# Imports

In [237]:
import os
from pathlib import Path

import matplotlib.patches as patches
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import ticker
from matplotlib.ticker import MaxNLocator
from matplotlib_venn import venn3
from scipy.spatial import ConvexHull

# Load Data

In [None]:
# Locate the five Excel workbooks inside the local "data" folder.
# Each pattern is globbed and the first hit is taken, so a missing workbook
# surfaces immediately as an IndexError.
data_dir = Path(os.path.join("data"))
(
    input_file_foras,
    input_file_synergy,
    fulltext_foras,
    fulltext_foras_2nd,
    fulltext_synergy,
) = (
    list(data_dir.glob(pattern))[0]
    for pattern in (
        "PTSS_Data_Foras.xlsx",
        "PTSS_Data_Synergy.xlsx",
        "PTSS_Data_Foras_Fulltext.xlsx",
        "PTSS_Data_Foras_Fulltext_2ndscreener.xlsx",
        "PTSS_Data_Synergy_Fulltext.xlsx",
    )
)

# Report which files the results are based on
print(
    "Results based on file: ",
    input_file_foras,
    input_file_synergy,
    fulltext_foras,
    fulltext_foras_2nd,
    fulltext_synergy,
)

# FORAS: keep the raw table and a view without the rows flagged as duplicates
foras_unfiltered = pd.read_excel(input_file_foras)
not_flagged_as_duplicate = foras_unfiltered["filter_duplicate"] != 1
foras_filtered = foras_unfiltered[not_flagged_as_duplicate]

# Remaining workbooks; from here on the fulltext_* names hold DataFrames
# instead of the paths they held above
synergy = pd.read_excel(input_file_synergy)
fulltext_foras = pd.read_excel(fulltext_foras)
fulltext_foras_2nd = pd.read_excel(fulltext_foras_2nd)
fulltext_synergy = pd.read_excel(fulltext_synergy)

# Row counts per table
for caption, frame in (
    ("Number of records in original FORAS file: ", foras_unfiltered),
    ("Number of records in FORAS after filtering duplicates: ", foras_filtered),
    ("Number of records in SYNERGY", synergy),
    ("Number of records in FORAS fulltext", fulltext_foras),
    ("Number of records in FORAS fulltext 2nd screener", fulltext_foras_2nd),
    ("Number of records in SYNERGY fulltext", fulltext_synergy),
):
    print(caption, frame.shape[0])

# Background variables

## Duplicates

In [None]:
# Count the duplicates in the unfiltered FORAS file using the `filter_duplicate` flag.
# A record is a duplicate when the flag equals 1 -- the same criterion used to build
# `foras_filtered` (`filter_duplicate != 1`) -- so the two printed numbers always add
# up to the size of the original file. (Counting non-null flags instead would report
# every record as a duplicate if non-duplicates are coded 0 rather than left empty.)
duplicates = int((foras_unfiltered["filter_duplicate"] == 1).sum())
print(
    "Number of duplicates in FORAS file, which will be ignored in the remainder of the analyses: ",
    duplicates,
)

# Number of records in the FORAS file after filtering out the duplicates.
# Re-applying the filter is a no-op on an already de-duplicated frame.
foras_filtered = foras_filtered[foras_filtered["filter_duplicate"] != 1]
print(
    "Number of records in FORAS after filtering duplicates: ", foras_filtered.shape[0]
)

## Number of PIDs, Titles and Abstracts

In [None]:
# Function to calculate missing data and plot for a given dataset
def analyze_missing_data(dataset, dataset_name):
    """Print, plot and return the number of records with missing key fields.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the columns "doi", "openalex_id", "title" and "abstract".
    dataset_name : str
        Name used in the printed report and in the chart title.

    Returns
    -------
    dict
        Maps each checked column name, plus the key "Both Missing" (records
        lacking both "doi" and "openalex_id"), to the number of affected records.
    """
    # Define columns to check for missing values
    columns_to_check = ["doi", "openalex_id", "title", "abstract"]
    is_missing = dataset[columns_to_check].isnull()
    missing_data_counts = {col: is_missing[col].sum() for col in columns_to_check}
    # Combine the two masks row by row *before* summing: the result is the number
    # of records that lack both identifiers. (AND-ing the two column totals would
    # be a bitwise operation on two integers, not a record count.)
    missing_data_counts["Both Missing"] = (
        is_missing["doi"] & is_missing["openalex_id"]
    ).sum()

    # Print the number of records without certain values
    print(f"Missing Data Analysis for {dataset_name}:")
    for category, count in missing_data_counts.items():
        print(f"  - Number of records without {category}: {count}")

    # Calculate percentages for visualization (an empty dataset yields 0% bars
    # instead of a division by zero)
    total_records = dataset.shape[0]
    percentages = {
        key: (value / total_records) * 100 if total_records else 0.0
        for key, value in missing_data_counts.items()
    }

    # Plotting
    plt.figure(figsize=(10, 6))
    categories = list(percentages.keys())
    values = list(percentages.values())
    bars = plt.bar(categories, values, color=plt.cm.tab10(range(len(categories))))

    # Add title and labels
    plt.title(f"Percentage of Missing Data in Each Category ({dataset_name})")
    plt.xlabel("Missing Data Category")
    plt.ylabel("Percentage of Total Records")

    # Annotate each bar with "<count> (<percentage>%)" just above its top edge
    for bar, category in zip(bars, categories):
        height = bar.get_height()
        count = missing_data_counts[category]
        plt.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{count} ({height:.1f}%)",
            ha="center",
            va="bottom",
        )

    # Display the chart
    plt.show()

    # Return results for further programmatic use
    return missing_data_counts


# Call the function for both datasets: the de-duplicated FORAS records and SYNERGY.
# Each call prints a report, shows a bar chart and returns the per-category counts.
foras_results = analyze_missing_data(foras_filtered, "Foras")
synergy_results = analyze_missing_data(synergy, "Synergy")

# Frequencies

## Frequencies for search results

In [None]:
def generate_frequency_overview(dataset, variable_list):
    """Print a count/percentage table for every requested column.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Table whose columns are summarized.
    variable_list : iterable of str
        Column names to report. Names that are not columns of ``dataset`` are
        reported as missing and skipped.

    Missing values are counted as a category of their own (``dropna=False``);
    percentages are rounded to two decimals. Nothing is returned.
    """
    divider = "\n" + "-" * 50 + "\n"

    print("Frequency Overview:\n")
    for variable in variable_list:
        if variable in dataset.columns:
            print(f"Variable: {variable}")
            counts = dataset[variable].value_counts(dropna=False)
            shares = (counts / counts.sum() * 100).round(2)
            print(pd.DataFrame({"Count": counts, "Percentage": shares}))
            print(divider)
        else:
            print(f"Variable '{variable}' not found in the dataset. Skipping...\n")


# New list of variables to analyze: the "search_*" columns (presumably one flag per
# search strategy -- confirm against the data dictionary) plus the "batch" column.
# NOTE: "inlusion" (sic) is spelled as in these column names as used throughout this
# notebook; do not "correct" it here without renaming the column in the data as well.
variable_list_additional = [
    "search_replication",
    "search_comprehensive",
    "search_snowballing",
    "search_fulltext",
    "search_openalex_inlusion_criteria",
    "search_openalex_inlusion_criteria_long",
    "search_openalex_logistic",
    "search_openalex_logistic_long",
    "search_openalex_all_abstracts",
    "search_openalex_all_abstracts_long",
    "batch",
]

# Print one frequency table per listed column of the de-duplicated FORAS data;
# columns that do not exist in the dataset are reported and skipped.
generate_frequency_overview(foras_filtered, variable_list_additional)

## Frequencies for labeling decisions

