## **Re**current **G**ener**A**tive C**L**assifier

In [None]:
import os
from datetime import datetime
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import wandb
from torch import nn
from torchsummary import summary

from models.ReGAL import ReGALModel
from utils.visualize import imshow_mnist, imshow_cifar10, plot_loss_history, show_samples
from utils.data_loaders import get_mnist_data_loaders, get_cifar10_data_loaders

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32

# Tell wandb which notebook this run comes from. This is set from Python instead
# of `%env "WANDB_NOTEBOOK_NAME" "main.ipynb"`: IPython's %env does not strip
# quotes, so that form created a variable literally named '"WANDB_NOTEBOOK_NAME"'
# (quotes included) and wandb never saw the intended variable.
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"
wandb.login()

print(f"... Running on {DEVICE} ...")

In [None]:
# MNIST variant, kept for switching datasets:
# X_train_mnist_loader, X_test_mnist_loader, classes_mnist = get_mnist_data_loaders(batch_size=BATCH_SIZE, root_path="data/", download=False)

# Options shared by the dataset loaders. download=False, so the CIFAR-10 files
# presumably must already exist under data/ (depends on the loader implementation).
loader_options = {"batch_size": BATCH_SIZE, "root_path": "data/", "download": False}
X_train_cifar_loader, X_test_cifar_loader, classes_cifar = get_cifar10_data_loaders(**loader_options)

In [None]:
# Sizes that several config entries depend on. Naming them once keeps the
# coupled entries below in sync; every resulting config value is unchanged.
NUM_CLASSES = 10                 # classifier head output width; presumably the CIFAR-10 class count
CLASSIFIER_CNN_OUT = 512         # classifier CNN feature size; the same features are fed to the generator head
GEN_BRANCH_WIDTH = 312           # width of each of the generator head's two input branches
GEN_INPUT_DIMS = (16, 8, 8)      # shape the generator head output is reshaped to before the trans-CNN
# GEN_INPUT_DIMS = (4, 16, 16)   # alternative layout with the same number of elements
GEN_INPUT_SIZE = GEN_INPUT_DIMS[0] * GEN_INPUT_DIMS[1] * GEN_INPUT_DIMS[2]  # = 1024

config_dict = {
    # --- classifier ---
    "classifier_cnn_layers": (16, 64, 32),
    "classifier_cnn_input_dims": (32, 32, 3),
    "classifier_cnn_output_dim": CLASSIFIER_CNN_OUT,
    "classifier_head_layers": (128, 64, 32, NUM_CLASSES),
    # --- generator: two input branches whose outputs are concatenated
    #     (2 * GEN_BRANCH_WIDTH) and mapped to the flattened trans-CNN input ---
    "generator_cnn_block_in_layer_shapes": (CLASSIFIER_CNN_OUT, GEN_BRANCH_WIDTH),
    "generator_prediction_in_layer_shapes": (NUM_CLASSES, GEN_BRANCH_WIDTH),
    "generator_in_combined_main_layer_shapes": (2 * GEN_BRANCH_WIDTH, 2 * GEN_BRANCH_WIDTH, GEN_INPUT_SIZE),
    "generator_cnn_trans_layer_shapes": (32, 32, 16, 3),
    "generator_input_dims": GEN_INPUT_DIMS,
    # --- optimisation ---
    "classifier_lr": 0.001,
    "classifier_weight_decay": 1e-5,
    "generator_alpha": 0.84,
    "generator_lr": 0.004,
    "generator_weight_decay": 1e-5,
    "eval_run_classifier_cnn_block_optimizer_lr": 0.008,
    "eval_run_classifier_cnn_block_optimizer_weight_decay": 1e-5,
    "eval_run_classifier_head_block_optimizer_lr": 0.016,
    "eval_run_classifier_head_block_optimizer_weight_decay": 1e-5,
    "device": DEVICE,
}

model = ReGALModel(config_dict=config_dict)
# To resume from a saved checkpoint instead of starting fresh:
# model.load_pretrained_params("model_parameters/model_parameters_dict_checkpoint_cifar10_26-12-2021_18.48.tar", load_optimizers=True)

### Pretrain

In [None]:
# Open a wandb run for the pretraining phase; the whole config_dict is logged with it.
run_tags = ["pretraining", "gen-two-step-loss"]
run_notes = "Generator's trans-cnn last layer using sigmoid"
wandb_run = wandb.init(
    project="ReGAL",
    entity="johnny1188",
    config=config_dict,
    tags=run_tags,
    notes=run_notes,
)

