# **Point Scanning Super Resolution (PSSR) - Inference**

# **1. Preparation: Set the Runtime type and mount your Google Drive**
---

**1.1 Set the Runtime type**
---

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*

<font size = 4>**Accelator: GPU** *(Graphics processing unit (GPU)*

In [1]:
#Run this cell to check if you have GPU access
%tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime settings are correct then Google did not allocate GPU to your session')
  print('Expect slow performance. To access GPU try reconnecting later')
else:
  print('You have GPU access')

from tensorflow.python.client import device_lib 
device_lib.list_local_devices()

TensorFlow 1.x selected.
You have GPU access


[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 18061946578985755917, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 9307723627770481009
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 15446412096300437072
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 11330115994
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 13295573838020484712
 physical_device_desc: "device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7"]

## **1.2 Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of notebook.

In [2]:
# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


## **1.3 Install PSSR and dependencies**
---

In [3]:
!pip install czifile



## **1.4 Specify your working folder - need your input**
---

In [4]:
root_path = "gdrive/My Drive/PSSR-master"

# **2. PSSR Inference**
---

In [5]:
import sys
sys.path.insert(1, root_path)
from fastai.vision import *
from utils import *
import PIL.Image
import czifile
import imageio
from fastprogress import master_bar, progress_bar 

In [6]:
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(0)

## **2.1 Specify your model and test data - need your input**
---

In [None]:
# Modify accordingly
model_name = 's_1_mito_50_512'
testset_name = 'real-world_mitotracker'
use_tiles = False
mode = 'L' #Param("L or RGBA", str)

## **2.2 Prepare for inference**
---

In [None]:
# Prepare model
test_path = Path(root_path)/'stats'
model_name = model_name
model_dir = test_path/'models'
print(f'{model_name} model is being used.')

In [None]:
# Prepare data
testset_name = testset_name
src_dir = test_path/'LR'/testset_name
out_dir = test_path/'LR-PSSR'/testset_name
out_dir = ensure_folder(out_dir)

## **2.3 Inference**
---

In [7]:
def process_tif(fn, processor, proc_func, out_fn, n_depth=1, n_time=1, mode='L'):
    with PIL.Image.open(fn) as img_tif:
        n_frame = max(n_depth, n_time)
        offset_frames = n_frame // 2

        if n_frame > img_tif.n_frames:
            if img_tif.n_frames == 1:
                times = n_frame
                img_tif = np.array(img_tif)
                data = np.repeat(img_tif[None],5,axis=0).astype(np.float32)
            else:
                return []
        else:
            times = img_tif.n_frames
            img_tifs = []
            for i in range(times):
                img_tif.seek(i)
                img_tif.load()
                img_tifs.append(np.array(img_tif).copy())
            data = np.stack(img_tifs).astype(np.float32)

        data, img_info = img_to_float(data)
        img_tiffs = []
        time_range = list(range(offset_frames, times - offset_frames))
        for t in progress_bar(time_range):
            time_slice = slice(t-offset_frames, t+offset_frames+1)
            img = data[time_slice].copy()
            pred_img = proc_func(img, img_info=img_info, mode=mode)
            pred_img8 = (pred_img * np.iinfo(np.uint8).max).astype(np.uint8)
            img_tiffs.append(pred_img8[None])

        imgs = np.concatenate(img_tiffs)
        if processor!='bilinear':
            fldr_name = f'{out_fn.parent}/{processor}'
        else:
            fldr_name = out_fn.parent.parent.parent/processor/out_fn.parent.stem
        save_name = f'{fn.stem}_{processor}.tif'
        out_fldr = ensure_folder(out_fn.parent/processor)

        if imgs.size < 4e9:
            imageio.mimwrite(out_fldr/save_name, imgs)
        else:
            imageio.mimwrite(out_fldr/save_name, imgs, bigtiff=True)

In [8]:
def process_czi(fn, processor, proc_func, out_fn, n_depth=1, n_time=1, mode='L'):
    stats = []
    with czifile.CziFile(fn) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        channels = proc_shape['C']
        depths = proc_shape['Z']
        times = proc_shape['T']
        x, y = proc_shape['X'], proc_shape['Y']

        data = czi_f.asarray().astype(np.float32)
        data, img_info = img_to_float(data)

        if depths < n_depth: return
        if times < n_time: return

        if n_depth > 1: # this is busted
            offset_frames = n_depth // 2
            for c in range(channels):
                for t in range(times):
                    for z in range(offset_frames, depths - offset_frame):
                        depth_slice = slice(z-offset_frames, z+offset_frame+1)
                        idx = build_index(
                            proc_axes, {
                                'T': t,
                                'C': c,
                                'Z': depth_slice,
                                'X': slice(0, x),
                                'Y': slice(0, y)
                        })
                        img = data[idx].copy()
                        tag = f'{c}_{t}_{z+offset_frames}_'

                        save_name = f'{proc_name}_{item.stem}_{tag}'

                        pred_img = proc_func(img, img_info=img_info, mode=mode)
                        pred_img8 = (pred_img * np.iinfo(np.uint8).max).astype(np.uint8)
                        PIL.Image.fromarray(pred_img8).save(out_fn)
        elif n_time > 1:
            offset_frames = n_time // 2
            for c in range(channels):
                for z in range(depths):
                    imgs = []
                    time_range = list(range(offset_frames, times - offset_frames))
                    for t in progress_bar(time_range):
                        time_slice = slice(t-offset_frames, t+offset_frames+1)
                        idx = build_index(
                            proc_axes, {
                                'T': time_slice,
                                'C': c,
                                'Z': z,
                                'X': slice(0, x),
                                'Y': slice(0, y)
                        })
                        img = data[idx].copy()
                        pred_img = proc_func(img, img_info=img_info, mode=mode)
                        pred_img8 = (pred_img * np.iinfo(np.uint8).max).astype(np.uint8)
                        imgs.append(pred_img8[None])

                    all_y = np.concatenate(imgs)
                    if processor!='bilinear':
                        fldr_name = f'{out_fn.parent}/{processor}'
                    else:
                        fldr_name = out_fn.parent.parent.parent/processor/out_fn.parent.stem
                    save_name = f'{fn.stem}_{processor}.tif'
                    if c > 1 or z > 1:
                        fldr_name = fldr_name/f'{c}_{z}'
                    out_fldr = ensure_folder(fldr_name)

                    if all_y.size < 4e9:
                        imageio.mimwrite(out_fldr/save_name, all_y)
                    else:
                        imageio.mimwrite(out_fldr/save_name, all_y, bigtiff=True)
        else:
            imgs = []
            for c in range(channels):
                for z in range(depths):
                    for t in range(times):
                        idx = build_index(
                            proc_axes, {
                                'T': t,
                                'C': c,
                                'Z': z,
                                'X': slice(0, x),
                                'Y': slice(0, y)
                        })
                        img = data[idx].copy()
                        pred_img = proc_func(img, img_info=img_info, mode=mode)
                        pred_img8 = (pred_img * np.iinfo(np.uint8).max).astype(np.uint8)
                        imgs.append(pred_img8[None])
            all_y = np.concatenate(imgs)
            if processor!='bilinear':
                fldr_name = f'{out_fn.parent}/{processor}'
            else:
                fldr_name = out_fn.parent.parent.parent/processor/out_fn.parent.stem
            save_name = f'{fn.stem}_{processor}.tif'
            out_fldr = ensure_folder(fldr_name)

            if all_y.size < 4e9:
                imageio.mimwrite(out_fldr/save_name, all_y)
            else:
                imageio.mimwrite(out_fldr/save_name, all_y, bigtiff=True)


In [9]:
def process_files(src_dir, out_dir, model_dir, processor, mode, use_tiles):
    proc_map = {
        '.tif': process_tif,
        '.czi': process_czi
    }
    proc_func, num_chan = get_named_processor(processor, model_dir, use_tiles)
    src_files = list(src_dir.glob('**/*.czi'))
    src_files += list(src_dir.glob('**/*.tif'))

    for fn in progress_bar(src_files):
        out_fn = out_dir/fn.relative_to(src_dir)
        ensure_folder(out_fn.parent)
        file_proc = proc_map.get(fn.suffix, None)
        if file_proc:
            n_depth = n_time = 1
            if 'z_' in processor: n_depth = num_chan
            if 't_' in processor: n_time = num_chan
            print('File being processed: ', fn)
            file_proc(fn, processor, proc_func, out_fn, n_depth=n_depth, n_time=n_time, mode=mode)

In [17]:
process_files(src_dir, out_dir, model_dir, model_name, mode, use_tiles)
print('All done!')

File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_7_LR.tif




File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_2_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_10_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_9_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_1_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_8_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_4_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_6_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_5_LR.tif


File being processed:  gdrive/My Drive/PSSR_0625/stats/LR/real-world_mitotracker/realworld_lowres_lowpower_3_LR.tif


All done!
