# Tutorial on extracting activations

### Extracting and saving ResNet18 activations on a random 5% subset of the CIFAR-10 test set

In [None]:
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
from pprint import pprint

from tcav.extract_activations import extract_activations

# Define your model. Switch it to eval mode so that layers such as BatchNorm
# use their stored running statistics rather than per-batch statistics while
# activations are recorded.
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# Define your dataset and dataloader. The pretrained torchvision weights were
# trained on 224x224 inputs normalized with the ImageNet channel mean/std, so
# apply the same normalization here to keep inputs in-distribution.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = datasets.CIFAR10(root='./example_data', train=False, download=True, transform=transform)

# Keep a random fraction of the test set. Sample WITHOUT replacement so no image
# appears twice in the subset, and seed the generator so the subset is
# reproducible across runs.
subset_fraction = 0.05
rng = np.random.default_rng(seed=0)
subset_indices = rng.choice(len(dataset), size=int(len(dataset) * subset_fraction), replace=False)
dataset = torch.utils.data.Subset(dataset, subset_indices.tolist())
# shuffle=False: batches are produced in subset_indices order.
dataloader = DataLoader(dataset, batch_size=12, shuffle=False, num_workers=4)

# Define the experiment name
experiment_name = 'resnet18_cifar10'

extract_activations(
    model=model,
    dataloader=dataloader,
    experiment_name=experiment_name,
    layers=None,  # None to extract from all layers
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

### Loading and interacting with activations

In [2]:
# Open the archive written by extract_activations. The path is spelled out
# (rather than built from experiment_name) so this cell also works on its own.
# The handle stays open so that later cells can keep indexing `activations`.
activations = np.load('activations/resnet18_cifar10.npz')

# Names of the arrays stored in the archive (one per recorded layer).
layer_names = activations.files
pprint(layer_names)

# Print the shape of each stored array, in the same order as layer_names.
for item in layer_names:
    layer_array = activations[item]
    pprint(layer_array.shape)

['conv1',
 'bn1',
 'relu',
 'maxpool',
 'layer1.0.conv1',
 'layer1.0.bn1',
 'layer1.0.relu',
 'layer1.0.conv2',
 'layer1.0.bn2',
 'layer1.0',
 'layer1.1.conv1',
 'layer1.1.bn1',
 'layer1.1.relu',
 'layer1.1.conv2',
 'layer1.1.bn2',
 'layer1.1',
 'layer1',
 'layer2.0.conv1',
 'layer2.0.bn1',
 'layer2.0.relu',
 'layer2.0.conv2',
 'layer2.0.bn2',
 'layer2.0.downsample.0',
 'layer2.0.downsample.1',
 'layer2.0.downsample',
 'layer2.0',
 'layer2.1.conv1',
 'layer2.1.bn1',
 'layer2.1.relu',
 'layer2.1.conv2',
 'layer2.1.bn2',
 'layer2.1',
 'layer2',
 'layer3.0.conv1',
 'layer3.0.bn1',
 'layer3.0.relu',
 'layer3.0.conv2',
 'layer3.0.bn2',
 'layer3.0.downsample.0',
 'layer3.0.downsample.1',
 'layer3.0.downsample',
 'layer3.0',
 'layer3.1.conv1',
 'layer3.1.bn1',
 'layer3.1.relu',
 'layer3.1.conv2',
 'layer3.1.bn2',
 'layer3.1',
 'layer3',
 'layer4.0.conv1',
 'layer4.0.bn1',
 'layer4.0.relu',
 'layer4.0.conv2',
 'layer4.0.bn2',
 'layer4.0.downsample.0',
 'layer4.0.downsample.1',
 'layer4.0.downs