# Basics

## Imports

In [119]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

## Load Data

In [None]:
# Locate the four Excel workbooks inside the local "data" directory.
# Indexing with [0] raises IndexError when a workbook is not found.
_data_dir = Path(os.path.join("data"))
input_file_foras = list(_data_dir.glob("PTSS_Data_Foras.xlsx"))[0]
input_file_synergy = list(_data_dir.glob("PTSS_Data_Synergy.xlsx"))[0]
_fulltext_foras_path = list(_data_dir.glob("PTSS_Data_Foras_Fulltext.xlsx"))[0]
_fulltext_synergy_path = list(_data_dir.glob("PTSS_Data_Synergy_Fulltext.xlsx"))[0]

# Report which files the results are based on
print(
    "Results based on file: ",
    *(
        input_file_foras,
        input_file_synergy,
        _fulltext_foras_path,
        _fulltext_synergy_path,
    ),
)

# FORAS: keep the raw table and a copy without the rows flagged as duplicates
foras_unfiltered = pd.read_excel(input_file_foras)
foras_filtered = foras_unfiltered[foras_unfiltered["filter_duplicate"].ne(1)]

# Remaining workbooks
synergy = pd.read_excel(input_file_synergy)
fulltext_foras = pd.read_excel(_fulltext_foras_path)
fulltext_synergy = pd.read_excel(_fulltext_synergy_path)

# Row count of every table
for _label, _frame in (
    ("Number of records in original FORAS file: ", foras_unfiltered),
    ("Number of records in FORAS after filtering duplicates: ", foras_filtered),
    ("Number of records in SYNERGY", synergy),
    ("Number of records in FORAS fulltext", fulltext_foras),
    ("Number of records in SYNERGY fulltext", fulltext_synergy),
):
    print(_label, len(_frame))

## Create sets with variables

In [121]:
# Shorter names for plotting
# Maps each search column name to the label shown for it in plots.
# NOTE(review): the "inlusion" spelling in two keys presumably mirrors the
# column headers of the data files; verify against the data before correcting it.
short_names = {
    "search_replication": "Replication",
    "search_comprehensive": "Comprehensive",
    "search_snowballing": "Snowballing",
    "search_fulltext": "Fulltext",
    # Keys ending in "_long" get the plain "OpenAlex" prefix, their
    # counterparts without the suffix get "OpenAlex-short".
    "search_openalex_inlusion_criteria": "OpenAlex-short: Inclusion Criteria",
    "search_openalex_inlusion_criteria_long": "OpenAlex: Inclusion Criteria",
    "search_openalex_logistic": "OpenAlex-short: Logistic",
    "search_openalex_logistic_long": "OpenAlex: Logistic",
    "search_openalex_all_abstracts": "OpenAlex-short: All Abstracts",
    "search_openalex_all_abstracts_long": "OpenAlex: All Abstracts",
}

## Tests

### Test if MID is unique

In [None]:
# Count the records lacking an 'MID' value in each dataset
num_records_without_mid_foras = foras_filtered["MID"].isnull().sum()
num_records_without_mid_synergy = synergy["MID"].isnull().sum()

# Count the surplus (duplicate) identifiers in each dataset
num_duplicate_ids_foras = len(foras_filtered["MID"]) - foras_filtered["MID"].nunique()
num_duplicate_ids_synergy = len(synergy["MID"]) - synergy["MID"].nunique()

# Run the same two checks on Foras first, then on Synergy. The first failing
# check of a dataset is printed and the remaining check for it is skipped.
for _dataset_name, _mid, _num_missing, _num_duplicates in (
    (
        "Foras",
        foras_filtered["MID"],
        num_records_without_mid_foras,
        num_duplicate_ids_foras,
    ),
    (
        "Synergy",
        synergy["MID"],
        num_records_without_mid_synergy,
        num_duplicate_ids_synergy,
    ),
):
    try:
        # Every record must carry an identifier
        assert (
            _mid.notnull().all()
        ), f"{_dataset_name} test failed: There are {_num_missing} records without an identifier in the 'MID' column."

        # No identifier may occur more than once
        assert (
            _mid.nunique() == len(_mid)
        ), f"{_dataset_name} test failed: There are {_num_duplicates} duplicate identifiers in the 'MID' column."

        print(
            f"{_dataset_name} test passed: 'MID' column contains no records without an identifier and all identifiers are unique."
        )
    except AssertionError as err:
        print(err)

### Test if PIDs, Titles and Abstracts are available

