In [None]:
import malariagen_data
import numpy as np
import pandas as pd
from xgboost import XGBClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import classification_report

In [None]:
# Open the Ag3 data resource and load sample metadata, discarding any row
# that lacks either a sample ID or a species assignment.
ag3 = malariagen_data.Ag3()
df = (
    ag3.sample_metadata()
    .dropna(subset=["sample_id", "aim_species"])
)
# Distinct species labels present in the retained metadata.
species_list = df["aim_species"].unique().tolist()

In [None]:
# Per-species sample cap: a species with at least N samples is randomly
# down-sampled to N; a species with fewer than N samples is kept in full.
N = 100

In [4]:
# Build the working sample set: at most N individuals per species.
sample_rows = []
for species in species_list:
    candidates = df.loc[df["aim_species"] == species]
    if len(candidates) >= N:
        # Fixed seed so the random draw is repeatable across runs.
        candidates = candidates.sample(n=N, random_state=42)
    sample_rows.append(candidates)

# One frame for all species, keyed by sample ID.
samples_df = pd.concat(sample_rows).set_index("sample_id")
sample_ids = samples_df.index.tolist()
labels = samples_df["aim_species"].values

In [5]:
# Turn the species names into integer class labels for the classifier;
# `species_names` keeps the mapping back from class index to name.
le = LabelEncoder()
le.fit(labels)
y = le.transform(labels)
species_names = le.classes_

In [6]:
# Contigs exposed by the Ag3 resource; the per-region loop iterates over these.
regions = ag3.contigs

In [7]:
def is_biallelic_site(genos):
    """Return True if a site's calls use no allele other than 0 and 1.

    Parameters
    ----------
    genos : array-like of int
        Allele calls for one site, any shape. ``-1`` marks a missing
        allele and is ignored.

    Returns
    -------
    bool
        True when every non-missing allele is 0 or 1. Note that this also
        holds for a site where only one allele is observed, for a site
        where every call is missing, and for empty input.
    """
    # Vectorised membership test: "all values are -1, 0 or 1" is the same as
    # "after dropping -1, the remaining values are a subset of {0, 1}".
    return bool(np.isin(np.asarray(genos), (-1, 0, 1)).all())

In [8]:
def encode_diploid(gt_slice):
    """Encode diploid calls as alternate-allele counts (float32).

    NOTE: a later cell defines ``encode_diploid`` again; when the cells run
    in order, that later definition replaces this one.

    Parameters
    ----------
    gt_slice : numpy.ndarray, indexed as ``[call, allele]``
        Column 0 and column 1 hold the two alleles of each call.

    Returns
    -------
    numpy.ndarray of float32, length ``len(gt_slice)``
        0 for 0/0, 1 for 0/1 or 1/0, 2 for 1/1, NaN for every other call
        (including any call with a negative allele).
    """
    first, second = gt_slice[:, 0], gt_slice[:, 1]
    first_ref, first_alt = first == 0, first == 1
    second_ref, second_alt = second == 0, second == 1

    # The three genotype classes are mutually exclusive; whatever matches
    # none of them falls through to the NaN default.
    genotype_classes = [
        first_ref & second_ref,
        (first_ref & second_alt) | (first_alt & second_ref),
        first_alt & second_alt,
    ]
    dosage = np.select(genotype_classes, [0.0, 1.0, 2.0], default=np.nan)
    return dosage.astype(np.float32)

In [9]:
def encode_diploid(gt_slice):
    """Encode diploid genotype calls as alternate-allele counts.

    Parameters
    ----------
    gt_slice : array-like, shape ``(..., 2)``
        Allele calls whose last axis holds the two alleles of each call.
        A 2-D ``(n_calls, 2)`` input behaves as before; higher-dimensional
        input (for example ``(variants, samples, 2)``) is encoded in one
        call.

    Returns
    -------
    numpy.ndarray of float64, shape ``gt_slice.shape[:-1]``
        0 for 0/0, 1 for 0/1 or 1/0, 2 for 1/1, and NaN for every other
        call (a negative/missing allele, or an allele other than 0 or 1).
    """
    gt = np.asarray(gt_slice)
    g0, g1 = gt[..., 0], gt[..., 1]

    # Start from NaN everywhere so that any call not matched below, including
    # calls with a negative allele, stays NaN without a separate pass.
    encoded = np.full(g0.shape, np.nan)
    encoded[(g0 == 0) & (g1 == 0)] = 0
    encoded[((g0 == 0) & (g1 == 1)) | ((g0 == 1) & (g1 == 0))] = 1
    encoded[(g0 == 1) & (g1 == 1)] = 2
    return encoded

