In [None]:
import os
import random
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Conv2d
from torch.nn import Dropout
from torch.nn import MaxPool2d
from torch.nn import Upsample
from torch.utils import data
from torchvision.utils import save_image

# My own files
from gridifier import gridify
from data_generator import Dataset

In [None]:
class Unet(nn.Module):
    """Four-level U-Net that predicts a bounded residual correction of its input.

    Encoder: at each level two 3x3 convolutions (reflect padding, ELU), then
    2x2 max-pooling. Dropout is applied to the outputs of levels 3 and 4.
    Decoder: at each level the coarser map is bilinearly resized to the size
    of the matching encoder map, projected with a 1x1 convolution (ELU),
    concatenated after the encoder map (skip connection) and passed through
    two 3x3 convolutions (ELU).
    Output: ``elu(input + tanh(conv1x1(decoded)))``.

    Args:
        in_channels: channels of the input image. The output has the same
            number of channels because of the residual addition.
        out_channels: stored for ``save`` and used only by the
            ``remover_conv*`` layers, which ``forward`` does not call.
        depth: widths of the four encoder levels (copied, never aliased).
        dropout: probability for the shared ``Dropout`` layer.
    """

    def __init__(self, in_channels, out_channels, depth=[8,16,32,64], dropout=0.5):
        super(Unet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Copy so neither the shared default list nor a caller's list is
        # aliased (and mutated) through ``self.depth``.
        self.depth = list(depth)
        self.dropout = dropout
        d = self.depth

        def conv3x3(c_in, c_out):
            """3x3 convolution that keeps H and W (reflect padding of 1)."""
            return Conv2d(c_in, c_out, (3, 3), padding=(1, 1), padding_mode="reflect")

        def conv1x1(c_in, c_out):
            """Pointwise convolution (no padding)."""
            return Conv2d(c_in, c_out, (1, 1), padding=(0, 0), padding_mode="reflect")

        # NOTE: modules are registered in a fixed order so parameter order
        # (optimiser state, seeded initialisation) stays stable.
        self.pool = MaxPool2d((2, 2))

        # Convolutions on the way down
        self.convd1_0 = conv3x3(in_channels, d[0])
        self.convd1_1 = conv3x3(d[0], d[0])
        self.convd2_0 = conv3x3(d[0], d[1])
        self.convd2_1 = conv3x3(d[1], d[1])
        self.convd3_0 = conv3x3(d[1], d[2])
        self.convd3_1 = conv3x3(d[2], d[2])
        self.convd4_0 = conv3x3(d[2], d[3])
        self.convd4_1 = conv3x3(d[3], d[3])

        # Convolutions on the way up. Their input is [skip, projected], i.e.
        # twice the width of the level (equals d[3], d[2], d[1] respectively
        # whenever the widths double, as with the default depth).
        self.convu1_0 = conv3x3(2 * d[2], d[2])
        self.convu1_1 = conv3x3(d[2], d[2])
        self.convu2_0 = conv3x3(2 * d[1], d[1])
        self.convu2_1 = conv3x3(d[1], d[1])

        # 1x1 projections applied to each upsampled map before concatenation
        self.up_conv1 = conv1x1(d[3], d[2])
        self.up_conv2 = conv1x1(d[2], d[1])
        self.up_conv3 = conv1x1(d[1], d[0])

        # "Horizontal" convolutions at full resolution
        self.convu3_0 = conv3x3(2 * d[0], d[0])
        self.convu3_1 = conv3x3(d[0], d[0])
        # Must output in_channels so the residual ``input + correction`` is well defined.
        self.convu3_2 = Conv2d(d[0], in_channels, (1, 1), padding_mode="reflect")

        # Not used by forward(); kept so previously saved state_dicts still load strictly.
        self.remover_conv1 = conv3x3(in_channels + 1, in_channels + 1)
        self.remover_conv2 = conv3x3(in_channels + 1, out_channels)

        # Not used by forward(), which resizes to each skip connection's exact
        # size instead; kept only for attribute compatibility (no parameters).
        self.up_samp1 = Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=False)

        self.drop = Dropout(dropout)

    @staticmethod
    def _upsample_to(x, reference):
        """Bilinearly resize ``x`` to the spatial size (H, W) of ``reference``.

        Matching the skip connection's exact size (instead of a fixed x2 scale)
        keeps the concatenation valid when H or W is not divisible by 8. When
        the reference is exactly twice as large it equals x2 bilinear upsampling.
        """
        return F.interpolate(x, size=reference.shape[-2:], mode='bilinear', align_corners=False)

    def _up_block(self, x, skip, project, conv_a, conv_b):
        """One decoder level: resize ``x`` to ``skip``, project, concat after ``skip``, two convs."""
        projected = F.elu(project(self._upsample_to(x, skip)))
        merged = torch.cat((skip, projected), dim=1)
        return F.elu(conv_b(F.elu(conv_a(merged))))

    def forward(self, input_data):
        """Return ``elu(input_data + tanh(correction))``, same shape as the input.

        Assumes a 4-D ``(N, C, H, W)`` tensor with ``C == in_channels`` (the
        ``dim=1`` concatenations and the residual add rely on it). H and W must
        be at least 16 so the reflect padding at the deepest (1/8) level is valid.
        """
        # Encoder
        enc1 = F.elu(self.convd1_1(F.elu(self.convd1_0(input_data))))
        enc2 = F.elu(self.convd2_1(F.elu(self.convd2_0(self.pool(enc1)))))
        enc3 = F.elu(self.convd3_1(F.elu(self.convd3_0(self.pool(enc2)))))
        # Dropout only on the path going deeper; the skip connection keeps enc3 intact.
        bottom = F.elu(self.convd4_1(F.elu(self.convd4_0(self.pool(self.drop(enc3))))))
        bottom = self.drop(bottom)

        # Decoder
        dec1 = self._up_block(bottom, enc3, self.up_conv1, self.convu1_0, self.convu1_1)
        dec2 = self._up_block(dec1, enc2, self.up_conv2, self.convu2_0, self.convu2_1)
        dec3 = self._up_block(dec2, enc1, self.up_conv3, self.convu3_0, self.convu3_1)

        # Bounded residual correction added to the input
        correction = torch.tanh(self.convu3_2(dec3))
        return F.elu(input_data + correction)

    def save(self, filename):
        """Save weights plus the constructor arguments needed to rebuild the model.

        Writes ``{"state_dict": ..., "args_dict": {...}}`` with ``torch.save``.
        To rebuild, use ``Unet(**ckpt["args_dict"])`` followed by ``load_state_dict``.
        """
        args_dict = {
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "depth": self.depth,
            "dropout": self.dropout
        }
        torch.save({
            "state_dict": self.state_dict(),
            "args_dict": args_dict
        }, filename)

