In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import sys
sys.path.append('../')

from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from data.load import get_data, get_patched_data, get_patched_src, subsample
from data.utils import custom_cutout
from model.metrics import psnr, ssim
from model import losses

## Choose Devices

In [None]:
# CUDA device index; handed to torch.cuda.set_device in the next cell.
gpu_id = 2
# Not referenced anywhere else in the visible notebook — presumably a worker
# count for data loading; TODO confirm it is still needed.
num_cores = 4

In [None]:
# Make `gpu_id` the current CUDA device for everything that follows.
torch.cuda.set_device(gpu_id)

**TODO / open question:** Loading the patch critic currently requires loading data first. Is there a way to load the model without loading any data — for example, by saving a good critic in a format other than a `.pth` checkpoint?

## SSIM Loss

In [None]:
# SSIM loss object from the project's `model.losses` module; passed as
# `loss_func` when the learner is built below.
ssim_loss = losses.SSIM()

## Metrics

In [None]:
# Metrics reported during training/validation: the project's MSE loss plus
# the PSNR and SSIM metrics imported from `model.metrics`.
superres_metrics = [losses.mse_loss, psnr, ssim]

# Model

## Initialization

In [None]:
# Prefix for every checkpoint name this notebook saves.
nb_name = 'ssimloss-resnet34-pretrained'

# Dataset root, plus the lr/hr training folders given relative to it.
data_pth = Path('/home/alaa/Dropbox/BPHO Staff/USF/')
lr_path = 'EM/training/trainsets/lr/'
hr_path = 'EM/training/trainsets/hr/'

# Directory used as the learner's model_dir (checkpoints are saved/loaded here).
model_path = data_pth / 'EM/models/feat_loss/'

# Round 1 of progressive resizing: batch size 64 at size 128.
bs_1, size_1 = 64, 128
db = get_data(
    data_pth=data_pth,
    lr_dir=lr_path,
    hr_dir=hr_path,
    bs=bs_1,
    in_sz=size_1,
    out_sz=size_1,
    max_zoom=6,
)

In [None]:
# Build the U-Net learner on a ResNet-34 backbone, optimising the SSIM loss
# and reporting the metrics listed in `superres_metrics`.
arch = models.resnet34
wd = 1e-3  # weight decay forwarded to the learner
learn = unet_learner(db, arch, wd=wd, 
                         loss_func=ssim_loss,
                         metrics=superres_metrics, 
                         # callback_fns=LossMetrics,  (disabled)
                         blur=False, 
                         norm_type=NormType.Weight, 
                         model_dir=model_path)
# Run a garbage-collection pass once the learner has been constructed.
gc.collect()

In [None]:
# Re-assert the checkpoint directory (it was also passed as `model_dir` when
# the learner was built) and print it as a sanity check.
learn.model_dir = model_path
print(learn.model_dir)

## Load Model (optional)

In [None]:
# Alternative (disabled): resume from this notebook's own '.3b' checkpoint.
# learn.model_dir = model_path
# learn = learn.load(f'{nb_name}.3b')

# Load the PSSR checkpoint "emsynth_005_unet.5" into the learner.
# NOTE(review): the section is titled "optional", but this line runs
# unconditionally under Run All — skip the cell to keep the freshly built learner.
learn = learn.load("emsynth_005_unet.5")

## Training

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

### 1a

In [None]:
# Stage 1a: one epoch of one-cycle training, before the learner is unfrozen in 1b.
# The original cell read `lr = # specify LR`, which is a SyntaxError and stops
# Restart & Run All. 1e-3 is only a placeholder so the cell parses and runs:
# TODO replace it with the value read off the LR-finder plot above.
lr = 1e-3
learn.fit_one_cycle(1, max_lr=lr)
# with_opt=False: store the weights without the optimizer state.
learn.save(f'{nb_name}.1a', with_opt=False)

In [None]:
learn.show_results(rows=5, figsize=(30,24))

In [None]:
len(learn.data.valid_dl)

### 1b

In [None]:
learn.unfreeze()

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
# Stage 1b: three epochs on the unfrozen learner. Passing a slice as max_lr
# makes fastai spread learning rates from 1e-5 up to 1e-3 across layer groups.
lr = slice(1e-5, 1e-3)
learn.fit_one_cycle(3, max_lr=lr)
# with_opt=False: store the weights without the optimizer state.
learn.save(f'{nb_name}.1b', with_opt=False)

In [None]:
learn.show_results(rows=5, figsize=(25, 15))

## 2a

In [None]:
# Progressive resizing, round 2: rebuild the data at size 256 with batch size 32
# and max_zoom 3. This rebinds `db`; the learner only picks it up when
# `learn.data = db` runs in the next cell.
bs_2 = 32
size_2 = 256
db = get_data(data_pth=data_pth, lr_dir=lr_path, hr_dir=hr_path,
             bs=bs_2, in_sz=size_2, out_sz=size_2, max_zoom=3)

In [None]:
learn.data = db

