# Imports

In [12]:
# Automatically reload edited modules (e.g. the local denoiser / lsm_utils) before each cell
# runs, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
%reload_ext autoreload

In [14]:
from skimage import io
from skimage import img_as_uint
from glob import glob
import numpy as np
import tifffile

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import torch
torch.cuda.empty_cache()
#import torch
#torch.cuda.set_per_process_memory_fraction(0.5, 0)

from denoiser import Denoiser
from lsm_utils import normalize_16bit_images, compute_norm_range
import yaml

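# Load data
- SHG sample images (`cropped_sample_data`): compute their intensity normalization range
- Load the model config from the YAML file (`model_config.yaml`)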

In [15]:
def sandbox_compute_norm_range(fname, percentiles=(0, 100), sample_r=0.1, crop=256):
    """Estimate an intensity normalization range from a single image file.

    Single-file sandbox variant of ``compute_norm_range``: reads ``fname``,
    converts it to 16-bit, and takes the requested percentiles of its top-left
    ``crop`` x ``crop`` region.

    Args:
        fname: Path of the image to read.
        percentiles: ``(low, high)`` percentiles used for the lower / upper bound.
        sample_r: Unused; accepted only so calls look like ``compute_norm_range(...)``.
        crop: Side length of the top-left region that is sampled (default 256).

    Returns:
        ``(min_val, max_val, fail_names)``. If the file cannot be read, the error
        and file name are printed, ``fail_names`` is ``[fname]`` and both bounds
        are NaN (previously this path crashed on the unbound ``img``).
    """
    fail_names = []
    try:
        img = img_as_uint(io.imread(fname))
    except Exception as e:
        # Best effort: report and record the failure instead of raising.
        print(e)
        print(fname)
        fail_names.append(fname)
        return np.nan, np.nan, fail_names

    region = img[:crop, :crop]
    # With a single image there is nothing to aggregate across files, so the
    # per-image percentiles are the final bounds.
    max_val = np.percentile(region, percentiles[1])
    min_val = np.percentile(region, percentiles[0])
    return min_val, max_val, fail_names

# Dataset-wide intensity normalization range (vmin, vmax), used by the config and denoising cells below.
norm_data_dir = 'cropped_sample_data'
vmin, vmax, fail_names = compute_norm_range(norm_data_dir, ext='tif', percentiles=(1, 99.5), sample_r=1)
if fail_names:
    # presumably the files compute_norm_range failed to read — surface them instead of ignoring them
    print(f'Warning: {len(fail_names)} file(s) could not be read: {fail_names}')

cropped_sample_data\*.tif


100%|██████████| 1/1 [00:00<00:00, 143.23it/s]


# Create model instance
* Create a `Denoiser` instance from the updated config (dataset path and normalization range overridden)

In [16]:
# Load the base model config and override the run-specific entries.
with open("model_config.yaml", "r") as config_file:  # closes the file handle after parsing
    config = yaml.load(config_file, Loader=yaml.FullLoader)
config['dataset'] = 'cropped_sample_data'
# Normalization range computed from the dataset (vmin/vmax truncated to ints).
config['norm-range'] = [int(vmin), int(vmax)]
# NOTE(review): presumably the data-loader worker count (0 = load in the main process) — confirm in Denoiser.
config['threads'] = 0
denoiser = Denoiser(config, screen_bg=False)

Using cache found in C:\Users\lociuser/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


# Denoise

Denoise a single input image directly with the model backbone; no config file is needed for this step.

In [18]:
import warnings
from skimage import exposure
import torch.nn.functional as F

# --- Blind-spot inference settings ---
average_factor = 50    # expected number of times each pixel is masked (and predicted) overall
blindspot_rate = 0.05  # fraction of pixels masked in each image copy
batch_size = 50        # masked copies of the image per forward pass
mask_seed = 0          # seeds the random masks so re-running the cell reproduces the output
pass_times = int(1/blindspot_rate * average_factor)  # total masked copies needed
iterations = int(np.ceil(pass_times/batch_size))

model = denoiser.backbone
device = next(model.parameters()).device
# NOTE(review): the backbone's train/eval mode is not set here; if it contains BatchNorm/Dropout
# layers, the output depends on whatever mode Denoiser left it in — confirm the intended mode.
fname = "cropped_sample_data/PB522-14-MAX_Fused.tif"
output_root = "output-self/cropped_sample_data"

# Clip to the dataset range [vmin, vmax] and stretch it to the full 16-bit range.
# rescale_intensity already clips to [0, 65535], so casting straight to uint16 is lossless
# and avoids relying on skimage's (warning-emitting) int64 -> uint16 downcast when saving.
img_arr = img_as_uint(io.imread(fname))
img_arr = exposure.rescale_intensity(img_arr, in_range=(int(vmin), int(vmax)), out_range=(0, 65535)).astype(np.uint16)
img_input = exposure.rescale_intensity(img_arr, in_range=(0, 65535), out_range=(0, 1))
img_tensor = torch.from_numpy(img_input)
# Assumes a 2-D (H, W) image: build a (batch_size, 1, H, W) batch of identical copies.
img_hyper_tensor = img_tensor.expand([batch_size, 1, img_tensor.shape[0], img_tensor.shape[1]]).float().to(device)
# Masked pixels are filled with the image mean; it is identical for every copy, so compute it once.
img_mean = img_hyper_tensor.mean((2, 3), keepdim=True)
out_tensor = torch.zeros_like(img_tensor)

torch.manual_seed(mask_seed)
# Inference only: without no_grad, out_tensor would keep every iteration's autograd graph alive
# (when the backbone's parameters require grad), steadily growing GPU memory use.
with torch.no_grad():
    for _ in range(iterations):
        # Exact 0/1 mask: 0 marks a blind spot (probability blindspot_rate), 1 keeps the pixel.
        drop_mask = (torch.rand(img_hyper_tensor.shape, device=device) >= blindspot_rate).float()
        pad_mask = (1 - drop_mask) * img_mean
        spotted = img_hyper_tensor * drop_mask + pad_mask
        prediction = model(spotted)
        # Keep predictions only at the blind spots; dividing by the rate and averaging over the
        # batch normalizes by the expected number of times each pixel was masked.
        prediction = prediction * (1 - drop_mask) / blindspot_rate
        out_tensor += prediction.mean(0).squeeze().cpu() / iterations

img_name = os.path.basename(fname)
clean_dir = os.path.join(output_root, "clean")
noisy_dir = os.path.join(output_root, "noisy")
with warnings.catch_warnings():
    # Presumably silences skimage conversion / save warnings (e.g. low-contrast notices).
    warnings.simplefilter('ignore')
    out_arr = img_as_uint(np.clip(out_tensor.detach().numpy().squeeze(), 0, 1))
    os.makedirs(clean_dir, exist_ok=True)
    os.makedirs(noisy_dir, exist_ok=True)
    io.imsave(os.path.join(clean_dir, img_name), out_arr)
    io.imsave(os.path.join(noisy_dir, img_name), img_arr)
print('Processed', end='\r')

Processed