In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.utils import data
from torchvision.utils import save_image

import sys
import random
from datetime import datetime
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np

# My own files
from data_generator import CombinedDataset, ArtificialDataset
from discriminator import *
from resunet import ResUnet

In [None]:
# ---------------------------------------------------------------------------
# Device, data pipeline, networks, losses and optimisers for the training run.
# ---------------------------------------------------------------------------
use_cuda = torch.cuda.is_available() and True  # flip `True` to `False` to force CPU
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True
#torch.backends.cudnn.deterministic=True


# Initialising the data generator and model
grid_path = "data/train/grid/"
nogrid_path = "data/train/nogrid/"

batch_size = 1
max_imgs = 5000

# Augmentation settings for CombinedDataset (only used by the commented-out
# alternative below).
augments = {"crop" : (400,400),
            "hflip" : 0.5,
            "vflip" : 0.5,
            "angle" : 0,
            "shear" : 0,
            "brightness" : (0.75,1.25),
            "pad" : (0,0,0,0),
            "contrast" : (0.5,2.0)}

# Augmentation settings for ArtificialDataset: the same options as above plus
# the grid_* entries.
artificial_augments = {"grid_size" : (30,90),
                       "grid_intensity" : (-1,1),
                       "grid_offset_x" : (0,90),
                       "grid_offset_y" : (0,90),
                       "crop" : (400,400),
                       "hflip" : 0.5,
                       "vflip" : 0.5,
                       "angle" : 0,
                       "shear" : 0,
                       "brightness" : (0.75,1.25),
                       "pad" : (0,0,0,0),
                       "contrast" : (0.5,2.0)}

# Alternative: train on the combined grid/nogrid folders instead.
#dataset = CombinedDataset(max_imgs, grid_path, nogrid_path, **augments)
#validation_set = CombinedDataset(max_imgs, grid_path, nogrid_path, **augments, seed=1337)
dataset = ArtificialDataset(max_imgs, nogrid_path, **artificial_augments)
validation_set = ArtificialDataset(max_imgs, nogrid_path, **artificial_augments, seed=1337)

training_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
# Validation must iterate over `validation_set`; the original passed `dataset`
# here, so the "validation" loss was measured on the training data.
validation_generator = data.DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=8)

# One discriminator per image domain.
nogrid_disc = Discriminator (3).to(processor)
grid_disc = Discriminator (3).to(processor)

# The grid -> nogrid generator starts from a saved ResUnet checkpoint.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
loaded_params = torch.load("saved/resunet/clamp_final_l2.pth", map_location=processor)
nogrid_gen = ResUnet (**loaded_params["args_dict"]).to(processor)
nogrid_gen.load_state_dict(loaded_params["state_dict"])
#nogrid_gen = ResUnet (3, 3).to(processor)

# The nogrid -> grid generator is trained from scratch. The original first
# built a second copy of the checkpointed network and then overwrote it with
# this fresh one; only the fresh network was ever used, so the redundant
# construction and weight load were removed.
# NOTE(review): if grid_gen was meant to start from the checkpoint as well,
# replace this line with the two-line pattern used for nogrid_gen above.
grid_gen = ResUnet (3, 3).to(processor)

cycle_loss = nn.L1Loss().to(processor)
id_loss = nn.L1Loss().to(processor)
# NOTE(review): BCELoss expects probabilities in [0, 1]; this presumes the
# Discriminator ends in a sigmoid - confirm in discriminator.py. Alternatives
# that were tried: nn.MSELoss(), nn.BCEWithLogitsLoss().
adv_loss = nn.BCELoss().to(processor)


# A single optimiser updates both generators; each discriminator has its own.
gen_opt = optim.Adam (chain (nogrid_gen.parameters(), grid_gen.parameters()), lr=1e-4)
nogrid_disc_opt = optim.Adam (nogrid_disc.parameters(), lr=1e-4)
grid_disc_opt = optim.Adam (grid_disc.parameters(), lr=1e-4)


In [None]:
# Initialising some variables for use in training
batches = float("inf")
time_diff = 0
no_improv = 0
min_loss = float("inf")
val_loss = float("inf")
min_val_loss = float("inf")
epochs = 5
early_stop_epoch = 0

# Per-batch training histories.
grid_gen_losses = []
nogrid_gen_losses = []

real_grid_disc_losses = []
real_nogrid_disc_losses = []

fake_grid_disc_losses = []
fake_nogrid_disc_losses = []

grid_disc_losses = []
nogrid_disc_losses = []

# Per-epoch validation histories (means over the validation loader).
val_grid_gen_losses = []
val_nogrid_gen_losses = []

