In [47]:
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm

In [48]:
compiled_labels_path = 'compiled_labels.csv'
img_dir = './faces'

# Reuse the cached label CSV when present; otherwise build it from the raw JSON
# and keep only rows whose image file actually exists under img_dir.
if os.path.exists(compiled_labels_path):
    df = pd.read_csv(compiled_labels_path)
else:
    json_file_path = 'training_raw_data.json'
    with open(json_file_path, 'r') as file:
        data = json.load(file)

    # Class index of an emotion is its position in this list.
    emotion_list = ['adoration', 'affection', 'aggravation', 'agitation', 'agony', 'alarm', 'alienation', 'amazement', 'amusement', 'anger', 'anguish', 'annoyance', 'anxiety', 'apprehension', 'arousal', 'astonishment', 'attraction', 'bitterness', 'bliss', 'caring', 'cheerfulness', 'compassion', 'contempt', 'contentment', 'defeat', 'dejection', 'delight', 'depression', 'desire', 'despair', 'disappointment', 'disgust', 'dislike', 'dismay', 'displeasure', 'distress', 'dread', 'eagerness', 'ecstasy', 'elation', 'embarrassment', 'enjoyment', 'enthrallment', 'enthusiasm', 'envy', 'euphoria', 'exasperation', 'excitement', 'exhilaration', 'fear', 'ferocity', 'fondness', 'fright', 'frustration', 'fury', 'gaiety', 'gladness', 'glee', 'gloom', 'glumness', 'grief', 'grouchiness', 'grumpiness', 'guilt', 'happiness', 'hate', 'homesickness', 'hope', 'hopelessness', 'horror', 'hostility', 'humiliation', 'hurt', 'hysteria', 'infatuation', 'insecurity', 'insult', 'irritation', 'isolation', 'jealousy', 'jolliness', 'joviality', 'joy', 'jubilation', 'liking', 'loathing', 'loneliness', 'longing', 'love', 'lust', 'melancholy', 'misery', 'mortification', 'neglect', 'nervousness', 'optimism', 'outrage', 'panic', 'passion', 'pity', 'pleasure', 'pride', 'rage', 'rapture', 'regret', 'rejection', 'relief', 'remorse', 'resentment', 'revulsion', 'sadness', 'satisfaction', 'scorn', 'sentimentality', 'shame', 'shock', 'sorrow', 'spite', 'suffering', 'surprise', 'sympathy', 'tenderness', 'tenseness', 'terror', 'thrill', 'torment', 'triumph', 'uneasiness', 'unhappiness', 'vengefulness', 'woe', 'worry', 'wrath', 'zeal', 'zest']
    label_to_idx = {emotion: idx for idx, emotion in enumerate(emotion_list)}

    # The i-th JSON record is paired with image file "<i>.jpg"; an emotion name
    # missing from emotion_list raises KeyError here.
    labels = [label_to_idx[record['label']] for record in data]
    df = pd.DataFrame({
        'Label': labels,
        'Name': [f"{i}.jpg" for i in range(len(labels))],
    })

    # Positions of rows whose image exists on disk (RangeIndex, so position == label).
    valid_images = [
        position
        for position, name in enumerate(
            tqdm(df['Name'], total=len(df), desc="Checking image paths")
        )
        if os.path.exists(os.path.join(img_dir, name))
    ]

    df = df.iloc[valid_images].reset_index(drop=True)
    df.to_csv(compiled_labels_path, index=False)

In [49]:
class EmotionDataset(Dataset):
    """Map-style dataset over pre-loaded images paired with a label DataFrame.

    Args:
        df: DataFrame with a 'Label' column; row ``i`` is the label of ``images[i]``.
        images: Positionally aligned sequence of images, each accepted by
            ``PIL.Image.fromarray`` (presumably HxWx3 uint8 arrays — confirm
            against the caller).
        transform: Optional transform applied to each PIL image. Defaults to a
            224x224 resize, tensor conversion and ImageNet mean/std normalization.

    Raises:
        ValueError: if ``images`` and ``df`` have different lengths.
    """

    def __init__(self, df: pd.DataFrame, images, transform: Optional[transforms.Compose] = None):
        if len(images) != len(df):
            raise ValueError(
                f"images ({len(images)}) and df ({len(df)}) must have the same length"
            )
        self.df = df
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        self.images = images
        # Cache labels as plain Python ints once: a per-sample df.iloc lookup is
        # slow and returns a numpy scalar rather than the annotated int.
        self.labels = df['Label'].tolist()

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        image = Image.fromarray(self.images[idx])

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

