## 1. Setup & Data Preparation

In [23]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim 
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models
from PIL import Image
from tqdm import tqdm
import pandas as pd

In [3]:
data_dir = "UTKFace"
model_path = "models"

In [4]:
# Define the dataset class
class UTKFaceDataset(Dataset):
    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transform = transforms
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(".jpg")]
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, index):
        image_name = self.image_files[index]
        image_path = os.path.join(self.root_dir, image_name)
        image = Image.open(image_path).convert('RGB')
        
        # Extract age and gender from the filename
        age = int(image_name.split('_')[0])
        gender = int(image_name.split('_')[1])
        
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor([age], dtype=torch.float32), torch.tensor([gender], dtype=torch.long)

# Define the transformation

train_transforms  = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

test_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Split the dataset into training and testing sets
train_dataset, test_dataset = torch.utils.data.random_split(
    UTKFaceDataset(data_dir, transforms=test_transforms),
    [int(0.8 * len(os.listdir(data_dir))), int(0.2 * len(os.listdir(data_dir)))],
    generator=torch.Generator().manual_seed(42)
)

test_dataset.transform = test_transforms
train_dataset.transform = train_transforms
# over sample the training dataset to balance the classes
class_counts = pd.Series([int(f.split('_')[1]) for f in os.listdir(data_dir) if f.endswith('.jpg')]).value_counts()
max_class_count = class_counts.max()
repeat_count = max(1, max_class_count // len(train_dataset))
train_dataset = torch.utils.data.ConcatDataset([
    UTKFaceDataset(data_dir, transforms=train_transforms) for _ in range(repeat_count)
])


# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



### Gender distribution in training and testing dataset

- ### 0: Male
- ### 1: Female

In [5]:
from tqdm import tqdm
training_female_count = 0
training_male_count = 0

for i in tqdm(range(len(train_dataset))):
    label = train_dataset[i][2]
    if torch.equal(label, torch.tensor([1])):
        training_female_count += 1
    elif torch.equal(label, torch.tensor([0])):
        training_male_count += 1



100%|██████████| 20855/20855 [01:59<00:00, 174.10it/s]


In [6]:
testing_female_count = 0
testing_male_count = 0

for i in tqdm(range(len(test_dataset))):
    label = test_dataset[i][2]
    if torch.equal(label, torch.tensor([1])):
        testing_female_count += 1 
    elif torch.equal(label, torch.tensor([0])):
        testing_male_count += 1

100%|██████████| 4171/4171 [00:23<00:00, 179.86it/s]


In [7]:
class_counts = {
    "Female": training_female_count,
    "Male": training_male_count
}
print("Class counts in training set:")
print(class_counts)
print("Total number of training images:", len(train_dataset))

print("__" * 50)
test_class_counts = {
    "Female": testing_female_count,
    "Male": testing_male_count
}
print("Class counts in test set:")
print(test_class_counts)
print("Total number of test images:", len(test_dataset))

Class counts in training set:
{'Female': 10528, 'Male': 10327}
Total number of training images: 20855
____________________________________________________________________________________________________
Class counts in test set:
{'Female': 2113, 'Male': 2058}
Total number of test images: 4171


## Training and testing models

### Defined training and evaluating functions

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [6]:
# Evaluate function: Calculate accuracy on the test set
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images in tqdm(test_loader):
            inputs, _, labels = images
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.squeeze()).sum().item()
    
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy


# Training function
def train_model(model, train_loader, criterion, optimizer, lr_scheduler=None ,num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        accuracy = 0.0
        for images in tqdm(train_loader):
            inputs, _, labels = images
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.squeeze())
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            accuracy += (outputs.argmax(dim=1) == labels.squeeze()).sum().item()
        if lr_scheduler:
            lr_scheduler.step()
        accuracy /= len(train_loader.dataset)
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")  
        evaluate_model(model, test_loader)


        

### Pretrained resnet model

  self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 4096), 
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

In [None]:
from torchvision.models import resnet50

class ResNet50_f(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        base_model = resnet50(pretrained=True)
        self.features = nn.Sequential(*list(base_model.children())[:-1])  # remove final fc

        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            nn.Dropout(0.5),

            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),
            
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.classifier(x)
        return x


In [8]:
resnet50_model = ResNet50_f(num_classes=2).to(device)

# Freeze the feature extractor layers
for param in resnet50_model.features.parameters():
    param.requires_grad = False
    

# Define loss function and optimizer
trainable_params = [param for param in resnet50_model.parameters() if param.requires_grad]

optimizer = optim.Adam(trainable_params, lr=0.0001)
criterion = nn.CrossEntropyLoss()
# Learning rate scheduler
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)




In [14]:
# Train the resnet50 model
train_model(resnet50_model, train_loader, criterion, optimizer, lr_scheduler, num_epochs=35)

100%|██████████| 652/652 [03:12<00:00,  3.39it/s]


Epoch [1/35], Loss: 0.4492, Accuracy: 0.7833


100%|██████████| 131/131 [00:39<00:00,  3.35it/s]


Test Accuracy: 0.8243


100%|██████████| 652/652 [03:04<00:00,  3.53it/s]


Epoch [2/35], Loss: 0.3914, Accuracy: 0.8226


