# LRP Fine-Tuning
The objective of this notebook is to demonstrate that a network can be further optimised after it has been trained.
We show that incorporating attention information (LRP relevance heatmaps) can improve the quality of a model.
This has several steps: 
- Download the Imagenette dataset (we use it instead of the full ImageNet because it is considerably more lightweight)
- Download two models pre-trained on ImageNet (VGG16 and VGG19, both from the VGG family)
- Evaluate the quality of the models on the dataset. 
- Use the better-performing model as the "teacher" model
- Attempt to improve the other ("learner") model using the teacher's heatmaps
- Periodically evaluate performance
 

In [None]:
import sys
sys.path.append('..')


In [None]:
import itertools

import numpy as np
import pandas as pd
import torch
import torchvision.models as models
from tqdm import tqdm

from experiments import process_batch, WrapperNet, perform_lrp_plain
from internal_utils import get_data_imagenette, preprocess_images
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of batches drawn from each DataLoader by evaluate_performance and
# process_dataset (a small cap so the notebook runs quickly).
TRUNCATE = 3

def imagenette_to_imagenet_label_mapping(imagenette_labels):
    """Map Imagenette class indices (0-9) to the matching ImageNet-1k class indices.

    Element-wise, pure-Python version of the lookup; see
    ``imagenette_to_imagenet_label_mapping_fast`` for the vectorised tensor version.

    Args:
        imagenette_labels: iterable of integer Imagenette labels. Items may be plain
            ``int``s or single-element integer tensors / NumPy scalars (e.g. what you
            get by iterating over a 1-D label tensor).

    Returns:
        list[int]: the corresponding ImageNet-1k class indices, in input order.

    Raises:
        KeyError: if a label is not one of the ten Imagenette class indices.
    """
    mapping = {
        0: 0,    # tench
        1: 217,  # English Springer
        2: 482,  # Cassette Player
        3: 491,  # Chain Saw
        4: 497,  # Church
        5: 566,  # French Horn
        6: 569,  # Garbage Truck
        7: 571,  # Gas Pump
        8: 574,  # Golf Ball
        9: 701   # Parachute
    }

    # int() accepts plain Python ints as well as 0-d tensors and NumPy scalars;
    # calling ``label.item()`` would crash on a plain list of ints.
    return [mapping[int(label)] for label in imagenette_labels]

def imagenette_to_imagenet_label_mapping_fast(imagenette_labels):
    """Vectorised mapping from Imagenette class indices (0-9) to ImageNet-1k indices.

    Args:
        imagenette_labels: integer label tensor of any shape on any device, or any
            object ``torch.as_tensor`` accepts (e.g. a list of ints).

    Returns:
        torch.Tensor: int64 tensor of ImageNet-1k class indices with the same shape
        and on the same device as the input labels.

    Raises:
        IndexError: if a label is outside the lookup table's range.
    """
    labels = torch.as_tensor(imagenette_labels)
    # Build the lookup table on the labels' device: indexing a CPU tensor with a
    # CUDA index tensor raises, so a CPU-only table would break for GPU labels.
    mapping = torch.tensor(
        [0, 217, 482, 491, 497, 566, 569, 571, 574, 701],
        device=labels.device,
    )
    # Advanced indexing performs the lookup element-wise and preserves the shape.
    return mapping[labels]

def get_vgg16():
    """Return torchvision's VGG16 with ImageNet-1k (V1) weights, switched to eval mode.

    The pretrained weights are fetched through torchvision (which downloads them
    when needed).
    """
    pretrained = models.VGG16_Weights.IMAGENET1K_V1
    # nn.Module.eval() returns the module itself, so it can be chained.
    return models.vgg16(weights=pretrained).eval()
    
def get_vgg19():
    """Return torchvision's VGG19 with ImageNet-1k (V1) weights, switched to eval mode.

    The pretrained weights are fetched through torchvision (which downloads them
    when needed).
    """
    pretrained = models.VGG19_Weights.IMAGENET1K_V1
    # nn.Module.eval() returns the module itself, so it can be chained.
    return models.vgg19(weights=pretrained).eval()
    
    
