In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet50
from torch.optim import Adam
import torch.nn as nn
import pandas as pd
import os
from PIL import Image
import numpy as np

In [2]:
# Load the CSV file into a DataFrame
csv_file = './list_attr_celeba.csv'  # Update this to the correct path
attributes_df = pd.read_csv(csv_file)
attributes_df

Unnamed: 0,image_id,5_o_Clock_Shadow,Arched_Eyebrows,Attractive,Bags_Under_Eyes,Bald,Bangs,Big_Lips,Big_Nose,Black_Hair,...,Sideburns,Smiling,Straight_Hair,Wavy_Hair,Wearing_Earrings,Wearing_Hat,Wearing_Lipstick,Wearing_Necklace,Wearing_Necktie,Young
0,000001.jpg,-1,1,1,-1,-1,-1,-1,-1,-1,...,-1,1,1,-1,1,-1,1,-1,-1,1
1,000002.jpg,-1,-1,-1,1,-1,-1,-1,1,-1,...,-1,1,-1,-1,-1,-1,-1,-1,-1,1
2,000003.jpg,-1,-1,-1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,1,-1,-1,-1,-1,-1,1
3,000004.jpg,-1,-1,1,-1,-1,-1,-1,-1,-1,...,-1,-1,1,-1,1,-1,1,1,-1,1
4,000005.jpg,-1,1,1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,-1,-1,-1,1,-1,-1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202594,202595.jpg,-1,-1,1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,-1,-1,-1,1,-1,-1,1
202595,202596.jpg,-1,-1,-1,-1,-1,1,1,-1,-1,...,-1,1,1,-1,-1,-1,-1,-1,-1,1
202596,202597.jpg,-1,-1,-1,-1,-1,-1,-1,-1,1,...,-1,1,-1,-1,-1,-1,-1,-1,-1,1
202597,202598.jpg,-1,1,1,-1,-1,-1,1,-1,1,...,-1,1,-1,1,1,-1,1,-1,-1,1


In [3]:
# Map filenames to labels (assuming gender is the attribute for classification)
# Let's say 'Male' is represented as 1 and 'Female' as -1 in the CSV
filename_to_label = {row['image_id']: 1 if row['Male'] == 1 else 0 for index, row in attributes_df.iterrows()}
filename_to_label

{'000001.jpg': 0,
 '000002.jpg': 0,
 '000003.jpg': 1,
 '000004.jpg': 0,
 '000005.jpg': 0,
 '000006.jpg': 0,
 '000007.jpg': 1,
 '000008.jpg': 1,
 '000009.jpg': 0,
 '000010.jpg': 0,
 '000011.jpg': 0,
 '000012.jpg': 1,
 '000013.jpg': 1,
 '000014.jpg': 0,
 '000015.jpg': 1,
 '000016.jpg': 1,
 '000017.jpg': 0,
 '000018.jpg': 0,
 '000019.jpg': 0,
 '000020.jpg': 1,
 '000021.jpg': 1,
 '000022.jpg': 0,
 '000023.jpg': 1,
 '000024.jpg': 0,
 '000025.jpg': 1,
 '000026.jpg': 0,
 '000027.jpg': 0,
 '000028.jpg': 0,
 '000029.jpg': 0,
 '000030.jpg': 1,
 '000031.jpg': 0,
 '000032.jpg': 1,
 '000033.jpg': 1,
 '000034.jpg': 0,
 '000035.jpg': 0,
 '000036.jpg': 1,
 '000037.jpg': 1,
 '000038.jpg': 1,
 '000039.jpg': 0,
 '000040.jpg': 0,
 '000041.jpg': 1,
 '000042.jpg': 0,
 '000043.jpg': 0,
 '000044.jpg': 0,
 '000045.jpg': 0,
 '000046.jpg': 0,
 '000047.jpg': 0,
 '000048.jpg': 1,
 '000049.jpg': 1,
 '000050.jpg': 1,
 '000051.jpg': 1,
 '000052.jpg': 1,
 '000053.jpg': 1,
 '000054.jpg': 0,
 '000055.jpg': 1,
 '000056.j

In [4]:
class CelebADataset(Dataset):
    def __init__(self, file_paths, file_to_label, transform=None):
        self.file_paths = file_paths
        self.file_to_label = file_to_label
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_name = self.file_paths[idx]
        image = Image.open(img_name).convert('RGB')
        label = self.file_to_label[os.path.basename(img_name)]

        if self.transform:
            image = self.transform(image)

        return image, label

In [5]:
# Define the transformations for the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the input size expected by ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalization parameters for ResNet
])

