In [1]:
import pandas as pd
from PIL import Image

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
from resnet import *

from sklearn.utils.class_weight import compute_class_weight

import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [2]:
class ImageDataset(Dataset):
    def __init__(self, root_folder, file_name, transform, attr, label = None, weights = None):
        self.transform=transform
        self.img_folder=root_folder+'img/img_align_celeba/'
        
        self.df = pd.read_csv(root_folder+file_name+'.csv').replace(-1,0).reset_index(drop=True)
        self.image_names = self.df.pop('image_id')
        self.attr = self.df[attr].values
        self.label = self.df[label].values
        self.weights = weights if weights is not None else None
   
    def __len__(self):
        return len(self.image_names)
 
    def __getitem__(self, index):
        image_path = self.img_folder + self.image_names[index]
        image=Image.open(image_path)
        image=self.transform(image)
        label = torch.tensor(self.label[index], dtype=torch.float32)
        weight = self.weights[self.label[index]][self.attr[index]] if self.weights is not None else 0
        weight = torch.tensor(weight, dtype=torch.float32)
        return image, label, weight, self.attr[index]

In [3]:
def get_weights(file, label, attr, classes = ["00","01", "10", "11"]):
    df = pd.read_csv(f'dataset/celebA/{file}.csv').replace(-1,0)
    l = []
    for _, row in df.iterrows():
        l.append(str(row[label])+str(row[attr]))
    df['classes'] = pd.Series(l)
    display(df['classes'].value_counts())
    weights = compute_class_weight(class_weight = 'balanced', y = df['classes'], classes = classes)
    weights = {
        0 : {0:weights[0], 1:weights[1]}, 
        1 : {0:weights[2], 1:weights[3]}, 
    }
    return weights

In [4]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [5]:
root_folder = 'dataset/celebA/'
img_dim = 64
batch_size = 32
file_name = 'down_train_downsaple_minority_group'
attr  = 'Male'
label = 'Young'

In [6]:
weights = get_weights(file_name, label, attr = attr)
weights

10    16590
01     4986
11     1700
00      468
Name: classes, dtype: int64

{0: {0: 12.683760683760683, 1: 1.1905334937825913},
 1: {0: 0.35780590717299576, 1: 3.491764705882353}}

In [7]:
transform = Compose([Resize((img_dim, img_dim)),
                     ToTensor(),
                     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
data = ImageDataset(root_folder=root_folder, file_name = file_name, transform=transform,
                    attr = attr, label = label, weights = weights)
train_dataloader = DataLoader(data, batch_size = batch_size,shuffle = True)
testdata = ImageDataset(root_folder=root_folder, file_name = 'down_test', transform=transform,
                    attr = attr, label = label)
test_dataloader = DataLoader(testdata, batch_size = batch_size,shuffle = True)

In [8]:
print("Number of samples : ", len(train_dataloader.dataset))

Number of samples :  23744


In [9]:
model = resnet18(pretrained=False, in_channels=3, fc_size=2048, out_dim=1).to(device)

In [10]:
def element_weighted_loss(y_hat, y, weights):
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    loss = criterion(y_hat, y)
    #loss = loss * weights
    #return loss.sum() / weights.sum()
    return loss.mean()

In [11]:
optimizer = optim.Adam(model.parameters())

In [12]:
num = len(train_dataloader.dataset)//batch_size + 1

In [13]:
model.train()
for epoch in range(5): 
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_dataloader), total = num):
        inputs, labels, weights, attr = data
        inputs, labels, weights, attr = inputs.to(device), labels.to(device), weights.to(device), attr.to(device)
    
        optimizer.zero_grad()
        outputs,_ = model(inputs)
        loss = element_weighted_loss(outputs, labels.unsqueeze(1), weights)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'[{epoch + 1}] Train loss: {running_loss / num:.3f}')
          
print('Finished Training')

  0%|          | 0/743 [00:00<?, ?it/s]

[1] Train loss: 0.351


  0%|          | 0/743 [00:00<?, ?it/s]

[2] Train loss: 0.262


  0%|          | 0/743 [00:00<?, ?it/s]

[3] Train loss: 0.251


  0%|          | 0/743 [00:00<?, ?it/s]

[4] Train loss: 0.231


  0%|          | 0/743 [00:00<?, ?it/s]

[5] Train loss: 0.222
Finished Training


In [14]:
#class:{attr:}
correct_pred = {0:{0:0,1:0},1:{0:0,1:0}}
total_pred = {0:{0:0,1:0},1:{0:0,1:0}}

with torch.no_grad():
    model.eval()
    for data in tqdm(test_dataloader, total=len(test_dataloader.dataset)//batch_size):
        inputs, labels, weights, attrs = data
        inputs, labels, weights, attrs = inputs.to(device), labels.to(device), weights.to(device), attrs.to(device)
        outputs, _ = model(inputs)
        predictions = torch.round(torch.sigmoid(outputs))
        for label, prediction, attr in zip(labels, predictions, attrs):
            if label == prediction:
                correct_pred[label.item()][attr.item()] += 1
            total_pred[label.item()][attr.item()] += 1

  0%|          | 0/126 [00:00<?, ?it/s]

In [None]:
for classname, correct_counts in correct_pred.items():
    for attr_name, correct_count in correct_counts.items():
        accuracy = 100 * float(correct_count) / total_pred[classname][attr_name]
        print(f'Accuracy for class: {classname} , attr: {attr_name}: {accuracy} total: {total_pred[classname][attr_name]}  ')

## Without weights

Accuracy for class: 0 , attr: 0: 12.269938650306749 total: 326  
Accuracy for class: 0 , attr: 1: 56.426332288401255 total: 638  
Accuracy for class: 1 , attr: 0: 97.52720079129574 total: 2022  
Accuracy for class: 1 , attr: 1: 91.18198874296435 total: 1066  

## With Weights
Accuracy for class: 0 , attr: 0: 7.975460122699387 total: 326  
Accuracy for class: 0 , attr: 1: 49.529780564263326 total: 638  
Accuracy for class: 1 , attr: 0: 99.06033630069238 total: 2022  
Accuracy for class: 1 , attr: 1: 94.74671669793621 total: 1066  