In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import argparse
import glob
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import json
from asreview import open_state
from asreviewcontrib.insights.plot import plot_recall
from asreviewcontrib.insights.plot import _recall_values
from asreviewcontrib.insights.utils import pad_simulation_labels
import asreviewcontrib.insights.metrics as met
from langdetect import detect
import pylab
import shutil
import os
import string

# Copy state files from the simulation output folders into the state_files folder

In [None]:
# Destination folder that collects the state files of all simulations.
path = r'../simulations/state_files'
# exist_ok=True makes the cell idempotent and avoids the check-then-create race.
os.makedirs(path, exist_ok=True)

In [None]:
# Path template of a state file inside a simulation's output folder.
# Placeholders, in order: prior, dataset, dataset, classifier, model.
state_source = (
    "../simulations/prior-knowledge-{}/output/simulation/Hamilton_{}/"
    "state_files/sim_Hamilton_{}_{}_{}_0.asreview"
)
# Path template of the collected copy in the shared state_files folder.
# Placeholders, in order: prior, dataset, classifier, model.
state_target = "../simulations/state_files/pk{}_sim_Hamilton_{}_{}_{}_0.asreview"

# Simulation grid: the cells below iterate over every combination of these values.
prior_knowledge = ['1', '2']
classifiers = ['logistic', 'svm']
datasets = ['Original', 'English', 'Multi_1', 'Multi_2']
models = [
    'tfidf', 'sbert', 'mbert', 'muse', 'mlongt5',
    'labse', 'laser', 'mpnet', 'minilm', 'stsb',
]

In [None]:
# Copy each simulation's state file into the shared state_files folder.
# Targets that already exist are left untouched, so re-running the cell is cheap.
# Sources that are missing (and have no copy yet) are reported instead of being
# skipped silently, because the extraction cell below opens every target file.
import itertools

missing_sources = []
for prior, classifier, dataset, model in itertools.product(
        prior_knowledge, classifiers, datasets, models):
    source = state_source.format(prior, dataset, dataset, classifier, model)
    target = state_target.format(prior, dataset, classifier, model)
    if os.path.exists(target):
        continue
    if os.path.exists(source):
        shutil.copy(source, target)
    else:
        missing_sources.append(source)

if missing_sources:
    print(f"{len(missing_sources)} simulation state file(s) not found:")
    print("\n".join(missing_sources))

# Extract the metrics of every state file into a DataFrame

In [None]:
# Build one row of metrics per simulation (prior x classifier x dataset x model).
import itertools

# Record ids of the target papers whose time to discovery (TD) is tracked.
TARGET_RECORDS = [300, 567, 741, 878, 1112]


def flatten_metrics(metrics: dict) -> dict:
    """Flatten the output of ``met.get_metrics`` into a ``{column: value}`` dict.

    - 'Time to discovery' entries are kept only for TARGET_RECORDS, stored as
      ``td_<record_id>``.
    - Scalar metrics are stored under their title.
    - Other list-valued metrics (e.g. recall or WSS at several cut-offs) become
      ``<title>_<cut-off>`` columns.
    """
    flat = {}
    for item in metrics['data']['items']:
        title, value = item.get('title'), item.get('value')
        if title == 'Time to discovery':
            for paper_td in value:
                if paper_td[0] in TARGET_RECORDS:
                    flat[f'td_{paper_td[0]}'] = paper_td[1]
        elif not isinstance(value, list):
            flat[title] = value
        else:
            for entry in value:
                flat[f'{title}_{entry[0]}'] = entry[1]
    return flat


rows = []
for prior, classifier, dataset, model in itertools.product(
        prior_knowledge, classifiers, datasets, models):
    row = {'Dataset': dataset, 'Classifier': classifier, 'Model': model, 'Prior': prior}
    with open_state(state_target.format(prior, dataset, classifier, model)) as state:
        row.update(flatten_metrics(met.get_metrics(state)))
    rows.append(row)

# Build the frame once from a list of dicts: DataFrame.append was removed in
# pandas 2.0 (and growing a frame row by row is quadratic).
metrics_df = pd.DataFrame(rows).rename(
    columns={"Average time to discovery": "ATD", "Work Saved over Sampling_0.95": "WSS@95"})
# Cumulative TD over all target records; NaN if any target's TD is missing.
metrics_df['ctd_targets'] = metrics_df[[f'td_{r}' for r in TARGET_RECORDS]].sum(axis=1, skipna=False)
metrics_df

