<a href="https://colab.research.google.com/github/Eitan177/EitanAmrom/blob/main/seebaf_values_over_segments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import glob
from google.colab import files
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks
from sklearn.cluster import KMeans

def chrom_key(chrom):
    """
    Build a natural-sort key for a chromosome name.

    The name is converted with str(), upper-cased, and split into alternating
    text / digit runs, so "chr2" sorts before "chr10". Numbered chromosomes sort
    before letter-only ones of the same style: "chr1" < "chrX", and "1" < "X".

    Args:
        chrom: The chromosome name (str, or anything str() can render, e.g. int).

    Returns:
        tuple: Alternating str and int parts. It always starts and ends with a
        (possibly empty) str, e.g. "chr10" -> ("CHR", 10, ""), "X" -> ("X",).
        Because ints and strs always occupy the same positions, keys built
        from different names never compare an int against a str.
    """
    # With a capturing group, re.split places every digit run at an odd index
    # and the text between runs at even indices.
    parts = re.split(r'(\d+)', str(chrom).upper())
    return tuple(int(part) if index % 2 else part for index, part in enumerate(parts))

def parse_dp4(dp4_str):
    """
    Parse a DP4 value from a MAF file and calculate the B-allele frequency.

    DP4 format: ref_forward,ref_reverse,alt_forward,alt_reverse

    Parsing is tolerant:
    - An optional VCF INFO-style "DP4=" prefix (any case) is ignored.
    - If plain comma splitting fails (e.g. "[1, 2, 3, 4]"), every run of
      digits in the value is used instead.

    Parameters:
    -----------
    dp4_str : str
        DP4 value containing read counts. Non-str values (NaN, lists, ...) are
        converted with str() before parsing.

    Returns:
    --------
    float or None
        alt / (ref + alt) read counts. Returns None when the value is missing,
        does not yield exactly four non-negative counts, or has zero total depth.
    """
    if dp4_str is None:
        return None

    text = str(dp4_str).strip()
    # Drop a "DP4=" key; otherwise the '4' in the key would be counted as an
    # extra number by the regex fallback and the value would be rejected.
    text = re.sub(r'^DP4\s*=\s*', '', text, flags=re.IGNORECASE)

    try:
        counts = [int(x) for x in text.split(',')]
    except ValueError:
        # Not a clean comma-separated list: pull out every run of digits.
        counts = [int(x) for x in re.findall(r'\d+', text)]

    # Read counts cannot be negative; reject malformed values.
    if len(counts) != 4 or any(count < 0 for count in counts):
        return None

    ref_counts = counts[0] + counts[1]  # ref_forward + ref_reverse
    alt_counts = counts[2] + counts[3]  # alt_forward + alt_reverse
    total_counts = ref_counts + alt_counts

    if total_counts == 0:
        return None

    return alt_counts / total_counts

def _resolve_column(columns, aliases):
    """Return the first column in *columns* matching one of *aliases*, or None.

    Exact or capitalized matches (e.g. 'chrom' / 'Chrom') are preferred, in
    alias order. Otherwise a case-insensitive match (e.g. 'CHROM') is accepted.
    """
    columns = list(columns)
    for alias in aliases:
        for candidate in (alias, alias.capitalize()):
            if candidate in columns:
                return candidate
    # reversed() so that, among case-insensitive duplicates, the first column wins.
    lowered = {str(col).lower(): col for col in reversed(columns)}
    for alias in aliases:
        if alias.lower() in lowered:
            return lowered[alias.lower()]
    return None