In [None]:
# Function to calculate missing data and plot for a given dataset
def analyze_missing_data(dataset, dataset_name):
    """Print and plot how many records lack a DOI, OpenAlex ID, title or abstract.

    Args:
        dataset: DataFrame providing the columns "doi", "openalex_id",
            "title" and "abstract".
        dataset_name: Label used in the printed lines and in the chart title.

    Returns:
        None. The counts are printed and a bar chart is displayed; bar heights
        are percentages of all records and each bar is annotated with the
        absolute count.
    """
    doi_missing = dataset["doi"].isnull()
    openalex_id_missing = dataset["openalex_id"].isnull()

    # Calculate the number of records with missing values for specified columns
    num_records_without_doi = doi_missing.sum()
    num_records_without_openalex_id = openalex_id_missing.sum()
    num_records_without_both = dataset[doi_missing & openalex_id_missing].shape[0]
    num_records_without_title = dataset["title"].isnull().sum()
    num_records_without_abstract = dataset["abstract"].isnull().sum()

    # Print the number of records without certain values
    print(
        f"{dataset_name} - Number of records without a DOI: {num_records_without_doi}"
    )
    print(
        f"{dataset_name} - Number of records without an OpenAlex ID: {num_records_without_openalex_id}"
    )
    print(
        f"{dataset_name} - Number of records without a Title: {num_records_without_title}"
    )
    print(
        f"{dataset_name} - Number of records without an Abstract: {num_records_without_abstract}"
    )

    # Data for plotting
    categories = [
        "DOI Missing",
        "OpenAlex ID Missing",
        "Both Missing",
        "Title Missing",
        "Abstract Missing",
    ]
    values = [
        num_records_without_doi,
        num_records_without_openalex_id,
        num_records_without_both,
        num_records_without_title,
        num_records_without_abstract,
    ]
    total_records = dataset.shape[0]
    # An empty dataset has nothing missing; report 0% instead of dividing by
    # zero (the plain-int "both missing" count would raise ZeroDivisionError).
    percentages = [
        (value / total_records) * 100 if total_records else 0.0 for value in values
    ]

    # Creating the bar chart with percentages
    plt.figure(figsize=(10, 6))
    bars = plt.bar(
        categories, percentages, color=["blue", "orange", "green", "purple", "pink"]
    )

    # Adding title and labels for visualization
    plt.title(f"Percentage of Missing Data in Each Category ({dataset_name})")
    plt.xlabel("Missing Data Category")
    plt.ylabel("Percentage of Total Records")

    # Annotate each bar with its absolute value, centred just above the bar
    for bar, value in zip(bars, values):
        plt.text(
            bar.get_x() + bar.get_width() / 2.0,
            bar.get_height(),
            f"{value}",
            ha="center",
            va="bottom",
        )

    # Display the chart
    plt.show()


# Run the missing-data report once per dataset, Foras first
for _dataset, _dataset_name in ((foras_filtered, "Foras"), (synergy, "Synergy")):
    analyze_missing_data(_dataset, _dataset_name)

### Test if search columns are correct

In [None]:
# Names of the search columns checked below
search_columns = [
    "search_replication",
    "search_comprehensive",
    "search_snowballing",
    "search_fulltext",
    "search_openalex_inlusion_criteria",
    "search_openalex_inlusion_criteria_long",
    "search_openalex_logistic",
    "search_openalex_logistic_long",
    "search_openalex_all_abstracts",
    "search_openalex_all_abstracts_long",
]

# Test 1: every search column of every record holds either 0 or 1
try:
    _row_has_only_flags = foras_unfiltered[search_columns].isin([0, 1]).all(axis=1)
    invalid_search_values = foras_unfiltered.loc[~_row_has_only_flags, "MID"]
    assert (
        invalid_search_values.empty
    ), f"Invalid values in search columns for MIDs: {invalid_search_values.tolist()}"

    print("All values in search columns are 0 or 1.")

except AssertionError as err:
    print(err)

# Test 2: every record is flagged with a '1' in at least one search column
try:
    _flags_per_record = foras_unfiltered[search_columns].sum(axis=1)
    records_with_no_one = foras_unfiltered.loc[_flags_per_record.eq(0), "MID"]
    assert (
        records_with_no_one.empty
    ), f"Records with no '1' in search columns for MIDs: {records_with_no_one.tolist()}"

    print("All records in the search columns have at least one '1'.")

except AssertionError as err:
    print(err)

### Check for valid values in eligibility variables

In [None]:
# Specify the column names to use for eligibility determination


def _invalid_value_mids(frame, columns, allowed_values):
    """Map each column to the MIDs whose non-missing value is not allowed.

    Missing entries are not reported here; completeness is checked separately
    for the columns that must not contain missing values.
    """
    return {
        col: frame.loc[
            ~frame[col].isin(allowed_values) & frame[col].notna(), "MID"
        ].tolist()
        for col in columns
    }


def _drop_empty(mids_per_column):
    """Keep only the columns that have at least one offending MID."""
    return {col: mids for col, mids in mids_per_column.items() if mids}


eligibility_columns_YNQ = [
    # TI-AB level Bruno
    "title_eligible_Bruno",
    "TI-AB_IC1_Bruno",
    "TI-AB_IC2_Bruno",
    "TI-AB_IC3_Bruno",
    "TI-AB_IC4_Bruno",
    "TI-AB_final_label_Bruno",
    # TI-AB level Rutger
    "title_eligible_Rutger",
    "TI-AB_final_label_Rutger",
    # final TI-AB IC labels
    "TI-AB_IC1_final",
    "TI-AB_IC2_final",
    "TI-AB_IC3_final",
    "TI-AB_IC4_final",
]

eligibility_columns_TU = [
    # LLM output in True/Unknown
    "TI-AB_IC1_LLM",
    "TI-AB_IC2_LLM",
    "TI-AB_IC3_LLM",
    "TI-AB_IC4_LLM",
    "TI-AB_final_label_LLM",
]

eligibility_columns_YNQD = [
    # joint IC labels D=disagreement
    "TI-AB_IC1_joint",
    "TI-AB_IC2_joint",
    "TI-AB_IC3_joint",
    "TI-AB_IC4_joint",
    "FT_IC1_joint",
    "FT_IC2_joint",
    "FT_IC3_joint",
    "FT_IC4_joint",
]

