#### Imports and Environment Setup

In [None]:
import os
import sys
import json
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import torch

# Custom project modules
# Ensure the directory that CONTAINS the 'src' package (not 'src' itself) is your current directory or is on PYTHONPATH
from src.data import GenotypeEncoder, ImputationDataset
from src.model import EvoFill
from src.tensor2vcf import make_imputed_vcfgz_from_prob

# Report the compute backend PyTorch will use for this session.
print("Using Device: " + ("cuda" if torch.cuda.is_available() else "cpu"))

#### Path Configuration

In [None]:
# NOTE: every path is made absolute here because the os.chdir() at the end of
# this cell changes the base that relative paths are resolved against. Without
# this, a relative WORK_DIR / MASK_DIR / model path would point somewhere else
# (or nowhere) in all later cells.

# 1. Define Workspace
WORK_DIR = Path("<path/to/output_directory>").expanduser().resolve()
MASK_DIR = Path("<path/to/data/masked_vcf>").expanduser().resolve()

# 2. Define Model and Metadata Files
MODEL_BIN = Path("<path/to/models/hg38_chr22_v1.0.bin>").expanduser().resolve()
MODEL_META = Path("<path/to/models/model_meta.json>").expanduser().resolve()
TRAIN_META = Path("<path/to/train/gt_enc_meta.json>").expanduser().resolve()

# Initialize Directories
WORK_DIR.mkdir(parents=True, exist_ok=True)
(WORK_DIR / "impute_out").mkdir(exist_ok=True)

# Set working directory for the session
os.chdir(WORK_DIR)

#### Load Pre-trained Model

In [None]:
def load_model(meta_path, bin_path, device):
    """Build an EvoFill model from its metadata JSON and load its weights.

    Args:
        meta_path: Path (``str`` or ``Path``) to the model metadata JSON. It
            must provide the keys ``alleles``, ``total_sites``,
            ``chunk_size``, ``overlap``, ``d_model``, ``d_state``,
            ``headdim``, ``bimamba_layers`` and ``stack_mamba_layers``; each
            value is converted with ``int()``.
        bin_path: Path (``str`` or ``Path``) to the serialized state dict.
        device: Device the freshly built model is moved to.

    Returns:
        A ``(model, meta)`` tuple: the model switched to eval mode, and the
        parsed metadata dictionary.

    Raises:
        KeyError: If a required key is missing from the metadata.
    """
    # Accept plain strings as well as Path objects (``.name`` is used below).
    meta_path = Path(meta_path)
    bin_path = Path(bin_path)

    print(f"[INIT] Loading metadata: {meta_path.name}")
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(meta_path, 'r', encoding="utf-8") as f:
        meta = json.load(f)

    model = EvoFill(
        n_alleles=int(meta["alleles"]),
        total_sites=int(meta["total_sites"]),
        chunk_size=int(meta["chunk_size"]),
        chunk_overlap=int(meta["overlap"]),
        d_model=int(meta["d_model"]),
        d_state=int(meta["d_state"]),
        headdim=int(meta["headdim"]),
        bimamba_layers=int(meta["bimamba_layers"]),
        stack_mamba_layers=int(meta["stack_mamba_layers"]),
    ).to(device)

    print(f"[INIT] Loading weights: {bin_path.name}")
    # Deserialize onto the CPU first; load_state_dict() then copies the values
    # into the model's existing parameters, which already live on `device`.
    # NOTE(review): torch.load() unpickles the file - only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    state_dict = torch.load(bin_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for layers such as dropout / batch norm

    return model, meta

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the network and keep its metadata around for later cells.
model, model_meta = load_model(
    meta_path=MODEL_META,
    bin_path=MODEL_BIN,
    device=device,
)
print(f"\n[OK] Model '{model_meta.get('model_name')}' is ready for inference.")

#### Define Inference Pipeline

In [None]:
def _collate_onehot(batch):
    """Stack element 0 of every sample into one tensor; keep element 1 as a list.

    Defined at module level (not nested) so DataLoader worker processes can
    pickle it when the multiprocessing start method is ``spawn``.
    """
    return torch.stack([b[0] for b in batch]), [b[1] for b in batch]


