In [1]:
import os
import torch
import numpy as np
import random

from torch.utils.data import DataLoader
from utils.inference_and_explain import Dataset_SNP_XAI, inference_and_explain
from exp.exp_main import Exp_Main
from models.treeMoE import Model

### Step 1. Environment settings (device, random seeds, model hyperparameters)

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Using device:", device)

# Seed Python's `random`, NumPy's legacy global RNG and PyTorch (in that order)
# so repeated runs of the notebook are reproducible.
fix_seed = 2025
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(fix_seed)


class Args:
    """Plain attribute namespace holding the hyperparameters shared by the notebook.

    An instance is passed as `args` to `Model`, `Dataset_SNP_XAI` and
    `inference_and_explain`.
    NOTE(review): these values must match the ones used to train the checkpoint
    loaded later (its directory name encodes md2 / hde16 / mlpTrue) — keep them
    in sync when switching checkpoints.
    """

    def __init__(self):
        vars(self).update(
            device=device,              # module-level torch.device selected above
            dim_input=437,              # input feature count (global importance vectors are length 437)
            num_classes=2,              # binary labels 0 / 1
            # --- tree / expert architecture (must match the checkpoint) ---
            max_depth=2,                # depth 2 -> 4 leaves (root_L_L ... root_R_R in the outputs)
            hidden_dim_expert=16,
            # presumably feature-selection regularisation weights — confirm in models.treeMoE
            alpha_fs=0.6568998058780943,
            beta_fs=0.2523567583654948,
            use_gating_mlp=1,           # truthy flag; presumably enables an MLP gate (checkpoint name says mlpTrue)
            gating_mlp_hidden=16,
            # --- gating temperature schedule / optimiser (presumably training-time only) ---
            initial_temp=1.0,
            final_temp=0.2,
            anneal_epochs=20,
            learning_rate=0.003,
        )


args = Args()

[INFO] Using device: cuda


### Step 2. Helper functions: load the model and build the test data loader

In [3]:
def load_model(checkpoint_path, strict=False):
    """Build a `Model` from the global `args` and restore trained weights if available.

    Args:
        checkpoint_path: Path to a ``state_dict`` file saved with ``torch.save``.
        strict: If True, raise ``FileNotFoundError`` when the checkpoint is missing
            instead of continuing with an untrained model. Defaults to False, which
            keeps the original warn-and-continue behaviour.

    Returns:
        The model, moved to the global ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: if ``strict`` is True and ``checkpoint_path`` does not exist.
    """
    model = Model(args).to(device)
    if os.path.exists(checkpoint_path):
        # map_location remaps the saved tensors onto `device`
        # (e.g. CUDA-saved weights on a CPU-only machine).
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"[INFO] Model loaded from {checkpoint_path}")
    elif strict:
        raise FileNotFoundError(f"checkpoint not found at {checkpoint_path}")
    else:
        print(f"[WARNING] checkpoint not found at {checkpoint_path}; using untrained model.")
    # The notebook uses the model only for inference/explanation; eval mode turns off
    # training-only behaviour of modules such as Dropout/BatchNorm, if Model has any.
    model.eval()
    return model

In [4]:
def prepare_test_loader(stop_loss, stock_name, root_path='./dataset',
                        data_path='FULL.csv', batch_size=64):
    """Build a non-shuffled DataLoader over the test split of `Dataset_SNP_XAI`.

    Args:
        stop_loss: Forwarded to ``Dataset_SNP_XAI`` (this notebook uses 0; other
            values such as 2-5 were tried). NOTE(review): exact meaning is defined
            by the dataset class — confirm there.
        stock_name: Ticker forwarded to ``Dataset_SNP_XAI`` (e.g. "AAPL").
        root_path: Directory containing the data file. Defaults to './dataset'.
        data_path: CSV file name inside ``root_path``. Defaults to 'FULL.csv'.
        batch_size: Mini-batch size for the loader. Defaults to 64.

    Returns:
        A ``DataLoader`` over the test split, in dataset order.
    """
    test_dataset = Dataset_SNP_XAI(
        args=args,
        root_path=root_path,
        data_path=data_path,
        flag='test',
        scale=True,  # presumably applies the dataset's feature scaling — confirm in Dataset_SNP_XAI
        stop_loss=stop_loss,
        stock_name=stock_name,
    )
    # shuffle=False keeps evaluation order deterministic across runs.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    print(f"[INFO] Test set size: {len(test_dataset)}")
    return test_loader


In [5]:
# Settings for this evaluation run.
CHECKPOINT_PATH = "./checkpoints/treeMoE_FULL_md2_hde16_al6_be2_mlpTrue_Exp_0/checkpoint.pth"
STOCK_NAME = "AAPL"
# Stop-loss settings to build test loaders for. Other values (2, 3, 4, 5) were
# previously tried via commented-out copies of the call; add them here instead.
STOP_LOSS_VALUES = [0]

