# Driver Fatigue Detection - Model Training & Comparison

This notebook trains and compares four models on the YAWDD dataset, classifying each image as Normal, Talking, or Yawning:
1. **Custom CNN**
2. **ResNet50**
3. **VGG16**
4. **MobileNetV4** (via `timm`; falls back to MobileNetV3 Large if the installed `timm` lacks it)

All models are built and trained with PyTorch (torchvision for ResNet50 and VGG16, `timm` for MobileNetV4).


In [None]:
# Imports. MobileNetV4 comes from `timm` (install it with `pip install timm` if it is missing).
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
import timm
from tqdm import tqdm
import time
import copy

# Report the runtime environment and pick the GPU whenever one is visible.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_ok}")
device = torch.device("cuda" if cuda_ok else "cpu")
print(f"Using device: {device}")

## 1. Data Preparation

In [None]:
# Configuration
# Root folder with one sub-folder per split (Train/Val/Test), each of which holds
# one sub-folder per class (the layout torchvision's ImageFolder expects).
# Set the DATA_DIR environment variable to use another location without editing the notebook.
DATA_DIR = os.environ.get(
    "DATA_DIR",
    "D:/FYP_ARMAN/Driver-Fatigue-Detection-Using-Vision-Based-Machine-Learning/SPLIT_DATASET",
)
BATCH_SIZE = 32
IMG_SIZE = 224
CLASSES = ['Normal', 'Talking', 'Yawning']  # expected class folder names
SPLITS = ['Train', 'Val', 'Test']

# ImageNet channel statistics, matching the normalization of the pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic pipeline shared by validation and test: resize + normalize only.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Data Transforms (light augmentation is applied to the training split only)
data_transforms = {
    'Train': transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    'Val': eval_transform,
    'Test': eval_transform,
}

# Load Datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x), data_transforms[x])
                  for x in SPLITS}

# ImageFolder derives label indices from each split's *own* sorted class folders,
# so a split with a missing or extra class folder would silently use different
# label indices. Fail fast instead of training/evaluating on misaligned labels.
class_names = image_datasets['Train'].classes
for split in SPLITS:
    if image_datasets[split].classes != class_names:
        raise ValueError(
            f"Class folders differ between splits: Train={class_names}, "
            f"{split}={image_datasets[split].classes}"
        )
if class_names != CLASSES:
    print(f"Note: class folders {class_names} differ from expected CLASSES {CLASSES}.")

# Derive the classifier output size from the data so the model head always
# matches the number of label indices produced by ImageFolder.
NUM_CLASSES = len(class_names)

dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=(x == 'Train'), num_workers=2)
               for x in SPLITS}

dataset_sizes = {x: len(image_datasets[x]) for x in SPLITS}

print(f"Classes: {class_names}")
print(f"Samples: {dataset_sizes}")

## 2. Define Models

In [None]:
class CustomCNN(nn.Module):
    """Small three-stage convolutional classifier trained from scratch.

    Each stage is Conv3x3(padding=1) -> ReLU -> MaxPool2x2, so the spatial size
    is divided by 8 overall. The first Linear layer expects 128 x 28 x 28
    features, i.e. 224x224 inputs; other input sizes raise a shape error.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 28 * 28, 512),  # 224 / 2**3 = 28
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


# timm architecture used for "MobileNetV4". The bare architecture name (no
# ".<tag>" suffix) lets timm choose its default pretrained weights rather than
# depending on one exact pretrained-tag string.
MOBILENETV4_ARCH = 'mobilenetv4_conv_medium'


def get_model(model_name, num_classes=3, pretrained=True):
    """Build one of the compared classifiers with a ``num_classes``-way output head.

    Args:
        model_name: one of 'CNN', 'ResNet50', 'VGG16', 'MobileNetV4'.
        num_classes: number of output logits.
        pretrained: load ImageNet weights for the backbone models
            (ignored for 'CNN', which is always randomly initialized).

    Returns:
        An ``nn.Module`` whose final layer outputs ``num_classes`` scores.
        For 'MobileNetV4', a torchvision MobileNetV3 Large is returned instead
        when the installed timm does not provide MobileNetV4 (a note is printed).

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    if model_name == 'CNN':
        return CustomCNN(num_classes)

    # IMAGENET1K_V1 is what the deprecated ``pretrained=True`` used to load.
    if model_name == 'ResNet50':
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet50(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if model_name == 'VGG16':
        weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.vgg16(weights=weights)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
        return model

    if model_name == 'MobileNetV4':
        # Check availability up front instead of catching every exception, so
        # real failures (e.g. a weight download error) are not silently turned
        # into a different architecture being trained under the MobileNetV4 label.
        if timm.is_model(MOBILENETV4_ARCH):
            return timm.create_model(MOBILENETV4_ARCH, pretrained=pretrained, num_classes=num_classes)
        print(f"Note: '{MOBILENETV4_ARCH}' not found in installed timm version "
              f"{timm.__version__}. Falling back to MobileNetV3 Large.")
        weights = models.MobileNet_V3_Large_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.mobilenet_v3_large(weights=weights)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        return model

    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of "
        "'CNN', 'ResNet50', 'VGG16', 'MobileNetV4'."
    )

