In [None]:
import os
import random
from functools import partial

import torch
from pytorch_lightning import LightningDataModule

from src.datamodules.cifar100_datamodule import CIFAR100DataModule, cifar100_normalization
from src.models.nn.moe.resnet_block_moe import ResNetBlockMoE
from src.models.nn.moe.resnet_conv_moe import ResNetConvMoE
from src.models.nn.moe.routing import ConvGlobalAvgRoutingNetwork, GlobalAvgLinearRoutingNetwork
from src.evaluation.activation_visualization import (
    _load_state_dict,
    compute_sensitivity_to_attacks,
    compute_sensitivity_to_gates,
    compute_sensitivity_to_samples,
    plot_grad_grid,
)

In [None]:
# model = ResNetBlockMoE(
#    num_classes=100, num_channels=3, small_inputs=True, layers=18, routing_layer_type=GlobalAvgLinearRoutingNetwork,
#    num_experts=4, k=4
# )

In [None]:
# Architecture of the checkpoint inspected in this notebook: an 18-layer ResNet
# variant for 3-channel inputs and 100 classes, with 4 experts per MoE layer and
# k=1 (presumably the number of experts selected per input — confirm in
# ResNetConvMoE). `partial` pins the routing network's kernel size to 3x3 so the
# model can instantiate it itself. These hyper-parameters must agree with the
# trained weights, otherwise the (strict, by default) `load_state_dict` call
# further down will fail.
model = ResNetConvMoE(
    layers=18,
    num_channels=3,
    num_classes=100,
    small_inputs=True,
    num_experts=4,
    k=1,
    routing_layer_type=partial(ConvGlobalAvgRoutingNetwork, kernel_size=(3, 3)),
)

In [None]:
# Build the CIFAR-100 data module and run its preparation hooks up front
# (prepare_data, then setup), so the test dataloader requested further down is
# ready to use.
datamodule: LightningDataModule
datamodule = CIFAR100DataModule()
datamodule.prepare_data()
datamodule.setup()

In [None]:
# Identifies the training run whose checkpoint is visualized. Both values are
# passed to `_load_state_dict` / `wandb.init`; `run_name` also names the output
# directory `expert_attention_<run_name>` where the figures are saved.
project = "robust-cifar100-resnet-moe"
run_name = "cifar100-resnet18-free-adv-train-conv-moe4-CGARN-1"

In [None]:
import wandb

# Start a W&B run in the same project as the checkpoint.
# The label used to be passed positionally. In the classic `wandb.init` signature
# the first positional parameter is `job_type`, but newer wandb releases order
# `init`'s parameters differently, so relying on position is fragile. Naming the
# keyword keeps the original meaning explicit.
# NOTE(review): presumably `_load_state_dict` needs an active run — confirm.
wandb.init(job_type="try-visualization", project=project)

In [None]:
# Load the trained weights of `run_name` (presumably fetched from W&B by the
# helper) into the freshly constructed model.
state_dict = _load_state_dict(run_name=run_name, project=project)

# A newly constructed nn.Module is in training mode. Switch to eval mode so any
# BatchNorm layers (standard in ResNets) use their stored running statistics.
# Otherwise they would normalize with statistics of the small batches fed to the
# visualization helpers below, and update their running stats as a side effect,
# which makes later cells depend on earlier ones.
# NOTE(review): the helpers may already set the mode themselves (not visible
# here). Calling eval() explicitly is harmless either way.
model.eval()
model.load_state_dict(state_dict)

In [None]:
# Take the first test batch once; every visualization below samples its images
# from `batch` (`batch[0]` holds the inputs, `batch[1]` the targets).
# `normalize` is handed to the sensitivity helpers together with the raw images.
loader = datamodule.test_dataloader()
batch = next(iter(loader))
normalize = cifar100_normalization()

In [None]:
# Per-image view: for 5 fixed test images, plot the sensitivity map returned by
# `compute_sensitivity_to_gates` for each of the four layer4 gates side by side.
# One figure is saved per image.
random.seed(5)
choices = random.sample(range(len(batch[0])), k=5)

