In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from scipy.stats import rankdata
from itertools import product

from src.datasets import get_cifar_test
from src.paths import get_local_data_dir
from src.utils import ActivationSwitch, LossSwitch, DatasetSwitch

# --- Dataset ---------------------------------------------------------------
# The dataset switch also determines which local data directory is used.
dataset = DatasetSwitch.CIFAR10
COMPUTE_DATA_DIR = get_local_data_dir(dataset)

# --- String flags ----------------------------------------------------------
# NOTE(review): these look like command-line flags for an external launcher;
# none of them is read by the cells visible in this notebook — confirm.
bias = "--nobias"
port = ""  # alternative value: "--port 5678"
block_main = ""  # alternative value: "--block_main"

# --- Numeric / boolean settings --------------------------------------------
# Of these, the visible cells read only `epochs` (to pick the last-epoch
# checkpoint) and `add_inverse` (forwarded to the data loader).
batch_size, lr, epochs = 64, 1e-4, 50
img_size, ckpt_mod = 32, 10
add_inverse = False

# --- Sweep axes ------------------------------------------------------------
# Losses to iterate over; LossSwitch.MSE is currently switched off.
losses = [LossSwitch.CE]

# Activations to iterate over. The larger-beta softplus variants
# (SOFTPLUS_B100, SOFTPLUS_B1000, SOFTPLUS_B10000, SOFTPLUS_B100000)
# are currently switched off.
activations = [
    ActivationSwitch.RELU,
    ActivationSwitch.LEAKY_RELU,
    ActivationSwitch.SOFTPLUS_B1,
    ActivationSwitch.SOFTPLUS_B10,
]

In [None]:
# Build the test loader from the local data directory chosen in the config cell.
test_dataloader = get_cifar_test(root_path=COMPUTE_DATA_DIR, add_inverse=add_inverse)

In [None]:
# Take the first item yielded by the loader and split it into inputs / targets.
first_batch = next(iter(test_dataloader))
x, y = first_batch
# squeeze(0) removes axis 0 only when it has size 1
# (assumes x is a tensor-like object — TODO confirm the loader's batch shape).
x = x.squeeze(0)

In [15]:
# Default compute device; a later cell may rebind `device`.
device = torch.device("cpu")
# Directory (relative to the working directory) that checkpoint files are read from.
checkpoint_dir = "checkpoints"

In [None]:
# Absolute input gradient ("saliency") of one sample, one panel per activation.
#
# NOTE(review): `training_data`, `NeuralNetwork` and `convert_str_to_activation_fn`
# are not defined or imported in the visible cells; they must already be in scope
# for this cell to survive Restart & Run All.
image_idx = 1

# Fetch the sample once so every panel (and the final "input image" panel) is
# computed on exactly the same input.
raw_image, label = training_data[image_idx]

# Checkpoint-name suffix per bias mode. The loop variable is called `bias_mode`
# so that it does not overwrite the `bias` string defined in the config cell.
bias_path_suffix = {1: "_False", 0: "_True"}

plt.figure(figsize=(10, 10))
# 3x3 grid; column 1 is left free for the input image drawn after the loop,
# so plotting starts in slot 2 and skips every slot at the start of a row.
counter = 2
for activation, loss, epoch, bias_mode in product(
    activations,
    losses,
    [epochs - 1],
    [1],
):
    if bias_mode not in bias_path_suffix:
        raise ValueError(f"Invalid bias value: {bias_mode}")
    path = bias_path_suffix[bias_mode]

    # Only the convolutional bias depends on the mode here; the third
    # constructor argument is always True in this cell.
    conv_bias = bias_mode == 0

    # NOTE(review): unlike the later cells, this filename contains neither the
    # loss nor the epoch — confirm that this naming scheme is intended.
    checkpoint_filename = f"{activation}_{path}.pth"

    # Fresh leaf tensor per model so `.grad` never carries over between panels.
    image = raw_image.detach().clone().unsqueeze(0).to(device).requires_grad_(True)

    # Load the model from the checkpoint
    activation_fn = convert_str_to_activation_fn(activation)
    model = NeuralNetwork(activation_fn, conv_bias, True).to(device)
    checkpoint_path = f"{checkpoint_dir}/{checkpoint_filename}"
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()  # Set the model to evaluation mode

    # Gradient of the largest output entry with respect to the input.
    output = model(image)
    output.max().backward()
    input_gradient = np.abs(image.grad.detach().cpu().numpy())

    # Scale to [0, 1]; skip the division for an all-zero gradient to avoid NaNs.
    peak = np.max(input_gradient)
    if peak > 0:
        input_gradient = input_gradient / peak

    plt.subplot(3, 3, counter)
    counter += 2 if counter % 3 == 0 else 1
    plt.imshow(input_gradient.squeeze())
    plt.title(f"Input Gradient {activation} {loss} {epoch} {bias_mode}")

