In [None]:
## Train a DANN for 50 epochs on the combined FER2013 + RAF-DB data, then fine-tune on CK+48 for 15 epochs

import os
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from torchvision import datasets, transforms


# Training-time augmentation pipeline shared by every dataset loaded below:
# resize to 48x48, random crop/rescale to 40x40, convert to one grey channel,
# then flip / rotate / jitter and normalise from [0, 1] to [-1, 1].
# NOTE(review): ColorJitter runs after Grayscale, so its saturation and hue
# settings likely have no effect on the single-channel image — confirm against
# torchvision's adjust_saturation / adjust_hue before relying on them.
_train_steps = [
    transforms.Resize((48, 48)),                  # fixed working size
    transforms.RandomResizedCrop((40, 40)),       # random area/aspect crop, output 40x40
    transforms.Grayscale(),                       # one channel, matching the 1-channel first conv
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),                # rotate within +/- 10 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),                        # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5 -> [-1, 1]
]
img_transforms = transforms.Compose(_train_steps)

class GradientReversalFunction(torch.autograd.Function):
    """Identity on the forward pass; negated, scaled gradient on the backward pass.

    ``apply(x, lambda_val)`` returns a tensor equal to ``x``. During
    back-propagation the incoming gradient is multiplied by ``-lambda_val``
    before it reaches ``x``. ``lambda_val`` itself gets no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambda_val):
        # Stash the scale on ctx for use in backward.
        ctx.lambda_val = lambda_val
        # view_as hands autograd a new tensor object backed by x's data.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; the second slot is lambda_val's (absent) gradient.
        reversed_grad = -ctx.lambda_val * grad_output
        return reversed_grad, None

class GradientReversalLayer(nn.Module):
    """``nn.Module`` wrapper around ``GradientReversalFunction``.

    Args:
        lambda_val: factor applied, with a sign flip, to gradients flowing
            back through this layer. The forward pass is the identity.
    """

    def __init__(self, lambda_val=1.0):
        super().__init__()
        self.lambda_val = lambda_val  # backward-pass gradient scale

    def forward(self, x):
        reverse = GradientReversalFunction.apply
        return reverse(x, self.lambda_val)

class VGG19EmotionCNN(nn.Module):
    """VGG-style emotion classifier with a domain-adversarial (DANN) head.

    Despite the name, the convolutional stack follows the VGG-16 layout
    (2-2-3-3-3 = 13 conv layers). Each conv is followed by BatchNorm and ReLU,
    and each stage ends in a 2x2 max-pool.

    Args:
        num_classes: number of emotion classes produced by ``fc3``.
        lambda_val: gradient-reversal scale for the GRL placed in front of the
            domain classifier.

    ``forward(x)`` expects single-channel images (the first conv has
    ``in_channels=1``) of at least 32x32 pixels. It returns
    ``(emotion_output, domain_output)``:

    * ``emotion_output`` holds unnormalised class scores of shape
      ``(batch, num_classes)``.
    * ``domain_output`` has shape ``(batch, 1)`` and has already passed
      through ``nn.Sigmoid``, so it is a probability. Pair it with
      ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, num_classes, lambda_val=1.0):
        super(VGG19EmotionCNN, self).__init__()
        # Stages are built by _make_stage with Sequential indices
        # Conv, BN, ReLU, ..., MaxPool, so parameter names such as
        # ``conv_block1.0.weight`` stay stable for saved checkpoints.
        self.conv_block1 = self._make_stage(1, 64, 2)
        self.conv_block2 = self._make_stage(64, 128, 2)
        self.conv_block3 = self._make_stage(128, 256, 3)
        self.conv_block4 = self._make_stage(256, 512, 3)
        self.conv_block5 = self._make_stage(512, 512, 3)
        self.fc1 = nn.Linear(512, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)
        self.dropout = nn.Dropout(0.1)
        self.grl = GradientReversalLayer(lambda_val)
        self.domain_classifier = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, num_convs):
        """Build one stage: ``num_convs`` x (3x3 conv -> BN -> ReLU), then a 2x2 max-pool."""
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        # Five 2x2 pools leave a 1x1 map for 32x32..63x63 inputs (e.g. the
        # 40x40 crops used for training), where this pooling is an identity.
        # For larger inputs it keeps the flattened size at 512 so fc1 fits.
        x = nn.functional.adaptive_avg_pool2d(x, 1)
        x = torch.flatten(x, 1)
        # torch.relu avoids constructing a new nn.ReLU module on every call.
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.dropout(torch.relu(self.fc2(x)))
        emotion_output = self.fc3(x)
        # The GRL is the identity going forward. Going backward it negates and
        # scales the domain-loss gradient reaching the shared features, which
        # is intended to make those features domain-invariant.
        domain_output = self.domain_classifier(self.grl(x))
        return emotion_output, domain_output


class DataWrapper(Dataset):
    """Attach a fixed domain id to every sample of a wrapped dataset.

    Each item of ``data`` must unpack into an ``(input, label)`` pair. The
    wrapper yields ``(input, label, domain)``, so samples drawn from a
    ``ConcatDataset`` of several wrappers carry their origin with them.
    """

    def __init__(self, data, domain):
        self.data = data
        self.domain = domain

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx]
        return sample, target, self.domain

# Domain 0: FER2013.
data_a = DataWrapper(
    datasets.ImageFolder(root='./datasets/fer2013', transform=img_transforms),
    domain=0
)

# Domain 1: RAF-DB training split.
data_b = DataWrapper(
    datasets.ImageFolder(root='./datasets/raf-db/train', transform=img_transforms),
    domain=1
)

