In [89]:
import numpy
import polars as pl
import matplotlib
matplotlib.use('Agg')  # Use Agg backend for non-GUI environments
import matplotlib.pyplot as plt
import seaborn
import glob
import pandas as pd
from typing import Dict, List, Union, Optional, Tuple
from os import sep
import re
import regex
import os
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed

In [130]:
def load_files(directory_path: str, file_pattern: str) -> Dict[str, pl.DataFrame]:
    """
    Optimized: Loads files matching a specific pattern from a directory using Polars
    and stores them with their chromosome label as key.
    """
    loaded_files_dict: Dict[str, pl.DataFrame] = {}
    files: List[str] = glob.glob(os.path.join(directory_path, file_pattern))

    if not files:
        print(f"No files matching the pattern '{file_pattern}' were found in '{directory_path}'.")
        return loaded_files_dict

    # Precompile regex
    chr_pattern: Pattern[str] = re.compile(r"chr_(\d+|[A-Z])")

    def read_and_parse(file: str) -> Optional[Tuple[str, pl.DataFrame]]:
        if not os.path.isfile(file):
            return None

        basename = os.path.basename(file)
        file_basename = os.path.splitext(basename)[0]
        match = chr_pattern.search(file_basename)

        if not match:
            print(f"Skipping file without valid chr label: {basename}")
            return None

        chr_label = match.group()
        try:
            df = pl.read_csv(file, separator='\t')
            return chr_label, df
        except Exception as e:
            print(f"Error reading {file}: {e}")
            return None

    # Use threads for parallel I/O (safe for Polars)
    with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
        futures = {executor.submit(read_and_parse, f): f for f in files}
        for future in as_completed(futures):
            result = future.result()
            if result:
                chr_label, df = result
                loaded_files_dict[chr_label] = df

    return loaded_files_dict
# Directory containing the files and the pattern to match
file_directory = "../vcf_stats/"
file_pattern = "tajimaD_10000_*"
    
    # Load the files from the directory based on the pattern
loaded_files = load_files(file_directory, file_pattern)
loaded_files.items()

