## **Re**current **G**ener**A**tive C**L**assifier

In [None]:
import torch
from torch import nn
from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
import wandb

from models.ReGAL import ReGALModel
from utils.visualize import imshow_mnist, imshow_cifar10
from utils.data_loaders import get_mnist_data_loaders, get_cifar10_data_loaders

In [None]:
import os

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32

# Tell wandb which notebook this session belongs to.
# Set through os.environ rather than the `%env` magic: `%env` does not strip
# quote characters, so the quoted form `%env "WANDB_NOTEBOOK_NAME" "main.ipynb"`
# would not define WANDB_NOTEBOOK_NAME itself.
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"
wandb.login()

print(f"... Running on {DEVICE} ...")

In [None]:
# Build the CIFAR-10 train/test loaders and the class-name lookup.
# download=False: presumably expects the dataset to already be under "data/" -- confirm.
# MNIST alternative (disabled):
# X_train_mnist_loader, X_test_mnist_loader, classes_mnist = get_mnist_data_loaders(batch_size=BATCH_SIZE, root_path="data/", download=False)
(
    X_train_cifar_loader,
    X_test_cifar_loader,
    classes_cifar,
) = get_cifar10_data_loaders(
    batch_size=BATCH_SIZE,
    root_path="data/",
    download=False,
)

In [None]:
# Hyper-parameters handed to ReGALModel (and logged to wandb as the run config).
config_dict = {
    # --- classifier ---
    "classifier_cnn_layers": (16, 64, 32),
    "classifier_cnn_input_dims": (32, 32, 3),
    "classifier_cnn_output_dim": 512,
    "classifier_head_layers": (128, 64, 32, 10),
    # --- generator ---
    # NOTE(review): 458 + 64 == 522 and 16 * 8 * 8 == 1024, so the shapes below
    # look mutually consistent -- confirm against the ReGALModel implementation.
    "generator_cnn_block_in_layer_shapes": (512, 458),
    "generator_prediction_in_layer_shapes": (10, 64),
    "generator_in_combined_main_layer_shapes": (522, 1024),
    "generator_cnn_trans_layer_shapes": (32, 32, 16, 3),
    "generator_input_dims": (16, 8, 8),
    # --- optimizers used for pretraining ---
    "classifier_lr": 0.001,
    "classifier_weight_decay": 1e-5,
    "generator_lr": 0.004,
    "generator_weight_decay": 1e-5,
    # --- optimizers used during the evaluation run ---
    "eval_run_classifier_cnn_block_optimizer_lr": 0.008,
    "eval_run_classifier_cnn_block_optimizer_weight_decay": 1e-5,
    "eval_run_classifier_head_block_optimizer_lr": 0.009,
    "eval_run_classifier_head_block_optimizer_weight_decay": 1e-5,
    # --- runtime ---
    "device": DEVICE,
    "verbose": False,
}

# Instantiate the model, then restore weights and optimizer state from a saved checkpoint.
pretrained_checkpoint_path = "model_parameters/model_parameters_dict_checkpoint_cifar10_20-12-2021_0.20.tar"
model = ReGALModel(config_dict=config_dict)
model.load_pretrained_params(pretrained_checkpoint_path, load_optimizers=True)

### Pretrain

In [None]:
# Open a wandb run for the pretraining phase; the model config is attached to the run.
run = wandb.init(
    project="ReGAL",
    entity="johnny1188",
    tags=["pretraining"],
    config=config_dict,
)
# Optional gradient/parameter logging of the individual sub-modules (disabled):
# wandb.watch(models=(
#   model.classifier['cnn_block'], model.classifier['head_block'],
#   model.generator['head_block'],
#   model.generator['head_block'].dense_layers_stack_dict["in_classifier_prediction"],
#   model.generator['head_block'].dense_layers_stack_dict["in_combined_main_stack"],
#   model.generator['trans_cnn_block']), log="all", log_freq=500)

In [None]:
# Pretrain on the CIFAR-10 training loader; returns the loss curves and stored samples.
# NOTE(review): wandb_run=False is passed although a wandb run is opened before
# this cell -- confirm whether the run object was meant to be passed instead.
pretrain_kwargs = {
    "epochs": 10,
    "X_train_loader": X_train_cifar_loader,
    "batch_size": BATCH_SIZE,
    "past_loss_history": None,
    "verbose": True,
    "wandb_run": False,
}
loss_history, samples = model.pretrain(**pretrain_kwargs)

