In [None]:
import os
import pandas as pd
import numpy as np
from lifelines.statistics import logrank_test
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines import LogLogisticAFTFitter, WeibullAFTFitter, LogNormalAFTFitter

from scipy.stats import hmean

import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
import pycountry

import ipywidgets as widgets
from IPython.display import display, clear_output


# Locations of the two raw datasets.
atlas_path = "./Vivli Data Challenge 2025 - Nachtrab/ATLAS_Antibiotics/2025_03_11 atlas_antibiotics.xlsx"
cabbage_path = "./CABBAGEdata/processed_database/step1_merge_all_v18.csv/step1_merge_all_v18.csv"

atlas_df = pd.read_excel(atlas_path)
cabbage_df = pd.read_csv(cabbage_path)

# Only keep the liquid solution typing methods, i.e. measurements reported in
# mg/l. Rows with any other (or a missing) unit are removed in place.
other_unit = (cabbage_df["measurement_unit"] != "mg/l").to_numpy()
cabbage_df.drop(index=cabbage_df.index[other_unit], inplace=True)

In [None]:
# ATLAS (Vivli) has a single isolate (row) tested on multiple antibiotics.
# CABBAGE has a single test on a row, with possibly multiple rows per specimen/id.

atlas_data = atlas_df.copy()
cabbage_data = cabbage_df.copy()

# ---------------------------------------------------------------------------
# CABBAGE: make the data ingestible for comparison
# ---------------------------------------------------------------------------

# "<sign><value>" strings such as "<=2.0" or "16.0".
# `astype(str)` turns a missing measurement into the literal text "nan", which
# would survive every later `dropna()`; mask those rows back to NaN instead.
measurement_value = cabbage_data["measurement_value"]
combined_value = cabbage_data["measurement_sign"].fillna("") + measurement_value.astype(str)
cabbage_data["combined_value"] = combined_value.where(measurement_value.notna())

cabbage_antibiotics = cabbage_data["Antibiotic_name"].dropna().unique()

# One wide column per antibiotic (original CABBAGE spelling): the combined value
# on the rows that tested this antibiotic, NaN everywhere else.
for abx in cabbage_antibiotics:
    cabbage_data[abx] = cabbage_data["combined_value"].where(cabbage_data["Antibiotic_name"] == abx)

# Same columns again under the ATLAS spelling ("/" becomes " ", plus one
# hand-mapped abbreviation) so both datasets can be addressed by the same name.
atlas_style_names = {}
for antibiotic in cabbage_antibiotics:
    new_column_name = antibiotic.replace("/", " ")
    if new_column_name == "Trimethoprim sulfamethoxazole":
        new_column_name = "Trimethoprim sulfa"
    atlas_style_names[antibiotic] = new_column_name

    cabbage_data[new_column_name] = cabbage_data["combined_value"].where(
        cabbage_data["Antibiotic_name"] == antibiotic
    )


# Create a new Species column for cabbage combining genus and species
cabbage_data["combined_species"] = cabbage_data["genus"].astype(str) + " " + cabbage_data["species"]

# Unify sex with the ATLAS encoding used further down (Male 0.0, Female 1.0).
# The two datasets were previously encoded in opposite directions, so a shared
# gender filter could never select the same sex in both.
cabbage_data['host_sex'] = cabbage_data['host_sex'].replace({'male': 0.0, 'female': 1.0})

# Encode phenotype. Cabbage has categories: ['susceptible' 'resistant' 'non-susceptible', 'susceptible dose dependent', 'decreased susceptibility']
# Ignore for the moment. Non-susceptible = resistant. Breakpoints change by country, over time, etc
encoder = OrdinalEncoder(categories=[['susceptible', 'intermediate', 'resistant', 'decreased susceptibility', 'susceptible dose dependent', 'non-susceptible']])
mask = cabbage_data['phenotype'].notna()
cabbage_data.loc[mask, 'phenotype_encoded'] = encoder.fit_transform(cabbage_data.loc[mask, ['phenotype']])
cabbage_data['phenotype_encoded'] = cabbage_data['phenotype_encoded'].astype('float')

