In [None]:
import duckdb
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
# NOTE(review): train_test_split (sklearn.model_selection) is used below but is not imported here.

In [None]:
pd.set_option('display.max_colwidth', None)  
pd.set_option('display.expand_frame_repr', False)  
pd.set_option('display.max_columns', None)

In [None]:
# Load the full allele feature table into a DataFrame.
# The context manager closes the connection even if the query raises.
DB_PATH = r"C:\Users\vigne\Desktop\Capstone\datasets\model_train_data.duckdb"
with duckdb.connect(DB_PATH) as con:
    adf = con.execute("select * from allele").fetch_df()

# Majority Undersampling
- There are far more benign variants than pathogenic ones.
- Undersample the benign (majority) class so that it has the same number of rows as the pathogenic class.

In [None]:
# Undersample the majority class: draw as many benign rows as there are
# pathogenic rows (fixed seed for reproducibility), then stack both classes.
# NOTE: sample() without replacement raises ValueError if there are fewer
# benign rows than pathogenic rows.
significance = adf['ClinicalSignificance']
pathogenic_df = adf.loc[significance == 1]
benign_df = adf.loc[significance == 0]
n_pathogenic = len(pathogenic_df)
benign_sampled = benign_df.sample(n=n_pathogenic, random_state=42)
balanced_df = pd.concat([pathogenic_df, benign_sampled])

In [None]:
# Model inputs: every column except the identifiers and the target label
# (column order is preserved).
excluded_cols = {'AlleleID', 'ClinicalSignificance', 'GeneID'}
feature_cols_allele = [c for c in balanced_df.columns if c not in excluded_cols]

X = balanced_df[feature_cols_allele]
y = balanced_df['ClinicalSignificance']

# 70/15/15 train/validation/test split, stratified on the label at both steps:
# first hold out 30%, then split that holdout in half into validation and test.
# NOTE(review): train_test_split (sklearn.model_selection) must be imported before this cell runs.
trainx, X_temp, trainy, y_temp = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
valx, testx, valy, testy = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

# SENN

## Model

### Conceptizer : Identity