In [None]:
# Close the wandb run that was opened with wandb.init(...) for the pretraining phase.
run.finish()

### Save pretrained model's parameters

In [None]:
# Save a checkpoint whose file name carries the current local time as
# day-month-year_hour.minute (the components are not zero-padded).
now = datetime.now()
checkpoint_timestamp = f"{now.day}-{now.month}-{now.year}_{now.hour}.{now.minute}"
checkpoint_path = f"model_parameters/model_parameters_dict_checkpoint_cifar10_{checkpoint_timestamp}.tar"
model.save_model_params(checkpoint_path)

# Analysis of pretraining

In [None]:
# Plot the pretraining loss curves of the classifier and the generator side by side.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 8))
loss_panels = (
    (ax[0], "classifier", "classifier loss"),
    (ax[1], "generator", "generator loss"),
)
for panel, component, panel_title in loss_panels:
    # Each history entry is a tensor; move it to the CPU and detach it before plotting.
    curve = [t.cpu().detach() for t in loss_history[component]]
    panel.plot(curve)
    panel.title.set_text(panel_title)

In [None]:
# Show five stored samples of one pretraining epoch: the images in
# samples[epoch][0] titled with the class indexed by samples[epoch][2], then
# the images in samples[epoch][1] titled with the argmax class of samples[epoch][3].
epoch = 9
i = 0
epoch_samples = samples[epoch]

ground_truth_titles = [
    f"Ground truth: {classes_cifar[ epoch_samples[2][j].cpu().detach() ]}"
    for j in range(i, i + 5)
]
imshow_cifar10(
    epoch_samples[0][i:i+5].cpu().detach(),
    ground_truth_titles,
    w_color=True)

reconstruction_titles = [
    f"Reconstruction: {classes_cifar[ torch.argmax(epoch_samples[3][j].cpu().detach()) ]}"
    for j in range(i, i + 5)
]
imshow_cifar10(
    epoch_samples[1][i:i+5].cpu().detach(),
    reconstruction_titles,
    w_color=True)

In [None]:
# Take the first batch the CIFAR-10 training loader yields and unpack it into X and y.
images = iter(X_train_cifar_loader)
X, y = next(images)

In [None]:
# Inspect the first layer of the generator head's "in_combined_main_stack":
# print its weight shape and plot the weight columns -30..-21.
first_combined_layer = model.generator["head_block"].dense_layers_stack_dict["in_combined_main_stack"][0]
combined_weights = first_combined_layer.weight
print( combined_weights.shape )
plt.plot( combined_weights[:,-30:-20].detach().cpu() )
# torch.mean(torch.abs( model.generator["head_block"].dense_layers_stack_dict["in_combined_main_stack"][0].weight[:,:10] ))

In [None]:
# Reconstruct one training batch with the generator twice: once conditioned on
# the classifier's own prediction and once on an altered class vector, then
# display the inputs next to both reconstructions.
images = iter(X_train_cifar_loader)
X, y = [part_of_data.to(model.device) for part_of_data in next(images)]
# Use the size of the batch actually drawn instead of BATCH_SIZE, so the
# reshapes below also work when the loader yields a smaller batch.
n_samples = X.shape[0]

z = model.classifier["cnn_block"](X)
z = z.reshape((n_samples, model.classifier_cnn_output_dim_flattened))
y_hat = model.classifier["head_block"](z)

# Normal reconstruction as in the pretraining and evaluation phases
h = model.generator["head_block"](z.detach(), y_hat.detach()) # TODO: Try to pretrain the generator w/ true target labels
h_reshaped_for_cnn_block = torch.reshape(h, (n_samples, *model.generator_cnn_input_dims))
X_hat = model.generator["trans_cnn_block"](h_reshaped_for_cnn_block)

# One-hot class vectors whose hot entry is the predicted class shifted by 3
# (modulo the number of classes). Rows are addressed by sample index; the
# previous loop iterated over the predicted class values and used those as row
# indices, which filled the wrong rows and overwrote the notebook-level `i`.
preds = torch.argmax(y_hat, dim=1)
one_hot = torch.zeros(y_hat.shape, device=y_hat.device)
row_indices = torch.arange(n_samples, device=y_hat.device)
one_hot[row_indices, (preds + 3) % y_hat.shape[1]] = 1

