In [1]:
import os
import pandas as pd
from PIL import Image
from resnet import *

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm


from bgm import *
from sagan import *
from causal_model import *
from load_data import *

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report

In [2]:
class ImageDatasetDf(Dataset):
    """Image dataset backed by an in-memory DataFrame.

    Each item is ``(image, label, attr)``: the image file named by the row's
    ``image_id``, the row's value in the ``label`` column, and the row's
    values in ``cols`` (or in every column except ``image_id`` when ``cols``
    is None).

    Args:
        root_folder: Prefix of the image folder. Paths are built by plain
            string concatenation, so it must end with a path separator.
        df: DataFrame with an ``image_id`` column plus attribute columns.
            It is copied; any index is accepted.
        label: Name of the target column. Required in practice, because
            ``None`` is looked up as a column name.
        transform: Callable applied to the opened PIL image. With ``None``
            the PIL image is returned unchanged.
        cols: Column name or list of column names kept in ``attr``.
        img_folder: Image directory appended to ``root_folder``. The default
            keeps the path this class always used.
    """

    def __init__(self, root_folder, df, label = None, transform=None, cols = None,
                 img_folder='img/img_align_celeba/'):
        self.transform = transform
        self.img_folder = root_folder + img_folder
        self.df = df.copy()
        self.attr = df.copy()
        # Drop the caller's index: __getitem__ receives positions 0..len-1,
        # which would otherwise be treated as index labels and fail (or pick
        # the wrong row) for a filtered frame that was not re-indexed.
        self.image_names = self.attr.pop('image_id').reset_index(drop=True)
        if cols is not None:
            self.attr = self.attr[cols]
        self.label = self.df[label].values
        self.attr = self.attr.values

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        image_path = self.img_folder + self.image_names.iloc[index]
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, self.label[index], self.attr[index]