In [None]:
class Degridifier_model(nn.Module):
    """Plain CNN baseline: three 5x5 conv+ReLU layers, then a 1x1 conv with sigmoid.

    Every convolution keeps H and W (5x5 kernels with padding 2), and the
    output has 3 channels in (0, 1).

    Args:
        conv_depth: channel widths of the three 5x5 convolution layers.
        processor: optional device the caller intends to use. It is only stored
            as ``self.processor``; the constructor does not move the model.
    """

    def __init__(self, conv_depth=[5,7,5], processor=None):
        super(Degridifier_model, self).__init__()

        # Passed explicitly instead of read from a notebook global that may not exist yet.
        self.processor = processor
        # 'same' is not a valid Conv2d padding_mode (zeros/reflect/replicate/circular);
        # padding=(2,2) with a 5x5 kernel already keeps the size. Reflect matches Unet.
        self.convd1 = Conv2d(3, conv_depth[0], (5,5), padding=(2,2), padding_mode='reflect')
        self.convd2 = Conv2d(conv_depth[0], conv_depth[1], (5,5), padding=(2,2), padding_mode='reflect')
        self.convd3 = Conv2d(conv_depth[1], conv_depth[2], (5,5), padding=(2,2), padding_mode='reflect')
        self.convd4 = Conv2d(conv_depth[2], 3, (1,1))

        # NOTE: not used in forward().
        self.drop = Dropout()

    def forward(self, input_data):
        """Map a 3-channel image batch to a same-sized 3-channel output in (0, 1)."""
        conv1 = F.relu(self.convd1(input_data))
        conv2 = F.relu(self.convd2(conv1))
        conv3 = F.relu(self.convd3(conv2))
        conv4 = torch.sigmoid(self.convd4(conv3))

        return conv4
                       

