In [None]:
import os
import pandas as pd
import numpy as np
from lifelines.statistics import logrank_test
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines import LogLogisticAFTFitter, WeibullAFTFitter, LogNormalAFTFitter

import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
import pycountry

import ipywidgets as widgets
from IPython.display import display, clear_output


# --- ATLAS (Vivli) export: Excel workbook ---
atlas_path = "./Vivli Data Challenge 2025 - Nachtrab/ATLAS_Antibiotics/2025_03_11 atlas_antibiotics.xlsx"
atlas_df = pd.read_excel(atlas_path)

# --- CABBAGE export: dereplicated CSV ---
cabbage_path = "./CABBAGEdata/processed_database/dereplicated__all_clean.csv"
cabbage_df = pd.read_csv(cabbage_path)

# Discard, in place, every CABBAGE row whose measurement unit is not "mg/l"
# (per the original note these are the liquid-solution typing methods).
rows_not_in_mg_per_l = cabbage_df.index[cabbage_df["measurement_unit"] != "mg/l"]
cabbage_df.drop(index=rows_not_in_mg_per_l, inplace=True)

# --- Pre-computed species / antibiotic pairs ---
antibio_pairs_path = "./generated_data/sorted_species_antibio_pairs.csv"
pairs_df = pd.read_csv(antibio_pairs_path)

In [None]:
# ATLAS (Vivli) stores one isolate per row, tested against several antibiotics (one column each).
# CABBAGE stores one test per row, so a specimen/id may span several rows.

atlas_data = atlas_df.copy()
cabbage_data = cabbage_df.copy()

# --- MIC strings: "<sign><value>", e.g. "<=2" ---
# astype(str) turns a missing value into the literal text "nan", which dropna() would
# not remove downstream. Rows without a measurement are therefore kept as real NaN.
has_measurement = cabbage_data["measurement_value"].notna()
cabbage_data["combined_value"] = (
    cabbage_data["measurement_sign"].fillna("") + cabbage_data["measurement_value"].astype(str)
).where(has_measurement)

cabbage_antibiotics = cabbage_data["Antibiotic_name"].dropna().unique()

# --- One column per antibiotic (wide layout, to line up with ATLAS) ---
# Each column holds the MIC string on the rows that tested that antibiotic, NaN elsewhere.
values_by_antibiotic = {
    abx: cabbage_data["combined_value"].where(cabbage_data["Antibiotic_name"] == abx)
    for abx in cabbage_antibiotics
}

# First under the original CABBAGE name ...
antibiotic_columns = dict(values_by_antibiotic)

# ... then under the ATLAS-style name ("/" replaced by a space, plus one explicit alias).
for antibiotic, values in values_by_antibiotic.items():
    new_column_name = antibiotic.replace("/", " ")
    if new_column_name == "Trimethoprim sulfamethoxazole":
        new_column_name = "Trimethoprim sulfa"
    antibiotic_columns[new_column_name] = values

# Attach all new columns in a single concat: avoids the frame fragmentation and the
# float -> string dtype upcast caused by inserting them one by one.
cabbage_data = pd.concat(
    [
        cabbage_data.drop(columns=list(antibiotic_columns), errors="ignore"),
        pd.DataFrame(antibiotic_columns, index=cabbage_data.index),
    ],
    axis=1,
)

# Create a new Species column for cabbage combining genus and species
cabbage_data["combined_species"] = cabbage_data["genus"].astype(str) + " " + cabbage_data["species"]

# Unify sex: male -> 1.0, female -> 0.0.
# NOTE(review): the ATLAS section of this notebook documents "Male 0.0, Female 1.0",
# i.e. the opposite coding - confirm before comparing the two columns.
cabbage_data['host_sex'] = cabbage_data['host_sex'].replace({'male': 1.0, 'female': 0.0})

