In [1]:
from math import log10, sqrt
import time
import os
from os.path import join
import import_ipynb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from model import VDSR
from dataset import DF2K
import matplotlib.pyplot as plt
import config
from torchvision import transforms
import data_utils
from data2 import get_training_set, get_validation_set

importing Jupyter notebook from model.ipynb
importing Jupyter notebook from dataset.ipynb
importing Jupyter notebook from config.ipynb
importing Jupyter notebook from data_utils.ipynb
importing Jupyter notebook from data2.ipynb


In [2]:
plt.style.use('ggplot')

In [3]:
input_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop((config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)),
    transforms.RandomHorizontalFlip(),
]
)

target_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)),]
)

In [4]:
trainDS = DF2K(
    imagePaths=config.TRAIN_DATASET_PATH, 
    input_transforms=input_transforms, 
    target_transforms=target_transforms,
)
validDS = DF2K(
    imagePaths=config.VALID_DATASET_PATH, 
    input_transforms=input_transforms, 
    target_transforms=target_transforms,
)

print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(validDS)} examples in the valid set...")

trainLoader = DataLoader(trainDS, shuffle=True,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
    num_workers=0)
validLoader = DataLoader(validDS, shuffle=False,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
    num_workers=0)

[INFO] found 3450 examples in the training set...
[INFO] found 100 examples in the valid set...


In [5]:
vdsr = VDSR().to(config.DEVICE)

if os.path.exists('./output/last_model.pth'):
    unet.load_state_dict(torch.load('./output/last_model.pth')['model_state_dict'])
    
lossFunc = nn.MSELoss()

H = {"train_loss": [], "test_loss": []}

In [6]:
for epochs, lr in config.NUM_EPOCHS:
    opt = optim.Adam(vdsr.parameters(), lr=lr, weight_decay=config.WEIGHT_DECAY)

    for epoch in range(epochs):
        start_time = time.time()
        vdsr.train()
        epoch_loss = 0.0
        epoch_acc = 0.0
        print("Epoch = {}, lr = {}".format(epoch, opt.param_groups[0]["lr"]))

        for iteration, (input, target) in enumerate(trainLoader, 1):
            input, target = input.to(config.DEVICE), target.to(config.DEVICE)
            opt.zero_grad()
            output = vdsr(input)
            loss = lossFunc(output, target)
            epoch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(vdsr.parameters(), config.CLIP/lr)
            opt.step()

            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch, iteration, len(trainLoader), loss.item())
                 )
        print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
            epoch, epoch_loss / len(trainLoader))
             )

        elapsed_time = time.time() - start_time
        train_time += elapsed_time

        print("===> {:.2f} seconds to train this epoch".format(elapsed_time))

        start_time = time.time()

        with torch.no_grad():
            net.eval()

            for (input, target) in validLoader:
                input, target = input.to(config.DEVICE), target.to(config.DEVICE)
                prediction = vdsr(input)
                mse = lossFunc(prediction, target)
                psnr = 10 * log10(1.0 / mse.item())

                print("===> Avg. PSNR: {:.4f} dB".format(
                    avg_psnr / len(validLoader))
                     )
            if epoch % 10 == 0:
                model_out_path = "output/model_epoch_{}.pth".format(epoch)
                torch.save(vdsr, model_out_path)

        print("===> average training time per epoch: {:.2f} seconds".format(train_time/epochs))
        print("===> average validation time per epoch: {:.2f} seconds".format(validate_time/epochs))
        print("===> training time: {:.2f} seconds".format(train_time))
        print("===> validation time: {:.2f} seconds".format(validate_time))
        print("===> total training time: {:.2f} seconds".format(train_time+validate_time))  

Epoch = 0, lr = 0.001
===> Epoch[0](1/432): Loss: 1.0871
===> Epoch[0](2/432): Loss: 67.2682
===> Epoch[0](3/432): Loss: 0.3537
===> Epoch[0](4/432): Loss: 0.0333
===> Epoch[0](5/432): Loss: 0.0331


KeyboardInterrupt: 