# Bin ages. Atlas has defined age groups
bins = [0, 2, 12, 18, 64, 84, np.inf]
labels = ["0 to 2 Years", "3 to 12 Years", "13 to 18 Years",
          "19 to 64 Years", "65 to 84 Years", "85 and Over"]
# include_lowest=True so that an age of exactly 0 falls into "0 to 2 Years"
# instead of being left out of the right-closed first bin.
cabbage_data['Age Group'] = pd.cut(cabbage_data['host_age'], bins=bins, labels=labels,
                                   right=True, include_lowest=True)

# Encode. TODO: different than vivli - change vivli?
age_ordinal = OrdinalEncoder(categories=[labels],
                            handle_unknown="use_encoded_value",
                            unknown_value=-1,
                            dtype=float)
cabbage_data['Age Group Encoded'] = age_ordinal.fit_transform(cabbage_data[['Age Group']])
# Turn the -1 "unknown" marker back into NaN. Assign the result: an in-place
# replace on a column selection is not guaranteed to reach the frame.
cabbage_data['Age Group Encoded'] = cabbage_data['Age Group Encoded'].replace(-1, np.nan)

# Convert Year to numerical (the converted column has to be stored back).
cabbage_data['collection_date'] = cabbage_data['collection_date'].astype('float32')


# ---------------------------------------------------------------------------
# ATLAS: encode the vivli data into a numerical format
# ---------------------------------------------------------------------------

# Handle _I columns, [nan 'Susceptible' 'Intermediate' 'Resistant'] into nan, 0, 1, 2
columns_of_interest = [col for col in atlas_data.columns if col.endswith('_I')]
# Antibiotic names are the interpretation columns without their "_I" suffix
# (only those columns carry the suffix, so only those are stripped).
atlas_antibiotics = [col[:-2] for col in columns_of_interest]
atlas_data[columns_of_interest] = atlas_data[columns_of_interest].astype("string")

encoder = OrdinalEncoder(categories=[['Susceptible', 'Intermediate', 'Resistant']] * len(columns_of_interest),
                         handle_unknown="use_encoded_value", unknown_value=np.nan,
                         dtype=float)
atlas_data[columns_of_interest] = encoder.fit_transform(atlas_data[columns_of_interest])


# Handle age and gender [using nans]
atlas_data['Age Group'] = atlas_data['Age Group'].replace('Unknown', np.nan)
age_ordinal = OrdinalEncoder(categories=[["0 to 2 Years", "3 to 12 Years", "13 to 18 Years", "19 to 64 Years", "65 to 84 Years", "85 and Over"]],
                            handle_unknown="use_encoded_value", unknown_value=np.nan,
                            dtype=float)

atlas_data["Age Group"] = age_ordinal.fit_transform(atlas_data[["Age Group"]])


# Male 0.0, Female 1.0
atlas_data['Gender'] = atlas_data['Gender'].replace('Unknown', np.nan)
gender_ordinal = OrdinalEncoder(categories=[["Male", "Female"]],
                                handle_unknown="use_encoded_value", unknown_value=np.nan,
                                dtype=float)
atlas_data["Gender"] = gender_ordinal.fit_transform(atlas_data[["Gender"]])

# Substitute the unknown ranks for nan
atlas_data['Speciality'] = atlas_data['Speciality'].replace(["None Given", "Other"], np.nan)
atlas_data['In / Out Patient'] = atlas_data['In / Out Patient'].replace(['None Given', 'Other'], np.nan)

# Convert Year to numerical. Corresponds to Cabbage "collection_date"
atlas_data['Year'] = atlas_data['Year'].astype('float32')


# Convert country to Alpha 3
def _to_alpha3(country_name):
    """Return the alpha-3 code pycountry resolves *country_name* to, or None."""
    if not isinstance(country_name, str):
        return None
    try:
        return pycountry.countries.lookup(country_name).alpha_3
    except LookupError:
        return None


# Resolve every distinct country once instead of once per row.
alpha3_by_country = {name: _to_alpha3(name) for name in atlas_data["Country"].dropna().unique()}
atlas_data["country_alpha3"] = atlas_data["Country"].map(alpha3_by_country.get)


