In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet50
from torch.optim import Adam
import torch.nn as nn
import pandas as pd
import os
from PIL import Image
import numpy as np

In [2]:
# Load the CSV file into a DataFrame
csv_file = './partitioned.csv'  
attributes_df = pd.read_csv(csv_file)
attributes_df

Unnamed: 0,image_id,Gender,partition
0,039088.jpg,Female,0
1,030894.jpg,Male,0
2,045279.jpg,Female,0
3,016399.jpg,Female,0
4,013654.jpg,Male,0
...,...,...,...
49995,035413.jpg,Male,2
49996,013543.jpg,Female,2
49997,010990.jpg,Female,2
49998,027439.jpg,Female,2


In [4]:
def getImagePath(image_id):
    return os.path.join('img_align_celeba',image_id)

df = pd.read_csv("partitioned_multi_attr.csv")
train_df = df[df['partition'] == 0]
val_df = df[df['partition'] == 1]
test_df = df[df['partition'] == 2]

df_labels = df.set_index('image_id')
file_paths = df['image_id'].apply(getImagePath).tolist()


# Define the transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [5]:
class CelebADataset(Dataset):
    def __init__(self, file_paths, file_to_label, transform=None):
        self.file_paths = file_paths
        self.file_to_label = file_to_label
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_name = self.file_paths[idx]
        image = Image.open(img_name).convert('RGB')
        label = self.file_to_label[os.path.basename(img_name)][0]  # Only the "Male" label

        if self.transform:
            image = self.transform(image)

        return image, label

In [8]:
# Create separate file path and label mappings for each dataset partition
train_file_paths = train_df['image_id'].apply(getImagePath).tolist()
val_file_paths = val_df['image_id'].apply(getImagePath).tolist()
test_file_paths = test_df['image_id'].apply(getImagePath).tolist()

train_filename_to_label = {filename: labels.values for filename, labels in train_df.set_index('image_id').iterrows()}
val_filename_to_label = {filename: labels.values for filename, labels in val_df.set_index('image_id').iterrows()}
test_filename_to_label = {filename: labels.values for filename, labels in test_df.set_index('image_id').iterrows()}

# Initialize the datasets for each partition
train_dataset = CelebADataset(train_file_paths, train_filename_to_label, transform=transform)
val_dataset = CelebADataset(val_file_paths, val_filename_to_label, transform=transform)
test_dataset = CelebADataset(test_file_paths, test_filename_to_label, transform=transform)

# Create data loaders for each dataset partition
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

In [9]:
# Load a pre-trained ResNet model
model = resnet50(pretrained=True)

# Modify the model for binary classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)  # Output one value for binary classification



In [10]:
# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
device

device(type='cuda', index=0)

In [11]:
# Define the loss function and optimizer
criterion = nn.BCEWithLogitsLoss() 
optimizer = Adam(model.parameters(), lr=0.001)

In [18]:
from tqdm import tqdm

# Early stopping parameters
patience = 5  # How many epochs to wait after last time validation loss improved.
best_loss = np.Inf
epochs_no_improve = 0
early_stop = False

num_epochs = 100 
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    train_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (inputs, labels) in train_bar:
        inputs, labels = inputs.to(device), labels.to(device).float()
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = outputs.view(-1)  
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        train_bar.set_description(f'Epoch {epoch+1}/{num_epochs} [Train]')

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    val_bar = tqdm(val_loader, total=len(val_loader))
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device).long()  
            outputs = model(inputs)
            outputs = outputs.view(-1)  
            loss = criterion(outputs, labels.float())  
            val_loss += loss.item()
    
            # Apply threshold to get predictions
            predicted = outputs > 0.5  
            predicted = predicted.long()  
    
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            val_bar.set_description(f'Epoch {epoch+1}/{num_epochs} [Val]')
    
    # Calculate average losses
    train_loss = running_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    
    # Print training/validation statistics 
    print(f'Epoch: {epoch+1} \tTraining Loss: {train_loss:.6f} \tValidation Loss: {val_loss:.6f}')
    print(f'Validation Accuracy: {100 * correct / total}%')

    # Save model if validation loss has decreased
    if val_loss < best_loss:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        best_loss,
        val_loss))
        torch.save(model.state_dict(), 'gender_classification_model.pth')
        best_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print('Early stopping')
            early_stop = True
            break

    if early_stop:
        print("Stopped early due to no improvement in validation loss")
        break

print('Training complete')

