Include source package

In [None]:
# Switch to the project directory: the relative paths used in the cells below
# ('src', 'models/...', 'reports/...') are resolved against the working directory.
# NOTE: `%cd ..` is relative to the current directory, so re-running this cell
# moves up one more level each time - run it exactly once per kernel session.
%cd ..
# working directory should be ../FSE

In [None]:
import sys
import os
# Resolve the local `src` directory against the current working directory and
# register it on the import path (presumably this is where the `pdi` package
# imported below lives).  The membership check keeps re-runs of this cell from
# appending duplicate entries.
module_path = os.path.abspath('src')

already_registered = module_path in sys.path
if not already_registered:
    sys.path.append(module_path)

Experiment definition

In [None]:
import os
import pandas as pd
import numpy as np

from pdi.data.preparation import FeatureSetPreparation, MeanImputation, DeletePreparation, RegressionImputation, EnsemblePreparation
from pdi.constants import PARTICLES_DICT, TARGET_CODES, NUM_WORKERS, P_RANGE, P_RESOLUTION
from pdi.models import NeuralNet, NeuralNetEnsemble, AttentionModel
from pdi.data.types import Split, Additional
from pdi.evaluate import calculate_precision_recall, get_predictions_data_and_loss

def _experiment(model_class, **preparations):
    """Build one experiment entry.

    model_class:  class instantiated with the saved model arguments when the
                  pretrained checkpoint is loaded in the evaluation cell.
    preparations: data-type name ("all" / "complete_only") -> zero-argument
                  callable returning the data preparation object whose
                  `prepare_dataloaders` is used to build the test loader.
    """
    return {"model_class": model_class, "data": preparations}


# Experiment name -> model class and the data preparations it is evaluated on.
# Insertion order is kept: it determines the order of `model_names` and of the
# loops over experiments further down.
EXPERIMENTS = {
    "Mean": _experiment(
        NeuralNet,
        all=MeanImputation,
        complete_only=DeletePreparation,
    ),
    "Regression": _experiment(
        NeuralNet,
        all=RegressionImputation,
        complete_only=DeletePreparation,
    ),
    "Ensemble": _experiment(
        NeuralNetEnsemble,
        all=EnsemblePreparation,
        complete_only=lambda: EnsemblePreparation(complete_only=True),
    ),
    "Proposed": _experiment(
        AttentionModel,
        all=FeatureSetPreparation,
        complete_only=lambda: FeatureSetPreparation(complete_only=True),
    ),
    # Only defined for complete examples; there is no "all" variant.
    "Delete": _experiment(
        NeuralNet,
        complete_only=DeletePreparation,
    ),
}

# Per particle code: (lower bound, upper bound of the "90%" range, upper bound
# of the "95%" range).  The bounds are compared against the stored "momentum"
# values when the metrics are computed; both ranges share the same lower bound.
# NOTE(review): presumably these are quantiles of the transverse momentum
# distribution per particle - confirm where the numbers come from.
_PT_BOUNDS = {
    211: (0.118081, 1.204883, 1.605047),
    2212: (0.257578, 2.152858, 2.662042),
    321: (0.175148, 1.976095, 2.520606),
    -211: (0.119654, 1.203556, 1.604230),
    -2212: (0.265045, 2.204669, 2.746735),
    -321: (0.185005, 1.977073, 2.538974),
}

# particle code -> [lower, upper] for the "95%" range
RANGES_95 = {
    code: [lower, upper_95]
    for code, (lower, upper_90, upper_95) in _PT_BOUNDS.items()
}

# particle code -> [lower, upper] for the "90%" range
RANGES_90 = {
    code: [lower, upper_90]
    for code, (lower, upper_90, upper_95) in _PT_BOUNDS.items()
}

# Axes of the result tables built further down.
metrics = ["precision", "recall", "f1"]
data_types = ["all", "complete_only"]
pt_percentage = ["90%", "95%"]  # labels paired with RANGES_90 / RANGES_95, in that order

# Experiment names, as a view over the EXPERIMENTS keys (definition order).
model_names = EXPERIMENTS.keys()
# Particle names looked up for each target code, in TARGET_CODES order.
particle_names = [PARTICLES_DICT[code] for code in TARGET_CODES]

Evaluate pretrained models on test datasets

In [None]:
import torch
import pickle
from os.path import isfile

# Evaluate every pretrained model on the test split and cache the raw
# predictions of each sample in reports/test_results{sample}.pkl.

# Prefer the GPU, but fall back to the CPU so this cell also runs on machines
# without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Suffixes of the model checkpoints / result files to evaluate.
SAMPLES = ["", "_0", "_1", "_2"]

