In [None]:
import os
# Expose only the GPU with index 3 to this process (set before CUDA is initialized)
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [None]:
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset


class TemporalDataset(Dataset):
    """
    Sequence dataset in which four randomly drawn images appear over time on a 2x2 canvas.

    Every sample is a sequence with one frame per entry of ``img_to_timesteps_transforms``.
    Each frame is a ``(1, 56, 56)`` canvas filled with 0.5 onto which images are added in
    one of the four 28x28 quadrants. One image always appears at timestep 0; the other three
    appear at random timesteps. The label of a timestep is the label of the highest-contrast
    image among those that appeared at the most recent onset.
    """

    def __init__(self, split, dataset='fashion_mnist', transform=None, img_to_timesteps_transforms=None):
        """
        Loads the requested split and keeps its images and labels.
        :param split: 'train' or 'test'
        :param dataset: name of the dataset
            must be one of the datasets in the huggingface datasets package
        :param transform: transforms to apply to the images
        :param img_to_timesteps_transforms: list of functions
            for every desired timestep, there should be a function in the list that converts
            the image to the desired format; each is called as ``f(img, index, target)``
        """
        self.split = split
        self.transform = transform
        self.dataset = load_dataset(dataset, split=split)
        input_col_name = 'img' if 'img' in self.dataset.column_names else 'image'  # because different datasets have different names
        self.data, self.targets = self.dataset[input_col_name], self.dataset['label']
        self.img_to_timesteps_transforms = img_to_timesteps_transforms

    def int_to_coordinate(self, index):
        """
        Maps a quadrant number to the top-left (row, column) offset of that quadrant.
        :param index: quadrant number, 0 to 3
        :raises ValueError: if index is not one of 0, 1, 2, 3
        """
        corners = ((0, 0), (0, 28), (28, 0), (28, 28))
        for quadrant, corner in enumerate(corners):
            if index == quadrant:
                return corner
        raise ValueError('index must be between 0 and 3')

    def adjust_contrast(self, img, contrast):
        """
        Centers the image around zero and scales it by ``contrast``.
        The mean is deliberately not added back, so the result is a signed offset
        that the caller adds onto the canvas.
        """
        mean = img.mean()
        img = (img - mean) * contrast #+ mean
        return img

    def __getitem__(self, index):
        """
        Builds one random sequence.
        :param index: only forwarded to the timestep transforms; the images themselves
            are drawn at random from the whole split
        :return: (frames, labels) with frames stacked along dim 0 (one per timestep)
            and labels a 1-D tensor with one class index per timestep
        """
        img_timesteps = list()
        labels = list()
        n_timesteps = len(self.img_to_timesteps_transforms)

        # Three random onsets plus a guaranteed onset at timestep 0. The range follows the
        # actual sequence length (it used to be hard-coded to 20), so every image appears
        # within the sequence. max(..., 1) keeps randint valid for an empty transform list.
        img_onsets = np.random.randint(0, max(n_timesteps, 1), 3)
        img_onsets = np.append(img_onsets, 0)
        # Random assignment of the four images to the four quadrants
        img_locations = np.random.choice(4, 4, replace=False)
        prev_img = torch.zeros((1, 28 * 2, 28 * 2)) + 0.5
        n_image = 0
        for i, trans_func in enumerate(self.img_to_timesteps_transforms):
            # Number of images whose onset is this timestep
            count = int((img_onsets == i).sum())
            if count > 0:
                cur_labels = list()
                cur_contrasts = list()
                for j in range(count):
                    # sample random image
                    idx = np.random.randint(0, len(self.data))
                    img, target = self.data[idx], int(self.targets[idx])

                    if self.transform is not None:
                        img = self.transform(img)
                    # Only foreground pixels are drawn, so the background stays at 0.5
                    mask = img > 1e-2
                    img = trans_func(img, index, target)

                    rand_contrast = np.random.uniform(0.1, 1)
                    img = self.adjust_contrast(img, rand_contrast)

                    new_img = torch.zeros((1, 28 * 2, 28 * 2))
                    x, y = self.int_to_coordinate(img_locations[n_image])
                    new_img[:, x:x+28, y:y+28] = img * mask
                    n_image += 1

                    cur_labels.append(target)
                    cur_contrasts.append(rand_contrast)
                    # Creates a new tensor, so frames already stored are not modified
                    prev_img = prev_img + new_img
                # Target is the highest-contrast image among those that just appeared
                labels.append(
                    cur_labels[np.array(cur_contrasts).argmax()]
                )

                img_timesteps.append(prev_img)
            else:
                # Nothing new: repeat the previous frame and label
                img_timesteps.append(prev_img)
                labels.append(labels[-1])

        # Stack the augmented images along the timestep dimension
        img_timesteps = torch.stack(img_timesteps, dim=0)
        labels = torch.tensor(labels)
        return img_timesteps, labels

    def __len__(self):
        """Number of images in the underlying split."""
        return len(self.data)


In [None]:
import torchvision.transforms as transforms

# Image preprocessing: only the tensor conversion is active; normalization is commented out.
transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5,), (0.5,))
    ])

In [None]:
from utils.transforms import Identity

eye = Identity()

def worker_init_fn(worker_id):
    """
    Allows the calling DataLoader worker process to run on every CPU core.

    :param worker_id: worker index supplied by the DataLoader; not used
    """
    # os.sched_setaffinity is not available on every platform; skip silently there
    # instead of crashing each worker with an AttributeError.
    if not hasattr(os, "sched_setaffinity"):
        return
    n_cpus = os.cpu_count()
    # os.cpu_count() may return None when the count cannot be determined
    if n_cpus:
        # pid 0 means "the current process"
        os.sched_setaffinity(0, range(n_cpus))
    
# 20 timesteps, all using the same `Identity` instance as the per-timestep transform
timestep_transforms = [eye] * 20
# Create the training split of the temporal dataset (the default dataset is 'fashion_mnist')
train_dataset = TemporalDataset('train', transform=transform,
                                     img_to_timesteps_transforms=timestep_transforms)

In [None]:
from torch.utils.data import DataLoader

# NOTE(review): num_workers=100 is hard-coded; confirm the machine has enough CPU cores and memory for it.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=100, worker_init_fn=worker_init_fn)


