In [1]:
# Standard library imports
import os
import glob
import gzip
import pickle
import logging
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from statistics import mean
from typing import Dict, List, Any

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pysam
from scipy import stats
from scipy.optimize import curve_fit
from scipy.spatial.distance import cdist
from sklearn import metrics
from statsmodels.stats.multitest import multipletests

# Seaborn settings
# Set globally so every figure drawn later in the notebook shares the same
# grid style, "paper" context scaling and colour-blind-safe palette.
sns.set_style("whitegrid")
sns.set_context("paper")
sns.set_palette("colorblind")

# Logging settings
# force=True removes handlers already attached to the root logger before this
# configuration is applied, so re-running the cell does not stack handlers.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s", force=True)
logger = logging.getLogger(__name__)


# Oxford Nanopore Sequencing Benchmark

This report presents a benchmark of SNVs, indels, and SVs, and a characterisation of the ONT dataset used for this benchmark.

## Methods

### Data Processing

- **Basecalling**: [wf-basecalling v1.1.7](https://github.com/epi2me-labs/wf-basecalling/tree/v1.1.7)

- **Alignment and Variant Calling**: [wf-human-variation v2.1.0](https://github.com/epi2me-labs/wf-human-variation/tree/v2.1.0)

### Quality Control Tools

- **NanoPlot**: [1.42.0](https://quay.io/biocontainers/nanoplot:1.42.0--pyhdfd78af_0)

  - Generates summary statistics for each sample and creates visualizations of QC metrics for sequencing summaries and aligned BAM files.

- **NanoComp**: [1.23.1](https://quay.io/biocontainers/nanocomp:1.23.1--pyhdfd78af_0)

  - Compares multiple sequencing runs and generates comparative plots.

- **mosdepth**: [0.3.3](https://github.com/brentp/mosdepth/tree/v0.3.3)

  - Calculates sequencing depth across the human genome for each sample.

- **rtg-tools**: [3.12.1](https://github.com/RealTimeGenomics/rtg-tools/tree/3.12.1)

  - Performs variant comparison against a truth dataset.

- **SURVIVOR**: [1.0.7](https://github.com/fritzsedlazeck/SURVIVOR)

  - Merges VCF files to compare SVs within a sample and among populations/samples.


# Sequencing Quality Control

Aggregate table of the QC metrics from NanoStats for both singleplexed and multiplexed samples, from the aligned `.cram` files produced by `wf-human-variation`.

Unless otherwise specified, subsequent plots and statistics include only samples basecalled with the `sup` algorithm.


In [2]:
@dataclass
class NanoplotMetrics:
    """
    Data class to store Nanoplot metrics data.

    The field names mirror the ``sample``, ``basecall`` and ``multiplexing``
    keys that ``collect_nanoplot_data`` adds to each parsed metrics dict.

    NOTE(review): this class is not instantiated anywhere in the visible part
    of the notebook; the collection functions return plain dictionaries.

    Attributes:
        sample: Sample name.
        basecall: Basecalling label (presumably the model suffix such as
            ``"sup"`` - confirm against callers).
        multiplexing: Multiplexing status label.
        metrics: Mapping of metric name to parsed value.
    """

    sample: str
    basecall: str
    multiplexing: str
    metrics: Dict[str, Any]


def parse_nanostats(nanostats_file: Path) -> Dict[str, Any]:
    """
    Parse NanoStats.txt file and extract metrics.

    The first line is treated as a header and skipped. Each remaining
    non-empty line is split into fields (on tabs when the line contains one,
    otherwise on whitespace); the first field is the metric name and the
    second its value. Lines with fewer than two fields are ignored. Trailing
    ``:``, ``_``, ``(`` and ``)`` characters are stripped from the name and
    spaces are replaced by underscores.

    Value layouts:
        - ``<a> (<b>)`` on ``longest_read`` / ``highest_Q_read`` lines is
          stored as ``<key>_length`` and ``<key>_quality``. On
          ``longest_read`` lines the leading number is the length; on
          ``highest_Q_read`` lines the leading number is the quality and the
          parenthesised number is the length. Such lines without a
          parenthesised part are skipped.
        - ``<count> (<pct>%) <size>Mb`` is stored as ``<key>_count``,
          ``<key>_percentage`` and ``<key>_size_mb`` (all floats).
        - Anything else is converted to ``float`` (when it contains a ``.``)
          or ``int`` if possible, and kept as a string otherwise.

    Args:
        nanostats_file: Path to the NanoStats.txt file.

    Returns:
        Dictionary containing parsed metrics from the NanoStats file.

    Raises:
        FileNotFoundError: If the NanoStats file doesn't exist.
        ValueError: If the file is empty, contains no metrics, or a
            recognised value layout holds a non-numeric field.
    """
    metrics: Dict[str, Any] = {}

    try:
        with open(nanostats_file, "r") as f:
            # Skip header line. The default argument keeps an empty file from
            # leaking StopIteration; it falls through to the "no metrics"
            # ValueError below instead.
            next(f, None)

            for line in f:
                line = line.strip()
                if not line:
                    continue

                # Prefer tab separation so values that contain spaces
                # (e.g. "8151702 (99.8%) 46705.8Mb") stay in one field.
                parts = line.split("\t") if "\t" in line else line.split()
                if len(parts) < 2:
                    continue

                key = parts[0].rstrip(":_()").replace(" ", "_")
                value = parts[1]

                # "<main> (<extra>)" pairs for the top longest / best reads
                is_longest = "longest_read" in key
                is_highest_q = "highest_Q_read" in key
                if is_longest or is_highest_q:
                    value = value.strip("()")
                    if "(" in value:
                        main_val, extra_val = value.split("(")
                        main_val = float(main_val.strip())
                        extra_val = float(extra_val.rstrip(")"))
                        if is_longest:
                            length, quality = main_val, extra_val
                        else:
                            # highest_Q_read lists the quality first and the
                            # read length in parentheses.
                            length, quality = extra_val, main_val
                        metrics[f"{key}_length"] = length
                        metrics[f"{key}_quality"] = quality
                    continue

                # Handle percentage cases (e.g., "Reads >Q5: 8151702 (99.8%) 46705.8Mb")
                if "(" in value and ")" in value:
                    tokens = value.split()
                    metrics[f"{key}_count"] = float(tokens[0])
                    metrics[f"{key}_percentage"] = float(
                        value.split("(")[1].split("%")[0]
                    )
                    metrics[f"{key}_size_mb"] = float(tokens[-1].rstrip("Mb"))
                    continue

                # Convert numeric values; keep as string if not numeric
                try:
                    value = float(value) if "." in value else int(value)
                except ValueError:
                    pass

                metrics[key] = value

        if not metrics:
            raise ValueError(f"No metrics found in {nanostats_file}")

        return metrics

    except FileNotFoundError:
        logger.error("NanoStats file not found: %s", nanostats_file)
        raise
    except Exception as e:
        logger.error("Error parsing NanoStats file %s: %s", nanostats_file, e)
        raise


def determine_multiplexing(sample_name: str, seq_summaries_path: Path) -> str:
    """
    Classify a sample as multiplexed or singleplexed from run folder names.

    The sub-directories of ``seq_summaries_path`` are scanned in the order the
    filesystem yields them. The first one whose name contains ``sample_name``
    decides the result: a double underscore in that name is taken to mean two
    samples shared the run.

    Args:
        sample_name: Name of the sample to look for.
        seq_summaries_path: Directory holding one sub-directory per
            sequencing summary.

    Returns:
        ``'multiplexed'`` when the first matching directory name contains
        ``'__'``, otherwise ``'singleplexed'``. ``'singleplexed'`` is also
        returned, after logging a warning, when no directory matches.

    Raises:
        FileNotFoundError: If the sequencing summaries directory doesn't exist.
    """
    try:
        # Lazily walk the directory; only the first match is ever needed.
        matching_names = (
            entry.name
            for entry in seq_summaries_path.iterdir()
            if entry.is_dir() and sample_name in entry.name
        )
        first_match = next(matching_names, None)

        if first_match is None:
            logger.warning(f"No matching directory found for sample: {sample_name}")
            return "singleplexed"

        return "multiplexed" if "__" in first_match else "singleplexed"

    except FileNotFoundError:
        logger.error(f"Sequencing summaries directory not found: {seq_summaries_path}")
        raise
    except Exception as e:
        logger.error(
            f"Error determining multiplexing for sample {sample_name}: {str(e)}"
        )
        raise


def collect_nanoplot_data(
    aligned_bams_dir: Path | str,
    seq_summaries_dir: Path | str,
    basecall_suffixes: List[str] = ["sup"],
) -> List[Dict[str, Any]]:
    """
    Collects and processes Nanoplot data from specified directories.

    Every sub-directory of ``aligned_bams_dir`` whose name ends in
    ``_<suffix>`` for one of ``basecall_suffixes`` is treated as one sample
    run. Its ``NanoStats.txt`` is parsed with ``parse_nanostats`` and the
    resulting dict is extended with ``sample``, ``basecall`` and
    ``multiplexing`` keys. Directories without a ``NanoStats.txt`` are
    reported with a warning and skipped.

    Args:
        aligned_bams_dir: Directory containing NanoStats files generated from aligned bam files.
        seq_summaries_dir: Directory containing NanoStats files generated from sequencing summaries.
        basecall_suffixes: List of basecalling suffixes (without the leading
            underscore). The default list is only read, never mutated.

    Returns:
        List of dictionaries containing processed Nanoplot metrics.

    Raises:
        FileNotFoundError: If required directories don't exist.
        ValueError: If no valid data found in the directories.
    """
    aligned_bams_path = Path(aligned_bams_dir)
    seq_summaries_path = Path(seq_summaries_dir)

    if not aligned_bams_path.exists():
        raise FileNotFoundError(
            f"NanoStats Aligned BAMs directory not found: {aligned_bams_path}"
        )
    if not seq_summaries_path.exists():
        raise FileNotFoundError(
            f"NanoStats Sequencing summaries directory not found: {seq_summaries_path}"
        )

    nanoplot_data = []

    try:
        for subdir in aligned_bams_path.iterdir():
            if not subdir.is_dir():
                continue

            for basecall_suffix in basecall_suffixes:
                file_suffix = f"_{basecall_suffix}"
                if not subdir.name.endswith(file_suffix):
                    continue

                # Remove only the trailing suffix. Splitting on the suffix
                # would cut the name at its first occurrence and truncate
                # sample names that happen to contain it earlier.
                sample_name = subdir.name.removesuffix(file_suffix)
                nanostats_file = subdir / "NanoStats.txt"

                if nanostats_file.is_file():
                    logger.info(f"Processing NanoStats for sample: {sample_name}")
                    metrics = parse_nanostats(nanostats_file)
                    metrics["sample"] = sample_name
                    metrics["basecall"] = basecall_suffix
                    metrics["multiplexing"] = determine_multiplexing(
                        sample_name, seq_summaries_path
                    )
                    nanoplot_data.append(metrics)
                else:
                    logger.warning(
                        f"NanoStats.txt file not found for sample: {sample_name}"
                    )
                # A directory matches at most one suffix
                break

        if not nanoplot_data:
            raise ValueError(
                "No valid Nanoplot data found in the specified directories"
            )

        return nanoplot_data

    except Exception as e:
        logger.error(f"Error processing Nanoplot data: {str(e)}")
        raise


# Define paths
# Root folders of the NanoPlot outputs: the sequencing-summary folder is only
# used to infer multiplexing from its sub-directory names, the aligned-BAM
# folder is where the per-sample NanoStats.txt files are read from.
np_seq_summaries_dir = Path(
    "/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/seq_summaries/"
)
np_aligned_bams_dir = Path(
    "/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/aligned_bams/"
)

# Collect metrics
# Only sub-directories ending in "_sup" are collected here.
# collect_nanoplot_data raises FileNotFoundError when either folder is absent.
np_metrics_data = collect_nanoplot_data(
    aligned_bams_dir=np_aligned_bams_dir,
    seq_summaries_dir=np_seq_summaries_dir,
    basecall_suffixes=["sup"],
)

logger.info(f"Successfully processed {len(np_metrics_data)} samples")


FileNotFoundError: NanoStats Aligned BAMs directory not found: /scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/aligned_bams

In [None]:
def parse_nanostats(file_path):
    """
    Parse a NanoStats file into a flat ``{metric: value}`` dictionary.

    Blank lines and the ``Metrics dataset`` header are skipped. Lines starting
    with ``Reads >Q`` (``<name>: <count> (<pct>%) <size>Mb``) are expanded
    into three numeric entries: ``<name>_count`` (int, thousands separators
    removed), ``<name>_percentage`` (float) and ``<name>_bases`` (float, the
    megabase figure multiplied by 1e6). Every other line is split at the first
    run of whitespace into a name and a value, and the value is stored as a
    string. Lines that hold a name but no value are ignored.

    Args:
        file_path: Path of the NanoStats file to read.

    Returns:
        Dictionary of parsed metrics.
    """
    metrics = {}
    with open(file_path, "r") as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue

            # Collapse whitespace so the header is recognised whether its two
            # words are separated by a space or a tab; otherwise it would be
            # stored as a bogus "Metrics" entry.
            if " ".join(line.split()).startswith("Metrics dataset"):
                continue

            if line.startswith("Reads >Q"):
                parts = line.split(":")
                key = parts[0].strip()
                values = parts[1].strip().split()
                metrics[f"{key}_count"] = int(values[0].replace(",", ""))
                metrics[f"{key}_percentage"] = float(values[1].strip("()%"))
                metrics[f"{key}_bases"] = float(values[2].strip("Mb")) * 1e6
                continue

            fields = line.split(maxsplit=1)
            if len(fields) < 2:
                # A name without a value cannot be unpacked; skip it rather
                # than abort the whole file.
                continue
            key, value = fields
            metrics[key.strip()] = value.strip()
    return metrics


def determine_multiplexing(sample_name, seq_summaries_dir):
    """
    Infer a sample's multiplexing status from sequencing-summary folder names.

    Args:
        sample_name: Sample name to search for (substring match).
        seq_summaries_dir: Directory whose entries are inspected.

    Returns:
        ``"multiplex"`` if the first entry containing ``sample_name`` also
        contains ``"__"``, ``"singleplex"`` if it does not, and ``"unknown"``
        when no entry contains the sample name.
    """
    first_hit = next(
        (entry for entry in os.listdir(seq_summaries_dir) if sample_name in entry),
        None,
    )
    if first_hit is None:
        return "unknown"
    if "__" in first_hit:
        return "multiplex"
    return "singleplex"


def collect_nanoplot_data():
    """
    Gather NanoStats metrics for every sample folder under NP_ALIGNED_BAMS_DIR.

    A folder is considered when its name ends with one of
    NP_BASECALL_SUFFIXES; only the first matching suffix is used. If the
    folder contains a ``NanoStats.txt`` it is parsed and the result is
    extended with ``sample`` (folder name without the suffix), ``basecall``
    (suffix without leading underscores) and ``multiplexing`` (looked up in
    NP_SEQ_SUMMARIES_DIR). Folders without the file are skipped silently.

    Returns:
        List with one metrics dictionary per parsed sample folder.
    """
    np_data = []
    for subdir in os.listdir(NP_ALIGNED_BAMS_DIR):
        # First configured suffix this folder name ends with, if any
        suffix = next(
            (candidate for candidate in NP_BASECALL_SUFFIXES if subdir.endswith(candidate)),
            None,
        )
        if suffix is None:
            continue

        file_path = os.path.join(NP_ALIGNED_BAMS_DIR, subdir, "NanoStats.txt")
        if not os.path.isfile(file_path):
            continue

        # Remove only the trailing suffix. Splitting on the suffix would cut
        # the name at its first occurrence and truncate sample names that
        # contain it earlier.
        sample_name = subdir.removesuffix(suffix)

        metrics = parse_nanostats(file_path)
        metrics["sample"] = sample_name
        metrics["basecall"] = suffix.lstrip("_")
        metrics["multiplexing"] = determine_multiplexing(
            sample_name, NP_SEQ_SUMMARIES_DIR
        )
        np_data.append(metrics)
    return np_data


def process_nanoplot_dataframe(df):
    """
    Tidy the raw NanoStats table for reporting and plotting.

    The input frame is copied first, so the caller's object is left untouched.
    On the copy the function:

    1. drops the ``Metrics`` column if present;
    2. adds ``anonymised_sample`` (``"Sample <n>"``, numbered from 1 in sorted
       order of the unique ``sample`` values);
    3. casts the label columns to ``category`` and the known metric columns
       to numeric (unparseable values become NaN);
    4. moves the identifying columns to the front and sorts the rows by
       ``multiplexing`` then ``sample``.

    Args:
        df: DataFrame with at least ``sample``, ``basecall`` and
            ``multiplexing`` columns.

    Returns:
        A new, cleaned and sorted DataFrame.

    Raises:
        KeyError: If one of the identifying columns is missing.
    """
    id_columns = ["sample", "anonymised_sample", "basecall", "multiplexing"]
    category_columns = ["multiplexing", "basecall", "anonymised_sample"]
    numeric_columns = [
        "number_of_reads",
        "number_of_bases",
        "number_of_bases_aligned",
        "fraction_bases_aligned",
        "mean_read_length",
        "median_read_length",
        "read_length_stdev",
        "n50",
        "average_identity",
    ] + [
        f"Reads >Q{cutoff}_{suffix}"
        for cutoff in (5, 7, 10, 12, 15)
        for suffix in ("count", "percentage", "bases")
    ]

    # Work on a private copy: dropping and adding columns in place would
    # silently modify the frame owned by the caller.
    df = df.copy()
    df = df.drop(columns=["Metrics"], errors="ignore")

    sample_to_anon = {
        sample: f"Sample {number}"
        for number, sample in enumerate(sorted(df["sample"].unique()), start=1)
    }
    df["anonymised_sample"] = df["sample"].map(sample_to_anon)

    for column in category_columns:
        if column in df.columns:
            df[column] = df[column].astype("category")
    for column in numeric_columns:
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors="coerce")

    column_order = id_columns + [col for col in df.columns if col not in id_columns]
    return df[column_order].sort_values(by=["multiplexing", "sample"])


# Input folders for the NanoPlot outputs. These module-level constants are
# read directly by collect_nanoplot_data().
NP_SEQ_SUMMARIES_DIR = (
    "/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/seq_summaries/"
)
NP_ALIGNED_BAMS_DIR = (
    "/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/aligned_bams/"
)
# Folder-name suffixes, including the leading underscore; the underscore is
# stripped when the "basecall" value is filled in.
NP_BASECALL_SUFFIXES = ["_sup", "_hac"]

# Parse every matching NanoStats.txt, tabulate the results and clean the table.
np_metrics_data = collect_nanoplot_data()
np_metrics_df = pd.DataFrame(np_metrics_data)
np_metrics_df = process_nanoplot_dataframe(np_metrics_df)

# Last expression of the cell: displays the aggregate table.
np_metrics_df


## Sequencing Yield

### 1. Raw Yields


In [None]:
def create_yield_plot(ax, data, x, y, hue, title, xlabel, ylabel):
    """
    Draw a bar chart of one metric per sample on ``ax``.

    Bars are ordered by the first run of digits found in each value of the
    ``x`` column. The input frame is copied before the helper sort column is
    added, so the caller's data is not modified.

    Args:
        ax: Matplotlib axes to draw on.
        data: DataFrame holding the ``x``, ``y`` and ``hue`` columns.
        x: Column with the sample labels (must contain a number).
        y: Column with the bar heights.
        hue: Column used to colour the bars; also used as the legend title.
        title: Axes title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
    """
    ordered = data.copy()
    ordered["sample_num"] = ordered[x].str.extract(r"(\d+)").astype(int)
    ordered = ordered.sort_values("sample_num")

    sns.barplot(data=ordered, x=x, y=y, hue=hue, order=ordered[x], ax=ax)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(45)
        tick_label.set_ha("right")

    # Move every tick 0.2 units to the right of its current position
    shifted_positions = [position + 0.2 for position in ax.get_xticks()]
    ax.set_xticks(shifted_positions)

    ax.legend(title=hue)


def plot_sample_yields(metrics_df, basecall_type="sup", figsize=(16, 6), dpi=300):
    """
    Show read-count and base-count bar charts side by side.

    Args:
        metrics_df: Metrics table with ``basecall``, ``anonymised_sample``,
            ``multiplexing``, ``number_of_reads`` and ``number_of_bases``.
        basecall_type: Only rows with this ``basecall`` value are plotted.
        figsize: Figure size passed to ``plt.subplots``.
        dpi: Figure resolution passed to ``plt.subplots``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    selected = metrics_df[metrics_df["basecall"] == basecall_type]

    fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)

    # (y column, title, y-axis label) for the left and right panel
    panels = (
        ("number_of_reads", "Read Yield per Sample", "Number of Reads"),
        ("number_of_bases", "Base Yield per Sample", "Number of Bases"),
    )
    for axis, (column, title, ylabel) in zip(axes, panels):
        create_yield_plot(
            axis,
            selected,
            "anonymised_sample",
            column,
            "multiplexing",
            title,
            "Sample",
            ylabel,
        )

    plt.tight_layout()
    plt.show()


# NOTE: plot_sample_yields() displays the figure and has no return statement,
# so yield_plot is bound to None.
yield_plot = plot_sample_yields(np_metrics_df)


#### Summary Stats


In [None]:
def format_number(num):
    """Return ``num`` as a string with comma thousands separators."""
    return format(num, ",")


def calculate_yield_stats(df):
    """
    Summarise read and base yields.

    Args:
        df: DataFrame with ``number_of_reads`` and ``number_of_bases`` columns.

    Returns:
        ``{"reads": {...}, "bases": {...}}`` where each inner dictionary holds
        the ``max``, ``min``, ``mean``, ``std`` and ``median`` of the column.
    """
    stat_names = ("max", "min", "mean", "std", "median")

    def summarise(series):
        # Each statistic is the Series method of the same name
        return {name: getattr(series, name)() for name in stat_names}

    return {
        "reads": summarise(df["number_of_reads"]),
        "bases": summarise(df["number_of_bases"]),
    }


def print_yield_stats(stats, sample_type):
    """
    Print the yield statistics of one sample group.

    Args:
        stats: Mapping with ``"reads"`` and ``"bases"`` entries, each a
            mapping of statistic name to numeric value.
        sample_type: Group label used in the heading.
    """
    print(f"\n{sample_type} Samples Statistics:")
    print("=" * 40)
    for metric in ("reads", "bases"):
        print(f"\n{metric.capitalize()}:")
        for stat, value in stats[metric].items():
            # Values are rounded to whole numbers before formatting
            label = stat.capitalize()
            formatted_value = format_number(round(value))
            print(f"  {label:4s}: {formatted_value}")


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return the change from ``multiplex_val`` to ``singleplex_val`` as a percentage of ``multiplex_val``."""
    difference = singleplex_val - multiplex_val
    relative_change = difference / multiplex_val
    return relative_change * 100


# Split the sup-basecalled rows by multiplexing status. The labels compared
# here ("singleplex" / "multiplex") are the ones returned by the
# determine_multiplexing() variant that yields "multiplex"/"singleplex"/"unknown".
singleplex_yields = np_metrics_df[
    (np_metrics_df["multiplexing"] == "singleplex")
    & (np_metrics_df["basecall"] == "sup")
]
multiplex_yields = np_metrics_df[
    (np_metrics_df["multiplexing"] == "multiplex")
    & (np_metrics_df["basecall"] == "sup")
]

singleplex_stats = calculate_yield_stats(singleplex_yields)
multiplex_stats = calculate_yield_stats(multiplex_yields)

print_yield_stats(singleplex_stats, "Singleplexed")
print_yield_stats(multiplex_stats, "Multiplexed")

# Relative difference of the group means, expressed against the multiplexed mean.
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for metric in ["reads", "bases"]:
    increase = calculate_percentage_increase(
        singleplex_stats[metric]["mean"], multiplex_stats[metric]["mean"]
    )
    print(f"Mean Number of {metric.capitalize():5s}: {increase:6.2f}%")


### 2. Read Lengths


In [None]:
def prepare_read_length_data(metrics_df):
    """
    Reshape mean/median read lengths of sup-basecalled samples to long format.

    Args:
        metrics_df: Metrics table with ``basecall``, ``sample``,
            ``anonymised_sample``, ``multiplexing``, ``mean_read_length`` and
            ``median_read_length`` columns.

    Returns:
        DataFrame with one row per sample and read-length type, holding the
        display name in ``read_length_type``, the value in ``read_length`` and
        the number extracted from ``anonymised_sample`` in ``sample_num``;
        rows are sorted by ``sample_num``.
    """
    display_names = {
        "mean_read_length": "Mean Read Length",
        "median_read_length": "Median Read Length",
    }

    sup_rows = metrics_df[metrics_df["basecall"] == "sup"]
    long_df = sup_rows.melt(
        id_vars=["sample", "anonymised_sample", "multiplexing"],
        value_vars=["mean_read_length", "median_read_length"],
        var_name="read_length_type",
        value_name="read_length",
    )
    long_df["read_length_type"] = long_df["read_length_type"].replace(display_names)

    # Numeric part of "Sample <n>" so that ordering is numeric, not lexical
    extracted_numbers = long_df["anonymised_sample"].str.extract(r"(\d+)")
    long_df["sample_num"] = extracted_numbers.astype(int)

    return long_df.sort_values("sample_num")


def create_length_subplot(data, ax, title, hue):
    """
    Draw a per-sample read-length bar chart on ``ax``.

    Samples appear in the order of their first occurrence in ``data``.

    Args:
        data: Long-format frame with ``anonymised_sample`` and ``read_length``
            columns plus the ``hue`` column.
        ax: Matplotlib axes to draw on.
        title: Axes title.
        hue: Column used to colour the bars; also used as the legend title.
    """
    sample_order = data["anonymised_sample"].unique()
    sns.barplot(
        data=data,
        x="anonymised_sample",
        y="read_length",
        hue=hue,
        order=sample_order,
        errorbar=None,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_xlabel("Sample")
    ax.set_ylabel("Read Length (bp)")

    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(45)
        tick_label.set_ha("right")

    # Move every tick 0.2 units to the right of its current position
    shifted_positions = [position + 0.2 for position in ax.get_xticks()]
    ax.set_xticks(shifted_positions)

    ax.legend(title=hue)


def plot_read_lengths(read_length_df, figsize=(16, 6), dpi=300):
    """
    Show mean and median read length per sample in two panels sharing a y-axis.

    Args:
        read_length_df: Long-format frame as returned by
            ``prepare_read_length_data``.
        figsize: Figure size passed to ``plt.subplots``.
        dpi: Figure resolution passed to ``plt.subplots``.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=True, dpi=dpi)

    panel_types = ("Mean Read Length", "Median Read Length")
    for position, panel_type in enumerate(panel_types):
        axis = axes[position]
        panel_rows = read_length_df[read_length_df["read_length_type"] == panel_type]
        create_length_subplot(panel_rows, axis, panel_type, hue="multiplexing")

        if position > 0:
            # sharey hides the y tick labels on the right panel; show them again
            axis.yaxis.set_tick_params(labelleft=True)

    plt.tight_layout()
    plt.show()


# Long-format mean/median read lengths (sup-basecalled rows only), then plot.
# read_length_df is reused later by create_combined_yield_plot().
read_length_df = prepare_read_length_data(np_metrics_df)
plot_read_lengths(read_length_df)


In [None]:
def load_nanoplot_data(base_dir, metrics_df):
    """
    Load the per-read NanoPlot tables of all samples into one DataFrame.

    For every row of ``metrics_df`` the file
    ``<base_dir>/<sample>_<basecall>/NanoPlot-data.pickle`` is read if it
    exists; samples without that file are skipped silently. Only the
    ``readIDs``, ``quals``, ``lengths`` and ``mapQ`` columns are kept, and
    ``anonymised_sample`` / ``basecall`` columns are added to identify the
    source of each read.

    Args:
        base_dir: Directory containing one ``<sample>_<basecall>`` folder per
            sample.
        metrics_df: Table with ``anonymised_sample``, ``sample`` and
            ``basecall`` columns.

    Returns:
        Concatenated DataFrame with a fresh 0..n-1 index, or an empty
        DataFrame when no pickle file was found.
    """
    read_columns = ["readIDs", "quals", "lengths", "mapQ"]
    per_sample_frames = []

    for _, row in metrics_df.iterrows():
        basecall = row["basecall"]
        sample_dir = f"{row['sample']}_{basecall}"
        pickle_path = os.path.join(base_dir, sample_dir, "NanoPlot-data.pickle")
        if not os.path.isfile(pickle_path):
            continue

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only point base_dir at NanoPlot outputs you produced yourself.
        with open(pickle_path, "rb") as file:
            nanoplot_data = pickle.load(file)

        sample_data = pd.DataFrame(nanoplot_data)[read_columns].copy()
        sample_data["anonymised_sample"] = row["anonymised_sample"]
        sample_data["basecall"] = basecall
        per_sample_frames.append(sample_data)

    if not per_sample_frames:
        return pd.DataFrame()

    # One concat at the end avoids re-copying the accumulated reads for every
    # additional sample.
    return pd.concat(per_sample_frames, ignore_index=True)


def process_nanoplot_data(nanoplot_data, metrics_df):
    """
    Attach sample metadata to the per-read table and bin the read lengths.

    Args:
        nanoplot_data: Per-read table with ``anonymised_sample``, ``basecall``
            and ``lengths`` columns.
        metrics_df: Per-sample table providing ``multiplexing`` and
            ``number_of_reads`` for each ``anonymised_sample``/``basecall``.

    Returns:
        The merged table (reads without a matching metrics row are dropped by
        the default inner join) with an added ``length_bin`` column: the
        interval of 100 log-spaced edges between 10 and the longest read that
        each read length falls into.
    """
    sample_info = metrics_df[
        ["anonymised_sample", "multiplexing", "basecall", "number_of_reads"]
    ]
    merged = nanoplot_data.merge(sample_info, on=["anonymised_sample", "basecall"])

    longest_read = merged["lengths"].max()
    log_edges = np.logspace(np.log10(10), np.log10(longest_read), num=100)
    merged["length_bin"] = pd.cut(merged["lengths"], bins=log_edges)

    return merged


def calculate_length_distribution(processed_data, basecall_type="sup"):
    """
    Compute, per sample, the share of reads falling in each length bin.

    Args:
        processed_data: Per-read table with ``basecall``,
            ``anonymised_sample``, ``multiplexing``, ``number_of_reads``,
            ``lengths`` and a categorical ``length_bin`` column.
        basecall_type: Only reads with this ``basecall`` value are counted.

    Returns:
        DataFrame with one row per sample / length bin / multiplexing
        combination holding the read ``count``, the sample's
        ``number_of_reads``, the ``percentage`` (count relative to
        ``number_of_reads``) and ``bin_center`` (geometric mean of the bin
        edges, recomputed from the longest read in ``processed_data``).
    """
    selected = processed_data[processed_data["basecall"] == basecall_type]

    group_keys = ["anonymised_sample", "length_bin", "multiplexing"]
    bin_counts = (
        selected.groupby(group_keys, observed=False).size().reset_index(name="count")
    )

    reads_per_sample = selected[
        ["anonymised_sample", "number_of_reads"]
    ].drop_duplicates()
    distribution = bin_counts.merge(reads_per_sample, on="anonymised_sample")

    distribution["percentage"] = (
        distribution["count"] / distribution["number_of_reads"] * 100
    )

    # Rebuild the same log-spaced edges used for binning to derive one centre
    # per bin category.
    longest_read = processed_data["lengths"].max()
    log_edges = np.logspace(np.log10(10), np.log10(longest_read), num=100)
    centres = pd.Series(
        np.sqrt(log_edges[:-1] * log_edges[1:]),
        index=distribution["length_bin"].cat.categories,
    )
    distribution["bin_center"] = distribution["length_bin"].map(centres)

    return distribution


def plot_read_length_distribution(length_dist, max_length):
    """
    Plot the read-length distribution of every sample on a log-scaled x-axis.

    Samples whose percentages sum to zero are left out. Lines are coloured by
    sample (ordered by the number in the sample label) and styled by
    multiplexing status.

    Args:
        length_dist: Table with ``anonymised_sample``, ``multiplexing``,
            ``bin_center`` and ``percentage`` columns.
        max_length: Longest read length; sets the right x-limit and the upper
            end of the 20 log-spaced tick positions.
    """
    plt.figure(figsize=(14, 6), dpi=300)

    # observed=False is spelled out so the result does not depend on the
    # pandas default for categorical group keys; unobserved samples sum to 0
    # and are removed by the filter below either way.
    percentage_per_sample = length_dist.groupby("anonymised_sample", observed=False)[
        "percentage"
    ].sum()
    non_zero_samples = percentage_per_sample[percentage_per_sample > 0].index

    # Copy the filtered rows: adding a column to a slice of length_dist is a
    # chained assignment and triggers SettingWithCopyWarning.
    filtered_length_dist = length_dist[
        length_dist["anonymised_sample"].isin(non_zero_samples)
    ].copy()

    filtered_length_dist["sample_num"] = (
        filtered_length_dist["anonymised_sample"].str.extract(r"(\d+)").astype(int)
    )
    filtered_length_dist = filtered_length_dist.sort_values("sample_num")

    # Renamed only so the legend title reads "sample ID"
    filtered_length_dist = filtered_length_dist.rename(
        columns={"anonymised_sample": "sample ID"}
    )

    sns.lineplot(
        data=filtered_length_dist,
        x="bin_center",
        y="percentage",
        hue="sample ID",
        style="multiplexing",
        hue_order=filtered_length_dist["sample ID"].unique(),
    )

    plt.xscale("log")
    plt.xlabel("Read Length")
    plt.ylabel("Proportion of Reads (%)")
    plt.title("Distribution of Read Lengths")

    tick_positions = np.logspace(np.log10(10), np.log10(max_length), num=20)
    plt.xticks(
        ticks=tick_positions, labels=[f"{int(tick):,}" for tick in tick_positions]
    )

    plt.xlim(left=10, right=max_length)
    plt.tight_layout()
    plt.show()


# NOTE(review): this base path lacks the "ont-benchmark" component present in
# NP_ALIGNED_BAMS_DIR. load_nanoplot_data() skips samples whose pickle file is
# missing without raising, so a wrong path would not fail at load time -
# confirm this location is intended.
NANOPLOT_BASE_DIR = "/scratch/prj/ppn_als_longread/qc/nanoplot/aligned_bams/"

# Per-read data -> merged with sample metadata and binned -> per-sample
# length distribution. The later objects are reused by subsequent cells.
nanoplot_data = load_nanoplot_data(NANOPLOT_BASE_DIR, np_metrics_df)
processed_nanoplot_data = process_nanoplot_data(nanoplot_data, np_metrics_df)
length_distribution = calculate_length_distribution(processed_nanoplot_data)
max_read_length = processed_nanoplot_data["lengths"].max()

plot_read_length_distribution(length_distribution, max_read_length)


In [None]:
def format_number(num):
    """Return ``num`` with comma thousands separators and two decimal places."""
    return format(num, ",.2f")


def calculate_length_stats(df):
    """
    Summarise the ``lengths`` column of a per-read table.

    Args:
        df: DataFrame with a ``lengths`` column.

    Returns:
        Dictionary with the ``max``, ``min``, ``mean``, ``std`` and ``median``
        of the read lengths, in that order.
    """
    read_lengths = df["lengths"]
    stat_names = ("max", "min", "mean", "std", "median")
    # Each statistic is the Series method of the same name
    return {name: getattr(read_lengths, name)() for name in stat_names}


def print_length_stats(stats, sample_type):
    """
    Print the read-length statistics of one sample group.

    Args:
        stats: Mapping of statistic name to numeric value.
        sample_type: Group label used in the heading.
    """
    print(f"\n{sample_type} Samples Statistics:")
    print("=" * 40)
    print("\nRead Lengths:")
    for stat in stats:
        label = stat.capitalize()
        formatted_value = format_number(stats[stat])
        print(f"  {label:6s}: {formatted_value}")


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return the change from ``multiplex_val`` to ``singleplex_val`` as a percentage of ``multiplex_val``."""
    difference = singleplex_val - multiplex_val
    relative_change = difference / multiplex_val
    return relative_change * 100


# Per-read tables of the sup-basecalled samples, split by multiplexing status.
singleplexed_reads = processed_nanoplot_data[
    (processed_nanoplot_data["multiplexing"] == "singleplex")
    & (processed_nanoplot_data["basecall"] == "sup")
]
multiplexed_reads = processed_nanoplot_data[
    (processed_nanoplot_data["multiplexing"] == "multiplex")
    & (processed_nanoplot_data["basecall"] == "sup")
]

# NOTE: these names were previously bound to the yield statistics; from here
# on they hold read-length statistics instead.
singleplex_stats = calculate_length_stats(singleplexed_reads)
multiplex_stats = calculate_length_stats(multiplexed_reads)

print_length_stats(singleplex_stats, "Singleplexed")
print_length_stats(multiplex_stats, "Multiplexed")

# Relative difference of the group mean/median, against the multiplexed value.
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat in ["mean", "median"]:
    increase = calculate_percentage_increase(
        singleplex_stats[stat], multiplex_stats[stat]
    )
    print(f"{stat.capitalize():6s} Read Length: {increase:6.2f}%")


### 3. Combined Plots


In [None]:
def create_combined_yield_plot(metrics_df, read_length_df, length_distribution):
    """
    Draw the five-panel yield figure (A-E) and display it.

    Layout (3 rows x 2 columns, bottom row spans both columns):
        A: reads per sample            B: bases per sample
        C: mean read length            D: median read length
        E: distribution of read lengths (log-scaled x-axis)

    Panels A-D show sup-basecalled samples only; the legend is kept on panel
    A and removed from B-D.

    Args:
        metrics_df: Per-sample metrics table (see ``create_yield_plot``).
        read_length_df: Long-format read lengths as returned by
            ``prepare_read_length_data``.
        length_distribution: Per-bin distribution with ``anonymised_sample``,
            ``multiplexing``, ``bin_center`` and ``percentage`` columns.
    """
    fig = plt.figure(figsize=(12, 12), dpi=300)
    gs = fig.add_gridspec(3, 2, height_ratios=[1, 1, 1.2])

    def adjust_tick_labels(ax):
        # NOTE(review): create_yield_plot and create_length_subplot already
        # rotate the labels and shift the ticks by 0.2, so the panels end up
        # shifted twice. Kept as in the original to leave the figure unchanged.
        for tick in ax.get_xticklabels():
            tick.set_rotation(45)
            tick.set_ha("right")
        ax.set_xticks([loc + 0.2 for loc in ax.get_xticks()])

    def add_panel_letter(ax, letter, x_offset=-0.10):
        # Bold panel letter just above the top-left corner of the axes
        ax.text(
            x_offset,
            1.07,
            letter,
            transform=ax.transAxes,
            fontsize=12,
            fontweight="bold",
            va="top",
        )

    sup_metrics = metrics_df[metrics_df["basecall"] == "sup"]

    # A, B: read and base yield per sample
    yield_panels = (
        ("A", gs[0, 0], "number_of_reads", "Reads per Sample", "Number of Reads"),
        ("B", gs[0, 1], "number_of_bases", "Bases per Sample", "Number of Bases"),
    )
    for letter, cell, column, title, ylabel in yield_panels:
        ax = fig.add_subplot(cell)
        create_yield_plot(
            ax,
            sup_metrics,
            "anonymised_sample",
            column,
            "multiplexing",
            title,
            "Sample",
            ylabel,
        )
        add_panel_letter(ax, letter)
        adjust_tick_labels(ax)
        if letter != "A":
            # Panel A carries the only bar-chart legend
            ax.get_legend().remove()

    # C, D: mean and median read length per sample
    length_panels = (
        ("C", gs[1, 0], "Mean Read Length"),
        ("D", gs[1, 1], "Median Read Length"),
    )
    for letter, cell, length_type in length_panels:
        ax = fig.add_subplot(cell)
        panel_data = read_length_df[read_length_df["read_length_type"] == length_type]
        create_length_subplot(panel_data, ax, length_type, "multiplexing")
        add_panel_letter(ax, letter)
        adjust_tick_labels(ax)
        ax.get_legend().remove()

    # E: Distribution of Read Lengths
    ax_e = fig.add_subplot(gs[2, :])

    # observed=False is spelled out so the result does not depend on the
    # pandas default for categorical group keys; unobserved samples sum to 0
    # and are removed by the filter below either way.
    percentage_per_sample = length_distribution.groupby(
        "anonymised_sample", observed=False
    )["percentage"].sum()
    non_zero_samples = percentage_per_sample[percentage_per_sample > 0].index

    # Copy the filtered rows: adding a column to a slice of
    # length_distribution is a chained assignment and triggers
    # SettingWithCopyWarning.
    filtered_length_dist = length_distribution[
        length_distribution["anonymised_sample"].isin(non_zero_samples)
    ].copy()

    filtered_length_dist["sample_num"] = (
        filtered_length_dist["anonymised_sample"].str.split().str[-1].astype(int)
    )
    filtered_length_dist = filtered_length_dist.sort_values("sample_num")

    # Renamed only so the legend title reads "Sample ID"
    filtered_length_dist = filtered_length_dist.rename(
        columns={"anonymised_sample": "Sample ID"}
    )

    sns.lineplot(
        data=filtered_length_dist,
        x="bin_center",
        y="percentage",
        hue="Sample ID",
        style="multiplexing",
        ax=ax_e,
        hue_order=filtered_length_dist["Sample ID"].unique(),
    )

    ax_e.set_xscale("log")
    ax_e.set_xlabel("Read Length")
    ax_e.set_ylabel("Proportion of Reads (%)")
    ax_e.set_title("Distribution of Read Lengths")
    # The full-width panel needs a smaller offset to line up with A-D
    add_panel_letter(ax_e, "E", x_offset=-0.05)

    max_length = filtered_length_dist["bin_center"].max()

    tick_positions = np.logspace(np.log10(10), np.log10(max_length), num=19)
    ax_e.set_xticks(tick_positions)
    ax_e.set_xticklabels([f"{int(tick):,}" for tick in tick_positions])
    ax_e.set_xlim(left=10, right=max_length)
    ax_e.set_ylim(bottom=0)

    plt.setp(ax_e.get_xticklabels(), ha="center")

    plt.tight_layout()
    plt.show()


# Render the combined A-E figure from the tables built in the previous cells.
create_combined_yield_plot(np_metrics_df, read_length_df, length_distribution)


## Read Quality

### 1. Basecalling Quality


In [None]:
def calculate_base_quality_distribution(processed_data, basecall_type="sup"):
    """
    Compute, per sample, the share of reads in each 0.5-wide quality bin.

    Args:
        processed_data: Per-read table with ``basecall``, ``quals``,
            ``anonymised_sample`` and ``multiplexing`` columns.
        basecall_type: Only reads with this ``basecall`` value are used.

    Returns:
        DataFrame with one row per observed sample / quality bin /
        multiplexing combination: ``Sample ID``, ``quals_bin``,
        ``multiplexing``, ``count``, ``percentage`` (relative to the binned
        reads of that sample), ``quals_bin_mid`` and ``sample_num``; sorted by
        ``sample_num``.
    """
    reads = processed_data.loc[processed_data["basecall"] == basecall_type].copy()

    # Half-unit bins spanning the observed quality range
    lowest_quality = reads["quals"].min()
    highest_quality = reads["quals"].max()
    half_unit_bins = pd.interval_range(
        start=lowest_quality, end=highest_quality, freq=0.5
    )
    reads["quals_bin"] = pd.cut(reads["quals"], bins=half_unit_bins)

    reads["sample_num"] = reads["anonymised_sample"].str.extract(r"(\d+)").astype(int)
    reads = reads.sort_values("sample_num")

    group_keys = ["anonymised_sample", "quals_bin", "multiplexing"]
    distribution = (
        reads.groupby(group_keys, observed=True).size().reset_index(name="count")
    )

    reads_per_sample = distribution.groupby(
        ["anonymised_sample", "multiplexing"], observed=True
    )["count"].transform("sum")
    distribution["percentage"] = (distribution["count"] / reads_per_sample) * 100

    distribution["quals_bin_mid"] = distribution["quals_bin"].apply(
        lambda interval: interval.mid
    )

    distribution["sample_num"] = (
        distribution["anonymised_sample"].str.extract(r"(\d+)").astype(int)
    )
    distribution = distribution.sort_values("sample_num")

    return distribution.rename(columns={"anonymised_sample": "Sample ID"})


def plot_base_quality_distribution(quality_dist):
    """
    Plot the per-sample distribution of read quality scores and display it.

    Args:
        quality_dist: Table with ``Sample ID``, ``multiplexing``,
            ``quals_bin_mid`` and ``percentage`` columns. Samples are coloured
            in their order of first appearance.
    """
    plt.figure(figsize=(14, 6), dpi=300)

    sns.lineplot(
        data=quality_dist,
        x="quals_bin_mid",
        y="percentage",
        hue="Sample ID",
        style="multiplexing",
        hue_order=quality_dist["Sample ID"].unique(),
    )

    plt.xlabel("Quality Score")
    plt.ylabel("Percentage of Total Reads")
    plt.title("Distribution of Read Quality Scores")

    # One tick every 5 quality units, from 0 up to the highest bin midpoint
    highest_midpoint = int(quality_dist["quals_bin_mid"].max())
    five_unit_ticks = np.arange(0, highest_midpoint + 1, 5)
    plt.xticks(ticks=five_unit_ticks, labels=five_unit_ticks)

    # Legend is placed outside the axes, to the right
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


# Per-sample quality-score distribution of the sup-basecalled reads, then plot.
quality_distribution = calculate_base_quality_distribution(processed_nanoplot_data)
plot_base_quality_distribution(quality_distribution)


In [None]:
def prepare_qscore_data(metrics_df):
    """Reshape the "sup" rows of the metrics table into long Q-score form.

    Args:
        metrics_df: Frame with ``"basecall"``, ``"anonymised_sample"``,
            ``"multiplexing"`` and the five ``"Reads >Q*_percentage"`` columns.

    Returns:
        Long-format frame with the columns ``"Sample ID"``, ``"multiplexing"``,
        ``"Quality_Score"`` (``"Q5"`` … ``"Q15"``) and ``"Percentage"``.
    """
    percentage_columns = [
        "Reads >Q5_percentage",
        "Reads >Q7_percentage",
        "Reads >Q10_percentage",
        "Reads >Q12_percentage",
        "Reads >Q15_percentage",
    ]

    sup_metrics = metrics_df[metrics_df["basecall"] == "sup"]
    long_df = sup_metrics.melt(
        id_vars=["anonymised_sample", "multiplexing"],
        value_vars=percentage_columns,
        var_name="Quality_Score",
        value_name="Percentage",
    )

    # Strip the column-name decoration so only the threshold label remains.
    without_prefix = long_df["Quality_Score"].str.replace("Reads >", "")
    long_df["Quality_Score"] = without_prefix.str.replace("_percentage", "")

    return long_df.rename(columns={"anonymised_sample": "Sample ID"})


def plot_qscore_distribution(qscore_df, figsize=(20, 6), dpi=300):
    """Plot per-sample Q-score percentages, multiplex and singleplex side by side.

    Args:
        qscore_df: Long-format frame with ``"Sample ID"``, ``"multiplexing"``,
            ``"Quality_Score"`` and ``"Percentage"``. NOTE: a ``"sample_num"``
            column is written onto this frame in place.
        figsize: Figure size passed to ``plt.subplots``.
        dpi: Figure resolution passed to ``plt.subplots``.
    """
    score_levels = ["Q5", "Q7", "Q10", "Q12", "Q15"]

    # Numeric sort key taken from the first run of digits in the sample ID.
    qscore_df["sample_num"] = qscore_df["Sample ID"].str.extract(r"(\d+)").astype(int)
    ordered_df = qscore_df.sort_values("sample_num")
    ordered_df["Quality_Score"] = pd.Categorical(
        ordered_df["Quality_Score"], categories=score_levels, ordered=True
    )

    fig, panels = plt.subplots(1, 2, figsize=figsize, dpi=dpi)

    for panel, kind in zip(panels, ["multiplex", "singleplex"]):
        subset = ordered_df[ordered_df["multiplexing"] == kind]

        if subset.empty:
            # Nothing to draw: replace the axes with a centred placeholder.
            panel.text(
                0.5,
                0.5,
                f"No {kind} samples",
                ha="center",
                va="center",
                transform=panel.transAxes,
            )
            panel.set_axis_off()
            continue

        sns.barplot(
            data=subset,
            x="Sample ID",
            y="Percentage",
            hue="Quality_Score",
            hue_order=score_levels,
            order=subset["Sample ID"].unique(),
            errorbar=None,
            ax=panel,
        )

        panel.set_xlabel("Sample ID")
        panel.set_ylabel("Percentage of Reads")
        panel.set_title(
            f"Percentage of Reads Above Quality Scores\n{kind.capitalize()} Samples"
        )
        panel.legend(title="Quality Score", loc="lower right")

        for label in panel.get_xticklabels():
            label.set_rotation(45)
            label.set_ha("right")

        # Nudge the tick positions to the right by 0.2 axis units.
        panel.set_xticks([position + 0.2 for position in panel.get_xticks()])

    plt.tight_layout()
    plt.show()


# Build the long-format Q-score table from the "sup" rows of the metrics table
# and draw the bar charts. `qscore_df` is reused by the combined figure below.
qscore_df = prepare_qscore_data(np_metrics_df)
plot_qscore_distribution(qscore_df)


In [None]:
def calculate_base_quality_stats(df):
    """Summarise the ``"quals"`` column of ``df``.

    Returns:
        Dict with the keys ``"max"``, ``"min"``, ``"mean"``, ``"std"`` and
        ``"median"`` (in that order), each computed with the pandas Series
        method of the same name.
    """
    quals = df["quals"]
    summary_names = ("max", "min", "mean", "std", "median")
    return {name: getattr(quals, name)() for name in summary_names}


def print_base_quality_stats(stats, sample_type):
    """Print a base-quality summary dict as an aligned text block.

    Args:
        stats: Mapping of statistic name to value. The parameter name shadows
            the module-level ``scipy.stats`` import inside this function.
        sample_type: Label used in the heading line.
    """
    heading = f"\n{sample_type} Samples Quality Statistics:"
    print(heading)
    print("=" * 40)
    print("\nBase Qualities:")
    for name, raw_value in stats.items():
        rendered = format_number(raw_value)
        label = name.capitalize()
        print(f"  {label:6s}: {rendered}")


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return the change from ``multiplex_val`` to ``singleplex_val`` in percent.

    The multiplex value is the baseline; no guard is applied for a zero baseline.
    """
    difference = singleplex_val - multiplex_val
    relative_change = difference / multiplex_val
    return relative_change * 100


# Split the "sup"-basecalled reads by multiplexing mode. These two frames are
# reused later for the mapping-quality statistics.
singleplexed_quals = processed_nanoplot_data[
    (processed_nanoplot_data["multiplexing"] == "singleplex")
    & (processed_nanoplot_data["basecall"] == "sup")
]
multiplexed_quals = processed_nanoplot_data[
    (processed_nanoplot_data["multiplexing"] == "multiplex")
    & (processed_nanoplot_data["basecall"] == "sup")
]

# Summary statistics of the "quals" column for each group.
singleplex_quality_stats = calculate_base_quality_stats(singleplexed_quals)
multiplex_quality_stats = calculate_base_quality_stats(multiplexed_quals)

print_base_quality_stats(singleplex_quality_stats, "Singleplexed")
print_base_quality_stats(multiplex_quality_stats, "Multiplexed")

# Relative difference of singleplex over multiplex (multiplex is the baseline).
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat in ["mean", "median"]:
    increase = calculate_percentage_increase(
        singleplex_quality_stats[stat], multiplex_quality_stats[stat]
    )
    print(f"{stat.capitalize():6s} Base Quality: {increase:6.2f}%")


### 2. Mapping Quality


In [None]:
def calculate_mapping_quality_distribution(processed_data, basecall_type="sup"):
    """Bin mapping qualities per sample and express each bin as a percentage.

    Bins are 0.5 wide and right-closed. One extra bin is placed below the
    observed minimum so that reads at the minimum mapQ are counted; a
    right-closed bin that *starts* at the minimum would exclude them and
    silently shrink the percentage denominator. The upper edge is extended
    when needed so the observed maximum is always covered.

    Args:
        processed_data: Frame with ``"basecall"``, ``"mapQ"``,
            ``"anonymised_sample"`` and ``"multiplexing"`` columns.
        basecall_type: Value of ``"basecall"`` to keep.

    Returns:
        Frame with one row per (sample, bin, multiplexing) holding ``"count"``,
        ``"percentage"`` (of that sample's reads), ``"mapQ_bin"``,
        ``"mapQ_bin_mid"``, ``"sample_num"`` and the sample column renamed to
        ``"Sample ID"``, ordered by ``"sample_num"``.

    Raises:
        ValueError: If no rows match ``basecall_type``.
    """
    bin_width = 0.5
    mapQ_data = processed_data[processed_data["basecall"] == basecall_type].copy()
    if mapQ_data.empty:
        raise ValueError(f"No rows with basecall == {basecall_type!r}")

    lowest = mapQ_data["mapQ"].min()
    highest = mapQ_data["mapQ"].max()

    # Breaks run from (lowest - bin_width) up to at least `highest`.
    n_bins = int(np.ceil((highest - lowest) / bin_width)) + 1
    edges = (lowest - bin_width) + bin_width * np.arange(n_bins + 1)
    if edges[-1] < highest:
        # Guard against floating-point shortfall on the last edge.
        edges = np.append(edges, edges[-1] + bin_width)

    mapQ_data["mapQ_bin"] = pd.cut(
        mapQ_data["mapQ"], bins=pd.IntervalIndex.from_breaks(edges)
    )

    mapping_quality_dist = (
        mapQ_data.groupby(
            ["anonymised_sample", "mapQ_bin", "multiplexing"], observed=True
        )
        .size()
        .reset_index(name="count")
    )

    total_counts = mapping_quality_dist.groupby(
        ["anonymised_sample", "multiplexing"], observed=True
    )["count"].transform("sum")

    mapping_quality_dist["percentage"] = (
        mapping_quality_dist["count"] / total_counts * 100
    )

    mapping_quality_dist["mapQ_bin_mid"] = mapping_quality_dist["mapQ_bin"].apply(
        lambda interval: interval.mid
    )

    # Numeric sort key from the first run of digits in the sample name; a
    # stable sort keeps the bins in ascending order within each sample.
    mapping_quality_dist["sample_num"] = (
        mapping_quality_dist["anonymised_sample"].str.extract(r"(\d+)").astype(int)
    )
    mapping_quality_dist = mapping_quality_dist.sort_values(
        "sample_num", kind="stable"
    )

    mapping_quality_dist = mapping_quality_dist.rename(
        columns={"anonymised_sample": "Sample ID"}
    )

    return mapping_quality_dist


def plot_mapping_quality_distribution(mapping_quality_dist):
    """Draw the binned mapping-quality distribution, one line per sample.

    Args:
        mapping_quality_dist: Frame with the columns ``"Sample ID"``,
            ``"mapQ_bin_mid"``, ``"percentage"`` and ``"multiplexing"``.
            Hue order follows the order in which sample IDs appear.
    """
    ordered_samples = mapping_quality_dist["Sample ID"].unique()

    # X ticks every 5 units from 0 up to the largest bin midpoint.
    highest_mid = int(mapping_quality_dist["mapQ_bin_mid"].max())
    ticks = np.arange(0, highest_mid + 1, 5)

    plt.figure(figsize=(14, 6), dpi=300)
    sns.lineplot(
        x="mapQ_bin_mid",
        y="percentage",
        hue="Sample ID",
        hue_order=ordered_samples,
        style="multiplexing",
        legend="full",
        data=mapping_quality_dist,
    )

    plt.title("Distribution of Mapping Quality Scores")
    plt.xlabel("Mapping Quality Score")
    plt.ylabel("Percentage of Total Reads")
    plt.xticks(ticks=ticks, labels=ticks)

    # Legend sits outside the axes so it never covers the curves.
    plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1))
    plt.tight_layout()
    plt.show()


# Bin the mapping qualities (default basecall type "sup") and plot them.
# `mapping_quality_distribution` is reused by the combined figure below.
mapping_quality_distribution = calculate_mapping_quality_distribution(
    processed_nanoplot_data
)

plot_mapping_quality_distribution(mapping_quality_distribution)


In [None]:
def calculate_mapQ_stats(df):
    """Summarise the ``"mapQ"`` column of ``df``.

    Returns:
        Dict with the keys ``"max"``, ``"min"``, ``"mean"``, ``"std"`` and
        ``"median"`` (in that order), each computed with the pandas Series
        method of the same name.
    """
    map_quality = df["mapQ"]
    summary_names = ("max", "min", "mean", "std", "median")
    return {name: getattr(map_quality, name)() for name in summary_names}


def print_mapQ_stats(stats, sample_type):
    """Print a mapping-quality summary dict as an aligned text block.

    Args:
        stats: Mapping of statistic name to value. The parameter name shadows
            the module-level ``scipy.stats`` import inside this function.
        sample_type: Label used in the heading line.
    """
    heading = f"\n{sample_type} Samples Statistics:"
    print(heading)
    print("=" * 40)
    print("\nMapping Quality:")
    for name, raw_value in stats.items():
        rendered = format_number(raw_value)
        label = name.capitalize()
        print(f"  {label:6s}: {rendered}")


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return how far ``singleplex_val`` lies above ``multiplex_val``, in percent.

    Same formula as the definition in the base-quality cell; the multiplex
    value is the baseline and a zero baseline is not guarded.
    """
    delta = singleplex_val - multiplex_val
    fraction = delta / multiplex_val
    return fraction * 100


# Reuses the "sup" singleplex/multiplex read subsets built in the base-quality
# statistics cell above.
singleplex_mapQ_stats = calculate_mapQ_stats(singleplexed_quals)
multiplex_mapQ_stats = calculate_mapQ_stats(multiplexed_quals)

print_mapQ_stats(singleplex_mapQ_stats, "Singleplexed")
print_mapQ_stats(multiplex_mapQ_stats, "Multiplexed")

# Relative difference of singleplex over multiplex (multiplex is the baseline).
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat in ["mean", "median"]:
    increase = calculate_percentage_increase(
        singleplex_mapQ_stats[stat], multiplex_mapQ_stats[stat]
    )
    print(f"{stat.capitalize():6s} Mapping Quality: {increase:6.2f}%")


### 3. Combined Plots


In [None]:
def create_combined_sequencing_metrics_plot(
    quality_distribution, mapping_quality_distribution, qscore_df
):
    """Assemble the four-panel (A-D) sequencing-quality figure.

    Panel A: read-quality distribution. Panel B: mapping-quality distribution
    (carries the sample legend). Panels C and D: Q-score bar charts for the
    multiplex and singleplex samples respectively.

    Args:
        quality_distribution: Frame with ``"quals_bin_mid"``, ``"percentage"``,
            ``"Sample ID"`` and ``"multiplexing"``.
        mapping_quality_distribution: Frame with ``"mapQ_bin_mid"``,
            ``"percentage"``, ``"Sample ID"`` and ``"multiplexing"``.
        qscore_df: Long-format frame with ``"Sample ID"``, ``"multiplexing"``,
            ``"Quality_Score"`` and ``"Percentage"``.
    """
    fig = plt.figure(figsize=(12, 8), dpi=300)
    grid = fig.add_gridspec(2, 2)

    def tag_panel(axis, letter):
        # Bold panel letter just outside the top-left corner of the axes.
        axis.text(
            -0.1,
            1.05,
            letter,
            transform=axis.transAxes,
            fontsize=12,
            fontweight="bold",
            va="top",
        )

    def rotate_and_shift_ticks(axis):
        # Slant the labels, then nudge the tick positions right by 0.2.
        for label in axis.get_xticklabels():
            label.set_rotation(45)
            label.set_ha("right")
        axis.set_xticks([position + 0.2 for position in axis.get_xticks()])

    def draw_distribution(axis, frame, x_column, x_label, title, **lineplot_kwargs):
        # Shared body of panels A and B: per-sample lines plus ticks every 5.
        sns.lineplot(
            data=frame,
            x=x_column,
            y="percentage",
            hue="Sample ID",
            style="multiplexing",
            ax=axis,
            **lineplot_kwargs,
        )
        axis.set_xlabel(x_label)
        axis.set_ylabel("Percentage of Total Reads")
        axis.set_title(title)
        upper = int(frame[x_column].max())
        ticks = np.arange(0, upper + 1, 5)
        axis.set_xticks(ticks=ticks, labels=ticks)

    # A: Base Quality Distribution (legend suppressed; panel B carries it)
    ax_a = fig.add_subplot(grid[0, 0])
    draw_distribution(
        ax_a,
        quality_distribution,
        "quals_bin_mid",
        "Quality Score",
        "Distribution of Read Quality Scores",
        legend=False,
    )
    tag_panel(ax_a, "A")

    # B: Mapping Quality Distribution
    ax_b = fig.add_subplot(grid[0, 1])
    draw_distribution(
        ax_b,
        mapping_quality_distribution,
        "mapQ_bin_mid",
        "Mapping Quality Score",
        "Distribution of Mapping Quality Scores",
    )
    tag_panel(ax_b, "B")
    ax_b.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    # C and D: Q-Score Distribution
    ax_c = fig.add_subplot(grid[1, 0])
    ax_d = fig.add_subplot(grid[1, 1])
    bar_panels = ((ax_c, "multiplex", "C"), (ax_d, "singleplex", "D"))
    for axis, kind, letter in bar_panels:
        subset = qscore_df[qscore_df["multiplexing"] == kind]
        if subset.empty:
            # Nothing to draw: show a centred placeholder instead of axes.
            axis.text(
                0.5,
                0.5,
                f"No {kind} samples",
                ha="center",
                va="center",
                transform=axis.transAxes,
            )
            axis.set_axis_off()
            continue

        sns.barplot(
            data=subset,
            x="Sample ID",
            y="Percentage",
            hue="Quality_Score",
            errorbar=None,
            ax=axis,
            order=subset["Sample ID"].unique(),
        )
        axis.set_xlabel("Sample ID")
        axis.set_ylabel("Percentage of Reads")
        axis.set_title(
            f"Percentage of Reads Above Quality Scores\n{kind.capitalize()} Samples"
        )
        rotate_and_shift_ticks(axis)
        tag_panel(axis, letter)

    # Keep a single Q-score legend: drop C's, re-anchor D's outside the axes.
    legend_c = ax_c.get_legend()
    if legend_c is not None:
        legend_c.remove()
    if ax_d.get_legend() is not None:
        ax_d.legend(title="Quality Score", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


# Render the combined figure from the three tables built in the cells above.
create_combined_sequencing_metrics_plot(
    quality_distribution, mapping_quality_distribution, qscore_df
)


## Sequencing Depth

### 1. Depth per Chromosome


In [None]:
def process_mosdepth_file(file_path, suffix):
    """Load one mosdepth summary file and keep the primary-chromosome rows.

    Args:
        file_path: Path to a tab-separated summary file with at least the
            columns ``"chrom"`` and ``"mean"``.
        suffix: Text removed from the file-name stem to get the sample name.

    Returns:
        Frame with ``"chrom"``, ``"mean"`` and ``"sample"`` for chr1-chr22,
        chrX, chrY and the ``"total"`` row. It is an independent copy, so it
        can be modified without chained-assignment warnings.
    """
    # Sample name = everything before the first "." with the suffix removed.
    sample_name = os.path.basename(file_path).split(".")[0].replace(suffix, "")
    summary = pd.read_csv(file_path, sep="\t")

    # Exact-name matching already excludes "*_region" rows, so no separate
    # filter on that suffix is needed.
    wanted = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY", "total"]
    result = summary.loc[summary["chrom"].isin(wanted), ["chrom", "mean"]].copy()
    result["sample"] = sample_name
    return result


def process_per_base_file(file_path, suffix):
    """Compute per-chromosome per-base mean depth and its standard error.

    Each row of the input is an interval ``[start, end)`` with one depth
    value, so rows cover different numbers of bases. Every row is therefore
    weighted by its length; an unweighted mean/SEM over rows would describe
    the intervals, not the bases.

    Args:
        file_path: Path to a headerless, tab-separated file with the columns
            chrom, start, end, depth (compression inferred from the extension).
        suffix: Text removed from the file-name stem to get the sample name.

    Returns:
        Frame with ``"chrom"``, ``"mean"``, ``"sem"`` and ``"sample"`` for
        chr1-chr22, chrX and chrY. ``"sem"`` uses the sample standard
        deviation (ddof=1) and is NaN when a chromosome has a single base.
    """
    sample_name = os.path.basename(file_path).split(".")[0].replace(suffix, "")
    df = pd.read_csv(
        file_path, sep="\t", header=None, names=["chrom", "start", "end", "depth"]
    )
    df["chrom"] = df["chrom"].astype(str)
    chromosomes = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
    df = df[df["chrom"].isin(chromosomes)]

    # Length-weighted sufficient statistics per interval.
    span = df["end"] - df["start"]
    weighted = pd.DataFrame(
        {
            "chrom": df["chrom"],
            "n_bases": span,
            "depth_sum": df["depth"] * span,
            "depth_sq_sum": df["depth"] ** 2 * span,
        }
    )
    totals = weighted.groupby("chrom").sum()

    n_bases = totals["n_bases"]
    mean_depth = totals["depth_sum"] / n_bases
    # Sum of squared deviations; clipped because rounding can push it below 0.
    sum_sq_dev = (totals["depth_sq_sum"] - n_bases * mean_depth**2).clip(lower=0)
    variance = sum_sq_dev / (n_bases - 1)
    sem = np.sqrt(variance / n_bases)

    result = pd.DataFrame(
        {
            "chrom": totals.index.to_numpy(),
            "mean": mean_depth.to_numpy(),
            "sem": sem.to_numpy(),
        }
    )
    result["sample"] = sample_name
    return result


def analyze_mosdepth_data(
    np_metrics_df,
    suffix="_sup",
    mosdepth_dir="/scratch/prj/ppn_als_longread/qc/mosdepth",
):
    """Collect mosdepth outputs into per-chromosome and whole-genome tables.

    Args:
        np_metrics_df: Frame with ``"sample"``, ``"multiplexing"`` and
            ``"anonymised_sample"`` columns, used to annotate the depth rows.
        suffix: Run suffix used both in the glob patterns and when deriving
            sample names from file names.
        mosdepth_dir: Directory containing one ``*{suffix}`` sub-directory
            per sample.

    Returns:
        ``(wg_depth_df, total_depth_df)``: per-chromosome depth rows (with
        ``"per_base_mean"``/``"per_base_sem"`` and ``"Sample ID"``) and one
        whole-genome ``"mean_depth"`` row per sample. Both are independent
        frames, so callers can add columns without chained-assignment
        warnings.

    Raises:
        ValueError: If no summary or per-base files match the patterns.
    """
    path_to_mosdepth_summary = (
        f"{mosdepth_dir}/*{suffix}/*{suffix}.mosdepth.summary.txt"
    )
    path_to_per_base = f"{mosdepth_dir}/*{suffix}/*{suffix}.per-base.bed.gz"

    # Sorted so the result does not depend on file-system listing order.
    mosdepth_summary_files = sorted(glob.glob(path_to_mosdepth_summary))
    per_base_files = sorted(glob.glob(path_to_per_base))

    # Fail with the offending pattern rather than pandas' generic
    # "No objects to concatenate".
    if not mosdepth_summary_files:
        raise ValueError(
            f"No mosdepth summary files matched {path_to_mosdepth_summary}"
        )
    if not per_base_files:
        raise ValueError(f"No mosdepth per-base files matched {path_to_per_base}")

    depth_df = pd.concat(
        [process_mosdepth_file(file, suffix) for file in mosdepth_summary_files]
    )
    per_base_df = pd.concat(
        [process_per_base_file(file, suffix) for file in per_base_files]
    )

    depth_df = depth_df.merge(
        per_base_df.rename(columns={"mean": "per_base_mean", "sem": "per_base_sem"}),
        on=["chrom", "sample"],
        how="left",
    )

    sample_info = np_metrics_df[["sample", "multiplexing", "anonymised_sample"]]

    # Whole-genome table: the "total" row of each sample.
    total_depth_df = (
        depth_df.loc[depth_df["chrom"] == "total", ["sample", "mean"]]
        .drop_duplicates(subset="sample")
        .rename(columns={"mean": "mean_depth"})
        .merge(sample_info, on="sample")
        .sort_values(by=["multiplexing", "sample"])
    )

    # Per-chromosome table. Copy before assigning so the write does not land
    # on a filtered slice of the merged frame.
    depth_df = depth_df[depth_df["chrom"] != "total"].copy()

    chromosome_order = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
    depth_df["chrom"] = pd.Categorical(
        depth_df["chrom"], categories=chromosome_order, ordered=True
    )

    depth_df = depth_df.merge(sample_info, on="sample", how="left")

    # Numeric sort key from the first run of digits in the anonymised ID.
    depth_df["sample_num"] = (
        depth_df["anonymised_sample"].str.extract(r"(\d+)").astype(int)
    )
    depth_df = depth_df.sort_values(by="sample_num")
    depth_df = depth_df.rename(columns={"anonymised_sample": "Sample ID"})

    depth_df = depth_df.drop_duplicates(subset=["chrom", "Sample ID"]).sort_values(
        by=["Sample ID", "multiplexing"]
    )

    wg_depth_df = depth_df[depth_df["chrom"].isin(chromosome_order)].copy()

    return wg_depth_df, total_depth_df


def plot_mean_depth_per_chromosome_with_sem(wg_depth_df):
    """Plot mean depth per chromosome for every sample with a ±SEM band.

    Args:
        wg_depth_df: Frame with ``"Sample ID"``, ``"chrom"``, ``"mean"``,
            ``"per_base_sem"`` and ``"multiplexing"``. The frame is copied
            first, so the caller's data (often a filtered slice) is not
            written to.
    """
    plt.figure(figsize=(14, 6), dpi=300)

    # Work on a private copy: adding the sort key to a caller-owned slice
    # triggers chained-assignment warnings and mutates shared state.
    wg_depth_df = wg_depth_df.copy()
    wg_depth_df["sample_num"] = (
        wg_depth_df["Sample ID"].str.extract(r"(\d+)").astype(int)
    )
    wg_depth_df = wg_depth_df.sort_values(["sample_num", "chrom"])

    # One fixed colour per sample, shared by the line and its SEM band.
    unique_samples = wg_depth_df["Sample ID"].unique()
    color_palette = sns.color_palette("husl", n_colors=len(unique_samples))
    color_dict = dict(zip(unique_samples, color_palette))

    sns.lineplot(
        data=wg_depth_df,
        x="chrom",
        y="mean",
        hue="Sample ID",
        style="multiplexing",
        palette=color_dict,
        legend="full",
        hue_order=unique_samples,
    )

    for sample_id in unique_samples:
        sample_df = wg_depth_df[wg_depth_df["Sample ID"] == sample_id].sort_values(
            "chrom"
        )
        plt.fill_between(
            sample_df["chrom"],
            sample_df["mean"] - sample_df["per_base_sem"],
            sample_df["mean"] + sample_df["per_base_sem"],
            alpha=0.2,
            color=color_dict[sample_id],
        )

    plt.title("Mean Depth per Chromosome (with SEM)")
    plt.xlabel("Chromosome")
    plt.ylabel("Mean Depth")

    # Re-apply the current tick labels at positions shifted right by 0.01.
    tick_locations, tick_labels = plt.xticks()
    shifted_locations = [location + 0.01 for location in tick_locations]

    plt.xticks(shifted_locations, tick_labels, rotation=45, ha="right")
    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


# Load the mosdepth outputs (default suffix "_sup"). Both tables are reused by
# the depth statistics, flowcell analysis and combined-plot cells below.
wg_depth_df, total_depth_df = analyze_mosdepth_data(np_metrics_df)

plot_mean_depth_per_chromosome_with_sem(wg_depth_df)


### 2. Mean Whole Genome Depth


In [None]:
def plot_mean_whole_genome_depth(total_depth_df):
    """Bar chart of whole-genome mean depth, one bar per sample.

    Args:
        total_depth_df: Frame with ``"anonymised_sample"``, ``"mean_depth"``
            and ``"multiplexing"``. NOTE: a ``"sample_num"`` column is written
            onto this frame in place.
    """
    # Numeric sort key from the first run of digits in the anonymised ID.
    sample_numbers = total_depth_df["anonymised_sample"].str.extract(r"(\d+)")
    total_depth_df["sample_num"] = sample_numbers.astype(int)

    ordered_df = total_depth_df.sort_values("sample_num")

    plt.figure(figsize=(14, 6), dpi=300)
    sns.barplot(
        x="anonymised_sample",
        y="mean_depth",
        hue="multiplexing",
        order=ordered_df["anonymised_sample"],
        dodge=False,
        data=ordered_df,
    )

    plt.title("Mean Whole Genome Depth per Sample")
    plt.xlabel("Sample ID")
    plt.ylabel("Depth")

    # Re-apply the current tick labels at positions shifted left by 0.2.
    tick_locations, tick_labels = plt.xticks()
    shifted_locations = [location - 0.2 for location in tick_locations]
    plt.xticks(shifted_locations, tick_labels, rotation=45, ha="right")

    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


# Side effect: adds a "sample_num" column to the global `total_depth_df`.
plot_mean_whole_genome_depth(total_depth_df)


In [None]:
def calculate_depth_stats(df, column_name):
    """Summarise one numeric column of ``df``.

    Args:
        df: Source frame.
        column_name: Column to summarise.

    Returns:
        Dict with the keys ``"max"``, ``"min"``, ``"mean"``, ``"std"`` and
        ``"median"`` (in that order), each computed with the pandas Series
        method of the same name.
    """
    values = df[column_name]
    summary_names = ("max", "min", "mean", "std", "median")
    return {name: getattr(values, name)() for name in summary_names}


def print_depth_stats(stats, sample_type):
    """Print a depth summary dict with every value shown to two decimals.

    Args:
        stats: Mapping of statistic name to numeric value. The parameter name
            shadows the module-level ``scipy.stats`` import inside this
            function.
        sample_type: Label used in the heading line.
    """
    heading = f"\n{sample_type} Samples Statistics:"
    print(heading)
    print("=" * 40)
    print("\nDepth:")
    for name, raw_value in stats.items():
        rendered = f"{raw_value:.2f}"
        label = name.capitalize()
        print(f"  {label:6s}: {rendered}")


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return the singleplex-over-multiplex difference as a percentage.

    Same formula as the earlier definitions of this name; the multiplex value
    is the baseline and a zero baseline is not guarded.
    """
    gain = singleplex_val - multiplex_val
    ratio = gain / multiplex_val
    return ratio * 100


# Calculate stats for per-chromosome depth
# (each row of wg_depth_df is one chromosome of one sample, so these
# statistics are taken over chromosome-level means).
singleplexed_depth = wg_depth_df[wg_depth_df["multiplexing"] == "singleplex"]
multiplexed_depth = wg_depth_df[wg_depth_df["multiplexing"] == "multiplex"]

singleplex_depth_stats = calculate_depth_stats(singleplexed_depth, "mean")
multiplex_depth_stats = calculate_depth_stats(multiplexed_depth, "mean")

print("\nPer-Chromosome Depth Statistics:")
print_depth_stats(singleplex_depth_stats, "Singleplexed")
print_depth_stats(multiplex_depth_stats, "Multiplexed")

# Relative difference of singleplex over multiplex (multiplex is the baseline).
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat in ["mean", "median"]:
    increase = calculate_percentage_increase(
        singleplex_depth_stats[stat], multiplex_depth_stats[stat]
    )
    print(f"{stat.capitalize():6s} Depth: {increase:6.2f}%")

# Calculate stats for whole genome depth
# (one "mean_depth" value per sample).
singleplexed_wg = total_depth_df[total_depth_df["multiplexing"] == "singleplex"]
multiplexed_wg = total_depth_df[total_depth_df["multiplexing"] == "multiplex"]

singleplex_wg_stats = calculate_depth_stats(singleplexed_wg, "mean_depth")
multiplex_wg_stats = calculate_depth_stats(multiplexed_wg, "mean_depth")

print("\nWhole Genome Depth Statistics:")
print_depth_stats(singleplex_wg_stats, "Singleplexed")
print_depth_stats(multiplex_wg_stats, "Multiplexed")

print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat in ["mean", "median"]:
    increase = calculate_percentage_increase(
        singleplex_wg_stats[stat], multiplex_wg_stats[stat]
    )
    print(f"{stat.capitalize():6s} Depth: {increase:6.2f}%")


### 3. Flowcell Quality


In [None]:
def read_seq_stats(file_path):
    """Load the sequencing-stats CSV and give each flowcell a display name.

    A row counts as "multiplex" when its ``"flowcell_id"`` contains ``"__"``,
    otherwise as "singleplex". Rows are sorted by multiplexing then
    ``"flowcell_id"`` and numbered from 1 within each group.

    Args:
        file_path: Path to a CSV with at least a ``"flowcell_id"`` column.

    Returns:
        The sorted frame with ``"multiplexing"`` and ``"new_flowcell_name"``
        added and ``"flowcell_id"`` cast to ``str``.
    """

    def classify(flowcell_id):
        return "multiplex" if "__" in str(flowcell_id) else "singleplex"

    seq_stats = pd.read_csv(file_path)
    seq_stats["multiplexing"] = seq_stats["flowcell_id"].apply(classify)
    seq_stats["flowcell_id"] = seq_stats["flowcell_id"].astype(str)
    seq_stats = seq_stats.sort_values(["multiplexing", "flowcell_id"])

    # Running 1-based counter per multiplexing group, in sorted row order.
    seen = {"multiplex": 0, "singleplex": 0}
    display_names = []
    for kind in seq_stats["multiplexing"]:
        seen[kind] += 1
        if kind == "multiplex":
            display_names.append(f"Multiplex Flowcell {seen[kind]}")
        else:
            display_names.append(f"Singleplex Flowcell {seen[kind]}")

    seq_stats["new_flowcell_name"] = display_names
    return seq_stats


def plot_pores_at_start(df):
    """Line plot of the starting pore count for every flowcell.

    Args:
        df: Frame with ``"new_flowcell_name"``, ``"number_pores_start"`` and
            ``"multiplexing"``. NOTE: ``"multiplexing"`` is converted to a
            categorical column on this frame in place.
    """
    df["multiplexing"] = df["multiplexing"].astype("category")

    plt.figure(figsize=(14, 6), dpi=300)
    sns.lineplot(
        x="new_flowcell_name",
        y="number_pores_start",
        hue="multiplexing",
        style="multiplexing",
        marker="o",
        data=df,
    )

    plt.title("Number of Pores Available at the Start of Sequencing")
    plt.xlabel("Flowcell ID")
    plt.ylabel("Number of Pores")
    plt.xticks(rotation=45, ha="right")
    plt.legend(title="Multiplexing", loc="upper left", bbox_to_anchor=(1.05, 1))

    # Anchor the y axis at zero so pore counts are compared on a full scale.
    plt.ylim(bottom=0)
    plt.tight_layout()
    plt.show()


# Location of the per-flowcell sequencing statistics CSV.
SEQ_STATS_PATH = "/scratch/prj/ppn_als_longread/seq_stats.csv"
seq_stats_df = read_seq_stats(SEQ_STATS_PATH)

# Side effect: converts seq_stats_df["multiplexing"] to a categorical column.
plot_pores_at_start(seq_stats_df)


In [None]:
def calculate_pore_stats(df, multiplexing_type):
    """Summarise ``"number_pores_start"`` for one multiplexing group.

    Args:
        df: Frame with ``"multiplexing"`` and ``"number_pores_start"``.
        multiplexing_type: Value of ``"multiplexing"`` to select.

    Returns:
        Dict with the keys ``"max"``, ``"min"``, ``"mean"``, ``"std"`` and
        ``"median"`` (in that order), each computed with the pandas Series
        method of the same name.
    """
    in_group = df["multiplexing"] == multiplexing_type
    pores = df[in_group]["number_pores_start"]
    summary_names = ("max", "min", "mean", "std", "median")
    return {name: getattr(pores, name)() for name in summary_names}


def print_pore_stats(stats, sample_type):
    """Print a pore-count summary dict as an aligned text block.

    Values that are ``float`` instances are shown to two decimals; anything
    else is printed as-is.

    Args:
        stats: Mapping of statistic name to value. The parameter name shadows
            the module-level ``scipy.stats`` import inside this function.
        sample_type: Label used in the heading line.
    """
    heading = f"\n{sample_type} Flowcells Statistics:"
    print(heading)
    print("=" * 40)
    print("\nNumber of Pores at Start:")
    for name, raw_value in stats.items():
        if isinstance(raw_value, float):
            rendered = f"{raw_value:.2f}"
        else:
            rendered = f"{raw_value}"
        label = name.capitalize()
        print(f"  {label:6s}: {rendered}")


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return the percent by which ``singleplex_val`` exceeds ``multiplex_val``.

    Same formula as the earlier definitions of this name; the multiplex value
    is the baseline and a zero baseline is not guarded.
    """
    excess = singleplex_val - multiplex_val
    proportion = excess / multiplex_val
    return proportion * 100


# Starting pore counts summarised separately for each multiplexing group.
singleplex_pore_stats = calculate_pore_stats(seq_stats_df, "singleplex")
multiplex_pore_stats = calculate_pore_stats(seq_stats_df, "multiplex")

print_pore_stats(singleplex_pore_stats, "Singleplexed")
print_pore_stats(multiplex_pore_stats, "Multiplexed")

# Relative difference of singleplex over multiplex (multiplex is the baseline).
print("\nPercentage Increase (Singleplexed vs Multiplexed):")
print("=" * 40)
for stat in ["mean", "median"]:
    increase = calculate_percentage_increase(
        singleplex_pore_stats[stat], multiplex_pore_stats[stat]
    )
    print(f"{stat.capitalize():6s} Number of Pores: {increase:6.2f}%")


### 4. Relation between Flowcell Quality and Mean Whole Genome Depth


In [None]:
def parse_seq_stats_data(seq_stats_df, total_depth_df):
    """Attach the summed whole-genome depth of each flowcell to its stats row.

    The ``"flowcell_id"`` column of ``seq_stats_df`` holds the sample name(s)
    run on a flowcell (joined by ``"__"`` for multiplex rows), while
    ``"flow_cell_id"`` holds the flowcell identifier used as the join key.

    Args:
        seq_stats_df: Frame with ``"flow_cell_id"``, ``"flowcell_id"`` and
            ``"multiplexing"``.
        total_depth_df: Frame with ``"sample"`` and ``"mean_depth"``. NOTE:
            modified in place: ``"flowcell_id"`` is added (as ``str``) and
            ``"multiplexing"`` is overwritten from the sample lookup.

    Returns:
        ``seq_stats_df`` inner-joined with a per-flowcell
        ``"total_mean_depth"`` (sum of the samples' mean depths).
    """
    flowcell_by_sample = {}
    multiplexing_by_sample = {}

    for _, record in seq_stats_df.iterrows():
        flowcell = record["flow_cell_id"]
        kind = record["multiplexing"]

        # Multiplex rows list several samples separated by "__".
        if kind == "multiplex":
            sample_names = record["flowcell_id"].split("__")
        else:
            sample_names = [record["flowcell_id"]]

        flowcell_by_sample.update(dict.fromkeys(sample_names, flowcell))
        multiplexing_by_sample.update(dict.fromkeys(sample_names, kind))

    sample_ids = total_depth_df["sample"]
    total_depth_df["flowcell_id"] = sample_ids.map(flowcell_by_sample).astype(str)
    total_depth_df["multiplexing"] = sample_ids.map(multiplexing_by_sample)

    # Depth of a flowcell = sum of the mean depths of the samples run on it.
    per_flowcell = (
        total_depth_df.groupby("flowcell_id")
        .agg(
            total_mean_depth=("mean_depth", "sum"),
            multiplexing=("multiplexing", "first"),
        )
        .reset_index()
        .rename(columns={"flowcell_id": "flow_cell_id"})
    )

    combined = pd.merge(
        seq_stats_df,
        per_flowcell,
        how="inner",
        on="flow_cell_id",
        suffixes=(None, "_y"),
    )

    # The right-hand multiplexing column duplicates the left one.
    return combined.drop(columns=["multiplexing_y"])


def plot_pores_vs_depth(grouped_seq_stats_df):
    """Scatter starting pore count against summed depth with a linear fit.

    Draws the points coloured by multiplexing, the least-squares line and a
    95% confidence band for the fitted mean.

    Args:
        grouped_seq_stats_df: Frame with ``"number_pores_start"``,
            ``"total_mean_depth"`` and ``"multiplexing"``.

    Returns:
        ``(slope, intercept, r_value, p_value)`` from ``stats.linregress``.
    """
    fit_color = sns.color_palette("colorblind")[0]

    plt.figure(figsize=(12, 8), dpi=300)
    sns.scatterplot(
        x="number_pores_start",
        y="total_mean_depth",
        hue="multiplexing",
        s=100,
        data=grouped_seq_stats_df,
    )

    pores = grouped_seq_stats_df["number_pores_start"]
    depth = grouped_seq_stats_df["total_mean_depth"]

    slope, intercept, r_value, p_value, _std_err = stats.linregress(pores, depth)

    # Fitted line evaluated on an even grid across the observed pore range.
    grid = np.linspace(pores.min(), pores.max(), 100)
    fitted = slope * grid + intercept

    plt.plot(
        grid,
        fitted,
        color=fit_color,
        linestyle="-",
        linewidth=2,
        label="line of best fit",
    )

    # Confidence band: t * residual standard error * sqrt(leverage term),
    # with n - 2 degrees of freedom.
    n_points = len(pores)
    dof = n_points - 2
    residuals = depth - (slope * pores + intercept)
    residual_se = np.sqrt(np.sum(residuals**2) / dof)
    t_crit = stats.t.ppf(0.975, dof)
    pores_mean = np.mean(pores)
    leverage = 1 / n_points + (grid - pores_mean) ** 2 / np.sum(
        (pores - pores_mean) ** 2
    )
    half_width = t_crit * residual_se * np.sqrt(leverage)

    plt.fill_between(
        grid,
        fitted - half_width,
        fitted + half_width,
        color=fit_color,
        alpha=0.2,
        label="95% Confidence Interval",
    )

    plt.title("Number of Pores at Start vs Total Mean Whole Genome Depth")
    plt.xlabel("Number of Pores at Start")
    plt.ylabel("Mean Whole Genome Depth")
    plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.show()

    return slope, intercept, r_value, p_value


# Re-read the stats CSV so this cell starts from a fresh frame (the earlier
# plotting cell converted the "multiplexing" column to categorical in place).
SEQ_STATS_PATH = "/scratch/prj/ppn_als_longread/seq_stats.csv"
seq_stats_df = read_seq_stats(SEQ_STATS_PATH)

# Side effect: parse_seq_stats_data adds "flowcell_id" to the global
# `total_depth_df` and overwrites its "multiplexing" column.
grouped_seq_stats_df = parse_seq_stats_data(seq_stats_df, total_depth_df)
slope, intercept, r_value, p_value = plot_pores_vs_depth(grouped_seq_stats_df)

print(f"Slope: {slope:.4f}")
print(f"Intercept: {intercept:.4f}")
print(f"R-value: {r_value:.4f}")
print(f"Linear regression p-value: {p_value:.4e}")


### 5. Barcoding Quality


In [None]:
def get_sample_barcode_mapping():
    """Return the fixed sample-name to barcode lookup for the multiplexed runs."""
    sample_barcode_pairs = (
        ("A046_12", "barcode01"),
        ("A079_07", "barcode02"),
        ("A081_91", "barcode03"),
        ("A048_09", "barcode04"),
        ("A097_92", "barcode05"),
        ("A085_00", "barcode06"),
    )
    return dict(sample_barcode_pairs)


def parse_nanostats_barcoded(file_path):
    """Read the first data row of a tab-separated per-barcode stats file.

    Only the header line and the line directly after it are read. Columns
    named ``"unclassified"`` or starting with ``"barcode"`` are kept; the
    first column (the metric name) is skipped.

    Args:
        file_path: Path to the tab-separated stats file.

    Returns:
        Dict mapping column name to the integer value on the first data row.
    """
    with open(file_path, "r") as handle:
        column_names = handle.readline().strip().split("\t")
        first_row = handle.readline().strip().split("\t")

    # Index 0 is the metric-name column, so pair up everything after it.
    return {
        name: int(raw_value)
        for name, raw_value in zip(column_names[1:], first_row[1:])
        if name.startswith("barcode") or name == "unclassified"
    }


def plot_multiplexed_flowcell_reads(seq_summaries_dir, ax=None):
    """Plot per-barcode read counts for the multiplexed flowcells and print QC.

    Scans ``seq_summaries_dir`` for sub-directories whose name contains
    ``"__"``, reads ``NanoStats_barcoded.txt`` from each, and draws one bar
    per sample plus one for unclassified reads. Also prints the average
    unclassified percentage and the coefficient of variation (CV) of the
    read counts.

    Args:
        seq_summaries_dir: Directory holding one sub-directory per flowcell.
        ax: Optional matplotlib axes to draw on. When ``None`` a new figure
            is created and shown.

    Returns:
        ``(ax, cv_per_flowcell, overall_cv)`` where ``cv_per_flowcell`` maps
        the display name of each flowcell to its CV in percent.
    """
    sample_barcode_mapping = get_sample_barcode_mapping()
    # Inverted lookup: barcode -> sample name.
    barcode_sample_mapping = {v: k for k, v in sample_barcode_mapping.items()}

    # Samples expected on each multiplexed flowcell, keyed by directory name.
    flowcell_samples = {
        "A046_12__A079_07": ["A046_12", "A079_07"],
        "A081_91__A048_09": ["A081_91", "A048_09"],
        "A097_92__A085_00": ["A097_92", "A085_00"],
    }

    flowcell_rename = {
        name: f"Multiplex Flowcell {i+1}"
        for i, name in enumerate(flowcell_samples.keys())
    }

    # Samples are numbered by sorted name, except that sorted positions 1 and
    # 2 are swapped (index 1 -> "Sample 3", index 2 -> "Sample 2").
    # NOTE(review): presumably this matches the anonymised IDs used elsewhere
    # in the notebook — confirm.
    sample_rename = {
        sample: f"Sample {i+1 if i != 1 and i != 2 else 3 if i == 1 else 2}"
        for i, sample in enumerate(sorted(set(sum(flowcell_samples.values(), []))))
    }
    sample_rename["Unclassified"] = "Unclassified"

    data = []
    flowcell_total_reads = {}
    flowcell_unclassified_reads = {}

    for subdir in os.listdir(seq_summaries_dir):
        if "__" in subdir:
            flowcell_name = subdir
            nanostats_path = os.path.join(
                seq_summaries_dir, subdir, "NanoStats_barcoded.txt"
            )
            if os.path.exists(nanostats_path):
                metrics = parse_nanostats_barcoded(nanostats_path)
                flowcell_total_reads[flowcell_name] = 0
                flowcell_unclassified_reads[flowcell_name] = 0

                for barcode, read_count in metrics.items():
                    # The total counts every parsed column, including
                    # barcodes that do not belong to this flowcell.
                    flowcell_total_reads[flowcell_name] += read_count
                    if barcode in barcode_sample_mapping:
                        sample = barcode_sample_mapping[barcode]
                        # Only plot barcodes expected on this flowcell.
                        # NOTE(review): a "__" directory missing from
                        # flowcell_samples raises KeyError here.
                        if sample in flowcell_samples[flowcell_name]:
                            data.append(
                                {
                                    "Flowcell": flowcell_rename[flowcell_name],
                                    "Sample": sample_rename[sample],
                                    "Read Count": read_count,
                                }
                            )
                    elif barcode == "unclassified":
                        flowcell_unclassified_reads[flowcell_name] = read_count
                        data.append(
                            {
                                "Flowcell": flowcell_rename[flowcell_name],
                                "Sample": "Unclassified",
                                "Read Count": read_count,
                            }
                        )

    df = pd.DataFrame(data)
    # Ordered categoricals so the sort below follows display order.
    df["Flowcell"] = pd.Categorical(
        df["Flowcell"], categories=list(flowcell_rename.values()), ordered=True
    )
    df["Sample"] = pd.Categorical(
        df["Sample"],
        categories=[f"Sample {i+1}" for i in range(len(sample_rename) - 1)]
        + ["Unclassified"],
        ordered=True,
    )
    df = df.sort_values(["Flowcell", "Sample"])

    # CV of the classified read counts within each flowcell; np.std uses its
    # default ddof=0 (population standard deviation).
    cv_per_flowcell = {}
    for flowcell in df["Flowcell"].unique():
        flowcell_data = df[
            (df["Flowcell"] == flowcell) & (df["Sample"] != "Unclassified")
        ]
        read_counts = flowcell_data["Read Count"].values
        cv = np.std(read_counts) / np.mean(read_counts) * 100
        cv_per_flowcell[flowcell] = cv

    all_read_counts = df[df["Sample"] != "Unclassified"]["Read Count"].values
    overall_cv = np.std(all_read_counts) / np.mean(all_read_counts) * 100

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6), dpi=300)
        standalone = True
    else:
        standalone = False

    color = sns.color_palette("colorblind")[0]

    flowcells = df["Flowcell"].unique()
    width = 0.25
    # NOTE(review): assumes exactly three bars per flowcell (two samples plus
    # unclassified); any other row count makes x and the heights differ in
    # length.
    x = range(len(flowcells) * 3)

    bars = ax.bar(x, df["Read Count"], width, color=color)

    ax.set_ylabel("Number of reads")
    ax.set_title("Number of barcoded reads")

    # Built-in x tick labels are hidden; labels are drawn as text below.
    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

    ax.xaxis.grid(False)

    # Sample label under every bar, at a fixed y position in data units.
    for i, bar in enumerate(bars):
        sample = df.iloc[i]["Sample"]
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            -2e6,
            sample,
            ha="center",
            va="bottom",
            rotation=45,
        )

    # Flowcell label centred under each group of three bars.
    for i, flowcell in enumerate(flowcells):
        ax.text(
            i * 3 + 1,
            -2.4e6,
            flowcell,
            ha="center",
        )

    # Thin vertical separators between the flowcell groups.
    for i in range(1, len(flowcells)):
        ax.axvline(x=i * 3 - 0.5, color="gray", linestyle="-", linewidth=0.5)

    if standalone:
        plt.tight_layout()
        plt.show()

    # Share of unclassified reads per flowcell, relative to all parsed reads.
    unclassified_percentages = {
        flowcell_rename[flowcell]: (unclassified / flowcell_total_reads[flowcell]) * 100
        for flowcell, unclassified in flowcell_unclassified_reads.items()
    }
    avg_unclassified_percentage = mean(unclassified_percentages.values())

    print(
        f"Average percentage of unclassified reads across all flowcells: {avg_unclassified_percentage:.2f}%"
    )

    print(f"Coefficient of Variation (CV) for each flowcell:")
    for flowcell, cv in cv_per_flowcell.items():
        print(f"{flowcell}: {cv:.2f}%")
    print(f"Overall CV across all flowcells: {overall_cv:.2f}%")

    return ax, cv_per_flowcell, overall_cv


# Standalone barcoding plot; NP_SEQ_SUMMARIES_DIR is defined in an earlier cell.
ax, cv_per_flowcell, overall_cv = plot_multiplexed_flowcell_reads(NP_SEQ_SUMMARIES_DIR)


### 6. Combined Plots


In [None]:
def create_combined_sequencing_plots(
    wg_depth_df, total_depth_df, seq_stats_df, grouped_seq_stats_df, seq_summaries_dir
):
    """Draw the combined sequencing QC figure (panels A-E), show it and return it.

    Panels:
        A. Mean depth per chromosome for every sample, with a +/- SEM band.
        B. Mean whole-genome depth per sample.
        C. Number of pores available at the start of sequencing per flowcell.
        D. Pores at start vs total mean depth, with a least-squares line and
           its 95% confidence band.
        E. Number of barcoded reads, drawn by plot_multiplexed_flowcell_reads.

    Args:
        wg_depth_df: Per-chromosome depth table. The columns "Sample ID",
            "chrom", "mean", "per_base_sem" and "multiplexing" are read.
        total_depth_df: Per-sample depth table. The columns
            "anonymised_sample", "mean_depth" and "multiplexing" are read.
        seq_stats_df: Per-flowcell table. The columns "new_flowcell_name",
            "number_pores_start" and "multiplexing" are read.
        grouped_seq_stats_df: Table with "number_pores_start",
            "total_mean_depth" and "multiplexing" columns.
        seq_summaries_dir: Forwarded unchanged to plot_multiplexed_flowcell_reads.

    Returns:
        The matplotlib Figure that holds all panels.
    """
    # Work on copies: helper columns are added below and the caller's frames
    # should not be modified as a side effect of plotting.
    wg_depth_df = wg_depth_df.copy()
    total_depth_df = total_depth_df.copy()

    palette = sns.color_palette("colorblind")
    main_color = palette[0]

    fig = plt.figure(figsize=(14, 16), dpi=300)
    gs = fig.add_gridspec(4, 2, height_ratios=[1, 1, 1, 1])

    def adjust_tick_labels(ax):
        """Rotate the x tick labels and nudge the tick positions to the right."""
        for tick in ax.get_xticklabels():
            tick.set_rotation(45)
            tick.set_ha("right")
        ax.set_xticks([loc + 0.2 for loc in ax.get_xticks()])

    # A: Mean Depth per Chromosome
    ax_a = fig.add_subplot(gs[0, :])
    wg_depth_df["Sample ID"] = wg_depth_df["Sample ID"].astype(str)
    # Sort by the first run of digits in the sample ID so samples are ordered
    # numerically rather than lexicographically.
    wg_depth_df["sample_num"] = (
        wg_depth_df["Sample ID"].str.extract(r"(\d+)").astype(int)
    )
    wg_depth_df = wg_depth_df.sort_values(["sample_num", "chrom"])

    unique_samples = wg_depth_df["Sample ID"].unique()
    color_palette = sns.color_palette("husl", n_colors=len(unique_samples))
    color_dict = dict(zip(unique_samples, color_palette))

    sns.lineplot(
        data=wg_depth_df,
        x="chrom",
        y="mean",
        hue="Sample ID",
        style="multiplexing",
        palette=color_dict,
        legend="full",
        hue_order=unique_samples,
        ax=ax_a,
    )

    # Shade mean +/- SEM for each sample in the same colour as its line.
    for sample_id in unique_samples:
        sample_df = wg_depth_df[wg_depth_df["Sample ID"] == sample_id].sort_values(
            "chrom"
        )
        color = color_dict[sample_id]
        ax_a.fill_between(
            sample_df["chrom"],
            sample_df["mean"] - sample_df["per_base_sem"],
            sample_df["mean"] + sample_df["per_base_sem"],
            alpha=0.2,
            color=color,
        )

    ax_a.set_title("Mean Depth per Chromosome (with SEM)")
    ax_a.set_xlabel("Chromosome")
    ax_a.set_ylabel("Mean Depth")

    locs, labels = ax_a.get_xticks(), ax_a.get_xticklabels()
    new_locs = [loc + 0.01 for loc in locs]
    ax_a.set_xticks(new_locs)
    ax_a.set_xticklabels(labels, rotation=45, ha="right")

    ax_a.grid(axis="y", linestyle="--", alpha=0.7)
    ax_a.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax_a.text(
        -0.05,
        1.05,
        "A",
        transform=ax_a.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    # B: Mean Whole Genome Depth per Sample
    ax_b = fig.add_subplot(gs[1, 0])
    total_depth_df["sample_num"] = (
        total_depth_df["anonymised_sample"].str.extract(r"(\d+)").astype(int)
    )
    total_depth_df = total_depth_df.sort_values("sample_num")
    sns.barplot(
        data=total_depth_df,
        x="anonymised_sample",
        y="mean_depth",
        hue="multiplexing",
        dodge=False,
        order=total_depth_df["anonymised_sample"],
        ax=ax_b,
    )
    ax_b.set_title("Mean Whole Genome Depth per Sample")
    ax_b.set_xlabel("Sample ID")
    ax_b.set_ylabel("Depth")
    adjust_tick_labels(ax_b)
    ax_b.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax_b.text(
        -0.05,
        1.05,
        "B",
        transform=ax_b.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    # C: Number of Pores Available at the Start of Sequencing
    ax_c = fig.add_subplot(gs[1, 1])
    sns.lineplot(
        data=seq_stats_df,
        x="new_flowcell_name",
        y="number_pores_start",
        hue="multiplexing",
        marker="o",
        style="multiplexing",
        ax=ax_c,
    )
    ax_c.set_title("Number of Pores Available at the Start of Sequencing")
    ax_c.set_xlabel("Flowcell ID")
    ax_c.set_ylabel("Number of Pores")
    adjust_tick_labels(ax_c)
    ax_c.legend(title="Multiplexing", bbox_to_anchor=(1.05, 1), loc="upper left")
    ax_c.set_ylim(bottom=0)
    ax_c.text(
        -0.05,
        1.05,
        "C",
        transform=ax_c.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    # D: Number of Pores at Start vs Total Mean Whole Genome Depth
    ax_d = fig.add_subplot(gs[2, 0])
    sns.scatterplot(
        data=grouped_seq_stats_df,
        x="number_pores_start",
        y="total_mean_depth",
        hue="multiplexing",
        s=100,
        ax=ax_d,
    )

    x = grouped_seq_stats_df["number_pores_start"]
    y = grouped_seq_stats_df["total_mean_depth"]

    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

    x_line = np.linspace(x.min(), x.max(), 100)
    y_line = slope * x_line + intercept

    ax_d.plot(
        x_line,
        y_line,
        color=main_color,
        linestyle="-",
        linewidth=2,
        alpha=0.5,
        label="line of best fit",
    )

    # 95% confidence band of the fitted line: t quantile with n - 2 degrees of
    # freedom times the standard error of the fitted mean at each x.
    n = len(x)
    y_pred = slope * x + intercept
    s_err = np.sqrt(np.sum((y - y_pred) ** 2) / (n - 2))
    t = stats.t.ppf(0.975, n - 2)
    ci = (
        t
        * s_err
        * np.sqrt(1 / n + (x_line - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2))
    )

    ax_d.fill_between(
        x_line,
        y_line - ci,
        y_line + ci,
        color=main_color,
        alpha=0.2,
        label="95% confidence interval",
    )

    ax_d.set_title("Number of Pores at Start vs Total Mean Whole Genome Depth")
    ax_d.set_xlabel("Number of Pores at Start")
    ax_d.set_ylabel("Mean Whole Genome Depth")
    ax_d.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")
    ax_d.text(
        -0.05,
        1.05,
        "D",
        transform=ax_d.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    stats_text = f"slope: {slope:.4f}\nintercept: {intercept:.4f}\nr-value: {r_value:.4f}\np-value: {p_value:.4e}"
    ax_d.text(
        0.05,
        0.95,
        stats_text,
        transform=ax_d.transAxes,
        fontsize=8,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.7),
    )

    # E: Number of barcoded reads
    ax_e = fig.add_subplot(gs[2, 1])
    plot_multiplexed_flowcell_reads(seq_summaries_dir, ax=ax_e)
    ax_e.text(
        -0.05,
        1.05,
        "E",
        transform=ax_e.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    plt.tight_layout()
    plt.show()

    # Hand the figure back so the caller's `combined_fig = ...` binding is
    # usable (e.g. for saving) instead of being None.
    return fig


# Build the multi-panel sequencing QC figure from the previously prepared
# depth and flowcell tables.
combined_fig = create_combined_sequencing_plots(
    wg_depth_df,
    total_depth_df,
    seq_stats_df,
    grouped_seq_stats_df,
    NP_SEQ_SUMMARIES_DIR,
)


# SNV Benchmark

## Comparative analysis

### 1. Sensitivity, Precision, and F1


In [None]:
def read_summary(file_path):
    """Parse the "None"-threshold row of a vcfeval-style summary file.

    The first line whose stripped text starts with "None" is split on
    whitespace; fields 1-7 are converted to the metrics returned below.

    Args:
        file_path: Path of the summary text file.

    Returns:
        A dict with the keys "True-pos-baseline", "True-pos-call",
        "False-pos", "False-neg" (ints) and "Precision", "Sensitivity",
        "F-measure" (floats). An empty dict is returned, after printing a
        message, when the file is missing, has no "None" row, or cannot be
        parsed.
    """
    # (output key, converter) for whitespace-separated fields 1..7 of the row.
    field_layout = (
        ("True-pos-baseline", int),
        ("True-pos-call", int),
        ("False-pos", int),
        ("False-neg", int),
        ("Precision", float),
        ("Sensitivity", float),
        ("F-measure", float),
    )

    try:
        with open(file_path, "r") as handle:
            content = handle.readlines()

        for candidate in content:
            if not candidate.strip().startswith("None"):
                continue
            tokens = candidate.split()
            return {
                key: convert(tokens[position])
                for position, (key, convert) in enumerate(field_layout, start=1)
            }

        print(f"No 'None' threshold line found in {file_path}")
        return {}
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return {}
    except Exception as e:
        print(f"Error reading file {file_path}: {str(e)}")
        return {}


def calculate_rtg_statistics(df):
    """Summarise Precision, Sensitivity and F-measure per complexity group.

    Args:
        df: DataFrame with a "complexity" column and the three metric columns.

    Returns:
        DataFrame indexed by complexity whose columns are (metric, statistic)
        pairs for the statistics mean, std, median, min and max.
    """
    summary_functions = ["mean", "std", "median", "min", "max"]
    aggregation_plan = {
        column: list(summary_functions)
        for column in ("Precision", "Sensitivity", "F-measure")
    }
    per_complexity = df.groupby("complexity")
    return per_complexity.agg(aggregation_plan)


def collect_snv_metrics(sample_ids, technologies, complexities):
    """Collect per-sample SNV summary metrics for each technology and complexity.

    For every row of ``sample_ids`` and every (technology, complexity) pair,
    the summary file under ``output/snv/rtg_vcfeval/`` is read with
    ``read_summary``. The "ont" technology uses the row's "ont_id"; every
    other technology uses "lp_id".

    Args:
        sample_ids: DataFrame with "ont_id" and "lp_id" columns.
        technologies: Iterable of technology names.
        complexities: Iterable of complexity names (used as path components).

    Returns:
        Nested dict ``{technology: {complexity: [entry, ...]}}`` where each
        entry holds "sample_id", "complexity" and the parsed summary metrics.
        The ont/illumina x hc/lc lists are always present; any additional
        technology or complexity passed in gets its own (possibly empty) list.
    """
    # Keep the historical default layout, then make sure every requested
    # technology/complexity has a list so appending below cannot KeyError.
    snv_metrics = {"ont": {"hc": [], "lc": []}, "illumina": {"hc": [], "lc": []}}
    for tech in technologies:
        tech_bucket = snv_metrics.setdefault(tech, {})
        for complexity in complexities:
            tech_bucket.setdefault(complexity, [])

    for _, row in sample_ids.iterrows():
        ont_id = row["ont_id"]
        lp_id = row["lp_id"]

        for tech in technologies:
            for complexity in complexities:
                sample_id = ont_id if tech == "ont" else lp_id
                summary_file = (
                    f"output/snv/rtg_vcfeval/{complexity}/{sample_id}.snv/summary.txt"
                )
                summary = read_summary(summary_file)

                if summary:
                    metrics_entry = {
                        "sample_id": sample_id,
                        "complexity": complexity,
                        **summary,
                    }
                    snv_metrics[tech][complexity].append(metrics_entry)
                else:
                    print(
                        f"Skipping empty summary for {sample_id}, {tech}, {complexity}"
                    )

    return snv_metrics


def process_snv_metrics(snv_metrics):
    """Turn the collected SNV metric lists into per-technology tables.

    Args:
        snv_metrics: Nested dict with "ont" and "illumina" entries, each
            holding "hc" and "lc" lists of metric dicts.

    Returns:
        Tuple ``(ont_df, illumina_df, ont_stats, illumina_stats)``: the two
        per-sample DataFrames (hc rows followed by lc rows) and their
        per-complexity summaries from ``calculate_rtg_statistics``.
    """
    ont_group = snv_metrics["ont"]
    ont_frame = pd.DataFrame(ont_group["hc"] + ont_group["lc"])

    illumina_group = snv_metrics["illumina"]
    illumina_frame = pd.DataFrame(illumina_group["hc"] + illumina_group["lc"])

    ont_summary = calculate_rtg_statistics(ont_frame)
    illumina_summary = calculate_rtg_statistics(illumina_frame)

    return ont_frame, illumina_frame, ont_summary, illumina_summary


def perform_ttest(ont_data, illumina_data, metric):
    """Run an independent two-sample t-test on one metric.

    Args:
        ont_data: Mapping/DataFrame holding the ONT values under ``metric``.
        illumina_data: Mapping/DataFrame holding the Illumina values under
            ``metric``.
        metric: Key of the column to compare.

    Returns:
        Tuple ``(t_statistic, p_value)`` from ``scipy.stats.ttest_ind``.
    """
    ont_sample = ont_data[metric]
    illumina_sample = illumina_data[metric]
    t_stat, p_value = stats.ttest_ind(ont_sample, illumina_sample)
    return t_stat, p_value


def prepare_stats_dataframe(df, metrics):
    """Add empty test-result columns for every metric, in place.

    For each metric the columns (metric, "t_statistic"), (metric, "p_value")
    and (metric, "adjusted_p_value") are created and filled with None.

    Args:
        df: Summary table to extend; it is modified in place.
        metrics: Metric names to add placeholder columns for.

    Returns:
        The same ``df`` object.
    """
    placeholder_fields = ("t_statistic", "p_value", "adjusted_p_value")
    for metric_name in metrics:
        for field in placeholder_fields:
            df[(metric_name, field)] = None
    return df


def perform_statistical_tests(
    ont_metrics_df,
    illumina_metrics_df,
    ont_stats,
    illumina_stats,
    complexities,
    metrics,
):
    """Compare ONT and Illumina metrics with a t-test per complexity and metric.

    The resulting statistic and p-value are written into both summary tables
    under (metric, "t_statistic") and (metric, "p_value").

    Args:
        ont_metrics_df: Per-sample ONT metrics with a "complexity" column.
        illumina_metrics_df: Per-sample Illumina metrics with a "complexity"
            column.
        ont_stats: ONT summary table, updated in place.
        illumina_stats: Illumina summary table, updated in place.
        complexities: Complexity levels to test.
        metrics: Metric column names to test.

    Returns:
        Tuple ``(p_values, ont_stats, illumina_stats)`` where ``p_values`` is
        ordered complexity-major, metric-minor.
    """
    collected_p_values = []

    for level in complexities:
        ont_subset = ont_metrics_df[ont_metrics_df["complexity"] == level]
        illumina_subset = illumina_metrics_df[
            illumina_metrics_df["complexity"] == level
        ]

        for metric_name in metrics:
            t_stat, p_value = perform_ttest(ont_subset, illumina_subset, metric_name)

            # Both tables carry the same comparison result.
            for table in (ont_stats, illumina_stats):
                table.loc[level, (metric_name, "t_statistic")] = t_stat
                table.loc[level, (metric_name, "p_value")] = p_value

            collected_p_values.append(p_value)

    return collected_p_values, ont_stats, illumina_stats


def apply_fdr_correction(
    all_p_values, ont_stats, illumina_stats, complexities, metrics
):
    """Write Benjamini-Hochberg adjusted p-values into both summary tables.

    ``all_p_values`` must be ordered complexity-major, metric-minor, i.e. in
    the order produced by iterating ``complexities`` and then ``metrics``.

    Non-finite p-values (for example NaN from a t-test on too few samples) are
    left out of the correction and stay NaN. Passing them to ``multipletests``
    would otherwise turn every adjusted p-value into NaN.

    Args:
        all_p_values: Raw p-values, one per (complexity, metric) pair.
        ont_stats: ONT summary table, updated in place.
        illumina_stats: Illumina summary table, updated in place.
        complexities: Complexity levels, in the order used for the tests.
        metrics: Metric names, in the order used for the tests.

    Returns:
        Tuple ``(ont_stats, illumina_stats)``.

    Raises:
        IndexError: If fewer p-values than (complexity, metric) pairs are given.
    """
    raw_p_values = np.asarray(all_p_values, dtype=float)
    adjusted_p_values = np.full(raw_p_values.shape, np.nan)

    finite_mask = np.isfinite(raw_p_values)
    if finite_mask.any():
        corrected = multipletests(raw_p_values[finite_mask], method="fdr_bh")[1]
        adjusted_p_values[finite_mask] = corrected

    adjusted_p_value_index = 0
    for complexity in complexities:
        for metric in metrics:
            adjusted_p_value = adjusted_p_values[adjusted_p_value_index]
            ont_stats.loc[complexity, (metric, "adjusted_p_value")] = adjusted_p_value
            illumina_stats.loc[complexity, (metric, "adjusted_p_value")] = (
                adjusted_p_value
            )
            adjusted_p_value_index += 1

    return ont_stats, illumina_stats


def run_statistical_analysis(
    snv_ont_metrics_df,
    snv_illumina_metrics_df,
    snv_ont_stats,
    snv_illumina_stats,
    complexities,
):
    """Run the ONT-vs-Illumina tests and FDR correction for the SNV metrics.

    The pipeline is: add placeholder columns, run the per-complexity t-tests
    for Precision, Sensitivity and F-measure, then add adjusted p-values.

    Args:
        snv_ont_metrics_df: Per-sample ONT metrics.
        snv_illumina_metrics_df: Per-sample Illumina metrics.
        snv_ont_stats: ONT summary table to extend.
        snv_illumina_stats: Illumina summary table to extend.
        complexities: Complexity levels to test.

    Returns:
        Tuple ``(snv_ont_stats, snv_illumina_stats)`` with the test columns
        filled in.
    """
    metrics_to_test = ["Precision", "Sensitivity", "F-measure"]

    ont_table = prepare_stats_dataframe(snv_ont_stats, metrics_to_test)
    illumina_table = prepare_stats_dataframe(snv_illumina_stats, metrics_to_test)

    raw_p_values, ont_table, illumina_table = perform_statistical_tests(
        snv_ont_metrics_df,
        snv_illumina_metrics_df,
        ont_table,
        illumina_table,
        complexities,
        metrics_to_test,
    )

    ont_table, illumina_table = apply_fdr_correction(
        raw_p_values, ont_table, illumina_table, complexities, metrics_to_test
    )

    return ont_table, illumina_table


# Load the sample sheet; collect_snv_metrics reads its "ont_id" and "lp_id" columns.
sample_ids = pd.read_csv("sample_ids.csv")
technologies = ["ont", "illumina"]
complexities = ["hc", "lc"]

# Read every per-sample summary file and aggregate the metrics per complexity.
snv_metrics = collect_snv_metrics(sample_ids, technologies, complexities)
snv_ont_metrics_df, snv_illumina_metrics_df, snv_ont_stats, snv_illumina_stats = (
    process_snv_metrics(snv_metrics)
)

# Add t-statistics, p-values and FDR-adjusted p-values to both summary tables.
snv_ont_stats, snv_illumina_stats = run_statistical_analysis(
    snv_ont_metrics_df,
    snv_illumina_metrics_df,
    snv_ont_stats,
    snv_illumina_stats,
    complexities,
)

# Lift the pandas row/column display limits so the whole table is shown.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    print("ONT SNV Stats:")
    display(snv_ont_stats)


In [None]:
# Show the Illumina summary table with the pandas row/column display limits lifted.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    print("\nIllumina SNV Stats:")
    display(snv_illumina_stats)


In [None]:
def prepare_snv_performance_data(ont_stats, illumina_stats, metrics, complexities):
    """Build the long-format table used for the SNV performance bar plots.

    For each (complexity, metric) pair two rows are produced, one labelled
    "long-read" (ONT mean) and one labelled "short-read" (Illumina mean). Both
    carry the same significance marker, derived from the adjusted p-value
    stored in ``ont_stats``.

    Args:
        ont_stats: ONT summary table with (metric, "mean") and
            (metric, "adjusted_p_value") columns.
        illumina_stats: Illumina summary table with (metric, "mean") columns.
        metrics: Metric names to include.
        complexities: Complexity levels to include.

    Returns:
        DataFrame with the columns "Complexity" (upper-cased), "Metric",
        "Technology", "Value" and "Significance".
    """
    rows = []
    for level in complexities:
        for metric_name in metrics:
            ont_mean = ont_stats.loc[level, (metric_name, "mean")]
            illumina_mean = illumina_stats.loc[level, (metric_name, "mean")]
            adjusted = ont_stats.loc[level, (metric_name, "adjusted_p_value")]

            stars = get_significance(adjusted)

            for technology, value in (
                ("long-read", ont_mean),
                ("short-read", illumina_mean),
            ):
                rows.append(
                    {
                        "Complexity": level.upper(),
                        "Metric": metric_name,
                        "Technology": technology,
                        "Value": value,
                        "Significance": stars,
                    }
                )

    return pd.DataFrame(rows)


def get_significance(p_value):
    """Map a p-value to a star marker.

    Returns "***" below 0.001, "**" below 0.01, "*" below 0.05 and an empty
    string otherwise (including for NaN, which fails every comparison).
    """
    star_levels = ((0.001, "***"), (0.01, "**"), (0.05, "*"))
    for cutoff, stars in star_levels:
        if p_value < cutoff:
            return stars
    return ""


def add_significance(ax, data, y_offset=0.02):
    """Write significance stars above the bars of each metric.

    Metrics are positioned by their order of first appearance in
    ``data["Metric"]``. The marker is taken from the first row of each metric
    and placed ``y_offset`` above that metric's largest value. Metrics with
    an empty marker are left unannotated.

    Args:
        ax: Axes to annotate.
        data: DataFrame with "Metric", "Value" and "Significance" columns.
        y_offset: Vertical gap between the tallest bar and the marker.
    """
    for position, metric_name in enumerate(data["Metric"].unique()):
        rows = data[data["Metric"] == metric_name]
        stars = rows["Significance"].iloc[0]
        if not stars:
            continue
        ax.text(
            position,
            max(rows["Value"]) + y_offset,
            stars,
            ha="center",
            va="bottom",
            color="black",
            fontweight="bold",
        )


def plot_performance_bars(ax, data, x, y, hue, **kwargs):
    """Draw a grouped bar plot without error bars and label every bar.

    Args:
        ax: Axes to draw on.
        data: DataFrame passed to ``sns.barplot``.
        x: Column for the x axis.
        y: Column for the bar heights.
        hue: Column used to group the bars.
        **kwargs: Extra keyword arguments forwarded to ``sns.barplot``.
    """
    sns.barplot(data=data, x=x, y=y, hue=hue, errorbar=None, ax=ax, **kwargs)
    # Print each bar's value (three decimals) just above it.
    for bar_group in ax.containers:
        ax.bar_label(bar_group, fmt="%.3f", padding=3)


def create_performance_plot(
    df,
    metrics,
    complexities,
    ax=None,
    figsize=(14, 6),
    dpi=300,
    ylim=(0, 1.05),
    title="Performance Comparison",
):
    """Plot one performance panel per complexity level.

    Args:
        df: Long-format table with "Complexity" (upper-case), "Metric",
            "Value", "Technology" and "Significance" columns.
        metrics: Accepted for interface compatibility; not used here because
            the metrics shown are whatever ``df`` contains.
        complexities: Complexity levels; panel ``i`` shows ``complexities[i]``.
        ax: Optional axes to draw on: a single Axes or any array/sequence of
            Axes, used in flattened order. A new 1x2 figure is created when
            omitted.
        figsize: Figure size used when a new figure is created.
        dpi: Resolution used when a new figure is created.
        ylim: Y-axis limits applied to every panel.
        title: Figure title; pass a falsy value to omit it.

    Returns:
        Tuple ``(fig, ax)`` where ``ax`` is the axes container that was passed
        in or created.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 2, figsize=figsize, dpi=dpi)

    # Flatten so that 1-D arrays, 2-D grids, lists and single Axes are all
    # indexed the same way. The previous code read the figure with a 2-D
    # index but drew with a 1-D index, so no supplied `ax` could work.
    panels = np.asarray(ax, dtype=object).ravel()
    fig = panels[0].figure

    for i, complexity in enumerate(complexities):
        panel = panels[i]
        panel_data = df[df["Complexity"] == complexity.upper()]

        plot_performance_bars(
            panel,
            panel_data,
            x="Metric",
            y="Value",
            hue="Technology",
        )
        panel.set_title(
            f"{'High' if complexity == 'hc' else 'Low'} Complexity", fontsize=14
        )
        panel.set_ylim(ylim)
        add_significance(panel, panel_data)

        panel.set_xlabel("")
        panel.set_ylabel("Performance", fontsize=12)

        # Only the panels after the first keep a legend.
        if i == 0:
            panel.legend_.remove()
        else:
            panel.legend(title="Technology", bbox_to_anchor=(1, 0.5))

    if title:
        fig.suptitle(title, fontsize=16)

    # Lay out the figure that owns the panels, which need not be pyplot's
    # current figure when the axes were supplied by the caller.
    fig.tight_layout()
    return fig, ax


def plot_snv_performance_metrics(ont_stats, illumina_stats, ax=None, **kwargs):
    """Plot mean SNV Precision, Sensitivity and F-measure for both technologies.

    Args:
        ont_stats: ONT summary table.
        illumina_stats: Illumina summary table.
        ax: Optional axes forwarded to ``create_performance_plot``.
        **kwargs: Extra keyword arguments forwarded to
            ``create_performance_plot``.

    Returns:
        Whatever ``create_performance_plot`` returns.
    """
    metric_names = ["Precision", "Sensitivity", "F-measure"]
    complexity_levels = ["hc", "lc"]

    long_table = prepare_snv_performance_data(
        ont_stats, illumina_stats, metric_names, complexity_levels
    )
    return create_performance_plot(
        long_table, metric_names, complexity_levels, ax=ax, **kwargs
    )


# Draw the SNV performance comparison on a new figure and display it.
plot_snv_performance_metrics(snv_ont_stats, snv_illumina_stats)
plt.show()


### 2. Error Analysis


In [None]:
def calculate_snv_error_rates(sample_ids, technologies, complexities):
    """Compute per-sample FP and FN rates for each of the 12 SNV types.

    For every sample, technology and complexity the fp/fn/query VCFs under
    ``output/snv/rtg_vcfeval/`` are read. Each rate is the FP (or FN) count of
    one substitution type divided by the number of records in the query VCF
    (0 when the query VCF is empty). Combinations with a missing VCF are
    reported and skipped. The "long-read" technology uses "ont_id"; every
    other technology uses "lp_id".

    Args:
        sample_ids: DataFrame with "ont_id" and "lp_id" columns.
        technologies: Technology labels.
        complexities: Complexity levels (used as path components).

    Returns:
        Tuple ``(snv_error_rates, sample_counts)``:
        ``snv_error_rates[tech][complexity]["FP" | "FN"][snv_type]`` is a list
        with one rate per processed sample, and
        ``sample_counts[tech][complexity]`` is the number of processed samples.
    """
    snv_error_rates = {}
    sample_counts = {}
    for tech in technologies:
        snv_error_rates[tech] = {comp: {"FP": {}, "FN": {}} for comp in complexities}
        sample_counts[tech] = {comp: 0 for comp in complexities}

    # All 12 single-base substitutions, ordered A>C, A>G, ..., T>G.
    snv_types = [
        f"{ref_base}>{alt_base}"
        for ref_base in "ACGT"
        for alt_base in "ACGT"
        if ref_base != alt_base
    ]

    for _, row in sample_ids.iterrows():
        ont_id = row["ont_id"]
        lp_id = row["lp_id"]

        for tech in technologies:
            sample_id = ont_id if tech == "long-read" else lp_id

            for complexity in complexities:
                fp_vcf = (
                    f"output/snv/rtg_vcfeval/{complexity}/{sample_id}.snv/fp.vcf.gz"
                )
                fn_vcf = (
                    f"output/snv/rtg_vcfeval/{complexity}/{sample_id}.snv/fn.vcf.gz"
                )
                query_vcf = (
                    f"output/snv/rtg_vcfeval/{complexity}/{sample_id}.snv/query.vcf.gz"
                )

                have_all_files = all(
                    os.path.exists(path) for path in (fp_vcf, fn_vcf, query_vcf)
                )
                if not have_all_files:
                    print(f"VCF files not found for {sample_id}, {tech}, {complexity}")
                    continue

                sample_counts[tech][complexity] += 1

                fp_counts = count_snv_types(fp_vcf)
                fn_counts = count_snv_types(fn_vcf)
                total_variants = count_total_variants(query_vcf)

                fp_rates = snv_error_rates[tech][complexity]["FP"]
                fn_rates = snv_error_rates[tech][complexity]["FN"]

                for snv_type in snv_types:
                    if snv_type not in fp_rates:
                        fp_rates[snv_type] = []
                        fn_rates[snv_type] = []

                    if total_variants > 0:
                        fp_rate = fp_counts.get(snv_type, 0) / total_variants
                        fn_rate = fn_counts.get(snv_type, 0) / total_variants
                    else:
                        fp_rate = 0
                        fn_rate = 0

                    fp_rates[snv_type].append(fp_rate)
                    fn_rates[snv_type].append(fn_rate)

    return snv_error_rates, sample_counts


def count_snv_types(vcf_file):
    """Count single-base substitutions in a VCF, keyed by "REF>ALT".

    Only the first ALT allele of each record is inspected, and a record is
    counted when both REF and that ALT are one character long. Records
    without any ALT allele are skipped instead of raising.

    Args:
        vcf_file: Path of the VCF/BCF file to read with pysam.

    Returns:
        Dict mapping substitution type (e.g. "A>G") to its record count.
    """
    snv_counts = {}
    with pysam.VariantFile(vcf_file) as vcf:
        for record in vcf:
            alts = record.alts
            if not alts:
                # pysam reports a missing ALT as None; indexing it would raise.
                continue

            ref = record.ref
            alt = alts[0]
            if len(ref) == 1 and len(alt) == 1:
                snv_type = f"{ref}>{alt}"
                snv_counts[snv_type] = snv_counts.get(snv_type, 0) + 1
    return snv_counts


def count_total_variants(vcf_file):
    """Return the number of records in a VCF file.

    Args:
        vcf_file: Path of the VCF/BCF file to read with pysam.
    """
    with pysam.VariantFile(vcf_file) as vcf:
        return sum(1 for _ in vcf)


def prepare_snv_error_data(snv_error_rates):
    """Flatten the nested error-rate dict into a long-format table of means.

    Args:
        snv_error_rates: Nested dict
            ``{technology: {complexity: {error_type: {snv_type: [rates]}}}}``.

    Returns:
        DataFrame with one row per (technology, complexity, error type, SNV
        type) and the columns "Technology", "Complexity", "Error Type",
        "SNV Type" and "Error Rate" (the mean of the per-sample rates).
    """
    rows = [
        {
            "Technology": technology,
            "Complexity": level,
            "Error Type": error_kind,
            "SNV Type": substitution,
            "Error Rate": np.mean(sample_rates),
        }
        for technology, by_complexity in snv_error_rates.items()
        for level, by_error in by_complexity.items()
        for error_kind, by_substitution in by_error.items()
        for substitution, sample_rates in by_substitution.items()
    ]
    return pd.DataFrame(rows)


def perform_statistical_tests(snv_error_rates, sample_counts):
    """Compare long-read and short-read error rates with paired t-tests.

    One paired t-test is run per (complexity, error type, SNV type) on the
    per-sample rate lists, followed by a Benjamini-Hochberg correction over
    all tests.

    Non-finite p-values (for example NaN from a test with too few pairs) are
    left out of the correction, stay NaN, get an empty significance marker
    and are reported as not rejected. Passing them to ``multipletests`` would
    otherwise turn every corrected p-value into NaN. An input that yields no
    tests at all produces an empty DataFrame.

    NOTE(review): the two rate lists are paired by position and truncated to
    the shorter length. That pairing is only valid if both lists cover the
    same samples in the same order; confirm this when samples can be skipped
    for one technology only.

    Args:
        snv_error_rates: Nested dict with "long-read" and "short-read"
            entries: ``{tech: {complexity: {"FP" | "FN": {snv_type: [rates]}}}}``.
        sample_counts: Accepted for interface compatibility; not used.

    Returns:
        DataFrame with the columns "Complexity", "Error Type", "SNV Type",
        "t-statistic", "p-value", "n", "p-value (FDR corrected)",
        "Significance" and "Rejected (FDR)".
    """
    results = []
    p_values = []

    for complexity in snv_error_rates["long-read"]:
        for error_type in ["FP", "FN"]:
            for snv_type in snv_error_rates["long-read"][complexity][error_type]:
                long_read_rates = snv_error_rates["long-read"][complexity][error_type][
                    snv_type
                ]
                short_read_rates = snv_error_rates["short-read"][complexity][
                    error_type
                ][snv_type]

                n = min(len(long_read_rates), len(short_read_rates))

                t_statistic, p_value = stats.ttest_rel(
                    long_read_rates[:n], short_read_rates[:n]
                )

                results.append(
                    {
                        "Complexity": complexity,
                        "Error Type": error_type,
                        "SNV Type": snv_type,
                        "t-statistic": t_statistic,
                        "p-value": p_value,
                        "n": n,
                    }
                )
                p_values.append(p_value)

    if not results:
        # Nothing was tested, so there is nothing to correct.
        return pd.DataFrame(results)

    raw_p_values = np.asarray(p_values, dtype=float)
    p_values_corrected = np.full(raw_p_values.shape, np.nan)
    rejected = np.zeros(raw_p_values.shape, dtype=bool)

    finite_mask = np.isfinite(raw_p_values)
    if finite_mask.any():
        finite_rejected, finite_corrected, _, _ = multipletests(
            raw_p_values[finite_mask], method="fdr_bh"
        )
        rejected[finite_mask] = finite_rejected
        p_values_corrected[finite_mask] = finite_corrected

    for result, p_corrected, is_rejected in zip(results, p_values_corrected, rejected):
        result["p-value (FDR corrected)"] = p_corrected
        result["Significance"] = get_significance(p_corrected)
        result["Rejected (FDR)"] = is_rejected

    return pd.DataFrame(results)


def get_significance(p_value):
    """Return the star marker for a p-value.

    "***" for p < 0.001, "**" for p < 0.01, "*" for p < 0.05, otherwise an
    empty string (NaN also yields an empty string).
    """
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    return "*" if p_value < 0.05 else ""


def plot_error_rates(plot_data, statistical_results, fig=None, axes=None):
    """Plot mean SNV error rates per substitution type on a 2x2 grid.

    Rows are the complexities (hc, lc) and columns the error types (FP, FN).
    The significance marker of each SNV type is written at the height of its
    tallest bar.

    Args:
        plot_data: Long-format table with "Complexity", "Error Type",
            "SNV Type", "Technology" and "Error Rate" columns.
        statistical_results: Table with "Complexity", "Error Type",
            "SNV Type" and "Significance" columns.
        fig: Existing figure; must be given together with ``axes``.
        axes: Existing 2x2 axes grid; must be given together with ``fig``.

    Returns:
        Tuple ``(fig, axes)``.

    Raises:
        ValueError: If only one of ``fig`` and ``axes`` is provided.
    """
    if (fig is None) != (axes is None):
        raise ValueError("Both fig and axes must be provided if one is provided")
    if fig is None:
        fig, axes = plt.subplots(2, 2, figsize=(20, 16))

    complexities = ["hc", "lc"]
    error_types = ["FP", "FN"]

    for row_idx, complexity in enumerate(complexities):
        for col_idx, error_type in enumerate(error_types):
            panel = axes[row_idx, col_idx]

            in_panel = (plot_data["Complexity"] == complexity) & (
                plot_data["Error Type"] == error_type
            )
            panel_rows = plot_data[in_panel]

            sns.barplot(
                x="SNV Type", y="Error Rate", hue="Technology", data=panel_rows, ax=panel
            )

            panel.set_title(
                f"{'High' if complexity == 'hc' else 'Low'} Complexity - {error_type}"
            )
            panel.set_xlabel("SNV Type")
            panel.set_ylabel("Error Rate")

            for x_pos, snv_type in enumerate(panel_rows["SNV Type"].unique()):
                match = statistical_results[
                    (statistical_results["Complexity"] == complexity)
                    & (statistical_results["Error Type"] == error_type)
                    & (statistical_results["SNV Type"] == snv_type)
                ]
                if match.empty:
                    continue

                stars = match["Significance"].values[0]
                tallest = panel_rows[panel_rows["SNV Type"] == snv_type][
                    "Error Rate"
                ].max()
                panel.text(
                    x_pos,
                    tallest,
                    stars,
                    ha="center",
                    va="bottom",
                    fontweight="bold",
                )

    plt.tight_layout()
    return fig, axes


# Reload the sample sheet and switch to the long-read/short-read labels that
# calculate_snv_error_rates and the error-rate tests expect.
sample_ids = pd.read_csv("sample_ids.csv")
technologies = ["long-read", "short-read"]
complexities = ["hc", "lc"]

# Per-sample FP/FN rates for every SNV type, plus the number of samples used.
snv_error_rates, sample_counts = calculate_snv_error_rates(
    sample_ids, technologies, complexities
)
plot_data = prepare_snv_error_data(snv_error_rates)
# This resolves to the perform_statistical_tests defined in this cell, which
# takes (snv_error_rates, sample_counts).
statistical_results = perform_statistical_tests(snv_error_rates, sample_counts)

plot_error_rates(plot_data, statistical_results)
plt.show()


In [None]:
# Show the per-SNV-type test results table as the cell output.
statistical_results


### 3. Combined Plots


In [None]:
def create_combined_snv_metrics_plot(
    snv_ont_stats, snv_illumina_stats, plot_data, statistical_results
):
    """Build the six-panel SNV figure: performance (A-B) and error rates (C-F).

    Panels A and B show mean Precision, Sensitivity and F-measure for the
    long-read and short-read data at high and low complexity. Panels C-F show
    the mean FP/FN rate per SNV type for each complexity. Significance stars
    are taken from the adjusted p-values in ``snv_ont_stats`` (A-B) and from
    ``statistical_results`` (C-F).

    Args:
        snv_ont_stats: ONT summary table with (metric, "mean") and
            (metric, "adjusted_p_value") columns.
        snv_illumina_stats: Illumina summary table with (metric, "mean")
            columns.
        plot_data: Long-format error-rate table with "Complexity",
            "Error Type", "SNV Type", "Technology" and "Error Rate" columns.
        statistical_results: Table with "Complexity", "Error Type",
            "SNV Type" and "Significance" columns.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(figsize=(12, 14), dpi=300)
    grid = fig.add_gridspec(3, 2)

    def add_significance(ax, data, y_offset=0.02):
        """Write each metric's stars just above its tallest bar."""
        for position, metric_name in enumerate(data["Metric"].unique()):
            rows = data[data["Metric"] == metric_name]
            stars = rows["Significance"].iloc[0]
            if not stars:
                continue
            ax.text(
                position,
                max(rows["Value"]) + y_offset,
                stars,
                ha="center",
                va="bottom",
                color="black",
                fontweight="bold",
            )

    def tag_panel(ax, letter):
        """Put the bold panel letter at the top-left corner of the axes."""
        ax.text(
            -0.05,
            1.05,
            letter,
            transform=ax.transAxes,
            fontsize=16,
            fontweight="bold",
            va="top",
        )

    # A and B: SNV Performance Metrics
    fig.text(0.5, 1.00, "SNV Performance", fontsize=12, ha="center")
    metric_names = ["Precision", "Sensitivity", "F-measure"]
    complexities = ["hc", "lc"]

    for i, complexity in enumerate(complexities):
        ax = fig.add_subplot(grid[0, i])

        rows = []
        for metric in metric_names:
            ont_value = snv_ont_stats.loc[complexity, (metric, "mean")]
            illumina_value = snv_illumina_stats.loc[complexity, (metric, "mean")]
            p_value = snv_ont_stats.loc[complexity, (metric, "adjusted_p_value")]

            if p_value < 0.001:
                significance = "***"
            elif p_value < 0.01:
                significance = "**"
            elif p_value < 0.05:
                significance = "*"
            else:
                significance = ""

            for technology, value in (
                ("long-read", ont_value),
                ("short-read", illumina_value),
            ):
                rows.append(
                    {
                        "Complexity": complexity.upper(),
                        "Metric": metric,
                        "Technology": technology,
                        "Value": value,
                        "Significance": significance,
                    }
                )

        df = pd.DataFrame(rows)
        sns.barplot(x="Metric", y="Value", hue="Technology", data=df, ax=ax)
        ax.set_ylim(0, 1.05)
        ax.set_title(f"{'High' if complexity == 'hc' else 'Low'} Complexity")
        ax.set_xlabel("")
        ax.set_ylabel("Performance" if i == 0 else "")

        # Only the right-hand performance panel keeps its legend.
        if i == 1:
            ax.legend(title="Technology", loc="upper right")
        else:
            ax.get_legend().remove()

        add_significance(ax, df)
        tag_panel(ax, "AB"[i])

    # C, D, E, F: Error Rates
    fig.text(0.5, 0.665, "SNV Error Analysis", fontsize=12, ha="center")
    error_types = ["FP", "FN"]

    for i, complexity in enumerate(complexities):
        for j, error_type in enumerate(error_types):
            ax = fig.add_subplot(grid[i + 1, j])

            in_panel = (plot_data["Complexity"] == complexity) & (
                plot_data["Error Type"] == error_type
            )
            subset = plot_data[in_panel]
            sns.barplot(
                x="SNV Type", y="Error Rate", hue="Technology", data=subset, ax=ax
            )

            ax.set_title(
                f"{'High' if complexity == 'hc' else 'Low'} Complexity - {error_type}"
            )
            ax.set_xlabel("SNV Type")
            ax.set_ylabel("Error Rate")
            ax.get_legend().remove()

            for idx, snv_type in enumerate(subset["SNV Type"].unique()):
                match = statistical_results[
                    (statistical_results["Complexity"] == complexity)
                    & (statistical_results["Error Type"] == error_type)
                    & (statistical_results["SNV Type"] == snv_type)
                ]
                if match.empty:
                    continue

                stars = match["Significance"].values[0]
                tallest = subset[subset["SNV Type"] == snv_type]["Error Rate"].max()
                ax.text(
                    idx,
                    tallest,
                    stars,
                    ha="center",
                    va="bottom",
                    fontweight="bold",
                )

            # Panels are lettered C, D (first complexity) then E, F.
            tag_panel(ax, chr(ord("C") + i * 2 + j))

    plt.tight_layout()

    return fig


# Build the combined SNV performance / error-rate figure and display it.
combined_fig = create_combined_snv_metrics_plot(
    snv_ont_stats, snv_illumina_stats, plot_data, statistical_results
)

plt.show()


# Indel Benchmark

## Comparative analysis

### 1. Sensitivity, Precision, and F1


In [None]:
def collect_indel_metrics(indel_sample_ids, indel_complexities):
    """Collect per-sample ONT indel summary metrics for each complexity.

    For every row of ``indel_sample_ids`` and every complexity, the summary
    file under ``output/indel/rtg_vcfeval/`` is read with ``read_summary``.
    Samples whose summary is empty are reported and skipped.

    Args:
        indel_sample_ids: DataFrame with an "ont_id" column.
        indel_complexities: Iterable of complexity names (used as path
            components).

    Returns:
        Dict ``{"ont": {complexity: [entry, ...]}}`` where each entry holds
        "sample_id", "complexity" and the parsed summary metrics. The "hc"
        and "lc" lists are always present; any additional complexity passed
        in gets its own (possibly empty) list.
    """
    # Keep the historical hc/lc layout, then make sure every requested
    # complexity has a list so appending below cannot KeyError.
    indel_metrics_collection = {"ont": {"hc": [], "lc": []}}
    for indel_complexity in indel_complexities:
        indel_metrics_collection["ont"].setdefault(indel_complexity, [])

    for _, row in indel_sample_ids.iterrows():
        indel_ont_id = row["ont_id"]

        for indel_complexity in indel_complexities:
            indel_summary_file = f"output/indel/rtg_vcfeval/{indel_complexity}/{indel_ont_id}.indel/summary.txt"
            indel_summary = read_summary(indel_summary_file)

            if indel_summary:
                indel_metrics_entry = {
                    "sample_id": indel_ont_id,
                    "complexity": indel_complexity,
                    **indel_summary,
                }
                indel_metrics_collection["ont"][indel_complexity].append(
                    indel_metrics_entry
                )
            else:
                print(f"Skipping empty summary for {indel_ont_id}, {indel_complexity}")

    return indel_metrics_collection


def process_indel_metrics(indel_metrics_collection):
    """Turn the collected ONT indel metrics into a table and its summary.

    Args:
        indel_metrics_collection: Dict with an "ont" entry holding "hc" and
            "lc" lists of metric dicts.

    Returns:
        Tuple ``(indel_ont_metrics_df, indel_ont_stats)``: the per-sample
        DataFrame (hc rows followed by lc rows) and its per-complexity
        summary from ``calculate_rtg_statistics``.
    """
    ont_group = indel_metrics_collection["ont"]
    per_sample = pd.DataFrame(ont_group["hc"] + ont_group["lc"])
    return per_sample, calculate_rtg_statistics(per_sample)


def plot_indel_performance_metrics(indel_ont_stats):
    """Draw side-by-side bar charts of mean indel performance for HC and LC.

    Args:
        indel_ont_stats: Table indexed by complexity (``"hc"``/``"lc"``) with
            two-level columns; the ``(metric, "mean")`` cells are plotted for
            Precision, Sensitivity and F-measure.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    indel_metrics_list = ["Precision", "Sensitivity", "F-measure"]
    panel_titles = {"hc": "High Complexity", "lc": "Low Complexity"}

    indel_df = pd.DataFrame(
        [
            {
                "Complexity": indel_complexity.upper(),
                "Metric": indel_metric,
                "Value": indel_ont_stats.loc[indel_complexity, (indel_metric, "mean")],
            }
            for indel_complexity in panel_titles
            for indel_metric in indel_metrics_list
        ]
    )

    indel_fig, indel_axes = plt.subplots(1, 2, figsize=(16, 6))

    for indel_ax, (indel_complexity, panel_title) in zip(
        indel_axes, panel_titles.items()
    ):
        sns.barplot(
            x="Metric",
            y="Value",
            data=indel_df[indel_df["Complexity"] == indel_complexity.upper()],
            ax=indel_ax,
        )
        indel_ax.set_title(panel_title, fontsize=14)
        indel_ax.set_ylim(0, 1.05)
        indel_ax.set_xlabel("")
        # The axis spans 0-1.05 and bars are labelled with three decimals, so
        # the values are drawn as fractions; a "(%)" suffix would mislabel
        # them (the combined figure also uses plain "Performance").
        indel_ax.set_ylabel("Performance", fontsize=12)

        for container in indel_ax.containers:
            indel_ax.bar_label(container, fmt="%.3f", padding=3)

    indel_fig.suptitle("Indel Performance for ONT", fontsize=16)
    plt.tight_layout()
    plt.show()


# Sample sheet; collect_indel_metrics reads its "ont_id" column.
indel_sample_ids = pd.read_csv("sample_ids.csv")
indel_complexities = ["hc", "lc"]

# Collect the per-sample summaries, then flatten and aggregate them.
indel_metrics = collect_indel_metrics(indel_sample_ids, indel_complexities)
indel_ont_metrics_df, indel_ont_stats = process_indel_metrics(indel_metrics)

print("ONT Indel Stats:")
# Bare trailing expression: shown as the cell output when run in a notebook.
indel_ont_stats


In [None]:
# Plot the HC/LC bar charts from the statistics table built in the cell above.
plot_indel_performance_metrics(indel_ont_stats)


### 2. Error Analysis


In [None]:
def collect_indel_error_metrics(indel_sample_ids, indel_complexities):
    """Count false-positive and false-negative indels by type and length.

    For every sample (``ont_id`` column) and complexity, the ``tp.vcf.gz``,
    ``fp.vcf.gz`` and ``fn.vcf.gz`` files of the matching vcfeval directory are
    read. Only the first ALT allele of a record is considered: it is an
    "insertion" when longer than REF, otherwise a "deletion".

    Args:
        indel_sample_ids: DataFrame with an ``ont_id`` column.
        indel_complexities: Complexity labels used as directory names and keys.

    Returns:
        A tuple ``(error_counts, variant_totals)``:
        ``error_counts[complexity]["FP" | "FN"]["<type>_<length>"]`` is a count,
        and ``variant_totals[complexity]`` is the number of TP + FP + FN records.
    """
    error_counts = {
        level: defaultdict(lambda: defaultdict(int)) for level in indel_complexities
    }
    variant_totals = dict.fromkeys(indel_complexities, 0)

    def tally(vcf_path, level, label):
        # Bucket each record of one error file and add it to the running total.
        with pysam.VariantFile(vcf_path) as vcf:
            for record in vcf:
                alt_len = len(record.alts[0])
                ref_len = len(record.ref)
                kind = "insertion" if alt_len > ref_len else "deletion"
                error_counts[level][label][f"{kind}_{abs(alt_len - ref_len)}"] += 1
                variant_totals[level] += 1

    for _, sample_row in indel_sample_ids.iterrows():
        sample = sample_row["ont_id"]

        for level in indel_complexities:
            run_dir = f"output/indel/rtg_vcfeval/{level}/{sample}.indel"

            # True positives only contribute to the denominator.
            with pysam.VariantFile(os.path.join(run_dir, "tp.vcf.gz")) as vcf:
                variant_totals[level] += sum(1 for _ in vcf)

            tally(os.path.join(run_dir, "fp.vcf.gz"), level, "FP")
            tally(os.path.join(run_dir, "fn.vcf.gz"), level, "FN")

    return error_counts, variant_totals


def process_indel_error_metrics(indel_error_metrics, indel_total_variants):
    """Turn nested indel error counts into a sorted long-format DataFrame.

    Args:
        indel_error_metrics: ``{complexity: {error_type: {"<type>_<length>":
            count}}}``.
        indel_total_variants: ``{complexity: total}`` used as the denominator
            of every error rate for that complexity.

    Returns:
        DataFrame with the columns ``Complexity``, ``Error Type``,
        ``Indel Type``, ``Error Rate (%)``, ``Indel Length`` and
        ``Indel Category``, sorted by category, length and complexity. When no
        error was recorded an empty frame with these columns is returned
        instead of raising ``KeyError``.
    """
    columns = [
        "Complexity",
        "Error Type",
        "Indel Type",
        "Error Rate (%)",
        "Indel Length",
        "Indel Category",
    ]

    rows = []
    for indel_complexity, error_data in indel_error_metrics.items():
        total_variants = indel_total_variants[indel_complexity]
        for error_type, indel_types in error_data.items():
            for indel_type, count in indel_types.items():
                # Keys look like "insertion_3": category first, length last.
                key_parts = indel_type.split("_")
                rows.append(
                    {
                        "Complexity": indel_complexity,
                        "Error Type": error_type,
                        "Indel Type": indel_type,
                        "Error Rate (%)": (count / total_variants) * 100,
                        "Indel Length": int(key_parts[-1]),
                        "Indel Category": key_parts[0],
                    }
                )

    # Passing the column list keeps the schema intact even with zero rows.
    indel_error_df = pd.DataFrame(rows, columns=columns)

    return indel_error_df.sort_values(
        ["Indel Category", "Indel Length", "Complexity"]
    )


def plot_indel_error_metrics(indel_error_df):
    """Plot error rate against indel length in a 2x2 grid of bar charts.

    Columns of the grid are the error types (FP, FN), rows are the indel
    categories (insertion, deletion); bars are coloured by complexity.

    Args:
        indel_error_df: Frame with ``Error Type``, ``Indel Category``,
            ``Indel Length``, ``Error Rate (%)`` and ``Complexity`` columns.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    for grid_col, error_type in enumerate(("FP", "FN")):
        of_error_type = indel_error_df["Error Type"] == error_type

        for grid_row, indel_category in enumerate(("insertion", "deletion")):
            panel = axes[grid_row, grid_col]
            category_label = indel_category.capitalize()
            subset = indel_error_df[
                of_error_type & (indel_error_df["Indel Category"] == indel_category)
            ]

            sns.barplot(
                data=subset,
                x="Indel Length",
                y="Error Rate (%)",
                hue="Complexity",
                hue_order=["hc", "lc"],
                ax=panel,
            )

            panel.set_title(f"{error_type} - {category_label}s")
            panel.set_xlabel(f"{category_label} Length")
            panel.set_ylabel("Error Rate (%)")

            # One tick per distinct length, labelled with the length itself.
            lengths_present = sorted(subset["Indel Length"].unique())
            panel.set_xticks(range(len(lengths_present)))
            panel.set_xticklabels(lengths_present)

            panel.tick_params(axis="x", rotation=90)
            panel.legend(title="Complexity")

    plt.tight_layout()
    plt.show()


def perform_indel_statistical_test(
    hc_data, lc_data, hc_total, lc_total, error_type, indel_type
):
    """Test whether an indel error occurs at different rates in HC and LC.

    A 2x2 table of (errors, non-errors) for HC and LC is built. If any cell is
    below 5, Fisher's exact test is used; otherwise the chi-square contingency
    test is used.

    Args:
        hc_data, lc_data: ``{error_type: {indel_type: count}}`` mappings.
        hc_total, lc_total: Total variant counts for each complexity.
        error_type: Key selecting the error class (e.g. ``"FP"``).
        indel_type: Key of the indel bucket; a missing bucket counts as 0.

    Returns:
        The p-value of the selected test.
    """
    table = []
    for counts, total in ((hc_data, hc_total), (lc_data, lc_total)):
        errors = counts[error_type].get(indel_type, 0)
        table.append([errors, total - errors])

    has_small_cell = any(cell < 5 for table_row in table for cell in table_row)

    if has_small_cell:
        _, p_value = stats.fisher_exact(table)
        return p_value

    _, p_value, _, _ = stats.chi2_contingency(table)
    return p_value


def analyze_indel_error_statistics(indel_error_metrics, indel_total_variants):
    """Compare HC and LC error rates for every indel bucket, with FDR control.

    Each "<type>_<length>" bucket seen in either complexity is tested with
    ``perform_indel_statistical_test`` for FP and FN; the p-values are then
    adjusted together with the Benjamini-Hochberg procedure.

    Args:
        indel_error_metrics: Mapping with ``"hc"`` and ``"lc"`` entries, each
            ``{error_type: {indel_type: count}}``.
        indel_total_variants: ``{"hc": total, "lc": total}`` denominators.

    Returns:
        DataFrame sorted by error type and indel type with the columns
        ``Error Type``, ``Indel Type``, ``HC Rate (%)``, ``LC Rate (%)``,
        ``p-value``, ``Adjusted p-value`` (all four formatted as strings) and
        ``Significance`` (``***`` < 0.001, ``**`` < 0.01, ``*`` < 0.05).
        An empty frame with these columns is returned when no error bucket
        exists.
    """
    columns = [
        "Error Type",
        "Indel Type",
        "HC Rate (%)",
        "LC Rate (%)",
        "p-value",
        "Adjusted p-value",
        "Significance",
    ]

    def significance_stars(p_adjusted):
        # Star rating of an adjusted p-value.
        if p_adjusted < 0.001:
            return "***"
        if p_adjusted < 0.01:
            return "**"
        if p_adjusted < 0.05:
            return "*"
        return ""

    hc_data = indel_error_metrics["hc"]
    lc_data = indel_error_metrics["lc"]
    hc_total = indel_total_variants["hc"]
    lc_total = indel_total_variants["lc"]

    indel_stat_results = []

    for error_type in ["FP", "FN"]:
        indel_types = set(hc_data[error_type].keys()) | set(lc_data[error_type].keys())

        for indel_type in indel_types:
            p_value = perform_indel_statistical_test(
                hc_data, lc_data, hc_total, lc_total, error_type, indel_type
            )
            hc_rate = hc_data[error_type].get(indel_type, 0) / hc_total * 100
            lc_rate = lc_data[error_type].get(indel_type, 0) / lc_total * 100

            indel_stat_results.append(
                {
                    "Error Type": error_type,
                    "Indel Type": indel_type,
                    "HC Rate (%)": hc_rate,
                    "LC Rate (%)": lc_rate,
                    "p-value": p_value,
                }
            )

    # Nothing to test: return the empty schema rather than failing on the
    # missing "p-value" column.
    if not indel_stat_results:
        return pd.DataFrame(columns=columns)

    indel_stat_df = pd.DataFrame(indel_stat_results)

    # Perform FDR correction
    _, p_values_corrected, _, _ = multipletests(
        indel_stat_df["p-value"], method="fdr_bh"
    )
    indel_stat_df["Adjusted p-value"] = p_values_corrected

    # Rate significance on the unrounded values. Doing it after formatting
    # would compare a 4-digit rounding of the p-value against the thresholds
    # and could drop a star for values just below a cut-off.
    indel_stat_df["Significance"] = [
        significance_stars(p_adjusted) for p_adjusted in p_values_corrected
    ]

    indel_stat_df = indel_stat_df.sort_values(["Error Type", "Indel Type"])

    indel_stat_df["HC Rate (%)"] = indel_stat_df["HC Rate (%)"].map("{:.4f}".format)
    indel_stat_df["LC Rate (%)"] = indel_stat_df["LC Rate (%)"].map("{:.4f}".format)
    indel_stat_df["p-value"] = indel_stat_df["p-value"].map("{:.4e}".format)
    indel_stat_df["Adjusted p-value"] = indel_stat_df["Adjusted p-value"].map(
        "{:.4e}".format
    )

    return indel_stat_df[columns]


# Re-read the sample sheet so this cell can run on its own.
indel_sample_ids = pd.read_csv("sample_ids.csv")
indel_complexities = ["hc", "lc"]

# Count FP/FN indels per complexity and convert the counts to error rates.
indel_error_metrics, indel_total_variants = collect_indel_error_metrics(
    indel_sample_ids, indel_complexities
)
indel_error_df = process_indel_error_metrics(indel_error_metrics, indel_total_variants)

plot_indel_error_metrics(indel_error_df)

# NOTE(review): indel_stat_df is computed here but not printed or displayed
# in this cell - confirm whether it is shown elsewhere.
indel_stat_df = analyze_indel_error_statistics(
    indel_error_metrics, indel_total_variants
)


In [None]:
# Lift the pandas row/column display limits so the whole table is rendered.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    print("Indel Error Analysis DataFrame:")
    # `display` is not imported in this file; presumably the IPython/Jupyter
    # built-in - confirm when running outside a notebook.
    display(indel_error_df)


### 3. Size Distribution Analysis


In [None]:
def analyze_indel_size_distribution(indel_complexity, indel_sample_ids):
    """Build indel size histograms from the query and truth VCFs of all samples.

    For each sample (``ont_id`` column) the ``query.vcf.gz`` records are
    counted into the first histogram and the ``truth.vcf.gz`` records into the
    second. Only the first ALT allele is considered: a record is an
    "insertion" when that allele is longer than REF, otherwise a "deletion".

    Args:
        indel_complexity: Complexity label used as directory name.
        indel_sample_ids: DataFrame with an ``ont_id`` column.

    Returns:
        A tuple ``(query_sizes, truth_sizes)`` of nested defaultdicts indexed
        as ``sizes[indel_type][length] -> count``.
    """
    query_sizes = defaultdict(lambda: defaultdict(int))
    truth_sizes = defaultdict(lambda: defaultdict(int))

    def accumulate(vcf_path, histogram):
        # Add every record of one VCF file to the given histogram.
        with pysam.VariantFile(vcf_path) as vcf:
            for record in vcf:
                alt_len = len(record.alts[0])
                ref_len = len(record.ref)
                kind = "insertion" if alt_len > ref_len else "deletion"
                histogram[kind][abs(alt_len - ref_len)] += 1

    for _, sample_row in indel_sample_ids.iterrows():
        sample = sample_row["ont_id"]
        run_dir = f"output/indel/rtg_vcfeval/{indel_complexity}/{sample}.indel"

        accumulate(os.path.join(run_dir, "query.vcf.gz"), query_sizes)
        accumulate(os.path.join(run_dir, "truth.vcf.gz"), truth_sizes)

    return query_sizes, truth_sizes


def prepare_indel_plot_data(indel_ont_sizes, indel_illumina_sizes, indel_complexity):
    """Convert indel size histograms into long-format percentage rows.

    For each indel type and each size from 1 up to the largest size seen on
    either platform, one ONT row and one Illumina row are emitted (in that
    order) giving the share of that platform's indels having this size.

    Args:
        indel_ont_sizes: ``{indel_type: {size: count}}`` for ONT.
        indel_illumina_sizes: ``{indel_type: {size: count}}`` for Illumina.
        indel_complexity: Label copied into the ``Complexity`` field.

    Returns:
        List of dicts with the keys ``Indel Type``, ``Size``, ``Percentage``,
        ``Platform`` and ``Complexity``.

    An indel type with no records on one or both platforms is tolerated
    (percentage 0, or no rows at all), and the input histograms are only read,
    never extended with zero-count entries.
    """
    platforms = (("ONT", indel_ont_sizes), ("Illumina", indel_illumina_sizes))

    indel_plot_data = []
    for indel_type in ["insertion", "deletion"]:
        # .get() avoids both the KeyError of plain dicts and the silent key
        # insertion of defaultdicts.
        histograms = [
            (platform, sizes.get(indel_type, {})) for platform, sizes in platforms
        ]
        totals = {
            platform: sum(histogram.values()) for platform, histogram in histograms
        }
        # default=0 yields an empty size range instead of a ValueError when
        # neither platform has a record of this type.
        indel_max_size = max(
            (size for _, histogram in histograms for size in histogram),
            default=0,
        )

        for indel_size in range(1, indel_max_size + 1):
            for platform, histogram in histograms:
                platform_total = totals[platform]
                percentage = (
                    (histogram.get(indel_size, 0) / platform_total) * 100
                    if platform_total > 0
                    else 0
                )
                indel_plot_data.append(
                    {
                        "Indel Type": indel_type,
                        "Size": indel_size,
                        "Percentage": percentage,
                        "Platform": platform,
                        "Complexity": indel_complexity,
                    }
                )
    return indel_plot_data


def plot_indel_size_distribution(indel_plot_data):
    """Plot indel size distributions per platform in a 2x2 grid of line charts.

    Rows of the grid are the indel types (insertion, deletion), columns are
    the complexities ("High Complexity", "Low Complexity").

    Args:
        indel_plot_data: DataFrame with ``Indel Type``, ``Complexity``,
            ``Size``, ``Percentage`` and ``Platform`` columns.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))

    for axes_row, indel_type in zip(axes, ("insertion", "deletion")):
        of_type = indel_plot_data["Indel Type"] == indel_type

        for panel, complexity_label in zip(
            axes_row, ("High Complexity", "Low Complexity")
        ):
            subset = indel_plot_data[
                of_type & (indel_plot_data["Complexity"] == complexity_label)
            ]

            sns.lineplot(
                data=subset,
                x="Size",
                y="Percentage",
                hue="Platform",
                ax=panel,
                marker="o",
            )

            panel.set_title(f"{complexity_label} - {indel_type.capitalize()}s")
            panel.set_xlabel("Indel Size")
            panel.set_ylabel("Percentage")
            panel.legend(title="Platform")

    plt.tight_layout()
    plt.show()


def calculate_indel_summary_stats(indel_sizes):
    """Summarise a ``{size: count}`` histogram of indel sizes.

    Args:
        indel_sizes: Mapping from indel size to the number of indels observed.

    Returns:
        Dict with ``total`` (sum of counts), the count-weighted ``mean`` and
        ``std`` of the sizes, and the ``median`` of the expanded sample. All
        four values are 0 when the total count is 0.
    """
    total_count = sum(indel_sizes.values())
    if total_count == 0:
        return dict(total=0, mean=0, median=0, std=0)

    # Column 0 holds the sizes, column 1 their counts.
    table = np.array(list(indel_sizes.items()))
    sizes, counts = table[:, 0], table[:, 1]

    weighted_mean = np.average(sizes, weights=counts)
    weighted_variance = np.average((sizes - weighted_mean) ** 2, weights=counts)

    # The median is taken over the histogram expanded into individual values.
    expanded_sample = np.repeat(sizes, counts.astype(int))

    return {
        "total": total_count,
        "mean": weighted_mean,
        "median": np.median(expanded_sample),
        "std": np.sqrt(weighted_variance),
    }


def print_indel_summary_stats(indel_ont_sizes, indel_illumina_sizes, indel_complexity):
    """Print size summary statistics per indel type for ONT and Illumina.

    Args:
        indel_ont_sizes: ``{indel_type: {size: count}}`` for ONT.
        indel_illumina_sizes: ``{indel_type: {size: count}}`` for Illumina.
        indel_complexity: Heading printed above the statistics.
    """
    print(f"\n{indel_complexity}")

    for indel_type in ("insertion", "deletion"):
        ont = calculate_indel_summary_stats(indel_ont_sizes[indel_type])
        illumina = calculate_indel_summary_stats(indel_illumina_sizes[indel_type])

        ont_line = f"    ONT     - Total: {ont['total']}, Mean Size: {ont['mean']:.2f}, Std Dev: {ont['std']:.2f}, Median Size: {ont['median']}"
        illumina_line = f"    Illumina - Total: {illumina['total']}, Mean Size: {illumina['mean']:.2f}, Std Dev: {illumina['std']:.2f}, Median Size: {illumina['median']}"

        print(f"  {indel_type.capitalize()}s:")
        print(ont_line)
        print(illumina_line)


def compare_indel_distributions(indel_ont_sizes, indel_illumina_sizes):
    """Compare two indel size histograms with a two-sample KS test.

    Args:
        indel_ont_sizes: ``{size: count}`` histogram for ONT.
        indel_illumina_sizes: ``{size: count}`` histogram for Illumina.

    Returns:
        A tuple ``(ks_statistic, p_value)`` from ``scipy.stats.ks_2samp``.
    """

    def expand(histogram):
        # Repeat each size as many times as it was counted.
        sample = []
        for size, count in histogram.items():
            for _ in range(count):
                sample.append(size)
        return sample

    outcome = stats.ks_2samp(expand(indel_ont_sizes), expand(indel_illumina_sizes))
    return outcome[0], outcome[1]


def perform_indel_statistical_tests(
    indel_hc_ont_sizes,
    indel_hc_illumina_sizes,
    indel_lc_ont_sizes,
    indel_lc_illumina_sizes,
):
    """Run and print KS tests of ONT vs Illumina indel size distributions.

    One test is run per complexity (high, low) and indel type (insertion,
    deletion); the four p-values are then adjusted together with the
    Benjamini-Hochberg procedure and all results are printed.

    Args:
        indel_hc_ont_sizes, indel_hc_illumina_sizes: High-complexity
            ``{indel_type: {size: count}}`` histograms.
        indel_lc_ont_sizes, indel_lc_illumina_sizes: Low-complexity
            histograms of the same shape.
    """
    histogram_pairs = {
        "High Complexity": (indel_hc_ont_sizes, indel_hc_illumina_sizes),
        "Low Complexity": (indel_lc_ont_sizes, indel_lc_illumina_sizes),
    }

    outcomes = []
    for complexity_label, (ont_sizes, illumina_sizes) in histogram_pairs.items():
        for indel_type in ("insertion", "deletion"):
            ks_statistic, p_value = compare_indel_distributions(
                ont_sizes[indel_type], illumina_sizes[indel_type]
            )
            outcomes.append((complexity_label, indel_type, ks_statistic, p_value))

    raw_p_values = [outcome[3] for outcome in outcomes]
    # Element 1 of the multipletests result holds the adjusted p-values.
    adjusted_p_values = multipletests(raw_p_values, method="fdr_bh")[1]

    print("\nStatistical Test Results (Kolmogorov-Smirnov test with FDR correction):")
    for outcome, adjusted_p in zip(outcomes, adjusted_p_values):
        complexity_label, indel_type, ks_statistic, p_value = outcome
        print(f"\n{complexity_label} - {indel_type.capitalize()}s:")
        print(f"    KS statistic: {ks_statistic:.4f}")
        print(f"    Original p-value: {p_value:.4e}")
        print(f"    Corrected p-value: {adjusted_p:.4e}")


# Re-read the sample sheet so this cell can run on its own.
indel_sample_ids = pd.read_csv("sample_ids.csv")
indel_complexities = ["hc", "lc"]

# Size histograms per complexity: (query VCF, truth VCF) for every sample.
indel_hc_ont_sizes, indel_hc_illumina_sizes = analyze_indel_size_distribution(
    "hc", indel_sample_ids
)
indel_lc_ont_sizes, indel_lc_illumina_sizes = analyze_indel_size_distribution(
    "lc", indel_sample_ids
)

# Long-format percentage rows; the label strings must match the ones the
# plotting functions filter on.
indel_hc_plot_data = prepare_indel_plot_data(
    indel_hc_ont_sizes, indel_hc_illumina_sizes, "High Complexity"
)
indel_lc_plot_data = prepare_indel_plot_data(
    indel_lc_ont_sizes, indel_lc_illumina_sizes, "Low Complexity"
)

indel_plot_data = pd.DataFrame(indel_hc_plot_data + indel_lc_plot_data)

plot_indel_size_distribution(indel_plot_data)

print("\nIndel Size Distribution Summary Statistics:")
print_indel_summary_stats(
    indel_hc_ont_sizes, indel_hc_illumina_sizes, "High Complexity"
)
print_indel_summary_stats(indel_lc_ont_sizes, indel_lc_illumina_sizes, "Low Complexity")

# KS tests of ONT vs Illumina size distributions, FDR-adjusted and printed.
perform_indel_statistical_tests(
    indel_hc_ont_sizes,
    indel_hc_illumina_sizes,
    indel_lc_ont_sizes,
    indel_lc_illumina_sizes,
)


### 4. Combined Plots


In [None]:
def create_combined_indel_plot(indel_ont_stats, indel_error_df, indel_plot_data):
    """Assemble the ten-panel (A-J) indel summary figure.

    Layout on a 5x2 grid:
      * row 0 (A, B): mean Precision/Sensitivity/F-measure bars for HC and LC;
      * rows 1-2 (C-F): size distribution lines per indel type and complexity;
      * rows 3-4 (G-J): error-rate bars per error type and indel category.

    Args:
        indel_ont_stats: Table indexed by ``"hc"``/``"lc"`` with two-level
            columns; the ``(metric, "mean")`` cells are plotted.
        indel_error_df: Frame with ``Error Type``, ``Indel Category``,
            ``Indel Length``, ``Error Rate (%)`` and ``Complexity`` columns.
        indel_plot_data: Frame with ``Indel Type``, ``Complexity``, ``Size``,
            ``Percentage`` and ``Platform`` columns.

    Returns:
        The matplotlib figure (not shown by this function).
    """
    fig = plt.figure(figsize=(12, 16), dpi=300)
    gs = fig.add_gridspec(5, 2)

    def label_panel(ax, letter):
        # Bold panel letter just above the top-left corner of the axes.
        ax.text(
            -0.05,
            1.1,
            letter,
            transform=ax.transAxes,
            fontsize=16,
            fontweight="bold",
            va="top",
        )

    def keep_single_legend(ax, keep, title):
        # Give the chosen panel a titled legend; drop any legend elsewhere.
        if keep:
            ax.legend(title=title)
            return
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # A and B: Indel Performance for ONT
    fig.text(0.5, 1.00, "ONT Indel Performance", fontsize=12, ha="center")

    # Named metric_names so the imported sklearn `metrics` module is not shadowed.
    metric_names = ["Precision", "Sensitivity", "F-measure"]

    for i, complexity in enumerate(["hc", "lc"]):
        ax = fig.add_subplot(gs[0, i])

        df = pd.DataFrame(
            [
                {
                    "Complexity": complexity.upper(),
                    "Metric": metric,
                    "Value": indel_ont_stats.loc[complexity, (metric, "mean")],
                }
                for metric in metric_names
            ]
        )
        sns.barplot(x="Metric", y="Value", data=df, ax=ax, legend=False)
        ax.set_ylim(0, 1.05)
        ax.set_title(f"{'High' if complexity == 'hc' else 'Low'} Complexity")
        ax.set_xlabel("")
        ax.set_ylabel("Performance" if i == 0 else "")

        label_panel(ax, "AB"[i])

    # C, D, E, F: Indel Size Distribution
    fig.text(0.5, 0.80, "Indel Size Distribution", fontsize=12, ha="center")

    plot_data = indel_plot_data.copy()
    plot_data["Platform"] = plot_data["Platform"].replace(
        {"ONT": "long-read", "Illumina": "short-read"}
    )

    for i, indel_type in enumerate(["insertion", "deletion"]):
        for j, complexity in enumerate(["High Complexity", "Low Complexity"]):
            ax = fig.add_subplot(gs[i + 1, j])
            show_legend = i == 0 and j == 1  # Only show legend for plot D

            data = plot_data[
                (plot_data["Indel Type"] == indel_type)
                & (plot_data["Complexity"] == complexity)
            ]

            sns.lineplot(
                data=data,
                x="Size",
                y="Percentage",
                hue="Platform",
                ax=ax,
                marker="o",
                legend=show_legend,
            )

            ax.set_title(f"{complexity} - {indel_type.capitalize()}s")
            ax.set_xlabel("Indel Size")
            ax.set_ylabel("Percentage")

            keep_single_legend(ax, show_legend, "Platform")
            label_panel(ax, chr(ord("C") + i * 2 + j))

    # G, H, I, J: Indel Error Analysis
    fig.text(0.5, 0.405, "Indel Error Analysis", fontsize=12, ha="center")

    error_data = indel_error_df.copy()
    error_data["Complexity"] = error_data["Complexity"].replace(
        {"hc": "high complexity", "lc": "low complexity"}
    )

    for i, error_type in enumerate(["FP", "FN"]):
        for j, indel_category in enumerate(["insertion", "deletion"]):
            ax = fig.add_subplot(gs[i + 3, j])
            show_legend = i == 0 and j == 1  # Only show legend for plot H

            data = error_data[
                (error_data["Error Type"] == error_type)
                & (error_data["Indel Category"] == indel_category)
            ]

            sns.barplot(
                data=data,
                x="Indel Length",
                y="Error Rate (%)",
                hue="Complexity",
                hue_order=["high complexity", "low complexity"],
                ax=ax,
                legend=show_legend,
            )

            ax.set_title(f"{error_type} - {indel_category.capitalize()}s")
            ax.set_xlabel(f"{indel_category.capitalize()} Length")
            ax.set_ylabel("Error Rate (%)")

            # One tick per distinct length, labelled with the length itself.
            unique_lengths = sorted(data["Indel Length"].unique())
            ax.set_xticks(range(len(unique_lengths)))
            ax.set_xticklabels(unique_lengths)

            ax.tick_params(axis="x", rotation=90)

            keep_single_legend(ax, show_legend, "Complexity")
            label_panel(ax, chr(ord("G") + i * 2 + j))

    plt.tight_layout()
    return fig


# Build the combined A-J indel figure from the results of the previous cells.
combined_indel_fig = create_combined_indel_plot(
    indel_ont_stats, indel_error_df, indel_plot_data
)

plt.show()


## Impacts of sequencing on SNV and Indel variant calling

### 1. Impact of multiplexing on variant calling


In [None]:
def create_performance_comparison_plot(metrics_df, multiplexing_df):
    """Draw violin plots of performance metrics split by multiplexing status.

    The 2x3 grid has one row per complexity (HC, LC) and one column per metric
    (Precision, Sensitivity, F-measure).

    Args:
        metrics_df: Per-sample metrics with ``sample_id``, ``complexity`` and
            the three metric columns.
        multiplexing_df: Frame with ``sample`` and ``multiplexing`` columns;
            joined to the metrics on ``sample_id == sample``.

    Returns:
        The matplotlib figure.
    """
    metric_names = ("Precision", "Sensitivity", "F-measure")

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(
        "Performance Metrics: Multiplexed vs Singleplexed Samples",
        fontsize=16,
    )

    multiplexing_lookup = multiplexing_df[["sample", "multiplexing"]]

    for axes_row, complexity in zip(axes, ("hc", "lc")):
        of_complexity = metrics_df[metrics_df["complexity"] == complexity]
        merged_df = of_complexity.merge(
            multiplexing_lookup, left_on="sample_id", right_on="sample"
        )

        for ax, metric in zip(axes_row, metric_names):
            sns.violinplot(x="multiplexing", y=metric, data=merged_df, ax=ax)
            ax.set_title(f"{metric} ({complexity.upper()})")
            ax.set_xlabel("Multiplexing")
            ax.set_ylabel(metric)

    plt.tight_layout()
    return fig


# SNV metrics vs multiplexing. Both inputs are defined outside this section
# of the notebook and must already exist in the session.
performance_comparison_fig = create_performance_comparison_plot(
    snv_ont_metrics_df, np_metrics_df
)


In [None]:
def create_indel_performance_comparison_plot(indel_metrics_df, multiplexing_df):
    """Draw violin plots of indel metrics split by multiplexing status.

    The 2x3 grid has one row per complexity (HC, LC) and one column per metric
    (Precision, Sensitivity, F-measure).

    Args:
        indel_metrics_df: Per-sample indel metrics with ``sample_id``,
            ``complexity`` and the three metric columns.
        multiplexing_df: Frame with ``sample`` and ``multiplexing`` columns;
            joined to the metrics on ``sample_id == sample``.

    Returns:
        The matplotlib figure.
    """
    metric_names = ("Precision", "Sensitivity", "F-measure")

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(
        "Indel Performance Metrics: Multiplexed vs Singleplexed Samples",
        fontsize=16,
    )

    multiplexing_lookup = multiplexing_df[["sample", "multiplexing"]]

    for axes_row, complexity in zip(axes, ("hc", "lc")):
        of_complexity = indel_metrics_df[indel_metrics_df["complexity"] == complexity]
        merged_df = of_complexity.merge(
            multiplexing_lookup, left_on="sample_id", right_on="sample"
        )

        for ax, metric in zip(axes_row, metric_names):
            sns.violinplot(x="multiplexing", y=metric, data=merged_df, ax=ax)
            ax.set_title(f"{metric} ({complexity.upper()})")
            ax.set_xlabel("Multiplexing")
            ax.set_ylabel(metric)

    plt.tight_layout()
    return fig


# Indel metrics vs multiplexing; np_metrics_df is defined outside this section.
indel_performance_comparison_fig = create_indel_performance_comparison_plot(
    indel_ont_metrics_df, np_metrics_df
)


In [None]:
def create_combined_performance_comparison_plot(
    snv_metrics_df, indel_metrics_df, multiplexing_df
):
    """Draw the 4x3 grid of SNV and indel metrics split by multiplexing.

    Rows are SNV HC, SNV LC, Indel HC and Indel LC (lettered A-D); columns are
    Precision, Sensitivity and F-measure. The three panels of a row share one
    y-range derived from that row's data.

    Args:
        snv_metrics_df: Per-sample SNV metrics with ``sample_id``,
            ``complexity`` and the three metric columns.
        indel_metrics_df: Per-sample indel metrics of the same shape.
        multiplexing_df: Frame with ``sample`` and ``multiplexing`` columns;
            joined to the metrics on ``sample_id == sample``.

    Returns:
        The matplotlib figure, after it has been shown with ``plt.show()``.
    """
    fig, axes = plt.subplots(4, 3, figsize=(14, 16), dpi=300)
    fig.suptitle("Performance Metrics vs Multiplexing", fontsize=16, y=0.99)

    metric_names = ["Precision", "Sensitivity", "F-measure"]
    sources = {"SNV": snv_metrics_df, "Indel": indel_metrics_df}
    grid_rows = [
        (data_type, metrics_df, complexity)
        for data_type, metrics_df in sources.items()
        for complexity in ("hc", "lc")
    ]

    for row, (data_type, metrics_df, complexity) in enumerate(grid_rows):
        merged_df = metrics_df[metrics_df["complexity"] == complexity].merge(
            multiplexing_df[["sample", "multiplexing"]],
            left_on="sample_id",
            right_on="sample",
        )

        # Shared y-range for the row: the data range, padded by 40% of the
        # range below and 20% above.
        lowest = min(merged_df[metric].min() for metric in metric_names)
        highest = max(merged_df[metric].max() for metric in metric_names)
        span = highest - lowest
        y_min = lowest - 0.4 * span
        y_max = highest + 0.2 * span

        # Vertical row title in figure coordinates, left of the axes.
        fig.text(
            -0.01,
            0.853 - row * 0.242,
            f"{data_type} {complexity.upper()}",
            va="center",
            ha="left",
            fontsize=12,
            rotation=90,
        )

        first_panel = axes[row, 0]
        first_panel.text(
            -0.15,
            1.1,
            chr(65 + row),
            transform=first_panel.transAxes,
            fontsize=16,
            fontweight="bold",
            va="top",
        )

        for col, metric in enumerate(metric_names):
            ax = axes[row, col]
            sns.violinplot(x="multiplexing", y=metric, data=merged_df, ax=ax)
            ax.set_title(metric)
            ax.set_xlabel("")
            ax.set_ylabel("Performance" if col == 0 else "")
            ax.set_ylim(y_min, y_max)

    plt.tight_layout()
    plt.show()
    return fig


# Combined SNV + indel figure; the function also calls plt.show() itself.
combined_performance_comparison_fig = create_combined_performance_comparison_plot(
    snv_ont_metrics_df, indel_ont_metrics_df, np_metrics_df
)


In [None]:
def create_variant_multiplexing_stats(snv_df, indel_df, multiplexing_df):
    """Tabulate mean/std/median of each metric per multiplexing group.

    Statistics are computed separately for every variant type (SNV, Indel) and
    complexity (hc, lc) and stacked into one frame.

    Args:
        snv_df: Per-sample SNV metrics with ``sample_id``, ``complexity``,
            ``Precision``, ``Sensitivity`` and ``F-measure`` columns.
        indel_df: Per-sample indel metrics of the same shape.
        multiplexing_df: Frame with ``sample`` and ``multiplexing`` columns;
            joined to the metrics on ``sample_id == sample``.

    Returns:
        DataFrame whose first columns are ``variant_type``, ``complexity`` and
        ``multiplexing``, followed by ``<metric>_<stat>`` columns.
    """
    aggregations = {
        metric: ["mean", "std", "median"]
        for metric in ("Precision", "Sensitivity", "F-measure")
    }
    key_columns = ["variant_type", "complexity", "multiplexing"]

    def summarise(df, variant_type, complexity):
        # Statistics of one variant type / complexity, one row per group.
        merged_df = df[df["complexity"] == complexity].merge(
            multiplexing_df[["sample", "multiplexing"]],
            left_on="sample_id",
            right_on="sample",
        )
        summary = merged_df.groupby("multiplexing", observed=True).agg(aggregations)

        # Flatten the (metric, statistic) column pairs to "metric_statistic".
        summary.columns = ["_".join(pair) for pair in summary.columns]
        summary["variant_type"] = variant_type
        summary["complexity"] = complexity
        return summary.reset_index()

    frames = [
        summarise(df, variant_type, complexity)
        for variant_type, df in (("SNV", snv_df), ("Indel", indel_df))
        for complexity in ("hc", "lc")
    ]
    combined_stats = pd.concat(frames, ignore_index=True)

    value_columns = [
        column for column in combined_stats.columns if column not in key_columns
    ]
    return combined_stats[key_columns + value_columns]


def format_number(num):
    """Return ``num`` as a string in fixed-point notation with four decimals."""
    return "{:.4f}".format(num)


def calculate_percentage_increase(singleplex_val, multiplex_val):
    """Return how much ``singleplex_val`` exceeds ``multiplex_val``, in percent.

    The difference is expressed relative to ``multiplex_val``; the result is
    negative when the singleplex value is the smaller one.
    """
    difference = singleplex_val - multiplex_val
    relative_change = difference / multiplex_val
    return relative_change * 100


def summarize_variant_stats(stats_df):
    """Print singleplex vs multiplex statistics per variant type and complexity.

    Args:
        stats_df: Frame with ``variant_type``, ``complexity`` and
            ``multiplexing`` columns plus ``<metric>_<stat>`` columns for
            Precision, Sensitivity and F-measure (mean, std, median). Every
            variant type / complexity combination needs a ``"singleplex"`` and
            a ``"multiplex"`` row; the first matching row of each is used.
    """
    metric_names = ("Precision", "Sensitivity", "F-measure")

    for variant_type in ("SNV", "Indel"):
        for complexity in ("hc", "lc"):
            print(f"\n{variant_type} Statistics ({complexity.upper()}):")
            print("=" * 40)

            in_group = (stats_df["variant_type"] == variant_type) & (
                stats_df["complexity"] == complexity
            )
            group = stats_df[in_group]
            singleplex = group[group["multiplexing"] == "singleplex"].iloc[0]
            multiplex = group[group["multiplexing"] == "multiplex"].iloc[0]

            for metric in metric_names:
                print(f"\n{metric}:")

                for stat in ("mean", "std", "median"):
                    column = f"{metric}_{stat}"
                    print(
                        f"  {stat.capitalize():6s}: Singleplex: {format_number(singleplex[column])}, "
                        f"Multiplex: {format_number(multiplex[column])}"
                    )

                mean_column = f"{metric}_mean"
                increase = calculate_percentage_increase(
                    singleplex[mean_column], multiplex[mean_column]
                )
                print(
                    f"  Mean Percentage Increase (Singleplex vs Multiplex): {increase:6.2f}%"
                )


# Per-group statistics for SNVs and indels, then a printed comparison.
# snv_ont_metrics_df and np_metrics_df are defined outside this section.
variant_multiplexing_stats = create_variant_multiplexing_stats(
    snv_ont_metrics_df, indel_ont_metrics_df, np_metrics_df
)

summarize_variant_stats(variant_multiplexing_stats)


### 2. Impact of sequencing depth on variant calling


In [None]:
def plot_depth_vs_performance(
    wg_depth_df, snv_metrics_df, indel_metrics_df, np_metrics_df
):
    """Plot precision, sensitivity and F-measure against mean sequencing depth.

    One three-panel figure is shown for every combination of variant type
    (SNV, Indel) and complexity ("hc", "lc"). Each panel contains a scatter
    plot coloured by multiplexing, a fitted saturating curve
    ``a - b * exp(-c * depth)`` and a shaded band around it. The Pearson
    correlation of each metric with depth is printed after the figure.

    Args:
        wg_depth_df: Depth table with "chrom", "sample" and "mean" columns.
        snv_metrics_df: SNV metrics with "sample_id", "complexity",
            "Precision", "Sensitivity" and "F-measure" columns.
        indel_metrics_df: Indel metrics with the same columns.
        np_metrics_df: Per-sample table with "sample" and "multiplexing".
    """
    # NOTE(review): the value labelled "whole genome" depth is read from the
    # rows where chrom == "chr1" -- confirm that row holds the genome-wide mean.
    depth_data = wg_depth_df[wg_depth_df["chrom"] == "chr1"][["sample", "mean"]]
    depth_data = depth_data.rename(columns={"mean": "wg_mean_depth"})

    # Named metric_names so the module-level `metrics` import is not shadowed.
    metric_names = ["Precision", "Sensitivity", "F-measure"]

    def asymptotic_func(x, a, b, c):
        """Saturating exponential: tends to ``a`` as ``x`` grows."""
        return a - b * np.exp(-c * x)

    def band_half_width(x, y, x_range, popt):
        """Return the half-width of the shaded band evaluated on ``x_range``.

        NOTE(review): this is the mean-response interval formula of simple
        linear regression applied to the residuals of the exponential fit;
        it does not use the parameter covariance returned by curve_fit.
        """
        n = len(x)
        dof = n - len(popt)
        if dof <= 0:
            # Not enough points to estimate the residual variance. Return NaN
            # explicitly (nothing is shaded) instead of dividing by zero.
            return np.full(len(x_range), np.nan)

        residuals = y - asymptotic_func(x, *popt)
        y_err = np.sqrt(np.sum(residuals**2) / dof)
        t = stats.t.ppf(0.975, dof)
        return (
            t
            * y_err
            * np.sqrt(
                1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
            )
        )

    def plot_correlation(data, variant_type, complexity):
        """Draw one figure (three panels) and print its Pearson correlations."""
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=300)
        fig.suptitle(
            f"{variant_type} Performance Metrics vs Whole Genome Mean Depth ({complexity.upper()})",
            fontsize=16,
        )

        palette = sns.color_palette("colorblind")
        main_color = palette[0]

        correlations = {}

        for ax, metric in zip(axes, metric_names):
            sns.scatterplot(
                data=data, x="wg_mean_depth", y=metric, hue="multiplexing", ax=ax, s=100
            )

            x = data["wg_mean_depth"]
            y = data[metric]

            # Only the fitted parameters are needed; the covariance is unused.
            popt, _ = curve_fit(
                asymptotic_func, x, y, p0=[1, 0.1, 0.1], bounds=([0, 0, 0], [2, 1, 1])
            )

            x_range = np.linspace(x.min(), x.max(), 100)
            y_fit = asymptotic_func(x_range, *popt)
            ax.plot(
                x_range,
                y_fit,
                color=main_color,
                linestyle="-",
                linewidth=2,
                label="Line of best fit",
            )

            ci = band_half_width(x, y, x_range, popt)
            ax.fill_between(
                x_range,
                y_fit - ci,
                y_fit + ci,
                color=main_color,
                alpha=0.2,
                label="95% Confidence Interval",
            )

            r_value, p_value = stats.pearsonr(x, y)
            correlations[metric] = (r_value, p_value)

            ax.set_title(metric)
            ax.set_xlabel("Whole Genome Mean Depth")
            ax.set_ylabel(metric)
            ax.legend(title="")

        plt.tight_layout()
        plt.show()

        print(f"\nPearson Correlations for {variant_type} ({complexity.upper()}):")
        for metric, (r_value, p_value) in correlations.items():
            print(f"{metric}:")
            print(f"  Correlation coefficient: {r_value:.4f}")
            print(f"  p-value: {p_value:.4e}")
            print()

    for variant_type, metrics_df in [
        ("SNV", snv_metrics_df),
        ("Indel", indel_metrics_df),
    ]:
        for complexity in ["hc", "lc"]:
            # Attach the multiplexing label, then the depth, to each metrics row.
            data = pd.merge(
                metrics_df[metrics_df["complexity"] == complexity],
                np_metrics_df[["sample", "multiplexing"]],
                left_on="sample_id",
                right_on="sample",
            )
            data = pd.merge(data, depth_data, on="sample")
            plot_correlation(data, variant_type, complexity)

# Show the depth-versus-performance figures for the ONT SNV and indel metrics.
plot_depth_vs_performance(
    wg_depth_df, snv_ont_metrics_df, indel_ont_metrics_df, np_metrics_df
)


In [None]:
def create_combined_depth_vs_performance_plot(
    wg_depth_df, snv_metrics_df, indel_metrics_df, np_metrics_df
):
    """Draw a 4x3 grid of performance metrics against mean sequencing depth.

    Rows are SNV HC, SNV LC, Indel HC and Indel LC (lettered A-D); columns are
    Precision, Sensitivity and F-measure. Each panel holds a scatter plot
    coloured by multiplexing, a fitted curve ``a - b * exp(-c * depth)`` with a
    shaded band, and the Pearson r and p-value in its title. All panels in a
    row share the same y-limits.

    Args:
        wg_depth_df: Depth table with "chrom", "sample" and "mean" columns.
        snv_metrics_df: SNV metrics with "sample_id", "complexity",
            "Precision", "Sensitivity" and "F-measure" columns.
        indel_metrics_df: Indel metrics with the same columns.
        np_metrics_df: Per-sample table with "sample" and "multiplexing".

    Returns:
        The matplotlib figure that was drawn and shown.
    """
    # NOTE(review): the value labelled "whole genome" depth is read from the
    # rows where chrom == "chr1" -- confirm that row holds the genome-wide mean.
    depth_data = wg_depth_df[wg_depth_df["chrom"] == "chr1"][["sample", "mean"]]
    depth_data = depth_data.rename(columns={"mean": "wg_mean_depth"})

    def asymptotic_func(x, a, b, c):
        """Saturating exponential: tends to ``a`` as ``x`` grows."""
        return a - b * np.exp(-c * x)

    def band_half_width(x, y, x_range, popt):
        """Return the half-width of the shaded band evaluated on ``x_range``.

        NOTE(review): this is the mean-response interval formula of simple
        linear regression applied to the residuals of the exponential fit;
        it does not use the parameter covariance returned by curve_fit.
        """
        n = len(x)
        dof = n - len(popt)
        if dof <= 0:
            # Not enough points to estimate the residual variance. Return NaN
            # explicitly (nothing is shaded) instead of dividing by zero.
            return np.full(len(x_range), np.nan)

        residuals = y - asymptotic_func(x, *popt)
        y_err = np.sqrt(np.sum(residuals**2) / dof)
        t = stats.t.ppf(0.975, dof)
        return (
            t
            * y_err
            * np.sqrt(
                1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
            )
        )

    fig, axes = plt.subplots(4, 3, figsize=(18, 20), dpi=300)
    fig.suptitle("Performance Metrics vs Whole Genome Mean Depth", fontsize=16, y=0.99)

    # Named metric_names so the module-level `metrics` import is not shadowed.
    metric_names = ["Precision", "Sensitivity", "F-measure"]
    complexities = ["hc", "lc"]
    data_types = ["SNV", "Indel"]

    palette = sns.color_palette("colorblind")
    main_color = palette[0]

    for data_type_idx, (data_type, metrics_df) in enumerate(
        zip(data_types, [snv_metrics_df, indel_metrics_df])
    ):
        for complexity_idx, complexity in enumerate(complexities):
            row = data_type_idx * 2 + complexity_idx

            # Panel letter (A, B, C, D) above the first axes of the row.
            axes[row, 0].text(
                -0.15,
                1.1,
                chr(65 + row),
                transform=axes[row, 0].transAxes,
                fontsize=16,
                fontweight="bold",
                va="top",
            )

            # Attach the multiplexing label, then the depth, to each metrics row.
            merged_df = pd.merge(
                metrics_df[metrics_df["complexity"] == complexity],
                np_metrics_df[["sample", "multiplexing"]],
                left_on="sample_id",
                right_on="sample",
            )
            merged_df = pd.merge(merged_df, depth_data, on="sample")

            # Shared y-limits for the row, padded more below than above.
            y_min = min(merged_df[metric].min() for metric in metric_names)
            y_max = max(merged_df[metric].max() for metric in metric_names)

            y_range = y_max - y_min
            y_min -= 0.4 * y_range
            y_max += 0.2 * y_range

            # Rotated row title placed in figure coordinates at the left edge.
            row_title = f"{data_type} {complexity.upper()}"
            fig.text(
                -0.01,
                0.855 - row * 0.244,
                row_title,
                va="center",
                ha="left",
                fontsize=12,
                rotation=90,
            )

            for col, metric in enumerate(metric_names):
                ax = axes[row, col]
                sns.scatterplot(
                    data=merged_df,
                    x="wg_mean_depth",
                    y=metric,
                    hue="multiplexing",
                    ax=ax,
                    legend=row == 0 and col == 0,
                    s=100,
                )

                x = merged_df["wg_mean_depth"]
                y = merged_df[metric]

                # Only the fitted parameters are needed; the covariance is unused.
                popt, _ = curve_fit(
                    asymptotic_func,
                    x,
                    y,
                    p0=[1, 0.1, 0.1],
                    bounds=([0, 0, 0], [2, 1, 1]),
                )

                x_range = np.linspace(x.min(), x.max(), 100)
                y_fit = asymptotic_func(x_range, *popt)
                ax.plot(
                    x_range,
                    y_fit,
                    color=main_color,
                    linestyle="-",
                    linewidth=2,
                    label="Line of best fit",
                )

                ci = band_half_width(x, y, x_range, popt)
                ax.fill_between(
                    x_range,
                    y_fit - ci,
                    y_fit + ci,
                    color=main_color,
                    alpha=0.2,
                    label="95% Confidence Interval",
                )

                r_value, p_value = stats.pearsonr(x, y)

                ax.set_title(f"{metric}\nr={r_value:.2f}, p={p_value:.2e}")
                ax.set_xlabel("")

                if col == 0:
                    ax.set_ylabel("Performance")
                else:
                    ax.set_ylabel("")
                    ax.set_yticklabels([])

                ax.set_ylim(y_min, y_max)

                if row == 0 and col == 0:
                    # Single legend for the whole grid: the first two handles
                    # plus the last two (fitted line and band).
                    handles, labels = ax.get_legend_handles_labels()
                    ax.legend(
                        handles=handles[:2] + handles[-2:],
                        labels=labels[:2] + labels[-2:],
                        bbox_to_anchor=(1, 0.3),
                    )

    plt.tight_layout()
    plt.show()
    return fig


# Draw the combined 4x3 depth figure and keep the figure object for later use.
combined_depth_performance_fig = create_combined_depth_vs_performance_plot(
    wg_depth_df, snv_ont_metrics_df, indel_ont_metrics_df, np_metrics_df
)


### 3. Impact of read length on variant calling


In [None]:
def plot_read_length_vs_performance(
    read_length_df, snv_metrics_df, indel_metrics_df, np_metrics_df
):
    """Plot precision, sensitivity and F-measure against mean read length.

    One three-panel figure is shown for every combination of variant type
    (SNV, Indel) and complexity ("hc", "lc"). Each panel contains a scatter
    plot coloured by multiplexing, an ordinary least-squares line and its
    95% confidence band. Pearson correlations are printed after each figure.
    Combinations whose merged table is empty are reported and skipped.

    Args:
        read_length_df: Table with "read_length_type", "anonymised_sample"
            and "read_length" columns; only "Mean Read Length" rows are used.
        snv_metrics_df: SNV metrics with "sample_id", "complexity",
            "Precision", "Sensitivity" and "F-measure" columns.
        indel_metrics_df: Indel metrics with the same columns.
        np_metrics_df: Per-sample table with "sample", "anonymised_sample"
            and "multiplexing" columns.
    """
    mean_read_length_data = read_length_df[
        read_length_df["read_length_type"] == "Mean Read Length"
    ]
    mean_read_length_data = mean_read_length_data[["anonymised_sample", "read_length"]]
    mean_read_length_data = mean_read_length_data.rename(
        columns={"read_length": "mean_read_length"}
    )

    # Named metric_names so the module-level `metrics` import is not shadowed.
    metric_names = ["Precision", "Sensitivity", "F-measure"]

    def band_half_width(x, y, x_range, slope, intercept):
        """Return the 95% confidence half-width of the fitted line on ``x_range``."""
        n = len(x)
        dof = n - 2
        if dof <= 0:
            # Two points fit a line exactly: no residual variance to estimate,
            # so return NaN explicitly (nothing is shaded).
            return np.full(len(x_range), np.nan)

        y_err = np.sqrt(np.sum((y - (slope * x + intercept)) ** 2) / dof)
        return (
            stats.t.ppf(0.975, dof)
            * y_err
            * np.sqrt(
                1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
            )
        )

    def plot_correlation(data, variant_type, complexity):
        """Draw one figure (three panels) and print its Pearson correlations."""
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=300)
        fig.suptitle(
            f"{variant_type} Performance Metrics vs Mean Read Length ({complexity.upper()})",
            fontsize=16,
        )

        palette = sns.color_palette("colorblind")
        main_color = palette[0]
        correlations = {}

        for ax, metric in zip(axes, metric_names):
            sns.scatterplot(
                data=data,
                x="mean_read_length",
                y=metric,
                hue="multiplexing",
                ax=ax,
                s=100,
            )

            x = data["mean_read_length"]
            y = data[metric]

            # Only slope and intercept are taken from the regression; r and p
            # come from pearsonr below.
            fit = stats.linregress(x, y)
            slope, intercept = fit.slope, fit.intercept
            x_range = np.linspace(x.min(), x.max(), 100)
            y_fit = slope * x_range + intercept

            ax.plot(
                x_range,
                y_fit,
                color=main_color,
                linestyle="-",
                linewidth=2,
                label="Line of best fit",
            )

            ci = band_half_width(x, y, x_range, slope, intercept)
            ax.fill_between(
                x_range,
                y_fit - ci,
                y_fit + ci,
                color=main_color,
                alpha=0.2,
                label="95% Confidence Interval",
            )

            r_value, p_value = stats.pearsonr(x, y)
            correlations[metric] = (r_value, p_value)

            ax.set_title(metric)
            ax.set_xlabel("Mean Read Length")
            ax.set_ylabel(metric)
            ax.legend(title="")

        plt.tight_layout()
        plt.show()

        print(f"\nPearson Correlations for {variant_type} ({complexity.upper()}):")
        for metric, (r_value, p_value) in correlations.items():
            print(f"{metric}:")
            print(f"  Correlation coefficient: {r_value:.4f}")
            print(f"  p-value: {p_value:.4e}")
            print()

    for variant_type, metrics_df in [
        ("SNV", snv_metrics_df),
        ("Indel", indel_metrics_df),
    ]:
        for complexity in ["hc", "lc"]:
            # The caller's metrics frame is only read here, never written to.
            filtered_metrics = metrics_df[metrics_df["complexity"] == complexity]

            data = pd.merge(
                filtered_metrics,
                np_metrics_df[["sample", "anonymised_sample", "multiplexing"]],
                left_on="sample_id",
                right_on="sample",
            )

            # Read lengths are keyed by the anonymised sample label.
            data = pd.merge(data, mean_read_length_data, on="anonymised_sample")

            if data.empty:
                print(f"No data after merging for {variant_type} ({complexity})")
                continue

            plot_correlation(data, variant_type, complexity)


# Show the read-length-versus-performance figures for the ONT SNV and indel metrics.
plot_read_length_vs_performance(
    read_length_df, snv_ont_metrics_df, indel_ont_metrics_df, np_metrics_df
)


In [None]:
def plot_combined_read_length_vs_performance(
    read_length_df, snv_metrics_df, indel_metrics_df, np_metrics_df
):
    """Draw a 4x3 grid of performance metrics against mean read length.

    Rows are SNV HC, SNV LC, Indel HC and Indel LC (lettered A-D); columns are
    Precision, Sensitivity and F-measure. Each panel holds a scatter plot
    coloured by multiplexing, a least-squares line with a shaded band, and the
    Pearson r and p-value in its title.

    Args:
        read_length_df: Table with "read_length_type", "sample" and
            "read_length" columns; only "Mean Read Length" rows are used.
        snv_metrics_df: SNV metrics with "sample_id", "complexity",
            "Precision", "Sensitivity" and "F-measure" columns.
        indel_metrics_df: Indel metrics with the same columns.
        np_metrics_df: Per-sample table with "sample" and "multiplexing".

    Returns:
        The matplotlib figure that was drawn and shown.
    """
    mean_read_length_data = read_length_df[
        read_length_df["read_length_type"] == "Mean Read Length"
    ]
    # NOTE(review): this selects and later merges on "sample", whereas
    # plot_read_length_vs_performance in this file keys the same table on
    # "anonymised_sample" -- confirm read_length_df really has both columns.
    mean_read_length_data = mean_read_length_data[["sample", "read_length"]]
    mean_read_length_data = mean_read_length_data.rename(
        columns={"read_length": "mean_read_length"}
    )

    fig, axes = plt.subplots(4, 3, figsize=(14, 16), dpi=300)
    fig.suptitle("Performance Metrics vs Mean Read Length", fontsize=16, y=0.99)

    metrics = ["Precision", "Sensitivity", "F-measure"]
    complexities = ["hc", "lc"]
    data_types = ["SNV", "Indel"]

    def plot_correlation(data, variant_type, complexity, row):
        # Fill one grid row (three panels) for a variant type / complexity pair.
        # Shared y-limits for the row, padded more below than above.
        y_min = min(data[metric].min() for metric in metrics)
        y_max = max(data[metric].max() for metric in metrics)

        y_range = y_max - y_min
        y_min -= 0.4 * y_range
        y_max += 0.2 * y_range

        # Rotated row title placed in figure coordinates at the left edge.
        row_title = f"{variant_type} {complexity.upper()}"
        fig.text(
            -0.01,
            0.85 - row * 0.242,
            row_title,
            va="center",
            ha="left",
            fontsize=12,
            rotation=90,
        )

        # Add letter label to each row
        axes[row, 0].text(
            -0.15,
            1.1,
            chr(65 + row),
            transform=axes[row, 0].transAxes,
            fontsize=16,
            fontweight="bold",
            va="top",
        )

        palette = sns.color_palette("colorblind")
        main_color = palette[0]

        for col, metric in enumerate(metrics):
            ax = axes[row, col]
            # Only the top-left panel carries a legend.
            sns.scatterplot(
                data=data,
                x="mean_read_length",
                y=metric,
                hue="multiplexing",
                ax=ax,
                legend=row == 0 and col == 0,
            )

            x = data["mean_read_length"]
            y = data[metric]

            # r_value and p_value from linregress are replaced by the
            # pearsonr results further down; std_err is not used.
            slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
            x_range = np.linspace(x.min(), x.max(), 100)
            y_fit = slope * x_range + intercept

            ax.plot(
                x_range,
                y_fit,
                color=main_color,
                linestyle="-",
                linewidth=2,
                label="Line of best fit" if row == 0 and col == 0 else None,
            )

            # Confidence band of the fitted line from the residual standard
            # error with n - 2 degrees of freedom.
            # NOTE(review): n == 2 divides by zero here and yields a NaN band.
            n = len(x)
            y_err = np.sqrt(np.sum((y - (slope * x + intercept)) ** 2) / (n - 2))
            ci = (
                stats.t.ppf(0.975, n - 2)
                * y_err
                * np.sqrt(
                    1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
                )
            )
            ax.fill_between(
                x_range,
                y_fit - ci,
                y_fit + ci,
                color=main_color,
                alpha=0.2,
                label="95% Confidence Interval" if row == 0 and col == 0 else None,
            )

            r_value, p_value = stats.pearsonr(x, y)

            ax.set_title(f"{metric}\nr={r_value:.2f}, p={p_value:.2e}")
            ax.set_xlabel("")

            # Only the first column keeps its y-axis label and tick labels.
            if col == 0:
                ax.set_ylabel("Performance")
            else:
                ax.set_ylabel("")
                ax.set_yticklabels([])

            ax.set_ylim(y_min, y_max)

            if row == 0 and col == 0:
                # Rebuild the legend without a title.
                handles, labels = ax.get_legend_handles_labels()
                ax.legend(handles=handles, labels=labels, title="")

    for variant_type_idx, (variant_type, metrics_df) in enumerate(
        zip(data_types, [snv_metrics_df, indel_metrics_df])
    ):
        for complexity_idx, complexity in enumerate(complexities):
            # Row order: SNV hc, SNV lc, Indel hc, Indel lc.
            row = variant_type_idx * 2 + complexity_idx
            data = pd.merge(
                metrics_df[metrics_df["complexity"] == complexity],
                np_metrics_df[["sample", "multiplexing"]],
                left_on="sample_id",
                right_on="sample",
            )
            data = pd.merge(data, mean_read_length_data, on="sample")
            plot_correlation(data, variant_type, complexity, row)

    plt.tight_layout()
    plt.show()
    return fig


# Draw the combined 4x3 read-length figure and keep the figure object for later use.
combined_read_length_performance_fig = plot_combined_read_length_vs_performance(
    read_length_df, snv_ont_metrics_df, indel_ont_metrics_df, np_metrics_df
)


# SV Benchmark

## 1. SV Consensus Calls


In [None]:
def read_sv_vcf_file(file_path):
    """Parse a VCF into a DataFrame with one row per ALT allele.

    Every ALT allele of every record is passed to ``extract_sv_info``; alleles
    for which it returns a falsy value are dropped. The zero-based ALT index is
    stored in the "allele_idx" column.

    Args:
        file_path: Path of the VCF file to open with pysam.

    Returns:
        DataFrame built from the per-allele dictionaries. It has no columns
        when the file yields no alleles.
    """
    svs = []
    with pysam.VariantFile(file_path) as vcf:
        for record in vcf:
            # pysam reports `alts` as None for records without an ALT allele;
            # treat those as having nothing to extract instead of failing.
            for alt_idx, alt in enumerate(record.alts or ()):
                sv_info = extract_sv_info(record, alt, alt_idx)
                if sv_info:
                    sv_info["allele_idx"] = alt_idx
                    svs.append(sv_info)
    return pd.DataFrame(svs)


def extract_sv_info(record, alt, alt_idx):
    """Describe one ALT allele of a VCF record as a dictionary.

    The allele is routed to the STR, symbolic-allele or standard handler
    depending on the record's SVTYPE and the form of ``alt``.

    Args:
        record: pysam variant record.
        alt: The ALT allele string being described.
        alt_idx: Zero-based index of ``alt`` within the record's ALT alleles.

    Returns:
        Whatever the selected handler returns: a dict with "type", "length",
        "chrom", "start" and "end", or None when the STR handler cannot
        determine a repeat count.
    """
    chrom = record.chrom
    start = record.pos
    # pysam reserves the END INFO key and exposes the end coordinate as
    # `record.stop`; a lookup through `record.info` does not return the VCF
    # END value, which left `end` equal to `start` for every record.
    end = record.stop
    sv_type = record.info.get("SVTYPE", "Unknown")

    alt_is_text = isinstance(alt, str)

    if sv_type == "STR" or (alt_is_text and alt.startswith("<STR")):
        return handle_str(record, alt, chrom, start, end, alt_idx)
    if alt_is_text and alt.startswith("<") and alt.endswith(">"):
        return handle_symbolic_allele(record, alt, chrom, start, end, sv_type, alt_idx)
    return handle_standard_sv(record, alt, chrom, start, end, sv_type, alt_idx)


def handle_str(record, alt, chrom, start, end, alt_idx):
    """Describe a short tandem repeat allele.

    The repeat count comes from the first sample's REPCN field when present
    (indexed by ``alt_idx`` if it is a tuple); otherwise it is parsed from a
    ``<STRn>`` ALT allele. The reported length is the repeat count multiplied
    by the length of the RU INFO value (an empty string when RU is absent).

    Returns:
        A dict with "type" ("STR"), "length", "chrom", "start" and "end", or
        None when there is no REPCN and ``alt`` is not a ``<STR...>`` allele.

    Raises:
        ValueError: If a REPCN value is not an int, float, str or tuple.
    """

    def parse_int_or_first(value):
        # Strings such as "12/15" and tuples contribute their first entry.
        if isinstance(value, str):
            return int(value.split("/")[0])
        if isinstance(value, tuple):
            return int(value[0])
        if isinstance(value, (int, float)):
            return int(value)
        raise ValueError(f"Unexpected value type: {type(value)}")

    repcn = record.samples[0].get("REPCN")

    if repcn is None:
        if not alt.startswith("<STR"):
            return None
        # "<STR12>" -> 12: strip the "<STR" prefix and the closing ">".
        repeat_count = int(record.alts[alt_idx][4:-1])
    elif isinstance(repcn, tuple):
        repeat_count = parse_int_or_first(repcn[alt_idx])
    else:
        repeat_count = parse_int_or_first(repcn)

    repeat_unit = record.info.get("RU", "")

    return {
        "type": "STR",
        "length": repeat_count * len(repeat_unit),
        "chrom": chrom,
        "start": start,
        "end": end,
    }


def handle_symbolic_allele(record, alt, chrom, start, end, sv_type, alt_idx):
    """Describe a symbolic ALT allele such as ``<DEL>`` or ``<INS>``.

    The length is taken from the first source that provides one:
    SVINSLEN (inversions only), SVLEN, the combined length of
    LEFT_SVINSSEQ and RIGHT_SVINSSEQ (insertions only), and finally
    ``end - start``. The absolute value is reported.

    Returns:
        A dict with "type", "length", "chrom", "start" and "end".
    """

    def for_this_allele(value):
        # Per-allele tuples are indexed by alt_idx, falling back to the first
        # entry when the tuple is too short; scalars pass through unchanged.
        if not isinstance(value, tuple):
            return value
        return value[alt_idx] if len(value) > alt_idx else value[0]

    info = record.info
    length = None

    if sv_type == "INV" and "SVINSLEN" in info:
        length = for_this_allele(info.get("SVINSLEN"))

    if length is None:
        length = for_this_allele(info.get("SVLEN"))

    if length is None and sv_type == "INS":
        left_flank = info.get("LEFT_SVINSSEQ", "")
        right_flank = info.get("RIGHT_SVINSSEQ", "")
        length = len(left_flank) + len(right_flank)

    if length is None:
        length = end - start

    return {
        "type": sv_type,
        "length": None if length is None else abs(length),
        "chrom": chrom,
        "start": start,
        "end": end,
    }


def handle_standard_sv(record, alt, chrom, start, end, sv_type, alt_idx):
    """Describe an ALT allele that is neither an STR nor a symbolic allele.

    The length is taken from the first source that provides one:

    1. SVLEN (indexed by ``alt_idx`` when it is a tuple, falling back to the
       first entry when the tuple is too short);
    2. for insertions that carry LEFT_SVINSSEQ and/or RIGHT_SVINSSEQ, the
       combined length of those sequences;
    3. otherwise the difference between the ALT and REF allele lengths.

    The absolute value is reported. A sequence-resolved insertion without
    SVLEN or flank sequences therefore gets the number of inserted bases
    rather than 0.

    Returns:
        A dict with "type", "length", "chrom", "start" and "end".
    """
    info = record.info

    if "SVLEN" in info:
        sv_len = info["SVLEN"]
        if isinstance(sv_len, tuple):
            sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]
    elif sv_type == "INS" and (
        "LEFT_SVINSSEQ" in info or "RIGHT_SVINSSEQ" in info
    ):
        left_seq = info.get("LEFT_SVINSSEQ", "")
        right_seq = info.get("RIGHT_SVINSSEQ", "")
        sv_len = len(left_seq) + len(right_seq)
    else:
        # NOTE(review): for breakend-style ALT strings this is a difference of
        # string lengths, not a genomic size -- confirm how BND lengths are used.
        sv_len = len(alt) - len(record.ref)

    return {
        "type": sv_type,
        "length": abs(sv_len),
        "chrom": chrom,
        "start": start,
        "end": end,
    }


def analyze_sv_calls(sample_id, ont_id, illumina_id):
    """Load one sample's ONT, Illumina and merged SV calls into a single table.

    The three VCFs are read from ``output/sv/survivor/<ont_id>/``. Calls are
    identified by chromosome, start, end and type; the ONT and Illumina calls
    are stacked, de-duplicated on that identifier (the first occurrence is
    kept) and flagged with boolean "ONT", "Illumina" and "Merged" columns that
    record which files contain the identifier.

    Args:
        sample_id: Label stored in the "sample_id" column of the result.
        ont_id: ONT identifier used for the directory and ONT/merged file names.
        illumina_id: Illumina identifier used in the Illumina/merged file names.

    Returns:
        DataFrame of unique calls with the three flags and "sample_id".
    """
    ont_file = f"output/sv/survivor/{ont_id}/{ont_id}.ont.sv_str.filtered.vcf"
    illumina_file = (
        f"output/sv/survivor/{ont_id}/{illumina_id}.illumina.sv.filtered.vcf"
    )
    merged_file = f"output/sv/survivor/{ont_id}/{ont_id}_{illumina_id}_merged.vcf"

    ont_svs = read_sv_vcf_file(ont_file)
    illumina_svs = read_sv_vcf_file(illumina_file)
    merged_svs = read_sv_vcf_file(merged_file)

    def identify(row):
        return f"{row['chrom']}_{row['start']}_{row['end']}_{row['type']}"

    for calls in (ont_svs, illumina_svs, merged_svs):
        calls["sv_id"] = calls.apply(identify, axis=1)

    # ONT rows come first, so a call found by both platforms keeps its ONT row.
    all_svs = pd.concat([ont_svs, illumina_svs]).drop_duplicates(subset="sv_id")

    membership = (
        ("ONT", ont_svs),
        ("Illumina", illumina_svs),
        ("Merged", merged_svs),
    )
    for flag, calls in membership:
        all_svs[flag] = all_svs["sv_id"].isin(calls["sv_id"])

    all_svs["sample_id"] = sample_id

    # The identifier is only a matching key and is not part of the result.
    return all_svs.drop("sv_id", axis=1)


# Collect the per-sample SV tables and stack them into one DataFrame.
sv_data_list = []

# The ONT identifier is passed both as the sample label and as the ONT id;
# "lp_id" supplies the Illumina identifier.
for _, row in sample_ids.iterrows():
    sample_data = analyze_sv_calls(row["ont_id"], row["ont_id"], row["lp_id"])
    sv_data_list.append(sample_data)

sv_data_df = pd.concat(sv_data_list, ignore_index=True)

# Bare expression: displayed as the cell output when run in a notebook.
sv_data_df


In [None]:
def compare_sv_counts(sv_data_df):
    """Count ONT, Illumina and consensus SV calls per sample.

    Samples are ordered by their identifier and relabelled "Sample 1",
    "Sample 2", ... in that order.

    Args:
        sv_data_df: Table with "sample_id" and boolean "ONT", "Illumina" and
            "Merged" columns.

    Returns:
        DataFrame with "anonymised_sample", "ONT", "Illumina" and "Consensus"
        columns, one row per sample.
    """

    def tally(sample_id):
        rows = sv_data_df[sv_data_df["sample_id"] == sample_id]
        return {
            "Sample": sample_id,
            "ONT": rows["ONT"].sum(),
            "Illumina": rows["Illumina"].sum(),
            "Consensus": rows["Merged"].sum(),
        }

    per_sample = pd.DataFrame(
        [tally(sample_id) for sample_id in sv_data_df["sample_id"].unique()]
    )
    per_sample = per_sample.sort_values("Sample")

    # Labels follow the sorted order, not the order of first appearance.
    per_sample["anonymised_sample"] = [
        f"Sample {i+1}" for i in range(len(per_sample))
    ]

    return per_sample[["anonymised_sample", "ONT", "Illumina", "Consensus"]]


def plot_sv_counts(sv_counts_df, figsize=(12, 6), dpi=300):
    """Show a grouped bar chart of SV call counts per sample.

    Every column of ``sv_counts_df`` other than "anonymised_sample" becomes
    one bar group ("Platform") in the chart.

    Args:
        sv_counts_df: Table with an "anonymised_sample" column and one count
            column per platform.
        figsize: Figure size passed to matplotlib.
        dpi: Figure resolution passed to matplotlib.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Long format: one row per (sample, platform) pair.
    long_counts = sv_counts_df.melt(
        id_vars=["anonymised_sample"], var_name="Platform", value_name="Count"
    )

    sns.barplot(
        data=long_counts,
        x="anonymised_sample",
        y="Count",
        hue="Platform",
        ax=ax,
    )

    ax.set_title("SV Call Counts by Sample and Platform")
    ax.set_xlabel("Sample")
    ax.set_ylabel("Number of SV Calls")

    # Slant the sample labels so they do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    plt.tight_layout()
    plt.show()


# Tabulate the per-sample call counts, print them, then chart them.
sv_counts_df = compare_sv_counts(sv_data_df)
print("SV call counts:")
print(sv_counts_df)

plot_sv_counts(sv_counts_df)


In [None]:
def calculate_consensus_percentages(sv_counts_df):
    """Add and summarise the share of each platform's calls that are consensus calls.

    Two columns, "ONT_Consensus_Percent" and "Illumina_Consensus_Percent",
    are written into ``sv_counts_df`` itself (NaN results become 0), and their
    mean and standard deviation are printed.

    NOTE(review): the SD printed here comes from ``np.std`` while
    ``calculate_average_difference`` in this file uses pandas ``.std()``;
    confirm both are meant to use the same ddof.

    Args:
        sv_counts_df: Table with "ONT", "Illumina" and "Consensus" columns.

    Returns:
        The same ``sv_counts_df`` object, with the two columns added.
    """
    consensus = sv_counts_df["Consensus"]

    ont_share = consensus / sv_counts_df["ONT"] * 100
    sv_counts_df["ONT_Consensus_Percent"] = ont_share.fillna(0)

    illumina_share = consensus / sv_counts_df["Illumina"] * 100
    sv_counts_df["Illumina_Consensus_Percent"] = illumina_share.fillna(0)

    ont_column = sv_counts_df["ONT_Consensus_Percent"]
    illumina_column = sv_counts_df["Illumina_Consensus_Percent"]

    avg_ont_percent, sd_ont_percent = np.mean(ont_column), np.std(ont_column)
    avg_illumina_percent = np.mean(illumina_column)
    sd_illumina_percent = np.std(illumina_column)

    print(
        f"Average consensus percentage for ONT: {avg_ont_percent:.2f}% ± {sd_ont_percent:.2f}% (mean ± SD)"
    )
    print(
        f"Average consensus percentage for Illumina: {avg_illumina_percent:.2f}% ± {sd_illumina_percent:.2f}% (mean ± SD)"
    )

    return sv_counts_df


# Adds the two consensus-percentage columns to sv_counts_df in place and prints
# their mean ± SD; the returned frame is the cell's displayed output.
calculate_consensus_percentages(sv_counts_df)


In [None]:
def calculate_average_difference(sv_counts_df):
    """Print the mean and SD of the per-sample ONT-to-Illumina call count ratio.

    The ratio is written into ``sv_counts_df`` itself as the
    "ONT_Illumina_Ratio" column (NaN results become 0).

    Args:
        sv_counts_df: Table with "ONT" and "Illumina" count columns.

    Returns:
        None.
    """
    ratio = sv_counts_df["ONT"] / sv_counts_df["Illumina"]
    sv_counts_df["ONT_Illumina_Ratio"] = ratio.fillna(0)

    stored_ratio = sv_counts_df["ONT_Illumina_Ratio"]
    average_difference = stored_ratio.mean()
    sd_difference = stored_ratio.std()

    print(
        f"Average ratio of SV counts between ONT and Illumina: {average_difference:.2f} ± {sd_difference:.2f} (mean ± SD)"
    )


# Adds the "ONT_Illumina_Ratio" column to sv_counts_df in place and prints its mean ± SD.
calculate_average_difference(sv_counts_df)


## 2. SV Size Distribution


In [None]:
def plot_sv_length_distributions(sv_data_df, figsize=(12, 6), dpi=300):
    """Show overlaid SV length histograms for ONT and Illumina calls.

    Lengths are pooled over all rows flagged for a platform; infinite and
    missing lengths are dropped. Both histograms use a log-scaled x axis,
    50 bins, density normalisation and a KDE overlay.

    Args:
        sv_data_df: Table with "length" and boolean "ONT" / "Illumina" columns.
        figsize: Figure size passed to matplotlib.
        dpi: Figure resolution passed to matplotlib.
    """

    def finite_lengths(platform):
        lengths = sv_data_df.loc[sv_data_df[platform], "length"]
        return lengths.replace([np.inf, -np.inf], np.nan).dropna().tolist()

    layers = [
        ("long-read", finite_lengths("ONT")),
        ("short-read", finite_lengths("Illumina")),
    ]

    plt.figure(figsize=figsize, dpi=dpi)

    for label, lengths in layers:
        sns.histplot(
            lengths,
            log_scale=True,
            bins=50,
            stat="density",
            kde=True,
            alpha=0.5,
            label=label,
        )

    plt.title("Aggregated ONT and Illumina SV Size Distribution")
    plt.xlabel("SV Size (log scale)")
    plt.ylabel("Density")
    plt.legend()

    plt.tight_layout()
    plt.show()


# Show the pooled ONT and Illumina SV length histograms.
plot_sv_length_distributions(sv_data_df)


In [None]:
def calculate_sv_length_stats(lengths):
    """Return the max, min, mean, std and median of ``lengths`` as a dict.

    All values are computed with NumPy; "std" uses ``np.std`` with its
    default arguments.
    """
    reducers = (
        ("max", np.max),
        ("min", np.min),
        ("mean", np.mean),
        ("std", np.std),
        ("median", np.median),
    )
    return {name: reducer(lengths) for name, reducer in reducers}


def format_number(num):
    """Format a number with thousands separators.

    Integers (Python or NumPy) are shown without decimals, floats with two
    decimals, and anything else is converted with ``str``.

    NOTE: an earlier function in this file has the same name and formats to
    four decimals; this definition replaces it from the point it is executed.
    """
    if isinstance(num, float):
        return format(num, ",.2f")
    if isinstance(num, (int, np.integer)):
        return format(num, ",d")
    return str(num)


def print_sv_length_stats(stats, sample_type):
    """Print a titled block with one formatted line per statistic.

    Args:
        stats: Mapping of statistic name to value.
        sample_type: Label used in the block title.
    """
    print(f"\n{sample_type} SV Length Statistics:", "=" * 40, sep="\n")
    for stat in stats:
        formatted_value = format_number(stats[stat])
        print(f"  {stat.capitalize():6s}: {formatted_value}")


def analyze_sv_length_distributions(sv_data_df):
    """Print SV length summary statistics for ONT and Illumina calls.

    Lengths are pooled over all rows flagged for a platform; infinite and
    missing lengths are dropped before the statistics are computed.

    Args:
        sv_data_df: Table with "length" and boolean "ONT" / "Illumina" columns.
    """

    def finite_lengths(platform):
        lengths = sv_data_df.loc[sv_data_df[platform], "length"]
        return lengths.replace([np.inf, -np.inf], np.nan).dropna().tolist()

    platforms = ("ONT", "Illumina")
    lengths_by_platform = {name: finite_lengths(name) for name in platforms}

    # Compute everything first so nothing is printed if a summary fails.
    stats_by_platform = {
        name: calculate_sv_length_stats(lengths_by_platform[name])
        for name in platforms
    }

    for name in platforms:
        print_sv_length_stats(stats_by_platform[name], name)


# Print the ONT and Illumina SV length summaries.
analyze_sv_length_distributions(sv_data_df)


## 3. SV Types


In [None]:
def analyze_sv_types(sv_data_df):
    """Count SV types per sample and platform and summarise them per platform.

    NOTE(review): a type contributes to a platform's statistics only for the
    samples in which ``value_counts`` reports it, so samples with no call of
    that type may be left out of the mean/median/std -- confirm this is
    intended.

    Args:
        sv_data_df: Table with "sample_id", "type" and boolean "ONT" /
            "Illumina" columns.

    Returns:
        A pair ``(counts, stats)``: ``counts`` is indexed by
        (sample_id, platform) with one column per SV type; ``stats`` is
        indexed by ("Platform", "SV Type") with "mean", "median" and
        "std_dev" columns.
    """
    per_sample_counts = defaultdict(lambda: defaultdict(int))
    counts_by_platform = defaultdict(lambda: defaultdict(list))

    for sample_id in sv_data_df["sample_id"].unique():
        sample_rows = sv_data_df[sv_data_df["sample_id"] == sample_id]

        for platform in ("ONT", "Illumina"):
            platform_key = platform.lower()
            tallies = sample_rows.loc[sample_rows[platform], "type"].value_counts()

            for sv_type, count in tallies.items():
                per_sample_counts[(sample_id, platform_key)][sv_type] = count
                counts_by_platform[platform_key][sv_type].append(count)

    df_sv_type_counts = pd.DataFrame(per_sample_counts).T

    summary = {
        (platform_key, sv_type): {
            "mean": np.mean(counts),
            "median": np.median(counts),
            "std_dev": np.std(counts),
        }
        for platform_key, by_type in counts_by_platform.items()
        for sv_type, counts in by_type.items()
    }

    df_sv_type_stats = pd.DataFrame(summary).T
    df_sv_type_stats.index.names = ["Platform", "SV Type"]

    return df_sv_type_counts, df_sv_type_stats


# Compute the per-sample type counts and per-platform statistics; the counts
# are printed here.
sv_type_counts_df, sv_type_stats_df = analyze_sv_types(sv_data_df)

print("\nSV type counts:")
print(sv_type_counts_df)


In [None]:
# Print the per-platform mean / median / std of the SV type counts.
print("\nSV type statistics:")
print(sv_type_stats_df)


In [None]:
def plot_sv_types(sv_data_df, figsize=(15, 10), dpi=300):
    """Show stacked bar charts of SV type counts per sample, one per platform.

    The top panel shows ONT ("Long-read") calls and the bottom panel Illumina
    ("Short-read") calls. Samples are relabelled "Sample 1", "Sample 2", ...
    following the sorted order of their identifiers.

    Note: an "anonymised_sample" column is written into ``sv_data_df`` itself.

    Args:
        sv_data_df: Table with "sample_id", "type" and boolean "ONT" /
            "Illumina" columns.
        figsize: Figure size passed to matplotlib.
        dpi: Figure resolution passed to matplotlib.
    """
    plt.figure(figsize=figsize, dpi=dpi)

    platform_names = {"ONT": "Long-read", "Illumina": "Short-read"}

    # Map each sample identifier to an anonymised label in sorted order.
    unique_samples = sorted(sv_data_df["sample_id"].unique())
    sample_map = {sample: f"Sample {i+1}" for i, sample in enumerate(unique_samples)}
    # This assignment modifies the caller's DataFrame.
    sv_data_df["anonymised_sample"] = sv_data_df["sample_id"].map(sample_map)

    for i, platform in enumerate(["ONT", "Illumina"]):
        ax = plt.subplot(2, 1, i + 1)

        platform_data = sv_data_df[sv_data_df[platform]]

        # Summing the boolean platform flag counts the calls per sample and type.
        sv_type_counts = platform_data.pivot_table(
            index="anonymised_sample", columns="type", values=platform, aggfunc="sum"
        ).fillna(0)

        # Put the rows in "Sample 1", "Sample 2", ... order; labels missing
        # from the pivot table are added by reindex.
        sv_type_counts = sv_type_counts.reindex(
            [f"Sample {i+1}" for i in range(len(unique_samples))]
        )

        sv_type_counts.plot(kind="bar", stacked=True, ax=ax)

        plt.title(f"SV Types by Sample - {platform_names[platform]}")
        plt.xlabel("Sample")
        plt.ylabel("Number of SVs")
        plt.xticks(rotation=45, ha="right")

        # Keep the legend entries of the first panel for the shared legend.
        if i == 0:
            handles, labels = ax.get_legend_handles_labels()

        # Remove the per-panel legend.
        ax.legend().remove()

    # Shared legend built from the first panel's entries, placed outside the plot.
    plt.legend(
        handles, labels, title="SV Type", bbox_to_anchor=(1.05, 1), loc="upper left"
    )

    plt.tight_layout()
    plt.show()


# Show the stacked SV type charts; this also adds an "anonymised_sample"
# column to sv_data_df.
plot_sv_types(sv_data_df)


In [None]:
def plot_sv_size_distribution_by_type(sv_data_df, figsize=(20, 15), dpi=300):
    """Show SV length histograms per SV type, overlaying ONT and Illumina calls.

    Six panels are drawn in a 3x2 grid (INS, DEL, DUP, INV, BND, STR). The
    x axis is log-scaled for every type except BND and STR. A platform with no
    calls of a type is left out of that panel. Only the first panel carries a
    legend.

    Args:
        sv_data_df: Table with "type", "length" and boolean "ONT" /
            "Illumina" columns.
        figsize: Figure size passed to matplotlib.
        dpi: Figure resolution passed to matplotlib.
    """
    # Insertion order of this mapping is the panel order.
    sv_full_names = {
        "INS": "Insertion",
        "DEL": "Deletion",
        "DUP": "Duplication",
        "INV": "Inversion",
        "BND": "Breakend",
        "STR": "Short Tandem Repeat",
    }
    platforms = (("ONT", "long-read"), ("Illumina", "short-read"))

    fig, axes = plt.subplots(3, 2, figsize=figsize, dpi=dpi)
    axes = axes.flatten()

    for idx, sv_type in enumerate(sv_full_names):
        ax = axes[idx]
        use_log_axis = sv_type not in ["BND", "STR"]
        has_type = sv_data_df["type"] == sv_type

        for platform, label in platforms:
            sv_data = sv_data_df[has_type & (sv_data_df[platform])]
            if sv_data.empty:
                continue

            sns.histplot(
                data=sv_data,
                x="length",
                kde=True,
                log_scale=use_log_axis,
                stat="density",
                ax=ax,
                alpha=0.7,
                label=label,
            )

        ax.set_title(f"{sv_full_names[sv_type]} Size Distribution")
        ax.set_xlabel("SV Size" + (" (log scale)" if use_log_axis else ""))
        ax.set_ylabel("Density")

        if idx == 0:
            ax.legend()

    plt.tight_layout()
    plt.show()


# Show the per-type SV length histograms.
plot_sv_size_distribution_by_type(sv_data_df)


## 4. SV Chromosomal Distribution


In [None]:
# Chromosome name -> length, used to normalise SV counts per chromosome
# (normalize_by_chrom_length divides by length / 1e6).
# NOTE(review): the values appear to be GRCh38/hg38 chromosome lengths in
# base pairs -- confirm they match the reference used for alignment.
chrom_lengths = {
    "chr1": 248956422,
    "chr2": 242193529,
    "chr3": 198295559,
    "chr4": 190214555,
    "chr5": 181538259,
    "chr6": 170805979,
    "chr7": 159345973,
    "chr8": 145138636,
    "chr9": 138394717,
    "chr10": 133797422,
    "chr11": 135086622,
    "chr12": 133275309,
    "chr13": 114364328,
    "chr14": 107043718,
    "chr15": 101991189,
    "chr16": 90338345,
    "chr17": 83257441,
    "chr18": 80373285,
    "chr19": 58617616,
    "chr20": 64444167,
    "chr21": 46709983,
    "chr22": 50818468,
    "chrX": 156040895,
    "chrY": 57227415,
}


def analyze_chrom_distribution(sv_data_df):
    """Count SV calls per chromosome for each sequencing platform.

    Only chr1-chr22, chrX and chrY are tallied; rows whose "chrom" value is
    anything else do not contribute to any count.

    Args:
        sv_data_df: DataFrame with a "chrom" column and "ONT" / "Illumina"
            columns that are used as boolean row masks.

    Returns:
        DataFrame indexed by chromosome name (unnamed index, in karyotype
        order) with one integer count column per platform, named "ont" and
        "illumina".
    """
    autosomes = [f"chr{number}" for number in range(1, 23)]
    chromosomes = autosomes + ["chrX", "chrY"]

    per_platform = {}
    for platform in ("ONT", "Illumina"):
        # Chromosome labels of the calls flagged for this platform.
        called_chroms = sv_data_df[sv_data_df[platform]]["chrom"]
        per_platform[platform.lower()] = {
            chrom: int((called_chroms == chrom).sum()) for chrom in chromosomes
        }

    # Built from plain nested dicts so the resulting index stays unnamed.
    return pd.DataFrame(per_platform)


def normalize_by_chrom_length(chrom_distribution_df):
    """Convert per-chromosome SV counts into counts per megabase.

    Args:
        chrom_distribution_df: DataFrame indexed by chromosome name; every
            index label must be a key of the module-level ``chrom_lengths``.

    Returns:
        A new float DataFrame of the same shape in which each row has been
        divided by that chromosome's length expressed in Mb. The input
        frame is not modified.

    Raises:
        KeyError: If an index label has no entry in ``chrom_lengths``.
    """
    as_float = chrom_distribution_df.copy().astype(float)
    megabases = pd.Series(
        [chrom_lengths[chrom] / 1e6 for chrom in as_float.index],
        index=as_float.index,
        dtype=float,
    )
    # Row-wise division: each chromosome's counts by its own length in Mb.
    return as_float.div(megabases, axis=0)


def plot_chrom_distribution(chrom_distribution_df, figsize=(12, 6), dpi=300):
    """Draw a grouped bar chart of length-normalised SV share per chromosome.

    Counts are first converted to SVs per Mb with
    ``normalize_by_chrom_length`` and then rescaled so that each platform's
    bars sum to 100.

    Args:
        chrom_distribution_df: Per-chromosome counts with "ont" and
            "illumina" columns and an unnamed chromosome index.
        figsize: Figure size in inches, passed to ``plt.figure``.
        dpi: Figure resolution, passed to ``plt.figure``.
    """
    per_mb = normalize_by_chrom_length(chrom_distribution_df).rename(
        columns={"ont": "long-read", "illumina": "short-read"}
    )

    # Express each platform's per-Mb density as a percentage of its own total.
    share_pct = per_mb.apply(lambda column: column / column.sum() * 100)

    # Long format for seaborn; reset_index() yields an "index" column because
    # the chromosome index carries no name.
    long_form = (
        share_pct.reset_index()
        .melt(id_vars="index", var_name="Platform", value_name="Normalized Percentage")
        .rename(columns={"index": "Chromosome"})
    )

    plt.figure(figsize=figsize, dpi=dpi)
    ax = sns.barplot(
        data=long_form,
        x="Chromosome",
        y="Normalized Percentage",
        hue="Platform",
    )

    plt.title("Normalised Chromosomal Distribution of SVs", fontsize=16)
    plt.xlabel("Chromosome", fontsize=12)
    plt.ylabel("Percentage of SVs per Mb", fontsize=12)
    plt.legend(title="Platform", title_fontsize=12, fontsize=10)

    plt.xticks(rotation=45, ha="right")

    # Pin the tick positions before re-applying the labels with rotation.
    tick_count = len(ax.get_xticklabels())
    ax.set_xticks(range(tick_count))
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

    ax.tick_params(axis="x", which="major", pad=0)

    plt.tight_layout()
    plt.show()


# Usage: tally SVs per chromosome for each platform, then plot the
# length-normalised distribution. chrom_distribution_df is reused by later cells.
chrom_distribution_df = analyze_chrom_distribution(sv_data_df)
plot_chrom_distribution(chrom_distribution_df)


In [None]:
def calculate_sv_correlations(chrom_distribution_df, chrom_lengths):
    """Correlate per-chromosome SV counts with chromosome length.

    Chromosome names and lengths are taken from the *index labels* of
    ``chrom_distribution_df`` and looked up in ``chrom_lengths``, so every
    length is paired with the counts of the same chromosome regardless of
    the ordering of either input.

    Args:
        chrom_distribution_df: DataFrame indexed by chromosome name with
            "ont" and "illumina" count columns.
        chrom_lengths: Mapping of chromosome name to chromosome length.

    Returns:
        Tuple ``(corr_data, ont_corr, illumina_corr)`` where ``corr_data`` is
        a DataFrame (indexed like ``chrom_distribution_df``) with columns
        "chrom", "length", "ont_count" and "illumina_count", and the two
        floats are the Pearson correlation coefficients of length against
        each platform's counts.

    Raises:
        KeyError: If a chromosome in ``chrom_distribution_df`` has no entry
            in ``chrom_lengths``.
    """
    chroms = list(chrom_distribution_df.index)

    missing = [chrom for chrom in chroms if chrom not in chrom_lengths]
    if missing:
        raise KeyError(f"No chromosome length available for: {missing}")

    # Look lengths up by label rather than relying on dict order matching
    # the DataFrame's row order.
    corr_data = pd.DataFrame(
        {
            "chrom": chroms,
            "length": [chrom_lengths[chrom] for chrom in chroms],
            "ont_count": chrom_distribution_df["ont"],
            "illumina_count": chrom_distribution_df["illumina"],
        }
    )

    ont_corr, ont_p = stats.pearsonr(corr_data["length"], corr_data["ont_count"])
    illumina_corr, illumina_p = stats.pearsonr(
        corr_data["length"], corr_data["illumina_count"]
    )

    print(f"ONT correlation: {ont_corr:.3f} (p-value: {ont_p:.3e})")
    print(f"Illumina correlation: {illumina_corr:.3f} (p-value: {illumina_p:.3e})")

    return corr_data, ont_corr, illumina_corr


def plot_sv_correlations(corr_data, ont_corr, illumina_corr, figsize=(12, 5), dpi=300):
    """Overlay scatter-plus-regression plots of SV count against chromosome length.

    Args:
        corr_data: DataFrame with "length", "ont_count" and "illumina_count"
            columns.
        ont_corr: Accepted as part of the signature; not used by this figure.
        illumina_corr: Accepted as part of the signature; not used by this
            figure.
        figsize: Figure size in inches, passed to ``plt.subplots``.
        dpi: Figure resolution, passed to ``plt.subplots``.
    """
    _, axis = plt.subplots(figsize=figsize, dpi=dpi)

    # (count column, legend label) for each platform, drawn in this order.
    series_to_plot = (
        ("ont_count", "long-read"),
        ("illumina_count", "short-read"),
    )
    for count_column, legend_label in series_to_plot:
        sns.regplot(
            data=corr_data,
            x="length",
            y=count_column,
            label=legend_label,
            scatter_kws={"alpha": 0.7},
            ax=axis,
        )

    axis.set_title("SVs vs Chromosome Length")
    axis.set_xlabel("Chromosome Length")
    axis.set_ylabel("Number of SVs")
    axis.legend()

    plt.tight_layout()
    plt.show()


# Compute (and print) the length-vs-count Pearson correlations per platform,
# then plot them. The returned values are reused by the combined figures below.
corr_data, ont_corr, illumina_corr = calculate_sv_correlations(
    chrom_distribution_df, chrom_lengths
)

plot_sv_correlations(corr_data, ont_corr, illumina_corr)


## 5. Combined Plots


In [None]:
def create_combined_sv_plots(
    sv_data_df, sv_counts_df, chrom_distribution_df, corr_data, ont_corr, illumina_corr
):
    """Render the three multi-panel SV summary figures.

    Figure 1 (2x2): A) SV call counts per sample and platform, B) aggregated
    SV size distribution, C/D) stacked SV-type counts per sample for the
    long-read and short-read platforms.
    Figure 2 (3x2): size distribution per SV type (INS, DEL, DUP, INV, BND,
    STR), panels lettered A-F.
    Figure 3 (1x2): A) length-normalised chromosomal distribution, B) SV count
    vs chromosome length with the correlation coefficients in the legend.

    Side effects:
        An "anonymised_sample" column ("Sample 1", "Sample 2", ...) is written
        in place onto both ``sv_data_df`` and the ``sv_counts_df`` passed in.
        Each figure is displayed with ``plt.show()``.

    Args:
        sv_data_df: Per-SV table; the code reads its "sample_id", "type",
            "length", "ONT" and "Illumina" columns.
        sv_counts_df: Per-sample counts table; it is only used to locate the
            sample column and receive the anonymised labels, because the name
            is rebound to ``compare_sv_counts(sv_data_df)`` before plotting.
        chrom_distribution_df: Per-chromosome counts with "ont" and
            "illumina" columns.
        corr_data: DataFrame with "length", "ont_count" and "illumina_count"
            columns.
        ont_corr: Correlation coefficient shown in the long-read legend entry.
        illumina_corr: Correlation coefficient shown in the short-read legend
            entry.

    Raises:
        StopIteration: If ``sv_counts_df`` has no column named (case
            insensitively) "sample", "sample_id" or "anonymised_sample".
    """
    # First column whose lower-cased name is one of the accepted sample keys;
    # next() has no default, so a missing column raises StopIteration.
    sample_column = next(
        col
        for col in sv_counts_df.columns
        if col.lower() in ["sample", "sample_id", "anonymised_sample"]
    )

    # Map sorted sample ids to "Sample N" labels and store them on both input
    # frames (in-place mutation of the caller's DataFrames).
    unique_samples = sorted(sv_data_df["sample_id"].unique())
    sample_map = {sample: f"Sample {i+1}" for i, sample in enumerate(unique_samples)}
    sv_data_df["anonymised_sample"] = sv_data_df["sample_id"].map(sample_map)
    sv_counts_df["anonymised_sample"] = sv_counts_df[sample_column].map(sample_map)

    # Figure 1: 2x2 grid (counts, aggregated sizes, SV types per platform)
    fig1 = plt.figure(figsize=(12, 8), dpi=300)
    gs1 = fig1.add_gridspec(2, 2)

    # Plot A: SV Counts per sample
    ax_counts = fig1.add_subplot(gs1[0, 0])
    # The sv_counts_df argument is replaced here by a freshly computed table.
    # NOTE(review): compare_sv_counts is defined elsewhere; presumably it
    # returns one row per "anonymised_sample" with ONT / Illumina / Consensus
    # count columns — confirm.
    sv_counts_df = compare_sv_counts(sv_data_df)
    sv_counts_df_renamed = sv_counts_df.rename(
        columns={"ONT": "long-read", "Illumina": "short-read", "Consensus": "consensus"}
    )
    sv_counts_melted = sv_counts_df_renamed.melt(
        id_vars=["anonymised_sample"], var_name="Platform", value_name="Count"
    )

    sns.barplot(
        x="anonymised_sample",
        y="Count",
        hue="Platform",
        data=sv_counts_melted,
        ax=ax_counts,
        order=sv_counts_df["anonymised_sample"],
    )
    ax_counts.set_title("SV Call Counts by Sample and Platform")
    ax_counts.set_xlabel("Sample")
    ax_counts.set_ylabel("Number of SV Calls")
    for tick in ax_counts.get_xticklabels():
        tick.set_rotation(45)
        tick.set_ha("right")
    ax_counts.legend(title="Platform", loc="upper left")

    # Panel letter, positioned in axes coordinates just outside the top-left corner.
    ax_counts.text(
        -0.12,
        1.05,
        "A",
        transform=ax_counts.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    # Plot B: Aggregated SV Size Distribution
    ax_size = fig1.add_subplot(gs1[0, 1])
    # Drop infinite and missing lengths before plotting on a log scale.
    ont_lengths = (
        sv_data_df[sv_data_df["ONT"]]["length"]
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
        .tolist()
    )
    illumina_lengths = (
        sv_data_df[sv_data_df["Illumina"]]["length"]
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
        .tolist()
    )
    sns.histplot(
        ont_lengths,
        log_scale=True,
        bins=50,
        stat="density",
        kde=True,
        alpha=0.7,
        label="long-read",
        ax=ax_size,
    )
    sns.histplot(
        illumina_lengths,
        log_scale=True,
        bins=50,
        stat="density",
        kde=True,
        alpha=0.7,
        label="short-read",
        ax=ax_size,
    )
    ax_size.set_title("Aggregated SV Size Distribution")
    ax_size.set_xlabel("SV Size (log scale)")
    ax_size.set_ylabel("Density")
    ax_size.legend()
    ax_size.text(
        -0.12,
        1.05,
        "B",
        transform=ax_size.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    # Plot C and D: SV Types per sample
    ax_types_ont = fig1.add_subplot(gs1[1, 0])
    ax_types_illumina = fig1.add_subplot(gs1[1, 1])

    platform_names = {"ONT": "Long-read", "Illumina": "Short-read"}

    for i, (platform, ax) in enumerate(
        [("ONT", ax_types_ont), ("Illumina", ax_types_illumina)]
    ):
        platform_data = sv_data_df[sv_data_df[platform]]

        # Summing the platform flag over the already-filtered rows gives the
        # number of calls per (sample, SV type); absent combinations become 0.
        sv_type_counts = platform_data.pivot_table(
            index="anonymised_sample", columns="type", values=platform, aggfunc="sum"
        ).fillna(0)

        # Ensure correct order of samples
        # (the comprehension's `i` is scoped to the comprehension and does not
        # overwrite the loop's `i`).
        sv_type_counts = sv_type_counts.reindex(
            [f"Sample {i+1}" for i in range(len(unique_samples))]
        )

        sv_type_counts.plot(kind="bar", stacked=True, ax=ax)

        ax.set_title(f"SV Types by Sample - {platform_names[platform]}")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Number of SVs")
        for tick in ax.get_xticklabels():
            tick.set_rotation(45)
            tick.set_ha("right")

        # Only the first (long-read) panel keeps its SV-type legend.
        if i == 0:
            ax.legend(title="SV Type", loc="upper left")
        else:
            ax.legend().remove()

        ax.text(
            -0.12,
            1.05,
            "C" if i == 0 else "D",
            transform=ax.transAxes,
            fontsize=12,
            fontweight="bold",
            va="top",
        )

    plt.tight_layout()
    plt.show()

    # Figure 2: 3x2 grid of per-SV-type size distributions
    fig2 = plt.figure(figsize=(12, 12), dpi=300)
    gs2 = fig2.add_gridspec(3, 2)

    sv_types = ["INS", "DEL", "DUP", "INV", "BND", "STR"]
    sv_full_names = {
        "INS": "Insertion",
        "DEL": "Deletion",
        "DUP": "Duplication",
        "INV": "Inversion",
        "BND": "Breakend",
        "STR": "Short Tandem Repeat",
    }

    for idx, sv_type in enumerate(sv_types):
        # Fill the grid row by row: two panels per row.
        ax = fig2.add_subplot(gs2[idx // 2, idx % 2])
        for platform, alpha, label in [
            ("ONT", 0.7, "long-read"),
            ("Illumina", 0.7, "short-read"),
        ]:
            sv_data = sv_data_df[
                (sv_data_df["type"] == sv_type) & (sv_data_df[platform])
            ]
            # Skip platforms with no calls of this type.
            if not sv_data.empty:
                sns.histplot(
                    data=sv_data,
                    x="length",
                    kde=True,
                    # BND and STR are drawn on a linear x axis; all other
                    # types use a log scale.
                    log_scale=(sv_type not in ["BND", "STR"]),
                    stat="density",
                    ax=ax,
                    alpha=alpha,
                    label=label,
                )
        ax.set_title(f"{sv_full_names[sv_type]} Size Distribution")
        ax.set_xlabel(
            "SV Size" + (" (log scale)" if sv_type not in ["BND", "STR"] else "")
        )
        ax.set_ylabel("Density")
        # Legend only on the first panel.
        if idx == 0:
            ax.legend()
        # Panel letters A-F derived from the panel index.
        ax.text(
            -0.07,
            1.05,
            chr(ord("A") + idx),
            transform=ax.transAxes,
            fontsize=12,
            fontweight="bold",
            va="top",
        )

    plt.tight_layout()
    plt.show()

    # Figure 3: Chromosome distribution and correlation plots side-by-side
    fig3 = plt.figure(figsize=(12, 4), dpi=300)
    gs3 = fig3.add_gridspec(1, 2)

    # Chromosome distribution plot
    ax_chrom = fig3.add_subplot(gs3[0, 0])
    normalized_df = normalize_by_chrom_length(chrom_distribution_df)
    normalized_df = normalized_df.rename(
        columns={"ont": "long-read", "illumina": "short-read"}
    )
    # Rescale each platform column so its values sum to 100.
    chrom_distribution_pct = normalized_df.apply(lambda x: x / x.sum() * 100)
    chrom_distribution_melted = chrom_distribution_pct.reset_index().melt(
        id_vars="index", var_name="Platform", value_name="Normalized Percentage"
    )
    chrom_distribution_melted = chrom_distribution_melted.rename(
        columns={"index": "Chromosome"}
    )
    sns.barplot(
        x="Chromosome",
        y="Normalized Percentage",
        hue="Platform",
        data=chrom_distribution_melted,
        ax=ax_chrom,
    )
    ax_chrom.set_title("Normalised Chromosomal Distribution of SVs", fontsize=12)
    ax_chrom.set_xlabel("Chromosome", fontsize=12)
    ax_chrom.set_ylabel("Percentage of SVs per Mb", fontsize=12)
    ax_chrom.legend(title="Platform", title_fontsize=12, fontsize=10)
    # plt.xticks acts on pyplot's current axes.
    # NOTE(review): expected to be ax_chrom at this point — confirm.
    plt.xticks(rotation=45, ha="right")
    ax_chrom.text(
        -0.07,
        1.05,
        "A",
        transform=ax_chrom.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    # Correlation plot (long-read and short-read overlaid)
    ax_corr = fig3.add_subplot(gs3[0, 1])
    sns.regplot(
        x="length",
        y="ont_count",
        data=corr_data,
        ax=ax_corr,
        label=f"long-read (r={ont_corr:.3f})",
        scatter_kws={"alpha": 0.7},
    )
    sns.regplot(
        x="length",
        y="illumina_count",
        data=corr_data,
        ax=ax_corr,
        label=f"short-read (r={illumina_corr:.3f})",
        scatter_kws={"alpha": 0.7},
    )
    ax_corr.set_title("SV Counts vs Chromosome Length")
    ax_corr.set_xlabel("Chromosome Length")
    ax_corr.set_ylabel("Number of SVs")
    ax_corr.legend()
    ax_corr.text(
        -0.07,
        1.05,
        "B",
        transform=ax_corr.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
    )

    plt.tight_layout()
    plt.show()


# Render the three combined figures from the tables computed in earlier cells.
# Note: this call adds an "anonymised_sample" column to sv_data_df and
# sv_counts_df in place.
create_combined_sv_plots(
    sv_data_df, sv_counts_df, chrom_distribution_df, corr_data, ont_corr, illumina_corr
)
