In [1]:
import sys
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
import argparse
import numpy as np
import matplotlib.pyplot as plt

from bgm import *
from sagan import *
from causal_model import *
import os
import random
import utils

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Resize, Normalize

from tqdm.notebook import tqdm
import seaborn as sns

In [2]:
class ImageDataset(Dataset):
    """Image/attribute dataset read from a CelebA-style folder layout.

    Expects ``root_folder`` to contain ``attr.csv`` (an ``image_id`` column plus one
    column per attribute) and the images under ``img/img_align_celeba/``.
    Attribute values of -1 are mapped to 0 before use.

    Args:
        root_folder: dataset root directory (trailing separator optional).
        transform: callable applied to each PIL image in ``__getitem__``.
        cols: optional list of attribute columns to keep (and their order).
        frac: fraction of rows of ``attr.csv`` to sample (default keeps 1%).
        random_state: seed forwarded to ``DataFrame.sample`` for a reproducible subset;
            ``None`` gives a different subset on every construction.
    """

    def __init__(self, root_folder, transform, cols=None, frac=0.01, random_state=None):
        self.transform = transform
        # Keep a trailing separator so img_folder matches the historical string form.
        self.img_folder = os.path.join(root_folder, 'img', 'img_align_celeba', '')

        attr = pd.read_csv(os.path.join(root_folder, 'attr.csv')).replace(-1, 0)
        attr = attr.sample(frac=frac, random_state=random_state)
        # Image names come from the (sub-sampled) attribute table, so rows and
        # images stay aligned by position.
        self.image_names = list(attr.pop('image_id'))
        if cols is not None:
            attr = attr[cols]
        self.num_feat = len(attr.columns)
        self.order = list(attr.columns)  # attribute name for each label position

        self.attr = attr.values

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        """Return ``(transformed_image, float_label_vector)`` for row ``index``."""
        image_path = os.path.join(self.img_folder, self.image_names[index])
        # Close the file handle promptly and force 3 channels for the RGB transforms.
        with Image.open(image_path) as img:
            image = self.transform(img.convert('RGB'))
        label = torch.tensor(self.attr[index], dtype=torch.float)

        return image, label

