# VCF statistics

- Modality-specific statistics
    - DeepSomatic
    - Mutect2
    - Strelka
- Consensus statistics
- Rescue statistics

```bash
micromamba install -n rnadnavar -c conda-forge -c bioconda cyvcf2=0.31.1 pysam=0.22.1 bcftools=1.21 htslib=1.21 pandas polars seaborn plotly ipykernel jupyterlab_widgets ipywidgets anywidget nbformat
```
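
The install command pins `cyvcf2` and `pysam` to exact versions, so it can be useful to confirm from inside the notebook kernel that those pins are what actually got loaded. The following is a minimal sketch using only the standard library; `PINNED` and `check_pins` are hypothetical names introduced here, and the check assumes the conda packages register Python distribution metadata under the same names.

```python
from importlib import metadata

# Versions copied from the install command above (Python packages only;
# bcftools and htslib are command-line tools and are not checked here).
PINNED = {"cyvcf2": "0.31.1", "pysam": "0.22.1"}


def check_pins(pins: dict) -> dict:
    """Return a status string per package: 'ok', 'missing', or a mismatch note."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if installed == wanted else f"mismatch (found {installed})"
    return report


for name, status in check_pins(PINNED).items():
    print(f"{name}: {status}")
```

A `missing` status usually means the notebook kernel is not running inside the `rnadnavar` environment.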

In [1]:
import os
import sys
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import warnings

warnings.filterwarnings("ignore")

# Data processing
import numpy as np
import pandas as pd

# VCF and BAM handling
from cyvcf2 import VCF
import pysam

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set plotting style
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (12, 6)

print("✓ All libraries imported successfully")

✓ All libraries imported successfully


## Configuration and File Discovery

Define paths and discover all VCF and alignment files in the dataset.

In [2]:
REFERENCE_FASTA = "/t9k/mnt/joey/bio_db/references/Homo_sapiens/GATK/GRCh38/Sequence/WholeGenomeFasta/Homo_sapiens_assembly38.fasta"

# Configuration
BASE_DIR = Path("/t9k/mnt/hdd/work/Vax/sequencing/aim_exp/rdv_test/COO8801.subset")

# Define file structure
TOOLS = ["deepsomatic", "mutect2", "strelka"]
MODALITIES = ["DNA_TUMOR_vs_DNA_NORMAL", "RNA_TUMOR_vs_DNA_NORMAL"]
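
To make the expected layout concrete, here is a small sketch of how `TOOLS` and `MODALITIES` combine into `variant_calling/<tool>/<modality>` directories, and how a non-gVCF `*.vcf.gz` file could be picked from each. It runs against a throwaway temporary directory rather than `BASE_DIR`; the helper names `expected_vcf_dirs` and `find_vcfs` and the `demo.*` file names are assumptions made for illustration only.

```python
import tempfile
from itertools import product
from pathlib import Path

TOOLS = ["deepsomatic", "mutect2", "strelka"]
MODALITIES = ["DNA_TUMOR_vs_DNA_NORMAL", "RNA_TUMOR_vs_DNA_NORMAL"]


def expected_vcf_dirs(base_dir: Path, stage: str = "variant_calling") -> dict:
    """Map '<tool>_<modality>' keys to the directory where that VCF is expected."""
    return {
        f"{tool}_{modality}": base_dir / stage / tool / modality
        for tool, modality in product(TOOLS, MODALITIES)
    }


def find_vcfs(base_dir: Path, stage: str = "variant_calling") -> dict:
    """Return the first (sorted) non-gVCF '*.vcf.gz' per tool/modality, if any."""
    found = {}
    for key, vcf_dir in expected_vcf_dirs(base_dir, stage).items():
        if not vcf_dir.is_dir():
            continue
        candidates = sorted(
            p for p in vcf_dir.glob("*.vcf.gz") if not p.name.endswith(".g.vcf.gz")
        )
        if candidates:
            found[key] = candidates[0]
    return found


# Demo on a temporary tree (file names are made up for illustration).
with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    strelka_dir = base / "variant_calling" / "strelka" / "DNA_TUMOR_vs_DNA_NORMAL"
    strelka_dir.mkdir(parents=True)
    (strelka_dir / "demo.strelka.vcf.gz").touch()
    (strelka_dir / "demo.strelka.g.vcf.gz").touch()

    demo_found = {key: path.name for key, path in find_vcfs(base).items()}

print(demo_found)
```

Sorting the glob results keeps the choice stable between runs when a directory holds more than one matching file.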

## Variant Classification Functions

Define functions to classify variants from different callers (Strelka, DeepSomatic, Mutect2) into biological categories: Somatic, Germline, Reference, and Artifact.

In [3]:
def classify_strelka_variant(filter_val, nt_val, normal_dp):
    """
    Classifies a Strelka variant into Somatic, Germline, Reference, or Artifact.

    Args:
        filter_val (str): The value from the VCF FILTER column (e.g., 'PASS', 'LowDepth').
                          Can be None if the VCF parser returns None for PASS.
        nt_val (str): The value of the INFO/NT field (e.g., 'ref', 'het', 'hom').
        normal_dp (int): The read depth of the Normal sample (FORMAT/DP).

    Returns:
        str: One of ['Somatic', 'Germline', 'Reference', 'Artifact']
    """

    # 1. Somatic: Strictly relies on PASS filter
    # cyvcf2 returns None for PASS records; other parsers return the literal "PASS" string.
    if filter_val == "PASS" or filter_val is None:
        return "Somatic"

    # 2. Germline: Filter failed, but NT indicates variant presence in Normal
    if nt_val in ["het", "hom"]:
        # Rescue if Normal Depth is sufficient (>= 2 per Strelka specs)
        if normal_dp >= 2:
            return "Germline"
        return "Artifact"

    # 3. Reference: Filter failed, but NT indicates Normal is Reference
    if nt_val == "ref":
        # Rescue if Normal Depth is sufficient
        if normal_dp >= 2:
            return "Reference"
        return "Artifact"

    # 4. Artifact: Catch-all for everything else (e.g. NT='conflict', LowDepth with dp<2)
    return "Artifact"


def classify_deepsomatic_variant(filter_val):
    """
    Classifies a DeepSomatic variant into Somatic, Germline, Reference, or Artifact.

    DeepSomatic already provides explicit FILTER labels:
    - PASS: Somatic variant
    - GERMLINE: Germline variant
    - RefCall: Reference (no variant)
    - Other filters: Artifact

    Args:
        filter_val (str): The value from the VCF FILTER column.
                          Can be None (treated as PASS by some parsers).

    Returns:
        str: One of ['Somatic', 'Germline', 'Reference', 'Artifact']
    """

    # Normalize filter value
    if filter_val is None or filter_val == "." or filter_val == "PASS":
        return "Somatic"

    # Check explicit DeepSomatic filter labels
    if "GERMLINE" in filter_val.upper():
        return "Germline"

    if "REFCALL" in filter_val.upper() or "REF_CALL" in filter_val.upper():
        return "Reference"

    # Everything else is an artifact (failed filters)
    return "Artifact"


def classify_mutect2_variant(filter_val):
    """
    Classifies a Mutect2 variant into Somatic, Germline, Reference, or Artifact.

    Note: Mutect2 typically only outputs PASS variants in somatic calling mode.
    This function handles edge cases where other filters might appear.

    Args:
        filter_val (str): The value from the VCF FILTER column.
                          Can be None (treated as PASS by some parsers).

    Returns:
        str: One of ['Somatic', 'Germline', 'Reference', 'Artifact']
    """

    # Mutect2 in somatic mode primarily outputs PASS variants
    if filter_val is None or filter_val == "." or filter_val == "PASS":
        return "Somatic"

    # Check for common Mutect2 germline-related filters (if present)
    filter_upper = filter_val.upper()
    if "GERMLINE" in filter_upper or "NORMAL_ARTIFACT" in filter_upper:
        return "Germline"

    # Other failed filters are artifacts
    return "Artifact"


def classify_variant(variant, caller_name, sample_indices=None):
    """
    Universal variant classifier that dispatches to caller-specific functions.

    Args:
        variant: cyvcf2.Variant object
        caller_name (str): Name of the variant caller ('strelka', 'deepsomatic', 'mutect2')
        sample_indices (dict): Optional dict mapping 'tumor' and 'normal' to sample indices

    Returns:
        str: One of ['Somatic', 'Germline', 'Reference', 'Artifact']
    """

    filter_val = variant.FILTER
    caller_lower = caller_name.lower()

    if caller_lower == "strelka":
        # Strelka requires NT field and normal depth
        try:
            nt_val = variant.INFO.get("NT", "")

            # Get normal sample depth
            normal_dp = 0
            if sample_indices and "normal" in sample_indices:
                normal_idx = sample_indices["normal"]
                dp_array = variant.format("DP")
                if dp_array is not None and len(dp_array) > normal_idx:
                    # cyvcf2 returns a NumPy array; missing depths come back
                    # as negative sentinel values rather than None.
                    normal_dp = max(int(dp_array[normal_idx][0]), 0)

            return classify_strelka_variant(filter_val, nt_val, normal_dp)
        except Exception as e:
            print(
                f"Warning: Strelka classification failed for variant at {variant.CHROM}:{variant.POS}: {e}"
            )
            return "Artifact"

    elif caller_lower == "deepsomatic":
        return classify_deepsomatic_variant(filter_val)

    elif caller_lower == "mutect2":
        return classify_mutect2_variant(filter_val)

    else:
        # Unknown caller, use generic PASS/FAIL logic
        if filter_val is None or filter_val == "." or filter_val == "PASS":
            return "Somatic"
        return "Artifact"


def get_sample_indices(vcf_obj, caller_name):
    """
    Determine tumor and normal sample indices from VCF samples.

    Args:
        vcf_obj: cyvcf2.VCF object
        caller_name (str): Name of the variant caller

    Returns:
        dict: {'tumor': int, 'normal': int} or None
    """
    samples = vcf_obj.samples

    if len(samples) < 2:
        return None

    indices = {}

    # Try to identify tumor and normal samples
    for i, sample in enumerate(samples):
        sample_lower = sample.lower()
        if "tumor" in sample_lower or "tumour" in sample_lower:
            indices["tumor"] = i
        elif "normal" in sample_lower:
            indices["normal"] = i

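    # Fallback when sample names are uninformative: give the unassigned role
    # the remaining column. If neither role was found, assume the first
    # sample is normal and the second is tumor.
    if "tumor" not in indices and "normal" not in indices:
        indices["normal"], indices["tumor"] = 0, 1
    elif "tumor" not in indices:
        indices["tumor"] = 1 if indices["normal"] == 0 else 0
    elif "normal" not in indices:
        indices["normal"] = 1 if indices["tumor"] == 0 else 0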

    return indices


print("✓ Variant classification functions defined")

✓ Variant classification functions defined
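
The Strelka rules are the only ones that combine three inputs (FILTER, `INFO/NT`, and normal depth), so a small decision table helps check the order in which they apply. The sketch below restates the logic of `classify_strelka_variant` as a standalone function so it can run without a VCF; `strelka_category` and the example rows are illustrative and hand-written, not pipeline output.

```python
from typing import Optional


def strelka_category(filter_val: Optional[str], nt_val: str, normal_dp: int) -> str:
    """Condensed restatement of the Strelka rules used in this notebook."""
    # 1. PASS (or None, as returned by cyvcf2) is always Somatic.
    if filter_val is None or filter_val == "PASS":
        return "Somatic"
    # 2. A failed filter with normal depth below 2 cannot be rescued.
    if normal_dp < 2:
        return "Artifact"
    # 3. Otherwise the normal genotype (NT) decides; unknown values are artifacts.
    return {"het": "Germline", "hom": "Germline", "ref": "Reference"}.get(
        nt_val, "Artifact"
    )


# (FILTER, NT, normal DP) combinations chosen to exercise each branch.
examples = [
    (None, "ref", 30),
    ("PASS", "ref", 0),
    ("LowEVS", "het", 25),
    ("LowEVS", "ref", 25),
    ("LowDepth", "het", 1),
    ("LowEVS", "conflict", 40),
]

results = [strelka_category(*row) for row in examples]
for row, category in zip(examples, results):
    print(f"{row!s:<28} -> {category}")
```

Note that the PASS check runs first, so a PASS record is labelled Somatic regardless of its NT value or normal depth.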


In [None]:
class VCFFileDiscovery:
    """Discover and organize all VCF files in the pipeline output"""

    def __init__(self, base_dir: Path):
        self.base_dir = Path(base_dir)
        self.vcf_files = {
            "variant_calling": {},  # Raw tool outputs
            "normalized": {},  # Normalized VCFs
            "annotated": {},  # VEP annotated
            "consensus": {},  # Consensus VCFs
            "rescue": {},  # Rescue VCFs
            "filtered": {},  # Filtered VCFs
        }
        self.alignment_files = {}

    def discover_vcfs(self):
        """Discover all VCF files"""

        # 1. Per-tool variant calling outputs
        for tool in TOOLS:
            for modality in MODALITIES:
                vcf_dir = self.base_dir / "variant_calling" / tool / modality
                if vcf_dir.exists():
                    # Sort for a deterministic pick and drop gVCF files
                    vcf_files = sorted(
                        f
                        for f in vcf_dir.glob("*.vcf.gz")
                        if not f.name.endswith(".g.vcf.gz")
                    )
                    if vcf_files:
                        key = f"{tool}_{modality}"
                        self.vcf_files["variant_calling"][key] = vcf_files[0]

        # 2. Normalized VCFs
        for tool in TOOLS:
            for modality in MODALITIES:
                vcf_dir = self.base_dir / "normalized" / tool / modality
                if vcf_dir.exists():
                    vcf_files = list(vcf_dir.glob("*.norm.vcf.gz"))
                    if vcf_files:
                        key = f"{tool}_{modality}"
                        self.vcf_files["normalized"][key] = vcf_files[0]

        # 3. Consensus VCFs
        consensus_dir = self.base_dir / "consensus" / "vcf"
        if consensus_dir.exists():
            for modality in MODALITIES:
                vcf_dir = consensus_dir / modality
                if vcf_dir.exists():
                    vcf_files = list(vcf_dir.glob("*.consensus.vcf.gz"))
                    if vcf_files:
                        self.vcf_files["consensus"][modality] = vcf_files[0]

        # 4. Rescue VCFs
        rescue_dir = self.base_dir / "rescue"
        if rescue_dir.exists():
            for subdir in rescue_dir.iterdir():
                if subdir.is_dir():
                    vcf_files = list(subdir.glob("*.rescued.vcf.gz"))
                    if vcf_files:
                        self.vcf_files["rescue"][subdir.name] = vcf_files[0]

        return self.vcf_files

    def discover_alignments(self):
        """Discover alignment files (CRAM/BAM)"""
        recal_dir = self.base_dir / "preprocessing" / "recalibrated"

        if recal_dir.exists():
            for sample_dir in recal_dir.iterdir():
                if sample_dir.is_dir():
                    cram_files = list(sample_dir.glob("*.cram"))
                    bam_files = list(sample_dir.glob("*.bam"))

                    if cram_files:
                        self.alignment_files[sample_dir.name] = cram_files[0]
                    elif bam_files:
                        self.alignment_files[sample_dir.name] = bam_files[0]

        return self.alignment_files

    def print_summary(self):
        """Print discovery summary"""
        print("=" * 80)
        print("VCF FILE DISCOVERY SUMMARY")
        print("=" * 80)

        for category, files in self.vcf_files.items():
            if files:
                print(f"\n{category.upper()}:")
                for name, path in files.items():
                    print(f"  ✓ {name}: {path.name}")

        if self.alignment_files:
            print("\nALIGNMENT FILES:")
            for name, path in self.alignment_files.items():
                print(f"  ✓ {name}: {path.name}")

        print("\n" + "=" * 80)


class VCFStatisticsExtractor:
    """Extract comprehensive statistics from VCF files"""

    def __init__(self, vcf_path: Path, caller_name: Optional[str] = None):
        self.vcf_path = vcf_path
        self.vcf = None
        self.stats = {}
        self.caller_name = caller_name

        # Auto-detect caller from filename if not provided
        if self.caller_name is None:
            filename = str(vcf_path.name).lower()
            if "deepsomatic" in filename:
                self.caller_name = "deepsomatic"
            elif "mutect2" in filename:
                self.caller_name = "mutect2"
            elif "strelka" in filename:
                self.caller_name = "strelka"
            else:
                self.caller_name = "unknown"

    def _is_consensus_or_rescue(self) -> bool:
        """Determine if the current VCF is a consensus or rescue VCF."""
        p = str(self.vcf_path).lower()
        return (
            ".consensus.vcf.gz" in p
            or ".rescued.vcf.gz" in p
            or "/consensus/" in p
            or "/rescue/" in p
        )

    def extract_basic_stats(self):
        """Extract basic variant statistics with biological classification or FILTER category"""
        try:
            self.vcf = VCF(str(self.vcf_path))

            # Get sample indices for classification
            sample_indices = get_sample_indices(self.vcf, self.caller_name)

            stats = {
                "total_variants": 0,
                "snps": 0,
                "indels": 0,
                "mnps": 0,
                "complex": 0,
                "passed": 0,
                "filtered": 0,
                "chromosomes": set(),
                "qualities": [],
                "variant_types": defaultdict(int),
                # Classification or category (based on FILTER for consensus/rescue)
                "classification": {},
            }

            use_filter_as_category = self._is_consensus_or_rescue()

            for variant in self.vcf:
                stats["total_variants"] += 1
                stats["chromosomes"].add(variant.CHROM)

                # Quality scores
                if variant.QUAL is not None and variant.QUAL > 0:
                    stats["qualities"].append(variant.QUAL)

                # Filter status
                if (
                    variant.FILTER is None
                    or variant.FILTER == "PASS"
                    or variant.FILTER == "."
                ):
                    stats["passed"] += 1
                else:
                    stats["filtered"] += 1

                # Variant type
                if variant.is_snp:
                    stats["snps"] += 1
                    stats["variant_types"]["SNP"] += 1
                elif variant.is_indel:
                    stats["indels"] += 1
                    if variant.is_deletion:
                        stats["variant_types"]["DEL"] += 1
                    else:
                        stats["variant_types"]["INS"] += 1
                else:
                    stats["complex"] += 1
                    stats["variant_types"]["COMPLEX"] += 1

                # Classification or FILTER-based category
                try:
                    if use_filter_as_category:
                        # Normalize FILTER into unified categories
                        raw_filter = variant.FILTER if variant.FILTER else "PASS"
                        cat = "Artifact" if raw_filter == "NoConsensus" else raw_filter
                        stats["classification"][cat] = (
                            stats["classification"].get(cat, 0) + 1
                        )
                    else:
                        classification = classify_variant(
                            variant, self.caller_name, sample_indices
                        )
                        stats["classification"][classification] = (
                            stats["classification"].get(classification, 0) + 1
                        )
                except Exception:
                    # Fallback
                    fallback_filter = variant.FILTER if variant.FILTER else "Artifact"
                    fallback_cat = (
                        "Artifact"
                        if (use_filter_as_category and fallback_filter == "NoConsensus")
                        else fallback_filter
                    )
                    stats["classification"][fallback_cat] = (
                        stats["classification"].get(fallback_cat, 0) + 1
                    )

            stats["chromosomes"] = sorted(list(stats["chromosomes"]))

            self.stats["basic"] = stats
            return stats

        except Exception as e:
            print(f"Error processing {self.vcf_path}: {e}")
            return None

    def extract_info_fields(self):
        """Extract INFO field statistics"""
        try:
            # Always reopen VCF to reset the iterator
            self.vcf = VCF(str(self.vcf_path))

            # Get available INFO fields from header
            info_fields = {}
            print("  [DEBUG] Starting header parsing...")
            try:
                for key in self.vcf.header_iter():
                    try:
                        # HREC objects: 'in' operator raises KeyError! Must use try-except
                        header_type = key["HeaderType"]
                        if header_type == "INFO":
                            field_id = key["ID"]
                            try:
                                field_type = key["Type"]
                            except KeyError:
                                field_type = "unknown"

                            if field_id:
                                info_fields[field_id] = {
                                    "type": field_type,
                                    "values": [],
                                }
                    except (KeyError, AttributeError, TypeError) as e:
                        # Skip this header entry if we can't parse it
                        continue
            except Exception as header_err:
                print(
                    f"  Error parsing header: {type(header_err).__name__}: {header_err}"
                )
                raise

            print(f"  [DEBUG] Found {len(info_fields)} INFO fields in header")

            # Collect values
            variant_count = 0
            for variant in self.vcf:
                variant_count += 1
                for info_id in info_fields.keys():
                    try:
                        val = variant.INFO.get(info_id)
                        if val is not None:
                            info_fields[info_id]["values"].append(val)
                    except Exception:
                        # Skip INFO values that cannot be read for this record
                        pass

                # Limit to the first 10,000 variants for efficiency
                if variant_count >= 10000:
                    break

            print(
                f"  [DEBUG] Processed {variant_count} variants, calculating statistics..."
            )

            # Calculate statistics for numeric fields
            info_stats = {}
            for info_id, data in info_fields.items():
                if data["values"]:
                    try:
                        # Try to convert to numeric
                        numeric_vals = []
                        for v in data["values"]:
                            if isinstance(v, (list, tuple)):
                                numeric_vals.extend(
                                    [float(x) for x in v if x is not None]
                                )
                            else:
                                numeric_vals.append(float(v))

                        if numeric_vals:
                            info_stats[info_id] = {
                                "count": len(numeric_vals),
                                "mean": np.mean(numeric_vals),
                                "median": np.median(numeric_vals),
                                "std": np.std(numeric_vals),
                                "min": np.min(numeric_vals),
                                "max": np.max(numeric_vals),
                                "q25": np.percentile(numeric_vals, 25),
                                "q75": np.percentile(numeric_vals, 75),
                            }
                    except (ValueError, TypeError):
                        # Non-numeric field
                        info_stats[info_id] = {
                            "count": len(data["values"]),
                            "type": "categorical",
                        }

            print(f"  [DEBUG] Calculated statistics for {len(info_stats)} INFO fields")
            self.stats["info"] = info_stats
            return info_stats

        except Exception as e:
            import traceback

            print(f"Error extracting INFO fields from {self.vcf_path}:")
            print(f"  {type(e).__name__}: {str(e)}")
            traceback.print_exc()  # ALWAYS print traceback
            return {}

    def extract_format_fields(self):
        """Extract FORMAT field statistics (sample-level)"""
        try:
            # Always reopen VCF to reset the iterator
            self.vcf = VCF(str(self.vcf_path))

            samples = self.vcf.samples
            format_stats = {sample: {} for sample in samples}

            # Common FORMAT fields to extract
            format_fields = ["DP", "AD", "AF", "GQ"]

            for sample in samples:
                for field in format_fields:
                    format_stats[sample][field] = []

            variant_count = 0
            for variant in self.vcf:
                variant_count += 1

                for i, sample in enumerate(samples):
                    # Depth (missing integers come back as negative sentinels)
                    try:
                        dp = variant.format("DP")[i]
                        if dp is not None and dp[0] > 0:
                            format_stats[sample]["DP"].append(dp[0])
                    except Exception:
                        pass

                    # Allelic depth
                    try:
                        ad = variant.format("AD")[i]
                        if ad is not None:
                            format_stats[sample]["AD"].append(ad)
                    except Exception:
                        pass

                    # Allele frequency (missing floats come back as NaN)
                    try:
                        af = variant.format("AF")[i]
                        if af is not None and not np.isnan(af[0]):
                            format_stats[sample]["AF"].append(af[0])
                    except Exception:
                        pass

                    # Genotype quality (skip missing values: negative or NaN)
                    try:
                        gq = variant.format("GQ")[i]
                        if gq is not None and gq[0] >= 0:
                            format_stats[sample]["GQ"].append(gq[0])
                    except Exception:
                        pass

                # Limit to the first 10,000 variants for efficiency
                if variant_count >= 10000:
                    break

            # Calculate statistics
            format_summary = {}
            for sample, fields in format_stats.items():
                format_summary[sample] = {}
                for field, values in fields.items():
                    if values and field != "AD":
                        format_summary[sample][field] = {
                            "count": len(values),
                            "mean": np.mean(values),
                            "median": np.median(values),
                            "min": np.min(values),
                            "max": np.max(values),
                            "q25": np.percentile(values, 25),
                            "q75": np.percentile(values, 75),
                        }

            self.stats["format"] = format_summary
            return format_summary

        except Exception as e:
            print(f"Error extracting FORMAT fields from {self.vcf_path}: {e}")
            return {}

    def extract_all_stats(self, verbose: bool = True):
        """Extract all statistics"""
        if verbose:
            print(f"\nProcessing: {self.vcf_path.name}")

        basic = self.extract_basic_stats()
        info = self.extract_info_fields()
        format_stats = self.extract_format_fields()

        if verbose and basic:
            print(f"  ✓ Total variants: {basic['total_variants']}")
            print(f"  ✓ SNPs: {basic['snps']}, INDELs: {basic['indels']}")
            print(
                f"  ✓ Passed filters: {basic['passed']}, Filtered: {basic['filtered']}"
            )

        return self.stats


def process_all_vcfs(vcf_files_dict):
    """Process all VCF files and collect statistics"""
    all_stats = {}

    for category, files in vcf_files_dict.items():
        if not files:
            continue

        print(f"\n{'=' * 80}")
        print(f"PROCESSING: {category.upper()}")
        print(f"{'=' * 80}")

        all_stats[category] = {}

        for name, vcf_path in files.items():
            try:
                extractor = VCFStatisticsExtractor(vcf_path)
                stats = extractor.extract_all_stats()
                all_stats[category][name] = {"path": vcf_path, "stats": stats}
            except Exception as e:
                print(f"  ✗ Failed to process {name}: {e}")

    return all_stats


def analyze_rescue_vcf(all_vcf_stats, show_plot: bool = True):
    """
    Analyze and visualize rescue VCF statistics by FILTER category.

    This function:
    1. Shows the FILTER category breakdown for DNA consensus, RNA consensus, and rescued variants
    2. Compares per-category counts across stages (DNA → RNA → Rescue); individual variants are not tracked
    3. Reports variant type composition (SNP vs INDEL) for each stage

    Args:
        all_vcf_stats: Dictionary containing statistics for all VCF categories
        show_plot: Whether to display visualization plots
    """
    if "rescue" not in all_vcf_stats or not all_vcf_stats["rescue"]:
        print("No rescue VCFs found")
        return

    print("\n" + "=" * 80)
    print("RESCUE VCF ANALYSIS WITH FILTER TRACKING")
    print("=" * 80)

    # Define unified color scheme
    CATEGORY_COLORS = {
        "Somatic": "#636EFA",
        "Germline": "#00CC96",
        "Reference": "#FFA15A",
        "Artifact": "#EF553B",
    }
    CATEGORY_ORDER = ["Somatic", "Germline", "Reference", "Artifact"]

    # Collect classification data for each stage
    dna_classification = {}
    rna_classification = {}
    rescue_classification = {}

    dna_total = 0
    rna_total = 0
    rescue_total = 0

    # Get DNA consensus classification
    if "consensus" in all_vcf_stats:
        for name, data in all_vcf_stats["consensus"].items():
            if "DNA_TUMOR" in name:
                basic = data.get("stats", {}).get("basic", {})
                dna_classification = basic.get("classification", {})
                dna_total = basic.get("total_variants", 0)

    # Get RNA consensus classification
    if "consensus" in all_vcf_stats:
        for name, data in all_vcf_stats["consensus"].items():
            if "RNA_TUMOR" in name:
                basic = data.get("stats", {}).get("basic", {})
                rna_classification = basic.get("classification", {})
                rna_total = basic.get("total_variants", 0)

    # Get rescued classification
    if "rescue" in all_vcf_stats:
        for name, data in all_vcf_stats["rescue"].items():
            basic = data.get("stats", {}).get("basic", {})
            rescue_classification = basic.get("classification", {})
            rescue_total = basic.get("total_variants", 0)

    # Print summary statistics
    print(
        f"\n{'Category':<15} {'DNA Consensus':<15} {'RNA Consensus':<15} {'Rescued':<15}"
    )
    print("-" * 60)
    for filter_cat in CATEGORY_ORDER:
        dna_count = dna_classification.get(filter_cat, 0)
        rna_count = rna_classification.get(filter_cat, 0)
        rescue_count = rescue_classification.get(filter_cat, 0)
        print(f"{filter_cat:<15} {dna_count:<15,} {rna_count:<15,} {rescue_count:<15,}")

    print("-" * 60)
    print(f"{'TOTAL':<15} {dna_total:<15,} {rna_total:<15,} {rescue_total:<15,}")

    if rescue_total > 0 and dna_total > 0:
        rescue_rate = rescue_total / dna_total * 100
        print(f"\nOverall Rescue Rate: {rescue_rate:.2f}% of DNA consensus variants")

    # Calculate FILTER transition statistics
    print("\n" + "=" * 80)
    print("FILTER CATEGORY TRANSITION ANALYSIS")
    print("=" * 80)
    print("\nThis analysis shows how FILTER labels change across stages:")
    print("  DNA Consensus → RNA Consensus → Rescued")
    print("\nNote: Exact variant-level tracking requires VCF file parsing.")
    print("Below shows category-level changes based on total counts:\n")

    # Show changes for each category
    for filter_cat in CATEGORY_ORDER:
        dna_count = dna_classification.get(filter_cat, 0)
        rna_count = rna_classification.get(filter_cat, 0)
        rescue_count = rescue_classification.get(filter_cat, 0)

        if dna_count > 0 or rna_count > 0 or rescue_count > 0:
            print(f"\n{filter_cat} Category:")
            print(f"  DNA Consensus: {dna_count:,}")

            if dna_count > 0:
                rna_change = rna_count - dna_count
                rna_pct = rna_count / dna_count * 100
                print(
                    f"  RNA Consensus: {rna_count:,} ({rna_change:+,}, {rna_pct:.1f}% of DNA)"
                )

                rescue_change = rescue_count - dna_count
                rescue_pct = rescue_count / dna_count * 100
                print(
                    f"  Rescued: {rescue_count:,} ({rescue_change:+,}, {rescue_pct:.1f}% of DNA)"
                )

    # Create comprehensive visualization
    if show_plot:
        fig = make_subplots(
            rows=1,
            cols=3,
            subplot_titles=("DNA Consensus", "RNA Consensus", "Rescued"),
            horizontal_spacing=0.10,
            shared_yaxes=True,
        )

        stages = [
            ("DNA Consensus", dna_classification, 1),
            ("RNA Consensus", rna_classification, 2),
            ("Rescued", rescue_classification, 3),
        ]

        for stage_name, classification, col_idx in stages:
            for filter_cat in CATEGORY_ORDER:
                count = classification.get(filter_cat, 0)
                if count > 0:
                    fig.add_trace(
                        go.Bar(
                            name=filter_cat,
                            x=[stage_name],
                            y=[count],
                            marker_color=CATEGORY_COLORS[filter_cat],
                            text=[count],
                            textposition="inside",
                            showlegend=(col_idx == 1),
                            legendgroup=filter_cat,
                            hovertemplate=f"<b>{stage_name}</b><br>{filter_cat}: %{{y:,}}<extra></extra>",
                        ),
                        row=1,
                        col=col_idx,
                    )

            fig.update_xaxes(title_text="Stage", row=1, col=col_idx)
            if col_idx == 1:
                fig.update_yaxes(title_text="Number of Variants", row=1, col=col_idx)

        # Calculate max y value for consistent scaling
        max_y = max(dna_total, rna_total, rescue_total)
        if max_y > 0:
            fig.update_yaxes(range=[0, max_y * 1.1], row=1, col=1)

        fig.update_layout(
            title_text="Rescue VCF Analysis: FILTER Categories Across Stages",
            template="plotly_white",
            height=600,
            barmode="stack",
            showlegend=True,
        )

        fig.show()

    # Additional variant type analysis
    print("\n" + "=" * 80)
    print("VARIANT TYPE COMPOSITION")
    print("=" * 80)

    # Each stage reads one category of all_vcf_stats; consensus entries are
    # further selected by a substring of their name.
    for stage_name, category, name_filter in [
        ("DNA Consensus", "consensus", "DNA_TUMOR"),
        ("RNA Consensus", "consensus", "RNA_TUMOR"),
        ("Rescued", "rescue", None),
    ]:
        for name, data in all_vcf_stats.get(category, {}).items():
            if name_filter is not None and name_filter not in name:
                continue
            basic = data.get("stats", {}).get("basic", {})
            var_types = basic.get("variant_types", {})
            snps = var_types.get("SNP", 0)
            dels = var_types.get("DEL", 0)
            ins = var_types.get("INS", 0)
            indels = dels + ins
            total = basic.get("total_variants", 0)
            if total > 0:
                print(f"\n{stage_name}:")
                print(f"  SNPs: {snps:,} ({snps / total * 100:.1f}%)")
                print(
                    f"  INDELs: {indels:,} ({indels / total * 100:.1f}%) [DEL: {dels:,}, INS: {ins:,}]"
                )

    print("\n" + "=" * 80)
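

# Illustrative sketch (hypothetical helper, not part of the original notebook):
# the per-stage SNP/INDEL breakdown printed above only needs the "basic" stats
# dict. It assumes the same keys that the code above reads: "variant_types"
# with SNP/DEL/INS counts, and "total_variants".
def sketch_variant_type_breakdown(basic):
    """Return SNP/INDEL counts and percentages, or None if there are no variants."""
    var_types = basic.get("variant_types", {})
    snps = var_types.get("SNP", 0)
    dels = var_types.get("DEL", 0)
    ins = var_types.get("INS", 0)
    total = basic.get("total_variants", 0)
    if total <= 0:
        return None
    indels = dels + ins
    return {
        "snps": snps,
        "indels": indels,
        "snp_pct": snps / total * 100,
        "indel_pct": indels / total * 100,
    }


# Possible usage, with made-up counts:
# sketch_variant_type_breakdown(
#     {"total_variants": 10, "variant_types": {"SNP": 8, "DEL": 1, "INS": 1}}
# )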


def generate_summary_report(
    all_vcf_stats,
    vcf_files,
    variant_summary,
    quality_summary,
    tool_comparison,
    consensus_comparison,
    output_dir: Path,
):
    """Generate comprehensive summary report"""

    report = []
    report.append("=" * 80)
    report.append("VCF STATISTICS - COMPREHENSIVE SUMMARY REPORT")
    report.append("=" * 80)
    report.append("")

    # 1. Overview
    report.append("## 1. OVERVIEW")
    report.append("")
    total_vcfs = sum(len(files) for files in vcf_files.values() if files)
    report.append(f"Total VCF files analyzed: {total_vcfs}")
    report.append(
        f"Categories: {', '.join([cat for cat, files in vcf_files.items() if files])}"
    )
    report.append(f"Tools: {', '.join(TOOLS)}")
    report.append("Modalities: DNA, RNA")
    report.append("")

    # 2. Variant Calling Tools Comparison
    report.append("## 2. VARIANT CALLING TOOLS COMPARISON")
    report.append("")

    if not tool_comparison.empty:
        report.append("### DNA Modality:")
        dna_tools = tool_comparison[
            tool_comparison["Modality"].str.contains("DNA_TUMOR")
        ]
        for _, row in dna_tools.iterrows():
            report.append(
                f"  {row['Tool']:12} - {row['Total_Variants']:6} variants "
                f"(SNPs: {row['SNPs']:5}, INDELs: {row['INDELs']:4})"
            )

        report.append("")
        report.append("### RNA Modality:")
        rna_tools = tool_comparison[
            tool_comparison["Modality"].str.contains("RNA_TUMOR")
        ]
        for _, row in rna_tools.iterrows():
            report.append(
                f"  {row['Tool']:12} - {row['Total_Variants']:6} variants "
                f"(SNPs: {row['SNPs']:5}, INDELs: {row['INDELs']:4})"
            )
    report.append("")

    # 3. Consensus Analysis
    report.append("## 3. CONSENSUS ANALYSIS")
    report.append("")

    if not consensus_comparison.empty:
        for modality in ["DNA_TUMOR_vs_DNA_NORMAL", "RNA_TUMOR_vs_DNA_NORMAL"]:
            mod_name = "DNA" if "DNA_TUMOR" in modality else "RNA"
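            # Match on the tumor token: "DNA" alone would also match the
            # "DNA_NORMAL" suffix of the RNA modality label.
            mod_data = consensus_comparison[
                consensus_comparison["Modality"].str.contains(f"{mod_name}_TUMOR")
            ]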

            if not mod_data.empty:
                consensus_count = mod_data["Consensus_Variants"].iloc[0]
                report.append(f"### {mod_name} Consensus: {consensus_count} variants")
                report.append("")
                report.append("  Tool contributions:")
                for _, row in mod_data.iterrows():
                    retention = row["Retention_Rate"] * 100
                    report.append(
                        f"    {row['Tool']:12}: {row['Tool_Variants']:5} variants "
                        f"→ {retention:5.1f}% retained in consensus"
                    )
                report.append("")

    # 4. Rescue Statistics
    report.append("## 4. RESCUE (CROSS-MODALITY) ANALYSIS")
    report.append("")

    if "rescue" in all_vcf_stats and all_vcf_stats["rescue"]:
        for name, data in all_vcf_stats["rescue"].items():
            if "stats" in data and "basic" in data["stats"]:
                basic = data["stats"]["basic"]
                rescue_total = basic.get("total_variants", 0)

                # Compare with DNA consensus
                if "DNA_TUMOR_vs_DNA_NORMAL" in all_vcf_stats.get("consensus", {}):
                    dna_consensus = all_vcf_stats["consensus"][
                        "DNA_TUMOR_vs_DNA_NORMAL"
                    ]
                    if "stats" in dna_consensus and "basic" in dna_consensus["stats"]:
                        dna_total = dna_consensus["stats"]["basic"].get(
                            "total_variants", 0
                        )
                        added = rescue_total - dna_total
                        pct_increase = (added / dna_total * 100) if dna_total > 0 else 0

                        report.append(f"DNA Consensus: {dna_total} variants")
                        report.append(f"After RNA rescue: {rescue_total} variants")
                        report.append(f"Variants added: {added} ({pct_increase:+.1f}%)")
                        report.append(
                            f"SNPs: {basic.get('snps', 0)}, INDELs: {basic.get('indels', 0)}"
                        )
    else:
        report.append("No rescue VCFs found")
    report.append("")

    # 5. Quality Metrics
    report.append("## 5. QUALITY METRICS")
    report.append("")

    if not quality_summary.empty:
        report.append("Average quality scores by tool:")
        for _, row in quality_summary.iterrows():
            if row["Category"] == "variant_calling":
                report.append(
                    f"  {row['Tool']:12} ({row['Modality'][:3]}): "
                    f"Mean={row['Mean_QUAL']:7.2f}, Median={row['Median_QUAL']:7.2f}"
                )
    report.append("")

    # 6. Filter Status
    report.append("## 6. FILTER STATUS SUMMARY")
    report.append("")

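    if variant_summary.empty:
        # No per-VCF rows were collected, so there is nothing to sum
        total_passed = total_filtered = 0
    else:
        total_passed = variant_summary["Passed"].sum()
        total_filtered = variant_summary["Filtered"].sum()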
    total_all = total_passed + total_filtered
    pass_rate = (total_passed / total_all * 100) if total_all > 0 else 0

    report.append(f"Total variants across all VCFs: {total_all}")
    report.append(f"  Passed filters: {total_passed} ({pass_rate:.1f}%)")
    report.append(f"  Filtered out: {total_filtered} ({100 - pass_rate:.1f}%)")
    report.append("")

    # 7. Recommendations
    report.append("## 7. KEY INSIGHTS")
    report.append("")

    if not tool_comparison.empty:
        # Find most/least sensitive tool
        dna_tools = tool_comparison[
            tool_comparison["Modality"].str.contains("DNA_TUMOR")
        ]
        if not dna_tools.empty:
            most_sensitive = dna_tools.loc[dna_tools["Total_Variants"].idxmax()]
            least_sensitive = dna_tools.loc[dna_tools["Total_Variants"].idxmin()]

            report.append(
                f"• Most sensitive tool (DNA): {most_sensitive['Tool']} "
                f"({most_sensitive['Total_Variants']} variants)"
            )
            report.append(
                f"• Most conservative tool (DNA): {least_sensitive['Tool']} "
                f"({least_sensitive['Total_Variants']} variants)"
            )
            report.append("")

    if "rescue" in all_vcf_stats and all_vcf_stats["rescue"]:
        report.append(
            "• Cross-modality rescue successfully recovered additional variants from RNA data"
        )
        report.append(
            "• RNA sequencing provides complementary variant detection to DNA"
        )

    report.append("")
    report.append("=" * 80)
    report.append("END OF REPORT")
    report.append("=" * 80)

    # Print report
    report_text = "\n".join(report)
    print(report_text)

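    # Save report (UTF-8, because the report contains "•" and "→")
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "summary_report.txt", "w", encoding="utf-8") as f:
        f.write(report_text)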

    print(f"\n✓ Report saved to {output_dir / 'summary_report.txt'}")

    return report_text
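

# Illustrative sketch (hypothetical helper, not part of the original notebook):
# modality labels such as "RNA_TUMOR_vs_DNA_NORMAL" also contain "DNA",
# because the matched normal is DNA. Keying on the "<kind>_TUMOR" token, as the
# tool comparison section of the report does, keeps the two modalities apart.
def sketch_tumor_modality(label):
    """Return "DNA" or "RNA" based on the tumor token, or None if neither is present."""
    for kind in ("DNA", "RNA"):
        if f"{kind}_TUMOR" in label:
            return kind
    return None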


def export_results(
    variant_summary,
    quality_summary,
    tool_comparison,
    consensus_comparison,
    validation_df=None,
    output_dir: Path = Path("vcf_statistics_output"),
):
    """Export summary statistics to CSV files"""
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Exporting results...")

    # 1. Variant count summary
    variant_summary.to_csv(output_dir / "variant_count_summary.csv", index=False)
    print(f"✓ Exported: {output_dir / 'variant_count_summary.csv'}")

    # 2. Quality summary
    quality_summary.to_csv(output_dir / "quality_summary.csv", index=False)
    print(f"✓ Exported: {output_dir / 'quality_summary.csv'}")

    # 3. Tool comparison
    tool_comparison.to_csv(output_dir / "tool_comparison.csv", index=False)
    print(f"✓ Exported: {output_dir / 'tool_comparison.csv'}")

    # 4. Consensus comparison
    consensus_comparison.to_csv(output_dir / "consensus_comparison.csv", index=False)
    print(f"✓ Exported: {output_dir / 'consensus_comparison.csv'}")

    # 5. Validation results (if available)
    if validation_df is not None and not validation_df.empty:
        validation_df.to_csv(output_dir / "bam_validation_results.csv", index=False)
        print(f"✓ Exported: {output_dir / 'bam_validation_results.csv'}")

    print(f"\n✓ All results exported to {output_dir}/")


class BAMValidator:
    """Validate variants using BAM/CRAM alignment files"""

    def __init__(self, reference_fasta: Optional[str] = None):
        self.reference_fasta = reference_fasta

    def validate_variants(
        self, vcf_path: Path, bam_paths: Dict[str, Path], max_variants: int = 100
    ):
        """
        Validate variants by checking read support in BAM files

        Args:
            vcf_path: Path to VCF file
            bam_paths: Dictionary mapping sample names to BAM/CRAM paths
            max_variants: Maximum number of variants to validate
        """
        validation_results = []

        try:
            vcf = VCF(str(vcf_path))

            # Open BAM files
            bam_files = {}
            for sample, bam_path in bam_paths.items():
                try:
                    if self.reference_fasta and str(bam_path).endswith(".cram"):
                        bam_files[sample] = pysam.AlignmentFile(
                            str(bam_path), "rc", reference_filename=self.reference_fasta
                        )
                    else:
                        bam_files[sample] = pysam.AlignmentFile(str(bam_path))
                except Exception as e:
                    print(f"Warning: Could not open {sample} BAM file: {e}")

            if not bam_files:
                print("No BAM files could be opened for validation")
                return []

            # Validate variants
            variant_count = 0
            for variant in vcf:
                if variant_count >= max_variants:
                    break

                chrom = variant.CHROM
                pos = variant.POS
                ref = variant.REF
                alts = variant.ALT

                variant_result = {
                    "chrom": chrom,
                    "pos": pos,
                    "ref": ref,
                    "alt": ",".join(alts) if alts else "",
                    "qual": variant.QUAL,
                    "filter": variant.FILTER if variant.FILTER else "PASS",
                }

                # Check each sample
                for sample_name, bam_file in bam_files.items():
                    try:
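                        # Count bases at this position. Only single-base
                        # alleles can match a read base here, so multi-base
                        # (indel) ALT alleles get zero supporting reads.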
                        pileup_count = 0
                        ref_count = 0
                        alt_counts = {alt: 0 for alt in alts if alt}
                        total_depth = 0

                        for pileupcolumn in bam_file.pileup(
                            chrom,
                            pos - 1,
                            pos,
                            truncate=True,
                            min_base_quality=20,
                            max_depth=10000,
                        ):
                            if pileupcolumn.pos == pos - 1:  # 0-based
                                total_depth = pileupcolumn.n

                                for pileupread in pileupcolumn.pileups:
                                    if (
                                        not pileupread.is_del
                                        and not pileupread.is_refskip
                                    ):
                                        base = pileupread.alignment.query_sequence[
                                            pileupread.query_position
                                        ]

                                        if base == ref:
                                            ref_count += 1
                                        elif base in alt_counts:
                                            alt_counts[base] += 1

                                        pileup_count += 1

                        variant_result[f"{sample_name}_total_depth"] = total_depth
                        variant_result[f"{sample_name}_ref_count"] = ref_count
                        for alt, count in alt_counts.items():
                            variant_result[f"{sample_name}_alt_{alt}_count"] = count

                        # Calculate VAF over reads with a usable base at this
                        # position (deletions and reference skips excluded)
                        total_alt = sum(alt_counts.values())
                        variant_result[f"{sample_name}_vaf"] = (
                            total_alt / pileup_count if pileup_count > 0 else 0
                        )

                    except Exception as e:
                        variant_result[f"{sample_name}_error"] = str(e)

                validation_results.append(variant_result)
                variant_count += 1

            # Close BAM files
            for bam_file in bam_files.values():
                bam_file.close()

            return validation_results

        except Exception as e:
            print(f"Error during validation: {e}")
            return []

    def summarize_validation(self, validation_results: List[Dict]) -> pd.DataFrame:
        """Convert validation results to DataFrame"""
        if not validation_results:
            return pd.DataFrame()

        df = pd.DataFrame(validation_results)
        return df
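

# Illustrative sketch (hypothetical helper, not part of the original notebook):
# the allele counting in BAMValidator.validate_variants, reduced to plain
# Python. `bases` stands in for the read bases observed at one position after
# deletions and reference skips have been dropped; no pysam is involved.
def sketch_count_alleles(bases, ref, alts):
    """Count REF/ALT support and compute VAF as total ALT reads / counted reads."""
    alt_counts = {alt: 0 for alt in alts if alt}
    ref_count = 0
    for base in bases:
        if base == ref:
            ref_count += 1
        elif base in alt_counts:
            alt_counts[base] += 1
    # Bases matching neither REF nor ALT still count toward the denominator
    counted = len(bases)
    vaf = sum(alt_counts.values()) / counted if counted > 0 else 0
    return ref_count, alt_counts, vaf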


class StatisticsAggregator:
    """Aggregate and summarize VCF statistics"""

    def __init__(self, all_stats: Dict):
        self.all_stats = all_stats

    def create_variant_count_summary(self) -> pd.DataFrame:
        """Create summary table of variant counts across all VCFs"""
        rows = []

        for category, files in self.all_stats.items():
            for name, data in files.items():
                if "stats" in data and "basic" in data["stats"]:
                    basic = data["stats"]["basic"]

                    # Parse tool and modality from name
                    parts = name.split("_")
                    if len(parts) >= 2:
                        tool = parts[0]
                        modality = "_".join(parts[1:])
                    else:
                        tool = category
                        modality = name

                    row = {
                        "Category": category,
                        "Tool": tool,
                        "Modality": modality,
                        "Total_Variants": basic.get("total_variants", 0),
                        "SNPs": basic.get("snps", 0),
                        "INDELs": basic.get("indels", 0),
                        "Passed": basic.get("passed", 0),
                        "Filtered": basic.get("filtered", 0),
                        "Pass_Rate": basic.get("passed", 0)
                        / basic.get("total_variants", 1)
                        if basic.get("total_variants", 0) > 0
                        else 0,
                    }

                    # Add classification counts if available
                    if "classification" in basic:
                        for class_name, count in basic["classification"].items():
                            row[class_name] = count

                    rows.append(row)

        df = pd.DataFrame(rows)
        if df.empty:
            # sort_values raises KeyError on a frame without these columns
            return df
        return df.sort_values(["Category", "Tool", "Modality"])

    def create_classification_summary(self) -> pd.DataFrame:
        """Create summary table of variant biological classifications"""
        rows = []

        for category, files in self.all_stats.items():
            for name, data in files.items():
                if "stats" in data and "basic" in data["stats"]:
                    basic = data["stats"]["basic"]

                    # Only include if classification data is present
                    if "classification" not in basic:
                        continue

                    # Parse tool and modality from name
                    parts = name.split("_")
                    if len(parts) >= 2:
                        tool = parts[0]
                        modality = "_".join(parts[1:])
                    else:
                        tool = category
                        modality = name

                    classification = basic["classification"]
                    total = sum(classification.values())

                    rows.append(
                        {
                            "Category": category,
                            "Tool": tool,
                            "Modality": modality,
                            "Total_Variants": total,
                            "Somatic": classification.get("Somatic", 0),
                            "Germline": classification.get("Germline", 0),
                            "Reference": classification.get("Reference", 0),
                            "Artifact": classification.get("Artifact", 0),
                            "Somatic_Rate": classification.get("Somatic", 0) / total
                            if total > 0
                            else 0,
                            "Germline_Rate": classification.get("Germline", 0) / total
                            if total > 0
                            else 0,
                            "Reference_Rate": classification.get("Reference", 0) / total
                            if total > 0
                            else 0,
                            "Artifact_Rate": classification.get("Artifact", 0) / total
                            if total > 0
                            else 0,
                        }
                    )

        df = pd.DataFrame(rows)
        if df.empty:
            # sort_values raises KeyError on a frame without these columns
            return df
        return df.sort_values(["Category", "Tool", "Modality"])

    def create_quality_summary(self) -> pd.DataFrame:
        """Create summary of quality score distributions"""
        rows = []

        for category, files in self.all_stats.items():
            for name, data in files.items():
                if "stats" in data and "basic" in data["stats"]:
                    basic = data["stats"]["basic"]
                    qualities = basic.get("qualities", [])

                    if qualities:
                        parts = name.split("_")
                        tool = parts[0] if parts else category
                        modality = "_".join(parts[1:]) if len(parts) > 1 else name

                        rows.append(
                            {
                                "Category": category,
                                "Tool": tool,
                                "Modality": modality,
                                "Mean_QUAL": np.mean(qualities),
                                "Median_QUAL": np.median(qualities),
                                "Min_QUAL": np.min(qualities),
                                "Max_QUAL": np.max(qualities),
                                "Q25": np.percentile(qualities, 25),
                                "Q75": np.percentile(qualities, 75),
                            }
                        )

        df = pd.DataFrame(rows)
        if df.empty:
            # sort_values raises KeyError on a frame without these columns
            return df
        return df.sort_values(["Category", "Tool", "Modality"])

    def create_info_field_summary(self, info_field: str) -> pd.DataFrame:
        """Create summary for specific INFO field across all VCFs"""
        rows = []

        for category, files in self.all_stats.items():
            for name, data in files.items():
                if "stats" in data and "info" in data["stats"]:
                    info_stats = data["stats"]["info"]

                    if info_field in info_stats:
                        field_data = info_stats[info_field]

                        if isinstance(field_data, dict) and "mean" in field_data:
                            parts = name.split("_")
                            tool = parts[0] if parts else category
                            modality = "_".join(parts[1:]) if len(parts) > 1 else name

                            row = {
                                "Category": category,
                                "Tool": tool,
                                "Modality": modality,
                                "Field": info_field,
                            }
                            row.update(field_data)
                            rows.append(row)

        df = pd.DataFrame(rows)
        if df.empty:
            # sort_values raises KeyError on a frame without these columns
            return df
        return df.sort_values(["Category", "Tool", "Modality"])

    def compare_tools_by_modality(self) -> pd.DataFrame:
        """Compare variant calling tools within each modality"""
        rows = []

        # Focus on variant_calling category
        if "variant_calling" in self.all_stats:
            for name, data in self.all_stats["variant_calling"].items():
                if "stats" in data and "basic" in data["stats"]:
                    basic = data["stats"]["basic"]
                    parts = name.split("_")

                    if len(parts) >= 2:
                        tool = parts[0]
                        modality = "_".join(parts[1:])

                        rows.append(
                            {
                                "Tool": tool,
                                "Modality": modality,
                                "Total_Variants": basic.get("total_variants", 0),
                                "SNPs": basic.get("snps", 0),
                                "INDELs": basic.get("indels", 0),
                                "SNP_Ratio": basic.get("snps", 0)
                                / basic.get("total_variants", 1)
                                if basic.get("total_variants", 0) > 0
                                else 0,
                                "INDEL_Ratio": basic.get("indels", 0)
                                / basic.get("total_variants", 1)
                                if basic.get("total_variants", 0) > 0
                                else 0,
                            }
                        )

        df = pd.DataFrame(rows)
        if df.empty:
            # sort_values raises KeyError on a frame without these columns
            return df
        return df.sort_values(["Modality", "Tool"])

    def compare_consensus_to_individual(self) -> pd.DataFrame:
        """Compare consensus VCFs to individual tool outputs"""
        rows = []

        # Get consensus counts
        consensus_counts = {}
        if "consensus" in self.all_stats:
            for modality, data in self.all_stats["consensus"].items():
                if "stats" in data and "basic" in data["stats"]:
                    consensus_counts[modality] = data["stats"]["basic"].get(
                        "total_variants", 0
                    )

        # Get individual tool counts
        if "variant_calling" in self.all_stats:
            for name, data in self.all_stats["variant_calling"].items():
                if "stats" in data and "basic" in data["stats"]:
                    basic = data["stats"]["basic"]
                    parts = name.split("_")

                    if len(parts) >= 2:
                        tool = parts[0]
                        modality = "_".join(parts[1:])
                        tool_count = basic.get("total_variants", 0)
                        consensus_count = consensus_counts.get(modality, 0)

                        rows.append(
                            {
                                "Tool": tool,
                                "Modality": modality,
                                "Tool_Variants": tool_count,
                                "Consensus_Variants": consensus_count,
                                "Difference": tool_count - consensus_count,
                                "Retention_Rate": consensus_count / tool_count
                                if tool_count > 0
                                else 0,
                            }
                        )

        df = pd.DataFrame(rows)
        if df.empty:
            # sort_values raises KeyError on a frame without these columns
            return df
        return df.sort_values(["Modality", "Tool"])
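

# Illustrative sketch (hypothetical helpers, not part of the original notebook):
# how create_variant_count_summary derives "Tool" and "Modality" from a stats
# key, and how "Retention_Rate" is computed. The key format "<tool>_<modality>"
# is an assumption inferred from the split("_") logic above.
def sketch_parse_stats_name(name, category):
    """Split "<tool>_<modality>" into (tool, modality); fall back to (category, name)."""
    parts = name.split("_")
    if len(parts) >= 2:
        return parts[0], "_".join(parts[1:])
    return category, name


def sketch_retention_rate(tool_count, consensus_count):
    """Ratio of consensus to tool variant counts; 0 when the tool called nothing.

    This compares totals only, so it can exceed 1 when the consensus holds
    more variants than the tool called.
    """
    return consensus_count / tool_count if tool_count > 0 else 0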


class VCFVisualizer:
    """Create visualizations for VCF statistics"""

    def __init__(self, all_stats: Dict):
        self.all_stats = all_stats

    def plot_variant_counts_by_tool(self):
        """Bar plot comparing variant counts with FILTER categories across tools and modalities"""

        # Define unified color scheme
        CATEGORY_COLORS = {
            "Somatic": "#636EFA",
            "Germline": "#00CC96",
            "Reference": "#FFA15A",
            "Artifact": "#EF553B",
        }
        CATEGORY_ORDER = ["Somatic", "Germline", "Reference", "Artifact"]

        data = []

        if "variant_calling" in self.all_stats:
            for name, vcf_data in self.all_stats["variant_calling"].items():
                if "stats" in vcf_data and "basic" in vcf_data["stats"]:
                    basic = vcf_data["stats"]["basic"]
                    classification = basic.get("classification", {})
                    parts = name.split("_")
                    tool = parts[0] if parts else name
                    modality = "DNA" if "DNA_TUMOR" in name else "RNA"

                    # Add data for each FILTER category
                    for filter_cat in CATEGORY_ORDER:
                        count = classification.get(filter_cat, 0)
                        if count > 0:
                            data.append(
                                {
                                    "Tool": tool,
                                    "Modality": modality,
                                    "Category": filter_cat,
                                    "Count": count,
                                }
                            )

        if not data:
            print("No data available for plotting")
            return

        df = pd.DataFrame(data)

        # Create subplots for DNA and RNA
        fig = make_subplots(
            rows=1,
            cols=2,
            subplot_titles=("DNA Modality", "RNA Modality"),
            horizontal_spacing=0.12,
        )

        # Plot DNA modality
        df_dna = df[df["Modality"] == "DNA"]
        if not df_dna.empty:
            tools = sorted(df_dna["Tool"].unique())
            for filter_cat in CATEGORY_ORDER:
                df_cat = df_dna[df_dna["Category"] == filter_cat]
                counts = [
                    df_cat[df_cat["Tool"] == t]["Count"].sum()
                    if not df_cat[df_cat["Tool"] == t].empty
                    else 0
                    for t in tools
                ]
                if sum(counts) > 0:
                    fig.add_trace(
                        go.Bar(
                            name=filter_cat,
                            x=tools,
                            y=counts,
                            marker_color=CATEGORY_COLORS[filter_cat],
                            text=counts,
                            textposition="inside",
                            showlegend=True,
                            legendgroup=filter_cat,
                        ),
                        row=1,
                        col=1,
                    )

        # Plot RNA modality
        df_rna = df[df["Modality"] == "RNA"]
        if not df_rna.empty:
            tools = sorted(df_rna["Tool"].unique())
            for filter_cat in CATEGORY_ORDER:
                df_cat = df_rna[df_rna["Category"] == filter_cat]
                counts = [
                    df_cat[df_cat["Tool"] == t]["Count"].sum()
                    if not df_cat[df_cat["Tool"] == t].empty
                    else 0
                    for t in tools
                ]
                if sum(counts) > 0:
                    fig.add_trace(
                        go.Bar(
                            name=filter_cat,
                            x=tools,
                            y=counts,
                            marker_color=CATEGORY_COLORS[filter_cat],
                            text=counts,
                            textposition="inside",
                            showlegend=False,
                            legendgroup=filter_cat,
                        ),
                        row=1,
                        col=2,
                    )

        fig.update_xaxes(title_text="Tool", row=1, col=1)
        fig.update_xaxes(title_text="Tool", row=1, col=2)
        fig.update_yaxes(title_text="Number of Variants", row=1, col=1)
        fig.update_yaxes(title_text="Number of Variants", row=1, col=2)

        fig.update_layout(
            title="Variant Counts by Tool and FILTER Category",
            template="plotly_white",
            barmode="stack",
            height=500,
            showlegend=True,
        )

        fig.show()

    def plot_quality_distributions(self):
        """Box plot of quality score distributions"""
        data = []

        for category, files in self.all_stats.items():
            for name, vcf_data in files.items():
                if "stats" in vcf_data and "basic" in vcf_data["stats"]:
                    qualities = vcf_data["stats"]["basic"].get("qualities", [])

                    if qualities:
                        parts = name.split("_")
                        tool = parts[0] if parts else category
                        modality = "DNA" if "DNA_TUMOR" in name else "RNA"

                        for qual in qualities[:1000]:  # Limit for performance
                            data.append(
                                {
                                    "Category": category,
                                    "Tool": tool,
                                    "Modality": modality,
                                    "Quality": qual,
                                }
                            )

        if not data:
            print("No quality data available")
            return

        df = pd.DataFrame(data)

        fig = px.box(
            df,
            x="Tool",
            y="Quality",
            color="Modality",
            facet_col="Category",
            title="Quality Score Distributions",
            template="plotly_white",
            height=500,
        )

        fig.update_yaxes(title_text="QUAL Score")
        fig.show()

    def plot_variant_type_distribution(self):
        """Stacked bar charts showing SNP vs INDEL distribution with FILTER categories"""

        # Define colors
        CATEGORY_COLORS = {
            "Somatic": "#636EFA",
            "Germline": "#00CC96",
            "Reference": "#FFA15A",
            "Artifact": "#EF553B",
        }
        CATEGORY_ORDER = ["Somatic", "Germline", "Reference", "Artifact"]

        data = []

        # Collect data from consensus and rescue VCFs for a cleaner view
        for vcf_type in ["consensus", "rescue"]:
            if vcf_type in self.all_stats:
                for name, vcf_data in self.all_stats[vcf_type].items():
                    if "stats" in vcf_data and "basic" in vcf_data["stats"]:
                        basic = vcf_data["stats"]["basic"]
                        classification = basic.get("classification", {})

                        if vcf_type == "consensus":
                            if "DNA_TUMOR" in name:
                                modality = "DNA Consensus"
                            elif "RNA_TUMOR" in name:
                                modality = "RNA Consensus"
                            else:
                                continue
                        else:
                            modality = "Rescued"

                        # Get variant types
                        variant_types = basic.get("variant_types", {})
                        snps = variant_types.get("SNP", 0)
                        indels = variant_types.get("DEL", 0) + variant_types.get(
                            "INS", 0
                        )

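                        # For each FILTER category, estimate the SNP/INDEL split by
                        # applying the file-wide SNP fraction (an approximation, not
                        # a per-variant count; every non-SNP type lands in "INDEL")
                        total_vars = (
                            basic.get("total_variants", 0) or 1
                        )  # guard against division by zero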
                        for filter_cat in CATEGORY_ORDER:
                            count = classification.get(filter_cat, 0)
                            if count > 0:
                                # Proportionally split into SNPs and INDELs
                                snp_count = int(count * (snps / total_vars))
                                indel_count = count - snp_count

                                if snp_count > 0:
                                    data.append(
                                        {
                                            "Modality": modality,
                                            "Type": "SNP",
                                            "Category": filter_cat,
                                            "Count": snp_count,
                                        }
                                    )
                                if indel_count > 0:
                                    data.append(
                                        {
                                            "Modality": modality,
                                            "Type": "INDEL",
                                            "Category": filter_cat,
                                            "Count": indel_count,
                                        }
                                    )

        if not data:
            print("No data available for plotting")
            return

        df = pd.DataFrame(data)

        # Create subplots
        modalities = ["DNA Consensus", "RNA Consensus", "Rescued"]
        available_mods = [m for m in modalities if m in df["Modality"].values]
        n_mods = len(available_mods)

        fig = make_subplots(
            rows=1,
            cols=n_mods,
            subplot_titles=available_mods,
            horizontal_spacing=0.10,
        )

        for i, modality in enumerate(available_mods, 1):
            df_mod = df[df["Modality"] == modality]

            # Group by Type (SNP/INDEL) and stack by Category
            for var_type in ["SNP", "INDEL"]:
                df_type = df_mod[df_mod["Type"] == var_type]
                if not df_type.empty:
                    for filter_cat in CATEGORY_ORDER:
                        df_cat = df_type[df_type["Category"] == filter_cat]
                        if not df_cat.empty:
                            count = df_cat["Count"].sum()
                            fig.add_trace(
                                go.Bar(
                                    name=f"{filter_cat}",
                                    x=[var_type],
                                    y=[count],
                                    marker_color=CATEGORY_COLORS[filter_cat],
                                    text=[count],
                                    textposition="inside",
                                    showlegend=(i == 1),
                                    legendgroup=filter_cat,
                                    hovertemplate=f"<b>{var_type}</b><br>{filter_cat}: %{{y}}<extra></extra>",
                                ),
                                row=1,
                                col=i,
                            )

            fig.update_xaxes(title_text="Variant Type", row=1, col=i)
            if i == 1:
                fig.update_yaxes(title_text="Number of Variants", row=1, col=i)

        fig.update_layout(
            title_text="Variant Type Distribution (SNP vs INDEL) by FILTER Category",
            height=500,
            barmode="stack",
            template="plotly_white",
            showlegend=True,
        )

        fig.show()

    def plot_consensus_comparison(self):
        """Compare consensus variants to individual tools with FILTER category breakdown"""

        # Define unified color scheme
        CATEGORY_COLORS = {
            "Somatic": "#636EFA",
            "Germline": "#00CC96",
            "Reference": "#FFA15A",
            "Artifact": "#EF553B",
        }
        CATEGORY_ORDER = ["Somatic", "Germline", "Reference", "Artifact"]

        data = []

        # Get tool-level classification data
        if "variant_calling" in self.all_stats:
            for name, vcf_data in self.all_stats["variant_calling"].items():
                if "stats" in vcf_data and "basic" in vcf_data["stats"]:
                    basic = vcf_data["stats"]["basic"]
                    classification = basic.get("classification", {})
                    parts = name.split("_")
                    tool = parts[0] if parts else name
                    modality = "DNA" if "DNA_TUMOR" in name else "RNA"

                    for filter_cat in CATEGORY_ORDER:
                        count = classification.get(filter_cat, 0)
                        if count > 0:
                            data.append(
                                {
                                    "Category": tool,
                                    "Modality": modality,
                                    "FilterCat": filter_cat,
                                    "Count": count,
                                }
                            )

        # Get consensus classification data
        if "consensus" in self.all_stats:
            for modality_key, vcf_data in self.all_stats["consensus"].items():
                if "stats" in vcf_data and "basic" in vcf_data["stats"]:
                    basic = vcf_data["stats"]["basic"]
                    classification = basic.get("classification", {})
                    modality = "DNA" if "DNA_TUMOR" in modality_key else "RNA"

                    for filter_cat in CATEGORY_ORDER:
                        count = classification.get(filter_cat, 0)
                        if count > 0:
                            data.append(
                                {
                                    "Category": "consensus",
                                    "Modality": modality,
                                    "FilterCat": filter_cat,
                                    "Count": count,
                                }
                            )

        if not data:
            print("No comparison data available")
            return

        df = pd.DataFrame(data)

        # Create subplots for DNA and RNA modalities
        fig = make_subplots(
            rows=1,
            cols=2,
            subplot_titles=("DNA Modality", "RNA Modality"),
            horizontal_spacing=0.12,
            shared_yaxes=True,
        )

        # Get all unique categories and sort them (tools alphabetically, consensus last)
        categories = sorted([c for c in df["Category"].unique() if c != "consensus"])
        if "consensus" in df["Category"].unique():
            categories.append("consensus")

        # DNA modality
        df_dna = df[df["Modality"] == "DNA"]
        if not df_dna.empty:
            for filter_cat in CATEGORY_ORDER:
                df_filter = df_dna[df_dna["FilterCat"] == filter_cat]
                if not df_filter.empty:
                    # .sum() on an empty selection is already 0, so no explicit check is needed
                    counts = [
                        df_filter.loc[df_filter["Category"] == cat, "Count"].sum()
                        for cat in categories
                    ]
                    fig.add_trace(
                        go.Bar(
                            name=filter_cat,
                            x=categories,
                            y=counts,
                            marker_color=CATEGORY_COLORS[filter_cat],
                            text=counts,
                            textposition="inside",
                            showlegend=True,
                            legendgroup=filter_cat,
                        ),
                        row=1,
                        col=1,
                    )

        # RNA modality
        df_rna = df[df["Modality"] == "RNA"]
        if not df_rna.empty:
            for filter_cat in CATEGORY_ORDER:
                df_filter = df_rna[df_rna["FilterCat"] == filter_cat]
                if not df_filter.empty:
                    # .sum() on an empty selection is already 0, so no explicit check is needed
                    counts = [
                        df_filter.loc[df_filter["Category"] == cat, "Count"].sum()
                        for cat in categories
                    ]
                    fig.add_trace(
                        go.Bar(
                            name=filter_cat,
                            x=categories,
                            y=counts,
                            marker_color=CATEGORY_COLORS[filter_cat],
                            text=counts,
                            textposition="inside",
                            showlegend=False,
                            legendgroup=filter_cat,
                        ),
                        row=1,
                        col=2,
                    )

        fig.update_xaxes(title_text="Category", row=1, col=1)
        fig.update_xaxes(title_text="Category", row=1, col=2)
        fig.update_yaxes(title_text="Number of Variants", row=1, col=1)

        fig.update_layout(
            title_text="Variant Counts by Category with FILTER Classification",
            template="plotly_white",
            height=500,
            barmode="stack",
        )

        fig.show()

    def plot_filter_status(self):
        """Stacked bar chart showing unified FILTER categories (Germline, Somatic, Reference, Artifact)"""

        # Define unified color scheme for the 4 FILTER categories
        CATEGORY_COLORS = {
            "Somatic": "#636EFA",  # blue - high confidence somatic variants
            "Germline": "#00CC96",  # green - germline variants
            "Reference": "#FFA15A",  # orange - reference calls
            "Artifact": "#EF553B",  # red - artifacts/filtered
        }

        # Expected order for consistent display
        CATEGORY_ORDER = ["Somatic", "Germline", "Reference", "Artifact"]

        # Collect data from all VCF stats
        data = []

        for category, files in self.all_stats.items():
            for name, vcf_data in files.items():
                if "stats" not in vcf_data or "basic" not in vcf_data["stats"]:
                    continue

                basic = vcf_data["stats"]["basic"]
                classification = basic.get("classification", {})

                if not classification:
                    continue

                # Determine subplot category and tool name
                parts = name.split("_")

                if category == "variant_calling":
                    tool = parts[0] if parts else "unknown"
                    if "DNA_TUMOR" in name:
                        subplot_cat = "DNA"
                    elif "RNA_TUMOR" in name:
                        subplot_cat = "RNA"
                    else:
                        continue  # Skip if not DNA or RNA

                elif category == "consensus":
                    tool = "consensus"
                    if "DNA_TUMOR" in name:
                        subplot_cat = "DNA"
                    elif "RNA_TUMOR" in name:
                        subplot_cat = "RNA"
                    else:
                        continue

                elif category == "rescue":
                    tool = "rescue"
                    subplot_cat = "Rescue"

                else:
                    continue  # Skip other categories

                # Add counts for each classification category
                for filter_cat in CATEGORY_ORDER:
                    count = classification.get(filter_cat, 0)
                    if count > 0:  # Only add if there are variants
                        data.append(
                            {
                                "SubplotCategory": subplot_cat,
                                "Tool": tool,
                                "FilterCategory": filter_cat,
                                "Count": count,
                                "Name": name,
                            }
                        )

        if not data:
            print("No classification data available for plotting")
            return

        df = pd.DataFrame(data)

        # Determine available subplot categories
        subplot_categories = ["DNA", "RNA", "Rescue"]
        available_subplots = [
            cat for cat in subplot_categories if cat in df["SubplotCategory"].values
        ]
        n_subplots = len(available_subplots)

        if n_subplots == 0:
            print("No data to plot")
            return

        # Create subplots with unified y-axis
        fig = make_subplots(
            rows=1,
            cols=n_subplots,
            subplot_titles=available_subplots,
            horizontal_spacing=0.10,
            shared_yaxes=True,  # Unified y-axis across all subplots
        )

        # Find global y-axis max for consistent scaling
        max_y = 0

        # Plot each subplot
        for i, subplot_cat in enumerate(available_subplots, 1):
            df_subplot = df[df["SubplotCategory"] == subplot_cat]

            # Get unique tools in this subplot and sort them
            tools = sorted(df_subplot["Tool"].unique())

            # For each filter category, add a stacked bar
            for filter_cat in CATEGORY_ORDER:
                df_filter = df_subplot[df_subplot["FilterCategory"] == filter_cat]

                # Create counts array aligned with tools
                counts = []
                for tool in tools:
                    tool_data = df_filter[df_filter["Tool"] == tool]
                    count = tool_data["Count"].sum() if not tool_data.empty else 0
                    counts.append(count)

                # Only add trace if there are non-zero counts
                if sum(counts) > 0:
                    fig.add_trace(
                        go.Bar(
                            name=filter_cat,
                            x=tools,
                            y=counts,
                            marker_color=CATEGORY_COLORS[filter_cat],
                            text=counts,
                            textposition="inside",
                            textfont=dict(color="white", size=10),
                            showlegend=(i == 1),  # Only show legend for first subplot
                            legendgroup=filter_cat,
                            hovertemplate=f"<b>%{{x}}</b><br>{filter_cat}: %{{y}}<extra></extra>",
                        ),
                        row=1,
                        col=i,
                    )

            # Calculate total height for this subplot
            for tool in tools:
                tool_data = df_subplot[df_subplot["Tool"] == tool]
                total = tool_data["Count"].sum()
                max_y = max(max_y, total)

            # Update axes labels
            fig.update_xaxes(title_text="Caller", row=1, col=i)
            if i == 1:
                fig.update_yaxes(title_text="Number of Variants", row=1, col=i)

        # Update layout with unified settings
        fig.update_layout(
            title_text="Variant Classification by FILTER Category (Stacked)",
            template="plotly_white",
            height=500,
            barmode="stack",
            showlegend=True,
            legend=dict(
                orientation="v",
                yanchor="top",
                y=1.0,
                xanchor="left",
                x=1.02,
                title="FILTER Category",
            ),
        )

        # Set unified y-axis range with some padding
        for i in range(1, n_subplots + 1):
            fig.update_yaxes(range=[0, max_y * 1.1], row=1, col=i)

        fig.show()

---

## EXECUTION SECTION

All reusable code (classes and functions) is defined above. The execution cells below use it.

## Step 1: Discover VCF and Alignment Files

Discover all VCF files in the pipeline output, along with the alignment files.

In [5]:
# Discover files
discovery = VCFFileDiscovery(BASE_DIR)
vcf_files = discovery.discover_vcfs()
alignment_files = discovery.discover_alignments()
discovery.print_summary()

VCF FILE DISCOVERY SUMMARY

VARIANT_CALLING:
  ✓ deepsomatic_DNA_TUMOR_vs_DNA_NORMAL: DNA_TUMOR_vs_DNA_NORMAL.deepsomatic.vcf.gz
  ✓ deepsomatic_RNA_TUMOR_vs_DNA_NORMAL: RNA_TUMOR_vs_DNA_NORMAL.deepsomatic.vcf.gz
  ✓ mutect2_DNA_TUMOR_vs_DNA_NORMAL: DNA_TUMOR_vs_DNA_NORMAL.mutect2.vcf.gz
  ✓ mutect2_RNA_TUMOR_vs_DNA_NORMAL: RNA_TUMOR_vs_DNA_NORMAL.mutect2.vcf.gz
  ✓ strelka_DNA_TUMOR_vs_DNA_NORMAL: DNA_TUMOR_vs_DNA_NORMAL.strelka.variants.vcf.gz
  ✓ strelka_RNA_TUMOR_vs_DNA_NORMAL: RNA_TUMOR_vs_DNA_NORMAL.strelka.variants.vcf.gz

NORMALIZED:
  ✓ deepsomatic_DNA_TUMOR_vs_DNA_NORMAL: DNA_TUMOR_vs_DNA_NORMAL.deepsomatic.variants.dec.norm.vcf.gz
  ✓ deepsomatic_RNA_TUMOR_vs_DNA_NORMAL: RNA_TUMOR_vs_DNA_NORMAL.deepsomatic.variants.dec.norm.vcf.gz
  ✓ mutect2_DNA_TUMOR_vs_DNA_NORMAL: DNA_TUMOR_vs_DNA_NORMAL.mutect2.variants.dec.norm.vcf.gz
  ✓ mutect2_RNA_TUMOR_vs_DNA_NORMAL: RNA_TUMOR_vs_DNA_NORMAL.mutect2.variants.dec.norm.vcf.gz
  ✓ strelka_DNA_TUMOR_vs_DNA_NORMAL: DNA_TUMOR_vs_DNA_NORMAL
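
The discovery keys above appear to follow a `<tool>_<modality>` pattern, which is also what the plotting methods rely on when they take `name.split("_")[0]` as the tool. A minimal sketch of that parsing follows. `parse_vcf_key` and `short_modality` are hypothetical helpers, not part of the notebook, and they assume the tool name contains no underscore.

```python
from typing import Tuple


def parse_vcf_key(name: str) -> Tuple[str, str]:
    """Split a key such as 'mutect2_DNA_TUMOR_vs_DNA_NORMAL' into (tool, modality).

    Hypothetical helper: assumes the tool name itself contains no underscore.
    """
    tool, _, modality = name.partition("_")
    return tool, modality


def short_modality(modality: str) -> str:
    """Collapse the comparison label to 'DNA' or 'RNA' based on the tumor sample."""
    if modality.startswith("DNA_TUMOR"):
        return "DNA"
    if modality.startswith("RNA_TUMOR"):
        return "RNA"
    return "unknown"  # explicit fallback instead of silently defaulting to RNA


tool, modality = parse_vcf_key("strelka_RNA_TUMOR_vs_DNA_NORMAL")
print(tool, modality, short_modality(modality))
```

Returning `"unknown"` for unrecognized labels is a design choice of this sketch. Several plotting methods instead treat anything without `DNA_TUMOR` in the name as RNA.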

## Step 2: Process All VCF Files

Extract comprehensive statistics from all discovered VCF files.

In [6]:
# Process all VCFs
print("Starting comprehensive VCF analysis...")
all_vcf_stats = process_all_vcfs(vcf_files)
print("\n✓ All VCF files processed successfully!")

Starting comprehensive VCF analysis...

PROCESSING: VARIANT_CALLING

Processing: DNA_TUMOR_vs_DNA_NORMAL.deepsomatic.vcf.gz
  [DEBUG] Starting header parsing...
  [DEBUG] Found 1 INFO fields in header
  [DEBUG] Processed 10001 variants, calculating statistics...
  [DEBUG] Calculated statistics for 0 INFO fields
  ✓ Total variants: 27697
  ✓ SNPs: 26353, INDELs: 1344
  ✓ Passed filters: 52, Filtered: 27645

Processing: RNA_TUMOR_vs_DNA_NORMAL.deepsomatic.vcf.gz
  [DEBUG] Starting header parsing...
  [DEBUG] Found 1 INFO fields in header
  [DEBUG] Processed 10001 variants, calculating statistics...
  [DEBUG] Calculated statistics for 0 INFO fields
  ✓ Total variants: 13719
  ✓ SNPs: 10866, INDELs: 2853
  ✓ Passed filters: 48, Filtered: 13671

Processing: DNA_TUMOR_vs_DNA_NORMAL.mutect2.vcf.gz
  [DEBUG] Starting header parsing...
  [DEBUG] Found 25 INFO fields in header
  [DEBUG] Processed 758 variants, calculating statistics...
  [DEBUG] Calculated statistics for 16 INFO fields
  ✓ Total

In [7]:
all_vcf_stats["variant_calling"]["strelka_RNA_TUMOR_vs_DNA_NORMAL"]["stats"]["basic"]

{'total_variants': 8738,
 'snps': 8695,
 'indels': 43,
 'mnps': 0,
 'complex': 0,
 'passed': 248,
 'filtered': 8490,
 'chromosomes': ['chr1',
  'chr10',
  'chr11',
  'chr12',
  'chr13',
  'chr14',
  'chr15',
  'chr16',
  'chr17',
  'chr18',
  'chr19',
  'chr2',
  'chr20',
  'chr21',
  'chr22',
  'chr3',
  'chr4',
  'chr5',
  'chr6',
  'chr7',
  'chr8',
  'chr9',
  'chrX'],
 'qualities': [],
 'variant_types': defaultdict(int, {'SNP': 8695, 'INS': 26, 'DEL': 17}),
 'classification': {'Reference': 7897,
  'Somatic': 248,
  'Artifact': 526,
  'Germline': 67}}
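
The `basic` dictionary above can be sanity-checked with a few lines of arithmetic: the classification counts should add up to `total_variants`, as should `passed + filtered`. The sketch below copies the relevant values from the output. `summarize_basic` is a hypothetical helper, not one of the notebook's functions.

```python
# Values copied from the strelka RNA output above
basic = {
    "total_variants": 8738,
    "passed": 248,
    "filtered": 8490,
    "classification": {
        "Reference": 7897,
        "Somatic": 248,
        "Artifact": 526,
        "Germline": 67,
    },
}


def summarize_basic(basic: dict) -> dict:
    """Derive pass rate and per-category rates from a 'basic' stats dict."""
    total = basic["total_variants"]
    classification = basic.get("classification", {})
    return {
        "pass_rate": basic["passed"] / total if total else 0.0,
        "class_rates": {
            cat: (n / total if total else 0.0) for cat, n in classification.items()
        },
        "classified_total": sum(classification.values()),
    }


summary = summarize_basic(basic)
print(f"pass rate: {summary['pass_rate']:.4f}")
print(f"classified: {summary['classified_total']} of {basic['total_variants']}")
```

For this file the Somatic count equals `passed`, which matches the legend's description of Somatic as the PASS set.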

In [8]:
# Check if INFO and FORMAT fields are populated
print("Checking INFO and FORMAT field extraction:\n")

for category, vcfs in all_vcf_stats.items():
    print(f"\n{category.upper()}:")
    for name, data in vcfs.items():
        if "stats" in data:
            info_count = len(data["stats"].get("info", {}))
            format_data = data["stats"].get("format", {})
            format_count = sum(len(fields) for fields in format_data.values())

            print(
                f"  {name[:40]:40} - INFO: {info_count:2} fields, FORMAT: {format_count:2} fields"
            )

            # Show first INFO field as example
            if info_count > 0:
                first_info = list(data["stats"]["info"].keys())[0]
                print(f"    → Example INFO field: {first_info}")


Checking INFO and FORMAT field extraction:


VARIANT_CALLING:
  deepsomatic_DNA_TUMOR_vs_DNA_NORMAL      - INFO:  0 fields, FORMAT:  2 fields
  deepsomatic_RNA_TUMOR_vs_DNA_NORMAL      - INFO:  0 fields, FORMAT:  2 fields
  mutect2_DNA_TUMOR_vs_DNA_NORMAL          - INFO: 16 fields, FORMAT:  4 fields
    → Example INFO field: AS_SB_TABLE
  mutect2_RNA_TUMOR_vs_DNA_NORMAL          - INFO: 16 fields, FORMAT:  4 fields
    → Example INFO field: AS_SB_TABLE
  strelka_DNA_TUMOR_vs_DNA_NORMAL          - INFO: 21 fields, FORMAT:  2 fields
    → Example INFO field: DP
  strelka_RNA_TUMOR_vs_DNA_NORMAL          - INFO: 22 fields, FORMAT:  2 fields
    → Example INFO field: DP

NORMALIZED:
  deepsomatic_DNA_TUMOR_vs_DNA_NORMAL      - INFO:  1 fields, FORMAT:  2 fields
    → Example INFO field: OLD_MULTIALLELIC
  deepsomatic_RNA_TUMOR_vs_DNA_NORMAL      - INFO:  1 fields, FORMAT:  2 fields
    → Example INFO field: OLD_MULTIALLELIC
  mutect2_DNA_TUMOR_vs_DNA_NORMAL          - INFO: 17 fields, FOR

## Step 3: BAM Validation (Optional)

Validate variants by checking read support in original BAM/CRAM alignment files.

### Run Validation Example

Example: Validate the first 50 variants from the DNA consensus VCF.

In [9]:
# Example: Validate DNA consensus VCF
if "DNA_TUMOR_vs_DNA_NORMAL" in vcf_files.get("consensus", {}):
    dna_consensus_vcf = vcf_files["consensus"]["DNA_TUMOR_vs_DNA_NORMAL"]

    # Map to BAM files
    bam_map = {}
    if "DNA_TUMOR" in alignment_files:
        bam_map["DNA_TUMOR"] = alignment_files["DNA_TUMOR"]
    if "DNA_NORMAL" in alignment_files:
        bam_map["DNA_NORMAL"] = alignment_files["DNA_NORMAL"]

    if bam_map:
        # Create validator instance with reference genome
        validator = BAMValidator(reference_fasta=REFERENCE_FASTA)

        print(f"Validating {dna_consensus_vcf.name} with alignment files...")
        validation_results = validator.validate_variants(
            dna_consensus_vcf, bam_map, max_variants=50
        )

        if validation_results:
            validation_df = validator.summarize_validation(validation_results)
            print(f"\n✓ Validated {len(validation_results)} variants")
            print("\nFirst few validation results:")
            print(validation_df.head(10))
        else:
            print("No validation results obtained")
    else:
        print("No alignment files available for validation")
else:
    print("No consensus VCF found for validation example")

Validating DNA_TUMOR_vs_DNA_NORMAL.consensus.vcf.gz with alignment files...

✓ Validated 50 variants

First few validation results:
  chrom     pos ref alt       qual       filter  DNA_TUMOR_total_depth  \
0  chr1  935849  GC   G   0.000000  NoConsensus                      4   
1  chr1  942451   T   C  37.400002  NoConsensus                      4   
2  chr1  943314   G   A   0.100000    Reference                      4   
3  chr1  944105   C   A   0.000000  NoConsensus                     22   
4  chr1  946277   C   A   0.000000      Somatic                      5   
5  chr1  948149   C   T   0.300000  NoConsensus                      6   
6  chr1  952056   G   T   0.000000    Reference                     10   
7  chr1  953259   T   C   8.400000  NoConsensus                     24   
8  chr1  953279   T   C  34.400002  NoConsensus                     25   
9  chr1  953858   G   A  36.599998  NoConsensus                      9   

   DNA_TUMOR_ref_count  DNA_TUMOR_alt_G_count  DNA_TU
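
Read-support columns such as `DNA_TUMOR_ref_count` and the per-allele alt counts can be turned into a variant allele frequency (VAF). The sketch below shows that step in isolation. `allele_fraction`, `has_read_support`, and both thresholds are assumptions for illustration, not the behaviour of `BAMValidator`.

```python
from typing import Optional


def allele_fraction(ref_count: int, alt_count: int) -> Optional[float]:
    """Return alt / (ref + alt), or None when no informative reads cover the site."""
    depth = ref_count + alt_count
    if depth == 0:
        return None
    return alt_count / depth


def has_read_support(
    ref_count: int,
    alt_count: int,
    min_alt_reads: int = 2,  # hypothetical threshold
    min_vaf: float = 0.05,  # hypothetical threshold
) -> bool:
    """Flag a variant as supported when both the alt read count and VAF clear the thresholds."""
    vaf = allele_fraction(ref_count, alt_count)
    return vaf is not None and alt_count >= min_alt_reads and vaf >= min_vaf


print(allele_fraction(18, 4))   # 4 alt reads out of 22 informative reads
print(has_read_support(18, 4))
print(has_read_support(4, 0))
```

Note that this denominator counts only ref and alt reads. A total-depth column may also include reads carrying other alleles, so the two definitions can differ slightly.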

## Step 4: Data Aggregation & Summary Statistics

Aggregate statistics across all VCF files and create comprehensive summaries.

In [10]:
# Create aggregator
aggregator = StatisticsAggregator(all_vcf_stats)

# Generate summary tables
print("=" * 80)
print("VARIANT COUNT SUMMARY")
print("=" * 80)
variant_summary = aggregator.create_variant_count_summary()
print(variant_summary.to_string(index=False))

print("\n" + "=" * 80)
print("QUALITY SCORE SUMMARY")
print("=" * 80)
quality_summary = aggregator.create_quality_summary()
print(quality_summary.to_string(index=False))

print("\n" + "=" * 80)
print("TOOL COMPARISON BY MODALITY")
print("=" * 80)
tool_comparison = aggregator.compare_tools_by_modality()
print(tool_comparison.to_string(index=False))

print("\n" + "=" * 80)
print("CONSENSUS vs INDIVIDUAL TOOLS")
print("=" * 80)
consensus_comparison = aggregator.compare_consensus_to_individual()
print(consensus_comparison.to_string(index=False))

VARIANT COUNT SUMMARY
       Category        Tool                                            Modality  Total_Variants  SNPs  INDELs  Passed  Filtered  Pass_Rate  Reference  Germline  Somatic  Artifact
      consensus         DNA                                 TUMOR_vs_DNA_NORMAL           30624 29260    1364       0     30624   0.000000    10948.0    1129.0      670   17877.0
      consensus         RNA                                 TUMOR_vs_DNA_NORMAL           19246 16331    2915       0     19246   0.000000     2441.0     637.0      318   15850.0
     normalized deepsomatic                             DNA_TUMOR_vs_DNA_NORMAL           27701 26357    1344      52     27649   0.001877    21613.0    6036.0       52       NaN
     normalized deepsomatic                             RNA_TUMOR_vs_DNA_NORMAL           13722 10868    2854      48     13674   0.003498    11279.0    2395.0       48       NaN
     normalized     mutect2                             DNA_TUMOR_vs_DNA_NORMAL    

In [11]:
print("\n" + "=" * 80)
print("VARIANT BIOLOGICAL CLASSIFICATION")
print("=" * 80)
classification_summary = aggregator.create_classification_summary()
if not classification_summary.empty:
    print(classification_summary.to_string(index=False))
    print("\nClassification Legend:")
    print("  • Somatic: High-confidence somatic variants (PASS filter)")
    print("  • Germline: Variants present in normal sample")
    print("  • Reference: Reference calls (no variant in tumor)")
    print("  • Artifact: Low-quality or failed filter variants")
else:
    print("No classification data available")


VARIANT BIOLOGICAL CLASSIFICATION
       Category        Tool                                            Modality  Total_Variants  Somatic  Germline  Reference  Artifact  Somatic_Rate  Germline_Rate  Reference_Rate  Artifact_Rate
      consensus         DNA                                 TUMOR_vs_DNA_NORMAL           30624      670      1129      10948     17877      0.021878       0.036867        0.357497       0.583758
      consensus         RNA                                 TUMOR_vs_DNA_NORMAL           19246      318       637       2441     15850      0.016523       0.033098        0.126832       0.823548
     normalized deepsomatic                             DNA_TUMOR_vs_DNA_NORMAL           27701       52      6036      21613         0      0.001877       0.217898        0.780225       0.000000
     normalized deepsomatic                             RNA_TUMOR_vs_DNA_NORMAL           13722       48      2395      11279         0      0.003498       0.174537        0.821965 

## Step 5: Visualizations

Create comprehensive visualizations of VCF statistics.

In [12]:
# Create visualizer
visualizer = VCFVisualizer(all_vcf_stats)
print("✓ Visualizer created. Ready to generate plots.")

✓ Visualizer created. Ready to generate plots.


### Plot 1: Variant Counts by Tool

In [13]:
visualizer.plot_variant_counts_by_tool()

### Plot 2: Quality Score Distributions

In [14]:
# visualizer.plot_quality_distributions()

### Plot 3: Variant Type Distribution

In [15]:
visualizer.plot_variant_type_distribution()

### Plot 4: Consensus vs Individual Tools

In [16]:
visualizer.plot_consensus_comparison()

### Plot 5: Filter Status

In [17]:
visualizer.plot_filter_status()

## Step 6: Advanced Analysis - Rescue VCF Statistics

Analyze the rescue VCFs that combine DNA and RNA modality variants.

In [18]:
analyze_rescue_vcf(all_vcf_stats)


RESCUE VCF ANALYSIS WITH FILTER TRACKING

Category        DNA Consensus   RNA Consensus   Rescued        
------------------------------------------------------------
Somatic         670             318             979            
Germline        1,129           637             1,579          
Reference       10,948          2,441           13,349         
Artifact        17,877          15,850          32,058         
------------------------------------------------------------
TOTAL           30,624          19,246          47,965         

Overall Rescue Rate: 156.63% of DNA consensus variants

FILTER CATEGORY TRANSITION ANALYSIS

This analysis shows how FILTER labels change across stages:
  DNA Consensus → RNA Consensus → Rescued

Note: Exact variant-level tracking requires VCF file parsing.
Below shows category-level changes based on total counts:


Somatic Category:
  DNA Consensus: 670
  RNA Consensus: 318 (-352, 47.5% of DNA)
  Rescued: 979 (+309, 146.1% of DNA)

Germline Cate


VARIANT TYPE COMPOSITION

DNA Consensus:
  SNPs: 29,260 (95.5%)
  INDELs: 1,364 (4.5%) [DEL: 824, INS: 540]

RNA Consensus:
  SNPs: 16,331 (84.9%)
  INDELs: 2,915 (15.1%) [DEL: 1,573, INS: 1,342]

Rescued:
  SNPs: 43,711 (91.1%)
  INDELs: 4,254 (8.9%) [DEL: 2,380, INS: 1,874]
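
The percentages in the rescue table are ratios of category totals, which is why they can exceed 100%: the rescued set combines DNA and RNA calls. The sketch below reproduces that arithmetic from the counts printed above. `category_change` is a hypothetical helper, and it says nothing about which individual variants moved between categories.

```python
# Counts copied from the rescue table above
dna = {"Somatic": 670, "Germline": 1129, "Reference": 10948, "Artifact": 17877}
rescued = {"Somatic": 979, "Germline": 1579, "Reference": 13349, "Artifact": 32058}


def category_change(before: dict, after: dict) -> dict:
    """Map each category to (absolute change, 'after' as a percentage of 'before')."""
    rows = {}
    for cat, n_before in before.items():
        n_after = after.get(cat, 0)
        pct = 100.0 * n_after / n_before if n_before else float("nan")
        rows[cat] = (n_after - n_before, pct)
    return rows


changes = category_change(dna, rescued)
for cat, (delta, pct) in changes.items():
    print(f"{cat:<10} {delta:+,} ({pct:.1f}% of DNA)")

overall = 100.0 * sum(rescued.values()) / sum(dna.values())
print(f"Overall: {overall:.2f}% of DNA consensus variants")
```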



## Step 7: Export Results

Export summary statistics to CSV files for further analysis.

In [19]:
# Export results
output_dir = Path("vcf_statistics_output")
export_results(
    variant_summary,
    quality_summary,
    tool_comparison,
    consensus_comparison,
    output_dir=output_dir,
)

Exporting results...
✓ Exported: vcf_statistics_output/variant_count_summary.csv
✓ Exported: vcf_statistics_output/quality_summary.csv
✓ Exported: vcf_statistics_output/tool_comparison.csv
✓ Exported: vcf_statistics_output/consensus_comparison.csv

✓ All results exported to vcf_statistics_output/


## Step 8: Summary Report

Generate a comprehensive summary report of all analyses.

In [21]:
# Generate report
summary_report = generate_summary_report(
    all_vcf_stats,
    vcf_files,
    variant_summary,
    quality_summary,
    tool_comparison,
    consensus_comparison,
    output_dir,
)

VCF STATISTICS - COMPREHENSIVE SUMMARY REPORT

## 1. OVERVIEW

Total VCF files analyzed: 15
Categories: variant_calling, normalized, consensus, rescue
Tools: deepsomatic, mutect2, strelka
Modalities: DNA, RNA

## 2. VARIANT CALLING TOOLS COMPARISON

### DNA Modality:
  deepsomatic  -  27697 variants (SNPs: 26353, INDELs: 1344)
  mutect2      -    758 variants (SNPs:   731, INDELs:   27)
  strelka      -  15555 variants (SNPs: 15545, INDELs:   10)

### RNA Modality:
  deepsomatic  -  13719 variants (SNPs: 10866, INDELs: 2853)
  mutect2      -    338 variants (SNPs:   303, INDELs:   35)
  strelka      -   8738 variants (SNPs:  8695, INDELs:   43)

## 3. CONSENSUS ANALYSIS

### DNA Consensus: 30624 variants

  Tool contributions:
    deepsomatic : 27697 variants → 110.6% retained in consensus
    mutect2     :   758 variants → 4040.1% retained in consensus
    strelka     : 15555 variants → 196.9% retained in consensus
    deepsomatic : 13719 variants → 140.3% retained in consensus
    mu
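
Each percentage above equals the consensus total divided by the tool's call count (for example, 30624 / 27697 ≈ 110.6%), which is why the values exceed 100% and why the label "retained in consensus" does not fit them. A retention rate in the usual sense is the share of a tool's calls that also appear in the consensus, and computing it needs variant keys rather than totals. A minimal sketch follows, with hypothetical helpers `size_ratio` and `retention_rate` operating on totals and on sets of `(chrom, pos, ref, alt)` keys respectively.

```python
def size_ratio(consensus_total, tool_total):
    """Consensus size relative to a tool's call count, in percent."""
    return 100.0 * consensus_total / tool_total


def retention_rate(tool_keys, consensus_keys):
    """Percentage of a tool's variants that are also present in the consensus."""
    if not tool_keys:
        return 0.0
    return 100.0 * len(tool_keys & consensus_keys) / len(tool_keys)


# Totals taken from the report: these ratios match the printed DNA values.
for tool, total in [("deepsomatic", 27697), ("mutect2", 758), ("strelka", 15555)]:
    print(f"{tool:<12}: {size_ratio(30624, total):.1f}%")

# Toy keys (made up) showing a retention rate that cannot exceed 100%.
tool_keys = {("chr1", 100, "A", "T"), ("chr1", 200, "G", "C"), ("chr3", 7, "T", "G")}
consensus_keys = {("chr1", 100, "A", "T"), ("chr1", 200, "G", "C"), ("chr2", 50, "C", "CA")}
print(f"retention: {retention_rate(tool_keys, consensus_keys):.1f}%")
```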

---

## Quick Reference Guide

### What This Notebook Does

This notebook provides:

1. **File Discovery** - Automatically finds all VCF files across your pipeline
2. **Statistics Extraction** - Uses cyvcf2 to extract:
   - Variant counts (SNPs, INDELs, complex)
   - Quality scores and distributions
   - INFO field statistics (DP, AF, TLOD, etc.)
   - FORMAT field statistics (per-sample depth, allele frequency, genotype quality)
   - Filter status

3. **BAM Validation** - Uses pysam to:
   - Cross-reference variants with alignment files
   - Calculate read support (ref/alt counts)
   - Validate variant allele frequencies (VAF)

4. **Comprehensive Analysis**:
   - Tool comparison (DeepSomatic, Mutect2, Strelka)
   - Modality comparison (DNA vs RNA)
   - Consensus analysis (agreement across tools)
   - Rescue analysis (cross-modality variant recovery)

5. **Visualizations**:
   - Interactive Plotly charts
   - Quality distributions
   - Variant type breakdowns
   - Tool performance comparisons

6. **Export** - Saves the summary tables as CSV files
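
The VAF check in item 3 can be thought of as comparing the allele frequency reported in the VCF with one recomputed from read counts. A minimal sketch follows, assuming ref/alt read counts have already been obtained (for example with pysam). `vaf_from_counts` and `vaf_agrees` are hypothetical names, and the 0.1 tolerance is an arbitrary illustration rather than a value used by this notebook.

```python
def vaf_from_counts(ref_count, alt_count):
    """Variant allele frequency from read support; None when there is no coverage."""
    total = ref_count + alt_count
    return alt_count / total if total > 0 else None


def vaf_agrees(reported_af, ref_count, alt_count, tolerance=0.1):
    """True when the recomputed VAF is within `tolerance` of the reported AF."""
    observed = vaf_from_counts(ref_count, alt_count)
    if observed is None:
        return False
    return abs(observed - reported_af) <= tolerance


print(vaf_from_counts(30, 10))   # 10 alt reads out of 40 -> 0.25
print(vaf_agrees(0.22, 30, 10))  # within tolerance
print(vaf_agrees(0.60, 30, 10))  # outside tolerance
```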

### Main Classes and Functions

- `VCFFileDiscovery`: Discover VCF and alignment files
- `VCFStatisticsExtractor`: Extract statistics from VCF files
- `BAMValidator`: Validate variants using BAM/CRAM files
- `StatisticsAggregator`: Aggregate and compare statistics
- `VCFVisualizer`: Create visualizations
- `process_all_vcfs()`: Batch process all VCFs
- `analyze_rescue_vcf()`: Analyze rescue variants
- `export_results()`: Export to CSV
- `generate_summary_report()`: Generate text report

### Quick Customization Examples

**Change base directory:**
```python
BASE_DIR = Path("/your/custom/path")
```

**Process a single VCF:**
```python
extractor = VCFStatisticsExtractor(vcf_path)
stats = extractor.extract_all_stats()
```

**Analyze a specific INFO field:**
```python
tlod_summary = aggregator.create_info_field_summary('TLOD')
print(tlod_summary)
```

**Custom validation:**
```python
results = validator.validate_variants(vcf_path, bam_paths, max_variants=200)
```