# ---------------------------------------------------------------------------
# Get common bacteria and antibiotics
# ---------------------------------------------------------------------------

# Sorted so that later cells, which only process the first few entries, see the
# same species/antibiotics on every run (set order is not stable for strings).
shared_species = set(cabbage_data["combined_species"].dropna()) & set(atlas_data["Species"].dropna())
overlapping_bacteria = sorted({s.lower() for s in shared_species})

# Counts are computed outside the f-string: re-using the same quote character
# inside an f-string expression is a SyntaxError before Python 3.12.
n_cabbage_species = len(cabbage_data["combined_species"].unique())
n_atlas_species = len(atlas_data["Species"].unique())
print(f"overlapping_bacteria length: {len(overlapping_bacteria)}, cabbage species: {n_cabbage_species}, atlas species: {n_atlas_species}")


# Match on the original CABBAGE names as well as their ATLAS-style spellings, so
# that combination drugs written with "/" in CABBAGE are not silently skipped.
cabbage_antibiotic_names = set(cabbage_antibiotics) | set(atlas_style_names.values())
overlapping_antibio = sorted({s.lower() for s in cabbage_antibiotic_names & set(atlas_antibiotics)})
print(f"overlapping_antibio length: {len(overlapping_antibio)}")


In [None]:
# Weibull estimation
# There are 29 species x 36 antibio = 1044 pairs to test

# Parse value for Left truncated (late entry) data
# https://lifelines.readthedocs.io/en/latest/Survival%20analysis%20with%20lifelines.html#left-truncated-late-entry-data
def parse_value(val):
    """
    Parse a measurement string into censoring bounds.

    Recognised forms and their result ``(lower, upper, observed)``:

    * ``'>,<a,b'``         -> ``(a, b, True)``   (explicit interval)
    * ``'>=x'``            -> ``(x, inf, True)``
    * ``'>x'``             -> ``(x, inf, False)``
    * ``'<=x'`` / ``'<x'`` -> ``(0.0, x, False)``
    * ``'=x'`` / ``'x'``   -> ``(x, x, True)``

    Missing input, and any value whose numeric part cannot be read
    (e.g. ``'2/32'``, ``'>=abc'``, ``'<'``, ``'nan'``), yields
    ``(nan, nan, nan)`` instead of raising, so one malformed cell cannot abort
    a whole ``Series.apply``.

    Returns: lower bound, upper bound, observed
    """
    missing = (np.nan, np.nan, np.nan)
    if pd.isna(val):
        return missing

    text = str(val).strip()

    try:
        # Intervals happen when there is ambiguity
        if text.startswith(">,<"):
            parts = text[3:].split(",")
            bounds = (float(parts[0]), float(parts[1]), True)

        # Right censoring
        elif text.startswith(">="):
            bounds = (float(text[2:]), np.inf, True)

        elif text.startswith(">"):
            bounds = (float(text[1:]), np.inf, False)

        # <= and < are in practice observed the same way. So both false.
        elif text.startswith("<="):
            bounds = (0.0, float(text[2:]), False)

        elif text.startswith("<"):
            bounds = (0.0, float(text[1:]), False)

        elif text.startswith("="):
            v = float(text[1:])
            bounds = (v, v, True)  # Upper value inf? Or v?

        # Just straight up numbers
        else:
            v = float(text)
            bounds = (v, v, True)
    except (ValueError, IndexError):
        # Non-numeric remainder, or an interval with fewer than two numbers.
        return missing

    # float() accepts the text "nan"; a NaN bound is not a usable observation.
    if np.isnan(bounds[0]) or np.isnan(bounds[1]):
        return missing
    return bounds



def _rows_matching(column, wanted):
    """Boolean mask of the rows of *column* equal to *wanted* (missing never matches)."""
    if pd.api.types.is_numeric_dtype(column):
        # Numeric columns have no `.str` accessor; compare as numbers. A value
        # that is not a number becomes NaN and therefore matches nothing.
        return column == pd.to_numeric(wanted, errors="coerce")
    return column.notna() & (column.astype(str).str.lower() == str(wanted).lower())


