# Open In Colab: https://colab.research.google.com/github/Madhumasa84/BitCamp2023_Cyber_masa/blob/main/Untitled1.ipynb

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns


class MultiModalSyntheticDataset(Dataset):
    """Random-noise stand-in for a multimodal dataset.

    Every modality is drawn from ``np.random.randn`` (standard normal) and
    stored as float32; labels are drawn uniformly from {0, 1}. All arrays are
    generated eagerly in ``__init__`` using NumPy's global RNG, so seeding
    ``np.random`` beforehand makes the data reproducible.

    Args:
        n: number of samples.
    """

    def __init__(self, n=800):
        super().__init__()
        self.n = n
        draw = np.random.randn
        # Draw order is significant: it fixes which values each modality gets
        # under a given global seed.
        self.eeg = draw(n, 1, 64, 128).astype(np.float32)        # (n, 1, 64, 128)
        self.face = draw(n, 3, 224, 224).astype(np.float32)      # (n, 3, 224, 224)
        self.phys = draw(n, 128).astype(np.float32)              # (n, 128)
        self.landmarks = draw(n, 68, 2).astype(np.float32)       # (n, 68, 2)
        # Labels (0 or 1)
        self.labels = np.random.randint(0, 2, size=n)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors; ``"label"`` is int64."""
        sample = {
            modality: torch.tensor(getattr(self, modality)[idx])
            for modality in ("eeg", "face", "phys", "landmarks")
        }
        sample["label"] = torch.tensor(self.labels[idx]).long()
        return sample


class DenseNet201(nn.Module):
    """DenseNet-201 backbone from timm (random init) with a separate linear head.

    The timm model is created with ``num_classes=0`` so it outputs pooled
    features of size ``num_features``; ``self.fc`` maps those to logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = timm.create_model(
            "densenet201", pretrained=False, num_classes=0, in_chans=3
        )
        self.model = backbone
        self.fc = nn.Linear(backbone.num_features, num_classes)

    def forward(self, x):
        # x presumably (B, 3, H, W) images (in_chans=3 above) — TODO confirm size.
        features = self.model(x)
        return self.fc(features)

# 2.2 Swin Transformer
class SwinTransformer(nn.Module):
    """Swin-Tiny (patch 4, window 7) classifier from timm, randomly initialised.

    The classification head is created by timm itself via ``num_classes``.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        arch = "swin_tiny_patch4_window7_224"
        self.model = timm.create_model(arch, pretrained=False, num_classes=num_classes)

    def forward(self, x):
        # NOTE(review): the architecture name suggests 224x224 RGB input — confirm.
        logits = self.model(x)
        return logits

# 2.3 CNN + BiLSTM for physiological signals
class CNNBiLSTM(nn.Module):
    """1-D CNN front end followed by a bidirectional LSTM classifier.

    ``forward`` expects a batch of 1-D signals shaped (batch, length). It adds
    the channel axis required by the first ``Conv1d(1, ...)``. Two stride-2
    max-pools shrink the time axis to ``length // 4``, so ``length`` must be
    at least 4.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv1d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        # 2 directions x hidden_size 128.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = x.unsqueeze(1)          # (B, L) -> (B, 1, L)
        x = self.cnn(x)             # (B, 64, L // 4)
        x = x.permute(0, 2, 1)      # (B, L // 4, 64) for batch_first LSTM
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, B, hidden). For the last layer, h_n[-2] is
        # the forward direction after the final step, and h_n[-1] is the
        # backward direction after reading all the way back to the first step.
        # Taking output[:, -1] would give a backward half that saw only one step.
        features = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(features)

