# LRP Fine-tuning
This notebook demonstrates that a network can be optimised after training.
Specifically, we show that incorporating attention information (heatmaps) from a stronger model can improve the quality of a weaker one.
The workflow has several steps:
- Download the Imagenette dataset (we use this rather than the full ImageNet because it is considerably more lightweight)
- Download two models pre-trained on ImageNet (from the VGG family)
- Evaluate the quality of the models on the dataset
- Use the better-performing model as the "teacher model"
- Attempt to improve the learner model using the teacher's heatmaps
- Periodically evaluate performance
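
One way to picture the last two steps is as a combined objective: the usual classification loss plus a term that penalises disagreement between the learner's and the teacher's heatmaps. The following is a minimal NumPy sketch under that assumption. The names `hybrid_loss` and `heatmap_weight`, the heatmap normalisation, and the mean-squared-error distance are all illustrative choices, not the implementation behind `WrapperNet(..., hybrid_loss=True)` used later in this notebook.

```python
import numpy as np


def cross_entropy(logits, label):
    """Cross-entropy of a single sample from raw logits (numerically stable)."""
    shifted = logits - np.max(logits)
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    return float(-log_probs[label])


def normalise(heatmap):
    """Scale a heatmap so its absolute values sum to 1 (an assumed normalisation)."""
    total = np.sum(np.abs(heatmap))
    return heatmap / total if total > 0 else heatmap


def hybrid_loss(logits, label, learner_heatmap, teacher_heatmap, heatmap_weight=1.0):
    """Classification loss plus a weighted heatmap-mismatch penalty (hypothetical)."""
    classification = cross_entropy(logits, label)
    diff = normalise(learner_heatmap) - normalise(teacher_heatmap)
    mismatch = float(np.mean(diff ** 2))
    return classification + heatmap_weight * mismatch


# Toy inputs: 3 classes and 2x2 heatmaps, made up purely for illustration.
logits = np.array([2.0, 0.5, -1.0])
teacher = np.array([[0.0, 1.0], [1.0, 0.0]])
learner = np.array([[1.0, 0.0], [0.0, 1.0]])

same = hybrid_loss(logits, 0, teacher, teacher)       # heatmaps agree: no penalty
different = hybrid_loss(logits, 0, learner, teacher)  # heatmaps disagree: penalised
```

With `heatmap_weight=0` the objective reduces to plain cross-entropy, so the weight controls how strongly the learner is pulled towards the teacher's explanations.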
 

## Helper functions
The following cell contains most of the helper functions that are specific to evaluating the performance and explanations of VGG16 vs VGG19.
Generic helper functions have been moved to the utils folder for use in other notebooks.
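
The `plot_comparative_figure` helper defined in the next cell only plots a subset of the samples: for the "small" perturbations it keeps rows whose predicted class did not change, and for the "large" perturbations it keeps rows whose class did change. The sketch below isolates that selection rule on a toy DataFrame. The function name `select_rows`, the method name `model_a`, and all values are invented for illustration; only the column naming pattern is taken from the helper.

```python
import pandas as pd


def select_rows(df, method, fig_type):
    """Return the distance values that the comparative plot would keep."""
    # Small perturbations: keep samples whose class stayed the same.
    # Large perturbations: keep samples whose class changed.
    expected_change = "small" not in fig_type
    mask = df[f"{method}_{fig_type}_class_change"] == expected_change
    return df.loc[mask, f"{method}_{fig_type}"]


# Invented toy values, for illustration only.
toy = pd.DataFrame({
    "model_a_distance_noise_small": [0.1, 0.2, 0.3],
    "model_a_distance_noise_small_class_change": [False, True, False],
    "model_a_distance_noise_large": [0.7, 0.8, 0.9],
    "model_a_distance_noise_large_class_change": [True, True, False],
})

kept_small = select_rows(toy, "model_a", "distance_noise_small").tolist()
kept_large = select_rows(toy, "model_a", "distance_noise_large").tolist()
```

Because the filter is applied per method, the two boxplots in a panel can be based on different numbers of samples.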

In [None]:
import sys
sys.path.append('..')
import torch
import torchvision.models as models
import numpy as np
import pandas as pd
from experiments import WrapperNet, perform_lrp_plain, evaluate_performance, evaluate_explanations
from internal_utils import get_data_imagenette, get_vgg16, get_vgg19
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRUNCATE = 25
print(f"WARNING: TRUNCATING THE DATASET TO {TRUNCATE} --- THIS WILL ALSO BE TRUNCATED IN THE FUNCTIONS INCLUDED IN THE experiments/run_evaluation.py FILE")
print("TO RUN OVER THE ENTIRE DATASET, UNCOMMENT THE RELEVANT LINES IN THE FUNCTIONS (SEARCH FOR THE STRING 'TRUNCATE')")


def plot_comparative_figure(df, method_0, method_1, data_type="Train"):
    """
    Plot a comparative figure of the results between the two models.
    """
    figs_per_row = ["distance_noise_small", "distance_noise_large", "distance_blur_small", "distance_blur_large"]

    # Create a single row figure with two boxplots per column
    fig, axs = plt.subplots(1, len(figs_per_row), figsize=(20, 5), sharey=True)

    for j, fig_type in enumerate(figs_per_row):
        # For small perturbations keep samples whose class did not change;
        # for large perturbations keep samples whose class did change.
        expected_change = "small" not in fig_type
        df_method_0 = df[df[f"{method_0}_{fig_type}_class_change"] == expected_change]
        df_method_1 = df[df[f"{method_1}_{fig_type}_class_change"] == expected_change]

        # Combine the data for boxplot using a hue for methods
        combined_df = pd.DataFrame({
            # ignore_index gives rows from both methods unique labels
            'Value': pd.concat([df_method_0[f"{method_0}_{fig_type}"], df_method_1[f"{method_1}_{fig_type}"]], ignore_index=True),
            'Method': [method_0] * len(df_method_0) + [method_1] * len(df_method_1)
        })

        # Create boxplot
        sns.boxplot(x='Method', y='Value', hue='Method', data=combined_df, ax=axs[j])
        axs[j].set_title(f"{fig_type}".replace("_", " "))
    fig.suptitle(f"Comparative Analysis of {method_0} and {method_1} on {data_type} Data", fontsize=16)
    plt.tight_layout()
    plt.show()


## Pre-training evaluation
Here, before any fine-tuning, we evaluate the performance of the models on the test and train data, and then evaluate the quality of their explanations over the same datasets.
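
As a rough picture of what the performance comparison is for, the sketch below computes top-1 accuracy for two models and picks the better one as the teacher. It assumes that accuracy is the deciding metric, which this notebook does not state explicitly. `top1_accuracy`, `pick_teacher`, the model names, and the predictions are hypothetical and do not reflect what `evaluate_performance` returns.

```python
def top1_accuracy(predictions, labels):
    """Fraction of samples whose predicted class equals the true label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty dataset")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


def pick_teacher(accuracy_by_model):
    """Return the name of the model with the highest accuracy."""
    return max(accuracy_by_model, key=accuracy_by_model.get)


# Made-up predictions for two hypothetical models on four samples.
labels = [0, 1, 2, 1]
accuracies = {
    "model_a": top1_accuracy([0, 1, 2, 0], labels),
    "model_b": top1_accuracy([0, 2, 2, 0], labels),
}
teacher = pick_teacher(accuracies)
```

The other model then plays the role of the learner in the fine-tuning step.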


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

train_data, test_data = get_data_imagenette()
vgg16_results = evaluate_performance(get_vgg16(), train_data, test_data)
vgg19_results = evaluate_performance(get_vgg19(), train_data, test_data)

# visualise results from the initial dataframes
vgg16_results['model'] = 'VGG16'
vgg19_results['model'] = 'VGG19'

# Concatenate dataframes
combined_df = pd.concat([vgg16_results, vgg19_results])

# Melt the dataframe to long format for seaborn
melted_df = combined_df.melt(id_vars=['model'], var_name='metric', value_name='value')

# Create the boxplot
plt.figure(figsize=(12, 8))
sns.boxplot(x='metric', y='value', hue='model', data=melted_df)
plt.title('Model Performance Comparison')
plt.xlabel('Metric')
plt.ylabel('Value')
plt.legend(title='Model')
plt.show()

# To evaluate the explanations, we pass in a list of "methods".
# Each method is a tuple of the form (name, method, model):
# name is a string that identifies the method, e.g. "VGG16"
# method is a function that generates a heatmap for a given model, e.g. perform_lrp_plain
# model is the model the method is applied to; it must be in heatmap form (here, wrapped in WrapperNet)
vgg16 = get_vgg16()
vgg19 = get_vgg19()
methods = [
        ("VGG16", perform_lrp_plain, WrapperNet(vgg16, hybrid_loss=True)),
        ("VGG19", perform_lrp_plain, WrapperNet(vgg19, hybrid_loss=True))
    ]
# Now evaluate the explanations
df_train, df_test = evaluate_explanations(train_data, test_data, methods, save_results=False)


In [None]:
plot_comparative_figure(df_test, "VGG16", "VGG19", "Test")
plot_comparative_figure(df_train, "VGG16", "VGG19", "Train")