val_real_grid_disc_losses = []
val_real_nogrid_disc_losses = []

val_fake_grid_disc_losses = []
val_fake_nogrid_disc_losses = []

val_grid_disc_losses = []
val_nogrid_disc_losses = []

id_losses = []
gan_losses = []
cycle_losses = []

gen_losses = []
val_gen_losses = []

# Weights of the three generator loss terms.
id_loss_mult = 5
gan_loss_mult = 5
cycle_loss_mult = 1


def _generator_losses(real_grid, real_nogrid):
    """Compute all weighted generator loss terms for one batch.

    Uses the notebook-level networks (grid_gen, nogrid_gen, grid_disc,
    nogrid_disc), loss modules and loss multipliers.

    Returns a dict with the six weighted loss terms, their sum under "total",
    and the generated images under "fake_grid" / "fake_nogrid".
    """
    # Identity loss: a generator fed an image that is already in its target
    # domain should return it unchanged. The original used grid_gen for the
    # nogrid term as well, which penalised grid_gen for adding a grid at all.
    id_grid_loss = id_loss(grid_gen(real_grid), real_grid) * id_loss_mult
    id_nogrid_loss = id_loss(nogrid_gen(real_nogrid), real_nogrid) * id_loss_mult

    # GAN loss: each generator tries to make its discriminator answer "real".
    # Labels are shaped like the prediction so a smaller final batch works.
    fake_grid = grid_gen(real_nogrid)
    fake_grid_pred = grid_disc(fake_grid)
    gan_loss_grid = adv_loss(fake_grid_pred, torch.ones_like(fake_grid_pred)) * gan_loss_mult

    fake_nogrid = nogrid_gen(real_grid)
    fake_nogrid_pred = nogrid_disc(fake_nogrid)
    gan_loss_nogrid = adv_loss(fake_nogrid_pred, torch.ones_like(fake_nogrid_pred)) * gan_loss_mult

    # Cycle loss: translating there and back should reproduce the input.
    cycle_loss_grid = cycle_loss(grid_gen(fake_nogrid), real_grid) * cycle_loss_mult
    cycle_loss_nogrid = cycle_loss(nogrid_gen(fake_grid), real_nogrid) * cycle_loss_mult

    total = (id_grid_loss + id_nogrid_loss
             + gan_loss_grid + gan_loss_nogrid
             + cycle_loss_grid + cycle_loss_nogrid)

    return {"id_grid": id_grid_loss, "id_nogrid": id_nogrid_loss,
            "gan_grid": gan_loss_grid, "gan_nogrid": gan_loss_nogrid,
            "cycle_grid": cycle_loss_grid, "cycle_nogrid": cycle_loss_nogrid,
            "total": total,
            "fake_grid": fake_grid, "fake_nogrid": fake_nogrid}


def _discriminator_losses(disc, real, fake):
    """Return (real_loss, fake_loss, combined_loss) of `disc` for one batch.

    `fake` is detached so that no gradient reaches the generator that made it.
    The combined loss is the mean of the real and the fake term.
    """
    real_pred = disc(real)
    real_loss = adv_loss(real_pred, torch.ones_like(real_pred))

    fake_pred = disc(fake.detach())
    fake_loss = adv_loss(fake_pred, torch.zeros_like(fake_pred))

    return real_loss, fake_loss, (real_loss + fake_loss) / 2


networks = (nogrid_disc, nogrid_gen, grid_disc, grid_gen)
# Number of batches per epoch, used only for the remaining-time estimate.
batches_per_epoch = min(batches, len(training_generator))