class EmotionClassifier(nn.Module):
    """ResNet-50 backbone whose final layer is replaced by a two-layer MLP head.

    Args:
        num_classes: Number of output logits.
        pretrained: If True, load the ImageNet ``IMAGENET1K_V1`` backbone weights
            (the checkpoint torchvision's legacy ``pretrained=True`` selected).
        hidden_dim: Width of the hidden layer in the classification head.
        dropout: Dropout probability applied after the hidden layer.
    """

    def __init__(self, num_classes: int = 135, pretrained: bool = True,
                 hidden_dim: int = 512, dropout: float = 0.3):
        super().__init__()
        self.resnet = self._build_backbone(pretrained)

        num_features = self.resnet.fc.in_features
        # Swap the backbone's classifier for a task-specific head.
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Construct ResNet-50, preferring the ``weights=`` API over deprecated ``pretrained=``.

        torchvision >= 0.13 deprecates ``pretrained=``; older versions have no
        ``ResNet50_Weights`` enum, so fall back to the legacy keyword there.
        """
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            return models.resnet50(pretrained=pretrained)
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for the image batch ``x``."""
        return self.resnet(x)

def create_data_loaders(df: pd.DataFrame, 
                       img_dir: str, 
                       batch_size: int = 32,
                       train_split: float = 0.95,
                       num_workers: int = 4,
                       seed: Optional[int] = None) -> Tuple[DataLoader, DataLoader]:
    """Load every image listed in ``df`` into memory and build train/val loaders.

    Args:
        df: DataFrame with a 'Name' column (file name relative to ``img_dir``)
            and a 'Label' column.
        img_dir: Directory containing the image files.
        batch_size: Batch size for both loaders.
        train_split: Fraction of samples assigned to the training split.
        num_workers: DataLoader worker processes per loader.
        seed: If given, seeds the random train/val split so it is reproducible.
            ``None`` keeps the previous behaviour (torch's global RNG).

    Returns:
        ``(train_loader, val_loader)``. The train loader shuffles and drops the
        last incomplete batch; the val loader does neither.
    """
    # Resolve paths up front so the worker threads never touch pandas.
    img_paths = [os.path.join(img_dir, name) for name in df['Name']]

    def load_image(img_path):
        """Read one image as an RGB array; on failure log it and return a black 224x224 placeholder."""
        try:
            with Image.open(img_path) as img:
                image = np.array(img.convert('RGB'))
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            image = np.zeros((224, 224, 3), dtype=np.uint8)
        return image

    print("Loading images into memory with multithreading...")
    with ThreadPoolExecutor() as executor:
        images = list(tqdm(executor.map(load_image, img_paths), total=len(img_paths)))

    dataset = EmotionDataset(df, images)

    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    # Only pass a generator when seeded, so the unseeded path is unchanged.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    # Pinned memory only helps host->GPU copies; on CPU-only machines it just warns.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    return train_loader, val_loader

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load from the same img_dir the label cell used to filter missing files, rather
# than repeating the path literal (the two could otherwise silently diverge).
train_loader, val_loader = create_data_loaders(
    df=df,
    img_dir=img_dir,
    batch_size=128
)

Loading images into memory with multithreading...


 17%|████████████▍                                                             | 87000/516926 [09:45<1:26:41, 82.65it/s]

In [None]:
def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                    optimizer: torch.optim.Optimizer, device: torch.device,
                    desc: str) -> float:
    """Run one training pass over ``loader`` and return the mean loss of the trained batches.

    Batches that contain a label of -1 are skipped entirely: they are neither
    trained on nor counted in the average. Returns NaN if no batch was trained,
    instead of failing on an undefined average.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc=desc, leave=True)
    for images, labels in pbar:
        # Check before the host->device copy so skipped batches are never
        # transferred and the membership test does not force a device sync.
        if (labels == -1).any():
            continue

        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{running_loss / num_batches:.4f}'})

    return running_loss / num_batches if num_batches else float('nan')


def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module,
             device: torch.device) -> Tuple[float, float]:
    """Return ``(per-sample mean loss, accuracy in percent)`` over ``loader``.

    Assumes ``criterion`` averages over the batch (``reduction='mean'``, the
    ``nn.CrossEntropyLoss`` default used below). Returns NaNs for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Validation', leave=True)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so dividing by the sample
            # count gives a true per-sample mean (the last batch may be smaller).
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += batch_size

            pbar.set_postfix({
                'loss': f'{loss_sum/total:.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100. * correct / total


model = EmotionClassifier(num_classes=135).to(device)

# Set to True to replicate the model across all visible GPUs with nn.DataParallel.
USE_DATA_PARALLEL = False
if USE_DATA_PARALLEL and torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device,
                                 desc=f'Epoch {epoch+1}/{num_epochs}')
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(f'\nEpoch {epoch+1}/{num_epochs} - '
          f'Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}, '
          f'Val Acc: {val_acc:.2f}%\n')



Using 3 GPUs


Epoch 1/10:  12%|██████▊                                                | 402/3230 [06:48<47:54,  1.02s/it, loss=4.0186]


KeyboardInterrupt: 