In [1]:
import itertools
import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor

import allel
import joblib # For saving/loading models
import malariagen_data
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler

In [2]:
# Keyword arguments for malariagen_data.Ag3; shared by the notebook-level client
# and by every worker process (each worker builds its own client from these).
ag3_init_params = dict(
    sample_sets="3.3",
    debug=False,
    show_progress=False,
)

In [3]:
# Notebook-level Ag3 client, used for metadata and contig extents. Worker
# processes in the training cell create their own client from ag3_init_params.
ag3_main_process = malariagen_data.Ag3(**ag3_init_params)
samples_metadata = ag3_main_process.sample_metadata()

In [4]:
# Keep samples whose `aim_species` is gambiae or coluzzii and that have both an ID and a label.
gambiae_coluzzii_samples = (
    samples_metadata
    .query("aim_species in ['gambiae', 'coluzzii']")
    .dropna(subset=["sample_id", "aim_species"])
)

In [5]:
# Fixed-seed random (not stratified) subset used to train every partition's classifier.
num_samples_for_training = 500
sub_samples = gambiae_coluzzii_samples.sample(n=num_samples_for_training, random_state=42)
sub_sample_ids = sub_samples["sample_id"].to_list()

In [6]:
# Classification target: species label per sample, indexed by sample_id.
y_labels_all_samples = sub_samples.set_index("sample_id").loc[:, "aim_species"]

In [7]:
def encode_diploid_allel(gt_array):
    """Encode a genotype call array as a per-call dosage of allele 1.

    Parameters
    ----------
    gt_array : array-like of int, shape (n_variants, n_samples, ploidy)
        Allele index of each called allele (0 = reference); negative values
        mark a missing allele.

    Returns
    -------
    numpy.ndarray of float, shape (n_variants, n_samples)
        Number of copies of allele 1 in each call, or NaN when any allele of
        the call is missing.

    Raises
    ------
    ValueError
        If ``gt_array`` is not 3-dimensional.

    Notes
    -----
    Only allele index 1 is counted; copies of alleles 2, 3, ... contribute 0.
    If the callset lists several alternate alleles per site, a site whose
    segregating alternate is not allele 1 encodes as all zeros.
    TODO(review): confirm this is the intended feature definition.
    """
    genotypes = np.asarray(gt_array)
    if genotypes.ndim != 3:
        raise ValueError(
            f"expected a (variants, samples, ploidy) array, got shape {genotypes.shape}"
        )
    # Counted directly instead of via allel's to_allele_counts(). That function sizes
    # its allele axis from the data's maximum allele, so an input with no allele >= 1
    # has no column 1 and indexing [:, :, 1] raises IndexError.
    allele1_dosage = np.count_nonzero(genotypes == 1, axis=-1)
    call_is_missing = np.any(genotypes < 0, axis=-1)
    return np.where(call_is_missing, np.nan, allele1_dosage)

In [8]:
# Contigs the partitioning cell iterates over.
contigs = ag3_main_process.contigs
print("Detected contigs:", contigs)

Detected contigs: ('2R', '2L', '3R', '3L', 'X')


In [9]:
# Width of each genomic window (positions per partition) that gets its own classifier.
partition_window_size = 10 ** 6

In [10]:
# Per-partition metrics and confusion matrices, filled by the parallel-training cell.
all_partition_results = dict()
overall_confusion_matrices = list()

In [11]:
# Directory for per-partition classifier files; created if it does not exist yet.
output_dir = "trained_classifiers"
os.makedirs(output_dir, exist_ok=True)
print("Trained classifiers will be saved in: " + output_dir)

Trained classifiers will be saved in: trained_classifiers


