In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss, MSELossWithCrop

# from orig_data import prepare_spiking_data_loaders
from data import prepare_v1_dataloaders, SyntheticDataset, BatchPatchesDataLoader, MixedBatchLoader, PerSampleStoredDataset

# Switch on lovely_tensors' patched tensor display for the rest of the notebook.
lt.monkey_patch()

# Data root for this project, built from the DATA_PATH environment variable
# (a missing variable raises KeyError here).
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "cat_V1_spiking_model")
print(f"{DATA_PATH=}")

# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"{device=}")

## Gradient maps

In [None]:
run_name = "2023-08-26_10-50-13"

### load ckpt
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from CUDA tensors can still be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
ckpt = torch.load(
    os.path.join(DATA_PATH, "models", run_name, "decoder.pt"),
    map_location=device,
)
config = ckpt["config"]
history = ckpt["history"]
best = ckpt["best"]
decoder = CNN_Decoder(**config["decoder"]["model"]).to(device)
decoder.load_state_dict(ckpt["decoder"])


def crop_stim(x):
    """Index the last two dimensions of `x` with config["stim_crop_win"].

    Returns `x` unchanged when config["stim_crop_win"] is None.
    """
    crop_win = config["stim_crop_win"]
    if crop_win is None:
        return x
    # presumably crop_win holds one slice per spatial dimension - TODO confirm
    return x[..., crop_win[0], crop_win[1]]


### seed every RNG used in this notebook with the checkpoint's seed
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])
random.seed(config["seed"])

### put the whole decoder into evaluation mode
decoder.eval()
### turn off only batchnorm and dropout
# for m in decoder.modules():
#     if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.Dropout)):
#         m.eval()

In [None]:
### input
# Single all-zero row sized to the first decoder layer's `in_features`,
# created as a leaf tensor that already tracks gradients.
x = torch.zeros(
    (1, decoder.layers[0].in_features),
    device=device,
    requires_grad=True,
)
# alternative: all-ones input
# x = torch.ones((1, decoder.layers[0].in_features), device=device, requires_grad=True)

In [None]:
### output - intermediate
# Pass x through layer 0 and hand that result straight to layer 4;
# layers 1-3 are not applied in this cell.
# NOTE(review): the result is expected to be (B, C, H, W) - confirm against CNN_Decoder.
y = decoder.layers[4](decoder.layers[0](x))

In [None]:
### output - final
# Call `decoder.layers` as a whole on x. This rebinds `y`, so it replaces any
# value an earlier cell stored under that name; the cells that follow use
# whichever `y` was assigned last.
y = decoder.layers(x)

In [None]:
### collect gradient maps for each channel for specific location
# For every channel i, take d y[0, i, loc] / d x and store it as one row of
# `grad_maps` (one row per channel of y).
loc = (0, 0)
grad_maps = []
for i in range(y.shape[1]):
    # torch.autograd.grad returns a fresh tensor on every call and never touches
    # x.grad. Reading x.grad through .numpy() and zeroing it afterwards is unsafe
    # on CPU: the array shares memory with x.grad, so zero_() wipes the stored map.
    (grad_wrt_x,) = torch.autograd.grad(
        y[0, i, loc[0], loc[1]],
        x,
        retain_graph=True,  # the same graph is differentiated once per channel
    )
    grad_maps.append(grad_wrt_x[0].detach().cpu().numpy())
grad_maps = np.array(grad_maps)

In [None]:
# Sanity check: evaluates to True only when every collected gradient is exactly zero.
np.all(grad_maps == 0)

In [None]:
### plot - intermediate
# One panel per collected channel, up to the 16 panels of the 4x4 grid.
# Pairing axes with the maps that exist avoids an IndexError when fewer than
# 16 channels were collected; the remaining panels simply stay empty.
fig, axs = plt.subplots(4, 4, figsize=(16, 16))
for i, (ax, grad_map) in enumerate(zip(axs.flatten(), grad_maps)):
    # view(...) requires each gradient map to hold exactly 100 * 100 values.
    grad_map_standardized = standardize(torch.tensor(grad_map).view(1, 1, 100, 100)).numpy()[0, 0]
    # NOTE(review): the log is only defined while the standardized values stay
    # below 1 + 1e-4 - confirm the output range of csng.utils.standardize.
    ax.imshow(-np.log(1 - grad_map_standardized + 1e-4), cmap="magma")
    ax.set_title(f"C={i} H={loc[0]} W={loc[1]}")
plt.tight_layout()
plt.show()

In [None]:
### plot - final
# Single-panel figure showing the gradient map of the first collected channel.
fig, ax = plt.subplots(figsize=(6, 6))
# view(...) requires the gradient map to hold exactly 100 * 100 values.
grad_map_standardized = standardize(
    torch.tensor(grad_maps[0]).view(1, 1, 100, 100)
).numpy()[0, 0]
# NOTE(review): the log is only defined while the standardized values stay
# below 1 + 1e-4 - confirm the output range of csng.utils.standardize.
ax.imshow(-np.log(1 - grad_map_standardized + 1e-4), cmap="magma")
ax.set_title(f"H={loc[0]} W={loc[1]}")
plt.tight_layout()
plt.show()