# Encode phenotype as the position in the category list below (0 .. 5).
# NOTE(review): the original remark says "non-susceptible = resistant", but the two
# categories receive different codes here; breakpoints also change by country and over time.
encoder = OrdinalEncoder(categories=[['susceptible', 'intermediate', 'resistant', 'decreased susceptibility', 'susceptible dose dependent', 'non-susceptible']])
mask = cabbage_data['phenotype'].notna()
cabbage_data['phenotype_encoded'] = np.nan
if mask.any():
    cabbage_data.loc[mask, 'phenotype_encoded'] = encoder.fit_transform(cabbage_data.loc[mask, ['phenotype']]).ravel()
cabbage_data['phenotype_encoded'] = cabbage_data['phenotype_encoded'].astype('float')

# Bin ages into the age groups used by ATLAS.
# include_lowest=True keeps age 0 in the first bin; without it the first interval is (0, 2]
# and newborns would silently become NaN.
bins = [0, 2, 12, 18, 64, 84, np.inf]
labels = ["0 to 2 Years", "3 to 12 Years", "13 to 18 Years",
          "19 to 64 Years", "65 to 84 Years", "85 and Over"]
cabbage_data['Age Group'] = pd.cut(cabbage_data['host_age'], bins=bins, labels=labels, right=True, include_lowest=True)

# Encode. TODO: different than vivli - change vivli?
age_ordinal = OrdinalEncoder(categories=[labels],
                            handle_unknown="use_encoded_value",
                            unknown_value=-1,
                            dtype=float)
cabbage_data['Age Group Encoded'] = age_ordinal.fit_transform(cabbage_data[['Age Group']]).ravel()
# -1 marks "no age group"; store it as NaN. Assign the result back: an in-place replace
# on a column selection does not reliably modify the frame.
cabbage_data['Age Group Encoded'] = cabbage_data['Age Group Encoded'].replace(-1, np.nan)

# Convert Year to numerical (the bare astype() call discarded its result).
cabbage_data['collection_date'] = cabbage_data['collection_date'].astype('float32')





# Encode atlas vivli data into a numerical format

# Interpretation columns ("<antibiotic>_I"): [nan 'Susceptible' 'Intermediate' 'Resistant'] -> nan, 0, 1, 2
columns_of_interest = [col for col in atlas_data.columns if col.endswith('_I')]
# Antibiotic names are the "_I" columns without their suffix. Stripping two characters
# from every column (as before) also produced truncated names of unrelated columns.
atlas_antibiotics = [col[:-2] for col in columns_of_interest]

if columns_of_interest:
    atlas_data[columns_of_interest] = atlas_data[columns_of_interest].astype("string")

    encoder = OrdinalEncoder(categories=[['Susceptible', 'Intermediate', 'Resistant']] * len(columns_of_interest),
                             handle_unknown="use_encoded_value", unknown_value=np.nan,
                             dtype=float)
    atlas_data[columns_of_interest] = encoder.fit_transform(atlas_data[columns_of_interest])


# Handle age and gender [using nans]
atlas_data['Age Group'] = atlas_data['Age Group'].replace('Unknown', np.nan)
age_ordinal = OrdinalEncoder(categories=[["0 to 2 Years", "3 to 12 Years", "13 to 18 Years", "19 to 64 Years", "65 to 84 Years", "85 and Over"]],
                            handle_unknown="use_encoded_value", unknown_value=np.nan,
                            dtype=float)

atlas_data["Age Group"] = age_ordinal.fit_transform(atlas_data[["Age Group"]]).ravel()


# Male 0.0, Female 1.0
# NOTE(review): the CABBAGE preparation in this notebook maps male -> 1.0 and female -> 0.0,
# i.e. the opposite coding - confirm before comparing the two columns.
atlas_data['Gender'] = atlas_data['Gender'].replace('Unknown', np.nan)
gender_ordinal = OrdinalEncoder(categories=[["Male", "Female"]],
                                handle_unknown="use_encoded_value", unknown_value=np.nan,
                                dtype=float)
atlas_data["Gender"] = gender_ordinal.fit_transform(atlas_data[["Gender"]]).ravel()

