In [1]:
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop

from sklearn.utils.class_weight import compute_class_weight

import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [2]:
class ImageDataset(Dataset):
    def __init__(self, root_folder, file_name, transform, attr, label = None, weights = None):
        self.transform=transform
        self.img_folder=root_folder+'img/img_align_celeba/'
        
        self.df = pd.read_csv(root_folder+file_name+'.csv').replace(-1,0).reset_index(drop=True)
        self.image_names = self.df.pop('image_id')
        self.attr = self.df[attr].values
        self.label = self.df[label].values
        self.weights = weights if weights is not None else None
   
    def __len__(self):
        return len(self.image_names)
 
    def __getitem__(self, index):
        image_path = self.img_folder + self.image_names[index]
        image=Image.open(image_path)
        image=self.transform(image)
        label = torch.tensor(self.label[index], dtype=torch.float32)
        weight = self.weights[self.label[index]][self.attr[index]] if self.weights is not None else 0
        weight = torch.tensor(weight, dtype=torch.float32)
        return image, label, weight, self.attr[index]

In [3]:
def get_weights(file, label, label_value, attr, classes = [0,1]):
    df = pd.read_csv(f'dataset/celebA/{file}.csv').replace(-1,0)
    df = df[df[label] == label_value].reset_index(drop=True)
    display(df[attr].value_counts())
    weights = compute_class_weight(class_weight = 'balanced', y = df[attr], classes = classes)
    weights = {classname : weights[i] for i, classname in enumerate(classes)}
    return weights

In [4]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [5]:
root_folder = 'dataset/celebA/'
img_dim = 64
batch_size = 16
file_name = 'down_train_downsaple'
label = 'Smiling'
attr = 'Male'

In [6]:
weights = {0:{}, 1:{}}
weights[0] = get_weights(file_name, label, 0, attr)
weights[1] = get_weights(file_name, label, 1, attr)
weights

1    8002
0     865
Name: Male, dtype: int64

0    10284
1      548
Name: Male, dtype: int64

{0: {0: 5.125433526011561, 1: 0.5540489877530618},
 1: {0: 0.5266433294437962, 1: 9.883211678832117}}

In [7]:
transform = Compose([CenterCrop(128),
                     Resize((img_dim, img_dim)),
                     ToTensor(),
                     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
data = ImageDataset(root_folder=root_folder, file_name = file_name, transform=transform,
                    attr = attr, label = label, weights = weights)
train_dataloader = DataLoader(data, batch_size = batch_size,shuffle = True)
testdata = ImageDataset(root_folder=root_folder, file_name = 'down_test', transform=transform,
                    attr = attr, label = label)
test_dataloader = DataLoader(testdata, batch_size = batch_size,shuffle = True)

In [8]:
print("Number of samples : ", len(train_dataloader.dataset))

Number of samples :  19699


In [9]:
class AttributeWeight(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 16, 5)
        self.fc1 = nn.Linear(5408, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        return x

In [10]:
def element_weighted_loss(y_hat, y, weights):
    criterion = nn.BCELoss(reduction='none')
    loss = criterion(y_hat, y)
    loss = loss * weights
    return loss.sum() / weights.sum()
    #return loss.mean()

In [11]:
model = AttributeWeight().to(device)

In [12]:
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [13]:
num = len(train_dataloader.dataset)//batch_size + 1

In [14]:
for epoch in range(20):  
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_dataloader), total = num):
        inputs, labels, weights, attr = data
        inputs, labels, weights, attr = inputs.to(device), labels.to(device), weights.to(device), attr.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = element_weighted_loss(outputs, labels.unsqueeze(1), weights)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'[{epoch + 1}] loss: {running_loss / num:.3f}')
          
print('Finished Training')

  0%|          | 0/1232 [00:00<?, ?it/s]

[1] loss: 3.685


  0%|          | 0/1232 [00:00<?, ?it/s]