In [None]:
class EvalDataWrapper(Dataset):
    """Simple Wrapper that adds contrast and repeated noise information to the dataset to power
    the evaluation metrics"""

    def __init__(self, dataset, contrast, rep_noise):
        """
        :param dataset: wrapped dataset whose items are ``(x, y)`` pairs
        :param contrast: contrast value reported for every item (stored as float)
        :param rep_noise: repeated-noise flag reported for every item (stored as bool)
        """
        self.dataset = dataset
        self.contrast = float(contrast)
        self.rep_noise = bool(rep_noise)

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        :param idx: item index; a Python int, a NumPy integer or an index tensor
        :return: ``(x, y, contrast, rep_noise)`` where the last two are tensors shaped
            like the index (shape ``(1,)`` for a scalar index) filled with the constants
            given at construction
        """
        x, y = self.dataset[idx]
        # torch.full_like needs a tensor. Converting every non-tensor index (not just
        # exact `int`) also covers NumPy integers, which previously raised a TypeError.
        if not torch.is_tensor(idx):
            idx = torch.tensor([int(idx)])
        contrast = torch.full_like(idx, self.contrast, dtype=torch.float)
        rep_noise = torch.full_like(idx, self.rep_noise, dtype=torch.bool)
        return x, y, contrast, rep_noise
# NOTE(review): this "test" loader wraps train_dataset (the 'train' split), not a held-out split.
test_loader = DataLoader(EvalDataWrapper(train_dataset, contrast=1, rep_noise=False), batch_size=64, shuffle=True, num_workers=100, worker_init_fn=worker_init_fn)

# Train Divisive Normalization

In [None]:
from modules.lateral_recurrence import LateralRecurrence
from modules.exponential_decay import ExponentialDecay
from modules.divisive_norm import DivisiveNorm
from modules.divisive_norm_group import DivisiveNormGroup
from modules.div_norm_channel import DivisiveNormChannel
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
import json
from models.adaptation import Adaptation
from models.HookedRecursiveCNN import HookedRecursiveCNN

# Trains the model selected by config["adaptation_module"] on the temporal dataset and
# saves a checkpoint under learned_models/.
config_path = 'config.json'
with open(config_path, 'r') as f:
    config = json.load(f)

dataset = config["dataset"]
# NOTE: layer_kwargs read from the config here is overwritten by the hard-coded list further down.
if config["dataset"] == 'fashion_mnist':
    layer_kwargs = config["layer_kwargs_fmnist"]
elif config["dataset"] == 'cifar10':
    layer_kwargs = config["layer_kwargs_cifar10"]

# Image preprocessing (tensor conversion only; normalization is commented out)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,))
])

# NOTE: this CSV logger is only referenced by the commented-out test code at the end of the cell.
logger = CSVLogger(config["log_dir"], name=config["log_name"])

# Map the configured module name to its class and its kwargs section in the config.
if config["adaptation_module"] == 'LateralRecurrence':
    adaptation_module = LateralRecurrence
    adaptation_kwargs = config["adaptation_kwargs_lateral"]
elif config["adaptation_module"] == 'ExponentialDecay':
    adaptation_module = ExponentialDecay
    adaptation_kwargs = config["adaptation_kwargs_additive"]
elif config["adaptation_module"] == 'DivisiveNorm':
    adaptation_module = DivisiveNorm
    adaptation_kwargs = config["adaptation_kwargs_div_norm"]
elif config["adaptation_module"] == 'DivisiveNormGroup':
    adaptation_module = DivisiveNormGroup
    adaptation_kwargs = config["adaptation_kwargs_div_norm_group"]
elif config["adaptation_module"] == 'DivisiveNormChannel':
    adaptation_module = DivisiveNormChannel
    adaptation_kwargs = config["adaptation_kwargs_div_norm_channel"]
else:
    raise ValueError(f'Adaptation module {config["adaptation_module"]} not implemented')

t_steps = 20

num_epoch = 10

# Hard-coded architecture and adaptation settings; these replace the values taken from the
# config above. The adaptation kwargs use divisive-norm style keys (K, alpha, sigma), so
# NOTE(review): config["adaptation_module"] presumably has to be 'DivisiveNorm' for this cell.
layer_kwargs = [{'in_channels': 1, 'out_channels': 32, 'kernel_size': 5},
 {'in_channels': 32, 'out_channels': 32, 'kernel_size': 5},
 {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
 {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
 {'in_features': 128, 'out_features': 1024}]
# Only the first entry has trainable parameters; the others are frozen.
adaptation_kwargs = [
    {"epsilon":  1e-8, "K_init":  0.2, "train_K":  True, "alpha_init":  -2.0, "train_alpha": True, "sigma_init": 0.1, "train_sigma": True},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  -2000000.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  -2000000.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  -2000000.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  0.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False}
  ]

hooked_model = HookedRecursiveCNN(t_steps=t_steps, layer_kwargs=layer_kwargs,
                                  adaptation_module=adaptation_module,
                                  adaptation_kwargs=adaptation_kwargs, decode_every_timestep=True)
model = Adaptation(hooked_model, lr=config["lr"], contrast_metrics=False)

# `contrast` is only used to build run and checkpoint names in this cell.
contrast = 'random'
# NOTE: tb_logger is created but not passed to the Trainer (which uses wandb_logger).
tb_logger = TensorBoardLogger("lightning_logs",
                              name=f'video_{config["adaptation_module"]}_001',
                              version=f'videon_{config["adaptation_module"]}_001')

# wandb.init(project='ai-thesis', config=config, entity='ai-thesis', name=f'{config["log_name"]}_{config["adaptation_module"]}_c_{contrast}_rep_{repeat_noise}_ep_{num_epoch}')
wandb_logger = pl.loggers.WandbLogger(project='ai-thesis', config=config,
                                      name=f'video_fmnist_002_{config["adaptation_module"]}_c_{contrast}_ep_{num_epoch}_{config["log_name"]}')

trainer = pl.Trainer(max_epochs=num_epoch, logger=wandb_logger)
# test_results = trainer.test(model, dataloaders=test_loader)
wandb_logger.watch(hooked_model, log='all', log_freq=1000)

# test_loader is passed in the validation-dataloader position
trainer.fit(model, train_loader, test_loader)

# test
# test_results = trainer.test(model, dataloaders=train_loader)
# logger.log_metrics({'contrast': contrast, 'epoch': num_epoch, 'repeat_noise': 'n/a',
#                     'test_acc': test_results[0]["test_acc"]})
# logger.save()

trainer.save_checkpoint(
    f'learned_models/video_fmnist_{config["adaptation_module"]}_contrast_{contrast}_epoch_{num_epoch}.ckpt')

# Train Exponential Decay

In [None]:
# Trains the ExponentialDecay variant; same pipeline as the divisive-normalization cell,
# with the module name forced below and different adaptation kwargs.
# Create the temporal training dataset and its loaders
train_dataset = TemporalDataset('train', transform=transform,
                                img_to_timesteps_transforms=timestep_transforms)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=3)

# NOTE(review): the evaluation loader wraps the training dataset, not a held-out split.
test_loader = DataLoader(EvalDataWrapper(train_dataset, contrast=1, rep_noise=False), batch_size=64, shuffle=True,
                         num_workers=3)
config_path = 'config.json'
with open(config_path, 'r') as f:
    config = json.load(f)

dataset = config["dataset"]
# NOTE: layer_kwargs read from the config here is overwritten by the hard-coded list further down.
if config["dataset"] == 'fashion_mnist':
    layer_kwargs = config["layer_kwargs_fmnist"]
elif config["dataset"] == 'cifar10':
    layer_kwargs = config["layer_kwargs_cifar10"]

# Force the module choice for this cell, regardless of what config.json says
config['adaptation_module'] = 'ExponentialDecay'

# Image preprocessing (tensor conversion only; normalization is commented out)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,))
])

# NOTE: this CSV logger is only referenced by the commented-out test code at the end of the cell.
logger = CSVLogger(config["log_dir"], name=config["log_name"])

# Map the configured module name to its class and its kwargs section in the config.
if config["adaptation_module"] == 'LateralRecurrence':
    adaptation_module = LateralRecurrence
    adaptation_kwargs = config["adaptation_kwargs_lateral"]
elif config["adaptation_module"] == 'ExponentialDecay':
    adaptation_module = ExponentialDecay
    adaptation_kwargs = config["adaptation_kwargs_additive"]
elif config["adaptation_module"] == 'DivisiveNorm':
    adaptation_module = DivisiveNorm
    adaptation_kwargs = config["adaptation_kwargs_div_norm"]
elif config["adaptation_module"] == 'DivisiveNormGroup':
    adaptation_module = DivisiveNormGroup
    adaptation_kwargs = config["adaptation_kwargs_div_norm_group"]
elif config["adaptation_module"] == 'DivisiveNormChannel':
    adaptation_module = DivisiveNormChannel
    adaptation_kwargs = config["adaptation_kwargs_div_norm_channel"]
else:
    raise ValueError(f'Adaptation module {config["adaptation_module"]} not implemented')

t_steps = 20

num_epoch = 10

# Hard-coded architecture and adaptation settings; these replace the values taken from the config above.
layer_kwargs = [{'in_channels': 1, 'out_channels': 32, 'kernel_size': 5},
                {'in_channels': 32, 'out_channels': 32, 'kernel_size': 5},
                {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
                {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
                {'in_features': 128, 'out_features': 1024}]
# Only the first entry has trainable alpha/beta; the others are frozen.
adaptation_kwargs = [
    {"alpha_init":  0.5, "train_alpha": True, "beta_init": 1, "train_beta": True},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False}
  ]

hooked_model = HookedRecursiveCNN(t_steps=t_steps, layer_kwargs=layer_kwargs,
                                  adaptation_module=adaptation_module,
                                  adaptation_kwargs=adaptation_kwargs, decode_every_timestep=True)
model = Adaptation(hooked_model, lr=config["lr"], contrast_metrics=False)

# `contrast` is only used to build run and checkpoint names in this cell.
contrast = 'random'
# NOTE: tb_logger is created but not passed to the Trainer (which uses wandb_logger).
tb_logger = TensorBoardLogger("lightning_logs",
                              name=f'video_{config["adaptation_module"]}_001',
                              version=f'videon_{config["adaptation_module"]}_001')

# wandb.init(project='ai-thesis', config=config, entity='ai-thesis', name=f'{config["log_name"]}_{config["adaptation_module"]}_c_{contrast}_rep_{repeat_noise}_ep_{num_epoch}')
wandb_logger = pl.loggers.WandbLogger(project='ai-thesis', config=config,
                                      name=f'video_fmnist_002_{config["adaptation_module"]}_c_{contrast}_ep_{num_epoch}_{config["log_name"]}')

trainer = pl.Trainer(max_epochs=num_epoch, logger=wandb_logger)
# test_results = trainer.test(model, dataloaders=test_loader)
wandb_logger.watch(hooked_model, log='all', log_freq=1000)

# test_loader is passed in the validation-dataloader position
trainer.fit(model, train_loader, test_loader)

# test
# test_results = trainer.test(model, dataloaders=train_loader)
# logger.log_metrics({'contrast': contrast, 'epoch': num_epoch, 'repeat_noise': 'n/a',
#                     'test_acc': test_results[0]["test_acc"]})
# logger.save()

trainer.save_checkpoint(
    f'learned_models/video_fmnist_{config["adaptation_module"]}_contrast_{contrast}_epoch_{num_epoch}.ckpt')

# Train no-time baseline

In [None]:
from modules.lateral_recurrence import LateralRecurrence
from modules.exponential_decay import ExponentialDecay
from modules.divisive_norm import DivisiveNorm
from modules.divisive_norm_group import DivisiveNormGroup
from modules.div_norm_channel import DivisiveNormChannel
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
import json
from models.adaptation import Adaptation
from models.HookedRecursiveCNN import HookedRecursiveCNN

In [None]:
# Trains the no-adaptation baseline: an ExponentialDecay model in which every layer's
# alpha/beta is fixed (train_alpha/train_beta are False everywhere).
# Create the temporal training dataset and its loaders
train_dataset = TemporalDataset('train', transform=transform,
                                img_to_timesteps_transforms=timestep_transforms)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=3)

# NOTE(review): the evaluation loader wraps the training dataset, not a held-out split.
test_loader = DataLoader(EvalDataWrapper(train_dataset, contrast=1, rep_noise=False), batch_size=64, shuffle=True,
                         num_workers=3)
config_path = 'config.json'
with open(config_path, 'r') as f:
    config = json.load(f)

dataset = config["dataset"]
# NOTE: layer_kwargs read from the config here is overwritten by the hard-coded list further down.
if config["dataset"] == 'fashion_mnist':
    layer_kwargs = config["layer_kwargs_fmnist"]
elif config["dataset"] == 'cifar10':
    layer_kwargs = config["layer_kwargs_cifar10"]

# Force the module choice for this cell, regardless of what config.json says
config['adaptation_module'] = 'ExponentialDecay'

# Image preprocessing (tensor conversion only; normalization is commented out)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,))
])

# NOTE: this CSV logger is only referenced by the commented-out test code at the end of the cell.
logger = CSVLogger(config["log_dir"], name=config["log_name"])

# Map the configured module name to its class and its kwargs section in the config.
if config["adaptation_module"] == 'LateralRecurrence':
    adaptation_module = LateralRecurrence
    adaptation_kwargs = config["adaptation_kwargs_lateral"]
elif config["adaptation_module"] == 'ExponentialDecay':
    adaptation_module = ExponentialDecay
    adaptation_kwargs = config["adaptation_kwargs_additive"]
elif config["adaptation_module"] == 'DivisiveNorm':
    adaptation_module = DivisiveNorm
    adaptation_kwargs = config["adaptation_kwargs_div_norm"]
elif config["adaptation_module"] == 'DivisiveNormGroup':
    adaptation_module = DivisiveNormGroup
    adaptation_kwargs = config["adaptation_kwargs_div_norm_group"]
elif config["adaptation_module"] == 'DivisiveNormChannel':
    adaptation_module = DivisiveNormChannel
    adaptation_kwargs = config["adaptation_kwargs_div_norm_channel"]
else:
    raise ValueError(f'Adaptation module {config["adaptation_module"]} not implemented')

t_steps = 20

num_epoch = 10

# Hard-coded architecture and adaptation settings; these replace the values taken from the config above.
layer_kwargs = [{'in_channels': 1, 'out_channels': 32, 'kernel_size': 5},
                {'in_channels': 32, 'out_channels': 32, 'kernel_size': 5},
                {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
                {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
                {'in_features': 128, 'out_features': 1024}]
# Baseline: no adaptation parameter is trainable in any layer.
adaptation_kwargs = [
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False}
  ]

hooked_model = HookedRecursiveCNN(t_steps=t_steps, layer_kwargs=layer_kwargs,
                                  adaptation_module=adaptation_module,
                                  adaptation_kwargs=adaptation_kwargs, decode_every_timestep=True)
model = Adaptation(hooked_model, lr=config["lr"], contrast_metrics=False)

# `contrast` is only used to build run and checkpoint names in this cell.
contrast = 'random'
# NOTE: tb_logger is created but not passed to the Trainer (which uses wandb_logger).
tb_logger = TensorBoardLogger("lightning_logs",
                              name=f'video_{config["adaptation_module"]}_001',
                              version=f'baseline_video_{config["adaptation_module"]}_001')

# wandb.init(project='ai-thesis', config=config, entity='ai-thesis', name=f'{config["log_name"]}_{config["adaptation_module"]}_c_{contrast}_rep_{repeat_noise}_ep_{num_epoch}')
wandb_logger = pl.loggers.WandbLogger(project='ai-thesis', config=config,
                                      name=f'baseline_video_fmnist_002_{config["adaptation_module"]}_c_{contrast}_ep_{num_epoch}_{config["log_name"]}')

trainer = pl.Trainer(max_epochs=num_epoch, logger=wandb_logger)
# test_results = trainer.test(model, dataloaders=test_loader)
wandb_logger.watch(hooked_model, log='all', log_freq=1000)

# test_loader is passed in the validation-dataloader position
trainer.fit(model, train_loader, test_loader)

# test
# test_results = trainer.test(model, dataloaders=train_loader)
# logger.log_metrics({'contrast': contrast, 'epoch': num_epoch, 'repeat_noise': 'n/a',
#                     'test_acc': test_results[0]["test_acc"]})
# logger.save()

trainer.save_checkpoint(
    f'learned_models/baseline_video_fmnist_{config["adaptation_module"]}_contrast_{contrast}_epoch_{num_epoch}.ckpt')

# Visualize data

In [None]:
# Pull one batch from the training loader and display the shapes of its frames and labels
x, y = next(iter(train_loader))
x.shape, y.shape

In [None]:
# Display names for the ten class indices, in index order (0-9).
int_to_label = dict(enumerate([
    'T-shirt',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow

# Plot all 20 frames of one sequence from the batch `x`, labelled with the target class
# of each timestep, and save the figure.
idx = 7
# Move the channel axis last so that imshow receives (height, width, channel) images
sample = x[idx].permute((0, 2, 3, 1))
fig, axes = plt.subplots(1, 20, figsize=(20, 4))  # Adjusted figsize to make it wider
fig.subplots_adjust(wspace=0.02)

for i, ax in enumerate(axes):
    ax.imshow(sample[i], cmap='gray', vmin=0.4, vmax=1)  # Set colormap to greyscale
    ax.set_xlabel(int_to_label[int(y[idx, i])], rotation=25, labelpad=10, fontsize=14)  # Rotate and position labels
    ax.set_xticks([])  # Remove x ticks
    ax.set_yticks([])  # Remove y ticks

# Drawing a long arrow (positioned in figure coordinates)
arrow = FancyArrow(0.15, 0.2, 0.65, 0, width=0.01, color='black', transform=fig.transFigure, clip_on=False)
fig.add_artist(arrow)

# NOTE: the figures/ directory must already exist
plt.savefig("figures/attn_001.svg", format='svg')
plt.savefig("figures/attn_001.png")


# Load pretrained models

In [None]:
from modules.lateral_recurrence import LateralRecurrence
from modules.exponential_decay import ExponentialDecay
from modules.divisive_norm import DivisiveNorm
from modules.divisive_norm_group import DivisiveNormGroup
from modules.div_norm_channel import DivisiveNormChannel
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
import json
from models.adaptation import Adaptation
from models.HookedRecursiveCNN import HookedRecursiveCNN

# HookedRecursiveCNN needs layer_kwargs and the adaptation kwargs to know how to set up the model,
# but the concrete init values are unimportant as they will get overwritten with the pretrained values
layer_kwargs = [{'in_channels': 1, 'out_channels': 32, 'kernel_size': 5},
 {'in_channels': 32, 'out_channels': 32, 'kernel_size': 5},
 {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
 {'in_channels': 32, 'out_channels': 32, 'kernel_size': 3},
 {'in_features': 128, 'out_features': 1024}]

# Same values as used when training the divisive-normalization model
div_norm_kwargs = [
    {"epsilon":  1e-8, "K_init":  0.2, "train_K":  True, "alpha_init":  -2.0, "train_alpha": True, "sigma_init": 0.1, "train_sigma": True},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  -2000000.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  -2000000.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  -2000000.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False},
    {"epsilon":  1e-8, "K_init":  1.0, "train_K":  False, "alpha_init":  0.0, "train_alpha": False, "sigma_init": 1.0, "train_sigma": False}
  ]
# Same values as used when training the baseline; also reused below for the trained
# ExponentialDecay checkpoint
exp_decay_kwargs = [
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False},
    {"alpha_init":  1.0, "train_alpha": False, "beta_init": 1, "train_beta": False}
  ]

# Constructor arguments for the two model variants
div_norm_cfg = {
    't_steps': 20, 'layer_kwargs': layer_kwargs,
    'adaptation_module': DivisiveNorm,
    'adaptation_kwargs': div_norm_kwargs, 'decode_every_timestep': True
}
exp_decay_cfg = {
    't_steps': 20, 'layer_kwargs': layer_kwargs,
    'adaptation_module': ExponentialDecay,
    'adaptation_kwargs': exp_decay_kwargs, 'decode_every_timestep': True
}

In [None]:
# Load the three checkpoints written by the training cells above.
# NOTE(review): the checkpoints were saved through the Trainer for the `Adaptation` wrapper, and in
# stock PyTorch Lightning the second positional argument of load_from_checkpoint is `map_location`.
# Confirm that HookedRecursiveCNN.load_from_checkpoint is a project-specific loader that accepts a
# config dict in that position.
div_norm_model = HookedRecursiveCNN.load_from_checkpoint('learned_models/video_fmnist_DivisiveNorm_contrast_random_epoch_10.ckpt', div_norm_cfg)
exp_decay_model = HookedRecursiveCNN.load_from_checkpoint('learned_models/video_fmnist_ExponentialDecay_contrast_random_epoch_10.ckpt', exp_decay_cfg)
baseline_model = HookedRecursiveCNN.load_from_checkpoint('learned_models/baseline_video_fmnist_ExponentialDecay_contrast_random_epoch_10.ckpt', exp_decay_cfg)

In [None]:
# Display name -> loaded model; the keys become the 'Model' column in the result tables
# and therefore the legend entries in the plots.
models = {
    'Divisive Norm.': div_norm_model,
    'Additive': exp_decay_model,
    'No Adaptation': baseline_model
}

# Accuracy per timestep

In [None]:
from tqdm import tqdm
from torchmetrics.functional import accuracy
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Per-timestep accuracy of every model on the multi-image sequences from test_loader.
# Rows are collected in a list and turned into a DataFrame once, instead of growing the
# frame row by row.
records = []
# Pure inference: without no_grad() every forward pass would also build an autograd graph.
with torch.no_grad():
    for name, model in models.items():
        # The inputs are moved to the GPU below, so the weights must live there as well
        # (the later evaluation cells do the same).
        model.cuda()
        for j, (x, y, _, _) in enumerate(tqdm(test_loader)):
            x = x.cuda()
            y = y.cuda()
            logits = model(x)
            for i in range(20):
                l = logits[:, i, :]
                t = y[:, i]
                preds = torch.argmax(l, dim=1)
                acc = accuracy(preds, t, task='multiclass', num_classes=10)
                # Timesteps are reported 1-based
                records.append({'Model': name, 'Accuracy': float(acc), 'Timestep': i + 1})
            # Evaluate at most 101 batches per model
            if j >= 100:
                break
df = pd.DataFrame(records, columns=['Model', 'Accuracy', 'Timestep'])

In [None]:
# Accuracy over time per model on the multi-image sequences; seaborn aggregates the
# per-batch rows of `df` for each timestep.
plt.figure(figsize=(3, 3))
sns.set_style('white')
sns.lineplot(data=df, x='Timestep', y='Accuracy', hue='Model', palette='dark')
plt.ylim(0, 1)
sns.despine(offset=10)
plt.tight_layout()

# NOTE: the figures/ directory must already exist
plt.savefig('figures/attn_002.svg')
plt.savefig('figures/attn_002.png')

## Only one image the whole time

In [None]:
class OneImageTemporalDataset(Dataset):
    """
    Sequence dataset that shows one image, unchanged, at every timestep.

    The image is drawn once into the top-left 28x28 corner of a ``(1, 56, 56)`` canvas
    filled with 0.5; that frame and its label are then repeated for all timesteps.
    """

    def __init__(self, split, dataset='fashion_mnist', transform=None, img_to_timesteps_transforms=None, contrast='random'):
        """
        Loads the requested split and keeps its images and labels.
        :param split: 'train' or 'test'
        :param dataset: name of the dataset
            must be one of the datasets in the huggingface datasets package
        :param transform: transforms to apply to the images
        :param img_to_timesteps_transforms: list of functions, one per timestep;
            only the first one is applied to the image, the list length sets the sequence length
        :param contrast: 'random' to draw a contrast uniformly from [0.1, 1) per item,
            otherwise the fixed contrast factor to use
        """
        hf_split = load_dataset(dataset, split=split)
        # different datasets name their image column differently
        image_column = 'image' if 'img' not in hf_split.column_names else 'img'
        images = hf_split[image_column]
        targets = hf_split['label']

        self.split = split
        self.transform = transform
        self.dataset = hf_split
        self.data = images
        self.targets = targets
        self.img_to_timesteps_transforms = img_to_timesteps_transforms
        self.contrast = contrast

    def int_to_coordinate(self, index):
        """
        Maps a quadrant number (0-3) to the top-left (row, column) offset of that quadrant.
        :raises ValueError: if index is not one of 0, 1, 2, 3
        """
        corners = ((0, 0), (0, 28), (28, 0), (28, 28))
        for quadrant, corner in enumerate(corners):
            if index == quadrant:
                return corner
        raise ValueError('index must be between 0 and 3')

    def adjust_contrast(self, img, contrast):
        """Centers the image around zero and scales it by ``contrast`` (the mean is not added back)."""
        return (img - img.mean()) * contrast

    def __getitem__(self, index):
        """
        :param index: position of the image in the split
        :return: (frames, labels) with frames stacked along dim 0 (one per timestep)
            and labels a 1-D tensor repeating the image's class index
        """
        img = self.data[index]
        target = int(self.targets[index])
        frames = []
        frame_labels = []

        canvas = torch.zeros((1, 28 * 2, 28 * 2)) + 0.5
        for step, trans_func in enumerate(self.img_to_timesteps_transforms):
            if step != 0:
                # Later timesteps simply repeat the first frame and its label
                frames.append(canvas)
                frame_labels.append(frame_labels[-1])
                continue

            if self.transform is not None:
                img = self.transform(img)
            # Only foreground pixels are drawn, so the background stays at 0.5
            mask = img > 1e-2
            img = trans_func(img, index, target)

            scale = np.random.uniform(0.1, 1) if self.contrast == 'random' else self.contrast
            img = self.adjust_contrast(img, scale)

            # The image always goes into the top-left quadrant
            patch = torch.zeros((1, 28 * 2, 28 * 2))
            patch[:, 0:28, 0:28] = img * mask

            canvas = canvas + patch
            frame_labels.append(target)
            frames.append(canvas)

        # Stack the frames along the timestep dimension
        return torch.stack(frames, dim=0), torch.tensor(frame_labels)

    def __len__(self):
        """Number of images in the underlying split."""
        return len(self.data)


import torchvision.transforms as transforms

# Image preprocessing (tensor conversion only; normalization is commented out)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,))
])
from utils.transforms import MeanFlat, RandomRepeatedNoise, Identity
from functools import partial

eye = Identity()

# 20 timesteps, all using the same `Identity` instance
timestep_transforms = [eye] * 20
# Create the single-image dataset from the 'train' split; contrast keeps its default ('random')
one_image_dataset = OneImageTemporalDataset('train', transform=transform,
                                img_to_timesteps_transforms=timestep_transforms)
from torch.utils.data import DataLoader

loader = DataLoader(one_image_dataset, batch_size=64, shuffle=True, num_workers=3)

from utils.visualization import visualize_first_batch_with_timesteps

# NOTE(review): the meaning of the second argument (8) is defined in utils.visualization — confirm there
visualize_first_batch_with_timesteps(loader, 8)

In [None]:
from tqdm import tqdm
from torchmetrics.functional import accuracy
import pandas as pd

# Per-timestep accuracy of every model on the single-image sequences, evaluated on the CPU.
# Rows are collected in a list and turned into a DataFrame once, instead of growing the
# frame row by row. Timesteps are stored 0-based here.
records = []
# Pure inference: without no_grad() every forward pass would also build an autograd graph.
with torch.no_grad():
    for name, model in models.items():
        model.cpu()
        for j, (x, y) in enumerate(tqdm(loader)):
            logits = model(x)
            for i in range(20):
                l = logits[:, i, :]
                # The label is identical at every timestep, so the first one is used
                t = y[:, 0]
                preds = torch.argmax(l, dim=1)
                acc = accuracy(preds, t, task='multiclass', num_classes=10)
                records.append({'Model': name, 'Accuracy': float(acc), 'Timestep': i})
            # Evaluate at most 101 batches per model
            if j >= 100:
                break
df = pd.DataFrame(records, columns=['Model', 'Accuracy', 'Timestep'])

In [None]:
# Convert timesteps from 0-based to 1-based for plotting.
# NOTE: this mutates df in place, so running the cell twice shifts the values twice.
df['Timestep'] += 1

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

# Accuracy over time per model on the single-image sequences
plt.figure(figsize=(3, 3))
sns.set_style('white')
sns.lineplot(data=df, x='Timestep', y='Accuracy', hue='Model', palette='dark')
plt.ylim(0, 1)
sns.despine(offset=10)
plt.tight_layout()

# NOTE: the figures/ directory must already exist
plt.savefig('figures/attn_003.svg')
plt.savefig('figures/attn_003.png')

## Accuracy per contrast

In [None]:
import pandas as pd

from torch.utils.data import DataLoader

# Per-timestep accuracy of every model on single-image test sequences, for a range of
# fixed contrast levels. Timesteps are stored 0-based here.
acc_dict = {'Contrast': [], 'Timestep': [], 'Accuracy': [], 'Model': []}

# Move the models to the GPU once instead of once per batch
for model in models.values():
    model.cuda()

# Pure inference: without no_grad() every forward pass would also build an autograd graph.
with torch.no_grad():
    for contrast in [0.2, 0.4, 0.6, 0.8, 1.0]:
        one_image_dataset = OneImageTemporalDataset('test', transform=transform,
                                        img_to_timesteps_transforms=timestep_transforms, contrast=contrast)

        loader = DataLoader(one_image_dataset, batch_size=64, shuffle=False, num_workers=10, worker_init_fn=worker_init_fn)

        for j, (x, y) in enumerate(tqdm(loader)):
            x = x.cuda()
            y = y.cuda()
            for name, model in models.items():
                logits = model(x)
                for i in range(20):
                    l = logits[:, i, :]
                    # The label is identical at every timestep, so the first one is used
                    t = y[:, 0]
                    preds = torch.argmax(l, dim=1)
                    acc = accuracy(preds, t, task='multiclass', num_classes=10)
                    acc_dict['Contrast'].append(contrast)
                    acc_dict['Timestep'].append(i)
                    acc_dict['Accuracy'].append(float(acc.cpu()))
                    acc_dict['Model'].append(name)
            # Evaluate at most 101 batches per contrast level
            if j >= 100:
                break
df = pd.DataFrame(acc_dict)
df

In [None]:
# Convert timesteps from 0-based to 1-based for plotting.
# NOTE: this mutates df in place, so running the cell twice shifts the values twice.
df.Timestep += 1

In [None]:
# Accuracy over time, one line per contrast level and one row per model
# (the 'No Adaptation' baseline is filtered out)
sns.set_style('white')
sns.relplot(data=df[df.Model.isin(['Divisive Norm.', 'Additive'])], x='Timestep', y='Accuracy', hue='Contrast', row='Model', height=2.5, aspect=1.5, kind='line')
plt.ylim(0, 1)

sns.despine(offset=5)

# NOTE: the figures/ directory must already exist
plt.savefig('figures/attn_011.svg')
plt.savefig('figures/attn_011.png')

## Activation Scale

In [None]:
# Activation statistics per model, layer and timestep on single-image test sequences with
# random contrast. One row is recorded per batch; timesteps are stored 0-based here.
actv_dict = {'Timestep': [], 'Layer': [], 'Mean': [], 'num_active': [], 'mean_not_null': [], 'Norm': [], 'Model': []}
one_image_dataset = OneImageTemporalDataset('test', transform=transform,
                                img_to_timesteps_transforms=timestep_transforms, contrast='random')

loader = DataLoader(one_image_dataset, batch_size=64, shuffle=True, num_workers=3, pin_memory=True, pin_memory_device='cuda')

# Move the models to the GPU once instead of once per batch
for model in models.values():
    model.cuda()

# Pure inference: without no_grad() the cached activations would keep autograd graphs alive.
with torch.no_grad():
    for x, y in tqdm(loader):
        x = x.cuda()
        y = y.cuda()
        for name, model in models.items():
            logits, cache = model.run_with_cache(x)
            for layer in range(4):
                for timestep in range(20):
                    actv = cache[f'hks.adapt_{layer}_{timestep}']
                    # Units counted as active: activation above 1e-4
                    active = actv > 1e-4
                    actv_dict['Timestep'].append(timestep)
                    actv_dict['Layer'].append(layer)
                    actv_dict['Mean'].append(float(actv.mean()))
                    actv_dict['num_active'].append(float(active.sum()))
                    actv_dict['mean_not_null'].append(float(actv[active].mean()))
                    actv_dict['Norm'].append(float(actv.norm()))
                    actv_dict['Model'].append(name)

actv_df = pd.DataFrame(actv_dict)
actv_df

In [None]:
# Mean activation of layer 0 over time, one line per model
plt.figure(figsize=(3,3))
sns.lineplot(data=actv_df[actv_df.Layer==0], x='Timestep', y='Mean', hue='Model', palette='dark')
sns.despine(offset=5)
plt.tight_layout()
# NOTE: the next plotting cell writes to the same attn_006 files and overwrites this figure
plt.savefig('figures/attn_006.svg')
plt.savefig('figures/attn_006.png')
#plt.ylim(0, 0.09)

In [None]:
# Layer-0 activations, normalised per model by that model's mean activation at
# the first timestep.
# The 'Normalized Activations' column is derived here on a copy: at this point
# in the notebook `actv_df` only holds the raw statistics, so plotting the
# column directly fails on a fresh Restart & Run All.
layer0_actv = actv_df[actv_df.Layer == 0].copy()
first_timestep = layer0_actv.Timestep.min()
baseline_by_model = (
    layer0_actv[layer0_actv.Timestep == first_timestep]
    .groupby('Model')['Mean']
    .mean()
)
layer0_actv['Normalized Activations'] = layer0_actv['Mean'] / layer0_actv['Model'].map(baseline_by_model)

plt.figure(figsize=(3,3))
sns.lineplot(data=layer0_actv, x='Timestep', y='Normalized Activations', hue='Model', palette='dark')
sns.despine(offset=5)
plt.tight_layout()
plt.ylim(0, 1)

# NOTE(review): these are the same file names the raw-mean plot in the cell
# above writes to, so this figure overwrites it — confirm that is intended.
plt.savefig('figures/attn_006.svg')
plt.savefig('figures/attn_006.png')

In [None]:
# Same activation statistics as above, but swept over fixed contrast levels;
# Timestep is stored 1-based here.
actv_dict = {'Timestep': [], 'Layer': [], 'Mean': [], 'num_active': [], 'mean_not_null': [], 'Norm': [], 'Model': [], 'Contrast': []}

for contrast in [0.2, 0.4, 0.6, 0.8, 1.0]:
    one_image_dataset = OneImageTemporalDataset(
        'test',
        transform=transform,
        img_to_timesteps_transforms=timestep_transforms,
        contrast=contrast,
    )
    loader = DataLoader(
        one_image_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        pin_memory_device='cuda',
        worker_init_fn=worker_init_fn,
    )

    for x, y in tqdm(loader):
        x, y = x.cuda(), y.cuda()
        for name, model in models.items():
            model.cuda()
            logits, cache = model.run_with_cache(x)
            for layer in range(4):
                for timestep in range(20):
                    actv = cache[f'hks.adapt_{layer}_{timestep}']
                    # Units above 1e-4 count as active.
                    record = {
                        'Timestep': timestep + 1,
                        'Layer': layer,
                        'Mean': float(actv.mean()),
                        'num_active': float((actv > 1e-4).sum()),
                        'mean_not_null': float(actv[actv > 1e-4].mean()),
                        'Norm': float(actv.norm()),
                        'Model': name,
                        'Contrast': contrast,
                    }
                    for column, value in record.items():
                        actv_dict[column].append(value)

actv_df = pd.DataFrame(actv_dict)
actv_df

In [None]:
# Normalise every mean activation by its (Model, Layer) baseline: the average
# mean activation at Timestep 1 for full contrast (1.0).
# The baseline is joined back on the group keys instead of using
# groupby.apply(...).reset_index(level=...), whose result index depends on the
# pandas version (group keys are not always prepended for like-indexed results).
baseline = (
    actv_df[(actv_df.Timestep == 1) & (actv_df.Contrast == 1.0)]
    .groupby(['Model', 'Layer'], sort=False)['Mean']
    .mean()
    .rename('baseline_mean')
)
# The left join keeps actv_df's index, so the division is row-aligned; groups
# without a baseline row yield NaN, as before.
actv_df['Normalized Activations'] = (
    actv_df['Mean'] / actv_df.join(baseline, on=['Model', 'Layer'])['baseline_mean']
)

sns.set_style('white')
sns.relplot(data=actv_df[(actv_df.Layer==1) & (actv_df.Model.isin(['Divisive Norm.', 'Additive']))],
            x='Timestep', y='Normalized Activations', hue='Contrast', kind='line', row='Model',
            height=2.5, aspect=1.5, legend=False)

sns.despine(offset=5)

plt.savefig('figures/attn_012.svg')
plt.savefig('figures/attn_012.png')

## Two images

In [None]:
class TwoImageTemporalDataset(Dataset):
    """
    Temporal sequences in which two images appear on a shared 56x56 canvas.

    The canvas starts as a constant 0.5 background. The indexed image is added
    to the upper-left 28x28 quadrant at timestep 0; at timestep ``onset_img_2``
    a randomly drawn second image is added to the lower-right quadrant. Both
    stay on the canvas for all later timesteps. The label at each timestep is
    the class of the most recently added image.
    """

    def __init__(self, split, onset_img_2, dataset='fashion_mnist', transform=None, img_to_timesteps_transforms=None, contrast_1='random', contrast_2='random'):
        """
        Initializes the TwoImageTemporalDataset
        :param split: 'train' or 'test'
        :param onset_img_2: timestep index at which the second image appears.
            Timestep 0 is always used for the first image, so a value of 0 (or
            one outside the number of timesteps) means no second image is shown
        :param dataset: name of the dataset
            must be one of the datasets in the huggingface datasets package
        :param transform: transforms to apply to the images
        :param img_to_timesteps_transforms: list of functions, one per timestep;
            its length sets the sequence length. Only the functions at timestep 0
            and at ``onset_img_2`` are called, as ``f(img, index, target)``
        :param contrast_1: contrast factor for the first image, or 'random' to
            draw it uniformly from [0.1, 1) for every sample
        :param contrast_2: same as ``contrast_1``, for the second image
        """
        self.split = split
        self.transform = transform
        self.dataset = load_dataset(dataset, split=split)
        input_col_name = 'img' if 'img' in self.dataset.column_names else 'image'  # because different datasets have different names
        self.data, self.targets = self.dataset[input_col_name], self.dataset['label']
        self.img_to_timesteps_transforms = img_to_timesteps_transforms
        self.contrast_1 = contrast_1
        self.contrast_2 = contrast_2
        self.onset_img_2 = onset_img_2

    def int_to_coordinate(self, index):
        """Map a quadrant index (0-3) to the top-left (row, col) of that 28x28 quadrant."""
        if index == 0:
            return 0, 0
        elif index == 1:
            return 0, 28
        elif index == 2:
            return 28, 0
        elif index == 3:
            return 28, 28
        else:
            raise ValueError('index must be between 0 and 3')

    def adjust_contrast(self, img, contrast):
        """Mean-centre ``img`` and scale it by ``contrast``; the mean is not added back."""
        mean = img.mean()
        img = (img - mean) * contrast  #+ mean
        return img

    def _add_image(self, canvas, index, trans_func, contrast, corner):
        """
        Add dataset item ``index`` to ``canvas`` with its top-left at ``corner``.

        :return: ``(new_canvas, target)``; ``canvas`` itself is not modified
        """
        img, target = self.data[index], int(self.targets[index])
        if self.transform is not None:
            img = self.transform(img)
        # Foreground mask is taken before the timestep transform and contrast
        # change, so background pixels stay at the canvas level.
        mask = img > 1e-2
        img = trans_func(img, index, target)

        if contrast == 'random':
            contrast = np.random.uniform(0.1, 1)
        img = self.adjust_contrast(img, contrast)

        # assumes img is (1, 28, 28) after `transform` — TODO confirm for other datasets
        patch = torch.zeros((1, 28 * 2, 28 * 2))
        row, col = corner
        patch[:, row:row + 28, col:col + 28] = img * mask
        return canvas + patch, target

    def __getitem__(self, index):
        """
        Build one sequence.

        :return: ``(img_timesteps, labels)`` where ``img_timesteps`` stacks one
            canvas per timestep along dim 0 and ``labels`` holds one class id
            per timestep
        """
        img_timesteps = list()
        labels = list()

        canvas = torch.zeros((1, 28 * 2, 28 * 2)) + 0.5
        for i, trans_func in enumerate(self.img_to_timesteps_transforms):
            if i == 0:
                canvas, target = self._add_image(
                    canvas, index, trans_func, self.contrast_1, self.int_to_coordinate(0))
                labels.append(target)
            elif i == self.onset_img_2:
                # Draw the second image from the whole split; the previous
                # hard-coded range of 10000 ignored the real dataset size.
                second_index = int(np.random.randint(len(self.data)))
                canvas, target = self._add_image(
                    canvas, second_index, trans_func, self.contrast_2, self.int_to_coordinate(3))
                labels.append(target)
            else:
                # Nothing new appears: repeat the canvas and carry the label forward.
                labels.append(labels[-1])
            img_timesteps.append(canvas)

        # Stack the per-timestep canvases along the timestep dimension
        img_timesteps = torch.stack(img_timesteps, dim=0)
        labels = torch.tensor(labels)
        return img_timesteps, labels

    def __len__(self):
        """Number of items in the loaded split."""
        return len(self.data)


# Build a two-image training dataset (second image appears at timestep 5),
# wrap it in a DataLoader and visualise the first batch.
# NOTE(review): the imports below sit mid-notebook, and NoisyTemporalDataset,
# MeanFlat, RandomRepeatedNoise and partial are not used in this cell.
import torchvision.transforms as transforms

# Only ToTensor is applied; normalisation is intentionally left disabled.
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,))
])
from models.noisy_dataloader import NoisyTemporalDataset
from utils.transforms import MeanFlat, RandomRepeatedNoise, Identity
from functools import partial

eye = Identity()

# The same Identity transform is used at each of the 20 timesteps.
timestep_transforms = [eye] * 20
# Create an instance of the two-image dataset (default dataset: fashion_mnist)
two_image_dataset = TwoImageTemporalDataset('train', 5, transform=transform,
                                img_to_timesteps_transforms=timestep_transforms)
from torch.utils.data import DataLoader

loader = DataLoader(two_image_dataset, batch_size=64, shuffle=True, num_workers=3)

from utils.visualization import visualize_first_batch_with_timesteps

# presumably the second argument is the number of samples/timesteps to show — verify against utils.visualization
visualize_first_batch_with_timesteps(loader, 8)

In [None]:
# Per-timestep accuracy of every model for several onsets of the second image.
# First image has fixed contrast 0.6, second image a random contrast.
# Timestep is stored 0-based here, matching the 0-based onset values.
acc_dict = {'Model': [], 'Timestep': [], 'Accuracy': [], 'Onset': []}
for onset in [2, 5, 15]:
    one_image_dataset = TwoImageTemporalDataset('test', onset_img_2=onset, transform=transform,
                                    img_to_timesteps_transforms=timestep_transforms, contrast_1=0.6, contrast_2='random')
    loader = DataLoader(one_image_dataset, batch_size=64, shuffle=True, num_workers=30)

    for name, model in models.items():
        # Keep model and batch on the same device, as the other evaluation
        # cells do; feeding CPU tensors to a CUDA model raises a device error.
        model.cuda()
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.cuda()
            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                logits = model(x)
            for i in range(20):
                l = logits[:, i, :]
                t = y[:, i]
                preds = torch.argmax(l, dim=1)
                acc = accuracy(preds, t, task='multiclass', num_classes=10)
                acc_dict['Timestep'].append(i)
                acc_dict['Accuracy'].append(float(acc))
                acc_dict['Onset'].append(onset)
                acc_dict['Model'].append(name)
df = pd.DataFrame(acc_dict)
df

In [None]:
# Draw one batch of two-image sequences (second image at timestep 4, both at
# contrast 0.8) and cache the divisive-norm model's activations for it.
one_image_dataset = TwoImageTemporalDataset(
    'test',
    onset_img_2=4,
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
    contrast_1=.8,
    contrast_2=.8,
)
loader = DataLoader(one_image_dataset, batch_size=10, shuffle=True, num_workers=0)

x, y = next(iter(loader))
x, y = x.cuda(), y.cuda()
_, cache = div_norm_model.run_with_cache(x)

In [None]:
# Top row: the first 8 input frames of one sample.
# Bottom row: one cached 'hks.adapt_0_{t}' feature map for the same timesteps.
idx = 3
feature_map = 2  # was `map`, which shadowed the builtin for the rest of the session
# Move channels last so every frame can be passed to imshow.
sample = x[idx].permute((0, 2, 3, 1)).cpu()
fig, axes = plt.subplots(ncols=8, nrows=2, figsize=(10, 3))
fig.subplots_adjust(wspace=0.03, hspace=0.01)

for i, ax in enumerate(axes[0]):
    ax.imshow(sample[i], cmap='gray')  # Set colormap to greyscale
    ax.set_xticks([])
    ax.set_yticks([])

for i, ax in enumerate(axes[1]):
    ax.imshow(cache[f'hks.adapt_0_{i}'][idx, feature_map].cpu(), vmin=0, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])

axes[0, 0].set_ylabel('Input', fontsize=14)
axes[1, 0].set_ylabel('Feature Map', fontsize=14)

# Long horizontal arrow in figure coordinates
arrow = FancyArrow(0.15, 0.1, 0.65, 0, width=0.01, color='black', transform=fig.transFigure, clip_on=False)
fig.add_artist(arrow)

plt.savefig("figures/attn_005.svg", format='svg')

In [None]:
# Single feature map at the onset of the second image (timestep 2), with the
# two regions outlined in red and green.
one_image_dataset = TwoImageTemporalDataset(
    'test',
    onset_img_2=2,
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
    contrast_1=.8,
    contrast_2=.8,
)
loader = DataLoader(one_image_dataset, batch_size=10, shuffle=True, num_workers=0)

x, y = next(iter(loader))
x, y = x.cuda(), y.cuda()
_, cache = div_norm_model.run_with_cache(x)

sample = cache['hks.adapt_0_2'].cpu()
fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(sample[0, 9], cmap='gray')
# (origin, side length, colour) for each outlined square
for origin, side, colour in (((0, 0), 26, 'red'), ((26, 26), 25, 'green')):
    ax.add_patch(plt.Rectangle(origin, side, side, fill=False, edgecolor=colour, linestyle='--', linewidth=2))
ax.set_xticks([])
ax.set_yticks([])

for figure_path in ('figures/attn_007.svg', 'figures/attn_007.png'):
    plt.savefig(figure_path)

In [None]:
# Accuracy over time per onset, one facet column per model, with a dashed
# vertical marker at each onset in the colour of its line.
palette = sns.color_palette("mako", 3)
g = sns.relplot(
    data=df,
    x='Timestep',
    y='Accuracy',
    hue='Onset',
    col='Model',
    style='Onset',
    kind='line',
    height=2.5,
    palette=palette,
)
plt.subplots_adjust(wspace=0.2)
sns.despine(offset=5)

unique_onsets = df['Onset'].unique()
onset_colors = list(zip(unique_onsets, palette))

for model, ax in g.axes_dict.items():
    for onset, color in onset_colors:
        ax.axvline(x=onset, color=color, linestyle='--', alpha=0.7)

for figure_path in ('figures/attn_004.svg', 'figures/attn_004.png'):
    plt.savefig(figure_path)

In [None]:
import pandas as pd


def target_logit_margin(step_logits, targets):
    """
    Average, over the batch, of how far each sample's target-class logit lies
    above the mean of that sample's remaining class logits.

    :param step_logits: logits of one timestep, indexed as (batch, num_classes)
    :param targets: integer class id per sample, shape (batch,)
    :return: the batch-averaged margin as a Python float
    """
    target_logit = step_logits.gather(1, targets.unsqueeze(1)).squeeze(1)
    other_mean = (step_logits.sum(dim=1) - target_logit) / (step_logits.shape[1] - 1)
    return float((target_logit - other_mean).mean())


# Per-timestep accuracy and logit margins for every combination of onset and
# first/second image contrast; at most 201 batches are evaluated per setting.
acc_dict = {'Contrast first image': [], 'Contrast second image': [], 'Timestep': [], 'Accuracy': [], 'Onset': [], 'Logits novel image': [], 'Model': [], 'Logits first image': []}
for onset in [2, 5, 15]:
    for contrast_1 in [0.2, 0.4, 0.6, 0.8, 1.0]:
        for contrast_2 in [0.2, 0.4, 0.6, 0.8, 1.0]:
            one_image_dataset = TwoImageTemporalDataset('train', onset_img_2=onset, transform=transform,
                                            img_to_timesteps_transforms=timestep_transforms, contrast_1=contrast_1, contrast_2=contrast_2)
            loader = DataLoader(one_image_dataset, batch_size=64, shuffle=True, num_workers=10, worker_init_fn=worker_init_fn)

            j = 0
            for x, y in tqdm(loader):
                x = x.cuda()
                y = y.cuda()
                for name, model in models.items():
                    model.cuda()
                    # Evaluation only: no autograd graph is needed.
                    with torch.no_grad():
                        logits = model(x)
                    for i in range(20):
                        l = logits[:, i, :]
                        t = y[:, i]
                        preds = torch.argmax(l, dim=1)
                        acc = accuracy(preds, t, task='multiclass', num_classes=10)
                        acc_dict['Contrast first image'].append(contrast_1)
                        acc_dict['Contrast second image'].append(contrast_2)
                        acc_dict['Timestep'].append(i)
                        acc_dict['Accuracy'].append(float(acc))
                        acc_dict['Onset'].append(onset)
                        # How much higher the logit of the novel / first image's class is
                        # than the other classes, per sample. The previous `l[:, labels]`
                        # picked batch-many columns for every sample, and `~labels` is a
                        # bitwise NOT (negative column indices), not "all other classes".
                        acc_dict['Logits novel image'].append(target_logit_margin(l, y[:, onset]))
                        acc_dict['Logits first image'].append(target_logit_margin(l, y[:, 0]))
                        acc_dict['Model'].append(name)

                j += 1
                if j > 200:
                    break
df = pd.DataFrame(acc_dict)
df

# Receptive Fields

In [None]:
# Activation statistics per region of the feature map (upper-left quadrant,
# lower-right quadrant, whole map) for a fixed onset and several contrasts of
# the second image; at most 41 batches are evaluated per contrast.
# Rows are collected in a list and turned into a DataFrame once: enlarging a
# DataFrame with `.loc[len(df)] = row` copies it on every insert.
actv_columns = ['Model', 'Timestep', 'Layer', 'Mean', 'Norm', 'num_active', 'mean_not_null', 'Onset', 'Region', 'Contrast']
actv_rows = []
for onset in [4]:  # [2, 5, 15]:
    for contrast in [0.2, 0.4, 0.6, 0.8, 1.0]:
        two_image_dataset = TwoImageTemporalDataset('test', onset_img_2=onset, transform=transform,
                                        img_to_timesteps_transforms=timestep_transforms, contrast_1=0.6, contrast_2=contrast)
        loader = DataLoader(two_image_dataset, batch_size=100, shuffle=False, num_workers=10, worker_init_fn=worker_init_fn)
        j = 0
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.cuda()
            for name, model in models.items():
                model.cuda()
                logits, cache = model.run_with_cache(x)
                for layer in range(4):
                    for timestep in range(20):
                        actv = cache[f'hks.adapt_{layer}_{timestep}'] # batch channel width height
                        width = actv.shape[-1]
                        regions = (
                            ('ul', actv[..., :width // 2, :width // 2]),
                            ('lr', actv[..., width // 2:, width // 2:]),
                            ('both', actv),
                        )
                        for region, cur_actv in regions:
                            active = cur_actv > 1e-4  # units counted as active
                            actv_rows.append([
                                name, timestep, layer, float(cur_actv.mean()), float(cur_actv.norm()),
                                float(active.sum()), float(cur_actv[active].mean()),
                                onset, region, contrast
                            ])
            j += 1
            if j > 40:
                break
actv_df = pd.DataFrame(actv_rows, columns=actv_columns)
actv_df

In [None]:
# Mean layer-1 activation over time in the two quadrants, one facet per model.
palette = ['red', 'green']
quadrant_rows = actv_df[(actv_df.Layer==1) & (actv_df.Region.isin(['ul', 'lr']))]
sns.relplot(
    data=quadrant_rows,
    x='Timestep',
    y='Mean',
    hue='Region',
    style='Onset',
    col='Model',
    kind='line',
    height=2.5,
    facet_kws={'sharey': False},
    palette=palette,
)
sns.despine(offset=5)

for figure_path in ('figures/attn_008.svg', 'figures/attn_008.png'):
    plt.savefig(figure_path)

In [None]:
# Region-wise activation statistics (upper-left quadrant, lower-right quadrant,
# whole map) for onset 4 and several contrasts of the second image; at most
# 101 batches are evaluated per contrast.
actv_dict = {'contrast': [], 'timestep': [], 'layer': [], 'mean': [], 'num_active': [], 'mean_not_null': [], 'onset': [], 'region': [], 'model': []}
for onset in [4]:
    for contrast in [0.2, 0.4, 0.6, 0.8, 1.0]:
        two_image_dataset = TwoImageTemporalDataset('test', onset_img_2=onset, transform=transform,
                                        img_to_timesteps_transforms=timestep_transforms, contrast_1=0.6, contrast_2=contrast)
        loader = DataLoader(two_image_dataset, batch_size=50, shuffle=False, num_workers=10, worker_init_fn=worker_init_fn)

        j = 0
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.cuda()
            for name, model in models.items():
                model.cuda()
                logits, cache = model.run_with_cache(x)
                for layer in range(4):
                    for timestep in range(20):
                        actv = cache[f'hks.adapt_{layer}_{timestep}'] # batch channel width height
                        width = actv.shape[-1]
                        # Pairing each region name with its view replaces the old
                        # if/elif chain, whose final `else` could never be reached.
                        regions = (
                            ('ul', actv[..., :width // 2, :width // 2]),
                            ('lr', actv[..., width // 2:, width // 2:]),
                            ('both', actv),
                        )
                        for region, cur_actv in regions:
                            active = cur_actv > 1e-4  # units counted as active
                            actv_dict['contrast'].append(contrast)
                            actv_dict['timestep'].append(timestep)
                            actv_dict['layer'].append(layer)
                            # float() already copies the scalar to the host; no .cpu() needed.
                            actv_dict['mean'].append(float(cur_actv.mean()))
                            actv_dict['num_active'].append(float(active.sum()))
                            actv_dict['mean_not_null'].append(float(cur_actv[active].mean()))
                            actv_dict['onset'].append(onset)
                            actv_dict['region'].append(region)
                            actv_dict['model'].append(name)

            j += 1
            if j > 100:
                break
actv_df = pd.DataFrame(actv_dict)
actv_df


In [None]:
# Make sure the three statistic columns hold plain Python floats.
for stat_column in ('mean', 'num_active', 'mean_not_null'):
    actv_df[stat_column] = [float(val) for val in actv_df[stat_column]]

In [None]:
# Whole-map mean activation over time, coloured by contrast, with one facet
# column per layer and one facet row per onset.
sns.relplot(data=actv_df[actv_df.region=='both'], x='timestep', y='mean', hue='contrast', kind='line', col='layer', row='onset')

In [None]:
# Give the columns their display names.
actv_df = actv_df.rename(columns={'timestep': 'Timestep', 'mean': 'Mean', 'num_active': 'Number of Active Units', 'mean_not_null': 'Mean of Active Units', 'onset': 'Onset', 'region': 'Region', 'contrast': 'Contrast', 'model': 'Model', 'layer': 'Layer'})
# Shift to 1-based timesteps only while the column is still 0-based, so that
# re-running this cell does not shift the axis a second time.
if actv_df['Timestep'].min() == 0:
    actv_df['Timestep'] += 1
# Normalise by the (Model, Layer) baseline: the average upper-left mean
# activation at Timestep 1. The baseline is joined back on the group keys
# instead of using groupby.apply(...).reset_index(level=...), whose result
# index depends on the pandas version.
baseline = (
    actv_df[(actv_df.Timestep == 1) & (actv_df.Region == 'ul')]
    .groupby(['Model', 'Layer'], sort=False)['Mean']
    .mean()
    .rename('baseline_mean')
)
# The left join keeps actv_df's index, so the division is row-aligned.
actv_df['Normalized Activations'] = (
    actv_df['Mean'] / actv_df.join(baseline, on=['Model', 'Layer'])['baseline_mean']
)

In [None]:
# Normalised layer-1 activations in the two quadrants, one facet row per model,
# with line style encoding the contrast.
sns.set_style('white')
in_quadrant = actv_df.Region.isin(['ul', 'lr'])
in_layer = (actv_df.Layer==1)
in_models = (actv_df.Model.isin(['Divisive Norm.', 'Additive']))
sns.relplot(
    data=actv_df[in_quadrant & in_layer & in_models],
    x='Timestep',
    y='Normalized Activations',
    hue='Region',
    palette=['red', 'green'],
    kind='line',
    row='Model',
    style='Contrast',
    facet_kws={'sharey': False},
    height=2.5,
    aspect=1.5,
)
sns.despine(offset=5)

# Legend anchored above the figure, in figure coordinates.
# NOTE(review): plt.legend builds a legend on the current Axes; it does not move
# the figure-level legend created by relplot — confirm this is the intended result.
figure_transform = plt.gcf().transFigure
plt.legend(bbox_to_anchor=(0.5, 1.1), loc='upper center', ncol=3, bbox_transform=figure_transform)

for figure_path in ('figures/attn_013.svg', 'figures/attn_013.png'):
    plt.savefig(figure_path)

In [None]:
# Persist the activation table. exist_ok=True keeps the cell re-runnable:
# os.mkdir raises FileExistsError once the directory exists.
os.makedirs('trained_models', exist_ok=True)
actv_df.to_csv('trained_models/video_activation_data.csv')

## Causal Experiments

In [None]:
from tqdm import tqdm
# Result table for the causal experiment; each key is one column and every
# measurement appends exactly one value to each list.
res = {'Intervention': [], 'Accuracy': [], 'Image': [], 'Mean': [], 'Model': []}
def hook_fn(actv, hook):
    """
    Exchange the mean activation level of the upper-left and lower-right quadrants.

    For every leading index (batch, map) the upper-left quadrant is multiplied
    by ``ratio = mean(lower-right) / (mean(upper-left) + 1e-5)`` and the
    lower-right quadrant is divided by ``ratio + 1e-5``. The tensor is modified
    in place and also returned. ``hook`` is unused.
    """
    # Last two dims are the spatial dims; the split point is half the last dim.
    half = actv.shape[-1] // 2
    upper_left = actv[..., :half, :half]
    lower_right = actv[..., half:, half:]

    lower_right_mean = lower_right.mean(dim=(-1, -2), keepdim=True)
    upper_left_mean = upper_left.mean(dim=(-1, -2), keepdim=True)
    ratio = lower_right_mean / (upper_left_mean + 1e-5)

    # Both slices are views, so the in-place ops write through to `actv`.
    upper_left.mul_(ratio)
    lower_right.div_(ratio + 1e-5)
    return actv

def acc(logits, target, timestep):
    """
    Multiclass (10-class) accuracy of the predictions at one timestep.

    :param logits: model output indexed as (batch, timestep, class)
    :param target: class id per sample to compare against
    :param timestep: which timestep of ``logits`` to evaluate
    """
    step_logits = logits[:, timestep, :]
    predicted = torch.argmax(step_logits, dim=1)
    return accuracy(predicted, target, task='multiclass', num_classes=10)
        
    
# Causal experiment: compare accuracy and quadrant activation at the onset
# timestep with and without `hook_fn` applied to 'hks.adapt_0_{onset}'.
onset = 2
contrast = 0.6
two_image_dataset = TwoImageTemporalDataset(
    'train',
    onset_img_2=onset,
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
    contrast_1=0.6,
    contrast_2=contrast,
)
loader = DataLoader(two_image_dataset, batch_size=5, shuffle=True, num_workers=0)

# (image name, timestep whose label is used, region of the feature map to average)
probes = (
    ('second', onset, (..., slice(26, None), slice(26, None))),
    ('first', 0, (..., slice(None, 26), slice(None, 26))),
)


def record_measurements(intervention, logits, cache, labels, model_name):
    """Append one row per probed image (second, then first) to `res`."""
    onset_actv = cache[f'hks.adapt_0_{onset}']
    for image, label_step, region in probes:
        res['Intervention'].append(intervention)
        res['Accuracy'].append(float(acc(logits, labels[:, label_step], onset)))
        res['Image'].append(image)
        res['Mean'].append(float(onset_actv[region].mean()))
        res['Model'].append(model_name)


j = 0
for x, y in tqdm(loader):
    x, y = x.cuda(), y.cuda()
    for name, model in models.items():
        model.cuda()
        # Clean run is recorded before the hooked run starts.
        logits_clean, cache_clean = model.run_with_cache(x)
        record_measurements(False, logits_clean, cache_clean, y, name)

        with model.hooks([(f'hks.adapt_0_{onset}', hook_fn)]):
            logits_swapped, cache_swapped = model.run_with_cache(x)
            record_measurements(True, logits_swapped, cache_swapped, y, name)

    j += 1
    # Stop after 200 batches.
    if j >= 200:
        break
causal_df = pd.DataFrame(res)
causal_df


In [None]:
# One batch with the second image appearing at timestep 5: run the
# divisive-norm model once as is and once with `hook_fn` on 'hks.adapt_0_{onset}'.
onset = 5
two_image_dataset = TwoImageTemporalDataset(
    'train',
    onset_img_2=onset,
    transform=transform,
    img_to_timesteps_transforms=timestep_transforms,
    contrast_1=0.6,
    contrast_2=0.6,
)
loader = DataLoader(two_image_dataset, batch_size=5, shuffle=True, num_workers=0)

j = 0
x, y = next(iter(loader))
x, y = x.cuda(), y.cuda()
logits_clean, cache_clean = div_norm_model.run_with_cache(x)

hooked_layer = f'hks.adapt_0_{onset}'
with div_norm_model.hooks([(hooked_layer, hook_fn)]):
    logits_swapped, cache_swapped = div_norm_model.run_with_cache(x)

In [None]:
# Side-by-side view of one feature map at the onset timestep: clean run (left)
# and run with the hook applied (right), joined by an arrow.
idx = 2
feature_map = 14  # was `map`, which shadowed the builtin for the rest of the session
onset = 5
clean = cache_clean[f'hks.adapt_0_{onset}'][idx, feature_map].cpu()
corr = cache_swapped[f'hks.adapt_0_{onset}'][idx, feature_map].cpu()

fig, ax = plt.subplots(ncols=2)

ax[0].imshow(clean, vmin=0, vmax=.7)
ax[0].set_xticks([])
ax[0].set_yticks([])

ax[1].imshow(corr, vmin=0, vmax=.7)
ax[1].set_xticks([])
ax[1].set_yticks([])

# Arrow from the right edge of the left panel towards the right panel,
# in the left panel's axes-fraction coordinates.
ax[0].annotate('', xy=(1.215, 0.5), xycoords='axes fraction', xytext=(1, 0.5),
               arrowprops=dict(arrowstyle="->", lw=3.5))

plt.savefig('figures/attn_009.svg')
plt.savefig('figures/attn_009.png')

In [None]:
# Accuracy on the first and second image with (red) and without (grey) the
# intervention, one facet column per model.
sns.catplot(data=causal_df, x='Image', y='Accuracy', hue='Intervention', order=['first', 'second'], kind='bar', col='Model', col_order=['Divisive Norm.', 'Additive'], height=3, palette=['grey', 'red'])
#sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
#plt.tight_layout()
sns.despine(offset=10)

# NOTE(review): the other figures in this notebook use a three-digit suffix
# (attn_004 ... attn_013); 'attn_10' breaks that pattern — confirm the intended name.
plt.savefig('figures/attn_10.svg')
plt.savefig('figures/attn_10.png')

In [None]:
# Mean activation of the probed region for each image, split by intervention
# (all models pooled).
sns.barplot(data=causal_df, x='Image', y='Mean', hue='Intervention')
