In [None]:
# Notebook parameters.
# Sample set identifier passed to ag3.sample_metadata() below.
sample_sets = "AG1000G-BF-A"
# Version of the cohorts analysis, passed to the Ag3 constructor.
cohorts_analysis = "20230223"
# Minimum number of female samples a cohort must have to be selected.
min_cohort_size = 20
# When True, the Ag3 client is pointed at a "simplecache::" URL with a local
# cache directory under the project root.
use_gcs_cache = False

In [None]:
import yaml
import pandas as pd
import malariagen_data
from pyprojroot import here
import numpy as np

In [None]:
# Extra constructor arguments are only needed when caching remote reads locally.
# NOTE(review): the "simplecache::" prefix looks like fsspec URL chaining - confirm.
extra_params = (
    {
        "url": "simplecache::gs://vo_agam_release",
        "simplecache": {"cache_storage": (here() / "gcs_cache").as_posix()},
    }
    if use_gcs_cache
    else {}
)

ag3 = malariagen_data.Ag3(
    # the cohorts analysis version is pinned so that results are reproducible
    cohorts_analysis=cohorts_analysis,
    # computed results are cached in a directory under the project root
    results_cache=(here() / "malariagen_data_cache").as_posix(),
    **extra_params,
)
ag3

In [None]:
# Load the sample metadata for the configured sample set(s); the cells below
# use the result as a pandas DataFrame with one row per sample.
df_samples = ag3.sample_metadata(sample_sets=sample_sets)

In [None]:
def month_to_quarter(row):
    """Return the calendar quarter (1-4) for the month stored on ``row``.

    ``row`` is any object exposing a ``month`` attribute, e.g. a row of the
    sample metadata. Months 1-3 map to 1, 4-6 to 2, 7-9 to 3 and 10-12 to 4.
    Any month that is not strictly positive (including NaN) yields -1.
    """
    month = row.month
    # "not month > 0" rather than "month <= 0" so that NaN also lands here
    if not month > 0:
        return -1
    return (month - 1) // 3 + 1

In [None]:
# Derive a "quarter" column from each sample's month (row-wise), for convenience.
df_samples["quarter"] = df_samples.apply(month_to_quarter, axis=1)
df_samples

In [None]:
# Sanity check of the quarter logic: list the distinct months seen in each quarter.
(
    df_samples
    .groupby("quarter")
    .agg({"month": lambda months: set(months)})
)

In [None]:
# Name of the sample metadata column holding each sample's cohort identifier.
# It is used below both as the groupby key and inside the per-cohort sample queries.
cohorts_col = "cohort_admin2_quarter"

In [None]:
def make_cohort_label(row):
    """Build a human-readable label for a cohort.

    ``row`` must expose ``country``, ``admin2_name``, ``taxon``, ``year`` and
    ``quarter`` attributes. The label is
    "country / admin2_name / taxon / year", followed by " / Q<quarter>" when
    the quarter is strictly positive.
    """
    label = f"{row.country} / {row.admin2_name} / {row.taxon} / {row.year}"
    # N.B., some cohorts have no quarter because the samples came without a
    # collection month in the metadata; those fall back to a year-only label.
    if row.quarter > 0:
        return f"{label} / Q{row.quarter}"
    return label

In [None]:
# Build one row per cohort from the female samples, keeping only cohorts that
# reach the minimum size.
df_cohorts_selected = (
    df_samples
    # N.B., restrict to females, otherwise data on the X chromosome will be wonky
    .query("sex_call == 'F'")
    .groupby(cohorts_col)
    .agg(
        {
            # count of sample_id values among the female samples of the cohort
            "sample_id": "count",
            # descriptive fields, each aggregated with "first"
            **{
                field: "first"
                for field in (
                    "country",
                    "admin1_iso",
                    "admin1_name",
                    "admin2_name",
                    "taxon",
                    "year",
                    "quarter",
                )
            },
        }
    )
    .reset_index()
    .rename(columns={"sample_id": "cohort_size", cohorts_col: "cohort_id"})
    # discard cohorts smaller than the configured minimum
    .query(f"cohort_size >= {min_cohort_size}")
)

# Human-readable label for each cohort.
df_cohorts_selected["cohort_label"] = df_cohorts_selected.apply(
    make_cohort_label, axis=1
)

# Query string that selects the samples of each cohort.
# N.B., females only here as well, for the same X chromosome reason as above.
df_cohorts_selected["sample_query"] = df_cohorts_selected.apply(
    lambda cohort: f"{cohorts_col} == '{cohort.cohort_id}' and sex_call == 'F'",
    axis=1,
)
df_cohorts_selected

In [None]:
## Add average latitude and longitude for each cohort for plotting.
## May want to use a different approach, but the mean is OK at very small scales
## (which our cohorts usually are).

# Mean coordinates per cohort, computed in a single pass over the samples.
# The grouping key is `cohorts_col` (the same column the cohorts were built from)
# rather than a hard-coded column name, so the two cannot drift apart.
# N.B., the mean is taken over all samples of a cohort in df_samples; no
# sex filter is applied here.
cohort_coords = (
    df_samples
    .groupby(cohorts_col)[["latitude", "longitude"]]
    .mean()
)

# Look up each selected cohort's mean coordinates by cohort id. Using map()
# creates both columns even when no cohort was selected.
for coord in ("latitude", "longitude"):
    df_cohorts_selected[coord] = df_cohorts_selected["cohort_id"].map(cohort_coords[coord])

In [None]:
# Write the selected cohorts to build/cohorts.csv under the project root.
cohorts_path = here() / "build" / "cohorts.csv"
# to_csv does not create missing directories, so make sure "build" exists first.
cohorts_path.parent.mkdir(parents=True, exist_ok=True)
df_cohorts_selected.to_csv(cohorts_path, index=False)