In [11]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import pandas as pd
from typing import Optional, Tuple
import numpy as np
import json
from tqdm import tqdm

In [12]:
compiled_labels_path = 'compiled_labels.csv'
img_dir = './faces'

# Label vocabulary: the position of each emotion in this list is its class index.
# Defined unconditionally (not only when the CSV cache is missing) so the
# index <-> name mapping exists on every run, e.g. for decoding predictions.
emotion_list = ['adoration', 'affection', 'aggravation', 'agitation', 'agony', 'alarm', 'alienation', 'amazement', 'amusement', 'anger', 'anguish', 'annoyance', 'anxiety', 'apprehension', 'arousal', 'astonishment', 'attraction', 'bitterness', 'bliss', 'caring', 'cheerfulness', 'compassion', 'contempt', 'contentment', 'defeat', 'dejection', 'delight', 'depression', 'desire', 'despair', 'disappointment', 'disgust', 'dislike', 'dismay', 'displeasure', 'distress', 'dread', 'eagerness', 'ecstasy', 'elation', 'embarrassment', 'enjoyment', 'enthrallment', 'enthusiasm', 'envy', 'euphoria', 'exasperation', 'excitement', 'exhilaration', 'fear', 'ferocity', 'fondness', 'fright', 'frustration', 'fury', 'gaiety', 'gladness', 'glee', 'gloom', 'glumness', 'grief', 'grouchiness', 'grumpiness', 'guilt', 'happiness', 'hate', 'homesickness', 'hope', 'hopelessness', 'horror', 'hostility', 'humiliation', 'hurt', 'hysteria', 'infatuation', 'insecurity', 'insult', 'irritation', 'isolation', 'jealousy', 'jolliness', 'joviality', 'joy', 'jubilation', 'liking', 'loathing', 'loneliness', 'longing', 'love', 'lust', 'melancholy', 'misery', 'mortification', 'neglect', 'nervousness', 'optimism', 'outrage', 'panic', 'passion', 'pity', 'pleasure', 'pride', 'rage', 'rapture', 'regret', 'rejection', 'relief', 'remorse', 'resentment', 'revulsion', 'sadness', 'satisfaction', 'scorn', 'sentimentality', 'shame', 'shock', 'sorrow', 'spite', 'suffering', 'surprise', 'sympathy', 'tenderness', 'tenseness', 'terror', 'thrill', 'torment', 'triumph', 'uneasiness', 'unhappiness', 'vengefulness', 'woe', 'worry', 'wrath', 'zeal', 'zest']

label_to_idx = {emotion: idx for idx, emotion in enumerate(emotion_list)}

if not os.path.exists(compiled_labels_path):
    # Cache miss: build the (Label, Name) table from the raw JSON annotations.
    json_file_path = 'training_raw_data.json'
    with open(json_file_path, 'r') as file:
        data = json.load(file)

    labels = []
    for record_idx, image in enumerate(data):
        label_name = image['label']
        if label_name not in label_to_idx:
            raise KeyError(f"Record {record_idx} has unknown emotion label {label_name!r}")
        labels.append(label_to_idx[label_name])

    # A record's position in the JSON determines its image file name: "<i>.jpg".
    df = pd.DataFrame({
        'Label': labels,
        'Name': [f"{i}.jpg" for i in range(len(labels))],
    })

    # Keep only rows whose image file exists on disk. An explicit boolean Series
    # (rather than a bare list) keeps the columns intact even when df is empty.
    exists_mask = pd.Series(
        [os.path.exists(os.path.join(img_dir, name))
         for name in tqdm(df['Name'], total=len(df), desc="Checking image paths")],
        index=df.index,
        dtype=bool,
    )
    df = df[exists_mask].reset_index(drop=True)
    df.to_csv(compiled_labels_path, index=False)
else:
    # Cache hit: reuse the previously compiled, path-filtered table.
    df = pd.read_csv(compiled_labels_path)

# Every stored label must be a valid class index for the vocabulary above.
assert df['Label'].between(0, len(emotion_list) - 1).all(), "label index out of range"

In [13]:
class EmotionDataset(Dataset):
    """Face-image dataset yielding ``(image_tensor, label_index)`` pairs.

    Args:
        df: Table with a ``'Name'`` column (image file name relative to
            ``img_dir``) and a ``'Label'`` column (integer class index).
        img_dir: Directory containing the image files.
        transform: Transform applied to each RGB PIL image. Defaults to a
            224x224 resize, ``ToTensor`` and ImageNet mean/std normalization.

    Images that fail to load or transform are reported on stdout and replaced
    by a ``(3, 224, 224)`` zero tensor with label ``-1`` so callers can filter
    them out. NOTE(review): with a custom ``transform`` producing a different
    size, that placeholder shape will not match the real samples.
    """

    def __init__(self, df: pd.DataFrame, img_dir: str, transform: Optional[transforms.Compose] = None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet channel statistics, matching the pretrained backbone.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self) -> int:
        """Number of rows in the label table."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load, convert and transform image ``idx``; return it with its label.

        Returns ``(zeros(3, 224, 224), -1)`` if anything goes wrong.
        """
        row = self.df.iloc[idx]  # single positional lookup reused below
        img_path = os.path.join(self.img_dir, row['Name'])

        try:
            # The context manager releases the file handle deterministically;
            # convert() produces an independent in-memory copy.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, int(row['Label'])

        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            return torch.zeros((3, 224, 224)), -1

