# In [None]:
import time
import os
import multiprocessing as mp
import numpy as np
from collections import defaultdict, Counter
import psutil
from functools import partial

# Performance tracking: wall-clock start of the run and every memory sample taken.
start_time = time.time()
memory_usage = []

def track_memory():
    """Sample this process's resident set size, record it, and return it in MB."""
    this_process = psutil.Process(os.getpid())
    rss_megabytes = this_process.memory_info().rss / 1024 / 1024
    memory_usage.append(rss_megabytes)
    return rss_megabytes

print(f"Starting with memory: {track_memory()} MB")

# ==================== FASTA/FASTQ Parsing Functions ====================

def read_fasta_file(filename):
    """Read a FASTA file and return its sequence as one string.

    Every line starting with '>' is treated as a header and skipped; all
    other lines are stripped of surrounding whitespace and concatenated.
    A multi-record file therefore yields the records joined end to end.

    Args:
        filename: path to the FASTA file.

    Returns:
        str: the concatenated sequence ("" for an empty file).
    """
    parts = []
    with open(filename, 'r') as f:
        for line in f:
            # Only real header lines are skipped, so a file without a leading
            # header does not lose its first sequence line.
            if line.startswith('>'):
                continue
            parts.append(line.strip())
    # join is linear; repeated `+=` on str is only fast by CPython accident.
    return "".join(parts)

def read_fastq_file(filename, max_reads=None):
    """Read FASTQ records and return them as ``(header, sequence)`` tuples.

    Records are read four lines at a time (header, sequence, separator,
    quality). Reading stops at end of file, at the first record whose
    header, sequence or quality line is empty, or once ``max_reads``
    records have been collected.

    Args:
        filename: path to the FASTQ file.
        max_reads: maximum number of records to return; ``None`` means no
            limit and ``0`` returns an empty list.

    Returns:
        list[tuple[str, str]]: header line (stripped, otherwise as read) and
        sequence for each record; quality strings are not kept.
    """
    reads = []
    with open(filename, 'r') as f:
        # `is None` rather than truthiness so max_reads=0 means "read nothing".
        while max_reads is None or len(reads) < max_reads:
            header = f.readline().strip()
            if not header:
                break  # End of file
            sequence = f.readline().strip()
            f.readline()  # Skip "+" line
            quality = f.readline().strip()
            if not (sequence and quality):
                break  # Incomplete entry
            reads.append((header, sequence))
    return reads

def read_genome_files(filenames):
    """Load every reference FASTA into a dict keyed by species display name.

    The key is the file's basename up to the first '.', translated to
    abbreviated binomial form for the five known references (for example
    "e_coli" becomes "E. coli"); any other stem is used unchanged.
    """
    display_names = {
        "e_coli": "E. coli",
        "b_subtilis": "B. subtilis",
        "p_aeruginosa": "P. aeruginosa",
        "s_aureus": "S. aureus",
        "m_tuberculosis": "M. tuberculosis",
    }

    print("Reading reference genomes...")
    loaded = {}
    for path in filenames:
        stem = os.path.basename(path).partition('.')[0]
        label = display_names.get(stem, stem)

        print(f"Reading genome for {label}...")
        bases = read_fasta_file(path)
        loaded[label] = bases
        print(f"Read {len(bases)} bases for {label}")

    return loaded