eligibility_columns_YNR = [
    # FT labels NR=not relevant
    "FT_IC1_Bruno",
    "FT_IC2_Bruno",
    "FT_IC3_Bruno",
    "FT_IC4_Bruno",
]

eligibility_columns_YN = [
    "FT_IC1_final",
    "FT_IC2_final",
    "FT_IC3_final",
    "FT_IC4_final",
]

# All label columns are converted to the categorical dtype
for col in (
    eligibility_columns_YNQ
    + eligibility_columns_TU
    + eligibility_columns_YNQD
    + eligibility_columns_YNR
    + eligibility_columns_YN
):
    foras_unfiltered[col] = foras_unfiltered[col].astype("category")

eligibility_columns_binary = [
    "TI_final_label",
    "TI-AB_disagreement_human-human",
    "TI-AB_final_label",
    "LLM_re-assessed",
    "TI-AB_disagreement_human-LLM",
    "full_text_available",
    "FT_inclusion_Bruno",
    "FT_inclusion_Rutger",
    "FT_disagreements_Bruno-Rutger",
    "FT_final_label",
    "label_included",
]

# Binary columns are converted to numbers. errors="coerce" replaces anything
# non-numeric with NaN, which the validity check below would then treat as
# "missing" rather than "invalid". Remember those entries before they are lost.
# (Once a column has been coerced, re-running this cell cannot detect them.)
non_numeric_binary_mids_per_column = {}
for col in eligibility_columns_binary:
    numeric_values = pd.to_numeric(foras_unfiltered[col], errors="coerce")
    lost_to_coercion = foras_unfiltered[col].notna() & numeric_values.isna()
    non_numeric_binary_mids_per_column[col] = foras_unfiltered.loc[
        lost_to_coercion, "MID"
    ].tolist()
    foras_unfiltered[col] = numeric_values


eligibility_columns_no_missing = [
    "title_eligible_Bruno",
    "TI-AB_final_label_Bruno",
    "TI_final_label",
    "label_included",
]

# Test for valid values

try:
    # Identify invalid values per column in eligibility_columns_YNQ
    invalid_ynq_values_per_column = _invalid_value_mids(
        foras_unfiltered, eligibility_columns_YNQ, ["Y", "N", "Q"]
    )

    # Identify invalid values per column in eligibility_columns_TU
    invalid_TU_values_per_column = _invalid_value_mids(
        foras_unfiltered, eligibility_columns_TU, ["True", "Unknown"]
    )

    # Identify invalid values per column in eligibility_columns_YNQD
    invalid_YNQD_values_per_column = _invalid_value_mids(
        foras_unfiltered,
        eligibility_columns_YNQD,
        ["Y", "N", "Q", "D", "D_fi", "D_llm", "D_re"],
    )

    # Identify invalid values per column in eligibility_columns_YNR
    invalid_YNR_values_per_column = _invalid_value_mids(
        foras_unfiltered, eligibility_columns_YNR, ["Y", "N", "NR"]
    )

    # Identify invalid values per column in eligibility_columns_YN
    invalid_YN_values_per_column = _invalid_value_mids(
        foras_unfiltered, eligibility_columns_YN, ["Y", "N"]
    )

    # Identify invalid binary values per column in eligibility_columns_binary:
    # numbers other than 0/1 plus the entries that were not numeric at all
    invalid_binary_values_per_column = _invalid_value_mids(
        foras_unfiltered, eligibility_columns_binary, [0, 1]
    )
    for col, mids in non_numeric_binary_mids_per_column.items():
        invalid_binary_values_per_column[col] = (
            mids + invalid_binary_values_per_column[col]
        )

    # Identify missing values in columns that should not have missing values in eligibility_columns_no_missing
    missing_values_per_column = {
        col: foras_unfiltered.loc[pd.isna(foras_unfiltered[col]), "MID"].tolist()
        for col in eligibility_columns_no_missing
    }

    # The asserts run in order; the first failing group is printed and the
    # remaining groups are not evaluated.

    # Check if there are any invalid values for TI-AB level columns with only YNQ values
    ynq_failures = _drop_empty(invalid_ynq_values_per_column)
    assert not ynq_failures, f"Invalid values in YNQ columns:\n{ynq_failures}"

    # Check if there are any invalid values for TI-AB_IC._LLM columns with only True/Unknown values
    TU_failures = _drop_empty(invalid_TU_values_per_column)
    assert not TU_failures, f"Invalid values in TI-AB_IC._LLM columns:\n{TU_failures}"

    # Check if there are any invalid values for TI-AB_IC._joint columns with only YNQD values
    YNQD_failures = _drop_empty(invalid_YNQD_values_per_column)
    assert (
        not YNQD_failures
    ), f"Invalid values in TI-AB_IC._joint columns:\n{YNQD_failures}"

    # Check if there are any invalid values for FT_IC._Bruno columns with only YNR values
    YNR_failures = _drop_empty(invalid_YNR_values_per_column)
    assert not YNR_failures, f"Invalid values in FT_IC._Bruno columns:\n{YNR_failures}"

    # Check if there are any invalid values for FT_IC._final with only YN values
    yn_failures = _drop_empty(invalid_YN_values_per_column)
    assert not yn_failures, f"Invalid values in FT_IC._final columns:\n{yn_failures}"

    # Check if there are any invalid values for binary columns
    binary_failures = _drop_empty(invalid_binary_values_per_column)
    assert not binary_failures, f"Invalid values in binary columns:\n{binary_failures}"

    # Check if there are any missing values in columns that should not have missing values
    no_missing_failures = _drop_empty(missing_values_per_column)
    assert not no_missing_failures, f"Missing values in columns that should not have missing values:\n{no_missing_failures}"

    print(
        "Test passed: All records have valid values and no missing values where not allowed."
    )