# Substitute the uninformative labels with NaN
uninformative_labels = ["None Given", "Other"]
atlas_data['Speciality'] = atlas_data['Speciality'].replace(uninformative_labels, np.nan)
atlas_data['In / Out Patient'] = atlas_data['In / Out Patient'].replace(uninformative_labels, np.nan)

# Convert Year to numerical (the bare astype() call discarded its result).
# Corresponds to Cabbage "collection_date".
atlas_data['Year'] = atlas_data['Year'].astype('float32')


# Convert country to ISO 3166-1 alpha-3.
def _country_to_alpha3(country_name):
    """Return the alpha-3 code pycountry finds for `country_name`, or None if there is none."""
    if not isinstance(country_name, str):
        return None
    try:
        # lookup() is the call that performs the conversion; guarding it with an
        # exact-name get() (as before) rejected names that lookup() can resolve.
        return pycountry.countries.lookup(country_name).alpha_3
    except LookupError:
        return None


# Resolve each distinct country once instead of once per row.
alpha3_by_country = {
    name: _country_to_alpha3(name) for name in atlas_data["Country"].dropna().unique()
}
atlas_data["country_alpha3"] = atlas_data["Country"].map(alpha3_by_country)




# Get common bacteria and antibiotics.
# Names are lowercased BEFORE intersecting so that a pure capitalisation difference
# between the two datasets does not hide a match; missing names are ignored.
cabbage_species_lower = set(cabbage_data["combined_species"].dropna().astype(str).str.lower())
atlas_species_lower = set(atlas_data["Species"].dropna().astype(str).str.lower())
overlapping_bacteria = sorted(cabbage_species_lower & atlas_species_lower)

# Counts are hoisted out of the f-string: quoting a column name inside a same-quoted
# f-string only parses on Python 3.12+.
n_cabbage_species = len(cabbage_data["combined_species"].unique())
n_atlas_species = len(atlas_data["Species"].unique())
print(f"overlapping_bacteria length: {len(overlapping_bacteria)}, cabbage species: {n_cabbage_species}, atlas species: {n_atlas_species}")


cabbage_antibiotics_lower = {str(name).lower() for name in cabbage_antibiotics}
atlas_antibiotics_lower = {str(name).lower() for name in atlas_antibiotics}
overlapping_antibio = sorted(cabbage_antibiotics_lower & atlas_antibiotics_lower)
print(f"overlapping_antibio length: {len(overlapping_antibio)}")


In [None]:
def print_yoy_data(analysis_df, species, antibiotic, location=None, years=range(2000, 2026)):
    """
    Print, year by year, how many rows of an ATLAS dataframe have a value for an antibiotic.

    Parameters:
        analysis_df: dataframe with "Species" and "Year" columns, one column per
            antibiotic, and (when `location` is given) a "country_alpha3" column.
        species: species name, matched case-insensitively against "Species".
        antibiotic: antibiotic name, matched case-insensitively against the column names.
        location: optional country code, matched case-insensitively against "country_alpha3".
        years: iterable of years to report (defaults to 2000 .. 2025).

    Raises:
        KeyError: if no column matches `antibiotic`.
    """
    antibiotic_column_name = next(
        (col for col in analysis_df.columns if str(col).lower() == antibiotic.lower()),
        None,
    )
    if antibiotic_column_name is None:
        # A bare next() would leak StopIteration, which says nothing about the cause.
        raise KeyError(f"No column matching antibiotic {antibiotic!r}")

    subset = analysis_df[analysis_df["Species"].str.lower() == species.lower()]
    if location:
        subset = subset[subset["country_alpha3"].str.lower() == location.lower()]

    # Only rows that actually carry a value for this antibiotic are counted.
    subset = subset.dropna(subset=[antibiotic_column_name])

    print(f"For country: {location}")
    for year in years:
        print(f"Year: {year}, nb: {(subset['Year'] == year).sum()}")

In [None]:
# DataFrame.info() writes its report itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
pairs_df.info()
# Lift the row limit so the 50-row preview below is shown in full.
pd.set_option('display.max_rows', None)
pairs_df.head(50)