100%|██████████| 131/131 [00:35<00:00,  3.71it/s]


Test Accuracy: 0.8108


100%|██████████| 652/652 [03:01<00:00,  3.59it/s]


Epoch [3/35], Loss: 0.3822, Accuracy: 0.8243


100%|██████████| 131/131 [00:35<00:00,  3.71it/s]


Test Accuracy: 0.8454


100%|██████████| 652/652 [03:02<00:00,  3.57it/s]


Epoch [4/35], Loss: 0.3709, Accuracy: 0.8320


100%|██████████| 131/131 [00:34<00:00,  3.80it/s]


Test Accuracy: 0.8442


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [5/35], Loss: 0.3625, Accuracy: 0.8337


100%|██████████| 131/131 [00:34<00:00,  3.83it/s]


Test Accuracy: 0.8410


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [6/35], Loss: 0.3576, Accuracy: 0.8368


100%|██████████| 131/131 [00:34<00:00,  3.84it/s]


Test Accuracy: 0.8564


100%|██████████| 652/652 [03:02<00:00,  3.57it/s]


Epoch [7/35], Loss: 0.3496, Accuracy: 0.8403


100%|██████████| 131/131 [00:34<00:00,  3.82it/s]


Test Accuracy: 0.8571


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [8/35], Loss: 0.3392, Accuracy: 0.8471


100%|██████████| 131/131 [00:34<00:00,  3.79it/s]


Test Accuracy: 0.8602


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [9/35], Loss: 0.3345, Accuracy: 0.8529


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.8607


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [10/35], Loss: 0.3323, Accuracy: 0.8519


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.8645


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [11/35], Loss: 0.3090, Accuracy: 0.8669


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8693


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [12/35], Loss: 0.3063, Accuracy: 0.8643


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.8715


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [13/35], Loss: 0.3080, Accuracy: 0.8660


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8727


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [14/35], Loss: 0.3072, Accuracy: 0.8664


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8729


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [15/35], Loss: 0.3026, Accuracy: 0.8711


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8717


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [16/35], Loss: 0.2998, Accuracy: 0.8702


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8739


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [17/35], Loss: 0.2970, Accuracy: 0.8699


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.8765


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [18/35], Loss: 0.2982, Accuracy: 0.8709


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.8760


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [19/35], Loss: 0.2969, Accuracy: 0.8738


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.8770


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [20/35], Loss: 0.2956, Accuracy: 0.8719


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8780


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [21/35], Loss: 0.2938, Accuracy: 0.8733


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8789


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [22/35], Loss: 0.2917, Accuracy: 0.8752


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8782


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [23/35], Loss: 0.2923, Accuracy: 0.8751


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8777


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [24/35], Loss: 0.2947, Accuracy: 0.8749


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8780


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [25/35], Loss: 0.2917, Accuracy: 0.8752


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.8782


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [26/35], Loss: 0.2901, Accuracy: 0.8746


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8784


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [27/35], Loss: 0.2935, Accuracy: 0.8717


100%|██████████| 131/131 [00:34<00:00,  3.84it/s]


Test Accuracy: 0.8784


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [28/35], Loss: 0.2953, Accuracy: 0.8729


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.8775


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [29/35], Loss: 0.2920, Accuracy: 0.8764


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8787


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [30/35], Loss: 0.2939, Accuracy: 0.8749


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.8770


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [31/35], Loss: 0.2927, Accuracy: 0.8758


100%|██████████| 131/131 [00:34<00:00,  3.84it/s]


Test Accuracy: 0.8772


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [32/35], Loss: 0.2899, Accuracy: 0.8758


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8772


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [33/35], Loss: 0.2911, Accuracy: 0.8754


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.8770


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [34/35], Loss: 0.2930, Accuracy: 0.8739


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.8772


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [35/35], Loss: 0.2939, Accuracy: 0.8750


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]

Test Accuracy: 0.8775





In [15]:
# Save the resnet50 model
torch.save(resnet50_model.state_dict(), os.path.join(model_path, "resnet50_f_gender_classifier.pth"))

### Pretrained VGG_f model

In [9]:
class VGG_f(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        base_model = models.vgg16(pretrained=True)
        self.features = base_model.features  # Convolutional part

        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 4096), 
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


In [10]:
# Initialize the VGG model
vgg_model = VGG_f(num_classes=2).to(device)
# Freeze the feature extractor layers
for param in vgg_model.features.parameters():
    param.requires_grad = False
    
# Define loss function and optimizer
trainable_params = [param for param in vgg_model.parameters() if param.requires_grad]

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(trainable_params, lr=0.0001)
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)



In [18]:
# Train the vgg model
train_model(vgg_model, train_loader, criterion, optimizer, lr_scheduler, num_epochs=35)

100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [1/35], Loss: 0.3917, Accuracy: 0.8244


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.8593


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [2/35], Loss: 0.3334, Accuracy: 0.8546


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.8590


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [3/35], Loss: 0.3064, Accuracy: 0.8657


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.8866


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [4/35], Loss: 0.2882, Accuracy: 0.8771


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.9043


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [5/35], Loss: 0.2637, Accuracy: 0.8869


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9178


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [6/35], Loss: 0.2393, Accuracy: 0.8998


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9293


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [7/35], Loss: 0.2136, Accuracy: 0.9125