def _parsed_bounds(values):
    """Parse raw measurements into a T_lower/T_upper frame, dropping unparseable rows."""
    parsed = [parse_value(value) for value in values.dropna()]
    bounds = pd.DataFrame(
        {
            "T_lower": [p[0] for p in parsed],
            "T_upper": [p[1] for p in parsed],
        },
        dtype=float,
    )
    # parse_value reports unreadable entries as NaN bounds; they carry no
    # information and must not be counted or handed to the fitter.
    return bounds.dropna().reset_index(drop=True)


def generate_analysis(anal_cabbage_df, anal_atlas_df, species, antibiotic, year, location, gender):
    """
    Fit one Weibull AFT model per dataset for a species/antibiotic pair.

    Parameters
    ----------
    anal_cabbage_df, anal_atlas_df : DataFrame
        The CABBAGE frame (read through "combined_species", "collection_date",
        "isolation_country", "host_sex") and the ATLAS frame (read through
        "Species", "Year", "country_alpha3", "Gender"). Each must also hold a
        column named like *antibiotic*.
    species, antibiotic : str
        Matched case-insensitively.
    year, location, gender : optional
        Extra filters; ``None`` or ``""`` disables a filter. Rows with a missing
        value in a filtered column are excluded. Numeric columns are compared
        numerically, text columns case-insensitively.

    Returns
    -------
    dict or None
        ``{"species", "antibiotic", "weibull_cabbage", "weibull_atlas"}``, or
        ``None`` when the two subsets differ in size by more than a factor of
        100, when either frame has no column for the antibiotic, or when fewer
        than 10 measurements could be parsed on either side.
    """
    wanted_species = species.lower()
    internal_subset = anal_cabbage_df[anal_cabbage_df["combined_species"].str.lower() == wanted_species]
    external_subset = anal_atlas_df[anal_atlas_df["Species"].str.lower() == wanted_species]

    # Extra filters /!\ excludes NaNs!!!
    optional_filters = (
        (year, "collection_date", "Year"),
        (location, "isolation_country", "country_alpha3"),
        (gender, "host_sex", "Gender"),
    )
    for wanted, cabbage_field, atlas_field in optional_filters:
        # Test for "not given" explicitly: a plain truthiness test would also
        # skip legitimate numeric codes such as 0.0.
        if wanted is None or (isinstance(wanted, str) and wanted == ""):
            continue
        internal_subset = internal_subset[_rows_matching(internal_subset[cabbage_field], wanted)]
        external_subset = external_subset[_rows_matching(external_subset[atlas_field], wanted)]

    # Check if too little data
    if len(internal_subset) < len(external_subset)/100 or len(external_subset) < len(internal_subset)/100:
        print(f"Imbalanced data, cab {len(internal_subset)}, atlas {len(external_subset)}")
        return None

    # Locate the antibiotic column in each frame, ignoring case. A missing
    # column is reported and skipped rather than escaping as StopIteration.
    wanted_column = antibiotic.lower()
    cabbage_col = next(
        (col for col in anal_cabbage_df.columns if isinstance(col, str) and col.lower() == wanted_column),
        None,
    )
    atlas_col = next(
        (col for col in anal_atlas_df.columns if isinstance(col, str) and col.lower() == wanted_column),
        None,
    )
    if cabbage_col is None or atlas_col is None:
        print(f"No column for {antibiotic}, cab {cabbage_col}, atlas {atlas_col}")
        return None

    # Parse values from the given antibiotic column
    df_int = _parsed_bounds(internal_subset[cabbage_col])
    df_ext = _parsed_bounds(external_subset[atlas_col])

    if len(df_int) < 10 or len(df_ext) < 10:
        print(f"Too little parsed data, cab {len(df_int)}, atlas {len(df_ext)}")
        return None

    print(len(df_int))
    print(len(df_ext))

    aft_int = WeibullAFTFitter()
    aft_int.fit_interval_censoring(df_int, lower_bound_col="T_lower", upper_bound_col="T_upper")

    aft_ext = WeibullAFTFitter()
    aft_ext.fit_interval_censoring(df_ext, lower_bound_col="T_lower", upper_bound_col="T_upper")

    return {
        "species": species,
        "antibiotic": antibiotic,
        "weibull_cabbage": aft_int,
        "weibull_atlas": aft_ext
    }