In [None]:
# Report graphs
import seaborn as sns
import scipy

# Parse value for Left truncated (late entry) data
# https://lifelines.readthedocs.io/en/latest/Survival%20analysis%20with%20lifelines.html#left-truncated-late-entry-data
def parse_value(val):
    """
    Parse an MIC entry such as '<=2', '>8', '=0.5' or '16' into censoring bounds.

    Mapping (v is the number following the prefix):
        '>,<a,b' -> (a, b, True)      interval entry
        '>=v'    -> (v, v, True)
        '>v'     -> (v, 2*v, False)
        '<=v'    -> (0.0, v, True)
        '<v'     -> (0.0, v, False)
        '=v'     -> (v, v, True)
        'v'      -> (v, v, True)      plain number (numeric input is accepted too)

    NOTE(review): an earlier remark stated that '<=' and '<' should both be flagged
    as not observed, yet '<=' has always returned True here; the behaviour is kept
    as-is - confirm which one is intended.

    Returns: (lower bound, upper bound, observed). Missing or unparsable input
    (e.g. '2/32', '>abc', 'nan') yields (nan, nan, nan).
    """
    unparsable = (np.nan, np.nan, np.nan)

    if pd.isna(val):
        return unparsable

    text = str(val).strip()

    try:
        # Intervals happen when there is ambiguity
        if text.startswith(">,<"):
            parts = text[3:].split(",")
            parsed = (float(parts[0]), float(parts[1]), True)

        # ">=" must be tested before ">" (and "<=" before "<"): the longer prefix wins.
        elif text.startswith(">="):
            v = float(text[2:])
            parsed = (v, v, True)

        elif text.startswith(">"):
            v = float(text[1:])
            parsed = (v, v * 2, False)

        elif text.startswith("<="):
            parsed = (0.0, float(text[2:]), True)

        elif text.startswith("<"):
            parsed = (0.0, float(text[1:]), False)

        elif text.startswith("="):
            v = float(text[1:])
            parsed = (v, v, True)

        # Just straight up numbers
        else:
            v = float(text)
            parsed = (v, v, True)

    except (ValueError, IndexError):
        # A malformed number after a prefix used to raise, while a malformed bare
        # value returned NaNs; every unparsable entry now gets the same answer.
        return unparsable

    # float("nan") parses successfully; a NaN bound is not a usable measurement.
    if np.isnan(parsed[0]) or np.isnan(parsed[1]):
        return unparsable

    return parsed