except AssertionError as e:
    print(e)

## Test for logical combinations of values

#### Logical combination test for Title Eligibility

Conditions:
1. If `title_eligible_Bruno == 'Y'` OR `title_eligible_Rutger == 'Y'`, then `TI_final_label` must be `1`.
2. If `title_eligible_Bruno == 'Y'` AND `title_eligible_Rutger == 'N'` (or vice versa), then `TI_final_label` must be `1`.
3. If `title_eligible_Bruno == 'N'` AND `title_eligible_Rutger == 'N'`, then `TI_final_label` must be `0`.
4. If `title_eligible_Bruno == 'N'` AND `title_eligible_Rutger` is missing, then `TI_final_label` must be `0`.


In [None]:
try:

    def validate_ti_final_label(row):
        """Return the title-level rule violations of one record (empty list if none)."""
        bruno_label = row["title_eligible_Bruno"]
        rutger_label = row["title_eligible_Rutger"]
        final_label = row["TI_final_label"]

        # Each entry: (rule applies to this record, required TI_final_label, message)
        rules = (
            (
                bruno_label == "Y" or rutger_label == "Y",
                1,
                "Rule 1: TI_final_label != 1 when title_eligible_Bruno or title_eligible_Rutger is Y",
            ),
            (
                bruno_label == "N" and rutger_label == "N",
                0,
                "Rule 2: TI_final_label != 0 when both title_eligible_Bruno and title_eligible_Rutger are N",
            ),
            (
                bruno_label == "N" and pd.isna(rutger_label),
                0,
                "Rule 3: TI_final_label != 0 when title_eligible_Bruno is N and title_eligible_Rutger is missing",
            ),
            (
                bruno_label == "Y" and rutger_label == "N",
                1,
                "Rule 4: TI_final_label != 1 when title_eligible_Bruno is Y and title_eligible_Rutger is N",
            ),
            (
                bruno_label == "N" and rutger_label == "Y",
                1,
                "Rule 5: TI_final_label != 1 when title_eligible_Bruno is N and title_eligible_Rutger is Y",
            ),
        )

        # Rules 4 and 5 are special cases of Rule 1, so a record breaking one
        # of them is reported under both rule numbers.
        return [
            message
            for applies, required_label, message in rules
            if applies and final_label != required_label
        ]

    # One list of violations per record
    failing_records = foras_unfiltered.apply(validate_ti_final_label, axis=1)

    # Keep only the records that violate something, keyed by their MID
    violations_by_mid = {}
    for position, record_violations in enumerate(failing_records):
        if record_violations:
            mid = foras_unfiltered.iloc[position]["MID"]
            violations_by_mid[mid] = record_violations

    assert (
        not violations_by_mid
    ), f"Logical combination test failed at Title level:\n{violations_by_mid}"

    print(
        "Logical combination test passed for Title level: All records satisfy the conditions."
    )
except AssertionError as err:
    print(err)

#### Logical combination test for Inclusion Criteria and TI-AB for Human-Human agreement

Conditions:
1. **TI_final_label Validation**:
   - If `TI_final_label == 1`, all `TI-AB_IC._Bruno` columns (`IC1`, `IC2`, `IC3`, `IC4`) must have a value.
2. **TI-AB_final_label_Bruno Validation**:
   - If all `TI-AB_IC._Bruno` columns are `Y` or `Q`, then `TI-AB_final_label_Bruno` must be `Y`. Otherwise, it must be `N`.
3. **Agreement Check**:
   - If `TI-AB_final_label_Bruno == TI-AB_final_label_Rutger`, then `TI-AB_disagreement_human-human = 0`.
   - If they differ, then `TI-AB_disagreement_human-human = 1`.
   - This check is skipped when `TI-AB_final_label_Rutger` is missing.