def get_dataloader_from_df(root_folder, df, label = None, img_dim=64, batch_size=32, cols = None,
                           crop_size=None, shuffle=True):
    """Build a DataLoader over ``ImageDatasetDf`` for the rows of ``df``.

    Images are resized to ``img_dim x img_dim``, converted to tensors and
    normalised with mean 0.5 / std 0.5 per channel.

    Args:
        root_folder: Forwarded to ``ImageDatasetDf``.
        df: DataFrame describing the images; forwarded to ``ImageDatasetDf``.
        label: Name of the target column.
        img_dim: Side length of the square output image.
        batch_size: Number of samples per batch.
        cols: Column name(s) returned as the third element of each item.
        crop_size: If given, a ``CenterCrop(crop_size)`` is applied before the
            resize, so the preprocessing can match pipelines that crop first.
            The default ``None`` keeps the original resize-only behaviour.
        shuffle: Whether the loader shuffles; defaults to the original ``True``.

    Returns:
        A ``DataLoader`` yielding ``(image, label, attr)`` batches.
    """
    steps = []
    if crop_size is not None:
        steps.append(CenterCrop(crop_size))
    steps.extend([Resize((img_dim, img_dim)),
                  ToTensor(),
                  Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    transform = Compose(steps)
    data = ImageDatasetDf(root_folder=root_folder, df = df, label = label, transform=transform, cols = cols)
    dataloader = DataLoader(data, batch_size = batch_size, shuffle = shuffle)
    return dataloader

In [3]:
class ImageDataset(Dataset):
    """CSV-backed image dataset yielding ``(image, label, attr)`` triples.

    The table is read from ``root_folder + file_name + '.csv'`` and every
    ``-1`` in it is replaced by ``0``. The ``image_id`` column is split off as
    the file names; images are opened from ``root_folder + img_folder``.
    ``attr`` and ``label`` name the columns returned alongside each image.
    """

    def __init__(self, root_folder, file_name, transform, attr, img_folder = None, label = None):
        self.transform = transform
        self.img_folder = root_folder + img_folder

        csv_path = root_folder + file_name + '.csv'
        table = pd.read_csv(csv_path)
        table = table.replace(-1, 0).reset_index(drop=True)
        self.df = table
        # pop() removes the column from self.df as well as returning it.
        self.image_names = table.pop('image_id')
        self.attr = table[attr].values
        self.label = table[label].values

    def __len__(self):
        return self.image_names.shape[0]

    def __getitem__(self, index):
        file_path = self.img_folder + self.image_names[index]
        picture = self.transform(Image.open(file_path))
        # The label is handed back as a float32 tensor.
        target = torch.tensor(self.label[index], dtype=torch.float32)
        return picture, target, self.attr[index]

In [4]:
# Backend preference: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [5]:
# Side length the images are resized to, and the loader batch size.
img_dim, batch_size = 64, 128

# Dataset location: CSV and image paths are built by concatenating onto
# root_folder, so the trailing slashes matter.
root_folder = 'dataset/celebA/'
img_folder = 'img/img_align_celebA/'
file_name = 'dear_train_downsample_smile'

# Column used as the target and column kept as the extra attribute.
label, attr = 'Smiling', 'Male'

# Not referenced by any of the visible cells.
apply_weight = False

In [6]:
causal_diagram = 'smiling'

if causal_diagram == 'smiling':
    # Checkpoint directory. Alternatives left commented out in the original:
    #saved_folder = 'saved_model_downsample_smile_reduce_latent_dim'
    #saved_folder = 'saved_model_downsample_smile_increase_latent_dim'
    saved_folder = 'saved_model'
    cols = ['Smiling', 'Male', 'High_Cheekbones', 'Mouth_Slightly_Open', 'Narrow_Eyes', 'Chubby']
    num_label = len(cols)
    # Square matrix over `cols`, handed to BGM together with the 'linscm' prior.
    # Presumably A[i, j] = 1 marks an edge cols[i] -> cols[j] — TODO confirm
    # against the BGM implementation.
    A = torch.zeros((num_label, num_label))
    A[0, 2:6] = 1   # row 'Smiling', columns 'High_Cheekbones'..'Chubby'
    A[1, 4] = 1     # row 'Male', column 'Narrow_Eyes'
else:
    # Fail here rather than with a NameError on cols / num_label / A later on.
    raise ValueError(f"Unsupported causal_diagram {causal_diagram!r}; only 'smiling' is configured")

In [7]:
# Not referenced by any of the visible cells.
in_channels, fc_size = 3, 2048

# Latent size passed to BGM. Alternatives left commented out in the original:
#latent_dim = 10
#latent_dim = 150
latent_dim = 100

# Remaining BGM constructor arguments.
g_conv_dim = 32
enc_dist, enc_arch = 'gaussian', 'resnet'
enc_fc_size, enc_noise_dim = 2048, 128
dec_dist = 'implicit'
prior = 'linscm'

In [8]:
# Build the BGM with the hyper-parameters above and restore the trained weights.
model = BGM(latent_dim, g_conv_dim, img_dim,
            enc_dist, enc_arch, enc_fc_size, enc_noise_dim, dec_dist,
            prior, num_label, A)
# The model is wrapped before load_state_dict and unwrapped again right after;
# presumably the checkpoint was saved from a DataParallel model, so its keys
# carry the 'module.' prefix — TODO confirm.
model = nn.DataParallel(model)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(f'{saved_folder}/bgm', map_location='cpu')
print(checkpoint['epoch'])
model.load_state_dict(checkpoint['model_state_dict'])
# The model is only used for inference below, so switch to eval mode: layers
# such as batch-norm or dropout then stop using per-batch statistics and stop
# updating their running buffers while the dataset is encoded.
model = model.module.to(device).eval()

130


In [9]:
# Preprocessing shared by the train and test sets: centre crop to 128 px,
# resize to img_dim, convert to a tensor, normalise with mean/std 0.5.
transform = Compose([CenterCrop(128),
                     Resize((img_dim, img_dim)),
                     ToTensor(),
                     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data = ImageDataset(root_folder=root_folder, file_name=file_name, transform=transform,
                    img_folder=img_folder,
                    attr=attr, label=label)
train_dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)

# The test set lives under the same root and image folder as the training set,
# so reuse the configured paths instead of repeating the literals.
testdata = ImageDataset(root_folder=root_folder, file_name='dear_test', transform=transform,
                        img_folder=img_folder,
                        attr=attr, label=label)
test_dataloader = DataLoader(testdata, batch_size=batch_size, shuffle=True)

In [10]:
# Output root. A name containing 'synthetic' also enables the synthetic export cell.
dest_dir = 'synthetic_latent_dataset'
#dest_dir = 'latent_dataset'
train_dir = f'{dest_dir}/train/'
test_dir = f'{dest_dir}/test/'


def make_dir(path):
    """Create ``path`` (and any missing parents); do nothing if it already exists."""
    # exist_ok avoids the check-then-create race of testing os.path.exists first.
    os.makedirs(path, exist_ok=True)


make_dir(dest_dir)
make_dir(train_dir)
make_dir(test_dir)

In [11]:
def save_numpy(z, y, attr, num_id, d):
    """Write one ``.pt`` file per sample and return the next unused id.

    Despite the name, samples are written with ``torch.save``: sample ``i`` of
    the batch is stored as the tuple ``(z[i], y[i], attr[i])`` of CPU tensors
    in ``f'{d}{num_id}.pt'``, with ``num_id`` incremented per sample.

    Args:
        z, y, attr: Batched tensors with the same leading dimension.
        num_id: Id used for the first sample of this batch.
        d: Output directory prefix; it is concatenated with the file name, so
            it must end with a path separator.

    Returns:
        ``num_id`` advanced by ``len(z)``.
    """
    for idx in range(len(z)):
        # t[idx] is a view into the batch. torch.save serialises the whole
        # underlying storage of a view, so clone to keep only this sample;
        # detach so no autograd state is stored with it.
        sample = tuple(t[idx].detach().clone().cpu() for t in (z, y, attr))
        torch.save(sample, f'{d}{num_id}.pt')
        num_id = num_id + 1
    return num_id

In [12]:
# Names of the target column and of the attribute column, plus both as a list.
y_name = 'Smiling'
attr_name = 'Male'
final_cols = [y_name, attr_name]

In [13]:
# Encode the training set and write one (z, y, attr) file per image into
# train_dir, numbered from 0.
num_id = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Loop variables use *_batch names so the config string `attr` ('Male')
    # is not overwritten by a tensor.
    for x_batch, y_batch, attr_batch in tqdm(train_dataloader):
        x_batch = x_batch.to(device)
        eps = model.encode(x_batch)
        # The first num_label coordinates go through the prior; the rest are
        # kept as encoded.
        z_label = model.prior(eps[:, :num_label])
        other_label = eps[:, num_label:]
        z = torch.cat([z_label, other_label], dim = 1)
        num_id = save_numpy(z, y_batch, attr_batch, num_id, train_dir)

  0%|          | 0/854 [00:00<?, ?it/s]

In [14]:
if 'synthetic' in dest_dir:
    # Appends counterfactual samples to train_dir. File numbering continues
    # from `num_id`, so this must run right after the training export cell.
    file_name = "dear_train_downsample_smile"
    df = pd.read_csv(f"dataset/celebA/{file_name}.csv").replace(-1,0)
    idx_attr = cols.index(attr_name)
    idx_y = cols.index(y_name)

    def _save_counterfactuals(loader, y_new, attr_new, start_id, max_batches=202):
        """Encode ``loader``, pin the label/attribute latents, save, return next id.

        For every batch the latent coordinates ``idx_y`` and ``idx_attr`` are
        overwritten with -2 (new value 0) or +2 (any other new value) before
        the label block goes through ``model.prior``. At most ``max_batches``
        batches are processed; 202 is the number the original
        ``if ii > 200: break`` let through.
        """
        next_id = start_id
        with torch.no_grad():
            for batch_idx, (x_batch, y_batch, attr_batch) in enumerate(tqdm(loader)):
                eps = model.encode(x_batch.to(device))
                eps[:, idx_attr] = -2 if attr_new == 0 else 2
                eps[:, idx_y] = -2 if y_new == 0 else 2
                label_z = model.prior(eps[:, :num_label])
                other_z = eps[:, num_label:]
                z = torch.cat([label_z, other_z], dim=1)
                # Store the labels the latent was set to, not those of the
                # source image: the attribute is flipped by the intervention,
                # so the loader's value would mislabel the sample. y is saved
                # as float32 like the labels ImageDataset produces for the
                # real samples in the same directory.
                y_cf = torch.full_like(y_batch, y_new, dtype=torch.float32)
                attr_cf = torch.full_like(attr_batch, attr_new)
                next_id = save_numpy(z, y_cf, attr_cf, next_id, train_dir)
                if batch_idx + 1 >= max_batches:
                    break
        return next_id

    # NOTE(review): get_dataloader_from_df resizes without the CenterCrop(128)
    # that the main `transform` applies — confirm the difference is intended.

    # Source group (Smiling=0, Male=1) -> latents set to (Smiling=0, Male=0).
    df_temp = df[(df['Smiling']==0) & (df['Male']==1)].reset_index(drop=True)
    dc1_loader = get_dataloader_from_df(root_folder, df_temp, label=y_name, batch_size = batch_size, cols = attr_name)
    attr_new , y_new = 0, 0
    num_id = _save_counterfactuals(dc1_loader, y_new, attr_new, num_id)

    # Source group (Smiling=1, Male=0) -> latents set to (Smiling=1, Male=1).
    df_temp = df[(df['Smiling']==1) & (df['Male']==0)].reset_index(drop=True)
    dc2_loader = get_dataloader_from_df(root_folder, df_temp, label=y_name,batch_size = batch_size, cols=attr_name)
    attr_new , y_new = 1, 1
    num_id = _save_counterfactuals(dc2_loader, y_new, attr_new, num_id)

  0%|          | 0/317 [00:00<?, ?it/s]

  0%|          | 0/399 [00:00<?, ?it/s]

In [15]:
# Encode the test set and write one (z, y, attr) file per image into test_dir,
# numbered from 0.
num_id = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Loop variables use *_batch names so the config string `attr` ('Male')
    # is not overwritten by a tensor.
    for x_batch, y_batch, attr_batch in tqdm(test_dataloader):
        x_batch = x_batch.to(device)
        eps = model.encode(x_batch)
        # The first num_label coordinates go through the prior; the rest are
        # kept as encoded.
        z_label = model.prior(eps[:, :num_label])
        other_label = eps[:, num_label:]
        z = torch.cat([z_label, other_label], dim = 1)
        num_id = save_numpy(z, y_batch, attr_batch, num_id, test_dir)

  0%|          | 0/159 [00:00<?, ?it/s]