In [None]:
class IdentityConceptizer(nn.Module):
    """
    Pass-through conceptizer: each raw input feature is treated as its own concept.

        - encode() only appends a trailing singleton dimension and decode() removes it,
          so nothing is learned or transformed here
        - Exists to keep the conceptizer -> parameterizer -> aggregator SENN layout,
          so a learned conceptizer can be swapped in later without restructuring
        - Consequence: the reconstruction equals the input, so any reconstruction loss is 0
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored, for interface compatibility.
        super().__init__()

    def encode(self, x):
        """(BATCH, FEATURES) -> (BATCH, FEATURES, 1)."""
        return x.unsqueeze(dim=-1)

    def decode(self, z):
        """(BATCH, FEATURES, 1) -> (BATCH, FEATURES)."""
        return z.squeeze(dim=-1)

    def forward(self, x):
        """Return (concepts, reconstruction) for input x."""
        concepts = self.encode(x)
        return concepts, self.decode(concepts)


In [None]:
class LinearParameterizer(nn.Module):
    """
    MLP mapping raw input features to per-concept, per-class relevance scores.

        - Hidden layers by default: 128, 64, 32 -> achieved 93% test accuracy
        - Custom hidden_sizes can be provided for experimentation
        - Takes raw input features (not concepts, so it is called with x or
          concepts.squeeze(-1)), since IdentityConceptizer makes them equivalent

    Args:
        num_features: width of each input feature vector.
        num_concepts: number of concepts to produce relevances for.
        num_classes: number of output classes.
        hidden_sizes: optional iterable of hidden-layer widths; defaults to (128, 64, 32).
        dropout: dropout probability applied after every hidden Linear layer.

    forward(x) returns a tensor of shape (BATCH, num_concepts, num_classes).
    """

    def __init__(self, num_features, num_concepts, num_classes, hidden_sizes=None, dropout=0.3):
        super().__init__()
        self.num_concepts = num_concepts
        self.num_classes = num_classes

        if hidden_sizes is None:
            hidden_sizes = (128, 64, 32)
        sizes = [num_features, *hidden_sizes, num_concepts * num_classes]

        layers = []
        output_layer_idx = len(sizes) - 2  # index of the final Linear layer
        for i, (h, h_next) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(h, h_next))
            # Decide by position rather than by width: a hidden layer whose width
            # happens to equal the output width must still get its activation.
            if i < output_layer_idx:
                layers.append(nn.Dropout(dropout))
                layers.append(nn.ReLU())

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        output = self.layers(x)
        return output.view(x.size(0), self.num_concepts, self.num_classes)


In [None]:
class SumAggregator(nn.Module):
    '''
        - Combines concepts with their class relevances as a per-class weighted sum,
          computed as one batched matrix product: relevances^T @ concepts
        - Returns log-probabilities via log_softmax over the class dimension
    '''
    def __init__(self, num_classes, **kwargs):
        super().__init__()
        self.num_classes = num_classes  # stored for reference; forward() does not read it

    def forward(self, concepts, relevances):
        # concepts:   (BATCH, NUM_CONCEPTS, 1)
        # relevances: (BATCH, NUM_CONCEPTS, NUM_CLASSES) -> (BATCH, NUM_CLASSES, NUM_CONCEPTS)
        per_class_weights = relevances.transpose(1, 2)
        # (BATCH, NUM_CLASSES, NUM_CONCEPTS) @ (BATCH, NUM_CONCEPTS, 1) -> (BATCH, NUM_CLASSES, 1)
        scores = torch.bmm(per_class_weights, concepts)
        logits = scores.squeeze(-1)
        return logits.log_softmax(dim=1)


In [None]:
class SENN(nn.Module):
    '''
        SENN model assembled from three pluggable parts:

        - conceptizer(x) -> (concepts, reconstruction)
        - parameterizer(x) -> relevances (fed the raw input x, not the concepts)
        - aggregator(concepts, relevances) -> predictions
        - forward returns: predictions, explanations=(concepts, relevances), reconstruction
        - Explanations show which concepts are relevant for each class prediction
        - With IdentityConceptizer the reconstruction equals the input (reconstruction loss = 0)
    '''
    def __init__(self, conceptizer, parameterizer, aggregator):
        super().__init__()
        self.conceptizer = conceptizer
        self.parameterizer = parameterizer
        self.aggregator = aggregator

    def forward(self, x):
        concepts, reconstruction = self.conceptizer(x)
        relevances = self.parameterizer(x)
        return (
            self.aggregator(concepts, relevances),
            (concepts, relevances),
            reconstruction,
        )

## Training Functions

### Create dataloaders
- batch size 64

In [None]:
def create_data_loader(X, y, batch_size=64, shuffle=True):
    """Wrap features and labels in a PyTorch DataLoader.

    Args:
        X: feature matrix, either a pandas DataFrame (as used in this notebook)
            or any array-like of shape (n_samples, n_features). Converted to float32.
        y: labels, either a pandas Series or any 1-D array-like of length
            n_samples. Converted to int64 (torch.long).
        batch_size: number of samples per batch.
        shuffle: reshuffle the samples every epoch when True.

    Returns:
        DataLoader yielding (features, labels) tensor pairs.
    """
    # np.asarray with an explicit dtype also handles mixed-dtype frames (e.g. bool
    # and float columns), where DataFrame.values is an object array that torch
    # cannot convert. torch.tensor copies, so the source data is never aliased.
    X_tensor = torch.tensor(np.asarray(X, dtype=np.float32))
    y_tensor = torch.tensor(np.asarray(y, dtype=np.int64))
    dataset = TensorDataset(X_tensor, y_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

### Training loop
- Runs a single training epoch and returns the mean per-batch loss and the training accuracy.

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device, recon_weight=0.01):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: module whose forward returns ``(predictions, explanations, recon_x)``,
            as SENN does.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: classification loss applied to ``(predictions, target)``.
        device: device that each batch is moved to.
        recon_weight: weight of the MSE reconstruction term in the total loss
            (default 0.01).

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch combined losses, and
        the fraction of samples whose argmax prediction equals the target.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    correct = 0
    total = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        predictions, _explanations, recon_x = model(data)

        # NOTE(review): with SumAggregator the predictions are already log-probabilities
        # (log_softmax), so criterion presumably should be nn.NLLLoss, not
        # nn.CrossEntropyLoss. Confirm at the call site.
        cls_loss = criterion(predictions, target)

        # Reconstruction term: 0 with IdentityConceptizer (recon_x == data).
        # Kept so that other conceptizers can be tried.
        recon_loss = F.mse_loss(recon_x, data)
        loss = cls_loss + recon_weight * recon_loss

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        correct += predictions.argmax(dim=1).eq(target).sum().item()
        total += target.size(0)

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return running_loss / n_batches, correct / total