## Purpose
This notebook builds the training and evaluation data loaders and loads the trained model from its best checkpoint.

In [None]:
# if image_size_for_grid_centers is None:
#     image_size_for_grid_centers = config.data.image_size


In [None]:
# Placeholder cell: emit a single blank line.
print()

In [None]:
# from disentangle.data_loader.overlapping_dloader import get_overlapping_dset
from disentangle.data_loader.multi_channel_determ_tiff_dloader import MultiChDeterministicTiffDloader
from disentangle.data_loader.multiscale_mc_tiff_dloader import MultiScaleTiffDloader
from disentangle.core.data_split_type import DataSplitType
from disentangle.data_loader.single_channel.single_channel_dloader import SingleChannelDloader
from disentangle.data_loader.single_channel.single_channel_mc_dloader import SingleChannelMSDloader
from disentangle.data_loader.pavia2_3ch_dloader import Pavia2ThreeChannelDloader
from disentangle.data_loader.patch_index_manager import GridAlignement
from disentangle.data_loader.multi_dset_dloader import IBA1Ki67DataLoader

import gc

# Padding options forwarded to the data loaders.
# NOTE(review): 'mode' / 'constant_values' look like numpy.pad keyword arguments — confirm
# against the loader implementations.
padding_kwargs = {
    'mode': config.data.get('padding_mode', 'constant'),
}
if padding_kwargs['mode'] == 'constant':
    padding_kwargs['constant_values'] = config.data.get('padding_value', 0)

dloader_kwargs = {'overlapping_padding_kwargs': padding_kwargs}

# Pick the dataset class (and any class-specific constructor kwargs) from the data/model type.
if config.data.data_type == DataType.SemiSupBloodVesselsEMBL:
    # The import of get_overlapping_dset at the top of this cell is commented out, so import
    # it here, only on the branches that actually use it.
    from disentangle.data_loader.overlapping_dloader import get_overlapping_dset
    if 'multiscale_lowres_count' in config.data and config.data.multiscale_lowres_count is not None:
        data_class = get_overlapping_dset(SingleChannelMSDloader)
        dloader_kwargs['num_scales'] = config.data.multiscale_lowres_count
        dloader_kwargs['padding_kwargs'] = padding_kwargs
    else:
        data_class = get_overlapping_dset(SingleChannelDloader)
elif config.data.data_type == DataType.Pavia2:
    from disentangle.data_loader.overlapping_dloader import get_overlapping_dset
    data_class = get_overlapping_dset(Pavia2ThreeChannelDloader)
elif (config.data.data_type == DataType.HTIba1Ki67
      and config.model.model_type in [ModelType.LadderVaeMultiDataSet,
                                      ModelType.LadderVaeMultiDatasetMultiBranch,
                                      ModelType.LadderVaeMultiDatasetMultiOptim]):
    data_class = IBA1Ki67DataLoader
elif 'multiscale_lowres_count' in config.data and config.data.multiscale_lowres_count is not None:
    data_class = MultiScaleTiffDloader
    dloader_kwargs['num_scales'] = config.data.multiscale_lowres_count
    dloader_kwargs['padding_kwargs'] = padding_kwargs
elif config.model.model_type == ModelType.AutoRegresiveLadderVAE:
    from disentangle.data_loader.autoregressive_dloader import AutoRegressiveDloader
    data_class = AutoRegressiveDloader
else:
    data_class = MultiChDeterministicTiffDloader

# Most data types read straight from data_dir; only these two point at a specific tiff in it.
if config.data.data_type == DataType.OptiMEM100_014:
    datapath = os.path.join(data_dir, 'OptiMEM100x014.tif')
elif config.data.data_type == DataType.Prevedel_EMBL:
    datapath = os.path.join(data_dir, 'MS14__z0_8_sl4_fr10_p_10.1_lz510_z13_bin5_00001.tif')
else:
    datapath = data_dir