# Input image in the free left column (title reuses the last loop values).
plt.subplot(1, 3, 1)
plt.imshow(raw_image.squeeze())
plt.title(f"input image {activation} {loss} {epoch}")

In [None]:
# For each activation: vanilla input gradient ("VG") and its rank transform
# ("IT"), once with absolute values and once signed.
#
# NOTE(review): `training_data`, `NeuralNetwork` and `convert_str_to_activation_fn`
# are not defined or imported in the visible cells; they must already be in scope
# for this cell to survive Restart & Run All.
image_idx = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the sample once so every panel is computed on exactly the same input.
raw_image, label = training_data[image_idx]

# Checkpoint-name suffix per bias mode. The loop variable is called `bias_mode`
# so that it does not overwrite the `bias` string defined in the config cell.
bias_path_suffix = {2: "_conv_False", 1: "_False", 0: ""}

plt.figure(figsize=(20, 24))
# 6x5 grid; column 1 is left free for the input image drawn after the loop.
counter = 2
for activation, loss, epoch, bias_mode in product(
    activations, [LossSwitch.CE], [epochs - 1], [2]
):
    if bias_mode not in bias_path_suffix:
        raise ValueError(f"Invalid bias value: {bias_mode}")
    path = bias_path_suffix[bias_mode]

    # Mode 0: both biases on; mode 2: only fc bias on; mode 1: both off.
    conv_bias, fc_bias = (True, True) if bias_mode == 0 else (False, bias_mode == 2)

    checkpoint_filename = f'{activation}_{loss}_{epoch}{path}.pth'

    # Load the model once per checkpoint; it is reused for every (j, k) below.
    activation_fn = convert_str_to_activation_fn(activation)
    model = NeuralNetwork(activation_fn, conv_bias, fc_bias).to(device)
    checkpoint_path = f'{checkpoint_dir}/{checkpoint_filename}'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()  # Set the model to evaluation mode

    # j == 0 -> absolute gradient, j == 1 -> signed gradient.
    # k == 0 -> differentiate the max output, otherwise the min output.
    for j, k in product(range(2), [0]):
        # Fresh leaf tensor per pass so `.grad` never accumulates across passes.
        image = raw_image.detach().clone().unsqueeze(0).to(device).requires_grad_(True)

        output = model(image)
        target = output.max() if k == 0 else output.min()
        target.backward()

        signed = j != 0
        input_gradient = image.grad.detach().cpu().numpy()
        if not signed:
            input_gradient = np.abs(input_gradient)

        # Normalise by the largest magnitude so signs are preserved and values
        # lie in [-1, 1]; skip the division for an all-zero gradient.
        scale = np.max(np.abs(input_gradient))
        if scale > 0:
            input_gradient = input_gradient / scale
        grad_map = input_gradient.squeeze()

        minmax_str = 'max' if k == 0 else 'min'
        abs_str = 'abs' if j == 0 else ''
        cmap = 'Blues' if j == 0 else 'bwr'

        plt.subplot(6, 5, counter)
        counter += 2 if counter % 5 == 0 else 1
        if signed:
            # Symmetric limits keep zero at the centre of the diverging colormap.
            plt.imshow(grad_map, cmap=cmap, vmin=-1, vmax=1)
        else:
            plt.imshow(grad_map, cmap=cmap)
        plt.title(f'VG {activation} {abs_str} {minmax_str}')

        # Rank transform: replace each value by its rank divided by the number
        # of elements, giving values in (0, 1] for any array shape.
        U = (rankdata(grad_map.flatten()) / grad_map.size).reshape(grad_map.shape)

        plt.subplot(6, 5, counter)
        counter += 2 if counter % 5 == 0 else 1
        plt.imshow(U, cmap=cmap)
        plt.title(f'IT {activation} {abs_str} {bias_mode}')