def load_maf_baf_data(maf_file):
    """
    Load B-allele frequency data from a MAF file with a DP4 column.

    Columns named 'Chromosome', 'Start_Position' and 'DP4' are used directly.
    When one is absent, a known alternative name (e.g. 'chr'/'chrom',
    'pos'/'start', 'dp4'/'readcounts', matched case-insensitively) is renamed
    to the standard name. BAF is computed per row with parse_dp4. Rows whose
    DP4 cannot be parsed, or whose position is not numeric, are dropped.

    Parameters:
    -----------
    maf_file : str
        Path to the MAF file (tab-separated, '#' header lines allowed)

    Returns:
    --------
    pandas.DataFrame or None
        DataFrame with 'chromosome', 'position' and 'baf' columns. Returns None
        if the file cannot be read or required columns are missing; the reason
        is printed.
    """
    print(f"Loading MAF file: {maf_file}")

    try:
        # comment='#' skips MAF '#version' header lines. Note that pandas also
        # drops the remainder of any data line after an inline '#'.
        # low_memory=False infers each column's dtype from the whole file, so a
        # chromosome column does not end up holding a mix of ints and strings.
        maf_data = pd.read_csv(maf_file, sep='\t', comment='#', low_memory=False)

        required_columns = ['Chromosome', 'Start_Position', 'DP4']
        column_aliases = {
            'Chromosome': ['chr', 'chromosome', 'chrom'],
            'Start_Position': ['pos', 'position', 'start', 'start_position'],
            'DP4': ['dp4', 'readcounts'],
        }

        # DataFrame.rename expects {existing_name: new_name}.
        rename_map = {}
        for standard_name, aliases in column_aliases.items():
            if standard_name in maf_data.columns:
                continue
            found = _resolve_column(maf_data.columns, aliases)
            if found is not None:
                rename_map[found] = standard_name
        if rename_map:
            maf_data = maf_data.rename(columns=rename_map)

        # Check if required columns are now available
        missing_columns = [col for col in required_columns if col not in maf_data.columns]
        if missing_columns:
            print(f"Error: Missing required columns in MAF file: {missing_columns}")
            print(f"Available columns: {list(maf_data.columns)}")
            return None

        # Calculate BAF from DP4 (None for unparseable values).
        maf_data['BAF'] = maf_data['DP4'].apply(parse_dp4)
        # Positions are compared numerically against segment bounds later.
        maf_data['Start_Position'] = pd.to_numeric(maf_data['Start_Position'], errors='coerce')

        # Keep only rows with a usable BAF and position.
        maf_data = maf_data.dropna(subset=['BAF', 'Start_Position'])

        # Create standardized output with chromosome, position, and BAF
        baf_data = pd.DataFrame({
            'chromosome': maf_data['Chromosome'],
            'position': maf_data['Start_Position'],
            'baf': maf_data['BAF']
        })

        print(f"Loaded {len(baf_data)} variants with valid BAF values")
        return baf_data

    except Exception as e:
        # Best-effort loader: the caller treats None as "no BAF data".
        print(f"Error loading MAF file: {e}")
        return None



def detect_baf_clusters(baf_values, min_cluster_points=5, min_cluster_fraction=0.3,
                        two_cluster_offset=0.05, single_cluster_offset=0.1):
    """
    Detect clusters in B-allele frequencies to determine if they diverge from 0.5.

    Values <= 0.05 or >= 0.95 are discarded first.

    Two ways to flag divergence:
    - A 2-cluster KMeans fit counts when its smaller cluster holds at least
      min_cluster_fraction of the points and either centre lies more than
      two_cluster_offset from 0.5.
    - Failing that, the mean of the values is used as a single centre, which
      counts when it lies more than single_cluster_offset from 0.5.

    NOTE(review): a forced 2-cluster split of unimodal, noisy values centred on
    0.5 places the two centres on either side of 0.5. With low read depth they
    may exceed two_cluster_offset, so balanced segments can be flagged. Tune the
    threshold for the data's depth.

    Parameters:
    -----------
    baf_values : array-like
        BAF values for a segment (ndarray, list or Series)
    min_cluster_points : int
        Minimum number of points (before and after filtering) required to attempt clustering
    min_cluster_fraction : float
        Minimum share of points the smaller of two clusters must hold
    two_cluster_offset : float
        Distance from 0.5 that either of two cluster centres must exceed
    single_cluster_offset : float
        Distance from 0.5 that the single (mean) centre must exceed

    Returns:
    --------
    tuple
        (diverges, centers).
        - Two-cluster divergence: (True, [low, high]) with centres sorted ascending.
        - Single shifted centre: (True, [center]).
        - Otherwise: (False, [0.5]).
    """
    baf_values = np.asarray(baf_values, dtype=float)
    if len(baf_values) < min_cluster_points:
        return False, [0.5]  # Not enough data points

    # Remove extreme values (near 0 and 1) as they might skew clustering
    filtered_values = baf_values[(baf_values > 0.05) & (baf_values < 0.95)]
    if len(filtered_values) < min_cluster_points:
        return False, [0.5]

    # Two clusters need at least two distinct values; otherwise KMeans would
    # warn and return a degenerate fit, so go straight to the single-centre test.
    if np.unique(filtered_values).size >= 2:
        kmeans = KMeans(n_clusters=2, random_state=42, n_init=10)
        labels = kmeans.fit_predict(filtered_values.reshape(-1, 1))
        cluster_counts = np.bincount(labels)
        cluster_centers = sorted(kmeans.cluster_centers_.flatten())

        if len(cluster_counts) == 2:
            proportion = cluster_counts.min() / cluster_counts.sum()
            far_from_half = (abs(cluster_centers[0] - 0.5) > two_cluster_offset
                             or abs(cluster_centers[1] - 0.5) > two_cluster_offset)
            if proportion >= min_cluster_fraction and far_from_half:
                return True, cluster_centers

    # A single-cluster KMeans centre is simply the mean of the values.
    cluster_center = float(np.mean(filtered_values))
    if abs(cluster_center - 0.5) > single_cluster_offset:
        return True, [cluster_center]
    return False, [0.5]