In [None]:
# List of variables to analyze: the labeling columns of the FORAS file.
# Judging by their names they cover title (TI), title-abstract (TI-AB) and
# full-text (FT) decisions per screener, per inclusion criterion (IC1-IC4), for
# the LLM, and the joint/final labels -- confirm the exact meaning in the codebook.
variable_list = [
    "title_eligible_Bruno",
    "TI-AB_IC1_Bruno",
    "TI-AB_IC2_Bruno",
    "TI-AB_IC3_Bruno",
    "TI-AB_IC4_Bruno",
    "TI-AB_final_label_Bruno",
    "title_eligible_Rutger",
    "TI_final_label",
    "TI-AB_final_label_Rutger",
    "TI-AB_disagreement_human-human",
    "TI-AB_IC1_LLM",
    "TI-AB_IC2_LLM",
    "TI-AB_IC3_LLM",
    "TI-AB_IC4_LLM",
    "TI-AB_final_label_LLM",
    "LLM_re-assessed",
    "TI-AB_disagreement_human-LLM",
    "TI-AB_IC1_joint",
    "TI-AB_IC2_joint",
    "TI-AB_IC3_joint",
    "TI-AB_IC4_joint",
    "TI-AB_IC1_final",
    "TI-AB_IC2_final",
    "TI-AB_IC3_final",
    "TI-AB_IC4_final",
    "TI-AB_final_label",
    "full_text_available",
    "FT_IC1_Bruno",
    "FT_IC2_Bruno",
    "FT_IC3_Bruno",
    "FT_IC4_Bruno",
    "FT_inclusion_Bruno",
    "FT_inclusion_Rutger",
    "FT_disagreements_Bruno-Rutger",
    "FT_IC1_joint",
    "FT_IC2_joint",
    "FT_IC3_joint",
    "FT_IC4_joint",
    "FT_IC1_final",
    "FT_IC2_final",
    "FT_IC3_final",
    "FT_IC4_final",
    "FT_final_label",
    "label_included_TIAB",
    "label_included_FT",
]

# Call the function to generate the frequency overview: one count/percentage table
# per listed column of the de-duplicated FORAS data (missing values are counted too).
generate_frequency_overview(foras_filtered, variable_list)

## Labeling combinations

In [243]:
def analyze_inclusions(dataset, final_label, label_1, label_2):
    """Print three cross-tabulations relating a final label to two other labels.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Table containing the three columns.
    final_label, label_1, label_2 : str
        Column names. The tables are printed in this order:
        ``label_1`` vs ``label_2``, ``final_label`` vs ``label_1`` and
        ``final_label`` vs ``label_2``.

    Every table includes the "All" margins. Nothing is returned.
    """
    print(f"Inclusion Analysis for {final_label}:\n")

    comparisons = (
        (label_1, label_2),
        (final_label, label_1),
        (final_label, label_2),
    )
    for row_column, col_column in comparisons:
        print(f"Cross-tabulation of {row_column} vs {col_column}:\n")
        table = pd.crosstab(
            dataset[row_column], dataset[col_column], margins=True, dropna=True
        )
        print(table, "\n")

In [None]:
# Run the inclusion analysis once per screening stage, in this order:
# title inclusions, abstract inclusions, full text inclusions.
for final_column, rutger_column, bruno_column in (
    ("TI_final_label", "title_eligible_Rutger", "title_eligible_Bruno"),
    ("TI-AB_final_label", "TI-AB_final_label_Rutger", "TI-AB_final_label_Bruno"),
    ("FT_final_label", "FT_inclusion_Rutger", "FT_inclusion_Bruno"),
):
    analyze_inclusions(foras_filtered, final_column, rutger_column, bruno_column)

In [None]:
# Run the inclusion analysis for the ti-ab inclusion criteria 1 to 4 (in that order)
for final_column, bruno_column, joint_column in (
    ("TI-AB_IC1_final", "TI-AB_IC1_Bruno", "TI-AB_IC1_joint"),
    ("TI-AB_IC2_final", "TI-AB_IC2_Bruno", "TI-AB_IC2_joint"),
    ("TI-AB_IC3_final", "TI-AB_IC3_Bruno", "TI-AB_IC3_joint"),
    ("TI-AB_IC4_final", "TI-AB_IC4_Bruno", "TI-AB_IC4_joint"),
):
    analyze_inclusions(foras_filtered, final_column, bruno_column, joint_column)

In [None]:
# Run the inclusion analysis for the FT inclusion criteria 1 to 4 (in that order)
for final_column, bruno_column, joint_column in (
    ("FT_IC1_final", "FT_IC1_Bruno", "FT_IC1_joint"),
    ("FT_IC2_final", "FT_IC2_Bruno", "FT_IC2_joint"),
    ("FT_IC3_final", "FT_IC3_Bruno", "FT_IC3_joint"),
    ("FT_IC4_final", "FT_IC4_Bruno", "FT_IC4_joint"),
):
    analyze_inclusions(foras_filtered, final_column, bruno_column, joint_column)

# Plots

## Search results

### Frequencies

In [None]:
# Search columns (one per search method; summed below to get the number of hits)
search_columns = [
    "search_replication",
    "search_comprehensive",
    "search_snowballing",
    "search_fulltext",
    "search_openalex_inlusion_criteria",
    "search_openalex_inlusion_criteria_long",
    "search_openalex_logistic",
    "search_openalex_logistic_long",
    "search_openalex_all_abstracts",
    "search_openalex_all_abstracts_long",
]

# Shorter names for plotting
short_names = {
    "search_replication": "Replication",
    "search_comprehensive": "Comprehensive",
    "search_snowballing": "Snowballing",
    "search_fulltext": "Fulltext",
    "search_openalex_inlusion_criteria": "OpenAlex-Inclusion",
    "search_openalex_inlusion_criteria_long": "OpenAlex-Inclusion-Long",
    "search_openalex_logistic": "OpenAlex-Logistic",
    "search_openalex_logistic_long": "OpenAlex-Logistic-Long",
    "search_openalex_all_abstracts": "OpenAlex-All-Abstracts",
    "search_openalex_all_abstracts_long": "OpenAlex-All-Abstracts-Long",
}

# Define a color scheme
bar_color = "#009739"

# Calculate counts for each binary column in Foras
counts_foras = [foras_filtered[column].sum() for column in search_columns]

# Look the labels up per column so that every bar is guaranteed to carry the label
# of the column its count came from (independent of the dict's insertion order)
bar_labels = [short_names[column] for column in search_columns]

# Set up the figure and axis
plt.figure(figsize=(12, 6))
bars_foras = plt.bar(bar_labels, counts_foras, color=bar_color)

# Add title and labels
plt.title("Number of Records Found via the Different Search Methods", fontsize=14)
plt.ylabel("Count", fontsize=12)
plt.xlabel("Search Method", fontsize=12)
plt.xticks(rotation=45, ha="right")

# Add gridlines for better readability
plt.grid(axis="y", linestyle="--", alpha=0.7)

# Annotate bars with their values, placed 5% of the tallest bar above each bar
for bar in bars_foras:
    yval = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        yval + 0.05 * max(counts_foras),
        f"{int(yval)}",
        ha="center",
        va="bottom",
        fontsize=10,
    )

# Fix the layout *before* saving: savefig writes the figure as it is laid out at
# that moment, so calling tight_layout afterwards would only affect the on-screen
# plot and leave the rotated tick labels cramped or clipped in the saved files
plt.tight_layout()

# save plot as pdf and PNG in /output folder
plt.savefig("output/number_of_records_found_via_the_different_search_methods.pdf")
plt.savefig("output/number_of_records_found_via_the_different_search_methods.png")

# Display the plot
plt.show()

### Old versus New

In [None]:
# Work on a private copy so the helper columns added below never end up in
# foras_filtered (and to avoid SettingWithCopyWarning)
foras_filtered_copy = foras_filtered.copy()

# Search-method columns that make up each group
old_school_columns = ["search_replication", "search_comprehensive", "search_snowballing"]
new_school_columns = [
    "search_openalex_inlusion_criteria_long",
    "search_openalex_logistic_long",
    "search_openalex_all_abstracts_long",
    "search_fulltext",
]

# 'old-school' / 'new-school': 1 if any column of the group is set, else 0
old_school_hit = foras_filtered_copy[old_school_columns].any(axis=1)
foras_filtered_copy["old-school"] = old_school_hit.astype(int)
new_school_hit = foras_filtered_copy[new_school_columns].any(axis=1)
foras_filtered_copy["new-school"] = new_school_hit.astype(int)

# Counts per category; the first two include the records found by both groups
categories = [
    "Old-school",
    "New-school",
    "Both old and new",
]
found_by_both = (foras_filtered_copy["old-school"] == 1) & (
    foras_filtered_copy["new-school"] == 1
)
counts = [
    foras_filtered_copy["old-school"].sum(),
    foras_filtered_copy["new-school"].sum(),
    foras_filtered_copy[found_by_both].shape[0],
]

# Print the counts for validation
print("Total records: ", foras_filtered_copy.shape[0])
for caption, value in zip(("Old-school: ", "New-school: ", "Both: "), counts):
    print(caption, value)

# Bar chart of the three counts
plt.figure(figsize=(12, 6))
bar_color = "#009739"
bars = plt.bar(categories, counts, color=bar_color)

