In [1]:
import os
import pandas as pd
from PIL import Image
from resnet import *

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch.nn.functional import normalize

from bgm import *
from sagan import *
from causal_model import *
from load_data import *

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report

In [2]:
# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [3]:
class TrainImageDataset(Dataset):
    """Dataset of pre-computed ``(z, y, attr)`` samples stored one per ``.pt`` file.

    Despite the name, this class is used for both the train and the test split.

    Args:
        folder: Directory containing the ``.pt`` files. A trailing path
            separator is optional.
    """

    def __init__(self, folder):
        # Match on the extension (not a substring, which would also pick up
        # e.g. ``x.ptx``) and sort so the sample order does not depend on the
        # filesystem's directory-listing order.
        self.folder = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith('.pt')
        )

    def __len__(self):
        return len(self.folder)

    def __getitem__(self, index):
        # Each file is expected to hold a 3-tuple; unpacking enforces that.
        # map_location='cpu' lets files saved from a GPU session load on
        # CPU/MPS-only machines; the training/eval cells move tensors to `device`.
        z, y, attr = torch.load(self.folder[index], map_location='cpu')
        return z, y, attr

In [4]:
# Experiment configuration.
latent_dim = 100  # input width of LatentModel; must match the width of the stored z vectors
batch_size = 64
cols = ['Smiling', 'Male', 'High_Cheekbones', 'Mouth_Slightly_Open', 'Narrow_Eyes', 'Chubby']
dest_dir = 'synthetic_latent_dataset'
train_folder = f'{dest_dir}/train/'
test_folder = f'{dest_dir}/test/'
seed = 0

# Seed torch's global RNG so weight initialisation and train-loader shuffling
# are reproducible across Restart & Run All.
torch.manual_seed(seed)

train_data = TrainImageDataset(train_folder)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Evaluation only aggregates counts, so shuffling buys nothing; a fixed order
# keeps the collected label/prediction lists reproducible.
test_data = TrainImageDataset(test_folder)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [5]:
# Sanity check: inspect the shape of the latent vectors in the first training batch.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    z, y, attr = first_batch
    print(z.shape)

torch.Size([64, 100])


In [6]:
class LatentModel(nn.Module):
    """Three-layer MLP mapping a latent vector to a single logit.

    Args:
        latent_dim: Width of the input vectors.

    ``forward`` returns raw logits of shape ``(..., 1)``; no sigmoid is applied,
    so pair it with ``BCEWithLogitsLoss``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        # Hidden layers share the same Linear -> ReLU pattern.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        # Output layer stays linear: the loss expects logits.
        return self.fc3(x)

# Instantiate the classifier and move its parameters to the selected device.
model = LatentModel(latent_dim)
model = model.to(device)

In [7]:
def element_weighted_loss(y_hat, y, weights):
    """Weighted mean of element-wise binary cross-entropy computed on logits.

    Args:
        y_hat: Raw logits.
        y: Float targets with the same shape as ``y_hat``.
        weights: Per-element weights. They are broadcast to the shape of the
            element-wise loss (never the other way round), so e.g. a ``(B,)``
            weight vector against ``(B, 1)`` logits raises a RuntimeError
            instead of silently producing a ``(B, B)`` product.

    Returns:
        Scalar tensor ``sum(w * loss) / sum(w)`` over all loss elements.
        If every weight is zero the result is NaN (0 / 0).
    """
    loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction='none')
    # Expand the weights to the loss shape so numerator and denominator count
    # exactly the same elements.
    weights = torch.broadcast_to(weights, loss.shape)
    return (loss * weights).sum() / weights.sum()

In [8]:
criterion = nn.BCEWithLogitsLoss()
# Batches per epoch. len(DataLoader) is ceil(N / batch_size) when drop_last is
# False; writing N // batch_size + 1 instead over-counts by one whenever N is
# a multiple of batch_size.
num = len(train_dataloader)
optimizer = optim.Adam(model.parameters())  # PyTorch default Adam hyper-parameters

In [9]:
num_epochs = 2

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    n_batches = 0
    # The DataLoader has a length, so tqdm infers the correct total itself.
    for z, y, _attr in tqdm(train_dataloader):
        # attr is not used by the loss, so only inputs and targets go to the device.
        z, y = z.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(z)
        # Targets are reshaped to the model's (batch, 1) logit shape and cast to
        # the logits' dtype, as BCEWithLogitsLoss requires floating targets.
        loss = criterion(outputs, y.unsqueeze(1).to(outputs.dtype))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    # Average over the batches actually processed in this epoch.
    print(f'[{epoch + 1}] loss: {running_loss / max(n_batches, 1):.3f}')

print('Finished Training')

  0%|          | 0/2515 [00:00<?, ?it/s]

[1] loss: 0.097


  0%|          | 0/2515 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
# Per-group tallies keyed as {class label: {attr value: count}}: how many test
# samples in each (label, attr) group were predicted correctly / seen in total.
correct_pred = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
total_pred = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
model.eval()
label_l = []
pred_l = []
with torch.no_grad():
    for z, y, attr in tqdm(test_dataloader, total=len(test_dataloader)):
        outputs = model(z.to(device))
        # Threshold sigmoid probabilities at 0.5 (torch.round maps exactly 0.5
        # to 0), then convert each batch to Python values once instead of one
        # .item() device sync per sample.
        batch_preds = torch.round(torch.sigmoid(outputs)).reshape(-1).tolist()
        batch_labels = y.reshape(-1).tolist()
        batch_attrs = attr.reshape(-1).tolist()
        # Assumes exactly one label and one attr value per sample.
        assert len(batch_labels) == len(batch_preds) == len(batch_attrs), \
            'expected one label and one attr value per sample'
        for label, prediction, attr_value in zip(batch_labels, batch_preds, batch_attrs):
            if label == prediction:
                correct_pred[label][attr_value] += 1
            total_pred[label][attr_value] += 1
            label_l.append(label)
            pred_l.append(prediction)

  0%|          | 0/316 [00:00<?, ?it/s]

In [11]:
# Report accuracy (in percent) for every (class label, attr value) group;
# empty groups are reported as 0.
for classname, per_attr_correct in correct_pred.items():
    for attr_name, correct_count in per_attr_correct.items():
        group_total = total_pred[classname][attr_name]
        if group_total > 0:
            accuracy = 100 * float(correct_count) / group_total
        else:
            accuracy = 0
        print(f'Accuracy for class: {classname} , attr: {attr_name}: {accuracy} total: {group_total}  ')

Accuracy for class: 0 , attr: 0: 82.67673048600884 total: 5432  
Accuracy for class: 0 , attr: 1: 94.1944556046605 total: 4978  
Accuracy for class: 1 , attr: 0: 93.77819256494011 total: 6429  
Accuracy for class: 1 , attr: 1: 78.98275358082432 total: 3421  
