## Purpose
This notebook builds the training and evaluation datasets (data loaders) and then loads the trained model from its best checkpoint.

In [None]:
# Fall back to the configured image size when no grid-center size was supplied.
image_size_for_grid_centers = (
    config.data.image_size if image_size_for_grid_centers is None else image_size_for_grid_centers
)


In [None]:
# Placeholder cell: emits a single blank line.
print()

In [None]:
from disentangle.data_loader.overlapping_dloader import get_overlapping_dset
from disentangle.data_loader.multi_channel_determ_tiff_dloader import MultiChDeterministicTiffDloader
from disentangle.data_loader.multiscale_mc_tiff_dloader import MultiScaleTiffDloader
from disentangle.core.data_split_type import DataSplitType
from disentangle.data_loader.single_channel_dloader import SingleChannelDloader


import gc

# Border padding handed to the overlapping-tile dataset wrapper; only the
# 'constant' mode takes an explicit fill value.
padding_kwargs = dict(mode=config.data.get('padding_mode', 'constant'))
if padding_kwargs['mode'] == 'constant':
    padding_kwargs['constant_values'] = config.data.get('padding_value', 0)
dloader_kwargs = dict(overlapping_padding_kwargs=padding_kwargs)

# Select the underlying loader class, then wrap it for overlapping-tile access.
if config.data.get('multiscale_lowres_count') is not None:
    _base_dloader = MultiScaleTiffDloader
    dloader_kwargs.update(num_scales=config.data.multiscale_lowres_count,
                          padding_kwargs=padding_kwargs)
elif config.data.data_type == DataType.SemiSupBloodVesselsEMBL:
    _base_dloader = SingleChannelDloader
else:
    _base_dloader = MultiChDeterministicTiffDloader
data_class = get_overlapping_dset(_base_dloader)

# Resolve the input location: some data types read the whole directory, the
# others read one named TIFF inside it.
# NOTE(review): any other data type leaves `datapath` unassigned here, so either a
# value set elsewhere (possibly stale) or a NameError reaches the loaders — confirm.
if config.data.data_type in (DataType.CustomSinosoid, DataType.CustomSinosoidThreeCurve,
                             DataType.AllenCellMito, DataType.SeparateTiffData,
                             DataType.SemiSupBloodVesselsEMBL):
    datapath = data_dir
elif config.data.data_type == DataType.OptiMEM100_014:
    datapath = os.path.join(data_dir, 'OptiMEM100x014.tif')
elif config.data.data_type == DataType.Prevedel_EMBL:
    datapath = os.path.join(data_dir, 'MS14__z0_8_sl4_fr10_p_10.1_lz510_z13_bin5_00001.tif')

normalized_input = config.data.normalized_input
use_one_mu_std = config.data.use_one_mu_std
train_aug_rotate = config.data.train_aug_rotate
# Random cropping is switched off (it used to follow `config.data.deterministic_grid is False`).
enable_random_cropping = False

# Keyword arguments shared by the training and evaluation datasets.
_shared_dset_kwargs = dict(
    val_fraction=config.training.val_fraction,
    test_fraction=config.training.test_fraction,
    normalized_input=normalized_input,
    use_one_mu_std=use_one_mu_std,
    image_size_for_grid_centers=image_size_for_grid_centers,
)

train_dset = data_class(config.data,
                        datapath,
                        datasplit_type=DataSplitType.Train,
                        enable_rotation_aug=train_aug_rotate,
                        enable_random_cropping=enable_random_cropping,
                        **_shared_dset_kwargs,
                        **dloader_kwargs)
gc.collect()
max_val = train_dset.get_max_val()

# The evaluation split never rotates or randomly crops: it is evaluated on
# deterministic grids, and it reuses the training split's max value.
val_dset = data_class(config.data,
                      datapath,
                      datasplit_type=eval_datasplit_type,
                      enable_rotation_aug=False,
                      enable_random_cropping=False,
                      max_val=max_val,
                      **_shared_dset_kwargs,
                      **dloader_kwargs)

# Both splits are normalized with the training split's statistics.
mean_val, std_val = train_dset.compute_mean_std()
for _dset in (train_dset, val_dset):
    _dset.set_mean_std(mean_val, std_val)

val_dset = train_dset if evaluate_train else val_dset
data_mean, data_std = train_dset.get_mean_std()


In [None]:
# Placeholder cell: emits a single blank line.
print()

In [None]:
# Data types whose image size is put back to `old_image_size` (defined outside
# this chunk; presumably the size saved before it was overridden — confirm).
_size_restored_types = (DataType.OptiMEM100_014, DataType.CustomSinosoid,
                        DataType.SeparateTiffData, DataType.CustomSinosoidThreeCurve)
with config.unlocked():
    if config.data.data_type in _size_restored_types and old_image_size is not None:
        config.data.image_size = old_image_size

# Normalization statistics passed to the model: the training set's individual
# statistics when separate target normalization is enabled, else the shared mean/std.
if config.data.target_separate_normalization is True:
    _norm_stats = train_dset.compute_individual_mean_std()
else:
    _norm_stats = train_dset.get_mean_std()
model = create_model(config, *_norm_stats)

# Restore the best checkpoint's weights, switch to eval mode and move to the GPU.
ckpt_fpath = get_best_checkpoint(ckpt_dir)
checkpoint = torch.load(ckpt_fpath)
_ = model.load_state_dict(checkpoint['state_dict'])
model.eval()
_ = model.cuda()

# The normalization stats are plain attributes, so move them to the GPU explicitly.
for _stat_name in ('data_mean', 'data_std'):
    setattr(model, _stat_name, getattr(model, _stat_name).cuda())
print('Loading from epoch', checkpoint['epoch'])

In [None]:
def count_parameters(model, *, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable_only: When True (the default, matching the original behaviour),
            count only parameters with ``requires_grad`` set; when False, count
            every parameter, including frozen ones.

    Returns:
        Total number of elements across the selected parameter tensors.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# Report trainable parameters in millions (3 decimal places).
print(f'Model has {count_parameters(model) / 1_000_000:.3f}M parameters')