In [1]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [2]:
DATA_DIR = '/media/dtsarev/SatSSD/data/face_images'
TRAIN_CSV = os.path.join('/media/dtsarev/SatSSD/data', 'train_split.csv')
VAL_CSV   = os.path.join('/media/dtsarev/SatSSD/data', 'valid_split.csv')
EMOTIONS = ['Admiration', 'Amusement', 'Determination', 'Empathic Pain', 'Excitement', 'Joy']
IMG_SIZE = 224
BATCH_SIZE = 1  # One directory per batch
CHUNK_SIZE = 16  # Process this many images at a time to limit memory
LR = 1e-4
EPOCHS = 10
NUM_WORKERS = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
def pearson_corr(preds, targets):
    """
    Compute mean Pearson correlation across all emotions.
    """
    preds = preds.detach().cpu()
    targets = targets.detach().cpu()
    vx = preds - preds.mean(0)
    vy = targets - targets.mean(0)
    corr = (vx * vy).sum(0) / (torch.sqrt((vx**2).sum(0) * (vy**2).sum(0)) + 1e-8)
    return corr.mean().item()

In [6]:
class FaceEmotionDataset(Dataset):
    """
    Loads all images from each subject directory and corresponding emotion labels.
    Filters out directories with no images.
    """
    def __init__(self, csv_file, data_dir, transform=None):
        df = pd.read_csv(csv_file)
        self.data_dir = data_dir
        self.transform = transform
        self.emotions = EMOTIONS
        self.samples = []  # list of (dir_path, label_tensor, image_files)
        for _, row in df.iterrows():
            dir_name = f"{int(row['Filename']):05d}"
            dir_path = os.path.join(self.data_dir, dir_name)
            if not os.path.isdir(dir_path):
                continue
            files = [f for f in os.listdir(dir_path) if f.lower().endswith('.jpg')]
            if not files:
                continue
            labels = torch.tensor([row[e] for e in self.emotions], dtype=torch.float32)
            self.samples.append((dir_path, labels, files))
        if len(self.samples) < len(df):
            print(f"Warning: {len(df) - len(self.samples)} samples removed due to missing images.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        dir_path, labels, files = self.samples[idx]
        images = []
        for fname in files:
            file_path = os.path.join(dir_path, fname)
            # Open and close file handle properly to avoid too many open files
            with Image.open(file_path) as img_src:
                img = img_src.convert('RGB')
            if self.transform:
                img = self.transform(img)
            images.append(img)
        return images, labels

In [7]:
def collate_fn(batch):
    images_list = [item[0] for item in batch]  # list of list of img tensors
    labels = torch.stack([item[1] for item in batch], dim=0)
    return images_list, labels

In [8]:
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = FaceEmotionDataset(TRAIN_CSV, DATA_DIR, transform=train_transform)
val_dataset   = FaceEmotionDataset(VAL_CSV, DATA_DIR, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, collate_fn=collate_fn)



In [9]:
model = timm.create_model('resnet50', pretrained=True, num_classes=len(EMOTIONS))
model = model.to(DEVICE)

In [10]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [11]:
def aggregate_predictions(model, images, device):
    """
    Predict on a list of image tensors in chunks, sum then average.
    """
    total = torch.zeros(len(EMOTIONS), device=device)
    count = 0
    for i in range(0, len(images), CHUNK_SIZE):
        chunk = images[i:i+CHUNK_SIZE]
        tensor = torch.stack(chunk, dim=0).to(device)
        with torch.set_grad_enabled(model.training):
            preds = model(tensor)  # shape [chunk_size, num_emotions]
        total += preds.sum(dim=0)
        count += preds.size(0)
        # free memory
        del tensor, preds
        torch.cuda.empty_cache()
    avg = total / count
    return avg

In [12]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_targets = []

    for images_list, labels in tqdm(loader):
        # batch size is 1: one directory per batch
        images = images_list[0]
        labels_dir = labels.to(device)
        dir_size = len(images)

        optimizer.zero_grad()
        # chunk-wise forward and backward to limit memory
        total_loss = 0.0
        for i in range(0, dir_size, CHUNK_SIZE):
            chunk = images[i:i+CHUNK_SIZE]
            tensor = torch.stack(chunk, dim=0).to(device)
            outputs = model(tensor)  # [chunk_size, num_emotions]
            # repeat label for each frame
            labels_rep = labels_dir.repeat(outputs.size(0), 1)
            loss_chunk = criterion(outputs, labels_rep)
            # weight by fraction of frames
            weight = outputs.size(0) / dir_size
            (loss_chunk * weight).backward()
            total_loss += loss_chunk.item() * weight
            # free memory
            del tensor, outputs, labels_rep
            torch.cuda.empty_cache()

        optimizer.step()
        running_loss += total_loss

        # compute prediction for metrics (no grad)
        with torch.no_grad():
            avg_pred = aggregate_predictions(model, images, device)
        all_preds.append(avg_pred.detach().cpu())
        all_targets.append(labels_dir.detach().cpu())

    epoch_loss = running_loss / len(loader.dataset)
    preds_cat = torch.stack(all_preds)
    targets_cat = torch.stack(all_targets)
    epoch_corr = pearson_corr(preds_cat, targets_cat)
    return epoch_loss, epoch_corr

In [13]:
def eval_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for images_list, labels in tqdm(loader):
            batch_preds = []
            for images in images_list:
                avg_pred = aggregate_predictions(model, images, device)
                batch_preds.append(avg_pred)
            preds = torch.stack(batch_preds, dim=0)
            labels = labels.to(device)

            loss = criterion(preds, labels)
            running_loss += loss.item() * preds.size(0)
            all_preds.append(preds.detach().cpu())
            all_targets.append(labels.detach().cpu())

    epoch_loss = running_loss / len(loader.dataset)
    preds_cat = torch.cat(all_preds)
    targets_cat = torch.cat(all_targets)
    overall_corr = pearson_corr(preds_cat, targets_cat)
    per_emotion_corrs = []
    for i in range(targets_cat.size(1)):
        vx = preds_cat[:, i] - preds_cat[:, i].mean()
        vy = targets_cat[:, i] - targets_cat[:, i].mean()
        corr = (vx * vy).sum() / (torch.sqrt((vx**2).sum() * (vy**2).sum()) + 1e-8)
        per_emotion_corrs.append(corr.item())
    return epoch_loss, overall_corr, per_emotion_corrs

In [14]:
train_losses, train_corrs = [], []
val_losses, val_corrs = [], []
best_val_corr = -1
best_model_wts = None

for epoch in range(1, EPOCHS + 1):
    train_loss, train_corr = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_corr, val_corr_per_emotion = eval_epoch(model, val_loader, criterion, DEVICE)
    scheduler.step()

    train_losses.append(train_loss)
    train_corrs.append(train_corr)
    val_losses.append(val_loss)
    val_corrs.append(val_corr)

    print(f"Epoch {epoch}/{EPOCHS} | Train Loss: {train_loss:.4f}, Train Corr: {train_corr:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Corr: {val_corr:.4f}")
    if val_corr > best_val_corr:
        best_val_corr = val_corr
        best_model_wts = model.state_dict().copy()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7953/7953 [41:33<00:00,  3.19it/s]
 79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                 | 3604/4565 [03:29<00:55, 17.38it/s]Exception ignored in sys.unraisablehook: <built-in function unraisablehook>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3027, in write
    result = original_write(data, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/ipykernel/iostream.py", line 692, in write
    self.pub_thread.schedule(self._flush)
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/

KeyboardInterrupt: 

In [None]:
model.load_state_dict(best_model_wts)

In [None]:
plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.plot(range(1, EPOCHS+1), train_losses, label='Train Loss')
plt.plot(range(1, EPOCHS+1), val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Loss over Epochs')
plt.legend()

plt.subplot(1,2,2)
plt.plot(range(1, EPOCHS+1), train_corrs, label='Train Pearson')
plt.plot(range(1, EPOCHS+1), val_corrs, label='Val Pearson')
plt.xlabel('Epoch')
plt.ylabel('Pearson Correlation')
plt.title('Correlation over Epochs')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
print('Per-emotion Pearson correlation on validation set:')
for emo, corr in zip(EMOTIONS, val_corr_per_emotion):
    print(f"{emo}: {corr:.4f}")