def generate_internal_comparison(anal_atlas_df, species, antibiotic, years, location=None, gender=None,
                                 resistance_breakpoint=8.0):
    """
    Compare a dataset amongst itself, stratified by year.

    For each requested year the function prints the percentage of measurements at or
    above `resistance_breakpoint`, draws a Kaplan-Meier curve of the parsed MIC values,
    then fits a Cox proportional-hazards model with the year as a categorical covariate
    and prints its summary, the global likelihood-ratio test and the
    proportional-hazards diagnostics.

    Parameters:
        anal_atlas_df: dataframe with "Species", "Year", one column per antibiotic and
            (when `location` is given) "country_alpha3".
        species: species name, matched case-insensitively.
        antibiotic: antibiotic name, matched case-insensitively against the column names.
        years: years to compare; the first one that has data is the Cox reference level.
        location: optional country code, matched case-insensitively.
        gender: accepted for interface compatibility; currently unused.
        resistance_breakpoint: concentration drawn as the vertical line and used for the
            printed percentages (default 8.0).

    Returns:
        None

    Raises:
        KeyError: if no column matches `antibiotic`.
        ValueError: if no usable measurement is left after filtering.
    """
    # Drop repeated years while keeping the caller's order (categories must be unique).
    years = list(dict.fromkeys(years))

    atlas_col = next(
        (col for col in anal_atlas_df.columns if str(col).lower() == antibiotic.lower()),
        None,
    )
    if atlas_col is None:
        raise KeyError(f"No column matching antibiotic {antibiotic!r}")

    internal_subset = anal_atlas_df[anal_atlas_df["Species"].str.lower() == species.lower()]
    if location:
        internal_subset = internal_subset[internal_subset["country_alpha3"].str.lower() == location.lower()]

    internal_subsets = {}
    percentage_resistant = {}
    unique_x_vals = set()
    for year in years:
        sub = internal_subset[internal_subset["Year"] == year].dropna(subset=[atlas_col]).copy()

        # Expand each MIC entry into censoring columns. Plain lists are assigned so that
        # no index alignment is involved and the original columns are preserved.
        bounds = [parse_value(value) for value in sub[atlas_col]]
        sub["T_lower"] = [b[0] for b in bounds]
        sub["T_upper"] = [b[1] for b in bounds]
        sub["T_observed"] = [b[2] for b in bounds]

        # Unparsable entries come back as NaN and cannot be fitted.
        sub = sub.dropna(subset=["T_upper", "T_observed"])
        if sub.empty:
            print(f"Year: {year}, no usable measurements - skipped")
            continue
        sub = sub.assign(T_observed=sub["T_observed"].astype(bool))

        internal_subsets[year] = sub
        unique_x_vals.update(sub["T_upper"].unique())

        percentage_resistant[year] = ((sub["T_upper"] >= resistance_breakpoint).sum() / len(sub)) * 100
        print(f"Year: {year}, %: {percentage_resistant[year]}")

    if not internal_subsets:
        raise ValueError("No rows to fit the Cox model after filtering.")

    # ---- Kaplan-Meier curves, one per year ----
    fig, ax = plt.subplots(figsize=(9, 5))

    ax.set_xlim(min(unique_x_vals), max(unique_x_vals))
    ax.set_ylim(-0.05, 1.05)
    ax.set_xscale('log', base=2)
    print(unique_x_vals)
    ax.set_xticks(np.sort(np.array(list(unique_x_vals))))

    for year_key, sub in internal_subsets.items():
        kmf_int = KaplanMeierFitter()
        kmf_int.fit(sub["T_upper"], event_observed=sub["T_observed"], label=f"{year_key}, n={len(sub)}")
        kmf_int.plot_survival_function(ax=ax)

    ax.axvline(x=resistance_breakpoint, ymin=0, ymax=1.0, linestyle="--", color='red', label="EUCAST Breakpoint, 2021")
    ax.set_ylabel("Proportion of samples")
    ax.set_xlabel("Log2 Concentration (μg/mL)")
    # Rebuild the legend so the breakpoint line, added after the curves, is listed too.
    ax.legend()

    # ---- Cox PH model across years ----
    years_with_data = [year for year in years if year in internal_subsets]
    if len(years_with_data) < 2:
        print("Cox model skipped: fewer than two years have usable measurements.")
        return None

    # Combine all years into one modeling dataframe
    model_df = pd.concat(
        [sub[["T_upper", "T_observed"]].assign(Year=year_key) for year_key, sub in internal_subsets.items()],
        ignore_index=True,
    )

    # Encode Year as categorical; drop_first makes the first year with data the reference.
    # dtype=float keeps the dummy columns numeric (recent pandas defaults to bool).
    model_df["Year"] = pd.Categorical(model_df["Year"], categories=years_with_data, ordered=False)
    X = pd.get_dummies(model_df["Year"], prefix="Year", drop_first=True, dtype=float)

    # Assemble dataset for lifelines
    cox_df = pd.concat([model_df[["T_upper", "T_observed"]], X], axis=1)

    # Fit Cox PH
    cph = CoxPHFitter()
    cph.fit(cox_df, duration_col="T_upper", event_col="T_observed")
    print(cph.summary)

    # Global test: do years matter overall?
    # lifelines exposes this as a method; the attribute spelling used before does not
    # exist, so the result was silently never printed.
    try:
        llrt = cph.log_likelihood_ratio_test()
    except AttributeError:
        print("\nGlobal LLR test is not available in this lifelines version.")
    else:
        degrees_freedom = getattr(llrt, "degrees_freedom", "n/a")
        print(f"\nGlobal LLR test for Year terms: chi2={llrt.test_statistic:.3f}, df={degrees_freedom}, p={llrt.p_value:.4g}")

    # Optional: check proportional hazards assumption (prints diagnostics)
    cph.check_assumptions(cox_df, p_value_threshold=0.05, show_plots=True)

    return None


