## Generate data on how many examples each dataset contains for an antibiotic - species pair
### The harmonic mean of the two per-dataset example counts is used to rank the pairs, rewarding both balance and data volume

In [1]:
import os
import pandas as pd
import numpy as np

from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
import pycountry

from scipy.stats import hmean


# Locations of the two source datasets (relative paths).
atlas_path = "./Vivli Data Challenge 2025 - Nachtrab/ATLAS_Antibiotics/2025_03_11 atlas_antibiotics.xlsx"
cabbage_path = "./CABBAGEdata/processed_database/dereplicated__all_clean.csv"

# Load the ATLAS workbook and the CABBAGE table into memory.
atlas_df = pd.read_excel(atlas_path)
cabbage_df = pd.read_csv(cabbage_path)

# Only keep the liquid solution typing methods: every CABBAGE row whose
# measurement unit is anything other than "mg/l" (missing included) is
# dropped in place.
cabbage_df.drop(
    index=cabbage_df.loc[cabbage_df["measurement_unit"].ne("mg/l")].index,
    inplace=True,
)

  cabbage_df = pd.read_csv(cabbage_path)


In [2]:
# Work on copies so the frames loaded from disk are left untouched.
atlas_data = atlas_df.copy()
cabbage_data = cabbage_df.copy()

# Concatenate the measurement sign and the measurement value into one string
# column. astype(str) renders a missing value as the literal "nan", so a
# fillna("") placed after the cast never fires; blank missing values out
# explicitly instead.
cabbage_data["combined_value"] = (
    cabbage_data["measurement_sign"].fillna("")
    + cabbage_data["measurement_value"].astype(str).where(
        cabbage_data["measurement_value"].notna(), ""
    )
)

# Use the shorter spelling for this drug
# (presumably the spelling used on the ATLAS side -- confirm).
cabbage_data.loc[cabbage_data['Antibiotic_name'] == 'Trimethoprim sulfamethoxazole', 'Antibiotic_name'] = 'Trimethoprim sulfa'

# Male 0.0, Female 1.0; any other value is left unchanged by replace().
cabbage_data["host_sex"] = (
    cabbage_data["host_sex"]
    .replace({"male": 0.0, "female": 1.0})
    .infer_objects(copy=False)
    )


# Ordinal-encode the phenotype in the category order listed below. Only rows
# with a phenotype are passed to the encoder; the others stay NaN.
mask = cabbage_data["phenotype"].notna()
phen_enc = pd.Series(index=cabbage_data.index, dtype="float64")
phen_enc.loc[mask] = OrdinalEncoder(categories=[[
    "susceptible", "intermediate", "resistant",
    "decreased susceptibility", "susceptible dose dependent",
    "non-susceptible"
]]).fit_transform(cabbage_data.loc[mask, ["phenotype"]]).ravel()
cabbage_data["phenotype_encoded"] = phen_enc.astype("float32")


# Bucket the numeric host age into the labelled age groups.
bins = [0, 2, 12, 18, 64, 84, np.inf]
labels = ["0 to 2 Years", "3 to 12 Years", "13 to 18 Years",
          "19 to 64 Years", "65 to 84 Years", "85 and Over"]

# include_lowest=True makes the first bin closed on the left, so an age of
# exactly 0 lands in "0 to 2 Years" instead of falling outside every bin
# (with right=True alone the first interval is (0, 2]).
cabbage_data["Age Group"] = pd.cut(
    cabbage_data["host_age"],
    bins=bins,
    labels=labels,
    right=True,
    include_lowest=True,
)

# Encode the age group by its position in `labels`; values the encoder does
# not know are first mapped to -1 and then turned into NaN.
age_ordinal = OrdinalEncoder(categories=[labels], handle_unknown="use_encoded_value",
                             unknown_value=-1, dtype=float)
age_vals = age_ordinal.fit_transform(cabbage_data[["Age Group"]]).ravel()
age_vals = np.where(age_vals == -1, np.nan, age_vals)
cabbage_data["Age Group Encoded"] = age_vals.astype("float32")

# Create a new Species column for cabbage combining genus and species
cabbage_data["combined_species"] = cabbage_data["genus"].astype(str) + " " + cabbage_data["species"]

# Convert Year to numerical. The result has to be assigned back: astype()
# returns a new Series and does not modify the column in place.
cabbage_data['collection_date'] = cabbage_data['collection_date'].astype('float32')




# Below: reshaping of the ATLAS table from wide to long (written with
# assistance from ChatGPT).
import re

# Give every ATLAS row a stable integer key so the long table can be joined
# back to its metadata. (Column names are used exactly as loaded.)
if "row_id" not in atlas_data.columns:
    atlas_data = atlas_data.reset_index(drop=True)
    atlas_data["row_id"] = np.arange(len(atlas_data), dtype=np.int32)