def _evaluate_loader(model, data_loader, max_batches):
    """Return per-batch (losses, accuracies) of ``model`` on the first ``max_batches`` batches."""
    losses, accuracies = [], []
    # islice walks ONE iterator over the loader, so successive steps see successive
    # batches. Calling ``next(iter(loader))`` each step restarts the loader and, for
    # an unshuffled loader, re-scores the same first batch every time.
    batches = itertools.islice(data_loader, max_batches)
    for images, labels in tqdm(batches, total=max_batches):
        images = preprocess_images(images).to(device)
        labels = imagenette_to_imagenet_label_mapping_fast(labels).to(device)
        with torch.no_grad():
            logits = model(images)
            loss = torch.nn.functional.cross_entropy(logits, labels)
            accuracy = (logits.argmax(dim=1) == labels).float().mean() * 100
        # .item() copies the scalar to host memory; np.array() over a list of
        # CUDA tensors would fail when running on the GPU.
        losses.append(loss.item())
        accuracies.append(accuracy.item())
    return losses, accuracies


def evaluate_performance(model, train_data, test_data):
    """Evaluate ``model`` on the Imagenette train and test loaders.

    The model is moved to ``device``. Only the first ``TRUNCATE`` batches of each
    loader are scored. Images go through ``preprocess_images``, and Imagenette
    labels are mapped to ImageNet-1k indices before scoring.

    Args:
        model: ImageNet classifier returning logits.
        train_data: DataLoader yielding ``(images, imagenette_labels)`` batches.
        test_data: DataLoader yielding ``(images, imagenette_labels)`` batches.

    Returns:
        pd.DataFrame: one row per batch, with columns ``train_loss``,
        ``train_accuracy``, ``test_loss`` and ``test_accuracy``. Loss is mean
        cross-entropy and accuracy is top-1 in percent. If one loader yields fewer
        batches than the other, the missing cells are NaN.
    """
    model.to(device)
    train_loss, train_accuracy = _evaluate_loader(model, train_data, TRUNCATE)
    test_loss, test_accuracy = _evaluate_loader(model, test_data, TRUNCATE)
    # Series (rather than bare arrays) let columns of different lengths align
    # instead of raising in the DataFrame constructor.
    results = {
        "train_loss": pd.Series(np.array(train_loss, dtype=float)),
        "train_accuracy": pd.Series(np.array(train_accuracy, dtype=float)),
        "test_loss": pd.Series(np.array(test_loss, dtype=float)),
        "test_accuracy": pd.Series(np.array(test_accuracy, dtype=float)),
    }
    return pd.DataFrame.from_dict(results)

def _accumulate_results(table, results):
    """Append one batch's results to ``table`` in place, concatenating along dim 0."""
    for key, value in results.items():
        # A method may return None for a batch; store an empty tensor instead.
        if value is None:
            value = torch.empty(0)
        value = value.detach()
        if key in table:
            table[key] = torch.cat([table[key], value], dim=0)
        else:
            table[key] = value


def process_dataset(data_loader, methods, kernel_size_min, kernel_size_max, noise_level_min, noise_level_max):
    """Generate results over a single dataset.

    Only the first ``TRUNCATE`` batches of ``data_loader`` are processed.

    Args:
        data_loader (DataLoader object): Dataloader object without heavy preprocessing steps
        methods (List of Tuples): list of methods(functions) to be used on each datapoint of form (name, method, model)
        kernel_size_min (int): For Gaussian Blur, minimum kernel size
        kernel_size_max (int): For Gaussian Blur, maximum kernel size
        noise_level_min (float): For adding noise, minimum noise level
        noise_level_max (float): For adding noise, maximum noise level

    Returns:
        dict: results keyed as returned by ``process_batch`` (presumably one key
        per method/metric — confirm against ``process_batch``). Each value is the
        detached tensor of all batches' results concatenated along dim 0.
    """
    table = {}
    num_batches = min(TRUNCATE, len(data_loader))
    # Iterate a single iterator: ``next(iter(data_loader))`` restarted the loader
    # every step and, without shuffling, processed the same first batch repeatedly.
    batches = itertools.islice(data_loader, TRUNCATE)
    for i, (input_batch, input_labels) in enumerate(batches):
        # Tensor.to() is not in-place; the result must be re-bound, otherwise the
        # batch silently stays where it was while the labels move to `device`.
        input_batch = input_batch.to(device)
        input_labels = imagenette_to_imagenet_label_mapping_fast(input_labels).to(device)
        results = process_batch(
            input_batch,
            input_labels,
            methods,
            kernel_size_min,
            kernel_size_max,
            noise_level_min,
            noise_level_max,
        )
        _accumulate_results(table, results)
        print(f"Processed batch {i+1}/{num_batches}")
    return table

