# TCR-GEX Joint Analysis - EAE Dataset

This notebook performs joint analysis of T-cell receptor (TCR) and gene expression (GEX) data to predict tissue localization using the EAE_allTcells.csv dataset.

**Created:** Fri Aug 8 14:24:07 2025  
**Author:** a4945  
**Updated:** For EAE dataset analysis

## Overview
This script predicts tissue localization (CN vs SP) using:
- DeepTCR embeddings (Temb_0 to Temb_95)
- Gene expression features (CD4, CD8a, CD8b1, NKG7, Foxp3, etc.)
- T-cell receptor distance (TCRdist_MOG)
- Cell type information (modified_cell_type)
- Clone size information (clone_id_size)

## Train/Test Split
- **Test set:** mouse_id '5_3' and '5_4'
- **Training set:** All other mouse_ids

## Output Management
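All results are saved to a folder named `results_<test name>`, where `<test name>` is the name you enter in the first cell (a timestamped default is used if you leave it blank).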


In [None]:
#%% Test Name Input and Output Directory Setup
import os
from datetime import datetime

def setup_test_environment():
    """
    Ask for a test name and create the matching output directory.

    The name is read interactively. If nothing is entered, or no input stream
    is available (``input`` raises ``EOFError``), a timestamped default of the
    form ``EAE_analysis_<YYYYmmdd_HHMMSS>`` is used instead.

    Path separators, the characters ``: * ? " < > |`` and control characters
    are replaced by ``_``: the name becomes part of the directory name
    returned here, so it has to be a single valid path component.

    Returns
    -------
    tuple of (str, str)
        ``(test_name, output_dir)``, where ``output_dir`` is
        ``results_<test_name>`` relative to the current working directory.
        The directory exists when the function returns.
    """
    import re  # local import: `re` is not among the notebook's top-level imports

    # Get test name from user input; treat a closed/absent stdin like an empty answer
    try:
        raw_name = input("Enter test name (e.g., 'EAE_analysis_2025'): ").strip()
    except EOFError:
        raw_name = ""

    # Keep the name usable as one path component on common filesystems
    test_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", raw_name)
    if test_name != raw_name:
        print(f"Replaced characters not allowed in file names, using: {test_name}")

    if not test_name:
        # Generate default name with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        test_name = f"EAE_analysis_{timestamp}"
        print(f"No name provided, using default: {test_name}")

    # Create output directory (re-using an existing one is allowed)
    output_dir = f"results_{test_name}"
    os.makedirs(output_dir, exist_ok=True)

    print(f"Test name: {test_name}")
    print(f"Output directory: {output_dir}")
    print(f"All results will be saved to: {os.path.abspath(output_dir)}")

    return test_name, output_dir

# Setup test environment: prompts for the test name once, when this cell runs.
# TEST_NAME and OUTPUT_DIR are module-level names read by the later cells that
# write the dataset summary, the plots, the model files and the results table.
TEST_NAME, OUTPUT_DIR = setup_test_environment()


In [None]:
# -*- coding: utf-8 -*-
"""
Created on Fri Aug  8 14:24:07 2025

@author: a4945
Updated for EAE dataset analysis
"""

import warnings
warnings.filterwarnings("ignore")
import os
os.environ["SCIPY_ARRAY_API"] = "1"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from sklearn.model_selection import KFold
torch.manual_seed(455)

import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(455)

import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Read the EAE table into memory and print a quick overview of it
print("Loading EAE_allTcells.csv...")
df_all_features = pd.read_csv('EAE_allTcells.csv')

print(f"Dataset shape: {df_all_features.shape}")
print(f"Columns: {df_all_features.columns.tolist()}")

# Value counts of the three columns that drive the later split and filtering
for _label, _column in (
    ("Tissue", 'tissue'),
    ("Mouse ID", 'mouse_id'),
    ("Cell type", 'modified_cell_type'),
):
    print(f"{_label} distribution: {df_all_features[_column].value_counts()}")


In [None]:
# Use the GPU when PyTorch reports one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:
#%% data pre-processing
from imblearn.over_sampling import RandomOverSampler
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

def preprocessing(df_in, target):
    """
    One-hot encode the text-like columns of ``df_in`` and balance the classes.

    Columns of dtype object, string or category are expanded with
    ``pd.get_dummies`` (uint8 indicators, plus one extra indicator per column
    for missing values); all other columns are passed through unchanged.
    The encoded table and ``target`` are then resampled with
    ``RandomOverSampler(random_state=0)``.

    Returns
    -------
    tuple
        ``(X_resampled, Y_resampled, feature_names)`` where ``feature_names``
        holds the column labels of the encoded table as strings.
    """
    categorical_columns = df_in.select_dtypes(
        include=["object", "string", "category"]
    ).columns

    encoded = pd.get_dummies(
        df_in,
        columns=categorical_columns,
        dtype="uint8",
        dummy_na=True,
    )
    # Column labels are forced to str before they are handed to the sampler
    encoded.columns = encoded.columns.astype(str)
    feature_names = encoded.columns

    sampler = RandomOverSampler(random_state=0)
    X_resampled, Y_resampled = sampler.fit_resample(encoded, target)
    return X_resampled, Y_resampled, feature_names


## Model Architecture

Define the TCR classifier neural network model with regularization.


In [None]:
#%% build model
class TCRClassifier(nn.Module):
    """
    Feed-forward classifier: input -> 64 -> 32 -> 16 -> ``num_classes``.

    Each hidden layer is followed by ReLU and dropout (0.3, 0.2, 0.2).

    Output convention
    -----------------
    * Training mode (``model.train()``): ``forward`` returns raw logits.
      ``nn.CrossEntropyLoss`` applies log-softmax itself, so feeding it
      already-softmaxed values would normalise twice and flatten the gradients.
    * Evaluation mode (``model.eval()``): ``forward`` returns class
      probabilities (softmax over dim 1), ready for argmax / ROC use.

    Parameters
    ----------
    input_size : int
        Number of input features per sample.
    num_classes : int
        Number of output classes.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(input_size, 64)
        self.dropout1 = nn.Dropout(0.3)
        self.layer2 = nn.Linear(64, 32)
        self.dropout2 = nn.Dropout(0.2)
        self.layer3 = nn.Linear(32, 16)
        self.dropout3 = nn.Dropout(0.2)
        self.output = nn.Linear(16, num_classes)

        # Common coefficient for the L1 and L2 penalty terms in l1_l2_loss()
        self.l1_l2_reg = 0.01

    def forward(self, x):
        """Return logits in training mode and softmax probabilities in eval mode."""
        x = F.relu(self.layer1(x))
        x = self.dropout1(x)
        x = F.relu(self.layer2(x))
        x = self.dropout2(x)
        x = F.relu(self.layer3(x))
        x = self.dropout3(x)
        logits = self.output(x)
        if self.training:
            # The loss function is expected to do the normalisation
            return logits
        return F.softmax(logits, dim=1)

    def l1_l2_loss(self):
        """
        Return the regularisation penalty to add to the training loss.

        The penalty is ``l1_l2_reg * (sum of L1 norms + sum of L2 norms)``
        taken over every parameter tensor, biases included. The L2 term is
        the plain (not squared) norm of each tensor.
        """
        l1_loss = sum(torch.norm(p, 1) for p in self.parameters())
        l2_loss = sum(torch.norm(p, 2) for p in self.parameters())
        return self.l1_l2_reg * (l1_loss + l2_loss)

def build_model(input_size, num_classes):
    """Create a new ``TCRClassifier`` for ``input_size`` features and ``num_classes`` outputs."""
    return TCRClassifier(input_size, num_classes)


In [None]:
#%%      plot confusion matrix  
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

def plot_confusion(Y_test, class_pred, s, title):
    """
    Plot the confusion matrix of one evaluation and save it to ``OUTPUT_DIR``.

    Parameters
    ----------
    Y_test : array-like
        True class codes (0, 1, ...).
    class_pred : array-like
        Predicted class codes, same length as ``Y_test``.
    s : array-like or None
        Raw (un-encoded) target labels of the evaluated subset. Their sorted
        unique values are used as axis labels, matching the order in which
        ``astype('category').cat.codes`` numbers them. When ``None``, the
        labels are taken from the global ``df_all_features[target_class]``.
    title : str
        Plot title; also the stem of the saved file ``<title>_confusion.png``.
    """
    if s is None:
        s = df_all_features[target_class]
    class_labels = pd.Series(s).astype('category').cat.categories

    # Fixing `labels` keeps the matrix square with one row/column per class,
    # even when a class is missing from both the truth and the predictions.
    cm = confusion_matrix(Y_test, class_pred, labels=np.arange(len(class_labels)))

    # Accuracy = correctly classified (diagonal) / all samples; valid for any class count
    n_total = cm.sum()
    pred_accuracy = np.trace(cm) / n_total if n_total else float('nan')

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
    disp.plot(cmap='Blues')
    plt.title(str(title) + "  acc:" + str(round(pred_accuracy, 3)))

    # Save to output directory
    output_path = os.path.join(OUTPUT_DIR, f"{title}_confusion.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Confusion matrix saved to: {output_path}")
    plt.show()

## plot AUC 
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt

def plot_ROC(Y_test, test_pred, title):
    """
    Draw the ROC curve of a binary prediction and save it to ``OUTPUT_DIR``.

    Parameters
    ----------
    Y_test : array-like
        True binary labels (0 or 1).
    test_pred : 2-D array
        Per-class scores; column 1 is used as the score of class 1.
    title : str
        Plot title; also the stem of the saved file ``<title>_ROC.png``.
    """
    positive_scores = test_pred[:, 1]
    fpr, tpr, _ = roc_curve(Y_test, positive_scores)
    auc = roc_auc_score(Y_test, positive_scores)

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(fpr, tpr, label=f'AUC = {auc:.3f}')
    # Diagonal reference line, labelled as the random baseline
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(str(title))
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    # Write the figure into the run's output folder before displaying it
    output_path = os.path.join(OUTPUT_DIR, f"{title}_ROC.png")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"ROC curve saved to: {output_path}")
    plt.show()


## Model Training and Evaluation

Train the TCR classifier model and evaluate performance across different cell types and clonality conditions.


In [None]:
#%% Updated training configuration for EAE dataset
# Columns fed to the model, in three groups.
# 1) cell type, clone size and TCR distance
input_cat_features = ['modified_cell_type', 'clone_id_size', 'TCRdist_MOG']

# 2) DeepTCR embeddings (Temb_0 to Temb_95)
input_embs = ['Temb_' + str(idx) for idx in range(96)]

# 3) gene expression columns
gene_features = [
    'gene_Cd4', 'gene_Cd8a', 'gene_Cd8b1', 'gene_Nkg7', 'gene_Foxp3',
    'gene_Ikzf2', 'gene_Ctla4', 'gene_Il2ra', 'gene_Ccr6', 'gene_Il22',
]

# Number of passes over the training data per model
epoch_num = 100

# Column to predict (CN vs SP)
target_class = 'tissue'

# One model is trained per (clonality, cell type) combination
Bool_list = [True, False]                  # True: cloned cells, False: non-cloned cells
Cell_types = ['CD4+ T', 'CD8+ T', 'Treg']  # cell types analysed

# Keep only rows whose TCRdist_MOG is below 200
df_all_features = df_all_features.loc[df_all_features['TCRdist_MOG'] < 200]

print(f"Filtered dataset shape: {df_all_features.shape}")
print(f"Tissue distribution after filtering: {df_all_features['tissue'].value_counts()}")
print(f"Cell type distribution: {df_all_features['modified_cell_type'].value_counts()}")
print(f"Clonality distribution: {df_all_features['clone_id_size'].value_counts().head()}")

# Write the same overview (plus the per-mouse counts) next to the other outputs
summary_file = os.path.join(OUTPUT_DIR, f"{TEST_NAME}_dataset_summary.txt")
summary_lines = [
    f"Test Name: {TEST_NAME}\n",
    f"Dataset shape: {df_all_features.shape}\n",
    f"Tissue distribution:\n{df_all_features['tissue'].value_counts()}\n",
    f"Cell type distribution:\n{df_all_features['modified_cell_type'].value_counts()}\n",
    f"Mouse ID distribution:\n{df_all_features['mouse_id'].value_counts()}\n",
    f"Clonality distribution:\n{df_all_features['clone_id_size'].value_counts().head()}\n",
]
with open(summary_file, 'w') as f:
    f.writelines(summary_lines)
print(f"Dataset summary saved to: {summary_file}")


In [None]:
#%% Updated training loop for EAE dataset
# One classifier is trained and evaluated per (clonality, cell type) subset.
# Per subset the loop: splits by mouse, oversamples the TRAINING rows only,
# trains, evaluates on the untouched test rows, then saves plots, metrics and
# the model weights into OUTPUT_DIR.
from sklearn.metrics import accuracy_score

# Initialize results tracking
results_summary = []

# Cells from these mice form the held-out test group; all other mice train the model
test_id = ['5_3', '5_4']

# Number of output units = number of distinct target labels in the filtered dataset
num_classes = df_all_features[target_class].astype('category').value_counts().shape[0]

for ind1 in Bool_list:
    # Create clonality filter based on clone_id_size
    if ind1:
        M_sub1 = df_all_features[df_all_features['clone_id_size'] > 1]  # Cloned cells
        clonality_label = "Cloned"
    else:
        M_sub1 = df_all_features[df_all_features['clone_id_size'] == 1]  # Non-cloned cells
        clonality_label = "Non-cloned"

    for ind2 in Cell_types:
        M_sub = M_sub1[M_sub1['modified_cell_type'] == ind2]

        if len(M_sub) < 10:  # Skip if too few samples
            print(f"Skipping {clonality_label} {ind2}: only {len(M_sub)} samples")
            continue

        # NOTE: despite its name this is the model label ("<clonality>_<cell type>"),
        # used for plot titles and output file names.
        mouse_id = clonality_label + "_" + ind2
        features = M_sub[input_cat_features + input_embs + gene_features]

        target = M_sub[target_class].astype('category').cat.codes
        s = M_sub[target_class]

        test_idx = M_sub['mouse_id'].isin(test_id)

        if test_idx.sum() < 5:  # Skip if too few test samples
            print(f"Skipping {mouse_id}: only {test_idx.sum()} test samples")
            continue

        features_train = features[~test_idx]
        target_train = target[~test_idx]
        features_test = features[test_idx]
        target_test = target[test_idx]

        # Oversampling needs >= 2 classes in the training rows, and the ROC AUC
        # is undefined when the test rows contain a single class.
        if target_train.nunique() < 2 or target_test.nunique() < 2:
            print(f"Skipping {mouse_id}: train and test sets must each contain at least two classes")
            continue

        # Training rows: one-hot encode + random oversampling
        X_train, Y_train, feature_names = preprocessing(features_train, target_train)

        # Test rows: same encoding but NO resampling, so every metric is computed
        # on the real held-out cells. Columns are aligned to the training layout
        # (missing indicator columns are filled with 0, unseen ones are dropped).
        str_cols = features_test.select_dtypes(include=["object", "string", "category"]).columns
        X_test = pd.get_dummies(features_test, columns=str_cols, dtype="uint8", dummy_na=True)
        X_test.columns = X_test.columns.astype(str)
        X_test = X_test.reindex(columns=X_train.columns, fill_value=0)
        Y_test = target_test

        num_features = X_train.shape[1]

        print(f"Training {mouse_id}: {X_train.shape[0]} train, {X_test.shape[0]} test samples, {num_features} features")

        # Convert to PyTorch tensors
        X_train_tensor = torch.FloatTensor(X_train.values)
        Y_train_tensor = torch.LongTensor(Y_train.values)
        X_test_tensor = torch.FloatTensor(X_test.values)
        Y_test_tensor = torch.LongTensor(Y_test.values)

        # Create data loaders
        train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

        # Build and train model on the device selected earlier in the notebook
        model = build_model(num_features, num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.01)

        # Training loop
        model.train()
        for epoch in range(epoch_num):
            for batch_X, batch_Y in train_loader:
                batch_X = batch_X.to(device)
                batch_Y = batch_Y.to(device)
                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_Y) + model.l1_l2_loss()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

        #% test
        model.eval()
        with torch.no_grad():
            # .cpu() is required before .numpy() when the model runs on the GPU
            test_pred = model(X_test_tensor.to(device)).cpu().numpy()
        class_pred = np.argmax(test_pred, axis=1)

        # Calculate metrics (column 1 of test_pred is the score of class 1)
        accuracy = accuracy_score(Y_test, class_pred)
        auc = roc_auc_score(Y_test, test_pred[:,1])

        # Store results
        results_summary.append({
            'model_name': mouse_id,
            'train_samples': X_train.shape[0],
            'test_samples': X_test.shape[0],
            'num_features': num_features,
            'accuracy': accuracy,
            'auc': auc
        })

        print(f"Results for {mouse_id}: Accuracy = {accuracy:.3f}, AUC = {auc:.3f}")

        plot_confusion(Y_test, class_pred, s, mouse_id)
        plot_ROC(Y_test, test_pred, mouse_id)

        # Save model
        model_path = os.path.join(OUTPUT_DIR, f"{mouse_id}_model.pth")
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to: {model_path}")

# Save results summary (one row per trained model)
results_df = pd.DataFrame(results_summary)
results_file = os.path.join(OUTPUT_DIR, f"{TEST_NAME}_results_summary.csv")
results_df.to_csv(results_file, index=False)
print(f"Results summary saved to: {results_file}")
print(f"\nAll outputs saved to: {OUTPUT_DIR}")


## Model Saving and Loading

Optional code for saving and loading trained models.


In [None]:
#%%  save model
# torch.save(model.state_dict(), 'TCR_EAE.pth')
           
# model = build_model(num_features, num_classes)
# model.load_state_dict(torch.load('TCR_EAE.pth'))

# Note: Models are automatically saved during training to the output directory
# Each model is saved as: {clonality}_{cell_type}_model.pth


## SHAP Analysis

Model interpretability analysis using SHAP to understand feature importance.


In [None]:
#%% shap explain
# explain all the predictions in the test set

# import shap
# def shap_eavl(X_train, X_test, features):
    
    # Background (masker) — sample to keep things fast and stable
# rng = np.random.default_rng(0)
# bg_idx = rng.choice(X_train.shape[0], size=min(100, X_train.shape[0]), replace=False)
# background = X_train.iloc[bg_idx]

# Prediction function that includes preprocessing if you want to explain raw X
# Here we already precomputed X_train_s/X_test_s; if you'd rather pass raw X to SHAP,
# define: f = lambda data: model.predict(scaler.transform(data), verbose=0)
# def predict_function(data):
#     data_tensor = torch.FloatTensor(data.values)
#     model.eval()
#     with torch.no_grad():
#         return model(data_tensor).numpy()

# f = predict_function

# Create the explainer. `f` is a plain Python callable, so SHAP treats the model as a black box and picks a model-agnostic algorithm for it.
# explainer = shap.Explainer(f, shap.maskers.Independent(background))

# Use a manageable slice for speed (at most 30 test samples)
# sample_idx = rng.choice(X_test.shape[0], size=min(30, X_test.shape[0]), replace=False)
# X_eval = X_test.iloc[sample_idx]

# Compute explanations
# shap_values = explainer(X_eval)  # returns a shap.Explanation

# shap_values.feature_names = feature_names.tolist()

# k = 0  # or np.argmax(model.predict(X_eval), axis=1)[i] for per-sample class
# shap.plots.beeswarm(shap_values[:, :, k], max_display=5)        # class k

# or overall ranking across classes:
# shap.plots.bar(shap_values.abs.mean(axis=2), max_display=20)     # mean|SHAP| over classes

# shap_eavl(X_train, X_test, feature_names)


# TCR-GEX Joint Analysis - EAE Dataset

This notebook performs joint analysis of T-cell receptor (TCR) and gene expression (GEX) data to predict tissue localization using the EAE_allTcells.csv dataset.

**Created:** Fri Aug 8 14:24:07 2025  
**Author:** a4945  
**Updated:** For EAE dataset analysis

## Overview
This script predicts tissue localization (CN vs SP) using:
- DeepTCR embeddings (Temb_0 to Temb_95)
- Gene expression features (CD4, CD8a, CD8b1, NKG7, Foxp3, etc.)
- T-cell receptor distance (TCRdist_MOG)
- Cell type information (modified_cell_type)
- Clone size information (clone_id_size)

## Train/Test Split
- **Test set:** mouse_id '5_3' and '5_4'
- **Training set:** All other mouse_ids


In [None]:
# -*- coding: utf-8 -*-
"""
Created on Fri Aug  8 14:24:07 2025

@author: a4945
Updated for EAE dataset analysis
"""

import warnings
warnings.filterwarnings("ignore")
import os
os.environ["SCIPY_ARRAY_API"] = "1"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from sklearn.model_selection import KFold
torch.manual_seed(455)

import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(455)

import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Load the EAE table and report its size, columns and key value counts
print("Loading EAE_allTcells.csv...")
df_all_features = pd.read_csv('EAE_allTcells.csv')

_n_rows_cols = df_all_features.shape
_column_names = list(df_all_features.columns)
print(f"Dataset shape: {_n_rows_cols}")
print(f"Columns: {_column_names}")

_tissue_counts = df_all_features['tissue'].value_counts()
print(f"Tissue distribution: {_tissue_counts}")
_mouse_counts = df_all_features['mouse_id'].value_counts()
print(f"Mouse ID distribution: {_mouse_counts}")
_cell_type_counts = df_all_features['modified_cell_type'].value_counts()
print(f"Cell type distribution: {_cell_type_counts}")


In [None]:
# Select CUDA when PyTorch can see a GPU, else fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


In [None]:
#%% data pre-processing
from imblearn.over_sampling import RandomOverSampler
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

def preprocessing(df_in, target):
    """
    Encode the feature table and oversample it to balanced classes.

    Steps:
      1. Columns of dtype object / string / category are replaced by uint8
         indicator columns (``pd.get_dummies`` with ``dummy_na=True``);
         the remaining columns are kept unchanged.
      2. All column labels are converted to ``str``.
      3. Rows are resampled together with ``target`` by
         ``RandomOverSampler(random_state=0)``.

    Returns ``(X_resampled, Y_resampled, feature_names)``; ``feature_names``
    are the column labels after step 2.
    """
    text_like_dtypes = ["object", "string", "category"]
    str_cols = df_in.select_dtypes(include=text_like_dtypes).columns

    dummies_kwargs = {"columns": str_cols, "dtype": "uint8", "dummy_na": True}
    df = pd.get_dummies(df_in, **dummies_kwargs)
    df.columns = df.columns.astype(str)

    feature_names = df.columns

    X_resampled, Y_resampled = RandomOverSampler(random_state=0).fit_resample(df, target)
    return X_resampled, Y_resampled, feature_names


## Model Architecture

Define the TCR classifier neural network model with regularization.


In [None]:
#%% build model
class TCRClassifier(nn.Module):
    """
    Three-hidden-layer perceptron (64, 32, 16 units) with ReLU and dropout.

    ``forward`` returns raw logits while the module is in training mode and
    softmax probabilities in evaluation mode. The training loss used with
    this model, ``nn.CrossEntropyLoss``, expects unnormalised logits because
    it applies log-softmax internally. Handing it softmax outputs would apply
    the normalisation twice. Evaluation code that reads the output as
    probabilities (argmax, ROC) keeps working after ``model.eval()``.

    Parameters
    ----------
    input_size : int
        Number of input features per sample.
    num_classes : int
        Number of output classes.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(input_size, 64)
        self.dropout1 = nn.Dropout(0.3)
        self.layer2 = nn.Linear(64, 32)
        self.dropout2 = nn.Dropout(0.2)
        self.layer3 = nn.Linear(32, 16)
        self.dropout3 = nn.Dropout(0.2)
        self.output = nn.Linear(16, num_classes)

        # Weight applied to the combined L1 + L2 penalty (see l1_l2_loss)
        self.l1_l2_reg = 0.01

    def forward(self, x):
        """Compute logits (training mode) or class probabilities (eval mode)."""
        hidden = self.dropout1(F.relu(self.layer1(x)))
        hidden = self.dropout2(F.relu(self.layer2(hidden)))
        hidden = self.dropout3(F.relu(self.layer3(hidden)))
        logits = self.output(hidden)
        # Normalise only outside training; the loss normalises during training
        return logits if self.training else F.softmax(logits, dim=1)

    def l1_l2_loss(self):
        """
        Penalty term: ``l1_l2_reg`` times the sum, over all parameter tensors
        (weights and biases), of their L1 norm plus their unsquared L2 norm.
        """
        l1_loss = sum(torch.norm(p, 1) for p in self.parameters())
        l2_loss = sum(torch.norm(p, 2) for p in self.parameters())
        return self.l1_l2_reg * (l1_loss + l2_loss)

def build_model(input_size, num_classes):
    """Return a freshly initialised ``TCRClassifier(input_size, num_classes)``."""
    classifier = TCRClassifier(input_size=input_size, num_classes=num_classes)
    return classifier


In [None]:
#%%      plot confusion matrix  
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

def plot_confusion(Y_test, class_pred, s, title):
    """
    Plot the confusion matrix of one evaluation and save it as a PNG.

    Parameters
    ----------
    Y_test : array-like
        True class codes (0, 1, ...).
    class_pred : array-like
        Predicted class codes, same length as ``Y_test``.
    s : array-like or None
        Raw (un-encoded) target labels of the evaluated subset. Their sorted
        unique values label the axes, in the same order that
        ``astype('category').cat.codes`` numbers them. When ``None``, the
        global ``df_all_features[target_class]`` is used instead.
    title : str
        Plot title; the figure is written to ``<title>_confusion.png`` in the
        current working directory.
    """
    if s is None:
        s = df_all_features[target_class]
    class_labels = pd.Series(s).astype('category').cat.categories

    # One row/column per known class, even if a class never occurs in this split
    cm = confusion_matrix(Y_test, class_pred, labels=np.arange(len(class_labels)))

    # Diagonal = correct predictions; works for any number of classes
    n_total = cm.sum()
    pred_accuracy = np.trace(cm) / n_total if n_total else float('nan')

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
    disp.plot(cmap='Blues')
    plt.title(str(title) + "  acc:" + str(round(pred_accuracy, 3)))
    plt.savefig(str(title) + "_confusion.png")
    plt.show()

## plot AUC 
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt

def plot_ROC(Y_test, test_pred, title):
    """
    Plot the ROC curve for a binary classifier and save it as ``<title>_ROC.png``.

    ``Y_test`` holds the true 0/1 labels; column 1 of ``test_pred`` is used as
    the score of class 1. The figure is saved in the current working
    directory and then shown.
    """
    class1_scores = test_pred[:, 1]
    false_pos_rate, true_pos_rate, _ = roc_curve(Y_test, class1_scores)
    auc = roc_auc_score(Y_test, class1_scores)
    curve_label = f'AUC = {auc:.3f}'

    plt.figure(figsize=(6, 5))
    plt.plot(false_pos_rate, true_pos_rate, label=curve_label)
    # Dashed diagonal, labelled as the random baseline
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random')
    plt.title(str(title))
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(str(title) + "_ROC.png")
    plt.show()


## Model Training and Evaluation

Train the TCR classifier model and evaluate performance across different cell types and clonality conditions.


In [None]:
#%% Updated training configuration for EAE dataset
# Model input columns: cell type, clone size, TCR distance ...
input_cat_features = ['modified_cell_type', 'clone_id_size', 'TCRdist_MOG']

# ... the 96 DeepTCR embedding columns (Temb_0 to Temb_95) ...
input_embs = list(map('Temb_{}'.format, range(96)))

# ... and the gene expression columns
gene_features = [
    'gene_Cd4',
    'gene_Cd8a',
    'gene_Cd8b1',
    'gene_Nkg7',
    'gene_Foxp3',
    'gene_Ikzf2',
    'gene_Ctla4',
    'gene_Il2ra',
    'gene_Ccr6',
    'gene_Il22',
]

# Training length (epochs per model)
epoch_num = 100

# Drop rows whose TCRdist_MOG is 200 or more
_keep_rows = df_all_features['TCRdist_MOG'] < 200
df_all_features = df_all_features[_keep_rows]

# Subsets to model: cloned / non-cloned cells crossed with these cell types
Bool_list = [True, False]
Cell_types = ['CD4+ T', 'CD8+ T', 'Treg']

# Prediction target: CN vs SP
target_class = 'tissue'

print(f"Filtered dataset shape: {df_all_features.shape}")
print(f"Tissue distribution after filtering: {df_all_features['tissue'].value_counts()}")
print(f"Cell type distribution: {df_all_features['modified_cell_type'].value_counts()}")
print(f"Clonality distribution: {df_all_features['clone_id_size'].value_counts().head()}")


In [None]:
#%% Updated training loop for EAE dataset
# One classifier is trained and evaluated per (clonality, cell type) subset.
# Only the TRAINING rows are oversampled; the held-out rows are evaluated as-is.

# Cells from these mice form the held-out test group; all other mice train the model
test_id = ['5_3', '5_4']

# Number of output units = number of distinct target labels in the filtered dataset
num_classes = df_all_features[target_class].astype('category').value_counts().shape[0]

for ind1 in Bool_list:
    # Create clonality filter based on clone_id_size
    if ind1:
        M_sub1 = df_all_features[df_all_features['clone_id_size'] > 1]  # Cloned cells
        clonality_label = "Cloned"
    else:
        M_sub1 = df_all_features[df_all_features['clone_id_size'] == 1]  # Non-cloned cells
        clonality_label = "Non-cloned"

    for ind2 in Cell_types:
        M_sub = M_sub1[M_sub1['modified_cell_type'] == ind2]

        if len(M_sub) < 10:  # Skip if too few samples
            print(f"Skipping {clonality_label} {ind2}: only {len(M_sub)} samples")
            continue

        # NOTE: despite its name this is the model label ("<clonality>_<cell type>"),
        # used for plot titles and output file names.
        mouse_id = clonality_label + "_" + ind2
        features = M_sub[input_cat_features + input_embs + gene_features]

        target = M_sub[target_class].astype('category').cat.codes
        s = M_sub[target_class]

        test_idx = M_sub['mouse_id'].isin(test_id)

        if test_idx.sum() < 5:  # Skip if too few test samples
            print(f"Skipping {mouse_id}: only {test_idx.sum()} test samples")
            continue

        features_train = features[~test_idx]
        target_train = target[~test_idx]
        features_test = features[test_idx]
        target_test = target[test_idx]

        # Oversampling needs >= 2 classes in the training rows, and the ROC AUC
        # is undefined when the test rows contain a single class.
        if target_train.nunique() < 2 or target_test.nunique() < 2:
            print(f"Skipping {mouse_id}: train and test sets must each contain at least two classes")
            continue

        # Training rows: one-hot encode + random oversampling
        X_train, Y_train, feature_names = preprocessing(features_train, target_train)

        # Test rows: same encoding but NO resampling, so the plots reflect the
        # real held-out cells. Columns are aligned to the training layout
        # (missing indicator columns are filled with 0, unseen ones are dropped).
        str_cols = features_test.select_dtypes(include=["object", "string", "category"]).columns
        X_test = pd.get_dummies(features_test, columns=str_cols, dtype="uint8", dummy_na=True)
        X_test.columns = X_test.columns.astype(str)
        X_test = X_test.reindex(columns=X_train.columns, fill_value=0)
        Y_test = target_test

        num_features = X_train.shape[1]

        print(f"Training {mouse_id}: {X_train.shape[0]} train, {X_test.shape[0]} test samples, {num_features} features")

        # Convert to PyTorch tensors
        X_train_tensor = torch.FloatTensor(X_train.values)
        Y_train_tensor = torch.LongTensor(Y_train.values)
        X_test_tensor = torch.FloatTensor(X_test.values)
        Y_test_tensor = torch.LongTensor(Y_test.values)

        # Create data loaders
        train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

        # Build and train model on the device selected earlier in the notebook
        model = build_model(num_features, num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.01)

        # Training loop
        model.train()
        for epoch in range(epoch_num):
            for batch_X, batch_Y in train_loader:
                batch_X = batch_X.to(device)
                batch_Y = batch_Y.to(device)
                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_Y) + model.l1_l2_loss()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

        #% test
        model.eval()
        with torch.no_grad():
            # .cpu() is required before .numpy() when the model runs on the GPU
            test_pred = model(X_test_tensor.to(device)).cpu().numpy()
        class_pred = np.argmax(test_pred, axis=1)

        plot_confusion(Y_test, class_pred, s, mouse_id)
        plot_ROC(Y_test, test_pred, mouse_id)


## Model Saving and Loading

Optional code for saving and loading trained models.


In [None]:
#%%  save model
# torch.save(model.state_dict(), 'TCR_EAE.pth')
           
# model = build_model(num_features, num_classes)
# model.load_state_dict(torch.load('TCR_EAE.pth'))


## SHAP Analysis

Model interpretability analysis using SHAP to understand feature importance.


In [None]:
#%% shap explain
# explain all the predictions in the test set

# import shap
# def shap_eavl(X_train, X_test, features):
    
    # Background (masker) — sample to keep things fast and stable
# rng = np.random.default_rng(0)
# bg_idx = rng.choice(X_train.shape[0], size=min(100, X_train.shape[0]), replace=False)
# background = X_train.iloc[bg_idx]

# Prediction function that includes preprocessing if you want to explain raw X
# Here we already precomputed X_train_s/X_test_s; if you'd rather pass raw X to SHAP,
# define: f = lambda data: model.predict(scaler.transform(data), verbose=0)
# def predict_function(data):
#     data_tensor = torch.FloatTensor(data.values)
#     model.eval()
#     with torch.no_grad():
#         return model(data_tensor).numpy()

# f = predict_function

# Create the explainer. `f` is a plain Python callable, so SHAP treats the model as a black box and picks a model-agnostic algorithm for it.
# explainer = shap.Explainer(f, shap.maskers.Independent(background))

# Use a manageable slice for speed (at most 30 test samples)
# sample_idx = rng.choice(X_test.shape[0], size=min(30, X_test.shape[0]), replace=False)
# X_eval = X_test.iloc[sample_idx]

# Compute explanations
# shap_values = explainer(X_eval)  # returns a shap.Explanation

# shap_values.feature_names = feature_names.tolist()

# k = 0  # or np.argmax(model.predict(X_eval), axis=1)[i] for per-sample class
# shap.plots.beeswarm(shap_values[:, :, k], max_display=5)        # class k

# or overall ranking across classes:
# shap.plots.bar(shap_values.abs.mean(axis=2), max_display=20)     # mean|SHAP| over classes

# shap_eavl(X_train, X_test, feature_names)
