In [5]:
import block
import importlib
import dataloader
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch
import torch.optim as optim
import cv2
import numpy as np
import gc
import loss
import train
import test
import utils

AttributeError: module 'torch' has no attribute 'version'

In [None]:
# Pick up edits to the project modules without restarting the kernel.
# Note: names bound via `from module import ...` (e.g. ImageDataset) are NOT
# refreshed by a reload; re-run the cell that imports them.
for _project_module in (dataloader, block, loss, train, test, utils):
    importlib.reload(_project_module)

In [None]:
from dataloader import ImageDataset, ImageLoader

# Full dataset plus a batch-size-1 loader over it; later cells read a single
# sample back through `loader.dataset`.
DATA_ROOT = "../data/"
dataset = ImageDataset(DATA_ROOT)
loader = ImageLoader(dataset, batch_size=1)

In [None]:
# Fixed, unshuffled split: samples [0, 2000) feed training, the remainder is held out.
SPLIT_AT = 2000
test_loader = ImageLoader(dataset[SPLIT_AT:], batch_size=1)
train_loader = ImageLoader(dataset[:SPLIT_AT], batch_size=1)

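## Test image

Sample 2049 lies in the held-out slice (`dataset[2000:]`), so the model never trains on it. It is the image used in the visualizations below.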

In [None]:
# Index 2049 is >= 2000, i.e. inside the held-out test slice.
TEST_IMAGE_INDEX = 2049
img = loader.dataset[TEST_IMAGE_INDEX]

In [None]:
# Show channel 0 of the test image. The title reports the tensor's actual shape
# rather than hard-coded dimensions that may not match the data.
# NOTE(review): img[0][0] assumes a leading batch-like dimension, i.e. roughly
# (1, C, H, W) — img.shape[1] is used as the channel count elsewhere; confirm
# against ImageDataset.__getitem__.
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_title(f"Test image, channel 0 of {tuple(img.shape)}")
ax.imshow(img[0][0].detach().cpu().numpy())
plt.show()

In [None]:
# Compare extractor size across output widths; every row after the first also
# reports the relative growth over the previous width.
in_channels = img.shape[1]
previous_count = None
for out_size in (128, 256):
    fe = block.ShallowFE(in_channels, out_size)
    param_count = sum(param.numel() for param in fe.parameters())
    message = f"Number of parameters for output size {out_size}: {param_count:,}"
    if previous_count is not None:
        growth = (param_count - previous_count) / previous_count * 100
        message += f" (+{growth:.2f}%)"
    print(message)
    previous_count = param_count

# Release the last model before moving on.
del fe
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Display the repr of a 256-output ShallowFE for a quick look at its layers.
preview_channels = img.shape[1]
block.ShallowFE(preview_channels, 256)

In [None]:
# Stage switches for the cells below; by default only the degradation
# evaluation runs.
TRAIN, TEST, EVAL = False, False, True
if TRAIN:
    # NOTE(review): only width 128 is trained here, while the TEST cell loads
    # checkpoints for 64 and 256 — those must already exist under ./models/.
    for size in [128]:
        fe = block.ShallowFE(img.shape[1], size)
        optimizer = optim.Adam(fe.parameters(), lr=0.00001)
        # The positional 5 is presumably the epoch count (TEST/EVAL load
        # checkpoints _0 … _4); `iter=100`'s meaning is defined by train.train — confirm.
        train.train(fe, train_loader, loss.VGGPerceptualLoss(), optimizer, size, 5, iter=100)
        # Drop the last references first; while `fe` and `optimizer` are alive,
        # gc.collect()/torch.cuda.empty_cache() cannot free their tensors.
        del fe, optimizer
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
if TEST:
    # Checkpoints model({size})_0.pth … _4.pth, plus None for a freshly
    # initialised (untrained) model as the last panel.
    checkpoint_epochs = [0, 1, 2, 3, 4, None]
    for size in [64, 256]:
        fig, axes = plt.subplots(1, 1 + len(checkpoint_epochs), figsize=(32, 6))
        axes[0].set_title(f"{size}: Original")
        axes[0].imshow(img[0][0])
        for ax, ckpt in zip(axes[1:], checkpoint_epochs):
            # Build the label outside the f-string: reusing `"` inside a
            # "..." f-string is a SyntaxError before Python 3.12.
            label = "Untrained" if ckpt is None else f"Epoch{ckpt}"
            ax.set_title(f"{size}: {label}")
            fe = block.ShallowFE(img.shape[1], size)
            if ckpt is not None:
                # map_location="cpu" lets GPU-saved checkpoints load on any host;
                # load_state_dict copies the values onto the model's own device.
                state = torch.load(f"./models/model({size})_{ckpt}.pth", map_location="cpu")
                fe.load_state_dict(state)
            fe.eval()
            # Optional quantitative pass, currently disabled:
            # test.test(fe, test_loader)
            with torch.no_grad():  # inference only: skip building the autograd graph
                out = fe(img)
            ax.imshow(out[0][0].detach().cpu().numpy())
            torch.cuda.empty_cache()
            gc.collect()
        plt.show()

In [None]:
if EVAL:
    # Checkpoints model({size})_0.pth … _4.pth, plus None for a freshly
    # initialised (untrained) model as the last panel.
    checkpoint_epochs = [0, 1, 2, 3, 4, None]
    for res in range(2, 5):
        # Degrade a copy of the test image. What `depth` controls is defined by
        # utils.DecreaseResolution, outside this notebook.
        transform = transforms.Compose([utils.DecreaseResolution(3, depth=res)])
        copied = img.clone()
        copied[0] = transform(copied)
        for size in [128]:
            fig, axes = plt.subplots(1, 1 + len(checkpoint_epochs), figsize=(32, 6))
            # Figure-level title: plt.title() called before plt.subplot() would
            # land on an implicit full-figure axes, not above the panel row.
            fig.suptitle(f"Degraded scale {res}x")
            axes[0].set_title(f"{size}: Original")
            axes[0].imshow(copied[0][0])
            for ax, ckpt in zip(axes[1:], checkpoint_epochs):
                # Build the label outside the f-string: reusing `"` inside a
                # "..." f-string is a SyntaxError before Python 3.12.
                label = "Untrained" if ckpt is None else f"Epoch{ckpt}"
                ax.set_title(f"{size}: {label}")
                fe = block.ShallowFE(copied.shape[1], size)
                if ckpt is not None:
                    # map_location="cpu" lets GPU-saved checkpoints load on any host;
                    # load_state_dict copies the values onto the model's own device.
                    state = torch.load(f"./models/model({size})_{ckpt}.pth", map_location="cpu")
                    fe.load_state_dict(state)
                fe.eval()
                # Optional quantitative pass, currently disabled:
                # test.test(fe, test_loader)
                with torch.no_grad():  # inference only: skip building the autograd graph
                    # Feed a clone so the degraded image is never touched by the model.
                    out = fe(copied.clone())
                ax.imshow(out[0][0].detach().cpu().numpy())
                torch.cuda.empty_cache()
                gc.collect()
            plt.show()