# Import Libraries

In [120]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import seaborn as sns
import pandas as pd
import numpy as np
from torchinfo import summary  

# Load dataset

In [121]:
# Load the pre-split training / validation arrays from .npy files in the
# current working directory. The X_* arrays are later used as model inputs and
# the y_* arrays as labels (see the Dataset + DataLoader cell).
X_train = np.load('X_train.npy')
X_val = np.load('X_val.npy')
y_train = np.load('y_train.npy')
y_val = np.load('y_val.npy')

In [122]:
# Sanity check: print the shape of every loaded array.
print(f"X_train.shape: {X_train.shape}")
print(f"X_val.shape: {X_val.shape}")
print(f"y_train.shape: {y_train.shape}")
print(f"y_val.shape: {y_val.shape}")
# Print how many samples fall into each class. The labels are cast to int
# first because np.bincount only accepts non-negative integers.
print(f"y_train class distribution: {np.bincount(y_train.astype(int))}")
print(f"y_val class distribution: {np.bincount(y_val.astype(int))}")

X_train.shape: (1587, 100, 9)
X_val.shape: (397, 100, 9)
y_train.shape: (1587,)
y_val.shape: (397,)
y_train class distribution: [1505   82]
y_val class distribution: [376  21]


# Apply Class Weight

In [123]:
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weighting gives each class a weight inversely proportional to its
# frequency in y_train, so the rarer class receives the larger weight.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train,
)
# Key every weight by its position in the sorted unique-label array.
class_weight_dict = {position: weight for position, weight in enumerate(class_weights)}

# Soften the positive-class weight to 80% of its balanced value.
class_weight_dict[1] *= 0.8
print("Adjusted class weights:", class_weight_dict)

# Pack both weights into a float32 tensor ordered [class 0, class 1].
class_weights_tensor = torch.tensor(
    [class_weight_dict[0], class_weight_dict[1]],
    dtype=torch.float32,
)

Adjusted class weights: {0: 0.5272425249169436, 1: 7.741463414634147}


# Dataset + DataLoader

In [124]:
from torch.utils.data import Dataset, DataLoader

class TabularDataset(Dataset):
    """Map-style dataset that holds features and labels as in-memory tensors.

    Args:
        X: array-like of features; converted to a float32 tensor.
        y: array-like of labels; if it exposes a ``.values`` attribute (as a
            pandas Series does) that attribute is used. Converted to int64.
    """

    def __init__(self, X, y):
        # Unwrap pandas-style containers; anything else is passed through as is.
        raw_labels = getattr(y, 'values', y)
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(raw_labels, dtype=torch.long)

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        features = self.X[idx]
        label = self.y[idx]
        return features, label
    
# Wrap the numpy arrays loaded earlier in tensor-backed datasets.
train_ds = TabularDataset(X_train, y_train)
val_ds = TabularDataset(X_val, y_val)

# Mini-batches of 64 samples. Only the training loader reshuffles every epoch;
# the validation loader keeps a fixed order.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)

# Device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the class-weight tensor to the same device the model is placed on later.
class_weights_tensor = class_weights_tensor.to(device)



# Model