def process_mask_vcf(model, mask_vcf_path, out_dir, train_meta_path, device, digits=4):
    """Impute one masked VCF with ``model`` and write the result as a VCF.

    Steps: encode the VCF against the training metadata, run the model over
    every sample, save the raw probabilities and the missing-ness mask as
    ``.npy`` files, then rebuild an imputed ``.vcf.gz`` from the probabilities.

    Args:
        model: Callable returning a 3-tuple; its second element is used as the
            probabilities. Should already be in eval mode and on ``device``.
        mask_vcf_path: Path (``str`` or ``Path``) of the masked input VCF.
        out_dir: Output root (``str`` or ``Path``). ``impute_in/`` is handed to
            the encoder as its save directory, ``impute_out/`` receives the
            ``.npy`` files, and the final VCF is written directly in it.
        train_meta_path: Path of the training-time encoder metadata JSON.
        device: Device the input batches are moved to.
        digits: Forwarded unchanged to ``make_imputed_vcfgz_from_prob``
            (presumably the number of decimals written - TODO confirm).

    Returns:
        ``Path`` of the imputed ``.vcf.gz`` file.
    """
    mask_vcf_path = Path(mask_vcf_path)
    # Accept plain strings too; the ``/`` joins below need a Path.
    out_dir = Path(out_dir)

    # Create the output folder up front so a long inference run cannot fail at
    # the final np.save() just because the directory is missing.
    impute_out_dir = out_dir / "impute_out"
    impute_out_dir.mkdir(parents=True, exist_ok=True)

    # --- Step 1: Encoding ---
    # Encoder artefacts are directed to <out_dir>/impute_in.
    gt_enc = GenotypeEncoder(
        phased=False, gts012=False, save2disk=True,
        save_dir=out_dir / "impute_in"
    )
    gt_enc = gt_enc.encode_ref(
        ref_meta_json=str(train_meta_path),
        vcf_path=str(mask_vcf_path)
    )

    # --- Step 2: Dataset Loading ---
    dataset = ImputationDataset(x_gts_sparse=gt_enc.X_gt, seq_depth=gt_enc.seq_depth)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=4,
        # Pinned host memory only helps host-to-GPU copies; requesting it
        # without CUDA just triggers a warning.
        pin_memory=torch.device(device).type == "cuda",
        collate_fn=_collate_onehot
    )

    # --- Step 3: Inference ---
    y_prob, y_mask = [], []
    with torch.no_grad():
        for x_onehot, _ in tqdm(loader, desc=f"Imputing {mask_vcf_path.name}", leave=False):
            x_onehot = x_onehot.to(device)
            _, prob, _ = model(x_onehot)
            # The last channel of the final axis is read as the "missing" flag
            # (assumed from the encoding - TODO confirm against GenotypeEncoder).
            miss_mask = x_onehot[..., -1].bool()
            # Move results off the device immediately to keep its memory flat.
            y_prob.append(prob.cpu())
            y_mask.append(miss_mask.cpu())

    y_prob = torch.cat(y_prob).numpy()
    y_mask = torch.cat(y_mask).numpy()

    # Save intermediate results under <out_dir>/impute_out.
    # NOTE: Path.stem drops only the last suffix, so "x.vcf.gz" yields "x.vcf"
    # and the files are named "x.vcf_prob.npy" etc. The names are kept as they
    # are so existing downstream consumers keep working.
    prob_path = impute_out_dir / f"{mask_vcf_path.stem}_prob.npy"
    mask_path = impute_out_dir / f"{mask_vcf_path.stem}_mask.npy"
    np.save(prob_path, y_prob)
    np.save(mask_path, y_mask)

    # --- Step 4: VCF Reconstruction ---
    out_vcfgz = out_dir / f"{mask_vcf_path.stem}_imputed_Evofill.vcf.gz"
    make_imputed_vcfgz_from_prob(
        prob_npy=str(prob_path),
        mask_vcf_gz=str(mask_vcf_path),
        out_vcfgz=str(out_vcfgz),
        digits=digits
    )
    return out_vcfgz

#### Run Batch Processing

In [None]:
# Gather every bgzipped VCF sitting directly in MASK_DIR, in name order.
mask_files = sorted(MASK_DIR.glob("*.vcf.gz"))
print(f"[INFO] Found {len(mask_files)} files to process.\n")

# Impute the files one by one. A failure on one file is reported and the
# batch carries on with the next, so a single bad input cannot abort the run.
results = []
for f in mask_files:
    print(f"--- Processing: {f.name} ---")
    try:
        final_vcf = process_mask_vcf(
            model=model,
            mask_vcf_path=f,
            out_dir=WORK_DIR,
            train_meta_path=TRAIN_META,
            device=device,
        )
        results.append(final_vcf)
        print(f"[SUCCESS] Saved to: {final_vcf.name}\n")
    except Exception as e:
        print(f"[ERROR] Failed to process {f.name}: {e}\n")


# Final summary: list only the outputs that were written successfully.
print("INFERENCE COMPLETE")
for r in results:
    print(f"- {r}")