In [None]:
import os
from disentangle.data_loader.evaluation_dloader import EvaluationDloader
from disentangle.data_loader.patch_index_manager import GridAlignement
from disentangle.nets.model_utils import create_model
from nis2pyr.reader import read_nd2file
from disentangle.config_utils import load_config
import nd2
import glob
import torch
import numpy as np


def load_7D(fpath):
    """Read an ND2 file and return whatever ``read_nd2file`` produces for it.

    The ``nd2.ND2File`` handle is managed by a ``with`` block, so it is closed
    before the function returns.
    """
    print(f'Loading from {fpath}')
    with nd2.ND2File(fpath) as handle:
        return read_nd2file(handle)

def get_best_checkpoint(ckpt_dir):
    """Return the path of the single ``*_best.ckpt`` file inside ``ckpt_dir``.

    Args:
        ckpt_dir: Directory holding the training checkpoints.

    Returns:
        Path (str) of the unique file matching ``<ckpt_dir>/*_best.ckpt``.

    Raises:
        AssertionError: if zero or more than one matching file is found. It is
            raised explicitly (not via ``assert``) so the check survives
            ``python -O``, and the message names the directory and the count.
    """
    # Escape the directory so characters such as '[' or '*' in a real path are
    # matched literally rather than being treated as glob wildcards.
    pattern = os.path.join(glob.escape(ckpt_dir), "*_best.ckpt")
    matches = sorted(glob.glob(pattern))
    if len(matches) != 1:
        found = '\n'.join(matches) if matches else '<none>'
        raise AssertionError(
            f'Expected exactly one "*_best.ckpt" file in {ckpt_dir!r}, '
            f'found {len(matches)}:\n{found}')
    return matches[0]

In [None]:
# ckpt_dir = "/group/jug/ashesh/training/disentangle/2407/D28-M3-S0-L0/16"
ckpt_dir = "/group/jug/ashesh/training/disentangle/2407/D28-M3-S0-L0/20"
data_dir = '/facility/imganfacusers/Elisa/DIF17/DIF_17_1'

# Candidate inputs: files ending in '0001.nd2', in sorted order; entry 16 is analysed.
fnames = sorted(name for name in os.listdir(data_dir) if name.endswith('0001.nd2'))
datafile = os.path.join(data_dir, fnames[16])

batch_size = 8

In [None]:
# Load the full ND2 contents, then keep a 3D sub-array of it.
data = load_7D(datafile)
# NOTE(review): the meaning of each axis is not visible here. This keeps index 0
# on axes 0 and 1, all of axis 2, index 1 on axis 3 and index 0 on the last axis;
# later cells index the result as [z, row, col]. TODO confirm the axis layout.
# If read_nd2file returns a NumPy array this is a view, so the full buffer stays alive.
data = data[0,0,:,1,...,0]

In [None]:
# Experiment config for the checkpoint; later cells read config.data.background_values
# and config.data.image_size from it.
config = load_config(ckpt_dir)

In [None]:
# One z-slice, kept as a length-1 leading axis, cast to float32 and with the
# first configured background value subtracted in place.
test_z_idx = 8
test_data = data[test_z_idx:test_z_idx + 1].astype(np.float32)  # astype returns a new array; `data` is untouched
test_data -= config.data.background_values[0]

In [None]:
# Shape of the selected sub-array (displayed as the cell output).
data.shape

In [None]:
import matplotlib.pyplot as plt

# Quick visual check of a sub-region of the slice used for inference (raw
# intensities, before background subtraction). Uses test_z_idx instead of a
# hard-coded index so the preview follows the chosen slice.
plt.imshow(data[test_z_idx][1500:2500, 1000:1500], vmax=130);

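## Normalization statistics (mean / standard deviation)

Estimate the input mean and standard deviation from random crops of the test slice; these statistics are passed to `create_model`.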

In [None]:
from disentangle.data_loader.multicrops_dset import l2
def sample_crop(sz, data=None):
    """Return a random ``sz x sz`` crop from a random frame of a 3D stack.

    Args:
        sz: Side length of the square crop, in pixels.
        data: Array indexed as ``data[t, row, col]``. Defaults to the
            notebook-level ``test_data`` so existing calls keep working.

    Returns:
        ``data[t, x:x+sz, y:y+sz]`` for uniformly sampled valid ``t, x, y``.

    Raises:
        ValueError: (from ``np.random.randint``) if ``data`` is empty or ``sz``
            exceeds a spatial dimension.
    """
    if data is None:
        data = test_data
    t = np.random.randint(0, len(data))
    # randint excludes its upper bound; +1 makes the last valid start position
    # (shape - sz) reachable and allows a crop as large as the frame.
    x = np.random.randint(0, data.shape[1] - sz + 1)
    y = np.random.randint(0, data.shape[2] - sz + 1)
    return data[t, x:x + sz, y:y + sz]

def compute_mean_std(num_samples=30000, crop_size=None):
    """Estimate input/target normalization statistics from random crops.

    Draws ``num_samples`` random square crops with ``sample_crop`` and
    aggregates their per-crop means and standard deviations.

    Args:
        num_samples: Number of random crops (default 30000, the previous
            hard-coded value). Crops are drawn from the global NumPy RNG.
        crop_size: Crop side length; defaults to ``config.data.image_size``.

    Returns:
        ``(output_mean, output_std)`` dicts with keys ``'input'`` and
        ``'target'``. ``'input'`` arrays have shape ``(1, 1, 1, 1)``; the
        ``'target'`` arrays repeat that statistic over 2 channels,
        shape ``(1, 2, 1, 1)``.

    Raises:
        ValueError: if ``num_samples < 1`` (the mean of zero crops is undefined).
    """
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')
    if crop_size is None:
        crop_size = config.data.image_size

    crop_means = []
    crop_stds = []
    for _ in range(num_samples):
        crop = sample_crop(crop_size)
        crop_means.append(np.mean(crop))
        crop_stds.append(np.std(crop))

    input_mean = np.array([np.mean(crop_means)]).reshape(-1, 1, 1, 1)
    # Per-crop stds are combined with `l2` from multicrops_dset (formula defined there).
    input_std = np.array([l2(crop_stds)]).reshape(-1, 1, 1, 1)

    output_mean = {'input': input_mean,
                   'target': np.tile(input_mean, (1, 2, 1, 1))}
    output_std = {'input': input_std,
                  'target': np.tile(input_std, (1, 2, 1, 1))}
    return output_mean, output_std


In [None]:
# Normalization statistics from random crops of test_data. The crops come from
# the unseeded global NumPy RNG, so the values vary slightly between runs.
mean_dict, std_dict = compute_mean_std()

In [None]:
# Build the model with the estimated statistics. dict.copy() is shallow:
# create_model gets its own dicts, but the NumPy arrays inside are shared
# with mean_dict / std_dict.
model = create_model(config, mean_dict.copy(),std_dict.copy())

In [None]:
ckpt_fpath = get_best_checkpoint(ckpt_dir)
print('Loading checkpoint from', ckpt_fpath)
checkpoint = torch.load(ckpt_fpath)

# strict=False tolerates key mismatches between checkpoint and model, but any
# mismatched parameter silently keeps the value it had before loading. Report
# the skipped keys instead of discarding load_state_dict's result.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
missing_keys = getattr(load_result, 'missing_keys', [])
unexpected_keys = getattr(load_result, 'unexpected_keys', [])
if missing_keys:
    print(f'WARNING: {len(missing_keys)} model keys not found in checkpoint:', missing_keys)
if unexpected_keys:
    print(f'WARNING: {len(unexpected_keys)} checkpoint keys not used by the model:', unexpected_keys)
model.eval()
_= model.cuda()
# NOTE(review): presumably moves non-parameter tensors held by the model to the
# GPU as well; see the model class for details.
model.set_params_to_same_device_as(torch.Tensor(1).cuda())

print('Loading from epoch', checkpoint['epoch'])

In [None]:
def normalizer(x):
    """Standardize ``x`` with the input statistics from ``compute_mean_std``.

    Reads the module-level ``mean_dict`` / ``std_dict`` at call time and
    returns ``(x - mean) / std``.
    """
    mu = mean_dict['input'].squeeze()
    sigma = std_dict['input'].squeeze()
    return (x - mu) / sigma

In [None]:
# Hand-picked 128x128 patch from the single frame of test_data, used for a
# one-patch sanity check of the model.
inp_patch = test_data[0,1800:1928,1500:1628]
plt.imshow(inp_patch)

In [None]:
# Configure the model for the patch size and switch it to prediction mode.
# NOTE(review): the exact effect of reset_for_different_output_size / mode_pred
# is defined in the model class and is not visible here.
model.reset_for_different_output_size(inp_patch.shape[0])
model.mode_pred = True

In [None]:
# Single forward pass on the hand-picked patch as a quick sanity check.
inp = normalizer(inp_patch)
with torch.no_grad():
    # [None, None] adds two leading axes: (H, W) -> (1, 1, H, W).
    out = model(torch.Tensor(inp[None,None]).cuda())
# Only a cell's last expression is displayed, so print the shape explicitly.
print('Prediction shape:', tuple(out[0].shape))
plt.imshow(out[0][0,1].cpu().numpy(), vmax=30);

In [None]:
# Tile the test stack for inference.
# NOTE(review): the positional arguments appear to be (data, normalizer,
# identity second transform, patch size, a quarter of the patch size, grid
# alignment); confirm against EvaluationDloader's signature.
dset = EvaluationDloader(test_data, normalizer, lambda x: x, config.data.image_size, config.data.image_size//4, GridAlignement.Center)

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

def get_dset_predictions(model, dset, batch_size, mmse_count=1, num_workers=4, device='cuda'):
    """Run batched inference over ``dset`` and return per-item MMSE mean and std.

    Each batch is passed through the model ``mmse_count`` times; the mean over
    these samples is the MMSE estimate and their standard deviation measures
    sample variability.

    Args:
        model: Callable returning a 2-tuple ``(imgs, aux)`` and exposing
            ``reset_for_different_output_size``.
        dset: Indexable dataset whose items are batched by ``DataLoader``.
        batch_size: Items per forward pass.
        mmse_count: Number of model samples per batch; must be >= 1.
        num_workers: DataLoader worker processes.
        device: Device the input batches are moved to (default ``'cuda'``).

    Returns:
        ``(mean, std)`` NumPy arrays concatenated over batches along axis 0.
        With ``mmse_count == 1`` the std is all zeros; the unbiased std of a
        single sample is undefined and would otherwise be NaN.

    Raises:
        ValueError: if ``mmse_count < 1``.
    """
    if mmse_count < 1:
        raise ValueError(f'mmse_count must be >= 1, got {mmse_count}')
    # NOTE(review): dset[0].shape[0] is used as the output size; this assumes the
    # first axis of a dataset item is the spatial size. Confirm against the dataset.
    model.reset_for_different_output_size(dset[0].shape[0])

    dloader = DataLoader(dset, pin_memory=False, num_workers=num_workers, shuffle=False, batch_size=batch_size)
    predictions = []
    predictions_std = []
    with torch.no_grad():
        for inp in tqdm(dloader):
            inp = inp.to(device)
            recon_img_list = []
            for _ in range(mmse_count):
                imgs, _aux = model(inp)
                recon_img_list.append(imgs.cpu())

            # (mmse_count, *imgs.shape): samples stacked on a new leading axis.
            samples = torch.stack(recon_img_list, dim=0)
            mmse_imgs = samples.mean(dim=0)
            if mmse_count > 1:
                mmse_std = samples.std(dim=0)
            else:
                mmse_std = torch.zeros_like(mmse_imgs)
            predictions.append(mmse_imgs.numpy())
            predictions_std.append(mmse_std.numpy())
    return np.concatenate(predictions, axis=0), np.concatenate(predictions_std, axis=0)


In [None]:
# Inference over all tiles: batch_size*10 tiles per forward pass and 50 model
# samples per batch, averaged into the MMSE prediction (pred_std holds their std).
pred_tiled, pred_std = get_dset_predictions(model, dset, batch_size*10, mmse_count=50, num_workers=4)

In [None]:
from disentangle.analysis.stitch_prediction import stitch_predictions
# Reassemble the per-tile predictions into full-size images.
# NOTE(review): presumably uses dset's tiling layout; later cells index pred as
# [t, row, col, channel].
pred = stitch_predictions(pred_tiled,dset)


In [None]:
from disentangle.analysis.plot_utils import clean_ax
from matplotlib.colors import LogNorm

# Input vs. prediction for four random 300x300 crops of frame t_idx.
_,ax = plt.subplots(figsize=(16,8),ncols=4,nrows=2)
ax = ax.ravel()
t_idx = 0
sz = 300
# Dedicated seeded generator so the same crops are drawn on every re-run.
crop_rng = np.random.default_rng(0)
for i in range(len(ax)//2):
    # +1: `integers` excludes `high`, but shape - sz is still a valid start.
    hs = int(crop_rng.integers(0, test_data.shape[1] - sz + 1))
    ws = int(crop_rng.integers(0, test_data.shape[2] - sz + 1))
    # Left: background-subtracted input; right: channel 1 of the stitched prediction.
    ax[2*i].imshow(test_data[t_idx,hs:hs+sz,ws:ws+sz], vmax=130)
    ax[2*i+1].imshow(pred[t_idx,hs:hs+sz,ws:ws+sz,1])
    ax[2*i].set_title(f'Input, {t_idx,hs,ws}')
    ax[2*i+1].set_title('Puncta Removed')
clean_ax(ax)

## Nature Methods figure

In [None]:
# fnames:7, [8,1500:2700,500:2200]
# fnames:8, [7/8/9/10,:1400,:800] => good
# fnames:8, [7,1000:2400,800:2200]
# fnames[13], [9/11,900:1500,1000:2400]
# fnames[14], [9,2400:3800,700:2000]
# fnames[15] [9,2200:4500,500:1500]
# fnames[15] [9,300:2400,1000:2800]
# fnames[15] [9,500:2400,2800:3900]
# fnames[15] [9, 2500:4500,600:1500]
# fnames[16] [9,400:1500,500:1700]

In [None]:
# IPython shell escape: list the working directory (interactive check only; nothing below uses it).
!ls

In [None]:
import os
# Per-input output folder named after the .nd2 file without its extension.
# os.path.splitext strips only the trailing extension; str.replace would also
# remove '.nd2' occurring anywhere else in the name.
output_dir = os.path.join('/group/jug/ashesh/naturemethods/puncta/', os.path.splitext(os.path.basename(datafile))[0])
os.makedirs(output_dir, exist_ok=True)
output_dir

In [None]:
# Preview of rows 400:1600, cols 1400:2600 of the test slice (matches the
# commented-out hs_region/ws_region alternative in the next cell).
plt.imshow(test_data[0,400:1600,1400:2600], vmax=130)

In [None]:
# Select the figure region and build the filename prefix used by the save cells below.
save_to_file = True
hs_region = 400
ws_region = 500
# hs_region = 400
# ws_region = 1400
sz = 1200
# Input and channel 1 of the stitched prediction over the same sz x sz region.
inp_region = test_data[0,hs_region:hs_region+sz,ws_region:ws_region+sz]
pred_region = pred[0,hs_region:hs_region+sz,ws_region:ws_region+sz,1]
plt.imshow(inp_region, vmax=130)
if save_to_file:
    # fname_prefix is only defined when save_to_file is True; the later cells
    # read it only under the same condition.
    fname_prefix = f'z.{test_z_idx}_region.{hs_region}-{ws_region}_sz.{sz}'
    print(fname_prefix)

In [None]:
# np.save(f'inp.npy', inp_region)
# np.save(f'pred.npy', pred_region)

import matplotlib.patches as patches
cropsz = 256

# Top-left (row, col) corners, inside inp_region, of the zoom-in crops. They are
# outlined and numbered on the overview and shown as panels in the next cell.
hw_arr= [(900, 310),
(628, 313),
(758, 80),
(605, 49),
(424, 815),
(449, 541),
(92, 684),
(587,844)
]
# hw_arr = [
# (35, 50),
#  (591,434),
#  (911,568),
#  (917,395),
#  (127,684),
#  (662,804),
#  (350,179),
#  (72,498),
# ]

# Every crop must fit inside the region; otherwise NumPy slicing would silently
# return a truncated crop in the zoom-in panels.
for h_s, w_s in hw_arr:
    assert 0 <= h_s <= inp_region.shape[0] - cropsz and 0 <= w_s <= inp_region.shape[1] - cropsz, \
        f'Crop at {(h_s, w_s)} with size {cropsz} does not fit in region of shape {inp_region.shape}'

_,ax = plt.subplots(figsize=(16,8),ncols=2,nrows=1)
ax[0].imshow(inp_region, vmax=130)
ax[1].imshow(pred_region)
clean_ax(ax)
for i, loc in enumerate(hw_arr):
    (h_s, w_s) = loc
    # matplotlib's Rectangle takes its anchor as (x, y) = (col, row).
    rect = patches.Rectangle((w_s, h_s), cropsz, cropsz, linewidth=1, edgecolor='w', facecolor='none', linestyle='--')
    ax[0].add_patch(rect)
    # add a number at the top left of the rectangle
    ax[0].text(w_s, h_s, str(i+1), color='black', fontsize=14)

# adjust the subplot gap
plt.subplots_adjust(wspace=0.02, hspace=0.02)
if save_to_file:
    fpath = os.path.join(output_dir, f'{fname_prefix}_full_region.png')
    # dpi=100; raise it for a higher-resolution export
    plt.savefig(fpath, dpi=100)
    print('Saved to', fpath)

In [None]:
num_crops = len(hw_arr)
imgsz = 2.1  # figure inches per panel
# hw_arr = [(np.random.randint(0, inp_region.shape[0] - cropsz), np.random.randint(0, inp_region.shape[1] - cropsz)) for _ in range(num_crops)]

# squeeze=False keeps `ax` 2-D even when num_crops == 1; otherwise plt.subplots
# returns a 1-D array and ax[0, i] raises IndexError.
_,ax = plt.subplots(figsize=(num_crops*imgsz,2*imgsz), ncols=num_crops, nrows=2, squeeze=False)
for i,(h,w) in enumerate(hw_arr):
    print(f'{h},{w}')
    # Top row: input crop, numbered to match the overview; bottom row: prediction crop.
    ax[0,i].imshow(inp_region[h:h+cropsz,w:w+cropsz], vmax=130)
    ax[0,i].text(10,30, str(i+1), color='black', fontsize=14)
    ax[1,i].imshow(pred_region[h:h+cropsz,w:w+cropsz])
    clean_ax(ax[:,i])
plt.subplots_adjust(wspace=0.05, hspace=0.05)
if save_to_file:
    fpath = os.path.join(output_dir, f'{fname_prefix}_crops.png')
    plt.savefig(fpath, dpi=100)
    print('Saved to', fpath)

In [None]:
# 591,434
# 911,568
# 917,395
# 127,684
# 662,804
# 350,179
# 72,498