In [5]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, datasets, models
from PIL import Image
from tqdm.notebook import tqdm

# --- Hyperparameters ---
BATCH_SIZE = 8                 # samples per DataLoader batch
IMAGE_SIZE = (64, 64)          # size handed to transforms.Resize for every frame
DEVICE = torch.device("cpu")   # all tensors and the model are placed here

# --- FER-2013 image folders (one sub-folder per class, read via ImageFolder) ---
_fer_root = os.path.join("FER-2013", "data")
fer_train_path = os.path.join(_fer_root, "train")
fer_test_path = os.path.join(_fer_root, "test")

# --- DAiSEE label CSVs and extracted-frame folders ---
_daisee_labels = os.path.join("DAiSEE", "Labels")
_daisee_frames = os.path.join("DAiSEE", "frames_out")
daisee_train_csv = os.path.join(_daisee_labels, "TrainLabels.csv")
daisee_test_csv = os.path.join(_daisee_labels, "TestLabels.csv")
daisee_train_dir = os.path.join(_daisee_frames, "Train")
daisee_test_dir = os.path.join(_daisee_frames, "Test")

# FER-2013 class index -> label in the 0..3 range used by the classifier.
# NOTE(review): the keys are the indices ImageFolder assigns to the class
# folders; presumably that is the alphabetical order of the emotion folder
# names -- confirm against the folders on disk before changing this table.
fer_map = {
    0: 0,
    1: 0,
    2: 1,
    3: 3,
    4: 2,
    5: 1,
    6: 3,
}

In [6]:
# Per-channel normalisation statistics (the customary ImageNet values, which
# fits the ImageNet-pretrained ResNet-18 backbone used by the model).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing shared by every dataset: resize -> tensor -> normalise.
global_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

class DaiseeDataset(Dataset):
    """DAiSEE clips served as fixed-length stacks of frames.

    Every row of the label CSV names a clip (``ClipID``) and carries its
    ``Engagement`` label. Frames are read from
    ``<root_dir>/<ClipID without '.avi'>/frame_XX.jpg``.

    Args:
        csv_file: Path of the label CSV; needs ``ClipID`` and ``Engagement``
            columns.
        root_dir: Directory holding one sub-folder of frames per clip.
        transform: Optional callable applied to each loaded PIL image. The
            zero-frame fallback is a ``(3, H, W)`` tensor sized by
            ``IMAGE_SIZE``, so the transform is expected to yield tensors of
            that shape for ``torch.stack`` to succeed.
        num_frames: Number of frames loaded per clip
            (``frame_01`` .. ``frame_NN``). Defaults to 5.
    """

    def __init__(self, csv_file, root_dir, transform=None, num_frames=5):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.num_frames = num_frames

    def __len__(self):
        """Return the number of clips listed in the label CSV."""
        return len(self.annotations)

    def _load_frame(self, img_path):
        """Load one frame; a missing/unreadable file yields an all-zero frame.

        Only I/O and decoding failures (``OSError``, which covers
        ``FileNotFoundError`` and Pillow's ``UnidentifiedImageError``) trigger
        the fallback. Errors raised by the transform, and KeyboardInterrupt,
        propagate instead of being silently turned into blank frames.
        """
        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except OSError:
            return torch.zeros(3, IMAGE_SIZE[0], IMAGE_SIZE[1])
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, index):
        """Return ``(frames, label)`` for the clip at ``index``.

        ``frames`` is the stack of ``num_frames`` frames along a new leading
        axis; ``label`` is the row's ``Engagement`` value as an ``int``.
        """
        row = self.annotations.iloc[index]
        # NOTE(review): only a '.avi' suffix is stripped; if the CSV also lists
        # clips with another extension, confirm the frame folders match.
        clip_id = str(row['ClipID']).replace('.avi', '')
        folder_path = os.path.join(self.root_dir, clip_id)
        frames = [
            self._load_frame(os.path.join(folder_path, f"frame_{i:02d}.jpg"))
            for i in range(1, self.num_frames + 1)
        ]
        return torch.stack(frames), int(row['Engagement'])

class FER2013Dataset(datasets.ImageFolder):
    """ImageFolder variant that returns each still image as a 5-frame stack.

    Args:
        root: Image folder root, forwarded to ``ImageFolder``.
        transform: Optional callable applied to each loaded image.
        label_mapping: Optional dict translating the ImageFolder class index
            into the label that is returned.
    """

    def __init__(self, root, transform=None, label_mapping=None):
        super().__init__(root, transform=transform)
        self.label_mapping = label_mapping

    def __getitem__(self, index):
        """Return ``(frames, label)`` for the sample at ``index``."""
        img_path, class_idx = self.samples[index]
        picture = self.loader(img_path)
        if self.transform:
            picture = self.transform(picture)

        # Repeat the single image five times along a new leading axis.
        clip = torch.stack([picture for _ in range(5)])

        # No (or an empty) mapping: hand back the raw class index.
        if not self.label_mapping:
            return clip, class_idx
        # Class indices missing from the mapping fall back to label 0.
        return clip, self.label_mapping.get(class_idx, 0)