In [None]:
learn.freeze()

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
# Stage 2a: one epoch on the size-256 data with the learner frozen again
# (learn.freeze() was called above).
lr = 3e-3
learn.fit_one_cycle(1, max_lr=lr)
# with_opt=False: store the weights without the optimizer state.
learn.save(f'{nb_name}.2a', with_opt=False)

In [None]:
learn.show_results(rows=5, figsize=(25, 25))

## 2b

In [None]:
learn.unfreeze()

In [None]:
# Stage 2b: three epochs on the unfrozen learner at size 256, with learning
# rates spread from 1e-5 to 1e-3 across layer groups via the slice.
lr = slice(1e-5, 1e-3)
learn.fit_one_cycle(3, max_lr=lr)
# with_opt=False: store the weights without the optimizer state.
learn.save(f'{nb_name}.2b', with_opt=False)

In [None]:
# Reload the '.2b' checkpoint saved by the previous cell and evaluate it on
# the validation set (loss followed by the learner's metrics).
learn = learn.load(f'{nb_name}.2b');
learn.validate()

In [None]:
# Evaluate the PSSR checkpoint on the same validation data — presumably as a
# baseline to compare against the '.2b' numbers above.
# NOTE(review): load() replaces the weights in `learn`, so every later training
# cell continues from emsynth_005_unet.5 rather than from '.2b' — confirm that
# this is intended.
learn.load(f"emsynth_005_unet.5");
learn.validate()

## 3a

learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
# One epoch at lr=1e-3. `db` has not been rebuilt yet at this point, so this
# still trains on the size-256 data.
# NOTE(review): the size-512 stage further down also saves to '.3a', so the
# checkpoint written here gets overwritten — confirm the intended name.
lr = 1e-3
learn.fit_one_cycle(1, max_lr=lr)
learn.save(f'{nb_name}.3a', with_opt=False)

In [None]:
learn.show_results(rows=5, figsize=(25, 25))

In [None]:
# Progressive resizing, round 3: rebuild the data at size 512 with batch size 8
# and max_zoom 2. This rebinds `db`; it is attached to the learner in the next cell.
bs_3 = 8
size_3 = 512
db = get_data(data_pth=data_pth, lr_dir=lr_path, hr_dir=hr_path,
             bs=bs_3, in_sz=size_3, out_sz=size_3, max_zoom=2.)

In [None]:
learn.data = db
learn.data

In [None]:
learn.freeze()

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
# Stage 3a: one epoch on the size-512 data with the learner frozen
# (learn.freeze() was called above).
# NOTE(review): an earlier cell already saved a checkpoint under the same
# '.3a' name; this save overwrites it.
lr = 1e-3
learn.fit_one_cycle(1, max_lr=lr)
learn.save(f'{nb_name}.3a', with_opt=False)

In [None]:
learn.show_results(rows=5, figsize=(25, int(25*2)))

In [None]:
learn.unfreeze()

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
learn.load(f'{nb_name}.3a');

In [None]:
# Stage 3b: three epochs starting from the reloaded '.3a' weights, with
# learning rates spread from 3e-5 to 3e-4 across layer groups via the slice.
lr = slice(3e-5, 3e-4)
learn.fit_one_cycle(3, max_lr=lr)
# Saved under '.3bv2' (not '.3b'), without the optimizer state.
learn.save(f'{nb_name}.3bv2', with_opt=False)

In [None]:
learn.show_results(rows=5, figsize=(25, int(25*2)))

In [None]:
# NOTE(review): this notebook saves its stage-3b weights as '.3bv2'; no visible
# cell writes '.3b', so this load depends on a checkpoint left on disk by an
# earlier run. Confirm whether '.3bv2' was meant here.
learn.load(f'{nb_name}.3b');

In [None]:
learn.validate()

In [None]:
learn.show_results(rows=5, figsize=(25, int(25*2)))

In [None]:
def patchwise_mse(target, predict, coords):
    """Average the mean-squared error between `predict` and `target` over patches.

    Each row of `coords` holds four slice bounds ``(a0, a1, b0, b1)``; the patch
    it selects is ``x[a0:a1, b0:b1]``, i.e. the bounds apply to the first two
    axes of both inputs.

    Args:
        target: reference array/tensor (NumPy array or torch tensor).
        predict: prediction with the same layout as `target`.
        coords: sequence (array, tensor or list) of ``(a0, a1, b0, b1)`` rows.

    Returns:
        The mean over all patches of the per-patch MSE.

    Raises:
        ValueError: if `coords` contains no patches.
    """
    num_patches = len(coords)
    if num_patches == 0:
        raise ValueError("coords must contain at least one patch")

    # The original never initialised the accumulator and never used `target`.
    loss = 0.
    for xys in coords:
        # int() lets tensor / NumPy scalar bounds be used as slice indices.
        a0, a1, b0, b1 = (int(v) for v in xys)
        diff = predict[a0:a1, b0:b1] - target[a0:a1, b0:b1]
        loss = loss + (diff ** 2).mean()
    return loss / num_patches