# In [None]:
import time
import os
import subprocess
import multiprocessing as mp
from collections import Counter, defaultdict
import psutil
import sys
import tempfile
import shutil
import re

# First check if BLAST+ is installed
def check_blast_installation():
    """Report whether the NCBI BLAST+ executables used by this script are on PATH.

    Both ``blastn`` and ``makeblastdb`` are looked up with :func:`shutil.which`.
    This is portable and does not depend on an external ``which`` binary being
    installed, which the previous ``subprocess``-based check did.

    Returns:
        bool: True if both tools were found, False otherwise. Diagnostic
        messages (including install hints) are printed in either case.
    """
    blastn_path = shutil.which("blastn")
    if blastn_path is None:
        print("ERROR: BLAST+ tools (blastn, makeblastdb) not found in PATH")
        print("\nPlease install NCBI BLAST+ tools:")
        print("For Conda: conda install -c bioconda blast")
        print("For Ubuntu/Debian: sudo apt-get install ncbi-blast+")
        return False
    print(f"Found blastn at: {blastn_path}")

    makeblastdb_path = shutil.which("makeblastdb")
    if makeblastdb_path is None:
        print("ERROR: makeblastdb tool not found in PATH")
        return False
    print(f"Found makeblastdb at: {makeblastdb_path}")
    return True

# Performance tracking: wall-clock start of the run plus every RSS sample (in MB).
start_time = time.time()
memory_usage = []

def track_memory():
    """Sample this process's resident set size, record it, and return it.

    Returns:
        float: current RSS in MB (bytes / 1024 / 1024). The same value is
        appended to the module-level ``memory_usage`` list.
    """
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    memory_usage.append(rss_mb)
    return memory_usage[-1]

print(f"Starting with memory: {track_memory()} MB")

# ==================== Taxonomy Functions ====================

def generate_taxonomic_tree():
    """Build the lineage table for the five reference genomes.

    Every species is registered under three spellings: "E. coli", "E_coli"
    and "E__coli". Species names may reach the lookup with any of these
    separators. The last element of each lineage is always the key itself.

    Returns:
        dict[str, list[str]]: name -> [domain, phylum, class, order, family,
        genus, species]. The dotted names come first, then the single-underscore
        names, then the double-underscore names.
    """
    # (genus initial, species epithet, ranks from domain down to genus)
    genus_lineages = [
        ("E", "coli", ['Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Enterobacterales', 'Enterobacteriaceae', 'Escherichia']),
        ("B", "subtilis", ['Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Bacillaceae', 'Bacillus']),
        ("P", "aeruginosa", ['Bacteria', 'Proteobacteria', 'Gammaproteobacteria', 'Pseudomonadales', 'Pseudomonadaceae', 'Pseudomonas']),
        ("S", "aureus", ['Bacteria', 'Firmicutes', 'Bacilli', 'Bacillales', 'Staphylococcaceae', 'Staphylococcus']),
        ("M", "tuberculosis", ['Bacteria', 'Actinobacteria', 'Actinomycetia', 'Mycobacteriales', 'Mycobacteriaceae', 'Mycobacterium']),
    ]
    # Display form first, then the underscore variants produced by name rewriting.
    separators = (". ", "_", "__")
    return {
        f"{initial}{sep}{epithet}": [*ranks, f"{initial}{sep}{epithet}"]
        for sep in separators
        for initial, epithet, ranks in genus_lineages
    }

