# **Point Scanning Super Resolution (PSSR) - Training**

# **1. Preparation: Set the Runtime type and mount your Google Drive**
---

## **1.1. Set the Runtime type**
---

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*

In [None]:
#Run this cell to check if you have GPU access
%tensorflow_version 2.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime settings are correct then Google did not allocate GPU to your session')
  print('Expect slow performance. To access GPU try reconnecting later')
else:
  print('You have GPU access')

from tensorflow.python.client import device_lib 
device_lib.list_local_devices()

TensorFlow 2.x selected.
You have GPU access


[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 10765735311409383637, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 2317427044891050265
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 11729913141592060865
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 15956161332
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 6927160916135005224
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"]

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
# Attach the user's Google Drive to this Colab runtime under /content/gdrive.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


## **1.3. Install PSSR and dependencies**
---

In [None]:
!pip install czifile
!pip install fastai==1.0.61

Collecting czifile
  Downloading https://files.pythonhosted.org/packages/37/86/3d0b1829c8c24eb1a4214f098a02442209f80302766203db33c99a4681ec/czifile-2019.7.2-py2.py3-none-any.whl
Installing collected packages: czifile
Successfully installed czifile-2019.7.2


## **1.4. Specify your working folder - need your input**
---

In [None]:
# Absolute location of the PSSR project folder inside the mounted Google Drive
# (the Drive is mounted at /content/gdrive). An absolute path keeps the imports
# and the final CSV move working even if the working directory changes.
# Edit the last component if your folder has a different name.
root_path = "/content/gdrive/My Drive/PSSR-master"

# **2. PSSR - Get to know your training data**
---

In [None]:
# Put the project folder on the module search path
# (presumably this is where the `utils` module imported below lives -- confirm).
import sys
sys.path.insert(1, root_path)
# NOTE(review): the star imports below are presumably what bring `np` and `pd`
# (fastai) and `subfolders`, `get_czi_shape_info`, `build_index` (utils) into
# scope -- later cells use those names without importing them explicitly.
# Their order may matter for name shadowing, so do not reorder without checking.
from fastai import *
from fastai.vision import *
from utils import *
from pathlib import Path
from fastprogress import master_bar, progress_bar
from time import sleep
import shutil
import PIL
import czifile
# Set Pillow's pixel-count limit to a very large value
# (presumably so that very large microscopy images can be opened -- confirm).
PIL.Image.MAX_IMAGE_PIXELS = 99999999999999

## **2.1. Specify your datasource - need your input**
---

In [None]:
# Folders that are scanned for datasource sub-folders
# (each entry is wrapped in Path(...) and passed to subfolders()).
sources = ['datasources']
# Name of the CSV that receives the collected metadata; after writing, the
# file is moved into root_path.
output_file = 'live_mitotracker.csv'
# Sub-folder selection, matched against each sub-folder's stem:
#   only - keep the matching sub-folders; consulted only when `skip` is empty.
#   skip - drop the matching sub-folders; when non-empty it takes precedence
#          over `only`.
# Leave both empty to use every sub-folder.
only = 'mitotracker'
skip = ''

In [None]:
# Collect the datasource sub-folders to process.
#
# `skip` and `only` may each be a single folder name (str) or a collection of
# names. A bare string is wrapped in a list so that the `in` tests below compare
# whole folder names instead of doing a substring match (otherwise a folder
# called 'mito' would be selected by only='mitotracker').
skip_names = [skip] if isinstance(skip, str) and skip else list(skip or [])
only_names = [only] if isinstance(only, str) and only else list(only or [])

src_dirs = []
for src in sources:
    sub_fldrs = subfolders(Path(src))
    if skip_names:
        # `skip` wins over `only` when both are given.
        src_dirs += [fldr for fldr in sub_fldrs if fldr.stem not in skip_names]
    elif only_names:
        src_dirs += [fldr for fldr in sub_fldrs if fldr.stem in only_names]
    else:
        src_dirs += sub_fldrs

In [None]:
def process_czi(item, category, mode):
    """Collect per-frame intensity statistics for one .czi file.

    Only the first channel (index 0) is processed, since those are the only
    mitotracker channels. Along Z, only a window around the middle slice
    (mid-2 .. mid+1, clipped to the stack) is used; every time point is used.

    Args:
        item: path of the .czi file; opened with czifile.CziFile and stored
            in each record under 'fn'.
        category: stored in each record under 'category'.
        mode: stored in each record under 'dsplit'.

    Returns:
        A list with one dict per processed (channel, z, t) frame holding the
        frame statistics ('mi', 'ma', 'rmax', 'mean', 'sd'), the whole-file
        statistics ('all_rmax', 'all_mi', 'all_ma'), the file dimensions
        ('nc', 'nz', 'nt', 'x', 'y') and the frame coordinates ('c', 'z', 't').
    """
    tif_srcs = []
    with czifile.CziFile(item) as czi_f:
        # Read the pixel data once; the same array feeds the whole-file
        # statistics and the per-frame slices below.
        data = czi_f.asarray()
        axes, shape = get_czi_shape_info(czi_f)
        channels = shape['C']
        depths = shape['Z']
        times = shape['T']
        #times = min(times, 30) #ONLY USE FIRST 30 frames
        x,y = shape['X'], shape['Y']

        mid_depth = depths // 2
        depth_range = range(max(0,mid_depth-2), min(depths, mid_depth+2))
        is_multi = (times > 1) or (depths > 1)

        # Whole-file statistics (computed over every channel/slice/time point).
        all_rmax = data.max()
        all_mi, all_ma = np.percentile(data, [2,99.99])

        dtype = data.dtype
        is_uint8 = dtype == np.uint8
        #for channel in range(channels): #if other channels are needed, use this line
        for channel in range(0,1):
            for z in depth_range:
                for t in range(times):
                    idx = build_index(
                        axes, {
                            'T': t,
                            'C': channel,
                            'Z': z,
                            'X': slice(0, x),
                            'Y': slice(0, y)
                        })
                    img = data[idx]
                    mi, ma = np.percentile(img, [2,99.99])
                    # 8-bit data uses the full-scale value; otherwise the
                    # frame's own maximum is recorded.
                    if is_uint8: rmax = 255.
                    else: rmax = img.max()
                    tif_srcs.append({'fn': item, 'ftype': 'czi', 'multi':int(is_multi), 'category': category, 'dsplit': mode,
                                     'uint8': is_uint8, 'mi': mi, 'ma': ma, 'rmax': rmax,
                                     'all_rmax': all_rmax, 'all_mi': all_mi, 'all_ma': all_ma,
                                     'mean': img.mean(), 'sd': img.std(),
                                     'nc': channels, 'nz': depths, 'nt': times,
                                     'z': z, 't': t, 'c':channel, 'x': x, 'y': y})
    return tif_srcs

def is_live(item):
    """Tell whether `item` is stored under a 'live' folder.

    The check looks at the third-from-last component of the file's parent
    directory, i.e. it matches paths shaped like
    ``.../live/<category>/<split>/<file>``.

    Args:
        item: a pathlib path to a file.

    Returns:
        True when that component is exactly 'live'; False otherwise,
        including when the path is too shallow to have such a component
        (instead of raising IndexError).
    """
    parent_parts = item.parent.parts
    return len(parent_parts) >= 3 and parent_parts[-3] == 'live'

def process_tif(item, category, mode):
    """Collect per-frame intensity statistics for one (multi-page) TIFF file.

    Every frame of the file is read. For files under a 'live' folder (see
    is_live) the frames are recorded as time points (t = frame index, z = 0);
    otherwise they are recorded as Z slices (z = frame index, t = 0).

    Args:
        item: path of the TIFF file; stored in each record under 'fn'.
        category: stored in each record under 'category'.
        mode: stored in each record under 'dsplit'.

    Returns:
        A list with one dict per frame holding the frame statistics
        ('mi', 'ma', 'rmax', 'mean', 'sd'), the whole-file statistics
        ('all_rmax', 'all_mi', 'all_ma'), the dimensions ('nc', 'nz', 'nt',
        'x', 'y') and the frame coordinates ('c', 'z', 't').
    """
    tif_srcs = []
    # The context manager closes the underlying file once all frames are read.
    with PIL.Image.open(item) as img:
        n_frames = img.n_frames
        x,y = img.size
        #n_frames = min(n_frames, 30) #ONLY USE FIRST 30 frames

        frames = []
        for n in range(n_frames):
            img.seek(n)
            img.load()
            frames.append(np.array(img))
    is_multi = n_frames > 1

    data = np.stack(frames)
    # Whole-file statistics (computed over every frame).
    all_rmax = data.max()
    all_mi, all_ma = np.percentile(data, [2,99.99])

    # Same answer for every frame of the file, so decide it once.
    live = is_live(item)

    for n in range(n_frames):
        img_data = data[n]
        dtype = img_data.dtype
        mi, ma = np.percentile(img_data, [2,99.99])
        # 8-bit data uses the full-scale value; otherwise the frame's own
        # maximum is recorded.
        if dtype == np.uint8: rmax = 255.
        else: rmax = img_data.max()
        if live:
            t, z = n, 0
            nt, nz = n_frames, 1
        else:
            t, z = 0, n
            nt, nz = 1, n_frames

        tif_srcs.append({'fn': item, 'ftype': 'tif', 'multi':int(is_multi), 'category': category, 'dsplit': mode,
                         'uint8': dtype==np.uint8, 'mi': mi, 'ma': ma, 'rmax': rmax,
                         'all_rmax': all_rmax, 'all_mi': all_mi, 'all_ma': all_ma,
                         'mean': img_data.mean(), 'sd': img_data.std(),
                         'nc': 1, 'nz': nz, 'nt': nt,
                         'z': z, 't': t, 'c':0, 'x': x, 'y': y})
    return tif_srcs

def process_unk(item, category, mode):
    """Fallback handler for unsupported file types.

    Reports the file on stdout and contributes no records; `category` and
    `mode` are accepted only to match the other handlers' signature.
    """
    message = f"**** Unknown: {item}"
    print(message)
    return list()

def process_item(item, category, mode):
    """Dispatch one file to the handler for its extension.

    Files of the 'test' split are skipped. The extension is matched
    case-insensitively, so '.TIF' / '.Tiff' / '.CZI' are handled like their
    lower-case forms instead of being reported as unknown.

    Args:
        item: pathlib path of the file.
        category: forwarded to the handler.
        mode: dataset split name; 'test' short-circuits to an empty result.

    Returns:
        The handler's list of records, or [] for the 'test' split, for
        unknown extensions, and when the handler raises (best effort: the
        error is printed and the file is skipped).
    """
    if mode == 'test':
        return []
    try:
        item_map = {
            '.tif': process_tif,
            '.tiff': process_tif,
            '.czi': process_czi,
        }
        map_f = item_map.get(item.suffix.lower(), process_unk)
        return map_f(item, category, mode)
    except Exception as ex:
        # Deliberately best-effort: one unreadable file must not abort the scan.
        print(f'err procesing: {item}')
        print(ex)
        return []

def build_tifs(src, mbar=None):
    """Gather the metadata records of every file in one datasource folder.

    Looks for the sub-folders 'train', 'valid' and 'test' inside `src` and
    passes each entry found there to process_item, using the folder's stem
    as the category.

    Args:
        src: pathlib path of the datasource folder.
        mbar: optional parent progress bar; when given, its child bar is
            labelled with the split being processed. May be None.

    Returns:
        The concatenated list of records from all splits (possibly empty).
    """
    tif_srcs = []
    category = src.stem
    for mode in ['train', 'valid', 'test']:
        src_dir = src / mode
        # Missing split folders (or a stray file with that name) are skipped.
        items = list(src_dir.iterdir()) if src_dir.is_dir() else []
        if not items:
            continue
        for p in progress_bar(items, parent=mbar):
            # Only label the child bar when a parent bar was supplied.
            if mbar is not None:
                mbar.child.comment = mode
            tif_srcs += process_item(p, category=category, mode=mode)
    return tif_srcs

In [None]:
#pull metadata from datasources
mbar = master_bar(src_dirs)
tif_srcs = []
for src in mbar:
    mbar.write(f'process {src.stem}')
    tif_srcs += build_tifs(src, mbar=mbar)

In [None]:
#save csv to disk
tif_src_df = pd.DataFrame(tif_srcs)
tif_src_df[['category','dsplit','multi','ftype','uint8','mean','sd','all_rmax','all_mi','all_ma','mi','ma','rmax','nc','nz','nt','c','z','t','x','y','fn']].to_csv(output_file, header=True, index=False)
shutil.move(output_file, f'{root_path}/{output_file}')