def get_train_dataloader(root_folder, img_dim=64, batch_size=32, cols=None,
                         num_workers=2, prefetch_factor=4):
    """Build a shuffled training DataLoader over ``ImageDataset(root_folder, ...)``.

    Images are resized to ``img_dim x img_dim``, converted to tensors and normalized
    per channel with mean 0.5 / std 0.5.

    Args:
        root_folder: dataset root forwarded to ``ImageDataset`` (previously ignored
            in favour of a hard-coded path).
        img_dim: output image side length.
        batch_size: samples per batch.
        cols: optional attribute columns forwarded to ``ImageDataset``.
        num_workers: DataLoader worker processes.
        prefetch_factor: batches prefetched per worker; only passed when
            ``num_workers > 0`` because DataLoader rejects it otherwise.
    """
    transform = Compose([
        Resize((img_dim, img_dim)),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    training_data = ImageDataset(root_folder=root_folder, transform=transform, cols=cols)

    loader_kwargs = {}
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = prefetch_factor
    train_dataloader = DataLoader(training_data, batch_size=batch_size,
                                  num_workers=num_workers, shuffle=True,
                                  **loader_kwargs)
    return train_dataloader

In [3]:
# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Binary cross-entropy on raw logits (sigmoid applied inside the loss); used for the
# supervised attribute loss.
celoss = nn.BCEWithLogitsLoss()

In [5]:
# CelebA attribute columns used as supervised labels; their order fixes the label index.
cols = [
    'Smiling',
    'Male',
    'High_Cheekbones',
    'Mouth_Slightly_Open',
    'Narrow_Eyes',
    'Chubby',
]

In [6]:
# ---- Data --------------------------------------------------------------------
root_folder = 'sample_data/'   # dataset root handed to get_train_dataloader
img_dim = 64                   # images are resized to img_dim x img_dim
batch_size = 128

# ---- Generator side (BGM: encoder / decoder / prior) ------------------------
in_channels = 3                # NOTE(review): not referenced in the cells shown
fc_size = 2048                 # NOTE(review): not referenced in the cells shown
latent_dim = 100
g_conv_dim = 32
enc_dist = 'gaussian'
enc_arch = 'resnet'
enc_fc_size = 2048
enc_noise_dim = 128
dec_dist = 'implicit'
prior = 'linscm'

# ---- Discriminator ----------------------------------------------------------
d_conv_dim = 32
dis_fc_size = 1024

# One supervised latent coordinate per attribute in `cols`.
num_label = len(cols)

In [7]:
# Shuffled training loader restricted to the attribute columns in `cols`.
train_dataloader = get_train_dataloader(
    root_folder,
    img_dim=img_dim,
    batch_size=batch_size,
    cols=cols,
)

In [8]:
# Adjacency matrix over the attributes in `cols` (row/column i is cols[i]).
# Presumably A[i, j] = 1 marks an edge cols[i] -> cols[j]; TODO confirm with the prior.
# Nonzero entries: (Smiling, High_Cheekbones), (Smiling, Mouth_Slightly_Open),
# (Smiling, Narrow_Eyes), (Smiling, Chubby), (Male, Narrow_Eyes).
A = torch.zeros((num_label, num_label), device=device)
A[[0, 0, 0, 0, 1], [2, 3, 4, 5, 4]] = 1

In [9]:
# Generative model exposing .encoder, .decoder and .prior (used by the optimizer cells);
# the prior is built over num_label attributes with adjacency A. Arguments are positional.
model = BGM(
    latent_dim, g_conv_dim, img_dim,
    enc_dist, enc_arch, enc_fc_size, enc_noise_dim,
    dec_dist,
    prior, num_label, A,
)

In [10]:
# Discriminator over (image, latent) pairs; called as discriminator(x, z) during training.
discriminator = BigJointDiscriminator(
    latent_dim, d_conv_dim, img_dim, dis_fc_size,
)

In [11]:
# Materialize every parameter set as a list: a bare .parameters() generator can be
# consumed only once, so re-running the optimizer cell would otherwise give Adam an
# empty parameter list.
enc_param = list(model.encoder.parameters())
dec_param = list(model.decoder.parameters())
prior_param = list(model.prior.parameters())

# The first prior parameter gets its own optimizer (presumably the adjacency weights,
# given the A_optimizer name — TODO confirm the prior's parameter order).
A_optimizer = optim.Adam(prior_param[0:1], lr=1e-3)
prior_optimizer = optim.Adam(prior_param[1:], lr=1e-3, betas=(0, 0.999))

In [12]:
# Adam with beta1 = 0 for all three networks; the discriminator uses a 2x larger lr.
D_optimizer = optim.Adam(params=discriminator.parameters(), lr=1e-4, betas=(0, 0.999))
encoder_optimizer = optim.Adam(params=enc_param, lr=5e-5, betas=(0, 0.999))
decoder_optimizer = optim.Adam(params=dec_param, lr=5e-5, betas=(0, 0.999))


In [13]:
# Move both networks to the chosen device, then wrap them in DataParallel; the
# training loop reaches the wrapped model via `model.module`.
model = model.to(device)
model = nn.DataParallel(model)
discriminator = discriminator.to(device)
discriminator = nn.DataParallel(discriminator)

In [14]:
# Training schedule.
epochs = 10
d_steps_per_iter = 1   # discriminator updates per batch
g_steps_per_iter = 1   # encoder + decoder updates per batch

In [15]:
# Batches per epoch as reported by the DataLoader (ceil(N / batch_size) without
# drop_last). The old `N // batch_size + 1` overcounted whenever N was divisible
# by batch_size. Used as the progress-bar total.
number_batches = len(train_dataloader)
number_batches

16

In [None]:
def _epoch_mean(values):
    """Mean of the per-step losses recorded during one epoch (NaN if none were recorded)."""
    return sum(values) / len(values) if values else float("nan")


eval_every = 1   # run the one-batch qualitative check every `eval_every` epochs

for epoch in tqdm(range(epochs)):
    model.train()
    disc_loss, e_loss, g_loss, label_loss = [], [], [], []
    for batch_idx, (x, label) in tqdm(enumerate(train_dataloader), total=number_batches):
        x = x.to(device)
        # Rows whose first attribute is -1 are treated as unlabelled. Note that
        # ImageDataset maps -1 to 0, so with that dataset every row is labelled.
        sup_flag = label[:, 0] != -1
        if sup_flag.sum() > 0:
            label = label[sup_flag, :].float()

        label = label.to(device)

        # ---- Discriminator: score encoder pairs (x, z_fake) against decoder pairs (x_fake, z).
        for _ in range(d_steps_per_iter):
            discriminator.zero_grad()
            z = torch.randn(x.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, _ = model(x, z)
            encoder_score = discriminator(x, z_fake.detach())
            decoder_score = discriminator(x_fake.detach(), z.detach())
            del z_fake
            del x_fake

            loss_d = F.softplus(decoder_score).mean() + F.softplus(-encoder_score).mean()
            loss_d.backward()
            D_optimizer.step()
            disc_loss.append(loss_d.item())

        for _ in range(g_steps_per_iter):
            # ---- Encoder (+ prior): adversarial term plus weighted supervised label term.
            z = torch.randn(x.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, z_fake_mean = model(x, z)
            model.zero_grad()
            encoder_score = discriminator(x, z_fake)
            loss_encoder = encoder_score.mean()
            if sup_flag.sum() > 0:
                # The first num_label coordinates of the encoder mean act as label logits.
                label_z = z_fake_mean[sup_flag, :num_label]
                sup_loss = celoss(label_z, label)
                label_loss.append(sup_loss.item())
            else:
                sup_loss = torch.zeros([1], device=device)

            loss_encoder = loss_encoder + sup_loss * 5
            loss_encoder.backward()
            encoder_optimizer.step()
            prior_optimizer.step()
            e_loss.append(loss_encoder.item())

            # ---- Decoder (+ A and prior): fresh forward pass, score-weighted adversarial loss.
            model.zero_grad()
            z = torch.randn(x.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, z_fake_mean = model(x, z)
            decoder_score = discriminator(x_fake, z)
            # Per-sample weight exp(score) clipped to [0.5, 2], treated as a constant.
            r_decoder = torch.exp(decoder_score.detach())
            s_decoder = r_decoder.clamp(0.5, 2)
            loss_decoder = -(s_decoder * decoder_score).mean()

            loss_decoder.backward()
            decoder_optimizer.step()
            model.module.prior.set_zero_grad()
            A_optimizer.step()
            prior_optimizer.step()
            g_loss.append(loss_decoder.item())

    # Average over the steps actually recorded (not over number_batches), so the
    # figures stay correct when *_steps_per_iter != 1 or labels are missing.
    print(
        f"[{epoch+1}/{epochs}] Encoder Loss : {_epoch_mean(e_loss):>.5f} "
        f"Gen Loss : {_epoch_mean(g_loss):>.5f} "
        f"Disc Loss : {_epoch_mean(disc_loss):>.5f} "
        f"Label Loss : {_epoch_mean(label_loss):>.5f} "
    )

    if epoch % eval_every == 0:
        # Qualitative check on one batch: print label logits next to the targets.
        model.eval()
        t = 10
        x, label = next(iter(train_dataloader))
        with torch.no_grad():
            x = x.to(device)
            x_ = x[:t]
            # Reconstructions kept for inspection; not used further in this cell.
            x_recon = model(x, recon=True)[:t]
            x_recon = (x_recon * 0.5) + 0.5

            z = torch.randn(x_.size(0), latent_dim, device=x.device)
            z_fake, x_fake, z, z_fake_mean = model(x_, z)
            print(z_fake_mean[:, :num_label], label[:t])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

[1/10] Encoder Loss : 4.24643 Gen Loss : 0.18687     Disc Loss : 1.07666 Label Loss : 0.68507 
tensor([[ 0.8600,  0.3843,  0.2070,  0.7493,  0.1355,  0.0981],
        [ 0.4883, -0.1260,  0.9505,  0.9646, -0.0627,  0.0122],
        [ 0.5217,  0.8827,  0.4022,  1.3968,  0.4691, -0.1891],
        [ 1.0147,  0.0334,  0.4042,  1.2837,  0.3766,  0.2906],
        [-0.9183,  0.8653, -1.5007,  0.2716,  0.4320, -0.3325],
        [ 0.7080,  0.5943, -0.3317,  0.8240, -0.2940, -0.1361],
        [-1.9430,  1.7561, -1.1635, -0.3554, -0.1091,  0.4225],
        [-0.9249,  0.4667, -0.7342, -0.5807,  0.5890,  0.0558],
        [ 0.7989,  0.9563, -0.4863,  0.5815,  0.2102, -0.1263],
        [ 1.1079, -0.0544,  0.3481,  0.9085,  0.2052,  0.7795]],
       device='cuda:0') tensor([[1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [0., 1., 0., 0., 1., 1.],
  

  0%|          | 0/16 [00:00<?, ?it/s]

[2/10] Encoder Loss : 4.14264 Gen Loss : 0.21730     Disc Loss : 1.19750 Label Loss : 0.70576 
tensor([[-1.0161e+00,  2.7200e-01, -1.5046e-01, -7.4688e-02, -2.2919e-03,
          5.5520e-02],
        [ 5.1189e+00,  4.3418e-01,  4.5362e+00, -1.7756e+00,  4.8787e-01,
         -3.4919e+00],
        [ 7.2044e-01,  1.1513e+00,  1.4748e+00, -2.3064e-01,  5.3076e-01,
         -8.3296e-01],
        [ 1.8135e+00, -7.7615e-01,  1.6556e+00,  6.5426e-01,  1.0621e-02,
         -6.4994e-01],
        [ 8.0618e-01, -6.4680e-01,  8.3089e-01,  1.3163e+00, -6.2500e-01,
         -1.1739e-01],
        [ 2.8968e-01,  1.1804e+00, -1.0382e-01,  1.2346e-01,  5.0680e-02,
         -3.1262e-01],
        [ 1.0492e+00,  1.3940e+00,  5.2494e-01,  8.8545e-01, -2.6146e-01,
         -2.2744e-01],
        [ 1.7734e+00, -1.4187e+00,  1.4553e+00,  5.3321e-01, -8.2644e-03,
         -3.1987e-01],
        [ 7.3565e-01, -6.3030e-02,  4.8924e-01,  8.9291e-01,  7.7108e-02,
         -2.8999e-02],
        [ 1.4473e+00, -5.2998e-0

  0%|          | 0/16 [00:00<?, ?it/s]

[3/10] Encoder Loss : 3.82183 Gen Loss : 0.10074     Disc Loss : 1.21287 Label Loss : 0.66905 
tensor([[ 1.2260e+00, -3.4591e-01,  6.1457e-01,  1.6742e+00,  7.4209e-02,
         -9.0248e-02],
        [-3.2370e+00,  2.6131e+00, -8.7674e+00,  3.9810e-01, -3.1995e+00,
         -9.9728e-01],
        [ 2.4139e+00, -1.5654e+00,  2.0846e+00,  2.6250e+00, -2.2547e-01,
         -6.2592e-01],
        [ 4.1034e-01, -2.4784e-01,  7.8447e-02,  1.0769e+00,  9.8816e-02,
         -4.4357e-01],
        [-3.4657e-01,  6.7364e-01, -6.8868e-01,  8.6376e-01, -2.2930e-01,
         -3.0249e-01],
        [ 1.9201e-01,  1.2923e-01, -4.8674e-01,  5.0805e-01, -3.9868e-02,
         -5.4490e-01],
        [ 9.0511e-01, -6.4210e-02,  4.9290e-03,  1.1197e+00, -4.6967e-01,
         -1.8813e-01],
        [ 4.0772e-01, -3.1842e-01,  1.5133e-01,  9.5001e-01, -7.4103e-02,
         -3.9527e-01],
        [ 8.7849e-01, -3.5071e-01, -1.2073e-01,  7.6993e-01,  4.3614e-01,
         -4.4700e-02],
        [-9.7516e-01,  3.3416e+0

  0%|          | 0/16 [00:00<?, ?it/s]

[4/10] Encoder Loss : 3.66187 Gen Loss : 0.16200     Disc Loss : 1.26980 Label Loss : 0.67667 
tensor([[ 1.1343,  0.2473,  1.5052, -2.5445, -0.9021, -1.6407],
        [ 1.4498, -0.3423,  2.0364,  1.8824, -0.0467, -0.4832],
        [ 0.7266,  1.1375, -0.5711,  0.6637, -0.0302, -0.6111],
        [ 1.4939, -0.1968,  0.7627,  1.1310, -0.0804, -0.2769],
        [-0.4338,  1.3162, -1.4035, -0.0771, -0.0714, -0.1919],
        [-0.5759,  0.5586, -0.2856,  0.1496,  0.4907,  0.1516],
        [-0.8280,  2.7111, -1.0338, -0.3785, -0.3046, -0.2859],
        [ 0.9642,  0.3174,  0.5393,  1.5493, -0.1775, -0.2519],
        [ 0.0807,  1.7158, -0.3138,  0.5120, -0.3187, -0.1735],
        [ 1.1582, -0.8346,  0.2641,  0.6549,  0.1486, -0.2787]],
       device='cuda:0') tensor([[1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 1.],
        [1., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 0., 0., 1.],
  

  0%|          | 0/16 [00:00<?, ?it/s]

[5/10] Encoder Loss : 3.43490 Gen Loss : 0.00792     Disc Loss : 1.36360 Label Loss : 0.65586 
tensor([[ 0.1244,  0.1841, -0.1281,  0.7603, -0.0240,  0.0345],
        [ 0.6855,  1.1112,  0.6382,  1.4382, -0.0507, -0.0919],
        [-0.5487,  1.8626, -0.9505, -0.2376, -0.0763,  0.0409],
        [-0.8343,  0.1555, -1.1085,  0.1765,  0.2572,  0.0307],
        [-0.8048,  1.3992, -1.2198, -0.4744,  0.0277, -0.2790],
        [-0.3278,  1.5428, -0.8090, -0.4564, -0.1221,  0.0921],
        [-1.2499,  1.1823, -1.0495, -0.1409,  0.1810, -0.1134],
        [ 0.7563, -0.5279, -0.0859,  0.8443,  0.1623, -0.0344],
        [-0.0298,  0.6776, -0.5226,  0.4274,  0.1678,  0.0706],
        [-0.6779,  1.6144, -1.0899,  0.5840,  0.0700, -0.1456]],
       device='cuda:0') tensor([[0., 1., 0., 1., 0., 0.],
        [1., 0., 1., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0.],
  

  0%|          | 0/16 [00:00<?, ?it/s]

[6/10] Encoder Loss : 3.35328 Gen Loss : 0.03361     Disc Loss : 1.34268 Label Loss : 0.63681 
tensor([[ 0.0344, -1.0509, -0.1440,  0.4456, -0.1637, -0.1624],
        [ 0.7107,  0.4812,  0.2131,  1.1211, -0.2657,  0.0074],
        [ 0.9250,  0.5570,  0.2263,  0.7357, -0.1033,  0.2349],
        [ 0.2855,  1.4499, -0.1534,  0.8045, -0.0870,  0.0089],
        [ 2.4006, -1.0347,  1.9122,  1.9323, -0.4362,  0.0255],
        [-1.0977,  0.3528, -1.3901,  0.1838,  0.1040,  0.3569],
        [ 2.4263, -0.7024,  0.9966,  1.8047,  0.4023, -0.4068],
        [ 1.8128,  0.5858,  1.0608,  1.5652,  0.0380, -0.1439],
        [-0.6907, -0.0497, -0.2076, -0.0599,  0.1037, -0.0228],
        [ 0.4735,  1.1661, -0.0740,  0.9254,  0.0884,  0.0302]],
       device='cuda:0') tensor([[0., 0., 1., 1., 0., 0.],
        [0., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [0., 1., 0., 0., 0., 1.],
        [1., 0., 1., 1., 0., 0.],
        [0., 0., 0., 1., 1., 1.],
        [1., 0., 1., 1., 0., 0.],
  

  0%|          | 0/16 [00:00<?, ?it/s]