def find_lca(taxonomy, species_list):
    """Return the lowest common ancestor (LCA) of the given species.

    Args:
        taxonomy: Mapping of species name -> lineage list, ordered from the
            root rank down to the species itself.
        species_list: Species names a read matched. Duplicates are ignored.

    Returns:
        str:
        - "Unknown" for an empty list.
        - The species itself when exactly one distinct species is given.
        - Otherwise, the deepest rank shared by every lineage.
        - "Bacteria" when the lineages share no rank at all.
        - "Unknown" when a species has no lineage in ``taxonomy``. Its
          placement, and therefore the LCA, cannot be determined. Previously
          this case raised ``KeyError``.
    """
    if not species_list:
        return "Unknown"

    # dict.fromkeys de-duplicates while keeping first-seen order.
    distinct = list(dict.fromkeys(species_list))
    if len(distinct) == 1:
        return distinct[0]

    if any(species not in taxonomy for species in distinct):
        return "Unknown"
    paths = [taxonomy[species] for species in distinct]

    lca = None
    # zip stops at the shortest lineage; walk down while all ranks agree.
    for ranks in zip(*paths):
        if len(set(ranks)) != 1:
            break
        lca = ranks[0]

    return lca if lca is not None else "Bacteria"

# ==================== BLAST Database Creation ====================

def create_blast_db(fasta_files, temp_dir):
    """Build one nucleotide BLAST database per reference FASTA file.

    Each FASTA is first copied into ``temp_dir`` under a space-free name, and
    ``makeblastdb`` is then run on that copy. Known file stems (``e_coli``
    etc.) get a pretty display name; any other stem is used verbatim.

    Args:
        fasta_files: Paths to reference genome FASTA files.
        temp_dir: Directory that receives the copies and database files.

    Returns:
        dict[str, str]: display name (e.g. "E. coli") -> database path prefix.

    Raises:
        subprocess.CalledProcessError: if ``makeblastdb`` exits non-zero. The
        command and its output are printed before re-raising.
    """
    # file stem -> (display name, filesystem-safe identifier)
    known_names = {
        "e_coli": ("E. coli", "E_coli"),
        "b_subtilis": ("B. subtilis", "B_subtilis"),
        "p_aeruginosa": ("P. aeruginosa", "P_aeruginosa"),
        "s_aureus": ("S. aureus", "S_aureus"),
        "m_tuberculosis": ("M. tuberculosis", "M_tuberculosis"),
    }
    databases = {}

    print("Creating BLAST databases...")
    for source_path in fasta_files:
        stem = os.path.basename(source_path).split('.')[0]
        display, ident = known_names.get(stem, (stem, stem))

        prefix = os.path.join(temp_dir, ident)
        fasta_copy = f"{prefix}.fna"
        shutil.copy(source_path, fasta_copy)

        print(f"Creating BLAST database for {display}...")
        command = ["makeblastdb", "-in", fasta_copy, "-dbtype", "nucl", "-out", prefix, "-title", ident]
        try:
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as err:
            print(f"Error creating BLAST database for {display}:")
            print(f"Command: {' '.join(command)}")
            print(f"Error: {err}")
            print(f"stdout: {err.stdout.decode() if err.stdout else 'None'}")
            print(f"stderr: {err.stderr.decode() if err.stderr else 'None'}")
            raise err
        print(f"Successfully created database for {display}")

        databases[display] = prefix

    return databases

# ==================== FASTQ Parsing Functions ====================

def read_fastq_entries(fastq_file, max_reads=None):
    """Parse a 4-line-per-record FASTQ file into ``(header, sequence)`` tuples.

    Args:
        fastq_file: Path to the FASTQ file.
        max_reads: Stop after this many complete records. ``None`` reads the
            whole file; ``0`` returns no records.

    Returns:
        list[tuple[str, str]]: ``(header, sequence)`` pairs with surrounding
        whitespace stripped. The header keeps its leading ``@``. Quality
        strings are read only to check that the record is complete and are
        then discarded.

    Parsing stops at end of file, at an empty header line, or at the first
    record whose sequence or quality line is missing.
    """
    reads = []
    with open(fastq_file, 'r') as f:
        # Compare with None explicitly: the old truthiness test made
        # max_reads=0 behave like "unlimited".
        while max_reads is None or len(reads) < max_reads:
            header = f.readline().strip()
            if not header:
                break  # end of file (or an empty line where a header should be)

            sequence = f.readline().strip()
            f.readline()  # '+' separator line; its content is not used
            quality = f.readline().strip()

            if not (sequence and quality):
                break  # truncated / incomplete record
            reads.append((header, sequence))

    return reads