Epoch 1/100 [Train]: 100%|█████████████████████████████████████████████████████████| 1250/1250 [02:12<00:00,  9.41it/s]
Epoch 1/100 [Val]:   0%|                                                                       | 0/157 [00:09<?, ?it/s]

Epoch: 1 	Training Loss: 0.073770 	Validation Loss: 0.094534
Validation Accuracy: 96.82%
Validation loss decreased (inf --> 0.094534).  Saving model ...



  0%|                                                                                         | 0/1250 [00:00<?, ?it/s][A
Epoch 2/100 [Train]:   0%|                                                                    | 0/1250 [00:00<?, ?it/s][A
Epoch 2/100 [Train]:   0%|                                                            | 1/1250 [00:00<02:13,  9.36it/s][A
Epoch 2/100 [Train]:   0%|                                                            | 1/1250 [00:00<02:13,  9.36it/s][A
Epoch 2/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:14,  9.25it/s][A
Epoch 2/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:14,  9.25it/s][A
Epoch 2/100 [Train]:   0%|▏                                                           | 3/1250 [00:00<02:23,  8.67it/s][A
Epoch 2/100 [Train]:   0%|▏                                                           | 3/1250 [00:00<02:23,  8.67it/s][A
Epoch 2/100 [Tr

Epoch: 2 	Training Loss: 0.062172 	Validation Loss: 0.079552
Validation Accuracy: 97.26%
Validation loss decreased (0.094534 --> 0.079552).  Saving model ...


Epoch 3/100 [Train]: 100%|█████████████████████████████████████████████████████████| 1250/1250 [02:03<00:00, 10.10it/s]
Epoch 2/100 [Val]:   0%|                                                                       | 0/157 [02:13<?, ?it/s]
Epoch 3/100 [Val]:   0%|                                                                       | 0/157 [00:08<?, ?it/s]

Epoch: 3 	Training Loss: 0.057271 	Validation Loss: 0.084910
Validation Accuracy: 96.26%



  0%|                                                                                         | 0/1250 [00:00<?, ?it/s][A
Epoch 4/100 [Train]:   0%|                                                                    | 0/1250 [00:00<?, ?it/s][A
Epoch 4/100 [Train]:   0%|                                                                    | 0/1250 [00:00<?, ?it/s][A
Epoch 4/100 [Train]:   0%|                                                            | 2/1250 [00:00<01:59, 10.43it/s][A
Epoch 4/100 [Train]:   0%|                                                            | 2/1250 [00:00<01:59, 10.43it/s][A
Epoch 4/100 [Train]:   0%|                                                            | 2/1250 [00:00<01:59, 10.43it/s][A
Epoch 4/100 [Train]:   0%|▏                                                           | 4/1250 [00:00<02:01, 10.28it/s][A
Epoch 4/100 [Train]:   0%|▏                                                           | 4/1250 [00:00<02:01, 10.28it/s][A
Epoch 4/100 [Tr

Epoch: 4 	Training Loss: 0.048875 	Validation Loss: 0.080313
Validation Accuracy: 96.92%


Epoch 5/100 [Train]: 100%|█████████████████████████████████████████████████████████| 1250/1250 [02:08<00:00,  9.75it/s]
Epoch 4/100 [Val]:   0%|                                                                       | 0/157 [02:17<?, ?it/s]
Epoch 5/100 [Val]:   0%|                                                                       | 0/157 [00:08<?, ?it/s]

Epoch: 5 	Training Loss: 0.042525 	Validation Loss: 0.069374
Validation Accuracy: 97.5%
Validation loss decreased (0.079552 --> 0.069374).  Saving model ...



  0%|                                                                                         | 0/1250 [00:00<?, ?it/s][A
Epoch 6/100 [Train]:   0%|                                                                    | 0/1250 [00:00<?, ?it/s][A
Epoch 6/100 [Train]:   0%|                                                            | 1/1250 [00:00<02:06,  9.84it/s][A
Epoch 6/100 [Train]:   0%|                                                            | 1/1250 [00:00<02:06,  9.84it/s][A
Epoch 6/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:08,  9.73it/s][A
Epoch 6/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:08,  9.73it/s][A
Epoch 6/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:08,  9.73it/s][A
Epoch 6/100 [Train]:   0%|▏                                                           | 4/1250 [00:00<02:05,  9.95it/s][A
Epoch 6/100 [Tr

Epoch: 6 	Training Loss: 0.036201 	Validation Loss: 0.070581
Validation Accuracy: 97.48%


Epoch 7/100 [Train]: 100%|█████████████████████████████████████████████████████████| 1250/1250 [02:02<00:00, 10.23it/s]
Epoch 6/100 [Val]:   0%|                                                                       | 0/157 [02:11<?, ?it/s]
Epoch 7/100 [Val]:   0%|                                                                       | 0/157 [00:08<?, ?it/s]

Epoch: 7 	Training Loss: 0.028818 	Validation Loss: 0.074346
Validation Accuracy: 97.54%



  0%|                                                                                         | 0/1250 [00:00<?, ?it/s][A
Epoch 8/100 [Train]:   0%|                                                                    | 0/1250 [00:00<?, ?it/s][A
Epoch 8/100 [Train]:   0%|                                                                    | 0/1250 [00:00<?, ?it/s][A
Epoch 8/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:01, 10.24it/s][A
Epoch 8/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:01, 10.24it/s][A
Epoch 8/100 [Train]:   0%|                                                            | 2/1250 [00:00<02:01, 10.24it/s][A
Epoch 8/100 [Train]:   0%|▏                                                           | 4/1250 [00:00<02:05,  9.91it/s][A
Epoch 8/100 [Train]:   0%|▏                                                           | 4/1250 [00:00<02:05,  9.91it/s][A
Epoch 8/100 [Tr

Epoch: 8 	Training Loss: 0.027558 	Validation Loss: 0.074181
Validation Accuracy: 97.5%


Epoch 9/100 [Train]: 100%|█████████████████████████████████████████████████████████| 1250/1250 [03:11<00:00,  6.52it/s]
Epoch 8/100 [Val]:   0%|                                                                       | 0/157 [03:21<?, ?it/s]
Epoch 9/100 [Val]:   0%|                                                                       | 0/157 [00:10<?, ?it/s]

Epoch: 9 	Training Loss: 0.023269 	Validation Loss: 0.117772
Validation Accuracy: 95.92%



  0%|                                                                                         | 0/1250 [00:00<?, ?it/s][A
Epoch 10/100 [Train]:   0%|                                                                   | 0/1250 [00:00<?, ?it/s][A
Epoch 10/100 [Train]:   0%|                                                           | 1/1250 [00:00<02:44,  7.62it/s][A
Epoch 10/100 [Train]:   0%|                                                           | 1/1250 [00:00<02:44,  7.62it/s][A
Epoch 10/100 [Train]:   0%|                                                           | 2/1250 [00:00<02:54,  7.17it/s][A
Epoch 10/100 [Train]:   0%|                                                           | 2/1250 [00:00<02:54,  7.17it/s][A
Epoch 10/100 [Train]:   0%|▏                                                          | 3/1250 [00:00<03:00,  6.91it/s][A
Epoch 10/100 [Train]:   0%|▏                                                          | 3/1250 [00:00<03:00,  6.91it/s][A
Epoch 10/100 [T

Epoch: 10 	Training Loss: 0.022861 	Validation Loss: 0.080568
Validation Accuracy: 97.6%
Early stopping
Training complete


In [25]:
from tqdm import tqdm
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def test_model(model, val_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc='Evaluating', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            outputs = outputs.view(-1)
            loss = criterion(outputs, labels.float())  # Calculate loss
            val_loss += loss.item()

            # Apply sigmoid since BCEWithLogitsLoss includes the sigmoid layer
            predictions = torch.sigmoid(outputs) > 0.5
            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    # Calculate average loss
    avg_loss = val_loss / len(val_loader)

    # Convert to binary values and calculate metrics
    all_predictions = [int(pred) for pred in all_predictions]
    all_targets = [int(target) for target in all_targets]

    # Calculate metrics using sklearn
    accuracy = accuracy_score(all_targets, all_predictions)
    precision = precision_score(all_targets, all_predictions)
    recall = recall_score(all_targets, all_predictions)
    f1 = f1_score(all_targets, all_predictions)

    return avg_loss, accuracy, precision, recall, f1

criterion = torch.nn.BCEWithLogitsLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
avg_loss, accuracy, precision, recall, f1 = test_model(model, val_loader, criterion, device)
print(f'Validation Loss: {avg_loss:.4f}')
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1: {f1:.4f}')

                                                                                                                       

Validation Loss: 0.0694
Accuracy: 0.9754
Precision: 0.9707
Recall: 0.9711
F1: 0.9709


