## **Generative Classifier**

In [None]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import wandb
from torch import nn
from torchsummary import summary

from models.GenClassifier import GenClassifier
from train import train
from utils.data_loaders import get_mnist_data_loaders, get_cifar10_data_loaders, get_fashion_mnist_data_loaders
from utils.other import calc_accuracy
from utils.visualize import show_img, plot_loss_history, show_samples, plot_weights, plot_gradual_classification_loss, plot_conv_channels

In [None]:
# Runtime configuration shared by the rest of the notebook.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32  # 16 was also used at some point
SEED = 42

# Seed torch and numpy so torch-/numpy-based randomness (e.g. the np.random.permutation
# used in the analysis section) is repeatable across Restart & Run All.
# NOTE(review): CUDA kernels can still be non-deterministic; this does not change that.
torch.manual_seed(SEED)
np.random.seed(SEED)

# The env var name must be bare: `%env "WANDB_NOTEBOOK_NAME" "main.ipynb"` splits on
# whitespace and keeps the quotes, creating a variable literally named
# '"WANDB_NOTEBOOK_NAME"' that wandb never reads. Set it before wandb.login()/wandb.init().
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"
wandb.login()

print(f"... Running on {DEVICE} ...")

In [None]:
# Fashion-MNIST is the active dataset. get_mnist_data_loaders / get_cifar10_data_loaders
# accept the same keyword arguments and can be swapped in here; switching dataset
# presumably also requires updating config_dict (e.g. classifier_cnn_input_dims) — verify.
loader_kwargs = dict(batch_size=BATCH_SIZE, root_path="data/", download=False)
X_train_loader, X_test_loader, class_names = get_fashion_mnist_data_loaders(**loader_kwargs)

In [None]:
# Architecture and optimisation hyper-parameters for GenClassifier (also passed to wandb as the run config).
# NOTE(review): several sizes look coupled and should probably change together — confirm in GenClassifier:
#   - generator_cnn_block_in_layer_shapes[0] == classifier_cnn_output_dim (144)
#   - generator_in_combined_main_layer_shapes[0] == 240 + 72 (last widths of the two generator input branches)
#   - generator_in_combined_main_layer_shapes[-1] == 2 * 12 * 12 (generator_trans_cnn_input_dims)
config_dict = dict(
    classifier_cnn_layers=(32, 64, 4),
    classifier_cnn_input_dims=(28, 28, 1),
    classifier_cnn_output_dim=144,
    classifier_head_layers=(64, 128, 10),
    generator_cnn_block_in_layer_shapes=(144, 240),
    generator_prediction_in_layer_shapes=(10, 72),
    generator_in_combined_main_layer_shapes=(312, 512, 288),
    generator_trans_cnn_input_dims=(2, 12, 12),
    generator_cnn_trans_layer_shapes=(32, 64, 32, 1),
    classifier_lr=0.0008,
    classifier_weight_decay=1e-5,
    generator_reconstruction_loss_importance=1.2,
    generator_reconstruction_from_no_z_loss_importance=1.2,
    generator_classification_loss_importance=0.6,
    generator_z_similarity_loss_importance=1.0,
    generator_contrastive_loss_importance=0.8,
    generator_lr=0.001,
    generator_weight_decay=1e-5,
    eval_run_classifier_cnn_block_optimizer_lr=0.008,
    eval_run_classifier_cnn_block_optimizer_weight_decay=1e-5,
    eval_run_classifier_head_block_optimizer_lr=0.015,
    eval_run_classifier_head_block_optimizer_weight_decay=1e-5,
    device=DEVICE,
)

# Checkpoint to resume from (model weights plus optimizer state, since load_optimizers=True).
PRETRAINED_PARAMS_PATH = "model_parameters/fashion_mnist_23-1-2022_11.27.tar"

model = GenClassifier(config_dict=config_dict)
model.load_pretrained_params(PRETRAINED_PARAMS_PATH, load_optimizers=True)

### Training

In [None]:
# Open a wandb run for this training session; config_dict is attached as the run config.
wandb_run = wandb.init(
    project="ReGAL",
    entity="johnny1188",
    group="fashion-mnist",
    config=config_dict,
    tags=["pretraining", "gen-four-step-loss"],
    notes="",
)

In [None]:
# Continue training the model loaded from PRETRAINED checkpoint above.
# is_wandb_run=True presumably makes train() log to the wandb run opened in the previous cell — verify in train.py.
N_EPOCHS = 15

loss_history, samples = train(
    model,
    epochs=N_EPOCHS,
    X_train_loader=X_train_loader,
    batch_size=BATCH_SIZE,
    verbose=True,
    is_wandb_run=True,
    class_names=class_names,
)

In [None]:
# Close the wandb run opened before training.
wandb_run.finish()

### Save pretrained model's parameters