def _normalize_chrom(values):
    """Return chromosome labels as strings, rendering integral floats such as 1.0 as '1'."""
    return values.map(
        lambda c: str(int(c)) if isinstance(c, float) and c.is_integer() else str(c)
    )


def _compute_chrom_layout(cns_data, sorted_chroms):
    """Place chromosomes end-to-end on one x-axis, each followed by a gap of 5% of its span.

    Returns:
    --------
    tuple
        (chrom_mapping, chrom_starts, chrom_origins, total_width):
        - chrom_mapping[chrom]: the (x_start, x_end) span on the concatenated axis.
        - chrom_starts[chrom]: x_start.
        - chrom_origins[chrom]: the smallest segment 'start' on that chromosome,
          i.e. the genomic coordinate drawn at x_start.
        - total_width: the x position reached after the last chromosome and its gap.
    """
    chrom_mapping, chrom_starts, chrom_origins = {}, {}, {}
    current_pos = 0
    for chrom in sorted_chroms:
        chrom_data = cns_data[cns_data['chromosome'] == chrom]
        if chrom_data.empty:
            continue
        origin = chrom_data['start'].min()
        chrom_size = chrom_data['end'].max() - origin
        chrom_mapping[chrom] = (current_pos, current_pos + chrom_size)
        chrom_starts[chrom] = current_pos
        chrom_origins[chrom] = origin
        current_pos += chrom_size + (chrom_size * 0.05)  # Add 5% gap between chromosomes
    return chrom_mapping, chrom_starts, chrom_origins, current_pos


def _draw_baf_summary(ax, x_start, x_end, baf_values, label_offset):
    """Draw the BAF summary line(s) for one segment spanning [x_start, x_end] on *ax*."""
    has_divergence, cluster_centers = detect_baf_clusters(baf_values)

    if not has_divergence:
        # Balanced segment: a single line at BAF = 0.5
        ax.plot([x_start, x_end], [0.5, 0.5], linewidth=2, color='darkred', alpha=0.5)
        return

    # One solid line per detected centre, labelled with its value
    for center in cluster_centers:
        ax.plot([x_start, x_end], [center, center],
                linewidth=2, color='darkred', alpha=0.7)
        ax.text(x_end + label_offset, center, f"{center:.2f}", fontsize=8,
                color='darkred', ha='left', va='center')

    # A single strongly shifted centre: also draw its mirror image (1 - BAF), dashed
    if len(cluster_centers) == 1 and abs(cluster_centers[0] - 0.5) > 0.1:
        complementary_baf = 1 - cluster_centers[0]
        ax.plot([x_start, x_end], [complementary_baf, complementary_baf],
                linewidth=2, color='darkred', alpha=0.7, linestyle='--')
        ax.text(x_end + label_offset, complementary_baf, f"{complementary_baf:.2f}",
                fontsize=8, color='darkred', ha='left', va='center', alpha=0.7)


