In [1]:
import pysam
import re
import csv
from collections import defaultdict
import os

In [26]:
# Sequencing-run accession -> sample name. This mapping is the single source of
# truth: the accession list and the regex below are derived from it, so an
# accession can never be matched without having a sample to count it under
# (and vice versa).
sample_map = {
    "ERR3579819": "VLC001",
    "ERR3579736": "EMN001",
    "ERR3579812": "TAF017",
    "ERR3579813": "TAF017",
    "ERR3579814": "TAF017",
    "ERR3579815": "TAF017",
    "ERR3579816": "TAF017",
}

# Accessions in the order they are listed in sample_map.
sample_ids = list(sample_map)

# Regex that finds any known accession as a substring of a read name.
# - re.escape keeps an accession containing regex metacharacters literal.
# - Longest-first ordering stops a short accession from shadowing a longer one
#   that starts with it (alternation takes the first alternative that matches).
#   The sort is stable, so equal-length accessions keep their listed order.
sample_pattern = re.compile(
    "(" + "|".join(re.escape(sample_id) for sample_id in sorted(sample_ids, key=len, reverse=True)) + ")"
)


In [3]:
def count_reads_per_reference(bam_file, output_tsv):
    """Count alignment records per reference and sample, and write them as a TSV.

    Each record's sample is found by searching its read name with the
    module-level ``sample_pattern`` and translating the matched accession via
    ``sample_map``. Records whose name contains no known accession are ignored.

    Args:
        bam_file: Path to the BAM file to scan.
        output_tsv: Path of the TSV to write. The header is ``Reference``
            followed by the sample names in sorted order; each row holds one
            reference that received at least one counted record.

    Side effects:
        Writes ``output_tsv`` and prints a confirmation message.
    """
    # Sorted once so the header and every row use the same column order.
    sample_names = sorted(set(sample_map.values()))

    # reference name -> {sample name -> record count}; every sample starts at 0.
    counts_per_reference = defaultdict(lambda: dict.fromkeys(sample_names, 0))

    # The context manager closes the BAM handle even if iteration raises.
    with pysam.AlignmentFile(bam_file, "rb") as bam:
        # until_eof=True walks the whole file, including records that are not
        # placed on any reference.
        for read in bam.fetch(until_eof=True):
            # Records without a reference (reference_id == -1) have no contig
            # to be counted under; tallying them would yield a row with an
            # empty reference name in the output.
            if read.reference_id < 0:
                continue

            match = sample_pattern.search(read.query_name)
            if match is None:
                continue

            sample_name = sample_map.get(match.group(0))
            if not sample_name:
                continue

            reference_name = bam.get_reference_name(read.reference_id)
            counts_per_reference[reference_name][sample_name] += 1

    with open(output_tsv, "w", newline="") as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        writer.writerow(["Reference"] + sample_names)
        for reference, sample_counts in counts_per_reference.items():
            writer.writerow([reference] + [sample_counts[sample] for sample in sample_names])

    print(f"Counts have been written to {output_tsv}")


In [4]:
def load_counts_tsv(counts_tsv):
    """Load a per-reference read-count TSV into a nested dictionary.

    The file must have a ``Reference`` column; every other column is treated
    as a sample whose values are integer counts.

    Args:
        counts_tsv: Path to the tab-separated counts file.

    Returns:
        dict mapping each reference name to ``{sample name: int count}``.
        An empty or header-only file yields an empty dict.

    Raises:
        ValueError: If a count cell cannot be parsed as an integer.
    """
    counts_dict = {}
    # newline="" is the open mode the csv module documents for its readers.
    with open(counts_tsv, "r", newline="") as file:
        reader = csv.DictReader(file, delimiter='\t')
        # fieldnames is None when the file has no header line at all; treat
        # that as "no columns" instead of failing on iteration over None.
        sample_names = [field for field in (reader.fieldnames or []) if field != 'Reference']
        for row in reader:
            counts_dict[row['Reference']] = {sample: int(row[sample]) for sample in sample_names}
    return counts_dict


