In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import matplotlib.pyplot as plt

# ------------------ Load Data ------------------
# Use the MPS backend when PyTorch reports it as available; otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

features = torch.load("all_features.pt")  # [N, T, 2048]
labels = torch.load("all_labels.pt")      # [N]

print("Loaded features:", features.shape)
print("Loaded labels:", labels.shape)

BATCH_SIZE = 8
# Fixed seed for the dataset partition. Without it the split is drawn from the
# global RNG before any seeding happens, so every run would train and evaluate
# on different samples and the reported accuracies would not be comparable.
SPLIT_SEED = 42

dataset = TensorDataset(features, labels)
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
# The test split takes whatever is left so the three sizes always sum to total_size.
test_size = total_size - train_size - val_size
train_data, val_data, test_data = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)


# ------------------ Models ------------------
class AttentionPool(nn.Module):
    """Collapse a sequence of recurrent outputs into one vector via learned attention.

    A single linear layer assigns one score to every position along
    dimension 1. The scores are softmax-normalised over that dimension and
    used as weights for a sum over the same dimension.
    """

    def __init__(self, hidden_dim):
        """Create the scoring layer.

        Args:
            hidden_dim: Per-direction hidden size of the upstream recurrent
                layer; the scorer therefore takes ``hidden_dim * 2`` inputs.
        """
        super().__init__()
        self.attn = nn.Linear(in_features=2 * hidden_dim, out_features=1)

    def forward(self, lstm_out):
        """Return the attention-weighted sum of ``lstm_out`` over dimension 1."""
        attn_weights = self.attn(lstm_out).softmax(dim=1)
        return torch.sum(attn_weights * lstm_out, dim=1)

