In [None]:
import duckdb
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Notebook-wide pandas display settings: never truncate cell contents,
# never wrap wide frames onto several lines, and show every column.
pd.options.display.max_colwidth = None
pd.options.display.expand_frame_repr = False
pd.options.display.max_columns = None

In [None]:
# Load the whole `allele` table into a DataFrame.
# The database is opened read-only because this notebook only reads from it,
# and the connection is closed in `finally` so it is released even when the
# query fails (e.g. the table is missing).
con = duckdb.connect(r"C:\Users\vigne\Desktop\Capstone\datasets\model_train_data.duckdb", read_only=True)
try:
    adf = con.execute("select * from allele").fetch_df()
finally:
    con.close()

# Majority Undersampling
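- The dataset contains far more benign variants than pathogenic ones.
- To balance the classes, the benign (majority) class is randomly undersampled to match the number of pathogenic variants.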

In [None]:
# Undersample the benign rows (label 0) down to the number of pathogenic rows
# (label 1) so both labels are equally represented. The fixed seed makes the
# draw reproducible.
is_pathogenic = adf['ClinicalSignificance'] == 1
is_benign = adf['ClinicalSignificance'] == 0

pathogenic_df = adf[is_pathogenic]
benign_df = adf[is_benign]
benign_sampled = benign_df.sample(n=len(pathogenic_df), random_state=42)

balanced_df = pd.concat([pathogenic_df, benign_sampled])

In [None]:
# Every column except the identifiers and the label is a model input.
non_feature_cols = ('AlleleID', 'ClinicalSignificance', 'GeneID')
feature_cols_allele = [col for col in balanced_df.columns if col not in non_feature_cols]

X = balanced_df[feature_cols_allele]
y = balanced_df['ClinicalSignificance']

# Stratified 70 / 15 / 15 split: hold out 30% first, then cut that hold-out
# in half to obtain the validation and test sets.
trainx, X_temp, trainy, y_temp = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)
valx, testx, valy, testy = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
)

In [None]:
# Reference-base, alternate-base and chromosome indicator columns that are
# excluded from the model inputs. Defined once so the three splits cannot
# drift apart.
# NOTE(review): group_concepts() later matches the 'chr_', 'ref_is_' and
# 'alt_is_' prefixes; with these columns removed those matches come up empty.
dropped_feature_cols = [
    'ref_is_A', 'ref_is_T', 'ref_is_G', 'ref_is_C',
    'alt_is_A', 'alt_is_T', 'alt_is_G', 'alt_is_C',
    'chr_11', 'chr_6', 'chr_2', 'chr_20', 'chr_10', 'chr_16', 'chr_22',
    'chr_15', 'chr_1', 'chr_7', 'chr_8', 'chr_14', 'chr_21', 'chr_5',
    'chr_4', 'chr_19', 'chr_3', 'chr_17', 'chr_12', 'chr_18', 'chr_9',
    'chr_13', 'chr_MT', 'chr_Y', 'chr_X',
]

# Reassign instead of mutating in place, and ignore columns that are already
# gone, so re-running this cell is a no-op rather than a KeyError.
trainx = trainx.drop(columns=dropped_feature_cols, errors="ignore")
valx = valx.drop(columns=dropped_feature_cols, errors="ignore")
testx = testx.drop(columns=dropped_feature_cols, errors="ignore")

# SENN

## Model

### Conceptizer : Identity

In [None]:
class IdentityConceptizer(nn.Module):
    """
    Pass-through conceptizer: each input feature is used directly as a concept.

    - `encode` only appends a trailing singleton dimension; `decode` removes it again.
    - It exists so the model keeps the conceptizer / parameterizer / aggregator
      structure and a different conceptizer can be swapped in later.
    - Because the reconstruction equals the input, any reconstruction loss
      computed from it is ~0.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()

    def encode(self, x):
        """(BATCH, FEATURES) -> (BATCH, FEATURES, 1)."""
        return torch.unsqueeze(x, dim=-1)

    def decode(self, z):
        """(BATCH, FEATURES, 1) -> (BATCH, FEATURES)."""
        return torch.squeeze(z, dim=-1)

    def forward(self, x):
        """Return `(concepts, reconstruction)` for the input batch `x`."""
        concepts = self.encode(x)
        return concepts, self.decode(concepts)


In [None]:
class LinearParameterizer(nn.Module):
    """
    MLP that maps an input row to one relevance score per (concept, class) pair.

    - Default hidden layers: 128, 64, 32 (author's note: achieved 93% test accuracy).
    - Custom `hidden_sizes` can be provided for experimentation.
    - It is fed the raw input features rather than the concepts; with
      IdentityConceptizer the two are equivalent.

    Parameters
    ----------
    num_features : int
        Width of the input layer.
    num_concepts, num_classes : int
        The output layer has `num_concepts * num_classes` units and is reshaped
        to (BATCH, num_concepts, num_classes) in `forward`.
    hidden_sizes : sequence of int, optional
        Widths of the hidden layers; defaults to (128, 64, 32).
    dropout : float
        Dropout probability applied after every hidden linear layer.
    """

    def __init__(self, num_features, num_concepts, num_classes, hidden_sizes=None, dropout=0.3):
        super().__init__()
        self.num_concepts = num_concepts
        self.num_classes = num_classes

        if hidden_sizes is None:
            hidden_sizes = (128, 64, 32)
        layer_sizes = [num_features, *hidden_sizes, num_concepts * num_classes]

        layers = []
        output_layer_idx = len(layer_sizes) - 2
        for idx, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layers.append(nn.Linear(n_in, n_out))
            # Decide by position, not by width: a hidden layer that happens to
            # be as wide as the output layer must still get Dropout + ReLU.
            if idx < output_layer_idx:
                layers.append(nn.Dropout(dropout))
                layers.append(nn.ReLU())

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return relevances of shape (BATCH, num_concepts, num_classes)."""
        output = self.layers(x)
        return output.view(x.size(0), self.num_concepts, self.num_classes)


In [None]:
class SumAggregator(nn.Module):
    '''
        - Combines concepts and relevances into one score per class: for every
          class, the relevance-weighted sum of the concept values
          (implemented as a batched matrix product).
        - Returns log-probabilities over classes (log_softmax on dim 1).
    '''
    def __init__(self, num_classes, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.num_classes = num_classes

    def forward(self, concepts, relevances):
        # concepts:   (BATCH, NUM_CONCEPTS, 1)
        # relevances: (BATCH, NUM_CONCEPTS, NUM_CLASSES)
        class_major = relevances.transpose(1, 2)      # (BATCH, NUM_CLASSES, NUM_CONCEPTS)
        class_scores = torch.bmm(class_major, concepts)  # (BATCH, NUM_CLASSES, 1)
        return F.log_softmax(class_scores.squeeze(-1), dim=1)


In [None]:
class SENN(nn.Module):
    '''
        Self-explaining network assembled from three sub-modules.

        - conceptizer(x)   -> (concepts, reconstruction)
        - parameterizer(x) -> relevances (it receives the raw input, not the concepts)
        - aggregator(concepts, relevances) -> predictions

        forward returns (predictions, (concepts, relevances), reconstruction).
        With IdentityConceptizer the reconstruction is the input itself, so a
        reconstruction loss computed from it is 0.
    '''
    def __init__(self, conceptizer, parameterizer, aggregator):
        super().__init__()
        self.conceptizer = conceptizer
        self.parameterizer = parameterizer
        self.aggregator = aggregator

    def forward(self, x):
        concepts, recon_x = self.conceptizer(x)
        relevances = self.parameterizer(x)
        return self.aggregator(concepts, relevances), (concepts, relevances), recon_x

## Training Functions

### Create dataloaders
- batch size 64

In [None]:
def create_data_loader(X, y, batch_size=64, shuffle=True):
    """Wrap features and labels in a PyTorch DataLoader.

    Parameters
    ----------
    X : pandas.DataFrame or array-like
        Feature matrix; converted to a float32 tensor.
    y : pandas.Series or array-like
        Class labels; converted to an int64 tensor.
    batch_size : int
        Number of rows per batch.
    shuffle : bool
        Whether the loader reshuffles the rows every epoch.

    Returns
    -------
    torch.utils.data.DataLoader yielding `(features, labels)` batches.
    """
    # Convert through an explicitly typed NumPy array: a frame mixing bool and
    # integer columns has an object-dtype `.values`, which the legacy
    # torch.FloatTensor constructor rejects.
    features = torch.tensor(np.asarray(X, dtype=np.float32))
    labels = torch.tensor(np.asarray(y, dtype=np.int64))
    dataset = TensorDataset(features, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

### Training loop
- for 1 epoch

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over `train_loader`.

    The optimised loss is `criterion(predictions, target)` plus 0.01 times the
    MSE between the model's reconstruction and its input.

    Returns
    -------
    (mean_batch_loss, accuracy) : the optimised loss averaged over batches and
    the fraction of correctly classified samples, both measured while training.
    """
    model.train()
    running_loss, num_correct, num_seen = 0, 0, 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        log_probs, _explanations, reconstruction = model(inputs)

        classification_loss = criterion(log_probs, labels)
        # With an identity conceptizer the reconstruction equals the input, so
        # this term is 0; it is kept for experiments with other conceptizers.
        reconstruction_loss = F.mse_loss(reconstruction, inputs)
        batch_loss = classification_loss + 0.01 * reconstruction_loss

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        num_correct += (log_probs.argmax(dim=1) == labels).sum().item()
        num_seen += labels.size(0)

    return running_loss / len(train_loader), num_correct / num_seen

### Model evaluation

In [None]:
def evaluate(model, test_loader, criterion, device):
    """Score `model` on `test_loader` without updating it.

    Runs in eval mode with gradients disabled.

    Returns
    -------
    (mean_batch_loss, accuracy, all_preds, all_targets) where the loss is
    `criterion` averaged over batches, and the two lists hold the predicted
    and true class index of every sample in loader order.
    """
    model.eval()
    running_loss, num_correct, num_seen = 0, 0, 0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            log_probs, _explanations, _reconstruction = model(inputs)
            running_loss += criterion(log_probs, labels).item()

            predicted = log_probs.argmax(dim=1)
            num_correct += (predicted == labels).sum().item()
            num_seen += labels.size(0)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(test_loader)
    accuracy = num_correct / num_seen
    return avg_loss, accuracy, all_preds, all_targets

### Main Training
- Dropout : 0.3
- Criterion : NLL
- Optimizer : Adam
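- Scheduler : StepLR
    - reduces the learning rate
    - every 30 epochs
    - by a factor of gamma = 0.5
- Early stopping : a patience counter stops training when validation accuracy has not improved for 10 consecutive epochs (i.e. the model plateaus)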

In [None]:
def _plot_training_curves(train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss and accuracy for the training and validation sets side by side."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Validation Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    acc_ax.plot(train_accs, label='Train Accuracy')
    acc_ax.plot(val_accs, label='Validation Accuracy')
    acc_ax.set_title('Training and Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()


def train_senn_model(trainx, trainy, valx, valy,
                     num_epochs=100, batch_size=64, learning_rate=0.001,
                     checkpoint_path=r'C:\Users\vigne\Desktop\Capstone\models\best_senn_model.pth',
                     patience=10):
    """
    Complete SENN training pipeline.

    Builds a SENN (identity conceptizer, MLP parameterizer with hidden layers
    128/64/32 and dropout 0.3, sum aggregator) for binary classification,
    trains it with Adam + NLL loss, halves the learning rate every 30 epochs,
    and stops early when validation accuracy stops improving.

    Parameters:
    -----------
    trainx, trainy : Training features / labels
    valx, valy : Validation features / labels (used for checkpointing and early stopping)
    num_epochs : Maximum number of epochs
    batch_size : Mini-batch size for both loaders
    learning_rate : Initial Adam learning rate
    checkpoint_path : File that receives the state dict of the best epoch
        (highest validation accuracy). Its parent directory is created if missing.
    patience : Number of consecutive epochs without a validation-accuracy
        improvement after which training stops

    Returns:
    --------
    The trained model. It holds the weights of the LAST epoch run; the best
    weights are the ones stored at `checkpoint_path`.
    """
    from pathlib import Path  # local import: only needed to prepare the checkpoint directory

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data dimensions
    num_features = trainx.shape[1]
    num_concepts = num_features  # Each feature is a concept
    num_classes = 2  # Binary classification

    print(f"Number of features: {num_features}")
    print(f"Number of concepts: {num_concepts}")
    print(f"Number of classes: {num_classes}")

    # Only training and validation data are needed here; the test set is not
    # touched during training.
    train_loader = create_data_loader(trainx, trainy, batch_size, shuffle=True)
    val_loader = create_data_loader(valx, valy, batch_size, shuffle=False)

    # Initialize SENN components
    conceptizer = IdentityConceptizer()
    parameterizer = LinearParameterizer(
        num_features=num_features,
        num_concepts=num_concepts,
        num_classes=num_classes,
        hidden_sizes=[128, 64, 32],  # change if testing
        dropout=0.3
    )
    aggregator = SumAggregator(num_classes=num_classes)

    model = SENN(conceptizer, parameterizer, aggregator)
    model.to(device)

    # Loss and optimizer
    criterion = nn.NLLLoss()  # the aggregator already outputs log_softmax
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    best_val_acc = 0.0
    patience_counter = 0

    print("Starting training...")

    for epoch in tqdm(range(num_epochs), desc="Training Progress"):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step()

        # Track metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Checkpoint on improvement, otherwise count towards early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1

        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            print('-' * 50)

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    _plot_training_curves(train_losses, val_losses, train_accs, val_accs)

    return model

In [None]:
# Fit the SENN on the training split for at most 50 epochs, using the
# validation split for checkpointing and early stopping.
model = train_senn_model(
    trainx,
    trainy,
    valx,
    valy,
    num_epochs=50,
)

# OUTPUT Feature importance 

In [None]:
def get_feature_importance(model, feature_names, device, sample_input=None, class_names=("benign", "pathogenic")):
    """Summarise the SENN relevance scores per feature.

    Parameters
    ----------
    model : SENN-style module whose forward returns
        `(predictions, (concepts, relevances), reconstruction)` with
        relevances shaped (batch, num_concepts, num_classes).
    feature_names : sequence of str
        One name per concept, in column order.
    device : torch.device
        Device the model lives on; the input is moved there.
    sample_input : torch.Tensor, optional
        Batch of rows to explain, shape (batch, len(feature_names)).
        If omitted, a single random-normal row is used, which is only useful
        as a smoke test: the relevances depend on the input, so pass real data
        to get meaningful scores.
    class_names : sequence of str
        Output keys for the per-class scores, in class-index order.

    Returns
    -------
    dict mapping feature name -> {"global": mean relevance over batch and
    classes, <class name>: mean relevance over the batch for that class, ...}.
    """
    model.eval()

    if sample_input is None:
        import warnings
        warnings.warn(
            "get_feature_importance called without sample_input; "
            "relevances are computed on one random-noise row.",
            stacklevel=2,
        )
        sample_input = torch.randn(1, len(feature_names))
    # Caller-supplied tensors may live on another device than the model.
    sample_input = sample_input.to(device)

    with torch.no_grad():
        _predictions, (_concepts, relevances), _recon = model(sample_input)
        # relevances: [batch, num_concepts, num_classes]
        global_avg = relevances.mean(dim=(0, 2)).cpu().numpy()   # [num_concepts]
        per_class_avg = relevances.mean(dim=0).cpu().numpy()     # [num_concepts, num_classes]

    importance_dict = {}
    for i, name in enumerate(feature_names):
        scores = {"global": float(global_avg[i])}
        for class_idx, class_name in enumerate(class_names):
            scores[class_name] = float(per_class_avg[i, class_idx])
        importance_dict[name] = scores

    return importance_dict


In [None]:
def print_feature_importance_rankings(feature_importance):
    """Print the ten highest-|score| features for the global, benign and pathogenic scores.

    `feature_importance` maps feature name -> dict with the keys
    "global", "benign" and "pathogenic". Each section is ranked by absolute
    value, largest first, and prints the signed score.
    """
    sections = (
        ("Global", "global"),
        ("Benign", "benign"),
        ("Pathogenic", "pathogenic"),
    )

    for heading, score_key in sections:
        ranked = sorted(
            feature_importance.items(),
            key=lambda item: abs(item[1][score_key]),
            reverse=True,
        )
        print(f"\nTop 10 Most Important Features ({heading}):")
        for rank, (feature, scores) in enumerate(ranked[:10], start=1):
            print(f"{rank:2d}. {feature:<40} | {score_key}: {scores[score_key]:>8.4f}")



In [None]:
def group_concepts(feature_importance, feature_names):
    """Bucket features into named concepts by feature-name prefix.

    Only names listed in `feature_names` that also appear in
    `feature_importance` are considered. A feature belongs to a concept when
    its name starts with one of that concept's prefixes.

    Returns
    -------
    (concept_importance, detailed_contributions):
        concept -> sum of |global score| of its features, and
        concept -> {feature name: score dict} for the features it contains.
    """
    concept_prefixes = {
        'genomic_location': ('chr_', 'is_genomic', 'is_mitochondrial'),
        'sequence_change': ('ref_is_', 'alt_is_'),
        'gene_context': ('has_VariantGeneRelation_',),
        'molecular_consequence': ('has_MC_',),
        'data_source': ('has_Origin_',),
    }

    known_features = {
        name: feature_importance[name]
        for name in feature_names
        if name in feature_importance
    }

    concept_importance, detailed_contributions = {}, {}
    for concept, prefixes in concept_prefixes.items():
        # str.startswith accepts a tuple and matches any of its entries.
        members = {
            name: scores
            for name, scores in known_features.items()
            if name.startswith(prefixes)
        }

        total_importance = 0
        for scores in members.values():
            total_importance += abs(scores["global"])

        concept_importance[concept] = total_importance
        detailed_contributions[concept] = members

    return concept_importance, detailed_contributions


In [None]:
def print_concept_importance(concept_importance, detailed_contributions):
    """Print the concept ranking followed by the top five features of each concept.

    Concepts are ordered by importance (largest first); within a concept the
    features are ordered by |global score|. Concepts without features are
    skipped in the detailed section.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("CONCEPT-LEVEL IMPORTANCE ANALYSIS")
    print(rule)

    print("\nConcept-Level Importance (Global):")
    ranked_concepts = sorted(concept_importance.items(), key=lambda item: item[1], reverse=True)
    for concept, importance in ranked_concepts:
        print(f"{concept:<25}: {importance:>8.4f}")

    print("\nDetailed Feature Contributions by Concept:")
    for concept_name, features in detailed_contributions.items():
        if not features:
            continue

        print(f"\n{concept_name.upper()}:")
        top_features = sorted(
            features.items(),
            key=lambda item: abs(item[1]["global"]),
            reverse=True,
        )[:5]
        for feature, scores in top_features:
            print(f"  {feature:<30}: "
                  f"global={scores['global']:>6.3f}, "
                  f"benign={scores['benign']:>6.3f}, "
                  f"pathogenic={scores['pathogenic']:>6.3f}")


In [None]:
def final_result(model, testx, testy, feature_names, batch_size=64,
                 checkpoint_path=r'C:\Users\vigne\Desktop\Capstone\models\best_senn_model.pth'):
    """
    Complete SENN analysis: testing, feature importance, and concept grouping

    Args:
        model: SENN model with the same architecture as the saved checkpoint;
            its weights are overwritten with the checkpoint's
        testx, testy: Test features / labels
        feature_names: List of feature names (e.g., trainx.columns.tolist())
        batch_size: Batch size for testing
        checkpoint_path: State-dict file of the best model to load

    Returns:
        test_results, feature_importance, concept_importance, detailed_contributions
        where test_results is (test_loss, test_acc, test_preds, test_targets)
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the best checkpoint. map_location lets a checkpoint saved on GPU be
    # loaded on a CPU-only machine, and the model is moved to the device that
    # evaluate() sends the data to.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)

    # Test evaluation
    test_loader = create_data_loader(testx, testy, batch_size, shuffle=False)
    criterion = nn.NLLLoss()

    test_loss, test_acc, test_preds, test_targets = evaluate(model, test_loader, criterion, device)

    print(f"\nFinal Test Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print("\nDetailed Classification Report:")
    print(classification_report(test_targets, test_preds, target_names=['Benign', 'Pathogenic']))

    # Relevances are input-dependent, so average them over the real test rows
    # instead of the helper's random-noise fallback.
    test_features = torch.tensor(np.asarray(testx, dtype=np.float32)).to(device)
    feature_importance = get_feature_importance(
        model, feature_names, device, sample_input=test_features
    )

    # Print individual feature importance
    print_feature_importance_rankings(feature_importance)

    # Group into concepts
    concept_importance, detailed_contributions = group_concepts(feature_importance, feature_names)

    # Print concept-level results
    print_concept_importance(concept_importance, detailed_contributions)

    test_results = (test_loss, test_acc, test_preds, test_targets)
    return test_results, feature_importance, concept_importance, detailed_contributions

In [None]:
# Evaluate the trained model on the held-out test split and print the
# feature- and concept-level importance reports. Feature names come from the
# training frame's columns.
test_results, feature_importance, concept_importance, detailed_contributions = final_result(
    model,
    testx,
    testy,
    feature_names=list(trainx.columns),
)