# ImageFolder derives class indices independently from each root's sorted
# sub-folder names, and the shared emotion head is sized from RAF-DB alone.
# Labels from the two datasets only agree if both use identical class folders.
if data_a.data.classes != data_b.data.classes:
    warnings.warn(
        "FER2013 and RAF-DB class folders differ "
        f"({data_a.data.classes} vs {data_b.data.classes}); "
        "emotion labels from the two datasets will not line up."
    )

combined_data = ConcatDataset([data_a, data_b])

data_loader = DataLoader(
    combined_data,
    batch_size=32,
    shuffle=True,
)

num_classes = len(data_b.data.classes)
network = VGG19EmotionCNN(num_classes)
loss_fn1 = nn.CrossEntropyLoss()
# VGG19EmotionCNN's domain head ends in nn.Sigmoid, so dom_out is already a
# probability. BCEWithLogitsLoss would apply a second sigmoid, squashing every
# prediction into (0.5, 0.73). The domain loss could then never be driven low
# and its adversarial gradient would be distorted. BCELoss matches
# probability inputs.
loss_fn2 = nn.BCELoss()

optimizer_config = optim.SGD(network.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001, nesterov=True)
# Halve the LR once the monitored loss has not improved for more than 4 epochs.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_config, mode='min', factor=0.5, patience=4)

epochs = 50
device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network = network.to(device_type)

# Weight of the domain-adversarial term in the total loss. It is the same for
# every batch, so it is set once, outside the loop.
lambda_factor = 0.05

for epoch in range(epochs):
    network.train()
    total_cycle_loss = 0

    for imgs, lbls, doms in data_loader:
        imgs = imgs.to(device_type)
        lbls = lbls.to(device_type)
        # Per-sample int domain ids are collated into a 1-D batch tensor. The
        # domain head emits one column, so cast to float and add that axis.
        doms = doms.float().unsqueeze(1).to(device_type)

        optimizer_config.zero_grad()
        emo_out, dom_out = network(imgs)

        emo_loss = loss_fn1(emo_out, lbls)
        dom_loss = loss_fn2(dom_out, doms)
        combined_loss = emo_loss + lambda_factor * dom_loss

        combined_loss.backward()
        optimizer_config.step()

        total_cycle_loss += combined_loss.item()

    mean_loss = total_cycle_loss / len(data_loader)
    # Read the LR before stepping the scheduler, so the log shows this epoch's LR.
    curr_lr = optimizer_config.param_groups[0]['lr']
    lr_scheduler.step(mean_loss)
    print(f"Epoch {epoch+1}/{epochs}, lr:{curr_lr:.4f}\tLoss: {mean_loss:.4f}")

save_file = './modelsAll/DANN32_50.pth'
# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the result of the 50-epoch run is not lost to a failed save.
os.makedirs(os.path.dirname(save_file), exist_ok=True)
# Bundle model, optimizer and scheduler state dicts in one checkpoint.
torch.save({
    'model_state_dict': network.state_dict(),
    'optimizer_state_dict': optimizer_config.state_dict(),
    'scheduler_state_dict': lr_scheduler.state_dict(),
}, save_file)

print("Model DANN32_50 saved")

load_file = './modelsAll/DANN32_50.pth'
# Restore the weights saved after domain-adversarial pre-training.
# map_location places every tensor on the device selected earlier.
loaded_data = torch.load(load_file, map_location=device_type)
network.load_state_dict(loaded_data['model_state_dict'])
network = network.to(device_type)
network.train()

new_data = datasets.ImageFolder(root='./datasets/CK+48', transform=img_transforms)
# ImageFolder numbers classes by each root's sorted sub-folder names, and the
# emotion head was sized from RAF-DB's folders. If CK+48's folders differ,
# index k means a different emotion here than in pre-training, and
# fine-tuning silently relabels the head.
if new_data.classes != data_b.data.classes:
    warnings.warn(
        "CK+48 classes do not match the pre-training classes "
        f"({new_data.classes} vs {data_b.data.classes}); "
        "label indices will be misaligned during fine-tuning."
    )
new_loader = DataLoader(new_data, batch_size=32, shuffle=True)

# Fresh optimizer and scheduler, starting at half the initial pre-training LR.
optimizer_config = optim.SGD(network.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0001, nesterov=True)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_config, mode='min', factor=0.5, patience=4)

loss_fn1 = nn.CrossEntropyLoss()

fine_tune_epochs = 15
for epoch in range(fine_tune_epochs):
    network.train()
    total_tune_loss = 0

    for imgs, lbls in new_loader:
        imgs = imgs.to(device_type)
        lbls = lbls.to(device_type)

        optimizer_config.zero_grad()

        # Only the emotion head is supervised here. The domain output is
        # discarded, so the domain classifier receives no gradient.
        emo_out, _ = network(imgs)

        loss = loss_fn1(emo_out, lbls)

        loss.backward()
        optimizer_config.step()

        total_tune_loss += loss.item()

    mean_loss = total_tune_loss / len(new_loader)
    curr_lr = optimizer_config.param_groups[0]['lr']
    lr_scheduler.step(mean_loss)
    print(f"Epoch {epoch+1}/{fine_tune_epochs}, lr:{curr_lr:.4f}, Loss: {mean_loss:.4f}")

fine_tune_save = './modelsAll/DANN32.2_fine_tuned.pth'
# The checkpoint bundles the model, optimizer and scheduler state dicts.
tuned_checkpoint = {
    'model_state_dict': network.state_dict(),
    'optimizer_state_dict': optimizer_config.state_dict(),
    'scheduler_state_dict': lr_scheduler.state_dict(),
}
torch.save(tuned_checkpoint, fine_tune_save)

print("Model DANN32.2_fine_tuned.pth saved")