In [125]:
class SAINTTransformer(nn.Module):
    """Transformer-encoder sequence classifier that emits one logit per sample.

    Pipeline: per-step linear embedding -> learnable positional embedding ->
    stacked ``nn.TransformerEncoderLayer`` blocks -> attention pooling over
    the sequence axis -> small MLP head.

    Args:
        input_dim: number of features per sequence step.
        embedding_dim: width of the token embedding / transformer model.
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout rate inside each encoder layer.
        max_seq_len: number of positions held by the positional embedding.
    """

    def __init__(self, input_dim, embedding_dim=64, num_heads=4, num_layers=3, dropout=0.1, max_seq_len=100):
        super().__init__()
        # Project every step's raw features into the model dimension.
        self.embedding = nn.Linear(input_dim, embedding_dim)

        # One learnable vector per position, shared across the batch.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, embedding_dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Produces one unnormalised attention score per sequence step.
        self.attention_pool = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

        # Classification head: embedding_dim -> 32 -> 1 logit (binary task).
        self.hidden = nn.Linear(embedding_dim, 32)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(32, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to logits of shape (batch, 1)."""
        _, seq_len, _ = x.size()

        # Embed, then add the first `seq_len` positional vectors.
        tokens = self.embedding(x) + self.pos_embedding[:, :seq_len, :]  # (B, T, E)
        encoded = self.transformer(tokens)                                # (B, T, E)

        # Attention pooling: softmax over the sequence axis, then a weighted sum.
        step_weights = self.attention_pool(encoded).softmax(dim=1)        # (B, T, 1)
        pooled = (encoded * step_weights).sum(dim=1)                      # (B, E)

        features = self.relu(self.hidden(pooled))                         # (B, 32)
        return self.classifier(features)                                  # (B, 1)
    
# Derive every dimension from the data and the loader instead of hard-coding
# them, so the summary always describes the batches the model will really see.
n_features = X_train.shape[2]  # size of the last axis: features per sequence step
seq_len = X_train.shape[1]     # size of the middle axis: steps per sequence
model = SAINTTransformer(input_dim=n_features).to(device)
# Kept as the last expression so the notebook displays the layer table.
summary(model, input_size=(train_loader.batch_size, seq_len, n_features))

Layer (type:depth-idx)                        Output Shape              Param #
SAINTTransformer                              [64, 1]                   6,400
├─Linear: 1-1                                 [64, 100, 64]             640
├─TransformerEncoder: 1-2                     [64, 100, 64]             --
│    └─ModuleList: 2-1                        --                        --
│    │    └─TransformerEncoderLayer: 3-1      [64, 100, 64]             281,152
│    │    └─TransformerEncoderLayer: 3-2      [64, 100, 64]             281,152
│    │    └─TransformerEncoderLayer: 3-3      [64, 100, 64]             281,152
├─Sequential: 1-3                             [64, 100, 1]              --
│    └─Linear: 2-2                            [64, 100, 128]            8,320
│    └─Tanh: 2-3                              [64, 100, 128]            --
│    └─Linear: 2-4                            [64, 100, 1]              129
├─Linear: 1-4                                 [64, 32]                  

# Criterion (Loss Function) + Optimizer

In [126]:
# Binary cross-entropy computed on raw logits. `pos_weight` scales the loss of
# positive samples; here it is the adjusted class-1 weight from the class-weight cell.
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights_tensor[1])  # Use pos_weight for binary classification
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# The training loop calls `scheduler.step(val_loss)` after every epoch, so the
# scheduler has to exist before that cell runs (otherwise a fresh
# "Restart & Run All" fails with NameError). ReduceLROnPlateau is the
# scheduler whose step() takes the monitored metric.
# NOTE(review): factor/patience are chosen here; confirm against the intended schedule.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Epochs & Record Best Train & Val Metrics

In [127]:
# Number of passes over the training set.
epochs = 20

# Best-epoch trackers. Every value the training loop may overwrite and the
# final report cell reads is initialised here, so re-running this cell gives a
# clean slate and no later cell depends on names that only exist after an
# improvement has been recorded.
best_val_auc = 0
best_val_cm = None
best_val_prec = 0
best_val_rec = 0
best_val_f1 = 0
best_val_acc = 0

best_train_auc = 0
best_train_cm = None
best_train_prec = 0
best_train_rec = 0
best_train_f1 = 0
best_train_acc = 0

# Per-epoch metric curves (one appended value per epoch).
history = {
    'train_prec': [], 'val_prec': [],
    'train_rec': [], 'val_rec': [],
    'train_f1': [], 'val_f1': [],
    'train_auc': [], 'val_auc': [],
}

# Per-epoch loss curves.
loss = {
    'train_loss': [], 'val_loss': [],
}

# Per-epoch accuracy curves.
acc = {
    'train_acc': [], 'val_acc':[],
}

# Define Training Process & Evaluation Metrics

In [128]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, auc, roc_curve, precision_recall_curve, average_precision_score

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and score the predictions made on the way.

    The model is put in training mode, so the collected probabilities come
    from the same forward passes that are used for the gradient updates.

    Args:
        model: network returning logits of shape (batch, 1).
        loader: iterable of ``(features, labels)`` batches exposing ``.dataset``.
        criterion: loss applied to ``(logits, float labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        Tuple ``(loss, auc, accuracy, precision, recall, f1, confusion_matrix)``.
    """
    model.train()
    loss_sum = 0
    scores, decisions, truths = [], [], []

    for features, labels in loader:
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(features)
        batch_loss = criterion(logits.squeeze(1), labels.float())
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch figure is a
        # per-sample average (assumes the criterion returns a batch mean).
        loss_sum += batch_loss.item() * features.size(0)

        batch_scores = torch.sigmoid(logits).squeeze(dim=1).detach().cpu().numpy()
        scores.extend(batch_scores)
        # Hard decisions use a fixed 0.5 probability threshold.
        decisions.extend((batch_scores > 0.5).astype(int))
        truths.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    auc_value = roc_auc_score(truths, scores)
    precision = precision_score(truths, decisions)
    recall = recall_score(truths, decisions)
    f1 = f1_score(truths, decisions)
    accuracy = accuracy_score(truths, decisions)
    cm = confusion_matrix(truths, decisions)
    return mean_loss, auc_value, accuracy, precision, recall, f1, cm

def eval_one_epoch(model, loader, criterion, device):
    """Score ``model`` on ``loader`` in eval mode without tracking gradients.

    Args:
        model: network returning logits of shape (batch, 1).
        loader: iterable of ``(features, labels)`` batches exposing ``.dataset``.
        criterion: loss applied to ``(logits, float labels)``.
        device: device the batches are moved to.

    Returns:
        Tuple ``(loss, auc, accuracy, precision, recall, f1, confusion_matrix,
        targets, probabilities)``; the last two are lists holding one entry
        per evaluated sample, in loader order.
    """
    model.eval()
    loss_sum = 0
    scores, decisions, truths = [], [], []

    with torch.no_grad():
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)

            logits = model(features)
            batch_loss = criterion(logits.squeeze(1), labels.float())
            # Weight by batch size so the epoch figure is a per-sample average
            # (assumes the criterion returns a batch mean).
            loss_sum += batch_loss.item() * features.size(0)

            batch_scores = torch.sigmoid(logits).squeeze(dim=1).detach().cpu().numpy()
            scores.extend(batch_scores)
            # Hard decisions use a fixed 0.5 probability threshold.
            decisions.extend((batch_scores > 0.5).astype(int))
            truths.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    auc_value = roc_auc_score(truths, scores)
    precision = precision_score(truths, decisions)
    recall = recall_score(truths, decisions)
    f1 = f1_score(truths, decisions)
    accuracy = accuracy_score(truths, decisions)
    cm = confusion_matrix(truths, decisions)
    return mean_loss, auc_value, accuracy, precision, recall, f1, cm, truths, scores


