In [None]:
# Run configuration: trained-model checkpoint directory, raw .nd2 acquisition, inference batch size.
ckpt_dir = "/group/jug/ashesh/training/disentangle/2407/D28-M3-S0-L0/9"
datafile = "/facility/imganfacusers/Elisa/DIF17/DIF_17_1/DIF_17_Day_25_A100_1_4_0001.nd2"
batch_size = 8  # patches per DataLoader batch during tiled inference

In [None]:
from disentangle.data_loader.evaluation_dloader import EvaluationDloader
from disentangle.nets.model_utils import create_model
from nis2pyr.reader import read_nd2file
from disentangle.config_utils import load_config
import nd2
import glob
import torch
import numpy as np

def load_7D(fpath):
    """Read an .nd2 acquisition and return the array produced by ``read_nd2file``.

    The file handle is closed before returning. The name indicates a 7-D array, which matches
    the 7-index slicing applied by the caller -- confirm the axis order against nis2pyr.
    """
    print(f'Loading from {fpath}')
    with nd2.ND2File(fpath) as handle:
        return read_nd2file(handle)

def get_best_checkpoint(ckpt_dir):
    """Return the path of the single ``*_best.ckpt`` file directly inside ``ckpt_dir``.

    Args:
        ckpt_dir: directory to search (non-recursive).

    Returns:
        Path of the one matching checkpoint file.

    Raises:
        AssertionError: if zero or more than one ``*_best.ckpt`` file is found.
    """
    # Escape the directory part so glob metacharacters in it (e.g. '[') are matched literally.
    output = sorted(glob.glob(glob.escape(ckpt_dir) + "/*_best.ckpt"))
    if len(output) != 1:
        # Raised explicitly (not via `assert`) so the check survives `python -O`,
        # and with a message that is informative even when nothing matched.
        raise AssertionError(
            f"Expected exactly one '*_best.ckpt' in {ckpt_dir}, found {len(output)}:\n" + '\n'.join(output)
        )
    return output[0]

In [None]:
# Load the raw stack and reduce it to 3-D: fix axes 0 and 1 at index 0, keep axis 2,
# take index 1 on axis 3, and index 0 on the last axis (axis meanings come from the .nd2
# layout -- presumably channel/time/etc.; TODO document which is which).
data = load_7D(datafile)[0, 0, :, 1, ..., 0]

In [None]:
# Load the training configuration stored alongside the checkpoint (provides data.image_size, data.background_values, ...).
config = load_config(ckpt_dir)

In [None]:
# Frames 8..11 (hard-coded choice) as float32, with the configured background level removed.
# astype() always allocates a new array, so the in-place subtraction below never touches `data`.
test_data = data[8:12].astype(np.float32)
test_data -= config.data.background_values[0]

In [None]:
import matplotlib.pyplot as plt
# Quick look at the first background-subtracted test frame; the colour scale is clipped at 30.
plt.imshow(test_data[0], vmax=30)

## Normalization statistics (mean / standard deviation from random crops)

In [None]:
from disentangle.data_loader.multicrops_dset import l2
def sample_crop(sz, data=None):
    """Draw one random square crop from a random frame.

    Args:
        sz: side length of the square crop.
        data: array indexed as ``data[t, h, w]``; defaults to the notebook-level ``test_data``.

    Returns:
        A ``(sz, sz)`` view into ``data``.
    """
    if data is None:
        data = test_data
    t = np.random.randint(0, len(data))
    # randint's upper bound is exclusive: +1 makes the last valid offset (size - sz) reachable
    # and allows a crop as large as the whole frame.
    h = np.random.randint(0, data.shape[1] - sz + 1)
    w = np.random.randint(0, data.shape[2] - sz + 1)
    return data[t, h:h + sz, w:w + sz]

