## 1. Setup & Data Preparation

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim 
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd

In [3]:
data_dir = "UTKFace"
model_path = "models"

In [4]:
# Define the dataset class
class UTKFaceDataset(Dataset):
    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transform = transforms
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(".jpg")]
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, index):
        image_name = self.image_files[index]
        image_path = os.path.join(self.root_dir, image_name)
        image = Image.open(image_path).convert('RGB')
        
        # Extract age and gender from the filename
        age = int(image_name.split('_')[0])
        gender = int(image_name.split('_')[1])
        
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor([age], dtype=torch.float32), torch.tensor([gender], dtype=torch.long)

# Define the transformation

train_transforms  = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

test_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Split the dataset into training and testing sets
train_dataset, test_dataset = torch.utils.data.random_split(
    UTKFaceDataset(data_dir, transforms=test_transforms),
    [int(0.8 * len(os.listdir(data_dir))), int(0.2 * len(os.listdir(data_dir)))],
    generator=torch.Generator().manual_seed(42)
)

test_dataset.transform = test_transforms
train_dataset.transform = train_transforms
# over sample the training dataset to balance the classes
class_counts = pd.Series([int(f.split('_')[1]) for f in os.listdir(data_dir) if f.endswith('.jpg')]).value_counts()
max_class_count = class_counts.max()
repeat_count = max(1, max_class_count // len(train_dataset))
train_dataset = torch.utils.data.ConcatDataset([
    UTKFaceDataset(data_dir, transforms=train_transforms) for _ in range(repeat_count)
])


# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



## Deep CNNs

### Gender classification

In [5]:
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                                   padding=padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
      
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

# Build Separable Convolutional Neural Network for gender classification
class GenderClassifierCnn(nn.Module):
    def __init__(self):
        super(GenderClassifierCnn, self).__init__()
        # Depthwise separable convolutional layers
        self.conv1 = DepthwiseSeparableConv(3, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = DepthwiseSeparableConv(32, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = DepthwiseSeparableConv(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        
        
    def forward(self, x):
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x
        



In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [7]:
# Initialize the model, loss function, and optimizer
gender_classifier = GenderClassifierCnn()
gender_classifier.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gender_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [8]:
from tqdm import tqdm

# Evaluate function: Calculate accuracy on the test set
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images in tqdm(test_loader):
            inputs, _, labels = images
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.squeeze()).sum().item()
    
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy


# Training function
def train_model(model, train_loader, criterion, optimizer, lr_scheduler=None ,num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        accuracy = 0.0
        for images in tqdm(train_loader):
            inputs, _, labels = images
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.squeeze())
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            accuracy += (outputs.argmax(dim=1) == labels.squeeze()).sum().item()
        if lr_scheduler:
            lr_scheduler.step()
        accuracy /= len(train_loader.dataset)
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")  
        evaluate_model(model, test_loader)


        

In [None]:
# Train the model
train_model(gender_classifier, train_loader, criterion, optimizer, lr_scheduler ,num_epochs=30)

100%|██████████| 652/652 [03:07<00:00,  3.47it/s]


Epoch [1/30], Loss: 0.6298, Accuracy: 0.6516


100%|██████████| 131/131 [00:29<00:00,  4.47it/s]


Test Accuracy: 0.7159


100%|██████████| 652/652 [02:54<00:00,  3.74it/s]


Epoch [2/30], Loss: 0.5524, Accuracy: 0.7080


100%|██████████| 131/131 [00:29<00:00,  4.45it/s]


Test Accuracy: 0.7535


100%|██████████| 652/652 [02:54<00:00,  3.74it/s]


Epoch [3/30], Loss: 0.5112, Accuracy: 0.7381


100%|██████████| 131/131 [00:29<00:00,  4.47it/s]


Test Accuracy: 0.7758


100%|██████████| 652/652 [02:53<00:00,  3.75it/s]


Epoch [4/30], Loss: 0.4851, Accuracy: 0.7572


100%|██████████| 131/131 [00:29<00:00,  4.44it/s]


Test Accuracy: 0.7929


100%|██████████| 652/652 [03:13<00:00,  3.36it/s]


Epoch [5/30], Loss: 0.4581, Accuracy: 0.7767


100%|██████████| 131/131 [00:32<00:00,  4.02it/s]


Test Accuracy: 0.8008


100%|██████████| 652/652 [03:14<00:00,  3.35it/s]


Epoch [6/30], Loss: 0.4376, Accuracy: 0.7924


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.8046


100%|██████████| 652/652 [03:08<00:00,  3.46it/s]


Epoch [7/30], Loss: 0.4134, Accuracy: 0.8039


100%|██████████| 131/131 [00:30<00:00,  4.26it/s]


Test Accuracy: 0.8386


100%|██████████| 652/652 [03:02<00:00,  3.58it/s]


Epoch [8/30], Loss: 0.3905, Accuracy: 0.8184


100%|██████████| 131/131 [00:29<00:00,  4.37it/s]


Test Accuracy: 0.8521


100%|██████████| 652/652 [03:06<00:00,  3.49it/s]


Epoch [9/30], Loss: 0.3666, Accuracy: 0.8351


100%|██████████| 131/131 [00:31<00:00,  4.13it/s]


Test Accuracy: 0.8461


100%|██████████| 652/652 [03:28<00:00,  3.13it/s]


Epoch [10/30], Loss: 0.3426, Accuracy: 0.8473


100%|██████████| 131/131 [00:29<00:00,  4.37it/s]


Test Accuracy: 0.8804


100%|██████████| 652/652 [03:19<00:00,  3.26it/s]


Epoch [11/30], Loss: 0.2848, Accuracy: 0.8852


100%|██████████| 131/131 [00:40<00:00,  3.25it/s]


Test Accuracy: 0.8971


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [12/30], Loss: 0.2723, Accuracy: 0.8903


100%|██████████| 131/131 [00:31<00:00,  4.10it/s]


Test Accuracy: 0.8952


100%|██████████| 652/652 [03:27<00:00,  3.14it/s]


Epoch [13/30], Loss: 0.2621, Accuracy: 0.8944


100%|██████████| 131/131 [00:38<00:00,  3.43it/s]


Test Accuracy: 0.9024


100%|██████████| 652/652 [03:27<00:00,  3.14it/s]


Epoch [14/30], Loss: 0.2598, Accuracy: 0.8981


100%|██████████| 131/131 [00:31<00:00,  4.17it/s]


Test Accuracy: 0.9084


100%|██████████| 652/652 [03:09<00:00,  3.44it/s]


Epoch [15/30], Loss: 0.2540, Accuracy: 0.8996


100%|██████████| 131/131 [00:41<00:00,  3.17it/s]


Test Accuracy: 0.9051


100%|██████████| 652/652 [03:31<00:00,  3.09it/s]


Epoch [16/30], Loss: 0.2438, Accuracy: 0.9054


100%|██████████| 131/131 [00:35<00:00,  3.70it/s]


Test Accuracy: 0.9115


100%|██████████| 652/652 [02:51<00:00,  3.80it/s]


Epoch [17/30], Loss: 0.2369, Accuracy: 0.9064


100%|██████████| 131/131 [00:29<00:00,  4.42it/s]


Test Accuracy: 0.9142


100%|██████████| 652/652 [03:03<00:00,  3.55it/s]


Epoch [18/30], Loss: 0.2315, Accuracy: 0.9092


100%|██████████| 131/131 [00:29<00:00,  4.42it/s]


Test Accuracy: 0.9158


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [19/30], Loss: 0.2294, Accuracy: 0.9125


100%|██████████| 131/131 [00:30<00:00,  4.36it/s]


Test Accuracy: 0.9209


100%|██████████| 652/652 [03:06<00:00,  3.49it/s]


Epoch [20/30], Loss: 0.2238, Accuracy: 0.9144


100%|██████████| 131/131 [00:43<00:00,  3.01it/s]


Test Accuracy: 0.9247


100%|██████████| 652/652 [03:37<00:00,  3.00it/s]


Epoch [21/30], Loss: 0.2161, Accuracy: 0.9205


100%|██████████| 131/131 [00:32<00:00,  4.02it/s]


Test Accuracy: 0.9254


100%|██████████| 652/652 [03:00<00:00,  3.60it/s]


Epoch [22/30], Loss: 0.2134, Accuracy: 0.9192


100%|██████████| 131/131 [00:31<00:00,  4.10it/s]


Test Accuracy: 0.9226


100%|██████████| 652/652 [03:11<00:00,  3.40it/s]


Epoch [23/30], Loss: 0.2123, Accuracy: 0.9213


100%|██████████| 131/131 [00:39<00:00,  3.34it/s]


Test Accuracy: 0.9254


100%|██████████| 652/652 [03:21<00:00,  3.24it/s]


Epoch [24/30], Loss: 0.2115, Accuracy: 0.9205


100%|██████████| 131/131 [00:33<00:00,  3.89it/s]


Test Accuracy: 0.9226


100%|██████████| 652/652 [03:12<00:00,  3.38it/s]


Epoch [25/30], Loss: 0.2120, Accuracy: 0.9205


100%|██████████| 131/131 [00:32<00:00,  4.05it/s]


Test Accuracy: 0.9178


100%|██████████| 652/652 [03:11<00:00,  3.41it/s]


Epoch [26/30], Loss: 0.2087, Accuracy: 0.9211


100%|██████████| 131/131 [00:32<00:00,  4.04it/s]


Test Accuracy: 0.9228


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [27/30], Loss: 0.2105, Accuracy: 0.9226


100%|██████████| 131/131 [00:30<00:00,  4.31it/s]


Test Accuracy: 0.9238


100%|██████████| 652/652 [03:06<00:00,  3.49it/s]


Epoch [28/30], Loss: 0.2070, Accuracy: 0.9237


100%|██████████| 131/131 [00:34<00:00,  3.82it/s]


Test Accuracy: 0.9206


100%|██████████| 652/652 [03:09<00:00,  3.44it/s]


Epoch [29/30], Loss: 0.2071, Accuracy: 0.9229


100%|██████████| 131/131 [00:32<00:00,  4.04it/s]


Test Accuracy: 0.9254


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [30/30], Loss: 0.2071, Accuracy: 0.9212


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]

Test Accuracy: 0.9269





In [None]:
# Save the model

if not os.path.exists(model_path):
    os.makedirs(model_path)
torch.save(gender_classifier.state_dict(), os.path.join(model_path, "gender_classifier.pth"))

## Convolutional Block Attention Module (CBAM)

### 1. Channel Attention Module

In [9]:
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        
        return self.sigmoid(out) * x 
    
    

### 2. Spatial Attention Module

In [10]:
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out = torch.max(x, dim=1, keepdim=True).values
        x_cat = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x_cat)) * x

