In [None]:
import os
# Expose only GPU index 2 to this process. This is done through the environment
# and before the cells below import any CUDA-using library.
VISIBLE_GPU = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = VISIBLE_GPU

In [None]:
from attention_all_layers import TemporalAugmentedDataset, EvalDataWrapper
import torchvision.transforms as transforms

from functools import partial
from utils.transforms import MeanFlat, RandomRepeatedNoise, Identity

# Images are only converted to tensors; no normalisation step is applied.
to_tensor_only = [transforms.ToTensor()]
transform = transforms.Compose(to_tensor_only)

# Single Identity instance, reused for every timestep below
# (presumably a pass-through transform -- confirm in utils.transforms).
eye = Identity()

def worker_init_fn(worker_id):
    """Allow a DataLoader worker process to run on every CPU the machine reports.

    Passed as ``worker_init_fn`` to the ``DataLoader`` instances in this notebook.
    Presumably this undoes a restricted CPU affinity inherited by the worker
    processes -- TODO confirm the original motivation.

    Args:
        worker_id: Worker index supplied by the DataLoader (unused).

    Returns:
        None.
    """
    # ``os.sched_setaffinity`` is not available on every platform; skip the
    # call there instead of crashing each worker at start-up.
    if not hasattr(os, "sched_setaffinity"):
        return
    # ``os.cpu_count()`` may return None when the count cannot be determined.
    n_cpus = os.cpu_count() or 1
    os.sched_setaffinity(0, range(n_cpus))


from torch.utils.data import DataLoader, Dataset

from utils.visualization import visualize_first_batch_with_timesteps

# The same identity transform is used for each of the 20 timesteps.
timestep_transforms = [eye for _ in range(20)]

# Test split of the temporal dataset (Fashion MNIST, per the original comment).
test_dataset = TemporalAugmentedDataset(
    'test',
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
)

eval_dataset = EvalDataWrapper(test_dataset, contrast=1, rep_noise=False)
test_loader = DataLoader(
    eval_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=30,
    worker_init_fn=worker_init_fn,
)

In [None]:
from modules.exponential_decay import ExponentialDecay
from modules.divisive_norm import DivisiveNorm
from modules.div_norm_channel import DivisiveNormChannel
from models.HookedRecursiveCNN import HookedRecursiveCNN

