In [2]:
import sys
sys.path.append("..")

from predict import perform_inference
from rfw_loader import create_dataloaders
from train import train, save_model

import torch

In [4]:
# Optimisation settings.
EPOCHS = 5
LEARNING_RATE = 1e-2
BATCH_SIZE = 32

# Split ratio handed to create_dataloaders (presumably the training fraction — verify).
RATIO = 0.8

# CUDA device index used when a GPU is available.
DEVICE = 2

In [5]:
# Pick the configured GPU when CUDA is present, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{DEVICE}')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=2)

# Get Train and Validation Loss Curves

In [6]:
import torch
import torchvision.models as models
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F


class MultiHeadResNet(nn.Module):
    """ResNet-18 feature extractor shared by one linear classification head per attribute.

    Args:
        output_dims: mapping ``head_name -> num_classes``. One ``nn.Linear`` head is
            created per entry, registered in the mapping's iteration order.
        apply_softmax: when True (default, the original behaviour) ``forward`` returns
            per-head softmax probabilities; when False it returns raw logits.
            ``nn.CrossEntropyLoss`` applies log-softmax internally, so a model trained
            with that loss should be built with ``apply_softmax=False``.
    """

    def __init__(self, output_dims, apply_softmax=True):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        num_features = backbone.fc.in_features
        # Keep every child except the final fc layer; the heads below replace it.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.apply_softmax = apply_softmax
        # Assign heads one at a time so registration order follows output_dims exactly
        # (the training loop pairs the i-th head with target column i).
        self.heads = nn.ModuleDict()
        for head, num_classes in output_dims.items():
            self.heads[head] = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return a dict ``head_name -> tensor`` of shape (batch, num_classes) per head."""
        # torchvision's resnet18 ends in global average pooling before fc, so the backbone
        # output is expected to be (batch, num_features, 1, 1). Flatten from dim 1 rather
        # than ``squeeze()``: squeeze also drops the batch dim when batch == 1, which made
        # ``softmax(dim=1)`` fail on single-sample batches.
        features = torch.flatten(self.resnet(x), 1)
        outputs = {}
        for head, head_module in self.heads.items():
            logits = head_module(features)
            outputs[head] = F.softmax(logits, dim=1) if self.apply_softmax else logits
        return outputs

In [7]:
# Compression setting selecting the decoded-image subdirectory q_<lambda_value>/<data_rate>.
lambda_value = 64
data_rate = 1

# Absolute locations on the shared data volume.
ROOT = '/media/global_data/fair_neural_compression_data/decoded_rfw'
# NOTE: despite the _DIR suffix this is the path of a single CSV file.
RFW_LABELS_DIR = "/media/global_data/fair_neural_compression_data/datasets/RFW/clean_metadata/numerical_labels.csv"
compressed_image_path = f'{ROOT}/progressive_64x64/qres17m/q_{lambda_value}/{data_rate}'

# create_dataloaders returns three loaders, unpacked as train / validation / test;
# BATCH_SIZE and RATIO come from the config cell.
(train_loader,
 valid_loader,
 test_loader) = create_dataloaders(compressed_image_path, RFW_LABELS_DIR, BATCH_SIZE, RATIO)



In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np

# Race group names. Not used in this cell; presumably the label order of the `races`
# field yielded by the loaders — verify against rfw_loader.
RACE_LABELS = ['Indian', 'Asian', 'African', 'Caucasian']


def _run_epoch(model, loader, criterion, device, desc, postfix_key, optimizer=None):
    """Make one pass over ``loader`` and return the per-sample average loss.

    When ``optimizer`` is given, each batch is back-propagated and an optimizer step is
    taken; otherwise the pass only evaluates (the caller handles eval mode / no_grad).
    The loss of a batch is the sum of ``criterion`` over all heads.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    total_loss = 0.0
    total_samples = 0
    with tqdm(total=len(loader), desc=desc) as pbar:
        for inputs, targets, _races in loader:
            inputs, targets = inputs.to(device).float(), targets.to(device)
            if optimizer is not None:
                optimizer.zero_grad()
            outputs = model(inputs)
            # Column i of `targets` holds the labels for the i-th head (dict order).
            loss = 0
            for i, head in enumerate(outputs):
                loss += criterion(outputs[head], targets[:, i].to(torch.int64))
            if optimizer is not None:
                loss.backward()
                optimizer.step()
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            # Divide by the true number of samples seen so a smaller final batch
            # is weighted correctly.
            pbar.set_postfix({postfix_key: total_loss / total_samples})
            pbar.update(1)
    if total_samples == 0:
        raise ValueError(f"{desc}: loader yielded no batches")
    return total_loss / total_samples