# 2.4 GCN for facial landmarks
class SimpleGCN(nn.Module):
    """Two-layer graph-convolution classifier over a fixed node graph.

    ``forward`` indexes ``x`` as (batch, node, feature) via the einsum below;
    with the defaults that is (B, 68, 2) facial landmarks. Node embeddings are
    mean-pooled before the output layer.

    Args:
        num_nodes: number of graph nodes.
        in_features: per-node input feature size.
        hidden: hidden feature size of both graph layers.
        num_classes: number of output logits.
        adj: optional (num_nodes, num_nodes) adjacency matrix. Defaults to a
            fully connected graph, self-loops included.

    Raises:
        ValueError: if ``adj`` does not have shape (num_nodes, num_nodes).
    """

    def __init__(self, num_nodes=68, in_features=2, hidden=64, num_classes=2, adj=None):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, num_classes)

        if adj is None:
            # adjacency matrix (fully connected for simplicity)
            adj = torch.ones(num_nodes, num_nodes)
        else:
            adj = torch.as_tensor(adj, dtype=torch.float32)
            if tuple(adj.shape) != (num_nodes, num_nodes):
                raise ValueError(
                    f"adj must have shape ({num_nodes}, {num_nodes}), got {tuple(adj.shape)}"
                )
        # Row-normalise (D^-1 A) so each layer averages neighbours instead of
        # summing them. An unnormalised all-ones matrix scales activations by
        # num_nodes at every layer. Isolated nodes keep an all-zero row.
        degree = adj.sum(dim=1, keepdim=True).clamp_min(1.0)
        # A buffer follows module.to(device)/.half(). persistent=False keeps
        # state_dict keys unchanged, so existing checkpoints still load.
        self.register_buffer("adj", adj / degree, persistent=False)

    def forward(self, x):
        A = self.adj.to(device=x.device, dtype=x.dtype)
        x = F.relu(self.fc1(torch.einsum("ij,bjf->bif", A, x)))
        x = F.relu(self.fc2(torch.einsum("ij,bjf->bif", A, x)))
        x = x.mean(dim=1)
        return self.out(x)

# 2.5 Temporal Conformer
class ConformerBlock(nn.Module):
    """Pre-LayerNorm self-attention + feed-forward residual block.

    NOTE(review): despite the name, this block has no convolution module
    (and no macaron half-step feed-forward), so it is not a full Conformer
    block; it is a plain pre-norm Transformer encoder layer.
    Input and output are (batch, seq, dim) because attention uses batch_first.

    Args:
        dim: model width.
        heads: number of attention heads (must divide ``dim``).
    """

    def __init__(self, dim=128, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim*4),
            nn.GELU(),
            nn.Linear(dim*4, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x


class TemporalConformer(nn.Module):
    """Sequence classifier over a scalar time series.

    ``forward`` takes (batch, length). Each scalar step is projected to
    ``dim`` features, and a fixed sinusoidal position code is added. One
    ``ConformerBlock`` is applied, then the sequence is mean-pooled and
    classified.

    Without the position code, attention followed by mean pooling would give
    the same output for any reordering of the time steps.

    Args:
        num_classes: number of output logits.
        dim: model width.
        heads: attention heads in the block.
    """

    def __init__(self, num_classes=2, dim=128, heads=4):
        super().__init__()
        self.fc_in = nn.Linear(1, dim)
        self.block = ConformerBlock(dim, heads)
        self.out = nn.Linear(dim, num_classes)

    @staticmethod
    def _sinusoidal_positions(length, dim, device):
        """Return the (length, dim) sine/cosine position table from Vaswani et al. (2017)."""
        pos = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        idx = torch.arange(0, dim, 2, device=device, dtype=torch.float32)
        angles = pos * torch.pow(10000.0, -idx / dim)    # (length, ceil(dim / 2))
        pe = torch.zeros(length, dim, device=device)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])   # handles odd dim
        return pe

    def forward(self, x):
        x = self.fc_in(x.unsqueeze(-1))                  # (B, L) -> (B, L, dim)
        pe = self._sinusoidal_positions(x.shape[1], x.shape[2], x.device)
        x = x + pe.to(x.dtype)
        x = self.block(x)
        x = x.mean(dim=1)
        return self.out(x)