# Second reconstruction with an altered class input. The active variant feeds an
# all-zero class vector; the shifted one-hot variant is kept commented out.
# h_2 = model.generator["head_block"](z.detach(), one_hot.to(model.device)) # TODO: Try to pretrain the generator w/ true target labels
h_2 = model.generator["head_block"](z.detach(), torch.zeros(y_hat.shape).to(model.device)) # TODO: Try to pretrain the generator w/ true target labels
h_reshaped_for_cnn_block_2 = torch.reshape(h_2, (n_samples, *model.generator_cnn_input_dims))
X_hat_2 = model.generator["trans_cnn_block"](h_reshaped_for_cnn_block_2)

imshow_cifar10(
    X[0:5].cpu().detach(),
    "Ground truth",
    w_color=True
)
imshow_cifar10(
    X_hat[0:5].cpu().detach(),
    "Gen",
    w_color=True
)
imshow_cifar10(
    X_hat_2[0:5].cpu().detach(),
    "Gen (changed category)",
    w_color=True
)

# Evaluation

In [None]:
def eval(model, X_test_loader, max_reconstruction_steps=10, max_batches=200):
    """Return the per-batch cross-entropy classification loss of `model`.

    NOTE: the name shadows the builtin `eval`; it is kept because other cells
    call this function by that name.

    Args:
        model: Model exposing `turn_components_to_eval_mode()`, a `device`
            attribute, and a call signature
            `model(X, max_reconstruction_steps=...)` that returns class scores.
        X_test_loader: Iterable yielding `(X, y)` batches.
        max_reconstruction_steps: Forwarded to the model call for every batch.
        max_batches: Upper bound on the number of batches evaluated; values
            below zero are treated as zero.

    Returns:
        A list with one Python float per evaluated batch.
    """
    from itertools import islice

    model.turn_components_to_eval_mode()

    loss_func_classification = nn.CrossEntropyLoss()
    classification_loss_history = []

    # islice stops after exactly `max_batches` batches. The previous
    # `if i > max_batches: break`, checked after processing, let
    # max_batches + 2 batches through.
    # NOTE(review): gradients are deliberately not disabled here -- the
    # "eval_run_*_optimizer_*" config entries suggest the model may optimise
    # during the reconstruction steps; confirm before adding torch.no_grad().
    for data in islice(X_test_loader, max(max_batches, 0)):
        # Move the batch to the device the model reports, not a notebook global.
        X, y = [part_of_data.to(model.device) for part_of_data in data]

        y_hat = model(X, max_reconstruction_steps=max_reconstruction_steps)

        batch_loss = loss_func_classification(y_hat, y)
        classification_loss_history.append(batch_loss.detach().cpu().item())

    return classification_loss_history

In [None]:
# Evaluate the test set twice -- first without reconstruction steps, then with
# up to 15 of them -- and compare the per-batch classification losses.
classification_loss_history_wout_reconstruction = eval(
    model,
    X_test_loader=X_test_cifar_loader,
    max_reconstruction_steps=0,
    max_batches=150,
)
classification_loss_history_w_reconstruction = eval(
    model,
    X_test_loader=X_test_cifar_loader,
    max_reconstruction_steps=15,
    max_batches=150,
)

# Mean loss per run, rounded to four decimals for the printed summary.
mean_loss_w_reconstruction = round(
    sum(classification_loss_history_w_reconstruction) / len(classification_loss_history_w_reconstruction), 4
)
mean_loss_wout_reconstruction = round(
    sum(classification_loss_history_wout_reconstruction) / len(classification_loss_history_wout_reconstruction), 4
)

print(
f"""-----\nMean classification loss:
>>> with reconstruction: {mean_loss_w_reconstruction}
>>> without reconstruction: {mean_loss_wout_reconstruction}\n-----\n"""
)

# Overlay both per-batch loss curves in a single figure.
plt.figure(figsize=(12, 6))
loss_curves = (
    (classification_loss_history_w_reconstruction, "with reconstruction"),
    (classification_loss_history_wout_reconstruction, "without reconstruction"),
)
for curve, curve_label in loss_curves:
    plt.plot(curve, label=curve_label)
plt.legend()
plt.title("Classification loss history")
plt.show()