# Train a gender classifier (ResNet-34) on CelebA / FairFace and measure its accuracy per race

In [1]:
import os
# Dependencies (install outside the notebook, e.g. `%pip install pandas torchvision`).
# Make CUDA kernel launches synchronous so device-side errors are reported at the
# offending call. The variable must be set in the kernel's own environment before
# CUDA is initialised: `os.system("CUDA_LAUNCH_BLOCKING=1")` only ran a throwaway
# subshell and never affected this process.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.models import resnet34, ResNet34_Weights

import time

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm


# Reproducibility: fix the RNGs (head initialisation, DataLoader shuffling) and ask
# cuDNN for deterministic kernels.
SEED = 1
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per run and may select non-deterministic ones.
    torch.backends.cudnn.benchmark = False

np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the cell output

<torch._C.Generator at 0x7f6952f1e490>

In [2]:
# --- Experiment selection -------------------------------------------------------
size_train = 86744  # number of CelebA training samples; selects the pre-sampled label CSV
experiments = ["FairFace", "CelebA", "CelebA only white", "CelebA augmented"]
exp = 1  # index into `experiments`

# --- CelebA paths -----------------------------------------------------------------
celeb_attr_path = "datasets/celeba/list_attr_celeba.txt"
celeb_partitions_path = 'datasets/celeba/list_eval_partition.txt'
celeb_race_path = "CelebA/races/races_ff.csv"
celeb_label_dir = "CelebA/labels_split/"
celeb_img_dir = "CelebA/cropped/"
celeb_img_aug_dir = "CelebA/augmented/"
celeb_train_csv = f"train_{size_train}_samples_random.csv" # "train_total.csv"
celeb_train_only_white_csv = f"train_{size_train}_samples_random_white.csv"
celeb_train_aug_csv = f"train_aug_{size_train}_samples.csv"
celeb_val_csv = "val_total.csv"
celeb_test_csv = "test_total.csv"

# --- FairFace paths ---------------------------------------------------------------
ff_img_dir = "fairface/dataset/fairface-img-margin125-trainval"
ff_label_dir = "fairface/dataset/"
ff_train_csv = "fairface_label_train.csv"
ff_val_csv = "fairface_label_val.csv"

# --- Hyperparameters (one learning rate per entry of `experiments`) ----------------
learning_rates = [2e-5, 2e-5, 2e-5, 2e-5]
assert len(learning_rates) == len(experiments), "need one learning rate per experiment"
lr = learning_rates[exp]
num_epochs = 10

# --- Input size, batch sizes and device -------------------------------------------
feat_size = (256, 256)  # passed to transforms.Resize
bs_train = 128
bs_val = 128
bs_test = 128

# The reported runs used GPU 3. Fall back to the first GPU, then the CPU, so the
# notebook still runs on machines with fewer GPUs instead of failing on 'cuda:3'.
gpu_index = 3
if torch.cuda.device_count() > gpu_index:
    device = f"cuda:{gpu_index}"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Race names; their order defines the index of each per-race metric.
races = ["Black", "Indian", "Latino", "Middle Eastern", "Southeast Asian", "East Asian", "White"]
# CelebA attribute columns handed to CelebaDataset as `ignored_attributes`.
ignored_attributes = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Pale_Skin"]

