# In [None]:
import os, random, torch, glob
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import timm
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sklearn.model_selection import train_test_split
import seaborn as sns

# ====== Temporal Drowsiness Dataset ======
class TemporalDrowsinessDataset(Dataset):
    """Sliding-window dataset of consecutive frames for blinking detection.

    Each subject directory is expected to contain ``frames/labels.csv`` with a
    ``filename`` and a ``label`` column; subjects without that file are
    skipped. Rows are sorted by ``filename`` and cut into windows of
    ``sequence_length`` rows, one window every ``stride`` rows. A window is
    kept only when every image in it exists and passes ``Image.verify()``.

    Args:
        subject_dirs: Iterable of subject directory paths.
        transform: Optional callable applied to every RGB PIL frame. When it
            is ``None`` frames are converted to float tensors in [0, 1] with
            layout [C, H, W] so that they can still be stacked.
        sequence_length: Number of frames per sample (must be >= 1).
        stride: Row step between the starts of two windows (must be >= 1).

    Raises:
        ValueError: If ``sequence_length`` or ``stride`` is smaller than 1.

    Attributes:
        samples: List of ``(frame_paths, target_label, frame_labels)`` tuples.
    """

    def __init__(self, subject_dirs, transform=None, sequence_length=5, stride=2):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.transform = transform
        self.sequence_length = sequence_length
        self.stride = stride
        self.samples = []

        print(f"Building temporal dataset with window={sequence_length}, stride={stride}")
        valid_sequences = 0
        skipped_sequences = 0

        for subject_dir in subject_dirs:
            frames_dir = os.path.join(subject_dir, "frames")
            csv_path = os.path.join(frames_dir, "labels.csv")

            if not os.path.exists(csv_path):
                continue

            df = pd.read_csv(csv_path)
            # NOTE(review): this is a lexicographic sort; it only matches the
            # temporal order if frame numbers in the filenames are zero-padded
            # -- confirm against the data.
            df = df.sort_values('filename')

            # Pull the columns out once instead of doing a df.iloc lookup per
            # frame per window.
            filenames = df['filename'].tolist()
            frame_labels = df['label'].tolist()

            # Row index -> verified path, or None for an unusable frame.
            # Windows overlap, so without this cache the same file would be
            # opened and verified once per window that contains it.
            checked = {}

            for start in range(0, len(filenames) - sequence_length + 1, stride):
                seq_files = []
                for j in range(start, start + sequence_length):
                    if j not in checked:
                        img_path = os.path.join(frames_dir, filenames[j])
                        checked[j] = img_path if self._frame_is_usable(img_path) else None
                    if checked[j] is None:
                        break  # one bad frame invalidates the whole window
                    seq_files.append(checked[j])

                if len(seq_files) != sequence_length:
                    skipped_sequences += 1
                    continue

                seq_labels = frame_labels[start:start + sequence_length]
                # Majority vote for label (assumes 0/1 frame labels)
                target_label = 1 if sum(seq_labels) > len(seq_labels) // 2 else 0
                self.samples.append((seq_files, target_label, seq_labels))
                valid_sequences += 1

        print(f"Created {valid_sequences} valid sequences, skipped {skipped_sequences}")

    @staticmethod
    def _frame_is_usable(img_path):
        """Return True if ``img_path`` exists and passes ``Image.verify()``."""
        if not os.path.exists(img_path):
            return False
        try:
            with Image.open(img_path) as img:
                img.verify()
        except Exception:
            # verify() can raise several unrelated exception types depending
            # on the image format, so stay broad -- but unlike a bare
            # ``except`` this lets KeyboardInterrupt/SystemExit through.
            return False
        return True

    def __len__(self):
        """Return the number of valid frame windows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one window.

        Returns:
            ``(frames, label)`` where ``frames`` is the stacked per-frame
            tensors with the time axis first and ``label`` is the
            majority-vote target of the window.
        """
        seq_files, label, _ = self.samples[idx]
        frames = []

        for img_path in seq_files:
            try:
                # The context manager closes the file handle right after the
                # pixel data has been copied by convert().
                with Image.open(img_path) as img:
                    image = img.convert("RGB")
                if self.transform:
                    # NOTE(review): a random transform is re-sampled for every
                    # frame, so frames of one window may be cropped/flipped
                    # differently -- confirm this is intended.
                    image = self.transform(image)
                else:
                    # Without a transform the PIL image must still become a
                    # tensor, otherwise torch.stack below cannot work.
                    pixels = np.array(image, dtype=np.uint8)
                    image = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
            except Exception:
                # Best-effort fallback to a black frame so that one unreadable
                # file does not abort the whole epoch.
                if self.transform:
                    black_image = Image.new('RGB', (256, 256), color='black')
                    image = self.transform(black_image)
                else:
                    image = torch.zeros(3, 256, 256)
            frames.append(image)

        frames = torch.stack(frames, dim=0)  # [T, C, H, W]
        return frames, label

