In [29]:
#imports to work with...
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
from torch.utils.data import DataLoader
import torch
import torchvision
from torchvision import transforms
from sklearn.metrics import precision_recall_curve, PrecisionRecallDisplay, average_precision_score
from sklearn.preprocessing import label_binarize

from cl_framework.continual_learning.metrics.metric_evaluator_incdec import MetricEvaluatorIncDec
from cl_framework.utilities.matrix_logger import IncDecLogger
from torchmetrics import Recall


In [30]:
# Run folders to analyse, each paired with the name of the output sub-folder
# its PR-curve plots are written to. Keeping each pair on one line stops the
# two lists below from drifting out of sync when experiments are added/removed.
_experiments = [
    ('../runs_trainings/no_freeze/multilabel/weighted', 'no_freeze_baseline'),
    ('../runs_trainings/no_freeze/joint_incremental_restored_multilabel/weighted/reset', 'no_freeze_joint_reset'),
    ('../runs_trainings/no_freeze/joint_incremental_multilabel/weighted', 'no_freeze_joint'),
    ('../runs_trainings/no_freeze/incremental_decremental_multilabel/weighted', 'no_freeze_incdec'),
    ('../runs_trainings/no_freeze/decremental_multilabel/weighted', 'no_freeze_dec'),
]
results_path = [run_dir for run_dir, _run_name in _experiments]
save_exp_name = [run_name for _run_dir, run_name in _experiments]
del _experiments


In [31]:
def extract_data(task_dataframe):
    """Split an error-analysis dataframe into per-class scores and targets.

    For each class in ('food', 'phone', 'smoking', 'fatigue', 'selfcare') the
    column ``<class>`` is read as the model output and ``target_<class>`` as
    the ground truth.

    Args:
        task_dataframe: dataframe containing both sets of columns above.

    Returns:
        (probabilities, targets): two lists holding one inner list per class,
        in the order above; each inner list has one value per dataframe row.
    """
    class_names = ['food', 'phone', 'smoking', 'fatigue', 'selfcare']

    probabilities, targets = [], []
    for class_name in class_names:
        # Target column is read before the score column, per class.
        targets.append(task_dataframe['target_' + class_name].tolist())
        probabilities.append(task_dataframe[class_name].tolist())

    return probabilities, targets


def get_precision_recall_for_prcurve(probabilities, labels, num_classes=None):
    """Compute one one-vs-rest precision-recall curve per class.

    Args:
        probabilities: 2-D array-like indexed as [sample, class] holding the
            predicted scores.
        labels: 2-D array-like with the same layout holding the targets.
        num_classes: number of leading class columns to evaluate. Defaults to
            every column of ``labels``.

    Returns:
        (precision, recall): dicts mapping class index to the precision and
        recall arrays returned by sklearn's ``precision_recall_curve``.
    """
    probabilities = np.asarray(probabilities)
    labels = np.asarray(labels)
    if num_classes is None:
        num_classes = labels.shape[1]

    precision = {}
    recall = {}
    for class_idx in range(num_classes):
        # The thresholds (third output) are not needed for plotting.
        precision[class_idx], recall[class_idx], _ = precision_recall_curve(
            labels[:, class_idx], probabilities[:, class_idx]
        )
    return precision, recall


def plot_pr_curve(precision, recall, num_tasks, idx_class, class_name, output_path,
                  tasks_to_plot=None):
    """Plot one class's precision-recall curves for selected tasks and save them.

    Args:
        precision, recall: sequences indexed by task id; each element maps a
            class index to that task's precision / recall arrays (the dicts
            produced by ``get_precision_recall_for_prcurve``).
        num_tasks: number of tasks available in ``precision`` / ``recall``.
        idx_class: index of the class to plot.
        class_name: class label, used in the legend, the title and the file name.
        output_path: directory the figure is written to, as ``<class_name>.png``.
        tasks_to_plot: iterable of task ids to draw. Defaults to the first and
            the last task (0 and ``num_tasks - 1``).

    Returns:
        None. The figure is saved to disk and then closed.
    """
    if tasks_to_plot is None:
        tasks_to_plot = {0, num_tasks - 1}
    tasks_to_plot = set(tasks_to_plot)

    figure, ax = plt.subplots(figsize=(8, 8))
    try:
        for task_id in range(num_tasks):
            if task_id not in tasks_to_plot:
                continue
            display = PrecisionRecallDisplay(
                recall=recall[task_id][idx_class],
                precision=precision[task_id][idx_class],
            )
            display.plot(ax=ax, name=f"Precision-recall for task_id {task_id} for class {class_name}")
        ax.set_title(f"Precision-recall curves for class {class_name}")
        figure.savefig(os.path.join(output_path, class_name + '.png'))
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls in a loop do not accumulate open figures.
        plt.close(figure)


