In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
from torchvision import datasets, models, transforms

# Custom function to calculate bias metrics
def calculate_bias_metrics(y_true, y_pred, sensitive_attr):
    bias_metrics = {}
    unique_attrs = np.unique(sensitive_attr)
    for attr in unique_attrs:
        indices = sensitive_attr == attr
        bias_metrics[attr] = {
            'accuracy': accuracy_score(y_true[indices], y_pred[indices]),
            'f1_score': f1_score(y_true[indices], y_pred[indices], average='weighted')
        }
    return bias_metrics

# Load data
def load_data(batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    train_dataset = datasets.FakeData(transform=transform)
    test_dataset = datasets.FakeData(transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

# Train model
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    best_model_wts = model.state_dict()
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict()
    
    model.load_state_dict(best_model_wts)
    return model

# Evaluate model
def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    all_sensitive_attrs = []
    
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        # Simulated sensitive attribute, in practice replace with actual sensitive attribute
        all_sensitive_attrs.extend(np.random.randint(0, 10, size=labels.size(0)))
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_sensitive_attrs = np.array(all_sensitive_attrs)
    
    overall_accuracy = accuracy_score(all_labels, all_preds)
    overall_f1 = f1_score(all_labels, all_preds, average='weighted')
    bias_metrics = calculate_bias_metrics(all_labels, all_preds, all_sensitive_attrs)
    
    return overall_accuracy, overall_f1, bias_metrics

# Print metrics
def print_metrics(model_name, overall_accuracy, overall_f1, bias_metrics):
    print(f"Model: {model_name}")
    print(f"Overall Accuracy: {overall_accuracy:.4f}")
    print(f"Overall F1 Score: {overall_f1:.4f}")
    print("Bias Metrics by Age Group:")
    for age_group, metrics in bias_metrics.items():
        print(f"  Age Group {age_group}: Accuracy: {metrics['accuracy']:.4f}, F1 Score: {metrics['f1_score']:.4f}")

# Main execution
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = load_data(batch_size=32)

model_names = ['resnet18', 'vgg16', 'densenet121', 'mobilenet_v2', 'alexnet']
models = {
    'resnet18': models.resnet18(pretrained=True),
    'vgg16': models.vgg16(pretrained=True),
    'densenet121': models.densenet121(pretrained=True),
    'mobilenet_v2': models.mobilenet_v2(pretrained=True),
    'alexnet': models.alexnet(pretrained=True)
}

trained_models = []
for model_name in model_names:
    model = models[model_name]
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    print(f"Training {model_name}...")
    trained_model = train_model(model, {'train': train_loader, 'val': test_loader}, criterion, optimizer, num_epochs=5)
    trained_models.append(trained_model)
    torch.save(trained_model.state_dict(), f'model_{model_name}.pth')

for model, model_name in zip(trained_models, model_names):
    print(f"Evaluating {model_name}...")
    overall_accuracy, overall_f1, bias_metrics = evaluate_model(model, test_loader)
    print_metrics(model_name, overall_accuracy, overall_f1, bias_metrics)



Training resnet18...
Epoch 0/4
----------
train Loss: 3.1497 Acc: 0.0850
val Loss: 2.2987 Acc: 0.1850
Epoch 1/4
----------
train Loss: 1.9929 Acc: 0.3000
val Loss: 3.3991 Acc: 0.2070
Epoch 2/4
----------
train Loss: 1.3647 Acc: 0.5410
val Loss: 2.7661 Acc: 0.2540
Epoch 3/4
----------
train Loss: 0.6595 Acc: 0.7850
val Loss: 4.0931 Acc: 0.3420
Epoch 4/4
----------
train Loss: 0.3928 Acc: 0.8590
val Loss: 3.2349 Acc: 0.4500
Training vgg16...
Epoch 0/4
----------
train Loss: 5.9717 Acc: 0.0940
val Loss: 2.3452 Acc: 0.0900
Epoch 1/4
----------
train Loss: 2.3744 Acc: 0.1110
val Loss: 2.3353 Acc: 0.1210
Epoch 2/4
----------
train Loss: 2.3687 Acc: 0.1020
val Loss: 2.3293 Acc: 0.1270
Epoch 3/4
----------
train Loss: 2.3301 Acc: 0.1170
val Loss: 2.3224 Acc: 0.0930
Epoch 4/4
----------
train Loss: 2.3306 Acc: 0.1230
val Loss: 2.3208 Acc: 0.1210
Training densenet121...
Epoch 0/4
----------
train Loss: 3.2686 Acc: 0.0910
val Loss: 3.0399 Acc: 0.1290
Epoch 1/4
----------
train Loss: 2.1578 Acc: 0