# Input image in the free left column (title reuses the last loop values).
plt.subplot(1, 5, 1)
plt.imshow(raw_image.squeeze())
plt.title(f'input image {activation} {loss} {epoch}')

In [None]:
# Normalised 1-D power spectrum of the absolute input gradient, one curve per
# activation's checkpoint.
#
# NOTE(review): `training_data`, `NeuralNetwork`, `convert_str_to_activation_fn`
# and `compute_1d_power_spectrum` are not defined or imported in the visible
# cells; they must already be in scope for this cell to survive Restart & Run All.
# `device` is taken from an earlier cell.
image_idx = 1

# Fetch the sample once so every curve is computed on exactly the same input.
raw_image, label = training_data[image_idx]

# Checkpoint-name suffix per bias mode. The loop variable is called `bias_mode`
# so that it does not overwrite the `bias` string defined in the config cell.
bias_path_suffix = {2: "_conv_False", 1: "_False", 0: ""}

plt.figure(figsize=(8, 8))

# `loss` is iterated explicitly here: the checkpoint filename needs it, and
# relying on a value leaked from an earlier cell's loop breaks on a fresh kernel.
for activation, loss, epoch, bias_mode in product(
    activations, [LossSwitch.CE], [epochs - 1], [2]
):
    if bias_mode not in bias_path_suffix:
        raise ValueError(f"Invalid bias value: {bias_mode}")
    path = bias_path_suffix[bias_mode]

    # Mode 0: both biases on; mode 2: only fc bias on; mode 1: both off.
    conv_bias, fc_bias = (True, True) if bias_mode == 0 else (False, bias_mode == 2)

    checkpoint_filename = f'{activation}_{loss}_{epoch}{path}.pth'

    # Fresh leaf tensor per model so `.grad` never carries over between curves.
    image = raw_image.detach().clone().unsqueeze(0).to(device).requires_grad_(True)

    activation_fn = convert_str_to_activation_fn(activation)
    model = NeuralNetwork(activation_fn, conv_bias=conv_bias, fc_bias=fc_bias).to(device)
    checkpoint_path = f'{checkpoint_dir}/{checkpoint_filename}'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

    # Gradient of the largest output entry with respect to the input.
    output = model(image)
    output.max().backward()
    input_gradient = np.abs(image.grad.detach().cpu().numpy())

    # Scale to [0, 1]; skip the division for an all-zero gradient to avoid NaNs.
    peak = np.max(input_gradient)
    if peak > 0:
        input_gradient = input_gradient / peak
    input_gradient = input_gradient.squeeze()

    # Normalise the spectrum to unit sum so curves are comparable across models.
    power_spectrum = compute_1d_power_spectrum(input_gradient)
    power_spectrum = power_spectrum / power_spectrum.sum()
    # Drop the first (constant) component for better visualization.
    plt.plot(power_spectrum[1:], alpha=0.8, label=checkpoint_filename)

plt.xscale('log')
plt.yscale('log')
# NOTE(review): the axis is labelled in Hz although the input is an image
# gradient — confirm the intended unit against compute_1d_power_spectrum.
plt.xlabel('Frequency (Hz)')
plt.ylabel('Spectral Density')
plt.title('Spectral Density of Input Gradient')
plt.legend()
plt.show()