In [None]:
# Per-region pipeline: load genotype calls, keep sites whose alleles are all
# 0/1, encode each sample as alternate-allele counts, fit an XGBoost species
# classifier, write the importance-ranked SNP positions to CSV, and report a
# 3-fold cross-validated accuracy.

# Labels keyed by sample ID, so that each region's labels can be put in the
# same order as the samples actually present in that region's dataset.
label_by_sample = pd.Series(y, index=sample_ids)

for region in regions:
    print(f"Processing region: {region}")
    ds = ag3.snp_calls(region=region, sample_query=f"sample_id in {sample_ids}")
    variant_pos = ds['variant_position'].values
    # NOTE(review): `.values` materialises every site of the contig for every
    # selected sample at once; for large contigs this may not fit in memory.
    # Consider processing the contig in chunks.
    call_genotype = ds['call_genotype'].values
    # The query is a membership filter, not an ordering: take the sample
    # order from the dataset itself rather than assuming it matches
    # `sample_ids`.
    region_sample_ids = np.asarray(ds['sample_id'].values).astype(str)
    del ds  # Free dask/xarray

    # Re-order the labels to follow the genotype sample axis. `.loc` raises
    # KeyError if the dataset contains a sample that has no label.
    y_region = label_by_sample.loc[region_sample_ids].to_numpy()
    if len(y_region) != call_genotype.shape[1]:
        raise ValueError(
            f"{region}: {call_genotype.shape[1]} genotype samples but "
            f"{len(y_region)} labels after alignment"
        )

    # Biallelic filtering: keep a site when every allele call is -1, 0 or 1.
    # This is the same per-site test as `is_biallelic_site`, done as two
    # reductions over the (samples, ploidy) axes instead of a Python loop.
    biallelic_mask = (
        (call_genotype.min(axis=(1, 2)) >= -1)
        & (call_genotype.max(axis=(1, 2)) <= 1)
    )
    call_genotype_biallelic = call_genotype[biallelic_mask, :, :]
    variant_pos_biallelic = variant_pos[biallelic_mask]
    del call_genotype, variant_pos, biallelic_mask

    # Encode genotypes (samples x SNPs for this region)
    X = np.array([encode_diploid(call_genotype_biallelic[:, s, :])
                  for s in range(call_genotype_biallelic.shape[1])], dtype=np.float32)
    del call_genotype_biallelic

    # XGBoost classifier (chromosome-wise)
    # NOTE(review): `use_label_encoder` is deprecated in recent xgboost
    # releases - confirm against the installed version.
    clf = XGBClassifier(use_label_encoder=False, eval_metric='mlogloss', tree_method='auto', n_jobs=-1)
    clf.fit(X, y_region)

    # Feature importance: top million SNPs (or all if < 1M)
    top_n = min(1_000_000, X.shape[1])
    importances = clf.feature_importances_
    important_snps_idx = np.argsort(importances)[::-1][:top_n]
    variant_positions_top = variant_pos_biallelic[important_snps_idx]
    importances_top = importances[important_snps_idx]
    df_top_snps = pd.DataFrame({
        "variant_position": variant_positions_top,
        "feature_importance": importances_top
    })
    df_top_snps.to_csv(f"top_million_snps_{region}.csv", index=False)
    print(f"Saved {len(df_top_snps)} top SNPs for {region} to top_million_snps_{region}.csv")

    # Optionally print region summary (uses the same aligned labels as the fit)
    cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    cv_scores = cross_val_score(clf, X, y_region, cv=cv)
    print(f"  Mean CV accuracy for {region}: {np.mean(cv_scores):.2f}")

    del X, y_region, importances, important_snps_idx, variant_pos_biallelic, df_top_snps

Processing region: 2R
                                 