In [None]:
#metrics_df.to_csv("metrics.csv")

# Tables

In [None]:
# Overview of all simulations: identifying columns, headline metrics and the
# time to discovery (td_*) of each target record.
sim = metrics_df
overview_columns = ['Model', 'Dataset', 'Classifier', 'Prior',
                    'Recall_0.1', 'WSS@95', 'ATD',
                    'td_1112', 'td_300', 'td_567', 'td_741', 'td_878']
sim.loc[:, overview_columns]

In [None]:
# A single slice of the simulation grid; change the filter values to inspect another.
# 'ctd_targets' is the summed TD of the target records built in the extraction cell.
sim = metrics_df.loc[
    (metrics_df.Classifier == 'logistic')
    & (metrics_df.Prior == '1')
    & (metrics_df.Dataset == 'Multi_2')
]
sim[['Model', 'Dataset', 'Classifier', 'Prior',
     'Recall_0.1', 'WSS@95',
     'ATD', 'td_300', 'td_567', 'td_741', 'td_878', 'td_1112', 'ctd_targets']]

In [None]:
# Mean of each metric per dataset.
# Several GroupBy columns must be selected with a list (tuple-style selection is
# an error in pandas >= 2.0). 'ctd_targets' is the summed TD of the target records.
sim = metrics_df
sim.groupby(['Dataset'])[['Recall_0.1', 'WSS@95', 'ATD', 'td_300', 'td_567', 'td_741', 'td_878', 'td_1112', 'ctd_targets']].mean()

In [None]:
# Mean of each metric per dataset and classifier.
# Several GroupBy columns must be selected with a list (tuple-style selection is
# an error in pandas >= 2.0). 'ctd_targets' is the summed TD of the target records.
sim = metrics_df
sim.groupby(['Dataset', 'Classifier'])[['Recall_0.1', 'WSS@95', 'ATD', 'td_300', 'td_567', 'td_741', 'td_878', 'td_1112', 'ctd_targets']].mean()

In [None]:
# Mean of each metric per dataset and amount of prior knowledge.
# Several GroupBy columns must be selected with a list (tuple-style selection is
# an error in pandas >= 2.0). 'ctd_targets' is the summed TD of the target records.
sim = metrics_df
sim.groupby(['Dataset', 'Prior'])[['Recall_0.1', 'WSS@95', 'ATD', 'td_300', 'td_567', 'td_741', 'td_878', 'td_1112', 'ctd_targets']].mean()

In [None]:
# Mean of each metric per dataset and feature-extraction model.
# Several GroupBy columns must be selected with a list (tuple-style selection is
# an error in pandas >= 2.0). 'ctd_targets' is the summed TD of the target records.
sim = metrics_df
sim.groupby(['Dataset', 'Model'])[['Recall_0.1', 'WSS@95', 'ATD', 'td_300', 'td_567', 'td_741', 'td_878', 'td_1112', 'ctd_targets']].mean()

## ATD Plots

In [None]:
# ATDs of the target papers per dataset. Each bar stacks the mean time to
# discovery (TD) of every target record, so its height is the cumulative
# average time to discovery (CATD).

# Display names for the dataset identifiers used in the simulations.
dataset_mapping = {
    'Original': 'Multiple',
    'Multi_1': 'Non-English 1',
    'Multi_2': 'Non-English 2'
}

# Factor to group by (e.g. 'Dataset', 'Prior', 'Classifier').
group_by_factor = 'Dataset'
td_columns = ['td_1112', 'td_300', 'td_567', 'td_741', 'td_878']

# Mean TD per target record and group. Several GroupBy columns must be selected
# with a list: tuple-style selection is an error in pandas >= 2.0.
grouped_df = metrics_df.groupby(group_by_factor)[td_columns].mean()

# Order the bars from lowest to highest average TD.
grouped_df = grouped_df.reindex(grouped_df.mean(axis=1).sort_values().index)

fig, ax = plt.subplots()
# Rename the index (rather than patching tick labels afterwards) so the bars
# carry the display names; unmapped values keep their original name.
grouped_df.rename(index=dataset_mapping).plot.bar(stacked=True, ax=ax)

ax.set_xlabel(group_by_factor)
ax.set_ylabel('Average Time to Discovery')
ax.set_title(f"Cumulative Average Time to Discovery (CATD) of Target Papers per {group_by_factor}")

# Legend labels derived from the plotted columns so they cannot fall out of order.
legend_labels = [f"Record {col.split('_')[1]}" for col in td_columns]
ax.legend(labels=legend_labels, bbox_to_anchor=(1, 1))

