In [34]:
from predict import perform_inference
from celeba_loader import create_dataloaders
from train import train, write_model, save_model

import torch

In [35]:
DEVICE = 1
device = torch.device(f'cuda:{DEVICE}' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=1)

## Train Model

In [36]:
EPOCHS = 5
LEARNING_RATE = 0.01
RATIO = 0.8
BATCH_SIZE = 32
DEVICE = 1

trainloader, testloader = create_dataloaders("./data/celebA/img_align_celeba/", "./data/celebA/attr/list_attr_celeba.txt", BATCH_SIZE, RATIO)

In [37]:
model = train(EPOCHS, LEARNING_RATE, trainloader, DEVICE)

Using Device  cuda:1


Using cache found in /home/rasta/.cache/torch/hub/pytorch_vision_v0.10.0


KeyboardInterrupt: 

In [None]:
from datetime import datetime

def save_model(model, dir_path, model_name):
    current_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    os.makedirs(dir_path, exist_ok=True)
    filename = f"{dir_path}/{model_name}_{current_time}.pth"
    print("Writing Model")
    torch.save(model, filename)
    print("Model Saved")

In [None]:
save_model(model, 'models', 'resnet18')

Writing Model
Model Saved


## Test Model (Gender)

In [None]:
import torch
model = torch.load('models/resnet18_2024-05-01_15:35:06.pth')

In [None]:
EPOCHS = 5
LEARNING_RATE = 0.01
RATIO = 0.8
BATCH_SIZE = 32
DEVICE = 1

trainloader, testloader = create_dataloaders("./data/celebA/img_align_celeba/", "./data/celebA/attr/list_attr_celeba.txt", BATCH_SIZE, RATIO)

In [55]:
attr_path = '/home/rasta/fair-neural-compression-eval/data/celebA/attr/list_attr_celeba.txt'
LINE_PADDING = 2
attr_table = open(attr_path).readlines()[LINE_PADDING:]
column_labels = open(attr_path).readlines()[1]
attr = [row.split() for row in attr_table]

In [None]:
column_labels = open(attr_path).readlines()[1]
column_labels = column_labels.strip().split()

In [None]:
gender_index = column_labels.index('Male')
gender_index

20

In [65]:
from tqdm import tqdm

def save_male_female_predictions(model, model_name, testloader, device, gender_index, prediction_save_dir):
    predictions_female = torch.Tensor().to(device)
    predictions_male = torch.Tensor().to(device)
    labels_male, labels_female = torch.Tensor().to(device), torch.Tensor().to(device)
    
    testloader = tqdm(testloader, desc="Saving Predictions", unit="batch")
    
    for i, data in enumerate(testloader):
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)

        predictions = outputs >= 0.5

        gender = labels[:, gender_index]
        predictions_female = torch.cat((predictions_female, predictions[gender == 0]), dim=0)
        predictions_male = torch.cat((predictions_male, predictions[gender == 1]), dim=0)
        
        labels_female = torch.cat((labels_female, labels[gender == 0]), dim=0)
        labels_male = torch.cat((labels_male, labels[gender == 1]), dim=0)

    # Save predictions
    torch.save(predictions_female, f'{prediction_save_dir}/{model_name}_predictions_female.pt')
    torch.save(predictions_male, f'{prediction_save_dir}/{model_name}_predictions_male.pt')
    torch.save(labels_female, f'{prediction_save_dir}/{model_name}_labels_female.pt')
    torch.save(labels_male, f'{prediction_save_dir}/{model_name}_labels_male.pt')
    
    return predictions_male, predictions_female, labels_male, labels_female


In [20]:
MODEL_NAME = 'resnet18_2024-05-01_15:35:06'
PRED_SAVE_DIR = 'results/predictions'

In [66]:
male_pred, female_pred, male_label, female_label = \
    save_male_female_predictions(model, MODEL_NAME, testloader, device, gender_index, PRED_SAVE_DIR)

Saving Predictions: 100%|██████████| 1267/1267 [00:57<00:00, 22.07batch/s]


In [67]:
male_pred.shape

torch.Size([16963, 40])

In [68]:
female_pred.shape

torch.Size([23557, 40])

In [71]:
torch.sum(male_pred == male_label, axis = 0)/male_pred.shape[0]

tensor([0.8155, 0.9466, 0.8095, 0.7191, 0.9590, 0.9650, 0.8648, 0.6943, 0.8570,
        0.9803, 0.9406, 0.8771, 0.8593, 0.8831, 0.9162, 0.9724, 0.9031, 0.9387,
        0.9971, 0.8470, 0.9629, 0.9153, 0.9021, 0.8804, 0.8721, 0.7869, 0.9755,
        0.8359, 0.8883, 0.9975, 0.9044, 0.8937, 0.7704, 0.8800, 0.9825, 0.9780,
        0.9870, 0.9814, 0.8990, 0.8058], device='cuda:1')

In [72]:
torch.sum(female_pred == female_label, axis = 0)/female_pred.shape[0]

tensor([0.9999, 0.7256, 0.8010, 0.9048, 0.9999, 0.9388, 0.7057, 0.8907, 0.8958,
        0.9136, 0.9526, 0.8154, 0.9393, 0.9873, 0.9904, 0.9938, 0.9998, 0.9905,
        0.8390, 0.8695, 0.9704, 0.9207, 1.0000, 0.8871, 0.9976, 0.7260, 0.9528,
        0.6901, 0.9414, 0.9069, 1.0000, 0.9177, 0.8276, 0.7580, 0.7899, 0.9893,
        0.8783, 0.7635, 0.9994, 0.9030], device='cuda:1')