In [10]:
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from exp.exp_main import Exp_Main
from models.treeMoE import Model
from utils.inference_and_explain import Dataset_SNP_XAI, inference_and_explain
from utils.result_to_df import leaf_selection_to_dataframe, leaf_influence_to_dataframe, global_importance_to_dataframes, node_level_stats_to_dataframe

### Step 1. Environment settings

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("[INFO] Using device:", device)

# One shared seed for every RNG used in this notebook (Python, NumPy, PyTorch).
fix_seed = 2025
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(fix_seed)

class Args:
    """Plain attribute container holding the settings used throughout this notebook.

    The instance carries the module-level `device` plus a fixed set of
    model / gating / training hyperparameters as ordinary attributes.
    """

    def __init__(self):
        # Device selected at notebook start-up (module-level `device`).
        self.device = device
        # NOTE(review): max_depth, hidden_dim_expert, alpha_fs, beta_fs and
        # use_gating_mlp appear to mirror the checkpoint directory name
        # (md2, hde16, al6, be2, mlpTrue) — confirm they match the trained run.
        vars(self).update(
            dim_input=437,
            num_classes=2,
            max_depth=2,
            hidden_dim_expert=16,
            alpha_fs=0.6568998058780943,
            beta_fs=0.2523567583654948,
            use_gating_mlp=1,
            gating_mlp_hidden=16,
            initial_temp=1.0,
            final_temp=0.2,
            anneal_epochs=20,
            learning_rate=0.003,
        )

# Shared configuration instance; read as a global by load_model and
# prepare_test_loader and passed explicitly to inference_and_explain.
args = Args()

[INFO] Using device: cuda


### Step 2. Model-loading and test-data-loader helper functions

In [3]:
def load_model(checkpoint_path):
    """Build a `Model` from the notebook-level `args` and load trained weights if present.

    Args:
        checkpoint_path: Path to a file containing a state dict that
            `torch.load` can read.

    Returns:
        The model on `device`, switched to evaluation mode. If
        `checkpoint_path` does not exist, a warning is printed and the
        freshly initialised (untrained) model is returned instead.
    """
    model = Model(args).to(device)
    if os.path.exists(checkpoint_path):
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # from a trusted source (or pass weights_only=True if the installed
        # PyTorch version supports it).
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"[INFO] Model loaded from {checkpoint_path}")
    else:
        print(f"[WARNING] checkpoint not found at {checkpoint_path}; using untrained model.")
    # This notebook only runs inference, so mode-dependent layers (e.g. dropout,
    # batch norm) must use their evaluation behaviour. Call model.train()
    # explicitly if the returned model is to be trained further.
    model.eval()
    return model

In [4]:
def prepare_test_loader(stop_loss, stock_name, batch_size=64,
                        root_path='./dataset', data_path='FULL.csv'):
    """Create a non-shuffled DataLoader over the test split of `Dataset_SNP_XAI`.

    Args:
        stop_loss: Forwarded to `Dataset_SNP_XAI` as `stop_loss`.
        stock_name: Forwarded to `Dataset_SNP_XAI` as `stock_name`.
        batch_size: Number of samples per batch (default 64).
        root_path: Dataset directory forwarded to `Dataset_SNP_XAI`
            (default './dataset').
        data_path: Data file name forwarded to `Dataset_SNP_XAI`
            (default 'FULL.csv').

    Returns:
        A `DataLoader` over the test dataset with `shuffle=False`.
    """
    test_dataset = Dataset_SNP_XAI(
        args=args,
        root_path=root_path,
        data_path=data_path,
        flag='test',
        scale=True,
        stop_loss=stop_loss,
        stock_name=stock_name
    )
    # shuffle=False keeps the samples in dataset order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    print(f"[INFO] Test set size: {len(test_dataset)}")
    return test_loader


In [5]:
# Trained treeMoE checkpoint to be explained.
checkpoint_file = "./checkpoints/treeMoE_FULL_md2_hde16_al6_be2_mlpTrue_Exp_0/checkpoint.pth"
model = load_model(checkpoint_path=checkpoint_file)

# AAPL test split with stop_loss=0. Loaders for other stop_loss settings (2-5)
# can be built the same way with prepare_test_loader(stop_loss=..., stock_name="AAPL").
test_loader_0 = prepare_test_loader(stop_loss=0, stock_name="AAPL")

[INFO] Model loaded from ./checkpoints/treeMoE_FULL_md2_hde16_al6_be2_mlpTrue_Exp_0/checkpoint.pth
[INFO] Test set size: 279


In [6]:
# Explanation settings for this run, collected in one place so they are easy to tweak.
explain_options = dict(
    # NOTE(review): the previous inline note listed only 'soft' as the alternative;
    # presumably 'hard' is the other option — confirm in inference_and_explain.
    gating_mode='soft',
    temperature=0.1,
    top_k_leaf=5,
    bottom_k_leaf=5,
    global_importance_mode='both',      # 'none'/'phi_only'/'gamma_weight'/'both'
    top_k_global=5,
    bottom_k_global=5,
    leaf_influence_mode='gamma_weight',  # 'none' or 'gamma_weight'
)