In [None]:
# Save the trained parameters under a timestamped name: day-month-year_hour.minute
# (fields are not zero-padded). Change the "fashion_mnist" prefix when training on another dataset.
now = datetime.now()
save_path = f"model_parameters/fashion_mnist_{now.day}-{now.month}-{now.year}_{now.hour}.{now.minute}.tar"
model.save_model_params(save_path)

### Analysis of training

In [None]:
# Visualize the samples returned by train(). The extra "_from_generator" label presumably
# tags generator-produced images. NOTE(review): the positional ints (9, 0, 5) are passed
# through unchanged — their meaning is defined by utils.visualize.show_samples.
sample_class_labels = list(class_names) + ["_from_generator"]
show_samples(samples, 9, 0, 5, sample_class_labels)

In [None]:
# Check weights of the first dense layer of the generator head's combined stack.
first_combined_weight = (
    model.generator["head_block"]
    .dense_layers_stack_dict["in_combined_main_stack"][0]
    .weight.detach()
)
plot_weights(first_combined_weight.cpu().numpy())

# Mean |weight| over the two groups of input features feeding this layer.
# The original split at a hard-coded 73, which matches neither configured branch width
# (prediction branch -> 72, cnn branch -> 240, combined input -> 312), so derive it from the config.
# NOTE(review): assumes the prediction-branch features are concatenated FIRST in the head block;
# if the cnn-branch features come first, use generator_cnn_block_in_layer_shapes[-1] instead — TODO confirm.
split_idx = model.config_dict["generator_prediction_in_layer_shapes"][-1]
print(torch.mean(torch.abs(first_combined_weight[:, :split_idx])))
print(torch.mean(torch.abs(first_combined_weight[:, split_idx:])))

In [None]:
# Reconstruct one test batch in four ways, to compare how the generator's output depends on
# the classifier's latent code `z` versus the category input.
model.turn_model_to_mode("eval")

device = model.config_dict["device"]

images = iter(X_test_loader)
X, y = [part_of_data.to(device) for part_of_data in next(images)]
# Use the real batch length (not BATCH_SIZE) so a short batch does not break the reshapes below.
batch_len = X.shape[0]


def _generate(z_in, y_in):
    """Run the generator head + transposed-CNN block on (z_in, y_in); return (h, X_hat)."""
    h_out = model.generator["head_block"](z_in, y_in)
    h_for_cnn = torch.reshape(h_out, (batch_len, *model.config_dict["generator_trans_cnn_input_dims"]))
    return h_out, model.generator["trans_cnn_block"](h_for_cnn)


def _class_labels(scores):
    """Map the per-row argmax of `scores` (tensor or ndarray, one row per sample) to class names."""
    return [class_names[c_i] for c_i in torch.as_tensor(scores).argmax(dim=1).tolist()]


# Inference only: no autograd graph is needed for this visualization.
with torch.no_grad():
    z = model.classifier["cnn_block"](X)
    z = z.reshape((batch_len, model.config_dict["classifier_cnn_output_dim"]))
    y_hat = model.classifier["head_block"](z)
    y_hat_probs = nn.functional.softmax(y_hat, dim=1)
    # One-hot width follows the classifier's output width instead of a hard-coded 10.
    y_onehot = nn.functional.one_hot(y, y_hat.shape[1]).float().to(device)
    z_zeros = torch.zeros_like(z)

    # 1) Normal reconstruction as in the pretraining and evaluation phases (predicted category probabilities).
    h, X_hat = _generate(z, y_hat_probs)

    # 2) Keep z but shuffle the true one-hot targets across the batch: np.random.permutation on a
    #    2-D array shuffles its rows, so each image gets another sample's label (possibly its own class).
    y_one_hot_cloned = np.random.permutation(y_onehot.cpu().numpy())
    y_onehot_permuted_classes = torch.tensor(y_one_hot_cloned).to(device)
    h_2, X_hat_2 = _generate(z, y_onehot_permuted_classes)

    # 3) Zeroed-out cnn input, true target categories.
    h_3, X_hat_3 = _generate(z_zeros, y_onehot)

    # 4) Zeroed-out cnn input, predicted categories.
    h_4, X_hat_4 = _generate(z_zeros, y_hat_probs)

# indexes of imgs to show
img_i_start = 0
img_i_end = 5

# (title, images, per-sample scores whose argmax gives the caption)
panels = [
    ("Input images:", X, y_onehot),
    ("Generated - w/ predicted categories:", X_hat, y_hat),
    ("Generated - w/ permuted true target categories:", X_hat_2, y_one_hot_cloned),
    ("Generated - w/ zeroed-out cnn input & true target categories:", X_hat_3, y_onehot),
    ("Generated - w/ zeroed-out cnn input & predicted categories:", X_hat_4, y_hat),
]
for title, panel_imgs, label_scores in panels:
    print(title)
    show_img(
        panel_imgs[img_i_start:img_i_end].cpu().detach(),
        _class_labels(label_scores[img_i_start:img_i_end]),
    )