# Trial run: walk the species in order and, for each one, collect analyses
# until more than 10 have succeeded. The species banner is printed before the
# cut-off test, so the species that triggers the stop is still announced.
results = []
for species_index, species in enumerate(overlapping_bacteria):
    print(f"Analysing all antibiotics for {species}")
    if species_index > 3:
        break

    kept_for_species = 0
    for antibiotic in overlapping_antibio:
        if kept_for_species > 10:
            break
        row = generate_analysis(cabbage_data, atlas_data, species, antibiotic, None, None, None)
        if not row:
            continue
        results.append(row)
        kept_for_species += 1


results_df = pd.DataFrame(results)

In [None]:
# There are 29 species x 36 antibio = 1044 pairs to test

# Parse value for Left truncated (late entry) data
# https://lifelines.readthedocs.io/en/latest/Survival%20analysis%20with%20lifelines.html#left-truncated-late-entry-data
def parse_value(val):
    """
    Parse a measurement string into censoring bounds.

    Recognised forms and their result ``(lower, upper, observed)``:

    * ``'>,<a,b'``         -> ``(a, b, True)``   (explicit interval)
    * ``'>=x'``            -> ``(x, inf, True)``
    * ``'>x'``             -> ``(x, inf, False)``
    * ``'<=x'`` / ``'<x'`` -> ``(0.0, x, False)``
    * ``'=x'`` / ``'x'``   -> ``(x, x, True)``

    Missing input, and any value whose numeric part cannot be read
    (e.g. ``'2/32'``, ``'>=abc'``, ``'<'``, ``'nan'``), yields
    ``(nan, nan, nan)`` instead of raising, so one malformed cell cannot abort
    a whole ``Series.apply``.

    Returns: lower bound, upper bound, observed
    """
    missing = (np.nan, np.nan, np.nan)
    if pd.isna(val):
        return missing

    text = str(val).strip()

    try:
        # Intervals happen when there is ambiguity
        if text.startswith(">,<"):
            parts = text[3:].split(",")
            bounds = (float(parts[0]), float(parts[1]), True)

        # Right censoring
        elif text.startswith(">="):
            bounds = (float(text[2:]), np.inf, True)

        elif text.startswith(">"):
            bounds = (float(text[1:]), np.inf, False)

        # <= and < are in practice observed the same way. So both false.
        elif text.startswith("<="):
            bounds = (0.0, float(text[2:]), False)

        elif text.startswith("<"):
            bounds = (0.0, float(text[1:]), False)

        elif text.startswith("="):
            v = float(text[1:])
            bounds = (v, v, True)  # Upper value inf? Or v?

        # Just straight up numbers
        else:
            v = float(text)
            bounds = (v, v, True)
    except (ValueError, IndexError):
        # Non-numeric remainder, or an interval with fewer than two numbers.
        return missing

    # float() accepts the text "nan"; a NaN bound is not a usable observation.
    if np.isnan(bounds[0]) or np.isnan(bounds[1]):
        return missing
    return bounds



def _rows_matching(column, wanted):
    """Boolean mask of the rows of *column* equal to *wanted* (missing never matches)."""
    if pd.api.types.is_numeric_dtype(column):
        # Numeric columns have no `.str` accessor; compare as numbers. A value
        # that is not a number becomes NaN and therefore matches nothing.
        return column == pd.to_numeric(wanted, errors="coerce")
    return column.notna() & (column.astype(str).str.lower() == str(wanted).lower())


def _parsed_bounds(values):
    """Parse raw measurements into a T_lower/T_upper frame, dropping unparseable rows."""
    parsed = [parse_value(value) for value in values.dropna()]
    bounds = pd.DataFrame(
        {
            "T_lower": [p[0] for p in parsed],
            "T_upper": [p[1] for p in parsed],
        },
        dtype=float,
    )
    # parse_value reports unreadable entries as NaN bounds; they carry no
    # information and must not be counted.
    return bounds.dropna().reset_index(drop=True)