# Optional parameter/gradient logging (disabled):
# wandb.watch(models=(
#   model.classifier['cnn_block'], model.classifier['head_block'],
#   model.generator['head_block'],
#   model.generator['head_block'].dense_layers_stack_dict["in_classifier_prediction"],
#   model.generator['head_block'].dense_layers_stack_dict["in_combined_main_stack"],
#   model.generator['trans_cnn_block']), log="all", log_freq=500
# )

In [None]:
# Pretrain on the CIFAR-10 training loader with no previous loss history to
# extend, logging to the active wandb run.
pretrain_kwargs = dict(
    epochs=30,
    X_train_loader=X_train_cifar_loader,
    batch_size=BATCH_SIZE,
    past_loss_history=None,
    verbose=True,
    is_wandb_run=True,
    class_names=classes_cifar,
)
loss_history, samples = model.pretrain(**pretrain_kwargs)

In [None]:
# Finish the pretraining wandb run opened by wandb.init earlier in the notebook.
wandb_run.finish()

### Save pretrained model's parameters

In [None]:
# Save the pretrained parameters under a timestamped name. strftime zero-pads
# every field (e.g. 05-01-2022_09.05), so names keep a fixed width; the
# day-month-year_hour.minute layout matches existing checkpoint names.
now = datetime.now()
checkpoint_path = (
    "model_parameters/model_parameters_dict_checkpoint_cifar10_"
    f"{now.strftime('%d-%m-%Y_%H.%M')}.tar"
)
model.save_model_params(checkpoint_path)
print(f"Saved model parameters to {checkpoint_path}")

# Analysis of pretraining

In [None]:
# Plot the loss history returned by model.pretrain.
plot_loss_history(loss_history)

In [None]:
# Display generated samples returned by model.pretrain, labelled with the CIFAR-10 class names.
# NOTE(review): the positional arguments 0, 0, 5 are interpreted by
# utils.visualize.show_samples (presumably indices into `samples` and a number of images) — confirm.
show_samples(samples, 0, 0, 5, classes_cifar)

In [None]:
# Heatmap of the first layer of the generator head's "in_classifier_prediction"
# branch. Assuming that layer is an nn.Linear (weight shape (out, in)), the
# transpose puts input neurons (layer n-1) on the y axis and output neurons
# (layer n) on the x axis.
prediction_branch_weight = (
    model.generator["head_block"]
    .dense_layers_stack_dict["in_classifier_prediction"][0]
    .weight.detach().cpu().numpy()
)

fig, ax = plt.subplots(figsize=(30, 6))
# seaborn's target-axes parameter is `ax`. `axes=ax` is not a heatmap parameter;
# it was forwarded to matplotlib, and the plot landed on `ax` only because `ax`
# happened to be the current axes.
sns.heatmap(prediction_branch_weight.T, xticklabels=15, ax=ax)
ax.set_ylabel("layer n-1 neurons")
ax.set_xlabel("layer n neurons")
plt.show()

In [None]:
# Compare how strongly the first layer of the generator's combined stack uses
# each of its two concatenated input branches: mean |weight| over each column block.
# The split index is taken from the config instead of a hard-coded 312. Both
# branches are currently the same width, so the index is the same whichever
# branch comes first. Presumably columns [:split] belong to the CNN-feature
# branch — TODO confirm against the head block's concatenation order.
combined_first_layer_weight = (
    model.generator["head_block"].dense_layers_stack_dict["in_combined_main_stack"][0].weight
)
split = config_dict["generator_cnn_block_in_layer_shapes"][-1]
with torch.no_grad():
    first_block_mean = torch.mean(torch.abs(combined_first_layer_weight[:, :split])).item()
    second_block_mean = torch.mean(torch.abs(combined_first_layer_weight[:, split:])).item()
print(f"mean |w|, input columns [:{split}]: {first_block_mean}")
print(f"mean |w|, input columns [{split}:]: {second_block_mean}")

In [None]:
def _reconstruct(features, labels):
    """Run the generator on classifier features and labels and return images.

    Same path as the normal reconstruction used in pretraining/evaluation:
    generator head -> reshape to ``model.generator_cnn_input_dims`` -> trans-CNN.
    The batch size is taken from ``features`` so short batches also work.
    """
    head_out = model.generator["head_block"](features, labels)
    head_out = torch.reshape(head_out, (features.shape[0], *model.generator_cnn_input_dims))
    return model.generator["trans_cnn_block"](head_out)


# Take one training batch and move it to the model's device.
# NOTE(review): the model is not switched to eval mode here, so any
# mode-dependent layers (dropout, batch-norm) behave as in training — confirm intended.
X, y = [part_of_data.to(model.device) for part_of_data in next(iter(X_train_cifar_loader))]
n_images = X.shape[0]  # the loader may yield a short batch; don't assume BATCH_SIZE