In [5]:
def merge_tsvs(counts_tsv, cluster_tsv, output_tsv):
    """Append per-sample read counts to every row of a cluster report TSV.

    Each cluster-report row is linked to a reference via its ``SampleName``
    column: the reference is the text after ``md_`` up to the next ``_n``
    (or the whole ``SampleName`` when it contains no ``md_``). Rows whose
    reference has no entry in the counts file get 0 for every sample.

    Args:
        counts_tsv: Path to the counts TSV (``Reference`` plus sample columns).
        cluster_tsv: Path to the cluster report TSV; needs a ``SampleName`` column.
        output_tsv: Path of the merged TSV to write (cluster-report columns
            followed by the sample count columns).

    Side effects:
        Writes ``output_tsv`` and prints a confirmation message.
    """
    reads_by_reference = load_counts_tsv(counts_tsv)

    with open(cluster_tsv, "r") as cluster_handle, open(output_tsv, "w", newline='') as merged_handle:
        cluster_rows = csv.DictReader(cluster_handle, delimiter='\t')

        # The count column names are taken from the header of the counts file.
        with open(counts_tsv, "r") as header_handle:
            header_reader = csv.DictReader(header_handle, delimiter='\t')
            count_columns = [name for name in header_reader.fieldnames if name != 'Reference']

        merged_writer = csv.DictWriter(
            merged_handle,
            fieldnames=cluster_rows.fieldnames + count_columns,
            delimiter='\t',
        )
        merged_writer.writeheader()

        for record in cluster_rows:
            label = record['SampleName']
            pieces = label.split('md_')
            # Reference = segment after the first 'md_', cut at its first '_n';
            # labels without 'md_' are used unchanged.
            reference = pieces[1].split('_n')[0] if len(pieces) > 1 else label

            if reference in reads_by_reference:
                record.update(reads_by_reference[reference])
            else:
                # No reads were counted for this reference.
                record.update(dict.fromkeys(count_columns, 0))

            merged_writer.writerow(record)

    print(f"Output written to {output_tsv}")


In [6]:
def load_merged_tsv(merged_tsv, k_cluster):
    """Read the merged TSV and assign every row to its highest-valued cluster.

    Columns other than ``SampleID``, ``SampleName`` and those starting with
    ``Cluster`` are treated as integer per-sample counts.

    Args:
        merged_tsv: Path to the merged tab-separated file.
        k_cluster: Number of cluster columns (``Cluster1`` .. ``Cluster<k>``)
            to compare; a missing column counts as 0.

    Returns:
        list of dicts, one per row, with keys ``SampleID``, ``Reference``
        (the row's ``SampleName``), ``Cluster`` (1-based index of the largest
        cluster value, first one on ties) and one int entry per count column.
    """
    records = []
    with open(merged_tsv, "r") as handle:
        reader = csv.DictReader(handle, delimiter='\t')

        # Everything that is neither an identifier nor a cluster column is a count column.
        count_columns = []
        for column in reader.fieldnames:
            if column in ['SampleID', 'SampleName'] or column.startswith('Cluster'):
                continue
            count_columns.append(column)

        for line in reader:
            def cluster_value(index):
                return float(line.get(f'Cluster{index}', 0))

            best_cluster = max(range(1, k_cluster + 1), key=cluster_value)

            counts = {column: int(line[column]) for column in count_columns}

            record = {
                'SampleID': line['SampleID'],
                'Reference': line['SampleName'],
                'Cluster': best_cluster,
            }
            record.update(counts)
            records.append(record)
    return records


In [7]:
def _rounded_fraction(part, whole):
    """Return part / whole rounded to 4 decimals, or 0 when whole is not positive."""
    if whole > 0:
        return round(part / whole, 4)
    return 0