### 3. CBAM (Channel + Spatial)

In [11]:
class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=8, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)
    
    def forward(self, x):
        x = self.ca(x)
        x = self.sa(x)
        
        return x

### Channel Attention Gender Classifier

In [12]:
class GenderClassifierWithCA(nn.Module):
    def __init__(self):
        super(GenderClassifierWithCA, self).__init__()
        # Depthwise separable convolutional layers
        self.conv1 = DepthwiseSeparableConv(3, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.ca1 = ChannelAttention(32)
        
        self.conv2 = DepthwiseSeparableConv(32, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.ca2 = ChannelAttention(64)
        
        self.conv3 = DepthwiseSeparableConv(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.ca3 = ChannelAttention(128)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        
        
    def forward(self, x):
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(self.ca1(F.relu(self.bn1(self.conv1(x)))))
        x = self.pool(self.ca2(F.relu(self.bn2(self.conv2(x)))))
        x = self.pool(self.ca3(F.relu(self.bn3(self.conv3(x)))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x
        

In [43]:
# Initialize the model, loss function, and optimizer with Channel Attention
CA_classifier = GenderClassifierWithCA()
CA_classifier.to(device)
criterion_CA = nn.CrossEntropyLoss()
optimizer_CA = optim.Adam(CA_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler_CA = StepLR(optimizer_CA, step_size=10, gamma=0.1)


In [44]:
# Train the model with Channel Attention
train_model(CA_classifier, train_loader, criterion_CA, optimizer_CA, lr_scheduler=lr_scheduler_CA, num_epochs=30)

100%|██████████| 652/652 [03:07<00:00,  3.48it/s]


Epoch [1/30], Loss: 0.6169, Accuracy: 0.6651


100%|██████████| 131/131 [00:30<00:00,  4.29it/s]


Test Accuracy: 0.7317


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [2/30], Loss: 0.5397, Accuracy: 0.7159


100%|██████████| 131/131 [00:30<00:00,  4.23it/s]


Test Accuracy: 0.7492


100%|██████████| 652/652 [03:01<00:00,  3.60it/s]


Epoch [3/30], Loss: 0.5030, Accuracy: 0.7462


100%|██████████| 131/131 [00:30<00:00,  4.26it/s]


Test Accuracy: 0.7837


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [4/30], Loss: 0.4726, Accuracy: 0.7669


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.7864


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [5/30], Loss: 0.4493, Accuracy: 0.7822


100%|██████████| 131/131 [00:30<00:00,  4.24it/s]


Test Accuracy: 0.8149


100%|██████████| 652/652 [02:59<00:00,  3.63it/s]


Epoch [6/30], Loss: 0.4270, Accuracy: 0.7976


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.8341


100%|██████████| 652/652 [03:26<00:00,  3.16it/s]


Epoch [7/30], Loss: 0.4028, Accuracy: 0.8105


100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Test Accuracy: 0.8466


100%|██████████| 652/652 [03:47<00:00,  2.86it/s]


Epoch [8/30], Loss: 0.3791, Accuracy: 0.8235


100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Test Accuracy: 0.8629


100%|██████████| 652/652 [03:59<00:00,  2.72it/s]


Epoch [9/30], Loss: 0.3533, Accuracy: 0.8413


100%|██████████| 131/131 [00:43<00:00,  3.01it/s]


Test Accuracy: 0.8756


100%|██████████| 652/652 [03:59<00:00,  2.72it/s]


Epoch [10/30], Loss: 0.3308, Accuracy: 0.8526


100%|██████████| 131/131 [00:43<00:00,  3.03it/s]


Test Accuracy: 0.8943


100%|██████████| 652/652 [03:36<00:00,  3.01it/s]


Epoch [11/30], Loss: 0.2686, Accuracy: 0.8926


100%|██████████| 131/131 [00:43<00:00,  3.04it/s]


Test Accuracy: 0.9173


100%|██████████| 652/652 [03:56<00:00,  2.75it/s]


Epoch [12/30], Loss: 0.2542, Accuracy: 0.8963


100%|██████████| 131/131 [00:41<00:00,  3.14it/s]


Test Accuracy: 0.9163


100%|██████████| 652/652 [03:56<00:00,  2.75it/s]


Epoch [13/30], Loss: 0.2419, Accuracy: 0.9051


100%|██████████| 131/131 [00:31<00:00,  4.20it/s]


Test Accuracy: 0.9238


100%|██████████| 652/652 [03:36<00:00,  3.02it/s]


Epoch [14/30], Loss: 0.2377, Accuracy: 0.9060


100%|██████████| 131/131 [00:30<00:00,  4.32it/s]


Test Accuracy: 0.9254


100%|██████████| 652/652 [03:33<00:00,  3.06it/s]


Epoch [15/30], Loss: 0.2289, Accuracy: 0.9090


100%|██████████| 131/131 [00:30<00:00,  4.37it/s]


Test Accuracy: 0.9310


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [16/30], Loss: 0.2209, Accuracy: 0.9145


100%|██████████| 131/131 [00:30<00:00,  4.29it/s]


Test Accuracy: 0.9341


100%|██████████| 652/652 [02:58<00:00,  3.64it/s]


Epoch [17/30], Loss: 0.2160, Accuracy: 0.9167


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9353


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [18/30], Loss: 0.2111, Accuracy: 0.9191


100%|██████████| 131/131 [00:30<00:00,  4.28it/s]


Test Accuracy: 0.9379


100%|██████████| 652/652 [03:52<00:00,  2.81it/s]


Epoch [19/30], Loss: 0.2022, Accuracy: 0.9225


100%|██████████| 131/131 [00:33<00:00,  3.94it/s]


Test Accuracy: 0.9417


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [20/30], Loss: 0.2015, Accuracy: 0.9233


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9482


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [21/30], Loss: 0.1868, Accuracy: 0.9328


100%|██████████| 131/131 [00:30<00:00,  4.28it/s]


Test Accuracy: 0.9470


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [22/30], Loss: 0.1897, Accuracy: 0.9305


100%|██████████| 131/131 [00:30<00:00,  4.30it/s]


Test Accuracy: 0.9470


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [23/30], Loss: 0.1854, Accuracy: 0.9307


100%|██████████| 131/131 [00:29<00:00,  4.46it/s]


Test Accuracy: 0.9485


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [24/30], Loss: 0.1855, Accuracy: 0.9323


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9475


100%|██████████| 652/652 [03:35<00:00,  3.03it/s]


Epoch [25/30], Loss: 0.1878, Accuracy: 0.9295


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9489


100%|██████████| 652/652 [03:31<00:00,  3.08it/s]


Epoch [26/30], Loss: 0.1828, Accuracy: 0.9335


100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Test Accuracy: 0.9494


100%|██████████| 652/652 [03:05<00:00,  3.51it/s]


Epoch [27/30], Loss: 0.1859, Accuracy: 0.9311


100%|██████████| 131/131 [00:39<00:00,  3.32it/s]


Test Accuracy: 0.9509


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [28/30], Loss: 0.1837, Accuracy: 0.9325


100%|██████████| 131/131 [00:30<00:00,  4.29it/s]


Test Accuracy: 0.9501


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [29/30], Loss: 0.1844, Accuracy: 0.9332


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9509


100%|██████████| 652/652 [03:48<00:00,  2.86it/s]


Epoch [30/30], Loss: 0.1851, Accuracy: 0.9301


100%|██████████| 131/131 [00:41<00:00,  3.19it/s]

Test Accuracy: 0.9497





In [47]:
# Save the model with Channel Attention
torch.save(CA_classifier.state_dict(), os.path.join(model_path, "CA_gender_classifier.pth"))

### Spatial Attention Gender Classifier

In [13]:
# Build a model with Spatial Attention
class GenderClassifierWithSA(nn.Module):
    def __init__(self):
            super(GenderClassifierWithSA, self).__init__()
            # Depthwise separable convolutional layers
            self.conv1 = DepthwiseSeparableConv(3, 32)
            self.bn1 = nn.BatchNorm2d(32)
            self.sa1 = SpatialAttention()
            
            self.conv2 = DepthwiseSeparableConv(32, 64)
            self.bn2 = nn.BatchNorm2d(64)
            self.ca2 = ChannelAttention(64)
            self.sa2 = SpatialAttention()
            
            self.conv3 = DepthwiseSeparableConv(64, 128)
            self.bn3 = nn.BatchNorm2d(128)
            self.ca3 = ChannelAttention(128)
            self.sa3 = SpatialAttention()
            
            # Fully connected layers
            self.fc1 = nn.Linear(128 * 16 * 16, 256)
            self.bn_fc1 = nn.BatchNorm1d(256)
            self.fc2 = nn.Linear(256, 2)
            self.dropout = nn.Dropout(0.5)
            self.pool = nn.MaxPool2d(2, 2)
        
        
    def forward(self, x):
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(self.sa1(F.relu(self.bn1(self.conv1(x)))))
        x = self.pool(self.sa2(F.relu(self.bn2(self.conv2(x)))))
        x = self.pool(self.sa3(F.relu(self.bn3(self.conv3(x)))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x
        
        

In [58]:
# Initialize the model, loss function, and optimizer with Spatial Attention
SA_classifier = GenderClassifierWithSA()
SA_classifier.to(device)
criterion_SA = nn.CrossEntropyLoss()
optimizer_SA = optim.Adam(SA_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler_SA = StepLR(optimizer_SA, step_size=10, gamma=0.1)


In [59]:
# Train the model with Spatial Attention
train_model(SA_classifier, train_loader, criterion_SA, optimizer_SA, lr_scheduler=lr_scheduler_SA, num_epochs=30)

100%|██████████| 652/652 [03:48<00:00,  2.85it/s]


Epoch [1/30], Loss: 0.6174, Accuracy: 0.6631


100%|██████████| 131/131 [00:36<00:00,  3.56it/s]


Test Accuracy: 0.7363


100%|██████████| 652/652 [03:15<00:00,  3.34it/s]


Epoch [2/30], Loss: 0.5446, Accuracy: 0.7141


100%|██████████| 131/131 [00:29<00:00,  4.47it/s]


Test Accuracy: 0.7540


100%|██████████| 652/652 [03:08<00:00,  3.46it/s]


Epoch [3/30], Loss: 0.5073, Accuracy: 0.7434


100%|██████████| 131/131 [00:41<00:00,  3.19it/s]


Test Accuracy: 0.7895


100%|██████████| 652/652 [03:49<00:00,  2.84it/s]


Epoch [4/30], Loss: 0.4773, Accuracy: 0.7655


100%|██████████| 131/131 [00:41<00:00,  3.18it/s]


Test Accuracy: 0.7878


100%|██████████| 652/652 [03:03<00:00,  3.55it/s]


Epoch [5/30], Loss: 0.4481, Accuracy: 0.7804


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.8168


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [6/30], Loss: 0.4250, Accuracy: 0.7983


100%|██████████| 131/131 [00:29<00:00,  4.44it/s]


Test Accuracy: 0.8370


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [7/30], Loss: 0.3978, Accuracy: 0.8143


100%|██████████| 131/131 [00:29<00:00,  4.45it/s]


Test Accuracy: 0.8607


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [8/30], Loss: 0.3712, Accuracy: 0.8305


100%|██████████| 131/131 [00:29<00:00,  4.41it/s]


Test Accuracy: 0.8749


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [9/30], Loss: 0.3360, Accuracy: 0.8514


100%|██████████| 131/131 [00:29<00:00,  4.45it/s]


Test Accuracy: 0.8945


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [10/30], Loss: 0.3082, Accuracy: 0.8664


100%|██████████| 131/131 [00:29<00:00,  4.46it/s]


Test Accuracy: 0.9202


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [11/30], Loss: 0.2340, Accuracy: 0.9100


100%|██████████| 131/131 [00:30<00:00,  4.36it/s]


Test Accuracy: 0.9389


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [12/30], Loss: 0.2140, Accuracy: 0.9175


100%|██████████| 131/131 [00:29<00:00,  4.46it/s]


Test Accuracy: 0.9434


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [13/30], Loss: 0.2028, Accuracy: 0.9206


100%|██████████| 131/131 [00:29<00:00,  4.44it/s]


Test Accuracy: 0.9480


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [14/30], Loss: 0.1876, Accuracy: 0.9273


100%|██████████| 131/131 [00:29<00:00,  4.42it/s]


Test Accuracy: 0.9523


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [15/30], Loss: 0.1837, Accuracy: 0.9287


100%|██████████| 131/131 [00:30<00:00,  4.31it/s]


Test Accuracy: 0.9580


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [16/30], Loss: 0.1696, Accuracy: 0.9350


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9600


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [17/30], Loss: 0.1612, Accuracy: 0.9394


100%|██████████| 131/131 [00:30<00:00,  4.36it/s]


Test Accuracy: 0.9626


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [18/30], Loss: 0.1580, Accuracy: 0.9410


100%|██████████| 131/131 [00:29<00:00,  4.43it/s]


Test Accuracy: 0.9669


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [19/30], Loss: 0.1475, Accuracy: 0.9444


100%|██████████| 131/131 [00:29<00:00,  4.37it/s]


Test Accuracy: 0.9691


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [20/30], Loss: 0.1449, Accuracy: 0.9450


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9679


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [21/30], Loss: 0.1315, Accuracy: 0.9530


100%|██████████| 131/131 [00:29<00:00,  4.45it/s]


Test Accuracy: 0.9703


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [22/30], Loss: 0.1279, Accuracy: 0.9536


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9715


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [23/30], Loss: 0.1280, Accuracy: 0.9542


100%|██████████| 131/131 [00:29<00:00,  4.46it/s]


Test Accuracy: 0.9724


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [24/30], Loss: 0.1270, Accuracy: 0.9557


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9727


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [25/30], Loss: 0.1261, Accuracy: 0.9555


100%|██████████| 131/131 [00:30<00:00,  4.36it/s]


Test Accuracy: 0.9729


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [26/30], Loss: 0.1248, Accuracy: 0.9555


100%|██████████| 131/131 [00:29<00:00,  4.41it/s]


Test Accuracy: 0.9731


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [27/30], Loss: 0.1265, Accuracy: 0.9534


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9734


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [28/30], Loss: 0.1198, Accuracy: 0.9580


100%|██████████| 131/131 [00:29<00:00,  4.43it/s]


Test Accuracy: 0.9741


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [29/30], Loss: 0.1212, Accuracy: 0.9560


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9739


100%|██████████| 652/652 [02:56<00:00,  3.68it/s]


Epoch [30/30], Loss: 0.1241, Accuracy: 0.9543


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]

Test Accuracy: 0.9748





In [60]:
# Save the model with
torch.save(SA_classifier.state_dict(), os.path.join(model_path, "SA_gender_classifier.pth"))

### CBAM Gender Classifier

In [None]:
# Build a model with CBAM
class GenderClassifierWithCBAM(nn.Module):
    def __init__(self):
        super(GenderClassifierWithCBAM, self).__init__()
        # Depthwise separable convolutional layers
        self.conv1 = DepthwiseSeparableConv(3, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.cbam1 = CBAM(32)
        self.dropout1 = nn.Dropout2d(0.3)
        
        self.conv2 = DepthwiseSeparableConv(32, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.cbam2 = CBAM(64)
        self.dropout2 = nn.Dropout2d(0.3)
        
        self.conv3 = DepthwiseSeparableConv(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.cbam3 = CBAM(128)
        self.dropout3 = nn.Dropout2d(0.3)
        
        self.conv4 = DepthwiseSeparableConv(128, 64)
        self.bn4 = nn.BatchNorm2d(64)
        self.cbam4 = CBAM(64)
        self.dropout4 = nn.Dropout2d(0.3)
        
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        
        
    def forward(self, x):
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(self.cbam1(F.relu(self.bn1(self.conv1(x)))))
        x = self.pool(self.cbam2(F.relu(self.bn2(self.conv2(x)))))
        x = self.pool(self.cbam3(F.relu(self.bn3(self.conv3(x)))))
        x = self.pool(self.cbam4(F.relu(self.bn4(self.conv4(x)))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x

In [15]:
# Initialize the model, loss function, and optimizer with CBAM
CBAM_classifier = GenderClassifierWithCBAM()
CBAM_classifier.to(device)
criterion_CBAM = nn.CrossEntropyLoss()
optimizer_CBAM = optim.Adam(CBAM_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler_CBAM = StepLR(optimizer_CBAM, step_size=10, gamma=0.1)



In [17]:
# Train the model with CBAM
train_model(CBAM_classifier, train_loader, criterion_CBAM, optimizer_CBAM, lr_scheduler=lr_scheduler_CBAM, num_epochs=35)

100%|██████████| 652/652 [03:49<00:00,  2.84it/s]


Epoch [1/35], Loss: 0.6004, Accuracy: 0.6771


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.7473


100%|██████████| 652/652 [03:46<00:00,  2.88it/s]


Epoch [2/35], Loss: 0.5373, Accuracy: 0.7207


100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Test Accuracy: 0.7576


100%|██████████| 652/652 [03:35<00:00,  3.03it/s]


Epoch [3/35], Loss: 0.4961, Accuracy: 0.7468


100%|██████████| 131/131 [00:41<00:00,  3.18it/s]


Test Accuracy: 0.7691


100%|██████████| 652/652 [03:10<00:00,  3.43it/s]


Epoch [4/35], Loss: 0.4576, Accuracy: 0.7735


100%|██████████| 131/131 [00:30<00:00,  4.27it/s]


Test Accuracy: 0.8070


100%|██████████| 652/652 [03:11<00:00,  3.41it/s]


Epoch [5/35], Loss: 0.4328, Accuracy: 0.7907


100%|██████████| 131/131 [00:30<00:00,  4.28it/s]


Test Accuracy: 0.8319


100%|██████████| 652/652 [03:00<00:00,  3.60it/s]


Epoch [6/35], Loss: 0.3971, Accuracy: 0.8130


100%|██████████| 131/131 [00:30<00:00,  4.29it/s]


Test Accuracy: 0.8547


100%|██████████| 652/652 [02:59<00:00,  3.64it/s]


Epoch [7/35], Loss: 0.3733, Accuracy: 0.8257


100%|██████████| 131/131 [00:30<00:00,  4.30it/s]


Test Accuracy: 0.8614


100%|██████████| 652/652 [02:59<00:00,  3.64it/s]


Epoch [8/35], Loss: 0.3412, Accuracy: 0.8448


100%|██████████| 131/131 [00:30<00:00,  4.29it/s]


Test Accuracy: 0.8840


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [9/35], Loss: 0.3074, Accuracy: 0.8651


100%|██████████| 131/131 [00:30<00:00,  4.28it/s]


Test Accuracy: 0.8940


100%|██████████| 652/652 [02:59<00:00,  3.64it/s]


Epoch [10/35], Loss: 0.2787, Accuracy: 0.8787


100%|██████████| 131/131 [00:30<00:00,  4.29it/s]


Test Accuracy: 0.9082


100%|██████████| 652/652 [02:59<00:00,  3.64it/s]


Epoch [11/35], Loss: 0.2008, Accuracy: 0.9211


100%|██████████| 131/131 [00:31<00:00,  4.12it/s]


Test Accuracy: 0.9446


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [12/35], Loss: 0.1779, Accuracy: 0.9298


100%|██████████| 131/131 [00:30<00:00,  4.28it/s]


Test Accuracy: 0.9492


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [13/35], Loss: 0.1673, Accuracy: 0.9370


100%|██████████| 131/131 [00:30<00:00,  4.31it/s]


Test Accuracy: 0.9535


100%|██████████| 652/652 [03:09<00:00,  3.43it/s]


Epoch [14/35], Loss: 0.1615, Accuracy: 0.9384


100%|██████████| 131/131 [00:40<00:00,  3.21it/s]


Test Accuracy: 0.9578


100%|██████████| 652/652 [03:12<00:00,  3.38it/s]


Epoch [15/35], Loss: 0.1485, Accuracy: 0.9452


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.9612


100%|██████████| 652/652 [03:32<00:00,  3.07it/s]


Epoch [16/35], Loss: 0.1451, Accuracy: 0.9450


100%|██████████| 131/131 [00:31<00:00,  4.22it/s]


Test Accuracy: 0.9636


100%|██████████| 652/652 [03:46<00:00,  2.88it/s]


Epoch [17/35], Loss: 0.1344, Accuracy: 0.9511


100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Test Accuracy: 0.9667


100%|██████████| 652/652 [03:49<00:00,  2.85it/s]


Epoch [18/35], Loss: 0.1262, Accuracy: 0.9534


100%|██████████| 131/131 [00:43<00:00,  3.03it/s]


Test Accuracy: 0.9700


100%|██████████| 652/652 [03:07<00:00,  3.48it/s]


Epoch [19/35], Loss: 0.1200, Accuracy: 0.9551


100%|██████████| 131/131 [00:34<00:00,  3.77it/s]


Test Accuracy: 0.9710


100%|██████████| 652/652 [03:07<00:00,  3.48it/s]


Epoch [20/35], Loss: 0.1143, Accuracy: 0.9584


100%|██████████| 131/131 [00:31<00:00,  4.12it/s]


Test Accuracy: 0.9751


100%|██████████| 652/652 [03:08<00:00,  3.46it/s]


Epoch [21/35], Loss: 0.1032, Accuracy: 0.9626


100%|██████████| 131/131 [00:31<00:00,  4.17it/s]


Test Accuracy: 0.9767


100%|██████████| 652/652 [03:19<00:00,  3.27it/s]


Epoch [22/35], Loss: 0.1039, Accuracy: 0.9630


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.9765


100%|██████████| 652/652 [03:52<00:00,  2.80it/s]


Epoch [23/35], Loss: 0.1002, Accuracy: 0.9640


100%|██████████| 131/131 [00:38<00:00,  3.41it/s]


Test Accuracy: 0.9765


100%|██████████| 652/652 [03:10<00:00,  3.43it/s]


Epoch [24/35], Loss: 0.0990, Accuracy: 0.9655


100%|██████████| 131/131 [00:30<00:00,  4.30it/s]


Test Accuracy: 0.9777


100%|██████████| 652/652 [03:08<00:00,  3.46it/s]


Epoch [25/35], Loss: 0.0988, Accuracy: 0.9645


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.9779


100%|██████████| 652/652 [03:19<00:00,  3.27it/s]


Epoch [26/35], Loss: 0.0956, Accuracy: 0.9661


100%|██████████| 131/131 [00:34<00:00,  3.78it/s]


Test Accuracy: 0.9789


100%|██████████| 652/652 [03:15<00:00,  3.34it/s]


Epoch [27/35], Loss: 0.0951, Accuracy: 0.9651


100%|██████████| 131/131 [00:31<00:00,  4.13it/s]


Test Accuracy: 0.9789


100%|██████████| 652/652 [03:07<00:00,  3.47it/s]


Epoch [28/35], Loss: 0.0976, Accuracy: 0.9660


100%|██████████| 131/131 [00:31<00:00,  4.11it/s]


Test Accuracy: 0.9799


100%|██████████| 652/652 [03:23<00:00,  3.21it/s]


Epoch [29/35], Loss: 0.0958, Accuracy: 0.9657


100%|██████████| 131/131 [00:31<00:00,  4.11it/s]


Test Accuracy: 0.9794


100%|██████████| 652/652 [03:08<00:00,  3.46it/s]


Epoch [30/35], Loss: 0.0946, Accuracy: 0.9664


100%|██████████| 131/131 [00:30<00:00,  4.24it/s]


Test Accuracy: 0.9801


100%|██████████| 652/652 [03:01<00:00,  3.59it/s]


Epoch [31/35], Loss: 0.0920, Accuracy: 0.9676


100%|██████████| 131/131 [00:30<00:00,  4.27it/s]


Test Accuracy: 0.9801


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [32/35], Loss: 0.0933, Accuracy: 0.9677


100%|██████████| 131/131 [00:31<00:00,  4.16it/s]


Test Accuracy: 0.9799


100%|██████████| 652/652 [03:14<00:00,  3.35it/s]


Epoch [33/35], Loss: 0.0924, Accuracy: 0.9672


100%|██████████| 131/131 [00:41<00:00,  3.14it/s]


Test Accuracy: 0.9799


100%|██████████| 652/652 [03:17<00:00,  3.31it/s]


Epoch [34/35], Loss: 0.0955, Accuracy: 0.9668


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9803


100%|██████████| 652/652 [02:59<00:00,  3.63it/s]


Epoch [35/35], Loss: 0.0950, Accuracy: 0.9657


100%|██████████| 131/131 [00:30<00:00,  4.28it/s]

Test Accuracy: 0.9806





In [18]:
# Save the model with CBAM
torch.save(CBAM_classifier.state_dict(), os.path.join(model_path, "CBAM_gender_classifier.pth"))

## Squeeze-and-Excitation (SE) Block

In [25]:
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1) # Calculate the average weight of each channel
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        b, c, _, _ = x.size() # Batch size and number of channels
        y = self.global_avg_pool(x).view(b, c) # [B, C]
        y = self.fc(y).view(b, c, 1, 1) # [B, C, 1, 1]
        
        return x * y.expand_as(x)  # Scaled channel wise
        


In [26]:
class GenderClassifierWithSE(nn.Module):
    def __init__(self):
        super(GenderClassifierWithSE, self).__init__()
        self.conv1 = DepthwiseSeparableConv(3, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.se1 = SEBlock(32)
        
        self.conv2 = DepthwiseSeparableConv(32, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.se2 = SEBlock(64)
        
        self.conv3 = DepthwiseSeparableConv(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.se3 = SEBlock(128)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        
        
    def forward(self, x):
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(self.se1(F.relu(self.bn1(self.conv1(x)))))
        x = self.pool(self.se2(F.relu(self.bn2(self.conv2(x)))))
        x = self.pool(self.se3(F.relu(self.bn3(self.conv3(x)))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x
        


In [27]:
# Initialize the models with SE 
SE_classifier = GenderClassifierWithSE()
SE_classifier.to(device)
criterion_SE = nn.CrossEntropyLoss()
optimizer_SE = optim.Adam(SE_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler_SE = StepLR(optimizer_SE, step_size=10, gamma=0.1)


In [28]:
# Train the model with SE
train_model(SE_classifier, train_loader, criterion_SE, optimizer_SE, lr_scheduler=lr_scheduler_SE, num_epochs=35)

100%|██████████| 652/652 [03:47<00:00,  2.86it/s]


Epoch [1/35], Loss: 0.6229, Accuracy: 0.6580


100%|██████████| 131/131 [00:39<00:00,  3.28it/s]


Test Accuracy: 0.7181


100%|██████████| 652/652 [03:54<00:00,  2.79it/s]


Epoch [2/35], Loss: 0.5515, Accuracy: 0.7075


100%|██████████| 131/131 [00:40<00:00,  3.20it/s]


Test Accuracy: 0.7365


100%|██████████| 652/652 [03:19<00:00,  3.26it/s]


Epoch [3/35], Loss: 0.5123, Accuracy: 0.7399


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.7703


100%|██████████| 652/652 [02:59<00:00,  3.63it/s]


Epoch [4/35], Loss: 0.4852, Accuracy: 0.7560


100%|██████████| 131/131 [00:31<00:00,  4.17it/s]


Test Accuracy: 0.7588


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [5/35], Loss: 0.4617, Accuracy: 0.7743


100%|██████████| 131/131 [00:31<00:00,  4.19it/s]


Test Accuracy: 0.8012


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [6/35], Loss: 0.4407, Accuracy: 0.7915


100%|██████████| 131/131 [00:30<00:00,  4.23it/s]


Test Accuracy: 0.8214


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [7/35], Loss: 0.4188, Accuracy: 0.8016


100%|██████████| 131/131 [00:31<00:00,  4.18it/s]


Test Accuracy: 0.8334


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [8/35], Loss: 0.3987, Accuracy: 0.8146


100%|██████████| 131/131 [00:31<00:00,  4.10it/s]


Test Accuracy: 0.8504


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [9/35], Loss: 0.3818, Accuracy: 0.8253


100%|██████████| 131/131 [00:31<00:00,  4.17it/s]


Test Accuracy: 0.8523


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [10/35], Loss: 0.3577, Accuracy: 0.8359


100%|██████████| 131/131 [00:31<00:00,  4.20it/s]


Test Accuracy: 0.8806


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [11/35], Loss: 0.3018, Accuracy: 0.8752


100%|██████████| 131/131 [00:31<00:00,  4.19it/s]


Test Accuracy: 0.8933


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [12/35], Loss: 0.2909, Accuracy: 0.8796


100%|██████████| 131/131 [00:31<00:00,  4.14it/s]


Test Accuracy: 0.8986


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [13/35], Loss: 0.2840, Accuracy: 0.8811


100%|██████████| 131/131 [00:31<00:00,  4.21it/s]


Test Accuracy: 0.9007


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [14/35], Loss: 0.2773, Accuracy: 0.8824


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.9017


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [15/35], Loss: 0.2743, Accuracy: 0.8877


100%|██████████| 131/131 [00:31<00:00,  4.19it/s]


Test Accuracy: 0.9084


100%|██████████| 652/652 [02:56<00:00,  3.68it/s]


Epoch [16/35], Loss: 0.2660, Accuracy: 0.8909


100%|██████████| 131/131 [00:31<00:00,  4.17it/s]


Test Accuracy: 0.9087


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [17/35], Loss: 0.2602, Accuracy: 0.8926


100%|██████████| 131/131 [00:31<00:00,  4.20it/s]


Test Accuracy: 0.9135


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [18/35], Loss: 0.2557, Accuracy: 0.8980


100%|██████████| 131/131 [00:31<00:00,  4.16it/s]


Test Accuracy: 0.9182


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [19/35], Loss: 0.2475, Accuracy: 0.9038


100%|██████████| 131/131 [00:30<00:00,  4.25it/s]


Test Accuracy: 0.9235


100%|██████████| 652/652 [02:57<00:00,  3.66it/s]


Epoch [20/35], Loss: 0.2425, Accuracy: 0.9034


100%|██████████| 131/131 [00:31<00:00,  4.16it/s]


Test Accuracy: 0.9259


100%|██████████| 652/652 [02:59<00:00,  3.63it/s]


Epoch [21/35], Loss: 0.2369, Accuracy: 0.9071


100%|██████████| 131/131 [00:31<00:00,  4.15it/s]


Test Accuracy: 0.9271


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [22/35], Loss: 0.2348, Accuracy: 0.9084


100%|██████████| 131/131 [00:31<00:00,  4.22it/s]


Test Accuracy: 0.9281


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [23/35], Loss: 0.2346, Accuracy: 0.9092


100%|██████████| 131/131 [00:31<00:00,  4.21it/s]


Test Accuracy: 0.9278


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [24/35], Loss: 0.2334, Accuracy: 0.9095


100%|██████████| 131/131 [00:31<00:00,  4.18it/s]


Test Accuracy: 0.9295


100%|██████████| 652/652 [02:58<00:00,  3.64it/s]


Epoch [25/35], Loss: 0.2328, Accuracy: 0.9099


100%|██████████| 131/131 [00:31<00:00,  4.15it/s]


Test Accuracy: 0.9290


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [26/35], Loss: 0.2310, Accuracy: 0.9108


100%|██████████| 131/131 [00:31<00:00,  4.19it/s]


Test Accuracy: 0.9295


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [27/35], Loss: 0.2321, Accuracy: 0.9101


100%|██████████| 131/131 [00:31<00:00,  4.21it/s]


Test Accuracy: 0.9298


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [28/35], Loss: 0.2308, Accuracy: 0.9099


100%|██████████| 131/131 [00:31<00:00,  4.15it/s]


Test Accuracy: 0.9302


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [29/35], Loss: 0.2297, Accuracy: 0.9105


100%|██████████| 131/131 [00:31<00:00,  4.20it/s]


Test Accuracy: 0.9302


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [30/35], Loss: 0.2285, Accuracy: 0.9101


100%|██████████| 131/131 [00:31<00:00,  4.22it/s]


Test Accuracy: 0.9314


100%|██████████| 652/652 [03:00<00:00,  3.62it/s]


Epoch [31/35], Loss: 0.2298, Accuracy: 0.9103


100%|██████████| 131/131 [00:31<00:00,  4.16it/s]


Test Accuracy: 0.9310


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [32/35], Loss: 0.2287, Accuracy: 0.9108


100%|██████████| 131/131 [00:31<00:00,  4.19it/s]


Test Accuracy: 0.9310


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [33/35], Loss: 0.2283, Accuracy: 0.9127


100%|██████████| 131/131 [00:31<00:00,  4.19it/s]


Test Accuracy: 0.9307


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [34/35], Loss: 0.2291, Accuracy: 0.9124


100%|██████████| 131/131 [00:30<00:00,  4.26it/s]


Test Accuracy: 0.9312


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [35/35], Loss: 0.2258, Accuracy: 0.9141


100%|██████████| 131/131 [00:30<00:00,  4.27it/s]

Test Accuracy: 0.9312





In [29]:
# Save the model with SE
torch.save(SE_classifier.state_dict(), os.path.join(model_path, "SE_Gender_classifier.pth"))

### SE Block with Spatial Attention

In [30]:
class GenderClassifierWithSAandSE(nn.Module):
    def __init__(self):
        super(GenderClassifierWithSAandSE, self).__init__()
        # Depthwise separable convolutional layers
        self.conv1 = DepthwiseSeparableConv(3, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.se1 = SEBlock(32)
        self.sa1 = SpatialAttention()

        
        self.conv2 = DepthwiseSeparableConv(32, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.se2 = SEBlock(64)
        self.sa2 = SpatialAttention()
        
        
        self.conv3 = DepthwiseSeparableConv(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.se3 = SEBlock(128)
        self.sa3 = SpatialAttention()
        
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(self.sa1(self.se1(F.relu(self.bn1(self.conv1(x))))))
        x = self.pool(self.sa2(self.se2(F.relu(self.bn2(self.conv2(x))))))
        x = self.pool(self.sa3(self.se3(F.relu(self.bn3(self.conv3(x))))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x

In [31]:
# Initialize the model with SA and SE
SA_SE_classifier = GenderClassifierWithSAandSE()
SA_SE_classifier.to(device)
criterion_SA_SE = nn.CrossEntropyLoss()
optimizer_SA_SE = optim.Adam(SA_SE_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler_SA_SE = StepLR(optimizer_SA_SE, step_size=10, gamma=0.1)

In [32]:
# Train the model with SA and SE 
train_model(SA_SE_classifier, train_loader, criterion_SA_SE, optimizer_SA_SE, lr_scheduler=lr_scheduler_SA_SE, num_epochs=35)

100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [1/35], Loss: 0.6151, Accuracy: 0.6616


100%|██████████| 131/131 [00:29<00:00,  4.45it/s]


Test Accuracy: 0.7356


100%|██████████| 652/652 [03:01<00:00,  3.60it/s]


Epoch [2/35], Loss: 0.5446, Accuracy: 0.7187


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.7567


100%|██████████| 652/652 [02:59<00:00,  3.63it/s]


Epoch [3/35], Loss: 0.5011, Accuracy: 0.7450


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.7833


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [4/35], Loss: 0.4699, Accuracy: 0.7664


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.8044


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [5/35], Loss: 0.4400, Accuracy: 0.7873


100%|██████████| 131/131 [00:29<00:00,  4.37it/s]


Test Accuracy: 0.8235


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [6/35], Loss: 0.4085, Accuracy: 0.8069


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.8497


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [7/35], Loss: 0.3763, Accuracy: 0.8245


100%|██████████| 131/131 [00:29<00:00,  4.37it/s]


Test Accuracy: 0.8789


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [8/35], Loss: 0.3437, Accuracy: 0.8475


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.8955


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [9/35], Loss: 0.3085, Accuracy: 0.8658


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9254


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [10/35], Loss: 0.2695, Accuracy: 0.8878


100%|██████████| 131/131 [00:29<00:00,  4.41it/s]


Test Accuracy: 0.9271


100%|██████████| 652/652 [02:57<00:00,  3.66it/s]


Epoch [11/35], Loss: 0.1971, Accuracy: 0.9256


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9523


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [12/35], Loss: 0.1717, Accuracy: 0.9356


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9571


100%|██████████| 652/652 [02:57<00:00,  3.66it/s]


Epoch [13/35], Loss: 0.1624, Accuracy: 0.9381


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9607


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [14/35], Loss: 0.1542, Accuracy: 0.9436


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9640


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [15/35], Loss: 0.1437, Accuracy: 0.9475


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9667


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [16/35], Loss: 0.1365, Accuracy: 0.9496


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9700


100%|██████████| 652/652 [02:57<00:00,  3.66it/s]


Epoch [17/35], Loss: 0.1328, Accuracy: 0.9502


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9734


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [18/35], Loss: 0.1225, Accuracy: 0.9566


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9770


100%|██████████| 652/652 [02:58<00:00,  3.65it/s]


Epoch [19/35], Loss: 0.1133, Accuracy: 0.9600


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9801


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [20/35], Loss: 0.1097, Accuracy: 0.9616


100%|██████████| 131/131 [00:29<00:00,  4.41it/s]


Test Accuracy: 0.9820


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [21/35], Loss: 0.1029, Accuracy: 0.9663


100%|██████████| 131/131 [00:29<00:00,  4.42it/s]


Test Accuracy: 0.9823


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [22/35], Loss: 0.0982, Accuracy: 0.9672


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9820


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [23/35], Loss: 0.0946, Accuracy: 0.9689


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9837


100%|██████████| 652/652 [02:58<00:00,  3.66it/s]


Epoch [24/35], Loss: 0.0949, Accuracy: 0.9687


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9842


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [25/35], Loss: 0.0932, Accuracy: 0.9683


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9837


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [26/35], Loss: 0.0942, Accuracy: 0.9682


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9832


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [27/35], Loss: 0.0906, Accuracy: 0.9692


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]


Test Accuracy: 0.9835


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [28/35], Loss: 0.0904, Accuracy: 0.9698


100%|██████████| 131/131 [00:29<00:00,  4.38it/s]


Test Accuracy: 0.9842


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [29/35], Loss: 0.0911, Accuracy: 0.9703


100%|██████████| 131/131 [00:29<00:00,  4.41it/s]


Test Accuracy: 0.9835


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [30/35], Loss: 0.0936, Accuracy: 0.9693


100%|██████████| 131/131 [00:29<00:00,  4.39it/s]


Test Accuracy: 0.9837


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [31/35], Loss: 0.0898, Accuracy: 0.9704


100%|██████████| 131/131 [00:29<00:00,  4.44it/s]


Test Accuracy: 0.9844


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [32/35], Loss: 0.0895, Accuracy: 0.9705


100%|██████████| 131/131 [00:29<00:00,  4.43it/s]


Test Accuracy: 0.9844


100%|██████████| 652/652 [02:57<00:00,  3.68it/s]


Epoch [33/35], Loss: 0.0887, Accuracy: 0.9700


100%|██████████| 131/131 [00:29<00:00,  4.42it/s]


Test Accuracy: 0.9844


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [34/35], Loss: 0.0890, Accuracy: 0.9708


100%|██████████| 131/131 [00:29<00:00,  4.41it/s]


Test Accuracy: 0.9847


100%|██████████| 652/652 [02:57<00:00,  3.67it/s]


Epoch [35/35], Loss: 0.0879, Accuracy: 0.9713


100%|██████████| 131/131 [00:29<00:00,  4.40it/s]

Test Accuracy: 0.9842





In [33]:
# Save the model with SE and SA
torch.save(SA_SE_classifier.state_dict(), os.path.join(model_path, "SA_and_SE_Gender_classifier.pth"))

### SE Block with CBAM

In [34]:
class GenderClassifierWithCBAMandSE(nn.Module):
    def __init__(self):
        super(GenderClassifierWithCBAMandSE, self).__init__()
        # Depthwise separable convolutional layers
        self.conv1 = DepthwiseSeparableConv(3, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.cbam1 = CBAM(32)
        self.se1 = SEBlock(32)
        
        self.conv2 = DepthwiseSeparableConv(32, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.cbam2 = CBAM(64)
        self.se2 = SEBlock(64)
        
        self.conv3 = DepthwiseSeparableConv(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.cbam3 = CBAM(128)
        self.se3 = SEBlock(128)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        x = self.pool(self.cbam1(self.se1(F.relu(self.bn1(self.conv1(x))))))
        x = self.pool(self.cbam2(self.se2(F.relu(self.bn2(self.conv2(x))))))
        x = self.pool(self.cbam3(self.se3(F.relu(self.bn3(self.conv3(x))))))
        
        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # Fully connected layers with ReLU activation and dropout
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        return x




In [35]:
# Initialize the model with CBAM and SE
CBAM_SE_classifier = GenderClassifierWithCBAMandSE()
CBAM_SE_classifier.to(device)
criterion_CBAM_SE = nn.CrossEntropyLoss()
optimizer_CBAM_SE = optim.Adam(CBAM_SE_classifier.parameters(), lr=0.0001, weight_decay=1e-5)
lr_scheduler_CBAM_SE = StepLR(optimizer_CBAM_SE, step_size=10, gamma=0.1)

In [36]:
# Train the model with CBAM and SE
train_model(CBAM_SE_classifier, train_loader, criterion_CBAM_SE, optimizer_CBAM_SE, lr_scheduler=lr_scheduler_CBAM_SE, num_epochs=35)

100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [1/35], Loss: 0.6236, Accuracy: 0.6515


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.7276


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [2/35], Loss: 0.5634, Accuracy: 0.6993


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.7384


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [3/35], Loss: 0.5198, Accuracy: 0.7319


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.7540


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [4/35], Loss: 0.4911, Accuracy: 0.7490


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.7938


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [5/35], Loss: 0.4606, Accuracy: 0.7772


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.8046


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [6/35], Loss: 0.4284, Accuracy: 0.7947


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.8372


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [7/35], Loss: 0.4009, Accuracy: 0.8094


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.8398


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [8/35], Loss: 0.3644, Accuracy: 0.8318


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.8665


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [9/35], Loss: 0.3338, Accuracy: 0.8509


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.8979


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [10/35], Loss: 0.2915, Accuracy: 0.8768


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9230


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [11/35], Loss: 0.2133, Accuracy: 0.9138


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.9374


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [12/35], Loss: 0.1933, Accuracy: 0.9239


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9403


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [13/35], Loss: 0.1779, Accuracy: 0.9318


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9477


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [14/35], Loss: 0.1731, Accuracy: 0.9329


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9518


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [15/35], Loss: 0.1599, Accuracy: 0.9394


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9573


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [16/35], Loss: 0.1546, Accuracy: 0.9409


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9597


100%|██████████| 652/652 [03:00<00:00,  3.61it/s]


Epoch [17/35], Loss: 0.1426, Accuracy: 0.9476


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9633


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [18/35], Loss: 0.1380, Accuracy: 0.9482


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9664


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [19/35], Loss: 0.1271, Accuracy: 0.9540


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9669


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [20/35], Loss: 0.1228, Accuracy: 0.9549


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9724


100%|██████████| 652/652 [02:55<00:00,  3.72it/s]


Epoch [21/35], Loss: 0.1088, Accuracy: 0.9627


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.9715


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [22/35], Loss: 0.1121, Accuracy: 0.9613


100%|██████████| 131/131 [00:30<00:00,  4.32it/s]


Test Accuracy: 0.9731


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [23/35], Loss: 0.1074, Accuracy: 0.9618


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9731


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [24/35], Loss: 0.1052, Accuracy: 0.9632


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.9751


100%|██████████| 652/652 [02:56<00:00,  3.69it/s]


Epoch [25/35], Loss: 0.1085, Accuracy: 0.9627


100%|██████████| 131/131 [00:30<00:00,  4.31it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [26/35], Loss: 0.1083, Accuracy: 0.9618


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9755


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [27/35], Loss: 0.1080, Accuracy: 0.9606


100%|██████████| 131/131 [00:30<00:00,  4.35it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [28/35], Loss: 0.1021, Accuracy: 0.9652


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9758


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [29/35], Loss: 0.1023, Accuracy: 0.9635


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.9763


100%|██████████| 652/652 [02:56<00:00,  3.70it/s]


Epoch [30/35], Loss: 0.1014, Accuracy: 0.9641


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9772


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [31/35], Loss: 0.1036, Accuracy: 0.9648


100%|██████████| 131/131 [00:30<00:00,  4.32it/s]


Test Accuracy: 0.9767


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [32/35], Loss: 0.1012, Accuracy: 0.9655


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]


Test Accuracy: 0.9770


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [33/35], Loss: 0.0969, Accuracy: 0.9669


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9775


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [34/35], Loss: 0.1001, Accuracy: 0.9651


100%|██████████| 131/131 [00:30<00:00,  4.34it/s]


Test Accuracy: 0.9777


100%|██████████| 652/652 [02:55<00:00,  3.71it/s]


Epoch [35/35], Loss: 0.1032, Accuracy: 0.9653


100%|██████████| 131/131 [00:30<00:00,  4.33it/s]

Test Accuracy: 0.9779





In [37]:
# Save the model with CBAM and SE
torch.save(CBAM_SE_classifier.state_dict(), os.path.join(model_path, "CBAM_and_SE_Gender_classifier.pth"))