def write_fasta_file(reads, output_fasta):
    """Write ``(header, sequence)`` pairs to ``output_fasta`` as FASTA records.

    The record ID is the first whitespace-delimited token of the header after
    its first character (the FASTQ ``@``) is dropped.

    Returns:
        int: number of records written, i.e. ``len(reads)``.
    """
    records = (f">{hdr[1:].split()[0]}\n{seq}\n" for hdr, seq in reads)
    with open(output_fasta, 'w') as out:
        out.writelines(records)
    return len(reads)

# ==================== BLAST Processing ====================

def percent_identity_for_mismatches(read_length, max_mismatches):
    """Convert a mismatch budget into a BLAST ``-perc_identity`` threshold.

    Allowing ``max_mismatches`` in a read of ``read_length`` bp corresponds to
    ``100 - 100 * max_mismatches / read_length`` percent identity. For
    example, 1 mismatch in 100 bp gives 99.0. The result is clamped to
    [0, 100], the range BLAST accepts for this option.

    Raises:
        ValueError: if ``read_length`` is not positive.
    """
    if read_length <= 0:
        raise ValueError(f"read_length must be positive, got {read_length}")
    threshold = 100 - (max_mismatches * 100 / read_length)
    return min(100.0, max(0.0, threshold))


def run_blast(query_fasta, db_path, output_file, read_length=100, max_mismatches=1, evalue=1e-5, num_threads=1, max_hits=5, strand="plus"):
    """
    Run blastn for ``query_fasta`` against ``db_path`` and write tabular hits.

    Parameters:
    - query_fasta: FASTA file of query reads
    - db_path: BLAST database path prefix (as given to makeblastdb -out)
    - output_file: where BLAST writes its outfmt-6 results
    - read_length: Average length of reads (used to calculate percent identity)
    - max_mismatches: Maximum number of mismatches to allow
    - evalue, num_threads, max_hits: passed to -evalue, -num_threads, -max_target_seqs
    - strand: query strand(s) to search: "plus" (default, previous behaviour),
      "minus" or "both"

    NOTE: -perc_identity is evaluated over the aligned region only. It does
    not by itself cap the number of mismatches per read. Partial alignments,
    and reads longer than ``read_length``, can still pass.

    Returns the ``output_file`` path.

    Raises:
    - ValueError if ``read_length`` is not positive
    - subprocess.CalledProcessError if blastn exits non-zero (details are
      printed first)
    """
    percent_identity = percent_identity_for_mismatches(read_length, max_mismatches)

    # Tabular output (outfmt 6) is easier to parse than XML.
    cmd = [
        "blastn",
        "-query", query_fasta,
        "-db", db_path,
        "-out", output_file,
        "-outfmt", "6 qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore",
        "-evalue", str(evalue),
        "-max_target_seqs", str(max_hits),
        "-num_threads", str(num_threads),
        "-perc_identity", str(percent_identity),
        "-strand", strand,
    ]

    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        print(f"Error running BLAST:")
        print(f"Command: {' '.join(cmd)}")
        print(f"Error: {e}")
        print(f"stdout: {e.stdout.decode() if e.stdout else 'None'}")
        print(f"stderr: {e.stderr.decode() if e.stderr else 'None'}")
        raise

    return output_file