In [3]:
# define datasets
class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images.

    Reads a label CSV that has an ``Image_Name`` column, a ``Race`` column and one
    column per attribute, and yields ``(image, label, race)`` triples.

    Args:
        csv_path: path (or file-like object) of the label CSV.
        img_dir: directory that the ``Image_Name`` entries are relative to.
        transform: optional callable applied to each RGB ``PIL.Image``.
        ignored_attributes: attribute columns excluded from the labels when
            ``label_columns`` is None. Defaults to no exclusions.
        label_columns: attribute column(s) used as labels, in order. Defaults to
            ``("Male",)``. Pass ``None`` to use every column except ``Image_Name``,
            ``Race`` and ``ignored_attributes``.

    Attributes:
        y: label array of shape ``(num_samples, num_label_columns)``.
        races: race name of every sample.
    """

    def __init__(self, csv_path, img_dir, transform=None, ignored_attributes=None,
                 label_columns=("Male",)):
        df = pd.read_csv(csv_path, index_col=None)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.img_names = df["Image_Name"].to_numpy()
        self.races = df["Race"].to_numpy()
        if label_columns is None:
            # Multi-attribute mode: every attribute column not explicitly ignored.
            drop_cols = ["Image_Name", "Race", *(ignored_attributes or [])]
            self.y = df.drop(columns=drop_cols).to_numpy()
        else:
            # Selecting with a list keeps a 2-D (N, k) array, e.g. (N, 1) for gender.
            self.y = df[list(label_columns)].to_numpy()
        self.transform = transform

    def __getitem__(self, index):
        path = os.path.join(self.img_dir, self.img_names[index])
        # Decode eagerly and close the file handle right away; converting to RGB keeps
        # every sample 3-channel even for grayscale/palette images.
        with Image.open(path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, self.y[index], self.races[index]

    def __len__(self):
        return self.y.shape[0]
    


class FairFaceDataset(Dataset):
    """Custom Dataset for loading FairFace images.

    Reads a FairFace label CSV (the ``file``, ``gender`` and ``race`` columns are used)
    and yields ``(image, label, race)`` triples. ``label`` is a length-1 int array
    (1 = Male, 0 = Female). Race names are aligned with the CelebA labels by renaming
    ``Latino_Hispanic`` to ``Latino``.

    Args:
        csv_path: path (or file-like object) of the label CSV.
        img_dir: directory that the ``file`` entries are relative to.
        transform: optional callable applied to each RGB ``PIL.Image``.

    Raises:
        ValueError: if ``gender`` contains a value other than ``Male``/``Female``.
    """

    _GENDER_TO_LABEL = {"Male": 1, "Female": 0}

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path, index_col=None)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.img_names = df["file"].to_numpy()
        self.races = df["race"].replace("Latino_Hispanic", "Latino").to_numpy()
        # map() yields an integer column directly. Chained replace() of strings by ints
        # relied on pandas' deprecated silent downcast; without it `y` becomes an
        # object array, which the DataLoader's default collate rejects.
        gender = df["gender"].map(self._GENDER_TO_LABEL)
        if gender.isna().any():
            unknown = sorted(df.loc[gender.isna(), "gender"].astype(str).unique())
            raise ValueError(f"Unexpected gender values in {csv_path}: {unknown}")
        self.y = gender.to_numpy(dtype=np.int64)[:, None]
        self.transform = transform

    def __getitem__(self, index):
        path = os.path.join(self.img_dir, self.img_names[index])
        # Decode eagerly and close the file handle right away; converting to RGB keeps
        # every sample 3-channel even for grayscale/palette images.
        with Image.open(path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, self.y[index], self.races[index]

    def __len__(self):
        return self.y.shape[0]

In [4]:
# Create the train / validation / test datasets for the selected experiment and wrap
# them in DataLoaders.
num_workers = 6
custom_transform = transforms.Compose([transforms.Resize(feat_size),
                                       transforms.ToTensor()])

experiment = experiments[exp]

# training dataset
if experiment.startswith("CelebA"):
    if experiment.endswith("augmented"):
        train_csv, train_img_dir = celeb_train_aug_csv, celeb_img_aug_dir
    elif "only white" in experiment:
        train_csv, train_img_dir = celeb_train_only_white_csv, celeb_img_dir
    else:
        train_csv, train_img_dir = celeb_train_csv, celeb_img_dir

    train_dataset = CelebaDataset(csv_path=os.path.join(celeb_label_dir, train_csv),
                                  img_dir=train_img_dir,
                                  transform=custom_transform,
                                  ignored_attributes=ignored_attributes)
elif experiment.startswith("FairFace"):
    train_dataset = FairFaceDataset(csv_path=os.path.join(ff_label_dir, ff_train_csv),
                                    img_dir=ff_img_dir,
                                    transform=custom_transform)
else:
    # Fail here rather than with a NameError on `train_dataset` further down.
    raise ValueError(f"Unknown experiment {experiment!r}; expected one of {experiments}")

# validation dataset (FairFace validation split)
val_dataset = FairFaceDataset(csv_path=os.path.join(ff_label_dir, ff_val_csv),
                              img_dir=ff_img_dir,
                              transform=custom_transform)

# test datasets
test_dataset_celeb = CelebaDataset(csv_path=os.path.join(celeb_label_dir, celeb_test_csv),
                                   img_dir=celeb_img_dir,
                                   transform=custom_transform,
                                   ignored_attributes=ignored_attributes)

# NOTE(review): FairFace has no separate test split here. The "test" set is the
# validation set used for per-epoch monitoring, so its scores are not independent.
test_dataset_ff = val_dataset


def make_loader(dataset, batch_size, shuffle=False):
    """Wrap `dataset` in a DataLoader that uses the notebook's `num_workers`."""
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers)


