# Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext autoreload

In [2]:
from glob import glob
import os

import numpy as np
import tifffile
from skimage import img_as_uint, io

# PyTorch reads PYTORCH_CUDA_ALLOC_CONF when it sets up its CUDA caching
# allocator, so the variable is exported before torch is imported.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import torch

# Release cached-but-unused GPU memory held by the caching allocator
# (relevant when this cell is re-run in a live kernel).
torch.cuda.empty_cache()
# Optional: cap this process's share of GPU 0 with
# torch.cuda.set_per_process_memory_fraction(0.5, 0)

import yaml

from denoiser import Denoiser
from lsm_utils import compute_norm_range, normalize_16bit_images

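# Load data
- Compute the intensity normalization range from the SHG sample images in `sample_data/`.
- Load the model configuration from `model_config.yaml` and point it at this dataset.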

In [3]:
def sandbox_compute_norm_range(fname, percentiles=(0, 100), sample_r=0.1, crop=512):
    """Estimate an intensity normalization range from a single image file.

    Scratch, single-file variant whose signature mirrors
    ``lsm_utils.compute_norm_range``; handy for checking one image quickly.

    Parameters
    ----------
    fname : str
        Path of the image to read.
    percentiles : tuple of float, default (0, 100)
        ``(low, high)`` percentiles computed over the top-left tile.
    sample_r : float, default 0.1
        Unused; kept only so the signature matches ``compute_norm_range``.
    crop : int, default 512
        Edge length of the top-left ``crop`` x ``crop`` tile used for the
        statistics.

    Returns
    -------
    tuple
        ``(min_val, max_val, fail_names)``. ``min_val``/``max_val`` are the
        low/high percentile values of the tile (NumPy floats). If the file
        cannot be read, both are ``nan`` and ``fail_names == [fname]``;
        otherwise ``fail_names`` is empty.
    """
    fail_names = []
    try:
        img = img_as_uint(io.imread(fname))
    except Exception as e:
        print(e)
        print(fname)
        fail_names.append(fname)
        # Previously execution fell through to `img[...]` with `img` unbound
        # and raised UnboundLocalError; report the failure instead.
        return np.nan, np.nan, fail_names

    tile = img[:crop, :crop]
    # One image -> one value per bound, so the former 2nd/98th-percentile
    # aggregation over a one-element list was an identity and is dropped.
    min_val, max_val = np.percentile(tile, [percentiles[0], percentiles[1]])
    return min_val, max_val, fail_names

# Dataset-wide normalization range for the .tif files in sample_data, from the
# (1, 99.5) intensity percentiles (sample_r=1: presumably every file is sampled).
# For a quick single-file check, sandbox_compute_norm_range(path,
# percentiles=(1, 99.5)) can be used instead.
vmin, vmax, fail_names = compute_norm_range('sample_data', ext='tif', percentiles=(1, 99.5), sample_r=1)
# fail_names presumably lists files that could not be read; surface them
# rather than silently ignoring them.
if fail_names:
    print(f'Could not read {len(fail_names)} file(s): {fail_names}')

sample_data\*.tif


100%|██████████| 1/1 [00:00<00:00,  2.63it/s]


In [5]:
# Start from the saved model configuration; `with` closes the file handle
# (the bare open() call previously leaked it).
with open("model_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Point the model at this dataset and the range computed above.
config['dataset'] = 'sample_data'
# int() truncates the percentile values to plain Python ints.
config['norm-range'] = [int(vmin), int(vmax)]
# NOTE(review): presumably the number of data-loading worker processes,
# 0 keeping loading in the main process — TODO confirm in Denoiser.
config['threads'] = 0

# Create the model
* Instantiate `Denoiser` with the updated configuration (dataset path, normalization range, thread count).

In [6]:
# Background screening is turned off so that PB522-14-MAX_Fused.tif (the
# image processed further down) is accepted rather than screened out.
screen_background = False
denoiser = Denoiser(config, screen_bg=screen_background)

Using cache found in C:\Users\lociuser/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


In [16]:
# The denoiser only handles grayscale (single-channel) images.
# NOTE(review): sampling=True with sample_rate=0.1 presumably denoises only
# about 10% of the dataset as a quick check — TODO confirm in Denoiser.denoise.
denoiser.denoise(sample_rate=0.1, sampling=True)

In [4]:
# Re-run the whole setup (normalization range -> config -> model) in one cell.
vmin, vmax, fail_names = compute_norm_range('sample_data', ext='tif', percentiles=(1, 99.5), sample_r=1)
# fail_names presumably lists files that could not be read; surface them.
if fail_names:
    print(f'Could not read {len(fail_names)} file(s): {fail_names}')

# `with` closes the config file (the bare open() call previously leaked it).
with open("model_config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
config['dataset'] = 'sample_data'
config['norm-range'] = [int(vmin), int(vmax)]  # int() truncates
config['threads'] = 0
# Background screening off so PB522-14-MAX_Fused.tif is accepted.
denoiser = Denoiser(config, screen_bg=False)

sample_data\*.tif


100%|██████████| 1/1 [00:00<00:00,  2.50it/s]
Using cache found in C:\Users\lociuser/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


In [6]:
import warnings
from skimage import exposure
import torch.nn.functional as F

# Manual blind-spot inference on one image, bypassing Denoiser.denoise.
# Each iteration pushes BATCH copies of the image through the backbone. In
# every copy a random fraction p of pixels is replaced by the image mean, and
# the prediction is kept only at those replaced ("blind") pixels.
BATCH = 50  # masked copies per forward pass
CROP = 512  # only the top-left CROP x CROP tile is processed
IMG_PATH = "sample_data/PB522-14-MAX_Fused.tif"
OUT_ROOT = "output-self/sample_data"

p = config['blindspot-rate']
pass_times = int(1 / p * config['average-factor'])
iterations = int(np.ceil(pass_times / BATCH))
model = denoiser.backbone
device = next(model.parameters()).device

# Stretch the crop to full uint16 using the dataset-wide range, then scale
# to [0, 1] for the network. Keeping uint16 means the saved "noisy" image is
# already in its target dtype.
norm_lo, norm_hi = config['norm-range']
img_arr = img_as_uint(io.imread(IMG_PATH))[:CROP, :CROP]
img_arr = exposure.rescale_intensity(
    img_arr, in_range=(norm_lo, norm_hi), out_range=(0, 65535)
).astype(np.uint16)
img_tensor = torch.from_numpy(img_arr.astype(np.float32) / 65535).to(device)
height, width = img_tensor.shape
# expand() returns a view: the BATCH copies share one buffer and are only read.
img_batch = img_tensor.expand(BATCH, 1, height, width)
# All copies are the same image, so the fill value is loop-invariant.
fill_value = img_tensor.mean()

# NOTE(review): the backbone runs in whatever train/eval mode Denoiser left it
# in; if it contains BatchNorm/Dropout layers, model.eval() may be wanted —
# TODO confirm against Denoiser.denoise.
out_tensor = torch.zeros((height, width), dtype=torch.float32)
# no_grad: without it, if the backbone's parameters require grad,
# `out_tensor +=` drags every iteration's autograd graph into out_tensor
# (unbounded memory growth) and out_tensor.numpy() raises below.
with torch.no_grad():
    for _ in range(iterations):
        # Exact 0/1 mask: 0 (blind) with probability p, 1 (kept) otherwise.
        keep = (torch.rand(img_batch.shape, device=device) >= p).float()
        blind = 1.0 - keep
        spotted = img_batch * keep + blind * fill_value
        # Keep predictions only at blind pixels; /p compensates for each pixel
        # being blind in only a fraction p of the copies.
        prediction = model(spotted) * blind / p
        out_tensor += prediction.mean(dim=0).squeeze().cpu() / iterations

img_name = os.path.basename(IMG_PATH)
clean_dir = os.path.join(OUT_ROOT, "clean")
noisy_dir = os.path.join(OUT_ROOT, "noisy")
os.makedirs(clean_dir, exist_ok=True)
os.makedirs(noisy_dir, exist_ok=True)
with warnings.catch_warnings():
    # Silence skimage dtype-conversion / low-contrast warnings on save.
    warnings.simplefilter('ignore')
    out_arr = img_as_uint(np.clip(out_tensor.numpy(), 0, 1))
    io.imsave(os.path.join(clean_dir, img_name), out_arr)
    io.imsave(os.path.join(noisy_dir, img_name), img_arr)
print(f'Processed {img_name}')

KeyboardInterrupt: 