# Smoke test: build every architecture once and report whether construction worked.
for arch in ('CNN', 'ResNet50', 'VGG16', 'MobileNetV4'):
    try:
        m = get_model(arch)
        print(f"{arch} created successfully.")
    except Exception as err:
        print(f"Failed to create {arch}: {err}")

## 3. Training Loop

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Every epoch runs a 'Train' phase (gradient updates) over
    ``dataloaders['Train']`` followed by a 'Val' phase (no gradients) over
    ``dataloaders['Val']``.

    Args:
        model: network to train; updated in place and also returned.
        dataloaders: dict with 'Train' and 'Val' loaders yielding ``(inputs, labels)``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over the training data.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        ``(model, val_acc_history)``: the model with its best-validation weights
        loaded, and the per-epoch validation accuracies as floats in [0, 1].

    Raises:
        ValueError: if a phase's dataloader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['Train', 'Val']:
            is_train = phase == 'Train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0
            num_samples = 0  # counted from batches, so it is exact even with drop_last

            for inputs, labels in tqdm(dataloaders[phase], desc=f"{phase} Epoch {epoch}"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                # criterion is assumed to average over the batch (the nn default),
                # so re-weight by batch size to get a per-sample epoch mean.
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                num_samples += batch_size

            if num_samples == 0:
                raise ValueError(f"The '{phase}' dataloader yielded no samples.")

            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'Val':
                val_acc_history.append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model, val_acc_history

## 4. Train and Compare Models

In [None]:
models_list = ['CNN', 'ResNet50', 'VGG16', 'MobileNetV4']
history_dict = {}
trained_models = {}

NUM_EPOCHS = 10  # Epoch budget per model; adjust as needed.

# Train each architecture with the same loss, optimizer settings and epoch
# budget so their validation curves are directly comparable.
for model_name in models_list:
    print(f"\nTraining {model_name}...")
    net = get_model(model_name, num_classes=NUM_CLASSES).to(device)

    loss_fn = nn.CrossEntropyLoss()
    # SGD with momentum over every parameter: full fine-tuning, nothing frozen.
    opt = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    net, history = train_model(net, dataloaders, loss_fn, opt, num_epochs=NUM_EPOCHS)

    history_dict[model_name] = history
    trained_models[model_name] = net

    # Persist the best-validation weights returned by train_model.
    checkpoint_path = f"{model_name}_driver_fatigue.pth"
    torch.save(net.state_dict(), checkpoint_path)

## 5. Comparison & Visualization

In [None]:
# Overlay every model's per-epoch validation accuracy on a single chart.
fig, ax = plt.subplots(figsize=(10, 6))
for label, accs in history_dict.items():
    ax.plot(accs, label=label)

ax.set_title('Model Validation Accuracy Comparison')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend()
ax.grid(True)
plt.show()

## 6. Test Evaluation

In [None]:
def evaluate_on_test(model, test_loader, device=None):
    """Return the classification accuracy of ``model`` on ``test_loader``, in percent.

    The model is put in eval mode and run under ``torch.no_grad()``; the
    predicted class for each sample is the arg-max of the outputs along dim 1.

    Args:
        model: classifier producing one score per class along dim 1.
        test_loader: iterable of ``(inputs, labels)`` batches.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy.")
    return 100 * correct / total

# Report held-out test accuracy for each trained model.
print("Final Test Accuracy:")
test_loader = dataloaders['Test']
for model_name, net in trained_models.items():
    print(f"{model_name}: {evaluate_on_test(net, test_loader):.2f}%")