# Title and axis labels
plt.title("Number of Records Found via Different Search Methods in Foras", fontsize=14)
plt.ylabel("Count", fontsize=12)
plt.xlabel("Search Method", fontsize=12)
plt.xticks(rotation=45, ha="right")

# Write each count above its bar, lifted by 5% of the tallest bar
label_offset = 0.05 * max(counts)
for rect in bars:
    bar_top = rect.get_height()
    plt.text(
        rect.get_x() + rect.get_width() / 2,
        bar_top + label_offset,
        f"{int(bar_top)}",
        ha="center",
        va="bottom",
        fontsize=10,
    )

# Horizontal gridlines for easier reading
plt.grid(axis="y", linestyle="--", alpha=0.7)

# Save the plot as pdf and PNG in the output folder
for extension in ("pdf", "png"):
    plt.savefig(
        f"output/number_of_records_found_via_different_search_methods.{extension}",
        bbox_inches="tight",
    )

# Show the plot
plt.tight_layout()
plt.show()


### Bar chart with Unique search results 

In [None]:
# Define search columns. Their order fixes the position of each strategy inside the
# "permutation" and "combination" strings built below.
search_columns = [
    "search_replication",
    "search_comprehensive",
    "search_snowballing",
    "search_fulltext",
    "search_openalex_inlusion_criteria_long",
    "search_openalex_logistic_long",
    "search_openalex_all_abstracts_long",
]

# Generate permutation column: per record, the search flags cast to int and
# concatenated into one string (one character per search column).
# NOTE(review): astype(int) fails on missing values and the later `.str[index]`
# lookup assumes single-digit flags -- presumably the columns are strictly 0/1.
foras_filtered = foras_filtered.copy()  # Avoid SettingWithCopyWarning
foras_filtered["permutation"] = foras_filtered[search_columns].apply(
    lambda row: "".join(row.values.astype(int).astype(str)), axis=1
)

# Value counts and aggregation: number of records per permutation ...
value_counts = foras_filtered["permutation"].value_counts().reset_index()
value_counts.columns = ["permutation", "count"]

# ... and, per permutation, the sums of the TI-AB and FT final labels
sum_inclusions = (
    foras_filtered.groupby("permutation")[["TI-AB_final_label", "FT_final_label"]]
    .sum()
    .reset_index()
)

# Merge results into one row per permutation
result = pd.merge(value_counts, sum_inclusions, on="permutation")

# Convert permutation back to binary columns: character i of the string is the
# flag of search_columns[i]
for index, col in enumerate(search_columns):
    result[col] = result["permutation"].str[index].astype(int)

result.drop(columns="permutation", inplace=True)

# Reorder columns so the search flags come first and the three summary columns last
columns_to_move = ["count", "TI-AB_final_label", "FT_final_label"]
new_order = [
    col for col in result.columns if col not in columns_to_move
] + columns_to_move
result = result[new_order]

# Filter data: keep only permutations whose TI-AB label sum is non-zero
# (the `count != 0` part also drops permutations without records)
df_filtered = result[(result["count"] != 0) & (result["TI-AB_final_label"] != 0)].copy()

# Combination column: the flags joined with "-", used as the bar label in the plot
df_filtered["combination"] = (
    df_filtered[search_columns].astype(str).agg("-".join, axis=1)
)

# Aggregate data per combination: summed TI-AB labels as `relevant_count`, plus
# the summed search flags
df_relevant_filtered_agg = (
    df_filtered.groupby("combination")
    .agg(
        relevant_count=("TI-AB_final_label", "sum"),
        **{col: (col, "sum") for col in search_columns},
    )
    .reset_index()
)

# Largest combinations first
df_relevant_filtered_agg.sort_values(by="relevant_count", ascending=False, inplace=True)

# Column groups used to classify every combination of search strategies
old_school_cols = ["search_replication", "search_comprehensive"]
snowballing_cols = ["search_snowballing"]
fulltext_cols = ["search_fulltext"]
openalex_cols = [
    "search_openalex_inlusion_criteria_long",
    "search_openalex_logistic_long",
    "search_openalex_all_abstracts_long",
]


def _strategy_hits(columns):
    """Return, per combination, the row-wise sum of the given search flags."""
    return df_relevant_filtered_agg[columns].sum(axis=1)


# Define conditions with corrections (evaluated in this order; first match wins)
conditions = [
    # Uniquely via Snowballing
    (df_relevant_filtered_agg["search_snowballing"] == 1)
    & (_strategy_hits(old_school_cols + fulltext_cols + openalex_cols) == 0),
    # Unique via Old-School (Replication or Comprehensive)
    (_strategy_hits(old_school_cols) > 0)
    & (_strategy_hits(snowballing_cols + fulltext_cols + openalex_cols) == 0),
    # Unique via OpenAlex (Inclusion Criteria, Logistic, All Abstracts)
    (_strategy_hits(openalex_cols) > 0)
    & (_strategy_hits(old_school_cols + snowballing_cols + fulltext_cols) == 0),
    # Always Found (All methods combined, excluding Fulltext)
    (df_relevant_filtered_agg["search_fulltext"] == 0)
    & (
        _strategy_hits(old_school_cols + snowballing_cols + openalex_cols)
        == len(search_columns) - 1
    ),
    # Unique via Fulltext
    (df_relevant_filtered_agg["search_fulltext"] == 1)
    & (_strategy_hits(old_school_cols + snowballing_cols + openalex_cols) == 0),
]

# Assign colors: one per condition above; the last color is the fallback for
# combinations that match none of them
colors = ["#ff7f0e", "#8B4513", "#4682B4", "#2ca02c", "#ffcc00", "#d3d3d3"]
df_relevant_filtered_agg["color"] = np.select(
    conditions, colors[:-1], default=colors[-1]
)

# Bar chart: one bar per combination of search strategies, colored by condition
plt.figure(figsize=(14, 10))
bars = plt.bar(
    df_relevant_filtered_agg["combination"],
    df_relevant_filtered_agg["relevant_count"],
    color=df_relevant_filtered_agg["color"],
)

plt.title(
    "Overview of Included Records per Combination of Search Strategies", fontsize=14
)
plt.xlabel("Combination", fontsize=12)
plt.ylabel("Relevant Count", fontsize=12)
plt.xticks(rotation=90)

# Legend: one entry per condition, each showing the summed relevant count of the
# bars drawn in that color as "k=<n>" (math text, hence the italic k)
condition_names = [
    "Uniquely via Snowballing",
    "Unique via Old-School",
    "Unique via OpenAlex",
    "Always Found",
    "Unique via Fulltext",
    "Other Cases",
]
condition_counts = df_relevant_filtered_agg.groupby("color")["relevant_count"].sum()
legend_labels = []
for condition_name, swatch in zip(condition_names, colors):
    relevant_total = int(condition_counts.get(swatch, 0))
    legend_labels.append(f"{condition_name} $k={relevant_total}$")
handles = [
    plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=swatch, markersize=10)
    for swatch in colors
]
plt.legend(
    handles,
    legend_labels,
    title="Conditions",
    bbox_to_anchor=(1.05, 1),
    loc="upper left",
)

# Integer ticks only on the count axis
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))

# Key explaining which position of a combination label belongs to which strategy;
# drawn rotated by 90 degrees in figure coordinates
datasets = [
    "x  .  .  .  .  .  .  = Replication",
    ".  x  .  .  .  .  .  = Comprehensive",
    ".  .  x  .  .  .  .  = Snowballing",
    ".  .  .  x  .  .  .  = Fulltext",
    ".  .  .  .  x  .  .  = OpenAlex: Inclusion Criteria",
    ".  .  .  .  .  x  .  = OpenAlex: Logistic",
    ".  .  .  .  .  .  x  = OpenAlex: All Abstracts",
]
plt.text(
    0.77,
    0.04,
    "\n".join(datasets),
    horizontalalignment="left",
    verticalalignment="bottom",
    transform=plt.gcf().transFigure,
    rotation=90,
    bbox={"facecolor": "none", "edgecolor": "none"},
)

plt.tight_layout(rect=[0, 0, 1, 0.9])

# Save the plot as pdf and PNG in the output folder
for extension in ("pdf", "png"):
    plt.savefig(f"output/search_strategies.{extension}", bbox_inches="tight")

# Adjust layout for the on-screen figure and show it
plt.tight_layout()
plt.show()

In [None]:
# Relevant counts per combination and their grand total
relevant_counts = df_relevant_filtered_agg["relevant_count"]
total_relevant_count = relevant_counts.sum()