def calculate_fractions_relative_to_total(merged_tsv, output_tsv, sample_map, k_cluster):
    """Write per-cluster read fractions for every sample to a TSV.

    Rows of the merged TSV are assigned to clusters by ``load_merged_tsv``.
    For each cluster and sample two fractions are reported, both rounded to
    4 decimals:

    * ``<sample>_Fraction_of_Total``: the sample's reads in the cluster divided
      by all reads of all samples in all clusters.
    * ``<sample>_Fraction_of_Cluster``: the sample's reads in the cluster
      divided by all reads in that cluster.

    A final ``Total`` row gives each sample's share of all reads; its
    within-cluster cells are left empty.

    Args:
        merged_tsv: Path to the merged TSV produced by ``merge_tsvs``.
        output_tsv: Path of the fractions TSV to write.
        sample_map: Mapping whose values are the sample names; each must be a
            count column of ``merged_tsv``.
        k_cluster: Number of cluster columns to consider.

    Raises:
        KeyError: If a sample name from ``sample_map`` is not a column of
            ``merged_tsv``.

    Side effects:
        Writes ``output_tsv`` and prints a confirmation message.
    """
    data = load_merged_tsv(merged_tsv, k_cluster)
    sample_names = sorted(set(sample_map.values()))

    def empty_totals():
        return {'total': 0, **{sample: 0 for sample in sample_names}}

    # cluster id -> {'total': reads, <sample>: reads}
    cluster_totals = defaultdict(empty_totals)
    # Same layout, summed over all clusters.
    global_totals = empty_totals()

    for row in data:
        per_cluster = cluster_totals[row['Cluster']]
        for sample in sample_names:
            reads = row[sample]
            per_cluster[sample] += reads
            per_cluster['total'] += reads
            global_totals[sample] += reads
            global_totals['total'] += reads

    with open(output_tsv, "w", newline="") as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')

        writer.writerow(
            ['Cluster']
            + [f"{sample}_Fraction_of_Total" for sample in sample_names]
            + [f"{sample}_Fraction_of_Cluster" for sample in sample_names]
        )

        for cluster in sorted(cluster_totals):
            counts = cluster_totals[cluster]
            fractions_of_total = [
                _rounded_fraction(counts[sample], global_totals['total']) for sample in sample_names
            ]
            fractions_within_cluster = [
                _rounded_fraction(counts[sample], counts['total']) for sample in sample_names
            ]
            writer.writerow([cluster] + fractions_of_total + fractions_within_cluster)

        # The Total row is computed from the exact global counts. Summing the
        # per-cluster values that were already rounded to 4 decimals would
        # accumulate one rounding error per cluster.
        total_row = ['Total']
        total_row += [
            _rounded_fraction(global_totals[sample], global_totals['total']) for sample in sample_names
        ]
        # Within-cluster fractions have no meaningful total; leave them blank.
        total_row += ['' for _ in sample_names]
        writer.writerow(total_row)

    print(f"Cluster-specific fractions relative to total reads with sums have been written to {output_tsv}")


In [34]:
# Number of GMM clusters (k): used to pick the k-specific cluster report,
# to name the k-specific output files, and as k_cluster in the analysis.
cluster = 3
# Directory that receives the TSV files written by this notebook.
out = "VLCEMNTAF/min10k"
# exist_ok=True makes re-running this cell safe when the directory already exists.
os.makedirs(out, exist_ok=True)

In [35]:
# Input BAM file, cluster report for the chosen k, and output TSV paths.
# Alternative input BAM (VLC_EMN_PYK), currently disabled:
#bam_file = "/home/project2/data/empirical/yates2021/results/alignments/VLC_EMN_PYK_sorted_md.bam"
bam_file = "/home/project2/data/empirical/yates2021/results/alignments/VLC_EMN_TAF_min10k_sorted_md.bam"
# Cluster report for the selected number of clusters (`cluster`).
cluster_tsv = f"/home/project2/bam2profDevs/bam2prof_production/yates_vlc_emn_taf_min10k_v2_plots/GMM/k{cluster}/cluster_report_k{cluster}.tsv"
# Outputs, all written below `out`:
# per-reference read counts (independent of k)
reads_per_ref = f"{out}/output_counts.tsv"
# cluster report with the read counts appended
merge_tsv = f"{out}/merged_output_k{cluster}.tsv"
# per-cluster read fractions
fractions = f"{out}/cluster_fractions_k{cluster}.tsv"




In [31]:
# Step 1: count reads per reference and sample in the BAM file and write them to `reads_per_ref`.
count_reads_per_reference(bam_file, reads_per_ref)

Counts have been written to VLCEMNTAF/min10k/output_counts.tsv


In [36]:
# Step 2: append the per-sample read counts to the cluster report and write the result to `merge_tsv`.
merge_tsvs(reads_per_ref, cluster_tsv, merge_tsv)

Output written to VLCEMNTAF/min10k/merged_output_k3.tsv


In [37]:
# Step 3: compute per-cluster read fractions for every sample and write them to `fractions`.
calculate_fractions_relative_to_total(merge_tsv, fractions, sample_map, cluster)

Cluster-specific fractions relative to total reads with sums have been written to VLCEMNTAF/min10k/cluster_fractions_k3.tsv