100%|██████████| 131/131 [00:34<00:00,  3.74it/s]


Test Accuracy: 0.9381


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [8/35], Loss: 0.1880, Accuracy: 0.9237


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9384


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [9/35], Loss: 0.1686, Accuracy: 0.9333


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9597


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [10/35], Loss: 0.1495, Accuracy: 0.9398


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9609


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [11/35], Loss: 0.1075, Accuracy: 0.9599


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9679


100%|██████████| 652/652 [02:56<00:00,  3.68it/s]


Epoch [12/35], Loss: 0.0923, Accuracy: 0.9656


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9731


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [13/35], Loss: 0.0834, Accuracy: 0.9684


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9753


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [14/35], Loss: 0.0719, Accuracy: 0.9735


100%|██████████| 131/131 [00:33<00:00,  3.89it/s]


Test Accuracy: 0.9734


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [15/35], Loss: 0.0739, Accuracy: 0.9738


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9734


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [16/35], Loss: 0.0669, Accuracy: 0.9754


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9787


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [17/35], Loss: 0.0592, Accuracy: 0.9793


100%|██████████| 131/131 [00:33<00:00,  3.89it/s]


Test Accuracy: 0.9801


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [18/35], Loss: 0.0555, Accuracy: 0.9802


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9827


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [19/35], Loss: 0.0511, Accuracy: 0.9807


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9832


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [20/35], Loss: 0.0469, Accuracy: 0.9839


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9847


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [21/35], Loss: 0.0453, Accuracy: 0.9838


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.9859


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [22/35], Loss: 0.0424, Accuracy: 0.9844


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9859


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [23/35], Loss: 0.0438, Accuracy: 0.9843


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9854


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [24/35], Loss: 0.0443, Accuracy: 0.9847


100%|██████████| 131/131 [00:35<00:00,  3.69it/s]


Test Accuracy: 0.9851


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [25/35], Loss: 0.0433, Accuracy: 0.9846


100%|██████████| 131/131 [00:34<00:00,  3.79it/s]


Test Accuracy: 0.9863


100%|██████████| 652/652 [03:01<00:00,  3.58it/s]


Epoch [26/35], Loss: 0.0396, Accuracy: 0.9857


100%|██████████| 131/131 [00:37<00:00,  3.52it/s]


Test Accuracy: 0.9863


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [27/35], Loss: 0.0412, Accuracy: 0.9863


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9866


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [28/35], Loss: 0.0413, Accuracy: 0.9862


100%|██████████| 131/131 [00:34<00:00,  3.83it/s]


Test Accuracy: 0.9871


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [29/35], Loss: 0.0404, Accuracy: 0.9850


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9875


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [30/35], Loss: 0.0397, Accuracy: 0.9865


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.9875


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [31/35], Loss: 0.0396, Accuracy: 0.9864


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9880


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [32/35], Loss: 0.0395, Accuracy: 0.9868


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9883


100%|██████████| 652/652 [02:56<00:00,  3.68it/s]


Epoch [33/35], Loss: 0.0393, Accuracy: 0.9868


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9883


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [34/35], Loss: 0.0384, Accuracy: 0.9860


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9883


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [35/35], Loss: 0.0390, Accuracy: 0.9866


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]

Test Accuracy: 0.9885





In [19]:
# Save the vgg model
torch.save(vgg_model.state_dict(), os.path.join(model_path, "vgg_f_gender_classifier.pth"))

## Pretrained models with some modules added

In [11]:
# Define the Channel Attention module
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        
        return self.sigmoid(out) * x 

# Define the Spatial Attention module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out = torch.max(x, dim=1, keepdim=True).values
        x_cat = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x_cat)) * x
    
# Define the CBAM module
class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=8, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_planes, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)
        
    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

### Pretrained resnet with CBAM modules

In [17]:
class GhostModule(nn.Module):
    def __init__(self, in_channels, out_channels, ratio=2, dw_size=3):
        super(GhostModule, self).__init__()
        self.intrinsic_channels = out_channels // ratio
        self.primary_conv = nn.Conv2d(in_channels, self.intrinsic_channels, 1, bias=False)
        self.ghost_conv = nn.Conv2d(self.intrinsic_channels, self.intrinsic_channels, dw_size, padding=dw_size//2, groups=self.intrinsic_channels, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        intrinsic = self.primary_conv(x)
        ghost = self.ghost_conv(intrinsic)
        out = torch.cat([intrinsic, ghost], dim=1)
        out = self.bn(out)
        return self.relu(out)

In [None]:
# Build the modified ResNet model with CBAM
class ResNetCBAM(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetCBAM, self).__init__()
        base_model = resnet50(pretrained=True)
        self.features = nn.Sequential(*list(base_model.children())[:-1])  # remove final fc
        
        self.cbam = CBAM(2048, ratio=8, kernel_size=7)
        self.ghost_module = GhostModule(2048, 512, ratio=2, dw_size=3)
        
        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            nn.Dropout(0.5),

            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),
            
            nn.Linear(1024, num_classes)
        )
        
    
        
    def forward(self, x):
        x = self.features(x)  # [batch, 2048, 1, 1]
        x = self.cbam(x)      # [batch, 2048, 1, 1]
        x = self.ghost_module(x)  # [batch, 512, 1, 1]
        x = x.view(x.size(0), -1) # flatten to [batch, 512]
        x = self.classifier(x)
        return x
        