class ResNetModel(nn.Module):
    """Frame-wise ResNet-18 encoder with temporal mean pooling.

    Each frame of a clip is passed through an ImageNet-pretrained ResNet-18
    whose classification layer is replaced by an identity; the per-frame
    features are averaged over the time axis and fed to a linear classifier.

    Args:
        num_classes: Size of the output layer. Defaults to 4.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Read the feature width from the backbone instead of hard-coding it.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: 5-D tensor unpacked as ``(B, T, C, H, W)``.

        Returns:
            The output of the final linear layer, one row per clip.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. a transposed or
        # permuted batch, are accepted instead of raising a RuntimeError.
        x = x.reshape(B * T, C, H, W)
        feats = self.backbone(x)
        # Regroup frames by clip and average over the time axis.
        feats = feats.reshape(B, T, -1).mean(dim=1)
        return self.fc(feats)

In [7]:
print("--- Loading Datasets ---")

# DAiSEE is always part of both splits.
daisee_train = DaiseeDataset(
    csv_file=daisee_train_csv, root_dir=daisee_train_dir, transform=global_transform
)
daisee_test = DaiseeDataset(
    csv_file=daisee_test_csv, root_dir=daisee_test_dir, transform=global_transform
)

# FER-2013 is appended to a split only when its folder exists on disk.
if not os.path.exists(fer_train_path):
    print("Folder name error. Using DAiSEE only.")
    train_set = daisee_train
else:
    print(f"Loading FER Train from: {fer_train_path}")
    fer_train = FER2013Dataset(
        root=fer_train_path, transform=global_transform, label_mapping=fer_map
    )
    train_set = ConcatDataset([daisee_train, fer_train])

if not os.path.exists(fer_test_path):
    test_set = daisee_test
else:
    print(f"Loading FER Test from: {fer_test_path}")
    fer_test = FER2013Dataset(
        root=fer_test_path, transform=global_transform, label_mapping=fer_map
    )
    test_set = ConcatDataset([daisee_test, fer_test])

# Loaders: only the training split is shuffled; loading stays in-process.
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
test_loader = DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

print(f"Training Samples: {len(train_set)} | Test Samples: {len(test_set)}")

--- Loading Datasets ---
Loading FER Train from: FER-2013\data\train
Loading FER Test from: FER-2013\data\test
Training Samples: 34067 | Test Samples: 8962


In [8]:
# Training configuration for this run.
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

model = ResNetModel().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("\nStarting Training & Evaluation...")

for epoch in range(NUM_EPOCHS):
    # --- TRAIN LOOP ---
    model.train()
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")

    for inputs, labels in loop:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

    # --- EVALUATION LOOP (runs on the test split) ---
    # NOTE(review): these numbers come from the test loader; picking a
    # checkpoint based on them would leak test data into model selection.
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0  # sum of per-sample losses, not of per-batch means

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f"Epoch {epoch+1} [Test]"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            n_batch = labels.size(0)
            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller final batch does not count as a full one.
            test_loss += loss.item() * n_batch
            predicted = outputs.argmax(dim=1)
            total += n_batch
            correct += (predicted == labels).sum().item()

    # Guard against an empty test loader instead of dividing by zero.
    acc = 100 * correct / total if total else 0.0
    avg_loss = test_loss / total if total else 0.0

    print(f"END OF EPOCH {epoch+1} -> Test Accuracy: {acc:.2f}% | Test Loss: {avg_loss:.4f}")

    # Save a checkpoint of the weights after every epoch.
    torch.save(model.state_dict(), f"focus_tracker_model_epoch_{epoch+1}.pth")


Starting Training & Evaluation...


Epoch 1 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 1 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 1 -> Test Accuracy: 45.78% | Test Loss: 1.3760


Epoch 2 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 2 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 2 -> Test Accuracy: 53.21% | Test Loss: 1.0604


Epoch 3 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 3 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 3 -> Test Accuracy: 58.34% | Test Loss: 1.0401


Epoch 4 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 4 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 4 -> Test Accuracy: 59.31% | Test Loss: 0.9239


Epoch 5 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 5 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 5 -> Test Accuracy: 62.12% | Test Loss: 0.8974


Epoch 6 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 6 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 6 -> Test Accuracy: 61.44% | Test Loss: 0.9115


Epoch 7 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 7 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 7 -> Test Accuracy: 60.62% | Test Loss: 0.9836


Epoch 8 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 8 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 8 -> Test Accuracy: 61.41% | Test Loss: 1.0180


Epoch 9 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 9 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 9 -> Test Accuracy: 60.99% | Test Loss: 1.1185


Epoch 10 [Train]:   0%|          | 0/4259 [00:00<?, ?it/s]

Epoch 10 [Test]:   0%|          | 0/1121 [00:00<?, ?it/s]

END OF EPOCH 10 -> Test Accuracy: 61.68% | Test Loss: 1.1014
