In [None]:
from tcav import TCAV
from train import train_model, eval_model
from datasets import get_datasets_and_data_loaders, create_copy_of_train_set
from utils import imshow, visualize_model, get_reduced_activation_space_points, scatter_plot_classes, get_sensitive_filenames

import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import pickle

### (1) Train a CNN classification model on the provided images and labels, achieving high predictive performance on the train set

**Set training parameters and get the dataloaders**

In [None]:
# Set training parameters
LEARNING_RATE = 0.00001
BATCH_SIZE = 4
NUM_EPOCHS = 10

# Seed torch (all devices) and numpy so that data-loader shuffling, the randomly
# initialised classifier head and (presumably) any sampling done inside the
# project helpers are reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set data directory
data_dir = './data'

# Per-channel mean and std values of the model's original training dataset
# (the pretrained weights loaded later are 'IMAGENET1K_V1'); passed both to the
# data loaders and to imshow for display
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Get data loaders for the data
(image_datasets, dataloaders, dataset_sizes, class_names,
 single_sample_train_dataloader) = get_datasets_and_data_loaders(data_dir, mean, std, BATCH_SIZE)


In [None]:
# Show the class names found in the data; label index i corresponds to class_names[i]
print('Dataset Classes: ', class_names, sep=' ')

**Display some of the images and labels**

In [None]:
# Pull a single mini-batch (images + integer labels) from the training loader
inputs, classes = next(iter(dataloaders['train']))

# Tile the batch into one image grid and caption it with each sample's class name
out = torchvision.utils.make_grid(inputs)
batch_titles = [class_names[label] for label in classes]
imshow(out, mean, std, title=batch_titles)

**Load a pre-trained model, define the loss criterion and optimizer, and fine-tune on our dataset**

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a resnet152 model with ImageNet-pretrained weights
classification_model = torchvision.models.resnet152(weights='IMAGENET1K_V1')

# Replace the last fully connected layer with a fresh one that predicts one score
# per dataset class. The output size is derived from class_names rather than
# hard-coded, so the head always matches the labels the data loaders produce.
num_ftrs = classification_model.fc.in_features
classification_model.fc = nn.Linear(num_ftrs, len(class_names))

# Move the model to the selected device (GPU if available, else CPU)
classification_model = classification_model.to(device)

# Cross-entropy loss on the raw logits (with two classes this is the
# softmax form of a binary classifier)
criterion = nn.CrossEntropyLoss()

# SGD with momentum over *all* parameters, i.e. the whole network is
# fine-tuned, not only the new head
optimizer = optim.SGD(classification_model.parameters(), lr=LEARNING_RATE, momentum=0.9)


In [None]:
# Fine-tune the network on our dataset. return_best_test_set_model=False
# presumably keeps the final-epoch weights rather than a best-on-test
# checkpoint (see train.train_model to confirm).
classification_model = train_model(
    classification_model,
    criterion,
    optimizer,
    dataloaders,
    dataset_sizes,
    data_dir,
    num_epochs=NUM_EPOCHS,
    return_best_test_set_model=False,
)

**Visualize some of the predictions**

In [None]:
# Show five images together with the fine-tuned model's predicted class names
# (the original author describes these as train-set images)
visualize_model(
    classification_model,
    dataloaders,
    class_names,
    mean,
    std,
    num_images=5,
)

**Print the named internal modules of our classification model**

In [None]:
# Print the module hierarchy of the resnet152 network.
# `Module.modules` is a method; printing it uncalled only shows a
# "<bound method Module.modules of ...>" wrapper. Printing the model itself
# gives nn.Module's structured repr, including the named submodules
# (layer1 ... layer4, fc) referenced by name later in this notebook.
print('Classification Network Modules: ')
print('------------------------------------------')
print(classification_model)

### (2) Visualise the model’s hidden activation space

**Extract and visualize activation-space output from intermediate layers of the network on a portion of the train set**

In [None]:
# Named resnet152 submodules whose outputs we want to inspect
layers_out = ['layer1', 'layer2', 'layer3']

# Encode a portion of the train split (num_samples=600) at each of these layers
# and get 2-D dimensionally reduced points plus their class labels
activations, activation_labels = get_reduced_activation_space_points(
    classification_model,
    layers_out,
    dataloaders['train'],
    num_samples=600,
)

In [None]:
# Scatter-plot the dimensionally reduced (the original comment says PCA; see
# utils) train datapoints encoded at layers 1, 2 and 3, coloured by true class.
# Class ids, colours and legend names are derived from class_names instead of
# the hard-coded [0, 1] / 'classA' / 'classB' placeholders. With two classes the
# ids and colours are exactly the original {0: 'r', 1: 'g'}.
class_ids = list(range(len(class_names)))
palette = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
class_colors = {i: palette[i % len(palette)] for i in class_ids}

for layer in layers_out:
    print(layer, 'output')
    print('-----------------------------')
    scatter_plot_classes(activations[layer], activation_labels, class_ids, class_colors, list(class_names))

**Load Images of the “concept”** 

In [None]:
# Load examples of the concept images.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. Only load concept_imgs.pkl if it comes from a trusted source.
concept_path = os.path.join(data_dir, 'concept_imgs.pkl')
with open(concept_path, 'rb') as f:
    concept_examples = pickle.load(f)

print('Number of concept examples: ', len(concept_examples))

# Display the concept images as a single grid, passing the dataset mean/std to imshow
out = torchvision.utils.make_grid(concept_examples)
imshow(out, mean, std)

**Set TCAV parameters**

In [None]:
# TCAV configuration
# Intermediate activation space: outputs of the resnet model's 'layer2' block
layer_out = 'layer2'
# Class whose sensitivity to the concept we test: class 0 (enemies)
class_index = 0

**Negative samples approach:**
Use negative examples sampled uniformly at random from the train set. Because the sampling is uniform, these negatives may by chance include some Arnie (concept) samples.

In [None]:
# Build the TCAV object. Negative (random) samples are drawn uniformly from the
# whole train dataset, so they may by chance include some concept images.
tcav = TCAV(
    classification_model,
    concept_examples,
    layer_out,
    class_index,
    random_dataset=image_datasets['train'],
)

### (4) Train TCAV to recognise the particular concept 

### (5) Visualise the CAV vector in the model’s hidden space.

In [None]:
# Fit the CAV: the direction that separates our concept images from the
# random samples in the intermediate (layer_out) activation space.
# plot=True also draws it at a lower dimensionality.
cav = tcav.train_and_get_CAV(plot=True)

print('Shape of the cav vector computed: ', cav.shape)

### (6) Use the trained TCAV model to find images in the train set with high sensitivity to the concept.

In [None]:
# Minimum TCAV sensitivity for a train image to be flagged as showing the concept
SENSITIVITY_THRESHOLD = 0.005

# Go through the train set one sample at a time and collect the filenames of
# images whose TCAV sensitivity is above the threshold
concept_images_filenames = get_sensitive_filenames(
    tcav,
    single_sample_train_dataloader,
    threshold_sensitivity=SENSITIVITY_THRESHOLD,
)