# Start Training

In [129]:
for epoch in range(epochs):
    # One optimisation pass over the training set, then a gradient-free pass
    # over the validation set.
    (train_loss, train_auc, train_acc,
     train_prec, train_rec, train_f1, train_cm) = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    (val_loss, val_auc, val_acc, val_prec, val_rec, val_f1, val_cm,
     all_targets, all_probs) = eval_one_epoch(model, val_loader, criterion, device)

    # `scheduler` must be defined before this cell runs; it is stepped on the
    # validation loss of the epoch that just finished.
    scheduler.step(val_loss)
    print(f"Current LR: {optimizer.param_groups[0]['lr']}")

    # Record this epoch's training curves ...
    history['train_prec'].append(train_prec)
    history['train_rec'].append(train_rec)
    history['train_f1'].append(train_f1)
    history['train_auc'].append(train_auc)
    loss['train_loss'].append(train_loss)
    acc['train_acc'].append(train_acc)

    # ... and its validation curves.
    history['val_prec'].append(val_prec)
    history['val_rec'].append(val_rec)
    history['val_f1'].append(val_f1)
    history['val_auc'].append(val_auc)
    loss['val_loss'].append(val_loss)
    acc['val_acc'].append(val_acc)

    # New best validation AUC: snapshot its metrics and checkpoint the weights.
    if val_auc > best_val_auc:
        best_val_auc, best_val_cm = val_auc, val_cm
        best_val_prec, best_val_rec = val_prec, val_rec
        best_val_f1, best_val_acc = val_f1, val_acc
        torch.save(model.state_dict(), 'best_saint_model.pth')

    # New best training AUC: snapshot its metrics (no checkpoint is written).
    if train_auc > best_train_auc:
        best_train_auc, best_train_cm = train_auc, train_cm
        best_train_prec, best_train_rec = train_prec, train_rec
        best_train_f1, best_train_acc = train_f1, train_acc

    print(f"Epoch {epoch+1}/{epochs} - "
          f"Train Loss: {train_loss:.4f}, AUC: {train_auc:.4f}, Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Rec: {train_rec:.4f}, F1: {train_f1:.4f}")
    print(f"Val Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Rec: {val_rec:.4f}, F1: {val_f1:.4f}")