In [None]:
try:

    def validate_inclusion_criteria(row):
        """Return the TI-AB human-human rule violations of one record.

        Args:
            row: One record of the FORAS table (a row passed by DataFrame.apply).

        Returns:
            List of violation messages; empty when the record satisfies all rules.
        """
        # The title-level decision is the numeric TI_final_label column (0/1),
        # not the Y/N/Q column title_eligible_Bruno; reading the latter made
        # the `== 1` comparison of Rule 1 impossible to satisfy.
        ti_final_label = row["TI_final_label"]
        ic_columns = [
            "TI-AB_IC1_Bruno",
            "TI-AB_IC2_Bruno",
            "TI-AB_IC3_Bruno",
            "TI-AB_IC4_Bruno",
        ]
        final_label_bruno = row["TI-AB_final_label_Bruno"]
        final_label_rutger = row["TI-AB_final_label_Rutger"]
        disagreement_human_human = row["TI-AB_disagreement_human-human"]

        # Track rule violations for the current record
        violations = []

        # Rule 1: If TI_final_label = 1, all TI-AB_IC._Bruno columns must have a value
        if ti_final_label == 1 and any(pd.isna(row[col]) for col in ic_columns):
            violations.append(
                "Rule 1: Missing value in TI-AB_IC._Bruno columns when TI_final_label = 1"
            )

        # Rule 2: All TI-AB_IC._Bruno are Y or Q -> TI-AB_final_label_Bruno is
        # the inclusion label, otherwise the exclusion label. The final label
        # is compared as "Y"/"N"; the messages refer to these as 1/0.
        if all(row[col] in ["Y", "Q"] for col in ic_columns):
            if final_label_bruno != "Y":
                violations.append(
                    "Rule 2: TI-AB_final_label_Bruno != 1 when all TI-AB_IC._Bruno are Y or Q"
                )
        elif final_label_bruno != "N":
            violations.append(
                "Rule 2: TI-AB_final_label_Bruno != 0 when TI-AB_IC._Bruno are not all Y or Q"
            )

        # Rule 3: Agreement/disagreement check, only when Rutger's label is not missing
        if pd.notna(final_label_rutger):
            if (
                final_label_bruno == final_label_rutger
                and disagreement_human_human != 0
            ):
                violations.append("Rule 3: Disagreement flag != 0 when labels agree")
            if (
                final_label_bruno != final_label_rutger
                and disagreement_human_human != 1
            ):
                violations.append("Rule 3: Disagreement flag != 1 when labels disagree")

        # Return the list of violations (empty if none)
        return violations

    # Apply validation across rows
    failing_records = foras_unfiltered.apply(validate_inclusion_criteria, axis=1)

    # Collect MIDs with specific violations
    violations_by_mid = {
        foras_unfiltered.iloc[idx]["MID"]: fail
        for idx, fail in enumerate(failing_records)
        if fail
    }

    assert not violations_by_mid, f"Logical combination test for Inclusion Criteria Eligibility failed:\n{violations_by_mid}"

    print(
        "Logical combination test for Inclusion Criteria Eligibility passed: All records satisfy the conditions."
    )
except AssertionError as e:
    print(e)

#### Logical combination check for LLM results

Conditions:

1. **Rule 1a**:
   - If any `TI-AB_IC._LLM` column is missing (`NaN`), then `TI-AB_final_label_LLM` must also be missing (`NaN`).
   - Violation: `TI-AB_final_label_LLM` is not missing when any `TI-AB_IC._LLM` column is missing.

2. **Rule 1b**:
   - If all `TI-AB_IC._LLM` columns are `"True"`, then `TI-AB_final_label_LLM` must be `"True"`.
   - Violation: `TI-AB_final_label_LLM` is not `"True"` when all `TI-AB_IC._LLM` columns are `"True"`.

3. **Rule 1c**:
   - If any `TI-AB_IC._LLM` column is `"Unknown"`, then `TI-AB_final_label_LLM` must also be `"Unknown"`.
   - Violation: `TI-AB_final_label_LLM` is not `"Unknown"` when any `TI-AB_IC._LLM` column is `"Unknown"`.

4. **Rule 2**:
   - If `TI-AB_final_label_LLM` is `"True"` but either `TI-AB_final_label_Bruno` or `TI-AB_final_label_Rutger` is `0`, then `TI-AB_disagreement_human-LLM` must be `1`.
   - Ignore missing values for `TI-AB_final_label_Bruno` and `TI-AB_final_label_Rutger`.
   - Violation: `TI-AB_disagreement_human-LLM` is not `1` when `TI-AB_final_label_LLM` is `"True"` and either Bruno or Rutger's final label is `0`.


In [None]:
try:

    def validate_llm_criteria(row):
        """Return the LLM-related rule violations of one record.

        Args:
            row: One record of the FORAS table (a row passed by DataFrame.apply).

        Returns:
            List of violation messages; empty when the record satisfies all rules.
        """
        llm_columns = [
            "TI-AB_IC1_LLM",
            "TI-AB_IC2_LLM",
            "TI-AB_IC3_LLM",
            "TI-AB_IC4_LLM",
        ]
        final_label_llm = row["TI-AB_final_label_LLM"]
        final_label_bruno = row["TI-AB_final_label_Bruno"]
        final_label_rutger = row["TI-AB_final_label_Rutger"]
        disagreement_human_llm = row["TI-AB_disagreement_human-LLM"]

        # Track rule violations for the current record
        violations = []

        # Rule 1a: If any TI-AB_IC._LLM column is missing, TI-AB_final_label_LLM must be missing
        if any(pd.isna(row[col]) for col in llm_columns) and pd.notna(final_label_llm):
            violations.append(
                "Rule 1a: TI-AB_final_label_LLM is not missing when any TI-AB_IC._LLM is missing"
            )

        # Rule 1b: If all TI-AB_IC._LLM columns are True, TI-AB_final_label_LLM must be True
        if all(row[col] == "True" for col in llm_columns) and final_label_llm != "True":
            violations.append(
                "Rule 1b: TI-AB_final_label_LLM is not True when all TI-AB_IC._LLM are True"
            )

        # Rule 1c: If any TI-AB_IC._LLM column is Unknown, TI-AB_final_label_LLM must also be Unknown
        if (
            any(row[col] == "Unknown" for col in llm_columns)
            and final_label_llm != "Unknown"
        ):
            violations.append(
                "Rule 1c: TI-AB_final_label_LLM is not Unknown when any TI-AB_IC._LLM is Unknown"
            )

        # Rule 2: If TI-AB_final_label_LLM is True but either Bruno or Rutger's
        # final label is an exclusion, disagreement_human-LLM must be 1.
        # The human final labels are validated as "Y"/"N"/"Q" elsewhere in this
        # notebook, so an exclusion is "N"; comparing with 0 alone never matched.
        # A numeric 0 is still accepted in case the labels are encoded as 0/1.
        # Missing human labels match neither value and are therefore ignored.
        excluded_labels = ("N", 0)
        human_excluded = (
            final_label_bruno in excluded_labels
            or final_label_rutger in excluded_labels
        )
        if final_label_llm == "True" and human_excluded and disagreement_human_llm != 1:
            violations.append(
                "Rule 2: TI-AB_disagreement_human-LLM != 1 when TI-AB_final_label_LLM is True but Bruno or Rutger is 0"
            )

        # Return the list of violations (empty if none)
        return violations

    # Apply validation across rows
    failing_records = foras_unfiltered.apply(validate_llm_criteria, axis=1)

    # Collect MIDs with specific violations
    violations_by_mid = {
        foras_unfiltered.iloc[idx]["MID"]: fail
        for idx, fail in enumerate(failing_records)
        if fail
    }

    assert (
        not violations_by_mid
    ), f"Logical combination test for LLM Criteria failed:\n{violations_by_mid}"

    print(
        "Logical combination test for LLM Criteria passed: All records satisfy the conditions."
    )