# Annotate each bar with its total height (the CATD), rounded to an integer.
for x, total in enumerate(grouped_df.sum(axis=1)):
    ax.annotate(f'{int(round(total))}', (x, total), xytext=(0, 3), textcoords='offset points',
                ha='center', va='bottom', fontsize=8, color='black')

# Leave 10% headroom above the tallest bar for the annotations.
ax.set_ylim(top=ax.get_ylim()[1] * 1.1)

plt.show()

In [None]:
# ATDs of the target papers per dataset, split by a second factor: one stacked
# bar chart per factor in `groups` (classifier, then prior knowledge).

# Display names for the dataset identifiers used in the simulations.
dataset_mapping = {
    'Original': 'Multiple',
    'Multi_1': 'Non-English 1',
    'Multi_2': 'Non-English 2'
}

# Outer grouping factor shared by every chart in this cell.
group_by_factor = 'Dataset'
groups = ['Classifier', 'Prior']
td_columns = ['td_1112', 'td_300', 'td_567', 'td_741', 'td_878']
legend_labels = [f"Record {col.split('_')[1]}" for col in td_columns]

for group in groups:
    # Mean TD per target record for every (dataset, group value) pair. Several
    # GroupBy columns must be selected with a list (tuples fail in pandas >= 2.0).
    grouped_df = metrics_df.groupby([group_by_factor, group])[td_columns].mean()

    fig, ax = plt.subplots()
    grouped_df.plot.bar(stacked=True, ax=ax)

    ax.set_xlabel(group_by_factor + ' + ' + group)
    ax.set_ylabel('Average Time to Discovery')
    ax.legend(labels=legend_labels, bbox_to_anchor=(1, 1))

    # The index is a (dataset, group value) MultiIndex; label each bar as
    # "<dataset display name>, <group value>".
    xticklabels = [
        f"{dataset_mapping.get(dataset, dataset)}, {value}"
        for dataset, value in grouped_df.index
    ]
    ax.set_xticks(range(len(grouped_df)))
    ax.set_xticklabels(xticklabels, rotation=45, ha='right', fontsize=8)

    ax.set_title("Cumulative Average Time to Discovery (CATD) of Target Papers \n per Dataset and " + group)

    # Annotate each bar with its total height (the CATD), rounded to an integer.
    for x, total in enumerate(grouped_df.sum(axis=1)):
        ax.annotate(f'{int(round(total))}', (x, total), xytext=(0, 3), textcoords='offset points',
                    ha='center', va='bottom', fontsize=8, color='black')

    # Leave 10% headroom above the tallest bar for the annotations.
    ax.set_ylim(top=ax.get_ylim()[1] * 1.1)

    plt.show()

In [None]:
# ATDs of the target papers per model, one panel per dataset (2 x 2 grid).
# Uses `dataset_mapping` defined in the ATD plot cells above.
td_columns = ['td_1112', 'td_300', 'td_567', 'td_741', 'td_878']
legend_labels = [f"Record {col.split('_')[1]}" for col in td_columns]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for i, dataset in enumerate(datasets):
    # Panels are filled row by row.
    ax = axes.flat[i]

    # Mean TD per target record for each model on this dataset, ordered from
    # lowest to highest average TD. Several GroupBy columns need a list selector.
    grouped_df = metrics_df.loc[metrics_df.Dataset == dataset].groupby('Model')[td_columns].mean()
    grouped_df = grouped_df.reindex(grouped_df.mean(axis=1).sort_values().index)

    grouped_df.plot.bar(stacked=True, ax=ax)
    ax.set_xlabel('Model')
    ax.set_ylabel('Average Time to Discovery')

    # Panels are lettered A, B, C, D in dataset order.
    dataset_title = dataset_mapping.get(dataset, dataset)
    ax.set_title(f"({string.ascii_uppercase[i]}) Dataset: {dataset_title}")

    # Annotate each bar with its total height (the CATD), rounded to an integer.
    for x, total in enumerate(grouped_df.sum(axis=1)):
        ax.annotate(f'{int(round(total))}', (x, total), xytext=(0, 3), textcoords='offset points',
                    ha='center', va='bottom', fontsize=8, color='black')

    # Leave 10% headroom above the tallest bar for the annotations.
    ax.set_ylim(top=ax.get_ylim()[1] * 1.1)

fig.suptitle("Cumulative Average Time to Discovery (CATD) of Target Papers per Model and Dataset", fontsize=16)

