In [1]:
# Start every session with an empty ./runs directory (logs from previous runs are discarded).
# NOTE(review): presumably SRTrainer writes TensorBoard logs here (TensorBoard's default log dir) — confirm.
# Pure-Python equivalent of `! rm -rf ./runs/`: works outside IPython and on Windows, and,
# like `rm -rf`, silently does nothing when the directory does not exist.
import shutil
shutil.rmtree("./runs", ignore_errors=True)
from dataset import SRDDataset, Dataset
from model import BasicSRModel, BasicSRModelSkip
from trainer import SRTrainer
from tester import Tester

import torch.utils.data
import torch
import torch.nn as nn

import random
from os import listdir
from os.path import isfile, join

In [2]:
# Pick the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# torch.version.cuda is the CUDA version this PyTorch build was compiled against (None for CPU-only builds).
print(f'Cuda: Version {torch.version.cuda}; Using {device} device')

Cuda: Version 11.3; Using cuda device


In [3]:
# Prepare data: split ./train into training/validation subsets and use ./eval as the test set.
SEED = 0            # makes the train/valid split and torch-side shuffling/initialisation reproducible
TRAIN_DIR = "./train"
EVAL_DIR = "./eval"
train_split = 0.9   # fraction of TRAIN_DIR used for training; the remainder is validation

random.seed(SEED)
torch.manual_seed(SEED)


def list_files(directory):
    """Return the names of the regular files directly inside `directory`, sorted.

    Sorting matters: `os.listdir` returns entries in arbitrary, filesystem-dependent
    order, so shuffling its raw output would give a different split on every machine
    even with a fixed seed.
    """
    return sorted(name for name in listdir(directory) if isfile(join(directory, name)))


train_images = list_files(TRAIN_DIR)
random.shuffle(train_images)

n_train = int(train_split * len(train_images))
train_paths = train_images[:n_train]
valid_paths = train_images[n_train:]
tests_paths = list_files(EVAL_DIR)
assert train_paths and valid_paths, f"splitting {len(train_images)} files left an empty train or valid subset"

train_dataset = SRDDataset(TRAIN_DIR, train_paths, device, type=Dataset.TRAIN)
valid_dataset = SRDDataset(TRAIN_DIR, valid_paths, device, type=Dataset.VALID)
# Test set stays on the CPU: the evaluation cells run the model on 'cpu'.
tests_dataset = SRDDataset(EVAL_DIR, tests_paths, 'cpu', type=Dataset.VALID)

# num_workers=0: NOTE(review) presumably required because SRDDataset places tensors on `device`,
# and CUDA tensors cannot be created inside worker subprocesses — confirm before raising it.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=True)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
# No pin_memory: pinning only speeds up host->GPU copies, and this loader is consumed on the CPU.
tests_dataloader = torch.utils.data.DataLoader(tests_dataset, batch_size=1, shuffle=False, num_workers=0)

In [4]:
# Train the baseline BasicSRModel with an L1 objective; the training/validation
# metrics are reported per epoch (see the captured log below).
model = BasicSRModel().to(device)
srtrainer = SRTrainer(
    model=model,
    loss_fn=nn.L1Loss(),
    lr=1e-4,
)
srtrainer.train(
    train_dataloader,
    valid_dataloader,
    num_epochs=31,
)

100%|██████████| 67/67 [00:10<00:00,  6.54it/s]
100%|██████████| 31/31 [00:02<00:00, 10.54it/s]


Epoch 1/31 Training Loss: 0.218 Valid (L1): 0.094 Valid (PSNR): 18.760  Valid (SSIM): 0.565


100%|██████████| 67/67 [00:05<00:00, 13.15it/s]
 58%|█████▊    | 18/31 [00:01<00:01, 10.33it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate the saved baseline checkpoint on the ./eval test set, on the CPU.
# NOTE(review): "model30" is presumably the checkpoint SRTrainer writes after the last of the
# 31 epochs — confirm the naming scheme in trainer.py.
# The model is built on the CPU and map_location='cpu' is passed so that a checkpoint saved
# during GPU training also loads on a CPU-only machine. The Tester also only ever sees the
# model on the device it will actually run on.
model = BasicSRModel().to('cpu')
model.load_state_dict(torch.load("./model/model30", map_location='cpu'))
model.eval()
tester = Tester(model)
tester.test(tests_dataloader)

In [None]:
# Train the skip-connection variant with the same loss and learning rate as the baseline;
# its checkpoints are saved under the "skip_model" name so they don't overwrite the baseline's.
model = BasicSRModelSkip().to(device)
srtrainer = SRTrainer(
    model=model,
    loss_fn=nn.L1Loss(),
    lr=1e-4,
)
srtrainer.train(
    train_dataloader,
    valid_dataloader,
    num_epochs=31,
    model_name="skip_model",
)

In [None]:
# Evaluate the saved skip-connection checkpoint on the ./eval test set, on the CPU.
# NOTE(review): "skip_model30" is presumably the final-epoch checkpoint written by SRTrainer
# with model_name="skip_model" — confirm the naming scheme in trainer.py.
# The model is built on the CPU and map_location='cpu' is passed so that a checkpoint saved
# during GPU training also loads on a CPU-only machine. The Tester also only ever sees the
# model on the device it will actually run on.
model = BasicSRModelSkip().to('cpu')
model.load_state_dict(torch.load("./model/skip_model30", map_location='cpu'))
model.eval()
tester = Tester(model)
tester.test(tests_dataloader, model_name="skip_model")