def process_blast_results(results_file, query_ids, species, max_mismatches=None):
    """Record which queries hit ``species`` in a BLAST tabular (outfmt 6) file.

    Expects the 12 columns requested by ``run_blast``: qseqid sseqid pident
    length mismatch gapopen qstart qend sstart send evalue bitscore. Lines
    with fewer columns are ignored.

    Args:
        results_file: Path to the BLAST output (may be empty).
        query_ids: Every query ID; each gets an entry even without hits.
        species: Label recorded for each query with at least one accepted hit.
        max_mismatches: If given, hits whose ``mismatch`` column exceeds it
            are discarded. ``None`` (default) accepts every reported hit.

    Returns:
        dict[str, list[str]]: query ID -> ``[species]`` if it hit, else ``[]``.
        IDs that appear in the file but not in ``query_ids`` are added instead
        of raising ``KeyError``.

    Raises:
        OSError: if ``results_file`` does not exist.
    """
    results = {query_id: [] for query_id in query_ids}

    if os.path.getsize(results_file) == 0:
        return results  # BLAST found nothing

    with open(results_file, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 12:
                continue  # malformed / incomplete line

            query_id = fields[0]
            # Column 5 is BLAST's mismatch count for this alignment.
            if max_mismatches is not None and int(fields[4]) > max_mismatches:
                continue

            hits = results.setdefault(query_id, [])
            if species not in hits:
                hits.append(species)

    return results

def get_query_ids_from_fasta(fasta_file):
    """Return the header text (without the leading '>') of every FASTA record, in file order."""
    with open(fasta_file, 'r') as handle:
        return [row.strip()[1:] for row in handle if row.startswith('>')]

# ==================== Read Classification ====================

def classify_reads_with_blast(fastq_files, db_paths, taxonomy, max_reads_per_file=10000, temp_dir=None, num_threads=8, max_mismatches=1):
    """
    Classify reads using BLAST with a specified mismatch tolerance.

    For each named set of FASTQ files, the reads are pooled, written to one
    FASTA file and searched against every database in ``db_paths``. Each read
    is then labelled with the genomes it hit and their lowest common ancestor.

    Parameters:
    - fastq_files: dict mapping a set name -> list of FASTQ paths to pool
    - db_paths: dict mapping species display name -> BLAST database prefix
    - taxonomy: species name -> lineage list, used for the LCA
    - max_reads_per_file: per-file cap passed to read_fastq_entries
    - temp_dir: directory for intermediate files. If None, a private temporary
      directory is created and removed before returning.
    - num_threads: BLAST -num_threads value
    - max_mismatches: Maximum number of mismatches to allow in alignments
      (turned into a percent-identity threshold from the average read length)

    Returns:
    dict mapping set name -> list of (read_id, matches, lca) tuples, where:
    - read_id is "@"-prefixed
    - matches is the list of hit genome names, or the string "Unclassified"
    - lca is the LCA name, or "Unknown"
    """
    # Only delete a directory we created; a caller-supplied temp_dir belongs
    # to the caller.
    owns_temp_dir = temp_dir is None
    if owns_temp_dir:
        temp_dir = tempfile.mkdtemp()

    all_results = {}
    try:
        for set_name, file_paths in fastq_files.items():
            all_results[set_name] = _classify_read_set(
                set_name,
                file_paths,
                db_paths,
                taxonomy,
                max_reads_per_file,
                temp_dir,
                num_threads,
                max_mismatches,
            )
    finally:
        if owns_temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)

    return all_results