def generate_analysis(anal_cabbage_df, anal_atlas_df, species, antibiotic, year, location, gender):
    """
    Count the usable measurements of a species/antibiotic pair in each dataset.

    Parameters
    ----------
    anal_cabbage_df, anal_atlas_df : DataFrame
        The CABBAGE frame (read through "combined_species", "collection_date",
        "isolation_country", "host_sex") and the ATLAS frame (read through
        "Species", "Year", "country_alpha3", "Gender"). Each must also hold a
        column named like *antibiotic*.
    species, antibiotic : str
        Matched case-insensitively.
    year, location, gender : optional
        Extra filters; ``None`` or ``""`` disables a filter. Rows with a missing
        value in a filtered column are excluded. Numeric columns are compared
        numerically, text columns case-insensitively.

    Returns
    -------
    dict or None
        ``{"species", "antibiotic", "n_internal", "n_external"}`` where the two
        counts are the numbers of measurements that parsed to valid bounds, or
        ``None`` when the two subsets differ in size by more than a factor of
        100, when either frame has no column for the antibiotic, or when fewer
        than 10 measurements could be parsed on either side.
    """
    wanted_species = species.lower()
    internal_subset = anal_cabbage_df[anal_cabbage_df["combined_species"].str.lower() == wanted_species]
    external_subset = anal_atlas_df[anal_atlas_df["Species"].str.lower() == wanted_species]

    # Extra filters /!\ excludes NaNs!!!
    optional_filters = (
        (year, "collection_date", "Year"),
        (location, "isolation_country", "country_alpha3"),
        (gender, "host_sex", "Gender"),
    )
    for wanted, cabbage_field, atlas_field in optional_filters:
        # Test for "not given" explicitly: a plain truthiness test would also
        # skip legitimate numeric codes such as 0.0.
        if wanted is None or (isinstance(wanted, str) and wanted == ""):
            continue
        internal_subset = internal_subset[_rows_matching(internal_subset[cabbage_field], wanted)]
        external_subset = external_subset[_rows_matching(external_subset[atlas_field], wanted)]

    # Check if too little data
    if len(internal_subset) < len(external_subset)/100 or len(external_subset) < len(internal_subset)/100:
        print(f"Imbalanced data, cab {len(internal_subset)}, atlas {len(external_subset)}")
        return None

    # print(f"Enough data: cab {len(internal_subset)}, atlas {len(external_subset)}")

    # Locate the antibiotic column in each frame, ignoring case. A missing
    # column is reported and skipped rather than escaping as StopIteration.
    wanted_column = antibiotic.lower()
    cabbage_col = next(
        (col for col in anal_cabbage_df.columns if isinstance(col, str) and col.lower() == wanted_column),
        None,
    )
    atlas_col = next(
        (col for col in anal_atlas_df.columns if isinstance(col, str) and col.lower() == wanted_column),
        None,
    )
    if cabbage_col is None or atlas_col is None:
        print(f"No column for {antibiotic}, cab {cabbage_col}, atlas {atlas_col}")
        return None

    # Parse values from the given antibiotic column
    df_int = _parsed_bounds(internal_subset[cabbage_col])
    df_ext = _parsed_bounds(external_subset[atlas_col])

    if len(df_int) < 10 or len(df_ext) < 10:
        print(f"Too little parsed data, cab {len(df_int)}, atlas {len(df_ext)}")
        return None

    print(len(df_int))
    print(len(df_ext))

    # NOTE(review): everything below, up to the return statement, is disabled
    # experiment code kept for reference. It refers to names (T_int, T_ext,
    # T_int_upper, T_ext_upper) that are not defined in this function and would
    # have to be derived from df_int / df_ext before it can be re-enabled.

    # aft_int = WeibullAFTFitter()
    # aft_int.fit_interval_censoring(df_int, lower_bound_col="T_lower", upper_bound_col="T_upper")

    # aft_ext = WeibullAFTFitter()
    # aft_ext.fit_interval_censoring(df_ext, lower_bound_col="T_lower", upper_bound_col="T_upper")


    # mask_int = ~np.isnan(T_int) & ~np.isnan(T_int_upper)
    # mask_ext = ~np.isnan(T_ext) & ~np.isnan(T_ext_upper)


    # kmf_cabbage = KaplanMeierFitter()
    # kmf_atlas = KaplanMeierFitter()

    # if mask_int.sum() > 0:
    #     kmf_cabbage.fit_interval_censoring(T_int[mask_int], T_int_upper[mask_int], label="Cabbage")

    # if mask_ext.sum() > 0:
    #     kmf_atlas.fit_interval_censoring(T_ext[mask_ext], T_ext_upper[mask_ext], label="Atlas")


    # # CABBAGE exact or right-censored only
    # durations_A = []
    # event_A = []
    # for t, t_up in zip(T_int[mask_int], T_int_upper[mask_int]):
    #     if np.isinf(t_up):  # right-censored
    #         durations_A.append(t)
    #         event_A.append(0)
    #     elif t == t_up:  # exact
    #         durations_A.append(t)
    #         event_A.append(1)
    #     # else:
    #     #     print(f"t: {t}, t_up: {t_up}")


    # # ATLAS
    # durations_B = []
    # event_B = []
    # for t, t_up in zip(T_ext[mask_ext], T_ext_upper[mask_ext]):
    #     if np.isinf(t_up):  # right-censored
    #         durations_B.append(t)
    #         event_B.append(0)
    #     elif t == t_up:  # exact
    #         durations_B.append(t)
    #         event_B.append(1)
    #     # else:
    #     #     print(f"t: {t}, t_up: {t_up}")


    #  # Again check if too short/small
    # if len(durations_A) < len(durations_B)/100 or len(durations_B) < len(durations_A)/100:
    #     print(f"cox too sort/small, durA {len(durations_A)}, durB {len(durations_B)}")
    #     return None


    # # Create combined dataframe
    # dfA = pd.DataFrame({'T': durations_A, 'E': event_A, 'group': 1})
    # dfB = pd.DataFrame({'T': durations_B, 'E': event_B, 'group': 0})
    # df = pd.concat([dfA, dfB])

    # # Fit Cox model
    # try:
    #     cph = CoxPHFitter().fit(df, duration_col='T', event_col='E')
    #     summary = cph.summary
    #     p = summary.loc["group", "p"]
    #     if p>0.001:
    #         print("higher than 0.001")
    # except Exception as e:
    #     print("Couldn't fit cox")
    #     return None

    return {
        "species": species,
        "antibiotic": antibiotic,
        # "weibull_cabbage": aft_int,
        # "weibull_atlas": aft_ext,
        # "p_value": p,
        "n_internal": len(df_int),
        "n_external": len(df_ext),
        # "dur_a":durations_A,
        # "dur_b":durations_B,
        # "len_internal": len(internal_subset),
        # "len_external": len(external_subset),
        # "kmf_cabbage": kmf_cabbage,
        # "kmf_atlas": kmf_atlas
    }


