## Objective
In this notebook we experiment with discriminators. The idea is to train a discriminator that can tell the two classes apart.

In [None]:
from disentangle.configs.mnist_config import get_config
from disentangle.training import create_dataset
# from disentangle.loss.ssl_finetuning import finetune_two_forward_passes
from disentangle.loss.ssl_with_discriminator import finetune_with_D_two_forward_passes
from disentangle.nets.model_utils import create_model
import torch
# Turn on autograd anomaly detection for the whole notebook session.
# NOTE(review): PyTorch documents this mode as a debugging aid that slows
# execution -- confirm it should stay enabled for the full fine-tuning run.
torch.autograd.set_detect_anomaly(True)

In [None]:
# Hyper-parameters for the discriminator-based fine-tuning experiment.
# Most of them are forwarded verbatim as keyword arguments to
# finetune_with_D_two_forward_passes further down in the notebook.

# -- Statistics loss / mixing ratio ------------------------------------------
# Handed to get_stats_loss_func when the statistics loss is built.
k_moment_value = 2
# Initial value of the learnable `mixing_ratio` parameter.
best_t_estimate = 0.5
# Only referenced by a commented-out keyword argument at the moment.
psnr_evaluation = False

# -- Training schedule -------------------------------------------------------
max_step_count = 40000
validation_step_freq = 5_000
batch_size = 128
lr = 0.001
skip_pixels = 0

# -- Augmentation ------------------------------------------------------------
k_augmentations = 1
# Forwarded as `sample_mixing_ratio`.
enable_mixing_aug = True

# -- Discriminator (D) / generator (G) ---------------------------------------
D_mode = '-1_1'
D_realimg_key = 'pred_FP1'
D_fakeimg_key = 'pred_FP1_aug'
D_gp_lambda = 0.0
D_loss_scalar = 1e-3
D_only_one_channel_idx = None
external_real_data_probability = 1.0
# Forwarded as `D_train_G_on_both_real_and_fake`.
train_G_on_both_real_and_fake = True
# Number of discriminator steps per generator step; a value below one means
# fewer discriminator steps than generator steps.
k_Dsteps_perG = 1 / 5

# -- Regularisation ----------------------------------------------------------
tv_weight = 0.1

In [None]:
# Load the experiment configuration (imported from the MNIST config module)
# and build the training and validation datasets from it.
config = get_config()
# NOTE(review): hard-coded absolute path -- it has to exist on the machine
# that runs this notebook.
datadir = '/group/jug/ashesh/data/MNIST/'
train_dset, val_dset = create_dataset(config, datadir)

In [None]:
# Build the network from the config and the training-set mean/std, and move
# it to the GPU right away.
data_mean, data_std = train_dset.get_mean_std()
model = create_model(
    config,
    data_mean,
    data_std,
    val_idx_manager=None,
).cuda()
# A throw-away CUDA tensor is passed as the device reference.
# NOTE(review): presumably this moves parameters that `.cuda()` does not
# reach -- confirm against the model implementation.
model.set_params_to_same_device_as(torch.Tensor([1]).cuda())

In [None]:
# Assemble a pool of target images by sweeping the validation set k_times.
# NOTE(review): train_mode() is switched on first, presumably so repeated
# sweeps return differently sampled/augmented items -- confirm in the dataset.
val_dset.train_mode()
k_times = 5
sweep = (val_dset[idx] for _ in range(k_times) for idx in range(len(val_dset)))
# Keep only the target of every (input, target) pair and stack along a new
# leading axis.
data = torch.stack([tar for _, tar in sweep], dim=0)
data_normalized = data / 255.0

In [None]:
import matplotlib.pyplot as plt
# Preview the first ten collected targets: channel 0 in the top row,
# channel 1 in the bottom row.
_, ax = plt.subplots(nrows=2, ncols=10, figsize=(20, 4))
for i in range(10):
    sample = data[i]
    for ch in (0, 1):
        panel = ax[ch, i]
        panel.imshow(sample[ch].cpu().numpy(), cmap='gray')
        panel.axis('off')
# Tighten the gaps between the panels.
plt.subplots_adjust(wspace=0.05, hspace=0.05)

In [None]:
from disentangle.loss.ssl_finetuning import get_stats_loss_func
import numpy as np

# Earlier variant, kept for reference: build the reference data straight from
# the dataset's per-channel image arrays instead of sampled targets.
# ch0 = val_dset._ch0_images
# ch1 = val_dset._ch1_images
# n = min(len(ch0), len(ch1))
# data = np.stack([ch0[:n], ch1[:n]], axis=1)
# print(data.shape)

# Build the statistics loss from the normalized targets and k_moment_value
# (presumably the moment order -- confirm in get_stats_loss_func).
stats_loss_func = get_stats_loss_func(data_normalized, k_moment_value)
# Evaluate it on the first 15 samples; being the last expression of the cell,
# the result is displayed by the notebook.
stats_loss_func(torch.Tensor(data_normalized[:15]))