# Records found by exactly one search strategy (flags of the combination sum to 1)
found_by_single_strategy = df_relevant_filtered_agg[search_columns].sum(axis=1) == 1
unique_single_strategy_count = relevant_counts[found_by_single_strategy].sum()
unique_single_strategy_percentage = (
    unique_single_strategy_count / total_relevant_count
) * 100
print(
    f"Total number of records uniquely identified with only one search strategy: {unique_single_strategy_count}"
)
print(
    f"Percentage of records uniquely identified with only one search strategy: {unique_single_strategy_percentage:.2f}%"
)

# "Other cases": combinations that received the fallback color
is_other_case = df_relevant_filtered_agg["color"] == "#d3d3d3"
other_cases_count = relevant_counts[is_other_case].sum()
other_cases_percentage = (other_cases_count / total_relevant_count) * 100
print(f"Number of other cases: {other_cases_count}")
print(f"Percentage of other cases: {other_cases_percentage:.2f}%")

# Everything that is not an "other case"
minus_other_cases_count = total_relevant_count - other_cases_count
minus_other_cases_percentage = (minus_other_cases_count / total_relevant_count) * 100
print(f"Total number of records minus other cases: {minus_other_cases_count}")
print(f"Percentage of records minus other cases: {minus_other_cases_percentage:.2f}%")

## Labels

### Pie chart with Ti-Ab inclusions for Foras and Synergy

In [None]:
# Helper that builds the text shown inside each pie slice
def autopct_format(values):
    """Return a callback that turns a percentage into "<count> (<pct>%)".

    The count is recovered from the percentage as
    ``round(pct * sum(values) / 100)``; counts above 999 get thousands
    separators, and the percentage is shown with one decimal.
    """

    def format_slice(pct):
        val = round(pct * sum(values) / 100.0)
        val_str = f"{val:,.0f}" if val > 999 else f"{val}"
        return f"{val_str} ({pct:.1f}%)"

    return format_slice


# Draws one pie chart in a 1x3 grid of subplots
def plot_pie_chart(data, title, subplot_position, labels):
    """Draw ``data`` as a pie chart in slot ``subplot_position`` of a 1x3 grid.

    Parameters
    ----------
    data : pandas.Series
        Slice sizes; their sum is shown in the title as k_total.
    title : str
        First line of the subplot title.
    subplot_position : int
        Index of the subplot to draw into.
    labels : sequence of str
        One label per slice, matched to ``data`` by position.

    Slices use the module-level ``brazilian_colors`` and are annotated with
    "<count> (<pct>%)". The annotation of a slice labelled "Included" is
    right-aligned; all other annotations are centered.
    """
    plt.subplot(1, 3, subplot_position)
    wedges, texts, autotexts = plt.pie(
        data,
        labels=labels,
        colors=brazilian_colors,
        autopct=autopct_format(data),
        startangle=90,
    )
    plt.title(f"{title}\n $k_{{total}}= {data.sum():,}$", loc="center")

    # Alignment of the in-slice annotation depends on the slice label
    for slice_label, annotation in zip(labels, autotexts):
        alignment = "right" if slice_label == "Included" else "center"
        annotation.set_horizontalalignment(alignment)


# Brazilian flag colors (applied to the slices by position)
brazilian_colors = [
    "#FEDD00",  # yellow
    "#009739",  # green
]

# Data preparation
# value_counts() orders its result by frequency, while the slice labels and colors
# below are assigned by position. sort_index() pins the order to the label value so
# the lower label always comes first, whichever class happens to be larger.
# NOTE(review): assumes the labels are coded 0 = excluded, 1 = included -- confirm.
foras_tiab_records = foras_filtered["label_included_TIAB"].value_counts().sort_index()
synergy_tiab_records = synergy["TI-AB-corrected"].value_counts().sort_index()
combined_tiab_records = foras_tiab_records.add(
    synergy_tiab_records, fill_value=0
).sort_index()

# Plot setup
plt.figure(figsize=(18, 6))
plt.subplots_adjust(right=0.85)

# Plotting: FORAS, Synergy and their sum side by side
plot_pie_chart(foras_tiab_records, "FORAS TI-AB", 1, ["Excluded", "Included"])
plot_pie_chart(synergy_tiab_records, "Synergy TI-AB", 2, ["Excluded", "Included"])
plot_pie_chart(combined_tiab_records, "Total TI-AB", 3, ["Excluded", "Included"])

# Legend
plt.legend(
    ["Excluded", "Included"], loc="center left", bbox_to_anchor=(1, 0.5), frameon=False
)

# save plot as pdf and PNG in /output folder
plt.savefig("output/tiab_pie_chart.pdf")
plt.savefig("output/tiab_pie_chart.png")

# Display the plot
plt.show()

### Pie chart with FT inclusions for Foras and Synergy

In [None]:
# Data preparation for FT labels
# value_counts() orders its result by frequency, while the slice labels and colors
# below are assigned by position. sort_index() pins the order to the label value so
# the lower label always comes first, whichever class happens to be larger.
# NOTE(review): assumes the labels are coded 0 = excluded, 1 = included -- confirm.
foras_ft_records = foras_filtered["label_included_FT"].value_counts().sort_index()
synergy_ft_records = synergy["FT-corrected"].value_counts().sort_index()
combined_ft_records = foras_ft_records.add(
    synergy_ft_records, fill_value=0
).sort_index()

# Plot setup for FT labels
plt.figure(figsize=(18, 6))
plt.subplots_adjust(right=0.85)

# Plotting for FT labels: FORAS, Synergy and their sum side by side
plot_pie_chart(foras_ft_records, "FORAS FT", 1, labels=["Excluded", "Included"])
plot_pie_chart(synergy_ft_records, "Synergy FT", 2, labels=["Excluded", "Included"])
plot_pie_chart(combined_ft_records, "Total FT", 3, labels=["Excluded", "Included"])

# Legend
plt.legend(
    ["Excluded", "Included"], loc="center left", bbox_to_anchor=(1, 0.5), frameon=False
)

# save plot as pdf and PNG in /output folder
plt.savefig("output/FT_labels_pie_chart.pdf")
plt.savefig("output/FT_labels_pie_chart.png")

plt.show()

### Overview all labels

In [None]:
# Count occurrences of each column value for FORAS
ti_final_label_counts = foras_filtered["TI_final_label"].value_counts().get(1, 0)
ti_ab_IC1_counts = foras_filtered["TI-AB_IC1_joint"].value_counts()
ti_ab_IC2_counts = foras_filtered["TI-AB_IC2_joint"].value_counts()
ti_ab_IC3_counts = foras_filtered["TI-AB_IC3_joint"].value_counts()
ti_ab_IC4_counts = foras_filtered["TI-AB_IC4_joint"].value_counts()
ti_ab_final_label_counts = foras_filtered["TI-AB_final_label"].value_counts()
ti_ab_disagreement_counts = foras_filtered[
    "TI-AB_disagreement_human-human"
].value_counts()
ti_ab_disagreement_llm_counts = foras_filtered[
    "TI-AB_disagreement_human-LLM"
].value_counts()
title_eligible_bruno_counts = foras_filtered["title_eligible_Bruno"].value_counts()
title_eligible_rutger_counts = foras_filtered["title_eligible_Rutger"].value_counts()
ft_IC1_counts = foras_filtered["FT_IC1_joint"].value_counts()
ft_IC2_counts = foras_filtered["FT_IC2_joint"].value_counts()
ft_IC3_counts = foras_filtered["FT_IC3_joint"].value_counts()
ft_IC4_counts = foras_filtered["FT_IC4_joint"].value_counts()
ft_inclusion_bruno_counts = foras_filtered["FT_inclusion_Bruno"].value_counts()
ft_inclusion_rutger_counts = foras_filtered["FT_inclusion_Rutger"].value_counts()
ti_ab_final_label_counts_foras = foras_filtered["TI-AB_final_label"].value_counts()
ft_inclusion_counts_foras = foras_filtered["FT_final_label"].value_counts()
ft_inclusion_foras_count = ft_inclusion_counts_foras.get(1, 0)
ft_not_available = foras_filtered["full_text_available"].value_counts().get(0, 0)