def _table_to_frame(table):
    """Convert a ``{name: tensor}`` table into a DataFrame via host-side NumPy arrays."""
    # pandas cannot ingest CUDA tensors directly, so copy each column to the CPU first.
    return pd.DataFrame({key: value.detach().cpu().numpy() for key, value in table.items()})


def evaluate_explanations(train_data, test_data, save_results=False, *,
                          kernel_size_min=3, kernel_size_max=5,
                          noise_level_min=0.1, noise_level_max=0.2):
    """Compute LRP explanation results for pretrained VGG16 and VGG19 on both splits.

    Args:
        train_data: DataLoader for the training split (passed to ``process_dataset``).
        test_data: DataLoader for the test split (passed to ``process_dataset``).
        save_results: if True, write ``train_results.csv`` and ``test_results.csv``
            to the current working directory.
        kernel_size_min: minimum Gaussian-blur kernel size passed to ``process_dataset``.
        kernel_size_max: maximum Gaussian-blur kernel size passed to ``process_dataset``.
        noise_level_min: minimum added-noise level passed to ``process_dataset``.
        noise_level_max: maximum added-noise level passed to ``process_dataset``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(df_train, df_test)``, built from the
        tables returned by ``process_dataset``.
    """
    vgg16 = get_vgg16()
    vgg19 = get_vgg19()
    # NOTE(review): the wrapped models are not moved to `device` here — confirm that
    # WrapperNet / process_batch handle device placement when running on a GPU.
    methods = [
        ("VGG16", perform_lrp_plain, WrapperNet(vgg16, hybrid_loss=True)),
        ("VGG19", perform_lrp_plain, WrapperNet(vgg19, hybrid_loss=True)),
    ]
    perturbation = dict(
        kernel_size_min=kernel_size_min,
        kernel_size_max=kernel_size_max,
        noise_level_min=noise_level_min,
        noise_level_max=noise_level_max,
    )
    df_train = _table_to_frame(process_dataset(train_data, methods, **perturbation))
    df_test = _table_to_frame(process_dataset(test_data, methods, **perturbation))

    if save_results:
        df_train.to_csv("train_results.csv")
        df_test.to_csv("test_results.csv")
    return df_train, df_test



In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

train_data, test_data = get_data_imagenette()

# Score both pretrained VGGs on the same loaders, then tag each row with its model.
vgg16_results = evaluate_performance(get_vgg16(), train_data, test_data)
vgg19_results = evaluate_performance(get_vgg19(), train_data, test_data)
for frame, model_label in ((vgg16_results, 'VGG16'), (vgg19_results, 'VGG19')):
    frame['model'] = model_label

# Stack both models, then reshape to long format (model / metric / value rows),
# which is what seaborn's hue grouping expects.
combined_df = pd.concat([vgg16_results, vgg19_results])
melted_df = combined_df.melt(id_vars=['model'], var_name='metric', value_name='value')

# Side-by-side boxplots per metric, coloured by model.
fig, ax = plt.subplots(figsize=(12, 8))
sns.boxplot(data=melted_df, x='metric', y='value', hue='model', ax=ax)
ax.set_title('Model Performance Comparison')
ax.set_xlabel('Metric')
ax.set_ylabel('Value')
ax.legend(title='Model')
plt.show()

# Next: compute the LRP explanation results for both models.
df_train, df_test = evaluate_explanations(train_data, test_data, save_results=False)
