In [1]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from data import SRCNNDataset
import torch

# Fixed seed so the train/validation/test split (and training-batch order) is
# reproducible across kernel restarts.
SEED = 42
# Which sample of the first training batch to preview (must be < batch size).
PREVIEW_INDEX = 10

torch.manual_seed(SEED)

data = np.load("../256dataset_images_small.npy")
# Move axis 1 to the end: presumably NCHW on disk -> NHWC for SRCNNDataset (TODO confirm).
data = np.transpose(data, (0, 2, 3, 1))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = SRCNNDataset(hr_images=data, scale_factor=3)
# Explicit generator: the split stays the same no matter what consumed the global RNG first.
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(SEED)
)
# Reshuffle training batches every epoch; evaluation order does not matter.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=16, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

# Preview one (low-res input, high-res target) pair. Note: `lr`/`hr` here are image
# batches, not a learning rate.
lr, hr = next(iter(train_dataloader))

# plt.subplots creates its own figure; a preceding plt.figure() only adds an empty one.
f, ax = plt.subplots(1, 2)
# Samples are channel-first (C, H, W): the old `.T` only yielded an imshow-able array
# because the channel axis came first, but it reversed *every* axis to (W, H, C) and
# so showed the picture transposed. permute gives the (H, W, C) layout imshow expects.
ax[0].imshow(lr[PREVIEW_INDEX].permute(1, 2, 0))
ax[0].set_title("low resolution (input)")
ax[1].imshow(hr[PREVIEW_INDEX].permute(1, 2, 0))
ax[1].set_title("high resolution (target)")
plt.show()



In [2]:
from model import SRCNN
from util import psnr
import torch.backends.cudnn as cudnn

def _train_epoch(model, batches, loss_function, optimizer, device):
    """Run one optimisation pass over ``batches`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for index, (features, labels) in enumerate(batches):
        print(f"Batch {index} / {len(batches)-1}", end="\r")
        features = features.to(device)
        labels = labels.to(device)
        loss = loss_function(model(features), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() already copies the scalar to the host; an explicit .cpu() is redundant.
        total_loss += loss.item()
    return total_loss / len(batches)


def _validate_epoch(model, batches, loss_function, device):
    """Evaluate ``model`` on ``batches`` without gradients; return (mean loss, mean PSNR)."""
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    with torch.no_grad():
        for index, (features, labels) in enumerate(batches):
            print(f"Validation Batch {index} / {len(batches)-1}", end="\r")
            features = features.to(device)
            labels = labels.to(device)
            y_pred = model(features)
            total_loss += loss_function(y_pred, labels).item()
            # float() accepts a Python number, a NumPy scalar or a 0-dim tensor alike.
            total_psnr += float(psnr(y_pred, labels))
    return total_loss / len(batches), total_psnr / len(batches)


def _plot_history(history):
    """Plot per-epoch training/validation loss (left) and validation PSNR (right)."""
    epoch_axis = range(1, len(history["training_loss"]) + 1)
    # plt.subplots creates its own figure; no separate plt.figure() call is needed.
    fig, (loss_ax, psnr_ax) = plt.subplots(1, 2)
    loss_ax.plot(epoch_axis, history["training_loss"], label="training loss")
    loss_ax.plot(epoch_axis, history["validation_loss"], label="validation loss")
    loss_ax.set(xlabel="epoch", ylabel="MSE loss", title="Loss")
    loss_ax.legend(loc="best")
    psnr_ax.plot(epoch_axis, history["psnr"], label="PSNR")
    psnr_ax.set(xlabel="epoch", ylabel="PSNR", title="Validation PSNR")
    psnr_ax.legend(loc="best")
    plt.show()


def train(model, train_set, validation_set, epochs, lr, device=None):
    """Train ``model`` with plain SGD on an MSE objective and plot the learning curves.

    After every epoch the model is evaluated on ``validation_set`` (mean MSE and
    mean ``psnr`` over batches). The per-epoch averages are printed, and the curves
    are plotted once training finishes. The model is left in eval mode.

    Args:
        model: ``torch.nn.Module`` whose output is compared with the labels by
            ``torch.nn.MSELoss``.
        train_set: sized iterable of ``(features, labels)`` batches (e.g. a
            ``DataLoader``); must contain at least one batch.
        validation_set: same form as ``train_set``; evaluated under ``torch.no_grad()``.
        epochs: number of full passes over ``train_set``.
        lr: SGD learning rate.
        device: where each batch is moved before the forward pass. Defaults to the
            device of the model's first parameter.

    Returns:
        dict with keys ``"training_loss"``, ``"validation_loss"`` and ``"psnr"``,
        each a list holding one value per epoch.

    Raises:
        ValueError: if ``train_set`` or ``validation_set`` has no batches (the
            per-epoch averages would otherwise divide by zero).
    """
    if len(train_set) == 0 or len(validation_set) == 0:
        raise ValueError("train_set and validation_set must each contain at least one batch")
    if device is None:
        device = next(model.parameters()).device

    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    history = {"training_loss": [], "validation_loss": [], "psnr": []}

    # Release unoccupied cached GPU memory left over from earlier runs in this kernel.
    torch.cuda.empty_cache()

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch} / {epochs}")
        train_loss = _train_epoch(model, train_set, loss_function, optimizer, device)
        validation_loss, validation_psnr = _validate_epoch(
            model, validation_set, loss_function, device
        )
        history["training_loss"].append(train_loss)
        history["validation_loss"].append(validation_loss)
        history["psnr"].append(validation_psnr)
        print(
            f"Training loss: {train_loss} | Validation loss: {validation_loss} | PSNR: {validation_psnr}"
        )

    _plot_history(history)
    return history

# Hyper-parameters for this run. The checkpoint name is derived from EPOCHS so the
# two cannot drift apart.
EPOCHS = 1000
LEARNING_RATE = 1e-4

model = SRCNN().to(device)
train(model, train_dataloader, validation_dataloader, epochs=EPOCHS, lr=LEARNING_RATE)
# torch.save(model, ...) pickles the whole module: loading it later requires the
# `SRCNN` class to be importable from the same module path (`model.SRCNN`).
torch.save(model, f"SRCNN_{EPOCHS}epochs.pt")


Epoch 1 / 1000
Training loss: 0.10389182124767572 | Validation loss: 0.010497473918699792 | PSNR: 19.81127400924271
Epoch 2 / 1000
Training loss: 0.009238796136308609 | Validation loss: 0.008873600865315114 | PSNR: 20.547813977612638
Epoch 3 / 1000
Batch 613 / 1115

KeyboardInterrupt: 

In [None]:
# TODO:
# - Speed up training, if possible.
#   - Do we have access to stronger GPUs?
# - Train for 1000+ epochs.
# - Adjust the scale factor?
# - Try training on sub-images?