In [2]:
import time
import os
import multiprocessing as mp
import numpy as np
from collections import defaultdict, Counter
import psutil
from functools import partial

# Performance tracking
# Wall-clock timestamp taken when this cell starts; main() subtracts it to get the total run time.
start_time = time.time()
# Memory samples in MB, appended by track_memory(); main() reports max() of this list as the peak.
memory_usage = []

def track_memory():
    """Sample this process's resident set size and return it in MB.

    The sample is also appended to the module-level ``memory_usage`` list.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / 1024 / 1024  # bytes -> MB
    memory_usage.append(rss_mb)
    return rss_mb

# Take and print a first memory sample as soon as the cell runs.
print(f"Starting with memory: {track_memory()} MB")

# ==================== FASTA/FASTQ Parsing Functions ====================

def read_fasta_file(filename):
    """Read a FASTA file and return all of its sequence data as one string.

    Header lines (those starting with '>') are skipped wherever they occur;
    every other line is stripped of surrounding whitespace and the pieces
    are concatenated in file order. A multi-record file therefore yields
    its records joined end to end.

    Unlike a "skip the first line" approach, a file that does not begin
    with a header keeps its first sequence line.

    Args:
        filename: Path to the FASTA file.

    Returns:
        The concatenated sequence (an empty string for an empty file).
    """
    with open(filename, 'r') as f:
        # join() builds the result in one pass instead of repeated `+=`,
        # which can degrade to quadratic time on multi-megabase genomes.
        return "".join(line.strip() for line in f if not line.startswith('>'))

def read_fastq_file(filename, max_reads=None):
    """Parse a FASTQ file into a list of (header, sequence) tuples.

    Records are consumed four lines at a time. Reading stops at end of
    file, at the first record with an empty sequence or quality line, or
    once ``max_reads`` records were collected (a falsy ``max_reads`` means
    no limit).
    """
    records = []

    with open(filename, 'r') as handle:
        while not (max_reads and len(records) >= max_reads):
            header = handle.readline().strip()
            if not header:
                break  # End of file

            sequence = handle.readline().strip()
            handle.readline()  # Separator ("+") line is ignored
            quality = handle.readline().strip()

            if not (sequence and quality):
                break  # Truncated record

            records.append((header, sequence))

    return records

def read_genome_files(filenames):
    """Load each reference FASTA and return ``{species name: sequence}``.

    The species key is the file's base name up to the first '.', replaced
    by a display name for the five known references.
    """
    display_names = {
        "e_coli": "E. coli",
        "b_subtilis": "B. subtilis",
        "p_aeruginosa": "P. aeruginosa",
        "s_aureus": "S. aureus",
        "m_tuberculosis": "M. tuberculosis",
    }

    print("Reading reference genomes...")
    genomes = {}

    for path in filenames:
        stem = os.path.basename(path).split('.')[0]
        # Unknown stems are used as-is.
        species = display_names.get(stem, stem)

        print(f"Reading genome for {species}...")
        genomes[species] = read_fasta_file(path)
        print(f"Read {len(genomes[species])} bases for {species}")

    return genomes

def generate_taxonomic_tree():
    """Return the hard-coded lineage of each of the 5 reference genomes.

    Each value is a list of taxon names ending with the species itself,
    so the last element always equals the dictionary key.
    """
    ranks_above_species = {
        'E. coli': ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia'),
        'B. subtilis': ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus'),
        'P. aeruginosa': ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas'),
        'S. aureus': ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus'),
        'M. tuberculosis': ('Bacteria', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium'),
    }
    return {species: [*ranks, species] for species, ranks in ranks_above_species.items()}

def find_lca(taxonomy, species_list):
    """Return the deepest taxon shared by every species in ``species_list``.

    An empty list gives "Unknown"; a single species is returned unchanged.
    Otherwise the lineages are compared rank by rank from the root, and the
    last rank on which they all agree is returned ("Bacteria" if they do
    not even share the first rank).
    """
    if not species_list:
        return "Unknown"
    if len(species_list) == 1:
        return species_list[0]

    lineages = [taxonomy[name] for name in species_list]

    deepest_shared = "Bacteria"
    # zip() stops at the shortest lineage, matching a bounded index loop.
    for taxa_at_rank in zip(*lineages):
        if len(set(taxa_at_rank)) != 1:
            break
        deepest_shared = taxa_at_rank[0]

    return deepest_shared

# ==================== Suffix Array Implementation (Optimized) ====================

def build_suffix_array(text):
    """Build the suffix array of ``text``.

    The result lists every start position of ``text`` ordered by the
    lexicographic order of the suffix beginning there (a suffix that is a
    prefix of another sorts first), which is the order ``exact_search``
    relies on for its binary search.

    The array is built by prefix doubling with numpy: suffixes are ranked
    by their first character, then by their first 2, 4, 8, ... characters
    by sorting on the pair (rank of suffix i, rank of suffix i + k), until
    all ranks are distinct. Unlike sorting on a fixed-length prefix, this
    stays correct when the text contains repeats of any length.

    Args:
        text: The sequence to index (``str``, ``bytes`` or ``bytearray``).

    Returns:
        A list of ``len(text)`` Python ints.
    """
    n = len(text)
    if n == 0:
        return []

    # Turn the text into integer codes whose numeric order equals the
    # character order Python uses when comparing strings.
    if isinstance(text, (bytes, bytearray)):
        codes = np.frombuffer(bytes(text), dtype=np.uint8)
    else:
        codes = np.frombuffer(text.encode("utf-32-le", "surrogatepass"), dtype="<u4")

    # Dense initial ranks (0..alphabet-1) keep the combined key small.
    _, inverse = np.unique(codes, return_inverse=True)
    rank = inverse.reshape(-1).astype(np.int64)

    k = 1
    while True:
        # Pack (rank[i], rank[i + k]) into one integer. A suffix shorter
        # than k gets 0 in the second slot so it sorts before any suffix
        # that extends it; real ranks are shifted up by one.
        key = rank * (n + 1)
        key[:n - k] += rank[k:] + 1

        order = np.argsort(key)
        sorted_key = key[order]

        # Re-rank: equal keys share a rank, each new key bumps it by one.
        sorted_rank = np.concatenate(([0], np.cumsum(sorted_key[1:] != sorted_key[:-1])))
        rank = np.empty(n, dtype=np.int64)
        rank[order] = sorted_rank

        # All suffixes are distinct, so ranks become unique at the latest
        # once 2k >= n; k therefore never reaches n.
        if sorted_rank[-1] == n - 1:
            break
        k *= 2

    return order.tolist()

def exact_search(text, pattern, suffix_array):
    """Report whether ``pattern`` occurs in ``text`` via binary search.

    ``suffix_array`` must list the start positions of ``text`` in suffix
    order. Each probe compares the pattern with the equally long window of
    text at the probed position; a window cut short by the end of the text
    that is a prefix of the pattern counts as smaller than the pattern.
    An empty pattern is never reported as found.
    """
    m = len(pattern)
    if m == 0:
        return False

    lo, hi = 0, len(text) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        start = suffix_array[mid]
        window = text[start:start + m]

        if window == pattern:
            return True
        if window < pattern:
            lo = mid + 1
        else:
            hi = mid - 1

    return False

def build_genome_indices(genomes):
    """Index every genome and return ``{name: {"sequence", "suffix_array"}}``.

    Prints progress and the build time of each index.
    """
    print("Building suffix arrays for all genomes...")
    indices = {}

    for species, sequence in genomes.items():
        print(f"Building index for {species}...")
        started = time.time()

        indices[species] = {
            "sequence": sequence,
            "suffix_array": build_suffix_array(sequence),
        }

        print(f"Index for {species} built in {time.time() - started:.2f} seconds")

    return indices

# ==================== Read Processing with Direct Parallelization ====================

def process_read_chunk(read_chunk, genome_indices, taxonomy):
    """Classify each read of a chunk by exact matching against every genome.

    Returns one ``(header, classification, lca)`` tuple per read, where
    ``classification`` is the list of matching genome names, or the string
    "Unclassified" (with lca "Unknown") when nothing matches. Reads
    containing 'N' are not searched and are reported as unclassified.
    """
    classified = []

    for header, sequence in read_chunk:
        hits = []
        if 'N' not in sequence:
            # Ambiguous bases cannot match exactly, so only clean reads are searched.
            hits = [
                name
                for name, index in genome_indices.items()
                if exact_search(index["sequence"], sequence, index["suffix_array"])
            ]

        if hits:
            classified.append((header, hits, find_lca(taxonomy, hits)))
        else:
            classified.append((header, "Unclassified", "Unknown"))

    return classified

def distribute_reads(reads, num_chunks):
    """Split ``reads`` into at most ``num_chunks`` contiguous chunks.

    The chunk size is rounded up, so the number of chunks never exceeds
    ``num_chunks`` (rounding down could leave a small extra chunk). Order
    is preserved: concatenating the chunks gives back ``reads``.

    Args:
        reads: Sliceable sequence of reads.
        num_chunks: Desired number of chunks; values below 1 are treated as 1.

    Returns:
        A list of slices of ``reads`` (empty when ``reads`` is empty).
    """
    num_chunks = max(1, num_chunks)
    # -(-a // b) is ceiling division without floating point.
    chunk_size = max(1, -(-len(reads) // num_chunks))
    return [reads[i:i + chunk_size] for i in range(0, len(reads), chunk_size)]

def classify_reads_parallel(reads, genome_indices, taxonomy, num_processes=4):
    """Classify ``reads`` with a pool of ``num_processes`` worker processes.

    The reads are split into chunks, each chunk is handed to
    ``process_read_chunk``, and the per-chunk results are concatenated in
    chunk order.
    """
    worker = partial(process_read_chunk, genome_indices=genome_indices, taxonomy=taxonomy)
    chunks = distribute_reads(reads, num_processes)

    with mp.Pool(processes=num_processes) as pool:
        per_chunk = pool.map(worker, chunks)

    # Flatten the list of per-chunk result lists.
    return [record for chunk_result in per_chunk for record in chunk_result]

def summarize_results(classification_results):
    """Aggregate per-read classifications into summary counts.

    Returns a dict with the total number of reads, reads matching exactly
    one genome (per genome), genome counts over multi-genome reads, LCA
    counts over multi-genome reads, and the number of unclassified reads.
    """
    single_matches, multi_matches, lca_classifications = Counter(), Counter(), Counter()
    unclassified = 0

    for _header, genomes, lca in classification_results:
        if genomes == "Unclassified":
            unclassified += 1
            continue
        if len(genomes) == 1:
            single_matches[genomes[0]] += 1
            continue
        # Multi-genome read: credit every matching genome and its LCA.
        multi_matches.update(genomes)
        lca_classifications[lca] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(single_matches),
        "multi_matches": dict(multi_matches),
        "lca_classifications": dict(lca_classifications),
        "unclassified": unclassified,
    }

def print_summary(summary):
    """Print a formatted summary of classification results.

    Args:
        summary: Dict as produced by ``summarize_results`` with the keys
            'total_reads', 'single_matches', 'lca_classifications' and
            'unclassified'.

    Percentages are relative to 'total_reads'; when no reads were
    processed they are reported as 0.00% instead of dividing by zero.
    """
    total_reads = summary['total_reads']

    def percent_of_total(count):
        # Guard the empty-input case (e.g. an empty FASTQ file).
        return (count / total_reads) * 100 if total_reads else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total_reads}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({percent_of_total(count):.2f}%)")

    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({percent_of_total(count):.2f}%)")

    unclassified = summary['unclassified']
    print(f"\nUnclassified: {unclassified} reads ({percent_of_total(unclassified):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def main(num_processes=8, max_reads_per_file=10000):
    """Run the suffix-array classification pipeline end to end.

    Loads the five reference genomes, indexes them, then classifies each
    configured read set (at most ``max_reads_per_file`` reads per FASTQ
    file) with ``num_processes`` workers, printing per-set summaries and an
    overall time/memory report.
    """
    print(f"Starting metagenomic classification with {num_processes} processes")

    # Reference genomes
    genome_files = [
        "b_subtilis.fna",
        "e_coli.fna",
        "m_tuberculosis.fna",
        "p_aeruginosa.fna",
        "s_aureus.fna",
    ]

    # Read sets; the files of a set are pooled and classified together
    read_file_sets = {
        "error_free_reads": [
            "simulated_reads_no_errors_10k_R1.fastq",
            "simulated_reads_no_errors_10k_R2.fastq",
        ],
        "with_errors_reads": [
            "simulated_reads_miseq_10k_R1.fastq",
            "simulated_reads_miseq_10k_R2.fastq",
        ],
    }

    genomes = read_genome_files(genome_files)
    taxonomy = generate_taxonomic_tree()
    print(f"Memory after loading reference genomes: {track_memory()} MB")

    genome_indices = build_genome_indices(genomes)
    print(f"Memory after building indices: {track_memory()} MB")

    results_by_set = {}

    for set_name, file_paths in read_file_sets.items():
        print(f"\nProcessing {set_name}:")
        for path in file_paths:
            print(f"  - {path}")

        set_started = time.time()

        # Pool the reads of every file in the set
        combined_reads = []
        for fastq_path in file_paths:
            batch = read_fastq_file(fastq_path, max_reads=max_reads_per_file)
            combined_reads += batch
            print(f"Read {len(batch)} reads from {fastq_path}")

        print(f"Total combined reads: {len(combined_reads)}")

        classified = classify_reads_parallel(
            combined_reads,
            genome_indices,
            taxonomy,
            num_processes=num_processes,
        )

        elapsed = time.time() - set_started
        results_by_set[set_name] = classified

        print_summary(summarize_results(classified))

        print(f"Processing time: {elapsed:.2f} seconds")
        print(f"Reads per second: {len(classified) / elapsed:.2f}")

        print(f"Memory after processing {set_name}: {track_memory()} MB")

    # Overall report; total time is measured from the module-level start_time
    total_time = time.time() - start_time
    peak_memory = max(memory_usage)

    print("\n====================== PERFORMANCE SUMMARY ======================")
    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"Peak memory usage: {peak_memory:.2f} MB")

    total_reads = sum(map(len, results_by_set.values()))
    reads_per_second = total_reads / total_time

    print(f"\nTotal reads processed: {total_reads}")
    print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
    print("================================================================")

# Run the main function with the specified number of processes
if __name__ == "__main__":
    # Number of worker processes handed to main() (adjust to the available CPU cores)
    NUM_PROCESSES = 1  # Currently a single worker process; main() itself defaults to 8
    main(num_processes=NUM_PROCESSES)

Starting with memory: 272.140625 MB
Starting metagenomic classification with 1 processes
Reading reference genomes...
Reading genome for B. subtilis...
Read 4215606 bases for B. subtilis
Reading genome for E. coli...
Read 4641652 bases for E. coli
Reading genome for M. tuberculosis...
Read 4411532 bases for M. tuberculosis
Reading genome for P. aeruginosa...
Read 6264404 bases for P. aeruginosa
Reading genome for S. aureus...
Read 2821361 bases for S. aureus
Memory after loading reference genomes: 272.140625 MB
Building suffix arrays for all genomes...
Building index for B. subtilis...
Index for B. subtilis built in 5.83 seconds
Building index for E. coli...
Index for E. coli built in 6.70 seconds
Building index for M. tuberculosis...
Index for M. tuberculosis built in 6.32 seconds
Building index for P. aeruginosa...
Index for P. aeruginosa built in 9.43 seconds
Building index for S. aureus...
Index for S. aureus built in 3.79 seconds
Memory after building indices: 1567.4140625 MB

Pro

In [5]:
import time
import os
import subprocess
import multiprocessing as mp
from collections import Counter, defaultdict
import psutil
import sys
import tempfile
import shutil
import re

# First check if BLAST+ is installed
def check_blast_installation():
    """Check that the BLAST+ tools this script needs are on PATH.

    Looks up ``blastn`` and ``makeblastdb`` with ``shutil.which`` (no
    external ``which`` command is required, so the check also works where
    that utility is unavailable) and prints where each one was found, or
    installation hints when one is missing.

    Returns:
        True when both tools were found, False otherwise.
    """
    blastn_path = shutil.which("blastn")
    if blastn_path is None:
        print("ERROR: BLAST+ tools (blastn, makeblastdb) not found in PATH")
        print("\nPlease install NCBI BLAST+ tools:")
        print("For Conda: conda install -c bioconda blast")
        print("For Ubuntu/Debian: sudo apt-get install ncbi-blast+")
        return False
    print(f"Found blastn at: {blastn_path}")

    makeblastdb_path = shutil.which("makeblastdb")
    if makeblastdb_path is None:
        print("ERROR: makeblastdb tool not found in PATH")
        return False
    print(f"Found makeblastdb at: {makeblastdb_path}")

    return True

# Performance tracking
# Wall-clock timestamp taken when this cell starts executing.
start_time = time.time()
# Memory samples in MB, appended by track_memory().
memory_usage = []

def track_memory():
    """Sample this process's resident set size and return it in MB.

    The sample is also appended to the module-level ``memory_usage`` list.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / 1024 / 1024  # bytes -> MB
    memory_usage.append(rss_mb)
    return rss_mb

