In [None]:
import time
import os
import multiprocessing as mp
from collections import defaultdict, Counter, deque
import psutil
from functools import partial
import numpy as np
import sys

# Performance tracking
# Wall-clock reference taken when this module is executed; main() measures
# its total execution time relative to this value.
start_time = time.time()
# Memory samples in MB, appended by track_memory(); main() reports the
# maximum of this list as the peak memory usage.
memory_usage = []

def track_memory():
    """Sample this process's resident set size, log it, and return it in MB.

    The sample is appended to the module-level ``memory_usage`` list.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / 1024 / 1024  # bytes -> MB
    memory_usage.append(rss_mb)
    return rss_mb

# Runs at module level: prints the initial memory footprint and, through
# track_memory(), stores it as the first entry of memory_usage.
print(f"Starting with memory: {track_memory()} MB")

# ==================== K-mer Encoding Functions ====================

# 2-bit code of every recognised nucleotide. Lowercase (soft-masked) bases
# share the code of their uppercase form. Built once instead of per call.
_BASE_CODES = {
    'A': 0, 'C': 1, 'G': 2, 'T': 3,
    'a': 0, 'c': 1, 'g': 2, 't': 3,
}


def encode_kmer(kmer):
    """
    Encode a k-mer string into an integer using 2-bit encoding
    A=00, C=01, G=10, T=11

    The first base ends up in the most significant bit pair, so the result
    occupies 2*len(kmer) bits (62 bits for k=31).

    Args:
        kmer: Nucleotide string. Matching is case-insensitive for A/C/G/T.

    Returns:
        The packed integer. Any character other than A/C/G/T (either case)
        is encoded as 0, i.e. the same code as 'A'; an empty string gives 0.
    """
    code_of = _BASE_CODES.get
    result = 0
    for base in kmer:
        # Shift the bases seen so far up by one slot and append this one.
        result = (result << 2) | code_of(base, 0)
    return result

def decode_kmer(encoded, k):
    """
    Unpack a 2-bit encoded k-mer into its k-letter string.

    Each pair of bits selects one base (0=A, 1=C, 2=G, 3=T); the highest of
    the k pairs is the first letter of the result.
    """
    alphabet = "ACGT"
    # Visit the bit pairs from the most significant (first base) downwards.
    return ''.join(
        alphabet[(encoded >> (2 * slot)) & 3]
        for slot in reversed(range(k))
    )

# ==================== FASTA/FASTQ Parsing Functions ====================

def read_fasta_file(filename):
    """Read a FASTA file and return the sequence

    Every line starting with '>' is treated as a header and skipped; all
    other lines are stripped of surrounding whitespace and concatenated in
    file order, so a multi-record file yields one combined sequence.

    A file whose first line is not a header keeps that line as sequence
    data (it is no longer discarded unconditionally).

    Args:
        filename: Path of the FASTA file to read.

    Returns:
        The concatenated sequence as a single string ('' for an empty file).
    """
    pieces = []
    with open(filename, 'r') as f:
        for line in f:
            if line.startswith('>'):
                continue  # Header line, not sequence data
            pieces.append(line.strip())
    # Join once at the end instead of growing a string line by line.
    return ''.join(pieces)

def read_fastq_file(filename, max_reads=None):
    """Read a FASTQ file and return reads as a list of (header, sequence) tuples

    Records are assumed to occupy exactly four lines each (header, sequence,
    separator, quality). Reading stops at the first empty header line or at
    the first record with an empty sequence or quality line.

    Args:
        filename: Path of the FASTQ file to read.
        max_reads: Maximum number of reads to return, or None for no limit.
            0 now means "return no reads" rather than "no limit".

    Returns:
        List of (header, sequence) tuples in file order; quality strings are
        checked for presence but not returned.
    """
    reads = []

    with open(filename, 'r') as f:
        # Compare against None explicitly so that a limit of 0 is honoured.
        while max_reads is None or len(reads) < max_reads:
            header = f.readline().strip()
            if not header:
                break  # End of file (or a blank line)

            sequence = f.readline().strip()
            f.readline()  # Skip "+" line
            quality = f.readline().strip()

            if not (sequence and quality):
                break  # Incomplete entry

            reads.append((header, sequence))

    return reads

def read_genome_files(filenames):
    """Load every reference genome and map a species label to its sequence.

    The label is the file's base name up to the first '.'; the five known
    stems are replaced by their abbreviated binomial names, any other stem
    is used unchanged.
    """
    # File stem -> display name for the species
    display_names = {
        "e_coli": "E. coli",
        "b_subtilis": "B. subtilis",
        "p_aeruginosa": "P. aeruginosa",
        "s_aureus": "S. aureus",
        "m_tuberculosis": "M. tuberculosis",
    }

    print("Reading reference genomes...")
    genomes = {}

    for path in filenames:
        stem = os.path.basename(path).split('.')[0]
        species = display_names.get(stem, stem)

        print(f"Reading genome for {species}...")
        genome_sequence = read_fasta_file(path)
        genomes[species] = genome_sequence
        print(f"Read {len(genome_sequence)} bases for {species}")

    return genomes

def generate_taxonomic_tree():
    """Return the lineage of each of the 5 reference genomes.

    The result maps a species label to its list of ranks, ordered from the
    root ('Bacteria') down to the species itself. A fresh dictionary with
    fresh lists is built on every call.
    """
    lineages = (
        ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia', 'E. coli'),
        ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus', 'B. subtilis'),
        ('Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas', 'P. aeruginosa'),
        ('Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus', 'S. aureus'),
        ('Bacteria', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium', 'M. tuberculosis'),
    )
    # Each lineage ends with the species label, which doubles as its key.
    return {lineage[-1]: list(lineage) for lineage in lineages}

def find_lca(taxonomy, species_list):
    """Return the deepest rank shared by every species in species_list.

    An empty list gives "Unknown" and a single species is returned as-is
    (without consulting the taxonomy). If the lineages share no leading
    rank at all, "Bacteria" is returned.
    """
    if not species_list:
        return "Unknown"
    if len(species_list) == 1:
        return species_list[0]

    lineages = [taxonomy[name] for name in species_list]

    # Walk all lineages rank by rank (zip stops at the shortest one) and
    # remember the last rank on which they all agree.
    deepest_shared = "Bacteria"
    for ranks_at_level in zip(*lineages):
        if len(set(ranks_at_level)) != 1:
            break
        deepest_shared = ranks_at_level[0]

    return deepest_shared

# ==================== Minimizer Implementation ====================

def extract_minimizers_from_chunk(args):
    """Extract minimizers from a chunk of a genome sequence using a sliding window approach

    Args:
        args: Tuple ``(chunk, start_idx, k, w)`` where ``chunk`` is the
            sequence slice, ``start_idx`` is the offset added to every
            reported position, ``k`` is the k-mer length and ``w`` the
            number of consecutive k-mers per window.

    Returns:
        List of ``(encoded_kmer, positions)`` pairs, ordered by the first
        time each k-mer is selected. ``positions`` lists, in increasing
        order and without repeats, the offsets at which that k-mer was the
        minimum of some window. Ties inside a window go to the leftmost
        k-mer. An empty list is returned when the chunk is shorter than k.

    Raises:
        ValueError: If ``w`` is smaller than 1.

    Note:
        K-mers containing 'N' are dropped before windowing, so a window is
        ``w`` consecutive *surviving* k-mers and may span a gap.
    """
    chunk, start_idx, k, w = args

    # Skip if chunk is too short
    if len(chunk) < k:
        return []
    if w < 1:
        raise ValueError("window size w must be at least 1")

    # First, extract all k-mers with their positions
    kmers = []
    for i in range(len(chunk) - k + 1):
        kmer = chunk[i:i+k]
        if 'N' not in kmer:  # Skip k-mers with ambiguous bases
            kmers.append((encode_kmer(kmer), start_idx + i))

    # Sliding-window minimum with a monotonic deque: `candidates` holds
    # indices into `kmers` whose codes are non-decreasing from front to
    # back, so the front is always the leftmost minimum of the current
    # window. This is O(n) overall instead of rescanning every window.
    minimizers = {}
    candidates = deque()
    last_selected = -1

    for idx, (code, _) in enumerate(kmers):
        # Strictly larger codes can never be a minimum again; equal codes
        # stay so that ties keep resolving to the leftmost k-mer.
        while candidates and kmers[candidates[-1]][0] > code:
            candidates.pop()
        candidates.append(idx)

        # Drop the front once it has slid out of the window [idx-w+1, idx].
        if candidates[0] <= idx - w:
            candidates.popleft()

        # The same k-mer occurrence is the minimum of a contiguous run of
        # windows, so comparing with the previous pick is enough to avoid
        # recording a position twice.
        if idx >= w - 1 and candidates[0] != last_selected:
            last_selected = candidates[0]
            min_kmer, min_pos = kmers[last_selected]
            minimizers.setdefault(min_kmer, []).append(min_pos)

    return list(minimizers.items())

def extract_minimizers_parallel(sequence, k=31, w=10, chunk_size=1000000, num_processes=8):
    """Extract minimizers from a sequence in parallel

    The sequence is cut into chunks that start every ``chunk_size`` bases
    and extend ``k + w - 1`` bases past their nominal end, so windows that
    straddle a chunk boundary are still seen. Each chunk is handed to
    ``extract_minimizers_from_chunk`` in a worker pool.

    Args:
        sequence: Full genome sequence.
        k: K-mer length.
        w: Number of consecutive k-mers per minimizer window.
        chunk_size: Nominal number of bases per chunk.
        num_processes: Size of the worker pool.

    Returns:
        Dict mapping each encoded minimizer to the list of positions at
        which it was selected. Because neighbouring chunks overlap, the
        same position can be reported by two chunks; such repeats are
        removed while keeping first-seen order. An input with no chunk of
        at least k bases yields ``{}`` without starting any worker.
    """
    overlap = k + w - 1
    chunks = []
    for start in range(0, len(sequence), chunk_size):
        chunk = sequence[start:start + chunk_size + overlap]
        if len(chunk) >= k:
            chunks.append((chunk, start, k, w))

    # Nothing to do: avoid paying for a process pool.
    if not chunks:
        return {}

    with mp.Pool(processes=num_processes) as pool:
        results = pool.map(extract_minimizers_from_chunk, chunks)

    # Combine results from all chunks
    all_minimizers = {}
    for result in results:
        for encoded_kmer, positions in result:
            all_minimizers.setdefault(encoded_kmer, []).extend(positions)

    # Remove positions reported twice by overlapping chunks. Only values of
    # existing keys are reassigned, which is safe while iterating.
    for encoded_kmer, positions in all_minimizers.items():
        if len(positions) > 1:
            all_minimizers[encoded_kmer] = list(dict.fromkeys(positions))

    return all_minimizers

def build_minimizer_index_for_genome(genome_name, sequence, k=31, w=10, num_processes=8):
    """Build a minimizer-based index for a single genome

    Args:
        genome_name: Label used in the progress messages.
        sequence: Genome sequence.
        k: K-mer length.
        w: Number of consecutive k-mers per minimizer window.
        num_processes: Size of the worker pool used for extraction.

    Returns:
        Tuple ``(minimizers, total_kmers, total_minimizers)``:
        the dict returned by ``extract_minimizers_parallel``, the number of
        k-mer start positions in the sequence (0 when the sequence is
        shorter than k), and the number of distinct minimizers found.
    """
    print(f"Extracting minimizers (k={k}, w={w}) from {genome_name}...")
    start = time.time()

    # Extract minimizers in parallel
    minimizers = extract_minimizers_parallel(sequence, k, w, num_processes=num_processes)

    # Count the total number of k-mers; clamp at 0 so that a sequence
    # shorter than k does not contribute a negative count to the totals.
    total_kmers = max(0, len(sequence) - k + 1)
    total_minimizers = len(minimizers)
    reduction_ratio = total_kmers / total_minimizers if total_minimizers > 0 else float('inf')

    print(f"Found {total_minimizers:,} unique minimizers out of {total_kmers:,} possible k-mers")
    print(f"Reduction ratio: {reduction_ratio:.2f}x")
    print(f"Minimizer extraction for {genome_name} completed in {time.time() - start:.2f} seconds")

    return minimizers, total_kmers, total_minimizers

def merge_minimizer_indices(genome_indices):
    """Invert per-genome indices into one minimizer -> {genome: positions} dict.

    The position lists are stored by reference, not copied.
    """
    merged = {}

    for genome_name, per_genome in genome_indices.items():
        for encoded_kmer in per_genome:
            # Create the per-minimizer entry on first sight, then record
            # where this genome contains the minimizer.
            merged.setdefault(encoded_kmer, {})[genome_name] = per_genome[encoded_kmer]

    return merged

def build_minimizer_index(genomes, k=31, w=10, num_processes=8):
    """Index every genome, merge the per-genome indices and compute statistics.

    Returns a tuple ``(combined_index, minimizer_stats)``: the merged
    minimizer -> {genome: positions} dictionary and the statistics produced
    by ``analyze_minimizer_index``.
    """
    per_genome_indices = {}
    kmer_total = 0
    minimizer_total = 0

    print(f"Building minimizer index (k={k}, w={w}) for all genomes...")
    for name, sequence in genomes.items():
        print(f"Processing {name}...")
        index, kmer_count, minimizer_count = build_minimizer_index_for_genome(
            name, sequence, k, w, num_processes
        )
        per_genome_indices[name] = index
        kmer_total += kmer_count
        minimizer_total += minimizer_count

    # Combine the individual indices into a single lookup structure
    print("Merging minimizer indices...")
    merge_started = time.time()
    combined_index = merge_minimizer_indices(per_genome_indices)
    print(f"Merged minimizer index contains {len(combined_index):,} unique minimizers")
    print(f"Index merging completed in {time.time() - merge_started:.2f} seconds")

    # How many k-mers each retained minimizer stands for, over all genomes
    if len(combined_index) > 0:
        overall_reduction = kmer_total / len(combined_index)
    else:
        overall_reduction = float('inf')
    print(f"Overall reduction ratio: {overall_reduction:.2f}x (from {kmer_total:,} k-mers to {len(combined_index):,} minimizers)")

    minimizer_stats = analyze_minimizer_index(combined_index, k, w, kmer_total, minimizer_total)

    return combined_index, minimizer_stats

def analyze_minimizer_index(index, k, w, total_kmers, total_minimizers):
    """Analyze minimizer index statistics

    Args:
        index: Merged index mapping each encoded minimizer to a dict keyed
            by genome name.
        k: K-mer length.
        w: Window size; copied into the result unchanged.
        total_kmers: Number of k-mers summed over all genomes.
        total_minimizers: Per-genome minimizer counts summed over all
            genomes; copied into the result unchanged.

    Returns:
        Dict with the keys ``k``, ``w``, ``total_kmers``,
        ``total_minimizers``, ``total_unique_minimizers``,
        ``theoretical_kmers`` (4**k), ``coverage_percent``,
        ``reduction_ratio`` (``inf`` for an empty index),
        ``shared_minimizers`` and ``unique_to_genome``.

        ``unique_to_genome`` always lists the five reference genomes first
        (with 0 when they own no minimizer) and additionally contains any
        other genome name found in the index, instead of failing on it.
    """
    # Seed with the five reference genomes so they are always reported, in
    # a fixed order, even when their count is zero.
    unique_to_genome = dict.fromkeys(
        ['E. coli', 'B. subtilis', 'P. aeruginosa', 'S. aureus', 'M. tuberculosis'], 0
    )
    shared_minimizers = 0

    for genome_dict in index.values():
        if len(genome_dict) == 1:
            # Unique to one genome; genomes outside the seed list are
            # added on first sight.
            genome = next(iter(genome_dict))
            unique_to_genome[genome] = unique_to_genome.get(genome, 0) + 1
        else:
            # Shared between multiple genomes
            shared_minimizers += 1

    unique_count = len(index)

    # Theoretical k-mer space
    theoretical_kmers = 4**k

    # Build statistics dictionary
    stats = {
        "k": k,
        "w": w,
        "total_kmers": total_kmers,
        "total_minimizers": total_minimizers,
        "total_unique_minimizers": unique_count,
        "theoretical_kmers": theoretical_kmers,
        "coverage_percent": (unique_count / theoretical_kmers) * 100,
        "reduction_ratio": total_kmers / unique_count if unique_count > 0 else float('inf'),
        "shared_minimizers": shared_minimizers,
        "unique_to_genome": unique_to_genome
    }

    return stats

def print_minimizer_stats(stats):
    """Write a human-readable report of the minimizer index statistics to stdout."""
    # Assemble the whole report first, then emit it in one go.
    report = [
        "\n================ MINIMIZER INDEX STATISTICS ================",
        f"Parameters: k={stats['k']}, w={stats['w']}",
        f"Total unique minimizers in index: {stats['total_unique_minimizers']:,}",
        f"Total k-mers in reference genomes: {stats['total_kmers']:,}",
        f"Theoretical k-mer space (4^k): {stats['theoretical_kmers']:,}",
        f"Space coverage: {stats['coverage_percent']:.10f}%",
        f"Minimizer reduction ratio: {stats['reduction_ratio']:.2f}x",
        "\nMinimizers unique to each genome:",
    ]
    report.extend(
        f"- {genome}: {count:,} unique minimizers"
        for genome, count in stats['unique_to_genome'].items()
    )
    report.append(f"\nMinimizers shared between genomes: {stats['shared_minimizers']:,}")
    report.append("===========================================================")

    print("\n".join(report))

# ==================== Minimizer Based Classification ====================

def extract_read_minimizers(sequence, k=31, w=10):
    """Extract minimizers from a read sequence using a sliding window approach

    Args:
        sequence: Read sequence.
        k: K-mer length.
        w: Number of consecutive k-mers per window.

    Returns:
        Set of encoded k-mers, one per window minimum. The set is empty
        when the read is shorter than k or when fewer than w k-mers remain
        after dropping those that contain 'N'. A set is returned in every
        case (short reads used to yield a list).

    Note:
        K-mers containing 'N' are dropped before windowing, so a window is
        ``w`` consecutive *surviving* k-mers and may span a gap.
    """
    # Skip if read is too short
    if len(sequence) < k:
        return set()

    # Encode every k-mer of the read that has no ambiguous base
    codes = []
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i+k]
        if 'N' not in kmer:
            codes.append(encode_kmer(kmer))

    # The minimizer of a window is simply its smallest encoded k-mer;
    # positions are not needed because only the set of codes is returned.
    return {min(codes[i:i+w]) for i in range(len(codes) - w + 1)}

def process_read_chunk_with_minimizers(read_chunk, minimizer_index, taxonomy, k=31, w=10, min_minimizer_fraction=0.1):
    """Process a chunk of reads against the minimizer index

    Args:
        read_chunk: Iterable of ``(header, sequence)`` pairs.
        minimizer_index: Merged index mapping each encoded minimizer to a
            dict keyed by genome name.
        taxonomy: Species -> lineage mapping passed to ``find_lca`` when
            several genomes tie for the best score.
        k: K-mer length.
        w: Number of consecutive k-mers per minimizer window.
        min_minimizer_fraction: Minimum share of the read's distinct
            minimizers that the best genome must contain.

    Returns:
        One ``(header, classification, lca)`` tuple per read, in input
        order. Unclassified reads carry ``("Unclassified", "Unknown")``;
        classified reads carry the list of best-scoring genomes and either
        that single genome or the lowest common ancestor of the tie.

    Note:
        Genomes present in the index but absent from the five reference
        names are counted too (previously they raised KeyError). A tie
        that involves such a genome still needs it in ``taxonomy``.
    """
    # Seeding the counters in this order fixes the order of tied genomes
    # in the reported classification.
    reference_genomes = ('E. coli', 'B. subtilis', 'P. aeruginosa', 'S. aureus', 'M. tuberculosis')
    results = []

    for header, sequence in read_chunk:
        # Skip very short reads
        if len(sequence) < k:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # Extract minimizers from the read
        read_minimizers = extract_read_minimizers(sequence, k, w)
        if not read_minimizers:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # Count, per genome, how many of the read's minimizers it contains
        genome_matches = dict.fromkeys(reference_genomes, 0)
        for encoded_kmer in read_minimizers:
            if encoded_kmer in minimizer_index:
                for genome in minimizer_index[encoded_kmer]:
                    genome_matches[genome] = genome_matches.get(genome, 0) + 1

        # Determine best matching genome(s)
        max_matches = max(genome_matches.values())
        if max_matches == 0:
            results.append((header, "Unclassified", "Unknown"))
            continue

        # read_minimizers is non-empty here, so the division is safe.
        match_fraction = max_matches / len(read_minimizers)
        if match_fraction < min_minimizer_fraction:
            results.append((header, "Unclassified", "Unknown"))
            continue

        best_genomes = [g for g, c in genome_matches.items() if c == max_matches]
        if len(best_genomes) == 1:
            results.append((header, best_genomes, best_genomes[0]))
        else:
            results.append((header, best_genomes, find_lca(taxonomy, best_genomes)))

    return results

def distribute_reads(reads, num_chunks):
    """Distribute reads evenly across chunks

    Args:
        reads: Sliceable sequence of reads.
        num_chunks: Desired number of chunks; values below 1 are treated
            as 1.

    Returns:
        List of at most ``num_chunks`` contiguous, non-empty slices of
        ``reads`` whose sizes differ by at most one and which, concatenated
        in order, reproduce ``reads``. An empty input gives ``[]``.
    """
    total = len(reads)
    # Never create more chunks than there are reads, and never fewer than
    # one chunk for a non-empty input.
    chunk_count = min(max(1, num_chunks), total)
    if chunk_count == 0:
        return []

    # The first `remainder` chunks take one extra read each.
    base_size, remainder = divmod(total, chunk_count)
    chunks = []
    start = 0
    for chunk_no in range(chunk_count):
        stop = start + base_size + (1 if chunk_no < remainder else 0)
        chunks.append(reads[start:stop])
        start = stop
    return chunks

def classify_reads_parallel(reads, minimizer_index, taxonomy, k=31, w=10, num_processes=8, min_minimizer_fraction=0.1):
    """Classify all reads with a worker pool and return the results in read order.

    The reads are split with ``distribute_reads``; each batch is classified
    by ``process_read_chunk_with_minimizers`` and the per-batch results are
    concatenated in batch order.
    """
    batches = distribute_reads(reads, num_processes)

    with mp.Pool(processes=num_processes) as pool:
        # Bind everything except the batch itself so that pool.map can pass
        # one batch per call.
        classify_batch = partial(
            process_read_chunk_with_minimizers,
            minimizer_index=minimizer_index,
            taxonomy=taxonomy,
            k=k,
            w=w,
            min_minimizer_fraction=min_minimizer_fraction,
        )
        per_batch_results = pool.map(classify_batch, batches)

    # Flatten the list of per-batch result lists
    return [record for batch in per_batch_results for record in batch]

def summarize_results(classification_results):
    """Tally classification outcomes into a summary dictionary.

    Reads are split into unclassified reads, reads matching exactly one
    genome, and all remaining reads, which are counted once per listed
    genome and once under their LCA.
    """
    single_matches, multi_matches, lca_classifications = Counter(), Counter(), Counter()
    unclassified = 0

    for _header, genomes, lca in classification_results:
        if genomes == "Unclassified":
            unclassified += 1
            continue
        if len(genomes) == 1:
            single_matches[genomes[0]] += 1
            continue
        # Ambiguous read: credit every tied genome and the shared ancestor
        multi_matches.update(genomes)
        lca_classifications[lca] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(single_matches),
        "multi_matches": dict(multi_matches),
        "lca_classifications": dict(lca_classifications),
        "unclassified": unclassified
    }

def print_summary(summary):
    """Print a formatted summary of classification results

    Args:
        summary: Dict with the keys ``total_reads``, ``single_matches``,
            ``lca_classifications`` and ``unclassified`` (the layout built
            by ``summarize_results``).

    Percentages are relative to ``total_reads``; when no reads were
    processed they are reported as 0.00% instead of dividing by zero.
    """
    total_reads = summary['total_reads']

    def percent_of_total(count):
        # An empty run has no meaningful share; report 0% rather than fail.
        return (count / total_reads) * 100 if total_reads else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total_reads}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({percent_of_total(count):.2f}%)")

    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({percent_of_total(count):.2f}%)")

    unclassified = summary['unclassified']
    print(f"\nUnclassified: {unclassified} reads ({percent_of_total(unclassified):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def main(k=31, w=10, num_processes=8, max_reads_per_file=10000, min_minimizer_fraction=1.0):
    """Main function to run the entire classification pipeline

    Loads the five reference genomes, builds the minimizer index, classifies
    each read set and prints per-set summaries plus a performance report.
    All file names are hard-coded relative paths, so the files are looked
    up relative to the current working directory.

    Args:
        k: K-mer length.
        w: Number of consecutive k-mers per minimizer window.
        num_processes: Pool size used for index building and classification.
        max_reads_per_file: Upper bound on the reads taken from each FASTQ
            file (applied per file, not per read set).
        min_minimizer_fraction: Minimum share of a read's minimizers that
            the best genome must contain. The default here is 1.0, unlike
            the 0.1 default of classify_reads_parallel.

    Returns:
        None; all results are printed.
    """
    print(f"Starting minimizer-based metagenomic classification with k={k}, w={w}, and {num_processes} processes")
    print(f"Using minimum minimizer match fraction: {min_minimizer_fraction}")
    
    # Define file paths
    genome_files = [
        "b_subtilis.fna",
        "e_coli.fna",
        "m_tuberculosis.fna",
        "p_aeruginosa.fna",
        "s_aureus.fna"
    ]
    
    # Define combined read file sets: both files of a set are read and
    # classified together as one batch of reads
    read_file_sets = {
        "error_free_reads": [
            "simulated_reads_no_errors_10k_R1.fastq",
            "simulated_reads_no_errors_10k_R2.fastq"
        ],
        "with_errors_reads": [
            "simulated_reads_miseq_10k_R1.fastq", 
            "simulated_reads_miseq_10k_R2.fastq"
        ]
    }
    
    # Load reference genomes
    genomes = read_genome_files(genome_files)
    taxonomy = generate_taxonomic_tree()
    
    print(f"Memory after loading reference genomes: {track_memory()} MB")
    
    # Build minimizer index (timed separately for the performance report)
    index_start = time.time()
    minimizer_index, minimizer_stats = build_minimizer_index(
        genomes, k=k, w=w, num_processes=num_processes
    )
    index_time = time.time() - index_start
    
    print(f"Minimizer index built in {index_time:.2f} seconds")
    print(f"Memory after building minimizer index: {track_memory()} MB")
    
    # Print minimizer statistics
    print_minimizer_stats(minimizer_stats)
    
    # Process each combined read file set; results are kept per set name
    all_results = {}
    
    for set_name, file_paths in read_file_sets.items():
        print(f"\nProcessing {set_name}:")
        for path in file_paths:
            print(f"  - {path}")
        
        # The per-set timer covers both reading the FASTQ files and
        # classifying the reads
        start_process = time.time()
        
        # Read and combine FASTQ files
        combined_reads = []
        for file_path in file_paths:
            reads = read_fastq_file(file_path, max_reads=max_reads_per_file)
            combined_reads.extend(reads)
            print(f"Read {len(reads)} reads from {file_path}")
            
        print(f"Total combined reads: {len(combined_reads)}")
        
        # Process reads
        results = classify_reads_parallel(
            combined_reads, 
            minimizer_index, 
            taxonomy, 
            k=k,
            w=w,
            num_processes=num_processes,
            min_minimizer_fraction=min_minimizer_fraction
        )
        
        process_time = time.time() - start_process
        all_results[set_name] = results
        
        # Summarize and print results
        summary = summarize_results(results)
        print_summary(summary)
        
        print(f"Processing time: {process_time:.2f} seconds")
        print(f"Reads per second: {len(results) / process_time:.2f}")
        
        print(f"Memory after processing {set_name}: {track_memory()} MB")
    
    # Performance report. start_time and memory_usage are module-level
    # globals, so total_time is measured from when the module was executed.
    total_time = time.time() - start_time
    peak_memory = max(memory_usage)
    
    print("\n====================== PERFORMANCE SUMMARY ======================")
    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"Minimizer index building time: {index_time:.2f} seconds")
    print(f"Peak memory usage: {peak_memory:.2f} MB")
    print(f"Memory reduction compared to full k-mer index: {minimizer_stats['reduction_ratio']:.2f}x")
    
    # Calculate processing statistics. The denominator is everything except
    # index building, so it also includes genome loading and reporting time.
    total_reads = sum(len(results) for results in all_results.values())
    reads_per_second = total_reads / (total_time - index_time)
    
    print(f"\nTotal reads processed: {total_reads}")
    print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
    print("================================================================")

# Run the main function with the specified parameters.
# max_reads_per_file is not passed, so main() uses its default of 10000.
if __name__ == "__main__":
    # Set parameters
    K = 31  # k-mer length (as specified in the assignment)
    W = 10  # window size for minimizers
    NUM_PROCESSES = 1 # Number of processes to use (a pool with one worker is still created)
    MIN_MINIMIZER_FRACTION = 1.0  # For exact matching, require 100% of minimizers to match
    
    main(k=K, w=W, num_processes=NUM_PROCESSES, min_minimizer_fraction=MIN_MINIMIZER_FRACTION)