In [19]:
# # Initialize the modified model
cbam_resnet = ResNetCBAM(num_classes=2)
cbam_resnet.to(device)

# Freeze the convolutional layers
for param in cbam_resnet.parameters():
    param.requires_grad = False

# Unfreeze CBAM and fc layers
for param in cbam_resnet.cbam.parameters():
    param.requires_grad = True
for param in cbam_resnet.ghost_module.parameters():
    param.requires_grad = True
for param in cbam_resnet.classifier.parameters():
    param.requires_grad = True

# Optimizer with only trainable parameters
trainable_params = [p for p in cbam_resnet.parameters() if p.requires_grad]
cbam_optimizer = optim.Adam(trainable_params, lr=0.0001, weight_decay=1e-4)

# Loss function
criterion = nn.CrossEntropyLoss()

# Learning rate scheduler
cbam_lr_scheduler = StepLR(cbam_optimizer, step_size=10, gamma=0.1)




In [24]:
# Train the modified model
train_model(cbam_resnet, train_loader, criterion, cbam_optimizer, cbam_lr_scheduler, num_epochs=35)

100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [1/35], Loss: 0.4308, Accuracy: 0.7960


100%|██████████| 131/131 [00:35<00:00,  3.70it/s]


Test Accuracy: 0.8327


100%|██████████| 652/652 [02:54<00:00,  3.74it/s]


Epoch [2/35], Loss: 0.3833, Accuracy: 0.8248


100%|██████████| 131/131 [00:33<00:00,  3.90it/s]


Test Accuracy: 0.8195


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [3/35], Loss: 0.3676, Accuracy: 0.8300


100%|██████████| 131/131 [00:33<00:00,  3.90it/s]


Test Accuracy: 0.8413


100%|██████████| 652/652 [02:53<00:00,  3.76it/s]


Epoch [4/35], Loss: 0.3530, Accuracy: 0.8412


100%|██████████| 131/131 [00:33<00:00,  3.91it/s]


Test Accuracy: 0.8576


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [5/35], Loss: 0.3466, Accuracy: 0.8438


100%|██████████| 131/131 [00:33<00:00,  3.91it/s]


Test Accuracy: 0.8619


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [6/35], Loss: 0.3389, Accuracy: 0.8477


100%|██████████| 131/131 [00:33<00:00,  3.89it/s]


Test Accuracy: 0.8681


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [7/35], Loss: 0.3275, Accuracy: 0.8539


100%|██████████| 131/131 [00:33<00:00,  3.89it/s]


Test Accuracy: 0.8339


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [8/35], Loss: 0.3189, Accuracy: 0.8599


100%|██████████| 131/131 [00:33<00:00,  3.90it/s]


Test Accuracy: 0.8727


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [9/35], Loss: 0.3103, Accuracy: 0.8634


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.8758


100%|██████████| 652/652 [02:51<00:00,  3.79it/s]


Epoch [10/35], Loss: 0.2998, Accuracy: 0.8678


100%|██████████| 131/131 [00:33<00:00,  3.93it/s]


Test Accuracy: 0.8868


100%|██████████| 652/652 [02:54<00:00,  3.74it/s]


Epoch [11/35], Loss: 0.2722, Accuracy: 0.8825


100%|██████████| 131/131 [00:33<00:00,  3.90it/s]


Test Accuracy: 0.8928


100%|██████████| 652/652 [02:53<00:00,  3.77it/s]


Epoch [12/35], Loss: 0.2642, Accuracy: 0.8871


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.8974


100%|██████████| 652/652 [02:51<00:00,  3.79it/s]


Epoch [13/35], Loss: 0.2578, Accuracy: 0.8897


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.8967


100%|██████████| 652/652 [02:52<00:00,  3.79it/s]


Epoch [14/35], Loss: 0.2533, Accuracy: 0.8915


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.8998


100%|██████████| 652/652 [02:52<00:00,  3.79it/s]


Epoch [15/35], Loss: 0.2508, Accuracy: 0.8927


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9024


100%|██████████| 652/652 [02:51<00:00,  3.80it/s]


Epoch [16/35], Loss: 0.2498, Accuracy: 0.8939


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.9010


100%|██████████| 652/652 [02:52<00:00,  3.79it/s]


Epoch [17/35], Loss: 0.2454, Accuracy: 0.8969


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.9029


100%|██████████| 652/652 [02:51<00:00,  3.80it/s]


Epoch [18/35], Loss: 0.2465, Accuracy: 0.8951


100%|██████████| 131/131 [00:33<00:00,  3.92it/s]


Test Accuracy: 0.9053


100%|██████████| 652/652 [02:52<00:00,  3.79it/s]


Epoch [19/35], Loss: 0.2383, Accuracy: 0.8992


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.9072


100%|██████████| 652/652 [02:52<00:00,  3.79it/s]