# Run the model on the test loader and collect metrics plus explanation artefacts.
(
    acc,
    mcc,
    Node_level_routing_stat,
    Leaf_feature_selection,
    Global_importance,
    Leaf_influence,
) = inference_and_explain(args, model, test_loader_0, **explain_options)


===== Test Accuracy (soft gating) = 0.5556 =====

[Confusion Matrix (soft)]
[[ 33  91]
 [ 33 122]]

[Classification Report]
              precision    recall  f1-score   support

           0     0.5000    0.2661    0.3474       124
           1     0.5728    0.7871    0.6630       155

    accuracy                         0.5556       279
   macro avg     0.5364    0.5266    0.5052       279
weighted avg     0.5404    0.5556    0.5227       279

MCC: 0.0622

=== Node-level Routing Statistics (Soft Gating) ===
[Node] root | coverage=279.000 | label0=124.000, label1=155.000
[Node] root_L | coverage=82.580 | label0=37.658, label1=44.922
[Leaf] root_L_L | coverage=13.335 | label0=5.652, label1=7.683, acc=0.5700, mcc=0.0563
[Leaf] root_L_R | coverage=69.245 | label0=32.005, label1=37.240, acc=0.5319, mcc=0.0458
[Node] root_R | coverage=196.420 | label0=86.342, label1=110.078
[Leaf] root_R_L | coverage=76.075 | label0=32.225, label1=43.849, acc=0.5664, mcc=0.0754
[Leaf] root_R_R | coverage

In [7]:
# Echo the headline metrics returned by inference_and_explain.
print("acc", acc, sep=" : ")
print("mcc", mcc, sep=" : ")

acc : 0.5555555555555556
mcc : 0.062235355197987896


In [8]:
# Always (re)bind `data_gamma` so a later cell can never pick up a stale value
# left in the kernel by an earlier run, nor fail with an unexplained NameError.
if 'gamma_weight' in Leaf_influence:
    data_gamma = Leaf_influence['gamma_weight']
else:
    data_gamma = None
    print("[WARNING] 'gamma_weight' not found in Leaf_influence; "
          "re-run inference_and_explain with leaf_influence_mode='gamma_weight'.")

In [11]:
# Turn the raw explanation outputs into tabular form using the helpers
# imported from utils.result_to_df.
node = node_level_stats_to_dataframe(Node_level_routing_stat)
df_leaves = leaf_selection_to_dataframe(Leaf_feature_selection)
# feature_names=None here; a later cell recomputes df_cov/df_imp with the
# real column names read from col.csv.
df_cov, df_imp = global_importance_to_dataframes(Global_importance, feature_names=None)
# data_gamma is the 'gamma_weight' entry of Leaf_influence (set in the cell above).
df_leaf_infl = leaf_influence_to_dataframe(data_gamma)

In [16]:
# Load the full dataset; its column header is used below to derive the
# feature names. (pandas is imported in the notebook's top import cell.)
df = pd.read_csv("./dataset/FULL.csv")

In [26]:
# Columns of FULL.csv that are excluded from the feature-name list.
# NOTE(review): presumably date, label (Y*) and ticker columns — confirm.
non_feature_cols = ["Date", "Y", "Y_2", "Y_3", "Y_4", "Y_5", "Stock"]

# Only the header matters, so work on head() to avoid copying the whole frame.
feature_cols = list(df.head().drop(columns=non_feature_cols).columns)

# Persist the remaining column names as a one-column CSV ("col").
feature_cols_frame = pd.DataFrame({"col": feature_cols})
feature_cols_frame.to_csv("./dataset/col.csv", index=False)

In [27]:
# Read back the feature-name list written to col.csv (single column named "col").
col = pd.read_csv("./dataset/col.csv")

In [31]:
# Recompute the global-importance tables, this time labelling features with
# the column names loaded from col.csv (overwrites the earlier df_cov/df_imp).
feature_names = list(col["col"])
df_cov, df_imp = global_importance_to_dataframes(
    Global_importance,
    feature_names=feature_names,
)


In [22]:
# Display the global importance table (columns: type, feature_name, importance).
df_imp

Unnamed: 0,type,feature_name,importance
0,phi_only,SMA_2,0.271761
1,phi_only,EMA_2,0.407247
2,phi_only,ROC_2,0.729316
3,phi_only,RSI_2,0.887865
4,phi_only,WR_2,0.831899
...,...,...,...
869,gamma_weight,CloseStd_28,0.479874
870,gamma_weight,Range_29,0.206074
871,gamma_weight,CloseStd_29,0.223853
872,gamma_weight,Range_30,0.701853
