In [None]:
import shutup; shutup.please() # disable warnings
import numpy as np
import torch
import matplotlib.pyplot as plt

from cifar10_utils import load_data, im_show, Net, train_model, load_saved_model, test_model

# Sentinel: stays None unless the training cell below runs, so the evaluation
# cells can fall back to load_saved_model() on a fresh kernel.
model = None

# Pick the best available device: CUDA first, then Apple MPS, then CPU.
# torch.backends.mps only exists in newer PyTorch releases (1.12+), so look it
# up defensively instead of raising AttributeError on older installs.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    print("Using Metal Performance Shaders (MPS)")
else:
    device = torch.device("cpu")
    print("Using CPU")

## Collect data for the visualizations using a CNN trained on CIFAR-10

### Load and check out the dataset

In [None]:
# Build the train / validation / test data loaders, 20 images per batch.
train_loader, valid_loader, test_loader = load_data(batch_size=20)

# Human-readable class names, indexed by integer label
# (presumably in the standard CIFAR-10 label order; confirm in cifar10_utils.py).
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

In [None]:
# obtain one batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()  # convert images to numpy for display

# Plot (up to) the first 20 images of the batch with their class labels.
# The grid is derived from the actual batch length so this cell does not
# break (IndexError / undersized grid) if load_data's batch_size changes.
n_images = min(20, len(images))
n_rows = 2
n_cols = int(np.ceil(n_images / n_rows))  # ceil so odd counts still fit

fig = plt.figure(figsize=(25, 4))
for idx in range(n_images):
    ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
    im_show(images[idx])
    # int() turns the 0-d label tensor into a plain Python index
    ax.set_title(classes[int(labels[idx])])
plt.show()

### Train a CNN on the CIFAR-10 dataset

In [None]:
# Build a fresh network and train it on the train/validation loaders.
# train_losses presumably holds the loss history returned by train_model —
# confirm its structure in cifar10_utils.py.
model = Net()
train_losses = train_model(
    model,
    train_loader,
    valid_loader,
    device,
    n_epochs=120,  # long run; expect this cell to be slow
    lr=0.01,
)

### Evaluate the trained model on the test set

In [None]:
# Test the model: reuse the one trained above, or, on a fresh kernel where the
# training cell was skipped, load the saved model instead.
# Compare against None explicitly rather than relying on nn.Module truthiness.
if model is None:
    model = load_saved_model(device=device)
test_model(model, test_loader, device)

### Evaluate the trained model on the test set while collecting activations

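We run the test set again with `save_act` enabled (the `True` argument to `test_model`), which returns a list with one entry of layer-wise activations per batch. <br>
Not all of these activations may be needed for the visualizations. <br>
To see which layers the activations come from, look at `self.conv_layer` and `self.fc_layer` of `Net` in `cifar10_utils.py`.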

In [None]:
# Test the model again, this time collecting the layer-wise activations per batch.
# Compare against None explicitly rather than relying on nn.Module truthiness.
if model is None:
    model = load_saved_model(device=device)
# The fourth positional argument (True) presumably enables save_act —
# confirm the parameter name in cifar10_utils.test_model.
activations = test_model(model, test_loader, device, True)

In [None]:
# Show which activation entries were recorded for the first test batch.
first_batch_activations = activations[0]
first_batch_activations.keys()