# **Point Scanning Super Resolution (PSSR) - Training**

# **1. Preparation: Set the Runtime type and mount your Google Drive**
---

## **1.1. Set the Runtime type**
---

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(graphics processing unit)*

In [None]:
# Run this cell to check whether you have GPU access
%tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name() == '':
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime settings are correct, Google did not allocate a GPU to your session.')
  print('Expect slow performance. To access a GPU, try reconnecting later.')
else:
  print('You have GPU access')

# List every device TensorFlow can see (CPU, XLA devices, GPU)
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

TensorFlow 1.x selected.
You have GPU access


[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 10765735311409383637, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 2317427044891050265
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 11729913141592060865
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 15956161332
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 6927160916135005224
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"]

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Run the cell below to mount your Google Drive and follow the link. In the new browser window, select your Google account, click 'Allow', copy the authorization code, paste it into the cell and press Enter. This gives Colab access to the data on your drive.

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


## **1.3. Install PSSR and dependencies**
---

In [None]:
!pip install czifile

Collecting czifile
  Downloading https://files.pythonhosted.org/packages/37/86/3d0b1829c8c24eb1a4214f098a02442209f80302766203db33c99a4681ec/czifile-2019.7.2-py2.py3-none-any.whl
Installing collected packages: czifile
Successfully installed czifile-2019.7.2


## **1.4. Specify your working folder - need your input**
---

In [None]:
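# Path to the PSSR-master folder on your mounted Drive, relative to Colab's working directory (/content)
root_path = "gdrive/My Drive/PSSR-master"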

# **2. PSSR Training**
---

In [None]:
import sys
sys.path.insert(1, root_path)
from fastai.script import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from fastai.vision.models.unet import DynamicUnet
from fastai.vision.models import resnet18, resnet34, resnet50
from skimage.util import random_noise
from skimage import filters
from utils import *
from utils.resnet import *
from utils.utils import unet_image_from_tiles_blend

In [None]:
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(0)

## **2.1. Hyper-parameter configuration - need your input**
---

In [None]:
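# Basic adjustable hyper-parameters: configure as needed
datasetname = 't_5_mito'  # dataset folder under <root_path>/datasets
tile_sz = 512             # size of the pre-generated training tiles; None uses the plain 'hr'/'lr' folders
n_frames = 5              # number of LR frames given to the network as input channels
lr_type = 't'             # 's', 't' or 'z'; anything other than 's' selects the '_{lr_type}_{n_frames}' folders
bs = 8                    # batch size
size = 512                # HR (target) size in pixels; LR inputs are resized to size // 4
lr = 4e-4                 # maximum learning rate for the one-cycle schedule
cycles = 10               # number of epochs passed to fit_one_cycle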

In [None]:
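# Advanced adjustable hyper-parameters: keep the defaults in most cases

# data augmentation related
cutout = False    # passed to get_data as use_cutout
norm = True       # normalize x and y with data.normalize(do_y=True)
mode = 'L'        # PIL convert mode; 'L' = 8-bit grayscale

# network architecture (passed to the model through wnres_args)
arch = 'wnresnet34'   # architecture name, resolved to a constructor in section 2.3
attn = True           # 'self_attention'
blur = True           # 'blur'
final_blur = True     # 'blur_final'
bottle = True         # 'bottle'
last_cross = True     # 'last_cross'

# fitting related
l1_loss = False   # False = MSE loss, True = L1 loss
lr_start = None   # if set, train with slice(lr_start, lr)
load_name = None  # name of saved weights to load before training
freeze = False    # freeze all but the last layer group before training
wd = 1e-3         # weight decay
save_name = None  # defaults to f'{datasetname}_{cycles}epochs'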

In [None]:
# TO-DO: Export configuration as .json files
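The TO-DO above could be addressed with a small helper like the one below. This is a minimal sketch, assuming the hyper-parameters are plain Python values as set in section 2.1; the helpers `export_config` and `load_config` and the example file location are hypothetical, not part of PSSR. Store `arch` as its string name (before it is resolved in section 2.3), because a function object is not JSON-serializable.

```python
import json
from pathlib import Path

def export_config(config, out_path):
    """Write a hyper-parameter dict to a JSON file (hypothetical helper)."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)  # create the target folder if needed
    with out_path.open('w') as f:
        json.dump(config, f, indent=2, sort_keys=True)
    return out_path

def load_config(path):
    """Read a JSON hyper-parameter file back into a dict (hypothetical helper)."""
    with Path(path).open() as f:
        return json.load(f)

# A subset of the notebook's settings, using the values from section 2.1
config = {
    'datasetname': 't_5_mito', 'tile_sz': 512, 'n_frames': 5, 'lr_type': 't',
    'bs': 8, 'size': 512, 'lr': 4e-4, 'cycles': 10,
    'arch': 'wnresnet34', 'wd': 1e-3, 'l1_loss': False, 'lr_start': None,
}
```

In the notebook, a call such as `export_config(config, Path(root_path)/'stats'/'configs'/f'{datasetname}.json')` would keep the settings next to the other PSSR outputs; the `stats/configs` sub-folder is only a suggestion.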

## **2.2. Prepare Databunch**
---

In [None]:
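def get_src(x_data, y_data, n_frames=1, mode='L'):
    """Pair each LR input in x_data with its HR target in y_data, split by train/valid folders."""
    def map_to_hr(x):
        # Same relative path under y_data, with a .tif suffix
        return y_data/x.relative_to(x_data).with_suffix('.tif')

    if n_frames == 1:
        # Single-frame inputs are image files read in `mode`
        src = (ImageImageList
                .from_folder(x_data, convert_mode=mode)
                .split_by_folder()
                .label_from_func(map_to_hr, convert_mode=mode))
    else:
        # Multi-frame inputs are stored as .npy stacks
        src = (MultiImageImageList
                .from_folder(x_data, extensions=['.npy'])
                .split_by_folder()
                .label_from_func(map_to_hr, convert_mode=mode))
    return src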
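`map_to_hr` pairs each low-resolution input with its high-resolution target by keeping the file's path relative to `x_data` and swapping the suffix to `.tif`. The standalone sketch below reproduces that mapping with `pathlib` so the pairing can be checked without fastai. The folder names follow section 2.2's naming rule; the file name `tile_0001.npy` is hypothetical.

```python
from pathlib import PurePosixPath

def map_to_hr(x, x_data, y_data):
    """Standalone copy of get_src's inner map_to_hr, with the folders passed explicitly."""
    x, x_data, y_data = PurePosixPath(x), PurePosixPath(x_data), PurePosixPath(y_data)
    # Keep the sub-path (e.g. train/... or valid/...) and swap the extension to .tif
    return y_data / x.relative_to(x_data).with_suffix('.tif')

lr_dir = 'datasets/t_5_mito/lr_t_512_t_5'
hr_dir = 'datasets/t_5_mito/hr_t_512_t_5'
print(map_to_hr(f'{lr_dir}/train/tile_0001.npy', lr_dir, hr_dir))
# datasets/t_5_mito/hr_t_512_t_5/train/tile_0001.tif
```

Because the `train`/`valid` sub-folder is preserved, the split made by `split_by_folder()` on the inputs carries over to the targets, so every `.npy` input needs a `.tif` target of the same name under the HR folder.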

In [None]:
def get_data(bs, size, x_data, y_data,
             n_frames=1,
             max_rotate=10.,
             min_zoom=1., max_zoom=1.1,
             use_cutout=False,
             use_noise=False,
             scale=4,
             xtra_tfms=None,
             gauss_sigma=(0.4,0.7),
             pscale=(5,30),
             mode='L',
             norm=False,
             **kwargs):
    src = get_src(x_data, y_data, n_frames=n_frames, mode=mode)

    x_tfms, y_tfms = get_xy_transforms(
                          max_rotate=max_rotate,
                          min_zoom=min_zoom, max_zoom=max_zoom,
                          use_cutout=use_cutout,
                          use_noise=use_noise,
                          gauss_sigma=gauss_sigma,
                          pscale=pscale,
                          xtra_tfms = xtra_tfms)
    x_size = size // scale
    data = (src
            .transform(x_tfms, size=x_size)
            .transform_y(y_tfms, size=size)
            .databunch(bs=bs, **kwargs))
    if norm:
        print('normalizing x and y data')
        data = data.normalize(do_y=True)
    #data.c = 3 #why?
    return data
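`get_data` resizes inputs to `size // scale` pixels and targets to `size` pixels, so with `size = 512` and the default `scale = 4` each pair is a 128×128 input and a 512×512 target. The sketch below works through that arithmetic. The helpers `lr_hr_sizes` and `batch_shapes` are hypothetical, the divisibility check is an addition not present in `get_data`, and the batch shapes assume the frames are stacked as input channels (consistent with `in_c=n_frames` in section 2.3) and grayscale (`mode='L'`) single-channel targets.

```python
def lr_hr_sizes(size, scale=4):
    """(input_size, target_size) as get_data computes them: x_size = size // scale."""
    if size % scale:
        # Hypothetical guard: floor division would otherwise break the exact scale ratio
        raise ValueError(f'size={size} is not a multiple of scale={scale}')
    return size // scale, size

def batch_shapes(bs, n_frames, size, scale=4):
    """Assumed (N, C, H, W) shapes of one input batch and one target batch."""
    x_size, y_size = lr_hr_sizes(size, scale)
    # Assumption: frames become input channels, targets have a single channel
    return (bs, n_frames, x_size, x_size), (bs, 1, y_size, y_size)

print(lr_hr_sizes(512))         # (128, 512)
print(batch_shapes(8, 5, 512))  # ((8, 5, 128, 128), (8, 1, 512, 512))
```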

In [None]:
datasets = Path(root_path)/'datasets'
dataset = datasets/datasetname
if tile_sz is None:
    hr_tifs = dataset/f'hr'
    lr_tifs = dataset/f'lr'
else:
    multi_str = f'_{lr_type}_{n_frames}' if lr_type != 's' else ''
    hr_tifs = dataset/f'hr_t_{tile_sz:d}{multi_str}'
    lr_tifs = dataset/f'lr_t_{tile_sz:d}{multi_str}'

data = get_data(bs, size, lr_tifs, hr_tifs, n_frames=n_frames, max_zoom=4.,
                use_cutout=cutout, mode=mode, norm=norm)
print('bs:', bs, 'size: ', size)



normalizing x and y data
bs: 8 size:  512
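The folder names used above follow a fixed rule: tiled datasets live in `hr_t_{tile_sz}` / `lr_t_{tile_sz}`, and any `lr_type` other than `'s'` adds a `_{lr_type}_{n_frames}` suffix. The sketch below restates that rule as a pure function so you can check which folders a given configuration will read before building the DataBunch. The function name `dataset_dirs` is hypothetical; the logic is copied from the cell above.

```python
from pathlib import PurePosixPath

def dataset_dirs(root_path, datasetname, tile_sz=None, lr_type='s', n_frames=1):
    """Return (hr_dir, lr_dir) using the same naming rules as the cell above."""
    dataset = PurePosixPath(root_path) / 'datasets' / datasetname
    if tile_sz is None:
        return dataset / 'hr', dataset / 'lr'
    # Multi-frame datasets carry an extra '_{lr_type}_{n_frames}' suffix
    multi_str = f'_{lr_type}_{n_frames}' if lr_type != 's' else ''
    return (dataset / f'hr_t_{tile_sz:d}{multi_str}',
            dataset / f'lr_t_{tile_sz:d}{multi_str}')

hr_dir, lr_dir = dataset_dirs('gdrive/My Drive/PSSR-master', 't_5_mito', 512, 't', 5)
print(lr_dir)  # gdrive/My Drive/PSSR-master/datasets/t_5_mito/lr_t_512_t_5
```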


## **2.3. Set up Learner**
--- 

In [None]:
# Set up the learner
if save_name is None: 
    save_name = f'{datasetname}_{cycles}epochs'
pickle_models = Path(root_path)/'stats/models'
pth_models = Path(root_path)/'models'
if l1_loss: loss = F.l1_loss
else: loss = F.mse_loss
print('loss: ', loss)

callback_fns = []
callback_fns.append(partial(SaveModelCallback, name=f'{save_name}_best_{size}'))

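wnres_args = {
    'blur': blur,
    'blur_final': final_blur,
    'bottle': bottle,
    'self_attention': attn,
    'last_cross': last_cross
}
# Resolve the architecture name (e.g. 'wnresnet34') to its constructor;
# the isinstance check keeps this cell safe to re-run
if isinstance(arch, str): arch = eval(arch)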
learn = wnres_unet_learner(data, arch, in_c=n_frames, wnres_args=wnres_args,
                          path=Path('.'), loss_func=loss, metrics=sr_metrics,
                          model_dir=pth_models, callback_fns=callback_fns, wd=wd)

loss:  <function mse_loss at 0x7f908239bd90>
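The `psnr` and `norm_psnr` columns in the training log come from `sr_metrics` in the PSSR `utils` package, which this notebook does not show, so the data range and normalization it uses are not visible here. For orientation, PSNR is conventionally defined as 10·log10(MAX²/MSE). The sketch below implements that textbook formula on flat lists of pixel values; `psnr_sketch` and its `data_range` parameter are hypothetical and not the PSSR implementation.

```python
import math

def psnr_sketch(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length sequences of pixel values."""
    if len(pred) != len(target) or not pred:
        raise ValueError('pred and target must be non-empty and the same length')
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(data_range ** 2 / mse)

print(psnr_sketch([0.0, 0.0], [0.1, 0.1]))  # roughly 20 dB
```

Because PSNR is a decreasing function of MSE, lowering the MSE loss raises PSNR measured on the same data, which is why the two move together in the log.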


In [None]:
if load_name:
    learn = learn.load(f'{load_name}')
    print(f'loaded {load_name}')

if freeze: learn.freeze()

# fastai interprets a slice as per-layer-group learning rates
if lr_start is not None: lr = slice(lr_start, lr)
else: lr = slice(None, lr, None)
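To our understanding, fastai v1 turns a `slice` learning rate into one rate per layer group: `slice(start, stop)` is spaced geometrically from `start` to `stop`, while `slice(None, stop)` gives the last group `stop` and every earlier group `stop / 10`. The sketch below mirrors that behaviour in plain Python; `layer_group_lrs` is a hypothetical name, the guard for a single group is an addition, and the number of layer groups (3 in the example) is an assumption because it depends on how `wnres_unet_learner` splits the model.

```python
def layer_group_lrs(lr, n_groups):
    """Per-layer-group learning rates for a float or slice `lr` (sketch of fastai v1 behaviour)."""
    if not isinstance(lr, slice):
        return [lr] * n_groups  # a plain float applies to every group
    if lr.start:
        if n_groups == 1:
            return [lr.stop]  # guard added here; avoids dividing by zero below
        # Geometric (log-even) spacing from lr.start up to lr.stop
        step = (lr.stop / lr.start) ** (1 / (n_groups - 1))
        return [lr.start * step ** i for i in range(n_groups)]
    # slice(None, stop): last group gets stop, all earlier groups stop / 10
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

print(layer_group_lrs(slice(None, 4e-4, None), 3))  # default path in this notebook
print(layer_group_lrs(slice(1e-5, 1e-3), 3))        # path taken when lr_start is set
```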

## **2.4. Training**
--- 
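`fit_one_cycle(cycles, lr)` trains for `cycles` epochs while the learning rate first rises to its maximum and then anneals far below it. The sketch below approximates that schedule using what we understand to be the fastai v1 defaults (`div_factor=25`, `pct_start=0.3`, `final_div = div_factor * 1e4`) with cosine interpolation in both phases. It is an illustration only: momentum cycling is omitted, and the helper names and the 50 batches per epoch are hypothetical.

```python
import math

def cos_interp(start, end, pct):
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(it, total_its, max_lr, div_factor=25.0, pct_start=0.3, final_div=None):
    """Approximate learning rate at iteration `it` of a one-cycle schedule (sketch)."""
    if final_div is None:
        final_div = div_factor * 1e4
    warmup = int(total_its * pct_start)
    if it < warmup:
        # Phase 1: rise from max_lr / div_factor up to max_lr
        return cos_interp(max_lr / div_factor, max_lr, it / warmup)
    # Phase 2: anneal from max_lr down to max_lr / final_div
    return cos_interp(max_lr, max_lr / final_div, (it - warmup) / max(total_its - warmup, 1))

# 10 epochs of a hypothetical 50 batches each, with max_lr = 4e-4
total = 10 * 50
for it in (0, 75, 150, 325, 500):
    print(it, f'{one_cycle_lr(it, total, 4e-4):.2e}')
```

With a `slice` learning rate, the same shape is applied to each layer group's own maximum rate.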

In [None]:
learn.fit_one_cycle(cycles, lr)

| epoch | train_loss | valid_loss | ssim | psnr | norm_ssim | norm_psnr | time |
|---|---|---|---|---|---|---|---|
| 0 | 2.112518 | 1.314248 | 0.053371 | 18.152414 | 0.193791 | 24.315308 | 00:19 |
| 1 | 1.524127 | 0.893038 | 0.144292 | 19.878559 | 0.19853 | 21.536158 | 00:06 |
| 2 | 1.105469 | 0.525178 | 0.201754 | 22.321411 | 0.441906 | 24.159216 | 00:06 |
| 3 | 0.854487 | 0.361783 | 0.293551 | 24.011105 | 0.524344 | 25.495962 | 00:06 |
| 4 | 0.691903 | 0.285678 | 0.261414 | 25.020697 | 0.537696 | 26.280737 | 00:06 |
| 5 | 0.580606 | 0.226929 | 0.246743 | 25.951757 | 0.557221 | 27.009121 | 00:12 |
| 6 | 0.499706 | 0.197493 | 0.248876 | 26.523712 | 0.575726 | 27.447292 | 00:06 |
| 7 | 0.437946 | 0.173635 | 0.235162 | 27.033894 | 0.587575 | 27.961914 | 00:08 |
| 8 | 0.390096 | 0.153997 | 0.221161 | 27.552311 | 0.600354 | 28.364492 | 00:06 |
| 9 | 0.3533 | 0.139935 | 0.227021 | 27.955494 | 0.617168 | 28.735907 | 00:10 |


Better model found at epoch 0 with valid_loss value: 1.3142478466033936.
Better model found at epoch 1 with valid_loss value: 0.8930379748344421.
Better model found at epoch 2 with valid_loss value: 0.5251778364181519.
Better model found at epoch 3 with valid_loss value: 0.36178290843963623.
Better model found at epoch 4 with valid_loss value: 0.2856775224208832.
Better model found at epoch 5 with valid_loss value: 0.22692862153053284.
Better model found at epoch 6 with valid_loss value: 0.19749264419078827.
Better model found at epoch 7 with valid_loss value: 0.1736353635787964.
Better model found at epoch 8 with valid_loss value: 0.15399673581123352.
Better model found at epoch 9 with valid_loss value: 0.1399351805448532.
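The "Better model found" lines come from `SaveModelCallback`, which by default monitors `valid_loss` and saves a checkpoint (here named `t_5_mito_10epochs_best_512`, from `f'{save_name}_best_{size}'`) whenever that value improves on the best seen so far. The sketch below reproduces that decision rule on the `valid_loss` column of the log above; `improvement_epochs` is a hypothetical helper, and it ignores fastai's optional `min_delta` tolerance.

```python
import math

def improvement_epochs(valid_losses):
    """Epochs at which a 'save on improvement' monitor (lower is better) would save."""
    best = math.inf
    saved = []
    for epoch, loss in enumerate(valid_losses):
        if loss < best:  # strict improvement over the best value so far
            best = loss
            saved.append(epoch)
    return saved

# valid_loss column from the training log above
valid_losses = [1.314248, 0.893038, 0.525178, 0.361783, 0.285678,
                0.226929, 0.197493, 0.173635, 0.153997, 0.139935]
print(improvement_epochs(valid_losses))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Since validation loss fell at every epoch in this run, a checkpoint was written each time and the best checkpoint is also the final-epoch model.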


## **2.5. Export trained PSSR model**
--- 

In [None]:
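learn.save(save_name)  # weights -> <model_dir>/<save_name>.pth
print(f'saved: {save_name}')
pickle_models.mkdir(parents=True, exist_ok=True)  # make sure the export folder exists
learn.export(pickle_models/f'{save_name}_{size}.pkl')
print('exported')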

saved: t_5_mito_10epochs
exported