batch_size = 512

for sample in SAMPLES:
    prediction_data = {}
    save_path = f"reports/test_results{sample}.pkl"
    if isfile(save_path):
        # Results for this sample are already cached on disk; skip the
        # (expensive) evaluation.
        continue

    print(f"Testing: {sample}")

    for target_code in TARGET_CODES:
        prediction_data[target_code] = {}
        particle_name = PARTICLES_DICT[target_code]
        for experiment_name, exp_dict in EXPERIMENTS.items():
            load_path = f"models/{experiment_name}/{particle_name}{sample}.pt"
            # map_location makes checkpoints saved on a GPU loadable on
            # whichever device was selected above.
            saved_model = torch.load(load_path, map_location=device)
            model = exp_dict["model_class"](*saved_model["model_args"]).to(device)
            model.thres = saved_model["model_thres"]
            model.load_state_dict(saved_model["state_dict"])
            # NOTE(review): the model is never switched to eval mode here;
            # confirm that get_predictions_data_and_loss does so itself.

            prediction_data[target_code][experiment_name] = {}
            for data_type, data_prep in exp_dict["data"].items():
                # Only the test split is requested, hence the single-element unpacking.
                test_loader, = data_prep().prepare_dataloaders(batch_size, NUM_WORKERS, [Split.TEST])

                predictions, targets, add_data, _ = get_predictions_data_and_loss(model, test_loader, device)

                # One-vs-rest labels for the particle this model was trained on.
                binary_targets = targets == target_code

                # Raw scores and the threshold are stored (not the thresholded
                # selection) so later cells can apply their own momentum cuts.
                prediction_data[target_code][experiment_name][data_type] = {
                    "targets": binary_targets,
                    "predictions": predictions,
                    "momentum": add_data[Additional.fPt.name],
                    "threshold": model.thres
                }

    # Make sure the output directory exists before writing the cache file.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump(prediction_data, f)

In [None]:
import pickle

def _load_test_results(sample):
    """Read the cached test results written by the evaluation cell for one sample suffix."""
    # The pickle files are produced by this notebook itself; do not point this
    # at files from an untrusted source.
    with open(f"reports/test_results{sample}.pkl", "rb") as results_file:
        return pickle.load(results_file)


SAMPLES = ["", "_0", "_1", "_2"]

# One entry per sample, in SAMPLES order; each entry is nested as
# target code -> experiment name -> data type -> stored arrays.
prediction_data = [_load_test_results(sample) for sample in SAMPLES]

Plot precision (purity) and recall (efficiency) comparison

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import pdi.visualise as vis
from importlib import reload
# Re-import the plotting module so that edits made to it since the first
# import are picked up without restarting the kernel.
reload(vis)

# Consecutive (low, high) pairs of P_RESOLUTION evenly spaced points over
# P_RANGE.  `intervals` is only needed by the per-interval purity/efficiency
# plots, which are currently disabled (see the commented-out calls below).
p_min, p_max = P_RANGE
p_range = np.linspace(p_min, p_max, P_RESOLUTION)
intervals = [(low, high) for low, high in zip(p_range, p_range[1:])]

# Experiment names that are shown under a different label in the figures.
display_names = {"Proposed": "FSE + attention"}

for target_code in TARGET_CODES:
    particle_name = PARTICLES_DICT[target_code]
    # The figures are drawn from the first sample only.
    experiments = prediction_data[0][target_code]

    for data_type in data_types:
        # Keep only the experiments that were evaluated on this data type.
        data = {
            display_names.get(exp_name, exp_name): exp_dict[data_type]
            for exp_name, exp_dict in experiments.items()
            if data_type in exp_dict
        }

        save_dir = f"reports/figures/comparison_{data_type}/{particle_name}"
        os.makedirs(save_dir, exist_ok=True)
        # Disabled per-interval plots:
        # plot_purity_comparison(particle_name, data, intervals, save_dir)
        # plot_efficiency_comparison(particle_name, data, intervals, save_dir)

        # One precision/recall comparison per momentum range (90% first, then 95%).
        for percent, ranges in ((90, RANGES_90), (95, RANGES_95)):
            vis.plot_precision_recall_comparison(particle_name, data, ranges[target_code], percent, save_dir)