def compute_mean_std(num_samples=30000, crop_size=None):
    """Estimate input/target normalization statistics from random crops of ``test_data``.

    Args:
        num_samples: number of random crops to draw (30000 by default, the previous hard-coded value).
        crop_size: crop side length; defaults to ``config.data.image_size``.

    Returns:
        ``(output_mean, output_std)`` dicts keyed by ``'input'`` and ``'target'``. The input
        statistic is reshaped to ``(-1, 1, 1, 1)``; the target statistic is the input one
        tiled twice along axis 1.

    Raises:
        ValueError: if ``num_samples`` is not positive (the mean of zero crops is undefined).
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if crop_size is None:
        crop_size = config.data.image_size

    mean_inp = []
    std_inp = []
    for _ in range(num_samples):
        crop = sample_crop(crop_size)
        mean_inp.append(np.mean(crop))
        std_inp.append(np.std(crop))

    output_mean = {}
    output_std = {}
    output_mean['input'] = np.array([np.mean(mean_inp)]).reshape(-1, 1, 1, 1)
    # l2 (from disentangle.data_loader.multicrops_dset) aggregates the per-crop stds;
    # presumably returns a scalar -- confirm in that module.
    output_std['input'] = np.array([l2(std_inp)]).reshape(-1, 1, 1, 1)

    # Both target channels reuse the input statistics.
    output_mean['target'] = np.tile(output_mean['input'], (1, 2, 1, 1))
    output_std['target'] = np.tile(output_std['input'], (1, 2, 1, 1))
    return output_mean, output_std


In [None]:
# Normalization statistics from random crops. NOTE(review): np.random is not seeded, so values vary slightly between runs.
mean_dict, std_dict = compute_mean_std()

In [None]:
# Build the model from the saved config. Shallow dict copies are passed so the model cannot add/replace
# keys in mean_dict/std_dict (the arrays themselves are still shared).
model = create_model(config, mean_dict.copy(),std_dict.copy())

In [None]:
ckpt_fpath = get_best_checkpoint(ckpt_dir)
print('Loading checkpoint from', ckpt_fpath)
# Load onto the CPU first so a checkpoint saved on a GPU index that does not exist here still loads;
# the model is moved to CUDA right after.
checkpoint = torch.load(ckpt_fpath, map_location='cpu')

# strict=False tolerates key mismatches; report them instead of silently discarding the result.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
if load_result.missing_keys:
    print('Missing keys (not found in checkpoint):', load_result.missing_keys)
if load_result.unexpected_keys:
    print('Unexpected keys (ignored from checkpoint):', load_result.unexpected_keys)
model.eval()
_ = model.cuda()
model.set_params_to_same_device_as(torch.Tensor(1).cuda())

print('Loading from epoch', checkpoint['epoch'])

In [None]:
# Look at frame 2 of the test stack (the frame used for the single-patch check below); colour scale clipped at 30.
plt.imshow(test_data[2], vmax=30)

In [None]:
def normalizer(x):
    """Standardize ``x`` with the input statistics from ``compute_mean_std``.

    Reads the notebook-level ``mean_dict`` / ``std_dict`` at call time; their ``'input'``
    entries are squeezed so they broadcast against ``x``.
    """
    center = mean_dict['input'].squeeze()
    scale = std_dict['input'].squeeze()
    return (x - center) / scale

In [None]:
# A hand-picked 128x128 patch from frame 2 for a single-patch sanity check of the model.
inp_patch = test_data[2,1800:1928,1500:1628]
plt.imshow(inp_patch)

In [None]:
# Tell the model the spatial size it will now be fed (inp_patch is 2-D and square, so shape[0] is the side length);
# presumably this resizes internal expected-shape bookkeeping -- see the model class.
model.reset_for_different_output_size(inp_patch.shape[0])
# Switch the model to its prediction mode (exact semantics defined by the model class).
model.mode_pred = True

In [None]:
inp = normalizer(inp_patch)
with torch.no_grad():
    # Add two leading singleton axes so the 2-D patch becomes a 4-D (1, 1, H, W) batch.
    out = model(torch.Tensor(inp[None, None]).cuda())
# out[0] holds the prediction; show channel 1 of the first (only) sample.
plt.imshow(out[0][0, 1].cpu().numpy(), vmax=30)

In [None]:
# GridAlignement was used here without being imported (NameError on a fresh kernel).
from disentangle.data_loader.patch_index_manager import GridAlignement

# Tile the first test frame for inference. The two size arguments are presumably patch size (image_size)
# and grid size (image_size // 2) -- confirm against EvaluationDloader. The target transform is the identity.
dset = EvaluationDloader(test_data[:1], normalizer, lambda x: x, config.data.image_size, config.data.image_size//2, GridAlignement.Center)

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

def get_dset_predictions(model, dset, batch_size, mmse_count=1, num_workers=4, device='cuda'):
    """Run batched inference over ``dset`` and return per-patch MMSE predictions.

    Each batch is passed through ``model`` ``mmse_count`` times; the samples are averaged
    (MMSE estimate) and their per-pixel standard deviation is recorded.

    Args:
        model: callable returning a 2-tuple ``(imgs, _)`` for an input batch; must expose
            ``reset_for_different_output_size``.
        dset: indexable dataset of square patches (assumed ``(C, H, W)`` -- TODO confirm).
        batch_size: DataLoader batch size.
        mmse_count: number of forward passes per batch.
        num_workers: DataLoader worker processes.
        device: device each input batch is moved to before the forward pass (default 'cuda').

    Returns:
        ``(mean, std)`` numpy arrays concatenated over all patches along axis 0. ``std`` is
        all zeros when ``mmse_count == 1`` (a single sample has no spread).
    """
    # Use the last axis as the patch side: axis 0 of a (C, H, W) patch is the channel count.
    model.reset_for_different_output_size(dset[0].shape[-1])

    dloader = DataLoader(dset, pin_memory=False, num_workers=num_workers, shuffle=False, batch_size=batch_size)
    predictions = []
    predictions_std = []
    with torch.no_grad():
        for inp in tqdm(dloader):
            inp = inp.to(device)
            recon_img_list = []
            for _sample_idx in range(mmse_count):
                imgs, _unused = model(inp)
                recon_img_list.append(imgs.cpu())

            samples = torch.stack(recon_img_list, dim=0)
            mmse_imgs = samples.mean(dim=0)
            if mmse_count > 1:
                mmse_std = samples.std(dim=0)
            else:
                # torch.std of one sample is NaN (the unbiased estimator divides by n - 1 = 0).
                mmse_std = torch.zeros_like(mmse_imgs)
            predictions.append(mmse_imgs.numpy())
            predictions_std.append(mmse_std.numpy())
    return np.concatenate(predictions, axis=0), np.concatenate(predictions_std, axis=0)


In [None]:
# Per-patch predictions; with mmse_count=1 each patch gets a single forward pass, so the std carries no spread.
pred_tiled, pred_std = get_dset_predictions(model, dset, batch_size, mmse_count=1, num_workers=4)

In [None]:
from disentangle.analysis.stitch_prediction import stitch_predictions
# Reassemble the tiled patch predictions into full frames using dset's tiling layout.
pred = stitch_predictions(pred_tiled,dset)


In [None]:
from disentangle.analysis.plot_utils import clean_ax

# Side-by-side comparison of random 800x800 regions: input frame (left) vs. prediction channel 1 (right).
_, ax = plt.subplots(figsize=(16, 8), ncols=4, nrows=2)
ax = ax.reshape(-1,)
sz = 800
for i in range(len(ax) // 2):
    # randint's upper bound is exclusive, hence +1 so the last valid offset can be drawn.
    hs = np.random.randint(0, test_data.shape[1] - sz + 1)
    # The width offset must be bounded by the width axis (shape[2]), not the height axis.
    ws = np.random.randint(0, test_data.shape[2] - sz + 1)
    ax[2 * i].imshow(test_data[0, hs:hs + sz, ws:ws + sz], vmax=30)
    # Assumes the stitched prediction is (T, H, W, C) with channels last -- confirm against stitch_predictions.
    ax[2 * i + 1].imshow(pred[0, hs:hs + sz, ws:ws + sz, 1])
    ax[2 * i].set_title('Input')
    ax[2 * i + 1].set_title('Puncta Removed')
clean_ax(ax)