Epoch [20/35], Loss: 0.2425, Accuracy: 0.8975


100%|██████████| 131/131 [00:33<00:00,  3.96it/s]


Test Accuracy: 0.9079


100%|██████████| 652/652 [02:52<00:00,  3.78it/s]


Epoch [21/35], Loss: 0.2357, Accuracy: 0.9005


100%|██████████| 131/131 [00:33<00:00,  3.93it/s]


Test Accuracy: 0.9072


100%|██████████| 652/652 [02:52<00:00,  3.78it/s]


Epoch [22/35], Loss: 0.2350, Accuracy: 0.9026


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.9084


100%|██████████| 652/652 [02:51<00:00,  3.81it/s]


Epoch [23/35], Loss: 0.2320, Accuracy: 0.9008


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9096


100%|██████████| 652/652 [02:51<00:00,  3.79it/s]


Epoch [24/35], Loss: 0.2307, Accuracy: 0.9020


100%|██████████| 131/131 [00:33<00:00,  3.93it/s]


Test Accuracy: 0.9103


100%|██████████| 652/652 [02:52<00:00,  3.78it/s]


Epoch [25/35], Loss: 0.2302, Accuracy: 0.9015


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9089


100%|██████████| 652/652 [02:52<00:00,  3.78it/s]


Epoch [26/35], Loss: 0.2308, Accuracy: 0.9034


100%|██████████| 131/131 [00:33<00:00,  3.96it/s]


Test Accuracy: 0.9096


100%|██████████| 652/652 [02:52<00:00,  3.78it/s]


Epoch [27/35], Loss: 0.2315, Accuracy: 0.9033


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9089


100%|██████████| 652/652 [02:51<00:00,  3.79it/s]


Epoch [28/35], Loss: 0.2286, Accuracy: 0.9029


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9096


100%|██████████| 652/652 [02:51<00:00,  3.80it/s]


Epoch [29/35], Loss: 0.2327, Accuracy: 0.9040


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9096


100%|██████████| 652/652 [02:53<00:00,  3.76it/s]


Epoch [30/35], Loss: 0.2344, Accuracy: 0.9000


100%|██████████| 131/131 [00:33<00:00,  3.92it/s]


Test Accuracy: 0.9099


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [31/35], Loss: 0.2309, Accuracy: 0.9044


100%|██████████| 131/131 [00:34<00:00,  3.80it/s]


Test Accuracy: 0.9099


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [32/35], Loss: 0.2285, Accuracy: 0.9060


100%|██████████| 131/131 [00:33<00:00,  3.90it/s]


Test Accuracy: 0.9099


100%|██████████| 652/652 [02:51<00:00,  3.79it/s]


Epoch [33/35], Loss: 0.2281, Accuracy: 0.9052


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.9101


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [34/35], Loss: 0.2299, Accuracy: 0.9034


100%|██████████| 131/131 [00:33<00:00,  3.95it/s]


Test Accuracy: 0.9099


100%|██████████| 652/652 [02:51<00:00,  3.80it/s]


Epoch [35/35], Loss: 0.2292, Accuracy: 0.9040


100%|██████████| 131/131 [00:33<00:00,  3.96it/s]

Test Accuracy: 0.9099





In [25]:
torch.save(cbam_resnet.state_dict(), os.path.join(model_path, "cbam_resnet18_gender_classifier.pth"))

### Pretrained VGG with CBAM 

In [46]:
# Build the modified VGG model with CBAM and GhostNet
class VGGCBAM(nn.Module):
    def __init__(self, num_classes=2):
        super(VGGCBAM, self).__init__()
        base_model = models.vgg16(pretrained=True)
        self.features = base_model.features  # Convolutional part
        
        self.cbam = CBAM(512, ratio=8, kernel_size=7)
    
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 4096), 
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)  # [batch, 512, 4, 4]
        x = self.cbam(x)      # [batch, 512, 4, 4]
        x = x.view(x.size(0), -1) # flatten to [batch, 512 * 4 * 4]
        x = self.classifier(x)
        return x

In [47]:
# Initialize the VGGGhostCBAM model
vgg_cbam = VGGCBAM(num_classes=2)
vgg_cbam.to(device)
# Freeze the convolutional layers
for param in vgg_cbam.parameters():
    param.requires_grad = False
# Unfreeze CBAM and fc layers
for param in vgg_cbam.cbam.parameters():
    param.requires_grad = True
for param in vgg_cbam.classifier.parameters():
    param.requires_grad = True

# Optimizer with only trainable parameters
trainable_params = [p for p in vgg_cbam.parameters() if p.requires_grad]
vgg_cbam_optimizer = optim.Adam(trainable_params, lr=0.0001, weight_decay=1e-4)
# Loss function
criterion = nn.CrossEntropyLoss()
# Learning rate scheduler
vgg_cbam_lr_scheduler = StepLR(vgg_cbam_optimizer, step_size=10, gamma=0.1)




In [51]:
# Train the vgg modified model
train_model(vgg_cbam, train_loader, criterion, vgg_cbam_optimizer, vgg_cbam_lr_scheduler, num_epochs=35)


100%|██████████| 652/652 [03:06<00:00,  3.49it/s]