class LSTMClassifier(nn.Module):
    """Bidirectional 2-layer LSTM, attention pooling, batch norm, then an MLP with 2 outputs."""

    def __init__(self, hidden_dim=768, dropout=0.5):
        """Build the network.

        Args:
            hidden_dim: Hidden size of each LSTM direction; also sets the
                widths of the fully connected head.
            dropout: Dropout rate passed to ``nn.LSTM`` (the head uses a
                fixed rate of 0.4).
        """
        super().__init__()
        bi_dim = 2 * hidden_dim  # forward and backward states are concatenated
        half_dim = hidden_dim // 2

        self.lstm = nn.LSTM(
            input_size=2048,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.attn_pool = AttentionPool(hidden_dim)
        self.batchnorm = nn.BatchNorm1d(bi_dim)

        # Layers are created in forward order so parameter initialisation
        # draws from the RNG in the same sequence as the layer indices.
        head_layers = [
            nn.Linear(bi_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(half_dim, 2),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the 2-way logits for the batch ``x``."""
        sequence_out, _state = self.lstm(x)
        pooled = self.batchnorm(self.attn_pool(sequence_out))
        return self.fc(pooled)

class GRUClassifier(nn.Module):
    """Sequence classifier: stacked bidirectional GRU, attention pooling, batch norm, MLP head."""

    def __init__(self, hidden_dim=768, dropout=0.5):
        """Build the network.

        Args:
            hidden_dim: Hidden size of each GRU direction; the head narrows
                from ``hidden_dim * 2`` to ``hidden_dim`` to
                ``hidden_dim // 2`` to 2 outputs.
            dropout: Dropout rate passed to ``nn.GRU`` (the head uses a
                fixed rate of 0.4).
        """
        super().__init__()
        concat_dim = hidden_dim * 2  # both directions side by side
        head_dropout = 0.4

        self.gru = nn.GRU(
            2048,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.attn_pool = AttentionPool(hidden_dim)
        self.batchnorm = nn.BatchNorm1d(concat_dim)
        self.fc = nn.Sequential(
            nn.Linear(concat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(hidden_dim // 2, 2),
        )

    def forward(self, x):
        """Return the 2-way logits for the batch ``x``."""
        recurrent_out = self.gru(x)[0]  # drop the final hidden state
        summary = self.attn_pool(recurrent_out)
        normalised = self.batchnorm(summary)
        return self.fc(normalised)


# ------------------ Training Utilities ------------------
import random

# ------------------ Seed Control ------------------
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Ask CUDA directly instead of inspecting the module-level `device`: in
    # this script that device is only ever "mps" or "cpu", so a check against
    # 'cuda' could never succeed, and relying on a global made the helper
    # unusable on its own.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ------------------ Train Function ------------------
def train_model(model, model_name, seed=42):
    """Train ``model`` for 30 epochs and save the weights with the best validation accuracy.

    Batches come from the module-level ``train_loader`` and ``val_loader``.
    After every epoch the validation accuracy is printed; whenever it
    improves, a copy of the weights is kept. The best copy is written to
    ``"{model_name}.pth"`` once training finishes.

    Args:
        model: The ``nn.Module`` to train. It is moved to the module-level
            ``device``. It arrives already constructed, so ``seed`` does not
            influence its initial weights.
        model_name: Prefix for the log lines and stem of the checkpoint file.
        seed: Value handed to ``set_seed`` before training starts.
    """
    import copy  # local import: only needed for snapshotting the best weights

    set_seed(seed)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # NOTE(review): T_max=10 is shorter than the 30-epoch run, so the learning
    # rate is not monotonically decreasing over training — confirm intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    # Start below any reachable accuracy so the first epoch is always captured
    # and a real state dict is saved even if accuracy never exceeds 0.
    best_acc = -1.0
    best_state = None

    for epoch in range(30):
        model.train()
        running_loss = 0.0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss = criterion(out, y)
            loss.backward()
            # Cap the global gradient norm before the optimizer step.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item()

        # Guard against an empty training loader.
        avg_loss = running_loss / max(len(train_loader), 1)

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                predicted = model(X_val).argmax(dim=1)
                total += y_val.size(0)
                correct += (predicted == y_val).sum().item()

        # Guard against an empty validation loader.
        acc = correct / total if total else 0.0
        scheduler.step()
        print(f"{model_name} Epoch {epoch+1} | Loss: {avg_loss:.4f} | Val Acc: {acc:.4f}")

        if acc > best_acc:
            best_acc = acc
            # state_dict() hands back references to the live tensors, which
            # later optimizer steps keep modifying. Deep-copy so the snapshot
            # really holds this epoch's weights rather than the final ones.
            best_state = copy.deepcopy(model.state_dict())

    torch.save(best_state, f"{model_name}.pth")
    print(f"✅ Saved best {model_name} model with Val Acc: {best_acc:.4f}\n")


# ------------------ Train All Models ------------------
# Each entry: (model class, constructor kwargs, checkpoint name, seed).
MODEL_CONFIGS = [
    (LSTMClassifier, {"hidden_dim": 768, "dropout": 0.5}, "model1_lstm", 42),
    (LSTMClassifier, {"hidden_dim": 512, "dropout": 0.3}, "model2_lstm", 17),
    (GRUClassifier, {"hidden_dim": 768, "dropout": 0.4}, "model3_gru", 77),
]

for model_cls, model_kwargs, checkpoint_name, model_seed in MODEL_CONFIGS:
    # Seed before constructing the network: train_model only seeds after it
    # has received an already-initialised model, which would otherwise leave
    # the initial weights dependent on whatever RNG state preceded this call.
    set_seed(model_seed)
    train_model(model_cls(**model_kwargs), checkpoint_name, seed=model_seed)


# ------------------ Ensemble Prediction ------------------
def ensemble_predict(models, x):
    """Predict class indices by averaging the softmax outputs of several models.

    Every model is switched to eval mode and run on ``x`` with gradient
    tracking disabled. The per-model probabilities are averaged and the
    arg-max over dimension 1 is returned.

    Args:
        models: Iterable of modules that accept ``x`` and return logits.
        x: Input batch passed unchanged to each model.

    Returns:
        Tensor with the index of the highest averaged probability per row.
    """
    with torch.no_grad():
        # .eval() returns the module itself, so each model is put in eval
        # mode immediately before its own forward pass.
        member_probs = [F.softmax(member.eval()(x), dim=1) for member in models]
    mean_probs = torch.stack(member_probs).mean(dim=0)
    return mean_probs.argmax(dim=1)

# ------------------ Load Models for Inference ------------------
# Rebuild each architecture with the same hyper-parameters used for training
# so the saved state dicts match the module layouts.
model1 = LSTMClassifier(hidden_dim=768, dropout=0.5)
model2 = LSTMClassifier(hidden_dim=512, dropout=0.3)
model3 = GRUClassifier(hidden_dim=768, dropout=0.4)

# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; the models are moved to the target device afterwards.
for net, checkpoint_path in (
    (model1, "model1_lstm.pth"),
    (model2, "model2_lstm.pth"),
    (model3, "model3_gru.pth"),
):
    net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

models = [model1.to(device), model2.to(device), model3.to(device)]

# ------------------ Final Evaluation ------------------
# Count ensemble hits over the whole test split, batch by batch.
correct, total = 0, 0
for x_batch, y_batch in test_loader:
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)
    preds = ensemble_predict(models, x_batch)
    total += y_batch.size(0)
    correct += preds.eq(y_batch).sum().item()

print(f"\n🎯 Ensemble Test Accuracy: {correct / total:.4f}")


Using device: mps
Loaded features: torch.Size([2000, 32, 2048])
Loaded labels: torch.Size([2000])
model1_lstm Epoch 1 | Loss: 0.5515 | Val Acc: 0.8300
model1_lstm Epoch 2 | Loss: 0.5349 | Val Acc: 0.4833
model1_lstm Epoch 3 | Loss: 0.5384 | Val Acc: 0.8333
model1_lstm Epoch 4 | Loss: 0.5063 | Val Acc: 0.8433
model1_lstm Epoch 5 | Loss: 0.5094 | Val Acc: 0.8400
model1_lstm Epoch 6 | Loss: 0.4878 | Val Acc: 0.8300
model1_lstm Epoch 7 | Loss: 0.4948 | Val Acc: 0.8433
model1_lstm Epoch 8 | Loss: 0.4663 | Val Acc: 0.8233
model1_lstm Epoch 9 | Loss: 0.4759 | Val Acc: 0.8433
model1_lstm Epoch 10 | Loss: 0.4618 | Val Acc: 0.8500
model1_lstm Epoch 11 | Loss: 0.4599 | Val Acc: 0.8467
model1_lstm Epoch 12 | Loss: 0.4562 | Val Acc: 0.8433
model1_lstm Epoch 13 | Loss: 0.4754 | Val Acc: 0.8433
model1_lstm Epoch 14 | Loss: 0.4568 | Val Acc: 0.8300
model1_lstm Epoch 15 | Loss: 0.4657 | Val Acc: 0.8367
model1_lstm Epoch 16 | Loss: 0.4796 | Val Acc: 0.6367
model1_lstm Epoch 17 | Loss: 0.4865 | Val Acc: 