# ====== Temporal Blinking Model ======
class TemporalBlinkingModel(nn.Module):
    """Per-frame CNN/ViT backbone followed by a BiLSTM, self-attention and an MLP head.

    Args:
        backbone_name: Name passed to ``timm.create_model``.
        pretrained: Whether to load pretrained backbone weights.
        num_classes: Number of output logits.
        sequence_length: Stored on the instance; the forward pass itself
            accepts any sequence length.
        hidden_dim: Hidden size of each LSTM direction. ``hidden_dim * 2``
            must be divisible by the 4 attention heads.
    """

    def __init__(self, backbone_name='mobilevit_s', pretrained=True,
                 num_classes=2, sequence_length=5, hidden_dim=128):
        super().__init__()

        # Spatial feature extractor. num_classes=0 removes the classifier so
        # that calling the backbone returns its pooled feature vector.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        feat_dim = self.backbone.num_features

        # Temporal modeling
        self.lstm = nn.LSTM(
            feat_dim, hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )

        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            hidden_dim * 2,  # Bidirectional
            num_heads=4,
            dropout=0.2
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

        self.sequence_length = sequence_length

    def extract_features(self, x):
        """Extract one pooled feature vector per frame.

        Args:
            x: Tensor of shape [batch, seq_len, channels, height, width].

        Returns:
            Tensor of shape [batch, seq_len, feat_dim].
        """
        batch_size, seq_len = x.shape[:2]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(batch_size * seq_len, *x.shape[2:])

        # Let the backbone apply its own global pooling instead of averaging
        # dims [2, 3] of forward_features(): that manual pooling is only
        # correct for NCHW feature maps and breaks (or silently pools the
        # wrong axes) for token-based or NHWC backbones.
        features = self.backbone(flat)  # [batch * seq_len, feat_dim]

        return features.reshape(batch_size, seq_len, -1)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape [batch_size, sequence_length, channels, height, width].

        Returns:
            Logits of shape [batch_size, num_classes].
        """
        features = self.extract_features(x)

        # LSTM for temporal modeling; only the per-step outputs are used.
        lstm_out, _ = self.lstm(features)

        # The attention module was built without batch_first, so it expects
        # [seq_len, batch, features].
        lstm_out = lstm_out.transpose(0, 1)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        attn_out = attn_out.transpose(0, 1)  # [batch, seq_len, features]

        # Temporal aggregation
        temporal_features = attn_out.mean(dim=1)  # Mean pooling

        # Classification
        return self.classifier(temporal_features)

# ====== Training Functions ======
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(frames)`` to produce class logits.
        loader: Iterable of ``(frames, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    # Counted in the loop so loaders without __len__ work and an empty loader
    # can be reported clearly instead of as a ZeroDivisionError.
    num_batches = 0

    for frames, labels in tqdm(loader, desc="Training"):
        frames, labels = frames.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(frames)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        num_batches += 1

    if num_batches == 0 or total == 0:
        raise ValueError("train_epoch received a loader that yielded no samples")

    return total_loss / num_batches, 100. * correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating weights.

    Args:
        model: Module called as ``model(frames)`` to produce class logits.
        loader: Iterable of ``(frames, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent, all_preds, all_labels)`` where
        the last two are flat lists of predicted and true class indices in
        loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    # Counted in the loop so loaders without __len__ work and an empty loader
    # can be reported clearly instead of as a ZeroDivisionError.
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for frames, labels in tqdm(loader, desc="Validating"):
            frames, labels = frames.to(device), labels.to(device)

            outputs = model(frames)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            num_batches += 1

            # Plain Python ints keep the returned lists device independent.
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if num_batches == 0 or total == 0:
        raise ValueError("validate received a loader that yielded no samples")

    return total_loss / num_batches, 100. * correct / total, all_preds, all_labels

# ====== Visualization ======
def plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss and accuracy curves side by side.

    The figure is written to ``blinking_training_history.png`` and then shown.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        train_accs: Training accuracy (percent) per epoch.
        val_accs: Validation accuracy (percent) per epoch.
    """
    # One entry per panel: (train curve, train legend, val curve, val legend,
    # y-axis label, panel title).
    panels = (
        (train_losses, 'Train Loss', val_losses, 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (train_accs, 'Train Acc', val_accs, 'Val Acc',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    for axis, panel in zip(axes, panels):
        train_curve, train_name, val_curve, val_name, y_label, title = panel
        axis.plot(train_curve, label=train_name)
        axis.plot(val_curve, label=val_name)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig('blinking_training_history.png', dpi=150)
    plt.show()

def plot_confusion_matrix(y_true, y_pred, class_names=None):
    """Draw the confusion matrix as a heatmap, save it and show it.

    The figure is written to ``blinking_confusion_matrix.png``.

    Args:
        y_true: True class labels.
        y_pred: Predicted class labels.
        class_names: Tick labels, one per class. Defaults to
            ``['Not Blinking', 'Blinking']``.
    """
    # None sentinel instead of a mutable list as the default argument.
    if class_names is None:
        class_names = ['Not Blinking', 'Blinking']
    else:
        class_names = list(class_names)

    # When the labels are the integer ids 0..K-1, pin the matrix to all K
    # classes. Otherwise a class missing from both y_true and y_pred would
    # shrink the matrix and it would no longer line up with the tick labels.
    # For any other label scheme keep sklearn's own label inference.
    expected_ids = list(range(len(class_names)))
    observed = set(y_true) | set(y_pred)
    labels = expected_ids if observed <= set(expected_ids) else None

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix - Blinking Detection')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig('blinking_confusion_matrix.png', dpi=150)
    plt.show()

# ====== Main Training Script ======
def main(drowsy_base="/Users/shikharsrivastava/Desktop/Thesis/Thesis/MultiTask/Drowsiness/S5"):
    """Train the temporal blinking model and evaluate the best checkpoint.

    Subjects (numerically named sub-directories of ``drowsy_base``) are split
    subject-wise into train/val/test, the model is trained for a fixed number
    of epochs, the checkpoint with the highest validation accuracy is saved to
    ``best_blinking_model.pth`` and finally evaluated on the test subjects.

    Args:
        drowsy_base: Directory containing one numerically named folder per
            subject.

    Raises:
        ValueError: If fewer than 3 subjects are found, so that no non-empty
            train/val/test split exists.
        RuntimeError: If one of the splits contains no valid sequence.
    """
    # Configuration
    SEQUENCE_LENGTH = 5
    STRIDE = 2
    BATCH_SIZE = 8
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    CHECKPOINT_PATH = 'best_blinking_model.pth'
    CLASS_NAMES = ['Not Blinking', 'Blinking']

    # Device selection (MPS for Mac, CUDA for NVIDIA, CPU fallback).
    # torch.backends.mps does not exist on older PyTorch builds, so look it up
    # defensively instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
        print("Using Apple Metal Performance Shaders (MPS)")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    # Data transforms
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Get subject directories
    subjects = [d for d in os.listdir(drowsy_base)
                if os.path.isdir(os.path.join(drowsy_base, d)) and d.isdigit()]
    subjects = sorted(subjects, key=int)

    print(f"\nFound {len(subjects)} subjects: {subjects}")

    if len(subjects) < 3:
        raise ValueError(
            f"Need at least 3 subjects for a train/val/test split, "
            f"found {len(subjects)} in {drowsy_base}"
        )

    # Subject-wise split (70-15-15)
    random.seed(42)
    random.shuffle(subjects)

    # int(n * 0.15) is 0 for fewer than 7 subjects, which used to leave the
    # validation split empty. Force at least one validation and one test
    # subject; for 7 or more subjects the sizes are unchanged.
    n_val = max(1, int(len(subjects) * 0.15))
    n_train = min(int(len(subjects) * 0.7), len(subjects) - n_val - 1)

    train_subjects = subjects[:n_train]
    val_subjects = subjects[n_train:n_train + n_val]
    test_subjects = subjects[n_train + n_val:]

    print(f"\nSplit - Train: {train_subjects}, Val: {val_subjects}, Test: {test_subjects}")

    # Create datasets
    train_dirs = [os.path.join(drowsy_base, s) for s in train_subjects]
    val_dirs = [os.path.join(drowsy_base, s) for s in val_subjects]
    test_dirs = [os.path.join(drowsy_base, s) for s in test_subjects]

    print("\nCreating datasets...")
    train_dataset = TemporalDrowsinessDataset(train_dirs, transform_train, SEQUENCE_LENGTH, STRIDE)
    val_dataset = TemporalDrowsinessDataset(val_dirs, transform_val, SEQUENCE_LENGTH, STRIDE)
    test_dataset = TemporalDrowsinessDataset(test_dirs, transform_val, SEQUENCE_LENGTH, STRIDE)

    print(f"\nDataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Fail early with a readable message instead of an obscure sampler or
    # division error somewhere inside the first epoch.
    for split_name, split_dataset in (("train", train_dataset),
                                      ("val", val_dataset),
                                      ("test", test_dataset)):
        if len(split_dataset) == 0:
            raise RuntimeError(
                f"The {split_name} split contains no valid sequences; "
                f"check the frames/labels.csv files under {drowsy_base}"
            )

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model, loss, optimizer
    model = TemporalBlinkingModel(
        backbone_name='mobilevit_s',
        pretrained=True,
        num_classes=2,
        sequence_length=SEQUENCE_LENGTH
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total_params:,}")

    # Weighted loss for class imbalance (class 1 counts twice as much)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]).to(device))
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

    # Training loop
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Starting at 0 with a strict '>' could leave no checkpoint
    # (or a stale one from an earlier run) if validation accuracy stayed 0.
    best_val_acc = float('-inf')

    print("\nStarting training...")
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

        # Step scheduler
        scheduler.step()

        # Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print(f"Best model saved with Val Acc: {val_acc:.2f}%")

    # Plot training history
    plot_training_history(train_losses, val_losses, train_accs, val_accs)

    # Test evaluation
    print("\n=== Testing Best Model ===")
    # map_location keeps loading independent of the device the tensors were
    # saved from.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    test_loss, test_acc, test_preds, test_labels = validate(model, test_loader, criterion, device)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.2f}%")

    # Classification report. Explicit labels keep the report aligned with
    # target_names even when a class is missing from the test data;
    # zero_division=0 reports 0 for such a class instead of warning.
    print("\nClassification Report:")
    print(classification_report(test_labels, test_preds,
                                labels=[0, 1],
                                target_names=CLASS_NAMES,
                                zero_division=0))

    # Confusion matrix
    plot_confusion_matrix(test_labels, test_preds)

    print("\n=== Training Complete ===")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Model saved as: {CHECKPOINT_PATH}")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()