In [None]:
from disentangle.configs.mnist_config import get_config
from disentangle.training import create_dataset
from disentangle.loss.ssl_finetuning import finetune_two_forward_passes
from disentangle.nets.model_utils import create_model
import torch

In [None]:
# Notebook-level knobs consumed by later cells:
# - k_moment_value: second argument to get_stats_loss_func (presumably the number of moments matched)
# - best_t_estimate: initial value of the learnable mixing ratio used during finetuning
# - psnr_evaluation: forwarded to finetune_two_forward_passes as `psnr_evaluation`
k_moment_value, best_t_estimate, psnr_evaluation = 2, 0.5, True

In [None]:
import os

# Build the MNIST config and the train/validation datasets from it.
config = get_config()
# Root directory of the MNIST data. The default is a cluster-specific absolute path;
# set the MNIST_DATA_DIR environment variable to run the notebook elsewhere without editing it.
datadir = os.environ.get('MNIST_DATA_DIR', '/group/jug/ashesh/data/MNIST/')
train_dset, val_dset = create_dataset(config, datadir)

In [None]:
# Instantiate the model with the training-set normalization statistics and move it to the GPU.
data_mean, data_std = train_dset.get_mean_std()
model = create_model(config, data_mean, data_std, val_idx_manager=None).cuda()

# Tell the model which device its remaining tensors should live on by handing it a
# tensor that is already on the GPU.
# NOTE(review): presumably covers tensors that `.cuda()` does not move — confirm.
gpu_probe = torch.Tensor([1]).cuda()
model.set_params_to_same_device_as(gpu_probe)

In [None]:
from disentangle.loss.ssl_finetuning import get_stats_loss_func
import numpy as np

# Build the statistics-enforcing loss from paired validation images of both channels.
ch0 = val_dset._ch0_images
ch1 = val_dset._ch1_images
# The two channels may hold different numbers of images; pair up only the common prefix.
n = min(len(ch0), len(ch1))
# Stack along a new axis 1 -> (n, 2, ...): one slot per channel.
data = np.stack([ch0[:n], ch1[:n]], axis=1)
print(data.shape)

# The reference statistics are computed on data rescaled by 1/255 (presumably 8-bit images).
# The sanity-check input must use the same scale, otherwise the loss is evaluated on
# inputs 255x larger than the statistics it compares against.
data_normalized = data / 255.0
stats_loss_func = get_stats_loss_func(data_normalized, k_moment_value)
stats_loss_func(torch.Tensor(data_normalized[:15]))

In [None]:
import matplotlib.pyplot as plt