model = load_model(checkpoint_path=CHECKPOINT_PATH)
test_loaders = {
    stop_loss: prepare_test_loader(stop_loss=stop_loss, stock_name=STOCK_NAME)
    for stop_loss in STOP_LOSS_VALUES
}
# Later cells evaluate the stop_loss=0 split under this name (requires 0 in STOP_LOSS_VALUES).
test_loader_0 = test_loaders[0]

[INFO] Model loaded from ./checkpoints/treeMoE_FULL_md2_hde16_al6_be2_mlpTrue_Exp_0/checkpoint.pth
[INFO] Test set size: 279


In [10]:
# Run inference on the stop_loss=0 test split and collect every explanation artifact.
explain_options = {
    "gating_mode": "soft",               # soft routing (output reports "soft gating"); other modes — confirm in inference_and_explain
    # NOTE(review): differs from args.final_temp (0.2) — confirm 0.1 is intended.
    "temperature": 0.1,
    "top_k_leaf": 5,                     # per-leaf: k highest-scoring features
    "bottom_k_leaf": 5,                  # per-leaf: k lowest-scoring features
    "global_importance_mode": "both",    # 'none' / 'phi_only' / 'gamma_weight' / 'both'
    "top_k_global": 5,
    "bottom_k_global": 5,
    "leaf_influence_mode": "gamma_weight",  # 'none' or 'gamma_weight'
}
(
    acc,
    mcc,
    Node_level_routing_stat,
    Leaf_feature_selection,
    Global_importance,
    Leaf_influence,
) = inference_and_explain(args, model, test_loader_0, **explain_options)


===== Test Accuracy (soft gating) = 0.5556 =====

[Confusion Matrix (soft)]
[[ 33  91]
 [ 33 122]]

[Classification Report]
              precision    recall  f1-score   support

           0     0.5000    0.2661    0.3474       124
           1     0.5728    0.7871    0.6630       155

    accuracy                         0.5556       279
   macro avg     0.5364    0.5266    0.5052       279
weighted avg     0.5404    0.5556    0.5227       279

MCC: 0.0622

=== Node-level Routing Statistics (Soft Gating) ===
[Node] root | coverage=279.000 | label0=124.000, label1=155.000
[Node] root_L | coverage=82.580 | label0=37.658, label1=44.922
[Leaf] root_L_L | coverage=13.335 | label0=5.652, label1=7.683, acc=0.5700, mcc=0.0563
[Leaf] root_L_R | coverage=69.245 | label0=32.005, label1=37.240, acc=0.5319, mcc=0.0458
[Node] root_R | coverage=196.420 | label0=86.342, label1=110.078
[Leaf] root_R_L | coverage=76.075 | label0=32.225, label1=43.849, acc=0.5664, mcc=0.0754
[Leaf] root_R_R | coverage

In [11]:
# Headline test-set metrics returned by inference_and_explain.
for metric_label, metric_value in (("acc :", acc), ("mcc :", mcc)):
    print(metric_label, metric_value)


acc : 0.5555555555555556
mcc : 0.062235355197987896


In [12]:
# One dict per tree node (internal nodes and leaves, see `is_leaf`), in returned order.
for node_stat in Node_level_routing_stat:
    print(node_stat)