dict_items([('chr_24', shape: (689, 4)
┌────────┬───────────┬────────┬───────────┐
│ CHROM  ┆ BIN_START ┆ N_SNPS ┆ TajimaD   │
│ ---    ┆ ---       ┆ ---    ┆ ---       │
│ str    ┆ i64       ┆ i64    ┆ str       │
╞════════╪═══════════╪════════╪═══════════╡
│ chr_24 ┆ 0         ┆ 39     ┆ -2.08023  │
│ chr_24 ┆ 10000     ┆ 34     ┆ -2.2012   │
│ chr_24 ┆ 20000     ┆ 230    ┆ -2.30632  │
│ chr_24 ┆ 30000     ┆ 168    ┆ -2.35023  │
│ chr_24 ┆ 40000     ┆ 151    ┆ -2.34999  │
│ …      ┆ …         ┆ …      ┆ …         │
│ chr_24 ┆ 6840000   ┆ 169    ┆ -2.35056  │
│ chr_24 ┆ 6850000   ┆ 417    ┆ -0.716767 │
│ chr_24 ┆ 6860000   ┆ 429    ┆ -1.94     │
│ chr_24 ┆ 6870000   ┆ 452    ┆ 0.386404  │
│ chr_24 ┆ 6880000   ┆ 89     ┆ -1.50223  │
└────────┴───────────┴────────┴───────────┘), ('chr_32', shape: (44, 4)
┌────────┬───────────┬────────┬───────────┐
│ CHROM  ┆ BIN_START ┆ N_SNPS ┆ TajimaD   │
│ ---    ┆ ---       ┆ ---    ┆ ---       │
│ str    ┆ i64       ┆ i64    ┆ f64       │
╞════════

In [137]:
import polars as pl

def merge_dataframes_from_dict(
    dataFrames: Dict[str, pl.DataFrame], 
    first_file_keyName: str,
    column_index = str
) -> pl.DataFrame:
    """
    Stacks DataFrames stored in a dictionary vertically (like appending them).
    Assumes all DataFrames have the same schema (same columns).

    Args:
    - dataFrames (dict): A dictionary where keys are file names and values are Polars DataFrames.
    - first_file_keyName (str): The key name of the first DataFrame to use as the base for stacking.

    Returns:
    - pl.DataFrame: A single DataFrame resulting from stacking all DataFrames vertically.
    """
    if first_file_keyName not in dataFrames:
        raise ValueError(f"Key '{first_file_keyName}' not found in dataFrames.")

    # Start with the first DataFrame
    merged_df = dataFrames[first_file_keyName]

    # Vertically stack all other DataFrames
    for key, df in dataFrames.items():
        if key != first_file_keyName:
            df = df.with_columns(pl.col(column_index).cast(pl.Float64()))
            merged_df = merged_df.vstack(df)

    return merged_df

# Example usage
# Assuming `loaded_files` is a dictionary where each value is a Polars DataFrame
merged_files = merge_dataframes_from_dict(loaded_files, "chr_1", "TajimaD")

In [138]:
def chrom_to_sortable(chrom: str) -> Union[int, float]:
    """
    Converts chromosome names to sortable numeric values.
    Chromosomes like 'chr1', 'chr2', etc., are sorted numerically,
    while special chromosomes like 'W' and 'Z' are handled explicitly.
    Any unexpected chromosomes are sorted last as `float('inf')`.

    Args:
    - chrom: The chromosome label as a string.

    Returns:
    - An integer representing the chromosome for sorting or float('inf') for unknown chromosomes.
    """
    # Clean the chromosome label
    chrom = chrom.lower().replace("chr_", "").replace("chr", "")
    
    if chrom.isdigit():
        return int(chrom)
    elif chrom == "w":
        return 34
    elif chrom == "z":
        return 35
    else:
        return float("inf")  # Any unexpected chromosome gets sorted last

def sort_by_chromosome(dataFrame: pl.DataFrame) -> pl.DataFrame:
    """
    Sorts the DataFrame based on the chromosome column using the chrom_to_sortable function.
    This keeps the workflow within Polars for faster performance.

    Args:
    - dataFrame: The input Polars DataFrame containing a 'CHROM' column.

    Returns:
    - A sorted Polars DataFrame.
    """
    # Add a temporary column CHROM_NUM based on the chrom_to_sortable function
    sorted_df: pl.DataFrame = dataFrame.with_columns(
        pl.col("CHROM")
            .map_elements(chrom_to_sortable)
            .alias("CHROM_NUM")
    )
    
    # Sort by the CHROM_NUM column and drop the temporary CHROM_NUM column
    return sorted_df.sort("CHROM_NUM").drop("CHROM_NUM")
 # Ensure the merged files are sorted by chromosome
sorted_merged_files: pl.DataFrame = sort_by_chromosome(merged_files)
    
    



In [141]:
def plot_variants_per_chromosome(
    vcf_stats: pl.DataFrame, 
    x_axis: str, 
    y_axis: str, 
    title: str,
    bins: int = 55, 
    cols: int = 4, 
    figsize: tuple = (16, 16),
    output_filename: str = "Histogram.png",
):
    """
    Plots the distribution of variants per chromosome in histograms.
    
    Args:
    - vcf_stats (pd.DataFrame): The input DataFrame containing variant data.
    - x_axis (str): The column name in the DataFrame to plot on the x-axis.
    - y_axis (str): The column name in the DataFrame to plot on the y-axis.
    - title (str): The title of the plot.
    - bins (int): Number of bins for the histogram. Default is 55.
    - cols (int): Number of columns for the subplots. Default is 4.
    - figsize (tuple): Size of the figure. Default is (16, 16).
    - output_filename (str): The output file name for the saved figure.
    """
    # Ensure the chrom_column is cleaned
    vcf_stats = vcf_stats.to_pandas()
    vcf_stats[x_axis] = vcf_stats[x_axis].str.replace(r"chr_", "", regex=True)
    
    # Unique chromosomes
    chroms: List[str] = vcf_stats[x_axis].unique()
    n = len(chroms)

    # Set up the subplot grid
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=figsize, sharey=False)
    axes = axes.flatten()

    seaborn.set_style("darkgrid")

    for i, chrom in enumerate(chroms):
        ax = axes[i]
        seaborn.histplot(
            data=vcf_stats[vcf_stats[x_axis] == chrom],
            x=y_axis,
            stat="count",
            bins=bins,
            color="darkgreen",
            ax=ax
        )
        ax.set_title(f"Chromosome {chrom}", fontsize=14)

    # Turn off empty subplots
    for j in range(i + 1, len(axes)):
        axes[j].axis("off")

    # Set the overall title and adjust layout
    fig.suptitle(title, fontsize=26, y=0.90)
    plt.tight_layout()
    plt.subplots_adjust(top=0.80)

    # Save the plot
    plt.savefig(output_filename, format="png", dpi=600, bbox_inches="tight")
    plt.show()
# Plot the Tajima D histogram for each chromosome
plot_variants_per_chromosome(
        sorted_merged_files, 
        x_axis="CHROM", 
        y_axis="TajimaD", 
        title="Tajimas D at 10kb window sizes Histogram", 
        output_filename="TajimaD_histogram_10kb.png"
    )

