In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.utils import data
from torchvision.utils import save_image

import sys
import random
from datetime import datetime
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np

# My own files
from data_generator import CombinedDataset
from discriminator import Discriminator
from resunet import ResUnet

In [None]:
# ---------------------------------------------------------------------------
# Device selection
# ---------------------------------------------------------------------------
use_cuda = torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")

# benchmark=True lets cuDNN auto-tune its convolution algorithms, while
# deterministic=True restricts it to deterministic ones. Both flags are kept
# exactly as the notebook originally set them.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
grid_path = "data/train/grid/"
nogrid_path = "data/train/nogrid/"

batch_size = 1
max_imgs = 5000

# Keyword arguments forwarded verbatim to CombinedDataset; their precise
# meaning is defined in data_generator.py (not visible from this notebook).
augments = {"crop" : (400,400),
            "hflip" : 0.5,
            "vflip" : 0.5,
            "angle" : 0,
            "shear" : 0,
            "brightness" : (0.75,1.25),
            "pad" : (0,0,0,0),
            "contrast" : (0.5,2.0)}

dataset = CombinedDataset(max_imgs, grid_path, nogrid_path, **augments)
# NOTE(review): the validation set reads the same directories as the training
# set and differs only by the fixed seed -- confirm this is intended.
validation_set = CombinedDataset(max_imgs, grid_path, nogrid_path, **augments, seed=1337)

training_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
# Validation must iterate over `validation_set`; the original passed `dataset`
# here, which left `validation_set` unused.
validation_generator = data.DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=8)


# ---------------------------------------------------------------------------
# Networks
# ---------------------------------------------------------------------------
# One discriminator per image domain. The argument 3 is presumably the number
# of input channels -- TODO confirm against discriminator.py.
nogrid_disc = Discriminator(3).to(processor)
grid_disc = Discriminator(3).to(processor)

# The grid -> nogrid generator starts from a saved checkpoint. map_location
# makes the load work on CPU-only machines even if the file was written from a
# GPU. torch.load unpickles the file, so only load checkpoints you trust.
loaded_params = torch.load("res_200_epoch_early.pth", map_location=processor)
nogrid_gen = ResUnet(**loaded_params["args_dict"]).to(processor)
nogrid_gen.load_state_dict(loaded_params["state_dict"])
# The nogrid -> grid generator is trained from scratch.
grid_gen = ResUnet(3, 3).to(processor)


# ---------------------------------------------------------------------------
# Losses and optimisers
# ---------------------------------------------------------------------------
cycle_loss = nn.L1Loss().to(processor)
id_loss = nn.L1Loss().to(processor)
# Least-squares adversarial objective (MSE against 1/0 targets) rather than BCE.
adv_loss = nn.MSELoss().to(processor)

# Both generators share a single optimiser; each discriminator has its own.
# No learning-rate scheduler is attached to any of them.
gen_opt = optim.Adam(chain(nogrid_gen.parameters(), grid_gen.parameters()), lr=1e-3)
nogrid_disc_opt = optim.Adam(nogrid_disc.parameters(), lr=1e-3)
grid_disc_opt = optim.Adam(grid_disc.parameters(), lr=1e-3)


In [None]:
# Bookkeeping for the training run
batches = float("inf")      # optional cap on batches per epoch, used only by the time estimate
time_diff = 0               # duration of the previous batch (a timedelta after the first batch)
no_improv = 0               # batches since the lowest generator loss was seen
min_loss = float("inf")
val_loss = float("inf")
min_val_loss = float("inf")
epochs = 50
early_stop_epoch = 0        # epoch index with the lowest validation loss so far

gen_losses = []
disc_losses = []

id_losses = []
gan_losses = []
cycle_losses = []

val_loss_list = []

networks = (nogrid_disc, nogrid_gen, grid_disc, grid_gen)


def generator_losses(real_grid, real_nogrid):
    """Compute the generator objective terms for one batch.

    Uses the notebook-level networks (`grid_gen`, `nogrid_gen`, `grid_disc`,
    `nogrid_disc`) and loss modules (`id_loss`, `adv_loss`, `cycle_loss`).

    Args:
        real_grid: batch of images from the "grid" domain, already on `processor`.
        real_nogrid: batch of images from the "nogrid" domain, already on `processor`.

    Returns:
        Tuple `(identity, adversarial, cycle, fake_grid, fake_nogrid)` where the
        first three are scalar loss tensors (each the sum over both directions)
        and the last two are the generated images, returned so the
        discriminators can be trained on them without a second forward pass.
    """
    # Identity: a generator fed an image that is already in its *output* domain
    # should return it unchanged. The original used grid_gen for both terms,
    # which fought the adversarial term below and never constrained nogrid_gen.
    id_grid_loss = id_loss(grid_gen(real_grid), real_grid)
    id_nogrid_loss = id_loss(nogrid_gen(real_nogrid), real_nogrid)

    # Adversarial: each generator tries to make its discriminator output "real".
    # Targets are built from the prediction so their shape always matches it,
    # including on a smaller final batch.
    fake_grid = grid_gen(real_nogrid)
    fake_grid_pred = grid_disc(fake_grid)
    gan_loss_grid = adv_loss(fake_grid_pred, torch.ones_like(fake_grid_pred))

    fake_nogrid = nogrid_gen(real_grid)
    fake_nogrid_pred = nogrid_disc(fake_nogrid)
    gan_loss_nogrid = adv_loss(fake_nogrid_pred, torch.ones_like(fake_nogrid_pred))

    # Cycle consistency: translating there and back should recover the input.
    cycle_loss_grid = cycle_loss(grid_gen(fake_nogrid), real_grid)
    cycle_loss_nogrid = cycle_loss(nogrid_gen(fake_grid), real_nogrid)

    return (id_grid_loss + id_nogrid_loss,
            gan_loss_grid + gan_loss_nogrid,
            cycle_loss_grid + cycle_loss_nogrid,
            fake_grid,
            fake_nogrid)


