In [None]:
from IPython.display import display, HTML
# Stretch the notebook container to the full browser width.
_wide_container_css = "<style>.container { width:100% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Debug flag handed to setup_syspath_disentangle() in the next cell.
DEBUG = False

In [None]:
# Run the shared setup notebooks. root_dirs presumably defines
# setup_syspath_disentangle() (called right after it); judging by its name it
# configures sys.path for the `disentangle` package — confirm in nb_core/root_dirs.ipynb.
%run ./nb_core/root_dirs.ipynb
setup_syspath_disentangle(DEBUG)
# NOTE(review): disentangle_imports may provide common names used by later cells;
# verify which ones before relying on them.
%run ./nb_core/disentangle_imports.ipynb

In [None]:
from disentangle.configs.sox2golgi_config import get_config
from disentangle.core.model_type import ModelType
from disentangle.core.data_split_type import DataSplitType
from disentangle.data_loader.multifile_dset import MultiFileDset

config = get_config()
# NOTE(review): hard-coded cluster path; the same directory is declared again as
# `datadir` in a later cell — keep the two in sync.
datapath = '/group/jug/ashesh/data/TavernaSox2Golgi/'

normalized_input = config.data.normalized_input
use_one_mu_std = config.data.use_one_mu_std
train_aug_rotate = config.data.train_aug_rotate
# Random cropping only when the config explicitly disables the deterministic grid.
enable_random_cropping = config.data.deterministic_grid is False
lowres_supervision = config.model.model_type == ModelType.LadderVAEMultiTarget

# Random cropping is configurable for training; validation never crops randomly.
train_data_kwargs = {'enable_random_cropping': enable_random_cropping}
val_data_kwargs = {'enable_random_cropping': False}

# Padding is only configured when multiscale low-res inputs are enabled. The optional
# constant fill value belongs to that padding config, so it is nested under the same
# branch: checking it independently would index into `None` (TypeError) whenever
# `padding_value` is set but `multiscale_lowres_count` is not.
padding_kwargs = None
if 'multiscale_lowres_count' in config.data and config.data.multiscale_lowres_count is not None:
    padding_kwargs = {'mode': config.data.padding_mode}
    if 'padding_value' in config.data and config.data.padding_value is not None:
        padding_kwargs['constant_values'] = config.data.padding_value

# Constructor arguments shared by both splits; only the split type, rotation
# augmentation, max_val and the per-split kwargs differ between them.
_shared_dset_kwargs = dict(
    val_fraction=config.training.val_fraction,
    test_fraction=config.training.test_fraction,
    normalized_input=normalized_input,
    use_one_mu_std=use_one_mu_std,
    padding_kwargs=padding_kwargs,
)

# Training split: rotation augmentation follows the config.
train_data = MultiFileDset(
    config.data,
    datapath,
    datasplit_type=DataSplitType.Train,
    enable_rotation_aug=train_aug_rotate,
    **_shared_dset_kwargs,
    **train_data_kwargs,
)

# Validation split: never rotated, and it is given the max value taken from the training split.
max_val = train_data.get_max_val()
val_data = MultiFileDset(
    config.data,
    datapath,
    datasplit_type=DataSplitType.Val,
    enable_rotation_aug=False,
    max_val=max_val,
    **_shared_dset_kwargs,
    **val_data_kwargs,
)

# Mean/std are computed on the training split only and then set on both splits.
mean_val, std_val = train_data.compute_mean_std()
for _split in (train_data, val_data):
    _split.set_mean_std(mean_val, std_val)


In [None]:
# First validation sample, unpacked as (input, target), for a quick visual check.
inp, tar = val_data[0]

In [None]:
# Side-by-side view of the first validation sample: inp[0], tar[0] and tar[1].
# Titles make the panels identifiable when skimming; the trailing `;` hides the AxesImage repr.
_, ax = plt.subplots(figsize=(9, 3), ncols=3)
ax[0].imshow(inp[0])
ax[0].set_title('inp[0]')
ax[1].imshow(tar[0])
ax[1].set_title('tar[0]')
ax[2].imshow(tar[1])
ax[2].set_title('tar[1]');

In [None]:
# Shape of the sample input (the plot above indexes inp[0], suggesting a leading channel axis — confirm).
inp.shape

In [None]:
# Collect the input of every validation sample. A comprehension is used so the loop
# variables do not leak: the previous loop rebound the global `inp`/`tar` chosen from
# val_data[0] above, so re-running the plot/shape cells silently showed the LAST sample.
inp_arr = [val_data[i][0] for i in range(len(val_data))]

In [None]:
# Join all collected validation inputs along the first axis (np.concatenate's default).
inpdata = np.concatenate(inp_arr)

In [None]:
import matplotlib.pyplot as plt
# Distribution of every validation input value, drawn on explicit axes with labels
# so the figure stands on its own.
_, ax = plt.subplots()
ax.hist(inpdata.flatten(), bins=100)
ax.set(title='Validation inputs', xlabel='value', ylabel='count');