# Count occurrences of duplicate and data extraction for FORAS and Synergy
duplicate_foras = fulltext_foras["cohort_duplicate"].value_counts().get(1, 0)
duplicate_synergy = fulltext_synergy["cohort_duplicate"].value_counts().get(1, 0)
data_extracted_foras = (
    fulltext_foras["used_for_Paper2_prevalences"].value_counts().get(1, 0)
)
data_extracted_synergy = (
    fulltext_synergy["used_for_Paper2_prevalences"].value_counts().get(1, 0)
)
clinical_data_foras = (
    fulltext_foras["used_for_paper3_clinical"].value_counts().get(1, 0)
)
clinical_data_synergy = (
    fulltext_synergy["used_for_paper3_clinical"].value_counts().get(1, 0)
)

# Count occurrences for Synergy
ti_ab_final_label_counts_synergy = synergy["TI-AB-corrected"].value_counts()
ft_inclusion_counts_synergy = synergy["FT-corrected"].value_counts()
ti_ab_original_counts = synergy["TI-AB_original"].value_counts()
ft_original_counts = synergy["FT_original_34"].value_counts()

# Define total counts for FT inclusion
ft_inclusion_synergy_count = ft_inclusion_counts_synergy.get(1, 0)
ft_inclusion_foras_count = ft_inclusion_counts_foras.get(1, 0)


# Calculate the total number of records in both Synergy and Foras datasets
total_synergy_records = len(synergy)
total_foras_records = len(foras_filtered)


# Adjusted x positions for better spacing of TI-AB IC and FT IC variables
x = np.array([0, 0.4, 1.1, 1.6, 1.85, 2.1, 2.35, 2.65, 3.1, 3.35, 3.6, 3.85, 4.1])

# Adjusted bar widths for TI-AB IC and FT IC variables
ti_ab_ic_bar_width = 0.15
ft_ic_bar_width = 0.15
other_bar_width = 0.3

# Create the bar chart
fig, ax = plt.subplots(figsize=(16, 8))

# Extract counts for 'Y', 'Q', and 'D' for IC columns
ti_ab_y_counts = [
    ti_ab_IC1_counts.get("Y", 0),
    ti_ab_IC2_counts.get("Y", 0),
    ti_ab_IC3_counts.get("Y", 0),
    ti_ab_IC4_counts.get("Y", 0),
]
ti_ab_q_counts = [
    ti_ab_IC1_counts.get("Q", 0),
    ti_ab_IC2_counts.get("Q", 0),
    ti_ab_IC3_counts.get("Q", 0),
    ti_ab_IC4_counts.get("Q", 0),
]
ti_ab_d_counts = [
    ti_ab_IC1_counts.get("D", 0),
    ti_ab_IC2_counts.get("D", 0),
    ti_ab_IC3_counts.get("D", 0),
    ti_ab_IC4_counts.get("D", 0),
]
ti_ab_llm_counts = [
    ti_ab_IC1_counts.get("D_llm", 0),
    ti_ab_IC2_counts.get("D_llm", 0),
    ti_ab_IC3_counts.get("D_llm", 0),
    ti_ab_IC4_counts.get("D_llm", 0),
]

# Extract counts for 'Y' 'N' and 'D' for FT IC columns
ft_y_counts = [
    ft_IC1_counts.get("Y", 0),
    ft_IC2_counts.get("Y", 0),
    ft_IC3_counts.get("Y", 0),
    ft_IC4_counts.get("Y", 0),
]
ft_nr_counts = [
    ft_IC1_counts.get("D", 0),
    ft_IC2_counts.get("D", 0),
    ft_IC3_counts.get("D", 0),
    ft_IC4_counts.get("D", 0),
]

colors_synergy = [
    "#002776",
    "#667AB0",
    "#002776",
    "#667AB0",
]  # Synergy: TI-AB blue and light blue
colors_foras = [
    "#009739",
    "#009739",
    "#FFCC00",
    "#66C38D",
    "#A1F60F",
]  # FORAS: Title dark green, TI-AB Y standard green, Q standard yellow, D light green, D_llm lighter green
colors_ft_ic = ["#009739", "#FFCC00"]  # FT ICs: Y darker green, NR darker yellow
colors_combined = [
    "#A0B9E6",
    "#009739",
]  # Combined totals: Synergy darker blue, Foras standard green


### SYNERGY ###

## TI-AB ##
# Synergy: Plot TI-AB original and corrected bars
bars_synergy_tiab_original = ax.bar(
    x[0],
    ti_ab_original_counts.get(1, 0),
    color=colors_synergy[0],
    width=other_bar_width,
)
bars_synergy_tiab_corrected = ax.bar(
    x[0],
    ti_ab_final_label_counts_synergy.get(1, 0) - ti_ab_original_counts.get(1, 0),
    bottom=ti_ab_original_counts.get(1, 0),
    color=colors_synergy[1],
    width=other_bar_width,
)

## FT ##
# Synergy: Plot FT original and corrected bars
bars_synergy_ft_original = ax.bar(
    x[1], ft_original_counts.get(1, 0), color=colors_synergy[2], width=other_bar_width
)
bars_synergy_ft_corrected = ax.bar(
    x[1],
    ft_inclusion_counts_synergy.get(1, 0) - ft_original_counts.get(1, 0),
    bottom=ft_original_counts.get(1, 0),
    color=colors_synergy[3],
    width=other_bar_width,
)

# Synergy: Extract total counts for label 1
ti_ab_total_synergy = ti_ab_final_label_counts_synergy.get(1, 0)
ft_total_synergy = ft_inclusion_counts_synergy.get(1, 0)

# Place text only for the total corrected value
ax.text(
    x[0], ti_ab_total_synergy + 0.05, f"{ti_ab_total_synergy}", ha="center", va="bottom"
)
ax.text(x[1], ft_total_synergy + 0.05, f"{ft_total_synergy}", ha="center", va="bottom")


### FORAS ###
## TITLE ##

# FORAS: Title bar
bars_title = ax.bar(
    x[2], height=ti_final_label_counts, color=colors_foras[0], width=other_bar_width
)

# Display the total count for the Title eligible bar
for bar in bars_title:
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.05,
        f"{bar.get_height()}",
        ha="center",
        va="bottom",
    )


## TI-AB ##

# FORAS: TI-AB IC bars with closer spacing and smaller width
bars_ti_ab_y = ax.bar(
    x[3:7] - ti_ab_ic_bar_width / 2,
    ti_ab_y_counts,
    color=colors_foras[1],
    width=ti_ab_ic_bar_width,
)
bars_ti_ab_q = ax.bar(
    x[3:7] - ti_ab_ic_bar_width / 2,
    ti_ab_q_counts,
    bottom=ti_ab_y_counts,
    color=colors_foras[2],
    width=ti_ab_ic_bar_width,
)
bottom_ti_ab_yq = [y + q for y, q in zip(ti_ab_y_counts, ti_ab_q_counts)]
bars_ti_ab_d = ax.bar(
    x[3:7] - ti_ab_ic_bar_width / 2,
    ti_ab_d_counts,
    bottom=bottom_ti_ab_yq,
    color=colors_foras[3],
    width=ti_ab_ic_bar_width,
)
bottom_ti_ab_yqd = [
    y + q + d for y, q, d in zip(ti_ab_y_counts, ti_ab_q_counts, ti_ab_d_counts)
]
bars_llm = ax.bar(
    x[3:7] - ti_ab_ic_bar_width / 2,
    ti_ab_llm_counts,
    bottom=bottom_ti_ab_yqd,
    color=colors_foras[4],
    width=ti_ab_ic_bar_width,
)


# FORAS: Display totals for the TI-AB IC bars (Y + Q + D + LLM)
def display_totals_ti_ab_ic(bars_y, bars_q, bars_d, bars_llm):
    """Write the combined height above every stacked TI-AB IC bar.

    The four bar containers are walked in parallel; for each stack the heights
    of its Y, Q, D and LLM segments are added up and the integer total is drawn
    on the module-level ``ax``, centered on the LLM segment and 0.05 above the
    total height.
    """
    for stacked_segments in zip(bars_y, bars_q, bars_d, bars_llm):
        top_segment = stacked_segments[-1]
        total_height = sum(segment.get_height() for segment in stacked_segments)
        ax.text(
            top_segment.get_x() + top_segment.get_width() / 2,
            total_height + 0.05,
            f"{int(total_height)}",
            ha="center",
            va="bottom",
        )


# FORAS: annotate every TI-AB IC column with its stacked total
display_totals_ti_ab_ic(bars_ti_ab_y, bars_ti_ab_q, bars_ti_ab_d, bars_llm)


# FORAS: TI-AB final column (standard width), built from the label-1 counts
ti_ab_total_count = ti_ab_final_label_counts.get(1, 0)
ti_ab_disagreement_count = ti_ab_disagreement_counts.get(1, 0)
ti_ab_disagreement_llm_count = ti_ab_disagreement_llm_counts.get(1, 0)