def plot_cns_and_baf(cns_file, maf_file, output_prefix, log2_limit=3.0):
    """
    Plot segmented copy number data (.cns file) overlaid with B-allele frequencies from a MAF file.

    Chromosomes are laid out end-to-end (sorted with chrom_key), and each
    segment's log2 ratio is drawn as a horizontal line. When BAF data loads,
    individual BAF points and per-segment BAF summary lines (from
    detect_baf_clusters) are drawn on a second y-axis spanning 0..1.

    The log2 axis is symmetric (-log2_limit..log2_limit), so log2 = 0 sits at
    the same height as BAF = 0.5.

    Chromosome names are compared as strings on both inputs, so e.g. an integer
    1 in one file matches '1' in the other. Note that 'chr1' and '1' still do
    not match.

    Parameters:
    -----------
    cns_file : str
        Path to the CNVkit .cns file (segmented copy number)
    maf_file : str
        Path to the MAF file with DP4 column for B-allele frequencies
    output_prefix : str
        Prefix for output files. The plot is written to
        f"{output_prefix}_segments_and_baf.png".
    log2_limit : float
        Half-height of the log2 axis. Segments beyond +/-log2_limit are drawn
        clamped at the axis edge; their text label still shows the true value.

    Returns:
    --------
    None
    """
    print(f"Loading CNVkit segments from {cns_file}...")
    try:
        # CNVkit .cns files are tab-delimited with segments
        cns_data = pd.read_csv(cns_file, sep='\t')

        # Check that required columns are present
        required_columns = ['chromosome', 'start', 'end', 'log2']
        missing_columns = [col for col in required_columns if col not in cns_data.columns]
        if missing_columns:
            print(f"Error: Missing required columns in CNVkit segment file: {missing_columns}")
            return None

    except Exception as e:
        print(f"Error loading CNVkit segment file: {e}")
        return None

    # pandas may type a chromosome column as int, float or str depending on its
    # contents. Compare as strings on both sides so CNS and MAF rows line up.
    cns_data = cns_data.dropna(subset=['chromosome'])
    cns_data = cns_data.assign(chromosome=_normalize_chrom(cns_data['chromosome']))

    # Load BAF data from MAF file
    baf_data = load_maf_baf_data(maf_file)
    if baf_data is None:
        print("Warning: Could not load BAF data from MAF file. Plotting only segmented copy number data.")
    else:
        baf_data = baf_data.dropna(subset=['chromosome'])
        baf_data = baf_data.assign(chromosome=_normalize_chrom(baf_data['chromosome']))

    sorted_chroms = sorted(cns_data['chromosome'].unique(), key=chrom_key)
    chrom_mapping, chrom_starts, chrom_origins, current_pos = _compute_chrom_layout(
        cns_data, sorted_chroms)
    plotted_chroms = [chrom for chrom in sorted_chroms if chrom in chrom_mapping]
    segments_by_chrom = {chrom: cns_data[cns_data['chromosome'] == chrom]
                         for chrom in plotted_chroms}
    output_png = f"{output_prefix}_segments_and_baf.png"

    fig, ax1 = plt.subplots(figsize=(40, 10))
    try:
        # Plot segmented copy number data
        palette = plt.cm.tab10.colors
        for i, chrom in enumerate(plotted_chroms):
            color = palette[i % len(palette)]
            x_offset = chrom_starts[chrom]
            origin = chrom_origins[chrom]
            for _, segment in segments_by_chrom[chrom].iterrows():
                abs_start = x_offset + (segment['start'] - origin)
                abs_end = x_offset + (segment['end'] - origin)
                # Clamp to the visible range so extreme segments are not clipped away
                drawn_log2 = float(np.clip(segment['log2'], -log2_limit, log2_limit))
                ax1.plot([abs_start, abs_end], [drawn_log2, drawn_log2],
                         linewidth=3, color=color, alpha=0.7)
                # Label each segment with its true (unclamped) log2 value
                ax1.text((abs_start + abs_end) / 2, drawn_log2 + 0.1, f"{segment['log2']:.2f}",
                         fontsize=8, ha='center', va='bottom', color=color, alpha=0.9)

        # Set up the first y-axis for log2 ratio
        ax1.set_xlabel('Genomic Position by Chromosome', fontsize=12)
        ax1.set_ylabel('Log2 Ratio', color='blue', fontsize=12)
        ax1.tick_params(axis='y', labelcolor='blue')
        # Symmetric limits keep log2 = 0 at the vertical midpoint, the same height
        # as BAF = 0.5 on the 0..1 BAF axis.
        ax1.set_ylim(-log2_limit, log2_limit)
        ax1.axhline(y=0, color='blue', linestyle='-', alpha=0.3)

        # Chromosome boundaries and names just below the plotting area
        for chrom in plotted_chroms:
            start, end = chrom_mapping[chrom]
            ax1.axvline(x=end, color='black', linestyle='-', alpha=0.3)
            ax1.text((start + end) / 2, -log2_limit - 0.2, f"Chr {chrom}",
                     fontsize=10, ha='center', va='top', fontweight='bold')

        # Inner labels sit 0.8 above the bottom edge of the axis
        inner_label_y = -log2_limit + 0.8
        max_chroms = 24  # Maximum number of chromosomes to label individually
        if len(sorted_chroms) <= max_chroms:
            ax1.set_xticks([sum(chrom_mapping[chrom]) / 2 for chrom in plotted_chroms])
            ax1.set_xticklabels(plotted_chroms, rotation=45, fontsize=10)
            # Position markers at 25%, 50%, 75% within each chromosome
            for chrom in plotted_chroms:
                start, end = chrom_mapping[chrom]
                for pct, label in [(0.25, '25%'), (0.5, '50%'), (0.75, '75%')]:
                    ax1.annotate(label, xy=(start + (end - start) * pct, inner_label_y),
                                 xycoords=('data', 'data'),
                                 fontsize=8, ha='center', va='top', color='gray')
        else:
            # Too many chromosomes for ticks: write each name inside the axis instead
            for chrom in plotted_chroms:
                start, end = chrom_mapping[chrom]
                ax1.text((start + end) / 2, inner_label_y, chrom,
                         ha='center', va='top', fontsize=8, rotation=45)

        if baf_data is not None:
            ax2 = ax1.twinx()
            ax2.set_ylabel('B-allele Frequency', color='red', fontsize=12)
            ax2.tick_params(axis='y', labelcolor='red')
            ax2.set_ylim(0, 1)

            # Reference line at BAF = 0.5 (level with log2 = 0)
            ax2.axhline(y=0.5, color='red', linestyle='-', alpha=0.3)

            label_offset = current_pos * 0.001
            for chrom in plotted_chroms:
                x_offset = chrom_starts[chrom]
                origin = chrom_origins[chrom]
                chrom_baf_data = baf_data[baf_data['chromosome'] == chrom]

                if not chrom_baf_data.empty:
                    ax2.scatter(x_offset + (chrom_baf_data['position'] - origin),
                                chrom_baf_data['baf'], s=2, alpha=0.4, color='red')

                # Summary line(s) through the BAF points of each segment
                for _, segment in segments_by_chrom[chrom].iterrows():
                    segment_baf = chrom_baf_data[
                        (chrom_baf_data['position'] >= segment['start']) &
                        (chrom_baf_data['position'] <= segment['end'])
                    ]
                    if segment_baf.empty:
                        continue
                    _draw_baf_summary(ax2,
                                      x_offset + (segment['start'] - origin),
                                      x_offset + (segment['end'] - origin),
                                      segment_baf['baf'].values,
                                      label_offset)

            # Reference lines for commonly expected BAF values
            expected_baf_values = [0.125, 0.167, 0.25, 0.33, 0.375, 0.5, 0.625, 0.67, 0.75, 0.833, 0.875]
            for baf_value in expected_baf_values:
                if baf_value != 0.5:  # 0.5 already has a solid reference line
                    ax2.axhline(y=baf_value, color='red', linestyle='--', alpha=0.2)
                ax2.text(current_pos * 1.01, baf_value, f'BAF={baf_value:.3f}',
                         color='red', fontsize=9, verticalalignment='center')

            # Highlight the alignment between log2=0 and BAF=0.5
            plt.figtext(0.02, 0.5, "⟵ Aligned at log2=0 and BAF=0.5 ⟶",
                        rotation=90, color='purple', fontsize=10,
                        ha='center', va='center', alpha=0.7)

            from matplotlib.lines import Line2D
            legend_elements = [
                Line2D([0], [0], color='darkred', lw=2, alpha=0.7, label='BAF Summary Line'),
                Line2D([0], [0], color='darkred', lw=2, alpha=0.7, linestyle='--', label='Complementary BAF Line')
            ]
            ax2.legend(handles=legend_elements, loc='upper right')

        plt.title('Segmented Copy Number and B-allele Frequency by Chromosome', fontsize=14)
        plt.tight_layout()
        plt.savefig(output_png, dpi=300)
    finally:
        # Always release the (large) figure, even if drawing or saving fails
        plt.close(fig)

    print(f"Plot saved to {output_png}")

    return None


