In [0]:
#!/usr/bin/python

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader

import encoder as enc
import generator as gen
import discriminator as disc
import STL10GrayColor as STLGray
import utils as utls
import losses
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb

In [2]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#data
transform = transforms.Compose([transforms.Resize(128)])#,
#                                transforms.ToTensor()])

# Load STL10 dataset
stl10_trainset = STLGray.STL10GrayColor(root="./data",
                              split='train',
                              download=True,
                              transform=transform)

#TODO
#train+unlabeled in split

Files already downloaded and verified


In [0]:
# Parameters
batch_size = 32
z_dim = 128
params_loader = {'batch_size': batch_size,
               'shuffle': False}

train_loader = DataLoader(stl10_trainset, **params_loader)


In [0]:
#demultiplier = dem.Demultiplier()
#demultiplier = demultiplier.to(device)

encoder = enc.Encoder(z_dim=z_dim)
encoder = encoder.to(device)

generator = gen.Generator(z_dim=z_dim)
generator.apply(utls.weights_init)
generator = generator.to(device)

discriminator = disc.Discriminator()
discriminator.apply(utls.weights_init)
discriminator = discriminator.to(device)

optimizer_params = {'lr': 0.0001,
                    'betas':(0.5, 0.999)}

enc_loss = nn.MSELoss()

#optimizer_m = torch.optim.Adam(demultiplier.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_e = torch.optim.Adam(encoder.parameters(), **optimizer_params)
optimizer_g = torch.optim.Adam(generator.parameters(), **optimizer_params)
optimizer_d = torch.optim.Adam(discriminator.parameters(), **optimizer_params)


In [9]:
print(encoder)
print(generator)
print(discriminator)

Encoder(
  (conv_1_to_3): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (vgg): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inpla

In [0]:
n_epochs = 10


for epoch in range(n_epochs):
    print("epoch :", epoch)
  
    for i, (img_g, img_c) in enumerate(train_loader):
        
        img_g = img_g.to(device)
        img_c = img_c.to(device)

        bs, *_ = img_g.shape
        if bs != batch_size:
            continue


        #######################
        # Train Discriminator #
        #######################
        img_features = encoder(img_g)

        img_colorized = generator(img_features.detach())

        loss_d = losses.dis_loss(discriminator, img_c, img_colorized.detach())

        #bp
        discriminator.zero_grad()
        loss_d.backward()
        optimizer_d.step()
        
        #######################
        # Train Generator #
        #######################
        
        #img_colorized = generator(img_features) #re attach ?
        
        loss_g = losses.gen_loss(discriminator, img_colorized)
        
        #bp
        generator.zero_grad()     
        loss_g.backward()
        optimizer_g.step()
        
        #######################
        # Train Encoder #
        #######################
        
        #TODO BETTER WAY/optimizing img_colorized without detach
        #img_features = encoder(img_g)

        img_colorized = generator(img_features)
        
        loss_e = enc_loss(img_colorized, img_c)
        
        print("loss encoder :", loss_e.item())
        #bp
        encoder.zero_grad()
        loss_e.backward()
        optimizer_e.step()
        
        #printing shit
        if (i%1 == 0) :
            print("iteration ", i, "out of ", len(train_loader.dataset)//batch_size,
                  "\terrD : ", round(loss_d.item(),3), "\terrG : ", round(loss_g.item(),3))
        
        
        if i%100 == 0:
            img_display = utls.convert_lab2rgb(img_g, img_colorized.detach())
            
            vutils.save_image(img_display,
                              f"___epoch_{epoch}.png",
                              nrow=5,
                              normalize=True)
            print(">plotted shit")        
        
    
    

epoch : 0
loss encoder : 240.33721923828125
iteration  0 out of  156 	errD :  -2715.789 	errG :  2469.628
>plotted shit
loss encoder : 132.7208709716797
iteration  1 out of  156 	errD :  -2719.08 	errG :  2471.387
loss encoder : 170.93991088867188
iteration  2 out of  156 	errD :  -2720.473 	errG :  2473.179
loss encoder : 177.8868865966797
iteration  3 out of  156 	errD :  -2722.609 	errG :  2474.993
loss encoder : 156.9547576904297
iteration  4 out of  156 	errD :  -2724.297 	errG :  2476.717
loss encoder : 206.63722229003906
iteration  5 out of  156 	errD :  -2725.703 	errG :  2478.5
loss encoder : 180.01422119140625
iteration  6 out of  156 	errD :  -2727.545 	errG :  2480.249
loss encoder : 260.9034423828125
iteration  7 out of  156 	errD :  -2728.448 	errG :  2482.039
loss encoder : 160.64605712890625
iteration  8 out of  156 	errD :  -2731.073 	errG :  2483.851
loss encoder : 197.50213623046875
iteration  9 out of  156 	errD :  -2733.127 	errG :  2485.646
loss encoder : 156.0454

In [0]:
fig, axs = plt.subplots(2, figsize=(10,10))
fig.subplots_adjust(hspace=0.3)


axs[0].set_title("All Losses")
axs[0].set_xlabel("iterations")
axs[0].set_ylabel("Loss")
axs[0].plot(G_losses,label="G")
axs[0].plot(D_losses,label="D")
axs[0].legend()

axs[1].set_title("After 1000 iterations")
axs[1].set_xlabel("iterations")
axs[1].set_ylabel("Loss")
axs[1].plot(G_losses[1000:],label="G")
axs[1].plot(D_losses[1000:],label="D")
axs[1].legend()