# The Y segment is what remains of the total after removing both disagreement
# counts and the not available full texts
ti_ab_final_y_count = (
    ti_ab_total_count
    - ti_ab_disagreement_count
    - ft_not_available
    - ti_ab_disagreement_llm_count
)

# Segment order from the axis upwards: FT not available, Y, LLM disagreement,
# disagreement. The FT-not-available segment is drawn last, but sits at the
# bottom because every other segment starts above it.
ti_ab_final_x = x[7]
bars_ti_ab_final_y = ax.bar(
    ti_ab_final_x,
    ti_ab_final_y_count,
    width=other_bar_width,
    bottom=ft_not_available,
    color=colors_foras[1],
)
bars_llm_final = ax.bar(
    ti_ab_final_x,
    ti_ab_disagreement_llm_count,
    width=other_bar_width,
    bottom=ti_ab_final_y_count + ft_not_available,
    color=colors_foras[4],
)
bars_ti_ab_final_d = ax.bar(
    ti_ab_final_x,
    ti_ab_disagreement_count,
    width=other_bar_width,
    bottom=ti_ab_final_y_count + ft_not_available + ti_ab_disagreement_llm_count,
    color=colors_foras[3],
)
bars_ft_not_available = ax.bar(
    ti_ab_final_x,
    ft_not_available,
    width=other_bar_width,
    color=colors_ft_ic[1],
)

# Write the height of the whole stack above the column, centred on the Y bar
total_ti_ab_height = (
    ti_ab_final_y_count
    + ti_ab_disagreement_count
    + ft_not_available
    + ti_ab_disagreement_llm_count
)
ti_ab_final_rect = bars_ti_ab_final_y[0]
ax.text(
    ti_ab_final_rect.get_x() + ti_ab_final_rect.get_width() / 2,
    total_ti_ab_height + 0.1,
    f"{total_ti_ab_height}",
    ha="center",
    va="bottom",
)


## FT ##

# FORAS: FT IC columns. Like the TI-AB IC columns they use their own width and
# are shifted left by half of it; the NR segment is stacked on the Y segment.
ft_ic_x = x[8:12] - ft_ic_bar_width / 2
bars_ft_ic_y = ax.bar(
    ft_ic_x,
    ft_y_counts,
    width=ft_ic_bar_width,
    color=colors_ft_ic[0],
)  # 'Y' bars
bars_ft_ic_d = ax.bar(
    ft_ic_x,
    ft_nr_counts,
    width=ft_ic_bar_width,
    bottom=ft_y_counts,
    color=colors_foras[3],
)  # 'NR' bars


# FORAS: FT inclusion column (label-1 count) with the standard width
bars_ft_foras = ax.bar(
    x[12],
    ft_inclusion_counts_foras.get(1, 0),
    width=other_bar_width,
    color=colors_combined[1],
)


# Display totals for the FT IC bars (Y + NR only)
def display_totals_ft_ic(bars_y, bars_nr):
    """Write the combined Y + NR height of every stacked FT IC column.

    Both bar collections are walked in lockstep. The two segment heights are
    added and the total, truncated to an integer, is drawn on the module-level
    ``ax``: horizontally centred on the NR segment and 0.05 data units above
    the summed height.

    Args:
        bars_y: Bars of the Y segments, one per column.
        bars_nr: Bars of the NR segments, one per column.
    """
    for y_segment, nr_segment in zip(bars_y, bars_nr):
        column_total = y_segment.get_height() + nr_segment.get_height()
        column_center = nr_segment.get_x() + nr_segment.get_width() / 2
        ax.text(
            column_center,
            column_total + 0.05,
            f"{int(column_total)}",
            ha="center",
            va="bottom",
        )


# Annotate every FT IC column with its stacked total
display_totals_ft_ic(bars_ft_ic_y, bars_ft_ic_d)

# Write the height of the FORAS FT inclusion column just above its top edge
for ft_bar in bars_ft_foras:
    ft_bar_height = ft_bar.get_height()
    ft_bar_center = ft_bar.get_x() + ft_bar.get_width() / 2
    ax.text(
        ft_bar_center,
        ft_bar_height + 0.05,
        f"{ft_bar_height}",
        ha="center",
        va="bottom",
    )


## Add totals next to bars


def add_totals_next_to_bars(
    bars, counts, bar_width, font_size=7, jitter=0.0, vertical_jitter=0.0, color="black"
):
    """Write a count to the right of each bar segment, rotated by 90 degrees.

    Bars and counts are paired in order. The label is drawn on the
    module-level ``ax``, 0.02 data units (plus ``jitter``) to the right of the
    segment's right edge and at the segment's vertical midpoint, shifted
    upwards by ``vertical_jitter`` times the segment height. Whenever either
    jitter is non-zero, a thin line connects the midpoint of the right edge
    with the label position.

    Args:
        bars: Bar rectangles to annotate.
        counts: Values to print, one per bar.
        bar_width: Not used by the function body; kept for existing callers.
        font_size: Font size of the label.
        jitter: Extra horizontal offset of the label in data units.
        vertical_jitter: Vertical offset as a multiple of the bar height.
        color: Colour of the label and of the connecting line.
    """
    needs_leader_line = not (jitter == 0.0 and vertical_jitter == 0.0)

    for rect, value in zip(bars, counts):
        right_edge = rect.get_x() + rect.get_width()
        mid_height = rect.get_y() + rect.get_height() / 2
        label_x = right_edge + 0.02 + jitter
        label_y = mid_height + (rect.get_height() * vertical_jitter)

        ax.text(
            label_x,
            label_y,
            f"{value}",
            ha="left",
            va="center",
            fontsize=font_size,
            color=color,
            rotation=90,
            weight="bold",
        )

        # Tie a displaced label back to its segment
        if needs_leader_line:
            ax.plot(
                [right_edge, label_x],
                [mid_height, label_y],
                color=color,
                linewidth=0.5,
            )


# The side labels are placed by hand: every entry below carries its own
# vertical_jitter (a multiple of the segment height) and, where needed, a
# horizontal jitter. The tables keep the original call order.

# Add totals next to the Synergy TI-AB and FT bars
# (bars, value, jitter, vertical_jitter, color)
synergy_side_labels = (
    (
        bars_synergy_tiab_original,
        ti_ab_original_counts.get(1, 0),
        0.0,
        0,
        colors_synergy[0],
    ),
    (
        bars_synergy_tiab_corrected,
        ti_ab_final_label_counts_synergy.get(1, 0) - ti_ab_original_counts.get(1, 0),
        0,
        7,
        colors_synergy[1],
    ),
    (
        bars_synergy_ft_original,
        ft_original_counts.get(1, 0),
        0.0,
        2,
        colors_synergy[2],
    ),
    (
        bars_synergy_ft_corrected,
        ft_inclusion_counts_synergy.get(1, 0) - ft_original_counts.get(1, 0),
        0,
        15,
        colors_synergy[3],
    ),
)
for side_bars, side_value, side_jitter, side_vjitter, side_color in synergy_side_labels:
    add_totals_next_to_bars(
        side_bars,
        [side_value],
        other_bar_width,
        jitter=side_jitter,
        vertical_jitter=side_vjitter,
        color=side_color,
    )

# Add totals next to FORAS TI-AB IC bars, one stack layer after the other
# (bars, counts, color, jitter per column, vertical_jitter per column)
ti_ab_ic_side_labels = (
    (bars_ti_ab_y, ti_ab_y_counts, colors_foras[1], (0, 0, 0, 0), (0.0, 0.0, 0.0, 0.5)),
    (bars_ti_ab_q, ti_ab_q_counts, colors_foras[2], (0, 0, 0, 0), (0, 0.0, 0.0, 1.5)),
    (bars_ti_ab_d, ti_ab_d_counts, colors_foras[3], (0, 0, 0, 0), (0, 0.0, 0.0, 3)),
    (bars_llm, ti_ab_llm_counts, colors_foras[4], (0.02, 0.02, 0.02, 0), (1.5, 45, 1, 6)),
)
for layer_bars, layer_counts, layer_color, layer_jitters, layer_vjitters in (
    ti_ab_ic_side_labels
):
    for ic_idx, (ic_jitter, ic_vjitter) in enumerate(
        zip(layer_jitters, layer_vjitters)
    ):
        add_totals_next_to_bars(
            [layer_bars[ic_idx]],
            [layer_counts[ic_idx]],
            ti_ab_ic_bar_width,
            jitter=ic_jitter,
            vertical_jitter=ic_vjitter,
            color=layer_color,
        )