# Take and print a first memory sample as soon as the cell runs.
print(f"Starting with memory: {track_memory()} MB")

# ==================== Taxonomy Functions ====================

def generate_taxonomic_tree():
    """Return the hard-coded lineage of the 5 reference genomes.

    Each species is registered under three spellings of its name, with
    ". ", "_" and "__" between the genus initial and the epithet. Every
    value is a list of taxon names ending with the key itself.
    """
    lineages = (
        ('E', 'coli', ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia')),
        ('B', 'subtilis', ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus')),
        ('P', 'aeruginosa', ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas')),
        ('S', 'aureus', ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus')),
        ('M', 'tuberculosis', ('Bacteria', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium')),
    )

    taxonomy = {}
    # Display form first, then the single- and double-underscore variants.
    for separator in ('. ', '_', '__'):
        for initial, epithet, ranks in lineages:
            label = f"{initial}{separator}{epithet}"
            taxonomy[label] = [*ranks, label]
    return taxonomy

def find_lca(taxonomy, species_list):
    """Return the deepest taxon shared by every species in ``species_list``.

    An empty list gives "Unknown"; a single species is returned unchanged.
    Otherwise the lineages are compared rank by rank from the root, and the
    last rank on which they all agree is returned ("Bacteria" if they do
    not even share the first rank).
    """
    if not species_list:
        return "Unknown"
    if len(species_list) == 1:
        return species_list[0]

    lineages = [taxonomy[name] for name in species_list]

    deepest_shared = "Bacteria"
    # zip() stops at the shortest lineage, matching a bounded index loop.
    for taxa_at_rank in zip(*lineages):
        if len(set(taxa_at_rank)) != 1:
            break
        deepest_shared = taxa_at_rank[0]

    return deepest_shared

# ==================== BLAST Database Creation ====================

def create_blast_db(fasta_files, temp_dir):
    """Build one nucleotide BLAST database per FASTA file inside ``temp_dir``.

    Each FASTA is first copied into ``temp_dir`` under a space-free species
    id, then ``makeblastdb`` is run on the copy. Returns a dict mapping the
    species display name to the database path (without extension).

    Raises:
        subprocess.CalledProcessError: if ``makeblastdb`` exits non-zero
            (details are printed before re-raising).
    """
    # file stem -> (display name, id that is safe in file names)
    known_species = {
        "e_coli": ("E. coli", "E_coli"),
        "b_subtilis": ("B. subtilis", "B_subtilis"),
        "p_aeruginosa": ("P. aeruginosa", "P_aeruginosa"),
        "s_aureus": ("S. aureus", "S_aureus"),
        "m_tuberculosis": ("M. tuberculosis", "M_tuberculosis"),
    }

    print("Creating BLAST databases...")
    db_paths = {}

    for fasta_file in fasta_files:
        stem = os.path.basename(fasta_file).split('.')[0]
        species_display, species_id = known_species.get(stem, (stem, stem))

        db_path = os.path.join(temp_dir, species_id)
        staged_fasta = f"{db_path}.fna"
        shutil.copy(fasta_file, staged_fasta)

        print(f"Creating BLAST database for {species_display}...")
        cmd = ["makeblastdb", "-in", staged_fasta, "-dbtype", "nucl", "-out", db_path, "-title", species_id]

        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            print(f"Error creating BLAST database for {species_display}:")
            print(f"Command: {' '.join(cmd)}")
            print(f"Error: {e}")
            print(f"stdout: {e.stdout.decode() if e.stdout else 'None'}")
            print(f"stderr: {e.stderr.decode() if e.stderr else 'None'}")
            raise e
        else:
            print(f"Successfully created database for {species_display}")

        db_paths[species_display] = db_path

    return db_paths

# ==================== FASTQ Parsing Functions ====================

def read_fastq_entries(fastq_file, max_reads=None):
    """Parse a FASTQ file into a list of (header, sequence) tuples.

    Records are consumed four lines at a time. Reading stops at end of
    file, at the first record with an empty sequence or quality line, or
    once ``max_reads`` records were collected (a falsy ``max_reads`` means
    no limit).
    """
    entries = []

    with open(fastq_file, 'r') as handle:
        while not (max_reads and len(entries) >= max_reads):
            header = handle.readline().strip()
            if not header:
                break  # End of file

            sequence = handle.readline().strip()
            handle.readline()  # Separator ('+') line is ignored
            quality = handle.readline().strip()

            if not (sequence and quality):
                break  # Truncated record

            entries.append((header, sequence))

    return entries

def write_fasta_file(reads, output_fasta):
    """Write (header, sequence) reads as FASTA and return how many there were.

    The record id is the first whitespace-delimited token of the header
    with its leading character (the FASTQ '@') removed.
    """
    with open(output_fasta, 'w') as out:
        out.writelines(
            f">{header[1:].split()[0]}\n{sequence}\n"
            for header, sequence in reads
        )

    return len(reads)

# ==================== BLAST Processing ====================

def run_blast(query_fasta, db_path, output_file, read_length=100, max_mismatches=1, evalue=1e-5, num_threads=1, max_hits=5):
    """
    Run blastn for ``query_fasta`` against ``db_path`` and write tabular hits

    Parameters:
    - query_fasta: FASTA file with the query reads
    - db_path: BLAST database path passed to -db
    - output_file: Where blastn writes its 12-column tabular output
    - read_length: Average length of reads (used to calculate percent identity)
    - max_mismatches: Maximum number of mismatches to allow
    - evalue: Value passed to -evalue
    - num_threads: Value passed to -num_threads
    - max_hits: Value passed to -max_target_seqs

    Returns output_file. On a non-zero blastn exit the command and its
    output are printed and subprocess.CalledProcessError is re-raised.
    """

    # The mismatch budget becomes an identity floor,
    # e.g. 1 mismatch in a 100bp read = 99% identity
    percent_identity = 100 - (max_mismatches * 100 / read_length)

    # (flag, value) pairs in the order they appear on the command line
    blast_options = (
        ("-query", query_fasta),
        ("-db", db_path),
        ("-out", output_file),
        ("-outfmt", "6 qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore"),
        ("-evalue", str(evalue)),
        ("-max_target_seqs", str(max_hits)),
        ("-num_threads", str(num_threads)),
        ("-perc_identity", str(percent_identity)),
        ("-strand", "plus"),  # Only search the forward strand
    )

    cmd = ["blastn"]
    for flag, value in blast_options:
        cmd += [flag, value]

    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        print("Error running BLAST:")
        print(f"Command: {' '.join(cmd)}")
        print(f"Error: {e}")
        print(f"stdout: {e.stdout.decode() if e.stdout else 'None'}")
        print(f"stderr: {e.stderr.decode() if e.stderr else 'None'}")
        raise e

    return output_file

def process_blast_results(results_file, query_ids, species, max_mismatches=None):
    """Turn one species' tabular BLAST output into per-query hit lists.

    Args:
        results_file: Path to blastn output with 12 tab-separated columns
            (qseqid first, mismatch count fifth).
        query_ids: All query ids that were searched; each gets an entry in
            the result even when it has no hit.
        species: Name recorded for every query that has a qualifying hit.
        max_mismatches: Optional cap on the mismatch column. Hits above it
            are ignored. ``None`` (the default) accepts every reported hit.

    Returns:
        Dict mapping each query id to ``[species]`` if it has at least one
        qualifying hit, else ``[]``.

    Raises:
        KeyError: if the file reports a query id that is not in ``query_ids``.
        ValueError: if the mismatch column of a hit line is not an integer.
    """
    results = {query_id: [] for query_id in query_ids}

    # An empty file simply yields no lines, so no size check is needed.
    with open(results_file, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 12:
                continue  # Not a complete hit line

            query_id = fields[0]
            mismatches = int(fields[4])  # The mismatch column
            if max_mismatches is not None and mismatches > max_mismatches:
                continue

            # A query may have several hits; record the species only once.
            hits = results[query_id]
            if species not in hits:
                hits.append(species)

    return results

def get_query_ids_from_fasta(fasta_file):
    """Return the header text (without the leading '>') of every FASTA record, in file order."""
    with open(fasta_file, 'r') as handle:
        return [line.strip()[1:] for line in handle if line.startswith('>')]

# ==================== Read Classification ====================

def classify_reads_with_blast(fastq_files, db_paths, taxonomy, max_reads_per_file=10000, temp_dir=None, num_threads=8, max_mismatches=1):
    """
    Classify reads using BLAST with a specified mismatch tolerance

    For every read set the FASTQ files are pooled, written out as one FASTA
    query file, and searched with blastn against each database in
    ``db_paths``. A read is assigned every species whose database returned
    a hit for it, plus the lowest common ancestor of those species.

    Parameters:
    - fastq_files: Dict mapping a set name to a list of FASTQ paths
    - db_paths: Dict mapping a species name to its BLAST database path
    - taxonomy: Dict mapping a species name to its lineage (for find_lca)
    - max_reads_per_file: Maximum number of reads taken from each FASTQ file
    - temp_dir: Directory for intermediate files. When None, a temporary
      directory is created for this call and removed again before returning
    - num_threads: Number of threads passed to blastn
    - max_mismatches: Maximum number of mismatches to allow in alignments

    Returns a dict mapping each set name to a list of
    (header, matching species list or "Unclassified", lca) tuples.
    """
    # Only delete the working directory if this call created it; a
    # caller-supplied directory stays under the caller's control.
    owns_temp_dir = temp_dir is None
    if owns_temp_dir:
        temp_dir = tempfile.mkdtemp()

    all_results = {}

    try:
        for set_name, file_paths in fastq_files.items():
            print(f"\nProcessing {set_name}:")
            for path in file_paths:
                print(f"  - {path}")

            start_process = time.time()

            # Read FASTQ files and combine reads
            combined_reads = []
            for fastq_file in file_paths:
                reads = read_fastq_entries(fastq_file, max_reads=max_reads_per_file)
                combined_reads.extend(reads)
                print(f"Read {len(reads)} reads from {fastq_file}")

            print(f"Total combined reads: {len(combined_reads)}")

            # Calculate average read length for percent identity threshold
            # (falls back to 100 when the set contains no reads)
            total_length = sum(len(seq) for _, seq in combined_reads)
            avg_read_length = total_length / len(combined_reads) if combined_reads else 100

            print(f"Average read length: {avg_read_length:.2f}bp")
            print(f"Setting up BLAST to allow maximum {max_mismatches} mismatch(es)")
            percent_identity = 100 - (max_mismatches * 100 / avg_read_length)
            print(f"Using percent identity threshold of {percent_identity:.2f}%")

            # Write combined reads to a FASTA file
            combined_fasta = os.path.join(temp_dir, f"{set_name}_combined.fasta")
            write_fasta_file(combined_reads, combined_fasta)

            # Get all query IDs
            query_ids = get_query_ids_from_fasta(combined_fasta)

            # Every query starts with no matching species
            classification_results = {query_id: [] for query_id in query_ids}

            # Run BLAST against each reference genome
            for species, db_path in db_paths.items():
                print(f"BLASTing against {species}...")
                blast_start = time.time()

                # Dots and spaces in the species name are replaced so the
                # output file name stays simple
                output_file = os.path.join(temp_dir, f"{set_name}_{re.sub(r'[. ]', '_', species)}_blast.out")
                run_blast(
                    combined_fasta,
                    db_path,
                    output_file,
                    read_length=avg_read_length,
                    max_mismatches=max_mismatches,
                    num_threads=num_threads
                )

                # Process BLAST results
                species_results = process_blast_results(output_file, query_ids, species)

                # Merge results
                for query_id, hits in species_results.items():
                    classification_results[query_id].extend(hits)

                print(f"BLAST against {species} completed in {time.time() - blast_start:.2f} seconds")

            # Convert to the format used by the suffix array implementation
            formatted_results = []
            for query_id, matching_genomes in classification_results.items():
                if not matching_genomes:
                    formatted_results.append((f"@{query_id}", "Unclassified", "Unknown"))
                else:
                    # Use the original species names for taxonomy lookup
                    lca = find_lca(taxonomy, matching_genomes)
                    formatted_results.append((f"@{query_id}", matching_genomes, lca))

            all_results[set_name] = formatted_results

            # Calculate processing time
            process_time = time.time() - start_process

            # Summarize and print results
            summary = summarize_results(formatted_results)
            print_summary(summary)

            print(f"Processing time: {process_time:.2f} seconds")
            print(f"Reads per second: {len(formatted_results) / process_time:.2f}")

            print(f"Memory after processing {set_name}: {track_memory()} MB")
    finally:
        # Results are held in memory, so the intermediate files are no
        # longer needed once processing ends (or fails).
        if owns_temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)

    return all_results

def summarize_results(classification_results):
    """Aggregate per-read classifications into summary counts.

    Returns a dict with the total number of reads, reads matching exactly
    one genome (per genome), genome counts over multi-genome reads, LCA
    counts over multi-genome reads, and the number of unclassified reads.
    """
    single_matches, multi_matches, lca_classifications = Counter(), Counter(), Counter()
    unclassified = 0

    for _header, genomes, lca in classification_results:
        if genomes == "Unclassified":
            unclassified += 1
            continue
        if len(genomes) == 1:
            single_matches[genomes[0]] += 1
            continue
        # Multi-genome read: credit every matching genome and its LCA.
        multi_matches.update(genomes)
        lca_classifications[lca] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(single_matches),
        "multi_matches": dict(multi_matches),
        "lca_classifications": dict(lca_classifications),
        "unclassified": unclassified,
    }

def print_summary(summary):
    """Print a formatted summary of classification results.

    Args:
        summary: Dict as produced by ``summarize_results`` with the keys
            'total_reads', 'single_matches', 'lca_classifications' and
            'unclassified'.

    Percentages are relative to 'total_reads'; when no reads were
    processed they are reported as 0.00% instead of dividing by zero.
    """
    total_reads = summary['total_reads']

    def percent_of_total(count):
        # Guard the empty-input case (e.g. an empty FASTQ file).
        return (count / total_reads) * 100 if total_reads else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total_reads}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({percent_of_total(count):.2f}%)")

    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({percent_of_total(count):.2f}%)")

    unclassified = summary['unclassified']
    print(f"\nUnclassified: {unclassified} reads ({percent_of_total(unclassified):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def main(num_threads=8, max_reads_per_file=10000, max_mismatches=1):
    """
    Main function to run BLAST-based classification allowing up to the specified number of mismatches

    The function checks for a BLAST+ installation, builds the databases in a
    fresh temporary directory, classifies both read sets, prints a
    performance summary and finally removes the temporary directory.
    It returns None in every case.

    Parameters:
    - num_threads: Number of threads to use for BLAST
    - max_reads_per_file: Maximum number of reads to process from each file
    - max_mismatches: Maximum number of mismatches to allow in alignments
    """
    # Check if BLAST+ is installed
    # (check_blast_installation is defined outside this section; a falsy
    # result aborts the run before any files are touched.)
    if not check_blast_installation():
        print("Please install BLAST+ tools before running this script")
        return
        
    print(f"Starting BLAST-based metagenomic classification with {num_threads} threads")
    print(f"Allowing up to {max_mismatches} mismatch(es) per read")
    
    # Define file paths
    # Relative paths: the files are looked up in the current working directory.
    genome_files = [
        "b_subtilis.fna",
        "e_coli.fna",
        "m_tuberculosis.fna",
        "p_aeruginosa.fna",
        "s_aureus.fna"
    ]
    
    # Define combined read file sets
    # Each set name maps to the FASTQ files (R1/R2) that are handed over
    # together to classify_reads_with_blast.
    read_file_sets = {
        "error_free_reads": [
            "simulated_reads_no_errors_10k_R1.fastq",
            "simulated_reads_no_errors_10k_R2.fastq"
        ],
        "with_errors_reads": [
            "simulated_reads_miseq_10k_R1.fastq", 
            "simulated_reads_miseq_10k_R2.fastq"
        ]
    }
    
    # Create a temporary directory
    temp_dir = tempfile.mkdtemp()
    print(f"Using temporary directory: {temp_dir}")
    
    # The try/finally guarantees the temporary directory is deleted even if
    # database creation or classification raises.
    try:
        # Create BLAST databases
        db_paths = create_blast_db(genome_files, temp_dir)
        taxonomy = generate_taxonomic_tree()
        
        print(f"Memory after creating BLAST databases: {track_memory()} MB")
        
        # Classify reads
        # NOTE(review): all_results is presumably a dict of per-set result
        # lists (it is consumed via .values() and len() below) - confirm
        # against classify_reads_with_blast.
        all_results = classify_reads_with_blast(
            read_file_sets, 
            db_paths, 
            taxonomy, 
            max_reads_per_file=max_reads_per_file,
            temp_dir=temp_dir,
            num_threads=num_threads,
            max_mismatches=max_mismatches
        )
        
        # Performance report
        # start_time is the module-level timestamp, so the total covers
        # everything since the module started executing, not just this call.
        total_time = time.time() - start_time
        # memory_usage holds every sample recorded by track_memory() so far.
        peak_memory = max(memory_usage)
        
        print("\n====================== PERFORMANCE SUMMARY ======================")
        print(f"Total execution time: {total_time:.2f} seconds")
        print(f"Peak memory usage: {peak_memory:.2f} MB")
        print(f"Maximum mismatches allowed: {max_mismatches}")
        
        # Calculate processing statistics
        total_reads = sum(len(results) for results in all_results.values())
        reads_per_second = total_reads / total_time
        
        print(f"\nTotal reads processed: {total_reads}")
        print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
        print("================================================================")
        
    finally:
        # Clean up temporary directory
        print(f"Cleaning up temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir)

# Script entry point: run the BLAST pipeline with the settings below.
if __name__ == "__main__":
    # Use 8 threads for BLAST (or adjust based on your system)
    NUM_THREADS = 8
    
    # Maximum number of mismatches to allow (Task 1.3).
    # 0 means no mismatches are allowed; main() itself defaults to 1.
    MAX_MISMATCHES = 0
    
    # max_reads_per_file is left at main()'s default.
    main(num_threads=NUM_THREADS, max_mismatches=MAX_MISMATCHES)

Starting with memory: 255.203125 MB
Found blastn at: /usr/bin/blastn
Found makeblastdb at: /usr/bin/makeblastdb
Starting BLAST-based metagenomic classification with 8 threads
Allowing up to 0 mismatch(es) per read
Using temporary directory: /tmp/tmp14u7qb18
Creating BLAST databases...
Creating BLAST database for B. subtilis...
Successfully created database for B. subtilis
Creating BLAST database for E. coli...
Successfully created database for E. coli
Creating BLAST database for M. tuberculosis...
Successfully created database for M. tuberculosis
Creating BLAST database for P. aeruginosa...
Successfully created database for P. aeruginosa
Creating BLAST database for S. aureus...
Successfully created database for S. aureus
Memory after creating BLAST databases: 255.203125 MB

Processing error_free_reads:
  - simulated_reads_no_errors_10k_R1.fastq
  - simulated_reads_no_errors_10k_R2.fastq
Read 5000 reads from simulated_reads_no_errors_10k_R1.fastq
Read 5000 reads from simulated_reads_no_

In [1]:
import time
import os
import multiprocessing as mp
from collections import defaultdict, Counter
import psutil
from functools import partial
import numpy as np
import sys

# Performance tracking: wall-clock start and every memory sample taken so far.
start_time = time.time()
memory_usage = []

def track_memory():
    """Record the current resident set size (in MB) and return it.

    The sample is appended to the module-level ``memory_usage`` list so that
    a peak can be computed later.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / 1024 / 1024  # bytes -> MB
    memory_usage.append(rss_mb)
    return rss_mb

print(f"Starting with memory: {track_memory()} MB")

# ==================== K-mer Encoding Functions ====================

def encode_kmer(kmer):
    """
    Pack a k-mer string into an integer, two bits per base.

    Codes: A=00, C=01, G=10, T=11. The first base ends up in the most
    significant position. Any character other than A/C/G/T is encoded as 0
    (the same code as 'A').
    """
    two_bit = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    packed = 0
    for nucleotide in kmer:
        code = two_bit[nucleotide] if nucleotide in two_bit else 0
        # Make room for the next base, then drop its code into the low bits.
        packed <<= 2
        packed |= code
    return packed

def decode_kmer(encoded, k):
    """
    Turn a 2-bit packed k-mer back into its string of ``k`` bases.

    Only the lowest ``2*k`` bits of ``encoded`` are read.
    """
    alphabet = "ACGT"  # index == 2-bit code
    # Walk from the most significant base pair down to the least significant.
    return ''.join(
        alphabet[(encoded >> (2 * slot)) & 3]
        for slot in range(k - 1, -1, -1)
    )

# ==================== FASTA/FASTQ Parsing Functions ====================

def read_fasta_file(filename):
    """Read a FASTA file and return its sequence as a single string.

    Every line starting with '>' is treated as a header and skipped; all
    other lines are stripped of surrounding whitespace and concatenated in
    file order. If the file contains several records, their sequences are
    joined end-to-end without a separator. An empty file yields "".
    """
    with open(filename, 'r') as f:
        # Headers are recognised by content rather than by position, so a
        # file whose first line is already sequence does not lose that line.
        # Joining once avoids rebuilding a multi-megabase string per line.
        return "".join(line.strip() for line in f if not line.startswith('>'))

def read_fastq_file(filename, max_reads=None):
    """Read a FASTQ file and return reads as a list of (header, sequence) tuples.

    Records are expected as four consecutive lines (header, sequence, "+"
    separator, quality). Reading stops at the first empty header line, at the
    first record whose sequence or quality line is empty, or once
    ``max_reads`` records have been collected.

    ``max_reads=None`` means no limit; ``max_reads=0`` returns an empty list.
    """
    reads = []

    with open(filename, 'r') as f:
        # Compare against None explicitly: a limit of 0 is a real limit and
        # must not be mistaken for "unlimited".
        while max_reads is None or len(reads) < max_reads:
            header = f.readline().strip()
            if not header:
                break  # End of file (or a blank line)

            sequence = f.readline().strip()
            f.readline()  # Skip "+" line
            quality = f.readline().strip()

            if not (sequence and quality):
                break  # Incomplete entry

            reads.append((header, sequence))

    return reads

def read_genome_files(filenames):
    """Load every reference genome and return ``{species name: sequence}``.

    The species key is derived from the part of the file's base name before
    the first '.'; the five known stems are replaced by their display names,
    any other stem is used unchanged.
    """
    display_names = {
        "e_coli": "E. coli",
        "b_subtilis": "B. subtilis",
        "p_aeruginosa": "P. aeruginosa",
        "s_aureus": "S. aureus",
        "m_tuberculosis": "M. tuberculosis",
    }

    print("Reading reference genomes...")
    loaded = {}

    for path in filenames:
        stem = os.path.basename(path).split('.')[0]
        species = display_names.get(stem, stem)

        print(f"Reading genome for {species}...")
        bases = read_fasta_file(path)
        loaded[species] = bases
        print(f"Read {len(bases)} bases for {species}")

    return loaded

def generate_taxonomic_tree():
    """Return the lineage of each of the 5 reference genomes.

    The result maps a species name to its path from the root ('Bacteria')
    down to the species itself, one list entry per rank.
    """
    # (species, phylum, class, order, family, genus)
    lineages = (
        ('E. coli', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia'),
        ('B. subtilis', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus'),
        ('P. aeruginosa', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas'),
        ('S. aureus', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus'),
        ('M. tuberculosis', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium'),
    )
    return {
        species: ['Bacteria', *ranks, species]
        for species, *ranks in lineages
    }

def find_lca(taxonomy, species_list):
    """Return the lowest common ancestor of the given species.

    An empty list gives "Unknown"; a single species is returned as-is.
    Otherwise the lineages from ``taxonomy`` are compared rank by rank and
    the last rank they all share is returned ("Bacteria" if they share none).
    """
    if not species_list:
        return "Unknown"
    if len(species_list) == 1:
        return species_list[0]

    lineages = [taxonomy[name] for name in species_list]

    # zip() walks the lineages in lockstep and stops at the shortest one.
    shared = []
    for rank_members in zip(*lineages):
        if len(set(rank_members)) != 1:
            break
        shared.append(rank_members[0])

    if shared:
        return shared[-1]
    return "Bacteria"

# ==================== K-mer Index Implementation ====================

def extract_kmers_from_chunk(args):
    """Return ``(encoded k-mer, genome position)`` pairs for one chunk.

    ``args`` is the tuple ``(chunk, start_idx, k)``; ``start_idx`` is added
    to each offset inside the chunk to obtain the reported position.
    K-mers containing 'N' are left out.
    """
    chunk, start_idx, k = args

    windows = (
        (offset, chunk[offset:offset + k])
        for offset in range(len(chunk) - k + 1)
    )
    return [
        (encode_kmer(window), start_idx + offset)
        for offset, window in windows
        if 'N' not in window  # Skip k-mers with ambiguous bases
    ]

def extract_kmers_parallel(sequence, k=31, chunk_size=1000000, num_processes=8):
    """Extract all k-mers of ``sequence`` using a pool of worker processes.

    The sequence is cut every ``chunk_size`` bases; each piece is extended by
    ``k - 1`` bases so k-mers that straddle a cut are still seen. Returns one
    flat list of ``(encoded k-mer, position)`` pairs in chunk order.
    """
    starts = range(0, len(sequence), chunk_size)
    pieces = (
        (begin, sequence[begin:begin + chunk_size + k - 1])
        for begin in starts
    )
    # Pieces shorter than k cannot contain a k-mer and are dropped.
    tasks = [(piece, begin, k) for begin, piece in pieces if len(piece) >= k]

    with mp.Pool(processes=num_processes) as pool:
        per_chunk = pool.map(extract_kmers_from_chunk, tasks)

    # Flatten the per-chunk lists, keeping chunk order.
    return [entry for chunk_hits in per_chunk for entry in chunk_hits]

def build_kmer_index_for_genome(genome_name, sequence, k=31, num_processes=8):
    """Index one genome: map each encoded k-mer to the list of its positions.

    Progress and timing are printed along the way.
    """
    print(f"Extracting {k}-mers from {genome_name}...")
    began = time.time()

    occurrences = extract_kmers_parallel(sequence, k, num_processes=num_processes)

    # Group the positions by k-mer, preserving the order they were found in.
    positions_by_kmer = {}
    for code, offset in occurrences:
        positions_by_kmer.setdefault(code, []).append(offset)

    print(f"Found {len(occurrences)} {k}-mers ({len(positions_by_kmer)} unique) in {genome_name}")
    print(f"K-mer extraction for {genome_name} completed in {time.time() - began:.2f} seconds")

    return positions_by_kmer

def merge_kmer_indices(genome_indices):
    """Combine per-genome indices into ``{k-mer: {genome: positions}}``.

    The position lists are referenced, not copied.
    """
    merged = {}

    for genome, per_genome in genome_indices.items():
        for code, hits in per_genome.items():
            merged.setdefault(code, {})[genome] = hits

    return merged

def build_kmer_index(genomes, k=31, num_processes=8):
    """Index every genome, merge the indices and compute index statistics.

    Returns ``(combined_index, kmer_stats)``.
    """
    print(f"Building {k}-mer indices for all genomes...")
    per_genome = {}
    for name, sequence in genomes.items():
        print(f"Processing {name}...")
        per_genome[name] = build_kmer_index_for_genome(name, sequence, k, num_processes)

    # Fold the individual indices into one lookup keyed by k-mer.
    print("Merging k-mer indices...")
    merge_began = time.time()
    merged = merge_kmer_indices(per_genome)
    print(f"Merged k-mer index contains {len(merged)} unique k-mers")
    print(f"Index merging completed in {time.time() - merge_began:.2f} seconds")

    return merged, analyze_kmer_index(merged, k)

def analyze_kmer_index(index, k=31):
    """Analyze k-mer index statistics.

    ``index`` maps each k-mer to a dict keyed by genome name. A k-mer whose
    dict has exactly one key counts as unique to that genome; every other
    k-mer counts as shared.

    Returns a dict with "total_unique_kmers", "theoretical_kmers" (4**k),
    "coverage_percent", "shared_kmers" and "unique_to_genome". The five
    reference genomes always appear in "unique_to_genome" (possibly with 0);
    any other genome name found in the index is added on first sight instead
    of raising KeyError.
    """
    total_kmers = len(index)
    theoretical_kmers = 4**k

    # Seed the reference genomes so they are reported even with zero hits.
    unique_to_genome = dict.fromkeys(
        ['E. coli', 'B. subtilis', 'P. aeruginosa', 'S. aureus', 'M. tuberculosis'],
        0,
    )
    shared_kmers = 0

    for genome_dict in index.values():
        if len(genome_dict) == 1:
            # Unique to one genome
            genome = next(iter(genome_dict))
            unique_to_genome[genome] = unique_to_genome.get(genome, 0) + 1
        else:
            # Shared between multiple genomes
            shared_kmers += 1

    return {
        "total_unique_kmers": total_kmers,
        "theoretical_kmers": theoretical_kmers,
        "coverage_percent": (total_kmers / theoretical_kmers) * 100,
        "shared_kmers": shared_kmers,
        "unique_to_genome": unique_to_genome,
    }

def print_kmer_stats(stats):
    """Print the k-mer index statistics as a formatted report.

    ``stats`` needs the keys "total_unique_kmers", "theoretical_kmers",
    "coverage_percent", "unique_to_genome" and "shared_kmers".
    """
    report = [
        "\n====================== K-MER INDEX STATISTICS ======================",
        f"Total unique k-mers in index: {stats['total_unique_kmers']:,}",
        f"Theoretical k-mer space (4^k): {stats['theoretical_kmers']:,}",
        f"Space coverage: {stats['coverage_percent']:.10f}%",
        "\nK-mers unique to each genome:",
    ]
    report.extend(
        f"- {genome}: {count:,} unique k-mers"
        for genome, count in stats['unique_to_genome'].items()
    )
    report.append(f"\nK-mers shared between genomes: {stats['shared_kmers']:,}")
    report.append("==================================================================")

    # One line per entry, exactly as if each had been printed separately.
    print("\n".join(report))

# ==================== K-mer Based Classification ====================

def extract_read_kmers(sequence, k=31):
    """Return the encoded k-mers of a read, in read order.

    K-mers containing 'N' are left out; a read shorter than ``k`` yields an
    empty list.
    """
    windows = (sequence[start:start + k] for start in range(len(sequence) - k + 1))
    # Skip k-mers with ambiguous bases
    return [encode_kmer(window) for window in windows if 'N' not in window]

def process_read_chunk_with_kmers(read_chunk, kmer_index, taxonomy, k=31, min_kmer_fraction=0.1):
    """Process a chunk of reads against the k-mer index.

    Parameters:
    - read_chunk: iterable of (header, sequence) pairs
    - kmer_index: mapping of encoded k-mer -> container of genome names
    - taxonomy: lineage table passed on to find_lca for ambiguous reads
    - k: k-mer length
    - min_kmer_fraction: minimum share of a read's k-mers that must hit the
      best genome for the read to be classified

    Returns one (header, classification, lca) tuple per read, in input order.
    ``classification`` is the string "Unclassified" (with lca "Unknown") or
    the list of genomes tied for the most k-mer hits.

    Genome names found in the index that are not among the five reference
    genomes are counted as well instead of raising KeyError.
    """
    # The reference genomes are seeded in a fixed order so that ties are
    # always reported in the same order.
    reference_genomes = ('E. coli', 'B. subtilis', 'P. aeruginosa', 'S. aureus', 'M. tuberculosis')
    results = []

    for header, sequence in read_chunk:
        # Skip very short reads
        if len(sequence) < k:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # Extract k-mers from the read (empty when every k-mer contains 'N')
        read_kmers = extract_read_kmers(sequence, k)
        if not read_kmers:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # Count matches to each genome
        genome_matches = dict.fromkeys(reference_genomes, 0)
        for encoded_kmer in read_kmers:
            for genome in kmer_index.get(encoded_kmer, ()):
                genome_matches[genome] = genome_matches.get(genome, 0) + 1

        # Determine best matching genome(s)
        max_matches = max(genome_matches.values())
        match_fraction = max_matches / len(read_kmers)

        # Classify based on matches and threshold
        if max_matches == 0 or match_fraction < min_kmer_fraction:
            results.append((header, "Unclassified", "Unknown"))
            continue

        best_genomes = [g for g, c in genome_matches.items() if c == max_matches]
        if len(best_genomes) == 1:
            lca = best_genomes[0]
        else:
            lca = find_lca(taxonomy, best_genomes)
        results.append((header, best_genomes, lca))

    return results

def distribute_reads(reads, num_chunks):
    """Split ``reads`` into at most ``num_chunks`` consecutive, near-equal chunks.

    The chunk size is the ceiling of ``len(reads) / num_chunks`` so that no
    extra, nearly empty chunk is produced when the division leaves a
    remainder. Concatenating the chunks restores the original order. An
    empty input yields an empty list; ``num_chunks`` below 1 is treated as 1.
    """
    num_chunks = max(1, num_chunks)
    # Ceiling division using integers only.
    chunk_size = max(1, -(-len(reads) // num_chunks))
    return [reads[i:i + chunk_size] for i in range(0, len(reads), chunk_size)]

def classify_reads_parallel(reads, kmer_index, taxonomy, k=31, num_processes=8, min_kmer_fraction=0.1):
    """Classify reads with a pool of ``num_processes`` worker processes.

    The reads are split into batches, each batch is classified by
    ``process_read_chunk_with_kmers``, and the per-batch results are
    concatenated in batch order.
    """
    batches = distribute_reads(reads, num_processes)

    # Bind everything except the batch itself, so pool.map only varies the reads.
    worker = partial(
        process_read_chunk_with_kmers,
        kmer_index=kmer_index,
        taxonomy=taxonomy,
        k=k,
        min_kmer_fraction=min_kmer_fraction,
    )

    with mp.Pool(processes=num_processes) as pool:
        per_batch = pool.map(worker, batches)

    return [entry for batch in per_batch for entry in batch]

def summarize_results(classification_results):
    """Tally classification outcomes into a summary dictionary.

    Each item of ``classification_results`` is a ``(header, classification,
    lca)`` triple. ``classification`` is either the string "Unclassified" or a
    sequence of genome names; ``lca`` is only tallied for reads that matched
    more than one genome.

    Returns a dict with the keys "total_reads", "single_matches",
    "multi_matches", "lca_classifications" and "unclassified".
    """
    by_single_genome = Counter()
    by_shared_genome = Counter()
    by_ancestor = Counter()
    n_unclassified = 0

    for _header, hits, ancestor in classification_results:
        if hits == "Unclassified":
            n_unclassified += 1
            continue

        if len(hits) == 1:
            # Unambiguous read: credit the only genome it matched.
            by_single_genome[hits[0]] += 1
            continue

        # Ambiguous read: credit every matching genome, and its LCA once.
        for name in hits:
            by_shared_genome[name] += 1
        by_ancestor[ancestor] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(by_single_genome),
        "multi_matches": dict(by_shared_genome),
        "lca_classifications": dict(by_ancestor),
        "unclassified": n_unclassified,
    }

def print_summary(summary):
    """Print a formatted summary of classification results.

    ``summary`` must provide the keys "total_reads", "single_matches",
    "lca_classifications" and "unclassified" (the shape returned by
    ``summarize_results``). Percentages are relative to "total_reads"; when
    no reads were processed every percentage is reported as 0.00% instead of
    raising ZeroDivisionError.
    """
    total = summary['total_reads']

    def as_percent(count):
        # An empty read set would otherwise divide by zero.
        return (count / total) * 100 if total else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({as_percent(count):.2f}%)")

    # The multi-genome section is omitted entirely when there is nothing to show.
    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({as_percent(count):.2f}%)")

    unclassified = summary['unclassified']
    print(f"\nUnclassified: {unclassified} reads ({as_percent(unclassified):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def main(k=31, num_processes=8, max_reads_per_file=10000, min_kmer_fraction=1.0):
    """Main function to run the entire classification pipeline

    Steps: load the reference genomes, build the merged k-mer index, print
    index statistics, classify each read set and print a per-set summary,
    then print an overall performance report. Returns None.

    Parameters:
    - k: k-mer length used for both the index and the reads
    - num_processes: size of the worker pools used for indexing and classification
    - max_reads_per_file: maximum number of reads taken from each FASTQ file
    - min_kmer_fraction: minimum fraction of a read's k-mers that must match
      the best genome for the read to be classified
    """
    print(f"Starting k-mer based metagenomic classification with k={k} and {num_processes} processes")
    print(f"Using minimum k-mer match fraction: {min_kmer_fraction}")
    
    # Define file paths
    # Relative paths: the files are looked up in the current working directory.
    genome_files = [
        "b_subtilis.fna",
        "e_coli.fna",
        "m_tuberculosis.fna",
        "p_aeruginosa.fna",
        "s_aureus.fna"
    ]
    
    # Define combined read file sets
    # The files of one set are concatenated and classified together.
    read_file_sets = {
        "error_free_reads": [
            "simulated_reads_no_errors_10k_R1.fastq",
            "simulated_reads_no_errors_10k_R2.fastq"
        ],
        "with_errors_reads": [
            "simulated_reads_miseq_10k_R1.fastq", 
            "simulated_reads_miseq_10k_R2.fastq"
        ]
    }
    
    # Load reference genomes
    genomes = read_genome_files(genome_files)
    taxonomy = generate_taxonomic_tree()
    
    print(f"Memory after loading reference genomes: {track_memory()} MB")
    
    # Build k-mer index
    # index_time is kept so it can be excluded from the read throughput below.
    index_start = time.time()
    kmer_index, kmer_stats = build_kmer_index(genomes, k=k, num_processes=num_processes)
    index_time = time.time() - index_start
    
    print(f"K-mer index built in {index_time:.2f} seconds")
    print(f"Memory after building k-mer index: {track_memory()} MB")
    
    # Print k-mer statistics
    print_kmer_stats(kmer_stats)
    
    # Process each combined read file set
    all_results = {}  # set name -> list of per-read classification tuples
    
    for set_name, file_paths in read_file_sets.items():
        print(f"\nProcessing {set_name}:")
        for path in file_paths:
            print(f"  - {path}")
        
        # The per-set timer includes reading the FASTQ files.
        start_process = time.time()
        
        # Read and combine FASTQ files
        combined_reads = []
        for file_path in file_paths:
            reads = read_fastq_file(file_path, max_reads=max_reads_per_file)
            combined_reads.extend(reads)
            print(f"Read {len(reads)} reads from {file_path}")
            
        print(f"Total combined reads: {len(combined_reads)}")
        
        # Process reads
        results = classify_reads_parallel(
            combined_reads, 
            kmer_index, 
            taxonomy, 
            k=k,
            num_processes=num_processes,
            min_kmer_fraction=min_kmer_fraction
        )
        
        process_time = time.time() - start_process
        all_results[set_name] = results
        
        # Summarize and print results
        summary = summarize_results(results)
        print_summary(summary)
        
        print(f"Processing time: {process_time:.2f} seconds")
        print(f"Reads per second: {len(results) / process_time:.2f}")
        
        print(f"Memory after processing {set_name}: {track_memory()} MB")
    
    # Performance report
    # start_time is the module-level timestamp, so the total covers
    # everything since the module started executing, not just this call.
    total_time = time.time() - start_time
    # memory_usage holds every sample recorded by track_memory() so far.
    peak_memory = max(memory_usage)
    
    print("\n====================== PERFORMANCE SUMMARY ======================")
    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"K-mer index building time: {index_time:.2f} seconds")
    print(f"Peak memory usage: {peak_memory:.2f} MB")
    
    # Calculate processing statistics
    # Throughput is computed over all time that was not spent building the
    # index (genome loading is therefore included in the denominator).
    total_reads = sum(len(results) for results in all_results.values())
    reads_per_second = total_reads / (total_time - index_time)
    
    print(f"\nTotal reads processed: {total_reads}")
    print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
    print("================================================================")

# Script entry point: run the k-mer pipeline with the settings below.
if __name__ == "__main__":
    # Set parameters
    K = 31  # k-mer length (as specified in the assignment)
    NUM_PROCESSES = 1  # Number of processes to use (worker pool size)
    MIN_KMER_FRACTION = 1.0  # For exact matching, require 100% of k-mers to match
    
    # max_reads_per_file is left at main()'s default.
    main(k=K, num_processes=NUM_PROCESSES, min_kmer_fraction=MIN_KMER_FRACTION)

Starting with memory: 99.5 MB
Starting k-mer based metagenomic classification with k=31 and 1 processes
Using minimum k-mer match fraction: 1.0
Reading reference genomes...
Reading genome for B. subtilis...
Read 4215606 bases for B. subtilis
Reading genome for E. coli...
Read 4641652 bases for E. coli
Reading genome for M. tuberculosis...
Read 4411532 bases for M. tuberculosis
Reading genome for P. aeruginosa...
Read 6264404 bases for P. aeruginosa
Reading genome for S. aureus...
Read 2821361 bases for S. aureus
Memory after loading reference genomes: 121.5 MB
Building 31-mer indices for all genomes...
Processing B. subtilis...
Extracting 31-mers from B. subtilis...
Found 4215576 31-mers (4170526 unique) in B. subtilis
K-mer extraction for B. subtilis completed in 12.40 seconds
Processing E. coli...
Extracting 31-mers from E. coli...
Found 4641622 31-mers (4570839 unique) in E. coli
K-mer extraction for E. coli completed in 12.65 seconds
Processing M. tuberculosis...
Extracting 31-mers

In [1]:
import time
import os
import multiprocessing as mp
from collections import defaultdict, Counter, deque
import psutil
from functools import partial
import numpy as np
import sys

# Performance tracking: wall-clock start and every memory sample taken so far.
start_time = time.time()
memory_usage = []

def track_memory():
    """Record the current resident set size (in MB) and return it.

    The sample is appended to the module-level ``memory_usage`` list so that
    a peak can be computed later.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / 1024 / 1024  # bytes -> MB
    memory_usage.append(rss_mb)
    return rss_mb

print(f"Starting with memory: {track_memory()} MB")

# ==================== K-mer Encoding Functions ====================

def encode_kmer(kmer):
    """
    Pack a k-mer string into an integer, two bits per base.

    Codes: A=00, C=01, G=10, T=11. The first base ends up in the most
    significant position. Any character other than A/C/G/T is encoded as 0
    (the same code as 'A').
    """
    two_bit = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    packed = 0
    for nucleotide in kmer:
        code = two_bit[nucleotide] if nucleotide in two_bit else 0
        # Make room for the next base, then drop its code into the low bits.
        packed <<= 2
        packed |= code
    return packed

def decode_kmer(encoded, k):
    """
    Turn a 2-bit packed k-mer back into its string of ``k`` bases.

    Only the lowest ``2*k`` bits of ``encoded`` are read.
    """
    alphabet = "ACGT"  # index == 2-bit code
    # Walk from the most significant base pair down to the least significant.
    return ''.join(
        alphabet[(encoded >> (2 * slot)) & 3]
        for slot in range(k - 1, -1, -1)
    )

# ==================== FASTA/FASTQ Parsing Functions ====================

def read_fasta_file(filename):
    """Read a FASTA file and return its sequence as a single string.

    Every line starting with '>' is treated as a header and skipped; all
    other lines are stripped of surrounding whitespace and concatenated in
    file order. If the file contains several records, their sequences are
    joined end-to-end without a separator. An empty file yields "".
    """
    with open(filename, 'r') as f:
        # Headers are recognised by content rather than by position, so a
        # file whose first line is already sequence does not lose that line.
        # Joining once avoids rebuilding a multi-megabase string per line.
        return "".join(line.strip() for line in f if not line.startswith('>'))

def read_fastq_file(filename, max_reads=None):
    """Read a FASTQ file and return reads as a list of (header, sequence) tuples.

    Records are expected as four consecutive lines (header, sequence, "+"
    separator, quality). Reading stops at the first empty header line, at the
    first record whose sequence or quality line is empty, or once
    ``max_reads`` records have been collected.

    ``max_reads=None`` means no limit; ``max_reads=0`` returns an empty list.
    """
    reads = []

    with open(filename, 'r') as f:
        # Compare against None explicitly: a limit of 0 is a real limit and
        # must not be mistaken for "unlimited".
        while max_reads is None or len(reads) < max_reads:
            header = f.readline().strip()
            if not header:
                break  # End of file (or a blank line)

            sequence = f.readline().strip()
            f.readline()  # Skip "+" line
            quality = f.readline().strip()

            if not (sequence and quality):
                break  # Incomplete entry

            reads.append((header, sequence))

    return reads

def read_genome_files(filenames):
    """Load every reference genome and return ``{species name: sequence}``.

    The species key is derived from the part of the file's base name before
    the first '.'; the five known stems are replaced by their display names,
    any other stem is used unchanged.
    """
    display_names = {
        "e_coli": "E. coli",
        "b_subtilis": "B. subtilis",
        "p_aeruginosa": "P. aeruginosa",
        "s_aureus": "S. aureus",
        "m_tuberculosis": "M. tuberculosis",
    }

    print("Reading reference genomes...")
    loaded = {}

    for path in filenames:
        stem = os.path.basename(path).split('.')[0]
        species = display_names.get(stem, stem)

        print(f"Reading genome for {species}...")
        bases = read_fasta_file(path)
        loaded[species] = bases
        print(f"Read {len(bases)} bases for {species}")

    return loaded

def generate_taxonomic_tree():
    """Return the lineage of each of the 5 reference genomes.

    The result maps a species name to its path from the root ('Bacteria')
    down to the species itself, one list entry per rank.
    """
    # (species, phylum, class, order, family, genus)
    lineages = (
        ('E. coli', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia'),
        ('B. subtilis', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus'),
        ('P. aeruginosa', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas'),
        ('S. aureus', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus'),
        ('M. tuberculosis', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium'),
    )
    return {
        species: ['Bacteria', *ranks, species]
        for species, *ranks in lineages
    }

def find_lca(taxonomy, species_list):
    """Return the lowest common ancestor of the given species.

    An empty list gives "Unknown"; a single species is returned as-is.
    Otherwise the lineages from ``taxonomy`` are compared rank by rank and
    the last rank they all share is returned ("Bacteria" if they share none).
    """
    if not species_list:
        return "Unknown"
    if len(species_list) == 1:
        return species_list[0]

    lineages = [taxonomy[name] for name in species_list]

    # zip() walks the lineages in lockstep and stops at the shortest one.
    shared = []
    for rank_members in zip(*lineages):
        if len(set(rank_members)) != 1:
            break
        shared.append(rank_members[0])

    if shared:
        return shared[-1]
    return "Bacteria"

# ==================== Minimizer Implementation ====================

def extract_minimizers_from_chunk(args):
    """Extract minimizers from a chunk of a genome sequence using a sliding window approach.

    ``args`` is the tuple ``(chunk, start_idx, k, w)``. All k-mers of the
    chunk that contain no 'N' are encoded; then, for every run of ``w``
    consecutive retained k-mers, the one with the smallest encoding is
    selected (the leftmost one on ties, since ``min`` returns the first
    minimal item).

    Returns a list of ``(encoded k-mer, [positions])`` pairs, where each
    position is ``start_idx`` plus the offset inside the chunk and appears
    only once per k-mer. A chunk shorter than ``k``, or one with fewer than
    ``w`` usable k-mers, yields an empty list.
    """
    chunk, start_idx, k, w = args

    # Skip if chunk is too short
    if len(chunk) < k:
        return []

    # First, extract all k-mers with their positions
    kmers = []
    for i in range(len(chunk) - k + 1):
        kmer = chunk[i:i+k]
        if 'N' not in kmer:  # Skip k-mers with ambiguous bases
            kmers.append((encode_kmer(kmer), start_idx + i))

    # Now find minimizers in each window of w consecutive k-mers
    minimizers = {}
    for i in range(len(kmers) - w + 1):
        min_kmer, min_pos = min(kmers[i:i+w], key=lambda x: x[0])

        positions = minimizers.setdefault(min_kmer, [])
        # The leftmost minimum never moves backwards as the window slides,
        # so a repeated selection can only equal the most recent entry.
        # Checking that entry replaces a scan of the whole list.
        if not positions or positions[-1] != min_pos:
            positions.append(min_pos)

    return list(minimizers.items())

def extract_minimizers_parallel(sequence, k=31, w=10, chunk_size=1000000, num_processes=8):
    """Extract minimizers from a sequence in parallel.

    The sequence is cut every ``chunk_size`` bases and each piece is extended
    by ``k + w - 1`` bases so that no window at a chunk boundary is missed.
    Because of that overlap, neighbouring chunks can select the same
    minimizer position; such repeats are dropped while merging, so every
    position appears once per k-mer.

    Returns a dict mapping encoded k-mer -> list of positions.
    """
    chunks = []
    for i in range(0, len(sequence), chunk_size):
        # Add k+w-1 to ensure we don't miss any minimizers at chunk boundaries
        chunk = sequence[i:i+chunk_size+k+w-1]
        if len(chunk) >= k:
            chunks.append((chunk, i, k, w))

    with mp.Pool(processes=num_processes) as pool:
        results = pool.map(extract_minimizers_from_chunk, chunks)

    # Combine results from all chunks
    all_minimizers = {}
    for result in results:
        for encoded_kmer, positions in result:
            existing = all_minimizers.get(encoded_kmer)
            if existing is None:
                all_minimizers[encoded_kmer] = list(positions)
                continue
            # This k-mer was already reported by an earlier chunk: keep only
            # positions that chunk did not report.
            seen = set(existing)
            existing.extend(pos for pos in positions if pos not in seen)

    return all_minimizers

def build_minimizer_index_for_genome(genome_name, sequence, k=31, w=10, num_processes=8):
    """Build a minimizer-based index for a single genome.

    Args:
        genome_name: Label used in progress messages.
        sequence: The genome sequence.
        k: k-mer length.
        w: Number of consecutive k-mers per minimizer window.
        num_processes: Size of the worker pool used for extraction.

    Returns:
        Tuple ``(minimizers, total_kmers, total_minimizers)``: the dict
        produced by ``extract_minimizers_parallel``, the number of k-mer
        start positions in the sequence, and the number of distinct
        minimizers found.
    """
    print(f"Extracting minimizers (k={k}, w={w}) from {genome_name}...")
    start = time.time()

    # Extract minimizers in parallel
    minimizers = extract_minimizers_parallel(sequence, k, w, num_processes=num_processes)

    # Count the total number of k-mers. A sequence shorter than k has none;
    # clamp so the count never goes negative and skews the caller's totals.
    total_kmers = max(0, len(sequence) - k + 1)
    total_minimizers = len(minimizers)
    reduction_ratio = total_kmers / total_minimizers if total_minimizers > 0 else float('inf')

    print(f"Found {total_minimizers:,} unique minimizers out of {total_kmers:,} possible k-mers")
    print(f"Reduction ratio: {reduction_ratio:.2f}x")
    print(f"Minimizer extraction for {genome_name} completed in {time.time() - start:.2f} seconds")

    return minimizers, total_kmers, total_minimizers

def merge_minimizer_indices(genome_indices):
    """Combine per-genome minimizer indices into one lookup table.

    Args:
        genome_indices: Dict mapping genome name to that genome's index
            (encoded minimizer -> positions).

    Returns:
        Dict mapping each encoded minimizer to a dict of
        ``{genome_name: positions}`` for every genome that contains it.
    """
    merged = {}

    for genome_name in genome_indices:
        per_genome = genome_indices[genome_name]
        for encoded_kmer in per_genome:
            # Create the per-minimizer dict on first sight, then file the
            # genome's position list under its name.
            by_genome = merged.setdefault(encoded_kmer, {})
            by_genome[genome_name] = per_genome[encoded_kmer]

    return merged

def build_minimizer_index(genomes, k=31, w=10, num_processes=8):
    """Index every genome, merge the results and compute summary statistics.

    Args:
        genomes: Dict mapping genome name to its sequence.
        k: k-mer length.
        w: Number of consecutive k-mers per minimizer window.
        num_processes: Size of the worker pool used for extraction.

    Returns:
        Tuple ``(combined_index, minimizer_stats)``: the merged index from
        ``merge_minimizer_indices`` and the statistics dict from
        ``analyze_minimizer_index``.
    """
    per_genome = {}
    kmer_total = 0
    minimizer_total = 0

    print(f"Building minimizer index (k={k}, w={w}) for all genomes...")
    for name in genomes:
        print(f"Processing {name}...")
        built = build_minimizer_index_for_genome(name, genomes[name], k, w, num_processes)
        per_genome[name] = built[0]
        kmer_total += built[1]
        minimizer_total += built[2]

    # Fold the individual indices into a single table and time the step
    print("Merging minimizer indices...")
    merge_started = time.time()
    combined_index = merge_minimizer_indices(per_genome)
    print(f"Merged minimizer index contains {len(combined_index):,} unique minimizers")
    print(f"Index merging completed in {time.time() - merge_started:.2f} seconds")

    # Overall reduction: all k-mers versus distinct minimizers kept
    if len(combined_index) > 0:
        overall_reduction = kmer_total / len(combined_index)
    else:
        overall_reduction = float('inf')
    print(f"Overall reduction ratio: {overall_reduction:.2f}x (from {kmer_total:,} k-mers to {len(combined_index):,} minimizers)")

    # Summary statistics for reporting
    minimizer_stats = analyze_minimizer_index(combined_index, k, w, kmer_total, minimizer_total)

    return combined_index, minimizer_stats

def analyze_minimizer_index(index, k, w, total_kmers, total_minimizers):
    """Compute summary statistics for a merged minimizer index.

    Args:
        index: Dict mapping encoded minimizer to ``{genome_name: positions}``.
        k: k-mer length.
        w: Minimizer window size (reported only).
        total_kmers: Sum of k-mer counts over all genomes.
        total_minimizers: Sum of per-genome distinct minimizer counts.

    Returns:
        Dict with the keys ``k``, ``w``, ``total_kmers``,
        ``total_minimizers``, ``total_unique_minimizers``,
        ``theoretical_kmers``, ``coverage_percent``, ``reduction_ratio``,
        ``shared_minimizers`` and ``unique_to_genome``.
    """
    # The five reference genomes are always reported, even with a zero count.
    # Any other genome name found in the index is added on demand instead of
    # raising KeyError.
    unique_to_genome = dict.fromkeys(
        ['E. coli', 'B. subtilis', 'P. aeruginosa', 'S. aureus', 'M. tuberculosis'], 0
    )
    shared_minimizers = 0

    for genome_dict in index.values():
        if len(genome_dict) == 1:
            # Unique to one genome
            genome = next(iter(genome_dict))
            unique_to_genome[genome] = unique_to_genome.get(genome, 0) + 1
        else:
            # Shared between multiple genomes
            shared_minimizers += 1

    unique_minimizers = len(index)

    # Theoretical k-mer space
    theoretical_kmers = 4**k

    # Build statistics dictionary
    stats = {
        "k": k,
        "w": w,
        "total_kmers": total_kmers,
        "total_minimizers": total_minimizers,
        "total_unique_minimizers": unique_minimizers,
        "theoretical_kmers": theoretical_kmers,
        "coverage_percent": (unique_minimizers / theoretical_kmers) * 100,
        "reduction_ratio": total_kmers / unique_minimizers if unique_minimizers > 0 else float('inf'),
        "shared_minimizers": shared_minimizers,
        "unique_to_genome": unique_to_genome
    }

    return stats

def print_minimizer_stats(stats):
    """Write a formatted report of the minimizer index statistics to stdout.

    Args:
        stats: Statistics dict as returned by ``analyze_minimizer_index``.
    """
    report = [
        "\n================ MINIMIZER INDEX STATISTICS ================",
        f"Parameters: k={stats['k']}, w={stats['w']}",
        f"Total unique minimizers in index: {stats['total_unique_minimizers']:,}",
        f"Total k-mers in reference genomes: {stats['total_kmers']:,}",
        f"Theoretical k-mer space (4^k): {stats['theoretical_kmers']:,}",
        f"Space coverage: {stats['coverage_percent']:.10f}%",
        f"Minimizer reduction ratio: {stats['reduction_ratio']:.2f}x",
        "\nMinimizers unique to each genome:",
    ]

    # One line per genome, in the order the stats dict holds them
    per_genome = stats['unique_to_genome']
    for genome in per_genome:
        report.append(f"- {genome}: {per_genome[genome]:,} unique minimizers")

    report.append(f"\nMinimizers shared between genomes: {stats['shared_minimizers']:,}")
    report.append("===========================================================")

    for line in report:
        print(line)

# ==================== Minimizer Based Classification ====================

def extract_read_minimizers(sequence, k=31, w=10):
    """Extract the set of minimizers of a read with a sliding window.

    Every k-mer that contains no 'N' is encoded with ``encode_kmer``. A window
    of ``w`` consecutive retained k-mers slides over that list and the
    smallest encoded k-mer of each window is collected.

    Args:
        sequence: The read sequence.
        k: k-mer length.
        w: Number of consecutive k-mers per minimizer window.

    Returns:
        A set of encoded minimizers. The set is empty when the read is
        shorter than ``k`` or yields fewer than ``w`` usable k-mers. An empty
        *set* is returned in those cases too, so the return type is always
        the same.
    """
    # Skip if read is too short
    if len(sequence) < k:
        return set()

    # Extract all k-mers from the read
    kmers = []
    for offset in range(len(sequence) - k + 1):
        kmer = sequence[offset:offset + k]
        if 'N' not in kmer:  # Skip k-mers with ambiguous bases
            kmers.append((encode_kmer(kmer), offset))

    # Find minimizers in each window
    minimizers = set()
    for window_start in range(len(kmers) - w + 1):
        window = kmers[window_start:window_start + w]
        min_kmer, _ = min(window, key=lambda entry: entry[0])
        minimizers.add(min_kmer)

    return minimizers

def process_read_chunk_with_minimizers(read_chunk, minimizer_index, taxonomy, k=31, w=10, min_minimizer_fraction=0.1):
    """Classify a chunk of reads against the minimizer index.

    For each read, its minimizers are looked up in ``minimizer_index`` and
    the hits are tallied per genome. The read is assigned to the genome(s)
    with the most hits, provided that count divided by the number of read
    minimizers reaches ``min_minimizer_fraction``.

    Args:
        read_chunk: Iterable of ``(header, sequence)`` pairs.
        minimizer_index: Dict mapping encoded minimizer to
            ``{genome_name: positions}``.
        taxonomy: Taxonomy object passed through to ``find_lca`` when several
            genomes tie.
        k: k-mer length.
        w: Number of consecutive k-mers per minimizer window.
        min_minimizer_fraction: Minimum fraction of a read's minimizers that
            must hit the best genome for the read to be classified.

    Returns:
        List of ``(header, classification, lca)`` tuples. For unclassified
        reads ``classification`` is the string ``"Unclassified"`` and ``lca``
        is ``"Unknown"``. Otherwise ``classification`` is the list of
        best-matching genome names and ``lca`` is either the single genome
        name or the result of ``find_lca``.
    """
    # Seeding the tally with the reference genomes fixes the order in which
    # tied genomes are reported. Genomes outside this list are still counted.
    known_genomes = ('E. coli', 'B. subtilis', 'P. aeruginosa', 'S. aureus', 'M. tuberculosis')
    results = []

    for header, sequence in read_chunk:
        # Skip very short reads
        if len(sequence) < k:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # Extract minimizers from the read
        read_minimizers = extract_read_minimizers(sequence, k, w)
        if not read_minimizers:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # Count matches to each genome
        genome_matches = dict.fromkeys(known_genomes, 0)
        for encoded_kmer in read_minimizers:
            if encoded_kmer in minimizer_index:
                for genome in minimizer_index[encoded_kmer]:
                    # .get() keeps an index built from other genomes from
                    # raising KeyError here.
                    genome_matches[genome] = genome_matches.get(genome, 0) + 1

        # Determine best matching genome(s)
        max_matches = max(genome_matches.values())
        if max_matches == 0:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # read_minimizers is non-empty here, so the division is safe
        match_fraction = max_matches / len(read_minimizers)

        # Classify based on matches and threshold
        if match_fraction < min_minimizer_fraction:
            results.append((header, "Unclassified", "Unknown"))
            continue

        best_genomes = [g for g, c in genome_matches.items() if c == max_matches]
        if len(best_genomes) == 1:
            results.append((header, best_genomes, best_genomes[0]))
        else:
            # Several genomes tie: report their lowest common ancestor
            results.append((header, best_genomes, find_lca(taxonomy, best_genomes)))

    return results

def distribute_reads(reads, num_chunks):
    """Split reads into at most ``num_chunks`` contiguous, near-equal chunks.

    Chunk sizes differ by at most one, and concatenating the chunks in order
    gives back ``reads``. Floor-dividing the length instead would produce an
    extra, tiny trailing chunk whenever the length is not a multiple of
    ``num_chunks``.

    Args:
        reads: Sliceable sequence of reads.
        num_chunks: Desired number of chunks. Values below 1 are treated as 1.

    Returns:
        List of slices of ``reads``. It is empty when ``reads`` is empty, and
        has one chunk per read when there are fewer reads than chunks.
    """
    total = len(reads)
    if total == 0:
        return []

    # Never create empty chunks, and tolerate a zero/negative request.
    num_chunks = max(1, min(num_chunks, total))
    base_size, remainder = divmod(total, num_chunks)

    chunks = []
    start = 0
    for chunk_idx in range(num_chunks):
        # The first `remainder` chunks take one extra read each.
        end = start + base_size + (1 if chunk_idx < remainder else 0)
        chunks.append(reads[start:end])
        start = end

    return chunks

def classify_reads_parallel(reads, minimizer_index, taxonomy, k=31, w=10, num_processes=8, min_minimizer_fraction=0.1):
    """Classify reads with a pool of worker processes.

    The reads are split by ``distribute_reads``; each batch is classified by
    ``process_read_chunk_with_minimizers`` in the pool and the per-batch
    results are concatenated in batch order.

    Args:
        reads: Sequence of ``(header, sequence)`` pairs.
        minimizer_index: Merged minimizer index passed to every worker call.
        taxonomy: Taxonomy object passed to every worker call.
        k: k-mer length.
        w: Minimizer window size.
        num_processes: Pool size, also used as the number of batches asked for.
        min_minimizer_fraction: Classification threshold passed to the worker.

    Returns:
        Flat list of the workers' result tuples.
    """
    # Cut the input into batches, one unit of work per pool task
    batches = distribute_reads(reads, num_processes)

    # Bind everything except the batch itself
    worker = partial(
        process_read_chunk_with_minimizers,
        minimizer_index=minimizer_index,
        taxonomy=taxonomy,
        k=k,
        w=w,
        min_minimizer_fraction=min_minimizer_fraction,
    )

    with mp.Pool(processes=num_processes) as pool:
        per_batch = pool.map(worker, batches)

    # Flatten while keeping batch order
    return [record for batch_result in per_batch for record in batch_result]

def summarize_results(classification_results):
    """Tally classification outcomes into a summary dict.

    Args:
        classification_results: Sequence of ``(header, classification, lca)``
            tuples, where ``classification`` is either the string
            ``"Unclassified"`` or a collection of genome names.

    Returns:
        Dict with ``total_reads``, ``single_matches`` (genome -> reads that
        matched only that genome), ``multi_matches`` (genome -> reads that
        tied on it), ``lca_classifications`` (LCA -> tied reads) and
        ``unclassified``.
    """
    single_matches, multi_matches, lca_classifications = Counter(), Counter(), Counter()
    unclassified = 0

    for _header, assigned, lca in classification_results:
        if assigned == "Unclassified":
            unclassified += 1
            continue

        if len(assigned) == 1:
            single_matches[assigned[0]] += 1
            continue

        # Tie between several genomes: credit each of them, then the LCA
        for genome in assigned:
            multi_matches[genome] += 1
        lca_classifications[lca] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(single_matches),
        "multi_matches": dict(multi_matches),
        "lca_classifications": dict(lca_classifications),
        "unclassified": unclassified
    }

def print_summary(summary):
    """Print a formatted summary of classification results.

    Args:
        summary: Dict as returned by ``summarize_results``. The keys read are
            ``total_reads``, ``single_matches``, ``lca_classifications`` and
            ``unclassified``.

    Percentages are reported as 0.00% when no reads were processed, so an
    empty input no longer raises ZeroDivisionError.
    """
    total_reads = summary['total_reads']

    def percent_of_total(count):
        # Guard the empty-input case where the total is zero
        return (count / total_reads) * 100 if total_reads else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total_reads}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({percent_of_total(count):.2f}%)")

    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({percent_of_total(count):.2f}%)")

    unclassified = summary['unclassified']
    print(f"\nUnclassified: {unclassified} reads ({percent_of_total(unclassified):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def _reads_per_second(read_count, seconds):
    """Return reads processed per second, or infinity if no time elapsed."""
    return read_count / seconds if seconds > 0 else float('inf')


def main(k=31, w=10, num_processes=8, max_reads_per_file=10000, min_minimizer_fraction=1.0):
    """Run the entire classification pipeline and print a report.

    Steps: load the reference genomes and the taxonomy, build the minimizer
    index, classify each configured set of FASTQ files, then print per-set
    and overall performance figures.

    Args:
        k: k-mer length.
        w: Number of consecutive k-mers per minimizer window.
        num_processes: Pool size for both indexing and classification.
        max_reads_per_file: Upper bound on reads taken from each FASTQ file.
        min_minimizer_fraction: Minimum fraction of a read's minimizers that
            must hit the best genome for the read to be classified.
    """
    print(f"Starting minimizer-based metagenomic classification with k={k}, w={w}, and {num_processes} processes")
    print(f"Using minimum minimizer match fraction: {min_minimizer_fraction}")

    # Define file paths
    genome_files = [
        "b_subtilis.fna",
        "e_coli.fna",
        "m_tuberculosis.fna",
        "p_aeruginosa.fna",
        "s_aureus.fna"
    ]

    # Define combined read file sets
    read_file_sets = {
        "error_free_reads": [
            "simulated_reads_no_errors_10k_R1.fastq",
            "simulated_reads_no_errors_10k_R2.fastq"
        ],
        "with_errors_reads": [
            "simulated_reads_miseq_10k_R1.fastq",
            "simulated_reads_miseq_10k_R2.fastq"
        ]
    }

    # Load reference genomes
    genomes = read_genome_files(genome_files)
    taxonomy = generate_taxonomic_tree()

    print(f"Memory after loading reference genomes: {track_memory()} MB")

    # Build minimizer index
    index_start = time.time()
    minimizer_index, minimizer_stats = build_minimizer_index(
        genomes, k=k, w=w, num_processes=num_processes
    )
    index_time = time.time() - index_start

    print(f"Minimizer index built in {index_time:.2f} seconds")
    print(f"Memory after building minimizer index: {track_memory()} MB")

    # Print minimizer statistics
    print_minimizer_stats(minimizer_stats)

    # Process each combined read file set
    all_results = {}
    # Time spent reading and classifying reads only. The overall speed is
    # computed from this total, so that it is not diluted by genome loading
    # and report printing.
    total_process_time = 0.0

    for set_name, file_paths in read_file_sets.items():
        print(f"\nProcessing {set_name}:")
        for path in file_paths:
            print(f"  - {path}")

        start_process = time.time()

        # Read and combine FASTQ files
        combined_reads = []
        for file_path in file_paths:
            reads = read_fastq_file(file_path, max_reads=max_reads_per_file)
            combined_reads.extend(reads)
            print(f"Read {len(reads)} reads from {file_path}")

        print(f"Total combined reads: {len(combined_reads)}")

        # Process reads
        results = classify_reads_parallel(
            combined_reads,
            minimizer_index,
            taxonomy,
            k=k,
            w=w,
            num_processes=num_processes,
            min_minimizer_fraction=min_minimizer_fraction
        )

        process_time = time.time() - start_process
        total_process_time += process_time
        all_results[set_name] = results

        # Summarize and print results
        summary = summarize_results(results)
        print_summary(summary)

        print(f"Processing time: {process_time:.2f} seconds")
        print(f"Reads per second: {_reads_per_second(len(results), process_time):.2f}")

        print(f"Memory after processing {set_name}: {track_memory()} MB")

    # Performance report. start_time and memory_usage are module-level
    # trackers initialised when the module was loaded.
    total_time = time.time() - start_time
    peak_memory = max(memory_usage)

    print("\n====================== PERFORMANCE SUMMARY ======================")
    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"Minimizer index building time: {index_time:.2f} seconds")
    print(f"Peak memory usage: {peak_memory:.2f} MB")
    print(f"Memory reduction compared to full k-mer index: {minimizer_stats['reduction_ratio']:.2f}x")

    # Calculate processing statistics
    total_reads = sum(len(results) for results in all_results.values())
    reads_per_second = _reads_per_second(total_reads, total_process_time)

    print(f"\nTotal reads processed: {total_reads}")
    print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
    print("================================================================")

# Script entry point: run the whole pipeline with the parameters set below
if __name__ == "__main__":
    # Set parameters
    K = 31  # k-mer length (as specified in the assignment)
    W = 10  # window size: number of consecutive k-mers per minimizer window
    NUM_PROCESSES = 1 # Worker pool size for both indexing and read classification
    MIN_MINIMIZER_FRACTION = 1.0  # 1.0: every minimizer of a read must hit the best-matching genome
    
    main(k=K, w=W, num_processes=NUM_PROCESSES, min_minimizer_fraction=MIN_MINIMIZER_FRACTION)

Starting with memory: 99.0 MB
Starting minimizer-based metagenomic classification with k=31, w=10, and 1 processes
Using minimum minimizer match fraction: 1.0
Reading reference genomes...
Reading genome for B. subtilis...
Read 4215606 bases for B. subtilis
Reading genome for E. coli...
Read 4641652 bases for E. coli
Reading genome for M. tuberculosis...
Read 4411532 bases for M. tuberculosis
Reading genome for P. aeruginosa...
Read 6264404 bases for P. aeruginosa
Reading genome for S. aureus...
Read 2821361 bases for S. aureus
Memory after loading reference genomes: 121.0 MB
Building minimizer index (k=31, w=10) for all genomes...
Processing B. subtilis...
Extracting minimizers (k=31, w=10) from B. subtilis...
Found 890,678 unique minimizers out of 4,215,576 possible k-mers
Reduction ratio: 4.73x
Minimizer extraction for B. subtilis completed in 13.88 seconds
Processing E. coli...
Extracting minimizers (k=31, w=10) from E. coli...
Found 943,297 unique minimizers out of 4,641,622 possib