# Gate modules and their display names; both lists are in the same order.
gates = [
    model.model.layer4[0].conv1.gate,
    model.model.layer4[0].conv2.gate,
    model.model.layer4[1].conv1.gate,
    model.model.layer4[1].conv2.gate,
]
gate_names = [f"layer4.{b}.conv{c}" for b, c in ((0, 1), (0, 2), (1, 1), (1, 2))]

os.makedirs(f"expert_attention_{run_name}", exist_ok=True)

for idx in choices:
    sample = batch[0][idx]
    target = batch[1][idx]
    target_class = datamodule.class_map[target.item()]

    grad_grid = compute_sensitivity_to_gates(model, gates, normalize, sample)

    # The same image is repeated once per gate, so every panel pairs it with one gate's map.
    titles = [f"Class: {target_class}, Gate: {name}" for name in gate_names]
    fig = plot_grad_grid([sample] * len(gates), grad_grid, titles)
    fig.savefig(f"expert_attention_{run_name}/{idx}_{target_class}.png")

In [None]:
# Noise baseline: replace the selected images with random noise and plot each
# layer4 gate's sensitivity map over the noise inputs. One figure is saved per gate.
# The Python and torch RNGs are both seeded, so re-runs are reproducible. The
# noise is drawn once, so all four gates are evaluated on the same inputs and
# their figures are directly comparable.
random.seed(5)
torch.manual_seed(5)
choices = random.sample(range(len(batch[0])), k=5)

# Gate modules and their display names; both lists are in the same order.
gates = [
    model.model.layer4[0].conv1.gate,
    model.model.layer4[0].conv2.gate,
    model.model.layer4[1].conv1.gate,
    model.model.layer4[1].conv2.gate,
]
gate_names = [f"layer4.{b}.conv{c}" for b, c in ((0, 1), (0, 2), (1, 1), (1, 2))]

samples, targets = batch[0][choices], batch[1][choices]
# Noise scaled by the statistics of the selected images.
# NOTE(review): rand_like is uniform on [0, 1), so this produces values in
# [mean, mean + std) rather than noise matching the images' mean/std (that would
# be torch.randn_like). Confirm which was intended.
samples = torch.rand_like(samples) * samples.std() + samples.mean()
# Titles keep the labels of the images the noise statistics were taken from; the
# noise itself carries no class.
target_classes = [datamodule.class_map[target.item()] for target in targets]

os.makedirs(f"expert_attention_{run_name}", exist_ok=True)

for gate, gate_name in zip(gates, gate_names):
    # Defensive copy: the shared noise tensor is reused for every gate, so a helper
    # that mutates its input (e.g. requires_grad_/grad accumulation) cannot leak
    # state from one gate into the next.
    grad_grid = compute_sensitivity_to_samples(model, gate, normalize, samples.clone())
    titles = [f"Class: {target_class}, Gate: {gate_name}" for target_class in target_classes]

    fig = plot_grad_grid(samples, grad_grid, titles)
    fig.savefig(f"expert_attention_{run_name}/random_{gate_name}.png")

In [None]:
% load_ext autoreload
% autoreload 2

random.seed(5)
choices = random.sample(range(len(batch[0])), k=5)

gate = model.model.layer4[0].conv2.gate
gate_name = "layer4.0.conv2"

steps_list = [0, 8, 16, 32, 64]
eps = 8
alpha = 2

for idx in choices:
    sample, target = batch[0][idx], batch[1][idx]
    grad_grid, all_inputs = compute_sensitivity_to_attacks(
        model, normalize, gate, sample, target, eps=eps, alpha=alpha, steps_list=steps_list
    )
    preds = [pred for _, _, pred in grad_grid]

    target_class = datamodule.class_map[target.item()]
    titles = [
        f"Class: {target_class},Pred: {datamodule.class_map[pred.argmax().item()]},PGD(steps={steps},eps={eps},alpha={alpha})"
        for steps, pred in zip(steps_list, preds)
    ]

    fig = plot_grad_grid(all_inputs, grad_grid, titles)

    os.makedirs(f"expert_attention_{run_name}", exist_ok=True)
    fig.savefig(f"expert_attention_{run_name}/{idx}_{target_class}.png")