In [None]:
import torch
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn import Dropout
import torch.nn.functional as F
from torch.utils import data
from torchvision.utils import save_image

import sys
import random
from datetime import datetime

# My own files
from gridifier import gridify
from data_generator import Dataset

In [None]:
class Degridifier_model(nn.Module):
    """Small fully-convolutional network mapping a 3-channel image to a 3-channel image.

    The network is three 5x5 convolutions, each followed by a ReLU, then a 1x1
    convolution followed by a sigmoid. The 5x5 layers are zero-padded by two
    pixels on every side, so the spatial size of the input is preserved and the
    output has the same height and width as the input, with values in [0, 1].

    Args:
        conv_depth: Indexable of three ints giving the number of output
            channels of the first, second and third 5x5 convolution.
    """

    def __init__(self, conv_depth=(5, 7, 5)):
        super(Degridifier_model, self).__init__()

        # padding=(2, 2) with a 5x5 kernel keeps height and width unchanged.
        # The default padding_mode ('zeros') is used: 'same' is not a valid
        # padding_mode and makes current PyTorch raise ValueError.
        self.convd1 = Conv2d(3, conv_depth[0], (5, 5), padding=(2, 2))
        self.convd2 = Conv2d(conv_depth[0], conv_depth[1], (5, 5), padding=(2, 2))
        self.convd3 = Conv2d(conv_depth[1], conv_depth[2], (5, 5), padding=(2, 2))
        # 1x1 convolution that projects back to 3 output channels.
        self.convd4 = Conv2d(conv_depth[2], 3, (1, 1))

        # NOTE: this dropout layer is created but never applied in forward().
        self.drop = Dropout()

    @property
    def processor(self):
        """Device the model's parameters currently live on.

        Derived from the first convolution's weight instead of a notebook-level
        global, so the class can be instantiated on its own and the value stays
        correct after the model is moved with ``.to(...)``.
        """
        return self.convd1.weight.device

    def forward(self, input_data):
        """Run the network on ``input_data``.

        Args:
            input_data: Tensor with 3 channels in the channel dimension, as
                required by the first convolution.

        Returns:
            Tensor with 3 channels and the same spatial size as the input;
            every value has passed through a sigmoid.
        """
        conv1 = F.relu(self.convd1(input_data))
        conv2 = F.relu(self.convd2(conv1))
        conv3 = F.relu(self.convd3(conv2))
        conv4 = torch.sigmoid(self.convd4(conv3))

        return conv4
                       

In [None]:

# Device selection. Because of the trailing `and False`, `use_cuda` is always
# False, so `processor` is always the CPU device regardless of CUDA availability.
use_cuda = torch.cuda.is_available() and False
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0") if use_cuda else torch.device("cpu")
torch.backends.cudnn.benchmark = True


# Configuration for the data generator and the model
batch_size = 1
conv_depth = [5,7,5]
# Template string handed to Dataset; how the `{}` placeholder is filled is
# decided inside data_generator (not visible here).
labels_path = "gridless/{}.png"

# Keyword options forwarded unchanged to Dataset. Their exact meaning is
# defined by data_generator.Dataset -- TODO confirm against that module.
augments = dict(
    grid_size=(15, 90),
    grid_intensity=(-1, 1),
    grid_offset_x=(0, 90),
    grid_offset_y=(0, 90),
    hflip=0.5,
    vflip=0.5,
    angle=0,
    shear=0,
    brightness=(0.75, 1.25),
    pad=(0, 0, 0, 0),
    contrast=(0.5, 2),
)

dataset = Dataset(range(30), labels_path, **augments)
training_generator = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Build the network and the loss, then move each onto the selected device.
model = Degridifier_model(conv_depth)
model = model.to(processor)
loss_function = nn.BCELoss()
loss_function = loss_function.to(processor)

In [None]:
# Initialising some variables for use in training
batches = float("inf")   # upper bound used only by the time estimate; the loop does not stop early
time_diff = 0            # duration of the previous batch; a timedelta after the first batch
no_improv = 0            # number of batches since the lowest loss was seen
min_loss = float("inf")
epochs = 1
print_stuff = False
learning_rate = 1e-3

# Without an optimizer and a backward pass the weights never change, so the
# loop would only evaluate the randomly initialised model.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Number of batches per epoch, used consistently for the time estimate.
batches_per_epoch = min(batches, len(training_generator))

model.train()

for epoch in range(epochs):
    for i, (batch, labels) in enumerate(training_generator):

        # Keeping track of stuff
        start_time = datetime.now()

        # Batches left in this epoch plus all batches of the epochs still to
        # come, each assumed to take as long as the previous batch did.
        remaining_batches = (batches_per_epoch - i) + (epochs - (epoch + 1)) * batches_per_epoch
        est_time_left = str(time_diff * remaining_batches).split(".")[0]

        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Estimated time left: {3}. Best: {4} batches ago.".format(epoch+1, i+1, min_loss, est_time_left, no_improv))
        sys.stdout.flush()

        # Putting data on the selected device
        batch = batch.to(processor)
        labels = labels.to(processor)

        # One optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        out = model(batch)
        loss = loss_function(out, labels)
        loss.backward()
        optimizer.step()

        # NOTE: the best model is not checkpointed; only its loss value is tracked.
        batch_loss = loss.item()
        if batch_loss < min_loss:
            min_loss = batch_loss
            no_improv = 0
        else:
            no_improv += 1

        # For tracking progress
        end_time = datetime.now()
        time_diff = end_time - start_time

        # The file name depends only on the batch index, so later epochs
        # overwrite the images of earlier ones. Assumes the `model_gridless/`
        # directory already exists. `out` is detached because the saved image
        # does not need gradient tracking.
        save_image(out.detach().squeeze(0), "model_gridless/{0}.png".format(i))