In [None]:

# Device: first GPU when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
# Let cuDNN benchmark conv algorithms (trades bitwise determinism for speed).
torch.backends.cudnn.benchmark = True

# Seed the RNGs this notebook imports so weight init and random augmentations repeat.
# NOTE(review): if data_generator.Dataset uses numpy's RNG, seed that too.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Data generator
labels_path = "data/nogrid/"
conv_depth = [5,7,5]  # only used by the alternative Degridifier_model baseline below
batch_size = 8
max_imgs = 10000

augments = {"grid_size" : (30,30),
            "grid_intensity" : (-1,1),
            "grid_offset_x" : (0,90),
            "grid_offset_y" : (0,90),
            "crop" : (400,400),
            "hflip" : 0.5,
            "vflip" : 0.5,
            "angle" : 0,
            "shear" : 0,
            "brightness" : (0.75,1.25),
            "pad" : (0,0,0,0),
            "contrast" : (0.5,1)}

dataset = Dataset(max_imgs, labels_path, **augments)
training_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Model, loss and optimiser (swap in the baseline with: Degridifier_model(conv_depth).to(processor))
model = Unet(3, 3).to(processor)
loss_function = nn.MSELoss().to(processor)
optimiser = optim.Adam(model.parameters(), lr=1e-2)
#dataset.test()

In [None]:
# Training bookkeeping
batches = float("inf")   # optional cap on batches per epoch, used only by the time estimate
time_diff = 0            # duration of the previous batch (a timedelta after the first batch)
no_improv = 0            # batches since the last new minimum loss
min_loss = float("inf")
epochs = 50
print_stuff = False
loss_list = []           # per-batch training loss, plotted in the next cell

# Snapshots (prediction, target, input) are written here every batch; relative to
# the notebook directory, like labels_path.
SAMPLE_DIR = "data/artificialgrid/"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Batches per epoch (not samples) -- the unit the time estimate is measured in.
batches_per_epoch = min(batches, len(training_generator))

model.train()
for epoch in range(epochs):
    for i, (batch, labels) in enumerate(training_generator):
        start_time = datetime.now()

        # Rest of this epoch plus all remaining epochs, at the last batch's duration.
        batches_left = (batches_per_epoch - i) + (epochs - (epoch + 1)) * batches_per_epoch
        est_time_left = str(time_diff * batches_left).split(".")[0]

        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Estimated time left: {3}. Best: {4} batches ago.".format(epoch+1, i+1, min_loss, est_time_left, no_improv))
        sys.stdout.flush()

        batch = batch.to(processor)
        labels = labels.to(processor)

        optimiser.zero_grad()
        out = model(batch)
        loss = loss_function(out, labels)
        loss.backward()
        optimiser.step()

        loss_value = loss.item()
        loss_list.append(loss_value)
        if loss_value < min_loss:
            min_loss = loss_value
            no_improv = 0
        else:
            no_improv += 1

        time_diff = datetime.now() - start_time

        # Stack along the batch dim without squeezing (squeeze(0) would drop the batch dim
        # when batch_size == 1 and break the cat); detach to keep it out of autograd.
        to_save = torch.cat((out.detach(), labels, batch), dim=0)
        save_image(to_save, os.path.join(SAMPLE_DIR, "{0}.png".format(i)))

model.save("temp_model.pth")
print("\nFinished")

In [None]:

# Per-batch training loss with a log-scaled y axis; the figure is also written to disk.
fig, ax = plt.subplots()
ax.plot(loss_list)
ax.set_title("Training loss by batch")
ax.set_xlabel("Batch number")
ax.set_ylabel("Loss")
ax.set_yscale("log")

fig.savefig("loss_log.png")