# Pure inspection: no autograd graph is needed.
with torch.no_grad():
    z = model.classifier["cnn_block"](X)
    z = z.reshape((n_images, model.classifier_cnn_output_dim_flattened))
    y_hat = model.classifier["head_block"](z)

    # 1) Normal reconstruction, conditioned on the true labels.
    X_hat = _reconstruct(z, y.int())
    # 2) Same features, but every label replaced by class 0 (not a permutation).
    #    NOTE(review): these labels keep y's dtype while the other calls pass
    #    y.int() — kept as before; confirm the head block treats both the same.
    X_hat_2 = _reconstruct(z, torch.zeros_like(y))
    # 3) True labels, but the CNN features zeroed out.
    X_hat_3 = _reconstruct(torch.zeros_like(z), y.int())

n_show = 5
for images, title in (
    (X, "Ground truth"),
    (X_hat, "Gen"),
    (X_hat_2, "Gen (all labels set to class 0)"),
    (X_hat_3, "Gen (zeroed-out cnn input)"),
):
    imshow_cifar10(images[:n_show].cpu().detach(), title, w_color=True)

# Evaluation

In [None]:
# Multiply the gradient of classifier head layer dense_layers_stack[6]
# (presumably its final linear layer) by HEAD_WEIGHT_GRAD_SCALE during backward.
# The hook handle is kept so that re-running this cell first removes the
# previous hook. Without that, hooks stack and scale the gradient by 5**k.
HEAD_WEIGHT_GRAD_SCALE = 5
try:
    head_weight_grad_hook.remove()
except NameError:
    pass  # first run of this cell: no hook registered yet
head_weight_grad_hook = model.classifier["head_block"].dense_layers_stack[6].weight.register_hook(
    lambda grad, scale=HEAD_WEIGHT_GRAD_SCALE: grad * scale
)

In [None]:
def eval_model(model, X_test_loader, max_reconstruction_steps=10, max_batches=200):
    """Compute the per-batch classification loss of ``model`` on a test loader.

    Puts ``model`` into eval mode and runs
    ``model(X, max_reconstruction_steps=max_reconstruction_steps)`` on each batch.
    Records the cross-entropy between the returned logits and the labels.

    Args:
        model: object providing ``turn_model_to_mode(mode=...)`` and ``device``,
            callable as ``model(X, max_reconstruction_steps=...)`` returning logits.
        X_test_loader: iterable of ``(X, y)`` batches.
        max_reconstruction_steps: forwarded to the model's forward call
            (the notebook uses 0 for "without reconstruction").
        max_batches: maximum number of batches to evaluate; ``None`` evaluates all.

    Returns:
        List with one float loss per evaluated batch (empty for an empty loader).
    """
    model.turn_model_to_mode(mode="eval")

    loss_func_classification = nn.CrossEntropyLoss()
    classification_loss_history = []

    # islice yields exactly max_batches batches and does not fetch an extra one.
    for data in islice(X_test_loader, max_batches):
        # Use the model's own device rather than the notebook-global DEVICE.
        X, y = [part_of_data.to(model.device) for part_of_data in data]

        # Gradients stay enabled: the forward pass with reconstruction steps may
        # rely on them (config has eval_run_* optimizer settings) — presumably.
        y_hat = model(X, max_reconstruction_steps=max_reconstruction_steps)

        classification_loss_history.append(loss_func_classification(y_hat, y).item())

    return classification_loss_history

In [None]:
# Evaluate without (0 steps) and then with (20 steps) reconstruction, in that order.
classification_loss_history_wout_reconstruction = eval_model(
    model, X_test_loader=X_test_cifar_loader, max_reconstruction_steps=0, max_batches=120
)
classification_loss_history_w_reconstruction = eval_model(
    model, X_test_loader=X_test_cifar_loader, max_reconstruction_steps=20, max_batches=120
)


def _rounded_mean(values):
    """Arithmetic mean of ``values`` rounded to 4 decimal places."""
    return round(sum(values) / len(values), 4)


report = (
    "-----\nMean classification loss:\n"
    f">>> with reconstruction: {_rounded_mean(classification_loss_history_w_reconstruction)}\n"
    f">>> without reconstruction: {_rounded_mean(classification_loss_history_wout_reconstruction)}\n-----\n"
)
print(report)

loss_curves = (
    (classification_loss_history_w_reconstruction, "with reconstruction"),
    (classification_loss_history_wout_reconstruction, "without reconstruction"),
)
plt.figure(figsize=(12, 6))
for history, curve_label in loss_curves:
    plt.plot(history, label=curve_label)
plt.legend()
plt.title("Classification loss history")
plt.show()