# Example usage: plot the first CNVkit segment file and SNP-backbone MAF found
# in /content, then download the resulting PNG from Colab.
cns_matches = sorted(glob.glob("/content/*adjusted.cns"))
maf_matches = sorted(glob.glob("/content/*snpbackbone.maf"))
if not cns_matches or not maf_matches:
    raise FileNotFoundError(
        "Need at least one /content/*adjusted.cns and one /content/*snpbackbone.maf file "
        f"(found CNS: {cns_matches}, MAF: {maf_matches})"
    )
if len(cns_matches) > 1 or len(maf_matches) > 1:
    print(f"Multiple inputs found; using {cns_matches[0]} and {maf_matches[0]}")

output_prefix = "CPDC"
plot_cns_and_baf(cns_matches[0], maf_matches[0], output_prefix)

# plot_cns_and_baf saves "<prefix>_segments_and_baf.png" relative to the current
# working directory, so download that same path rather than a hard-coded location.
files.download(f"{output_prefix}_segments_and_baf.png")

Loading CNVkit segments from /content/CPDC2503695-SEQ-250405.adjusted.cns...
Loading MAF file: /content/CPDC2503695-SEQ-250405.snpbackbone.maf
Loaded 6151 variants with valid BAF values


  plt.tight_layout()


Plot saved to CPDC_segments_and_baf.png


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>