# ----- Training / evaluation -----
def train(model, loader, device, input_key="face", epochs=3, lr=3e-4):
    """Fit ``model`` with Adam and cross-entropy on one modality of ``loader``.

    Args:
        model: module mapping ``batch[input_key]`` to class logits.
        loader: iterable of dict batches containing ``input_key`` and ``"label"``.
        device: device the batches are moved to (the model must already be there).
        input_key: which modality in the batch dict to feed the model.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate.

    Prints the mean batch loss of every epoch.
    """
    model.train()
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for batch in loader:
            x = batch[input_key].to(device)
            y = batch["label"].to(device)

            optim.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optim.step()
            total_loss += loss.item()
            n_batches += 1

        # max(..., 1): an empty loader reports 0 instead of dividing by zero.
        print(f"Epoch {epoch+1} Loss: {total_loss / max(n_batches, 1):.4f}")


def evaluate(model, loader, device, model_name="model", input_key="face"):
    """Report accuracy and save a confusion-matrix heatmap for ``model``.

    Args:
        model: module mapping ``batch[input_key]`` to class logits.
        loader: iterable of dict batches containing ``input_key`` and ``"label"``.
        device: device the inputs are moved to.
        model_name: used in printed messages and in the ``{model_name}_cm.png``
            file written to the current working directory.
        input_key: which modality in the batch dict to feed the model.

    Returns:
        Accuracy in [0, 1], or None if ``loader`` yielded no samples.
    """
    model.eval()
    preds, trues = [], []
    num_classes = None

    with torch.no_grad():
        for batch in loader:
            x = batch[input_key].to(device)
            out = model(x)
            num_classes = out.shape[1]

            preds.extend(out.argmax(dim=1).cpu().numpy())
            trues.extend(batch["label"].cpu().numpy())

    if not trues:
        print(f"{model_name}: no samples to evaluate")
        return None

    acc = accuracy_score(trues, preds)
    print(f"{model_name} Accuracy: {acc*100:.2f}%")

    # Fix the label set so the matrix is always num_classes x num_classes,
    # even when some class never appears in trues or preds.
    cm = confusion_matrix(trues, preds, labels=list(range(num_classes)))
    fig = plt.figure(figsize=(5, 4))
    try:
        sns.heatmap(cm, annot=True, cmap="Blues", fmt="d")
        plt.title(f"Confusion Matrix – {model_name}")
        plt.savefig(f"{model_name}_cm.png", dpi=200)
    finally:
        plt.close(fig)

    print(f"Saved confusion matrix as {model_name}_cm.png")
    return acc

if __name__ == "__main__":
    from torch.utils.data import random_split

    # Seed every RNG the script touches. The dataset draws from np.random,
    # and the DataLoader shuffle and weight init draw from torch.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset: hold out 20% for validation so accuracy is not measured on
    # the same samples the models were trained on.
    ds = MultiModalSyntheticDataset()
    n_val = len(ds) // 5
    train_ds, val_ds = random_split(
        ds, [len(ds) - n_val, n_val], generator=torch.Generator().manual_seed(seed)
    )
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=16)

    # ----- RUN ALL MODELS -----
    # (header, model class, output name, modality). Each model gets the
    # modality whose shape it accepts: images -> face, (B, 128) -> phys,
    # (B, 68, 2) -> landmarks.
    experiments = [
        ("DenseNet201", DenseNet201, "DenseNet201", "face"),
        ("Swin Transformer", SwinTransformer, "SwinTransformer", "face"),
        ("CNN-BiLSTM", CNNBiLSTM, "CNN_BiLSTM", "phys"),
        ("GCN", SimpleGCN, "GCN", "landmarks"),
        ("Conformer", TemporalConformer, "Conformer", "phys"),
    ]
    for header, model_cls, model_name, input_key in experiments:
        print(f"\n=== {header} ===")
        model = model_cls().to(device)
        train(model, train_loader, device, input_key=input_key)
        evaluate(model, val_loader, device, model_name, input_key=input_key)