[2] loss: 2.232


  0%|          | 0/1232 [00:00<?, ?it/s]

[3] loss: 1.986


  0%|          | 0/1232 [00:00<?, ?it/s]

[4] loss: 1.751


  0%|          | 0/1232 [00:00<?, ?it/s]

[5] loss: 1.553


  0%|          | 0/1232 [00:00<?, ?it/s]

[6] loss: 1.454


  0%|          | 0/1232 [00:00<?, ?it/s]

[7] loss: 1.320


  0%|          | 0/1232 [00:00<?, ?it/s]

[8] loss: 1.144


  0%|          | 0/1232 [00:00<?, ?it/s]

[9] loss: 0.981


  0%|          | 0/1232 [00:00<?, ?it/s]

[10] loss: 0.889


  0%|          | 0/1232 [00:00<?, ?it/s]

[11] loss: 0.786


  0%|          | 0/1232 [00:00<?, ?it/s]

[12] loss: 0.609


  0%|          | 0/1232 [00:00<?, ?it/s]

[13] loss: 0.483


  0%|          | 0/1232 [00:00<?, ?it/s]

[14] loss: 0.497


  0%|          | 0/1232 [00:00<?, ?it/s]

[15] loss: 0.334


  0%|          | 0/1232 [00:00<?, ?it/s]

[16] loss: 0.341


  0%|          | 0/1232 [00:00<?, ?it/s]

[17] loss: 0.333


  0%|          | 0/1232 [00:00<?, ?it/s]

[18] loss: 0.262


  0%|          | 0/1232 [00:00<?, ?it/s]

[19] loss: 0.197


  0%|          | 0/1232 [00:00<?, ?it/s]

[20] loss: 0.225
Finished Training


In [15]:
#class:{attr:}
correct_pred = {0:{0:0,1:0},1:{0:0,1:0}}
total_pred = {0:{0:0,1:0},1:{0:0,1:0}}

with torch.no_grad():
    for data in tqdm(test_dataloader, total=len(test_dataloader.dataset)//batch_size):
        inputs, labels, weights, attrs = data
        inputs, labels, weights, attrs = inputs.to(device), labels.to(device), weights.to(device), attrs.to(device)
        outputs = model(inputs)
        predictions = torch.round(outputs)
        for label, prediction, attr in zip(labels, predictions, attrs):
            if label == prediction:
                correct_pred[label.item()][attr.item()] += 1
            total_pred[label.item()][attr.item()] += 1

  0%|          | 0/253 [00:00<?, ?it/s]

In [16]:
for classname, correct_counts in correct_pred.items():
    for attr_name, correct_count in correct_counts.items():
        accuracy = 100 * float(correct_count) / total_pred[classname][attr_name]
        print(f'Accuracy for class: {classname} , attr: {attr_name}: {accuracy} total: {total_pred[classname][attr_name]}  ')

Accuracy for class: 0 , attr: 0: 72.7433628318584 total: 1130  
Accuracy for class: 0 , attr: 1: 96.54178674351586 total: 1041  
Accuracy for class: 1 , attr: 0: 97.53694581280789 total: 1218  
Accuracy for class: 1 , attr: 1: 68.92911010558069 total: 663  


## Without weights

Accuracy for class: 0 , attr: 0: 66.46017699115045 total: 1130  
Accuracy for class: 0 , attr: 1: 96.06147934678194 total: 1041  
Accuracy for class: 1 , attr: 0: 98.0295566502463 total: 1218  
Accuracy for class: 1 , attr: 1: 75.41478129713424 total: 663  

## With Weights
Accuracy for class: 0 , attr: 0: 72.7433628318584 total: 1130  
Accuracy for class: 0 , attr: 1: 96.54178674351586 total: 1041  
Accuracy for class: 1 , attr: 0: 97.53694581280789 total: 1218  
Accuracy for class: 1 , attr: 1: 68.92911010558069 total: 663  