In [75]:
clust2kingdom = dict(zip(merged_df['cluster'], merged_df['kingdom']))
merged_df[['cluster', 'kingdom']].to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/cluster_kingdoms.csv')



In [79]:
viral_df = merged_df[merged_df['kingdom'] == 'Viruses']
viral_df.to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/viral_clusters.csv', index=None)
len(viral_df)

3030

In [81]:
bacteria_df = merged_df[merged_df['kingdom'] == 'Bacteria']
bacteria_df.to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Test_blast_SENZOR/LCA_results/bacteria_clusters.csv')
len(bacteria_df)

90908

In [64]:
import os
import pandas as pd
import numpy as np
from glob import glob
from collections import defaultdict
from statistics import mode
from Bio import Entrez
from ete3 import NCBITaxa
import logging
import sys
import re
from logging.handlers import RotatingFileHandler

# Configure logging
handler = RotatingFileHandler('util.log', maxBytes=5*1024*1024, backupCount=2)
logging.basicConfig(level=logging.INFO, handlers=[handler],
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Set Entrez email
Entrez.email = "ifeanyi.omah@ed.ac.uk"

# Initialize NCBI Taxonomy
ncbi = NCBITaxa()

valid_lineages = set()

# Updated column names to match your .m9 file
nr_column_names = [
    'query', 'subject', 'identity', 'alignment_length', 'mismatches', 'gap_opens',
    'q_start', 'q_end', 's_start', 's_end', 'evalue', 'bit_score', 'tax_id',
    'sci_name', 'com_names', 'superkingdom', 'phylum', 'family',
    'lca_taxid', 'lca_name'
]

# --------------------------------------------------------------------
# 1) Change nr_dir to where your NR LCA files actually live
#    If these files are named like: LCA_<SAMPLE>_modified_diamond_blast_nr_annotated.m9
#    then we can glob them below.
# --------------------------------------------------------------------
nr_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/"

def load_lca(nr_directory):
    """
    Load and process LCA results from NR files in the specified directory.
    Dynamically glob any pattern you want.
    """
    pattern = "LCA_*_modified_diamond_blast_nr_annotated.m9"
    nr_file_paths = glob(os.path.join(nr_directory, pattern))
    logging.info(f"Found {len(nr_file_paths)} NR files matching '{pattern}' in {nr_directory}.")

    if not nr_file_paths:
        logging.error(f"No files matching {pattern} found in directory: {nr_directory}")
        raise FileNotFoundError(f"No files matching {pattern} found in directory: {nr_directory}")

    # Process each BLAST output file
    nr_df = load_blast_output(nr_file_paths, database='nr')

    # Log columns of combined DataFrame
    logging.info(f"Columns in combined nr_df: {nr_df.columns.tolist()}")

    # Debugging: Print the first few rows before dropping NaNs
    print("Sample of combined nr_df before dropping NaNs in 'taxid':")
    print(nr_df.head())

    # Remove rows with missing taxid
    initial_nr = len(nr_df)
    if 'taxid' not in nr_df.columns:
        logging.error("'taxid' column is missing in the combined DataFrame after loading.")
        print("'taxid' column is missing in the combined DataFrame after loading.", file=sys.stderr)
        raise KeyError("'taxid' column is missing in the combined DataFrame after loading.")
    nr_df = nr_df.dropna(subset=['taxid'])
    final_nr = len(nr_df)
    dropped_nr = initial_nr - final_nr
    if dropped_nr > 0:
        logging.info(f"Dropped {dropped_nr} rows with missing taxid from nr_df.")

    # Debugging: Print the first few rows after dropping NaNs
    print("Sample of combined nr_df after dropping NaNs in 'taxid':")
    print(nr_df.head())

    # Convert taxid to integer
    try:
        nr_df['taxid'] = nr_df['taxid'].astype(int)
    except ValueError as e:
        logging.error(f"Error converting 'taxid' to integer: {e}")
        print(f"Error converting 'taxid' to integer: {e}", file=sys.stderr)
        raise e

    # Build dictionaries for name/rank
    taxids = set(nr_df.taxid)
    taxids = {t for taxid in taxids for t in ncbi.get_lineage(taxid) if t != 0}

    try:
        taxid2name = ncbi.get_taxid_translator(taxids)
    except Exception as e:
        logging.error(f"Error fetching taxid to name mapping: {e}")
        print(f"Error fetching taxid to name mapping: {e}", file=sys.stderr)
        raise e

    taxid2name[-1] = 'Ambiguous'
    taxid2name[0] = 'Unknown'
    try:
        taxid2rank = ncbi.get_rank(taxids)
    except Exception as e:
        logging.error(f"Error fetching taxid to rank mapping: {e}")
        print(f"Error fetching taxid to rank mapping: {e}", file=sys.stderr)
        raise e

    # Collect valid lineages
    taxids_in_df = set(nr_df['taxid'])
    for tx in taxids_in_df:
        try:
            lineage = ncbi.get_lineage(tx)
            for t in lineage:
                if t != 0:
                    valid_lineages.add(t)
        except ValueError:
            logging.warning(f"Skipping {tx}: taxid not found in local DB.")

    # Restrict to valid lineages
    taxids = valid_lineages
    try:
        taxid2name = ncbi.get_taxid_translator(taxids)
    except Exception as e:
        logging.error(f"Error fetching taxid to name mapping: {e}")
        print(f"Error fetching taxid to name mapping: {e}", file=sys.stderr)
        raise e

    def taxid2kingdom(taxid):
        if taxid == 0:
            return 'Unknown'
        try:
            lineage = ncbi.get_lineage(taxid)
            if not lineage:
                return 'Unknown'
            ranks = ncbi.get_rank(lineage)
            for taxid_in_lineage in lineage:
                rank = ranks.get(taxid_in_lineage, None)
                if rank == 'superkingdom':
                    kingdom = taxid2name.get(taxid_in_lineage, 'Unknown')
                    return kingdom
            return 'Unknown'
        except Exception as e:
            logging.error(f"Error fetching lineage for taxid {taxid}: {e}")
            return 'Unknown'

    # Map taxid -> name, rank, kingdom
    nr_df['name'] = nr_df['taxid'].map(taxid2name)
    nr_df['rank'] = nr_df['taxid'].map(taxid2rank)
    nr_df['kingdom'] = nr_df['taxid'].apply(taxid2kingdom)
    nr_df['db'] = 'nr'  # Since this is the NR dataset

    return nr_df

def load_blast_output(files, database='nr'):
    """
    Load BLAST output files into a single DataFrame.
    """
    if database == 'nr':
        column_names = nr_column_names
    else:
        raise ValueError(f"Unknown database type: {database}. Expected 'nr'.")

    dfs = []
    for file in files:
        try:
            # Quick check of columns in the first line
            with open(file, 'r') as f:
                first_line = f.readline()
                columns_in_file = len(first_line.strip().split('\t'))
                expected_columns = len(column_names)
                if columns_in_file != expected_columns:
                    logging.warning(f"{file} has {columns_in_file} columns; expected {expected_columns}.")

            # Potentially skip a header if present
            with open(file, 'r') as f:
                first_line = f.readline()
                if 'sci_name' in first_line.lower() or 'lca_taxid' in first_line.lower():
                    skiprows = 1
                else:
                    skiprows = 0

            df = pd.read_csv(
                file,
                sep='\t',
                header=None,
                names=column_names,
                na_values=['', 'NA', 'nan', 'Unknown'],
                skiprows=skiprows,
                on_bad_lines='warn'  # or 'error' if you want it strict
            )

            # Log columns after reading
            logging.info(f"Columns in {file} after reading: {df.columns.tolist()}")

            # Standardize column names to lowercase
            df.columns = [col.lower() for col in df.columns]

            if 'tax_id' in df.columns:
                df.rename(columns={'tax_id': 'taxid'}, inplace=True)
                logging.info(f"'tax_id' renamed to 'taxid' in {file}.")
            elif 'taxid' in df.columns:
                logging.info(f"'taxid' column is already correctly named in {file}.")
            else:
                logging.error(f"No 'tax_id' or 'taxid' column found in {file}.")
                raise KeyError(f"No 'tax_id' or 'taxid' column found in {file}.")

            # Replace 'Unknown' with '0' in taxid
            df['taxid'] = df['taxid'].replace('Unknown', '0')

            # Convert taxid to numeric
            df['taxid'] = pd.to_numeric(df['taxid'], errors='coerce')

            # Drop rows with NaN in 'taxid'
            before_drop = len(df)
            df.dropna(subset=['taxid'], inplace=True)
            after_drop = len(df)
            dropped = before_drop - after_drop
            if dropped > 0:
                logging.info(f"Dropped {dropped} rows with invalid taxid in {file}.")

            # Rename 'alignment_length' to 'align_length' if present
            if 'alignment_length' in df.columns:
                df.rename(columns={'alignment_length': 'align_length'}, inplace=True)
                logging.info(f"'alignment_length' renamed to 'align_length' in {file}.")

            # Derive sample name
            df['sample'] = get_sample_from_path(file)
            logging.info(f"Derived 'sample' for {file}: {df['sample'].unique()}")

            # Extract contig
            df['contig'] = df.apply(lambda row: extract_contig(row['query'], row['sample']), axis=1)
            logging.info(f"Extracted 'contig' for {file}.")

            # Build contig_key
            df['contig_key'] = (df['sample'] + '~' + df['contig']).str.lower().str.strip()
            logging.info(f"Created 'contig_key' for {file}.")

            # Convert numeric columns
            numeric_columns = ['identity', 'align_length', 'bit_score', 'evalue']
            for col in numeric_columns:
                if col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')

            # Drop rows with NaN in numeric fields (optional)
            before_drop_numeric = len(df)
            df = df.dropna(subset=numeric_columns)
            dropped_numeric = before_drop_numeric - len(df)
            if dropped_numeric > 0:
                logging.info(f"Dropped {dropped_numeric} rows with non-numeric values in {file}.")

            # Clean 'contig' by removing trailing dots and whitespace
            df['contig'] = df['contig'].str.rstrip('.').str.strip()

            dfs.append(df)

        except Exception as e:
            logging.error(f"Error reading {file}: {e}")
            continue

    combined_df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    logging.info(f"Combined DataFrame has {len(combined_df)} rows and {len(combined_df.columns)} columns.")
    logging.info(f"Columns in combined DataFrame: {combined_df.columns.tolist()}")
    return combined_df


def get_sample_from_path(path):
    """
    Extract the sample name from the file path.
    Assumes filenames are like: LCA_<SAMPLE>_modified_diamond_blast_nr_annotated.m9
    """
    filename = os.path.basename(path)
    prefix = "LCA_"
    suffix = "_modified_diamond_blast_nr_annotated.m9"
    if filename.startswith(prefix) and filename.endswith(suffix):
        sample = filename[len(prefix):-len(suffix)]
        return sample
    else:
        logging.warning(f"Filename {filename} does not match expected pattern. Assigning 'unknown_sample'.")
        return 'unknown_sample'


def extract_contig(query, sample):
    """
    Extract contig information by splitting the query string using '~'.
    Remove trailing dots if any.

    Parameters:
        query (str): The query string from BLAST output.
        sample (str): The sample name.

    Returns:
        str: Extracted contig information or 'unknown_contig'.
    """
    try:
        sample_from_query, contig = query.split('~', 1)
        if sample_from_query.lower() == sample.lower():
            contig = contig.rstrip('.').strip()  # Remove trailing dots and whitespace
            if contig:
                return contig
            else:
                logging.warning(f"Contig extracted from query '{query}' is empty after cleaning.")
                return 'unknown_contig'
        else:
            logging.warning(
                f"Sample in query '{sample_from_query}' does not match expected '{sample}'. Using 'unknown_contig'."
            )
            return 'unknown_contig'
    except ValueError:
        logging.warning(f"Query '{query}' does not contain '~'. Using 'unknown_contig'.")
        return 'unknown_contig'


# -------------------------------------------
# The rest of your functions remain unchanged
# -------------------------------------------

def load_cdhit_clusters(filename):
    """
    Load CD-HIT cluster information from a cluster file.

    Parameters:
        filename (str): Path to the CD-HIT cluster file.

    Returns:
        defaultdict: Cluster ID mapped to list of members.
    """
    clusters = defaultdict(list)
    cluster_id = None
    with open(filename, 'r') as file:
        for line in file:
            if line.startswith('>Cluster'):
                cluster_id = line.strip().split()[-1]
            else:
                if cluster_id is not None:
                    try:
                        member = parse_cdhit_row(line)
                        clusters[cluster_id].append(member)
                    except ValueError as e:
                        logging.error(f"Error parsing line: {line.strip()} - {e}")
    return clusters

def parse_cdhit_row(row):
    """
    Parse a single line from a CD-HIT cluster file.

    Parameters:
        row (str): Line from CD-HIT cluster file.

    Returns:
        dict: Parsed member information.
    """
    if '*' in row:
        # Reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = True
    else:
        # Non-reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = False

    # Clean and extract fields
    # Use regex to extract numeric length
    match = re.match(r'(\d+)', length_str)
    if match:
        length = int(match.group(1))
    else:
        logging.warning(f"Could not extract numeric length from '{length_str}' in row: {row}")
        length = 0  # Assign a default value or handle as needed

    name = name.strip('>').strip('.')

    # Extract sample and contig
    try:
        sample, contig = name.split('~')
    except ValueError:
        # Handle unexpected naming format
        sample, contig = 'unknown_sample', name  # Assign a default sample name

    return {
        'contig': contig,
        'sample': sample,
        'length': length,
        'is_ref': is_ref
    }

def get_cluster_rep(cluster):
    """
    Get the representative member of a cluster (the reference member).

    Parameters:
        cluster (list): List of cluster members.

    Returns:
        dict: Representative member.
    """
    for member in cluster:
        if member['is_ref']:
            return member
    # If no reference member, return the first member
    return cluster[0] if cluster else None

def merge_clusters_lca(clusters, lca_df):
    """
    Merge cluster information with LCA results.

    Parameters:
        clusters (defaultdict): Cluster information.
        lca_df (pd.DataFrame): LCA DataFrame.

    Returns:
        pd.DataFrame: Merged cluster and LCA information.
    """
    # Prepare cluster data
    cluster_ids = [int(id.strip()) for id in clusters.keys()]
    cluster_reps = [get_cluster_rep(clusters[str(id)]) for id in cluster_ids]
    cluster_lengths = [rep['length'] for rep in cluster_reps]
    cluster_sizes = [len(clusters[str(id)]) for id in cluster_ids]
    # Correct contig_key formatting using '~'
    cluster_contig_keys = [f"{rep['sample']}~{rep['contig']}".lower() for rep in cluster_reps]

    # Create DataFrame for clusters
    df = pd.DataFrame({
        'contig_key': cluster_contig_keys,
        'cluster': cluster_ids,
        'cluster_size': cluster_sizes,
        'contig_length': cluster_lengths
    })

    # Explicitly convert to string, then strip and lowercase
    df['contig_key'] = df['contig_key'].astype(str).str.strip().str.lower()

    # Debug: Log sample 'contig_key's
    logging.info("Sample 'contig_key's in clusters DataFrame:")
    logging.info(df['contig_key'].head())

    # Merge with LCA DataFrame
    merged_df = df.merge(lca_df, on='contig_key', how='inner')

    # Debug: Log the number of merged rows
    logging.info(f"Merged DataFrame has {len(merged_df)} rows.")

    if merged_df.empty:
        logging.warning("merged_df is empty after merging. Check 'contig_key' consistency.")
        print("merged_df is empty after merging. Please check 'contig_key' consistency between clusters and LCA DataFrames.", file=sys.stderr)
    else:
        logging.info("Successfully merged clusters with LCA results.")

    # Calculate align_percent
    merged_df['align_percent'] = np.round(100 * merged_df['align_length'] / merged_df['contig_length'], 1)

    # Use 'identity' as 'percent_identity'
    merged_df['percent_identity'] = np.round(merged_df['identity'], 1)

    # Define required columns (including additional columns)
    required_columns = [
        'cluster', 'cluster_size', 'contig_length', 'name', 'rank', 
        'kingdom', 'taxid', 'db', 'align_length', 'align_percent', 
        'percent_identity', 'bit_score', 'evalue', 'subject', 'sci_name',
        'com_names', 'superkingdom', 'phylum', 'family', 'lca_taxid', 'lca_name',
        'contig', 'sample', 'contig_key'
    ]

    # Check for missing columns
    missing_columns = [col for col in required_columns if col not in merged_df.columns]
    if missing_columns:
        logging.error(f"Missing columns in merged DataFrame: {missing_columns}")
        logging.info(f"Available columns: {merged_df.columns.tolist()}")
        raise KeyError(f"{missing_columns} not in index")

    # Subset the DataFrame to include only required columns
    merged_df = merged_df[required_columns]

    # Sort by cluster
    merged_df = merged_df.sort_values('cluster').reset_index(drop=True)

    return merged_df

def load_counts_data(counts_file):
    """
    Load and process the counts DataFrame.

    Parameters:
        counts_file (str): Path to the counts file.

    Returns:
        pd.DataFrame: Processed counts DataFrame with 'contig_key'.
    """
    try:
        # Read counts file with header, assuming first row is header
        counts = pd.read_csv(
            counts_file, 
            sep=',', 
            header=0,  # Assumes first row is header
            usecols=[0, 1, 2],  # Read 'contig', 'count', 'length'
            names=['contig', 'count', 'length'],  # Assign column names
            dtype={'contig': str, 'count': str, 'length': str},  # Read as strings initially
            na_values=['', 'NA', 'nan', 'Unknown'],
            low_memory=False
        )
        
        # Log a sample of the counts DataFrame
        logging.info("Raw Counts DataFrame sample:")
        logging.info(counts.head())
        
        # Convert 'count' to numeric, coercing errors to NaN
        counts['count'] = pd.to_numeric(counts['count'], errors='coerce')
        
        # Log the conversion result
        logging.info("Counts DataFrame after converting 'count' to numeric:")
        logging.info(counts.head())
        
        # Extract sample and contig from 'contig' column assuming 'sample~contig'
        counts[['sample', 'contig']] = counts['contig'].str.split('~', expand=True)
        
        # Handle cases where 'contig' does not contain '~'
        counts['contig'] = counts['contig'].fillna(counts['contig'])
        counts['sample'] = counts['sample'].fillna('unknown_sample')
        
        # Clean 'contig' by removing trailing dots and whitespace
        counts['contig'] = counts['contig'].str.rstrip('.').str.strip()
        
        # Create 'contig_key' as 'sample~contig', lowercase
        counts['contig_key'] = (counts['sample'].astype(str) + '~' + counts['contig'].astype(str)).str.lower().str.strip()
        
        # Debug: Log the processed counts DataFrame
        logging.info("Processed Counts DataFrame:")
        logging.info(counts.head())
        
        return counts
    except pd.errors.ParserError as e:
        logging.error(f"ParserError while reading counts file {counts_file}: {e}")
        print(f"ParserError while reading counts file {counts_file}: {e}", file=sys.stderr)
        raise e
    except Exception as e:
        logging.error(f"Unexpected error while processing counts file {counts_file}: {e}")
        print(f"Unexpected error while processing counts file {counts_file}: {e}", file=sys.stderr)
        raise e

def load_taxid2name(lca_df):
    """
    Create a taxid to name mapping from the LCA DataFrame.

    Parameters:
        lca_df (pd.DataFrame): LCA DataFrame.

    Returns:
        dict: Mapping from taxid to scientific name.
    """
    taxid2name = dict(zip(lca_df['taxid'], lca_df['name']))
    taxid2name[-1] = 'Ambiguous'  # Handle ambiguous taxids
    taxid2name[0] = 'Unknown'     # Handle 'Unknown' taxids
    return taxid2name

if __name__ == "__main__":
    # Example usage
    try:
        lca_df = load_lca(nr_dir)
        if lca_df.empty:
            logging.error("LCA DataFrame is empty. Exiting.")
            raise ValueError("LCA DataFrame is empty.")
        logging.info(f"LCA DataFrame loaded: {lca_df.shape[0]} rows")
        print(lca_df.head(3))
    except Exception as e:
        logging.error(f"Failed to load LCA DataFrame: {e}")
        sys.exit(1)
        
    try:
        # Load CD-HIT clusters
        clusters = load_cdhit_clusters('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/Combined_filtered_500_SENZOR_clsuter.clstr')
        logging.info("Loaded CD-HIT clusters.")
    except Exception as e:
        logging.error(f"Failed to load CD-HIT clusters: {e}")
        raise e  # Let the exception propagate

    try:
        # Merge clusters with LCA
        merged_df = merge_clusters_lca(clusters, lca_df)
        logging.info("Successfully merged clusters with LCA results.")
    except KeyError as e:
        logging.error(f"KeyError during merging: {e}")
        print(f"Error during merging: {e}", file=sys.stderr)
        raise e  # Let the exception propagate
    except Exception as e:
        logging.error(f"Unexpected error during merging: {e}")
        print(f"Unexpected error during merging: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Create clust2kingdom dictionary
        clust2kingdom = dict(zip(merged_df['cluster'], merged_df['kingdom']))
        logging.info("Created clust2kingdom dictionary.")

        # Save clust2kingdom to CSV
        merged_df[['cluster', 'kingdom']].to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/cluster_kingdoms.csv', index=False)
        logging.info("Saved cluster_kingdoms.csv successfully.")

        # Print sample entries of clust2kingdom
        print("clust2kingdom Dictionary Sample Entries:")
        for cluster_id, kingdom in list(clust2kingdom.items())[:5]:
            print(f"Cluster {cluster_id}: Kingdom {kingdom}")
    except Exception as e:
        logging.error(f"Error creating clust2kingdom: {e}")
        print(f"Error creating clust2kingdom: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Load and process the counts DataFrame
        counts_file_path = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/combined_SENSOR_csvs/1Modified_CSV/merged_contig_stats.csv'
        counts_df = load_counts_data(counts_file_path)
        logging.info("Loaded and processed counts DataFrame.")
    except Exception as e:
        logging.error(f"Failed to load counts DataFrame: {e}")
        print(f"Error loading counts DataFrame: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Merge counts with merged_df on 'contig_key'
        final_df = merged_df.merge(counts_df[['contig_key', 'count']], on='contig_key', how='left')
        logging.info("Merged counts with merged_df to create final_df.")

        # Convert 'count' to numeric if not already (redundant but safe)
        final_df['count'] = pd.to_numeric(final_df['count'], errors='coerce')

        # Apply filtering: Remove rows where 'count' is NaN
        initial_count = len(final_df)
        final_df = final_df.dropna(subset=['count'])
        final_count = len(final_df)
        removed_rows = initial_count - final_count
        if removed_rows > 0:
            logging.info(f"Filtered out {removed_rows} rows without 'count' value.")
            print(f"Filtered out {removed_rows} rows without 'count' value.")

        # Remove rows where 'count' is zero or negative:
        before_zero_drop = len(final_df)
        final_df = final_df[final_df['count'] > 0]
        after_zero_drop = len(final_df)
        removed_zero = before_zero_drop - after_zero_drop
        if removed_zero > 0:
            logging.info(f"Filtered out {removed_zero} rows with 'count' <= 0.")
            print(f"Filtered out {removed_zero} rows with 'count' <= 0.")

        # Verify if final_df is empty after filtering
        if final_df.empty:
            logging.warning("final_df is empty after filtering out rows without 'count'.")
            print("final_df is empty after filtering out rows without 'count'.", file=sys.stderr)
        else:
            # Display the final merged DataFrame with counts
            print("Final Merged DataFrame with Counts (after filtering):")
            print(final_df.head())

            # Save the final DataFrame to CSV
            final_df.to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/final_merged_with_counts.csv', index=False)
            logging.info("Saved final_merged_with_counts.csv successfully.")
    except Exception as e:
        logging.error(f"Error merging counts with merged_df: {e}")
        print(f"Error merging counts with merged_df: {e}", file=sys.stderr)
        raise e  # Let the exception propagate


Sample of combined nr_df before dropping NaNs in 'taxid':
                                               query  \
0  AIAMA_GOT005_OS_S43_673459~NODE_10000_length_3...   
1  AIAMA_GOT005_OS_S43_673459~NODE_10003_length_3...   
2  AIAMA_GOT005_OS_S43_673459~NODE_10010_length_3...   
3  AIAMA_GOT005_OS_S43_673459~NODE_10014_length_3...   
4  AIAMA_GOT005_OS_S43_673459~NODE_10015_length_3...   

                                             subject  identity  align_length  \
0  MCU8647632.1 hypothetical protein [Escherichia...     100.0            73   
1                       QLI46611.1 VP4 [Rotavirus A]      98.9            88   
2  XP_041588126.1 tyrosine-protein kinase Srms-li...      96.9            65   
3  XP_048962668.1 uncharacterized protein LOC1183...      90.3            31   
4  WP_288535963.1 M42 family metallopeptidase [un...     100.0            95   

   mismatches  gap_opens  q_start  q_end  s_start  s_end  ...  \
0           0          0      316     98        3     75  .



                                               query  \
0  AIAMA_GOT005_OS_S43_673459~NODE_10000_length_3...   
1  AIAMA_GOT005_OS_S43_673459~NODE_10003_length_3...   
2  AIAMA_GOT005_OS_S43_673459~NODE_10010_length_3...   

                                             subject  identity  align_length  \
0  MCU8647632.1 hypothetical protein [Escherichia...     100.0            73   
1                       QLI46611.1 VP4 [Rotavirus A]      98.9            88   
2  XP_041588126.1 tyrosine-protein kinase Srms-li...      96.9            65   

   mismatches  gap_opens  q_start  q_end  s_start  s_end  ...  \
0           0          0      316     98        3     75  ...   
1           1          0       53    316      600    687  ...   
2           2          0      219     25      247    311  ...   

               family  lca_taxid          lca_name  \
0  Enterobacteriaceae      562.0  Escherichia coli   
1          Reoviridae    28875.0       Rotavirus A   
2                 NaN        Na

In [55]:
import os
from glob import glob
import pandas as pd
import numpy as np
from collections import defaultdict
from statistics import mode
from Bio import Entrez
from ete3 import NCBITaxa
import logging
import sys
import re

# Configure logging
from logging.handlers import RotatingFileHandler
handler = RotatingFileHandler('util.log', maxBytes=5*1024*1024, backupCount=2)
logging.basicConfig(level=logging.INFO, handlers=[handler],
                    format='%(asctime)s - %(levelname)s - %(message)s')

Entrez.email = "ifeanyi.omah@ed.ac.uk"
ncbi = NCBITaxa()

# Define column names for NR
nr_column_names = [
    'query', 'lca_taxid', 'best_hit_sciname', 'subject_title',
    'identity', 'aligned_bases', 'contig_length', 'superkingdom',
    'bitscore', 'tax_id'
]import os
import pandas as pd
import numpy as np
from glob import glob
from collections import defaultdict
from statistics import mode
from Bio import Entrez
from ete3 import NCBITaxa
import logging
import sys
import re
from logging.handlers import RotatingFileHandler

handler = RotatingFileHandler('util.log', maxBytes=5*1024*1024, backupCount=2)
logging.basicConfig(level=logging.INFO, handlers=[handler],
                    format='%(asctime)s - %(levelname)s - %(message)s')

Entrez.email = "ifeanyi.omah@ed.ac.uk"
ncbi = NCBITaxa()

valid_lineages = set()

# Updated column names to match your .m9 file
nr_column_names = [
    'query', 'subject', 'identity', 'alignment_length', 'mismatches', 'gap_opens',
    'q_start', 'q_end', 's_start', 's_end', 'evalue', 'bit_score', 'tax_id',
    'sci_name', 'com_names', 'superkingdom', 'phylum', 'family',
    'lca_taxid', 'lca_name'
]

# --------------------------------------------------------------------
# 1) Change nr_dir to where your NR LCA files actually live
#    If these files are named like: LCA_<SAMPLE>_modified_diamond_blast_nr_annotated.m9
#    then we can glob them below.
# --------------------------------------------------------------------
nr_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/"

def load_lca(nr_directory):
    """
    Load and process LCA results from NR files in the specified directory.
    Dynamically glob any pattern you want.
    """
    # ----------------------------------------------------------------
    # 2) Dynamically pick up files named LCA_*_modified_diamond_blast_nr_annotated.m9
    #    Adjust the pattern to match your actual file naming.
    # ----------------------------------------------------------------
    pattern = "LCA_*_modified_diamond_blast_nr_annotated.m9"
    nr_file_paths = glob(os.path.join(nr_directory, pattern))
    logging.info(f"Found {len(nr_file_paths)} NR files matching '{pattern}' in {nr_directory}.")

    if not nr_file_paths:
        logging.error(f"No files matching {pattern} found in directory: {nr_directory}")
        raise FileNotFoundError(f"No files matching {pattern} found in directory: {nr_directory}")

    # Process each BLAST output file
    nr_df = load_blast_output(nr_file_paths, database='nr')

    # Remove rows with missing taxid
    initial_nr = len(nr_df)
    nr_df = nr_df.dropna(subset=['taxid'])  # <- Make sure 'taxid' actually exists
    final_nr = len(nr_df)
    dropped_nr = initial_nr - final_nr
    if dropped_nr > 0:
        logging.info(f"Dropped {dropped_nr} rows with missing taxid from nr_df.")

    # Convert taxid to integer
    nr_df['taxid'] = nr_df['taxid'].astype(int)

    # Build dictionaries for name/rank
    taxids = set(nr_df.taxid)
    taxids = {t for taxid in taxids for t in ncbi.get_lineage(taxid) if t != 0}

    taxid2name = ncbi.get_taxid_translator(taxids)
    taxid2name[-1] = 'Ambiguous'
    taxid2name[0] = 'Unknown'
    taxid2rank = ncbi.get_rank(taxids)

    # Collect valid lineages
    taxids_in_df = set(nr_df['taxid'])
    for tx in taxids_in_df:
        try:
            lineage = ncbi.get_lineage(tx)
            for t in lineage:
                if t != 0:
                    valid_lineages.add(t)
        except ValueError:
            logging.warning(f"Skipping {tx}: taxid not found in local DB.")

    # Restrict to valid lineages
    taxids = valid_lineages
    taxid2name = ncbi.get_taxid_translator(taxids)
    
    def taxid2kingdom(taxid):
        if taxid == 0:
            return 'Unknown'
        lineage_dict = ncbi.get_rank(ncbi.get_lineage(taxid))
        for k, v in lineage_dict.items():
            if v == 'superkingdom':
                kingdom = taxid2name.get(k, 'Unknown')
                return kingdom
        return 'Unknown'

    # Map taxid -> name, rank, kingdom
    nr_df['name'] = nr_df['taxid'].map(taxid2name)
    nr_df['rank'] = nr_df['taxid'].map(taxid2rank)
    nr_df['kingdom'] = nr_df['taxid'].apply(taxid2kingdom)
    nr_df['db'] = 'nr'  # Since this is the NR dataset

    return nr_df

def load_blast_output(files, database='nr'):
    """
    Load BLAST output files into a single DataFrame.
    """
    if database == 'nr':
        column_names = nr_column_names
    else:
        raise ValueError(f"Unknown database type: {database}. Expected 'nr'.")

    dfs = []
    for file in files:
        try:
            # Quick check of columns in the first line
            with open(file, 'r') as f:
                first_line = f.readline()
                columns_in_file = len(first_line.strip().split('\t'))
                expected_columns = len(column_names)
                if columns_in_file != expected_columns:
                    logging.warning(f"{file} has {columns_in_file} columns; expected {expected_columns}.")

            # Potentially skip a header if present
            with open(file, 'r') as f:
                first_line = f.readline()
                if 'sci_name' in first_line.lower() or 'lca_taxid' in first_line.lower():
                    skiprows = 1
                else:
                    skiprows = 0

            df = pd.read_csv(
                file,
                sep='\t',
                header=None,
                names=column_names,
                na_values=['', 'NA', 'nan', 'Unknown'],
                skiprows=skiprows,
                on_bad_lines='warn'
            )

            if 'tax_id' not in df.columns:
                # If your data is guaranteed to have 'tax_id', raise KeyError.
                logging.error(f"'tax_id' column is missing in {file}.")
                raise KeyError(f"'tax_id' column is missing in {file}.")

            # Rename 'tax_id' -> 'taxid'
            df.rename(columns={'tax_id': 'taxid'}, inplace=True)
            df['taxid'] = df['taxid'].replace('Unknown', '0')

            # Convert taxid to numeric
            df['taxid'] = pd.to_numeric(df['taxid'], errors='coerce')

            # Drop rows with NaN in 'taxid'
            before_drop = len(df)
            df.dropna(subset=['taxid'], inplace=True)
            after_drop = len(df)
            if (d := (before_drop - after_drop)) > 0:
                logging.info(f"Dropped {d} rows with invalid taxid in {file}.")

            # alignment_length -> align_length
            if 'alignment_length' in df.columns:
                df.rename(columns={'alignment_length': 'align_length'}, inplace=True)
                
                # Build contig_key
                df['contig_key'] = (df['sample'] + '~' + df['contig']).str.lower()

                # Convert numeric columns
                numeric_columns = ['identity', 'align_length', 'bit_score', 'evalue']
                for col in numeric_columns:
                    if col in df.columns:
                        df[col] = pd.to_numeric(df[col], errors='coerce')

                # Drop rows with NaN in numeric fields (optional)
                before_drop_numeric = len(df)
                df = df.dropna(subset=numeric_columns)
                dropped_numeric = before_drop_numeric - len(df)
                if dropped_numeric > 0:
                    logging.info(f"Dropped {dropped_numeric} rows with non-numeric in {file}.")

            # Build a contig_key if needed
            # For example: df['contig_key'] = (df['query'] + '~' + df['subject']).str.lower()
            # Or keep your existing method if you rely on a different scheme

            dfs.append(df)

        except Exception as e:
            logging.error(f"Error reading {file}: {e}")
            continue

    combined_df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    return combined_df


# ----------------------------------------------------------------------
# Below are the rest of your original functions (merge_clusters_lca, etc.)
# You can keep them exactly as is. The main difference is that now your
# `load_lca()` is truly dynamic, picking up all the .m9 files in nr_dir.
# ----------------------------------------------------------------------


def extract_contig(query, sample):
    """
    Extract contig information by removing the sample prefix from the query.

    Parameters:
        query (str): The query string from BLAST output.
        sample (str): The sample name.

    Returns:
        str: Extracted contig information or 'unknown_contig'.
    """
    prefix = sample + '_'
    if query.startswith(prefix):
        return query[len(prefix):]
    else:
        logging.warning(f"Query '{query}' does not start with sample prefix '{prefix}'. Using 'unknown_contig'.")
        return 'unknown_contig'

# -------------------------------------------
# The rest of your functions remain unchanged
# -------------------------------------------


def load_cdhit_clusters(filename):
    """
    Load CD-HIT cluster information from a cluster file.

    Parameters:
        filename (str): Path to the CD-HIT cluster file.

    Returns:
        defaultdict: Cluster ID mapped to list of members.
    """
    clusters = defaultdict(list)
    cluster_id = None
    with open(filename, 'r') as file:
        for line in file:
            if line.startswith('>Cluster'):
                cluster_id = line.strip().split()[-1]
            else:
                if cluster_id is not None:
                    try:
                        member = parse_cdhit_row(line)
                        clusters[cluster_id].append(member)
                    except ValueError as e:
                        logging.error(f"Error parsing line: {line.strip()} - {e}")
    return clusters

def parse_cdhit_row(row):
    """
    Parse a single line from a CD-HIT cluster file.

    Parameters:
        row (str): Line from CD-HIT cluster file.

    Returns:
        dict: Parsed member information.
    """
    if '*' in row:
        # Reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = True
    else:
        # Non-reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = False

    # Clean and extract fields
    # Use regex to extract numeric length
    match = re.match(r'(\d+)', length_str)
    if match:
        length = int(match.group(1))
    else:
        logging.warning(f"Could not extract numeric length from '{length_str}' in row: {row}")
        length = 0  # Assign a default value or handle as needed

    name = name.strip('>').strip('.')

    # Extract sample and contig
    try:
        sample, contig = name.split('~')
    except ValueError:
        # Handle unexpected naming format
        sample, contig = 'unknown_sample', name  # Assign a default sample name

    return {
        'contig': contig,
        'sample': sample,
        'length': length,
        'is_ref': is_ref
    }

def get_cluster_rep(cluster):
    """
    Get the representative member of a cluster (the reference member).

    Parameters:
        cluster (list): List of cluster members.

    Returns:
        dict: Representative member.
    """
    for member in cluster:
        if member['is_ref']:
            return member
    # If no reference member, return the first member
    return cluster[0] if cluster else None

def merge_clusters_lca(clusters, lca_df):
    """
    Merge cluster information with LCA results.

    Parameters:
        clusters (defaultdict): Cluster information.
        lca_df (pd.DataFrame): LCA DataFrame.

    Returns:
        pd.DataFrame: Merged cluster and LCA information.
    """
    # Prepare cluster data
    cluster_ids = [int(id.strip()) for id in clusters.keys()]
    cluster_reps = [get_cluster_rep(clusters[str(id)]) for id in cluster_ids]
    cluster_lengths = [rep['length'] for rep in cluster_reps]
    cluster_sizes = [len(clusters[str(id)]) for id in cluster_ids]
    # Correct contig_key formatting using '~'
    cluster_contig_keys = [f"{rep['sample']}~{rep['contig']}".lower() for rep in cluster_reps]

    # Create DataFrame for clusters
    df = pd.DataFrame({
        'contig_key': cluster_contig_keys,
        'cluster': cluster_ids,
        'cluster_size': cluster_sizes,
        'contig_length': cluster_lengths
    })

    # Explicitly convert to string, then strip and lowercase
    df['contig_key'] = df['contig_key'].astype(str).str.strip().str.lower()

    # Debug: Log sample 'contig_key's
    logging.info("Sample 'contig_key's in clusters DataFrame:")
    logging.info(df['contig_key'].head())

    # Merge with LCA DataFrame
    merged_df = df.merge(lca_df, on='contig_key', how='inner')

    # Debug: Log the number of merged rows
    logging.info(f"Merged DataFrame has {len(merged_df)} rows.")

    if merged_df.empty:
        logging.warning("merged_df is empty after merging. Check 'contig_key' consistency.")
        print("merged_df is empty after merging. Please check 'contig_key' consistency between clusters and LCA DataFrames.", file=sys.stderr)
    else:
        logging.info("Successfully merged clusters with LCA results.")

    # Calculate align_percent
    merged_df['align_percent'] = np.round(100 * merged_df['align_length'] / merged_df['contig_length'], 1)

    # Use 'identity' as 'percent_identity'
    merged_df['percent_identity'] = np.round(merged_df['identity'], 1)

    # Define required columns (including additional columns)
    required_columns = [
        'cluster', 'cluster_size', 'contig_length', 'name', 'rank', 
        'kingdom', 'taxid', 'db', 'align_length', 'align_percent', 
        'percent_identity', 'bit_score', 'evalue', 'subject', 'sci_name',
        'com_names', 'superkingdom', 'phylum', 'family', 'lca_taxid', 'lca_name',
        'contig', 'sample', 'contig_key'
    ]

    # Check for missing columns
    missing_columns = [col for col in required_columns if col not in merged_df.columns]
    if missing_columns:
        logging.error(f"Missing columns in merged DataFrame: {missing_columns}")
        logging.info(f"Available columns: {merged_df.columns.tolist()}")
        raise KeyError(f"{missing_columns} not in index")

    # Subset the DataFrame to include only required columns
    merged_df = merged_df[required_columns]

    # Sort by cluster
    merged_df = merged_df.sort_values('cluster').reset_index(drop=True)

    return merged_df

def load_counts_data(counts_file):
    """
    Load and process the counts DataFrame.

    Parameters:
        counts_file (str): Path to the counts file.

    Returns:
        pd.DataFrame: Processed counts DataFrame with 'contig_key'.
    """
    try:
        # Read counts file with header, assuming first row is header
        counts = pd.read_csv(
            counts_file, 
            sep=',', 
            header=0,  # Assumes first row is header
            usecols=[0, 1, 2],  # Read 'contig', 'count', 'length'
            names=['contig', 'count', 'length'],  # Assign column names
            dtype={'contig': str, 'count': str, 'length': str},  # Read as strings initially
            na_values=['', 'NA', 'nan', 'Unknown'],
            low_memory=False
        )
        
        # Log a sample of the counts DataFrame
        logging.info("Raw Counts DataFrame sample:")
        logging.info(counts.head())
        
        # Convert 'count' to numeric, coercing errors to NaN
        counts['count'] = pd.to_numeric(counts['count'], errors='coerce')
        
        # Log the conversion result
        logging.info("Counts DataFrame after converting 'count' to numeric:")
        logging.info(counts.head())
        
        # Extract sample and contig from 'contig' column assuming 'sample~contig'
        counts[['sample', 'contig']] = counts['contig'].str.split('~', expand=True)
        
        # Handle cases where 'contig' does not contain '~'
        counts['contig'] = counts['contig'].fillna(counts['contig'])
        counts['sample'] = counts['sample'].fillna('unknown_sample')
        
        # Create 'contig_key' as 'sample~contig', lowercase
        counts['contig_key'] = (counts['sample'].astype(str) + '~' + counts['contig'].astype(str)).str.strip().str.lower()
        
        # Debug: Log the processed counts DataFrame
        logging.info("Processed Counts DataFrame:")
        logging.info(counts.head())
        
        return counts
    except pd.errors.ParserError as e:
        logging.error(f"ParserError while reading counts file {counts_file}: {e}")
        print(f"ParserError while reading counts file {counts_file}: {e}", file=sys.stderr)
        raise e
    except Exception as e:
        logging.error(f"Unexpected error while processing counts file {counts_file}: {e}")
        print(f"Unexpected error while processing counts file {counts_file}: {e}", file=sys.stderr)
        raise e

def load_taxid2name(lca_df):
    """
    Create a taxid to name mapping from the LCA DataFrame.

    Parameters:
        lca_df (pd.DataFrame): LCA DataFrame.

    Returns:
        dict: Mapping from taxid to scientific name.
    """
    taxid2name = dict(zip(lca_df['taxid'], lca_df['name']))
    taxid2name[-1] = 'Ambiguous'  # Handle ambiguous taxids
    taxid2name[0] = 'Unknown'     # Handle 'Unknown' taxids
    return taxid2name

if __name__ == "__main__":
    # Example usage
    try:
        lca_df = load_lca(nr_dir)
        if lca_df.empty:
            logging.error("LCA DataFrame is empty. Exiting.")
            raise ValueError("LCA DataFrame is empty.")
        logging.info(f"LCA DataFrame loaded: {lca_df.shape[0]} rows")
        print(lca_df.head(3))
    except Exception as e:
        logging.error(f"Failed to load LCA DataFrame: {e}")
        sys.exit(1)
        
    try:
        # Load CD-HIT clusters
        clusters = load_cdhit_clusters('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/Combined_filtered_500_SENZOR_clsuter.clstr')
        logging.info("Loaded CD-HIT clusters.")
    except Exception as e:
        logging.error(f"Failed to load CD-HIT clusters: {e}")
        raise e  # Let the exception propagate

    try:
        # Merge clusters with LCA
        merged_df = merge_clusters_lca(clusters, lca_df)
        logging.info("Successfully merged clusters with LCA results.")
    except KeyError as e:
        logging.error(f"KeyError during merging: {e}")
        print(f"Error during merging: {e}", file=sys.stderr)
        raise e  # Let the exception propagate
    except Exception as e:
        logging.error(f"Unexpected error during merging: {e}")
        print(f"Unexpected error during merging: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Create clust2kingdom dictionary
        clust2kingdom = dict(zip(merged_df['cluster'], merged_df['kingdom']))
        logging.info("Created clust2kingdom dictionary.")

        # Save clust2kingdom to CSV
        merged_df[['cluster', 'kingdom']].to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/cluster_kingdoms.csv', index=False)
        logging.info("Saved cluster_kingdoms.csv successfully.")

        # Print sample entries of clust2kingdom
        print("clust2kingdom Dictionary Sample Entries:")
        for cluster_id, kingdom in list(clust2kingdom.items())[:5]:
            print(f"Cluster {cluster_id}: Kingdom {kingdom}")
    except Exception as e:
        logging.error(f"Error creating clust2kingdom: {e}")
        print(f"Error creating clust2kingdom: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Load and process the counts DataFrame
        counts_file_path = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/combined_SENSOR_csvs/1Modified_CSV/merged_contig_stats.csv'
        counts_df = load_counts_data(counts_file_path)
        logging.info("Loaded and processed counts DataFrame.")
    except Exception as e:
        logging.error(f"Failed to load counts DataFrame: {e}")
        print(f"Error loading counts DataFrame: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Merge counts with merged_df on 'contig_key'
        final_df = merged_df.merge(counts_df[['contig_key', 'count']], on='contig_key', how='left')
        logging.info("Merged counts with merged_df to create final_df.")

        # Convert 'count' to numeric if not already (redundant but safe)
        final_df['count'] = pd.to_numeric(final_df['count'], errors='coerce')

        # Apply filtering: Remove rows where 'count' is NaN
        initial_count = len(final_df)
        final_df = final_df.dropna(subset=['count'])
        final_count = len(final_df)
        removed_rows = initial_count - final_count
        if removed_rows > 0:
            logging.info(f"Filtered out {removed_rows} rows without 'count' value.")
            print(f"Filtered out {removed_rows} rows without 'count' value.")

        # Remove rows where 'count' is zero or negative:
        logging.info("Filtered out rows with 'count' <= 0.")

        # Verify if final_df is empty after filtering
        if final_df.empty:
            logging.warning("final_df is empty after filtering out rows without 'count'.")
            print("final_df is empty after filtering out rows without 'count'.", file=sys.stderr)
        else:
            # Display the final merged DataFrame with counts
            print("Final Merged DataFrame with Counts (after filtering):")
            print(final_df.head())

            # Save the final DataFrame to CSV
            final_df.to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/final_merged_with_counts.csv', index=False)
            logging.info("Saved final_merged_with_counts.csv successfully.")
    except Exception as e:
        logging.error(f"Error merging counts with merged_df: {e}")
        print(f"Error merging counts with merged_df: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    # ...the rest of your pipeline (e.g., load CD-HIT clusters, merge, etc.)
    # ...


# Suppose you no longer have NT files or simply skip them
nt_lca_files = []

def load_lca(lca_directory, nr_files):
    """
    Simplified function that only loads NR LCA files.
    """
    nr_file_paths = [os.path.join(lca_directory, filename) for filename in nr_files]
    # Load BLAST outputs with NR database
    nr_df = load_blast_output(nr_file_paths, database='nr')
    nr_df = nr_df.dropna(subset=['taxid'])  # remove rows with missing taxid
    nr_df['taxid'] = nr_df['taxid'].astype(int)

    # (Optionally do lineage lookups, merges, etc.)
    return nr_df

def load_blast_output(files, database='nr'):
    """
    Loads BLAST output files into a single DataFrame.
    """
    if database == 'nr':
        column_names = nr_column_names
    else:
        raise ValueError("Unknown DB")

    dfs = []
    for file in files:
        try:
            df = pd.read_csv(
                file,
                sep='\t',
                header=None,
                names=column_names,
                na_values=['', 'NA', 'nan', 'Unknown'],
                on_bad_lines='warn'
            )
            if 'tax_id' in df.columns:
                df.rename(columns={'tax_id': 'taxid'}, inplace=True)
            df['taxid'] = df['taxid'].replace('Unknown', '0')
            df['taxid'] = pd.to_numeric(df['taxid'], errors='coerce')
            df.dropna(subset=['taxid'], inplace=True)

            # You can further process the sample/contig name here:
            df['sample'] = [os.path.basename(file).replace('_modified_diamond_blast_nr_annotated.m9','')] * len(df)
            df['contig'] = df['query']  # or other logic

            # Create contig_key
            df['contig_key'] = df['sample'].str.lower() + "_" + df['contig'].str.lower()

            dfs.append(df)
        except Exception as e:
            logging.error(f"Error reading {file}: {e}")
            continue

    combined_df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    return combined_df

# Make sure these helper functions exist in your code
def load_cdhit_clusters(file_path):
    # ... your implementation ...
    return {}

def merge_clusters_lca(clusters, lca_df):
    # ... your implementation ...
    return pd.DataFrame()

def load_counts_data(file_path):
    # ... your implementation ...
    return pd.DataFrame()

if __name__ == "__main__":
    lca_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA"
    
    # Dynamically fetch .m9 files that match your naming pattern
    nr_lca_files = glob(os.path.join(lca_dir, "LCA_*_modified_diamond_blast_nr_annotated.m9"))
    print("Found NR LCA files:", nr_lca_files)

    # Now load and process only the NR files
    try:
        lca_df = load_lca(lca_dir, [os.path.basename(f) for f in nr_lca_files])
        print("NR LCA DataFrame shape:", lca_df.shape)
        print(lca_df.head())
    except Exception as e:
        logging.error(f"Failed to load LCA DataFrame: {e}")
        sys.exit(1)

    # --- Next block (same indentation as above) ---
    try:
        # Load CD-HIT clusters
        clusters = load_cdhit_clusters(
            '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/Combined_filtered_500_SENZOR_clsuter.clstr'
        )
        logging.info("Loaded CD-HIT clusters.")
    except Exception as e:
        logging.error(f"Failed to load CD-HIT clusters: {e}")
        raise e

    try:
        # Merge clusters with LCA
        merged_df = merge_clusters_lca(clusters, lca_df)
        logging.info("Successfully merged clusters with LCA results.")
    except KeyError as e:
        logging.error(f"KeyError during merging: {e}")
        print(f"Error during merging: {e}", file=sys.stderr)
        raise e
    except Exception as e:
        logging.error(f"Unexpected error during merging: {e}")
        print(f"Unexpected error during merging: {e}", file=sys.stderr)
        raise e

    try:
        # Create clust2kingdom dictionary
        clust2kingdom = dict(zip(merged_df['cluster'], merged_df['kingdom']))
        logging.info("Created clust2kingdom dictionary.")

        # Save clust2kingdom to CSV
        merged_df[['cluster', 'kingdom']].to_csv(
            '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/cluster_kingdoms.csv',
            index=False
        )
        logging.info("Saved cluster_kingdoms.csv successfully.")

        # Print sample entries of clust2kingdom
        print("clust2kingdom Dictionary Sample Entries:")
        for cluster_id, kingdom in list(clust2kingdom.items())[:5]:
            print(f"Cluster {cluster_id}: Kingdom {kingdom}")
    except Exception as e:
        logging.error(f"Error creating clust2kingdom: {e}")
        print(f"Error creating clust2kingdom: {e}", file=sys.stderr)
        raise e

    try:
        # Load and process the counts DataFrame
        counts_file_path = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/bowtie_csp_counts_1000.txt'
        counts_df = load_counts_data(counts_file_path)
        logging.info("Loaded and processed counts DataFrame.")
    except Exception as e:
        logging.error(f"Failed to load counts DataFrame: {e}")
        print(f"Error loading counts DataFrame: {e}", file=sys.stderr)
        raise e

    try:
        # Merge counts with merged_df on 'contig_key'
        final_df = merged_df.merge(counts_df[['contig_key', 'count']], on='contig_key', how='left')
        logging.info("Merged counts with merged_df to create final_df.")

        # Apply filtering: Remove rows where 'count' is NaN
        initial_count = len(final_df)
        final_df = final_df.dropna(subset=['count'])
        final_count = len(final_df)
        removed_rows = initial_count - final_count
        if removed_rows > 0:
            logging.info(f"Filtered out {removed_rows} rows without 'count' value.")
            print(f"Filtered out {removed_rows} rows without 'count' value.")

        # Optionally remove rows where 'count' is zero or negative
        final_df = final_df[final_df['count'] > 0]
        logging.info("Filtered out rows with 'count' <= 0.")

        # Verify if final_df is empty after filtering
        if final_df.empty:
            logging.warning("final_df is empty after filtering out rows without 'count'.")
            print("final_df is empty after filtering out rows without 'count'.", file=sys.stderr)
        else:
            # Display the final merged DataFrame with counts
            print("Final Merged DataFrame with Counts (after filtering):")
            print(final_df.head())

            # Save the final DataFrame to CSV
            final_df.to_csv(
                '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/final_merged_with_counts.csv',
                index=False
            )
            logging.info("Saved final_merged_with_counts.csv successfully.")
    except Exception as e:
        logging.error(f"Error merging counts with merged_df: {e}")
        print(f"Error merging counts with merged_df: {e}", file=sys.stderr)
        raise e



Found NR LCA files: ['/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/LCA_AIAMA_GOT005_OS_S43_673459_modified_diamond_blast_nr_annotated.m9', '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/LCA_AIAMA_GOT006_RS_S44_673460_modified_diamond_blast_nr_annotated.m9', '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/LCA_AIAMA_GOT007_OS_S45_673461_modified_diamond_blast_nr_annotated.m9', '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/LCA_AIAMA_LIZ002_OS_S50_673462_modified_diamond_blast_nr_annotated.m9', '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/LCA_AIAMA_PIG003_RS_S51_673463_modified_diamond_blast_nr_annotated.m9', '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/LCA_AIAMA_SHP004_OS_S46_673464_modified_diamond_blast_n

  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(
  df = pd.read_csv(


AttributeError: 'tuple' object has no attribute 'tb_frame'

In [66]:
import pandas as pd
import numpy as np
from glob import glob
from collections import defaultdict
from statistics import mode
from Bio import Entrez
from ete3 import NCBITaxa
import logging
import sys
import re
from logging.handlers import RotatingFileHandler

handler = RotatingFileHandler('util.log', maxBytes=5*1024*1024, backupCount=2)
logging.basicConfig(level=logging.INFO, handlers=[handler], 
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Set Entrez email
Entrez.email = "ifeanyi.omah@ed.ac.uk"

# Initialize NCBI Taxonomy
ncbi = NCBITaxa()
# Optionally: ncbi.update_taxonomy_database()

valid_lineages = set()

# Updated column names to match your .m9 file
nr_column_names = [
    'query', 'subject', 'identity', 'alignment_length', 'mismatches', 'gap_opens',
    'q_start', 'q_end', 's_start', 's_end', 'evalue', 'bit_score', 'tax_id',
    'sci_name', 'com_names', 'superkingdom', 'phylum', 'family',
    'lca_taxid', 'lca_name'
]

# Define NR directory
nr_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/"

def load_lca(nr_directory):
    """
    Load and process LCA results from NR files in the specified directory.

    Parameters:
        nr_directory (str): Path to the directory containing NR LCA files.

    Returns:
        pd.DataFrame: Processed LCA DataFrame.
    """
    # CHANGED: Search specifically for 'LCA_*.m9' files
    nr_file_paths = glob(f"{nr_directory}/LCA_*.m9")
    logging.info(f"Found {len(nr_file_paths)} NR files matching 'LCA_*.m9' in {nr_directory}.")

    if not nr_file_paths:
        logging.error(f"No files matching LCA_*.m9 found in directory: {nr_directory}")
        raise FileNotFoundError(f"No files matching LCA_*.m9 found in directory: {nr_directory}")

    # Process each BLAST output file
    nr_df = load_blast_output(nr_file_paths, database='nr')

    # Remove rows with missing taxid
    initial_nr = len(nr_df)
    nr_df = nr_df.dropna(subset=['taxid'])
    final_nr = len(nr_df)
    dropped_nr = initial_nr - final_nr
    if dropped_nr > 0:
        logging.info(f"Dropped {dropped_nr} rows with missing taxid from nr_df.")

    # Convert taxid to integer
    nr_df['taxid'] = nr_df['taxid'].astype(int)

    # Get all lineages for valid taxids
    taxids = set(nr_df.taxid)
    taxids = {t for taxid in taxids for t in ncbi.get_lineage(taxid) if t != 0}

    # Build dictionaries for name/rank
    taxid2name = ncbi.get_taxid_translator(taxids)
    taxid2name[-1] = 'Ambiguous'
    taxid2name[0] = 'Unknown'
    taxid2rank = ncbi.get_rank(taxids)

    # Verify valid lineages
    taxids = set(nr_df['taxid'])
    for taxid in taxids:
        try:
            lineage = ncbi.get_lineage(taxid)
            for t in lineage:
                if t != 0:
                    valid_lineages.add(t)
        except ValueError:
            logging.warning(f"Skipping {taxid}: taxid not found in local DB.")

    # Now limit to valid_lineages only
    taxids = valid_lineages
    taxid2name = ncbi.get_taxid_translator(taxids)
    
    def taxid2kingdom(taxid):
        """Retrieve the kingdom for a given taxid."""
        if taxid == 0:
            return 'Unknown'
        lineage_dict = ncbi.get_rank(ncbi.get_lineage(taxid))
        for k, v in lineage_dict.items():
            if v == 'superkingdom':
                kingdom = taxid2name.get(k, 'Unknown')
                if kingdom != 'Unknown' and taxid != k:
                    logging.warning(f"taxid {taxid} was translated into {kingdom}")
                return kingdom
        return 'Unknown'

    # Map taxid -> name, rank, kingdom
    nr_df['name'] = nr_df['taxid'].map(taxid2name)
    nr_df['rank'] = nr_df['taxid'].map(taxid2rank)
    nr_df['kingdom'] = nr_df['taxid'].apply(taxid2kingdom)
    nr_df['db'] = 'nr'  # Since this is the NR dataset

    return nr_df

def load_blast_output(files, database='nr'):
    """
    Load BLAST output files into a single DataFrame.

    Parameters:
        files (list): List of file paths to BLAST output files.
        database (str): Type of database ('nr').

    Returns:
        pd.DataFrame: Concatenated BLAST results.
    """
    def get_sample_from_path(path):
        filename = path.rsplit('/', 1)[-1]  # Extract just the filename
        prefix = "LCA_"
        suffix = "_modified_diamond_blast_nr_annotated.m9"
        # Remove the known prefix
        if filename.startswith(prefix):
            filename = filename[len(prefix):]
        # Remove the known suffix
        if filename.endswith(suffix):
            filename = filename[:-len(suffix)]
        return filename

    if database == 'nr':
        column_names = nr_column_names
    else:
        logging.error(f"Unknown database type: {database}. Expected 'nr'.")
        raise ValueError(f"Unknown database type: {database}. Expected 'nr'.")

    dfs = []
    for file in files:
        try:
            # Quick check of columns in the first line
            with open(file, 'r') as f:
                first_line = f.readline()
                columns_in_file = len(first_line.strip().split('\t'))
                expected_columns = len(column_names)
                if columns_in_file != expected_columns:
                    logging.warning(f"{file} has {columns_in_file} columns; expected {expected_columns}.")

            # Determine if the first line contains headers
            with open(file, 'r') as f:
                first_line = f.readline()
                if 'sci_name' in first_line.lower() or 'lca_taxid' in first_line.lower():
                    skiprows = 1
                else:
                    skiprows = 0

            df = pd.read_csv(
                file,
                sep='\t',
                header=None,
                names=column_names,
                na_values=['', 'NA', 'nan', 'Unknown'],
                skiprows=skiprows,
                on_bad_lines='warn'  # or 'error' if you want it strict
            )

            # Rename 'tax_id' -> 'taxid' for consistency
            if 'tax_id' in df.columns:
                df.rename(columns={'tax_id': 'taxid'}, inplace=True)
            else:
                logging.error(f"'tax_id' column is missing in {file}.")
                raise KeyError(f"'tax_id' column is missing in {file}.")

            # Replace 'Unknown' with '0' in taxid
            df['taxid'] = df['taxid'].replace('Unknown', '0')

            # Convert taxid to numeric
            df['taxid'] = pd.to_numeric(df['taxid'], errors='coerce')

            # Drop rows with NaN in 'taxid'
            before_drop = len(df)
            df = df.dropna(subset=['taxid'])
            after_drop = len(df)
            dropped = before_drop - after_drop
            if dropped > 0:
                logging.info(f"Dropped {dropped} rows with invalid taxid in {file}.")

            # Derive sample name
            df['sample'] = get_sample_from_path(file)

            # Extract contig
            df['contig'] = df.apply(lambda row: extract_contig(row['query'], row['sample']), axis=1)

            # alignment_length -> align_length
            df.rename(columns={'alignment_length': 'align_length'}, inplace=True)

            # Build contig_key
            df['contig_key'] = (df['sample'] + '~' + df['contig']).str.lower()

            # Convert numeric columns
            numeric_columns = ['identity', 'align_length', 'bit_score', 'evalue']
            for col in numeric_columns:
                if col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')

            # Drop rows with NaN in numeric fields (optional)
            before_drop_numeric = len(df)
            df = df.dropna(subset=numeric_columns)
            dropped_numeric = before_drop_numeric - len(df)
            if dropped_numeric > 0:
                logging.info(f"Dropped {dropped_numeric} rows with non-numeric in {file}.")

            dfs.append(df)

        except pd.errors.ParserError as e:
            logging.error(f"ParserError while reading {file}: {e}")
            print(f"ParserError while reading {file}: {e}", file=sys.stderr)
        except KeyError as e:
            logging.error(f"KeyError while processing {file}: {e}")
            print(f"KeyError while processing {file}: {e}", file=sys.stderr)
            raise e
        except Exception as e:
            logging.error(f"Unexpected error while reading {file}: {e}")
            print(f"Unexpected error while reading {file}: {e}", file=sys.stderr)
            raise e

    if dfs:
        combined_df = pd.concat(dfs, ignore_index=True)
    else:
        combined_df = pd.DataFrame()
        logging.warning("No DataFrames were created; returning empty DataFrame.")

    return combined_df

def extract_contig(query, sample):
    """
    Extract contig information by removing the sample prefix from the query.

    Parameters:
        query (str): The query string from BLAST output.
        sample (str): The sample name.

    Returns:
        str: Extracted contig information or 'unknown_contig'.
    """
    prefix = sample + '_'
    if query.startswith(prefix):
        return query[len(prefix):]
    else:
        logging.warning(f"Query '{query}' does not start with sample prefix '{prefix}'. Using 'unknown_contig'.")
        return 'unknown_contig'

# -------------------------------------------
# The rest of your functions remain unchanged
# -------------------------------------------


def load_cdhit_clusters(filename):
    """
    Load CD-HIT cluster information from a cluster file.

    Parameters:
        filename (str): Path to the CD-HIT cluster file.

    Returns:
        defaultdict: Cluster ID mapped to list of members.
    """
    clusters = defaultdict(list)
    cluster_id = None
    with open(filename, 'r') as file:
        for line in file:
            if line.startswith('>Cluster'):
                cluster_id = line.strip().split()[-1]
            else:
                if cluster_id is not None:
                    try:
                        member = parse_cdhit_row(line)
                        clusters[cluster_id].append(member)
                    except ValueError as e:
                        logging.error(f"Error parsing line: {line.strip()} - {e}")
    return clusters

def parse_cdhit_row(row):
    """
    Parse a single line from a CD-HIT cluster file.

    Parameters:
        row (str): Line from CD-HIT cluster file.

    Returns:
        dict: Parsed member information.
    """
    if '*' in row:
        # Reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = True
    else:
        # Non-reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = False

    # Clean and extract fields
    # Use regex to extract numeric length
    match = re.match(r'(\d+)', length_str)
    if match:
        length = int(match.group(1))
    else:
        logging.warning(f"Could not extract numeric length from '{length_str}' in row: {row}")
        length = 0  # Assign a default value or handle as needed

    name = name.strip('>').strip('.')

    # Extract sample and contig
    try:
        sample, contig = name.split('~')
    except ValueError:
        # Handle unexpected naming format
        sample, contig = 'unknown_sample', name  # Assign a default sample name

    return {
        'contig': contig,
        'sample': sample,
        'length': length,
        'is_ref': is_ref
    }

def get_cluster_rep(cluster):
    """
    Get the representative member of a cluster (the reference member).

    Parameters:
        cluster (list): List of cluster members.

    Returns:
        dict: Representative member.
    """
    for member in cluster:
        if member['is_ref']:
            return member
    # If no reference member, return the first member
    return cluster[0] if cluster else None

def merge_clusters_lca(clusters, lca_df):
    """
    Merge cluster information with LCA results.

    Parameters:
        clusters (defaultdict): Cluster information.
        lca_df (pd.DataFrame): LCA DataFrame.

    Returns:
        pd.DataFrame: Merged cluster and LCA information.
    """
    # Prepare cluster data
    cluster_ids = [int(id.strip()) for id in clusters.keys()]
    cluster_reps = [get_cluster_rep(clusters[str(id)]) for id in cluster_ids]
    cluster_lengths = [rep['length'] for rep in cluster_reps]
    cluster_sizes = [len(clusters[str(id)]) for id in cluster_ids]
    # Correct contig_key formatting using '~'
    cluster_contig_keys = [f"{rep['sample']}~{rep['contig']}".lower() for rep in cluster_reps]

    # Create DataFrame for clusters
    df = pd.DataFrame({
        'contig_key': cluster_contig_keys,
        'cluster': cluster_ids,
        'cluster_size': cluster_sizes,
        'contig_length': cluster_lengths
    })

    # Explicitly convert to string, then strip and lowercase
    df['contig_key'] = df['contig_key'].astype(str).str.strip().str.lower()

    # Debug: Log sample 'contig_key's
    logging.info("Sample 'contig_key's in clusters DataFrame:")
    logging.info(df['contig_key'].head())

    # Merge with LCA DataFrame
    merged_df = df.merge(lca_df, on='contig_key', how='inner')

    # Debug: Log the number of merged rows
    logging.info(f"Merged DataFrame has {len(merged_df)} rows.")

    if merged_df.empty:
        logging.warning("merged_df is empty after merging. Check 'contig_key' consistency.")
        print("merged_df is empty after merging. Please check 'contig_key' consistency between clusters and LCA DataFrames.", file=sys.stderr)
    else:
        logging.info("Successfully merged clusters with LCA results.")

    # Calculate align_percent
    merged_df['align_percent'] = np.round(100 * merged_df['align_length'] / merged_df['contig_length'], 1)

    # Use 'identity' as 'percent_identity'
    merged_df['percent_identity'] = np.round(merged_df['identity'], 1)

    # Define required columns (including additional columns)
    required_columns = [
        'cluster', 'cluster_size', 'contig_length', 'name', 'rank', 
        'kingdom', 'taxid', 'db', 'align_length', 'align_percent', 
        'percent_identity', 'bit_score', 'evalue', 'subject', 'sci_name',
        'com_names', 'superkingdom', 'phylum', 'family', 'lca_taxid', 'lca_name',
        'contig', 'sample', 'contig_key'
    ]

    # Check for missing columns
    missing_columns = [col for col in required_columns if col not in merged_df.columns]
    if missing_columns:
        logging.error(f"Missing columns in merged DataFrame: {missing_columns}")
        logging.info(f"Available columns: {merged_df.columns.tolist()}")
        raise KeyError(f"{missing_columns} not in index")

    # Subset the DataFrame to include only required columns
    merged_df = merged_df[required_columns]

    # Sort by cluster
    merged_df = merged_df.sort_values('cluster').reset_index(drop=True)

    return merged_df

def load_counts_data(counts_file):
    """
    Load and process the counts DataFrame.

    Parameters:
        counts_file (str): Path to the counts file.

    Returns:
        pd.DataFrame: Processed counts DataFrame with 'contig_key'.
    """
    try:
        # Read counts file with header, assuming first row is header
        counts = pd.read_csv(
            counts_file, 
            sep=',', 
            header=0,  # Assumes first row is header
            usecols=[0, 1, 2],  # Read 'contig', 'count', 'length'
            names=['contig', 'count', 'length'],  # Assign column names
            dtype={'contig': str, 'count': str, 'length': str},  # Read as strings initially
            na_values=['', 'NA', 'nan', 'Unknown'],
            low_memory=False
        )
        
        # Log a sample of the counts DataFrame
        logging.info("Raw Counts DataFrame sample:")
        logging.info(counts.head())
        
        # Convert 'count' to numeric, coercing errors to NaN
        counts['count'] = pd.to_numeric(counts['count'], errors='coerce')
        
        # Log the conversion result
        logging.info("Counts DataFrame after converting 'count' to numeric:")
        logging.info(counts.head())
        
        # Extract sample and contig from 'contig' column assuming 'sample~contig'
        counts[['sample', 'contig']] = counts['contig'].str.split('~', expand=True)
        
        # Handle cases where 'contig' does not contain '~'
        counts['contig'] = counts['contig'].fillna(counts['contig'])
        counts['sample'] = counts['sample'].fillna('unknown_sample')
        
        # Create 'contig_key' as 'sample~contig', lowercase
        counts['contig_key'] = (counts['sample'].astype(str) + '~' + counts['contig'].astype(str)).str.strip().str.lower()
        
        # Debug: Log the processed counts DataFrame
        logging.info("Processed Counts DataFrame:")
        logging.info(counts.head())
        
        return counts
    except pd.errors.ParserError as e:
        logging.error(f"ParserError while reading counts file {counts_file}: {e}")
        print(f"ParserError while reading counts file {counts_file}: {e}", file=sys.stderr)
        raise e
    except Exception as e:
        logging.error(f"Unexpected error while processing counts file {counts_file}: {e}")
        print(f"Unexpected error while processing counts file {counts_file}: {e}", file=sys.stderr)
        raise e

def load_taxid2name(lca_df):
    """
    Create a taxid to name mapping from the LCA DataFrame.

    Parameters:
        lca_df (pd.DataFrame): LCA DataFrame.

    Returns:
        dict: Mapping from taxid to scientific name.
    """
    taxid2name = dict(zip(lca_df['taxid'], lca_df['name']))
    taxid2name[-1] = 'Ambiguous'  # Handle ambiguous taxids
    taxid2name[0] = 'Unknown'     # Handle 'Unknown' taxids
    return taxid2name

# Example usage within util.py (optional)
if __name__ == "__main__":
    try:
        # Load LCA results from NR database
        lca_df = load_lca(nr_dir)
        if lca_df.empty:
            logging.error("LCA DataFrame is empty. Exiting.")
            raise ValueError("LCA DataFrame is empty.")
        logging.info("Loaded LCA DataFrame.")
    except Exception as e:
        logging.error(f"Failed to load LCA DataFrame: {e}")
        raise e  # Let the exception propagate

    try:
        # Load CD-HIT clusters
        clusters = load_cdhit_clusters('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/Combined_filtered_500_SENZOR_clsuter.clstr')
        logging.info("Loaded CD-HIT clusters.")
    except Exception as e:
        logging.error(f"Failed to load CD-HIT clusters: {e}")
        raise e  # Let the exception propagate

    try:
        # Merge clusters with LCA
        merged_df = merge_clusters_lca(clusters, lca_df)
        logging.info("Successfully merged clusters with LCA results.")
    except KeyError as e:
        logging.error(f"KeyError during merging: {e}")
        print(f"Error during merging: {e}", file=sys.stderr)
        raise e  # Let the exception propagate
    except Exception as e:
        logging.error(f"Unexpected error during merging: {e}")
        print(f"Unexpected error during merging: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Create clust2kingdom dictionary
        clust2kingdom = dict(zip(merged_df['cluster'], merged_df['kingdom']))
        logging.info("Created clust2kingdom dictionary.")

        # Save clust2kingdom to CSV
        merged_df[['cluster', 'kingdom']].to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/cluster_kingdoms.csv', index=False)
        logging.info("Saved cluster_kingdoms.csv successfully.")

        # Print sample entries of clust2kingdom
        print("clust2kingdom Dictionary Sample Entries:")
        for cluster_id, kingdom in list(clust2kingdom.items())[:5]:
            print(f"Cluster {cluster_id}: Kingdom {kingdom}")
    except Exception as e:
        logging.error(f"Error creating clust2kingdom: {e}")
        print(f"Error creating clust2kingdom: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Load and process the counts DataFrame
        counts_file_path = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/combined_SENSOR_csvs/1Modified_CSV/merged_contig_stats.csv'
        counts_df = load_counts_data(counts_file_path)
        logging.info("Loaded and processed counts DataFrame.")
    except Exception as e:
        logging.error(f"Failed to load counts DataFrame: {e}")
        print(f"Error loading counts DataFrame: {e}", file=sys.stderr)
        raise e  # Let the exception propagate

    try:
        # Merge counts with merged_df on 'contig_key'
        final_df = merged_df.merge(counts_df[['contig_key', 'count']], on='contig_key', how='left')
        logging.info("Merged counts with merged_df to create final_df.")

        # Convert 'count' to numeric if not already (redundant but safe)
        final_df['count'] = pd.to_numeric(final_df['count'], errors='coerce')

        # Apply filtering: Remove rows where 'count' is NaN
        initial_count = len(final_df)
        final_df = final_df.dropna(subset=['count'])
        final_count = len(final_df)
        removed_rows = initial_count - final_count
        if removed_rows > 0:
            logging.info(f"Filtered out {removed_rows} rows without 'count' value.")
            print(f"Filtered out {removed_rows} rows without 'count' value.")

        # Remove rows where 'count' is zero or negative:
        logging.info("Filtered out rows with 'count' <= 0.")

        # Verify if final_df is empty after filtering
        if final_df.empty:
            logging.warning("final_df is empty after filtering out rows without 'count'.")
            print("final_df is empty after filtering out rows without 'count'.", file=sys.stderr)
        else:
            # Display the final merged DataFrame with counts
            print("Final Merged DataFrame with Counts (after filtering):")
            print(final_df.head())

            # Save the final DataFrame to CSV
            final_df.to_csv('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/1final_merged_with_counts.csv', index=False)
            logging.info("Saved final_merged_with_counts.csv successfully.")
    except Exception as e:
        logging.error(f"Error merging counts with merged_df: {e}")
        print(f"Error merging counts with merged_df: {e}", file=sys.stderr)
        raise e  # Let the exception propagate



clust2kingdom Dictionary Sample Entries:
Cluster 0: Kingdom Bacteria
Cluster 1: Kingdom Viruses
Cluster 2: Kingdom Bacteria
Cluster 3: Kingdom Bacteria
Cluster 4: Kingdom Bacteria
Filtered out 1735 rows without 'count' value.
Final Merged DataFrame with Counts (after filtering):
   cluster  cluster_size  contig_length                              name  \
0        0             1          64751                Duncaniella sp. C9   
1        1             1          37429  Acinetobacter phage Aristophanes   
2        2             3          36456              Cutibacterium avidum   
3        3             1          35911          uncultured Treponema sp.   
4        4             1          34801          uncultured Treponema sp.   

      rank   kingdom    taxid  db  align_length  align_percent  ...  \
0  species  Bacteria  2530392  nr          2662            4.1  ...   
1  species   Viruses  2759203  nr           784            2.1  ...   
2  species  Bacteria    33010  nr          1

In [1]:
from Bio import SeqIO

def filter_fasta(input_file, output_file, min_length=300):
    """Filters a FASTA file, keeping only sequences longer than the specified minimum length.

    Args:
        input_file (str): Path to the input FASTA file.
        output_file (str): Path to the output FASTA file.
        min_length (int, optional): Minimum sequence length. Defaults to 500.
    """

    with open(output_file, 'w') as outfile:
        for record in SeqIO.parse(input_file, 'fasta'):
            if len(record.seq) >= min_length:
                SeqIO.write(record, outfile, 'fasta')

# Example usage:
input_file = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/All_contigs_combined.fasta"
output_file = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/All_contigs_combined.fa"
filter_fasta(input_file, output_file)

In [31]:
import pandas as pd

# Specify the path to your input and output files
input_file = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/blast_with_clusters.tsv'  # Replace with your input file path
output_file = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/500_contig_stats_filtered_viruses.csv'  # The file where viruses will be saved

try:
    # Load the data into a pandas DataFrame with the correct separator
    df = pd.read_csv(input_file, sep='\t')
    print(f"Successfully loaded data from {input_file}. Total records: {len(df)}")
    
    # Check if 'taxon_group' column exists
    if 'taxon_group' not in df.columns:
        raise KeyError("'taxon_group' column not found in the input file.")
    
    # Check if 'contig_length' column exists
    if 'contig_length' not in df.columns:
        raise KeyError("'contig_length' column not found in the input file.")
    
    # Ensure 'contig_length' is numeric
    if not pd.api.types.is_numeric_dtype(df['contig_length']):
        df['contig_length'] = pd.to_numeric(df['contig_length'], errors='coerce')
        num_non_numeric = df['contig_length'].isna().sum()
        if num_non_numeric > 0:
            print(f"Warning: {num_non_numeric} records have non-numeric 'contig_length' and were set to NaN.")
    
    # Filter rows where 'taxon_group' contains 'Viruses' (case-insensitive)
    virus_df = df[df['taxon_group'].str.contains('Viruses', case=False, na=False)]
    print(f"Filtered records containing 'Viruses'. Total virus records: {len(virus_df)}")
    
    # Further filter rows where 'contig_length' is >= 500
    filtered_virus_df = virus_df[virus_df['contig_length'] >= 500]
    print(f"After applying 'contig_length' >= 500, remaining records: {len(filtered_virus_df)}")
    
    # Save the filtered DataFrame to a new CSV file
    filtered_virus_df.to_csv(output_file, index=False)
    print(f"Filtered data saved to {output_file} successfully.")
    
except FileNotFoundError:
    print(f"Error: The file {input_file} was not found. Please check the file path.")
except pd.errors.ParserError:
    print(f"Error: Could not parse the file {input_file}. Please ensure it's a valid TSV file.")
except KeyError as e:
    print(f"Error: {e}")
except ValueError as e:
    print(f"Value Error: {e}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

Successfully loaded data from /Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/contig_stats_filtered_lca.tsv. Total records: 721594
Filtered records containing 'Viruses'. Total virus records: 38974
After applying 'contig_length' >= 500, remaining records: 3049
Filtered data saved to /Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/500_contig_stats_filtered_viruses.csv successfully.


In [5]:
from Bio import SeqIO

def filter_fasta(input_file, output_file, min_length=500):
    """Filters a FASTA file, keeping only sequences longer than the specified minimum length.

    Args:
        input_file (str): Path to the input FASTA file.
        output_file (str): Path to the output FASTA file.
        min_length (int, optional): Minimum sequence length to keep. Defaults to 500.
    """

    with open(output_file, 'w') as output_handle:
        for record in SeqIO.parse(input_file, 'fasta'):
            if len(record.seq) >= min_length:
                SeqIO.write(record, output_handle, 'fasta')

# Example usage:
input_file = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/Non_rodents_500.fasta'
output_file = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/Non_Rodents_500_cluster_filter.fasta'
filter_fasta(input_file, output_file)

In [39]:
import pandas as pd

def merge_txt_files(file1, file2, output_file):
    """Merges two text files with identical headers.

    Args:
        file1 (str): Path to the first text file.
        file2 (str): Path to the second text file.
        output_file (str): Path to the output merged file.
    """

    df1 = pd.read_csv(file1, sep='\t')
    df2 = pd.read_csv(file2, sep='\t')

    # Merge the DataFrames, assuming 'contig_name' is the unique identifier
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Write the merged DataFrame to a new text file
    merged_df.to_csv(output_file, sep='\t', index=False)

# Example usage:
file1 = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/Non_rodents_bowtie_csp_1000.txt"
file2 = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/merged_SENSOR_bowtie_csp_1000.txt"
output_file = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Cluster_contigs/1merged_SENSOR_bowtie_csp_1000.txt"
merge_txt_files(file1, file2, output_file)

In [49]:
import pandas as pd
import numpy as np
from glob import glob
from collections import defaultdict
from statistics import mode
from Bio import Entrez
from ete3 import NCBITaxa
import logging
import sys
import re
from logging.handlers import RotatingFileHandler

handler = RotatingFileHandler('util.log', maxBytes=5*1024*1024, backupCount=2)
logging.basicConfig(
    level=logging.INFO,
    handlers=[handler],
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Set Entrez email
Entrez.email = "ifeanyi.omah@ed.ac.uk"

# Initialize NCBI Taxonomy
ncbi = NCBITaxa()
# Optionally update taxonomy database
# ncbi.update_taxonomy_database()

valid_lineages = set()

# Define column names for NR files (updated to match actual .m9 file columns)
nr_column_names = [
    'query', 'subject', 'identity', 'alignment_length', 'mismatches', 'gap_opens',
    'q_start', 'q_end', 's_start', 's_end', 'evalue', 'bit_score', 'tax_id',
    'sci_name', 'com_names', 'superkingdom', 'phylum', 'family',
    'lca_taxid', 'lca_name'
]

# Define NR directory
nr_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA/"


def load_lca(nr_directory):
    """
    Load and process LCA results from NR files in the specified directory.

    Parameters:
        nr_directory (str): Path to the directory containing NR LCA files.

    Returns:
        pd.DataFrame: Processed LCA DataFrame.
    """
    # Dynamically list all NR .m9 files that start with 'LCA_'
    nr_file_paths = glob(f"{nr_directory}/*.m9")
    logging.info(f"Found {len(nr_file_paths)} NR files to process.")

    if not nr_file_paths:
        logging.error(f"No NR files found in directory: {nr_directory}")
        raise FileNotFoundError(f"No NR files found in directory: {nr_directory}")

    # Load BLAST outputs for NR database
    nr_df = load_blast_output(nr_file_paths, database='nr')

    # Remove rows with missing taxid
    initial_nr = len(nr_df)
    nr_df = nr_df.dropna(subset=['taxid'])
    final_nr = len(nr_df)
    dropped_nr = initial_nr - final_nr
    if dropped_nr > 0:
        logging.info(f"Dropped {dropped_nr} rows with missing taxid from nr_df.")

    # Convert taxid to integer
    nr_df['taxid'] = nr_df['taxid'].astype(int)

    # Retrieve all relevant taxids (including lineage)
    taxids = set(nr_df['taxid'])
    expanded_lineage = set()
    for taxid in taxids:
        try:
            lineage = ncbi.get_lineage(taxid)
            expanded_lineage.update(lineage)
        except ValueError:
            logging.warning(f"Skipping {taxid}: taxid not found in local DB.")
    # Remove zero/invalid
    expanded_lineage.discard(0)

    # Build dictionaries for name/rank
    taxid2name = ncbi.get_taxid_translator(expanded_lineage)
    taxid2name[-1] = 'Ambiguous'
    taxid2name[0] = 'Unknown'
    taxid2rank = ncbi.get_rank(expanded_lineage)

    def taxid2kingdom(tid):
        """
        Retrieve the kingdom for a given taxid.
        """
        if tid == 0:
            return 'Unknown'
        try:
            lineage = ncbi.get_lineage(tid)
            ranks = ncbi.get_rank(lineage)
            for k, v in ranks.items():
                if v == 'superkingdom':
                    # Map taxid -> name
                    kingdom_name = taxid2name.get(k, 'Unknown')
                    if kingdom_name != 'Unknown' and tid != k:
                        logging.warning(f"taxid {tid} was translated into {kingdom_name}")
                    return kingdom_name
        except ValueError:
            return 'Unknown'
        return 'Unknown'

    # Map taxid -> name, rank, kingdom
    nr_df['name'] = nr_df['taxid'].map(taxid2name)
    nr_df['rank'] = nr_df['taxid'].map(taxid2rank)
    nr_df['kingdom'] = nr_df['taxid'].apply(taxid2kingdom)
    nr_df['db'] = 'nr'  # Tag as NR data

    return nr_df


def load_blast_output(files, database='nr'):
    """
    Load BLAST output files into a single DataFrame.

    Parameters:
        files (list): List of file paths to BLAST output files.
        database (str): Type of database ('nr').

    Returns:
        pd.DataFrame: Concatenated BLAST results.
    """

    def get_sample_from_path(path):
        """
        Extract sample name from filename using regex:
        LCA_SOMESAMPLE_modified_diamond_blast_nr_annotated.m9
        """
        filename = path.split('/')[-1]
        # Regex capturing portion after "LCA_" and before "_modified..."
        match = re.match(r'^LCA_(.+?)_modified_diamond_blast_nr_annotated\.m9$', filename)
        if match:
            return match.group(1)  # Everything in the capturing group
        else:
            logging.warning(f"Could not extract sample name from filename: {filename}")
            return "unknown_sample"

    # Define column names based on database
    if database == 'nr':
        column_names = nr_column_names
    else:
        logging.error(f"Unknown database type: {database}. Expected 'nr'.")
        raise ValueError(f"Unknown database type: {database}. Expected 'nr'.")

    dfs = []

    for file in files:
        try:
            # Quick check of columns in the first line
            with open(file, 'r') as f:
                first_line = f.readline()
                columns_in_file = len(first_line.strip().split('\t'))
                expected_columns = len(column_names)
                if columns_in_file != expected_columns:
                    logging.warning(
                        f"{file} has {columns_in_file} columns; "
                        f"expected {expected_columns}."
                    )

            # Determine if the file has a header row
            with open(file, 'r') as f:
                first_line = f.readline()
                if 'sci_name' in first_line.lower() or 'lca_taxid' in first_line.lower():
                    skiprows = 1
                else:
                    skiprows = 0

            df = pd.read_csv(
                file,
                sep='\t',
                header=None,
                names=column_names,
                na_values=['', 'NA', 'nan', 'Unknown'],
                skiprows=skiprows,
                on_bad_lines='warn'  # or 'error' if you want strict
            )

            # Debug: Log the first few rows
            logging.info(f"First few rows of {file} after reading:")
            logging.info(df.head())

            # Rename 'tax_id' to 'taxid' if present
            if 'tax_id' in df.columns:
                df.rename(columns={'tax_id': 'taxid'}, inplace=True)
            elif 'taxid' not in df.columns:
                logging.error(f"'tax_id' or 'taxid' column not found in {file}")
                raise KeyError(f"'tax_id' or 'taxid' column not found in {file}.")

            # Replace 'Unknown' with '0'
            df['taxid'] = df['taxid'].replace('Unknown', '0')

            # Convert 'taxid' to numeric
            df['taxid'] = pd.to_numeric(df['taxid'], errors='coerce')

            # Drop rows with NaN in 'taxid'
            before_drop = len(df)
            df = df.dropna(subset=['taxid'])
            after_drop = len(df)
            dropped = before_drop - after_drop
            if dropped > 0:
                logging.info(f"Dropped {dropped} rows with invalid taxid in {file}.")

            # Derive sample name from the filename
            df['sample'] = get_sample_from_path(file)

            # Extract contig from 'query' by removing sample prefix
            df['contig'] = df.apply(
                lambda row: extract_contig(row['query'], row['sample']),
                axis=1
            )

            # alignment_length -> align_length
            df.rename(columns={'alignment_length': 'align_length'}, inplace=True)

            # Build contig_key
            df['contig_key'] = (df['sample'] + '~' + df['contig']).str.lower()

            # Debug: Log the created 'contig_key'
            logging.info(f"Created 'contig_key' for {file}:")
            logging.info(df[['contig_key']].head())

            # Convert numeric columns
            numeric_columns = ['identity', 'align_length', 'bit_score', 'evalue']
            for col in numeric_columns:
                if col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                else:
                    logging.warning(f"Column '{col}' not found in {file}.")

            # Optionally drop rows where numeric columns are NaN
            before_num_drop = len(df)
            df = df.dropna(subset=numeric_columns)
            after_num_drop = len(df)
            dropped_num = before_num_drop - after_num_drop
            if dropped_num > 0:
                logging.info(f"Dropped {dropped_num} rows with non-numeric values in {file}.")

            dfs.append(df)
        except pd.errors.ParserError as e:
            logging.error(f"ParserError while reading {file}: {e}")
            print(f"ParserError while reading {file}: {e}", file=sys.stderr)
        except KeyError as e:
            logging.error(f"KeyError while processing {file}: {e}")
            print(f"KeyError while processing {file}: {e}", file=sys.stderr)
            raise e
        except Exception as e:
            logging.error(f"Unexpected error while reading {file}: {e}")
            print(f"Unexpected error while reading {file}: {e}", file=sys.stderr)
            raise e

    if dfs:
        combined_df = pd.concat(dfs, ignore_index=True)
    else:
        combined_df = pd.DataFrame()
        logging.warning("No dataframes to concatenate. Returning empty DataFrame.")

    return combined_df


def extract_contig(query, sample):
    """
    Extract contig information by removing the sample prefix from the query.
    If the query doesn't start with 'sample_', label it 'unknown_contig'.
    """
    prefix = sample + '_'
    if query.startswith(prefix):
        return query[len(prefix):]
    else:
        logging.warning(
            f"Query '{query}' does not start with sample prefix '{prefix}'. "
            f"Assigning 'unknown_contig'."
        )
        return 'unknown_contig'


def load_cdhit_clusters(filename):
    """
    Load CD-HIT cluster information from a cluster file.
    """
    clusters = defaultdict(list)
    cluster_id = None

    with open(filename, 'r') as file:
        for line in file:
            if line.startswith('>Cluster'):
                cluster_id = line.strip().split()[-1]
            else:
                if cluster_id is not None:
                    try:
                        member = parse_cdhit_row(line)
                        clusters[cluster_id].append(member)
                    except ValueError as e:
                        logging.error(f"Error parsing line: {line.strip()} - {e}")
    return clusters


def parse_cdhit_row(row):
    """
    Parse a single line from a CD-HIT cluster file.
    """
    if '*' in row:
        # Reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = True
    else:
        # Non-reference member
        parts = row.strip().split()
        if len(parts) < 4:
            raise ValueError(f"Unexpected format in row: {row}")
        index, length_str, name, percent_id = parts[:4]
        is_ref = False

    # Extract numeric length
    match_len = re.match(r'(\d+)', length_str)
    if match_len:
        length = int(match_len.group(1))
    else:
        logging.warning(f"Could not extract numeric length from '{length_str}' in row: {row}")
        length = 0

    # Cleanup name
    name = name.strip('>').strip('.')

    # Extract sample, contig
    try:
        sample, contig = name.split('~')
    except ValueError:
        # Unexpected naming
        sample, contig = 'unknown_sample', name

    return {
        'contig': contig,
        'sample': sample,
        'length': length,
        'is_ref': is_ref
    }


def get_cluster_rep(cluster):
    """
    Get the representative member of a cluster (the reference member).
    """
    for member in cluster:
        if member['is_ref']:
            return member
    # If no reference member, just return the first
    return cluster[0] if cluster else None


def merge_clusters_lca(clusters, lca_df):
    """
    Merge cluster information with LCA results.
    """
    cluster_ids = [int(cid.strip()) for cid in clusters.keys()]
    cluster_reps = [get_cluster_rep(clusters[str(cid)]) for cid in cluster_ids]
    cluster_lengths = [rep['length'] for rep in cluster_reps]
    cluster_sizes = [len(clusters[str(cid)]) for cid in cluster_ids]
    # Build contig_key
    cluster_contig_keys = [f"{rep['sample']}~{rep['contig']}".lower() for rep in cluster_reps]

    df_clusters = pd.DataFrame({
        'contig_key': cluster_contig_keys,
        'cluster': cluster_ids,
        'cluster_size': cluster_sizes,
        'contig_length': cluster_lengths
    })

    # Lowercase & strip
    df_clusters['contig_key'] = df_clusters['contig_key'].str.strip().str.lower()

    # Debug
    logging.info("Sample 'contig_key's in clusters DataFrame:")
    logging.info(df_clusters['contig_key'].head())

    # Merge
    merged_df = df_clusters.merge(lca_df, on='contig_key', how='inner')
    logging.info(f"Merged DataFrame has {len(merged_df)} rows.")

    if merged_df.empty:
        logging.warning("merged_df is empty after merging. Check 'contig_key' consistency.")
        print(
            "merged_df is empty after merging. "
            "Please check 'contig_key' consistency between clusters and LCA DataFrames.",
            file=sys.stderr
        )
    else:
        logging.info("Successfully merged clusters with LCA results.")

    # Add align_percent
    merged_df['align_percent'] = np.round(
        100 * merged_df['align_length'] / merged_df['contig_length'], 1
    )
    # Use 'identity' as 'percent_identity'
    merged_df['percent_identity'] = np.round(merged_df['identity'], 1)

    # Required columns
    required_columns = [
        'cluster', 'cluster_size', 'contig_length', 'name', 'rank',
        'kingdom', 'taxid', 'db', 'align_length', 'align_percent',
        'percent_identity', 'bit_score', 'evalue', 'subject', 'sci_name',
        'com_names', 'superkingdom', 'phylum', 'family', 'lca_taxid', 'lca_name',
        'contig', 'sample', 'contig_key'
    ]

    missing_columns = [c for c in required_columns if c not in merged_df.columns]
    if missing_columns:
        logging.error(f"Missing columns in merged DataFrame: {missing_columns}")
        logging.info(f"Available columns: {merged_df.columns.tolist()}")
        raise KeyError(f"{missing_columns} not in index")

    # Subset only required columns
    merged_df = merged_df[required_columns]
    merged_df = merged_df.sort_values('cluster').reset_index(drop=True)

    return merged_df


def load_counts_data(counts_file):
    """
    Load and process the counts DataFrame.
    """
    try:
        counts = pd.read_csv(
            counts_file,
            sep=',',
            header=0,
            usecols=[0, 1, 2],
            names=['contig', 'count', 'length'],
            dtype={'contig': str, 'count': str, 'length': str},
            na_values=['', 'NA', 'nan', 'Unknown'],
            low_memory=False
        )

        # Debug
        logging.info("Raw Counts DataFrame sample:")
        logging.info(counts.head())

        counts['count'] = pd.to_numeric(counts['count'], errors='coerce')

        logging.info("Counts DataFrame after converting 'count' to numeric:")
        logging.info(counts.head())

        # Extract sample, contig from 'contig'
        counts[['sample', 'contig']] = counts['contig'].str.split('~', expand=True)
        counts['contig'] = counts['contig'].fillna(counts['contig'])
        counts['sample'] = counts['sample'].fillna('unknown_sample')

        # Build contig_key
        counts['contig_key'] = (
            counts['sample'] + '~' + counts['contig']
        ).str.strip().str.lower()

        # Debug
        logging.info("Processed Counts DataFrame:")
        logging.info(counts.head())

        return counts
    except pd.errors.ParserError as e:
        logging.error(f"ParserError while reading counts file {counts_file}: {e}")
        print(f"ParserError while reading counts file {counts_file}: {e}", file=sys.stderr)
        raise e
    except Exception as e:
        logging.error(f"Unexpected error while processing counts file {counts_file}: {e}")
        print(f"Unexpected error while processing counts file {counts_file}: {e}", file=sys.stderr)
        raise e


def load_taxid2name(lca_df):
    """
    Create a taxid to name mapping from the LCA DataFrame.
    """
    taxid2name = dict(zip(lca_df['taxid'], lca_df['name']))
    taxid2name[-1] = 'Ambiguous'
    taxid2name[0] = 'Unknown'
    return taxid2name


#
# Main execution (Example usage)
#
if __name__ == "__main__":
    try:
        # 1) Load LCA results
        lca_df = load_lca(nr_dir)
        if lca_df.empty:
            logging.error("LCA DataFrame is empty. Exiting.")
            raise ValueError("LCA DataFrame is empty.")
        logging.info("Loaded LCA DataFrame.")

        # Debugging step: check LCA keys
        unique_lca_keys = set(lca_df['contig_key'])
        print("Number of unique LCA contig_keys:", len(unique_lca_keys))

    except Exception as e:
        logging.error(f"Failed to load LCA DataFrame: {e}")
        print(f"Error loading LCA DataFrame: {e}")
        sys.exit(1)

    try:
        # 2) Load CD-HIT clusters
        clusters = load_cdhit_clusters('/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/Combined_filtered_500_SENZOR_clsuter.clstr'
        )
        logging.info("Loaded CD-HIT clusters.")

        # Debug: convert clusters -> DataFrame just to see how many contig_keys it has
        cluster_ids = clusters.keys()
        cluster_reps = [get_cluster_rep(clusters[cid]) for cid in cluster_ids]
        df_clusters = pd.DataFrame({
            'sample': [rep['sample'] for rep in cluster_reps],
            'contig': [rep['contig'] for rep in cluster_reps]
        })
        df_clusters['contig_key'] = (
            df_clusters['sample'] + '~' + df_clusters['contig']
        ).str.lower()

        unique_cluster_keys = set(df_clusters['contig_key'])
        print("Number of unique cluster contig_keys:", len(unique_cluster_keys))

        # Check what's missing from LCA but not in cluster
        missing_in_clusters = unique_lca_keys.difference(unique_cluster_keys)
        print(f"Number of LCA contig_keys not found in clusters: {len(missing_in_clusters)}")
        if missing_in_clusters:
            print("Sample missing contig_keys (cluster):", list(missing_in_clusters)[:10], "...")

    except Exception as e:
        logging.error(f"Failed to load CD-HIT clusters: {e}")
        print(f"Error loading CD-HIT clusters: {e}")
        sys.exit(1)

    try:
        # 3) Merge clusters with LCA
        merged_df = merge_clusters_lca(clusters, lca_df)
        logging.info("Successfully merged clusters with LCA results.")
    except KeyError as e:
        logging.error(f"KeyError during merging: {e}")
        print(f"Error during merging: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        logging.error(f"Unexpected error during merging: {e}")
        print(f"Unexpected error during merging: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        # 4) Create clust2kingdom dictionary
        clust2kingdom = dict(zip(merged_df['cluster'], merged_df['kingdom']))
        logging.info("Created clust2kingdom dictionary.")

        # Save clust2kingdom to CSV
        merged_df[['cluster', 'kingdom']].to_csv(
            '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/cluster_kingdoms.csv',
            index=False
        )
        logging.info("Saved cluster_kingdoms.csv successfully.")

        # Print sample entries of clust2kingdom
        print("clust2kingdom Dictionary Sample Entries:")
        for cluster_id, kingdom in list(clust2kingdom.items())[:5]:
            print(f"Cluster {cluster_id}: Kingdom {kingdom}")
    except Exception as e:
        logging.error(f"Error creating clust2kingdom: {e}")
        print(f"Error creating clust2kingdom: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        # 5) Load and process the counts DataFrame
        counts_file_path = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/combined_SENSOR_csvs/1Modified_CSV/merged_contig_stats.csv'
        counts_df = load_counts_data(counts_file_path)
        logging.info("Loaded and processed counts DataFrame.")

        # Debug: check how many contig_keys in counts
        unique_counts_keys = set(counts_df['contig_key'])
        missing_in_counts = unique_lca_keys.difference(unique_counts_keys)
        print(f"Number of LCA contig_keys not found in counts: {len(missing_in_counts)}")
        if missing_in_counts:
            print("Sample missing contig_keys (counts):", list(missing_in_counts)[:10], "...")

    except Exception as e:
        logging.error(f"Failed to load counts DataFrame: {e}")
        print(f"Error loading counts DataFrame: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        # 6) Merge counts with merged_df on 'contig_key'
        final_df = merged_df.merge(counts_df[['contig_key', 'count']], on='contig_key', how='left')
        logging.info("Merged counts with merged_df to create final_df.")

        # Convert 'count' to numeric if not already
        final_df['count'] = pd.to_numeric(final_df['count'], errors='coerce')

        # Filter: remove rows where 'count' is NaN
        initial_count = len(final_df)
        final_df = final_df.dropna(subset=['count'])
        final_count = len(final_df)
        removed_rows = initial_count - final_count
        if removed_rows > 0:
            logging.info(f"Filtered out {removed_rows} rows without 'count' value.")
            print(f"Filtered out {removed_rows} rows without 'count' value.")

        # Filter out rows where 'count' <= 0
        final_df = final_df[final_df['count'] > 0]
        logging.info("Filtered out rows with 'count' <= 0.")

        if final_df.empty:
            logging.warning("final_df is empty after filtering out rows with no/low 'count'.")
            print("final_df is empty after filtering out rows without a valid 'count'.", file=sys.stderr)
        else:
            # Show final DataFrame
            print("Final Merged DataFrame with Counts (after filtering):")
            print(final_df.head())

            # Save the final DataFrame
            final_df.to_csv(
                '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/Complete_contig_counts.csv',
                index=False
            )
            logging.info("Saved final_merged_with_counts.csv successfully.")

    except Exception as e:
        logging.error(f"Error merging counts with merged_df: {e}")
        print(f"Error merging counts with merged_df: {e}", file=sys.stderr)
        sys.exit(1)



Number of unique LCA contig_keys: 1616406
Number of unique cluster contig_keys: 320917
Number of LCA contig_keys not found in clusters: 1517383
Sample missing contig_keys (cluster): ['aiamagoat002_s49_656110~node_5986_length_472_cov_1.179747', 'aiamagoat003_s50_657720~node_66920_length_250_cov_0.855491', 'ainwzchicken001_s56_655747~node_5002_length_340_cov_0.855513', 'sezoowogdog006_s16_1_657760~node_12378_length_265_cov_0.787234', 'ainwzdog003_s39_657736~node_962_length_439_cov_2.287293', 'ainwzgoat001_s41_656115~node_3119_length_306_cov_1.039301', 'sezonowgoat010_s65_2_656741~node_38131_length_111_cov_34.264706', 'sezonowifgoat011_s29_2_656747~node_25293_length_429_cov_1.511364', 'sezonowowisr086_liv_s67_666422~node_4828_length_377_cov_0.666667', 'ainwzchicken001_s56_655747~node_21540_length_247_cov_0.864706'] ...
clust2kingdom Dictionary Sample Entries:
Cluster 0: Kingdom Bacteria
Cluster 1: Kingdom Viruses
Cluster 2: Kingdom Bacteria
Cluster 3: Kingdom Bacteria
Cluster 4: Kingdom B

In [35]:
import pandas as pd

# File path
file_path = '/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_results/final_merged_with_counts.csv'

# Load the CSV file
data = pd.read_csv(file_path)

# Get unique values for 'sample' column
unique_samples = data['sample'].unique()

# Display the unique values
unique_samples

array(['SEZONOWIFGOAT011_S29_2_656747', 'AIAMAGOAT001_S48_656109',
       'AINWZGOAT001_S41_656115', 'AIAMAGOAT002_S49_656110',
       'SEZOOWOGCHICKEN001_S35_655749', 'SEZONOWOGDOG014_S21_657751',
       'SEZONOWGOAT009_S67_2_656740', 'AINWZGOAT004_S44_1_656183',
       'SEZONOWIFRR033_LIV_S73_666402', 'AINWZGOAT007_S47_2_656739',
       'AIMUZSHEEP001_S52_656081', 'SEZONOWOGR004_LIV_S50_666375',
       'AIAMAR011_LIV_S31_663039', 'AINWZGOAT005_S45_1_656184',
       'SEZONOWDOG01_S7_657737', 'AINWZLIZARD001_S58_657718',
       'AINWZGOAT006_S46_1_656185', 'SEZONOWISSHEEP001_S32_656084',
       'SEZONOWISDOG002_S4_657741', 'AINWZDOG001_S37_657734',
       'SEZONOWISDOG015_S13_657743', 'AIMGBR023_LIV_S43_662128',
       'SEZONOWISLIZARD001_S36_657719', 'SEZONOWOGR006_LIV_S52_666377',
       'AIMGBR021_LIV_S41_663052', 'AINWZR015_LIV_S15_663078',
       'AINWZR008_LIV_S8_663067', 'AINWZPIG001_S62_657715',
       'SEZONOWOGR002_LIV_S48_666373', 'AINWZR009_LIV_S9_663068',
       'SEZONOWDO

In [33]:
unique_samples_count = data['sample'].value_counts()

# Display the counts of unique values
unique_samples_count

sample
SEZONOWIFGOAT011_S29_2_656747     18230
AIAMAGOAT003_S50_657720            7213
AIAMAGOAT002_S49_656110            4227
AINWZLIZARD001_S58_657718          2717
AIMGLDOG001_S59_657733             2609
                                  ...  
SEZONOWOWISR001_LIV_S62_666415       23
NTC_S71_670947                       19
AINWZR019_LIV_S19_663082             12
SEZONOWOGR001_LIV_S47_666398          5
NE_S82_663084                         3
Name: count, Length: 152, dtype: int64

In [43]:
import pandas as pd
import os
import json
import re
from ete3 import NCBITaxa
from tqdm import tqdm  # For progress bars
import logging
import glob

# ===========================
# Configuration
# ===========================

# Directories
json_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_Contig_quality/"
lca_dir = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_SENZOR_Blast_NR/LCA"

# Output TSV file paths
contig_stats_all_tsv = os.path.join(json_dir, "contig_stats.tsv")
complete_summary_tsv = os.path.join(lca_dir, "contig_stats_lca.tsv")

# ===========================
# Logging Configuration
# ===========================

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)

# ===========================
# Function Definitions
# ===========================

def load_json_data(json_path):
    """
    Load JSON data from the given path and process contig names.

    Args:
        json_path (str): Path to the JSON file.

    Returns:
        dict: Dictionary with contig_name as keys and read_count as values.
    """
    with open(json_path, 'r') as f:
        data = json.load(f)
    
    # Replace "~" with "_" in contig names and exclude any key named "*"
    processed_data = {key.replace("~", "_"): value for key, value in data.items() if key != "*"}
    return processed_data

def load_lca_data(lca_path):
    """
    Load LCA data from the given path with specified data types.

    Args:
        lca_path (str): Path to the LCA file.

    Returns:
        pd.DataFrame: DataFrame containing LCA data.
    """
    dtype_dict = {
        "query": str,
        "subject": str,
        "identity": float,
        "alignment_length": int,
        "mismatches": int,
        "gap_opens": int,
        "q_start": int,
        "q_end": int,
        "s_start": int,
        "s_end": int,
        "evalue": float,
        "bit_score": float,
        "tax_id": str,
        "sci_name": str,
        "com_names": str,
        "superkingdom": str,
        "phylum": str,
        "family": str,
        "lca_taxid": str,
        "lca_name": str
    }
    return pd.read_csv(
        lca_path, 
        sep="\t", 
        header=0,  # Use the first row as header
        dtype=dtype_dict,
        na_values=['Unknown'],  # Treat 'Unknown' as NaN if desired
        keep_default_na=False  # Prevent 'Unknown' from being treated as NaN
    )

def extract_contig_length(contig_name):
    """
    Extract contig length from contig name using regex.
    Assumes contig_name contains 'length_<number>'.

    Args:
        contig_name (str): The contig name string.

    Returns:
        int or float: Extracted contig length or NaN if not found.
    """
    match = re.search(r"length_(\d+)", contig_name)
    if match:
        return int(match.group(1))
    else:
        return pd.NA  # Use pandas NA for missing values

def get_tax_group(taxid, ncbi, taxon_groups, taxon_groups_id):
    """
    Determine the taxon group for a given taxid.

    Args:
        taxid (int or str): NCBI taxid.
        ncbi (NCBITaxa): Initialized NCBITaxa object.
        taxon_groups (list): List of taxon group names.
        taxon_groups_id (list): List of taxon group taxids.

    Returns:
        str: Taxon group name or "Ambiguous" if not found.
    """
    try:
        taxid = int(taxid)
        lineage = ncbi.get_lineage(taxid)
        for i, tax in enumerate(taxon_groups_id):
            if tax is not None and tax in lineage:
                return taxon_groups[i]
        return "Ambiguous"
    except Exception as e:
        logging.error(f"Error processing taxid {taxid}: {e}")
        return "Ambiguous"

# ===========================
# Initialize NCBITaxa
# ===========================

logging.info("Initializing NCBITaxa...")
ncbi = NCBITaxa()

# Define taxon groups and retrieve their taxids
taxon_groups = ["Viruses", "Bacteria", "Archaea", "Metazoa", "Eukaryota"]
taxon_groups_id = []

for group in taxon_groups:
    name_translator = ncbi.get_name_translator([group])
    if group in name_translator:
        taxid = name_translator[group][0]
        taxon_groups_id.append(taxid)
    else:
        logging.warning(f"Taxon group '{group}' not found in NCBITaxa database.")
        taxon_groups_id.append(None)  # Append None if taxon group not found

# ===========================
# Initialize Data Containers
# ===========================

# List to hold DataFrames for contig_stats_all.tsv
contig_stats_all_data = []

# List to hold DataFrames for contig_stats_lca.tsv
complete_summary_data = []

# Initialize a cache for taxid to taxon_group mapping to optimize performance
taxid_cache = {}

def get_tax_group_cached(taxid):
    """
    Determine the taxon group for a given taxid with caching.

    Args:
        taxid (int or str): NCBI taxid.

    Returns:
        str: Taxon group name or "Ambiguous" if not found.
    """
    if taxid in taxid_cache:
        return taxid_cache[taxid]
    else:
        group = get_tax_group(taxid, ncbi, taxon_groups, taxon_groups_id)
        taxid_cache[taxid] = group
        return group

# ===========================
# Dynamic File Listing
# ===========================

# Dynamically list all JSON files in json_dir
json_pattern = os.path.join(json_dir, "*_contigs_stats.json")
json_files = glob.glob(json_pattern)

if not json_files:
    logging.error(f"No JSON files found in directory: {json_dir}")
    exit(1)

# Dynamically list all NR BLAST result files in lca_dir
# Assuming NR blast files have 'blast_nr' in their filename
nr_lca_pattern = os.path.join(lca_dir, "*_modified_diamond_blast_nr_annotated.m9")
nr_lca_files = glob.glob(nr_lca_pattern)

if not nr_lca_files:
    logging.error(f"No NR BLAST result files found in directory: {lca_dir}")
    exit(1)

# Create a mapping from sample name to NR BLAST file
# Assuming sample name can be extracted from JSON and NR BLAST filenames
# For example:
# JSON: AIAMACAT001_S69_657769_contig_stats.json
# NR BLAST: AIAMACAT001_S69_657769_modified_diamond_blast_nr_annotated.m9

nr_blast_mapping = {}
for nr_file in nr_lca_files:
    basename = os.path.basename(nr_file)
    # Extract sample name by removing suffix '_modified_diamond_blast_nr_annotated.m9'
    match = re.match(r"(.+?)_modified_diamond_blast_nr_annotated\.m9", basename)
    if match:
        sample_name = match.group(1)  # "LCA_AIAMA_GOT005_OS_S43_673459"
        # Remove the "LCA_" prefix if present
        sample_name = re.sub(r"^LCA_", "", sample_name)
        nr_blast_mapping[sample_name] = nr_file
    else:
        logging.warning(f"NR BLAST file name does not match expected pattern: {basename}")

# Log the mapping for verification
logging.info("Mapping of samples to NR BLAST files:")
for sample, nr_file in nr_blast_mapping.items():
    logging.info(f"Sample: {sample} => NR BLAST File: {os.path.basename(nr_file)}")

# ===========================
# Processing JSON and NR BLAST Files
# ===========================

for json_file in json_files:
    basename = os.path.basename(json_file)
    # Extract sample name by removing suffix '_contig_stats.json'
    sample_name = basename.replace("_contigs_stats.json", "")
    
    logging.info(f"\nProcessing sample: {sample_name}")
    
    # Find corresponding NR BLAST file
    if sample_name not in nr_blast_mapping:
        logging.error(f"No corresponding NR BLAST file found for sample: {sample_name}")
        continue
    
    nr_lca_file = nr_blast_mapping[sample_name]
    
    # ---------------------------
    # Process JSON File
    # ---------------------------
    
    # Full path to the JSON file
    json_path = os.path.join(json_dir, json_file)
    
    # Check if JSON file exists
    if not os.path.exists(json_path):
        logging.error(f"JSON file not found: {json_path}")
        continue
    
    # Load and process JSON data
    json_data = load_json_data(json_path)
    num_contigs = len(json_data)
    logging.info(f"Loaded {num_contigs} contigs from {basename}")
    
    # Convert to DataFrame for contig_stats_all.tsv
    contig_stats = pd.DataFrame.from_dict(json_data, orient='index', columns=['read_count']).reset_index()
    contig_stats.rename(columns={'index': 'contig_name'}, inplace=True)
    
    # Extract contig_length
    contig_stats['contig_length'] = contig_stats['contig_name'].apply(extract_contig_length)
    
    # Handle contig_length extraction failures
    missing_lengths = contig_stats['contig_length'].isna().sum()
    if missing_lengths > 0:
        logging.warning(f"{missing_lengths} contigs in {basename} did not have a 'length_' pattern.")
    
    # Add sample name
    contig_stats['sample'] = sample_name
    
    # Reorder columns for contig_stats_all.tsv
    contig_stats = contig_stats[['sample', 'contig_name', 'contig_length', 'read_count']]
    
    # Append to contig_stats_all_data list
    contig_stats_all_data.append(contig_stats)
    
    # ---------------------------
    # Process NR LCA File
    # ---------------------------
    
    # Full path to NR LCA file
    nr_lca_path = nr_lca_file
    
    # Check if NR LCA file exists
    if not os.path.exists(nr_lca_path):
        logging.error(f"NR LCA file not found: {nr_lca_path}")
        continue
    
    # Load LCA data
    nr_lca = load_lca_data(nr_lca_path)
    logging.info(f"Loaded NR LCA data for {sample_name}")
    
    # Standardize contig names in LCA files by replacing "~" with "_"
    nr_lca["query"] = nr_lca["query"].str.replace("~", "_")
    
    # Check for matching contigs
    contig_names = set(contig_stats["contig_name"])
    nr_queries = set(nr_lca["query"])
    common_nr = contig_names.intersection(nr_queries)
    logging.info(f"Common contigs in NR LCA for {sample_name}: {len(common_nr)}")
    
    # Merge NR LCA data with contig stats
    nr_merged = pd.merge(
        contig_stats, 
        nr_lca, 
        left_on="contig_name", 
        right_on="query", 
        how="inner",
        suffixes=('_stats', '_nr')
    )
    nr_merged["nr"] = True
    nr_merged["nt_or_nr"] = "nr"
    logging.info(f"NR merged rows for {sample_name}: {len(nr_merged)}")
    
    # Add 'common_nr' column
    # This column will be True if the contig is present in NR, else False
    nr_merged["common_nr"] = nr_merged["contig_name"].isin(common_nr)
    logging.info(f"Added 'common_nr' column for {sample_name}")
    
    # Handle duplicate 'contig_length' columns
    # After merging, there might be 'contig_length_stats' and 'contig_length_nr'
    # We'll retain one 'contig_length' column and drop the others
    if 'contig_length_stats' in nr_merged.columns:
        nr_merged.rename(columns={'contig_length_stats': 'contig_length'}, inplace=True)
        # Drop the other contig_length columns
        nr_merged.drop(['contig_length_nr'], axis=1, inplace=True, errors='ignore')
    elif 'contig_length_nr' in nr_merged.columns:
        nr_merged.rename(columns={'contig_length_nr': 'contig_length'}, inplace=True)
        nr_merged.drop(['contig_length_stats'], axis=1, inplace=True, errors='ignore')
    
    # Define selected_columns with the updated 'contig_length'
    selected_columns = [
        "sample", 
        "contig_name", 
        "contig_length",  # Updated column
        "read_count", 
        "nr", 
        "nt_or_nr",
        "lca_taxid", 
        "bit_score", 
        "alignment_length", 
        "superkingdom", 
        "tax_id",
        "identity",
        "sci_name",
        "com_names",
        "phylum",
        "family",
        "lca_taxid",
        "lca_name",
        "common_nr"  # Include the new column
    ]
    
    # Select only columns that are present
    available_columns = nr_merged.columns.tolist()
    selected_columns = [col for col in selected_columns if col in available_columns]
    
    # Now, select the relevant columns
    combined = nr_merged[selected_columns]
    
    # Rename columns for clarity
    combined.rename(columns={
        "bit_score": "bitscore",
        "alignment_length": "align_length"
    }, inplace=True)
    
    # ===========================
    # Update taxon_group Using taxid
    # ===========================
    
    logging.info("\nUpdating 'taxon_group' based on 'taxid' using ETE3's NCBITaxa...")
    
    # Apply the cached get_tax_group function with progress bar
    tqdm.pandas(desc="Updating taxon_group")
    combined['taxon_group'] = combined['tax_id'].progress_apply(get_tax_group_cached)
    
    logging.info("Updated 'taxon_group' successfully.")
    
    # ---------------------------
    # Reorder columns after updating taxon_group
    # ---------------------------
    
    # Reorder columns if desired
    # Ensure that 'common_nr' is at the end
    desired_order = [
        "sample", 
        "contig_name", 
        "contig_length", 
        "read_count", 
        "nr", 
        "nt_or_nr",
        "taxid", 
        "bitscore", 
        "align_length", 
        "taxon_group", 
        "tax_id",
        "identity",
        "sci_name",
        "com_names",
        "phylum",
        "family",
        "lca_taxid",
        "lca_name",
        "common_nr"  # Ensure the new column is included
    ]
    
    # Adjust desired_order based on available columns
    final_columns = [col for col in desired_order if col in combined.columns]
    combined = combined[final_columns]
    
    logging.info(f"Final selected columns after updating 'taxon_group': {combined.columns.tolist()}")
    
    logging.info(f"Total merged rows for {sample_name}: {len(combined)}")
    
    # Append to complete_summary_data list
    complete_summary_data.append(combined)

# ===========================
# Combine All Data and Save Outputs
# ===========================

# ---------------------------
# Save contig_stats_all.tsv
# ---------------------------

if contig_stats_all_data:
    # Concatenate all contig_stats DataFrames
    contig_stats_all_df = pd.concat(contig_stats_all_data, ignore_index=True)
    
    # Optional: Sort the DataFrame for better readability
    contig_stats_all_df.sort_values(by=['sample', 'contig_name'], inplace=True)
    
    # Optional: Reset index after sorting
    contig_stats_all_df.reset_index(drop=True, inplace=True)
    
    # Save to TSV
    contig_stats_all_df.to_csv(contig_stats_all_tsv, sep='\t', index=False)
    logging.info(f"\nSuccessfully saved combined contig stats to {contig_stats_all_tsv}")
else:
    logging.warning("No contig stats data available to save.")

# ---------------------------
# Save contig_stats_lca.tsv
# ---------------------------

if complete_summary_data:
    # Concatenate all complete_summary DataFrames
    complete_summary_df = pd.concat(complete_summary_data, ignore_index=True)
    
    # Optional: Sort the DataFrame for better readability
    complete_summary_df.sort_values(by=['sample', 'contig_name'], inplace=True)
    
    # Optional: Reset index after sorting
    complete_summary_df.reset_index(drop=True, inplace=True)
    
    # Save to TSV
    complete_summary_df.to_csv(complete_summary_tsv, sep='\t', index=False)
    logging.info(f"Successfully saved complete summary to {complete_summary_tsv}")
else:
    logging.warning("No data available to save for contig_stats_lca.tsv.")

# ===========================
# Display Previews
# ===========================

# Preview contig_stats_all.tsv
if contig_stats_all_data:
    logging.info("\nPreview of contig_stats_all.tsv:")
    print(contig_stats_all_df.head())
else:
    logging.warning("No contig stats data available to preview.")

# Preview contig_stats_lca.tsv
if complete_summary_data:
    logging.info("\nPreview of contig_stats_lca.tsv:")
    print(complete_summary_df.head())
else:
    logging.warning("No data available to preview for contig_stats_lca.tsv.")

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  combined.rename(columns={
Updating taxon_group: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4469/4469 [00:00<00:00, 139303.09it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  combined['taxon_group'] = combined['tax_id'].progress_apply(get_tax_group_cached)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  combined.rename(columns={
Updating taxon

                       sample  \
0  AIAMAAGR001_LIV_S46_662105   
1  AIAMAAGR001_LIV_S46_662105   
2  AIAMAAGR001_LIV_S46_662105   
3  AIAMAAGR001_LIV_S46_662105   
4  AIAMAAGR001_LIV_S46_662105   

                                         contig_name  contig_length  \
0  AIAMAAGR001_LIV_S46_662105_NODE_1000_length_48...            480   
1  AIAMAAGR001_LIV_S46_662105_NODE_1001_length_48...            480   
2  AIAMAAGR001_LIV_S46_662105_NODE_1002_length_48...            480   
3  AIAMAAGR001_LIV_S46_662105_NODE_1003_length_48...            480   
4  AIAMAAGR001_LIV_S46_662105_NODE_1004_length_47...            479   

   read_count  
0          12  
1          12  
2           8  
3           4  
4          13  
                       sample  \
0  AIAMAAGR001_LIV_S46_662105   
1  AIAMAAGR001_LIV_S46_662105   
2  AIAMAAGR001_LIV_S46_662105   
3  AIAMAAGR001_LIV_S46_662105   
4  AIAMAAGR001_LIV_S46_662105   

                                         contig_name  contig_length  \
0  AIAMA

In [16]:
import json
import os

def remove_substring_from_keys(obj, substring):
    if isinstance(obj, dict):
        new_obj = {}
        for key, value in obj.items():
            new_key = key.replace(substring, "")
            new_obj[new_key] = remove_substring_from_keys(value, substring)
        return new_obj
    elif isinstance(obj, list):
        return [remove_substring_from_keys(item, substring) for item in obj]
    else:
        return obj

directory = "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Combined_Contig_quality"

for filename in os.listdir(directory):
    if filename.endswith(".json"):
        filepath = os.path.join(directory, filename)
        with open(filepath, 'r') as f:
            data = json.load(f)
        
        updated_data = remove_substring_from_keys(data, "extracted_csv")
        
        with open(filepath, 'w') as f:
            json.dump(updated_data, f, indent=4)
        
        print(f"Processed {filename}")

Processed AIAMA_GOT005_OS_S43_673459_contigs_stats.json
Processed AIAMA_GOT006_RS_S44_673460_contigs_stats.json
Processed AIAMA_GOT007_OS_S45_673461_contigs_stats.json
Processed AIAMA_LIZ002_OS_S50_673462_contigs_stats.json
Processed AIAMA_PIG003_RS_S51_673463_contigs_stats.json
Processed AIAMA_SHP004_OS_S46_673464_contigs_stats.json
Processed AIAMA_SHP005_RS_S47_673465_contigs_stats.json
Processed AIAMA_SHP006_OS_S48_673466_contigs_stats.json
Processed AIAMAAGR001_LIV_S46_662105_contigs_stats.json
Processed AIAMACAT001_S69_657769_contigs_stats.json
Processed AIAMADOG001_S40_657732_contigs_stats.json
Processed AIAMADOG002_OS_S64_673467_contigs_stats.json
Processed AIAMAGOAT001_S48_656109_contigs_stats.json
Processed AIAMAGOAT002_S49_656110_contigs_stats.json
Processed AIAMAGOAT003_S50_657720_contigs_stats.json
Processed AIAMAGOAT004_S51_656112_contigs_stats.json
Processed AIAMAPIG001_S63_657713_contigs_stats.json
Processed AIAMAPIG002_S64_657714_contigs_stats.json
Processed AIAMAR001_L