Current LR: 0.001
Epoch 1/20 - Train Loss: 0.6893, AUC: 0.7875, Acc: 0.8198, Prec: 0.1304, Rec: 0.4390, F1: 0.2011
Val Loss: 0.5893, AUC: 0.8499, Acc: 0.7355, Prec: 0.1557, Rec: 0.9048, F1: 0.2657
Current LR: 0.001
Epoch 2/20 - Train Loss: 0.6021, AUC: 0.8196, Acc: 0.7839, Prec: 0.1645, Rec: 0.7805, F1: 0.2718
Val Loss: 0.6280, AUC: 0.7898, Acc: 0.7531, Prec: 0.1532, Rec: 0.8095, F1: 0.2576
Current LR: 0.001
Epoch 3/20 - Train Loss: 0.5977, AUC: 0.8196, Acc: 0.7687, Prec: 0.1550, Rec: 0.7805, F1: 0.2586
Val Loss: 0.6353, AUC: 0.7950, Acc: 0.8060, Prec: 0.1585, Rec: 0.6190, F1: 0.2524
Current LR: 0.001
Epoch 4/20 - Train Loss: 0.5500, AUC: 0.8525, Acc: 0.8173, Prec: 0.1941, Rec: 0.8049, F1: 0.3128
Val Loss: 0.5642, AUC: 0.8345, Acc: 0.7708, Prec: 0.1759, Rec: 0.9048, F1: 0.2946
Current LR: 0.001
Epoch 5/20 - Train Loss: 0.5541, AUC: 0.8278, Acc: 0.7845, Prec: 0.1734, Rec: 0.8415, F1: 0.2875
Val Loss: 0.5739, AUC: 0.8267, Acc: 0.7607, Prec: 0.1696, Rec: 0.9048, F1: 0.2857
Current LR: 0.0

# Plot Metrics (AUC-ROC, Precision, Recall, F1)

In [130]:
import matplotlib.pyplot as plt

def plot_metrics(history, filename):
    """Save a grouped bar chart of the last-epoch train vs validation metrics.

    Args:
        history: dict of per-epoch lists keyed ``train_*`` / ``val_*`` for
            ``prec``, ``rec``, ``f1`` and ``auc``; only the final entry of
            each list is plotted.
        filename: path the figure is written to.
    """
    metric_labels = ['Precision', 'Recall', 'F1-score', 'AUC']
    train_keys = ('train_prec', 'train_rec', 'train_f1', 'train_auc')
    val_keys = ('val_prec', 'val_rec', 'val_f1', 'val_auc')

    # Final-epoch value of every metric, in the same order as the labels.
    train_scores = [history[key][-1] for key in train_keys]
    val_scores = [history[key][-1] for key in val_keys]

    positions = np.arange(len(metric_labels))  # one slot per metric
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(8, 6))
    # Two bars per slot, centred on the tick.
    ax.bar(positions - bar_width/2, train_scores, bar_width, label='Train', color='skyblue')
    ax.bar(positions + bar_width/2, val_scores, bar_width, label='Validation', color='lightcoral')

    ax.set_ylabel('Score')
    ax.set_ylim(0, 1.05)
    ax.set_title('Final Training vs Validation Metrics')
    ax.set_xticks(positions)
    ax.set_xticklabels(metric_labels)
    ax.legend()
    ax.grid(True, axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)

# Only the last recorded epoch of each metric is drawn (plot_metrics reads index -1).
plot_metrics(history, 'training_validation_metrics.png')

# Plot Loss & Accuracy