# Collect the measurement counts of every species/antibiotic pair. The banner
# is printed before the cut-off test, as a progress trace; the cut-offs stop
# after index 100 / after more than 100 successful pairs per species.
results = []
for species_index, species in enumerate(overlapping_bacteria):
    print(f"Analysing all antibiotics for {species}")
    if species_index > 100:
        break

    kept_for_species = 0
    for antibiotic in overlapping_antibio:
        if kept_for_species > 100:
            break
        row = generate_analysis(cabbage_data, atlas_data, species, antibiotic, None, None, None)
        if row:
            results.append(row)
            kept_for_species += 1


results_df = pd.DataFrame(results)

# Balance score: harmonic mean of the two sample sizes, which is only high when
# both datasets contribute many measurements.
if results_df.empty:
    # No pair qualified: still create the column so that sorting and the export
    # below produce an empty table instead of failing.
    results_df["balance_score"] = pd.Series(dtype=float)
else:
    results_df["balance_score"] = results_df.apply(
        lambda row: hmean([row["n_internal"], row["n_external"]]) if row["n_internal"] > 0 and row["n_external"] > 0 else 0,
        axis=1
    )

sorted_df = results_df.sort_values(by="balance_score", ascending=False)

# to_csv does not create missing directories.
os.makedirs('./generated_data', exist_ok=True)
sorted_df.to_csv('./generated_data/sorted_species_antibio_pairs.csv')