In [12]:
def _nonmajor_allele_frequency(genotypes):
    """Return, per variant, the fraction of called alleles that are not its most common allele.

    For a biallelic site this is the usual minor allele frequency. Taking the
    minimum over every allele column instead would give 0 for any site that
    lacks one of the alleles seen elsewhere in the window (e.g. allele 2 or 3),
    which would filter out ordinary biallelic sites.

    Parameters
    ----------
    genotypes : array-like of int, shape (n_variants, n_samples, ploidy)
        Allele indices accepted by ``allel.GenotypeArray``; negative = missing.

    Returns
    -------
    numpy.ndarray of float, shape (n_variants,)
        ``1 - (count of most common allele) / (number of called alleles)``;
        0.0 where no allele was called.
    """
    allele_counts = np.asarray(allel.GenotypeArray(genotypes).count_alleles())
    called = allele_counts.sum(axis=1)
    major = allele_counts.max(axis=1, initial=0)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(called > 0, 1.0 - major / called, 0.0)


def _dump_atomically(obj, path):
    """Serialise ``obj`` with joblib to ``path`` via a temporary file and ``os.replace``.

    An interrupted run therefore never leaves a truncated file at ``path``.
    This matters because ``process_single_partition`` treats an existing model
    file as "already trained" and skips that partition.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        joblib.dump(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def process_single_partition(partition_tuple, sub_sample_ids, y_labels_all_samples, ag3_init_parameters, output_directory):
    """Filter SNP features, cross-validate and save a species classifier for one genomic window.

    Meant to run inside a ProcessPoolExecutor worker, so it builds its own
    ``malariagen_data.Ag3`` client instead of sharing the notebook's.

    Parameters
    ----------
    partition_tuple : tuple of (str, int, int)
        ``(contig, start, end)``; loaded as region ``"{contig}:{start}-{end}"``.
    sub_sample_ids : list of str
        Sample IDs to load genotypes for.
    y_labels_all_samples : pandas.Series
        Species label per sample, indexed by sample ID.
    ag3_init_parameters : dict
        Keyword arguments for ``malariagen_data.Ag3``.
    output_directory : str
        Where ``classifier_{partition_id}.joblib`` (the fitted RandomForest) and
        ``preprocessing_{partition_id}.joblib`` (feature positions + imputer)
        are written.

    Returns
    -------
    tuple of (str, object)
        ``(partition_id, result)``. The result is one of:

        * ``"SKIPPED"`` if the classifier file already exists;
        * ``None`` if the window has no usable data or processing failed (the
          traceback is printed to stderr);
        * a dict of cross-validation metrics: ``{accuracy,precision,recall,f1}_{mean,std}``,
          ``classes``, ``aggregated_confusion_matrix``, ``n_samples``, ``n_features``.
    """
    p_contig, p_start, p_end = partition_tuple
    partition_id = f"{p_contig}_{p_start}-{p_end}"
    model_filename = os.path.join(output_directory, f"classifier_{partition_id}.joblib")
    # Different prefix so a "classifier_*.joblib" glob still matches only models.
    preprocessing_filename = os.path.join(output_directory, f"preprocessing_{partition_id}.joblib")

    # The model file doubles as the "done" marker for resumable runs.
    if os.path.exists(model_filename):
        return partition_id, "SKIPPED"

    try:
        ag3_instance = malariagen_data.Ag3(**ag3_init_parameters)

        # Load SNP genotypes for this window.
        callset_partition = ag3_instance.snp_calls(
            region=f"{p_contig}:{p_start}-{p_end}",
            sample_query=f"sample_id in {sub_sample_ids}",
        )
        if callset_partition['call_genotype'].shape[0] == 0:
            return partition_id, None

        call_genotype = callset_partition['call_genotype'].values
        missing_mask = (call_genotype < 0).any(axis=2)  # (variants, samples)

        # Step 1: keep variants with <=1% missing calls and samples with <=5% missing calls.
        keep_variants = missing_mask.mean(axis=1) <= 0.01
        keep_samples = missing_mask.mean(axis=0) <= 0.05
        filtered_gt = call_genotype[keep_variants][:, keep_samples]
        sample_ids = callset_partition['sample_id'].values[keep_samples]
        positions = callset_partition['variant_position'].values[keep_variants]
        if filtered_gt.shape[0] == 0 or filtered_gt.shape[1] == 0:
            return partition_id, None

        # Rows are samples, columns are variant positions.
        dosage_df = pd.DataFrame(
            encode_diploid_allel(filtered_gt).T, index=sample_ids, columns=positions
        )

        # Step 2: drop variants whose nonmajor allele frequency is below 1%.
        maf_mask = _nonmajor_allele_frequency(filtered_gt) >= 0.01
        dosage_maf = dosage_df.iloc[:, maf_mask]
        if dosage_maf.shape[1] == 0:
            return partition_id, None

        # Step 3: drop low-variance dosage columns. All-NaN columns have NaN
        # variance, which fails the comparison and is dropped as well.
        variances = dosage_maf.var(axis=0, skipna=True)
        X_partition = dosage_maf.iloc[:, (variances >= 0.01).to_numpy()]
        if X_partition.shape[1] == 0:
            return partition_id, None

        # NOTE: the imputer (like the filters above) is fit on every sample before
        # the CV split. It ignores labels, so leakage is mild, but CV scores may be
        # slightly optimistic.
        imputer = SimpleImputer(strategy='most_frequent')
        X_imputed = pd.DataFrame(
            imputer.fit_transform(X_partition),
            index=X_partition.index,
            columns=X_partition.columns,
        )

        # Intersect first so samples without a label are dropped instead of raising KeyError.
        labelled = X_imputed.index.intersection(y_labels_all_samples.index)
        X_partition_final = X_imputed.loc[labelled]
        y_partition_final = y_labels_all_samples.loc[labelled]
        if X_partition_final.shape[0] == 0 or y_partition_final.nunique() < 2:
            return partition_id, None

        # 5-fold stratified cross-validation.
        # Fixed label order so every fold's confusion matrix has the same layout.
        class_labels = np.unique(y_partition_final)
        kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        fold_scores = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}
        fold_confusions = []

        for train_index, test_index in kf.split(X_partition_final, y_partition_final):
            X_train, X_test = X_partition_final.iloc[train_index], X_partition_final.iloc[test_index]
            y_train, y_test = y_partition_final.iloc[train_index], y_partition_final.iloc[test_index]

            fold_model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=4)
            fold_model.fit(X_train, y_train)
            y_pred = fold_model.predict(X_test)

            fold_scores['accuracy'].append(accuracy_score(y_test, y_pred))
            fold_scores['precision'].append(precision_score(y_test, y_pred, average='weighted', zero_division=0))
            fold_scores['recall'].append(recall_score(y_test, y_pred, average='weighted', zero_division=0))
            fold_scores['f1'].append(f1_score(y_test, y_pred, average='weighted', zero_division=0))
            fold_confusions.append(confusion_matrix(y_test, y_pred, labels=class_labels))

        # Refit on every labelled sample. The saved classifier then reflects all the
        # data rather than a single fold's 80% training split.
        final_model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=4)
        final_model.fit(X_partition_final, y_partition_final)

        partition_results = {}
        for metric, scores in fold_scores.items():
            partition_results[f'{metric}_mean'] = np.mean(scores)
            partition_results[f'{metric}_std'] = np.std(scores)
        partition_results['classes'] = final_model.classes_.tolist()
        partition_results['aggregated_confusion_matrix'] = np.sum(fold_confusions, axis=0).tolist()
        partition_results['n_samples'] = int(X_partition_final.shape[0])
        partition_results['n_features'] = int(X_partition_final.shape[1])

        # Column names are integer positions, so sklearn keeps no feature_names_in_.
        # Persist the positions and the imputer so the model can be applied to new
        # genotypes. Write the sidecar first: the model file marks completion.
        _dump_atomically(
            {'feature_positions': X_partition_final.columns.to_numpy(), 'imputer': imputer},
            preprocessing_filename,
        )
        _dump_atomically(final_model, model_filename)

        return partition_id, partition_results

    except Exception:
        # Keep the batch going, but report why this partition produced no result.
        print(f"  [{partition_id}] failed:\n{traceback.format_exc()}", file=sys.stderr, flush=True)
        return partition_id, None

In [13]:
# Split each contig into consecutive, non-overlapping windows of
# `partition_window_size` positions, from 1 up to the contig's last variant position.
all_partitions_list = []
for contig in contigs:
    try:
        # Use the notebook-level client only to find how far each contig extends.
        contig_callset_info = ag3_main_process.snp_calls(region=contig, sample_query=f"sample_id in {sub_sample_ids}")
        has_variants = (
            'variant_position' in contig_callset_info
            and contig_callset_info['variant_position'].shape[0] != 0
        )
        if not has_variants:
            print(f"  No variants found for contig {contig}. Skipping.")
            continue
        max_pos = contig_callset_info['variant_position'].values.max()
    except Exception as e:
        print(f"  Could not get max position for contig {contig}: {e}. Skipping.")
        continue

    # Window starts are 1, 1 + W, 1 + 2W, ...; the final window is clipped at max_pos.
    partitions_for_contig = [
        (contig, window_start, min(window_start + partition_window_size - 1, max_pos))
        for window_start in range(1, max_pos + 1, partition_window_size)
    ]

    if len(partitions_for_contig) == 0:
        print(f"  No partitions generated for contig {contig}. Skipping.")
        continue
    print(f"  Generated {len(partitions_for_contig)} partitions for {contig}.")
    all_partitions_list += partitions_for_contig


  Generated 62 partitions for 2R.
  Generated 50 partitions for 2L.
  Generated 54 partitions for 3R.
  Generated 42 partitions for 3L.
  Generated 25 partitions for X.


In [None]:
if not all_partitions_list:
    print("No partitions generated for parallel processing.")
else:
    # Each worker's RandomForest also uses n_jobs=4, so peak CPU demand is roughly
    # num_parallel_processes * 4 threads; size this against your core count.
    num_parallel_processes = 4
    print(f"\nStarting parallel processing with {num_parallel_processes} workers...")

    # Reset accumulators so re-running this cell does not mix in stale results.
    # Partitions whose classifier file already exists are skipped by the worker,
    # so they contribute no metrics or confusion matrix here.
    all_partition_results = {}
    overall_confusion_matrices = []
    skipped_partition_ids = []      # classifier already on disk
    no_result_partition_ids = []    # empty/filtered-out window or worker failure

    with ProcessPoolExecutor(max_workers=num_parallel_processes) as executor:
        # executor.map yields results in submission order (NOT completion order),
        # so progress lines follow all_partitions_list. Arguments shared by every
        # task come from itertools.repeat; map stops when all_partitions_list ends.
        results_iterator = executor.map(
            process_single_partition,
            all_partitions_list,
            itertools.repeat(sub_sample_ids),
            itertools.repeat(y_labels_all_samples),
            itertools.repeat(ag3_init_params),
            itertools.repeat(output_dir),
        )

        for partition_id, result in results_iterator:
            if result == "SKIPPED":
                skipped_partition_ids.append(partition_id)
                print(f"  Classifier for partition {partition_id} already exists. Skipping training.")
            elif result is not None:
                all_partition_results[partition_id] = result
                overall_confusion_matrices.append(np.array(result['aggregated_confusion_matrix']))
                print(f"  Successfully processed partition: {partition_id} (Avg Accuracy: {result['accuracy_mean']:.4f})")
            else:
                no_result_partition_ids.append(partition_id)
                print(f"  Skipped or failed to process partition: {partition_id}")


Starting parallel processing with 4 workers...
  Classifier for partition 2R_1-1000000 already exists. Skipping training.
  Classifier for partition 2R_1000001-2000000 already exists. Skipping training.
  Classifier for partition 2R_2000001-3000000 already exists. Skipping training.
  Classifier for partition 2R_3000001-4000000 already exists. Skipping training.
  Classifier for partition 2R_4000001-5000000 already exists. Skipping training.
  Classifier for partition 2R_5000001-6000000 already exists. Skipping training.
  Classifier for partition 2R_6000001-7000000 already exists. Skipping training.
  Classifier for partition 2R_7000001-8000000 already exists. Skipping training.
  Classifier for partition 2R_8000001-9000000 already exists. Skipping training.
