In [12]:
import os
import pandas as pd
import numpy as np
import time
import logging
from ete3 import NCBITaxa
from Bio import Entrez

# ----------------------------
# Configuration and Setup
# ----------------------------
logging.basicConfig(
    filename='blast_lca_processing.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# NCBI Entrez credentials. Environment variables take precedence so the key
# does not have to live in the notebook; the previous hard-coded values stay
# as fallbacks so existing runs behave exactly as before.
# NOTE(review): this API key has been committed in plain text — consider
# rotating it and deleting the fallback once NCBI_API_KEY is set.
Entrez.email = os.environ.get("NCBI_EMAIL", "Ifeanyi.omah@ed.ac.uk")
Entrez.api_key = os.environ.get("NCBI_API_KEY", "d155c4478aa27128073f178361b921d2e407")

# Local NCBI taxonomy database used for lineage / LCA lookups.
ncbi = NCBITaxa()

# Expected tabular BLAST column layouts. Extra trailing columns in the input
# are dropped by parse_blast_chunk; the NT layout carries 'com_names' and
# 'pident', the NR layout carries 'identity' instead.
NT_COLUMNS = [
    "qseqid", "sseqid", "pident", "alignment_length", "mismatches", "gap_opens",
    "qstart", "qend", "sstart", "send", "evalue", "bitscore",
    "taxid", "sci_name", "com_names", "subject_title"
]
NR_COLUMNS = [
    "qseqid", "sseqid", "identity", "alignment_length", "mismatches", "gap_opens",
    "qstart", "qend", "sstart", "send", "evalue", "bitscore",
    "taxid", "sci_name", "subject_title"
]

# Helper functions

def parse_blast_chunk(chunk: pd.DataFrame, columns: list, blast_type: str) -> pd.DataFrame:
    """Name the leading BLAST columns of *chunk* and tag every row with *blast_type*.

    Columns beyond ``len(columns)`` are discarded. If *chunk* has fewer
    columns than expected, an error is logged and an empty DataFrame is
    returned so the caller can skip the chunk.
    """
    n_expected = len(columns)
    n_found = chunk.shape[1]
    if n_found >= n_expected:
        named = chunk.iloc[:, :n_expected].copy()
        named.columns = columns
        named['blast_type'] = blast_type
        return named
    logging.error(f"Column mismatch for {blast_type}: expected >= {n_expected}, got {n_found}")
    return pd.DataFrame()


def get_taxid(acc: str, db: str) -> int:
    """Look up the NCBI taxid of accession *acc* via Entrez esummary.

    Args:
        acc: Sequence accession (a BLAST ``sseqid``).
        db: Entrez database to query (``'nuccore'`` or ``'protein'`` here).

    Returns:
        The integer ``TaxId`` of the first summary record, or ``None`` when
        the request or parsing fails (a warning is logged).
    """
    handle = None
    try:
        handle = Entrez.esummary(id=acc, db=db)
        records = Entrez.read(handle)
        return int(records[0]['TaxId'])
    except Exception as e:
        logging.warning(f"Failed retrieving taxid for {acc}: {e}")
        return None
    finally:
        # Always release the HTTP handle, even when Entrez.read raises.
        if handle is not None:
            handle.close()
        # Throttle after every request, including failed ones, so a run of
        # failing accessions does not hammer the E-utilities endpoint.
        time.sleep(0.1)


def _taxid_ints(taxids: list) -> set:
    """Return the set of integer taxids contained in *taxids*.

    Entries may be ints, floats, numeric strings, or semicolon-separated
    strings such as ``"562;1234"``. ``None``/NaN and non-numeric tokens
    (e.g. ``"N/A"``) are skipped.
    """
    found = set()
    for value in taxids:
        if value is None:
            continue
        for token in str(value).split(';'):
            try:
                found.add(int(float(token)))
            except (ValueError, OverflowError):
                # "nan", "<NA>", "N/A", empty tokens, inf, ...
                continue
    return found


def get_lca(taxids: list) -> "int | str":
    """Return the lowest common ancestor taxid of *taxids*.

    Args:
        taxids: Candidate taxids (see ``_taxid_ints`` for accepted forms).

    Returns:
        The single taxid when only one distinct taxid is present, otherwise
        the deepest taxid shared by all of their NCBI lineages. ``'Unknown'``
        when no usable taxid is given or a lineage cannot be resolved.
    """
    unique = _taxid_ints(taxids)
    if not unique:
        return 'Unknown'
    if len(unique) == 1:
        return unique.pop()
    # NCBITaxa has no get_common_ancestor; derive the LCA from lineages,
    # each of which is a root-to-taxid path of taxids.
    try:
        lineages = [ncbi.get_lineage(t) for t in sorted(unique)]
    except Exception as e:
        logging.warning(f"LCA failure for {unique}: {e}")
        return 'Unknown'
    if not all(lineages):
        logging.warning(f"LCA failure for {unique}: missing lineage")
        return 'Unknown'
    shared = set(lineages[0]).intersection(*lineages[1:])
    # Shared nodes form a common prefix of every path; the deepest one is the LCA.
    for node in reversed(lineages[0]):
        if node in shared:
            return node
    return 'Unknown'


def select_taxids_for_lca(df: pd.DataFrame) -> list:
    """Return taxids of hits that are near-equivalent to the best hit.

    The best hit is the row with the highest bitscore. A hit is kept when
    its aligned bases (``alignment_length * identity / 100``) are at least
    the best hit's aligned bases minus the best hit's mismatch count.

    Args:
        df: Hits for one query. Uses the ``identity`` column when present,
            otherwise ``pident``; also ``alignment_length``, ``bitscore``,
            ``taxid`` and (optionally) ``mismatches``.

    Returns:
        Unique non-null taxids in order of first appearance; an empty list
        when *df* is empty or has no numeric bitscore.
    """
    if df.empty or df['bitscore'].isna().all():
        return []
    df = df.copy()
    ident_col = 'identity' if 'identity' in df.columns else 'pident'
    df['aligned_bases'] = df['alignment_length'] * df[ident_col] / 100
    best_idx = df['bitscore'].idxmax()
    best = df.loc[best_idx]
    mismatches = best.get('mismatches', 0)
    # A missing mismatch count would make the threshold NaN and select
    # nothing at all; treat it as zero, like an absent column.
    if pd.isna(mismatches):
        mismatches = 0
    threshold = best['aligned_bases'] - mismatches
    if pd.isna(threshold):
        # Best hit lacks a usable length/identity: fall back to it alone.
        hits = df.loc[[best_idx]]
    else:
        hits = df[df['aligned_bases'] >= threshold]
    return hits['taxid'].dropna().unique().tolist()


def calculate_lca(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse BLAST hits into one LCA record per query sequence.

    For each ``qseqid``, the near-best hits chosen by
    ``select_taxids_for_lca`` are reduced to a lowest common ancestor with
    ``get_lca``. The highest-bitscore hit supplies the descriptive fields.

    Args:
        df: Parsed hits (NT_COLUMNS or NR_COLUMNS layout plus ``blast_type``).

    Returns:
        One row per query with ``lca_taxid``, best-hit fields and the best
        hit's ``aligned_bases``. Queries whose hits all lack a numeric
        bitscore are skipped and counted in a warning.
    """
    df = df.copy()
    # Coerce numeric fields once for the whole frame instead of per group.
    for col in ['alignment_length', 'identity', 'mismatches', 'bitscore', 'pident']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

    # aligned_bases = alignment length x percent identity (NR layout carries
    # 'identity', NT layout 'pident'). Computing it here, on the frame the
    # best hit is taken from, lets the value reach the output record.
    ident_col = 'identity' if 'identity' in df.columns else 'pident'
    if ident_col in df.columns and 'alignment_length' in df.columns:
        df['aligned_bases'] = df['alignment_length'] * df[ident_col] / 100

    n_queries = df['qseqid'].nunique()
    df = df.dropna(subset=['bitscore'])
    n_dropped = n_queries - df['qseqid'].nunique()
    if n_dropped:
        logging.warning(f"Skipped {n_dropped} queries with no numeric bitscore")

    results = []
    for qseqid, group in df.groupby('qseqid'):
        taxids = select_taxids_for_lca(group)
        lca = get_lca(taxids)
        best = group.loc[group['bitscore'].idxmax()]
        # Build output record including extra fields
        results.append({
            'qseqid': qseqid,
            'lca_taxid': lca,
            'best_hit_sciname': best.get('sci_name', best.get('sseqid')),
            'alignment_length': best.get('alignment_length'),
            'identity': best.get('identity', np.nan),
            'pident': best.get('pident', np.nan),
            'mismatches': best.get('mismatches'),
            'bitscore': best['bitscore'],
            'evalue': best.get('evalue'),
            'com_names': best.get('com_names', ''),
            'aligned_bases': best.get('aligned_bases'),
            'blast_type': best['blast_type']
        })
    return pd.DataFrame(results)


def _fill_missing_taxids(parsed: pd.DataFrame, db_name: str, cache: dict) -> pd.DataFrame:
    """Fill NaN ``taxid`` cells of *parsed* from Entrez, one query per accession.

    *cache* maps accession -> taxid (``None`` on lookup failure). It is
    updated in place, so an accession repeated across rows or chunks is
    looked up only once.
    """
    missing = parsed['taxid'].isna()
    if not missing.any():
        return parsed
    accessions = parsed.loc[missing, 'sseqid']
    for acc in accessions.unique():
        if acc not in cache:
            cache[acc] = get_taxid(acc, db_name)
    parsed.loc[missing, 'taxid'] = accessions.map(cache)
    return parsed


def process_files(directory: str, blast_type: str, columns: list, db_name: str):
    """Write a per-query LCA table for every BLAST ``.m9`` file in *directory*.

    Each input ``<name>.m9`` is read in 500k-row chunks and its leading
    columns are named from *columns*. For ``blast_type == 'nr'``, missing
    taxids are filled from Entrez database *db_name*. The result is written
    next to the input as ``LCA_<name>.m9`` (tab-separated, with a header).

    Args:
        directory: Folder containing the BLAST tabular outputs.
        blast_type: ``'nt'`` or ``'nr'``; stored in the ``blast_type`` column.
        columns: Column names for the leading fields of each row.
        db_name: Entrez database used for taxid lookups (NR only).
    """
    # Outputs written by earlier runs also end in ".m9"; exclude them so a
    # re-run does not treat LCA tables as BLAST input. Sort for a stable order.
    files = sorted(
        f for f in os.listdir(directory)
        if f.endswith('.m9') and not f.startswith('LCA_')
    )
    print(f"Found {len(files)} {blast_type.upper()} files in {directory}")

    # Accession -> taxid lookups shared across all files in this call.
    taxid_cache = {}

    for idx, file in enumerate(files, start=1):
        in_path = os.path.join(directory, file)
        out_path = os.path.join(directory, f"LCA_{file}")
        print(f"[{idx}/{len(files)}] Processing {blast_type.upper()}: {file}")
        logging.info(f"Start {blast_type} processing: {file}")

        try:
            reader = pd.read_csv(
                in_path,
                sep='\t',
                header=None,
                comment='#',
                chunksize=500000,
                engine='python',
                on_bad_lines='warn'
            )
        except pd.errors.EmptyDataError:
            print(f"  - Skipped empty file: {file}")
            logging.warning(f"Empty file skipped: {file}")
            continue

        parsed_chunks = []
        # The chunked reader keeps the file open; the with-block closes it
        # even if parsing raises part-way through.
        with reader:
            for cidx, chunk in enumerate(reader, start=1):
                print(f"    Chunk {cidx}...")
                parsed = parse_blast_chunk(chunk, columns, blast_type)
                if parsed.empty:
                    continue
                if blast_type == 'nr':
                    parsed = _fill_missing_taxids(parsed, db_name, taxid_cache)
                parsed_chunks.append(parsed)

        if not parsed_chunks:
            print(f"  - No valid data in {file}, skipping.")
            logging.warning(f"No valid data in {file}")
            continue

        df_all = pd.concat(parsed_chunks, ignore_index=True)
        print(f"  - Calculating LCA on {len(df_all)} rows...")
        lca_df = calculate_lca(df_all)
        lca_df.to_csv(out_path, sep='\t', index=False)
        print(f"  - Saved LCA to {out_path}\n")
        logging.info(f"Saved LCA to {out_path}")


if __name__ == '__main__':
    base = '/Volumes/aine_store/Blast_nr'
    # One run per database, NT first then NR:
    # (sub-directory, blast_type, column layout, Entrez database)
    for _subdir, _btype, _cols, _db in (
        ('virus_NT', 'nt', NT_COLUMNS, 'nuccore'),
        ('Virus_Blast_nr', 'nr', NR_COLUMNS, 'protein'),
    ):
        process_files(
            os.path.join(base, _subdir),
            blast_type=_btype,
            columns=_cols,
            db_name=_db,
        )

Found 78 NT files in /Volumes/aine_store/Blast_nr/virus_NT
[1/78] Processing NT: 1_01_23_0574_S41_virus_contigs_blast_nt.m9
    Chunk 1...
  - Calculating LCA on 486302 rows...
  - Saved LCA to /Volumes/aine_store/Blast_nr/virus_NT/LCA_1_01_23_0574_S41_virus_contigs_blast_nt.m9

[2/78] Processing NT: 1_01_23_0590_S42_virus_contigs_blast_nt.m9
    Chunk 1...
  - Calculating LCA on 339589 rows...
  - Saved LCA to /Volumes/aine_store/Blast_nr/virus_NT/LCA_1_01_23_0590_S42_virus_contigs_blast_nt.m9

[3/78] Processing NT: 1_01_24_0050_S45_virus_contigs_blast_nt.m9
    Chunk 1...
    Chunk 2...
    Chunk 3...
    Chunk 4...
  - Calculating LCA on 1526342 rows...
  - Saved LCA to /Volumes/aine_store/Blast_nr/virus_NT/LCA_1_01_24_0050_S45_virus_contigs_blast_nt.m9

[4/78] Processing NT: 1_01_24_0188_S38_virus_contigs_blast_nt.m9
    Chunk 1...
  - Calculating LCA on 327560 rows...
  - Saved LCA to /Volumes/aine_store/Blast_nr/virus_NT/LCA_1_01_24_0188_S38_virus_contigs_blast_nt.m9

[5/78] Proc

In [14]:
import os
import re
import logging
import pandas as pd
from glob import glob
from tqdm import tqdm

# ====== Config ======
contig_dir = "/Volumes/aine_store/Blast_nr/contigs_reads"       # TSVs like 1_01_23_0574_S41_contig_idxstats_with_header.tsv
nt_dir     = "/Volumes/aine_store/Blast_nr/LCA_virusblas_nt"    # NT LCA TSVs
nr_dir     = "/Volumes/aine_store/Blast_nr/LCA_virublast_nr"    # NR LCA TSVs

out_contig_stats_all = os.path.join(contig_dir, "contig_stats_all.tsv")
out_complete_summary = os.path.join(contig_dir, "contig_stats_lca.tsv")

# ====== Logging ======
# force=True (Python 3.8+) replaces handlers installed earlier in the same
# interpreter (e.g. a previous notebook cell that logs to a file); without
# it basicConfig is a silent no-op and these messages never reach the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    force=True,
)

# ====== Helpers ======

def list_contig_tsvs(contig_dir):
    """Return the sorted paths of all ``*_contig_idxstats_with_header.tsv`` files in *contig_dir*."""
    pattern = os.path.join(contig_dir, "*_contig_idxstats_with_header.tsv")
    matches = glob(pattern)
    matches.sort()
    return matches

# Normalize any qseqid into canonical "<sample>~<contig>"
_QSEQ_PATTERNS = [
    # 1) LCA_<sample>_virus_contigs_blast_(nt|nr).m9~k141_XXXX
    re.compile(r"^LCA_(?P<sample>.+?)_virus_contigs_blast_(?:nt|nr)\.m9~(?P<contig>k141_\d+)$", re.I),
    # 2) LCA_<sample>_virus_hits.m9~k141_XXXX
    re.compile(r"^LCA_(?P<sample>.+?)_virus_hits\.m9~(?P<contig>k141_\d+)$", re.I),
    # 3) LCA_<sample>~k141_XXXX
    re.compile(r"^LCA_(?P<sample>.+?)~(?P<contig>k141_\d+)$", re.I),
    # 4) <sample>~k141_XXXX  (already canonical)
    re.compile(r"^(?P<sample>.+?)~(?P<contig>k141_\d+)$", re.I),
]

# qseqids already reported as unparseable, so each one is logged only once.
_unparsed_qseqids = set()

def normalize_qseqid(qseqid: str) -> str:
    """Map a BLAST/LCA qseqid onto the canonical ``<sample>~k141_<n>`` key.

    Patterns are tried in order (most specific first). Missing values are
    returned unchanged. Unrecognised ids are also returned unchanged; a
    warning is logged the first time each distinct id is seen.
    """
    if pd.isna(qseqid):
        return qseqid
    s = str(qseqid)
    for pat in _QSEQ_PATTERNS:
        m = pat.match(s)
        if m:
            return f"{m.group('sample')}~{m.group('contig')}"
    if s not in _unparsed_qseqids:
        _unparsed_qseqids.add(s)
        logging.warning(f"Could not normalize qseqid: {s}")
    return s

def sample_from_contig_filename(fn: str) -> str:
    """Return the sample id encoded in a contig-stats file name.

    The directory part is dropped and a trailing
    ``_contig_idxstats_with_header.tsv`` (any case) is removed, e.g.
    ``1_01_23_0574_S41_contig_idxstats_with_header.tsv`` -> ``1_01_23_0574_S41``.
    """
    name = os.path.basename(fn)
    suffix = re.compile(r"_contig_idxstats_with_header\.tsv$", re.I)
    return suffix.sub("", name)

def load_contig_stats_tsv(path: str) -> pd.DataFrame:
    """Load one per-sample contig stats TSV into the canonical contig table.

    Args:
        path: Tab-separated file whose header contains at least contig_id,
            contig_len_bp, mapped_reads and unmapped_mate_reads.

    Returns:
        DataFrame with columns sample, contig_name, contig_length and
        read_count. ``read_count`` is ``mapped_reads`` (missing -> 0).
        ``sample`` is the part of contig_name before ``~k141_<n>``, or NaN
        if the id does not have that form.

    Raises:
        ValueError: if an expected column is missing.
    """
    df = pd.read_csv(path, sep="\t", dtype=str)
    # Force expected columns and types
    for col in ["contig_id", "contig_len_bp", "mapped_reads", "unmapped_mate_reads"]:
        if col not in df.columns:
            raise ValueError(f"{path} missing expected column: {col}")

    # Drop blank ids and the "*" pseudo-contig (the row idxstats-style tables
    # use for reads not placed on any contig); neither is a real contig.
    keep = df["contig_id"].notna() & (df["contig_id"] != "*")
    df = df.loc[keep].reset_index(drop=True)

    df["contig_len_bp"] = pd.to_numeric(df["contig_len_bp"], errors="coerce")
    df["mapped_reads"] = pd.to_numeric(df["mapped_reads"], errors="coerce").fillna(0).astype(int)
    df["unmapped_mate_reads"] = pd.to_numeric(df["unmapped_mate_reads"], errors="coerce").fillna(0).astype(int)

    # contig_name is the canonical id, already "<sample>~k141_XXXX"
    df = df.rename(columns={"contig_id": "contig_name"})
    # read_count policy: mapped_reads (change to +unmapped if desired)
    df["read_count"] = df["mapped_reads"]

    # Add sample (left side of contig_name)
    df["sample"] = df["contig_name"].str.extract(r"^(.*?)~k141_\d+$", expand=False)
    df["contig_length"] = df["contig_len_bp"]

    return df[["sample", "contig_name", "contig_length", "read_count"]]

def load_lca_tsv(path: str) -> pd.DataFrame:
    """Read one NT/NR LCA TSV and key it by canonical contig id.

    Every cell is read as text with no NA parsing (so 'Unknown', 'NA' and
    blanks stay literal strings). ``qseqid`` is passed through
    ``normalize_qseqid``. ``aligned_bases`` is renamed to
    ``alignment_length`` only when the latter is absent. bitscore,
    identity, pident, alignment_length and mismatches are coerced to
    numbers (unparseable -> NaN) when present.

    Raises:
        ValueError: if the file has no ``qseqid`` column.
    """
    table = pd.read_csv(path, sep="\t", dtype=str, keep_default_na=False, na_values=[])
    if "qseqid" not in table.columns:
        raise ValueError(f"{path} missing qseqid")
    table["qseqid"] = table["qseqid"].map(normalize_qseqid)

    has_aligned = "aligned_bases" in table.columns
    has_length = "alignment_length" in table.columns
    if has_aligned and not has_length:
        table = table.rename(columns={"aligned_bases": "alignment_length"})

    numeric_cols = ("bitscore", "identity", "pident", "alignment_length", "mismatches")
    for name in (c for c in numeric_cols if c in table.columns):
        table[name] = pd.to_numeric(table[name], errors="coerce")
    return table

# Optional NCBITaxa grouping. If ete3 or its local taxonomy database is not
# usable, fall back to ncbi=None with empty group lists, so downstream
# grouping reports "Ambiguous".
taxon_groups = ["Viruses", "Bacteria", "Archaea", "Metazoa", "Eukaryota"]
try:
    from ete3 import NCBITaxa
    logging.info("Initializing NCBITaxa...")
    ncbi = NCBITaxa()
    _name2id = ncbi.get_name_translator(list(taxon_groups))
    # Parallel to taxon_groups; names the database does not know map to None.
    taxon_groups_id = [_name2id[g][0] if g in _name2id else None for g in taxon_groups]
except Exception as e:
    logging.warning(f"ETE3/NCBITaxa not available: {e}")
    ncbi = None
    taxon_groups = []
    taxon_groups_id = []

_taxid_cache = {}

def _tax_group_from_taxid(taxid: str) -> str:
    if not taxid or taxid in ("Unknown", "NA"):
        return "Ambiguous"
    if taxid in _taxid_cache:
        return _taxid_cache[taxid]
    if ncbi is None:
        _taxid_cache[taxid] = "Ambiguous"
        return "Ambiguous"
    try:
        lineage = ncbi.get_lineage(int(taxid))
        for name, tid in zip(taxon_groups, taxon_groups_id):
            if tid and tid in lineage:
                _taxid_cache[taxid] = name
                return name
        _taxid_cache[taxid] = "Ambiguous"
        return "Ambiguous"
    except Exception:
        _taxid_cache[taxid] = "Ambiguous"
        return "Ambiguous"

# ====== Main ======

# Fail fast: without per-sample contig stats TSVs there is nothing to summarise.
contig_files = list_contig_tsvs(contig_dir)
if not contig_files:
    raise SystemExit(
        f"No contig TSVs found in {contig_dir}"
    )

# Build maps from sample -> NT/NR LCA file (best guess: match by sample token in filename)
def sample_key_from_lca_filename(fn: str) -> str:
    """
    Examples:
      LCA_1_01_23_0574_S41_virus_contigs_blast_nt.m9
      LCA_1_01_23_0574_S41_virus_hits.m9
    -> 1_01_23_0574_S41
    """
    base = os.path.basename(fn)
    s = re.sub(r"^LCA_", "", base, flags=re.I)
    s = re.sub(r"_virus_contigs_blast_(?:nt|nr)\.m9$", "", s, flags=re.I)
    s = re.sub(r"_virus_hits\.m9$", "", s, flags=re.I)
    return s

nt_files = sorted(glob(os.path.join(nt_dir, "LCA_*")))
nr_files = sorted(glob(os.path.join(nr_dir, "LCA_*")))


def _index_lca_files(paths, label):
    """Map sample key -> LCA file path for one database (*label* is for logs).

    When two files resolve to the same sample the later (sorted) path wins,
    as before, but the collision is now logged instead of being silent.
    """
    if not paths:
        logging.warning(f"No {label} LCA_* files found")
    by_sample = {}
    for path in paths:
        key = sample_key_from_lca_filename(path)
        if key in by_sample:
            logging.warning(
                f"{label}: {os.path.basename(by_sample[key])} and {os.path.basename(path)} "
                f"both map to sample {key}; using {os.path.basename(path)}"
            )
        by_sample[key] = path
    return by_sample


nt_by_sample = _index_lca_files(nt_files, "NT")
nr_by_sample = _index_lca_files(nr_files, "NR")

contig_stats_all = []
complete_summary = []

for contig_path in contig_files:
    sample = sample_from_contig_filename(contig_path)
    logging.info(f"Processing sample: {sample}")

    # contig stats
    contig_stats = load_contig_stats_tsv(contig_path)
    contig_stats_all.append(contig_stats)

    # find matching NT/NR
    nt_path = nt_by_sample.get(sample)
    nr_path = nr_by_sample.get(sample)

    # Load LCA (if present)
    nt_df = load_lca_tsv(nt_path) if nt_path and os.path.exists(nt_path) else None
    nr_df = load_lca_tsv(nr_path) if nr_path and os.path.exists(nr_path) else None

    def merge_lca(side_df: pd.DataFrame, label: str):
        """Inner-join this sample's contig stats with one LCA table.

        Returns None when *side_df* is None. Otherwise returns the rows whose
        qseqid equals a contig_name, tagged with sample and nt/nr flags.
        Warns when a non-empty LCA table matched no contig at all.
        """
        if side_df is None:
            return None
        merged = pd.merge(
            contig_stats, side_df,
            left_on="contig_name", right_on="qseqid",
            how="inner", suffixes=("_stats", f"_{label}")
        )
        if merged.empty and not side_df.empty:
            # Every row was dropped by the inner join; this typically means
            # the qseqids are in a format normalize_qseqid did not recognise.
            logging.warning(
                f"{label.upper()} LCA for {sample}: none of {len(side_df)} qseqids "
                f"matched a contig (e.g. {side_df['qseqid'].iloc[0]!r})"
            )
        merged["sample"] = sample
        merged["nt"] = (label == "nt")
        merged["nr"] = (label == "nr")
        merged["nt_or_nr"] = label
        return merged

    nt_m = merge_lca(nt_df, "nt")
    nr_m = merge_lca(nr_df, "nr")

    if nt_m is None and nr_m is None:
        logging.warning(f"No LCA files found for {sample}")
        continue

    combined = pd.concat([x for x in (nt_m, nr_m) if x is not None], ignore_index=True)

    # Columns harmonization (rename ignores columns that are absent)
    rename_map = {
        "lca_taxid": "taxid",
        "alignment_length": "align_length",
    }
    combined = combined.rename(columns=rename_map)

    # taxon group via taxid
    if "taxid" in combined.columns:
        tqdm.pandas(desc=f"Tax grouping {sample}")
        combined["taxon_group"] = combined["taxid"].progress_apply(_tax_group_from_taxid)
    else:
        combined["taxon_group"] = "Ambiguous"

    # Keep a clean set of columns if present
    desired = [
        "sample",
        "contig_name",
        "contig_length",
        "read_count",
        "nt",
        "nr",
        "nt_or_nr",
        "taxid",
        "best_hit_sciname",
        "bitscore",
        "align_length",
        "identity",
        "pident",
        "mismatches",
        "evalue",
        "taxon_group",
        "blast_type",
    ]
    cols = [c for c in desired if c in combined.columns]
    combined = combined[cols].copy()

    # Flag contigs that appear in both the NT and NR results for this sample
    nt_contigs = combined.loc[combined["nt_or_nr"] == "nt", "contig_name"]
    nr_contigs = combined.loc[combined["nt_or_nr"] == "nr", "contig_name"]
    combined["common_nt_nr"] = combined["contig_name"].isin(nt_contigs) & combined["contig_name"].isin(nr_contigs)

    complete_summary.append(combined)

# ====== Write outputs ======
# Per-contig read stats for every sample, whether or not it has BLAST hits.
contig_stats_all_df = (
    pd.concat(contig_stats_all, ignore_index=True)
    .sort_values(["sample", "contig_name"])
)
contig_stats_all_df.to_csv(out_contig_stats_all, sep="\t", index=False)
logging.info(f"Wrote contig stats: {out_contig_stats_all}")

# Contigs joined with their NT and/or NR LCA rows.
if not complete_summary:
    logging.warning("No LCA merges produced output.")
else:
    complete_df = (
        pd.concat(complete_summary, ignore_index=True)
        .sort_values(["sample", "contig_name", "nt_or_nr"])
    )
    complete_df.to_csv(out_complete_summary, sep="\t", index=False)
    logging.info(f"Wrote LCA summary: {out_complete_summary}")

Tax grouping 1_01_23_0574_S41: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 533/533 [00:00<00:00, 17959.37it/s]
Tax grouping 1_01_23_0590_S42: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 596/596 [00:00<00:00, 24843.77it/s]
Tax grouping 1_01_24_0050_S45: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [None]:
Modify this code to handle the data described below:


import pandas as pd
import os
import json
import re
from ete3 import NCBITaxa
from tqdm import tqdm  # For progress bars
import logging

# ===========================
# Configuration
# ===========================

# Input directories: per-sample contig stats and LCA result tables.
json_dir, lca_dir = (
    "/Volumes/aine_store/Blast_nr/contigs_reads/",
    "/Volumes/aine_store/SENZOR_project/SPlited_SENSOR_porject/IDseq/Test_blast_SENZOR/LCA_results/",
)

# File Lists
json_files = [
    The input files are now TSV files with this layout:

contig_id	contig_len_bp	mapped_reads	unmapped_mate_reads
1_01_23_0574_S41~k141_19268	308	4	2
1_01_23_0574_S41~k141_0	309	20	2
1_01_23_0574_S41~k141_14451	324	4	0
1_01_23_0574_S41~k141_28900	285	4	0
1_01_23_0574_S41~k141_24084	282	6	0
1_01_23_0574_S41~k141_9634	366	5	1
1_01_23_0574_S41~k141_33716	295	2	0
1_01_23_0574_S41~k141_4817	494	18	0

]

1_01_23_0574_S41_contig_idxstats_with_header.tsv
1_01_23_0590_S42_contig_idxstats_with_header.tsv
1_01_24_0050_S45_contig_idxstats_with_header.tsv
1_01_24_0188_S38_contig_idxstats_with_header.tsv
1_01_24_0191_S43_contig_idxstats_with_header.tsv
1_01_24_0192_S44_contig_idxstats_with_header.tsv
1_01_24_0233_S47_contig_idxstats_with_header.tsv
1_01_24_0235_S48_contig_idxstats_with_header.tsv
1_01_24_0236_S49_contig_idxstats_with_header.tsv

/Volumes/aine_store/Blast_nr/LCA_virublast_nr (NR LCA files):
LCA_1_01_23_0574_S41_virus_hits.m9
LCA_1_01_23_0590_S42_virus_hits.m9
LCA_1_01_24_0050_S45_virus_hits.m9
LCA_1_01_24_0188_S38_virus_hits.m9
LCA_1_01_24_0191_S43_virus_hits.m9
LCA_1_01_24_0192_S44_virus_hits.m9

Snippet of the NR BLAST LCA files:

qseqid	lca_taxid	best_hit_sciname	alignment_length	identity	pident	mismatches	bitscore	evalue	com_names	aligned_bases	blast_type
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10155	2832643	Caudoviricetes sp.	51	96.1		2	106.0	9.3e-27			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10216	38018	Bacteriophage sp.	51	62.7		19	57.4	9.53e-08			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10476	1907787	Picobirnavirus sp.	45	57.8		17	51.6	6.85e-06			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10554	2832643	Caudoviricetes sp.	230	63.0		85	320.0	9.88e-102			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10564	Unknown	Aspergillus nidulans partitivirus 1	116	56.0		51	137.0	1.13e-34			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10614	2832643	Caudoviricetes sp.	197	83.2		28	328.0	1.64e-111			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_1071	Unknown	Corynebacterium phage HS01	110	62.7		38	127.0	7.73e-30			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_10896	2832643	Caudoviricetes sp.	87	50.6		43	92.8	1.53e-19			nr
LCA_1_01_23_0574_S41_virus_hits.m9~k141_1098	2831613	Herelleviridae sp.	55	47.3		29	54.7	6.03e-06			nr

/Volumes/aine_store/Blast_nr/LCA_virusblas_nt (NT LCA files):
LCA_1_01_23_0574_S41_virus_contigs_blast_nt.m9
LCA_1_01_23_0590_S42_virus_contigs_blast_nt.m9
LCA_1_01_24_0050_S45_virus_contigs_blast_nt.m9
LCA_1_01_24_0188_S38_virus_contigs_blast_nt.m9

Snippet of the NT BLAST LCA results:
qseqid	lca_taxid	best_hit_sciname	alignment_length	identity	pident	mismatches	bitscore	evalue	com_names	aligned_bases	blast_type
LCA_1_01_23_0574_S41~k141_10155	Unknown	Escherichia coli	153		100.0	0	277.0	2.42e-69	Escherichia coli		nt
LCA_1_01_23_0574_S41~k141_10216	2754726	Corynebacterium haemomassiliense	237		100.0	0	428.0	8.83e-115	Corynebacterium haemomassiliense		nt
LCA_1_01_23_0574_S41~k141_10476	Unknown	Salmonella enterica	224		100.0	0	405.0	8.19e-108	Salmonella enterica		nt
LCA_1_01_23_0574_S41~k141_10554	Unknown	Burkholderia gladioli	565		68.673	173	207.0	9.459999999999999e-48	Burkholderia gladioli		nt
LCA_1_01_23_0574_S41~k141_10564	2806428	Aspergillus nidulans partitivirus 1	197		75.127	37	128.0	4.0799999999999995e-24	Aspergillus nidulans partitivirus 1		nt
LCA_1_01_23_0574_S41~k141_10614	Unknown	Caudoviricetes sp.	595		81.176	97	564.0	2.6300000000000002e-155	Caudoviricetes sp.		nt
LCA_1_01_23_0574_S41~k141_1071	Unknown	Corynebacterium lizhenjunii	270		74.444	69	177.0	1.1199999999999999e-38	Corynebacterium lizhenjunii		nt
LCA_1_01_23_0574_S41~k141_11053	2832643	Caudoviricetes sp.	265		69.434	71	91.5	4.53e-13	Caudoviricetes sp.		nt
LCA_1_01_23_0574_S41~k141_11085	Unknown	Corynebacterium lujinxingii	386		81.347	72	372.0	7.03e-98	Corynebacterium lujinxingii		nt
LCA_1_01_23_0574_S41~k141_11271	Unknown	Bacillus anthracis	532		99.812	0	952.0	0.0	Bacillus anthracis		nt

# Output TSV file paths
contig_stats_all_tsv = os.path.join(json_dir, "contig_stats_all.tsv")
complete_summary_tsv = os.path.join(lca_dir, "contig_stats_lca.tsv")

# ===========================
# Logging Configuration
# ===========================

# force=True (Python 3.8+) replaces handlers installed earlier in the same
# interpreter (e.g. by another notebook cell); without it basicConfig is a
# silent no-op once the root logger already has a handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    handlers=[
        logging.StreamHandler()
    ],
    force=True,
)

# ===========================
# Function Definitions
# ===========================

def load_json_data(json_path):
    """
    Load a contig -> read-count JSON mapping and normalise contig names.

    Args:
        json_path (str): Path to the JSON file (an object of name -> count).

    Returns:
        dict: contig_name -> read_count, with "~" replaced by "_" in names
        and the "*" entry removed.
    """
    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Replace "~" with "_" in contig names and exclude any key named "*"
    return {key.replace("~", "_"): value for key, value in data.items() if key != "*"}

def load_lca_data(lca_path):
    """
    Load an LCA results table (tab-separated, first row is the header).

    Args:
        lca_path (str): Path to the LCA file.

    Returns:
        pd.DataFrame: LCA rows. Only 'Unknown' and empty cells are parsed as
        missing (NaN); pandas' other default NA tokens are kept as text.
        The fixed dtypes below are applied only to columns the file has.
    """
    dtype_dict = {
        "qseqid": str,
        "lca_taxid": str,
        "best_hit_sciname": str,
        "aligned_bases": float,
        "contig_length": int,
        "superkingdom": str,
        "bitscore": float,
        "tax_id": str,
        'identity': float,
        "subject_title": str
    }
    # Restrict dtypes to the columns actually present: NT and NR tables have
    # different layouts (e.g. no contig_length / subject_title here).
    header = pd.read_csv(lca_path, sep="\t", nrows=0).columns
    dtypes = {col: typ for col, typ in dtype_dict.items() if col in header}
    return pd.read_csv(
        lca_path,
        sep="\t",
        header=0,  # Use the first row as header
        dtype=dtypes,
        # With keep_default_na=False only these tokens mean "missing". The empty
        # string must be listed: NT rows leave 'identity' blank and both
        # layouts leave 'aligned_bases' blank, and a float dtype cannot parse "".
        na_values=['Unknown', ''],
        keep_default_na=False
    )

def extract_contig_length(contig_name):
    """
    Return the contig length embedded in a name as ``length_<digits>``.

    Args:
        contig_name (str): Contig name, e.g. ``NODE_1_length_308_cov_2.5``.

    Returns:
        int or pd.NA: The first ``length_<digits>`` number, or ``pd.NA`` when
        the name has no such token.
    """
    found = re.search(r"length_(\d+)", contig_name)
    return int(found.group(1)) if found else pd.NA

def get_tax_group(taxid, ncbi, taxon_groups, taxon_groups_id):
    """
    Determine the taxon group for a given taxid.

    Groups are tested in list order; the first group whose taxid occurs in
    the lineage is returned.

    Args:
        taxid (int or str): NCBI taxid. None, NaN, '' and 'Unknown'/'NA' are
            treated as "no taxid" and classified without logging.
        ncbi (NCBITaxa): Initialized NCBITaxa object.
        taxon_groups (list): List of taxon group names.
        taxon_groups_id (list): Taxids parallel to taxon_groups; None entries
            (groups not found in the database) are skipped.

    Returns:
        str: Taxon group name or "Ambiguous" if not found.
    """
    # Missing LCAs are routine ('Unknown' is read as NaN by load_lca_data),
    # so classify them quietly instead of logging an error for every row.
    if taxid is None:
        return "Ambiguous"
    if isinstance(taxid, str):
        if taxid.strip() in ("", "Unknown", "NA"):
            return "Ambiguous"
    elif pd.isna(taxid):
        return "Ambiguous"

    try:
        lineage = ncbi.get_lineage(int(taxid))
    except Exception as e:
        logging.error(f"Error processing taxid {taxid}: {e}")
        return "Ambiguous"

    for name, group_id in zip(taxon_groups, taxon_groups_id):
        if group_id is not None and lineage and group_id in lineage:
            return name
    return "Ambiguous"

# ===========================
# Initialize NCBITaxa
# ===========================

logging.info("Initializing NCBITaxa...")
# Opens the local NCBI taxonomy database used by get_tax_group.
ncbi = NCBITaxa()

# Define taxon groups and retrieve their taxids.
# Order matters: get_tax_group returns the first group found in a lineage,
# so the narrower Metazoa is listed before the broader Eukaryota.
taxon_groups = ["Viruses", "Bacteria", "Archaea", "Metazoa", "Eukaryota"]
# Parallel to taxon_groups: taxon_groups_id[i] is the taxid of taxon_groups[i].
taxon_groups_id = []

# NOTE: group, name_translator and taxid remain module-level names after this loop.
for group in taxon_groups:
    name_translator = ncbi.get_name_translator([group])
    if group in name_translator:
        # get_name_translator maps each name to a list of taxids; take the first.
        taxid = name_translator[group][0]
        taxon_groups_id.append(taxid)
    else:
        logging.warning(f"Taxon group '{group}' not found in NCBITaxa database.")
        taxon_groups_id.append(None)  # Keep the lists aligned when a group is not found

# ===========================
# Initialize Data Containers
# ===========================

# List to hold DataFrames for contig_stats_all.tsv
contig_stats_all_data = []

# List to hold DataFrames for the complete LCA summary (complete_summary_tsv)
complete_summary_data = []

# Cache of taxid -> taxon_group, read and filled by get_tax_group_cached
taxid_cache = {}

def get_tax_group_cached(taxid):
    """
    Return the taxon group for `taxid`, memoised in the module-level `taxid_cache`.

    On a cache miss the group is computed with get_tax_group using the
    module-level `ncbi`, `taxon_groups` and `taxon_groups_id`, then stored.

    Args:
        taxid (int or str): NCBI taxid (used as-is as the cache key).

    Returns:
        str: Taxon group name or "Ambiguous" if not found.
    """
    try:
        return taxid_cache[taxid]
    except KeyError:
        pass

    resolved_group = get_tax_group(taxid, ncbi, taxon_groups, taxon_groups_id)
    taxid_cache[taxid] = resolved_group
    return resolved_group

# ===========================
# Processing JSON and LCA Files
# ===========================

# NOTE(review): zip() pairs the three file lists by position only and silently
# stops at the shortest list, so json_files, nt_lca_files and nr_lca_files must
# list the same samples in the same order -- confirm where they are built.
for json_file, nt_lca_file, nr_lca_file in zip(json_files, nt_lca_files, nr_lca_files):
    sample_name = json_file.replace("_contig_stats.json", "")
    logging.info(f"\nProcessing sample: {sample_name}")
    
    # ---------------------------
    # Process JSON File
    # ---------------------------
    
    # Full path to the JSON file
    json_path = os.path.join(json_dir, json_file)
    
    # Check if JSON file exists
    if not os.path.exists(json_path):
        logging.error(f"JSON file not found: {json_path}")
        continue
    
    # Load and process JSON data
    json_data = load_json_data(json_path)
    num_contigs = len(json_data)
    logging.info(f"Loaded {num_contigs} contigs from {json_file}")
    
    # Convert to DataFrame for contig_stats_all.tsv: one row per JSON key
    # (contig name), with its value stored in 'read_count'
    contig_stats = pd.DataFrame.from_dict(json_data, orient='index', columns=['read_count']).reset_index()
    contig_stats.rename(columns={'index': 'contig_name'}, inplace=True)
    
    # Extract contig_length
    contig_stats['contig_length'] = contig_stats['contig_name'].apply(extract_contig_length)
    
    # Handle contig_length extraction failures
    missing_lengths = contig_stats['contig_length'].isna().sum()
    if missing_lengths > 0:
        logging.warning(f"{missing_lengths} contigs in {json_file} did not have a 'length_' pattern.")
    
    # Add sample name
    contig_stats['sample'] = sample_name
    
    # Reorder columns for contig_stats_all.tsv
    contig_stats = contig_stats[['sample', 'contig_name', 'contig_length', 'read_count']]
    
    # Append to contig_stats_all_data list. This happens before the LCA files
    # are checked, so a sample with missing LCA files is still reported here.
    contig_stats_all_data.append(contig_stats)
    
    # ---------------------------
    # Process LCA Files
    # ---------------------------
    
    # Full paths to LCA files
    nt_lca_path = os.path.join(lca_dir, nt_lca_file)
    nr_lca_path = os.path.join(lca_dir, nr_lca_file)
    
    # Check if LCA files exist
    if not os.path.exists(nt_lca_path):
        logging.error(f"NT LCA file not found: {nt_lca_path}")
        continue
    if not os.path.exists(nr_lca_path):
        logging.error(f"NR LCA file not found: {nr_lca_path}")
        continue
    
    # Load LCA data
    nt_lca = load_lca_data(nt_lca_path)
    nr_lca = load_lca_data(nr_lca_path)
    logging.info(f"Loaded NT and NR LCA data for {sample_name}")
    
    # Standardize contig names in LCA files by replacing "~" with "_"
    # (literal replacement; regex=False makes that explicit)
    nt_lca["qseqid"] = nt_lca["qseqid"].str.replace("~", "_", regex=False)
    nr_lca["qseqid"] = nr_lca["qseqid"].str.replace("~", "_", regex=False)
    
    # Check for matching contigs
    contig_names = set(contig_stats["contig_name"])
    nt_queries = set(nt_lca["qseqid"])
    nr_queries = set(nr_lca["qseqid"])
    common_nt = contig_names.intersection(nt_queries)
    common_nr = contig_names.intersection(nr_queries)
    logging.info(f"Common contigs in NT LCA for {sample_name}: {len(common_nt)}")
    logging.info(f"Common contigs in NR LCA for {sample_name}: {len(common_nr)}")
    
    # Identify contigs common to both NT and NR
    common_contigs = common_nt.intersection(common_nr)
    logging.info(f"Contigs common to both NT and NR for {sample_name}: {len(common_contigs)}")
    
    # Merge NT LCA data with contig stats
    nt_merged = pd.merge(
        contig_stats, 
        nt_lca, 
        left_on="contig_name", 
        right_on="qseqid", 
        how="inner",
        suffixes=('_stats', '_nt')
    )
    nt_merged["nt"] = True
    nt_merged["nr"] = False
    nt_merged["nt_or_nr"] = "nt"
    logging.info(f"NT merged rows for {sample_name}: {len(nt_merged)}")
    
    # Merge NR LCA data with contig stats
    nr_merged = pd.merge(
        contig_stats, 
        nr_lca, 
        left_on="contig_name", 
        right_on="qseqid", 
        how="inner",
        suffixes=('_stats', '_nr')
    )
    nr_merged["nt"] = False
    nr_merged["nr"] = True
    nr_merged["nt_or_nr"] = "nr"
    logging.info(f"NR merged rows for {sample_name}: {len(nr_merged)}")
    
    # Combine NT and NR merged DataFrames (one row per contig per database hit)
    combined = pd.concat([nt_merged, nr_merged], ignore_index=True)
    
    # Add sample name (already present in contig_stats, but ensure it's included)
    combined["sample"] = sample_name
    
    # Add 'common_nt_nr' column
    # This column will be True if the contig is present in both NT and NR, else False
    combined["common_nt_nr"] = combined["contig_name"].isin(common_contigs)
    logging.info(f"Added 'common_nt_nr' column for {sample_name}")
    
    # Inspect the columns to handle duplicates
    logging.info(f"Columns before handling duplicates: {combined.columns.tolist()}")
    
    # Handle duplicate 'contig_length' columns
    # If the LCA tables also carry 'contig_length', the merges suffix the
    # overlapping columns ('_stats', '_nt', '_nr'). Keep one as 'contig_length'
    # (preferring the contig_stats value) and drop the others.
    if 'contig_length_stats' in combined.columns:
        combined.rename(columns={'contig_length_stats': 'contig_length'}, inplace=True)
        # Drop the other contig_length columns
        combined.drop(['contig_length_nt', 'contig_length_nr'], axis=1, inplace=True, errors='ignore')
    elif 'contig_length_nt' in combined.columns:
        combined.rename(columns={'contig_length_nt': 'contig_length'}, inplace=True)
        combined.drop(['contig_length_stats', 'contig_length_nr'], axis=1, inplace=True, errors='ignore')
    elif 'contig_length_nr' in combined.columns:
        combined.rename(columns={'contig_length_nr': 'contig_length'}, inplace=True)
        combined.drop(['contig_length_stats', 'contig_length_nt'], axis=1, inplace=True, errors='ignore')
    
    # Define selected_columns with the updated 'contig_length'
    selected_columns = [
        "sample", 
        "contig_name", 
        "contig_length",  # Updated column
        "read_count", 
        "nt", 
        "nr", 
        "nt_or_nr",
        "lca_taxid", 
        "bitscore", 
        "aligned_bases", 
        "superkingdom", 
        "tax_id",
        'identity',
        "subject_title",
        "common_nt_nr"  # Include the new column
    ]
    
    # Select only columns that are present
    available_columns = combined.columns.tolist()
    selected_columns = [col for col in selected_columns if col in available_columns]
    
    # Now, select the relevant columns. .copy() gives an independent frame so
    # the in-place rename and the 'taxon_group' assignment below do not raise
    # SettingWithCopyWarning on a slice of the previous frame.
    combined = combined[selected_columns].copy()
    
    # Rename columns for clarity
    combined.rename(columns={
        "lca_taxid": "taxid",
        "aligned_bases": "align_length",
        "superkingdom": "taxon_group"
    }, inplace=True)
    
    # Without a taxid column there is nothing to classify. Skip this sample
    # instead of aborting the whole run with a KeyError.
    if "taxid" not in combined.columns:
        logging.error(f"No 'lca_taxid' column in LCA data for {sample_name}; skipping sample.")
        continue
    
    # ===========================
    # Update taxon_group Using taxid
    # ===========================
    
    logging.info("\nUpdating 'taxon_group' based on 'taxid' using ETE3's NCBITaxa...")
    
    # Apply the cached get_tax_group function with progress bar
    tqdm.pandas(desc="Updating taxon_group")
    combined['taxon_group'] = combined['taxid'].progress_apply(get_tax_group_cached)
    
    logging.info("Updated 'taxon_group' successfully.")
    
    # ---------------------------
    # Reorder columns after updating taxon_group
    # ---------------------------
    
    # Reorder columns if desired
    # Ensure that 'common_nt_nr' is at the end
    desired_order = [
        "sample", 
        "contig_name", 
        "contig_length", 
        "read_count", 
        "nt", 
        "nr", 
        "nt_or_nr",
        "taxid", 
        "bitscore", 
        "align_length", 
        "taxon_group", 
        "tax_id",
        'identity',
        "subject_title",
        "common_nt_nr"  # Ensure the new column is included
    ]
    
    # Adjust desired_order based on available columns
    final_columns = [col for col in desired_order if col in combined.columns]
    combined = combined[final_columns]
    
    logging.info(f"Final selected columns after updating 'taxon_group': {combined.columns.tolist()}")
    
    logging.info(f"Total merged rows for {sample_name}: {len(combined)}")
    
    # Append to complete_summary_data list
    complete_summary_data.append(combined)

# ===========================
# Combine All Data and Save Outputs
# ===========================

# ---------------------------
# Save contig_stats_all.tsv
# ---------------------------

# pd.concat raises ValueError on an empty list, which happens when no sample
# JSON file was processed. Guard it the same way the complete-summary output
# is guarded. Still bind contig_stats_all_df (empty, with the expected
# columns) so the preview code further down keeps working.
if contig_stats_all_data:
    # Concatenate all contig_stats DataFrames
    contig_stats_all_df = pd.concat(contig_stats_all_data, ignore_index=True)

    # Sort for readability, then renumber the index after sorting
    contig_stats_all_df.sort_values(by=['sample', 'contig_name'], inplace=True)
    contig_stats_all_df.reset_index(drop=True, inplace=True)

    # Save to TSV
    contig_stats_all_df.to_csv(contig_stats_all_tsv, sep='\t', index=False)
    logging.info(f"\nSuccessfully saved combined contig stats to {contig_stats_all_tsv}")
else:
    contig_stats_all_df = pd.DataFrame(columns=['sample', 'contig_name', 'contig_length', 'read_count'])
    logging.warning("No data available to save for contig_stats_all.tsv.")

# ---------------------------
# Save Complete_final_summary_nt_nr.tsv
# ---------------------------

# Only write the summary when at least one sample produced merged LCA rows;
# pd.concat cannot be called on an empty list.
if not complete_summary_data:
    logging.warning("No data available to save for Complete_final_summary_nt_nr.tsv.")
else:
    # Concatenate every sample, sort by sample/contig for readability and
    # renumber the index to follow the sorted order.
    complete_summary_df = (
        pd.concat(complete_summary_data, ignore_index=True)
        .sort_values(by=['sample', 'contig_name'])
        .reset_index(drop=True)
    )

    complete_summary_df.to_csv(complete_summary_tsv, sep='\t', index=False)
    logging.info(f"Successfully saved complete summary to {complete_summary_tsv}")

# ===========================
# Display Previews
# ===========================

# First rows of the combined per-contig statistics
logging.info("\nPreview of contig_stats_all.tsv:")
print(contig_stats_all_df.head())

# First rows of the merged NT/NR summary; complete_summary_df only exists
# when at least one sample produced merged rows.
if not complete_summary_data:
    logging.warning("No data available to preview for Complete_final_summary_nt_nr.tsv.")
else:
    logging.info("\nPreview of Complete_final_summary_nt_nr.tsv:")
    print(complete_summary_df.head())