# Add totals next to FORAS TI-AB final bars
# (bars, value, vertical_jitter, color)
ti_ab_final_side_labels = (
    (bars_ti_ab_final_y, ti_ab_final_y_count, 1, colors_foras[1]),
    (bars_ti_ab_final_d, ti_ab_disagreement_count, 2.5, colors_foras[3]),
    (bars_ft_not_available, ft_not_available, 3, colors_ft_ic[1]),
    (bars_llm_final, ti_ab_disagreement_llm_count, 90, colors_foras[4]),
)
for final_bars, final_value, final_vjitter, final_color in ti_ab_final_side_labels:
    add_totals_next_to_bars(
        final_bars,
        [final_value],
        other_bar_width,
        jitter=0.0,
        vertical_jitter=final_vjitter,
        color=final_color,
    )

# Add totals next to FORAS FT IC bars, first the Y layer and then the NR layer
# (bars, counts, color, vertical_jitter per column)
ft_ic_side_labels = (
    (bars_ft_ic_y, ft_y_counts, colors_ft_ic[0], (0.0, 0.0, 0.0, 0.5)),
    (bars_ft_ic_d, ft_nr_counts, colors_foras[3], (30, 15, 15, 10)),
)
for ft_layer_bars, ft_layer_counts, ft_layer_color, ft_layer_vjitters in (
    ft_ic_side_labels
):
    for ft_idx, ft_vjitter in enumerate(ft_layer_vjitters):
        add_totals_next_to_bars(
            [ft_layer_bars[ft_idx]],
            [ft_layer_counts[ft_idx]],
            ft_ic_bar_width,
            jitter=0,
            vertical_jitter=ft_vjitter,
            color=ft_layer_color,
        )

# Title (reporting the combined Synergy + FORAS record count) and y-axis label
ax.set_title(
    rf"Overview of Included Records: $\mathit{{k}}_{{\text{{grand - total}}}}$ = {total_synergy_records + total_foras_records:,}",
    fontsize=14,
)
ax.set_ylabel("Count")

# One tick label per column, in the order of the positions in x
labels = [
    "TI-AB Eligible",
    "FT Eligible",
    "Title Eligible",
    "TI-AB IC1",
    "TI-AB IC2",
    "TI-AB IC3",
    "TI-AB IC4",
    "TI-AB Eligible",
    "FT IC1",
    "FT IC2",
    "FT IC3",
    "FT IC4",
    "FT Eligible",
]
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=45, ha="right")

# Dashed separator at x = 0.75
ax.axvline(x=0.75, linestyle="--", color="grey")

# Semi-transparent header boxes with the record totals of both data sets,
# positioned in figure coordinates
header_boxes = (
    (0.12, rf" Synergy: $\mathit{{k}}_{{\text{{total}}}}$={total_synergy_records:,}"),
    (0.37, rf"Foras: $\mathit{{k}}_{{\text{{total}}}}$={total_foras_records:,}"),
)
for header_x, header_text in header_boxes:
    fig.text(
        header_x,
        0.8,
        header_text,
        ha="center",
        va="top",
        bbox={"facecolor": "white", "alpha": 0.5},
    )

# Legend: one coloured square per entry, paired with the labels by position
legend_labels = [
    "Synergy: - Initial",
    "              - Added",
    "FORAS:  - TI / AB / FT inclusion",
    "             - Questionable (IC 1...4) or no full text available (TI-AB Eligible)",
    "             - Human - Human Disagreement",
    "             - LLM - Human Disagreement",
    "             - Full text inclusion",
    "Data Extraction: - Synergy",
    "                           - FORAS",
]
legend_colors = [
    colors_synergy[0],
    colors_synergy[1],
    colors_foras[1],
    colors_foras[2],
    colors_foras[3],
    colors_foras[4],
    colors_ft_ic[0],
    colors_combined[0],
    colors_combined[1],
]
handles = [
    plt.Rectangle((0, 0), 1, 1, color=legend_color) for legend_color in legend_colors
]
ax.legend(
    handles,
    legend_labels,
    title="Legend",
    loc="upper left",
    bbox_to_anchor=(1.05, 1),
)

# Lay out the axes inside the lower 90 % of the figure before saving
plt.tight_layout(rect=[0, 0, 1, 0.9])

# save plot as pdf and PNG in /output folder
for overview_path in (
    "output/overview_included_records.pdf",
    "output/overview_included_records.png",
):
    plt.savefig(overview_path, bbox_inches="tight")

# Display the chart (the layout is recomputed without the reserved top margin)
plt.tight_layout()
plt.show()

## Interaction Search x Labels

### Stacked bar chart

In [None]:
# Search strategies in stacking order, mapped to the shorter names used on the
# plot; the list of search columns is taken from the dictionary keys
short_names = {
    "search_replication": "Replication",
    "search_comprehensive": "Comprehensive",
    "search_snowballing": "Snowballing",
    "search_fulltext": "Fulltext",
    "search_openalex_inlusion_criteria_long": "OpenAlex: Inclusion criteria",
    "search_openalex_logistic_long": "OpenAlex: Logistic",
    "search_openalex_all_abstracts_long": "OpenAlex: All abstracts",
}
search_columns = list(short_names)

# Records whose final TI-AB label is 1
ti_ab_final_label_1 = foras_filtered[foras_filtered["TI-AB_final_label"] == 1]

# Walk through the strategies in order and count, for each one, the records
# (and the TI-AB inclusions) that no earlier strategy had flagged yet
search_counts = []
ti_ab_inclusion_counts = []
previous_records = set()
previous_ti_ab_records = set()

for search_col in search_columns:
    flagged_records = set(foras_filtered[foras_filtered[search_col] == 1].index)
    flagged_ti_ab_records = set(
        ti_ab_final_label_1[ti_ab_final_label_1[search_col] == 1].index
    )

    new_records = flagged_records - previous_records
    new_ti_ab_records = flagged_ti_ab_records - previous_ti_ab_records

    search_counts.append(len(new_records))
    ti_ab_inclusion_counts.append(len(new_ti_ab_records))

    previous_records |= new_records
    previous_ti_ab_records |= new_ti_ab_records

# Calculate the grand totals
grand_total = sum(search_counts)
grand_total_ti_ab = sum(ti_ab_inclusion_counts)

# Create the bar chart
fig, ax = plt.subplots(figsize=(12, 8))

# A single stacked column: every strategy's new records sit on the previous ones
bottom = 0
for segment_idx, segment_height in enumerate(search_counts):
    ax.bar(
        "Records Found",
        segment_height,
        bottom=bottom,
        color=colors_brazil[segment_idx],
        width=0.2,
    )
    bottom += segment_height

# Set the color for the dots (yellow)
dot_color = "#FFCC00"

# Start height of every segment: the cumulative sum of the segments below it
bottoms = np.cumsum([0] + search_counts[:-1]).astype(float)

# Scatter one dot per TI-AB inclusion inside its segment and label the segment
# (k_total and k_TIAB) just to the right of the column
buffer = 0.1  # Share of the segment height kept free at its top and bottom
for segment_idx, (n_new, n_new_ti_ab) in enumerate(
    zip(search_counts, ti_ab_inclusion_counts)
):
    segment_start = bottoms[segment_idx]

    if n_new_ti_ab > 0:
        # Random heights inside the segment, away from both boundaries
        dot_y_positions = np.random.uniform(
            low=segment_start + buffer * n_new,
            high=segment_start + n_new - buffer * n_new,
            size=n_new_ti_ab,
        )
        for dot_y in dot_y_positions:
            # Random horizontal jitter for better visual spread
            dot_x = np.random.uniform(-0.09, 0.09)
            ax.plot(dot_x, dot_y, "o", color=dot_color, markersize=2)

    # Segment label, vertically centred on the segment
    segment_label = f"{short_names[search_columns[segment_idx]]} ($k_{{total}} = {n_new}$, $k_{{TIAB}} = {n_new_ti_ab}$)"
    ax.text(
        0.11,
        segment_start + n_new / 2,
        segment_label,
        ha="left",
        va="center",
        fontsize=14,
        color="black",
        weight="bold",
    )

# Set the labels and title
ax.set_ylabel("Number of Records")

# The header carries both grand totals (k_grand-total and k_TIAB-total)
ax.set_title(
    f"Additional number of records after adding yet another search strategy\n"
    f"($k_{{grand-total}} = {grand_total}$, $k_{{TIAB-total}} = {grand_total_ti_ab}$)",
    fontsize=14,
)

