In [0]:
#!/usr/bin/python

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.datasets as dsets

#import encoder as enc
import generator as gen
import discriminator as disc
import STL10GrayColor as STLGray
import utils as utls
import losses
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb

In [9]:

# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------
# Disabled alternative, kept for reference: a single STLGray dataset.
# transform = transforms.Compose([transforms.Resize(128)])
#
# stl10_trainset = STLGray.STL10GrayColor(root="./data",
#                                         split='train',
#                                         download=True,
#                                         transform=transform)

# Arguments shared by both STL10 views below (same root, same split).
_stl10_kwargs = dict(root="./data", download=True, split='train')

# Colour view: resize to 128, convert to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
stl10_dtset_c = dsets.STL10(transform=transform, **_stl10_kwargs)

# Grayscale view of the same split: identical resize, then a grayscale
# conversion and a single-channel normalisation.
# NOTE: `transform` is rebound here, so after this cell it refers to the
# grayscale pipeline (the colour dataset keeps its own reference).
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
stl10_dtset_g = dsets.STL10(transform=transform, **_stl10_kwargs)

# TODO: use the 'train+unlabeled' split instead of 'train'.

#########################
# Experiments still to try (TODO):
# - update the encoder and the generator at the same time
# - reduce the learning rate after n epochs
#

Files already downloaded and verified
Files already downloaded and verified


In [0]:
# Loader parameters
batch_size = 25
# z_dim = 256
params_loader = dict(
    batch_size=batch_size,
    # Shuffling stays off: the training loop zips the colour and grayscale
    # loaders, so both must walk the split in the same order for each
    # grayscale batch to line up with its colour batch.
    shuffle=False,
)

# One loader per view of the dataset, built with identical settings.
train_loader_c, train_loader_g = (
    DataLoader(dataset, **params_loader)
    for dataset in (stl10_dtset_c, stl10_dtset_g)
)


In [12]:
# Build both networks.
# NOTE(review): color_ch=3 / in_dim=3 presumably mean 3-channel (RGB)
# images on the generator output and discriminator input -- confirm
# against generator.py / discriminator.py.
netG = gen.GeneratorSeg(color_ch=3)
netD = disc.SADiscriminator(in_dim=3)

# Only the discriminator is re-initialised here.
# TODO: initialise the generator layers inside the generator class.
netD.apply(utls.xavier_init_weights)

# Move the weights to the selected device (Module.to returns the module).
netG = netG.to(device)
netD = netD.to(device)

# Hyper-parameters given in the original paper: the discriminator learns
# with a larger step size than the generator.
lr_g = 1e-4
lr_d = 4e-4

betas = (0., 0.9)

# One Adam optimizer per network so they can be stepped independently.
optimizer_g = Adam(netG.parameters(), betas=betas, lr=lr_g)
optimizer_d = Adam(netD.parameters(), betas=betas, lr=lr_d)

# Dump both architectures for inspection.
print(netG)
print(netD)


[Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)]
GeneratorSeg(
  (convert_bw_to_rgb): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(3, 3, kern

In [0]:
# Number of passes the training loop makes over the loaders.
n_epochs = 50
# Loss switch read by the training loop: False uses the hinge losses defined
# in this cell, True uses losses.dis_loss / losses.gen_loss instead.
wass_loss = False

def disc_hinge_loss(netD, real_data, fake_data):
    """Adversarial hinge loss for the discriminator.

    Computes ``mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))``.

    Args:
        netD: callable that maps a batch to discriminator scores.
        real_data: batch of real samples passed to ``netD``.
        fake_data: batch of generated samples passed to ``netD``.

    Returns:
        Scalar tensor holding the sum of the real and fake hinge terms.
    """
    # Score the real batch first, then the fake one.
    real_scores = netD(real_data)
    fake_scores = netD(fake_data)

    # Real scores are penalised while below +1, fake scores while above -1.
    real_term = torch.relu(1.0 - real_scores).mean()
    fake_term = torch.relu(1.0 + fake_scores).mean()

    return real_term + fake_term

def gen_hinge_loss(netD, fake_data):
    """Generator loss: the negated mean discriminator score on fakes.

    Args:
        netD: callable that maps a batch to discriminator scores.
        fake_data: batch of generated samples passed to ``netD``.

    Returns:
        Scalar tensor equal to ``-mean(D(fake_data))``.
    """
    fake_scores = netD(fake_data)
    return -fake_scores.mean()

# Adversarial training loop: for every batch, take one discriminator step and
# then one generator step, log both losses, and periodically save a grid of
# generated images.

# Number of batches per epoch, used only for the progress message. The colour
# and grayscale loaders are built from the same split with the same settings,
# so the colour loader's length is used for both.
n_batches = len(train_loader_c)

for epoch in range(n_epochs):
    print("epoch :", epoch)

    # Both loaders are iterated in lockstep (neither shuffles), so each
    # grayscale batch is paired with the colour batch of the same images.
    # The labels returned by the datasets are not used.
    # for idx, (img_g, img_c) in enumerate(train_loader):
    for idx, ((img_c, _), (img_g, _)) in enumerate(zip(train_loader_c, train_loader_g)):
        # The last batch may be smaller than batch_size; skip it before
        # paying for the host-to-device copy.
        if img_g.shape[0] != batch_size:
            continue

        img_g = img_g.to(device)
        img_c = img_c.to(device)

        #######################
        # Train Discriminator #
        #######################

        # Create fake colors from the grayscale batch. The same `fakes`
        # tensor is reused for the generator step below.
        fakes = netG(img_g)

        # detach() keeps the discriminator loss from back-propagating into
        # the generator.
        if wass_loss:
            d_loss = losses.dis_loss(netD, img_c, fakes.detach())
        else:
            d_loss = disc_hinge_loss(netD, img_c, fakes.detach())

        # Python float copy of the loss, kept for logging after `del`.
        m_d_loss = d_loss.item()

        # Backward and optimize
        netD.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # Release the gpu memory
        del d_loss

        ###################
        # Train Generator #
        ###################

        # Not detached here: gradients flow through netD back into netG.
        # netD also accumulates gradients during this backward pass; they are
        # cleared by netD.zero_grad() at the next discriminator step.
        if wass_loss:
            g_loss = losses.gen_loss(netD, fakes)
        else:
            g_loss = gen_hinge_loss(netD, fakes)

        # Backward and optimize
        netG.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        m_g_loss = g_loss.item()

        # NOTE: the values logged as "d_out_real" / "g_out_fake" are the
        # discriminator and generator *losses* for this batch.
        print(f"Epoch [{epoch}/{n_epochs}], "
              f"iter[{idx}/{n_batches}], "
              f"d_out_real: {m_d_loss}, "
              f"g_out_fake: {m_g_loss}")

        # Every 100 iterations, save the current generated batch as an image.
        if idx % 100 == 0:

            # grayscale = torch.squeeze(img_g.detach())
            # img_display = utls.convert_lab2rgb(grayscale,
            #                                    fakes.detach())
            vutils.save_image(fakes.detach(),
                              f'./_{epoch}_epoch_{idx}.png',
                              normalize=True)

        # Release the gpu memory
        del fakes, g_loss

        torch.cuda.empty_cache()

epoch : 0
Epoch [0/50], iter[0/200], d_out_real: 1.9555391073226929, g_out_fake: -0.3093000650405884
Epoch [0/50], iter[1/200], d_out_real: 1.3869214057922363, g_out_fake: 0.08069433271884918
Epoch [0/50], iter[2/200], d_out_real: 1.2135690450668335, g_out_fake: -1.0384349822998047
Epoch [0/50], iter[3/200], d_out_real: 2.369509220123291, g_out_fake: 0.19177043437957764
Epoch [0/50], iter[4/200], d_out_real: 1.0079423189163208, g_out_fake: 0.7593481540679932
Epoch [0/50], iter[5/200], d_out_real: 0.4191974997520447, g_out_fake: 5.093146324157715
Epoch [0/50], iter[6/200], d_out_real: 0.444720059633255, g_out_fake: 0.8791082501411438
Epoch [0/50], iter[7/200], d_out_real: 0.4378979504108429, g_out_fake: 4.45481538772583
Epoch [0/50], iter[8/200], d_out_real: 0.037082016468048096, g_out_fake: 2.01496958732605
Epoch [0/50], iter[9/200], d_out_real: 0.0, g_out_fake: 1.5156484842300415
Epoch [0/50], iter[10/200], d_out_real: 0.03327735885977745, g_out_fake: 2.298039674758911
Epoch [0/50], i