In [None]:
# import seaborn as sns
# sns.histplot(inpdata.flatten(),bins=100)

In [None]:
# Min, 1%, 10%, median, 99% and max of all validation input values.
np.quantile(inpdata, q=[0.0, 0.01, 0.1, 0.5, 0.99, 1.0])

In [None]:
# config.data


In [None]:
from disentangle.data_loader.sox2golgi_rawdata_loader import (get_train_val_data, get_one_channel_files, get_two_channel_files, SubDsetType)
import copy

# Same directory as `datapath` declared in the dataset cell.
datadir = '/group/jug/ashesh/data/TavernaSox2Golgi/'


def _load_test_split(subdset_type, val_fraction=0.1, test_fraction=0.1):
    """Load the Test split of one Sox2Golgi sub-dataset without mutating ``config``.

    ``config.data`` was already passed to ``train_data``/``val_data``. Setting
    ``subdset_type`` on it in place (as before) left the shared config pointing at
    whichever sub-dataset was loaded last, which matters if those datasets keep a
    reference to it. A deep copy is modified instead.

    Args:
        subdset_type: ``SubDsetType`` member selecting the sub-dataset.
        val_fraction / test_fraction: split fractions forwarded to ``get_train_val_data``.
            NOTE(review): 0.1/0.1 differs from config.training.* used for train/val
            above — confirm this is intentional.

    Returns:
        Whatever ``get_train_val_data`` returns for the Test split.
    """
    data_config = copy.deepcopy(config.data)
    data_config.subdset_type = subdset_type
    return get_train_val_data(datadir,
                              data_config,
                              DataSplitType.Test,
                              val_fraction=val_fraction,
                              test_fraction=test_fraction)


data2ch = _load_test_split(SubDsetType.TwoChannel)
data1ch = _load_test_split(SubDsetType.OneChannel)

In [None]:
# Number of Test-split entries per sub-dataset, displayed as (1ch, 2ch).
tuple(len(_dset) for _dset in (data1ch, data2ch))

In [None]:
import numpy as np
def _stack_axis2_means(dset):
    """Mean of each entry's first element over axis 2 (kept as size 1), stacked along the last axis."""
    per_entry_means = [np.mean(dset[i][0], axis=2, keepdims=True) for i in range(len(dset))]
    return np.concatenate(per_entry_means, axis=-1)


input1ch = _stack_axis2_means(data1ch)
input2ch = _stack_axis2_means(data2ch)

In [None]:
import seaborn as sns
# Density-normalized comparison of the 1-channel and 2-channel sub-datasets (axis-2 means).
# Both histograms are bound to the same explicit axes instead of relying on pyplot's current axes.
# NOTE(review): the 1ch values are halved before comparison; the reason for this factor is
# not documented here — confirm.
_, ax = plt.subplots()
sns.histplot(input1ch.flatten() / 2, bins=100, color='red', label='1ch', stat='density', ax=ax)
sns.histplot(input2ch.flatten(), bins=100, color='blue', label='2ch', stat='density', ax=ax)
ax.set(title='1ch (scaled by 1/2) vs 2ch inputs', xlabel='value')
ax.legend();

In [None]:
def _print_int_quantiles(label, values, levels=(0.0, 0.01, 0.1, 0.5, 0.9, 0.99, 1)):
    """Print `label` followed by the requested quantiles of `values`, truncated to int32."""
    print(label, np.quantile(values, list(levels)).astype(np.int32))


_print_int_quantiles('input 1ch', input1ch / 2)
_print_int_quantiles('input 2ch', input2ch)

In [None]:
# Split each 2-channel Test image into its first channel (index 0) and the remaining
# channel(s) (index 1 onward), then stack each group along the last axis.
_two_ch_images = [data2ch[i][0] for i in range(len(data2ch))]
ch1 = np.concatenate([img[:, :, :1] for img in _two_ch_images], axis=-1)
ch2 = np.concatenate([img[:, :, 1:] for img in _two_ch_images], axis=-1)

In [None]:
# Density-normalized comparison of the two channels of the 2-channel sub-dataset.
# Both histograms are bound to the same explicit axes instead of relying on pyplot's current axes.
_, ax = plt.subplots()
sns.histplot(ch1.flatten(), bins=100, color='red', label='channel 1st', stat='density', ax=ax)
sns.histplot(ch2.flatten(), bins=100, color='blue', label='channel 2nd', stat='density', ax=ax)
ax.set(title='2ch sub-dataset: per-channel values', xlabel='value')
ax.legend();

In [None]:
# Min, 1%, 10%, median, 99% and max of each channel, truncated to int32.
_channel_levels = [0.0, 0.01, 0.1, 0.5, 0.99, 1]
for _label, _channel in (('channel 1', ch1), ('channel 2', ch2)):
    print(_label, np.quantile(_channel, _channel_levels).astype(np.int32))