# Remove x-axis ticks and labels
ax.set_xticks([])
ax.set_xticklabels([])

# save plot as pdf and PNG in /output folder
for strategy_plot_path in (
    "output/records_found_per_search_strategy.pdf",
    "output/records_found_per_search_strategy.png",
):
    plt.savefig(strategy_plot_path, bbox_inches="tight")

# Display the chart
plt.show()

### Venn diagram

In [None]:
# Keep only the records whose final full-text label (FT_final_label) is 1
filtered_m_copy = foras_filtered[foras_filtered["FT_final_label"] == 1].copy()

# 'old-school': 1 if the row sum over search_replication and
# search_comprehensive is positive, else 0. The vectorised row sum replaces a
# per-row Python lambda; missing values are skipped by sum() in both forms.
filtered_m_copy["old-school"] = (
    filtered_m_copy[["search_replication", "search_comprehensive"]].sum(axis=1) > 0
).astype(int)

# 'new-school': 1 if the row sum over the three OpenAlex search columns
# (inclusion criteria, all abstracts, logistic) is positive, else 0
filtered_m_copy["new-school"] = (
    filtered_m_copy[
        [
            "search_openalex_inlusion_criteria_long",
            "search_openalex_all_abstracts_long",
            "search_openalex_logistic_long",
        ]
    ].sum(axis=1)
    > 0
).astype(int)

# Membership flags of the three sets shown in the Venn diagram
old_school_flag = filtered_m_copy["old-school"]
new_school_flag = filtered_m_copy["new-school"]
snowballing_flag = filtered_m_copy["search_snowballing"]


def _venn_region_size(old, new, snow):
    """Count the records whose three membership flags equal the given pattern.

    Flags are compared with ``==`` against the 0/1 values passed in, so a
    record with a missing snowballing flag matches no region.
    """
    in_region = (
        (old_school_flag == old)
        & (new_school_flag == new)
        & (snowballing_flag == snow)
    )
    return int(in_region.sum())


# Calculate counts for each (mutually exclusive) Venn region
only_old_school = _venn_region_size(1, 0, 0)
only_new_school = _venn_region_size(0, 1, 0)
only_snowballing = _venn_region_size(0, 0, 1)
old_and_new = _venn_region_size(1, 1, 0)
old_and_snowballing = _venn_region_size(1, 0, 1)
new_and_snowballing = _venn_region_size(0, 1, 1)
all_three = _venn_region_size(1, 1, 1)

# Sizes of the three main categories (sum of each flag column); computed once
size_old_school = filtered_m_copy["old-school"].sum()
size_new_school = filtered_m_copy["new-school"].sum()
size_snowballing = filtered_m_copy["search_snowballing"].sum()

# Print results
print("Only old-school:", only_old_school)
print("Only new-school:", only_new_school)
print("Only snowballing:", only_snowballing)
print("Old-school and new-school:", old_and_new)
print("Old-school and snowballing:", old_and_snowballing)
print("New-school and snowballing:", new_and_snowballing)
print("All three:", all_three)
print("Size of old-school:", size_old_school)
print("Size of new-school:", size_new_school)
print("Size of snowballing:", size_snowballing)


# Function to add dots efficiently using direct path checking
def add_dots_efficiently(ax, patch, num_dots, color, markersize, margin=0.2):
    """Scatter up to ``num_dots`` random dots inside a Venn diagram patch.

    Candidate points are drawn uniformly from the patch's bounding box, which
    is first shrunk on every side by ``margin`` times its extent. A candidate
    is plotted on ``ax`` only if it lies inside the patch outline. Sampling
    stops after ``num_dots`` hits or after ``num_dots * 10`` candidates,
    whichever comes first, so fewer dots than requested may be drawn.

    Args:
        ax: Axes to plot the dots on.
        patch: Patch of the region; nothing is drawn when it is ``None``.
        num_dots: Number of dots to place; nothing is drawn when it is 0.
        color: Colour of the dots.
        markersize: Marker size of the dots.
        margin: Fraction of the bounding-box extent trimmed from each side.
    """
    if patch is None or num_dots == 0:
        return

    # Rebuild the outline of the patch as a path that supports point tests
    outline = patch.get_path()
    vertices = outline.vertices
    region = mpath.Path(vertices, outline.codes)

    # Bounding box of the outline, shrunk by the margin on all four sides
    lower_corner = np.min(vertices, axis=0)
    upper_corner = np.max(vertices, axis=0)
    inset = margin * (upper_corner - lower_corner)
    min_x, min_y = lower_corner + inset
    max_x, max_y = upper_corner - inset

    # Rejection sampling with a cap on the number of candidates
    placed = 0
    tries_left = num_dots * 10
    while placed < num_dots and tries_left > 0:
        cand_x = np.random.uniform(min_x, max_x)
        cand_y = np.random.uniform(min_y, max_y)
        tries_left -= 1

        if region.contains_point(np.array([cand_x, cand_y])):
            ax.plot(cand_x, cand_y, "o", color=color, markersize=markersize)
            placed += 1


# Create the Venn diagram with custom colors.
# NOTE(review): matplotlib-venn documents the 7-tuple as the sizes of the
# regions (Abc, aBc, ABc, abC, AbC, aBC, ABC). Slots 0, 1 and 3 receive the
# full set totals here instead of the "only ..." counts, so the circles are
# presumably drawn larger than the sets they stand for. The values are left
# unchanged because the hand-placed coordinates further down depend on the
# current layout — confirm the intended sizing.
venn = venn3(
    subsets=(
        size_old_school,
        size_new_school,
        old_and_new,
        size_snowballing,
        old_and_snowballing,
        new_and_snowballing,
        all_three,
    ),
    set_labels=(
        f"Old-school ($k_{{total}}={size_old_school}$)",
        f"New-school ($k_{{total}}={size_new_school}$)",
        f"Snowballing ($k_{{total}}={int(size_snowballing)}$)",
    ),
    set_colors=("#009739", "#FFCC00", "#002776"),
    subset_label_formatter=None,
    alpha=1.0,
)

# Overwrite the region labels with the exclusive region counts. The lookup is
# guarded for regions without a label (None), in the same way the dot helper
# skips regions without a patch.
region_counts = {
    "100": only_old_school,
    "010": only_new_school,
    "001": only_snowballing,
    "110": old_and_new,
    "101": old_and_snowballing,
    "011": new_and_snowballing,
    "111": all_three,
}
for region_id, region_count in region_counts.items():
    region_label = venn.get_label_by_id(region_id)
    if region_label is not None:
        region_label.set_text(f"{region_count}")

# Get the axes object
ax = plt.gca()

# Add dots to the regions that exist: (region id, number of dots, color, size).
# The only-old-school region gets one yellow dot fewer because one white dot
# is added to it by hand below.
region_dots = (
    ("100", only_old_school - 1, "#FEDD00", 1),
    ("010", only_new_school, "#FFFFFF", 4),
    ("001", only_snowballing, "#FFFFFF", 4),
    ("110", old_and_new, "#FEDD00", 1),
    ("101", old_and_snowballing, "#FEDD00", 1),
    ("011", new_and_snowballing, "#FFFFFF", 4),
    ("111", all_three, "#FEDD00", 1),
)
for region_id, region_num_dots, region_dot_color, region_dot_size in region_dots:
    add_dots_efficiently(
        ax,
        venn.get_patch_by_id(region_id),
        region_num_dots,
        region_dot_color,
        markersize=region_dot_size,
    )

# Manually add one white dot at fixed coordinates chosen for the
# only_old_school region
ax.plot(-0.2, 0.35, "o", color="#FFFFFF", markersize=4)

# Separate, semi-transparent circle for the fulltext search, drawn at a fixed
# position and in a different green
fulltext_circle = patches.Circle((-0.4, -0.4), 0.1, color="#4CAF50", alpha=0.5)
ax.add_patch(fulltext_circle)

# Add a label for the fulltext circle
ax.text(
    -0.4,
    -0.55,
    f'Fulltext ($k_{{total}}={filtered_m_copy["search_fulltext"].sum()}$)',
    ha="center",
    va="center",
    fontsize=12,
    color="black",
)

# Add a single white dot inside the fulltext circle, slightly off its centre
ax.plot(-0.38, -0.42, "o", color="#FFFFFF", markersize=4)


# save plot as pdf and PNG in /output folder
plt.savefig("output/venn_diagram.pdf")
plt.savefig("output/venn_diagram.png")

# Display the Venn diagram with dots
plt.show()