Epoch [1/35], Loss: 0.4407, Accuracy: 0.8094


100%|██████████| 131/131 [00:35<00:00,  3.66it/s]


Test Accuracy: 0.8655


100%|██████████| 652/652 [03:04<00:00,  3.54it/s]


Epoch [2/35], Loss: 0.3309, Accuracy: 0.8538


100%|██████████| 131/131 [00:36<00:00,  3.56it/s]


Test Accuracy: 0.8832


100%|██████████| 652/652 [03:07<00:00,  3.47it/s]


Epoch [3/35], Loss: 0.3035, Accuracy: 0.8681


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.8995


100%|██████████| 652/652 [03:14<00:00,  3.36it/s]


Epoch [4/35], Loss: 0.2798, Accuracy: 0.8813


100%|██████████| 131/131 [00:40<00:00,  3.21it/s]


Test Accuracy: 0.9072


100%|██████████| 652/652 [03:25<00:00,  3.18it/s]


Epoch [5/35], Loss: 0.2639, Accuracy: 0.8888


100%|██████████| 131/131 [00:35<00:00,  3.71it/s]


Test Accuracy: 0.9190


100%|██████████| 652/652 [03:01<00:00,  3.59it/s]


Epoch [6/35], Loss: 0.2463, Accuracy: 0.8991


100%|██████████| 131/131 [00:34<00:00,  3.83it/s]


Test Accuracy: 0.9173


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [7/35], Loss: 0.2243, Accuracy: 0.9088


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9290


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [8/35], Loss: 0.2021, Accuracy: 0.9194


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9381


100%|██████████| 652/652 [03:01<00:00,  3.60it/s]


Epoch [9/35], Loss: 0.1835, Accuracy: 0.9261


100%|██████████| 131/131 [00:34<00:00,  3.84it/s]


Test Accuracy: 0.9475


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [10/35], Loss: 0.1686, Accuracy: 0.9335


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.9583


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [11/35], Loss: 0.1101, Accuracy: 0.9585


100%|██████████| 131/131 [00:34<00:00,  3.84it/s]


Test Accuracy: 0.9688


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [12/35], Loss: 0.0957, Accuracy: 0.9645


100%|██████████| 131/131 [00:34<00:00,  3.82it/s]


Test Accuracy: 0.9741


100%|██████████| 652/652 [03:12<00:00,  3.39it/s]


Epoch [13/35], Loss: 0.0852, Accuracy: 0.9686


100%|██████████| 131/131 [00:40<00:00,  3.21it/s]


Test Accuracy: 0.9767


100%|██████████| 652/652 [03:23<00:00,  3.20it/s]


Epoch [14/35], Loss: 0.0789, Accuracy: 0.9702


100%|██████████| 131/131 [00:37<00:00,  3.47it/s]


Test Accuracy: 0.9796


100%|██████████| 652/652 [03:37<00:00,  3.00it/s]


Epoch [15/35], Loss: 0.0673, Accuracy: 0.9747


100%|██████████| 131/131 [00:35<00:00,  3.68it/s]


Test Accuracy: 0.9835


100%|██████████| 652/652 [03:36<00:00,  3.02it/s]


Epoch [16/35], Loss: 0.0608, Accuracy: 0.9783


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9815


100%|██████████| 652/652 [03:05<00:00,  3.51it/s]


Epoch [17/35], Loss: 0.0576, Accuracy: 0.9790


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.9825


100%|██████████| 652/652 [03:09<00:00,  3.43it/s]


Epoch [18/35], Loss: 0.0534, Accuracy: 0.9812


100%|██████████| 131/131 [00:37<00:00,  3.48it/s]


Test Accuracy: 0.9849


100%|██████████| 652/652 [03:10<00:00,  3.43it/s]


Epoch [19/35], Loss: 0.0538, Accuracy: 0.9820


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9863


100%|██████████| 652/652 [03:12<00:00,  3.39it/s]


Epoch [20/35], Loss: 0.0479, Accuracy: 0.9830


100%|██████████| 131/131 [00:35<00:00,  3.67it/s]


Test Accuracy: 0.9866


100%|██████████| 652/652 [03:14<00:00,  3.35it/s]


Epoch [21/35], Loss: 0.0385, Accuracy: 0.9868


100%|██████████| 131/131 [00:34<00:00,  3.77it/s]


Test Accuracy: 0.9871


100%|██████████| 652/652 [03:17<00:00,  3.30it/s]


Epoch [22/35], Loss: 0.0394, Accuracy: 0.9865


100%|██████████| 131/131 [00:35<00:00,  3.70it/s]


Test Accuracy: 0.9875


100%|██████████| 652/652 [03:06<00:00,  3.50it/s]


Epoch [23/35], Loss: 0.0386, Accuracy: 0.9868


100%|██████████| 131/131 [00:34<00:00,  3.83it/s]


Test Accuracy: 0.9878


100%|██████████| 652/652 [03:04<00:00,  3.53it/s]


Epoch [24/35], Loss: 0.0377, Accuracy: 0.9867


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9880


100%|██████████| 652/652 [03:06<00:00,  3.50it/s]


Epoch [25/35], Loss: 0.0365, Accuracy: 0.9878