class EmotionClassifier(nn.Module):
    """ResNet-101 backbone with a two-layer MLP classification head.

    The backbone is frozen except for the stages named in ``trainable_layers``
    (``layer3`` and ``layer4`` by default) and the replacement ``fc`` head.

    Args:
        num_classes: Number of output logits.
        pretrained: Load ImageNet weights for the backbone.
        trainable_layers: Names of backbone sub-modules to unfreeze in addition
            to the ``fc`` head.
        dropout: Dropout probability inside the head.
    """

    def __init__(self, num_classes: int = 135, pretrained: bool = True,
                 trainable_layers: Tuple[str, ...] = ('layer3', 'layer4'),
                 dropout: float = 0.3):
        super(EmotionClassifier, self).__init__()
        self.resnet = self._build_backbone(pretrained)

        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

        # Freeze everything, then re-enable gradients for the chosen stages and the head.
        self.resnet.requires_grad_(False)
        for name in (*trainable_layers, 'fc'):
            getattr(self.resnet, name).requires_grad_(True)

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Create ResNet-101, preferring the ``weights=`` API over deprecated ``pretrained=``."""
        weights_enum = getattr(models, 'ResNet101_Weights', None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the boolean flag.
            return models.resnet101(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` flag loaded,
        # so results stay comparable with earlier runs.
        return models.resnet101(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        return self.resnet(x)

def create_data_loaders(df: pd.DataFrame,
                        img_dir: str,
                        batch_size: int = 32,
                        train_split: float = 0.95,
                        num_workers: int = 4,
                        seed: Optional[int] = 42) -> Tuple[DataLoader, DataLoader]:
    """Split ``df`` into train/validation subsets and wrap each in a DataLoader.

    Args:
        df: Label table passed to ``EmotionDataset``.
        img_dir: Directory containing the images named in ``df``.
        batch_size: Batch size for both loaders.
        train_split: Fraction of samples assigned to training, in ``(0, 1]``.
        num_workers: Number of DataLoader worker processes.
        seed: Seed for the train/val split so every call (and every kernel
            restart) yields the same validation set. ``None`` draws the split
            from torch's global RNG, i.e. a different split each call.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops
        the last incomplete batch; the validation loader keeps every sample in
        order.

    Raises:
        ValueError: If ``train_split`` is outside ``(0, 1]``.
    """
    if not 0.0 < train_split <= 1.0:
        raise ValueError(f"train_split must be in (0, 1], got {train_split}")

    dataset = EmotionDataset(df, img_dir)

    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    # Only pass a generator when seeding; otherwise keep random_split's default RNG.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False
    )

    return train_loader, val_loader

In [14]:
# Select the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device}")

# Reuse the image directory configured next to the label table instead of
# repeating the literal path, so the two can never drift apart.
train_loader, val_loader = create_data_loaders(
    df=df,
    img_dir=img_dir,
    batch_size=128
)

Using device cuda


In [17]:
model = EmotionClassifier(num_classes=135)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Best validation accuracy (%) seen in this run. Starting at -inf means the first
# evaluated epoch always writes a checkpoint. The previous literal 21.0 had no
# effect because an `epoch == 0` special case overrode it.
best_acc = float('-inf')
num_epochs = 20
best_model_path = 'best_model.pkl'


def _drop_failed_samples(images: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Remove samples that EmotionDataset flagged as unreadable (label == -1)."""
    keep = labels != -1
    if keep.all():
        return images, labels
    return images[keep], labels[keep]


for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    running_loss = 0.0  # mean of per-batch losses over the batches actually used
    train_batches = 0

    for images, labels in train_pbar:
        # Drop only the broken samples rather than discarding the whole batch.
        images, labels = _drop_failed_samples(images, labels)
        if labels.numel() == 0:
            continue

        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Incremental mean counted over used batches, so skipped ones don't skew it.
        train_batches += 1
        running_loss += (loss.item() - running_loss) / train_batches
        train_pbar.set_postfix({'loss': f'{running_loss:.4f}'})

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    val_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Valid]')

        for images, labels in val_pbar:
            images, labels = _drop_failed_samples(images, labels)
            if labels.numel() == 0:
                continue

            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            val_batches += 1
            val_loss += (loss.item() - val_loss) / val_batches
            val_pbar.set_postfix({
                'loss': f'{val_loss:.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })

    # Guard against a validation pass with no usable samples (was a ZeroDivisionError).
    val_acc = 100 * correct / total if total else 0.0
    if total and val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'Saving best model with accuracy {best_acc:.2f}%')

    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'Training Loss: {running_loss:.4f}')
    print(f'Validation Loss: {val_loss:.4f}')
    print(f'Validation Accuracy: {val_acc:.2f}%\n')

