In [1]:
!pip install cyvcf2



In [2]:
import random
import sys
from cyvcf2 import VCF, Writer
from pathlib import Path

# Extend the import search path with <third ancestor of the current working
# directory>/src/python (parents[2] climbs three levels up from the resolved cwd).
# NOTE(review): presumably this is where the `paths` module imported below lives,
# and it only resolves correctly when the kernel's cwd is the notebook's own
# directory — confirm.
sys.path.append(str(Path().resolve().parents[2] / "src" / "python"))
from paths import training_path, validation_path

In [3]:
def _reservoir_offer(reservoir, capacity, index):
    """Offer the ``index``-th element (1-based) of a stream to ``reservoir``.

    The first ``capacity`` offers are appended. After that, a single draw from
    the module-level ``random`` generator decides whether the newcomer
    overwrites one of the stored indices.
    """
    if len(reservoir) < capacity:
        reservoir.append(index)
        return
    # randrange(index) is uniform on [0, index); it falls below ``capacity``
    # with probability capacity / index, in which case that slot is replaced.
    slot = random.randrange(index)
    if slot < capacity:
        reservoir[slot] = index


def stratified_split(in_vcf, out_train, out_valid, frac, seed=None):
    """Split a VCF into training and validation files, stratified by GOLDEN.

    Records are divided into two classes according to whether their ``GOLDEN``
    INFO entry is truthy. For each class, ``int(class_size * frac)`` records
    are selected by reservoir sampling and written to ``out_valid``; all other
    records are written to ``out_train``. The input is read three times
    (count, sample, write), and only record indices are kept in memory.

    Args:
        in_vcf: Path of the input VCF, passed to ``cyvcf2.VCF``.
        out_train: Path of the training output, passed to ``cyvcf2.Writer``.
        out_valid: Path of the validation output, passed to ``cyvcf2.Writer``.
        frac: Fraction of each class sent to validation; must be in [0, 1].
        seed: If not ``None``, the global ``random`` module is seeded with it
            before sampling so the split is repeatable.

    Raises:
        ValueError: If ``frac`` is outside the closed interval [0, 1].

    Returns:
        None. A summary of the class totals and validation counts is printed.
    """
    if not 0 <= frac <= 1:
        raise ValueError(f"frac must be between 0 and 1, got {frac!r}")

    if seed is not None:
        random.seed(seed)

    # First pass: count how many GOLDEN / non-GOLDEN
    total_golden = 0
    total_non = 0
    vcf = VCF(in_vcf)
    try:
        for rec in vcf:
            if rec.INFO.get("GOLDEN"):
                total_golden += 1
            else:
                total_non += 1
    finally:
        vcf.close()

    n_valid_golden = int(total_golden * frac)
    n_valid_non = int(total_non * frac)

    # Reservoirs store 1-based indices within each class
    golden_reservoir = []
    non_reservoir = []

    # Second pass: build reservoirs. Records are offered in file order so the
    # sequence of random draws (and thus the split for a given seed) is stable.
    ig, ing = 0, 0
    vcf = VCF(in_vcf)
    try:
        for rec in vcf:
            if rec.INFO.get("GOLDEN"):
                ig += 1
                _reservoir_offer(golden_reservoir, n_valid_golden, ig)
            else:
                ing += 1
                _reservoir_offer(non_reservoir, n_valid_non, ing)
    finally:
        vcf.close()

    # Sets give O(1) membership tests; scanning the reservoir lists for every
    # record made the write pass O(records * reservoir size).
    golden_valid = set(golden_reservoir)
    non_valid = set(non_reservoir)

    # Third pass: write records to train vs. valid. The nested try/finally
    # blocks guarantee every opened handle is closed even if writing fails.
    ig, ing = 0, 0
    vcf = VCF(in_vcf)
    try:
        w_train = Writer(out_train, vcf)
        try:
            w_valid = Writer(out_valid, vcf)
            try:
                for rec in vcf:
                    if rec.INFO.get("GOLDEN"):
                        ig += 1
                        to_valid = ig in golden_valid
                    else:
                        ing += 1
                        to_valid = ing in non_valid
                    if to_valid:
                        w_valid.write_record(rec)
                    else:
                        w_train.write_record(rec)
            finally:
                w_valid.close()
        finally:
            w_train.close()
    finally:
        vcf.close()

    print(f"Total GOLDEN variants       : {total_golden}")
    print(f"Total non-GOLDEN variants   : {total_non}")
    print(f"Validation GOLDEN assigned  : {len(golden_reservoir)}")
    print(f"Validation non-GOLDEN assigned: {len(non_reservoir)}")
    print("Done. Stratified split complete.")

In [4]:
# Input VCF, located under training_path.
input_vcf = training_path / "annotated_hg38.vcf"
# Destination paths handed to stratified_split: the training split goes under
# training_path (alongside the input), the validation split under validation_path.
output_train_vcf = training_path / "training_hg38.vcf"
output_validation_vcf = validation_path / "validation_hg38.vcf"
# Fraction of each class (GOLDEN / non-GOLDEN) routed to the validation file.
frac = 0.2
# Seed passed to stratified_split so the split is repeatable.
seed = 42

In [5]:
# Run the split with the settings from the configuration cell, spelling out
# which value feeds which parameter.
stratified_split(
    in_vcf=input_vcf,
    out_train=output_train_vcf,
    out_valid=output_validation_vcf,
    frac=frac,
    seed=seed,
)

Total GOLDEN variants       : 11068
Total non-GOLDEN variants   : 268035
Validation GOLDEN assigned  : 2213
Validation non-GOLDEN assigned: 53607
Done. Stratified split complete.