normalized_input = config.data.normalized_input
use_one_mu_std = config.data.use_one_mu_std
train_aug_rotate = config.data.train_aug_rotate
# Random cropping is forced off here (instead of `config.data.deterministic_grid is False`).
enable_random_cropping = False
grid_alignment = GridAlignement.Center
print(data_class)

train_dset = data_class(
    config.data,
    datapath,
    datasplit_type=DataSplitType.Train,
    val_fraction=config.training.val_fraction,
    test_fraction=config.training.test_fraction,
    normalized_input=normalized_input,
    use_one_mu_std=use_one_mu_std,
    enable_rotation_aug=train_aug_rotate,
    enable_random_cropping=enable_random_cropping,
    grid_alignment=grid_alignment,
    **dloader_kwargs)
gc.collect()
max_val = train_dset.get_max_val()

# For normalizing, we should be using the training data's mean and std.
mean_val, std_val = train_dset.compute_mean_std()
train_dset.set_mean_std(mean_val, std_val)

if evaluate_train:
    # Evaluate on the training split itself: reuse train_dset rather than loading an
    # evaluation dataset that would immediately be discarded.
    val_dset = train_dset
else:
    val_dset = data_class(
        config.data,
        datapath,
        datasplit_type=eval_datasplit_type,
        val_fraction=config.training.val_fraction,
        test_fraction=config.training.test_fraction,
        normalized_input=normalized_input,
        use_one_mu_std=use_one_mu_std,
        enable_rotation_aug=False,  # No rotation aug on validation
        # No random cropping on validation. Validation is evaluated on deterministic grids.
        enable_random_cropping=False,
        grid_alignment=grid_alignment,
        max_val=max_val,
        **dloader_kwargs)
    val_dset.set_mean_std(mean_val, std_val)

data_mean, data_std = train_dset.get_mean_std()


In [None]:
# Placeholder cell: emit a single blank line.
print()

In [None]:
# IPython/Jupyter shell magic (not plain Python): list a hard-coded, user-specific directory.
# NOTE(review): presumably a training-run directory — adjust the path for your environment.
!ls /home/ashesh.ashesh/training/disentangle/2301/D3-M10-S0-L3/25

In [None]:
# For these data types, write `old_image_size` back into the config when it is set
# (presumably the size saved earlier in the notebook before image_size was changed).
with config.unlocked():
    if config.data.data_type in [DataType.OptiMEM100_014, DataType.CustomSinosoid,
                                 DataType.SeparateTiffData,
                                 DataType.CustomSinosoidThreeCurve, DataType.HTIba1Ki67] and old_image_size is not None:
        config.data.image_size = old_image_size

# Normalization statistics handed to the model: the train set's individual statistics when
# target_separate_normalization is exactly True, otherwise its shared mean/std.
if config.data.target_separate_normalization is True:
    mean_fr_model, std_fr_model = train_dset.compute_individual_mean_std()
else:
    mean_fr_model, std_fr_model = train_dset.get_mean_std()

if config.model.model_type == ModelType.LadderVaeSemiSupervised:
    # Prepend an axis via `[None]` (assumes array-like mean/std — TODO confirm).
    mean_fr_model = mean_fr_model[None]
    std_fr_model = std_fr_model[None]

model = create_model(config, mean_fr_model, std_fr_model)

ckpt_fpath = get_best_checkpoint(ckpt_dir)
# Deserialize onto the CPU so loading does not depend on which GPU the checkpoint was saved
# from; the model is moved to the GPU right after the weights are loaded.
checkpoint = torch.load(ckpt_fpath, map_location='cpu')

_ = model.load_state_dict(checkpoint['state_dict'])
model.eval()
_ = model.cuda()

model.set_params_to_same_device_as(torch.Tensor(1).cuda())

print('Loading from epoch', checkpoint['epoch'])

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted; each contributes
    its ``numel()``.
    """
    trainable_params = (param for param in model.parameters() if param.requires_grad)
    total = 0
    for param in trainable_params:
        total += param.numel()
    return total

# Report the trainable parameter count in millions.
print('Model has {:.3f}M parameters'.format(count_parameters(model) / 1_000_000))