Epoch 1/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [37:43<00:00,  1.69it/s, loss=3.4608]
Epoch 1/20 [Valid]: 100%|████████████████████████████████████| 202/202 [01:54<00:00,  1.77it/s, loss=3.2453, acc=16.30%]


Saving best model with accuracy 16.30%
Epoch 1/20:
Training Loss: 3.4608
Validation Loss: 3.2453
Validation Accuracy: 16.30%



Epoch 2/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:37<00:00,  9.65it/s, loss=3.1862]
Epoch 2/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.18it/s, loss=3.1383, acc=17.81%]


Saving best model with accuracy 17.81%
Epoch 2/20:
Training Loss: 3.1862
Validation Loss: 3.1383
Validation Accuracy: 17.81%



Epoch 3/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:15<00:00, 10.21it/s, loss=3.0518]
Epoch 3/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 18.32it/s, loss=3.0263, acc=19.03%]


Saving best model with accuracy 19.03%
Epoch 3/20:
Training Loss: 3.0518
Validation Loss: 3.0263
Validation Accuracy: 19.03%



Epoch 4/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:18<00:00, 10.14it/s, loss=2.9415]
Epoch 4/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.98it/s, loss=2.9747, acc=20.75%]


Saving best model with accuracy 20.75%
Epoch 4/20:
Training Loss: 2.9415
Validation Loss: 2.9747
Validation Accuracy: 20.75%



Epoch 5/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:21<00:00, 10.04it/s, loss=2.8359]
Epoch 5/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.78it/s, loss=2.9575, acc=20.80%]


Saving best model with accuracy 20.80%
Epoch 5/20:
Training Loss: 2.8359
Validation Loss: 2.9575
Validation Accuracy: 20.80%



Epoch 6/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:22<00:00, 10.04it/s, loss=2.7287]
Epoch 6/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.91it/s, loss=2.9569, acc=21.32%]


Saving best model with accuracy 21.32%
Epoch 6/20:
Training Loss: 2.7287
Validation Loss: 2.9569
Validation Accuracy: 21.32%



Epoch 7/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:22<00:00, 10.02it/s, loss=2.6184]
Epoch 7/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.28it/s, loss=2.9659, acc=21.18%]


Epoch 7/20:
Training Loss: 2.6184
Validation Loss: 2.9659
Validation Accuracy: 21.18%



Epoch 8/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:23<00:00, 10.01it/s, loss=2.5010]
Epoch 8/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.78it/s, loss=2.9973, acc=21.71%]


Saving best model with accuracy 21.71%
Epoch 8/20:
Training Loss: 2.5010
Validation Loss: 2.9973
Validation Accuracy: 21.71%



Epoch 9/20 [Train]: 100%|██████████████████████████████████████████████| 3836/3836 [06:26<00:00,  9.94it/s, loss=2.3800]
Epoch 9/20 [Valid]: 100%|████████████████████████████████████| 202/202 [00:11<00:00, 17.39it/s, loss=3.1324, acc=21.31%]


Epoch 9/20:
Training Loss: 2.3800
Validation Loss: 3.1324
Validation Accuracy: 21.31%



Epoch 10/20 [Train]: 100%|█████████████████████████████████████████████| 3836/3836 [06:24<00:00,  9.97it/s, loss=2.2586]
Epoch 10/20 [Valid]: 100%|███████████████████████████████████| 202/202 [00:11<00:00, 17.36it/s, loss=3.1488, acc=21.58%]


Epoch 10/20:
Training Loss: 2.2586
Validation Loss: 3.1488
Validation Accuracy: 21.58%



Epoch 11/20 [Train]: 100%|█████████████████████████████████████████████| 3836/3836 [06:39<00:00,  9.61it/s, loss=2.1417]
Epoch 11/20 [Valid]: 100%|███████████████████████████████████| 202/202 [00:11<00:00, 17.34it/s, loss=3.2386, acc=21.33%]


Epoch 11/20:
Training Loss: 2.1417
Validation Loss: 3.2386
Validation Accuracy: 21.33%



Epoch 12/20 [Train]: 100%|█████████████████████████████████████████████| 3836/3836 [06:25<00:00,  9.95it/s, loss=2.0270]
Epoch 12/20 [Valid]: 100%|███████████████████████████████████| 202/202 [00:11<00:00, 17.53it/s, loss=3.3883, acc=21.06%]


Epoch 12/20:
Training Loss: 2.0270
Validation Loss: 3.3883
Validation Accuracy: 21.06%



Epoch 13/20 [Train]: 100%|█████████████████████████████████████████████| 3836/3836 [06:25<00:00,  9.95it/s, loss=1.9172]
Epoch 13/20 [Valid]: 100%|███████████████████████████████████| 202/202 [00:11<00:00, 17.03it/s, loss=3.4207, acc=20.68%]


Epoch 13/20:
Training Loss: 1.9172
Validation Loss: 3.4207
Validation Accuracy: 20.68%



Epoch 14/20 [Train]:   3%|█▏                                             | 99/3836 [00:10<06:45,  9.21it/s, loss=1.6909]


KeyboardInterrupt: 