for epoch in range(epochs):
    # Validation leaves the networks in eval mode, so switch back every epoch.
    for net in networks:
        net.train()

    for i, batch in enumerate(training_generator):

        # Keeping track of stuff
        start_time = datetime.now()
        batches_left = (batches_per_epoch - i) + (epochs - (epoch + 1)) * batches_per_epoch
        # time_diff is 0 before the first batch and a timedelta afterwards.
        est_time_left = str(time_diff * batches_left).split(".")[0]

        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Time left: {3}. Best: {4} batches ago. Val loss: {5:.5f}".format(epoch+1, i+1, min_loss, est_time_left, no_improv, val_loss))

        # Putting data on gpu
        real_grid = batch["grid"].to(processor)
        real_nogrid = batch["nogrid"].to(processor)

        # Training the generators
        gen_opt.zero_grad()
        g = _generator_losses(real_grid, real_nogrid)
        gen_loss = g["total"]
        gen_loss.backward()
        gen_opt.step()

        fake_grid = g["fake_grid"]
        fake_nogrid = g["fake_nogrid"]

        # Training the discriminators.
        # zero_grad also discards the gradients that the generator backward
        # pass above left on the discriminator parameters.

        # Grid discriminator
        grid_disc_opt.zero_grad()
        real_grid_disc_loss, fake_grid_disc_loss, grid_disc_loss = _discriminator_losses(grid_disc, real_grid, fake_grid)
        grid_disc_loss.backward()
        grid_disc_opt.step()

        # Nogrid discriminator (the original zeroed grid_disc_opt here, so the
        # nogrid discriminator's gradients were never reset).
        nogrid_disc_opt.zero_grad()
        real_nogrid_disc_loss, fake_nogrid_disc_loss, nogrid_disc_loss = _discriminator_losses(nogrid_disc, real_nogrid, fake_nogrid)
        nogrid_disc_loss.backward()
        nogrid_disc_opt.step()

        # Logging
        grid_gen_losses.append((g["id_grid"] + g["gan_grid"] + g["cycle_grid"]).item())
        nogrid_gen_losses.append((g["id_nogrid"] + g["gan_nogrid"] + g["cycle_nogrid"]).item())

        real_grid_disc_losses.append(real_grid_disc_loss.item())
        real_nogrid_disc_losses.append(real_nogrid_disc_loss.item())

        nogrid_disc_losses.append(nogrid_disc_loss.item())
        grid_disc_losses.append(grid_disc_loss.item())

        fake_grid_disc_losses.append(fake_grid_disc_loss.item())
        fake_nogrid_disc_losses.append(fake_nogrid_disc_loss.item())

        id_losses.append((g["id_grid"] + g["id_nogrid"]).item())
        gan_losses.append((g["gan_grid"] + g["gan_nogrid"]).item())
        cycle_losses.append((g["cycle_grid"] + g["cycle_nogrid"]).item())
        gen_losses.append(gen_loss.item())

        # Save one preview per epoch: inputs and translations concatenated
        # along dim 2.
        if i < 1:
            to_save = torch.cat((real_grid, fake_nogrid.detach(), real_nogrid, fake_grid.detach()), dim=2)
            save_image(to_save, "data/training_temp/test.png")

        if gen_loss.item() < min_loss:
            min_loss = gen_loss.item()
            no_improv = 0
        else:
            no_improv += 1

        # For tracking progress
        end_time = datetime.now()
        time_diff = end_time - start_time

    # ----------------------------- validation -----------------------------
    val_gen_loss = 0
    val_grid_gen_loss = 0
    val_nogrid_gen_loss = 0

    val_real_grid_disc_loss = 0
    val_real_nogrid_disc_loss = 0

    val_fake_grid_disc_loss = 0
    val_fake_nogrid_disc_loss = 0

    with torch.no_grad():
        for net in networks:
            net.eval()

        for i, batch in enumerate (validation_generator):
            # Putting data on gpu
            real_grid = batch["grid"].to(processor)
            real_nogrid = batch["nogrid"].to(processor)

            # Same weighted terms as in training, so both curves share a scale.
            g = _generator_losses(real_grid, real_nogrid)
            fake_grid = g["fake_grid"]
            fake_nogrid = g["fake_nogrid"]

            # Evaluate the discriminators on this validation batch (the
            # original re-added the losses of the last training batch).
            v_real_grid, v_fake_grid, _ = _discriminator_losses(grid_disc, real_grid, fake_grid)
            v_real_nogrid, v_fake_nogrid, _ = _discriminator_losses(nogrid_disc, real_nogrid, fake_nogrid)

            val_grid_gen_loss += (g["id_grid"] + g["gan_grid"] + g["cycle_grid"]).item()
            val_nogrid_gen_loss += (g["id_nogrid"] + g["gan_nogrid"] + g["cycle_nogrid"]).item()

            val_real_grid_disc_loss += v_real_grid.item()
            val_real_nogrid_disc_loss += v_real_nogrid.item()

            val_fake_grid_disc_loss += v_fake_grid.item()
            val_fake_nogrid_disc_loss += v_fake_nogrid.item()

            val_gen_loss += g["total"].item()

            # Keep the first ten validation samples of every epoch.
            if i < 10:
                to_save = torch.cat((real_grid, fake_nogrid, real_nogrid, fake_grid), dim=2)
                save_image(to_save, "data/training_temp/{}_{}.png".format(epoch+1, i))

    # Turn the sums into means; max() guards against an empty validation loader.
    n_val_batches = max(len(validation_generator), 1)

    val_gen_loss = val_gen_loss/n_val_batches
    val_grid_gen_loss = val_grid_gen_loss/n_val_batches
    val_nogrid_gen_loss = val_nogrid_gen_loss/n_val_batches

    val_real_grid_disc_loss = val_real_grid_disc_loss/n_val_batches
    val_real_nogrid_disc_loss = val_real_nogrid_disc_loss/n_val_batches

    val_fake_grid_disc_loss = val_fake_grid_disc_loss/n_val_batches
    val_fake_nogrid_disc_loss = val_fake_nogrid_disc_loss/n_val_batches

    # Shown in the progress line during the next epoch.
    val_loss = val_gen_loss

    # Early-stopping checkpoint, selected on the nogrid generator's validation loss.
    if val_nogrid_gen_loss < min_val_loss:
        min_val_loss = val_nogrid_gen_loss
        early_stop_epoch = epoch + 1
        grid_gen.save("early_stop_grid_gen.pth")
        nogrid_gen.save("early_stop_nogrid_gen.pth")
        grid_disc.save("early_stop_grid_disc.pth")
        nogrid_disc.save("early_stop_nogrid_disc.pth")

    val_gen_losses.append(val_gen_loss)
    val_grid_gen_losses.append(val_grid_gen_loss)
    val_nogrid_gen_losses.append(val_nogrid_gen_loss)

    val_real_grid_disc_losses.append(val_real_grid_disc_loss)
    val_real_nogrid_disc_losses.append(val_real_nogrid_disc_loss)

    val_fake_grid_disc_losses.append(val_fake_grid_disc_loss)
    val_fake_nogrid_disc_losses.append(val_fake_nogrid_disc_loss)

    # Combined (mean of real and fake) discriminator losses on validation.
    val_grid_disc_losses.append((val_real_grid_disc_loss + val_fake_grid_disc_loss) / 2)
    val_nogrid_disc_losses.append((val_real_nogrid_disc_loss + val_fake_nogrid_disc_loss) / 2)

