## Objective
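In this notebook, we create several out-of-distribution (OOD) test datasets for the HTLIF24 and HTT24 datasets. Two kinds are produced: synthetic inputs obtained by blending two channels with weights w = 0.0, 0.1, …, 1.0, and the real input channel.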

In [None]:
# Training run whose config defines the data setup used below.
ckpt_dir = "/group/jug/ashesh/training/disentangle/2504/D21-M3-S0-L0/12"

In [None]:
import os
from copy import deepcopy

import numpy as np
import torch
from tqdm import tqdm

from disentangle.analysis.checkpoint_utils import get_best_checkpoint
from disentangle.config_utils import load_config
from disentangle.core.data_split_type import DataSplitType
from disentangle.core.data_type import DataType
from disentangle.data_loader.patch_index_manager import TilingMode
from disentangle.training import create_dataset, create_model


In [None]:
# Load the experiment config from the checkpoint directory.
config = load_config(ckpt_dir)
# Raw data root handed to create_dataset.
data_dir = '/group/jug/ashesh/data/TavernaSox2Golgi/acquisition2/'

In [None]:
# Padding options forwarded to the dataloader as `overlapping_padding_kwargs`.
padding_kwargs = dict(mode=config.data.get("padding_mode", "constant"))
if padding_kwargs["mode"] == "constant":
    # Only constant padding carries an explicit fill value.
    padding_kwargs["constant_values"] = config.data.get("padding_value", 0)

dloader_kwargs = dict(
    overlapping_padding_kwargs=padding_kwargs,
    tiling_mode=TilingMode.ShiftBoundary,
)

# The eval split is requested as DataSplitType.Test; the second return value
# presumably holds that split (verify against create_dataset).
train_dset, val_dset = create_dataset(
    config,
    data_dir,
    eval_datasplit_type=DataSplitType.Test,
    kwargs_dict=dloader_kwargs,
)


In [None]:
from disentangle.core.data_type import DataType

# Resolve config.data.data_type to its name via DataType.name; the name is
# used as the output sub-folder and as the prefix of every saved file.
dtype = DataType.name(config.data.data_type)
dtype  # displayed as the cell output

In [None]:
# Root folder for all generated OOD files; one sub-folder per data type.
outputdir = '/group/jug/ashesh/EnsDeLyon/OOD_data'
outputdir = os.path.join(outputdir, dtype)
# exist_ok=True creates the folder if needed without the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(outputdir, exist_ok=True)
outputdir

In [None]:
import matplotlib.pyplot as plt

# Sanity-check plot: index 0 along the leading axis, channel 1, of the first
# sub-dataset's raw `_data` array.
preview = val_dset.dsets[0]._data[0, ..., 1]
plt.imshow(preview)

In [None]:
import numpy as np

# Stack channels 0 and 1 of every test sub-dataset along axis 0.
# Iterate the sub-datasets directly instead of indexing via range(len(...)).
# NOTE(review): assumes the channel axis of each `_data` array is last and
# that all sub-datasets agree on the remaining (non-leading) dimensions.
ch0 = np.concatenate([sub._data[..., 0] for sub in val_dset.dsets], axis=0)
ch1 = np.concatenate([sub._data[..., 1] for sub in val_dset.dsets], axis=0)
print(ch0.shape, ch1.shape)

### OOD inputs by channel mixing

In [None]:
from disentangle.core.tiff_reader import save_tiff

# Synthetic OOD inputs: blends w * ch0 + (1 - w) * ch1 for w = 0.0, 0.1, ..., 1.0.
# np.linspace yields exactly 11 weights including both endpoints; a float-step
# np.arange would need a fudged stop value (e.g. 1.01) to include 1.0.
for w in np.linspace(0.0, 1.0, num=11):
    # Append a trailing singleton channel axis before saving.
    inp = (w * ch0 + (1 - w) * ch1)[..., None]
    fpath = os.path.join(outputdir, f"{dtype}_Test_W{w:.1f}.tif")
    print(f"Saving {inp.shape} {fpath}")
    save_tiff(fpath, inp)

## Real Input

In [None]:
from disentangle.data_loader.sox2golgi_v2_rawdata_loader import Sox2GolgiV2ChannelList
import ml_collections

# Reconfigure the data to three channels and set input_idx = 2 (GT_555_647),
# then rebuild the Test split with the same dataloader kwargs as before.
config = ml_collections.ConfigDict(config)
channels = [
    Sox2GolgiV2ChannelList.GT_Cy5,
    Sox2GolgiV2ChannelList.GT_TRITC,
    Sox2GolgiV2ChannelList.GT_555_647,
]
with config.unlocked():
    config.data.input_idx = 2
    config.data.channel_idx_list = channels
    config.data.num_channels = len(channels)

# Only the eval split is kept; it replaces the earlier val_dset.
_, val_dset = create_dataset(
    config,
    data_dir,
    eval_datasplit_type=DataSplitType.Test,
    kwargs_dict=dloader_kwargs,
)


In [None]:
# Sanity-check plot: index 0 along the leading axis, channel 2, of the first sub-dataset.
preview = val_dset.dsets[0]._data[0, ..., 2]
plt.imshow(preview)

In [None]:
# Stack the input channel of every test sub-dataset along axis 0.
# Read the channel index from the config (input_idx, set to 2 when the config
# was switched to the three-channel setup) instead of repeating the literal 2,
# and iterate the sub-datasets directly instead of using range(len(...)).
input_channel = config.data.input_idx
real_inp = np.concatenate(
    [sub._data[..., input_channel] for sub in val_dset.dsets], axis=0
)
print(real_inp.shape)

In [None]:
# Save the real-input stack; [..., None] appends a trailing singleton channel axis.
fpath = os.path.join(outputdir, f"{dtype}_Test_RealInput.tif")
stack_to_save = real_inp[..., None]
print(f"Saving {real_inp.shape} {fpath}")
save_tiff(fpath, stack_to_save)