except AssertionError as e:
    print(e)

#### Test for TI-AB_IC._joint Variables

Conditions:

1. **Rule 1**:
   - If `TI-AB_disagreement_human-human = 1`, at least one `TI-AB_IC._joint` must contain `"D"`.
   - **Violation**: No `TI-AB_IC._joint` contains `"D"` when `TI-AB_disagreement_human-human = 1`.

2. **Rule 2**:
   - If `TI-AB_disagreement_human-LLM = 1`, at least one `TI-AB_IC._joint` must contain `"D_fi"`, corresponding to the IC variable where `TI-AB_IC._Bruno` disagrees with `TI-AB_IC._LLM`.
   - **Violation**: No `TI-AB_IC._joint` contains `"D_fi"` where `TI-AB_IC._Bruno` disagrees with `TI-AB_IC._LLM`.

3. **Rule 3**:
   - If `TI-AB_IC._Bruno` contains `"N"` but `TI-AB_IC._LLM` is `"True"`, the corresponding `TI-AB_IC._joint` must be `"D_llm"` if not already `"D"`.
   - **Violation**: The `TI-AB_IC._joint` does not contain `"D_llm"` when `TI-AB_IC._Bruno = "N"` and `TI-AB_IC._LLM = "True"`.

4. **Rule 4**:
   - If `LLM_re-assessed = 1`, only `TI-AB_IC4_joint` can have the value `"D_re"`.
   - **Violation**: `TI-AB_IC4_joint` does not contain `"D_re"` when `LLM_re-assessed = 1`, or other `TI-AB_IC._joint` variables contain `"D_re"`.

5. **Rule 5**:
   - For all other cases, `TI-AB_IC._Bruno` must have the same value (`Y`, `N`, `Q`, or missing) as the corresponding `TI-AB_IC._joint` variable.
   - **Violation**: `TI-AB_IC._joint` does not match the value of `TI-AB_IC._Bruno`.

In [None]:
try:

    def validate_joint_criteria(row):
        """Return the TI-AB_IC._joint rule violations of one record (empty list if none)."""
        bruno_cols = [
            "TI-AB_IC1_Bruno",
            "TI-AB_IC2_Bruno",
            "TI-AB_IC3_Bruno",
            "TI-AB_IC4_Bruno",
        ]
        llm_cols = [
            "TI-AB_IC1_LLM",
            "TI-AB_IC2_LLM",
            "TI-AB_IC3_LLM",
            "TI-AB_IC4_LLM",
        ]
        joint_cols = [
            "TI-AB_IC1_joint",
            "TI-AB_IC2_joint",
            "TI-AB_IC3_joint",
            "TI-AB_IC4_joint",
        ]
        # (Bruno, LLM, joint) column names of each inclusion criterion
        criteria = list(zip(bruno_cols, llm_cols, joint_cols))

        human_human_flag = row["TI-AB_disagreement_human-human"]
        human_llm_flag = row["TI-AB_disagreement_human-LLM"]
        re_assessed_flag = row["LLM_re-assessed"]

        found = []

        # Rule 1: a human-human disagreement needs a `D` in some joint column
        if human_human_flag == 1 and not any(
            row[joint_col] == "D" for joint_col in joint_cols
        ):
            found.append(
                "Rule 1: TI-AB_disagreement_human-human = 1 but no TI-AB_IC._joint contains `D`"
            )

        # Rule 2: a human-LLM disagreement needs a `D_fi` on a criterion where
        # Bruno's value differs from the LLM's value.
        # NOTE(review): the Bruno columns are validated as Y/N/Q and the LLM
        # columns as True/Unknown, so the inequality below looks like it always
        # holds - confirm which value pairs should count as a disagreement.
        if human_llm_flag == 1:
            has_matching_d_fi = any(
                row[joint_col] == "D_fi" and row[bruno_col] != row[llm_col]
                for bruno_col, llm_col, joint_col in criteria
            )
            if not has_matching_d_fi:
                found.append(
                    "Rule 2: TI-AB_disagreement_human-LLM = 1 but no TI-AB_IC._joint contains `D_fi` where Bruno disagrees with LLM"
                )

        # Rule 3: only for records with human-human flag 0 and human-LLM flag
        # other than 1 - Bruno `N` with LLM `True` requires joint `D_llm`
        if human_human_flag == 0 and human_llm_flag != 1:
            for bruno_col, llm_col, joint_col in criteria:
                if (
                    row[bruno_col] == "N"
                    and row[llm_col] == "True"
                    and row[joint_col] != "D_llm"
                ):
                    found.append(
                        f"Rule 3: {joint_col} is not `D_llm` where Bruno = `N` and LLM = `True`, and there is no human disagreement (disagreement_human-human = 0)"
                    )

        # Rule 4: after an LLM re-assessment, `D_re` belongs in TI-AB_IC4_joint
        # and nowhere else. IC4 is the last joint column, so reporting the
        # other columns first keeps the messages in column order.
        if re_assessed_flag == 1:
            for joint_col in joint_cols:
                if joint_col == "TI-AB_IC4_joint":
                    continue
                if row[joint_col] == "D_re":
                    found.append(
                        "Rule 4: Other TI-AB_IC._joint contains `D_re` when LLM_re-assessed = 1"
                    )
            if row["TI-AB_IC4_joint"] != "D_re":
                found.append(
                    "Rule 4: TI-AB_IC4_joint is not `D_re` when LLM_re-assessed = 1"
                )

        # Rule 5: where Bruno has a value and the joint column carries no
        # disagreement code, both must hold the same value
        disagreement_codes = ("D", "D_llm", "D_re", "D_fi")
        for bruno_col, joint_col in zip(bruno_cols, joint_cols):
            if pd.isna(row[bruno_col]):
                continue
            if (
                row[bruno_col] != row[joint_col]
                and row[joint_col] not in disagreement_codes
            ):
                found.append(
                    f"Rule 5: {joint_col} does not match Bruno's value {bruno_col}"
                )

        return found

    # One list of violations per record
    failing_records = foras_unfiltered.apply(validate_joint_criteria, axis=1)

    # Keep only the records that violate something, keyed by their MID
    violations_by_mid = {}
    for position, record_violations in enumerate(failing_records):
        if record_violations:
            mid = foras_unfiltered.iloc[position]["MID"]
            violations_by_mid[mid] = record_violations

    assert not violations_by_mid, f"Logical combination test for TI-AB_IC._joint Criteria failed:\n{violations_by_mid}"

    print(
        "Logical combination test for TI-AB_IC._joint Criteria passed: All records satisfy the conditions."
    )