In [None]:
import matplotlib.pyplot as plt
# Gallery of randomly drawn training samples. Every grid row holds two
# (input, target C1, target C2) triplets side by side.
nimgs = 5*2
imgsz = 2
_, ax = plt.subplots(nrows=nimgs//2, ncols=6, figsize=(6*imgsz, imgsz*nimgs//2))

# Column headers go on the first row only; the same three titles repeat for
# the left and the right triplet.
for col, title in enumerate(('input', 'target C1', 'target C2') * 2):
    ax[0, col].set_title(title)

for i in range(nimgs):
    # Even samples fill the left triplet (columns 0-2), odd ones the right
    # triplet (columns 3-5).
    row_idx, side = divmod(i, 2)
    col_idx = 3 * side
    idx = np.random.randint(len(train_dset))
    inp, tar = train_dset[idx]
    for shift, img in enumerate((inp[0], tar[0], tar[1])):
        panel = ax[row_idx, col_idx + shift]
        panel.imshow(img, cmap='gray')
        panel.axis('off')
# Remove the space between subplots.
plt.subplots_adjust(hspace=0.01, wspace=0.01)

In [None]:
from disentangle.data_loader.mnist_dset import get_transform_obj
# Fine-tune the model together with a discriminator.
# Transform object built from the per-channel transform parameters in the
# config, placed on the GPU.
transform_all = get_transform_obj(
    config.data.ch0_transforms_params,
    config.data.ch1_transforms_params,
    device='cuda',
)
# The normalized target pool is handed over as the "external real data".
external_real_data = data_normalized

# Learnable scalars living on the GPU: one (factor, offset) pair per index
# (1 and 2), plus a mixing ratio initialised from best_t_estimate.
factor1 = torch.nn.Parameter(torch.tensor(1.0).cuda())
offset1 = torch.nn.Parameter(torch.tensor(0.0).cuda())

factor2 = torch.nn.Parameter(torch.tensor(1.0).cuda())
offset2 = torch.nn.Parameter(torch.tensor(0.0).cuda())
mixing_ratio = torch.nn.Parameter(torch.tensor(best_t_estimate).cuda())

# Module.parameters() returns a single-use generator. Materialise it so the
# fine-tuning routine can traverse the parameters more than once without
# silently getting an empty sequence on the second pass.
optimization_params = list(model.parameters())

# Note: the validation set is passed for both dataset arguments.
finetuning_output_dict = finetune_with_D_two_forward_passes(
    model,
    val_dset,
    val_dset,
    transform_all,
    max_step_count=max_step_count,
    external_real_data=external_real_data,
    external_real_data_probability=external_real_data_probability,
    batch_size=batch_size,
    skip_pixels=skip_pixels,
    validation_step_freq=validation_step_freq,
    scalar_params_dict={
        'factor1': factor1,
        'offset1': offset1,
        'factor2': factor2,
        'offset2': offset2,
        'mixing_ratio': mixing_ratio,
    },
    optimization_params_dict={'lr': lr, 'parameters': optimization_params},
    D_mode=D_mode,
    D_realimg_key=D_realimg_key,
    D_fakeimg_key=D_fakeimg_key,
    D_gp_lambda=D_gp_lambda,
    D_loss_scalar=D_loss_scalar,
    k_Dsteps_perG=k_Dsteps_perG,
    tv_weight=tv_weight,
    k_augmentations=k_augmentations,
    # The loss callable is passed directly; a pass-through lambda adds nothing.
    stats_enforcing_loss_fn=stats_loss_func,
    sample_mixing_ratio=enable_mixing_aug,
    D_only_one_channel_idx=D_only_one_channel_idx,
    D_train_G_on_both_real_and_fake=train_G_on_both_real_and_fake,
    # psnr_evaluation=psnr_evaluation,
)


In [None]:
# Plot the 'loss_tv' series recorded by the fine-tuning run.
# NOTE(review): presumably the total-variation loss term associated with
# tv_weight -- confirm in finetune_with_D_two_forward_passes.
plt.plot(finetuning_output_dict['loss_tv'])

In [None]:
from disentangle.analysis.ssl_plots import plot_finetuning_loss
# Visualise the contents of the fine-tuning output dict with the project's
# plotting helper.
plot_finetuning_loss(finetuning_output_dict)


In [None]:
import matplotlib.pyplot as plt
from disentangle.analysis.plot_utils import clean_ax
# Compare model predictions with the targets for a few validation samples.
# Every sample occupies two grid rows:
#   top row    -> input, predicted ch0, predicted ch1, sum of both predictions
#   bottom row -> target ch0 and target ch1 (outer columns are left blank)
nimgs = 3
imgsz = 2
# Extra figure height to leave room for the per-panel titles.
offset = imgsz*nimgs*0.3
_, ax = plt.subplots(figsize=(4*imgsz, 2*imgsz*nimgs + offset), nrows=2*nimgs, ncols=4)

val_dset.train_mode()
for img_idx in range(nimgs):
    # NOTE(review): index 0 is requested on every iteration; this only gives
    # different samples if the dataset draws randomly in train mode -- confirm.
    inp, tar = val_dset[0]
    # Pure inference: skip autograd bookkeeping (anomaly detection is enabled
    # globally in this notebook) instead of building a graph that is only
    # ever detached.
    with torch.no_grad():
        pred, _ = model(torch.Tensor(inp[None]).cuda())

    top, bottom = ax[2*img_idx], ax[2*img_idx + 1]
    panels = (
        (top[0], inp[0], 'Input'),
        (top[1], pred[0, 0], 'Pred ch0'),
        (top[2], pred[0, 1], 'Pred ch1'),
        (top[3], pred[0, 0] + pred[0, 1], 'Pred ch0 + ch1'),
        (bottom[1], tar[0], 'Target ch0'),
        (bottom[2], tar[1], 'Target ch1'),
    )
    for axis, img, title in panels:
        axis.imshow(img.cpu().detach().numpy(), cmap='gray')
        axis.set_title(title)

    # The bottom row has no input / sum counterpart.
    bottom[0].axis('off')
    bottom[3].axis('off')

# Remove space between subplots.
plt.subplots_adjust(hspace=0.1, wspace=0.02)
clean_ax(ax)

In [None]:
# Show the maximum value of the two predicted channels for the most recent
# prediction (`pred` is left over from the plotting loop).
pred[:,0].max(), pred[:,1].max()