In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.utils import data
from torchvision.utils import save_image

import sys
import random
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

# My own files
from gridifier import gridify
from data_generator import ArtificialDataset
from unet import Unet
from resunet import ResUnet

In [None]:

use_cuda = torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
# NOTE(review): benchmark=True lets cuDNN pick convolution algorithms by timing them, and
# that choice can differ between runs, which works against deterministic=True.
# Set benchmark to False if exact reproducibility matters more than speed.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Seed the global RNGs (Python, NumPy, PyTorch) before building the datasets/model so
# DataLoader shuffling and model initialisation are repeatable between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialising the data generator and model
labels_path = "data/train/nogrid/"
val_path = "data/val/nogrid/"

batch_size = 4
max_imgs = 10000

# Augmentation settings, forwarded to ArtificialDataset as keyword arguments.
augments = {"grid_size" : (30,90),
            "grid_intensity" : (-1,1),
            "grid_offset_x" : (0,90),
            "grid_offset_y" : (0,90),
            "crop" : (400,400),
            "hflip" : 0.5,
            "vflip" : 0.5,
            "angle" : 0,
            "shear" : 0,
            "brightness" : (0.75,1.25),
            "pad" : (0,0,0,0),
            "contrast" : (0.5,1.25)}

dataset = ArtificialDataset(max_imgs, labels_path, **augments)
# The fixed seed presumably makes validation samples repeatable across epochs — confirm in data_generator.
validation_set = ArtificialDataset(max_imgs, val_path, **augments, seed=1337)

training_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
# Validate on the held-out validation set, not on the training data.
validation_generator = data.DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=8)


#model = Unet (3,3).to(processor)
model = ResUnet (3,3).to(processor)

loss_function = nn.L1Loss().to(processor)
#loss_function = nn.MSELoss().to(processor)
optimiser = optim.Adam(model.parameters(), lr=1e-3)
# Multiplies the learning rate by 0.5 every 10 scheduler steps (stepped once per epoch in the training cell).
scheduler = optim.lr_scheduler.StepLR(optimiser, step_size=10, gamma=0.5)

In [None]:
# Initialising some variables for use in training
batches = float("inf")      # optional cap on batches per epoch used by the time-left estimate
time_diff = 0               # duration of the previous batch (becomes a timedelta after the first one)
no_improv = 0               # batches since the training loss last reached a new minimum
min_loss = float("inf")
val_loss = float("inf")
min_val_loss = float("inf")
epochs = 200
early_stop_epoch = 0        # 0-based index of the epoch with the lowest validation loss
loss_list = []              # training loss of every batch
val_loss_list = []          # mean validation loss of every epoch

# len() of a DataLoader is its number of batches, which is what the time estimate needs
# (len(dataset) counts samples, so using it over-estimated the current epoch by batch_size x).
batches_per_epoch = min(batches, len(training_generator))

for epoch in range(epochs):
    # Validation at the end of each epoch switches to eval mode, so switch back here.
    model.train()
    for i, (batch, labels) in enumerate(training_generator):

        # Keeping track of stuff.
        # NOTE: time_diff covers the forward/backward step only, not data loading.
        start_time = datetime.now()
        batches_left = (batches_per_epoch - i) + (epochs - (epoch + 1)) * batches_per_epoch
        est_time_left = str(time_diff * batches_left).split(".")[0]

        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Time left: {3}. Best: {4} batches ago. Val loss: {5:.5f}".format(epoch+1, i+1, min_loss, est_time_left, no_improv, val_loss))

        # Putting data on gpu
        batch = batch.to(processor)
        labels = labels.to(processor)

        optimiser.zero_grad()

        out = model(batch)
        loss = loss_function(out, labels)

        loss_list.append(loss.item())
        loss.backward()
        optimiser.step()

        if loss.item() < min_loss:
            min_loss = loss.item()
            no_improv = 0
        else:
            no_improv += 1

        # For tracking progress
        end_time = datetime.now()
        time_diff = end_time - start_time

    # Validation pass: no gradients are needed, so don't build the autograd graph.
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for batch, labels in validation_generator:
            batch = batch.to(processor)
            labels = labels.to(processor)

            out = model(batch)
            val_loss += loss_function(out, labels).item()
    val_loss = val_loss/len(validation_generator)

    if val_loss < min_val_loss:
        min_val_loss = val_loss
        early_stop_epoch = epoch
        # NOTE(review): nn.Module has no save(); presumably provided by ResUnet — confirm.
        model.save("early_stop_model.pth")
    val_loss_list.append(val_loss)
    scheduler.step()

model.save("temp_model.pth")
print ("\nFinished. Early stop was at:", early_stop_epoch)

In [None]:
# Plotting the loss through the epochs.
# x-positions are derived from the recorded lists rather than from `epoch`, so the plot
# still works if training was interrupted part-way through an epoch.
train_placings = np.arange(1, len(loss_list) + 1) / len(training_generator)  # epoch fraction after each batch
val_placings = np.arange(1, len(val_loss_list) + 1)  # one validation value at the end of each epoch

fig, ax = plt.subplots()
ax.plot(train_placings, loss_list, label="Training")
ax.plot(val_placings, val_loss_list, label="Validation")

ax.set_title("Loss during training")
ax.set_ylabel("Loss")
ax.set_xlabel("Epoch")
ax.set_yscale("log")
ax.legend()

fig.savefig("loss_log.png")