def generate_taxonomic_tree():
    """Return the lineage of each of the 5 reference genomes.

    Each lineage runs from domain down to species, and its last entry is the
    species name that also serves as the dict key. Fresh lists are built on
    every call.
    """
    lineages = (
        ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia', 'E. coli'),
        ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus', 'B. subtilis'),
        ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas', 'P. aeruginosa'),
        ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus', 'S. aureus'),
        ('Bacteria', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium', 'M. tuberculosis'),
    )
    return {ranks[-1]: list(ranks) for ranks in lineages}

def find_lca(taxonomy, species_list):
    """Return the deepest taxon shared by every species in ``species_list``.

    An empty list gives "Unknown". A single species is returned as-is
    without consulting ``taxonomy``. Otherwise the lineages are compared
    rank by rank, and the last rank on which all of them agree is returned;
    "Bacteria" is the fallback when they disagree at the very first rank.
    """
    if not species_list:
        return "Unknown"
    if len(species_list) == 1:
        return species_list[0]  # Return the species if only one match

    lineages = [taxonomy[name] for name in species_list]

    deepest_shared = "Bacteria"
    # zip stops at the shortest lineage, mirroring a walk to the minimum depth.
    for ranks_at_depth in zip(*lineages):
        reference = ranks_at_depth[0]
        if any(rank != reference for rank in ranks_at_depth[1:]):
            break
        deepest_shared = reference
    return deepest_shared

# ==================== Suffix Array Implementation (Optimized) ====================

def build_suffix_array(text):
    """Build the suffix array of ``text``.

    The result lists every start position so that the suffixes
    ``text[sa[0]:], text[sa[1]:], ...`` appear in increasing order under
    Python's ordinary string ordering: code-point order, with a proper
    prefix sorting before any longer string. ``exact_search`` relies on
    this order for its binary search, so it must be exact for every
    suffix, not just for a bounded prefix.

    The algorithm is prefix doubling, vectorised with NumPy. After the round
    with step ``k``, each suffix is ranked by its first ``2*k`` characters.
    The loop stops once all ranks are distinct, which happens after at most
    ``ceil(log2(n))`` rounds of ``np.lexsort``. No suffix substrings are ever
    materialised.

    Args:
        text: the string to index (e.g. a genome sequence).

    Returns:
        list[int]: suffix start positions in sorted suffix order.
    """
    n = len(text)
    if n == 0:
        return []

    # Initial rank of each position = its character's code point (exact for any str).
    rank = np.frombuffer(
        text.encode("utf-32-le", "surrogatepass"), dtype="<u4"
    ).astype(np.int64)

    k = 1
    while True:
        # Rank of the suffix k positions later; -1 past the end, so a suffix
        # that runs out of characters sorts before one that continues.
        second = np.full(n, -1, dtype=np.int64)
        if k < n:
            second[: n - k] = rank[k:]

        # Sort by (rank, second): np.lexsort treats the LAST key as primary.
        order = np.lexsort((second, rank))
        r, s = rank[order], second[order]

        # Re-rank densely: a new rank starts wherever the (rank, second) pair changes.
        starts_group = np.empty(n, dtype=bool)
        starts_group[0] = True
        starts_group[1:] = (r[1:] != r[:-1]) | (s[1:] != s[:-1])
        dense = np.cumsum(starts_group, dtype=np.int64) - 1

        rank = np.empty(n, dtype=np.int64)
        rank[order] = dense

        if dense[-1] == n - 1:
            # All suffixes have distinct ranks, so `order` is the final array.
            return order.tolist()
        k *= 2

def exact_search(text, pattern, suffix_array):
    """Return True if ``pattern`` occurs in ``text``, using its suffix array.

    The function binary-searches ``suffix_array`` for a suffix that begins
    with ``pattern``. ``suffix_array`` holds the start positions of the
    suffixes of ``text`` in sorted order. Each probe takes the next
    ``len(pattern)`` characters of the suffix (fewer if the suffix is
    shorter) and compares them with ``pattern`` in a single C-level string
    comparison, replacing a per-character Python loop.

    Args:
        text: the indexed string.
        pattern: the string to look for.
        suffix_array: sorted suffix start positions of ``text``.

    Returns:
        bool: True if an occurrence exists. An empty pattern always returns False.
    """
    m = len(pattern)
    if m == 0:
        return False

    lo, hi = 0, len(suffix_array) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        start = suffix_array[mid]
        window = text[start:start + m]
        if window == pattern:
            return True
        # A window that is a proper prefix of the pattern compares smaller,
        # so the search correctly moves right past shorter suffixes.
        if window < pattern:
            lo = mid + 1
        else:
            hi = mid - 1
    return False

def build_genome_indices(genomes):
    """Build one suffix-array index per genome.

    Returns a dict mapping each genome name to
    ``{"sequence": <genome string>, "suffix_array": <build_suffix_array result>}``,
    printing progress and the build time of each index.
    """
    print("Building suffix arrays for all genomes...")
    indices = {}

    for species, seq in genomes.items():
        print(f"Building index for {species}...")
        t0 = time.time()
        indices[species] = {
            "sequence": seq,
            "suffix_array": build_suffix_array(seq),
        }
        elapsed = time.time() - t0
        print(f"Index for {species} built in {elapsed:.2f} seconds")

    return indices

# ==================== Read Processing with Direct Parallelization ====================

# Maps each nucleotide to its Watson-Crick complement (case preserved).
_COMPLEMENT = str.maketrans("ACGTacgt", "TGCAtgca")


def _reverse_complement(sequence):
    """Return the reverse complement of a DNA string; non-ACGT characters pass through unchanged."""
    return sequence.translate(_COMPLEMENT)[::-1]


def process_read_chunk(read_chunk, genome_indices, taxonomy, search_reverse_complement=True):
    """Classify a chunk of reads by exact matching against every genome index.

    A read containing 'N' is reported as unclassified without searching.
    Any other read is looked up in every genome, and by default so is its
    reverse complement: a sequencing read may come from either DNA strand,
    while each reference holds only one. A genome counts as a match if
    either orientation occurs in it exactly.

    Args:
        read_chunk: iterable of ``(header, sequence)`` pairs.
        genome_indices: mapping genome name -> ``{"sequence": str,
            "suffix_array": list[int]}`` as produced by ``build_genome_indices``.
        taxonomy: mapping species -> lineage list, passed to ``find_lca``.
        search_reverse_complement: also search the reverse-complemented read.
            Pass False for forward-strand-only matching.

    Returns:
        list of ``(header, classification, lca)`` tuples. ``classification``
        is either ``"Unclassified"`` (with lca ``"Unknown"``) or the list of
        matching genome names (with their lowest common ancestor).
    """
    results = []

    for header, sequence in read_chunk:
        # Skip reads with ambiguous bases (N) for exact matching
        if 'N' in sequence:
            results.append((header, "Unclassified", "Unknown"))
            continue

        queries = [sequence]
        if search_reverse_complement:
            rc = _reverse_complement(sequence)
            if rc != sequence:  # palindromic reads need only one search
                queries.append(rc)

        matching_genomes = [
            genome_name
            for genome_name, index_data in genome_indices.items()
            if any(
                exact_search(index_data["sequence"], query, index_data["suffix_array"])
                for query in queries
            )
        ]

        if matching_genomes:
            results.append((header, matching_genomes, find_lca(taxonomy, matching_genomes)))
        else:
            results.append((header, "Unclassified", "Unknown"))

    return results

def distribute_reads(reads, num_chunks):
    """Split ``reads`` into at most ``num_chunks`` contiguous, near-equal chunks.

    Chunk sizes differ by at most one, and concatenating the chunks in order
    reproduces ``reads``. Fewer than ``num_chunks`` chunks are returned when
    there are fewer reads than chunks; an empty input gives ``[]``.

    Args:
        reads: a sliceable sequence of reads.
        num_chunks: desired number of chunks (must be >= 1).

    Returns:
        list of slices of ``reads``.

    Raises:
        ValueError: if ``num_chunks`` is less than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    total = len(reads)
    if total == 0:
        return []

    chunk_count = min(num_chunks, total)
    # The first `extra` chunks take one additional read each.
    base_size, extra = divmod(total, chunk_count)

    chunks = []
    start = 0
    for idx in range(chunk_count):
        size = base_size + (1 if idx < extra else 0)
        chunks.append(reads[start:start + size])
        start += size
    return chunks

def classify_reads_parallel(reads, genome_indices, taxonomy, num_processes=4):
    """Classify reads, fanning chunks out to a process pool when useful.

    When ``num_processes`` is 1 the reads are classified in the current
    process. This avoids starting a pool and pickling the (large) genome
    indices, and it still works where functions defined in an interactive
    ``__main__`` cannot be sent to spawned workers. Otherwise the reads are
    split into chunks, the chunks are classified in a pool of worker
    processes, and the results are concatenated in input order.

    Args:
        reads: list of ``(header, sequence)`` pairs.
        genome_indices: mapping built by ``build_genome_indices``.
        taxonomy: mapping species -> lineage list.
        num_processes: number of worker processes (must be >= 1).

    Returns:
        list of ``(header, classification, lca)`` tuples, one per read, in input order.

    Raises:
        ValueError: if ``num_processes`` is less than 1.
    """
    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")
    if not reads:
        return []

    process_func = partial(process_read_chunk, genome_indices=genome_indices, taxonomy=taxonomy)

    if num_processes == 1:
        return process_func(reads)

    read_chunks = distribute_reads(reads, num_processes)
    # Never start more workers than there are chunks to hand out.
    with mp.Pool(processes=min(num_processes, len(read_chunks))) as pool:
        chunk_results = pool.map(process_func, read_chunks)

    return [result for chunk in chunk_results for result in chunk]

def summarize_results(classification_results):
    """Tally per-read classification results into a summary dict.

    Each result is ``(header, classification, lca)``, where ``classification``
    is either the string ``"Unclassified"`` or a sequence of matching genome
    names. The tallies are:

    - A single-genome match counts once under ``single_matches``.
    - A multi-genome match adds one to every listed genome in
      ``multi_matches`` and one to its LCA in ``lca_classifications``.

    Returns:
        dict with keys ``total_reads``, ``single_matches``, ``multi_matches``,
        ``lca_classifications`` and ``unclassified``.
    """
    unique_hits = Counter()
    shared_hits = Counter()
    by_ancestor = Counter()
    n_unclassified = 0

    for _header, genomes_hit, ancestor in classification_results:
        if genomes_hit == "Unclassified":
            n_unclassified += 1
            continue
        if len(genomes_hit) == 1:
            unique_hits[genomes_hit[0]] += 1
            continue
        # Ambiguous read: credit every genome it hit, plus their shared ancestor.
        for genome_name in genomes_hit:
            shared_hits[genome_name] += 1
        by_ancestor[ancestor] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(unique_hits),
        "multi_matches": dict(shared_hits),
        "lca_classifications": dict(by_ancestor),
        "unclassified": n_unclassified,
    }

def print_summary(summary):
    """Print a formatted report of a ``summarize_results`` dict.

    The report shows the total read count, the single-genome matches, the
    multi-genome matches grouped by LCA (only if there are any) and the
    unclassified reads. Each count is followed by its percentage of
    ``total_reads``. Percentages are reported as 0.00% when ``total_reads``
    is 0, instead of raising ZeroDivisionError.

    Args:
        summary: dict with keys ``total_reads``, ``single_matches``,
            ``lca_classifications`` and ``unclassified``.
    """
    total = summary['total_reads']

    def pct(count):
        # Guard an empty run so the report still prints.
        return (count / total) * 100 if total else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({pct(count):.2f}%)")

    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({pct(count):.2f}%)")

    print(f"\nUnclassified: {summary['unclassified']} reads ({pct(summary['unclassified']):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def _per_second(count, seconds):
    """Return count/seconds, or 0.0 when no measurable time elapsed."""
    return count / seconds if seconds > 0 else 0.0


def _load_read_set(file_paths, max_reads_per_file):
    """Read every FASTQ file in ``file_paths`` and return all reads in one list."""
    combined_reads = []
    for file_path in file_paths:
        reads = read_fastq_file(file_path, max_reads=max_reads_per_file)
        combined_reads.extend(reads)
        print(f"Read {len(reads)} reads from {file_path}")
    print(f"Total combined reads: {len(combined_reads)}")
    return combined_reads


def main(num_processes=8, max_reads_per_file=10000, genome_files=None, read_file_sets=None):
    """Run the classification pipeline end to end and print reports.

    The pipeline has three stages:

    1. Load the reference genomes and build one suffix-array index per genome.
    2. For each named read set, pool the reads of all its FASTQ files,
       classify them, and print a summary, the timing and the memory use.
    3. Print an overall performance summary.

    Args:
        num_processes: worker processes passed to ``classify_reads_parallel``.
        max_reads_per_file: per-file read cap passed to ``read_fastq_file``.
        genome_files: reference FASTA paths. When None, the five ``*.fna``
            files listed below are used.
        read_file_sets: mapping of set name -> list of FASTQ paths whose reads
            are pooled and classified together. When None, the error-free and
            MiSeq-error simulated read pairs listed below are used.
    """
    print(f"Starting metagenomic classification with {num_processes} processes")

    if genome_files is None:
        genome_files = [
            "b_subtilis.fna",
            "e_coli.fna",
            "m_tuberculosis.fna",
            "p_aeruginosa.fna",
            "s_aureus.fna",
        ]

    if read_file_sets is None:
        read_file_sets = {
            "error_free_reads": [
                "simulated_reads_no_errors_10k_R1.fastq",
                "simulated_reads_no_errors_10k_R2.fastq",
            ],
            "with_errors_reads": [
                "simulated_reads_miseq_10k_R1.fastq",
                "simulated_reads_miseq_10k_R2.fastq",
            ],
        }

    # Load reference genomes and build indices
    genomes = read_genome_files(genome_files)
    taxonomy = generate_taxonomic_tree()
    print(f"Memory after loading reference genomes: {track_memory()} MB")

    genome_indices = build_genome_indices(genomes)
    print(f"Memory after building indices: {track_memory()} MB")

    # Process each combined read file set
    all_results = {}
    for set_name, file_paths in read_file_sets.items():
        print(f"\nProcessing {set_name}:")
        for path in file_paths:
            print(f"  - {path}")

        start_process = time.time()
        combined_reads = _load_read_set(file_paths, max_reads_per_file)
        results = classify_reads_parallel(
            combined_reads,
            genome_indices,
            taxonomy,
            num_processes=num_processes,
        )
        process_time = time.time() - start_process
        all_results[set_name] = results

        print_summary(summarize_results(results))

        print(f"Processing time: {process_time:.2f} seconds")
        # Guarded: a coarse clock can report zero elapsed time for tiny inputs.
        print(f"Reads per second: {_per_second(len(results), process_time):.2f}")
        print(f"Memory after processing {set_name}: {track_memory()} MB")

    # Performance report; total time is measured from module start (start_time).
    total_time = time.time() - start_time
    peak_memory = max(memory_usage)

    print("\n====================== PERFORMANCE SUMMARY ======================")
    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"Peak memory usage: {peak_memory:.2f} MB")

    total_reads = sum(len(results) for results in all_results.values())
    reads_per_second = _per_second(total_reads, total_time)

    print(f"\nTotal reads processed: {total_reads}")
    print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
    print("================================================================")

# Run the main function with the specified number of processes
if __name__ == "__main__":
    # Number of worker processes for read classification. 1 runs everything
    # in-process; raise it (up to the machine's core count) to classify in parallel.
    NUM_PROCESSES = 1
    main(num_processes=NUM_PROCESSES)