In [None]:
from disentangle.configs.mnist_config import get_config
from disentangle.training import create_dataset
from disentangle.loss.ssl_finetuning import finetune_two_forward_passes
from disentangle.nets.model_utils import create_model
import torch

In [None]:
# k_moment_value is forwarded to get_stats_loss_func; best_t_estimate is the
# initial value of the learnable mixing_ratio parameter used during finetuning.
k_moment_value, best_t_estimate = 2, 0.5

In [None]:
import os

config = get_config()
# Data location. The MNIST_DATADIR environment variable overrides it so the notebook
# can run outside the original cluster; the default is the original hard-coded path.
datadir = os.environ.get('MNIST_DATADIR', '/group/jug/ashesh/data/MNIST/')
train_dset, val_dset = create_dataset(config, datadir)

In [None]:
# Build the model from the training set's mean/std and move it onto the GPU.
data_mean, data_std = train_dset.get_mean_std()
model = create_model(config, data_mean, data_std, val_idx_manager=None).cuda()

# A small CUDA tensor handed to the model helper as the device reference
# (presumably only its device is used — confirm in set_params_to_same_device_as).
reference_tensor = torch.Tensor([1]).cuda()
model.set_params_to_same_device_as(reference_tensor)

In [None]:
from disentangle.loss.ssl_finetuning import get_stats_loss_func
import numpy as np

# Stack both validation channels into one array (truncated to the shorter channel
# so they can be stacked along axis 1). Note: reads private dataset attributes.
ch0 = val_dset._ch0_images
ch1 = val_dset._ch1_images
n = min(len(ch0), len(ch1))
data = np.stack([ch0[:n], ch1[:n]], axis=1)
print(data.shape)

# Scale applied to the reference data before its statistics are computed
# (presumably 8-bit pixel range -> [0, 1]; TODO confirm the dtype of _chX_images).
pixel_scale = 255.0
stats_loss_func = get_stats_loss_func(data / pixel_scale, k_moment_value)

# Sanity check on a slice of the reference data itself. It must use the same scaling
# as the reference; otherwise the value measures the scale mismatch, not the fit.
stats_loss_func(torch.Tensor(data[:15] / pixel_scale))

In [None]:
# Finetuning hyperparameters, all forwarded to finetune_two_forward_passes.
lr, batch_size = 1e-3, 256
max_step_count = 10_000
validation_step_freq = 1_000
skip_pixels = 0
k_augmentations = 1
enable_mixing_aug = False  # passed as sample_mixing_ratio

In [None]:
# First validation sample, used for the before/after-finetuning visual comparison.
# `tar` is presumably the ground-truth target for that sample.
sample_index = 0
inp, tar = val_dset[sample_index]

In [None]:
# Forward pass on the chosen sample before finetuning ([None] adds a batch dimension).
# This output is only inspected, so skip building the autograd graph.
with torch.no_grad():
    pred, _ = model(inp[None].cuda())

In [None]:
from disentangle.data_loader.mnist_dset import get_transform_obj
transform_all = get_transform_obj(config.data.ch1_transforms_params, config.data.ch2_transforms_params, device='cuda')

# Learnable scalar/offset pairs (factor1/offset1, factor2/offset2) and the mixing
# ratio, initialised from best_t_estimate. All of them are passed to the
# finetuning routine through scalar_params_dict.
factor1 = torch.nn.Parameter(torch.tensor(1.0).cuda())
offset1 = torch.nn.Parameter(torch.tensor(0.0).cuda())

factor2 = torch.nn.Parameter(torch.tensor(1.0).cuda())
offset2 = torch.nn.Parameter(torch.tensor(0.0).cuda())
mixing_ratio = torch.nn.Parameter(torch.tensor(best_t_estimate).cuda())

# model.parameters() returns a one-shot iterator. A list stays valid even if the
# finetuning routine (or a re-inspection in a later cell) iterates over it more than once.
optimization_params = list(model.parameters())
finetuning_output_dict = finetune_two_forward_passes(
    model,
    val_dset,
    val_dset,
    transform_all,
    max_step_count=max_step_count,
    batch_size=batch_size,
    skip_pixels=skip_pixels,
    validation_step_freq=validation_step_freq,
    scalar_params_dict={
        'factor1': factor1,
        'offset1': offset1,
        'factor2': factor2,
        'offset2': offset2,
        'mixing_ratio': mixing_ratio,
    },
    optimization_params_dict={'lr': lr, 'parameters': optimization_params},
    k_augmentations=k_augmentations,
    stats_enforcing_loss_fn=stats_loss_func,
    sample_mixing_ratio=enable_mixing_aug,
)


In [None]:
# Plot the curves collected in finetuning_output_dict using the project's plotting helper.
from disentangle.analysis.ssl_plots import plot_finetuning_loss
plot_finetuning_loss(finetuning_output_dict)

In [None]:
# List which series the finetuning run recorded (displayed as the cell output).
output_keys = finetuning_output_dict.keys()
output_keys

In [None]:
import matplotlib.pyplot as plt

# Prediction loss recorded during finetuning. The x-axis is the index of each recorded
# entry; how entries map to optimisation steps is decided by finetune_two_forward_passes.
fig, ax = plt.subplots()
ax.plot(finetuning_output_dict['loss_pred'])
ax.set_xlabel('recorded entry')
ax.set_ylabel('loss_pred')
ax.set_title('Finetuning prediction loss');

In [None]:
# Same forward pass after finetuning, for the comparison plot below.
# Inspection only, so no autograd graph is needed.
with torch.no_grad():
    pred, _ = model(inp[None].cuda())

In [None]:
import matplotlib.pyplot as plt

# Input (index 0 of the sample — presumably its first channel; TODO confirm layout)
# next to the two predicted channels of the single batch item.
panels = {
    'input': inp[0],
    'prediction ch 0': pred[0, 0],
    'prediction ch 1': pred[0, 1],
}
# Bind the figure to `fig` rather than `_`, which would shadow IPython's output cache.
fig, ax = plt.subplots(figsize=(9, 3), ncols=len(panels))
for axis, (title, image) in zip(ax, panels.items()):
    axis.imshow(image.cpu().detach().numpy(), cmap='gray')
    axis.set_title(title)