# Every panel stacks the same columns in the same order, so the first panel's
# bar handles serve as one shared legend. Pass them explicitly instead of letting
# Figure.legend auto-detect handles across all axes.
handles, _ = axes.flat[0].get_legend_handles_labels()
for ax in axes.flat:
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

fig.legend(handles, legend_labels, loc='lower center', title='Records', ncol=5, bbox_to_anchor=(0.5, -0.15))

# Leave room below the panels for the shared legend.
fig.tight_layout(rect=[0, -0.1, 1, 0.99])

plt.show()

## Recall Plots

In [None]:
# Detect the language of each paper abstract (this can take up to a minute).
from langdetect import DetectorFactory
from langdetect.lang_detect_exception import LangDetectException

# langdetect is non-deterministic unless its seed is fixed before the first call.
DetectorFactory.seed = 0


def detect_language(text):
    """Return the langdetect language code of ``text``, or NaN if none can be detected.

    ``detect`` raises LangDetectException for text without usable features
    (e.g. an empty string or only digits/punctuation). Such abstracts are treated
    like missing ones instead of aborting the whole cell.
    """
    try:
        return detect(text)
    except LangDetectException:
        return np.nan


csv_file = "../simulations/prior-knowledge-1/data/Hamilton_{}.csv"
language_dfs = {}
for dataset in datasets:
    languages = pd.read_csv(csv_file.format(dataset))
    # Missing abstracts are skipped here and become NaN through index alignment.
    languages['language'] = languages['abstract'].dropna().apply(detect_language)
    # The first CSV column is used as the record id. Rename it explicitly rather
    # than writing into `columns.values`, which bypasses pandas' Index handling.
    languages = languages.rename(columns={languages.columns[0]: 'record_id'})
    language_dfs[dataset] = languages

In [None]:
# Recall curves of every simulation, with the target records marked and coloured
# by the detected language of their abstract. One figure is drawn per
# (prior knowledge, classifier, dataset) combination, with one panel per model.
TARGET_RECORDS = [300, 567, 741, 878, 1112]
# Figure-title names for the datasets; any other dataset is titled 'English'.
dataset_titles = {'Original': 'Multiple', 'Multi_1': 'Non-English 1', 'Multi_2': 'Non-English 2'}

plot_count = 0
for prior in prior_knowledge:
    for classifier in classifiers:
        for dataset in datasets:
            fig = plt.figure(figsize=(20, 30))
            fig.subplots_adjust(hspace=0.25)
            tit = dataset_titles.get(dataset, 'English')
            fig.suptitle("Target records in recall plot - " + tit + ' - Prior Knowledge ' + prior,
                         fontsize=18, y=0.90)
            languages = language_dfs[dataset]

            for n, model in enumerate(models):
                ax = fig.add_subplot(5, 2, n + 1)
                with open_state(state_target.format(prior, dataset, classifier, model)) as state:
                    # Row positions of states_df are used below to index into the
                    # recall curve (x, y). A left join keeps every record of the state
                    # file in its original order, so a record missing from the
                    # language table cannot shift those positions.
                    states_df = state.get_dataset().merge(
                        languages[['record_id', 'language']], on='record_id', how='left')

                    sim_labels = pad_simulation_labels(state)
                    x, y = _recall_values(sim_labels, x_absolute=False, y_absolute=False)

                    targets = states_df.loc[states_df.record_id.isin(TARGET_RECORDS)]
                    # Targets without a detected language are plotted as 'unknown'
                    # instead of silently disappearing.
                    target_languages = targets.language.fillna('unknown')
                    for language in sorted(target_languages.unique()):
                        subset = targets.loc[target_languages == language]
                        ax.scatter([x[i] for i in subset.index], [y[i] for i in subset.index],
                                   marker="o", s=50, alpha=0.8, label=language)
                        for index, row in subset.iterrows():
                            ax.text(x[index] - 0.05, y[index] + 0.03, row.record_id, fontsize=9)
                    # Build the legend from the scatter labels so each language is
                    # paired with the colour actually used for it.
                    if not targets.empty:
                        ax.legend()

                    # draw the plot
                    plot_recall(ax, state)

                    ax.set_title(classifier + ' - ' + model + ' (' + string.ascii_uppercase[n] + ')')

            plot_count += 1
            # Uncomment to save each figure; keep it before plt.show().
            #plt.savefig('Recall plot ' + str(plot_count) +'.png')
            plt.show()