100%|██████████| 131/131 [00:33<00:00,  3.88it/s]


Test Accuracy: 0.9883


100%|██████████| 652/652 [03:06<00:00,  3.49it/s]


Epoch [26/35], Loss: 0.0366, Accuracy: 0.9870


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9883


100%|██████████| 652/652 [03:14<00:00,  3.35it/s]


Epoch [27/35], Loss: 0.0357, Accuracy: 0.9873


100%|██████████| 131/131 [00:34<00:00,  3.84it/s]


Test Accuracy: 0.9887


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [28/35], Loss: 0.0355, Accuracy: 0.9876


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9885


100%|██████████| 652/652 [03:05<00:00,  3.51it/s]


Epoch [29/35], Loss: 0.0350, Accuracy: 0.9879


100%|██████████| 131/131 [00:35<00:00,  3.67it/s]


Test Accuracy: 0.9892


100%|██████████| 652/652 [03:00<00:00,  3.60it/s]


Epoch [30/35], Loss: 0.0375, Accuracy: 0.9868


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9887


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [31/35], Loss: 0.0331, Accuracy: 0.9892


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9890


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [32/35], Loss: 0.0359, Accuracy: 0.9877


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9890


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [33/35], Loss: 0.0337, Accuracy: 0.9885


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9890


100%|██████████| 652/652 [02:59<00:00,  3.62it/s]


Epoch [34/35], Loss: 0.0360, Accuracy: 0.9880


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9890


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [35/35], Loss: 0.0345, Accuracy: 0.9880


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]

Test Accuracy: 0.9890





In [52]:
# Save the vgg modified model
torch.save(vgg_cbam.state_dict(), os.path.join(model_path, "vgg_ghost_cbam_gender_classifier.pth"))

### Pretrained VGG with SE and SA Blocks

In [53]:
# Define the SEBlock module
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1) # Calculate the average weight of each channel
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        b, c, _, _ = x.size() # Batch size and number of channels
        y = self.global_avg_pool(x).view(b, c) # [B, C]
        y = self.fc(y).view(b, c, 1, 1) # [B, C, 1, 1]
        
        return x * y.expand_as(x)  # Scaled channel wise
        

In [54]:
# Build the modified VGG model with Squeeze-and-Excitation (SE) and SA modules
class VGGSEandSA(nn.Module):
    def __init__(self, num_classes=2):
        super(VGGSEandSA, self).__init__()
        base_model = models.vgg16(pretrained=True)
        self.features = base_model.features  # Convolutional part
        
        self.se_block = SEBlock(512, reduction=16)
        self.spatial_attention = SpatialAttention(kernel_size=7)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 4096), 
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)  # [batch, 512, 4, 4]
        x = self.se_block(x)  # [batch, 512, 4, 4]
        x = self.spatial_attention(x)  # [batch, 512, 4, 4]
        x = x.view(x.size(0), -1) # flatten to [batch, 512 * 4 * 4]
        x = self.classifier(x)
        return x

In [55]:
# Initialize the modified VGGSEandSA
vgg_se_sa = VGGSEandSA(num_classes=2).to(device)

# Freeze the convolutional layers
for param in vgg_se_sa.parameters():
    param.requires_grad = False

# Unfreeze SE Block, Spatial Attention Block and fc layers
for param in vgg_se_sa.se_block.parameters():
    param.requires_grad = True

for param in vgg_se_sa.spatial_attention.parameters():
    param.requires_grad = True 
    
for param in vgg_se_sa.classifier.parameters():
    param.requires_grad = True
    
# Optimizer with only trainable parameters
trainable_params = [p for p in vgg_se_sa.parameters() if p.requires_grad]
vgg_se_sa_optimizer = optim.Adam(trainable_params, lr=0.0001, weight_decay=1e-4)
# Loss function
criterion = nn.CrossEntropyLoss()
# Learning rate scheduler
vgg_se_sa_lr_scheduler = StepLR(vgg_se_sa_optimizer, step_size=5, gamma=0.2)



In [56]:
# Train the vgg_se_sa model 
train_model(vgg_se_sa, train_loader, criterion, vgg_se_sa_optimizer, vgg_se_sa_lr_scheduler, num_epochs=35)

100%|██████████| 652/652 [03:16<00:00,  3.32it/s]


Epoch [1/35], Loss: 0.4386, Accuracy: 0.8072


100%|██████████| 131/131 [00:36<00:00,  3.62it/s]


Test Accuracy: 0.8732


100%|██████████| 652/652 [03:15<00:00,  3.34it/s]


Epoch [2/35], Loss: 0.3411, Accuracy: 0.8491


100%|██████████| 131/131 [00:34<00:00,  3.79it/s]


Test Accuracy: 0.8859


100%|██████████| 652/652 [03:06<00:00,  3.50it/s]


Epoch [3/35], Loss: 0.3077, Accuracy: 0.8666


100%|██████████| 131/131 [00:35<00:00,  3.67it/s]


Test Accuracy: 0.8945


100%|██████████| 652/652 [03:18<00:00,  3.28it/s]


Epoch [4/35], Loss: 0.2839, Accuracy: 0.8787


100%|██████████| 131/131 [00:36<00:00,  3.54it/s]