def train_numerical_rfw(
        model, 
        num_epochs, 
        lr, 
        train_loader, 
        valid_loader,
        device,
        patience=5  # Number of epochs to wait for improvement in validation loss before stopping
    ):
    """Train a multi-head classifier with SGD and early stopping on validation loss.

    Each loader must yield ``(inputs, targets, races)`` batches. ``targets[:, i]`` holds
    the class labels for the i-th head in the dict returned by ``model(inputs)``, and
    ``races`` is ignored. The loss of a batch is the sum of ``nn.CrossEntropyLoss`` over
    all heads.

    NOTE(review): CrossEntropyLoss expects raw logits. If ``model`` already applies
    softmax (MultiHeadResNet.forward calls F.softmax), the loss is computed on a double
    softmax, which flattens the gradients. Confirm which output the model produces.

    Args:
        model: module returning a dict ``head_name -> (batch, num_classes)`` tensor.
        num_epochs: maximum number of epochs.
        lr: SGD learning rate.
        train_loader, valid_loader: iterables of ``(inputs, targets, races)``.
        device: device the inputs and targets are moved to.
        patience: epochs without a new best validation loss before stopping.

    Returns:
        ``(model, train_losses, valid_losses)``: the model after the last epoch that ran
        (early stopping does NOT restore the best weights), plus the per-epoch average
        per-sample train and validation losses.

    Raises:
        ValueError: if either loader yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    train_losses = []
    valid_losses = []

    best_valid_loss = float('inf')
    no_improvement_count = 0

    for epoch in range(num_epochs):
        model.train()
        avg_train_loss = _run_epoch(
            model, train_loader, criterion, device,
            f"Epoch {epoch+1}/{num_epochs} - Training", 'train_loss',
            optimizer=optimizer,
        )
        print(f'Epoch {epoch + 1} train loss : {avg_train_loss}')
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        with torch.no_grad():
            avg_valid_loss = _run_epoch(
                model, valid_loader, criterion, device,
                f"Epoch {epoch+1}/{num_epochs} - Validation", 'valid_loss',
            )
        print(f'Epoch {epoch + 1} valid loss : {avg_valid_loss}')
        valid_losses.append(avg_valid_loss)

        # Check for early stopping
        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            no_improvement_count = 0
        else:
            no_improvement_count += 1
            if no_improvement_count >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return model, train_losses, valid_losses

In [None]:
# One classification head per facial attribute: name -> number of classes.
# Order matters: train_numerical_rfw pairs the i-th head with targets[:, i], so this must
# match the column order of the label tensor produced by the loaders.
output_dims = {
    'skin_type': 6,
    'eye_type': 2,
    'nose_type': 2,
    'lip_type': 2,
    'hair_type': 4,
    'hair_color': 5
}
model = MultiHeadResNet(output_dims).to(device)

# NOTE(review): these override the config cell (EPOCHS = 5, LEARNING_RATE = 0.01);
# consider using the config constants so there is a single source of truth.
num_epochs = 20
lr = 0.01

# NOTE(review): MultiHeadResNet.forward applies F.softmax, and train_numerical_rfw uses
# nn.CrossEntropyLoss, which applies log-softmax again — confirm this is intended.
trained_model, train_losses, valid_losses = train_numerical_rfw(
    model,
    num_epochs,  # previously a hard-coded 20 that ignored num_epochs
    lr,
    train_loader,
    valid_loader,
    device,
    patience=5
)