# **Point Scanning Super Resolution (PSSR) - Training**

# **1. Preparation: Set the Runtime type and mount your Google Drive**
---

## **1.1. Set the Runtime type**
---

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics Processing Unit)*

In [None]:
#Run this cell to check if you have GPU access
%tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime settings are correct then Google did not allocate GPU to your session')
  print('Expect slow performance. To access GPU try reconnecting later')
else:
  print('You have GPU access')

from tensorflow.python.client import device_lib 
device_lib.list_local_devices()

TensorFlow 1.x selected.
You have GPU access


[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 10765735311409383637, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 2317427044891050265
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 11729913141592060865
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 15956161332
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 6927160916135005224
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"]

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
# Mount the user's Google Drive into the Colab filesystem at /content/gdrive.
# Running this triggers the interactive authorization prompt described above.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


## **1.3. Install PSSR and dependencies**
---

In [None]:
!pip install czifile

Collecting czifile
  Downloading https://files.pythonhosted.org/packages/37/86/3d0b1829c8c24eb1a4214f098a02442209f80302766203db33c99a4681ec/czifile-2019.7.2-py2.py3-none-any.whl
Installing collected packages: czifile
Successfully installed czifile-2019.7.2


## **1.4. Specify your working folder - need your input**
---

In [None]:
# User setting: folder that is put on sys.path in section 2 so that the
# `utils` package can be imported (presumably the PSSR checkout - verify).
# This is a relative path, so it resolves against the current working
# directory; the Drive itself was mounted at /content/gdrive in section 1.2.
root_path = "gdrive/My Drive/PSSR-master"

# **2. PSSR - Generate training datasets**
---

In [None]:
import sys
# Put the user's working folder on the import path ahead of site-packages so
# that the `utils` imports below can be resolved from it.
sys.path.insert(1, root_path)
from fastai.script import *
from fastai.vision import *
from utils import *
from utils.crappifiers import *
from pathlib import Path
from fastprogress import master_bar, progress_bar
from time import sleep
import torchvision
import shutil
import PIL
import czifile
import glob
from PIL import Image
from skimage.transform import rescale
from skimage import filters
from skimage.util import random_noise
from scipy.ndimage.interpolation import zoom as npzoom
# Raise PIL's pixel-count guard to a very large value so that big source
# images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = 99999999999999

In [None]:
def need_cache_flush(tile_stats, last_stats):
    """Tell whether the currently open image reader has to be replaced.

    Returns True when no previous tile has been recorded yet
    (``last_stats is None``) or when the ``'fn'`` entry of ``tile_stats``
    differs from that of ``last_stats``; otherwise returns False.
    """
    nothing_open_yet = last_stats is None
    return nothing_open_yet or bool(tile_stats['fn'] != last_stats['fn'])

def get_tile_puller(tile_stat, crap_func, t_frames, z_frames):
    """Open the image named by ``tile_stat`` and return a closure that pulls tiles from it.

    Args:
        tile_stat: mapping (e.g. one metadata row) with at least ``'fn'`` (path of
            the source image) and ``'ftype'``. ``'czi'`` selects the czifile
            reader; any other value is opened with PIL.
        crap_func: callable producing the low-resolution version of a tile or
            frame, or a falsy value to skip that step.
        t_frames: number of frames read along T around the selected ``t``.
        z_frames: number of frames read along Z around the selected ``z``.

    Returns:
        ``puller(istat, tile_folder, crap_folder, close_me=False)``.
        Called normally it saves one random high-resolution tile (and, when both
        ``crap_func`` and ``crap_folder`` are truthy, its low-resolution
        counterpart) and returns a dict describing the tile.
        Called with ``close_me=True`` it closes the underlying file, ignores the
        other arguments and returns ``None``.

    Raises:
        Whatever the reader raises while opening/loading the file (for example
        ``MemoryError`` for very large CZI files). The CZI handle is closed
        before such an error is propagated.
    """
    fn = tile_stat['fn']
    ftype = tile_stat['ftype']

    # Number of neighbouring frames taken on each side of the centre frame.
    half_z = z_frames // 2
    half_t = t_frames // 2

    if ftype == 'czi':
        img_f = czifile.CziFile(fn)
        try:
            proc_axes, _ = get_czi_shape_info(img_f)
            # The whole file is loaded into memory once and reused for every tile.
            img_data = img_f.asarray()
            img_data = img_data.astype(np.float32)
        except Exception:
            # Loading can fail (callers catch MemoryError for oversized files);
            # do not leave the file handle open in that case.
            img_f.close()
            raise

        def czi_get(istat):
            """Return the frame(s) selected by ``istat`` as float32, scaled by ``all_rmax``."""
            c, z, t, x, y = [istat[fld] for fld in ['c', 'z', 't', 'x', 'y']]
            # 8-bit sources are always normalised by 255 instead of the recorded maximum.
            all_rmax = 255.0 if istat['uint8'] else istat['all_rmax']

            t_slice = slice(t-half_t, t+half_t+1) if half_t > 0 else t
            z_slice = slice(z-half_z, z+half_z+1) if half_z > 0 else z
            # build_index is defined elsewhere; presumably it orders these
            # per-axis selections according to proc_axes - verify.
            idx = build_index(
                proc_axes, {
                    'C': c,
                    'T': t_slice,
                    'Z': z_slice,
                    'X': slice(0, x),
                    'Y': slice(0, y)
                })
            img = img_data[idx].copy()
            img /= all_rmax
            # A single frame comes back 2-D; add a leading frame axis.
            if len(img.shape) <= 2: img = img[None]
            return img

        img_get = czi_get
        img_get._to_close = img_f
    else:
        pil_img = PIL.Image.open(fn)

        def pil_get(istat):
            """Return the frame(s) selected by ``istat`` as float32, scaled by ``all_rmax``."""
            z, t = istat['z'], istat['t']
            # Frames of the PIL image are addressed by one index: T has priority,
            # then Z, otherwise only frame 0 is read.
            if half_t > 0: n_start, n_end = t-half_t, t+half_t+1
            elif half_z > 0: n_start, n_end = z-half_z, z+half_z+1
            else: n_start, n_end = 0,1

            # 8-bit sources are always normalised by 255 instead of the recorded maximum.
            all_rmax = 255.0 if istat['uint8'] else istat['all_rmax']

            img_array = []
            for ix in range(n_start, n_end):
                pil_img.seek(ix)
                pil_img.load()
                img = np.array(pil_img)
                # Keep only the first channel of multi-channel frames.
                if len(img.shape) > 2: img = img[:,:,0]
                img_array.append(img.copy())

            img = np.stack(img_array)
            img = img.astype(np.float32)
            img /= all_rmax
            return img

        img_get = pil_get
        img_get._to_close = pil_img


    def puller(istat, tile_folder, crap_folder, close_me=False):
        """Save one random tile for ``istat`` and return its metadata (see outer docstring)."""
        if close_me:
            img_get._to_close.close()
            return None

        tile_id = istat['index']
        src = Path(istat['fn'])
        tile_sz = istat['tile_sz']
        out_stem = f'{tile_id:06d}_{src.stem}'

        raw_data = img_get(istat)
        # Rescale the normalised data to the 8-bit range.
        img_data = (np.iinfo(np.uint8).max * raw_data).astype(np.uint8)

        # thresh is the 2nd percentile of the pixel values; thresh_pct is 30% of
        # the fraction of pixels above it. Both are handed to draw_random_tile
        # (defined elsewhere; presumably used to reject mostly-empty tiles).
        thresh = np.percentile(img_data, 2)
        thresh_pct = (img_data > thresh).mean() * 0.30

        frame_count = img_data.shape[0]
        mid_frame = frame_count // 2
        # The high-resolution tile is always cut from the centre frame.
        crop_img, box = draw_random_tile(img_data[mid_frame], tile_sz, thresh, thresh_pct)
        crop_img.save(tile_folder/f'{out_stem}.tif')
        if crap_func and crap_folder:
            if frame_count > 1:
                # Multi-frame input: degrade the same box in every frame and
                # store the stack as one .npy file.
                crap_data = []
                for i in range(frame_count):
                    frame_img = img_data[i, box[0]:box[2], box[1]:box[3]]
                    crap_frame = crap_func(frame_img)
                    crap_data.append(np.array(crap_frame))
                multi_array = np.stack(crap_data)
                np.save(crap_folder/f'{out_stem}.npy', multi_array)
            else:
                crap_img = crap_func(crop_img)
                crap_img.save(crap_folder/f'{out_stem}.tif')

        # Start from the input row and add what was decided/measured for this tile.
        info = dict(istat)
        info['id'] = tile_id
        info['box'] = box
        info['tile_sz'] = tile_sz
        crop_data = np.array(crop_img)
        info['after_mean'] = crop_data.mean()
        info['after_sd'] = crop_data.std()
        info['after_max'] = crop_data.max()
        info['after_min'] = crop_data.min()
        return info

    return puller

def check_info(info, t_frames, z_frames):
    """Check that a full window of frames fits around the position described by ``info``.

    For each of the Z and T axes the axis must hold at least the requested
    number of frames (``info['nz']`` / ``info['nt']``), and the position
    (``info['z']`` / ``info['t']``) must be at least ``frames // 2`` away from
    both ends of the axis. The result is truthy only if both axes qualify.
    """
    def _window_fits(count_key, pos_key, n_frames):
        margin = n_frames // 2
        return (
            (info[count_key] >= n_frames)
            and (info[pos_key] >= margin)
            and (info[pos_key] < (info[count_key] - margin))
        )

    z_ok = _window_fits('nz', 'z', z_frames)
    t_ok = _window_fits('nt', 't', t_frames)
    return t_ok and z_ok

## **2.1. Specify your datasource - need your input**
---

In [None]:
# User settings for dataset generation. Every name below is read by the
# following cells, so each one must be a real assignment (`name = value`).
out = Path('datasets')  # dataset folder, Path
info = Path('live_mitotrakcer.csv')  # path of the metadata csv file, Path
tile = 512  # generated training tile size, int
n_train = 500  # number of train tiles, int
n_valid = 50  # number of validation tiles, int
crap_func = 'new_crap_AG_SP'  # crappifier name, str, check utils/crappifiers.py for more details
n_frames = 1  # number of frames, int, 1 if singleframe, >1 if multiframe, 5 for multiframe by default
lr_type = 's'  # (s)ingle, (t) multi or (z) multi, str, if multiframe, t if XYT time-lapse, z if XYZ 3D stack
scale = 4  # upsample factor, int
upsample = False  # if LR-Bilinear is needed to save to disk, boolean

In [None]:
# Turn the user settings from section 2.1 into the objects used by the cells below.
up = 'up' if upsample else ''
if lr_type not in ['s','t','z']:
    # `return` is only legal inside a function; at cell level it is a
    # SyntaxError, so stop the cell with an exception instead.
    raise ValueError('lr_type should be s, t or z')

# Frames to read along Z and T for the chosen low-resolution type.
if lr_type == 's':
    z_frames, t_frames = 1, 1
elif lr_type == 't':
    z_frames, t_frames = 1, n_frames
elif lr_type == 'z':
    z_frames, t_frames = n_frames, 1

# The settings may have been given as plain strings; `/` and `.stem` below need Paths.
out = Path(out)
info = Path(info)

# crap_func is still the crappifier *name* here, so it becomes part of the folder name.
out = ensure_folder(out/f'{lr_type}_{n_frames}_{info.stem}_{crap_func}')
# Always start from an empty output folder.
if out.exists(): shutil.rmtree(out)
out.mkdir(parents=True, mode=0o775, exist_ok=True)

# NOTE(review): eval() executes the configured string as Python code. That is
# acceptable for a value typed into this notebook, but do not feed it untrusted input.
crap_func = eval(crap_func)
if crap_func is not None:
    if not callable(crap_func):
        print('crap_func is not callable')
        crap_func = None
    else:
        crap_func = partial(crap_func, scale=scale, upsample=upsample)

# Load the metadata table and keep only rows with enough frames along Z and T.
info = pd.read_csv(info)
info = info.loc[info.nz >= z_frames]
info = info.loc[info.nt >= t_frames]

In [None]:
# `tile` may be a single size (as in the settings cell) or a list/tuple of
# sizes; normalise it so the loop below works in both cases.
tile_sizes = list(tile) if isinstance(tile, (list, tuple)) else [tile]

# Randomly choose which metadata row each requested tile is drawn from.
tile_infos = []
for mode, n_samples in [('train', n_train),('valid', n_valid)]:
    mode_info = info.loc[info.dsplit == mode]
    categories = list(mode_info.groupby('category'))
    if n_samples > 0 and not categories:
        raise ValueError(f"no rows with dsplit == '{mode}' in the metadata table")
    # category -> list of (fn, rows of that file); separate names avoid
    # shadowing the `info` table inside the comprehension.
    files_by_category = {cat: list(cat_rows.groupby('fn')) for cat, cat_rows in categories}

    for i in range(n_samples):
        # Sampling is hierarchical: a category, then a file in it, then a legal row.
        category, cat_df = random.choice(categories)
        fn, item_df = random.choice(files_by_category[category])
        legal_choices = [item_info for ix, item_info in item_df.iterrows() if check_info(item_info, t_frames, z_frames)]

        assert legal_choices, f'no row of {fn} leaves room for t_frames={t_frames}, z_frames={z_frames}'
        item_info = random.choice(legal_choices)
        for tile_sz in tile_sizes:
            item_d = dict(item_info)
            item_d['tile_sz'] = tile_sz
            tile_infos.append(item_d)

# reset_index() adds the 'index' column that the tile puller reads as the tile id.
tile_info_df = pd.DataFrame(tile_infos).reset_index()
print('num tile pulls:', len(tile_infos))
print(tile_info_df.groupby('category').fn.count())

In [None]:
# Pull every requested tile, keeping a single source file open at a time.
last_stat = None
tile_pull_info = []
tile_puller = None

multi_str = f'_{lr_type}_{n_frames}' if lr_type != 's' else ''
mbar = master_bar(tile_info_df.groupby('fn'))
for fn, tile_stats in mbar:
    for i, tile_stat in progress_bar(list(tile_stats.iterrows()), parent=mbar):
        try:
            mode = tile_stat['dsplit']
            category = tile_stat['category']
            tile_sz = tile_stat['tile_sz']
            tile_folder = ensure_folder(out / f'hr_t_{tile_sz}{multi_str}' / mode / category)
            if crap_func:
                crap_folder = ensure_folder(out / f'lr{up}_t_{tile_sz}{multi_str}' / mode / category)
            else: crap_folder = None

            if need_cache_flush(tile_stat, last_stat):
                if tile_puller:
                    tile_puller(None, None, None, close_me=True)
                # Forget the old reader before opening the new file. If opening
                # fails, later tiles of this file must not silently fall back to
                # the previous file's (already closed) reader.
                tile_puller = None
                last_stat = None
                tile_puller = get_tile_puller(tile_stat, crap_func, t_frames, z_frames)
                # Only remember the file once its reader really exists.
                last_stat = tile_stat.copy()
            tile_pull_info.append(tile_puller(tile_stat, tile_folder, crap_folder))
        except MemoryError:
            # some files are too big to read
            big_fn = Path(tile_stat['fn'])
            print(f'too big: {big_fn.stem}')

# Release the last file that is still open.
if tile_puller:
    tile_puller(None, None, None, close_me=True)
    tile_puller = None

# One row per tile that was written.
pd.DataFrame(tile_pull_info).to_csv(out/f'tiles{multi_str}.csv', index = False)