In [1]:
import pandas as pd
from PIL import Image
from resnet import *

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm

from sklearn.utils.class_weight import compute_class_weight

In [2]:
class ImageDataset(Dataset):
    """Image dataset driven by a CSV of image file names plus label/attribute columns.

    The CSV at ``root_folder + file_name + '.csv'`` must contain an ``image_id``
    column (file names relative to the image folder) together with the ``label``
    and ``attr`` columns. Every ``-1`` in the CSV is replaced by ``0`` so that
    ``{-1, 1}``-encoded columns become ``{0, 1}``.

    Each item is ``(image, label, weight, attr)``:
      * ``image``  -- ``transform`` applied to the image converted to RGB,
      * ``label``  -- float32 scalar tensor,
      * ``weight`` -- float32 scalar tensor ``weights[label][attr]`` (``0`` when
        no ``weights`` were given),
      * ``attr``   -- the raw attribute value from the CSV.

    Args:
        root_folder: prefix prepended (by plain string concatenation) to both the
            CSV name and ``img_folder``; include the trailing separator.
        file_name: CSV file name without the ``.csv`` extension.
        transform: callable applied to each RGB PIL image.
        attr: name of the attribute column.
        img_folder: image sub-folder relative to ``root_folder``; ``None`` means
            the images live directly in ``root_folder``.
        label: name of the target column.
        weights: optional nested mapping ``weights[label][attr] -> float``.
    """

    def __init__(self, root_folder, file_name, transform, attr, img_folder=None, label=None, weights=None):
        self.transform = transform
        # img_folder=None -> look for images directly under root_folder.
        self.img_folder = root_folder + (img_folder or '')

        self.df = pd.read_csv(root_folder + file_name + '.csv').replace(-1, 0).reset_index(drop=True)
        self.image_names = self.df.pop('image_id')
        self.attr = self.df[attr].values
        self.label = self.df[label].values
        self.weights = weights

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        image_path = self.img_folder + self.image_names[index]
        # The context manager closes the file handle. convert('RGB') forces a full
        # load before the file closes and guarantees 3 channels for the transform.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        label = torch.tensor(self.label[index], dtype=torch.float32)
        weight = self.weights[self.label[index]][self.attr[index]] if self.weights is not None else 0
        weight = torch.tensor(weight, dtype=torch.float32)
        return image, label, weight, self.attr[index]

In [3]:
def get_weights(file, label, attr, root, classes=("00", "01", "10", "11")):
    """Compute 'balanced' weights for every (label, attr) group in a CSV.

    Each row gets the group key ``str(label) + str(attr)`` (after mapping ``-1``
    to ``0``). The group counts are displayed, and sklearn's ``'balanced'``
    heuristic ``n_samples / (n_groups * group_count)`` gives one weight per group.

    Args:
        file: CSV file name without the ``.csv`` extension.
        label: name of the target column.
        attr: name of the attribute column.
        root: prefix (string-concatenated) of the CSV path.
        classes: two-character group keys (label digit followed by attr digit).
            sklearn raises if a key observed in the data is missing here.

    Returns:
        Nested dict ``weights[label][attr] -> float`` in the form read by
        ``ImageDataset``.
    """
    df = pd.read_csv(f'{root}{file}.csv').replace(-1, 0)
    # Vectorised instead of a per-row iterrows() loop. astype(int) keeps the keys
    # "0"/"1" even if pandas parsed a column as float (str(1.0) would be "1.0").
    df['classes'] = df[label].astype(int).astype(str) + df[attr].astype(int).astype(str)
    display(df['classes'].value_counts())
    classes = list(classes)
    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=df['classes'])
    # Key the result by the digits of each class string rather than by fixed
    # positions, so any ordering of `classes` maps weights to the right group.
    weights = {}
    for key, group_weight in zip(classes, class_weights):
        weights.setdefault(int(key[0]), {})[int(key[1])] = group_weight
    return weights

In [4]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [5]:
# Image side length after resizing, and mini-batch size for both loaders.
img_dim = 64
batch_size = 128

# Alternative training source: the real CelebA training CSV.
#root_folder = 'dataset/celebA/'
#img_folder = 'img/img_align_celebA/'
#file_name = 'dear_train_downsample_smile'

# Training source used here: the synthetic dataset.
root_folder = 'synthetic_dataset/'
file_name = 'generated_smiling'
img_folder = 'img/'

# Target column and the attribute column used for per-group weights/metrics.
label = 'Smiling'
attr = 'Male'

# Seed torch (model init, DataLoader shuffling) so runs are reproducible.
seed = 0
torch.manual_seed(seed)

# True -> train with element_weighted_loss, False -> plain BCEWithLogitsLoss.
apply_weight = False

In [6]:
# Per-(label, attr) group weights from the training CSV; shown for inspection.
weights = get_weights(file=file_name, label=label, attr=attr, root=root_folder)
weights

10    51066
00    44534
11    42640
01    40556
Name: classes, dtype: int64

{0: {0: 1.0037050343557732, 1: 1.1021550448762205},
 1: {0: 0.8753182156425019, 1: 1.0482879924953095}}

In [7]:
# Center-crop to 128 px, resize to img_dim x img_dim, then map [0, 1] -> [-1, 1].
transform = Compose([
    CenterCrop(128),
    Resize((img_dim, img_dim)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

data = ImageDataset(root_folder=root_folder, file_name=file_name, transform=transform,
                    img_folder=img_folder, attr=attr, label=label, weights=weights)
train_dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)

# Evaluation always uses the CelebA test CSV, whatever the training source is.
test_root_folder = 'dataset/celebA/'
test_img_folder = 'img/img_align_celebA/'
test_file_name = 'dear_test'
testdata = ImageDataset(root_folder=test_root_folder, file_name=test_file_name,
                        transform=transform, img_folder=test_img_folder,
                        attr=attr, label=label)
# No shuffling for evaluation: order doesn't affect the metrics and a fixed order is reproducible.
test_dataloader = DataLoader(testdata, batch_size=batch_size, shuffle=False)

In [8]:
# Size of the training set seen by the DataLoader.
num_train_samples = len(train_dataloader.dataset)
print("Number of samples : ", num_train_samples)

Number of samples :  178796


In [9]:
# Single-logit ResNet-18 trained from scratch on RGB inputs.
model = resnet18(
    pretrained=False,
    in_channels=3,
    fc_size=2048,
    out_dim=1,
)
model = model.to(device)

In [10]:
def element_weighted_loss(y_hat, y, weights):
    """Per-sample weighted BCE-with-logits loss, normalised by the total weight.

    Args:
        y_hat: logits with the same shape as ``y`` (``(B, 1)`` in the training loop).
        y: float targets in ``[0, 1]``.
        weights: per-sample weights, either shaped like the loss or a 1-D ``(B,)``
            vector that is broadcast along the trailing dimensions of the loss.

    Returns:
        Scalar tensor ``sum(w * loss) / sum(w)``. This equals the mean BCE when all
        weights are equal. All-zero weights give NaN (0 / 0).
    """
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    loss = criterion(y_hat, y)
    w = weights.to(device=loss.device, dtype=loss.dtype)
    if w.dim() == 1 and loss.dim() > 1:
        # (B, 1) * (B,) would broadcast to (B, B); align the batch axis first.
        w = w.view(-1, *([1] * (loss.dim() - 1)))
    w = w.expand_as(loss)
    return (loss * w).sum() / w.sum()

In [11]:
# Adam with its default learning rate, spelled out for visibility.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [12]:
# Number of training batches per epoch. len(DataLoader) is the exact ceil(N / batch_size);
# `N // batch_size + 1` over-counts by one when N divides evenly.
num = len(train_dataloader)

In [13]:
# Unweighted BCE-with-logits averaged over the batch (the default reduction).
criterion = nn.BCEWithLogitsLoss(reduction='mean')

In [14]:
num_epochs = 10

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    n_batches = 0
    # Batch-prefixed names keep the global `data`, `weights` and `attr` intact for re-runs.
    for batch_inputs, batch_labels, batch_weights, _batch_attrs in tqdm(train_dataloader, total=len(train_dataloader)):
        batch_inputs = batch_inputs.to(device)
        # (B,) -> (B, 1) to match the model's single-logit output.
        batch_targets = batch_labels.to(device).unsqueeze(1)
        optimizer.zero_grad()
        # The model returns a pair; only the first element (the logits) is used here.
        outputs, _ = model(batch_inputs)
        if apply_weight:
            loss = element_weighted_loss(outputs, batch_targets, batch_weights.to(device))
        else:
            loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    # Average over the batches actually processed in this epoch.
    print(f'[{epoch + 1}] loss: {running_loss / max(n_batches, 1):.3f}')

print('Finished Training')

  0%|          | 0/1397 [00:00<?, ?it/s]

[1] loss: 0.111


  0%|          | 0/1397 [00:00<?, ?it/s]

[2] loss: 0.071


  0%|          | 0/1397 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [15]:
# Per-group tallies, indexed as [class label][attribute value].
correct_pred = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
total_pred = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}

model.eval()
with torch.no_grad():
    for batch_inputs, batch_labels, _batch_weights, batch_attrs in tqdm(test_dataloader, total=len(test_dataloader)):
        outputs, _ = model(batch_inputs.to(device))
        # Round the sigmoid to a 0/1 prediction and flatten to one value per sample.
        batch_preds = torch.round(torch.sigmoid(outputs)).reshape(-1).cpu()
        # Distinct loop names so the global `label` / `attr` config strings survive.
        for y_true, y_pred, group in zip(batch_labels.long().tolist(),
                                         batch_preds.long().tolist(),
                                         batch_attrs.tolist()):
            total_pred[y_true][group] += 1
            if y_true == y_pred:
                correct_pred[y_true][group] += 1

  0%|          | 0/158 [00:00<?, ?it/s]

In [16]:
# Print per-(class, attr) accuracy; groups absent from the test set are reported as n/a.
for classname, attr_counts in correct_pred.items():
    for attr_name, correct_count in attr_counts.items():
        n_total = total_pred[classname][attr_name]
        if n_total == 0:
            print(f'Accuracy for class: {classname} , attr: {attr_name}: n/a total: 0  ')
            continue
        accuracy = 100 * float(correct_count) / n_total
        print(f'Accuracy for class: {classname} , attr: {attr_name}: {accuracy} total: {n_total}  ')

Accuracy for class: 0 , attr: 0: 75.31296023564065 total: 5432  
Accuracy for class: 0 , attr: 1: 95.43993571715548 total: 4978  
Accuracy for class: 1 , attr: 0: 97.1224140612848 total: 6429  
Accuracy for class: 1 , attr: 1: 79.24583455130079 total: 3421  


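## Recorded results from earlier runs

Per-group test accuracy, where "class" is the label value and "attr" is the attribute value (with the current config: `Smiling` and `Male`). The group totals below (1130 / 1041 / 1218 / 663) differ from the run above, so these numbers were measured on a different test file and are not directly comparable with it.

## Without weights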

Accuracy for class: 0 , attr: 0: 66.46017699115045 total: 1130  
Accuracy for class: 0 , attr: 1: 96.06147934678194 total: 1041  
Accuracy for class: 1 , attr: 0: 98.0295566502463 total: 1218  
Accuracy for class: 1 , attr: 1: 75.41478129713424 total: 663  

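## With weights (`apply_weight = True`)

Note: these numbers may have been produced by a version of `element_weighted_loss` that multiplied the (B, 1) per-sample loss by a (B,) weight vector. That product broadcasts to (B, B), so the result amounts to the plain sum of unweighted losses. Re-run to confirm them.
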
Accuracy for class: 0 , attr: 0: 72.7433628318584 total: 1130  
Accuracy for class: 0 , attr: 1: 96.54178674351586 total: 1041  
Accuracy for class: 1 , attr: 0: 97.53694581280789 total: 1218  
Accuracy for class: 1 , attr: 1: 68.92911010558069 total: 663  

## With synthetic training data
Accuracy for class: 0 , attr: 0: 78.40707964601769 total: 1130  
Accuracy for class: 0 , attr: 1: 96.25360230547551 total: 1041  
Accuracy for class: 1 , attr: 0: 97.1264367816092 total: 1218  
Accuracy for class: 1 , attr: 1: 75.41478129713424 total: 663  