# Eyeball the training data: draw `nimgs` random samples and show each as an
# (input, target C1, target C2) triplet, two triplets side by side per row.
nimgs = 5*2
imgsz = 2
_, ax = plt.subplots(figsize=(6*imgsz, imgsz*nimgs//2), ncols=6, nrows=nimgs//2)

panel_titles = ('input', 'target C1', 'target C2')
# Header row: the same three titles over the left and the right triplet.
for first_col in (0, 3):
    for offset, title in enumerate(panel_titles):
        ax[0, first_col + offset].set_title(title)

for i in range(nimgs):
    # Even samples fill the left triplet, odd samples the right one.
    row_idx, side = divmod(i, 2)
    col_idx = 3 * side
    idx = np.random.randint(len(train_dset))
    inp, tar = train_dset[idx]
    for offset, img in enumerate((inp[0], tar[0], tar[1])):
        cell_ax = ax[row_idx, col_idx + offset]
        cell_ax.imshow(img, cmap='gray')
        cell_ax.axis('off')
# Remove the gaps between subplots.
plt.subplots_adjust(hspace=0.01, wspace=0.01)

In [None]:
# Discriminator training hyper-parameters.
# max_step_count, lr and batch_size are also read by the finetuning cell further below.
max_step_count = 200
batch_size = 256
lr = 1e-4
# Passed to update_gradients_with_discriminator_loss next to enable_gradient_penalty
# (presumably the gradient-penalty weight).
lambda_term = 0.1
enable_gradient_penalty = True

# Only these two discriminator modes are supported.
discriminator_mode = '-1_1'
assert discriminator_mode in ('-1_1', 'wgan')

In [None]:
from disentangle.nets.discriminator import Discriminator
from disentangle.loss.discriminator_loss import update_gradients_with_discriminator_loss
from tqdm import tqdm

# Train a discriminator on the two target channels of each batch (channel 0 vs channel 1 of `tar`).
# NOTE(review): `update_gradients_with_discriminator_loss` is opaque from here. Since the
# optimizer steps right after it, it presumably back-propagates the discriminator loss, and it
# returns the scalar diagnostics logged below — confirm.
discriminator = Discriminator(channels=28*28, first_out_channel=128, dense=True).cuda()
opt = torch.optim.Adam(discriminator.parameters(), lr=lr, weight_decay=0)
dloader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=4)
d_pred_real = []
d_pred_fake = []
d_loss_gradient_penalty = []

# Keep a single iterator alive across steps. Creating a fresh iterator per step just to take
# one batch (`for ... in dloader: break`) re-spawns all `num_workers` worker processes and
# reshuffles the whole dataset every time.
batch_iter = iter(dloader)
with tqdm(total=max_step_count) as bar:
    for step in range(max_step_count):
        try:
            inp, tar = next(batch_iter)
        except StopIteration:
            # One full pass is done: start a new, reshuffled epoch.
            batch_iter = iter(dloader)
            inp, tar = next(batch_iter)
        # The model input is not used here; only the two target channels are compared.
        # (Distinct names so the `ch0`/`ch1` globals of the stats cell are not clobbered.)
        tar_ch0 = tar[:, :1].cuda()
        tar_ch1 = tar[:, 1:2].cuda()

        opt.zero_grad()
        loss_dict = update_gradients_with_discriminator_loss(discriminator, tar_ch0, tar_ch1, lambda_term,
                                                             enable_gradient_penalty=enable_gradient_penalty,
                                                             mode=discriminator_mode)
        opt.step()
        d_pred_real.append(loss_dict['d_pred_real'])
        d_pred_fake.append(loss_dict['d_pred_fake'])
        d_loss_gradient_penalty.append(loss_dict['d_loss_gradient_penalty'])
        bar.update(1)
        # Running means over all steps so far.
        bar.set_postfix(d_pred_real=np.mean(d_pred_real), d_pred_fake=np.mean(d_pred_fake),
                        d_loss_gradient_penalty=np.mean(d_loss_gradient_penalty))

In [None]:
import pandas as pd

# Discriminator diagnostics smoothed with a 10-step rolling mean: the recorded
# d_pred_real / d_pred_fake values on the left, the gradient-penalty term (log scale) on the right.
smoothing_window = 10
_, ax = plt.subplots(figsize=(6, 3), ncols=2)
pred_ax, penalty_ax = ax

for series_values, series_label in ((d_pred_real, 'd_pred_real'), (d_pred_fake, 'd_pred_fake')):
    pd.Series(series_values).rolling(smoothing_window).mean().plot(label=series_label, ax=pred_ax)
pd.Series(d_loss_gradient_penalty).rolling(smoothing_window).mean().plot(
    label='gradient_penalty', ax=penalty_ax, logy=True)

pred_ax.legend()
penalty_ax.legend()

In [None]:
# Deliberate stop for "Run All". The finetuning cells below use names that no cell above
# defines (skip_pixels, validation_step_freq, k_augmentations, enable_mixing_aug).
# Remove this cell once those are configured. An explicit exception gives a clear reason,
# whereas a bare syntax error does not.
raise RuntimeError('Stopping here: the finetuning cells below are still work in progress.')

In [None]:
from disentangle.data_loader.mnist_dset import get_transform_obj

# Finetuning hyper-parameters that no earlier cell defines. Fail up front with a clear
# message instead of a NameError in the middle of the long call below.
_required_names = ('skip_pixels', 'validation_step_freq', 'k_augmentations', 'enable_mixing_aug')
_missing_names = [name for name in _required_names if name not in globals()]
if _missing_names:
    raise NameError(f'Define {", ".join(_missing_names)} before running the finetuning cell.')

transform_all = get_transform_obj(config.data.ch0_transforms_params, config.data.ch1_transforms_params, device='cuda')

# Learnable per-channel scale (init 1) and offset (init 0), plus the mixing ratio
# initialized from `best_t_estimate`.
factor1 = torch.nn.Parameter(torch.tensor(1.0).cuda())
offset1 = torch.nn.Parameter(torch.tensor(0.0).cuda())

factor2 = torch.nn.Parameter(torch.tensor(1.0).cuda())
offset2 = torch.nn.Parameter(torch.tensor(0.0).cuda())
mixing_ratio = torch.nn.Parameter(torch.tensor(best_t_estimate).cuda())

# Materialize the generator: `model.parameters()` can only be iterated once, so any
# second pass over it inside the finetuning routine would silently see nothing.
optimization_params = list(model.parameters())

# NOTE(review): max_step_count, batch_size and lr are the discriminator values from the
# hyper-parameter cell above, and val_dset is passed as both datasets — confirm both are intended.
finetuning_output_dict = finetune_two_forward_passes(model, val_dset, val_dset, transform_all,
                                                     max_step_count=max_step_count,
                                                     batch_size=batch_size,
                                                     skip_pixels=skip_pixels,
                                                     validation_step_freq=validation_step_freq,
                                                     scalar_params_dict={'factor1': factor1, 'offset1': offset1,
                                                                         'factor2': factor2, 'offset2': offset2,
                                                                         'mixing_ratio': mixing_ratio},
                                                     optimization_params_dict={'lr': lr, 'parameters': optimization_params},
                                                     k_augmentations=k_augmentations,
                                                     stats_enforcing_loss_fn=stats_loss_func,
                                                     sample_mixing_ratio=enable_mixing_aug,
                                                     psnr_evaluation=psnr_evaluation,
                                                     )


In [None]:
from disentangle.analysis.ssl_plots import (
    plot_finetuning_loss,
)

# Loss curves from the output of finetune_two_forward_passes.
plot_finetuning_loss(finetuning_output_dict)

In [None]:
# List the entries returned by the finetuning run (shown as this cell's output).
finetuning_output_dict.keys()

In [None]:
import matplotlib.pyplot as plt

# Qualitative check on the first validation sample.
# Top row: input, the two predicted channels and their sum. Bottom row: the two target channels.
fig, ax = plt.subplots(figsize=(12, 6), nrows=2, ncols=4)

# NOTE(review): this puts the *validation* dataset into train mode, and the state persists
# after this cell — confirm this is the mode `val_dset[0]` should be read in.
val_dset.train_mode()
inp, tar = val_dset[0]
# Plotting only, so no autograd graph is needed. `as_tensor` accepts both NumPy arrays and
# tensors (on any device) and yields float32 like the original `torch.Tensor(...)` call.
with torch.no_grad():
    pred, _ = model(torch.as_tensor(inp[None], dtype=torch.float32).cuda())


def _to_image(t):
    """Return tensor `t` as a NumPy array that `imshow` can draw."""
    return t.detach().cpu().numpy()


panels = (
    (ax[0, 0], inp[0], 'Input'),
    (ax[0, 1], pred[0, 0], 'Pred ch0'),
    (ax[0, 2], pred[0, 1], 'Pred ch1'),
    (ax[0, 3], pred[0, 0] + pred[0, 1], 'Pred ch0 + ch1'),
    (ax[1, 0], tar[0], 'Target ch0'),
    (ax[1, 1], tar[1], 'Target ch1'),
)
for panel_ax, img, title in panels:
    panel_ax.imshow(_to_image(img), cmap='gray')
    panel_ax.set_title(title)
# The last two bottom-row axes have no content; hide their empty frames.
ax[1, 2].axis('off')
ax[1, 3].axis('off')
plt.show()