In [131]:
def plot_loss_and_accuracy(loss_history, acc_history, loss_filename, acc_filename):
    """Save two line charts: loss per epoch and accuracy per epoch.

    Args:
        loss_history: dict with ``'train_loss'`` and ``'val_loss'`` lists.
        acc_history: dict with ``'train_acc'`` and ``'val_acc'`` lists.
        loss_filename: output path of the loss figure.
        acc_filename: output path of the accuracy figure.
    """
    # Epochs are numbered from 1 on the x axis.
    epoch_axis = range(1, len(loss_history['train_loss']) + 1)

    def _save_line_chart(curves, y_label, title, out_path):
        """Draw every ``(values, fmt, label)`` curve on one figure and save it."""
        fig, ax = plt.subplots(figsize=(8,6))
        for values, fmt, label in curves:
            ax.plot(epoch_axis, values, fmt, label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(out_path)
        plt.close(fig)

    # Loss figure
    _save_line_chart(
        [
            (loss_history['train_loss'], 'b-', 'Train Loss'),
            (loss_history['val_loss'], 'r--', 'Validation Loss'),
        ],
        'Loss',
        'Training and Validation Loss over Epochs',
        loss_filename,
    )

    # Accuracy figure
    _save_line_chart(
        [
            (acc_history['train_acc'], 'g-', 'Train Accuracy'),
            (acc_history['val_acc'], 'm--', 'Validation Accuracy'),
        ],
        'Accuracy',
        'Training and Validation Accuracy over Epochs',
        acc_filename,
    )
    
# `loss` and `acc` are the per-epoch dicts filled by the training loop.
plot_loss_and_accuracy(loss, acc, 'training_validation_loss.png', 'training_validation_acc.png')


# Plot ROC Curve

In [132]:
def plot_roc_curve(y_true, y_scores, filename, title='AUC-ROC Curve (Validation)'):
    """Save a ROC curve with its AUC shown in the legend.

    The default title names the validation split, which is the data this
    notebook passes in; callers plotting another split can override it.

    Args:
        y_true: ground-truth binary labels.
        y_scores: predicted scores / probabilities for the positive class.
        filename: path the figure is written to.
        title: figure title.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8,6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    # Diagonal reference line: the expected curve of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate (Recall)')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

# all_targets / all_probs are reassigned by every eval_one_epoch call in the
# training loop, so they hold the validation outputs of the last epoch run
# (not necessarily the epoch with the best validation AUC).
plot_roc_curve(all_targets, all_probs, 'val_roc_curve.png')

# Plot Precision-Recall Curve

In [133]:
def plot_precision_recall_curve(y_true, y_scores, filename):
    """Save a precision-recall curve with its average precision in the legend.

    Args:
        y_true: ground-truth binary labels.
        y_scores: predicted scores / probabilities for the positive class.
        filename: path the figure is written to.
    """
    precision, recall, _thresholds = precision_recall_curve(y_true, y_scores)
    avg_precision = average_precision_score(y_true, y_scores)
    curve_label = f'Precision-Recall curve (AP = {avg_precision:.4f})'

    fig, ax = plt.subplots(figsize=(8,6))
    # Recall on the x axis, precision on the y axis.
    ax.plot(recall, precision, color='blue', lw=2, label=curve_label)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve (Validation)')
    ax.legend(loc="lower left")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)
    
# Uses the validation targets / probabilities returned by the most recent
# eval_one_epoch call, i.e. those of the last epoch run.
plot_precision_recall_curve(all_targets, all_probs, 'val_precision_recall_curve.png')


# Print Best Evaluation Metrics (Training & Validation)

In [134]:
def plot_and_save_cm(cm, title, filename):
    """Render a confusion matrix as an annotated heatmap and save it.

    Args:
        cm: 2-D array of integer counts (rows drawn as "Actual", columns as "Predicted").
        title: figure title.
        filename: path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(6,5))
    # Integer annotations in every cell, no colour bar.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)
# Final metric printouts
# The best_* values are snapshots taken in the training loop whenever the
# corresponding AUC improved, so each group below describes a single epoch
# (the one with the highest AUC), not per-metric maxima.
# print(f"\nBest Training Accuracy: {best_train_acc:.4f}")
print(f"\nBest Training AUC: {best_train_auc:.4f}")
print(f"Best Training Precision: {best_train_prec:.4f}")
print(f"Best Training Recall: {best_train_rec:.4f}")
print(f"Best Training F1-score: {best_train_f1:.4f}")
print("Best Training Confusion Matrix:")
print(best_train_cm)
# Also save the matrix as a heatmap image.
plot_and_save_cm(best_train_cm, 'Best Training Confusion Matrix', 'best_train_cm.png')

# Same report for the epoch with the highest validation AUC.
# print(f"\nBest Validation Accuracy: {best_val_acc:.4f}")
print(f"\nBest Validation AUC: {best_val_auc:.4f}")
print(f"Best Validation Precision: {best_val_prec:.4f}")
print(f"Best Validation Recall: {best_val_rec:.4f}")
print(f"Best Validation F1-score: {best_val_f1:.4f}")
print("Best Validation Confusion Matrix:")
print(best_val_cm)
# Also save the matrix as a heatmap image.
plot_and_save_cm(best_val_cm, 'Best Validation Confusion Matrix', 'best_val_cm.png')


Best Training AUC: 0.8525
Best Training Precision: 0.1941
Best Training Recall: 0.8049
Best Training F1-score: 0.3128
Best Training Confusion Matrix:
[[1231  274]
 [  16   66]]

Best Validation AUC: 0.8601
Best Validation Precision: 0.1759
Best Validation Recall: 0.9048
Best Validation F1-score: 0.2946
Best Validation Confusion Matrix:
[[287  89]
 [  2  19]]
