# Tutorial on CAVs


### Extracting and saving activations from ResNet18/5% of CIFAR10

In [1]:
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
from pprint import pprint

from tcav.extract_activations import extract_activations

# Define your model
model = torchvision.models.resnet18(pretrained=True)

# Define your dataset and dataloader
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
dataset = datasets.CIFAR10(root='./example_data', train=False, download=True, transform=transform)

dataset = torch.utils.data.Subset(dataset, np.random.choice(len(dataset), int(len(dataset) * 0.1)))
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

# Define the experiment name
experiment_name = 'resnet18_cifar10'

extract_activations(
    model=model,
    dataloader=dataloader,
    experiment_name=experiment_name,
    layers=None,  # None to extract from all layers
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

print('Done!')



Files already downloaded and verified
Loaded activations from './activations/resnet18_cifar10.npz'
Done!


### Loading and interacting with activations

Treat loaded file as a dict. Each key is a layer name, and each value is a numpy array of shape (num_samples, num_channels, height, width) for conv layers or (num_samples, num_features) for dense layers.
The first entry will be `labels`, which is a numpy array of shape (num_samples, no_features) containing the labels of the samples and concept values.

In [2]:
activations = np.load('activations/resnet18_cifar10.npz')

In [3]:
activations['labels'].shape

(500, 1)

In [4]:
layer_names = activations.files
pprint(layer_names)

['labels',
 'conv1',
 'bn1',
 'relu',
 'maxpool',
 'layer1.0.conv1',
 'layer1.0.bn1',
 'layer1.0.relu',
 'layer1.0.conv2',
 'layer1.0.bn2',
 'layer1.0',
 'layer1.1.conv1',
 'layer1.1.bn1',
 'layer1.1.relu',
 'layer1.1.conv2',
 'layer1.1.bn2',
 'layer1.1',
 'layer1',
 'layer2.0.conv1',
 'layer2.0.bn1',
 'layer2.0.relu',
 'layer2.0.conv2',
 'layer2.0.bn2',
 'layer2.0.downsample.0',
 'layer2.0.downsample.1',
 'layer2.0.downsample',
 'layer2.0',
 'layer2.1.conv1',
 'layer2.1.bn1',
 'layer2.1.relu',
 'layer2.1.conv2',
 'layer2.1.bn2',
 'layer2.1',
 'layer2',
 'layer3.0.conv1',
 'layer3.0.bn1',
 'layer3.0.relu',
 'layer3.0.conv2',
 'layer3.0.bn2',
 'layer3.0.downsample.0',
 'layer3.0.downsample.1',
 'layer3.0.downsample',
 'layer3.0',
 'layer3.1.conv1',
 'layer3.1.bn1',
 'layer3.1.relu',
 'layer3.1.conv2',
 'layer3.1.bn2',
 'layer3.1',
 'layer3',
 'layer4.0.conv1',
 'layer4.0.bn1',
 'layer4.0.relu',
 'layer4.0.conv2',
 'layer4.0.bn2',
 'layer4.0.downsample.0',
 'layer4.0.downsample.1',
 'lay

In [5]:
for item in layer_names:
    pprint(activations[item].shape)

(500, 1)
(500, 64, 112, 112)
(500, 64, 112, 112)
(500, 64, 112, 112)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(1000, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(1000, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 64, 56, 56)
(500, 128, 28, 28)
(500, 128, 28, 28)
(1000, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(1000, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 128, 28, 28)
(500, 256, 14, 14)
(500, 256, 14, 14)
(1000, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(1000, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 256, 14, 14)
(500, 512, 7, 7)
(500, 512, 7, 7)
(1000, 512, 7, 7)
(500, 512, 7,

## Primer on computing CAVs

In [6]:
from tcav.cav import compute_cav

In [7]:
labels_array = activations['labels'].flatten() # for simplicity, we flatten the labels as we will work on cifar labels and try to get CAVs concept of "5"
labels_array = np.where(labels_array == 5, 1, 0)
layer41_array = activations['layer4.1'].reshape(activations['layer4.1'].shape[0], -1) # we will use layer4.1 activations to compute CAVs
print(f"labels_array.shape: {labels_array.shape}, layer41_array.shape: {layer41_array.shape}")

labels_array.shape: (500,), layer41_array.shape: (500, 25088)


In [8]:
cav = compute_cav(layer41_array, labels_array, cav_type='ridge').reshape(-1)
print(cav.shape)

  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)


Best alpha: 1000
CAV type:  ridge
largest CAV values: torch.return_types.topk(
values=tensor([0.0391, 0.0383, 0.0362, 0.0350, 0.0346, 0.0340, 0.0330, 0.0325, 0.0324,
        0.0320]),
indices=tensor([20357, 20358, 17533, 22284, 20023, 20024,  7374, 20350, 14185, 20017]))
torch.Size([25088])


In [9]:
!pip install torchviz

Collecting torchviz
  Using cached torchviz-0.0.2-py3-none-any.whl
Collecting graphviz (from torchviz)
  Using cached graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Using cached graphviz-0.20.3-py3-none-any.whl (47 kB)
Installing collected packages: graphviz, torchviz
Successfully installed graphviz-0.20.3 torchviz-0.0.2


In [10]:
from tcav.tcav import get_tcav_scores
tcav_scores = [get_tcav_scores(model, cav, dataloader, "layer4.1", cls, device='cpu') for cls in [0, 1, 5]]

Computing TCAV Scores:  10%|█         | 100/1000 [00:43<06:27,  2.32it/s]
Computing TCAV Scores:  10%|█         | 100/1000 [00:40<06:07,  2.45it/s]
Computing TCAV Scores:  10%|█         | 100/1000 [00:41<06:14,  2.40it/s]


In [11]:
tcav_scores

[(0.2, 20, 100), (0.05, 5, 100), (0.02, 2, 100)]