except AssertionError as err:
    print(err)

### Logical Combination Test: TI-AB_IC._final Variables

Conditions:

1. **Rule 1**:
   - If `TI_final_label = 1`, all `TI-AB_IC._final` columns must not be missing.
   - For disagreement labels (`D`, `D_fi`, `D_re`, `D_llm`) in `TI-AB_IC._joint`, the corresponding `TI-AB_IC._final` column must have a value of `"Y"`, `"N"`, or `"Q"`.
   - **Violation**:
     - `TI-AB_IC._final` is missing when `TI_final_label = 1`.
     - `TI-AB_IC._final` has an invalid value (not `"Y"`, `"N"`, or `"Q"`) for a disagreement label.

2. **Rule 2**:
   - For non-disagreement labels, the values in `TI-AB_IC._final` columns must match `TI-AB_IC._joint`.
   - **Violation**:
     - `TI-AB_IC._final` does not match the corresponding values in `TI-AB_IC._joint` for non-disagreement labels.


In [None]:
def validate_final_criteria(row):
    """Collect violations of the TI-AB_IC._final rules for one record.

    Parameters
    ----------
    row : pandas.Series
        One row as passed by ``DataFrame.apply(axis=1)``. The function reads
        ``TI_final_label``, ``LLM_re-assessed`` and the four
        ``TI-AB_IC<n>_joint`` / ``TI-AB_IC<n>_final`` columns.

    Returns
    -------
    list of str
        One message per violated condition; empty if the record is valid.
    """
    ic_column_pairs = (
        ("TI-AB_IC1_final", "TI-AB_IC1_joint"),
        ("TI-AB_IC2_final", "TI-AB_IC2_joint"),
        ("TI-AB_IC3_final", "TI-AB_IC3_joint"),
        ("TI-AB_IC4_final", "TI-AB_IC4_joint"),
    )
    disagreement_labels = ("D", "D_fi", "D_re", "D_llm")
    resolved_values = ("Y", "N", "Q")

    ti_final_label = row["TI_final_label"]
    llm_re_assessed = row["LLM_re-assessed"]

    # Track rule violations for the current record
    violations = []

    # Rule 1: final columns must not be missing if TI_final_label = 1, and
    # disagreement labels must be resolved to Y, N, or Q (records that were
    # re-assessed by the LLM are exempt from the second check).
    for final, joint in ic_column_pairs:
        if ti_final_label == 1 and pd.isna(row[final]):
            violations.append(f"Rule 1: {final} is missing when TI_final_label = 1")
        if (
            row[joint] in disagreement_labels
            and row[final] not in resolved_values
            and llm_re_assessed != 1
        ):
            violations.append(
                f"Rule 1: {final} has an invalid value (not Y, N, or Q) for a disagreement label {joint}"
            )

    # Rule 2: values in final must match joint, except for disagreement labels
    for final, joint in ic_column_pairs:
        if ti_final_label != 1 or row[joint] in disagreement_labels:
            continue
        # NaN != NaN is always True, so two missing values would otherwise be
        # reported as a mismatch; the missing final value is already reported
        # by Rule 1 above.
        if pd.isna(row[final]) and pd.isna(row[joint]):
            continue
        if row[final] != row[joint]:
            violations.append(
                f"Rule 2: {final} does not match {joint} for non-disagreement values"
            )

    # Return the list of violations (empty if none)
    return violations