# Plot theme for the figure below (other seaborn styles: "whitegrid", "ticks", ...).
sns.set_style("darkgrid")

# Optional sanity check of the yearly sample sizes:
# print_yoy_data(atlas_data, "escherichia coli", "ampicillin", location="FRA")

# Year-over-year comparison inside ATLAS.
generate_internal_comparison(
    atlas_data,
    "escherichia coli",
    "ampicillin",
    years=[2013, 2017, 2021],
    location="DEU",
)


In [None]:

def generate_atlas_cabbage_comparison(atlas_dataset, cabbage_dataset, species, antibiotic, year_from_to, locations=None, gender=None,
                                      resistance_breakpoint=2.0, x_max=64.1):
    """
    Compare two datasets (ATLAS vs CABBAGE) for one species / antibiotic pair.

    Both datasets are restricted to the species, the countries and the year range, their
    MIC entries are parsed into censoring bounds, and one Kaplan-Meier curve per dataset
    is drawn on a shared log2 axis.

    Parameters:
        atlas_dataset: dataframe with "Species", "country_alpha3", "Year" and one
            column per antibiotic.
        cabbage_dataset: dataframe with "combined_species", "isolation_country",
            "collection_date" and one column per antibiotic.
        species: species name, matched case-insensitively.
        antibiotic: antibiotic name, matched case-insensitively against the column names.
        year_from_to: (first_year, last_year), both inclusive.
        locations: country codes to keep (a single code or an iterable); None keeps all.
        gender: accepted for interface compatibility; currently unused.
        resistance_breakpoint: concentration drawn as the vertical line; the returned
            fractions count measurements strictly above it (default 2.0).
        x_max: right-hand limit of the x axis (default 64.1).

    Returns:
        (atlas_fraction, cabbage_fraction): share of measurements above the breakpoint.

    Raises:
        KeyError: if a dataset has no column matching `antibiotic`.
        ValueError: if a dataset has no usable measurement left after filtering.
    """

    def _find_antibiotic_column(frame, dataset_name):
        """Return the column of `frame` whose name equals `antibiotic`, ignoring case."""
        match = next((col for col in frame.columns if str(col).lower() == antibiotic.lower()), None)
        if match is None:
            raise KeyError(f"No {dataset_name} column matching antibiotic {antibiotic!r}")
        return match

    def _with_censoring_columns(frame, column):
        """Keep rows with a parsable value in `column` and add T_lower / T_upper / T_observed."""
        frame = frame.dropna(subset=[column]).copy()
        bounds = [parse_value(value) for value in frame[column]]
        frame["T_lower"] = [b[0] for b in bounds]
        frame["T_upper"] = [b[1] for b in bounds]
        frame["T_observed"] = [b[2] for b in bounds]
        # Unparsable entries come back as NaN and cannot be fitted.
        frame = frame.dropna(subset=["T_upper", "T_observed"])
        return frame.assign(T_observed=frame["T_observed"].astype(bool))

    atlas_col = _find_antibiotic_column(atlas_dataset, "ATLAS")
    cabbage_col = _find_antibiotic_column(cabbage_dataset, "CABBAGE")

    atlas_sub = atlas_dataset[atlas_dataset["Species"].str.lower() == species.lower()]
    cabbage_sub = cabbage_dataset[cabbage_dataset["combined_species"].str.lower() == species.lower()]

    # The default locations=None means "no country filter"; isin(None) would raise.
    if locations is not None:
        wanted = [locations] if isinstance(locations, str) else list(locations)
        atlas_sub = atlas_sub[atlas_sub["country_alpha3"].isin(wanted)]
        cabbage_sub = cabbage_sub[cabbage_sub["isolation_country"].isin(wanted)]

    y_from, y_to = year_from_to
    atlas_sub = atlas_sub[(atlas_sub["Year"] >= y_from) & (atlas_sub["Year"] <= y_to)]
    cabbage_sub = cabbage_sub[(cabbage_sub["collection_date"] >= y_from) & (cabbage_sub["collection_date"] <= y_to)]

    print(len(atlas_sub))
    print(len(cabbage_sub))
    print((cabbage_sub[cabbage_col]).unique())

    # Drop NaNs, then apply parse_value and expand into new columns
    atlas_sub = _with_censoring_columns(atlas_sub, atlas_col)
    cabbage_sub = _with_censoring_columns(cabbage_sub, cabbage_col)

    if atlas_sub.empty or cabbage_sub.empty:
        raise ValueError(
            f"No usable measurements after filtering "
            f"(ATLAS rows: {len(atlas_sub)}, CABBAGE rows: {len(cabbage_sub)})."
        )

    unique_x_atlas_vals = atlas_sub["T_upper"].unique()
    unique_x_cabbage_vals = cabbage_sub["T_upper"].unique()
    internal_low = min(unique_x_atlas_vals.min(), unique_x_cabbage_vals.min())

    fig, ax = plt.subplots(figsize=(9, 5))

    ax.set_xlim(internal_low, x_max)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xscale('log', base=2)

    print(f"Unique Atl vals: {unique_x_atlas_vals}")
    print(f"Unique Cab vals: {unique_x_cabbage_vals}")

    # Ticks follow the ATLAS concentrations only.
    ax.set_xticks(np.sort(np.array(list(unique_x_atlas_vals))))

    kmf_int = KaplanMeierFitter()
    kmf_ext = KaplanMeierFitter()

    kmf_int.fit(atlas_sub["T_upper"], event_observed=atlas_sub["T_observed"], label=f"ATLAS Dataset, nb isolates: {len(atlas_sub)}")
    kmf_int.plot_survival_function(ax=ax)

    kmf_ext.fit(cabbage_sub["T_upper"], event_observed=cabbage_sub["T_observed"], label=f"CABBAGE Dataset, nb isolates: {len(cabbage_sub)}")
    kmf_ext.plot_survival_function(ax=ax)

    ax.axvline(x=resistance_breakpoint, ymin=0, ymax=1.0, linestyle="--", color='red', label="EUCAST Breakpoint, 2021")
    ax.set_ylabel("Proportion of samples")
    ax.set_xlabel("Log2 Concentration (μg/mL)")
    # Rebuild the legend so the breakpoint line, added after the curves, is listed too.
    ax.legend()

    atlas_count_above_breakpoint = (atlas_sub["T_upper"] > resistance_breakpoint).sum()
    cabbage_count_above_breakpoint = (cabbage_sub["T_upper"] > resistance_breakpoint).sum()
    return (atlas_count_above_breakpoint / len(atlas_sub), cabbage_count_above_breakpoint / len(cabbage_sub))


# Plot theme for the figure below (other seaborn styles: "whitegrid", "ticks", ...).
sns.set_style("darkgrid")

# ATLAS vs CABBAGE comparison over the listed countries.
# Antibiotics of interest: Meropenem, gentamicin
ret = generate_atlas_cabbage_comparison(
    atlas_data,
    cabbage_data,
    "escherichia coli",
    "gentamicin",
    year_from_to=(2013, 2021),
    locations=(
        "AUT", "BEL", "BGR", "HRV", "CYP", "CZE", "DNK", "EST", "FIN", "FRA",
        "DEU", "GRC", "HUN", "IRL", "ITA", "LVA", "LTU", "LUX", "MLT", "NLD",
        "POL", "PRT", "ROU", "SVK", "SVN", "ESP", "SWE",
    ),
)
print(ret)