# create dataloaders on these datasets
train_loader = make_loader(train_dataset, bs_train, shuffle=True)
val_loader = make_loader(val_dataset, bs_val)
test_loader_celeb = make_loader(test_dataset_celeb, bs_test)
test_loader_ff = make_loader(test_dataset_ff, bs_test)

In [5]:
# Model: ImageNet-pretrained ResNet-34 whose class-score output feeds a linear head with
# one output per label column. Outputs go through a sigmoid and are trained with binary
# cross-entropy; backbone and head are optimised jointly with Adam.
model = resnet34(weights=ResNet34_Weights.DEFAULT)
model.to(device)

num_attr_predicted = train_dataset.y.shape[1]
backbone_out = model.fc.out_features  # 1000 ImageNet classes
fc_layer = nn.Linear(in_features=backbone_out, out_features=num_attr_predicted, device=device)

sigmoid = nn.Sigmoid()
bin_ce = nn.BCELoss()

params = [*model.parameters(), *fc_layer.parameters()]
optimizer = torch.optim.Adam(params, lr=lr)

In [6]:
# define evaluation procedure
def evaluate_metrics(model, data_loader, device, show_tqdm=False, head=None, race_names=None):
    """Compute binary-classification metrics on `data_loader`, split by annotated race.

    Each batch must be ``(features, targets, gt_races)``: an input tensor, a label
    tensor of shape (batch, n_attr) and a sequence with the race name of every sample.
    This is what ``CelebaDataset`` / ``FairFaceDataset`` yield through a DataLoader.
    A prediction is positive when ``sigmoid(head(model(features))) >= 0.5``.

    Args:
        model: backbone applied to the features.
        data_loader: DataLoader to evaluate; ``len(data_loader.dataset)`` is printed.
        device: device the features are moved to before the forward pass.
        show_tqdm: display a progress bar.
        head: module applied to the backbone output; defaults to the global `fc_layer`.
        race_names: ordered race names used to bucket the metrics; defaults to the
            global `races`.

    Returns:
        ``(total_accuracy, accuracies, max_acc_disparity, total_precision, precisions,
        total_recall, recalls)``. The per-race entries are percentage strings in
        `race_names` order. ``max_acc_disparity`` is ``log(max/min)`` of the per-race
        accuracies over races with at least one example (NaN if there are none).

    Raises:
        ValueError: if a sample's race is not in `race_names`.
    """
    if head is None:
        head = fc_layer
    if race_names is None:
        race_names = races
    n_races = len(race_names)

    correct_predictions = np.zeros(n_races)
    true_pos = np.zeros(n_races)
    positive_preds = np.zeros(n_races)
    positive_targets = np.zeros(n_races)
    num_examples = np.zeros(n_races)
    total_examples = len(data_loader.dataset)
    num_attrs = 1  # labels per sample; taken from the batches below

    batches = data_loader
    if show_tqdm:
        batches = tqdm(data_loader, total=len(data_loader), desc="Evaluating")

    with torch.no_grad():  # inference only; harmless if the caller already disabled grads
        for features, targets, gt_races in batches:
            probas = torch.sigmoid(head(model(features.to(device))))
            prediction = (probas >= 0.5).cpu().numpy()
            targets = targets.numpy()
            num_attrs = targets.shape[1]

            # Race index of each sample, broadcast to one entry per predicted attribute.
            race_idx = np.array([race_names.index(race) for race in gt_races])
            race_idx = np.broadcast_to(race_idx[:, None], prediction.shape)
            is_correct = prediction == targets

            # Accumulate counts separately for each annotated race.
            for j in range(n_races):
                in_race = race_idx == j
                correct_preds = in_race & is_correct
                true_pos[j] += (correct_preds & (prediction == 1)).sum()
                correct_predictions[j] += correct_preds.sum()
                positive_targets[j] += (in_race & (targets == 1)).sum()
                positive_preds[j] += np.where(in_race, prediction, 0).sum()
                num_examples[j] += in_race.sum()

    # calculate and return metrics
    zero = 1e-10  # keeps per-race ratios finite for races/classes without examples
    print("Race distribution:", num_examples / num_attrs, "Total:", total_examples)

    n_total = num_examples.sum()
    total_accuracy = correct_predictions.sum() / n_total if n_total else float("nan")
    accuracies = correct_predictions / (num_examples + zero)
    accs_out = [f"{a:.2%}" for a in accuracies]

    # Only races that occur take part: an absent race has accuracy 0, which would make
    # the ratio (and the disparity) infinite.
    present = num_examples > 0
    if present.any():
        max_acc_disparity = np.log(accuracies[present].max() / accuracies[present].min())
    else:
        max_acc_disparity = float("nan")

    total_precision = true_pos.sum() / (positive_preds.sum() + zero)
    precisions = [f"{p:.2%}" for p in true_pos / (positive_preds + zero)]

    total_recall = true_pos.sum() / (positive_targets.sum() + zero)
    recalls = [f"{r:.2%}" for r in true_pos / (positive_targets + zero)]
    return total_accuracy, accs_out, max_acc_disparity, total_precision, precisions, total_recall, recalls