# Apply validation across rows
failing_records = foras_unfiltered.apply(validate_final_criteria, axis=1)

# Collect MIDs with specific violations. Pairing the MID column with the
# results positionally avoids materialising a full row per record.
violations_by_mid = {
    mid: fail for mid, fail in zip(foras_unfiltered["MID"], failing_records) if fail
}

# Report with a plain conditional rather than assert: assert statements are
# removed under `python -O`, which would make the test always report success.
if violations_by_mid:
    print(
        f"Logical combination test for TI-AB_IC._final Criteria failed:\n{violations_by_mid}"
    )
else:
    print(
        "Logical combination test for TI-AB_IC._final Criteria passed: All records satisfy the conditions."
    )

### Logical Combination Test: TI-AB_final_label

Conditions:

1. **Rule 1**:
   - If `TI_final_label = 1`, `TI-AB_final_label` must be `0` or `1`.
   - If `TI_final_label = 0`, `TI-AB_final_label` must be empty (`NaN`).
   - **Violation**:
     - `TI-AB_final_label` has a value other than `0` or `1` when `TI_final_label = 1`.
     - `TI-AB_final_label` is not empty when `TI_final_label = 0`.

2. **Rule 2**:
   - If `TI_final_label = 1`, `TI-AB_disagreement_human-human != 1`, `TI-AB_disagreement_human-LLM != 1`, and `LLM_re-assessed != 1`, `TI-AB_final_label` must match `TI-AB_final_label_Bruno`:
     - `1` if `TI-AB_final_label_Bruno = Y`.
     - `0` if `TI-AB_final_label_Bruno = N`.
   - **Violation**:
     - `TI-AB_final_label` does not match Bruno's final label when no disagreements or re-assessment exist.

3. **Rule 3**:
   - If any `TI-AB_IC._final` contains `"N"`, `TI-AB_final_label` must be `0`.
   - **Violation**:
     - `TI-AB_final_label` is not `0` when any `TI-AB_IC._final` contains `"N"`.




In [None]:
def validate_final_label(row):
    """Collect violations of the TI-AB_final_label rules for one record.

    Parameters
    ----------
    row : pandas.Series
        One row as passed by ``DataFrame.apply(axis=1)``. The function reads
        ``TI_final_label``, ``TI-AB_final_label``, ``TI-AB_final_label_Bruno``,
        ``TI-AB_disagreement_human-human``, ``TI-AB_disagreement_human-LLM``,
        ``LLM_re-assessed`` and the four ``TI-AB_IC<n>_final`` columns.

    Returns
    -------
    list of str
        One message per violated condition; empty if the record is valid.
    """
    ic_columns_final = (
        "TI-AB_IC1_final",
        "TI-AB_IC2_final",
        "TI-AB_IC3_final",
        "TI-AB_IC4_final",
    )
    ti_final_label = row["TI_final_label"]
    final_label = row["TI-AB_final_label"]
    final_label_bruno = row["TI-AB_final_label_Bruno"]

    # Track rule violations for the current record
    violations = []

    # Rule 1: If TI_final_label = 1, TI-AB_final_label must be 0 or 1;
    # if TI_final_label = 0, it must be empty
    if ti_final_label == 1 and final_label not in (0, 1):
        violations.append(
            "Rule 1: TI-AB_final_label must be 0 or 1 when TI_final_label = 1"
        )
    if ti_final_label == 0 and pd.notna(final_label):
        violations.append(
            "Rule 1: TI-AB_final_label must be empty when TI_final_label = 0"
        )

    # Rule 2: If no disagreements and not re-assessed, TI-AB_final_label must
    # match Bruno's final label (only checked when TI_final_label = 1)
    undisputed = (
        row["TI-AB_disagreement_human-human"] != 1
        and row["TI-AB_disagreement_human-LLM"] != 1
        and row["LLM_re-assessed"] != 1
        and ti_final_label == 1
    )
    if undisputed:
        if final_label_bruno == "Y" and final_label != 1:
            violations.append(
                "Rule 2: TI-AB_final_label must be 1 when TI-AB_final_label_Bruno = Y and no disagreements or re-assessment"
            )
        if final_label_bruno == "N" and final_label != 0:
            violations.append(
                "Rule 2: TI-AB_final_label must be 0 when TI-AB_final_label_Bruno = N and no disagreements or re-assessment"
            )

    # Rule 3: TI-AB_final_label must be 0 if any TI-AB_IC._final contains an N
    if any(row[col] == "N" for col in ic_columns_final) and final_label != 0:
        violations.append(
            "Rule 3: TI-AB_final_label must be 0 when any TI-AB_IC._final contains N"
        )

    # Return the list of violations (empty if none)
    return violations


# Apply validation across rows
failing_records = foras_unfiltered.apply(validate_final_label, axis=1)

# Collect MIDs with specific violations. Pairing the MID column with the
# results positionally avoids materialising a full row per record.
violations_by_mid = {
    mid: fail for mid, fail in zip(foras_unfiltered["MID"], failing_records) if fail
}

# Report with a plain conditional rather than assert: assert statements are
# removed under `python -O`, which would make the test always report success.
if violations_by_mid:
    print(
        f"Logical combination test for TI-AB_final_label failed:\n{violations_by_mid}"
    )
else:
    print(
        "Logical combination test for TI-AB_final_label passed: All records satisfy the conditions."
    )