In [1]:
import sys
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
import argparse
import numpy as np
import matplotlib.pyplot as plt

from bgm import *
from sagan import *
from causal_model import *
import os
import random
import utils

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize

from tqdm.notebook import tqdm
import seaborn as sns

In [2]:
class ImageDataset(Dataset):
    """Image/attribute dataset read from ``<root>/attr.csv`` and ``<root>/img/img_align_celeba/``.

    The attribute table must contain an ``image_id`` column holding the image
    file names; every other (selected) column becomes one entry of the label
    vector. Values of ``-1`` in the table are replaced by ``0``.

    Args:
        root_folder: Dataset root directory (with or without a trailing separator).
        transform: Callable applied to each PIL image in ``__getitem__``.
        cols: Optional list of attribute columns to keep, in the given order.
            ``None`` keeps every column except ``image_id``.
        frac: Fraction of rows of ``attr.csv`` to sample (default ``0.01``,
            the value that used to be hard-coded).
        random_state: Seed forwarded to ``DataFrame.sample`` so the subset can
            be reproduced; ``None`` keeps the previous unseeded behaviour.

    Raises:
        FileNotFoundError: If the image directory does not exist.
    """

    def __init__(self, root_folder, transform, cols=None, frac=0.01, random_state=None):
        self.transform = transform
        # The trailing '' keeps a path separator at the end, so code that does
        # ``img_folder + file_name`` keeps working.
        self.img_folder = os.path.join(root_folder, 'img', 'img_align_celeba', '')
        # Fail early on a wrong root without listing a potentially huge folder
        # (the listing result was discarded anyway).
        if not os.path.isdir(self.img_folder):
            raise FileNotFoundError(f"Image folder not found: {self.img_folder}")

        attr = (
            pd.read_csv(os.path.join(root_folder, 'attr.csv'))
            .replace(-1, 0)
            .sample(frac=frac, random_state=random_state)
        )
        # File names come from the sampled table so images and labels stay aligned.
        self.image_names = list(attr.pop('image_id'))
        if cols is not None:
            attr = attr[cols]
        self.num_feat = len(attr.columns)
        self.order = list(attr.columns)

        self.attr = attr.values

    def __len__(self):
        """Number of sampled rows."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Return ``(transformed_image, float_label_tensor)`` for row ``index``."""
        image_path = self.img_folder + self.image_names[index]
        # convert() loads the pixel data, so the file can be closed right after;
        # it also guarantees three channels for the downstream transform.
        with Image.open(image_path) as img:
            image = self.transform(img.convert('RGB'))
        label = torch.tensor(self.attr[index], dtype=torch.float)

        return image, label