In [6]:
# Load the CelebA dataset from a single directory
image_directory = './processed_img'  # Replace with your dataset path
image_paths = [os.path.join(image_directory, img) for img in os.listdir(image_directory)]

In [7]:
# Create the dataset
celeba_dataset = CelebADataset(image_paths, filename_to_label, transform=transform)

# Now you can create DataLoaders for training and validation
train_size = int(0.8 * len(celeba_dataset))
val_size = len(celeba_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(celeba_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [8]:
# Load a pre-trained ResNet model
model = resnet50(pretrained=True)

# Modify the model for binary classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: Male/Female



In [9]:
# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
device

device(type='cuda', index=0)

In [12]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

AttributeError: module 'torch.nn' has no attribute 'BinaryCrossEntropyLoss'

In [11]:
# Early stopping parameters
patience = 5  # How many epochs to wait after last time validation loss improved.
best_loss = np.Inf
epochs_no_improve = 0
early_stop = False

num_epochs = 100  # You can adjust the number of epochs
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
        # Print stats after each batch
        if (batch_idx + 1) % 100 == 0:  # Adjust the modulo number based on your batch size
            print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {running_loss / (batch_idx+1):.4f}', flush=True)

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # Calculate average losses
    train_loss = running_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    
    # Print training/validation statistics 
    print(f'Epoch: {epoch+1} \tTraining Loss: {train_loss:.6f} \tValidation Loss: {val_loss:.6f}')
    print(f'Validation Accuracy: {100 * correct / total}%')

    # Save model if validation loss has decreased
    if val_loss < best_loss:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        best_loss,
        val_loss))
        torch.save(model.state_dict(), 'gender_classification_model.pth')
        best_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print('Early stopping')
            early_stop = True
            break

    if early_stop:
        print("Stopped early due to no improvement in validation loss")
        break

print('Training complete')

Epoch: 1, Batch: 100, Loss: 0.3195
Epoch: 1, Batch: 200, Loss: 0.2720
Epoch: 1, Batch: 300, Loss: 0.2413
Epoch: 1, Batch: 400, Loss: 0.2263
Epoch: 1, Batch: 500, Loss: 0.2124
Epoch: 1, Batch: 600, Loss: 0.2063
Epoch: 1, Batch: 700, Loss: 0.1995
Epoch: 1, Batch: 800, Loss: 0.1910
Epoch: 1, Batch: 900, Loss: 0.1865
Epoch: 1, Batch: 1000, Loss: 0.1811
Epoch: 1, Batch: 1100, Loss: 0.1748
Epoch: 1, Batch: 1200, Loss: 0.1721
Epoch: 1, Batch: 1300, Loss: 0.1694
Epoch: 1, Batch: 1400, Loss: 0.1666
Epoch: 1, Batch: 1500, Loss: 0.1632
Epoch: 1, Batch: 1600, Loss: 0.1604
Epoch: 1, Batch: 1700, Loss: 0.1584
Epoch: 1, Batch: 1800, Loss: 0.1567
Epoch: 1, Batch: 1900, Loss: 0.1562
Epoch: 1, Batch: 2000, Loss: 0.1541
Epoch: 1, Batch: 2100, Loss: 0.1528
Epoch: 1, Batch: 2200, Loss: 0.1512
Epoch: 1, Batch: 2300, Loss: 0.1495
Epoch: 1, Batch: 2400, Loss: 0.1481
Epoch: 1, Batch: 2500, Loss: 0.1463
Epoch: 1, Batch: 2600, Loss: 0.1452
Epoch: 1, Batch: 2700, Loss: 0.1435
Epoch: 1, Batch: 2800, Loss: 0.1422
E

KeyboardInterrupt: 