Test Accuracy: 0.8945


100%|██████████| 652/652 [03:08<00:00,  3.46it/s]


Epoch [5/35], Loss: 0.2631, Accuracy: 0.8897


100%|██████████| 131/131 [00:35<00:00,  3.74it/s]


Test Accuracy: 0.9084


100%|██████████| 652/652 [03:34<00:00,  3.05it/s]


Epoch [6/35], Loss: 0.2091, Accuracy: 0.9154


100%|██████████| 131/131 [00:35<00:00,  3.64it/s]


Test Accuracy: 0.9329


100%|██████████| 652/652 [03:23<00:00,  3.20it/s]


Epoch [7/35], Loss: 0.1826, Accuracy: 0.9255


100%|██████████| 131/131 [00:34<00:00,  3.76it/s]


Test Accuracy: 0.9473


100%|██████████| 652/652 [03:15<00:00,  3.33it/s]


Epoch [8/35], Loss: 0.1650, Accuracy: 0.9349


100%|██████████| 131/131 [00:33<00:00,  3.92it/s]


Test Accuracy: 0.9530


100%|██████████| 652/652 [03:29<00:00,  3.11it/s]


Epoch [9/35], Loss: 0.1461, Accuracy: 0.9424


100%|██████████| 131/131 [00:42<00:00,  3.06it/s]


Test Accuracy: 0.9585


100%|██████████| 652/652 [03:44<00:00,  2.90it/s]


Epoch [10/35], Loss: 0.1315, Accuracy: 0.9503


100%|██████████| 131/131 [00:47<00:00,  2.76it/s]


Test Accuracy: 0.9612


100%|██████████| 652/652 [03:26<00:00,  3.16it/s]


Epoch [11/35], Loss: 0.1065, Accuracy: 0.9602


100%|██████████| 131/131 [00:35<00:00,  3.67it/s]


Test Accuracy: 0.9696


100%|██████████| 652/652 [03:07<00:00,  3.48it/s]


Epoch [12/35], Loss: 0.1003, Accuracy: 0.9636


100%|██████████| 131/131 [00:34<00:00,  3.80it/s]


Test Accuracy: 0.9712


100%|██████████| 652/652 [03:09<00:00,  3.44it/s]


Epoch [13/35], Loss: 0.0974, Accuracy: 0.9652


100%|██████████| 131/131 [00:34<00:00,  3.76it/s]


Test Accuracy: 0.9731


100%|██████████| 652/652 [03:03<00:00,  3.55it/s]


Epoch [14/35], Loss: 0.0877, Accuracy: 0.9695


100%|██████████| 131/131 [00:35<00:00,  3.71it/s]


Test Accuracy: 0.9734


100%|██████████| 652/652 [03:07<00:00,  3.49it/s]


Epoch [15/35], Loss: 0.0801, Accuracy: 0.9706


100%|██████████| 131/131 [00:34<00:00,  3.78it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [03:16<00:00,  3.32it/s]


Epoch [16/35], Loss: 0.0762, Accuracy: 0.9740


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [03:01<00:00,  3.60it/s]


Epoch [17/35], Loss: 0.0792, Accuracy: 0.9725


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [03:01<00:00,  3.60it/s]


Epoch [18/35], Loss: 0.0744, Accuracy: 0.9743


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.9770


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [19/35], Loss: 0.0734, Accuracy: 0.9748


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [03:01<00:00,  3.59it/s]


Epoch [20/35], Loss: 0.0685, Accuracy: 0.9769


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9775


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [21/35], Loss: 0.0726, Accuracy: 0.9755


100%|██████████| 131/131 [00:33<00:00,  3.85it/s]


Test Accuracy: 0.9779


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [22/35], Loss: 0.0717, Accuracy: 0.9752


100%|██████████| 131/131 [00:34<00:00,  3.81it/s]


Test Accuracy: 0.9777


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [23/35], Loss: 0.0708, Accuracy: 0.9758


100%|██████████| 131/131 [00:34<00:00,  3.85it/s]


Test Accuracy: 0.9777


100%|██████████| 652/652 [03:01<00:00,  3.60it/s]


Epoch [24/35], Loss: 0.0662, Accuracy: 0.9784


100%|██████████| 131/131 [00:33<00:00,  3.86it/s]


Test Accuracy: 0.9779


100%|██████████| 652/652 [02:59<00:00,  3.62it/s]


Epoch [25/35], Loss: 0.0703, Accuracy: 0.9764


100%|██████████| 131/131 [00:33<00:00,  3.89it/s]


Test Accuracy: 0.9782


100%|██████████| 652/652 [02:59<00:00,  3.63it/s]


Epoch [26/35], Loss: 0.0680, Accuracy: 0.9762


100%|██████████| 131/131 [00:33<00:00,  3.87it/s]


Test Accuracy: 0.9784


 97%|█████████▋| 632/652 [02:55<00:05,  3.60it/s]


KeyboardInterrupt: 

### The training took too long and the accuracy seemed can not be improved and i interrupted it and save the model

In [57]:
# Save the vgg_se_sa model
torch.save(vgg_se_sa.state_dict(), os.path.join(model_path, "vgg_se_sa_gender_classifier.pth"))