def get_elapsed_time(start_time):
    """Return the whole seconds elapsed since `start_time` (a `time.time()` value) as 'H:MM:SS'."""
    total_seconds = int(time.time() - start_time)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return "{}:{:02d}:{:02d}".format(hours, minutes, seconds)

In [7]:
# Training loop: fine-tune backbone + head for `num_epochs` epochs and report per-race
# metrics on the validation loader after every epoch.
start_time = time.time()

print(f"Initiating experiment '{experiments[exp]}' with a lr of {lr} and {size_train} samples on device {device}")


for epoch in range(num_epochs):
    epoch_tag = f"{(epoch+1):02d}/{num_epochs:02d}"

    model.train()
    pbar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch_tag}")
    for features, targets, _ in pbar:
        features = features.to(device)
        targets = targets.float().to(device)

        # Forward pass on raw logits. binary_cross_entropy_with_logits is the same loss
        # as sigmoid + BCELoss, but it uses the log-sum-exp trick and stays numerically
        # stable when the logits saturate.
        logits = fc_layer(model(features))
        loss = nn.functional.binary_cross_entropy_with_logits(logits, targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    model.eval()
    with torch.no_grad():  # no autograd graph is needed for evaluation
        acc_total, accs, max_acc_disp, prec_total, precs, rec_total, recs = evaluate_metrics(model, val_loader, device)
        print(f"Evaluation epoch {epoch_tag}:")
        print(f"Total accuracy: {acc_total:.2%}\t| Accuracies:\t{accs} | Max disparity: {max_acc_disp:.4f}")
        print(f"Total precision: {prec_total:.2%}\t| Precisions:\t{precs}")
        print(f"Total recall: {rec_total:.2%}\t| Recalls:\t{recs}\n")

print(f"Total Training Time: {get_elapsed_time(start_time)}")

Initiating experiment 'CelebA' with a lr of 2e-05 and 86744 samples on device cuda:3


Epoch 01/10: 100%|██████████| 678/678 [03:59<00:00,  2.83it/s, loss=0.1032]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 01/10:
Total accuracy: 82.31%	| Accuracies:	['75.06%', '81.33%', '84.60%', '87.84%', '80.71%', '80.06%', '86.19%'] | Max disparity: 0.1572
Total precision: 81.62%	| Precisions:	['71.21%', '77.39%', '81.39%', '90.51%', '83.67%', '82.05%', '85.82%']
Total recall: 85.88%	| Recalls:	['86.36%', '88.18%', '88.78%', '91.51%', '78.10%', '77.09%', '89.04%']



Epoch 02/10: 100%|██████████| 678/678 [03:55<00:00,  2.88it/s, loss=0.0151]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 02/10:
Total accuracy: 80.76%	| Accuracies:	['74.55%', '79.35%', '81.82%', '86.02%', '79.01%', '80.13%', '84.17%'] | Max disparity: 0.1431
Total precision: 77.87%	| Precisions:	['70.13%', '74.12%', '75.94%', '86.84%', '77.93%', '78.77%', '81.33%']
Total recall: 88.86%	| Recalls:	['87.86%', '89.77%', '91.93%', '93.36%', '83.13%', '82.63%', '91.62%']



Epoch 03/10: 100%|██████████| 678/678 [03:55<00:00,  2.88it/s, loss=0.0714]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 03/10:
Total accuracy: 82.20%	| Accuracies:	['75.64%', '81.86%', '84.97%', '87.10%', '79.79%', '79.48%', '86.00%'] | Max disparity: 0.1410
Total precision: 81.68%	| Precisions:	['73.60%', '78.45%', '81.44%', '89.91%', '81.84%', '80.64%', '85.05%']
Total recall: 85.51%	| Recalls:	['81.98%', '87.52%', '89.66%', '91.02%', '78.50%', '77.73%', '89.75%']



Epoch 04/10: 100%|██████████| 678/678 [03:56<00:00,  2.87it/s, loss=0.0184]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 04/10:
Total accuracy: 82.26%	| Accuracies:	['75.84%', '81.79%', '84.78%', '87.26%', '79.72%', '80.84%', '85.32%'] | Max disparity: 0.1404
Total precision: 81.18%	| Precisions:	['73.06%', '77.57%', '80.81%', '89.94%', '80.85%', '81.83%', '84.00%']
Total recall: 86.52%	| Recalls:	['83.85%', '89.11%', '90.29%', '91.27%', '79.86%', '79.41%', '89.84%']



Epoch 05/10: 100%|██████████| 678/678 [03:55<00:00,  2.87it/s, loss=0.0128]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 05/10:
Total accuracy: 82.36%	| Accuracies:	['76.86%', '82.12%', '84.23%', '88.01%', '80.14%', '80.71%', '84.65%'] | Max disparity: 0.1354
Total precision: 79.51%	| Precisions:	['72.28%', '76.31%', '79.03%', '88.84%', '79.40%', '78.80%', '82.03%']
Total recall: 89.78%	| Recalls:	['89.11%', '92.83%', '92.18%', '93.97%', '83.40%', '84.17%', '91.53%']



Epoch 06/10: 100%|██████████| 678/678 [03:56<00:00,  2.87it/s, loss=0.0699]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 06/10:
Total accuracy: 81.82%	| Accuracies:	['77.06%', '81.27%', '83.36%', '87.18%', '80.57%', '79.55%', '84.03%'] | Max disparity: 0.1234
Total precision: 78.67%	| Precisions:	['72.37%', '75.24%', '77.67%', '88.26%', '78.68%', '77.19%', '81.33%']
Total recall: 90.04%	| Recalls:	['89.49%', '92.83%', '92.56%', '93.36%', '85.85%', '84.04%', '91.27%']



Epoch 07/10: 100%|██████████| 678/678 [03:55<00:00,  2.88it/s, loss=0.0025]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 07/10:
Total accuracy: 82.78%	| Accuracies:	['76.41%', '82.12%', '85.46%', '87.51%', '81.13%', '81.74%', '85.08%'] | Max disparity: 0.1356
Total precision: 82.71%	| Precisions:	['74.60%', '78.76%', '82.50%', '90.56%', '84.72%', '83.65%', '84.57%']
Total recall: 85.26%	| Recalls:	['81.98%', '87.65%', '89.16%', '90.90%', '77.69%', '79.02%', '88.41%']



Epoch 08/10: 100%|██████████| 678/678 [03:55<00:00,  2.88it/s, loss=0.0001]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 08/10:
Total accuracy: 82.11%	| Accuracies:	['75.90%', '80.67%', '84.35%', '86.85%', '81.70%', '81.68%', '83.88%'] | Max disparity: 0.1348
Total precision: 78.79%	| Precisions:	['70.31%', '74.78%', '78.76%', '87.16%', '80.13%', '80.47%', '81.04%']
Total recall: 90.52%	| Recalls:	['91.86%', '92.16%', '93.06%', '94.34%', '86.12%', '83.78%', '91.44%']



Epoch 09/10: 100%|██████████| 678/678 [03:54<00:00,  2.89it/s, loss=0.0008]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 09/10:
Total accuracy: 82.00%	| Accuracies:	['76.22%', '80.21%', '82.99%', '87.76%', '81.98%', '80.65%', '84.51%'] | Max disparity: 0.1410
Total precision: 79.06%	| Precisions:	['71.64%', '74.22%', '77.47%', '88.89%', '81.17%', '78.84%', '81.88%']
Total recall: 89.71%	| Recalls:	['88.86%', '92.16%', '91.93%', '93.48%', '85.03%', '83.91%', '91.44%']



Epoch 10/10: 100%|██████████| 678/678 [03:55<00:00,  2.88it/s, loss=0.0001]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954
Evaluation epoch 10/10:
Total accuracy: 83.05%	| Accuracies:	['76.54%', '82.78%', '85.27%', '88.01%', '80.64%', '81.29%', '86.43%'] | Max disparity: 0.1396
Total precision: 82.87%	| Precisions:	['73.95%', '79.43%', '82.98%', '90.63%', '83.17%', '83.13%', '86.45%']
Total recall: 85.64%	| Recalls:	['83.85%', '88.18%', '87.89%', '91.64%', '78.64%', '78.64%', '88.68%']

Total Training Time: 0:42:28


In [8]:
# Evaluate the final model on both test sets.
# NOTE(review): `test_loader_ff` wraps `val_dataset` (the FairFace validation split that
# is also used for per-epoch validation), so its numbers repeat the last validation
# epoch and are not an independent test estimate.
model.eval()  # don't rely on the training cell having left the model in eval mode
with torch.no_grad():  # save memory during inference
    for set_name, loader in [("CelebA", test_loader_celeb), ("FairFace", test_loader_ff)]:
        acc_total, accs, max_acc_disp, prec_total, precs, rec_total, recs = evaluate_metrics(
            model, loader, device, show_tqdm=True)
        print(f"\nEvaluation {set_name} test set ({experiments[exp]}):")
        print(f"Total accuracy: {acc_total:.2%}\t| Accuracies:\t{accs}")
        print(f"Maximum accuracy disparity: {max_acc_disp:.4f}")
        print(f"Total precision: {prec_total:.2%}\t| Precisions:\t{precs}")
        print(f"Total recall: {rec_total:.2%}\t| Recalls:\t{recs}\n")


Evaluating: 100%|██████████| 156/156 [00:32<00:00,  4.82it/s]

Race distribution: [ 1461.   553.  1269.  1538.   311.  1777. 13053.] Total: 19962

Evaluation CelebA test set (CelebA):
Total accuracy: 97.76%	| Accuracies:	['96.65%', '99.28%', '98.90%', '98.37%', '93.89%', '96.17%', '97.94%']
Maximum accuracy disparity: 0.0558
Total precision: 97.34%	| Precisions:	['97.33%', '99.52%', '97.43%', '98.14%', '93.71%', '94.70%', '97.54%']
Total recall: 96.84%	| Recalls:	['96.75%', '98.56%', '98.55%', '98.53%', '93.06%', '93.38%', '96.92%']




Evaluating: 100%|██████████| 86/86 [00:14<00:00,  6.07it/s]


Race distribution: [1556. 1516. 1623. 1209. 1415. 1550. 2085.] Total: 10954

Evaluation FairFace test set (CelebA):
Total accuracy: 83.05%	| Accuracies:	['76.54%', '82.78%', '85.27%', '88.01%', '80.64%', '81.29%', '86.43%']
Maximum accuracy disparity: 0.1396
Total precision: 82.87%	| Precisions:	['73.95%', '79.43%', '82.98%', '90.63%', '83.17%', '83.13%', '86.45%']
Total recall: 85.64%	| Recalls:	['83.85%', '88.18%', '87.89%', '91.64%', '78.64%', '78.64%', '88.68%']