{'node_path': 'root', 'coverage': 279.0, 'coverage_label0': 124.0, 'coverage_label1': 155.0, 'acc_leaf': None, 'mcc_leaf': None, 'is_leaf': False}
{'node_path': 'root_L', 'coverage': 82.57977294921875, 'coverage_label0': 37.657535552978516, 'coverage_label1': 44.922237396240234, 'acc_leaf': None, 'mcc_leaf': None, 'is_leaf': False}
{'node_path': 'root_L_L', 'coverage': 13.33462905883789, 'coverage_label0': 5.652087688446045, 'coverage_label1': 7.682542324066162, 'acc_leaf': 0.5699782196756269, 'mcc_leaf': 0.05633554408682479, 'is_leaf': True}
{'node_path': 'root_L_R', 'coverage': 69.24514770507812, 'coverage_label0': 32.00544738769531, 'coverage_label1': 37.23969650268555, 'acc_leaf': 0.5319451085741977, 'mcc_leaf': 0.04576423537046383, 'is_leaf': True}
{'node_path': 'root_R', 'coverage': 196.42022705078125, 'coverage_label0': 86.34246826171875, 'coverage_label1': 110.0777587890625, 'acc_leaf': None, 'mcc_leaf': None, 'is_leaf': False}
{'node_path': 'root_R_L', 'coverage': 76.074813842

In [13]:
# Leaf_feature_selection: per leaf, its top-k and bottom-k (feature, score) pairs.
for leaf_entry in Leaf_feature_selection:
    print(leaf_entry)

{'leaf_path': 'root_L_L', 'top_features': [('StoK_2', 2.538090229034424), ('WR_3', 2.4669647216796875), ('WR_2', 2.235210418701172), ('ROC_24', 2.2015724182128906), ('CloseStd_2', 2.102259874343872)], 'bottom_features': [('Volume_EMA_3', -3.0052690505981445), ('BBup_3', -2.972235918045044), ('Volume_EMA_23', -2.6266863346099854), ('StoD_27', -2.553452730178833), ('BBup_15', -2.546664237976074)]}
{'leaf_path': 'root_L_R', 'top_features': [('ROC_10', 3.221736431121826), ('ROC_9', 3.0836308002471924), ('ROC_8', 3.007431983947754), ('ROC_23', 2.963824987411499), ('ROC_27', 2.9246933460235596)], 'bottom_features': [('BBup_2', -3.526888132095337), ('EMA_19', -3.4113845825195312), ('SMA_15', -3.394969940185547), ('SMA_27', -3.3744688034057617), ('BBdn_12', -3.3671507835388184)]}
{'leaf_path': 'root_R_L', 'top_features': [('ROC_17', 2.9934256076812744), ('ROC_26', 2.705493688583374), ('RSI_27', 2.6053860187530518), ('RSI_3', 2.5873496532440186), ('ROC_10', 2.5546252727508545)], 'bottom_feature

In [14]:
# Global_importance: one entry per weighting mode (presumably aggregated across leaves).
print("Global_importance keys =>", [*Global_importance])
if 'phi_only' in Global_importance:
    phi_entry = Global_importance['phi_only']
    # importance: one score per input feature; coverage_dict: leaf path -> share of
    # test samples routed there (appears to equal node coverage / test-set size).
    imp, cov = phi_entry['importance'], phi_entry['coverage_dict']
    print("phi_only => coverage:", cov)
    print("phi_only => importance shape:", imp.shape)

Global_importance keys => ['phi_only', 'gamma_weight']
phi_only => coverage: {'root_L_L': 0.047794368118047714, 'root_L_R': 0.2481904923915863, 'root_R_L': 0.27266958355903625, 'root_R_R': 0.43134555220603943}
phi_only => importance shape: (437,)


In [15]:

# Same summary for the gamma-weighted global importance.
if 'gamma_weight' in Global_importance:
    gamma_entry = Global_importance['gamma_weight']
    imp_g, cov_g = gamma_entry['importance'], gamma_entry['coverage_dict']
    print("gamma_weight => coverage:", cov_g)
    print("gamma_weight => importance shape:", imp_g.shape)

gamma_weight => coverage: {'root_L_L': 0.047794368118047714, 'root_L_R': 0.2481904923915863, 'root_R_L': 0.27266958355903625, 'root_R_R': 0.43134555220603943}
gamma_weight => importance shape: (437,)


In [16]:
# Leaf_influence: per mode, a list of per-leaf influence summaries.
print("Leaf_influence keys =>", [*Leaf_influence])
if 'gamma_weight' in Leaf_influence:
    # Each entry: {"leaf_path", "coverage", "top_features", "bottom_features"}.
    # Feature tuples are (name, s1, s2, s3); in the outputs s1 appears to be s2 * s3
    # — confirm the meaning of each score in inference_and_explain.
    data_gamma = Leaf_influence['gamma_weight']
    for leaf_entry in data_gamma:
        print(leaf_entry)

Leaf_influence keys => ['gamma_weight']
{'leaf_path': 'root_L_L', 'coverage': 0.04779437280470325, 'top_features': [('CloseStd_2', 1.5669842958450317, 0.8911226391792297, 1.7584384679794312), ('ROC_24', 1.3782854080200195, 0.900390625, 1.530763864517212), ('MACD_Signal', 1.365856409072876, 0.8461136817932129, 1.614270567893982), ('StoK_2', 1.29862380027771, 0.9267693161964417, 1.4012373685836792), ('MACD', 1.218380093574524, 0.8440697193145752, 1.4434590339660645)], 'bottom_features': [('Volume_EMA_3', 0.03131188079714775, 0.04718839377164841, 0.6635504961013794), ('BBup_3', 0.034894563257694244, 0.048696037381887436, 0.7165791392326355), ('Volume_EMA_28', 0.03706947714090347, 0.07471566647291183, 0.49614059925079346), ('Volume_SMA_25', 0.03841418772935867, 0.08139877021312714, 0.47192588448524475), ('SMA_9', 0.04062732309103012, 0.08424815535545349, 0.4822339713573456)]}
{'leaf_path': 'root_L_R', 'coverage': 0.24819049356658826, 'top_features': [('RSI_26', 4.348038673400879, 0.9263532