In [None]:
# # Download the dataset (uncomment to run). `-xf` extracts the archive whether or not it is compressed.
# !wget -N http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
# !tar -xf VOCtrainval_06-Nov-2007.tar

In [None]:
import os
import random

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import wandb
# Open the Weights & Biases run used by the wandb.config / wandb.log calls below.
# Uncomment the login call on a machine that has not been authenticated yet.
# wandb.login()
WANDB_PROJECT = "Y-Data-DL-Week4-Super-Resolution-Final"
wandb.init(project=WANDB_PROJECT)

from networks import network1, network2, network3, network4, network5, network6
from utils import VOC2007Dataset, ssim, show_images

In [None]:
# Hyperparameters and runtime switches, stored on the W&B run config.
#   no_cuda=True  -> the device-selection cell falls back to the CPU
#   window_size   -> forwarded to ssim() by the train/test helpers
run_settings = {
    "batch_size": 2,
    "epochs": 5,
    "lr": 0.01,
    "no_cuda": True,
    "window_size": 12,
    "seed": 42,
    "log_interval": 10,
}
wandb.config.update(run_settings)

In [None]:
# Resolve the compute device and the extra DataLoader options from the run config.
config = wandb.config
use_cuda = torch.cuda.is_available() if not config.no_cuda else False
if use_cuda:
    device = torch.device("cuda")
    kwargs = {'num_workers': 2, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# Seed the Python, PyTorch and NumPy generators (in that order) from the same config value.
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(config.seed)

In [None]:
# Both splits are slices of the VOC2007 'trainval' image set, selected through `sample_slice`.
transform = transforms.Compose([transforms.ToTensor()])
trainset = VOC2007Dataset(
    image_set='trainval',
    transform=transform,
    sample_slice=[0, 100],
)  # training dataset
testset = VOC2007Dataset(
    image_set='trainval',
    transform=transform,
    sample_slice=[-11, -1],
)  # validation dataset

# Identical loader settings for both splits: configured batch size, fixed order.
loader_options = dict(batch_size=config.batch_size, shuffle=False, **kwargs)
trainloader = DataLoader(trainset, **loader_options)
testloader = DataLoader(testset, **loader_options)

In [None]:
# Display one training sample (index 50) as a sanity check of the dataset output.
preview_sample = trainset[50]
show_images(preview_sample)

In [None]:
def train_model_y_mid(config, net, train_data, optimizer, epoch):
    """Run one training epoch of a single-output network against the `y_mid` target.

    Args:
        config: Object exposing `window_size`, which is forwarded to `ssim`.
        net: Model called as `net(X)`; its output is compared with `batch['y_mid']`.
        train_data: Sized iterable of batches indexable by `'X'` and `'y_mid'`.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, logged alongside the loss.

    The per-batch loss is `1 - ssim(output, y_mid)`. The mean over all batches is
    logged to W&B under 'Train Loss - Model 1' with `commit=False`.
    """
    net.train()
    n_batches = len(train_data)
    batch_losses = []
    for batch in tqdm(train_data, total=n_batches):
        inputs = batch['X'].to(device)
        target = batch['y_mid'].to(device)

        optimizer.zero_grad()
        prediction = net(inputs)
        loss = 1 - ssim(prediction, target, window_size=config.window_size)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    mean_loss = sum(batch_losses) / n_batches
    wandb.log({'Train Loss - Model 1': mean_loss, 'Epoch': epoch}, commit=False)

def test_model_y_mid(config, net, test_data, epoch):
    """Evaluate a single-output network on `test_data` and log the results to W&B.

    Args:
        config: Object exposing `window_size`, which is forwarded to `ssim`.
        net: Model called as `net(X)`; its output is compared with `batch['y_mid']`.
        test_data: Sized iterable of batches indexable by `'X'` and `'y_mid'`.
        epoch: Epoch number, logged alongside the metrics.

    Logs the mean `1 - ssim` loss, the mean PSNR and, for every batch, the first
    reconstructed image together with its target.
    """
    net.eval()
    n = len(test_data)
    test_loss = 0.0
    psnr_total = 0.0
    example_images = []
    to_pil = transforms.ToPILImage(mode='RGB')
    with torch.no_grad():
        for batch in tqdm(test_data, total=n):
            X = batch['X'].to(device)
            y_mid = batch['y_mid'].to(device)
            output = net(X)
            loss = 1 - ssim(output, y_mid, window_size=config.window_size)
            # Accumulate plain floats so the logged value is a number, not a tensor.
            test_loss += loss.item()
            # PSNR with a peak value of 1 (assumes images scaled to [0, 1] — TODO confirm).
            # The clamp keeps a perfect reconstruction from dividing by zero.
            mse = torch.nn.functional.mse_loss(output, y_mid).clamp_min(1e-10)
            psnr_total += (10 * torch.log10(1.0 / mse)).item()
            example_images.append(wandb.Image(to_pil(output[0]),
                                              caption="Output Reconstruction"))
            example_images.append(wandb.Image(to_pil(y_mid[0]),
                                              caption="Target"))
    denom = max(n, 1)  # an empty loader logs zeros instead of raising ZeroDivisionError
    wandb.log({'Test Loss - Model 1': test_loss / denom, 'Avg PSN Ratio - Model 1': psnr_total / denom,
               "Examples": example_images, 'Epoch': epoch,
               })

In [None]:
def run_training(network):
    """Build, train and evaluate a single-output model for `config.epochs` epochs.

    Args:
        network: Zero-argument callable returning the model to train.

    Returns:
        The trained model, placed on the global `device`.
    """
    model = network()
    model = model.to(device)
    wandb.watch(model, log="all")
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    # Epochs are numbered from 1 in the W&B logs.
    for epoch_idx in range(config.epochs):
        epoch = epoch_idx + 1
        train_model_y_mid(config, model, trainloader, optimizer, epoch)
        test_model_y_mid(config, model, testloader, epoch)
    return model

In [None]:
# Train the first (single-output) architecture.
model1 = run_training(network=network1)

In [None]:
# Qualitative check: run model1 on the first batch of the training loader.
# `test` keeps its name because later cells pass it on as `test=test`.
dataiter = iter(trainloader)
test = next(dataiter)
model1.eval()
with torch.no_grad():  # inference only, so the output carries no autograd graph
    y_mid = model1(test['X'].to(device))  # the input must be on the same device as the model
img_dct = dict(y_mid=y_mid[0].cpu(), X=test['X'][0])
show_images(img_dct)

In [None]:
def train_model_y_mid_large(config, net, train_data, optimizer, epoch, model_number=2, loss_fn=None):
    """Run one training epoch of a two-output network (mid and large targets).

    Args:
        config: Object exposing `window_size`, which is forwarded to `ssim`.
        net: Model called as `net(X)` and returning `(mid_output, large_output)`.
        train_data: Sized iterable of batches indexable by `'X'`, `'y_mid'` and `'y_large'`.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, logged alongside the loss.
        model_number: Number used in the W&B metric name.
        loss_fn: Optional callable `(prediction, target) -> scalar tensor` applied to
            each output. When None, `1 - ssim(prediction, target)` is used.

    The batch loss is the sum of the mid and large terms. The mean over all batches
    is logged to W&B with `commit=False`.
    """
    def _ssim_loss(prediction, target):
        return 1 - ssim(prediction, target, window_size=config.window_size)

    criterion = _ssim_loss if loss_fn is None else loss_fn

    net.train()
    n_batches = len(train_data)
    train_loss = 0.0
    for batch in tqdm(train_data, total=n_batches):
        X = batch['X'].to(device)
        y_mid = batch['y_mid'].to(device)
        y_large = batch['y_large'].to(device)
        optimizer.zero_grad()
        mid_output, large_output = net(X)
        loss = criterion(mid_output, y_mid) + criterion(large_output, y_large)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # max(..., 1): an empty loader logs 0 instead of raising ZeroDivisionError.
    wandb.log({f'Train Loss - Model {model_number}': train_loss / max(n_batches, 1), 'Epoch': epoch}, commit=False)

def test_model_y_mid_large(config, net, test_data, epoch, model_number=2, loss_fn=None):
    """Evaluate a two-output network on `test_data` and log the results to W&B.

    Args:
        config: Object exposing `window_size`, which is forwarded to `ssim`.
        net: Model called as `net(X)` and returning `(mid_output, large_output)`.
        test_data: Sized iterable of batches indexable by `'X'`, `'y_mid'` and `'y_large'`.
        epoch: Epoch number, logged alongside the metrics.
        model_number: Number used in the W&B metric names.
        loss_fn: Optional callable `(prediction, target) -> scalar tensor` applied to
            each output. When None, `1 - ssim(prediction, target)` is used.

    Logs the mean summed loss, the mean PSNR of the mid output and, for every batch,
    the first mid reconstruction, large reconstruction and mid target.
    """
    def _ssim_loss(prediction, target):
        return 1 - ssim(prediction, target, window_size=config.window_size)

    criterion = _ssim_loss if loss_fn is None else loss_fn

    net.eval()
    n = len(test_data)
    test_loss = 0.0
    psnr_total = 0.0
    example_images = []
    to_pil = transforms.ToPILImage(mode='RGB')
    with torch.no_grad():
        for batch in tqdm(test_data, total=n):
            X = batch['X'].to(device)
            y_mid = batch['y_mid'].to(device)
            y_large = batch['y_large'].to(device)
            mid_output, large_output = net(X)
            loss = criterion(mid_output, y_mid) + criterion(large_output, y_large)
            # Accumulate plain floats so the logged value is a number, not a tensor.
            test_loss += loss.item()
            # PSNR with a peak value of 1 (assumes images scaled to [0, 1] — TODO confirm).
            # The clamp keeps a perfect reconstruction from dividing by zero.
            mse = torch.nn.functional.mse_loss(mid_output, y_mid).clamp_min(1e-10)
            psnr_total += (10 * torch.log10(1.0 / mse)).item()
            example_images.append(wandb.Image(to_pil(mid_output[0]),
                                              caption="Mid Output Reconstruction"))
            example_images.append(wandb.Image(to_pil(large_output[0]),
                                              caption="Large Output Reconstruction"))
            example_images.append(wandb.Image(to_pil(y_mid[0]),
                                              caption="Target"))
    denom = max(n, 1)  # an empty loader logs zeros instead of raising ZeroDivisionError
    wandb.log({f'Test Loss - Model {model_number}': test_loss / denom,
               f'Avg PSN Ratio - Model {model_number}': psnr_total / denom,
               "Examples": example_images, 'Epoch': epoch})

In [None]:
def run_training2(network, model_number):
    """Build, train and evaluate a two-output model for `config.epochs` epochs.

    Args:
        network: Zero-argument callable returning a model whose forward pass yields
            `(mid_output, large_output)`.
        model_number: Number used in the W&B metric names for this model.

    Returns:
        The trained model, placed on the global `device`.
    """
    model = network().to(device)
    wandb.watch(model, log="all")
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    # No loss_fn is forwarded: the helpers apply their own 1 - SSIM objective.
    for epoch in range(1, config.epochs + 1):
        train_model_y_mid_large(config, model, trainloader, optimizer, epoch, model_number=model_number)
        test_model_y_mid_large(config, model, testloader, epoch, model_number=model_number)
    return model

In [None]:
# NOTE(review): the original passed `convnet2`, which is never defined or imported here.
# `network2` is the imported architecture matching model_number=2 — confirm the mapping.
model2 = run_training2(network2, model_number=2)

In [None]:
def show_test_images(model, test):
    """Run a two-output model on one batch and display the first sample's images.

    Args:
        model: Model called as `model(X)` and returning `(y_mid, y_large)`.
        test: Batch indexable by `'X'`.

    Shows the first input next to the first mid and large predictions via `show_images`.
    """
    model.eval()
    with torch.no_grad():  # inference only, so the outputs carry no autograd graph
        y_mid, y_large = model(test['X'].to(device))  # input must share the model's device
    img_dct = dict(y_large=y_large[0].cpu(), y_mid=y_mid[0].cpu(), X=test['X'][0])
    show_images(img_dct)

In [None]:
# Visual check of model 2 on the held batch.
show_test_images(model=model2, test=test)

In [None]:
# NOTE(review): the original passed `resnet`, which is never defined or imported here.
# `network3` is the imported architecture matching model_number=3 — confirm the mapping.
model3 = run_training2(network3, model_number=3)

In [None]:
# Visual check of model 3 on the held batch.
show_test_images(model=model3, test=test)

In [None]:
# NOTE(review): the original passed `dilation_net`, which is never defined or imported here.
# `network4` is the imported architecture matching model_number=4 — confirm the mapping.
model4 = run_training2(network4, model_number=4)

In [None]:
# Visual check of model 4 on the held batch.
show_test_images(model=model4, test=test)

In [None]:
# NOTE(review): the original passed `pretrained_net`, which is never defined or imported here.
# `network5` is the imported architecture matching model_number=5 — confirm the mapping.
model5 = run_training2(network5, model_number=5)

In [None]:
# Visual check of model 5 on the held batch.
show_test_images(model=model5, test=test)

In [None]:
# NOTE(review): the original passed `pixel_shuffle_net`, which is never defined or imported here.
# `network6` is the imported architecture matching model_number=6 — confirm the mapping.
model6 = run_training2(network6, model_number=6)

In [None]:
# Visual check of model 6 on the held batch.
show_test_images(model=model6, test=test)