def get_train_dataloader(root_folder, img_dim=64, batch_size=32, cols = None):
    """Build the shuffled training ``DataLoader`` over ``ImageDataset``.

    Args:
        root_folder: Dataset root passed on to ``ImageDataset``. This argument
            used to be ignored in favour of a hard-coded ``'dataset/celebA/'``;
            it is now honoured, so callers must pass the real dataset location.
        img_dim: Images are resized to ``img_dim x img_dim``.
        batch_size: Number of samples per batch.
        cols: Optional list of attribute columns used as labels.

    Returns:
        A ``DataLoader`` yielding ``(image, label)`` batches.
    """
    # ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 maps
    # each channel to [-1, 1].
    transform = Compose([Resize((img_dim, img_dim)),
                        ToTensor(),
                        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    training_data = ImageDataset(root_folder=root_folder, transform=transform, cols = cols)
    train_dataloader = DataLoader(training_data, batch_size = batch_size, num_workers = 2, 
                                  shuffle = True, prefetch_factor = 4)
    return train_dataloader

In [3]:
# Pick the compute backend by preference: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Binary cross-entropy on raw logits; the training loop applies it to the
# label slice of the encoder mean.
celoss = nn.BCEWithLogitsLoss()

In [5]:
# Attribute columns used as supervised labels. The order matters: it fixes the
# row/column indices of the adjacency matrix `A` and of the label vectors.
cols = [
    'Smiling',
    'Male',
    'High_Cheekbones',
    'Mouth_Slightly_Open',
    'Narrow_Eyes',
    'Chubby',
]

In [6]:
# --- Data -------------------------------------------------------------------
# Dataset root handed to get_train_dataloader. This is the directory the
# recorded run actually read from; the previous value ('sample_data/') was
# never used by the loader.
root_folder = 'dataset/celebA/'

in_channels = 3
fc_size = 2048
latent_dim = 100   # size of the noise vector z drawn in the training loop

img_dim = 64       # images are resized to img_dim x img_dim
batch_size = 128

# --- Generator / encoder (arguments of BGM) -----------------------------------
# NOTE(review): the meaning of the string options below is defined in bgm.py /
# causal_model.py, which are not visible here — confirm before changing them.
g_conv_dim = 32
enc_dist='gaussian'
enc_arch='resnet'
enc_fc_size=2048
enc_noise_dim=128
dec_dist = 'implicit'
prior = 'linscm'

# --- Discriminator (arguments of BigJointDiscriminator) -----------------------
d_conv_dim = 32
dis_fc_size = 1024

# One supervised latent dimension per selected attribute column.
num_label = len(cols)

In [7]:
# Build the training loader from the notebook-level configuration.
train_dataloader = get_train_dataloader(
    root_folder,
    img_dim=img_dim,
    batch_size=batch_size,
    cols=cols,
)

In [8]:
# Square (num_label x num_label) matrix passed to BGM as `A`; only the entries
# set below are non-zero. Indices follow the order of `cols`.
# NOTE(review): presumably A[i, j] = 1 marks a causal link between label i and
# label j — confirm the direction convention in causal_model.
A = torch.zeros(num_label, num_label, device=device)
A[0, 2:6] = 1  # row 0, columns 2..5
A[1, 4] = 1    # row 1, column 4

In [9]:
# Build the BGM model. Arguments are positional, so their order must match
# BGM's signature. The resulting object exposes .encoder, .decoder and .prior,
# which the optimizer setup relies on.
model = BGM(
    latent_dim,
    g_conv_dim,
    img_dim,
    enc_dist,
    enc_arch,
    enc_fc_size,
    enc_noise_dim,
    dec_dist,
    prior,
    num_label,
    A,
)

In [10]:
# Discriminator that scores (image, latent) pairs in the training loop.
discriminator = BigJointDiscriminator(
    latent_dim,
    d_conv_dim,
    img_dim,
    dis_fc_size,
)

In [11]:
# Parameter groups for the per-component optimizers.
# All three are materialised as lists: a bare `.parameters()` generator is
# exhausted after its first use, so re-running the optimizer cell would hand
# Adam an empty parameter list.
enc_param = list(model.encoder.parameters())
dec_param = list(model.decoder.parameters())
prior_param = list(model.prior.parameters())

# The first prior parameter gets its own optimizer with default betas; the
# remaining prior parameters use beta1 = 0.
# NOTE(review): presumably prior_param[0] is the learnable adjacency matrix A —
# confirm the parameter order in causal_model.
A_optimizer = optim.Adam(prior_param[0:1], lr=1e-3)
prior_optimizer = optim.Adam(prior_param[1:], lr=1e-3, betas=(0, 0.999))

In [12]:
# Adam with beta1 = 0 for all three networks; the discriminator uses twice the
# learning rate of the encoder and decoder.
encoder_optimizer = optim.Adam(
    enc_param,
    lr=5e-5,
    betas=(0, 0.999),
)
decoder_optimizer = optim.Adam(
    dec_param,
    lr=5e-5,
    betas=(0, 0.999),
)
D_optimizer = optim.Adam(
    discriminator.parameters(),
    lr=1e-4,
    betas=(0, 0.999),
)


In [13]:
# Move both networks to the selected device and wrap them in DataParallel.
# The isinstance guards make this cell safe to re-run: wrapping twice would
# nest DataParallel and break the later `model.module.prior` access.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model.to(device))
if not isinstance(discriminator, nn.DataParallel):
    discriminator = nn.DataParallel(discriminator.to(device))

In [14]:
# Training schedule: number of passes over the data, and how many
# discriminator / generator updates are performed per batch.
epochs = 1
d_steps_per_iter = g_steps_per_iter = 1

In [15]:
# Batches per epoch, taken from the loader itself. The manual
# `len(dataset) // batch_size + 1` over-counts by one whenever the dataset size
# is an exact multiple of the batch size.
number_batches = len(train_dataloader)
number_batches

16

In [16]:
def _mean_or_nan(values):
    """Average of the recorded values; NaN when nothing was recorded."""
    return sum(values) / len(values) if values else float("nan")


# Adversarial training: per batch, d_steps_per_iter discriminator updates are
# followed by g_steps_per_iter encoder and decoder/prior updates. After each
# epoch the mean losses are printed and a small sanity check compares the
# label slice of the encoder mean with the true labels.
for epoch in tqdm(range(epochs)):
    model.train()
    disc_loss, e_loss, g_loss, label_loss = [], [], [], []
    for batch_idx, (x, label) in tqdm(enumerate(train_dataloader), total = number_batches):
        x = x.to(device)
        # Rows whose first label is -1 are treated as unlabeled.
        # NOTE(review): ImageDataset replaces -1 with 0, so with this loader
        # every row is flagged as supervised.
        sup_flag = label[:, 0] != -1
        if sup_flag.sum() > 0:
            label = label[sup_flag, :].float()
        
        label = label.to(device)
        
        # ---- Discriminator update ----
        for _ in range(d_steps_per_iter):
            discriminator.zero_grad()
            z = torch.randn(x.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, _ = model(x, z)
            
            # Both pairs are detached so only the discriminator receives gradients.
            encoder_score = discriminator(x, z_fake.detach())
            decoder_score = discriminator(x_fake.detach(), z.detach())
            
            del z_fake
            del x_fake
            
            # Logistic loss: raise the score of (x, encoded z) pairs and lower
            # the score of (generated x, prior z) pairs.
            loss_d = F.softplus(decoder_score).mean() + F.softplus(-encoder_score).mean()
            loss_d.backward()
            D_optimizer.step()
            disc_loss.append(loss_d.item())
        
        # ---- Encoder, then decoder/prior update ----
        for _ in range(g_steps_per_iter):
            z = torch.randn(x.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, z_fake_mean = model(x, z)
            model.zero_grad()
            encoder_score = discriminator(x, z_fake)
            loss_encoder = encoder_score.mean()
            if sup_flag.sum() > 0:
                # The first num_label entries of the encoder mean are used as label logits.
                label_z = z_fake_mean[sup_flag, :num_label]
                sup_loss = celoss(label_z, label)
                label_loss.append(sup_loss.item())
            else:
                sup_loss = torch.zeros([1], device=device)
        
            # The supervised term is weighted 5x relative to the adversarial term.
            loss_encoder = loss_encoder + sup_loss * 5
            loss_encoder.backward()
            encoder_optimizer.step()
            prior_optimizer.step()
            e_loss.append(loss_encoder.item())
            
            model.zero_grad()
            z = torch.randn(x.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, z_fake_mean = model(x, z)
            decoder_score = discriminator(x_fake, z)
            # Per-sample weight exp(score), treated as a constant and clamped
            # to [0.5, 2] to bound its influence on the decoder gradient.
            r_decoder = torch.exp(decoder_score.detach())
            s_decoder = r_decoder.clamp(0.5, 2)
            loss_decoder = -(s_decoder * decoder_score).mean()
            
            loss_decoder.backward()
            decoder_optimizer.step()
            # NOTE(review): presumably masks gradients of the prior so some
            # entries of A stay fixed — confirm in causal_model.
            model.module.prior.set_zero_grad()
            A_optimizer.step()
            prior_optimizer.step()
            g_loss.append(loss_decoder.item())
    # Average over the values actually recorded, so the report stays correct
    # when steps-per-iter differs from 1 or a batch logs no label loss.
    print(
        f"[{epoch+1}/{epochs}] Encoder Loss : {_mean_or_nan(e_loss):>.5f} "
        f"Gen Loss : {_mean_or_nan(g_loss):>.5f}     "
        f"Disc Loss : {_mean_or_nan(disc_loss):>.5f} Label Loss : {_mean_or_nan(label_loss):>.5f} "
    )
    if epoch % 1 == 0:
        model.eval()
        t = 10  # number of samples used for the sanity check
        for batch_idx, (x, label) in enumerate(train_dataloader):
            with torch.no_grad():
                x = x.to(device)
                x_ = x[:t]
                # Reconstruct only the t inspected samples instead of the whole batch.
                x_recon = model(x_, recon=True)
                # Map values from [-1, 1] back to [0, 1] (inverse of the loader's Normalize).
                x_recon = (x_recon * 0.5) + 0.5

                z = torch.randn(x_.size(0), latent_dim, device=x.device)
                z_fake, x_fake, z, z_fake_mean = model(x_, z)
                print(z_fake_mean[:, :num_label], label[:t])
            # Only the first batch is inspected.
            break

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

[1/1] Encoder Loss : 4.13027 Gen Loss : -0.03233     Disc Loss : 1.31960 Label Loss : 0.68411 
tensor([[ 0.5639, -0.1954,  0.2110,  0.3716, -1.0446, -1.0229],
        [ 0.0669, -0.6117,  0.2885,  0.2823, -1.0709, -0.6989],
        [ 0.5926, -1.0375,  0.4332,  0.4410, -1.2946, -0.9399],
        [ 0.5642, -0.5155,  0.2861,  0.4061, -1.7704, -1.3865],
        [-0.2909, -0.9688, -0.2002,  0.1004, -0.5711, -0.6094],
        [ 0.1822, -0.4133,  0.1151, -0.0359, -0.9634, -0.7432],
        [ 0.3912,  0.0173,  0.5185,  0.3336, -1.1664, -0.9307],
        [-0.4052, -0.6264, -0.4974,  0.0023, -1.0457, -1.0831],
        [-0.2007, -0.7223,  0.1826,  0.3288, -0.9612, -1.0954],
        [-0.0042, -0.7489,  0.3492, -0.0936, -1.1638, -0.9495]],
       device='cuda:0') tensor([[0., 1., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 0., 1., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0.],
  