In [32]:
# Pooled version: for every experiment, the test predictions of all seeds (and
# of every run directory inside each seed folder) are concatenated per task
# before computing the precision-recall curves.
seeds = [0, 1, 2]
output_path = '../statistics_to_save/pr_curves/'
num_tasks = 6  # error-analysis files task_0 ... task_5 are read per run
classes = ['food', 'phone', 'smoking', 'fatigue', 'selfcare']

for idx_exp in range(len(results_path)):
    output_name_path = os.path.join(output_path, save_exp_name[idx_exp])
    # makedirs also creates missing parent folders; os.mkdir would fail there.
    os.makedirs(output_name_path, exist_ok=True)

    precision = []
    recall = []
    for task_id in range(num_tasks):
        task_name = 'task_' + str(task_id) + '_test_error_analysis.csv'
        task_probabilities = []
        task_targets = []

        for idx_seed in seeds:
            seed_path = os.path.join(results_path[idx_exp], 'seed_' + str(idx_seed))
            for name_exp in sorted(os.listdir(seed_path)):
                exp_path = os.path.join(seed_path, name_exp)
                if not os.path.isdir(exp_path):
                    continue  # skip stray files (e.g. .DS_Store) in the seed folder
                task_ea_path = os.path.join(exp_path, 'error_analysis', task_name)
                task_dataframe = pd.read_csv(task_ea_path)
                tmp_probabilities, tmp_targets = extract_data(task_dataframe)
                # extract_data returns one list per class; transpose to
                # (samples, classes) so runs can be stacked along the sample axis.
                task_probabilities.append(torch.Tensor(tmp_probabilities).permute(1, 0))
                task_targets.append(torch.Tensor(tmp_targets).permute(1, 0))

        if not task_probabilities:
            raise FileNotFoundError(
                f'no run directories found under {results_path[idx_exp]} for seeds {seeds}')
        probabilities = torch.cat(task_probabilities, 0).numpy()
        targets = torch.cat(task_targets, 0).numpy()

        tmp_precision, tmp_recall = get_precision_recall_for_prcurve(
            probabilities, targets, len(classes))
        precision.append(tmp_precision)
        recall.append(tmp_recall)

    for idx_class, class_name in enumerate(classes):
        plot_pr_curve(precision, recall, num_tasks, idx_class, class_name, output_name_path)

In [33]:
# Single-seed version: curves are computed from one seed only, separately for
# every run directory found inside that seed's folder.
# NOTE: `seeds` is meant to hold a single seed; with several, plots of later
# seeds overwrite those of earlier ones.
seeds = [0]
output_path = '../statistics_to_save/pr_curves/one_seed/'
num_tasks = 6  # error-analysis files task_0 ... task_5 are read per run
classes = ['food', 'phone', 'smoking', 'fatigue', 'selfcare']

for idx_exp in range(len(results_path)):
    output_name_path = os.path.join(output_path, save_exp_name[idx_exp])
    # makedirs also creates missing parents (e.g. one_seed/); os.mkdir would fail.
    os.makedirs(output_name_path, exist_ok=True)

    for idx_seed in seeds:
        seed_path = os.path.join(results_path[idx_exp], 'seed_' + str(idx_seed))
        exp_names = sorted(
            name for name in os.listdir(seed_path)
            if os.path.isdir(os.path.join(seed_path, name))
        )
        for name_exp in exp_names:
            ea_path = os.path.join(seed_path, name_exp, 'error_analysis')

            # Reset per run. Accumulating across runs made every run after the
            # first re-plot the first run's curves, because plot_pr_curve only
            # reads task indices 0..num_tasks-1.
            precision = []
            recall = []
            for task_id in range(num_tasks):
                task_name = 'task_' + str(task_id) + '_test_error_analysis.csv'
                task_dataframe = pd.read_csv(os.path.join(ea_path, task_name))
                tmp_probabilities, tmp_targets = extract_data(task_dataframe)
                # extract_data returns one list per class; transpose to
                # (samples, classes) as get_precision_recall_for_prcurve indexes it.
                probabilities = torch.Tensor(tmp_probabilities).permute(1, 0).numpy()
                targets = torch.Tensor(tmp_targets).permute(1, 0).numpy()

                tmp_precision, tmp_recall = get_precision_recall_for_prcurve(
                    probabilities, targets, len(classes))
                precision.append(tmp_precision)
                recall.append(tmp_recall)

            # With several runs per seed, give each its own folder so their plots
            # do not overwrite one another; a single run keeps the flat layout.
            run_output_path = output_name_path
            if len(exp_names) > 1:
                run_output_path = os.path.join(output_name_path, name_exp)
                os.makedirs(run_output_path, exist_ok=True)

            for idx_class, class_name in enumerate(classes):
                plot_pr_curve(precision, recall, num_tasks, idx_class, class_name, run_output_path)