grid_gen.save("final_grid_gen.pth")
nogrid_gen.save("final_nogrid_gen.pth")
grid_disc.save("final_grid_disc.pth")
nogrid_disc.save("final_nogrid_disc.pth")
print ("\nFinished. Early stop was at:", early_stop_epoch)

In [None]:
# Plotting the loss through the epochs
# Training losses are logged once per batch and spread evenly over the epochs
# that were run; validation losses are logged once per epoch.
train_placings = np.linspace(0, epoch + 1, len(gen_losses))
val_placings = np.arange(1, epoch + 2)


def _plot_loss_figure(gen_title, disc_title, out_file, train_curves, val_curves):
    """Draw one two-panel loss figure, save it to `out_file` and show it.

    The top panel holds the generator curves and the bottom panel the
    discriminator curves. `train_curves` and `val_curves` are lists of
    (panel_index, values, label) tuples drawn in the given order, training
    curves first.
    """
    fig, axes = plt.subplots(2, figsize=(16, 16))

    for panel, values, label in train_curves:
        axes[panel].plot(train_placings, values, label=label)
    for panel, values, label in val_curves:
        axes[panel].plot(val_placings, values, label=label)

    for axis, title in zip(axes, (gen_title, disc_title)):
        axis.set_title(title)
        axis.set_ylabel("Loss")
        axis.set_xlabel("Epoch")
        axis.legend()

    fig.savefig(out_file)
    fig.show()
    return fig, axes


grid_fig, grid_ax = _plot_loss_figure(
    "Grid generator loss during training",
    "Grid discriminator loss during training",
    "grid_loss_log.png",
    train_curves=[
        (0, grid_gen_losses, "Generator"),
        (1, real_grid_disc_losses, "Real discriminator"),
        (1, grid_disc_losses, "Combined"),
        (1, fake_grid_disc_losses, "Fake discriminator"),
    ],
    val_curves=[
        (0, val_grid_gen_losses, "Generator (validation)"),
        (1, val_real_grid_disc_losses, "Real discriminator (validation)"),
        (1, val_fake_grid_disc_losses, "Fake discriminator (validation)"),
    ],
)

nogrid_fig, nogrid_ax = _plot_loss_figure(
    "Nogrid generator loss during training",
    "Nogrid discriminator loss during training",
    "nogrid_loss_log.png",
    train_curves=[
        (0, nogrid_gen_losses, "Generator"),
        (1, real_nogrid_disc_losses, "Real discriminator"),
        (1, nogrid_disc_losses, "Combined"),
        (1, fake_nogrid_disc_losses, "Fake discriminator"),
    ],
    val_curves=[
        (0, val_nogrid_gen_losses, "Generator (validation)"),
        (1, val_real_nogrid_disc_losses, "Real discriminator (validation)"),
        (1, val_fake_nogrid_disc_losses, "Fake discriminator (validation)"),
    ],
)