# Interpretation columns are the ones whose name ends in "_I" (any case).
I_cols = [col for col in atlas_data.columns if re.search(r"(?i)_I$", col)]

# The antibiotic name is the "_I" column name without its suffix; the MIC
# lives in the column carrying that bare name, when such a column exists.
bases = [re.sub(r"(?i)_I$", "", col) for col in I_cols]
MIC_cols = [base for base in bases if base in atlas_data.columns]

# Separate the per-antibiotic columns from the remaining (metadata) columns.
abx_cols = I_cols + MIC_cols
meta = atlas_data.drop(columns=abx_cols)
abx = atlas_data.set_index("row_id")[abx_cols].copy()

# Label each antibiotic column with an (antibiotic_raw, measure) pair, where
# measure is "I" for "_I" columns and "MIC" for the bare-name columns.
pairs = [
    (re.sub(r"(?i)_I$", "", col), "I") if re.search(r"(?i)_I$", col) else (col, "MIC")
    for col in abx.columns
]
abx.columns = pd.MultiIndex.from_tuples(pairs, names=["antibiotic_raw", "measure"])

# Stack to long: one row per (row_id, antibiotic), one column per measure.
abx_long = abx.stack("antibiotic_raw", dropna=False).reset_index()

# Rename whichever measure columns are present.
rename_map = {
    old: new
    for old, new in (("I", "phenotype_encoded"), ("MIC", "mic"))
    if old in abx_long.columns
}
abx_long = abx_long.rename(columns=rename_map)

# Discard (row_id, antibiotic) rows in which every measure is missing.
measure_cols = [name for name in ("phenotype_encoded", "mic") if name in abx_long.columns]
if measure_cols:
    abx_long = abx_long.dropna(subset=measure_cols, how="all")

# Re-attach the metadata; each row_id must occur at most once in `meta`.
atlas_data = pd.merge(abx_long, meta, on="row_id", how="left", validate="many_to_one")


# NOTE: disabled phenotype encoding, kept for reference. `columns_of_interest`
# is not defined in the code shown here.
# encoder = OrdinalEncoder(categories=[['Susceptible', 'Intermediate', 'Resistant']] * len(columns_of_interest), 
#                          handle_unknown="use_encoded_value", unknown_value=np.nan, 
#                          dtype=float)
# atlas_data[columns_of_interest] = encoder.fit_transform(atlas_data[columns_of_interest])

# Handle age and gender [using nans]: 'Unknown' becomes NaN, the known labels
# are encoded by their position in the category list.
atlas_data['Age Group'] = atlas_data['Age Group'].replace('Unknown', np.nan)
age_ordinal = OrdinalEncoder(categories=[["0 to 2 Years", "3 to 12 Years", "13 to 18 Years", "19 to 64 Years", "65 to 84 Years", "85 and Over"]],
                            handle_unknown="use_encoded_value", unknown_value=np.nan, 
                            dtype=float)

atlas_data["Age Group"] = age_ordinal.fit_transform(atlas_data[["Age Group"]])

# Male 0.0, Female 1.0
atlas_data['Gender'] = atlas_data['Gender'].replace('Unknown', np.nan)
gender_ordinal = OrdinalEncoder(categories=[["Male", "Female"]],
                                handle_unknown="use_encoded_value", unknown_value=np.nan, 
                                dtype=float)
atlas_data["Gender"] = gender_ordinal.fit_transform(atlas_data[["Gender"]])


# Substitute the uninformative placeholder values with NaN
atlas_data['Speciality'] = atlas_data['Speciality'].replace(["None Given", "Other"], np.nan)
atlas_data['In / Out Patient'] = atlas_data['In / Out Patient'].replace(["None Given", "Other"], np.nan)

# Convert Year to numerical. Corresponds to Cabbage "collection_date".
# The result has to be assigned back: astype() returns a new Series and does
# not modify the column in place.
atlas_data['Year'] = atlas_data['Year'].astype('float32')


def _country_to_alpha3(name):
    """Return the alpha-3 code pycountry resolves *name* to, or None.

    Non-string input (e.g. NaN) and names pycountry cannot resolve give None.
    """
    if not isinstance(name, str):
        return None
    try:
        # A single lookup() call: guarding it with an exact-name get() would
        # reject names that lookup() itself is able to resolve.
        return pycountry.countries.lookup(name).alpha_3
    except LookupError:
        return None


# Convert country to Alpha 3
atlas_data["country_alpha3"] = atlas_data["Country"].apply(_country_to_alpha3)

  .replace({"male": 0.0, "female": 1.0})
  abx_long = abx.stack("antibiotic_raw", dropna=False).reset_index()


In [None]:

def generate_pair_info(atlas_data, cabbage_data, min_n=50):
    """
    Slow, non vectorized
    Generate volume of data and harmonic mean values for each antibiotic - species pair

    Species and antibiotic names are compared case-insensitively (both sides
    are lowercased) and rows missing either field are ignored, matching
    generate_pair_info_fast.

    Args:
        atlas_data: DataFrame with "Species" and "antibiotic_raw" columns.
        cabbage_data: DataFrame with "combined_species" and "Antibiotic_name"
            columns.
        min_n: Minimum number of rows a pair needs in *each* dataset to be
            reported.

    Returns:
        A list of dicts with keys "species", "antibiotic" (both lowercase),
        "n_external" (row count in atlas_data), "n_internal" (row count in
        cabbage_data) and "score" (harmonic mean of the two counts). Pairs
        are emitted in alphabetical order of species, then antibiotic.
    """
    # Normalise once, outside the loops, so both datasets are compared on the
    # same lowercase keys. dropna() also keeps NaN out of the string handling.
    atlas_pairs = atlas_data[["Species", "antibiotic_raw"]].dropna()
    atl_species = atlas_pairs["Species"].astype(str).str.lower()
    atl_antibio = atlas_pairs["antibiotic_raw"].astype(str).str.lower()

    cabbage_pairs = cabbage_data[["combined_species", "Antibiotic_name"]].dropna()
    cab_species = cabbage_pairs["combined_species"].astype(str).str.lower()
    cab_antibio = cabbage_pairs["Antibiotic_name"].astype(str).str.lower()

    # Sorted so the output order does not depend on set iteration order.
    common_species = sorted(set(atl_species) & set(cab_species))
    common_antibio = sorted(set(atl_antibio) & set(cab_antibio))

    results = []
    for species in common_species:
        # The species masks do not depend on the antibiotic; build them once.
        atl_is_species = atl_species == species
        cab_is_species = cab_species == species
        for antibiotic in common_antibio:
            len_atl = int((atl_is_species & (atl_antibio == antibiotic)).sum())
            len_cab = int((cab_is_species & (cab_antibio == antibiotic)).sum())
            # Boolean `and`, not bitwise `&`: `a > 50 & b > 50` parses as the
            # chained comparison `a > (50 & b) > 50`, which is never true.
            if len_atl >= min_n and len_cab >= min_n:
                results.append({"species": species,
                                "antibiotic": antibiotic,
                                "n_external": len_atl,
                                "n_internal": len_cab,
                                "score": float(hmean([len_atl, len_cab]))})

    return results
        

def generate_pair_info_fast(atlas_data, cabbage_data, min_n=50):
    """
    Vectorized (written with help from ChatGPT).
    Count the rows available per antibiotic - species pair in both datasets
    and rank the pairs by the harmonic mean of the two counts.

    Returns a DataFrame with the columns "species", "antibiotic", "n_atlas",
    "n_cabbage" and "score", restricted to pairs having at least `min_n` rows
    in each dataset and ordered by descending score.
    """

    def _count_pairs(frame, species_col, antibiotic_col, count_name):
        # Rows missing either field are dropped; names are lowercased so the
        # two datasets can be matched case-insensitively.
        present = frame[[species_col, antibiotic_col]].dropna()
        keyed = present.assign(
            species=present[species_col].astype(str).str.lower(),
            antibiotic=present[antibiotic_col].astype(str).str.lower(),
        )
        sizes = keyed.groupby(["species", "antibiotic"]).size()
        return sizes.reset_index(name=count_name)

    atlas_counts = _count_pairs(atlas_data, "Species", "antibiotic_raw", "n_atlas")
    cabbage_counts = _count_pairs(
        cabbage_data, "combined_species", "Antibiotic_name", "n_cabbage"
    )

    # Only pairs that occur in both datasets survive the inner join.
    merged = atlas_counts.merge(cabbage_counts, on=["species", "antibiotic"], how="inner")

    enough_atlas = merged["n_atlas"] >= min_n
    enough_cabbage = merged["n_cabbage"] >= min_n
    kept = merged.loc[enough_atlas & enough_cabbage]

    # Harmonic mean of the two counts.
    ranked = kept.assign(score=2.0 / (1.0 / kept["n_atlas"] + 1.0 / kept["n_cabbage"]))

    # mergesort is stable, so pairs with equal scores keep their order.
    ranked = ranked.sort_values("score", ascending=False, kind="mergesort")
    return ranked.reset_index(drop=True)

# generate_pair_info_fast already returns a DataFrame, so no extra
# pd.DataFrame(...) wrapper is needed.
results_df = generate_pair_info_fast(atlas_data, cabbage_data)

# Rank by descending score. A stable sort (mergesort) keeps pairs with equal
# scores in the order they already have; the default algorithm gives no such
# guarantee.
sorted_df = results_df.sort_values(by="score", ascending=False, kind="mergesort")


In [9]:
# Write the ranked pairs to disk. to_csv does not create missing folders, so
# make sure the output directory exists first.
output_path = './generated_data/sorted_species_antibio_pairs.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
sorted_df.to_csv(output_path)