def _classify_read_set(set_name, file_paths, db_paths, taxonomy, max_reads_per_file, temp_dir, num_threads, max_mismatches):
    """Pool one set of FASTQ files, BLAST it against every database, and label each read.

    Returns a list of (read_id, matches, lca) tuples; the list is empty when
    the set contains no reads.
    """
    print(f"\nProcessing {set_name}:")
    for path in file_paths:
        print(f"  - {path}")

    start_process = time.time()

    # Read FASTQ files and combine reads
    combined_reads = []
    for fastq_file in file_paths:
        reads = read_fastq_entries(fastq_file, max_reads=max_reads_per_file)
        combined_reads.extend(reads)
        print(f"Read {len(reads)} reads from {fastq_file}")

    print(f"Total combined reads: {len(combined_reads)}")
    if not combined_reads:
        # The averaging below and print_summary's percentages both divide by
        # the read count, so an empty set would crash.
        print(f"No reads to classify for {set_name}; skipping BLAST")
        return []

    # Average read length drives the percent-identity threshold given to BLAST
    avg_read_length = sum(len(seq) for _, seq in combined_reads) / len(combined_reads)

    print(f"Average read length: {avg_read_length:.2f}bp")
    print(f"Setting up BLAST to allow maximum {max_mismatches} mismatch(es)")
    percent_identity = 100 - (max_mismatches * 100 / avg_read_length)
    print(f"Using percent identity threshold of {percent_identity:.2f}%")

    # Write combined reads to a FASTA file and seed one (empty) entry per read
    combined_fasta = os.path.join(temp_dir, f"{set_name}_combined.fasta")
    write_fasta_file(combined_reads, combined_fasta)
    query_ids = get_query_ids_from_fasta(combined_fasta)
    classification_results = {query_id: [] for query_id in query_ids}

    # Run BLAST against each reference genome
    for species, db_path in db_paths.items():
        print(f"BLASTing against {species}...")
        blast_start = time.time()

        output_file = os.path.join(temp_dir, f"{set_name}_{re.sub(r'[. ]', '_', species)}_blast.out")
        run_blast(
            combined_fasta,
            db_path,
            output_file,
            read_length=avg_read_length,
            max_mismatches=max_mismatches,
            num_threads=num_threads
        )

        species_results = process_blast_results(output_file, query_ids, species)
        for query_id, hits in species_results.items():
            # setdefault tolerates IDs BLAST reports that were not pre-seeded
            classification_results.setdefault(query_id, []).extend(hits)

        print(f"BLAST against {species} completed in {time.time() - blast_start:.2f} seconds")

    # Convert to the (read_id, matches, lca) format used by the summary code
    formatted_results = []
    for query_id, matching_genomes in classification_results.items():
        if matching_genomes:
            lca = find_lca(taxonomy, matching_genomes)
            formatted_results.append((f"@{query_id}", matching_genomes, lca))
        else:
            formatted_results.append((f"@{query_id}", "Unclassified", "Unknown"))

    process_time = time.time() - start_process

    summary = summarize_results(formatted_results)
    print_summary(summary)

    print(f"Processing time: {process_time:.2f} seconds")
    if process_time > 0:
        print(f"Reads per second: {len(formatted_results) / process_time:.2f}")

    print(f"Memory after processing {set_name}: {track_memory()} MB")
    return formatted_results

def summarize_results(classification_results):
    """Tally per-read classifications into summary counts.

    Args:
        classification_results: sequence of ``(read_id, classification, lca)``
            tuples. ``classification`` is either the string "Unclassified" or
            a list of matching genome names.

    Returns:
        dict with these keys:
        - ``total_reads``: number of tuples.
        - ``single_matches``: genome -> reads that hit only that genome.
        - ``multi_matches``: genome -> multi-hit reads that include it.
        - ``lca_classifications``: LCA -> multi-hit reads.
        - ``unclassified``: number of reads with no hit.
    """
    unclassified = 0
    single, multi, by_lca = Counter(), Counter(), Counter()

    for _read_id, genomes, lca in classification_results:
        if genomes == "Unclassified":
            unclassified += 1
            continue
        if len(genomes) == 1:
            single[genomes[0]] += 1
            continue
        # Reads hitting several genomes count once per genome and once per LCA
        multi.update(genomes)
        by_lca[lca] += 1

    return {
        "total_reads": len(classification_results),
        "single_matches": dict(single),
        "multi_matches": dict(multi),
        "lca_classifications": dict(by_lca),
        "unclassified": unclassified
    }

def print_summary(summary):
    """Print a formatted summary of classification results.

    Args:
        summary: dict with the keys ``total_reads``, ``single_matches``,
            ``lca_classifications`` and ``unclassified`` (as built by
            ``summarize_results``).

    Percentages are relative to ``summary['total_reads']``. When that is 0
    they are shown as 0.00% instead of raising ZeroDivisionError.
    """
    total = summary['total_reads']

    def pct(count):
        # Share of all reads, guarded against an empty read set
        return (count / total) * 100 if total else 0.0

    print("\n====================== CLASSIFICATION SUMMARY ======================")
    print(f"Total reads processed: {total}")

    print("\nSingle Genome Matches:")
    for genome, count in summary['single_matches'].items():
        print(f"- {genome}: {count} reads ({pct(count):.2f}%)")

    if summary['lca_classifications']:
        print("\nMulti-Genome Matches (by LCA):")
        for lca, count in summary['lca_classifications'].items():
            print(f"- {lca}: {count} reads ({pct(count):.2f}%)")

    print(f"\nUnclassified: {summary['unclassified']} reads ({pct(summary['unclassified']):.2f}%)")
    print("===================================================================")