# HookedRecursiveCNN needs layer_kwargs and the adaptation kwargs to know how to
# set up the model; the concrete init values are unimportant because they get
# overwritten with the pretrained values.
layer_kwargs = [
    {'in_channels': 1, 'out_channels': 32, 'kernel_size': 5},
    {'in_channels': 32, 'out_channels': 32, 'kernel_size': 5},
    {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
    {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
    {'in_features': 128, 'out_features': 1024},
]

# Settings shared by every non-trainable divisive-norm entry.
_frozen_div_norm = {
    "epsilon": 1e-8, "K_init": 1.0, "train_K": False,
    "alpha_init": -2000000.0, "train_alpha": False,
    "sigma_init": 1.0, "train_sigma": False,
}
div_norm_kwargs = [
    # Only the first entry has trainable parameters.
    {"epsilon": 1e-8, "K_init": 0.2, "train_K": True,
     "alpha_init": -2.0, "train_alpha": True,
     "sigma_init": 0.1, "train_sigma": True, 'sqrt': True},
    {**_frozen_div_norm, 'sqrt': True},
    {**_frozen_div_norm, 'sqrt': True},
    # The last two entries carry no 'sqrt' key.
    {**_frozen_div_norm},
    {**_frozen_div_norm, "alpha_init": 0.0},
]


def _exp_decay_entry(trainable):
    """Return one exponential-decay kwargs dict with both train flags set to `trainable`."""
    return {"alpha_init": 1.0, "train_alpha": trainable, "beta_init": 1, "train_beta": trainable}


# Three trainable entries followed by two frozen ones.
exp_decay_kwargs = (
    [_exp_decay_entry(True) for _ in range(3)]
    + [_exp_decay_entry(False) for _ in range(2)]
)


def _adaptation_cfg(module, kwargs):
    """Return the model config for the given adaptation module and its per-layer kwargs."""
    return {
        't_steps': 20, 'layer_kwargs': layer_kwargs,
        'adaptation_module': module,
        'adaptation_kwargs': kwargs, 'decode_every_timestep': True,
    }


div_norm_cfg = _adaptation_cfg(DivisiveNorm, div_norm_kwargs)
exp_decay_cfg = _adaptation_cfg(ExponentialDecay, exp_decay_kwargs)

In [None]:
from tqdm import tqdm
from torchmetrics.functional import accuracy
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from temporal_datasets.one_image_temporal_augmented_dataset import OneImageTemporalAugmentedDataset

In [None]:
import torchvision.transforms as transforms
from utils.transforms import Identity
from torch.utils.data import DataLoader
from utils.visualization import visualize_first_batch_with_timesteps

split = 'test'
batch_size = 50
num_workers = 20

transform = transforms.Compose([transforms.ToTensor()])
eye = Identity()
timestep_transforms = [eye for _ in range(20)]

one_image_dataset = TemporalAugmentedDataset(
    split,
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
)

# Only the training split is shuffled.
shuffle = (split == 'train')
loader = DataLoader(
    one_image_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    worker_init_fn=worker_init_fn,
)

# Keep the first batch around for the visualisation cell.
x, y = next(iter(loader))

# Class index -> human-readable label.
int_to_label = dict(enumerate([
    'T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow

idx = 7
# Reorder axes for imshow (presumably (T, C, H, W) -> (T, H, W, C) -- confirm).
sample = x[idx].permute((0, 2, 3, 1))
n_panels = 20
fig, axes = plt.subplots(1, n_panels, figsize=(20, 4))  # wide strip, one panel per timestep
fig.subplots_adjust(wspace=0.02)

for step in range(len(axes)):
    panel = axes[step]
    panel.imshow(sample[step], cmap='gray', vmin=0, vmax=1)  # greyscale, fixed intensity range
    step_label = int_to_label[int(y[idx, step])]
    panel.set_xlabel(step_label, rotation=25, labelpad=10, fontsize=14)  # slanted label below the panel
    panel.set_xticks([])
    panel.set_yticks([])

# Long arrow below the panels, positioned in figure coordinates.
arrow = FancyArrow(0.15, 0.2, 0.65, 0, width=0.01, color='black', transform=fig.transFigure, clip_on=False)
fig.add_artist(arrow)

fig.savefig("figures/attnl_001.svg", format='svg')
fig.savefig("figures/attnl_001.png")

In [None]:
from tqdm import tqdm
from torchmetrics.functional import accuracy
import pandas as pd
import torch

from matplotlib import pyplot as plt
import seaborn as sns

# NOTE(review): `models` (display name -> model) is not defined in any cell
# visible here; it has to exist before this cell is run.
max_batches = 51  # the previous counter (`j > 50`) stopped after 51 batches
n_timesteps = 20

# Collect plain records and build the frame once, instead of growing a
# DataFrame one row at a time with `df.loc[len(df)] = ...`.
records = []
with torch.no_grad():  # evaluation only: no autograd graph is needed
    for batch_idx, (x, y) in enumerate(tqdm(loader)):
        x = x.cuda()
        y = y.cuda()
        for name, model in models.items():
            # A single `.cuda()` is enough; the earlier `.cpu()` followed by
            # `.cuda()` only copied the weights back and forth on every batch.
            model.cuda()
            logits = model(x)
            for i in range(n_timesteps):
                # The model decodes at every timestep: score each one separately.
                preds = torch.argmax(logits[:, i, :], dim=1)
                acc = accuracy(preds, y[:, i], task='multiclass', num_classes=10)
                # Timesteps are stored 1-based for plotting.
                records.append({'Model': name, 'Accuracy': float(acc.cpu()), 'Timestep': i + 1})
        if batch_idx + 1 >= max_batches:
            break

df = pd.DataFrame(records, columns=['Model', 'Accuracy', 'Timestep'])

plt.figure(figsize=(3, 3))
sns.set_style('white')
sns.lineplot(data=df, x='Timestep', y='Accuracy', hue='Model', palette='dark')
plt.ylim(0, 1)
sns.despine(offset=10)
plt.tight_layout()

plt.savefig("figures/attnl_002.svg", format='svg')
plt.savefig("figures/attnl_002.png")

In [None]:
# Fetch one batch and move inputs and targets to the GPU.
x, y = (tensor.cuda() for tensor in next(iter(loader)))
# NOTE(review): among the cells visible here, `div_norm_model` is first assigned
# in a later cell, so this cell depends on out-of-order execution or on a cell
# that is not shown.
_, cache = div_norm_model.run_with_cache(x)

In [None]:
idx = 0
# Feature-map index shown for each of the three layers. Named `feature_maps`
# rather than `map` so the builtin is not shadowed in the kernel namespace.
feature_maps = [0, 7, 0]
sample = x[idx].permute((0, 2, 3, 1)).cpu()
fig, axes = plt.subplots(ncols=8, nrows=4, figsize=(10, 5))
fig.subplots_adjust(wspace=0.03, hspace=0.01)

# Top row: the input at the first 8 timesteps (one per column).
for i, ax in enumerate(axes[0]):
    ax.imshow(sample[i], cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])

# Rows 1-3: cached 'hks.adapt_{layer}_{timestep}' activations of the chosen map.
for layer in range(3):
    for i, ax in enumerate(axes[layer + 1]):
        activation = cache[f'hks.adapt_{layer}_{i}'][idx, feature_maps[layer]]
        # detach() is a no-op for tensors without grad and keeps imshow working otherwise.
        ax.imshow(activation.detach().cpu(), vmin=0, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
    axes[layer + 1, 0].set_ylabel(f'Layer {layer + 1}', fontsize=12)
axes[0, 0].set_ylabel('Input', fontsize=14)

# Long arrow below the grid, positioned in figure coordinates.
arrow = FancyArrow(0.15, 0.05, 0.65, 0, width=0.01, color='black', transform=fig.transFigure, clip_on=False)
fig.add_artist(arrow)

plt.savefig("figures/attnl_005.svg", format='svg')
# Let the extension pick the format: passing format='svg' here wrote SVG data
# into a file named .png.
plt.savefig("figures/attnl_005.png")

In [None]:
def get_loader(split='train', num_workers=10, batch_size=64):
    """Build a DataLoader over ``OneImageTemporalAugmentedDataset`` for ``split``.

    Images go through ``ToTensor`` only, and the dataset receives a list of 20
    ``Identity`` timestep transforms. Batches are shuffled only when ``split``
    is ``'train'``.

    Args:
        split: Dataset split name forwarded to the dataset constructor.
        num_workers: Number of DataLoader worker processes.
        batch_size: Number of samples per batch.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from utils.transforms import Identity
    from utils.visualization import visualize_first_batch_with_timesteps

    n_timesteps = 20
    dataset = OneImageTemporalAugmentedDataset(
        split,
        transform=transforms.Compose([transforms.ToTensor()]),
        img_to_timesteps_transforms=[Identity()] * n_timesteps,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
    )

# Column store for activation statistics; one list per column, filled by later cells.
actv_columns = ('Timestep', 'Layer', 'Mean', 'num_active', 'mean_not_null', 'Norm', 'Model', 'State', 'Map')
actv_dict = {column: [] for column in actv_columns}

one_image_dataset = OneImageTemporalAugmentedDataset(
    'test',
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
)

# From here on `loader` iterates the one-image test set, in order.
loader = DataLoader(
    one_image_dataset,
    batch_size=50,
    shuffle=False,
    num_workers=25,
    pin_memory=True,
    pin_memory_device='cuda',
    worker_init_fn=worker_init_fn,
    persistent_workers=True,
)

In [None]:
from tqdm import tqdm
from torchmetrics.functional import accuracy
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

# NOTE(review): `models` (display name -> model) is not defined in any cell
# visible here; it has to exist before this cell is run.
max_batches = 51  # the previous counter (`j > 50`) stopped after 51 batches
n_timesteps = 20

# Collect plain records and build the frame once, instead of growing a
# DataFrame one row at a time with `df.loc[len(df)] = ...`.
records = []
with torch.no_grad():  # evaluation only: no autograd graph is needed
    for batch_idx, (x, y) in enumerate(tqdm(loader)):
        x = x.cuda()
        y = y.cuda()
        for name, model in models.items():
            # A single `.cuda()` is enough; the earlier `.cpu()` followed by
            # `.cuda()` only copied the weights back and forth on every batch.
            model.cuda()
            logits = model(x)
            for i in range(n_timesteps):
                # The model decodes at every timestep: score each one separately.
                preds = torch.argmax(logits[:, i, :], dim=1)
                acc = accuracy(preds, y[:, i], task='multiclass', num_classes=10)
                # Timesteps are stored 1-based for plotting.
                records.append({'Model': name, 'Accuracy': float(acc.cpu()), 'Timestep': i + 1})
        if batch_idx + 1 >= max_batches:
            break

df = pd.DataFrame(records, columns=['Model', 'Accuracy', 'Timestep'])

plt.figure(figsize=(3, 3))
sns.set_style('white')
sns.lineplot(data=df, x='Timestep', y='Accuracy', hue='Model', palette='dark', legend=False)
plt.ylim(0, 1)
sns.despine(offset=10)
plt.tight_layout()

plt.savefig("figures/attnl_003.svg", format='svg')
plt.savefig("figures/attnl_003.png")

In [None]:
# Build separate kwargs for the channel-wise variant instead of adding
# 'n_channels' to `div_norm_kwargs` in place: those dicts are the same objects
# referenced by `div_norm_cfg['adaptation_kwargs']`, so mutating them would
# silently change the DivisiveNorm configuration as well.
div_norm_channel_kwargs = [{**kwargs, 'n_channels': 32} for kwargs in div_norm_kwargs]
div_norm_channel_cfg = {
    't_steps': 20, 'layer_kwargs': layer_kwargs,
    'adaptation_module': DivisiveNormChannel,
    'adaptation_kwargs': div_norm_channel_kwargs, 'decode_every_timestep': True
}
# NOTE(review): the name `div_norm_model` is kept because another cell calls
# `div_norm_model.run_with_cache`; after this cell it refers to the
# DivisiveNormChannel checkpoint -- confirm that this is intended.
div_norm_model = HookedRecursiveCNN.load_from_checkpoint(
    'learned_models/new_augmented_attn_all_layers_DivisiveNormChannel_baseline=False_contrast_random_epoch_50.ckpt', div_norm_channel_cfg)
# NOTE(review): `models` must already exist; it is not created in any cell visible here.
models['Divisive Norm. Channel'] = div_norm_model

In [None]:
# Start from an empty column store so that re-running this cell does not append
# a second copy of every row to a dict created in another cell.
actv_dict = {'Timestep': [], 'Layer': [], 'Mean': [], 'num_active': [], 'mean_not_null': [], 'Norm': [], 'Model': [], 'State': [], 'Map': []}
max_batches = 51  # the previous counter (`j > 50`) stopped after 51 batches

with torch.no_grad():  # statistics only: no autograd graph is needed
    for batch_idx, (x, y) in enumerate(tqdm(loader)):
        x = x.cuda()
        y = y.cuda()
        for name, model in models.items():
            model.cuda()
            logits, cache = model.run_with_cache(x)
            for layer in range(4):
                for timestep in range(20):
                    actv = cache[f'hks.adapt_{layer}_{timestep}']
                    state = cache[f'hks.state_{layer}_{timestep}']
                    # `fmap` instead of `map`, so the builtin is not shadowed.
                    for fmap in range(32):
                        fmap_actv = actv[:, fmap]
                        is_active = fmap_actv > 1e-4
                        actv_dict['Map'].append(fmap)
                        actv_dict['Timestep'].append(timestep)
                        actv_dict['Layer'].append(layer)
                        actv_dict['Mean'].append(float(fmap_actv.mean()))
                        actv_dict['num_active'].append(float(is_active.sum()))
                        # NaN when no unit of this map exceeds the threshold.
                        actv_dict['mean_not_null'].append(float(fmap_actv[is_active].mean()))
                        actv_dict['Norm'].append(float(fmap_actv.norm()))
                        actv_dict['Model'].append(name)
                        actv_dict['State'].append(float(state[:, fmap].mean()))
        if batch_idx + 1 >= max_batches:
            break

actv_df = pd.DataFrame(actv_dict)
actv_df

In [None]:
# Rebuild the frame from the raw records and report timesteps 1-based.
actv_df = pd.DataFrame(actv_dict).assign(Timestep=lambda frame: frame['Timestep'] + 1)

In [None]:
def _normalize_by_first_timestep(group):
    """Divide a group's 'Mean' by the average 'Mean' of its rows with Timestep == 1."""
    baseline = group.loc[group.Timestep == 1, 'Mean'].mean()
    return group['Mean'] / baseline


# Normalise within each (Model, Layer) group, then drop the group levels so the
# result aligns with actv_df's original index.
normalized = actv_df.groupby(['Model', 'Layer'], sort=False).apply(_normalize_by_first_timestep)
actv_df['Normalized Activations'] = normalized.reset_index(level=['Model', 'Layer'], drop=True)

In [None]:
# One panel per model, one line per layer; only layers 0-2 are shown.
compared_models = ['Divisive Norm.', 'Additive']
keep_rows = actv_df.Model.isin(compared_models) & (actv_df.Layer < 3)
sns.relplot(
    data=actv_df[keep_rows],
    x='Timestep',
    y='Normalized Activations',
    hue='Layer',
    col='Model',
    kind='line',
    height=3,
)
sns.despine(offset=5)
plt.savefig("figures/attnl_004.svg", format='svg')
plt.savefig("figures/attnl_004.png")

# Causal experiments

In [None]:
def hook_fn(actv, hook, target_actv):
    """Ignore the incoming activation and return ``target_actv`` instead.

    Used with ``functools.partial`` to bind ``target_actv`` before being
    registered through ``model.hooks``; presumably the returned value replaces
    the hooked activation -- confirm against HookedRecursiveCNN.

    Args:
        actv: Activation passed in by the hook point (unused).
        hook: Hook object passed in by the hook point (unused).
        target_actv: Value to return.

    Returns:
        ``target_actv``, unchanged.
    """
    return target_actv

In [None]:
actv_dict = {'Timestep': [], 'Layer': [], 'Mean': [], 'num_active': [], 'mean_not_null': [], 'Norm': [], 'Model': [], 'State': [], 'Map': [], 'Intervention Layer': []}
max_batches = 51  # the previous counter (`j > 50`) stopped after 51 batches

with torch.no_grad():  # statistics only: no autograd graph is needed
    for batch_idx, (x, y) in enumerate(tqdm(loader)):
        x = x.cuda()
        y = y.cuda()
        for intervention_layer in range(3):
            for name, model in models.items():
                model.cuda()
                # Unmodified pass: take the intervened layer's activation at timestep 0.
                _, c = model.run_with_cache(x)
                target_actv = c[f'hks.adapt_{intervention_layer}_0']
                # Second pass: the hook returns that activation at all 20 timesteps of the layer.
                hook = partial(hook_fn, target_actv=target_actv)
                hooks = [(f'hks.adapt_{intervention_layer}_{i}', hook) for i in range(20)]
                with model.hooks(hooks):
                    _, cache = model.run_with_cache(x)
                for layer in range(4):
                    for timestep in range(20):
                        actv = cache[f'hks.adapt_{layer}_{timestep}']
                        state = cache[f'hks.state_{layer}_{timestep}']
                        # `fmap` instead of `map`, so the builtin is not shadowed.
                        for fmap in range(32):
                            fmap_actv = actv[:, fmap]
                            is_active = fmap_actv > 1e-4
                            actv_dict['Map'].append(fmap)
                            actv_dict['Timestep'].append(timestep)
                            actv_dict['Layer'].append(layer)
                            actv_dict['Mean'].append(float(fmap_actv.mean()))
                            actv_dict['num_active'].append(float(is_active.sum()))
                            # NaN when no unit of this map exceeds the threshold.
                            actv_dict['mean_not_null'].append(float(fmap_actv[is_active].mean()))
                            actv_dict['Norm'].append(float(fmap_actv.norm()))
                            actv_dict['Model'].append(name)
                            actv_dict['State'].append(float(state[:, fmap].mean()))
                            actv_dict['Intervention Layer'].append(intervention_layer)
        if batch_idx + 1 >= max_batches:
            break

actv_df = pd.DataFrame(actv_dict)
actv_df

In [None]:
# Rebuild the frame from the raw records first, so that re-running this cell
# does not shift 'Timestep' a second time (a bare in-place `+= 1` is not idempotent).
actv_df = pd.DataFrame(actv_dict)
actv_df['Timestep'] += 1  # report timesteps 1-based

# Normalise each (Model, Layer, Intervention Layer) group by its average 'Mean'
# at the first timestep, then drop the group levels to realign with actv_df.
group_keys = ['Model', 'Layer', 'Intervention Layer']
actv_df['Normalized Activations'] = (
    actv_df.groupby(group_keys, sort=False)
    .apply(lambda df: df['Mean'] / df.loc[df.Timestep == 1, 'Mean'].mean())
    .reset_index(level=group_keys, drop=True)
)

In [None]:
# Four colours from seaborn's "rocket" palette, in reversed order.
rocket_colors = sns.color_palette("rocket", n_colors=4)
palette = list(reversed(rocket_colors))

In [None]:
fig, ax = plt.subplots(figsize=(6, 3), ncols=2)
sns.set_style('white')

# (model name, target axis, whether the lines get a legend label)
panels = [('Divisive Norm.', ax[0], False), ('Additive', ax[1], True)]
for layer in range(3):
    # Rows for layer `layer + 1` recorded while layer `layer` was the intervened one.
    downstream = (actv_df.Layer == layer + 1) & (actv_df['Intervention Layer'] == layer)
    for model_name, axis, with_label in panels:
        label_kwargs = {'label': f'Layer {layer + 1}'} if with_label else {}
        sns.lineplot(
            data=actv_df[actv_df.Model.isin([model_name]) & downstream],
            x='Timestep',
            y='Normalized Activations',
            ax=axis,
            color=palette[layer + 1],
            **label_kwargs,
        )

# No gridlines on either panel.
for axis in (ax[0], ax[1]):
    axis.grid(False)
sns.despine(offset=5)

# Save in both vector and raster form.
plt.tight_layout()
plt.savefig("figures/attnl_006.svg", format='svg')
plt.savefig("figures/attnl_006.png")