def discriminator_loss(disc, real_images, fake_images):
    """Return the adversarial loss of `disc` on one real and one generated batch.

    `fake_images` is detached so no gradient flows back into the generator.
    The mean of the two terms is returned; halving only rescales the
    discriminator gradient and is kept from the original code.
    """
    real_pred = disc(real_images)
    real_term = adv_loss(real_pred, torch.ones_like(real_pred))

    fake_pred = disc(fake_images.detach())
    fake_term = adv_loss(fake_pred, torch.zeros_like(fake_pred))

    return (real_term + fake_term) / 2


# Number of optimisation steps per epoch, counted in batches (not samples).
batches_per_epoch = min(batches, len(training_generator))

for epoch in range(epochs):
    # Validation at the end of the previous epoch switched everything to eval().
    for net in networks:
        net.train()

    for i, batch in enumerate(training_generator):

        # Progress report, based on how long the previous batch took
        start_time = datetime.now()
        remaining_batches = (batches_per_epoch - i) + (epochs - (epoch + 1)) * batches_per_epoch
        est_time_left = str(time_diff * remaining_batches).split(".")[0]

        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Time left: {3}. Best: {4} batches ago. Val loss: {5:.5f}".format(epoch+1, i+1, min_loss, est_time_left, no_improv, val_loss))

        # Putting data on gpu
        real_grid = batch["grid"].to(processor)
        real_nogrid = batch["nogrid"].to(processor)

        # ---- Generators ----
        gen_opt.zero_grad()
        id_term, gan_term, cycle_term, fake_grid, fake_nogrid = generator_losses(real_grid, real_nogrid)
        gen_loss = id_term + gan_term + cycle_term
        gen_loss.backward()
        gen_opt.step()

        # ---- Grid discriminator ----
        # zero_grad also discards the gradients that gen_loss.backward() left
        # on the discriminator parameters.
        grid_disc_opt.zero_grad()
        grid_disc_loss = discriminator_loss(grid_disc, real_grid, fake_grid)
        grid_disc_loss.backward()
        grid_disc_opt.step()

        # ---- Nogrid discriminator ----
        # The original zeroed grid_disc_opt here, so this discriminator stepped
        # on gradients accumulated from the generator pass and every earlier batch.
        nogrid_disc_opt.zero_grad()
        nogrid_disc_loss = discriminator_loss(nogrid_disc, real_nogrid, fake_nogrid)
        nogrid_disc_loss.backward()
        nogrid_disc_opt.step()

        # ---- Logging ----
        gen_loss_value = gen_loss.item()
        gen_losses.append(gen_loss_value)
        disc_losses.append((grid_disc_loss + nogrid_disc_loss).item())

        id_losses.append(id_term.item())
        gan_losses.append(gan_term.item())
        cycle_losses.append(cycle_term.item())

        if gen_loss_value < min_loss:
            min_loss = gen_loss_value
            no_improv = 0
        else:
            no_improv += 1

        # For tracking progress
        end_time = datetime.now()
        time_diff = end_time - start_time

    # ---- Validation: mean generator loss over the validation loader ----
    val_loss = 0
    for net in networks:
        net.eval()

    # No gradients are needed here; skipping graph construction saves memory.
    with torch.no_grad():
        for batch in validation_generator:
            # Putting data on gpu
            real_grid = batch["grid"].to(processor)
            real_nogrid = batch["nogrid"].to(processor)

            id_term, gan_term, cycle_term, _, _ = generator_losses(real_grid, real_nogrid)
            val_loss += (id_term + gan_term + cycle_term).item()

    val_loss = val_loss/len(validation_generator)

    # Checkpoint all four networks whenever validation improves.
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        early_stop_epoch = epoch
        grid_gen.save("early_stop_grid_gen.pth")
        nogrid_gen.save("early_stop_nogrid_gen.pth")
        grid_disc.save("early_stop_grid_disc.pth")
        nogrid_disc.save("early_stop_nogrid_disc.pth")

    val_loss_list.append(val_loss)

# Weights after the last epoch, regardless of validation performance.
grid_gen.save("final_grid_gen.pth")
nogrid_gen.save("final_nogrid_gen.pth")
grid_disc.save("final_grid_disc.pth")
nogrid_disc.save("final_nogrid_disc.pth")
print ("\nFinished. Early stop was at:", early_stop_epoch)

In [None]:
# Generator loss curves: per-batch training loss against per-epoch validation loss.
# `epoch` is the last value of the training loop variable, so epoch + 1 is the
# number of epochs that were run.
epochs_run = epoch + 1

# Training losses are recorded once per batch; spread them evenly over [0, epochs_run].
train_placings = np.linspace(0, epochs_run, len(gen_losses))
# Validation losses are recorded once per epoch, at the end of epochs 1..epochs_run.
val_placings = np.arange(1, epochs_run + 1)

# Draw on the current axes (created on demand), exactly as pyplot would.
ax = plt.gca()
fig = ax.figure

ax.plot(train_placings, gen_losses, label="Training")
ax.plot(val_placings, val_loss_list, label="Validation")

ax.set_title("Generator Loss during training")
ax.set_ylabel("Loss")
ax.set_xlabel("Epoch")
ax.set_yscale("log")
ax.legend()

fig.savefig("gan_loss_log.png")