# ==================== Main Function ====================

def main(num_threads=8, max_reads_per_file=10000, max_mismatches=1, genome_files=None, read_file_sets=None):
    """
    Main function to run BLAST-based classification allowing up to the specified number of mismatches

    Parameters:
    - num_threads: Number of threads to use for BLAST
    - max_reads_per_file: Maximum number of reads to process from each file
    - max_mismatches: Maximum number of mismatches to allow in alignments
    - genome_files: reference FASTA paths. Defaults to the five genomes
      b_subtilis / e_coli / m_tuberculosis / p_aeruginosa / s_aureus (.fna).
    - read_file_sets: dict of set name -> FASTQ paths to pool. Defaults to the
      error-free and MiSeq-error simulated R1/R2 pairs.

    Returns None; all results are printed. Intermediate files go into a
    temporary directory that is removed in a ``finally`` block, so it is
    cleaned up even if a step raises.
    """
    # Check if BLAST+ is installed
    if not check_blast_installation():
        print("Please install BLAST+ tools before running this script")
        return

    print(f"Starting BLAST-based metagenomic classification with {num_threads} threads")
    print(f"Allowing up to {max_mismatches} mismatch(es) per read")

    # Default reference genomes
    if genome_files is None:
        genome_files = [
            "b_subtilis.fna",
            "e_coli.fna",
            "m_tuberculosis.fna",
            "p_aeruginosa.fna",
            "s_aureus.fna"
        ]

    # Default read sets; files within a set are pooled before classification
    if read_file_sets is None:
        read_file_sets = {
            "error_free_reads": [
                "simulated_reads_no_errors_10k_R1.fastq",
                "simulated_reads_no_errors_10k_R2.fastq"
            ],
            "with_errors_reads": [
                "simulated_reads_miseq_10k_R1.fastq",
                "simulated_reads_miseq_10k_R2.fastq"
            ]
        }

    temp_dir = tempfile.mkdtemp()
    print(f"Using temporary directory: {temp_dir}")

    try:
        db_paths = create_blast_db(genome_files, temp_dir)
        taxonomy = generate_taxonomic_tree()

        print(f"Memory after creating BLAST databases: {track_memory()} MB")

        all_results = classify_reads_with_blast(
            read_file_sets,
            db_paths,
            taxonomy,
            max_reads_per_file=max_reads_per_file,
            temp_dir=temp_dir,
            num_threads=num_threads,
            max_mismatches=max_mismatches
        )

        # Performance report; start_time is the module-level timestamp
        total_time = time.time() - start_time
        peak_memory = max(memory_usage)

        print("\n====================== PERFORMANCE SUMMARY ======================")
        print(f"Total execution time: {total_time:.2f} seconds")
        print(f"Peak memory usage: {peak_memory:.2f} MB")
        print(f"Maximum mismatches allowed: {max_mismatches}")

        total_reads = sum(len(results) for results in all_results.values())
        reads_per_second = total_reads / total_time if total_time > 0 else 0.0

        print(f"\nTotal reads processed: {total_reads}")
        print(f"Overall processing speed: {reads_per_second:.2f} reads per second")
        print("================================================================")

    finally:
        # Clean up temporary directory
        print(f"Cleaning up temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir)

# Script entry point: pick the run configuration, then hand off to main().
if __name__ == "__main__":
    # BLAST worker threads; tune to the cores available on this machine.
    NUM_THREADS = 8
    # Mismatch budget per read (Task 1.3). 0 maps to a 100% identity threshold.
    MAX_MISMATCHES = 0
    main(num_threads=NUM_THREADS, max_mismatches=MAX_MISMATCHES)