In [None]:
# Result table: one row per (momentum range, particle, model) and one column
# per (test case, data type, metric).  Cells that are never filled (e.g. a
# model without an "all" variant) stay NaN.
row_index = pd.MultiIndex.from_product(
    [pt_percentage, particle_names, model_names],
    names=["pt_range", "particle", "model"],
)
column_index = pd.MultiIndex.from_product(
    [list(range(len(SAMPLES))), data_types, metrics],
    names=["test_case", "data", "metric"],
)
metric_results = pd.DataFrame(index=row_index, columns=column_index)

# Tiny constant that keeps the F1 denominator non-zero.
eps = np.finfo(float).eps

for i in range(len(SAMPLES)):
    for target_code in TARGET_CODES:
        particle_name = PARTICLES_DICT[target_code]
        # Same order as the labels in pt_percentage: 90% first, then 95%.
        pt_ranges = [RANGES_90[target_code], RANGES_95[target_code]]

        for pt_percent, pt_range in zip(pt_percentage, pt_ranges):
            for data_type in data_types:
                for method_name, exp_dict in prediction_data[i][target_code].items():
                    # Skip experiments that were not evaluated on this data type.
                    if data_type not in exp_dict:
                        continue
                    results = exp_dict[data_type]

                    # Restrict the evaluation to examples inside the momentum range.
                    momentum = results["momentum"]
                    mask = (momentum >= pt_range[0]) & (momentum <= pt_range[1])
                    targets = results["targets"][mask]
                    selected = results["predictions"][mask] > results["threshold"]

                    true_positives = int(np.sum(selected & targets))
                    selected_positives = int(np.sum(selected))
                    positives = int(np.sum(targets))

                    precision, recall, _, _ = calculate_precision_recall(true_positives, selected_positives, positives)
                    f1 = 2 * precision * recall / (precision + recall + eps)

                    # The three values land in the precision / recall / f1
                    # columns of this (test case, data type) pair.
                    metric_results.loc[(pt_percent, particle_name, method_name), (i, data_type)] = precision, recall, f1

os.makedirs("reports/tables", exist_ok=True)
metric_results.to_csv("reports/tables/comparison_metrics.csv")

Create LaTeX table from metrics file

In [None]:
# Reload the metrics table written by the previous cell.  The first three
# columns form the row MultiIndex and the first three lines the column
# MultiIndex.
df = pd.read_csv(
    "reports/tables/comparison_metrics.csv",
    index_col=[0, 1, 2],
    header=[0, 1, 2],
)

# Use the LaTeX spelling of F1 as the column label in the generated tables.
df = df.rename(columns={"f1": "$F_1$"})

def bold_max(x):
    """Styler callback marking the entries with the highest mean as bold.

    x: iterable of ``[mean, std]`` pairs (one per table row).

    Returns an array with one entry per pair: the ``textbf:--rwrap`` property
    where the mean equals the largest non-NaN mean, ``None`` elsewhere.  Ties
    are all marked.
    """
    pairs = np.array(list(x))
    means = pairs[:, 0]  # only the mean decides; the std is ignored
    is_best = means == np.nanmax(means)
    return np.where(is_best, "textbf:--rwrap", None)


# For every momentum range, data type and particle: aggregate each metric over
# the test cases (mean and standard deviation, in percent), mark the best mean
# per metric in bold and write the table as LaTeX.
for percent in pt_percentage:
    for dt in data_types:
        # percent[:-1] drops the trailing "%" so it is not part of the path.
        save_dir = f"reports/tables/{percent[:-1]}/comparison_{dt}"
        os.makedirs(save_dir, exist_ok=True)
        for particle in particle_names:
            # Rows: models; columns: (test_case, metric) for this selection.
            selection = df.xs((percent, particle)).xs(dt, axis='columns', level=1)
            # Group the transposed frame by metric instead of using
            # groupby(axis="columns"), which is deprecated in pandas >= 2.1.
            results = selection.T.groupby(level="metric")

            results_mean = results.mean().T * 100
            results_std = results.std(numeric_only=True).T * 100

            # Collapse mean and std into one [mean, std] list per (model, metric) cell.
            combined = pd.concat([results_mean, results_std], keys=['mean', 'std'], axis=1)
            combined = combined.reorder_levels([1, 0], axis=1).stack().groupby("model").agg(list)
            combined = combined[["precision", "recall", "$F_1$"]]

            style = combined.style

            # Highlight the best model separately for each metric column.
            for metric_column in combined.columns:
                style.apply(bold_max, subset=(slice(None), metric_column))

            # Raw string: "\p" and "\%" are LaTeX, not Python escape sequences.
            style.format(r'{0[0]:.2f} $\pm$ {0[1]:.2f}\%')

            style.to_latex(f"{save_dir}/{particle}_results.tex",
                           hrules=True,
                           clines="all;data")
