In [1]:
import numpy as np, tensorflow as tf
print("NumPy:", np.__version__)
print("TensorFlow:", tf.__version__)


NumPy: 1.26.4
TensorFlow: 2.20.0


In [2]:
# Install all detected packages (CPU/default wheels) from PyPI
%pip install biopython gensim hdbscan matplotlib numpy pandas scikit-learn tensorflow torch ipykernel


Note: you may need to restart the kernel to use updated packages.


In [3]:
# Cell 1: imports, constants, reproducibility
import os
import sys
import json
import tarfile
import hashlib
import random
from pathlib import Path
from collections import Counter, defaultdict

# Basic ML / bio libs
import numpy as np
import pandas as pd

# PyTorch + sklearn
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# NLP / k-mer embeddings
import gensim

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths and chosen DB files (from NCBI ftp listing)
DOWNLOAD_DIR = Path("./ncbi_blast_db")
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Filenames observed on the BLAST DB FTP listing (we will download these exact files).
DB_FILES = {
    "ssu": "SSU_eukaryote_rRNA.tar.gz",
    "lsu": "LSU_eukaryote_rRNA.tar.gz",
    "its": "ITS_eukaryote_sequences.tar.gz",
    "ssu_meta": "SSU_eukaryote_rRNA-nucl-metadata.json",
    "lsu_meta": "LSU_eukaryote_rRNA-nucl-metadata.json",
    "its_meta": "ITS_eukaryote_sequences-nucl-metadata.json",
}

print("Working dir:", DOWNLOAD_DIR.resolve())
print("DB files to fetch:", DB_FILES)


Working dir: C:\Users\Srijit\sih\ncbi_blast_db
DB files to fetch: {'ssu': 'SSU_eukaryote_rRNA.tar.gz', 'lsu': 'LSU_eukaryote_rRNA.tar.gz', 'its': 'ITS_eukaryote_sequences.tar.gz', 'ssu_meta': 'SSU_eukaryote_rRNA-nucl-metadata.json', 'lsu_meta': 'LSU_eukaryote_rRNA-nucl-metadata.json', 'its_meta': 'ITS_eukaryote_sequences-nucl-metadata.json'}


In [4]:
# Cell 2 (fixed): Download the selected tarballs and metadata JSONs using Python (works in a Python kernel)
import urllib.request
import shutil
from pathlib import Path
import time

# Use variables from Cell 1: DOWNLOAD_DIR (Path) and DB_FILES (dict)
# If you re-ran the kernel and haven't executed Cell 1, re-declare DOWNLOAD_DIR and DB_FILES accordingly.
try:
    DOWNLOAD_DIR
    DB_FILES
except NameError:
    DOWNLOAD_DIR = Path("./ncbi_blast_db")
    DB_FILES = {
        "ssu": "SSU_eukaryote_rRNA.tar.gz",
        "lsu": "LSU_eukaryote_rRNA.tar.gz",
        "its": "ITS_eukaryote_sequences.tar.gz",
        "ssu_meta": "SSU_eukaryote_rRNA-nucl-metadata.json",
        "lsu_meta": "LSU_eukaryote_rRNA-nucl-metadata.json",
        "its_meta": "ITS_eukaryote_sequences-nucl-metadata.json",
    }
DOWNLOAD_DIR = Path(DOWNLOAD_DIR)
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

BASE_URL = "https://ftp.ncbi.nlm.nih.gov/blast/db"
files_to_fetch = [
    DB_FILES["ssu"],
    DB_FILES["ssu_meta"],
    DB_FILES["lsu"],
    DB_FILES["lsu_meta"],
    DB_FILES["its"],
    DB_FILES["its_meta"],
]

def download_url_to_path(url: str, out_path: Path, chunk_size=1024*1024):
    """Download URL to out_path with chunking and basic retries."""
    out_path_tmp = out_path.with_suffix(out_path.suffix + ".part")
    headers = {"User-Agent": "python-urllib/3 - ncbi-download-script"}
    req = urllib.request.Request(url, headers=headers)
    attempts = 3
    for attempt in range(1, attempts+1):
        try:
            with urllib.request.urlopen(req, timeout=60) as resp, open(out_path_tmp, "wb") as outfh:
                shutil.copyfileobj(resp, outfh)
            out_path_tmp.replace(out_path)
            return True
        except Exception as e:
            print(f"Attempt {attempt} failed for {url}: {e}")
            time.sleep(2 * attempt)
    return False

for fname in files_to_fetch:
    url = f"{BASE_URL}/{fname}"
    out_path = DOWNLOAD_DIR / fname
    if out_path.exists():
        print(f"[SKIP] Already present: {fname}  ({out_path.stat().st_size / (1024**2):.2f} MB)")
        continue
    print(f"[DOWNLOAD] {fname} from {url}")
    ok = download_url_to_path(url, out_path)
    if not ok:
        print(f"[ERROR] Failed to download {fname}. Check network or try again manually.")
    else:
        size_mb = out_path.stat().st_size / (1024**2)
        print(f"[OK] Saved {fname} — {size_mb:.2f} MB")

# Final listing
print("\nFiles in download directory:")
for p in sorted(DOWNLOAD_DIR.iterdir()):
    if p.is_file():
        print(f" - {p.name}  {p.stat().st_size / (1024**2):.2f} MB")

[SKIP] Already present: SSU_eukaryote_rRNA.tar.gz  (57.01 MB)
[SKIP] Already present: SSU_eukaryote_rRNA-nucl-metadata.json  (0.00 MB)
[SKIP] Already present: LSU_eukaryote_rRNA.tar.gz  (56.64 MB)
[SKIP] Already present: LSU_eukaryote_rRNA-nucl-metadata.json  (0.00 MB)
[SKIP] Already present: ITS_eukaryote_sequences.tar.gz  (71.03 MB)
[SKIP] Already present: ITS_eukaryote_sequences-nucl-metadata.json  (0.00 MB)

Files in download directory:
 - ITS_eukaryote_sequences-nucl-metadata.json  0.00 MB
 - ITS_eukaryote_sequences.tar.gz  71.03 MB
 - kmer_w2v_k6.model  4.90 MB
 - LSU_eukaryote_rRNA-nucl-metadata.json  0.00 MB
 - LSU_eukaryote_rRNA.tar.gz  56.64 MB
 - SSU_eukaryote_rRNA-nucl-metadata.json  0.00 MB
 - SSU_eukaryote_rRNA.tar.gz  57.01 MB


In [5]:
# Cell 3: Extract FASTA-like content from downloaded tarballs into an 'extracted' directory.
# This cell is defensive: it re-creates DOWNLOAD_DIR if missing and checks tarball presence.
import io
import tarfile
from pathlib import Path

# Re-define or sanity-check DOWNLOAD_DIR and DB_FILES if they are missing
try:
    DOWNLOAD_DIR
except NameError:
    DOWNLOAD_DIR = Path("./ncbi_blast_db")
DOWNLOAD_DIR = Path(DOWNLOAD_DIR)
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

# Tarball names from earlier (safe fallback)
try:
    DB_FILES
except NameError:
    DB_FILES = {
        "ssu": "SSU_eukaryote_rRNA.tar.gz",
        "lsu": "LSU_eukaryote_rRNA.tar.gz",
        "its": "ITS_eukaryote_sequences.tar.gz",
        "ssu_meta": "SSU_eukaryote_rRNA-nucl-metadata.json",
        "lsu_meta": "LSU_eukaryote_rRNA-nucl-metadata.json",
        "its_meta": "ITS_eukaryote_sequences-nucl-metadata.json",
    }

# Limits (safe default)
MAX_RECORDS_PER_MARKER = 5000  # change to None to extract everything

def extract_and_concatenate_fastas(tarball_path: Path, out_fasta: Path, max_records=None):
    """
    Extract any member file in tarball that appears to contain FASTA (detect '>' lines).
    Concatenate only sequence entries found into out_fasta.
    """
    count_sequences = 0
    if not tarball_path.exists():
        print(f"[WARN] Tarball not found: {tarball_path}")
        return False, 0
    with tarfile.open(tarball_path, "r:gz") as tar, open(out_fasta, "w") as outfh:
        for member in tar:
            if not member.isfile():
                continue
            # Read member in streaming mode
            try:
                f = tar.extractfile(member)
                if f is None:
                    continue
                # We'll scan for lines starting with '>' and write contiguous blocks
                writing = False
                seq_lines_written = 0
                for raw in f:
                    try:
                        line = raw.decode('utf-8', errors='replace')
                    except Exception:
                        line = raw.decode('latin-1', errors='replace')
                    if not line:
                        continue
                    if line.startswith(">"):
                        # New sequence header
                        if max_records and count_sequences >= max_records:
                            f.close()
                            print(f"[INFO] Reached max_records ({max_records}) for {tarball_path.name}")
                            return True, count_sequences
                        outfh.write(line)
                        writing = True
                        count_sequences += 1
                        seq_lines_written = 0
                    else:
                        if writing:
                            outfh.write(line)
                            seq_lines_written += 1
                f.close()
            except Exception as e:
                # skip unreadable members
                print(f"[DEBUG] skipped member {member.name} due to {e}")
                continue
    return True, count_sequences

# Run extraction for each marker
combined_fastas = {}
for marker in ("ssu", "lsu", "its"):
    tar_path = DOWNLOAD_DIR / DB_FILES[marker]
    out_fasta = EXTRACT_DIR / f"{marker}_combined.fasta"
    print(f"[EXTRACT] {marker} from {tar_path} -> {out_fasta}")
    ok, nseq = extract_and_concatenate_fastas(tar_path, out_fasta, max_records=MAX_RECORDS_PER_MARKER)
    if ok:
        print(f"  -> wrote approx {nseq} sequences to {out_fasta}")
        combined_fastas[marker] = out_fasta
    else:
        print(f"  -> extraction failed for {marker}")
        combined_fastas[marker] = None

# Quick summary: counts by marker
for k, p in combined_fastas.items():
    if p and p.exists():
        with open(p) as fh:
            seq_count = sum(1 for line in fh if line.startswith(">"))
        print(f"{k} combined FASTA: {p} (sequences: {seq_count})")
    else:
        print(f"{k} combined FASTA: not available")

[EXTRACT] ssu from ncbi_blast_db\SSU_eukaryote_rRNA.tar.gz -> ncbi_blast_db\extracted\ssu_combined.fasta
[DEBUG] skipped member taxdb.bti due to 'charmap' codec can't encode character '\ufffd' in position 4: character maps to <undefined>
[DEBUG] skipped member taxonomy4blast.sqlite3 due to 'charmap' codec can't encode character '\ufffd' in position 3: character maps to <undefined>
[DEBUG] skipped member SSU_eukaryote_rRNA.nin due to 'charmap' codec can't encode character '\ufffd' in position 1: character maps to <undefined>
[DEBUG] skipped member SSU_eukaryote_rRNA.nsq due to 'charmap' codec can't encode characters in position 5-9: character maps to <undefined>
  -> wrote approx 2 sequences to ncbi_blast_db\extracted\ssu_combined.fasta
[EXTRACT] lsu from ncbi_blast_db\LSU_eukaryote_rRNA.tar.gz -> ncbi_blast_db\extracted\lsu_combined.fasta
[DEBUG] skipped member taxdb.bti due to 'charmap' codec can't encode character '\ufffd' in position 4: character maps to <undefined>
[DEBUG] skipped 

In [6]:
# Cell 4: Load NCBI-provided metadata JSONs and construct a flexible DataFrame mapping accession -> taxonomy info
import json
import pandas as pd

META_DIR = DOWNLOAD_DIR
meta_paths = {
    "ssu": META_DIR / DB_FILES["ssu_meta"],
    "lsu": META_DIR / DB_FILES["lsu_meta"],
    "its": META_DIR / DB_FILES["its_meta"]
}

def load_json_if_exists(p: Path):
    if p.exists():
        try:
            with open(p, "r") as fh:
                return json.load(fh)
        except Exception as e:
            print(f"[ERROR] failed to read {p}: {e}")
            return None
    else:
        print(f"[WARN] metadata file not found: {p}")
        return None

# Load
metadata = {k: load_json_if_exists(path) for k, path in meta_paths.items()}

# The metadata format varies; we'll attempt to extract accession and any taxonomic lineage fields robustly.
def parse_metadata_list(meta_list):
    """Return a list of dicts with keys: accession, organism, tax_lineage (list/str), taxid, extra."""
    out = []
    if not meta_list:
        return out
    # meta_list could be dict or list depending on file - handle both
    if isinstance(meta_list, dict):
        # sometimes these JSONs are dict-of-accession->metadata
        iterable = meta_list.items()
    else:
        iterable = enumerate(meta_list)
    for key, entry in iterable:
        e = {}
        # entry may be dict with various structures
        if isinstance(entry, dict):
            # common fields (try many)
            e['accession'] = entry.get('accession_version') or entry.get('accession') or entry.get('seqid') or entry.get('id') or key
            e['organism'] = entry.get('organism') or entry.get('scientific_name') or entry.get('species') or entry.get('organism_name')
            e['taxid'] = entry.get('taxid') or entry.get('tax_id') or entry.get('taxonomic_id')
            # lineage may be a string or list
            lineage = entry.get('lineage') or entry.get('taxonomy') or entry.get('taxonomic_lineage') or entry.get('taxonomic_lineage_names')
            e['lineage_raw'] = lineage
            # if lineage is a string, attempt split by ';' or '|'
            if isinstance(lineage, str):
                if ';' in lineage:
                    e['lineage'] = [x.strip() for x in lineage.split(';') if x.strip()]
                elif '|' in lineage:
                    e['lineage'] = [x.strip() for x in lineage.split('|') if x.strip()]
                else:
                    e['lineage'] = [lineage.strip()]
            elif isinstance(lineage, list):
                e['lineage'] = lineage
            else:
                e['lineage'] = None
            # store raw entry for provenance
            e['raw'] = entry
        else:
            e['accession'] = key
            e['organism'] = None
            e['taxid'] = None
            e['lineage'] = None
            e['raw'] = entry
        out.append(e)
    return out

# Parse each metadata
parsed_meta = {}
for k, meta in metadata.items():
    parsed = parse_metadata_list(meta)
    parsed_meta[k] = pd.DataFrame(parsed)
    print(f"[META] {k}: parsed {len(parsed)} entries -> DataFrame columns: {parsed_meta[k].columns.tolist()}")

# Example preview for one (if available)
for k in parsed_meta:
    if len(parsed_meta[k])>0:
        display(parsed_meta[k].head(2))
        break

[META] ssu: parsed 12 entries -> DataFrame columns: ['accession', 'organism', 'taxid', 'lineage', 'raw']
[META] lsu: parsed 12 entries -> DataFrame columns: ['accession', 'organism', 'taxid', 'lineage', 'raw']
[META] its: parsed 12 entries -> DataFrame columns: ['accession', 'organism', 'taxid', 'lineage', 'raw']


Unnamed: 0,accession,organism,taxid,lineage,raw
0,dbname,,,,SSU_eukaryote_rRNA
1,version,,,,1.1


In [7]:
# Cell A — Diagnostic
from pathlib import Path
import os, sys, textwrap

# 1) Check whether DOWNLOAD_DIR exists in the current notebook state
try:
    download_dir = DOWNLOAD_DIR  # use your notebook's DOWNLOAD_DIR variable
    print("DOWNLOAD_DIR found from notebook variable.")
except NameError:
    download_dir = None
    print("DOWNLOAD_DIR is NOT defined in the current notebook namespace.")

# If not found, try common fallback locations (only for diagnostic; won't overwrite)
if download_dir is None:
    print("\nAttempting to auto-detect possible download directories (diagnostic only):")
    candidates = []
    cwd = Path.cwd()
    candidates.extend([cwd, Path("/mnt/data")])
    for c in candidates:
        if c.exists():
            print("  candidate:", c)
    # do not set DOWNLOAD_DIR automatically; just show candidates
else:
    download_dir = Path(download_dir)
    print("\nDOWNLOAD_DIR:", download_dir)
    print("Exists:", download_dir.exists())
    # show 'extracted' subdir
    extracted = download_dir / "extracted"
    print("\n'extracted' folder:", extracted)
    print("exists:", extracted.exists())
    if extracted.exists():
        print("\nFiles inside extracted/:")
        for p in sorted(extracted.iterdir()):
            print(" ", p.name, "(size:", p.stat().st_size, "bytes)")
    else:
        print("\nNo extracted/ folder found at that path.")

    # show whether the combined fasta files exist
    for m in ("ssu","lsu","its"):
        p = extracted / f"{m}_combined.fasta"
        print(f"  {p.name:25} exists: {p.exists()}  size:", p.stat().st_size if p.exists() else "N/A")


DOWNLOAD_DIR found from notebook variable.

DOWNLOAD_DIR: ncbi_blast_db
Exists: True

'extracted' folder: ncbi_blast_db\extracted
exists: True

Files inside extracted/:
  best_shared_heads.pt (size: 1014743 bytes)
  best_shared_heads_defensive.pt (size: 452813 bytes)
  best_shared_heads_defensive_state_dict.pt (size: 452693 bytes)
  best_shared_heads_labeled.pt (size: 1212295 bytes)
  best_shared_heads_pseudo_tensordataset_fix.pt (size: 1557749 bytes)
  best_shared_heads_resumed.pt (size: 1359303 bytes)
  best_shared_heads_resumed_state_dict.pt (size: 452645 bytes)
  best_shared_heads_retrain.pt (size: 1015687 bytes)
  calibration_bins_class.csv (size: 667 bytes)
  calibration_bins_family.csv (size: 667 bytes)
  calibration_bins_genus.csv (size: 668 bytes)
  calibration_bins_kingdom.csv (size: 665 bytes)
  calibration_bins_order.csv (size: 667 bytes)
  calibration_bins_phylum.csv (size: 667 bytes)
  calibration_bins_species.csv (size: 651 bytes)
  calibration_metrics_by_rank.csv (size:

In [8]:
# Cell B — Auto-combine FASTA sources into the expected combined files
from pathlib import Path
import gzip, shutil, sys, os

# Use the notebook's DOWNLOAD_DIR — abort if missing to avoid creating paths unexpectedly
try:
    download_dir = Path(DOWNLOAD_DIR)
except NameError:
    raise RuntimeError("DOWNLOAD_DIR is not defined in this notebook. Run the earlier cells that set DOWNLOAD_DIR, then re-run this cell.")

if not download_dir.exists():
    raise RuntimeError(f"DOWNLOAD_DIR path does not exist: {download_dir}")

extracted_dir = download_dir / "extracted"
extracted_dir.mkdir(parents=True, exist_ok=True)

# search for FASTA-like files recursively under DOWNLOAD_DIR
fasta_patterns = ("*.fasta","*.fa","*.fna","*.fasta.gz","*.fa.gz","*.fna.gz")
all_fasta_files = []
for pat in fasta_patterns:
    all_fasta_files.extend([p for p in download_dir.rglob(pat) if p.is_file()])

# remove any combined files that already exist in extracted (we will consider them as sources if needed)
# but prefer source files outside the combined names
all_fasta_files = [p for p in sorted(set(all_fasta_files)) if not p.name.endswith(("_combined.fasta","_combined.fa","_combined.fna"))]

print(f"Found {len(all_fasta_files)} candidate FASTA files under {download_dir}:")
for p in all_fasta_files:
    print(" ", p.relative_to(download_dir))

if len(all_fasta_files) == 0:
    print("\nNo FASTA files found under DOWNLOAD_DIR. Nothing to combine. You must run the earlier extraction/download step or place FASTA files under DOWNLOAD_DIR.")
else:
    # group files by marker name if their filename contains the marker, else list them as 'unassigned'
    markers = ("ssu","lsu","its")
    grouped = {m: [] for m in markers}
    unassigned = []
    for p in all_fasta_files:
        name = p.name.lower()
        found = False
        for m in markers:
            if m in name:
                grouped[m].append(p)
                found = True
                break
        if not found:
            unassigned.append(p)

    # If some markers received no files but there are unassigned files, we will NOT forcefully assign them.
    # Instead, present the situation and combine only where we have explicit matches.
    for m in markers:
        files_for_m = grouped[m]
        if not files_for_m:
            print(f"\nNo source FASTA files detected for marker '{m}'. Will NOT create {m}_combined.fasta.")
            continue

        dest = extracted_dir / f"{m}_combined.fasta"
        print(f"\nCreating combined file for marker '{m}': {dest}")
        with open(dest, "wb") as outfh:
            total_seqs = 0
            for src in files_for_m:
                print("  adding", src.relative_to(download_dir))
                # open gzipped or plain
                if src.suffix == ".gz" or src.name.endswith(".fasta.gz") or src.name.endswith(".fa.gz") or src.name.endswith(".fna.gz"):
                    opener = gzip.open
                    mode = "rt"
                else:
                    opener = open
                    mode = "r"
                with opener(src, mode) as inf:
                    # write text as bytes
                    for line in inf:
                        # count headers as sequences (approx)
                        if line.startswith(">"):
                            total_seqs += 1
                        outfh.write(line.encode() if isinstance(line, str) else line)
            print(f"  written {total_seqs} sequence headers to {dest} (approx).")

    if unassigned:
        print("\nThe following FASTA files did NOT match any marker name (ssu/lsu/its).")
        for p in unassigned:
            print("  ", p.relative_to(download_dir))
        print("If those should be part of a marker, rename them to include 'ssu','lsu' or 'its' in the filename, or move them into a folder whose name contains the marker.")

print("\nDone. Re-run your Word2Vec training cell now (the cell that previously raised 'Empty k-mer corpus').")


Found 3 candidate FASTA files under ncbi_blast_db:
  extracted\its_fetched.fasta
  extracted\lsu_fetched.fasta
  extracted\ssu_fetched.fasta

Creating combined file for marker 'ssu': ncbi_blast_db\extracted\ssu_combined.fasta
  adding extracted\ssu_fetched.fasta
  written 468 sequence headers to ncbi_blast_db\extracted\ssu_combined.fasta (approx).

Creating combined file for marker 'lsu': ncbi_blast_db\extracted\lsu_combined.fasta
  adding extracted\lsu_fetched.fasta
  written 406 sequence headers to ncbi_blast_db\extracted\lsu_combined.fasta (approx).

Creating combined file for marker 'its': ncbi_blast_db\extracted\its_combined.fasta
  adding extracted\its_fetched.fasta
  written 699 sequence headers to ncbi_blast_db\extracted\its_combined.fasta (approx).

Done. Re-run your Word2Vec training cell now (the cell that previously raised 'Empty k-mer corpus').


In [9]:
# Cell C — Sanity check: list fasta paths and count sequences found by the fasta_kmer_generator
from pathlib import Path
import os

# reuse the exact same variables and paths as your original cell
download_dir = Path(DOWNLOAD_DIR)
extracted_dir = download_dir / "extracted"

fasta_paths = []
for m in ("ssu", "lsu", "its"):
    p = extracted_dir / f"{m}_combined.fasta"
    if p.exists():
        fasta_paths.append(p)
print("[CORPUS] fasta paths used:", fasta_paths)

# quick k-mer generator (same as your code)
K = 6
def seq_to_kmers(seq, k=K):
    s = seq.strip().upper()
    return [s[i:i+k] for i in range(len(s)-k+1) if 'N' not in s[i:i+k]]

def fasta_kmer_generator(fasta_paths, k=K):
    for p in fasta_paths:
        if not p or not Path(p).exists():
            continue
        with open(p) as fh:
            seq = ""
            for line in fh:
                if line.startswith(">"):
                    if seq:
                        yield seq_to_kmers(seq, k=k)
                    seq = ""
                else:
                    seq += line.strip()
            if seq:
                yield seq_to_kmers(seq, k=k)

# count sequences found (cap to show first 5)
corpus_preview = []
for i, kmers in enumerate(fasta_kmer_generator(fasta_paths, k=K)):
    if kmers:
        corpus_preview.append(kmers)
    if i >= 9999:
        break
print("[CORPUS] total sequences (kmers) discovered (preview cap 10000):", len(corpus_preview))
if len(corpus_preview):
    print("Example k-mers from first sequence:", corpus_preview[0][:10])
else:
    print("No sequences found. If this prints, the combined FASTA files are still missing or empty.")


[CORPUS] fasta paths used: [WindowsPath('ncbi_blast_db/extracted/ssu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/lsu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/its_combined.fasta')]
[CORPUS] total sequences (kmers) discovered (preview cap 10000): 1573
Example k-mers from first sequence: ['TTATAC', 'TATACC', 'ATACCG', 'TACCGT', 'ACCGTG', 'CCGTGA', 'CGTGAA', 'GTGAAA', 'TGAAAC', 'GAAACT']


In [10]:
# Cell 5 (robust streaming, avoids OOM): Build k-mer corpus from combined FASTAs and train Word2Vec embeddings (k=6 default).
# Installs gensim automatically if missing.
import subprocess
import sys
import os
import gc
from collections import deque

# Ensure gensim present
try:
    import gensim
    from gensim.models import Word2Vec
except Exception:
    print("[INSTALL] gensim not found; installing gensim...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "gensim==4.3.1"])
    import gensim
    from gensim.models import Word2Vec

# Parameters (kept identical to your original names)
K = 6
W2V_VECTOR_SIZE = 128
W2V_WINDOW = 4
W2V_MIN_COUNT = 2
W2V_EPOCHS = 8

# New small constant to limit per-yield memory (buffer size of kmers yielded at once).
# Lower this to reduce peak memory further (500 is very conservative); keep at 1000 by default.
CHUNK_SIZE = 1000

def fasta_kmer_generator(fasta_paths, k=K, chunk_size=CHUNK_SIZE):
    """
    Stream k-mers from fasta files without building entire sequences.
    Yields small lists (length <= chunk_size) of k-mers.
    """
    for p in fasta_paths:
        if not p or not p.exists():
            continue
        # open and stream file; maintain a sliding window per sequence
        with open(p, 'r', errors='replace') as fh:
            window = deque()  # holds up to k characters
            buffer = []       # holds up to chunk_size k-mers to yield
            for raw in fh:
                if raw.startswith(">"):
                    # new header -> flush buffer and reset window
                    if buffer:
                        yield buffer
                        buffer = []
                    window.clear()
                    continue
                line = raw.strip().upper()
                if not line:
                    continue
                for ch in line:
                    if ch == 'N':
                        # ambiguous base -> reset window (k-mers cannot cross this)
                        window.clear()
                        continue
                    window.append(ch)
                    if len(window) == k:
                        # form kmer
                        kmer = ''.join(window)
                        buffer.append(kmer)
                        # slide window by 1
                        window.popleft()
                        # flush buffer if it reached chunk_size
                        if len(buffer) >= chunk_size:
                            yield buffer
                            buffer = []
            # End of file: flush any remaining buffer
            if buffer:
                yield buffer
            # ensure window cleared at sequence boundary (header logic above handled)
            # proceed to next file

class CorpusIterable:
    """Re-iterable, low-memory corpus wrapper that yields small k-mer chunks per iteration.
       __len__ scans files incrementally and counts sequences that would produce at least one valid k-mer.
    """
    def __init__(self, fasta_paths, k=K, chunk_size=CHUNK_SIZE):
        self.fasta_paths = list(fasta_paths)
        self.k = k
        self.chunk_size = chunk_size

    def __iter__(self):
        return fasta_kmer_generator(self.fasta_paths, k=self.k, chunk_size=self.chunk_size)

    def __len__(self):
        # Count sequences that yield at least one valid k-mer by scanning files without accumulating sequences
        cnt = 0
        for p in self.fasta_paths:
            if not p or not p.exists():
                continue
            with open(p, 'r', errors='replace') as fh:
                window = deque()
                seen_for_seq = False
                for raw in fh:
                    if raw.startswith(">"):
                        if seen_for_seq:
                            cnt += 1
                        seen_for_seq = False
                        window.clear()
                        continue
                    line = raw.strip().upper()
                    if not line:
                        continue
                    for ch in line:
                        if ch == 'N':
                            window.clear()
                            continue
                        window.append(ch)
                        if len(window) == self.k:
                            # we found at least one valid k-mer in this sequence
                            seen_for_seq = True
                            # slide to continue scanning without building strings
                            window.popleft()
                # EOF for this file: count last sequence if we saw any valid k-mer
                if seen_for_seq:
                    cnt += 1
        return cnt

# Collect list of combined FASTA paths from previous cells (kept identical)
fasta_paths = []
for m in ("ssu", "lsu", "its"):
    p = DOWNLOAD_DIR / "extracted" / f"{m}_combined.fasta"
    if p.exists():
        fasta_paths.append(p)
print("[CORPUS] fasta paths used:", fasta_paths)

# streaming corpus object (no materialized list)
corpus = CorpusIterable(fasta_paths, k=K, chunk_size=CHUNK_SIZE)

# keep max_seqs_to_collect for compatibility (not used to avoid materialization)
max_seqs_to_collect = 10000

print("[CORPUS] total sequences (kmers) in corpus (scanned):", len(corpus))

# Train or load Word2Vec model
w2v_model_path = DOWNLOAD_DIR / f"kmer_w2v_k{K}.model"
if w2v_model_path.exists():
    print("[W2V] loading existing model:", w2v_model_path)
    w2v = Word2Vec.load(str(w2v_model_path))
else:
    if len(corpus) == 0:
        raise RuntimeError("Empty k-mer corpus. Ensure combined FASTAs were extracted and not empty.")
    print("[W2V] training Word2Vec (streaming, memory-safe)...")
    # Use the streaming corpus directly (no materialized list)
    # Note: Word2Vec will iterate corpus multiple times internally; our corpus is re-iterable.
    w2v = Word2Vec(sentences=corpus, vector_size=W2V_VECTOR_SIZE, window=W2V_WINDOW,
                   min_count=W2V_MIN_COUNT, workers=max(1, (os.cpu_count() or 1)-1),
                   seed=42, epochs=W2V_EPOCHS)
    w2v.save(str(w2v_model_path))
    print("[W2V] saved model to", w2v_model_path)

# Quick sanity check: vector for a sample kmer
sample_kmer = None
for seq_kmers in corpus:
    if len(seq_kmers) > 0:
        sample_kmer = seq_kmers[0]
        break

if sample_kmer and sample_kmer in w2v.wv:
    print("[W2V] example k-mer:", sample_kmer, "vector_len:", w2v.vector_size)
else:
    print("[W2V] example k-mer not in vocabulary (common if min_count> observed).")

# final cleanup
gc.collect()


[CORPUS] fasta paths used: [WindowsPath('ncbi_blast_db/extracted/ssu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/lsu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/its_combined.fasta')]
[CORPUS] total sequences (kmers) in corpus (scanned): 1573
[W2V] loading existing model: ncbi_blast_db\kmer_w2v_k6.model
[W2V] example k-mer: TTATAC vector_len: 128


0

In [11]:
# Cell 6: Fetch representative sequences and taxonomy from NCBI Entrez if extracted FASTAs are small.
# Defensive: auto-install Biopython if missing, robust retries, saves outputs to DOWNLOAD_DIR/extracted/
import os
import time
from pathlib import Path
import json
import math

# Ensure Biopython installed
import subprocess, sys
try:
    from Bio import Entrez, SeqIO
except Exception:
    print("[INSTALL] Installing biopython...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "biopython"])
    from Bio import Entrez, SeqIO

# Config - change ENTREZ_EMAIL to a real email if you plan to fetch many records.
ENTREZ_EMAIL = "demoservice654@gmail.com"   # <-- replace with your email if you like
ENTREZ_API_KEY ="b39eb52d13753525a46134bc1ed3ac2fa708"  # set to a string if you have one; keeps requests higher-rate

Entrez.email = ENTREZ_EMAIL
if ENTREZ_API_KEY:
    Entrez.api_key = ENTREZ_API_KEY

EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

# Read the combined FASTA counts from previous run
def count_fasta_records(p: Path):
    if not p.exists():
        return 0
    with open(p) as fh:
        return sum(1 for line in fh if line.startswith(">"))

combined_paths = {m: EXTRACT_DIR / f"{m}_combined.fasta" for m in ("ssu", "lsu", "its")}
existing_counts = {m: count_fasta_records(p) for m, p in combined_paths.items()}
print("[STATUS] existing extracted fasta counts:", existing_counts)

# If counts are small, we'll fetch sequences from NCBI (safe defaults)
FETCH_IF_LESS_THAN = 200  # if existing sequences < this, fetch from Entrez
FETCH_PER_MARKER = 700    # target number to fetch per marker (you can reduce to 200-500 if rate-limited)
FETCH_BATCH = 200         # fetch in batches to avoid overloading responses

# Search queries for markers
search_queries = {
    "ssu": '18S[All Fields] AND "rRNA"[All Fields] AND Eukaryota[Organism]',
    "lsu": '28S[All Fields] AND "rRNA"[All Fields] AND Eukaryota[Organism]',
    "its": 'internal transcribed spacer[All Fields] AND Eukaryota[Organism]'
}

fetched_metadata = {}
for marker in ("ssu", "lsu", "its"):
    if existing_counts.get(marker, 0) >= FETCH_IF_LESS_THAN:
        print(f"[SKIP FETCH] {marker} has {existing_counts[marker]} records (>= {FETCH_IF_LESS_THAN})")
        continue

    query = search_queries[marker]
    print(f"[ENTREZ SEARCH] marker={marker} query={query} (fetch up to {FETCH_PER_MARKER})")
    # esearch
    try:
        handle = Entrez.esearch(db="nuccore", term=query, retmax=FETCH_PER_MARKER)
        result = Entrez.read(handle)
        handle.close()
    except Exception as e:
        print(f"[ERROR] Entrez esearch failed for {marker}: {e}")
        result = {"IdList": []}

    id_list = result.get("IdList", [])
    print(f"  -> found {len(id_list)} ids (limiting to {FETCH_PER_MARKER})")

    if not id_list:
        fetched_metadata[marker] = []
        continue

    # Fetch GenBank records in batches and save FASTA + structured metadata
    out_fasta = EXTRACT_DIR / f"{marker}_fetched.fasta"
    out_meta = EXTRACT_DIR / f"{marker}_fetched_metadata.json"
    saved = 0
    meta_list = []

    with open(out_fasta, "w") as fasta_fh:
        for start in range(0, len(id_list), FETCH_BATCH):
            batch_ids = id_list[start:start+FETCH_BATCH]
            ids_str = ",".join(batch_ids)
            print(f"  [EFETCH] fetching {len(batch_ids)} ids (start={start})")
            try:
                fh = Entrez.efetch(db="nuccore", id=ids_str, rettype="gb", retmode="text")
                records = SeqIO.parse(fh, "gb")
                for rec in records:
                    # write fasta
                    seq_str = str(rec.seq)
                    if len(seq_str) < 100:  # skip tiny sequences
                        continue
                    header = f">{rec.id} {rec.annotations.get('organism','')}\n"
                    fasta_fh.write(header)
                    # wrap sequence lines 80 chars
                    for i in range(0, len(seq_str), 80):
                        fasta_fh.write(seq_str[i:i+80] + "\n")
                    saved += 1
                    # extract taxonomy info from annotations
                    tax = rec.annotations.get("taxonomy", [])
                    organism = rec.annotations.get("organism", None)
                    meta = {
                        "id": rec.id,
                        "accession": getattr(rec, "name", None) or rec.id,
                        "organism": organism,
                        "taxonomy": tax,
                        "description": rec.description,
                        "seq_len": len(seq_str)
                    }
                    meta_list.append(meta)
                fh.close()
            except Exception as e:
                print(f"    [WARN] efetch batch failed: {e}. Sleeping and retrying once.")
                time.sleep(3)
                try:
                    fh = Entrez.efetch(db="nuccore", id=ids_str, rettype="gb", retmode="text")
                    records = SeqIO.parse(fh, "gb")
                    for rec in records:
                        seq_str = str(rec.seq)
                        if len(seq_str) < 100:
                            continue
                        header = f">{rec.id} {rec.annotations.get('organism','')}\n"
                        fasta_fh.write(header)
                        for i in range(0, len(seq_str), 80):
                            fasta_fh.write(seq_str[i:i+80] + "\n")
                        saved += 1
                        tax = rec.annotations.get("taxonomy", [])
                        organism = rec.annotations.get("organism", None)
                        meta = {
                            "id": rec.id,
                            "accession": getattr(rec, "name", None) or rec.id,
                            "organism": organism,
                            "taxonomy": tax,
                            "description": rec.description,
                            "seq_len": len(seq_str)
                        }
                        meta_list.append(meta)
                    fh.close()
                except Exception as e2:
                    print(f"    [ERROR] efetch retry failed: {e2}. Skipping batch.")
                    continue
            # be polite to NCBI
            time.sleep(0.34)  # keep requests moderate (or longer if no API key)
    print(f"[SAVED] {saved} sequences written to {out_fasta}")
    # save metadata
    with open(out_meta, "w") as mh:
        json.dump(meta_list, mh, indent=2)
    fetched_metadata[marker] = meta_list

# Save a summary file
with open(EXTRACT_DIR / "fetched_summary.json", "w") as fh:
    json.dump({k: len(v) for k, v in fetched_metadata.items()}, fh, indent=2)

print("[CELL6 COMPLETE] fetched counts:", {k: len(v) for k, v in fetched_metadata.items()})

[STATUS] existing extracted fasta counts: {'ssu': 468, 'lsu': 406, 'its': 699}
[SKIP FETCH] ssu has 468 records (>= 200)
[SKIP FETCH] lsu has 406 records (>= 200)
[SKIP FETCH] its has 699 records (>= 200)
[CELL6 COMPLETE] fetched counts: {}


In [12]:
# Fixed Cell 7: Build a re-iterable k-mer corpus (list, with cap) and update / train Word2Vec safely.
import os
import sys
import subprocess
from pathlib import Path

# Ensure gensim installed and import
try:
    import gensim
    from gensim.models import Word2Vec
except Exception:
    print("[INSTALL] gensim not found; installing gensim...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "gensim==4.3.1"])
    import gensim
    from gensim.models import Word2Vec

# Ensure required variables exist (fall back to defaults if not)
try:
    DOWNLOAD_DIR
except NameError:
    DOWNLOAD_DIR = Path("./ncbi_blast_db")
DOWNLOAD_DIR = Path(DOWNLOAD_DIR)
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"

# k-mer / w2v params (safe defaults)
K = globals().get("K", 6)
W2V_VECTOR_SIZE = globals().get("W2V_VECTOR_SIZE", 128)
W2V_WINDOW = globals().get("W2V_WINDOW", 4)
W2V_MIN_COUNT = globals().get("W2V_MIN_COUNT", 2)
W2V_EPOCHS = globals().get("W2V_EPOCHS", 10)
MAX_CORPUS_SEQS = 20000  # safety cap for in-memory corpus

# Helper functions
def seq_to_kmers(seq, k=K):
    s = seq.strip().upper()
    return [s[i:i+k] for i in range(len(s)-k+1) if 'N' not in s[i:i+k]]

def fasta_kmer_gen(paths, k=K):
    """Generator: yields list of kmers for each sequence in given fasta paths"""
    for p in paths:
        if not p.exists():
            continue
        with open(p) as fh:
            seq = ""
            for line in fh:
                if line.startswith(">"):
                    if seq:
                        yield seq_to_kmers(seq, k=k)
                    seq = ""
                else:
                    seq += line.strip()
            if seq:
                yield seq_to_kmers(seq, k=k)

# Collect fasta paths (include combined + fetched if present)
fasta_paths = []
for m in ("ssu","lsu","its"):
    p_combined = EXTRACT_DIR / f"{m}_combined.fasta"
    p_fetched = EXTRACT_DIR / f"{m}_fetched.fasta"
    if p_combined.exists() and p_combined.stat().st_size > 0:
        fasta_paths.append(p_combined)
    if p_fetched.exists() and p_fetched.stat().st_size > 0:
        fasta_paths.append(p_fetched)
if not fasta_paths:
    raise RuntimeError("No FASTA files found for corpus. Ensure previous extraction/fetch cells ran correctly.")
print("[FASTA PATHS FOR CORPUS]", fasta_paths)

# Materialize corpus as a list up to MAX_CORPUS_SEQS
corpus_list = []
for i, kmers in enumerate(fasta_kmer_gen(fasta_paths, k=K)):
    if kmers:
        corpus_list.append(kmers)
    if (i+1) % 1000 == 0:
        print(f"  collected {i+1} sequences for corpus")
    if len(corpus_list) >= MAX_CORPUS_SEQS:
        print(f"[CAP] reached MAX_CORPUS_SEQS={MAX_CORPUS_SEQS}; stopping corpus collection")
        break

print(f"[CORPUS] total sequences collected for training: {len(corpus_list)}")
if len(corpus_list) == 0:
    raise RuntimeError("Empty k-mer corpus after scanning FASTA files. Cannot train Word2Vec.")

# Train or update Word2Vec using the materialized corpus_list (re-iterable)
w2v_model_path = DOWNLOAD_DIR / f"kmer_w2v_k{K}.model"
workers = max(1, (os.cpu_count() or 1) - 1)

if w2v_model_path.exists():
    print("[W2V] loading existing model for update:", w2v_model_path)
    w2v = Word2Vec.load(str(w2v_model_path))
    print("[W2V] building vocab (update=True) with new corpus")
    w2v.build_vocab(corpus_list, update=True)  # corpus_list is re-iterable (list)
    print("[W2V] training updated model...")
    w2v.train(corpus_list, total_examples=len(corpus_list), epochs=W2V_EPOCHS)
    w2v.save(str(w2v_model_path))
    print("[W2V] updated model saved:", w2v_model_path)
else:
    print("[W2V] training new Word2Vec model from scratch...")
    w2v = Word2Vec(
        sentences=corpus_list,
        vector_size=W2V_VECTOR_SIZE,
        window=W2V_WINDOW,
        min_count=W2V_MIN_COUNT,
        workers=workers,
        seed=42,
        epochs=W2V_EPOCHS
    )
    w2v.save(str(w2v_model_path))
    print("[W2V] new model saved:", w2v_model_path)

# Sanity & summary
vocab_size = len(w2v.wv.index_to_key)
print(f"[W2V] vocab size: {vocab_size}; vector_size: {w2v.vector_size}; total_examples used: {len(corpus_list)}")

[FASTA PATHS FOR CORPUS] [WindowsPath('ncbi_blast_db/extracted/ssu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/ssu_fetched.fasta'), WindowsPath('ncbi_blast_db/extracted/lsu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/lsu_fetched.fasta'), WindowsPath('ncbi_blast_db/extracted/its_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/its_fetched.fasta')]
  collected 1000 sequences for corpus
  collected 2000 sequences for corpus
  collected 3000 sequences for corpus
[CORPUS] total sequences collected for training: 3146
[W2V] loading existing model for update: ncbi_blast_db\kmer_w2v_k6.model
[W2V] building vocab (update=True) with new corpus
[W2V] training updated model...
[W2V] updated model saved: ncbi_blast_db\kmer_w2v_k6.model
[W2V] vocab size: 4874; vector_size: 128; total_examples used: 3146


In [13]:
# Fixed Cell 9 (retry): Robust metadata rebuild (no fh.tell), PCA, clustering, novelty scoring, safe save
import os, sys, csv, subprocess
from pathlib import Path
import numpy as np
import pandas as pd

# Paths
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
emb_npy = EXTRACT_DIR / "embeddings.npy"
meta_csv = EXTRACT_DIR / "embeddings_meta.csv"
pca_npy = EXTRACT_DIR / "embeddings_pca.npy"
meta_pca_csv = EXTRACT_DIR / "embeddings_meta_pca.csv"
clustered_csv = EXTRACT_DIR / "embeddings_meta_clustered.csv"
clustered_json = EXTRACT_DIR / "embeddings_meta_clustered.json"

# Ensure embeddings exist
if not emb_npy.exists():
    raise RuntimeError(f"embeddings.npy not found at {emb_npy}. Run the embedding cell first.")

# Load embeddings
X_full = np.load(emb_npy)
n_samples, vec_size = X_full.shape
print(f"[START] Loaded embeddings.npy: samples={n_samples}, vec_size={vec_size}")

# Determine FASTA paths to read headers from
fasta_paths = []
for m in ("ssu", "lsu", "its"):
    p_combined = EXTRACT_DIR / f"{m}_combined.fasta"
    p_fetched  = EXTRACT_DIR / f"{m}_fetched.fasta"
    if p_combined.exists() and p_combined.stat().st_size > 0:
        fasta_paths.append(p_combined)
    if p_fetched.exists() and p_fetched.stat().st_size > 0:
        fasta_paths.append(p_fetched)
print("[INFO] fasta paths to scan for headers:", fasta_paths)

# Robust header reader without using fh.tell()
def robust_fasta_headers(paths):
    seq_counter = 0
    for p in paths:
        try:
            with open(p, "r", encoding="utf-8", errors="replace") as fh:
                for line in fh:
                    line = line.rstrip("\n\r")
                    if not line:
                        continue
                    if line.startswith(">"):
                        seq_counter += 1
                        header_text = line[1:].strip()
                        if header_text == "":
                            header = f"{p.stem}generated{seq_counter}"
                        else:
                            # take first token if exists
                            parts = header_text.split()
                            header = parts[0] if len(parts) > 0 and parts[0] != "" else f"{p.stem}generated{seq_counter}"
                        yield header
        except Exception as e:
            # Skip unreadable files but warn
            print(f"[WARN] failed to read {p}: {e}")
            continue

# Decide whether to rebuild metadata
rebuild_meta = False
if not meta_csv.exists():
    print("[INFO] metadata CSV not found; will rebuild from FASTA headers.")
    rebuild_meta = True
else:
    try:
        df_meta = pd.read_csv(meta_csv, dtype=str, keep_default_na=False, na_filter=False)
        if len(df_meta) != n_samples:
            print(f"[INFO] metadata rows ({len(df_meta)}) != embeddings ({n_samples}) -> will rebuild meta.")
            rebuild_meta = True
    except Exception as e:
        print("[WARN] failed to read existing metadata CSV:", e)
        rebuild_meta = True

# Rebuild meta if needed
if rebuild_meta:
    headers = list(robust_fasta_headers(fasta_paths))
    print(f"[INFO] headers found from FASTA files: {len(headers)}")
    # Pad headers if fewer than embeddings
    if len(headers) < n_samples:
        print(f"[WARN] found {len(headers)} headers but have {n_samples} embeddings. Will pad generated ids.")
    rows = []
    for i in range(n_samples):
        hdr = headers[i] if i < len(headers) else f"generated_seq_{i+1}"
        rows.append({"id": hdr, "seq_len": "", "no_kmer": ""})
    df_meta = pd.DataFrame(rows)
    # sanitize strings
    def sanitize_str(s):
        try:
            s2 = str(s)
        except Exception:
            s2 = ""
        s2 = s2.replace("\r", " ").replace("\n", " ").replace("\t", " ").replace('"', "''").strip()
        return s2
    for col in df_meta.columns:
        if df_meta[col].dtype == object:
            df_meta[col] = df_meta[col].apply(sanitize_str)
    # Save sanitized meta CSV
    df_meta.to_csv(meta_csv, index=False, encoding='utf-8', escapechar='\\', quoting=csv.QUOTE_MINIMAL)
    print(f"[REBUILT] saved metadata CSV to {meta_csv} with {len(df_meta)} rows.")
else:
    print("[INFO] using existing metadata CSV.")
    df_meta = pd.read_csv(meta_csv, dtype=str, keep_default_na=False, na_filter=False)

# Compute PCA
from sklearn.decomposition import PCA
n_components = min(64, vec_size, max(1, n_samples - 1))
print(f"[PCA] computing PCA with n_components={n_components} ...")
pca = PCA(n_components=n_components, random_state=42)
X_pca = pca.fit_transform(X_full)
np.save(pca_npy, X_pca)
print(f"[PCA] saved embeddings_pca.npy shape: {X_pca.shape}")

# Build df_meta + PCA columns, align lengths
pc_cols = [f"PC{i+1}" for i in range(X_pca.shape[1])]
df_pcas = pd.DataFrame(X_pca, columns=pc_cols)
min_n = min(len(df_meta), df_pcas.shape[0])
if len(df_meta) != df_pcas.shape[0]:
    print(f"[ALIGN] aligning meta ({len(df_meta)}) and PCA ({df_pcas.shape[0]}) to min_n={min_n}")
df_meta_full = pd.concat([df_meta.iloc[:min_n].reset_index(drop=True), df_pcas.iloc[:min_n].reset_index(drop=True)], axis=1)

# sanitize df_meta_full for safe CSV
def sanitize_df_for_csv(df):
    for col in df.columns:
        if df[col].dtype == object or df[col].dtype.name == 'string':
            df[col] = df[col].astype(str).apply(lambda s: s.replace("\r"," ").replace("\n"," ").replace("\t"," ").replace('"', "''").strip())
    return df

df_meta_full = sanitize_df_for_csv(df_meta_full)
df_meta_full.to_csv(meta_pca_csv, index=False, encoding='utf-8', escapechar='\\', quoting=csv.QUOTE_MINIMAL)
print(f"[SAVE] saved meta+PCA to {meta_pca_csv} rows={len(df_meta_full)} cols={len(df_meta_full.columns)}")

# Standardize PCA features
from sklearn.preprocessing import StandardScaler
Xs = StandardScaler().fit_transform(X_pca[:min_n])

# Clustering: prefer HDBSCAN; fallback to DBSCAN
labels = None
try:
    import hdbscan
    clusterer = hdbscan.HDBSCAN(min_cluster_size=5, min_samples=3, metric='euclidean')
    labels = clusterer.fit_predict(Xs)
    print(f"[HDBSCAN] found clusters: {len(set(labels)) - (1 if -1 in labels else 0)}, noise: {(labels==-1).sum()}")
except Exception as e:
    print(f"[HDBSCAN] not available or failed ({e}). Falling back to DBSCAN.")
    from sklearn.cluster import DBSCAN
    db = DBSCAN(eps=0.8, min_samples=3, metric='euclidean', n_jobs=-1)
    labels = db.fit_predict(Xs)
    print(f"[DBSCAN] found clusters: {len(set(labels)) - (1 if -1 in labels else 0)}, noise: {(labels==-1).sum()}")

# Novelty scoring using IsolationForest (or fallback centroid distance if small N)
from sklearn.ensemble import IsolationForest
if Xs.shape[0] < 10:
    centroid = Xs.mean(axis=0, keepdims=True)
    dists = np.linalg.norm(Xs - centroid, axis=1)
    novelty = (dists - dists.min()) / (dists.max() - dists.min() + 1e-12)
else:
    iso = IsolationForest(n_estimators=200, contamination=0.05, random_state=42)
    iso.fit(Xs)
    scores = iso.decision_function(Xs)  # higher -> more normal
    smin, smax = float(scores.min()), float(scores.max())
    if smax - smin == 0:
        novelty = np.zeros_like(scores)
    else:
        novelty = 1.0 - ((scores - smin) / (smax - smin))

# Attach and save clustered results (sanitize)
df_meta_full = df_meta_full.iloc[:len(labels)].reset_index(drop=True)
df_meta_full["cluster_label"] = labels.astype(int).astype(str)
df_meta_full["novelty_score"] = novelty.astype(float)
df_meta_full = sanitize_df_for_csv(df_meta_full)

df_meta_full.to_csv(clustered_csv, index=False, encoding='utf-8', escapechar='\\', quoting=csv.QUOTE_MINIMAL)
df_meta_full.to_json(clustered_json, orient='records', lines=False)
print(f"[DONE] saved clustered metadata CSV: {clustered_csv} and JSON: {clustered_json}")
print(f"  clusters example labels (unique, truncated): {sorted(set(labels))[:20]}")
print(f"  novelty min/max/mean: {float(np.min(novelty)):.4f}/{float(np.max(novelty)):.4f}/{float(np.mean(novelty)):.4f}")

[START] Loaded embeddings.npy: samples=2555, vec_size=128
[INFO] fasta paths to scan for headers: [WindowsPath('ncbi_blast_db/extracted/ssu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/ssu_fetched.fasta'), WindowsPath('ncbi_blast_db/extracted/lsu_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/lsu_fetched.fasta'), WindowsPath('ncbi_blast_db/extracted/its_combined.fasta'), WindowsPath('ncbi_blast_db/extracted/its_fetched.fasta')]
[INFO] using existing metadata CSV.
[PCA] computing PCA with n_components=64 ...
[PCA] saved embeddings_pca.npy shape: (2555, 64)
[SAVE] saved meta+PCA to ncbi_blast_db\extracted\embeddings_meta_pca.csv rows=2555 cols=70




[HDBSCAN] found clusters: 99, noise: 531
[DONE] saved clustered metadata CSV: ncbi_blast_db\extracted\embeddings_meta_clustered.csv and JSON: ncbi_blast_db\extracted\embeddings_meta_clustered.json
  clusters example labels (unique, truncated): [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
  novelty min/max/mean: 0.0000/1.0000/0.2123


In [14]:
# Replacement Cell 10 fix (robust): Build shared feature extractor + ModuleDict heads (no custom subclass)
# and provide training/evaluation functions that use them directly (avoids subclass/instantiation issues).

import os, pickle, json
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Paths
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
EMB_PCA_NPY = EXTRACT_DIR / "embeddings_pca.npy"
META_CLUSTERED_CSV = EXTRACT_DIR / "embeddings_meta_clustered.csv"
LABEL_ENCODERS_CANDIDATES = [
    EXTRACT_DIR / "label_encoders_safe.pkl",
    EXTRACT_DIR / "label_encoders_v2.pkl",
    EXTRACT_DIR / "label_encoders_rebuilt.pkl",
    EXTRACT_DIR / "label_encoders_final.pkl",
    EXTRACT_DIR / "label_encoders.pkl"
]

RANKS = ["kingdom","phylum","class","order","family","genus","species"]
BATCH_SIZE = 64
SEED = 42

# sanity
if not Path(EMB_PCA_NPY).exists():
    raise RuntimeError(f"Missing {EMB_PCA_NPY}. Run embedding PCA cell first.")
if not Path(META_CLUSTERED_CSV).exists():
    raise RuntimeError(f"Missing {META_CLUSTERED_CSV}. Run previous clustering cell first.")

# load features and meta
X_pca = np.load(EMB_PCA_NPY)
df_meta = pd.read_csv(META_CLUSTERED_CSV, dtype=str, keep_default_na=False, na_filter=False)
n = X_pca.shape[0]
if len(df_meta) != n:
    mn = min(len(df_meta), n)
    print(f"[ALIGN] trimming to {mn}")
    df_meta = df_meta.iloc[:mn].reset_index(drop=True)
    X_pca = X_pca[:mn]
    n = mn

print(f"[LOAD] X_pca shape: {X_pca.shape}, meta rows: {len(df_meta)}")

# load or rebuild label_encoders
label_encoders = None
for p in LABEL_ENCODERS_CANDIDATES:
    if p.exists():
        try:
            with open(p, "rb") as fh:
                label_encoders = pickle.load(fh)
            print(f"[LOAD] label_encoders loaded from: {p.name}")
            break
        except Exception:
            label_encoders = None

# fallback: use in-memory label_encoders if present
if label_encoders is None and "label_encoders" in globals():
    label_encoders = globals()["label_encoders"]
    print("[LOAD] label_encoders loaded from globals")

# final fallback: reconstruct light-weight encoders (best-effort)
if label_encoders is None:
    print("[INFO] Reconstructing simple label_encoders from fetched metadata (best-effort).")
    id_to_tax = {}
    for marker in ("ssu","lsu","its"):
        meta_json = EXTRACT_DIR / f"{marker}_fetched_metadata.json"
        if not meta_json.exists():
            continue
        try:
            recs = json.load(open(meta_json))
        except Exception:
            continue
        for rec in recs:
            rid = str(rec.get("id") or rec.get("accession") or rec.get("accession_version") or "")
            if not rid:
                continue
            taxonomy = rec.get("taxonomy") or []
            tax_map = {}
            for i, rank in enumerate(["kingdom","phylum","class","order","family","genus"]):
                if i < len(taxonomy) and taxonomy[i]:
                    tax_map[rank] = taxonomy[i]
            organism = rec.get("organism") or rec.get("description") or ""
            parts = organism.split()
            if len(parts) >= 2:
                tax_map["genus"] = tax_map.get("genus") or parts[0]
                tax_map["species"] = " ".join(parts[:2])
            id_to_tax[rid] = tax_map
    # map df_meta ids
    labels_by_rank = {r: [] for r in RANKS}
    for rid in df_meta["id"].astype(str).tolist():
        t = id_to_tax.get(rid) or id_to_tax.get(rid.split(".")[0]) or {}
        for r in RANKS:
            labels_by_rank[r].append(t.get(r) if t.get(r) is not None else "UNASSIGNED")
    from sklearn.preprocessing import LabelEncoder
    label_encoders = {}
    for r in RANKS:
        le = LabelEncoder()
        le.fit(labels_by_rank[r])
        label_encoders[r] = le
    # save
    with open(EXTRACT_DIR / "label_encoders_rebuilt.pkl", "wb") as fh:
        pickle.dump(label_encoders, fh)
    print("[SAVE] saved label_encoders_rebuilt.pkl")

# Prepare y_encoded arrays aligned with df_meta
y_encoded = {}
for r in RANKS:
    le = label_encoders[r]
    # Attempt mapping from df_meta columns if present (e.g., a column with taxon names) — otherwise fallback to UNASSIGNED
    # We don't assume any taxon column exists; use best-effort mapping via fetched metadata if available
    mapped_list = []
    # Try to read column named r in df_meta (if earlier pipeline placed taxon names there)
    if r in df_meta.columns:
        for val in df_meta[r].astype(str).tolist():
            mapped_list.append(val if val != "" else "UNASSIGNED")
    else:
        # fallback: create UNASSIGNED for all; label_encoder will still have classes (e.g., UNASSIGNED)
        mapped_list = ["UNASSIGNED"] * len(df_meta)
    # now encode, mapping unseen labels to UNASSIGNED or index 0
    encoded = []
    for lab in mapped_list:
        if lab in le.classes_:
            encoded.append(int(np.where(le.classes_ == lab)[0][0]))
        else:
            if "UNASSIGNED" in le.classes_:
                encoded.append(int(np.where(le.classes_ == "UNASSIGNED")[0][0]))
            else:
                encoded.append(0)
    y_encoded[r] = np.array(encoded, dtype=int)

# Build TensorDatasets (assuming train_idx/val_idx were created earlier)
if "train_idx" in globals() and "val_idx" in globals():
    train_idx = globals()["train_idx"]
    val_idx = globals()["val_idx"]
else:
    from sklearn.model_selection import train_test_split
    idx = np.arange(n)
    train_idx, val_idx = train_test_split(idx, test_size=0.15, random_state=SEED, shuffle=True)
    globals()["train_idx"], globals()["val_idx"] = train_idx, val_idx

X_tensor = torch.tensor(X_pca, dtype=torch.float32)
y_tensors = [torch.tensor(y_encoded[r], dtype=torch.long) for r in RANKS]

X_train = X_tensor[train_idx]
X_val   = X_tensor[val_idx]
y_train_list = [yt[train_idx] for yt in y_tensors]
y_val_list   = [yt[val_idx] for yt in y_tensors]

from torch.utils.data import TensorDataset
train_ds = TensorDataset(X_train, *y_train_list)
val_ds   = TensorDataset(X_val,   *y_val_list)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)
print(f"[DATA] train={len(train_ds)}, val={len(val_ds)}, batch_size={BATCH_SIZE}")

# Build shared extractor + ModuleDict heads (no subclass)
input_dim = X_pca.shape[1]
hidden_dim = 256
shared = nn.Sequential(
    nn.Linear(input_dim, hidden_dim),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(hidden_dim, hidden_dim//2),
    nn.ReLU()
)
heads = nn.ModuleDict()
for r in RANKS:
    ncls = len(label_encoders[r].classes_)
    heads[r] = nn.Linear(hidden_dim//2, ncls)

# Move modules to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shared.to(device)
heads.to(device)
print(f"[MODEL PARTS] shared on {device}; heads: {[ (r, heads[r].out_features) for r in RANKS ]}")

# Collect parameters and optimizer
params = list(shared.parameters()) + list(heads.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# Build criterions: try to use any class_weights in globals, otherwise uniform
criterions = {}
for r in RANKS:
    if "class_weights" in globals() and isinstance(globals()["class_weights"], dict) and r in globals()["class_weights"]:
        w = globals()["class_weights"][r]
        if isinstance(w, (list, tuple, np.ndarray)):
            w = torch.tensor(w, dtype=torch.float32)
        if isinstance(w, torch.Tensor):
            w = w.to(device)
        else:
            w = torch.tensor(np.asarray(w), dtype=torch.float32).to(device)
    else:
        w = torch.ones(len(label_encoders[r].classes_), dtype=torch.float32).to(device)
    criterions[r] = nn.CrossEntropyLoss(weight=w)

# Training & evaluation functions that use shared + heads directly
def train_one_epoch_shared(shared, heads, loader, optimizer, criterions, device=device):
    shared.train()
    heads.train()
    total_loss = 0.0
    n_batches = 0
    for batch in loader:
        x = batch[0].to(device)
        h = shared(x)
        outputs = {r: heads[r](h) for r in RANKS}
        loss = 0.0
        for i, r in enumerate(RANKS):
            target = batch[1 + i].to(device)
            loss = loss + criterions[r](outputs[r], target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / max(1, n_batches)

def evaluate_shared(shared, heads, loader, device=device):
    shared.eval()
    heads.eval()
    from sklearn.metrics import accuracy_score, f1_score
    preds = {r: [] for r in RANKS}
    trues = {r: [] for r in RANKS}
    with torch.no_grad():
        for batch in loader:
            x = batch[0].to(device)
            h = shared(x)
            outputs = {r: heads[r](h) for r in RANKS}
            for i, r in enumerate(RANKS):
                logits = outputs[r]
                pred = torch.argmax(torch.softmax(logits, dim=1), dim=1).cpu().numpy()
                true = batch[1 + i].cpu().numpy()
                preds[r].extend(pred.tolist())
                trues[r].extend(true.tolist())
    metrics = {}
    for r in RANKS:
        try:
            acc = accuracy_score(trues[r], preds[r])
            f1m = f1_score(trues[r], preds[r], average='macro', zero_division=0)
        except Exception:
            acc, f1m = None, None
        metrics[r] = {"accuracy": acc, "f1_macro": f1m, "n_classes": len(label_encoders[r].classes_)}
    return metrics

print("[READY] shared + heads built. Use train_one_epoch_shared(shared, heads, train_loader, optimizer, criterions) to train and evaluate_shared(shared, heads, val_loader) to evaluate.")

# Save the shared+heads state dicts and label encoders for later inference
torch.save({"shared_state": shared.state_dict(), "heads_state": {r: heads[r].state_dict() for r in RANKS}}, EXTRACT_DIR / "shared_heads_initial.pt")
with open(EXTRACT_DIR / "label_encoders_used.pkl", "wb") as fh:
    pickle.dump(label_encoders, fh)
print("[SAVE] saved shared_heads_initial.pt and label_encoders_used.pkl")

[LOAD] X_pca shape: (2555, 64), meta rows: 2555
[LOAD] label_encoders loaded from: label_encoders_rebuilt.pkl
[DATA] train=2171, val=384, batch_size=64
[MODEL PARTS] shared on cpu; heads: [('kingdom', 2), ('phylum', 5), ('class', 10), ('order', 13), ('family', 19), ('genus', 27), ('species', 182)]
[READY] shared + heads built. Use train_one_epoch_shared(shared, heads, train_loader, optimizer, criterions) to train and evaluate_shared(shared, heads, val_loader) to evaluate.
[SAVE] saved shared_heads_initial.pt and label_encoders_used.pkl


In [15]:
# Cell 11 (robust): Training loop with early stopping and robust ReduceLROnPlateau creation
import time, json, inspect
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

# --- Config ---
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
HISTORY_CSV = EXTRACT_DIR / "training_history.csv"
HISTORY_JSON = EXTRACT_DIR / "training_history.json"
BEST_CHECKPOINT = EXTRACT_DIR / "best_shared_heads.pt"

MAX_EPOCHS = 50
MIN_EPOCHS = 5
PATIENCE = 8            # early stopping patience (no improvement in val score)
LR_FACTOR = 0.5
LR_PATIENCE = 3
MIN_LR = 1e-6

# --- Sanity checks: required objects created by previous cells ---
reqs = ["shared", "heads", "train_loader", "val_loader", "criterions", "optimizer", "device", "RANKS", "label_encoders"]
missing = [n for n in reqs if n not in globals()]
if missing:
    raise RuntimeError(f"Cannot start training: missing objects in notebook globals: {missing}")

shared = globals()["shared"]
heads = globals()["heads"]
train_loader = globals()["train_loader"]
val_loader = globals()["val_loader"]
criterions = globals()["criterions"]
optimizer = globals()["optimizer"]
device = globals()["device"]
RANKS = globals()["RANKS"]
label_encoders = globals()["label_encoders"]

# ensure modules on device
shared.to(device)
heads.to(device)
for r in RANKS:
    heads[r].to(device)

# deterministic-ish
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Robust scheduler creation: inspect allowed kwargs and instantiate accordingly.
def create_reduce_on_plateau_scheduler(opt, mode="max", factor=0.5, patience=3, min_lr=1e-6):
    try:
        ctor = torch.optim.lr_scheduler.ReduceLROnPlateau
        sig = inspect.signature(ctor._init_)
        allowed = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
        kwargs = {}
        cand = {"mode": mode, "factor": factor, "patience": patience, "min_lr": min_lr}
        # Some old/new torch versions name min_lr differently; try to map common names
        renames = {"min_lr": "min_lr", "min_lr_alt": "min_lr"}
        for k, v in cand.items():
            if k in allowed:
                kwargs[k] = v
        # if verbose allowed, set it to False (safe)
        if "verbose" in allowed:
            kwargs["verbose"] = False
        # instantiate
        scheduler = ctor(opt, **kwargs)
        print("[SCHED] ReduceLROnPlateau created with kwargs:", kwargs)
        return scheduler
    except Exception as e:
        print("[SCHED WARN] Could not create ReduceLROnPlateau scheduler due to:", repr(e))
        # fallback: dummy scheduler with same interface
        class DummyScheduler:
            def step(self, metric=None):
                return None
        return DummyScheduler()

scheduler = create_reduce_on_plateau_scheduler(optimizer, mode="max", factor=LR_FACTOR, patience=LR_PATIENCE, min_lr=MIN_LR)

# helper: compute validation loss and other metrics
def compute_val_loss_and_metrics(shared, heads, loader, criterions, device):
    shared.eval()
    heads.eval()
    total_loss = 0.0
    n_batches = 0
    preds = {r: [] for r in RANKS}
    trues = {r: [] for r in RANKS}
    confidences = {r: [] for r in RANKS}   # top1 prob
    with torch.no_grad():
        for batch in loader:
            x = batch[0].to(device)
            targets = [batch[i+1].to(device) for i in range(len(RANKS))]
            h = shared(x)
            outputs = {r: heads[r](h) for r in RANKS}
            loss = 0.0
            for i, r in enumerate(RANKS):
                logits = outputs[r]
                tgt = targets[i]
                loss += criterions[r](logits, tgt)
                probs = F.softmax(logits, dim=1)
                top1 = torch.argmax(probs, dim=1).cpu().numpy()
                top1_conf = probs.max(dim=1).values.cpu().numpy()
                preds[r].extend(top1.tolist())
                trues[r].extend(tgt.cpu().numpy().tolist())
                confidences[r].extend(top1_conf.tolist())
            total_loss += float(loss.item())
            n_batches += 1
    avg_loss = total_loss / max(1, n_batches)
    metrics = {}
    f1s = []
    for r in RANKS:
        try:
            acc = accuracy_score(trues[r], preds[r])
            f1m = f1_score(trues[r], preds[r], average="macro", zero_division=0)
            mean_conf = float(np.mean(confidences[r])) if len(confidences[r])>0 else None
        except Exception:
            acc, f1m, mean_conf = None, None, None
        metrics[r] = {"accuracy": acc, "f1_macro": f1m, "mean_confidence": mean_conf,
                      "n_classes": len(label_encoders[r].classes_)}
        if f1m is not None:
            f1s.append(f1m)
    agg_f1 = float(np.mean(f1s)) if len(f1s)>0 else 0.0
    return avg_loss, metrics, agg_f1

# training loop with early stopping
best_score = -np.inf
epochs_no_improve = 0
history = []

start_time = time.time()
for epoch in range(1, MAX_EPOCHS + 1):
    t0 = time.time()
    # Train one epoch
    shared.train()
    heads.train()
    train_loss = 0.0
    n_batches = 0
    for batch in train_loader:
        x = batch[0].to(device)
        targets = [batch[i+1].to(device) for i in range(len(RANKS))]
        h = shared(x)
        outputs = {r: heads[r](h) for r in RANKS}
        loss = 0.0
        for i, r in enumerate(RANKS):
            loss = loss + criterions[r](outputs[r], targets[i])
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(shared.parameters()) + list(heads.parameters()), max_norm=5.0)
        optimizer.step()
        train_loss += float(loss.item())
        n_batches += 1
    train_loss = train_loss / max(1, n_batches)

    # Validation
    val_loss, val_metrics, val_agg_f1 = compute_val_loss_and_metrics(shared, heads, val_loader, criterions, device)

    # Scheduler step (on aggregated val score) - handle both schedulers with and w/o metric arg
    try:
        # ReduceLROnPlateau expects a metric
        scheduler.step(val_agg_f1)
    except TypeError:
        try:
            # some schedulers expect no args
            scheduler.step()
        except Exception:
            pass

    # Save metrics & history
    epoch_record = {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_agg_f1": val_agg_f1,
        "time_sec": time.time() - t0
    }
    for r in RANKS:
        m = val_metrics.get(r, {})
        epoch_record.update({f"{r}_acc": m.get("accuracy"), f"{r}_f1_macro": m.get("f1_macro"), f"{r}_mean_conf": m.get("mean_confidence")})
    history.append(epoch_record)

    # Print progress
    print(f"Epoch {epoch:03d} | train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f} | val_agg_f1: {val_agg_f1:.4f} | time: {epoch_record['time_sec']:.1f}s")

    # Check for improvement
    if val_agg_f1 > best_score + 1e-6:
        best_score = val_agg_f1
        epochs_no_improve = 0
        checkpoint = {
            "shared_state": shared.state_dict(),
            "heads_state": {r: heads[r].state_dict() for r in RANKS},
            "epoch": epoch,
            "val_agg_f1": val_agg_f1,
            "optimizer_state": optimizer.state_dict()
        }
        torch.save(checkpoint, BEST_CHECKPOINT)
        print(f"  [CHECKPOINT] saved new best model at epoch {epoch}, val_agg_f1={val_agg_f1:.4f}")
    else:
        epochs_no_improve += 1

    # Save history
    try:
        pd.DataFrame(history).to_csv(HISTORY_CSV, index=False)
        with open(HISTORY_JSON, "w") as fh:
            json.dump(history, fh, indent=2)
    except Exception as e:
        print("[WARN] could not save history:", e)

    # Early stopping
    if epoch >= MIN_EPOCHS and epochs_no_improve >= PATIENCE:
        print(f"[EARLY STOP] No improvement for {epochs_no_improve} epochs (patience={PATIENCE}). Stopping.")
        break

total_time = time.time() - start_time
print(f"[TRAINING COMPLETE] epochs_run={epoch} best_val_agg_f1={best_score:.4f} total_time_sec={total_time:.1f}")

# Load best checkpoint into memory (for immediate inference)
if Path(BEST_CHECKPOINT).exists():
    ckpt = torch.load(BEST_CHECKPOINT, map_location=device)
    shared.load_state_dict(ckpt["shared_state"])
    for r in RANKS:
        heads[r].load_state_dict(ckpt["heads_state"][r])
    print(f"[LOAD] loaded best checkpoint from {BEST_CHECKPOINT} (epoch {ckpt.get('epoch')}, val_agg_f1={ckpt.get('val_agg_f1'):.4f})")

# Save final history once more
try:
    pd.DataFrame(history).to_csv(HISTORY_CSV, index=False)
    with open(HISTORY_JSON, "w") as fh:
        json.dump(history, fh, indent=2)
    print(f"[SAVE] training history saved to {HISTORY_CSV} and {HISTORY_JSON}")
except Exception as e:
    print("[WARN] could not save final history:", e)

# Final evaluation on validation set (print nicely)
final_val_loss, final_val_metrics, final_val_agg_f1 = compute_val_loss_and_metrics(shared, heads, val_loader, criterions, device)
print("=== Final validation metrics (best model) ===")
for r in RANKS:
    m = final_val_metrics.get(r, {})
    print(f"{r:8s} | n_classes={m.get('n_classes', '?'):<3} | acc={m.get('accuracy'):.4f} | f1_macro={m.get('f1_macro'):.4f} | mean_conf={m.get('mean_confidence'):.4f}")
print(f"Aggregated mean-macro-F1: {final_val_agg_f1:.4f} | val_loss: {final_val_loss:.4f}")

[SCHED WARN] Could not create ReduceLROnPlateau scheduler due to: AttributeError("type object 'ReduceLROnPlateau' has no attribute '_init_'")
Epoch 001 | train_loss: 11.7235 | val_loss: 1.4548 | val_agg_f1: 1.0000 | time: 1.1s
  [CHECKPOINT] saved new best model at epoch 1, val_agg_f1=1.0000
Epoch 002 | train_loss: 0.1883 | val_loss: 0.0025 | val_agg_f1: 1.0000 | time: 0.7s
Epoch 003 | train_loss: 0.0021 | val_loss: 0.0013 | val_agg_f1: 1.0000 | time: 0.8s
Epoch 004 | train_loss: 0.0014 | val_loss: 0.0009 | val_agg_f1: 1.0000 | time: 1.0s
Epoch 005 | train_loss: 0.0010 | val_loss: 0.0006 | val_agg_f1: 1.0000 | time: 1.6s
Epoch 006 | train_loss: 0.0006 | val_loss: 0.0004 | val_agg_f1: 1.0000 | time: 1.1s
Epoch 007 | train_loss: 0.0004 | val_loss: 0.0002 | val_agg_f1: 1.0000 | time: 1.1s
Epoch 008 | train_loss: 0.0002 | val_loss: 0.0001 | val_agg_f1: 1.0000 | time: 1.2s
Epoch 009 | train_loss: 0.0001 | val_loss: 0.0000 | val_agg_f1: 1.0000 | time: 1.3s
[EARLY STOP] No improvement for 8 e

In [16]:
# Cell 12: Inference + per-sequence JSON export (predicted taxonomy, probs, confidence, novelty, cluster, abundance, explainability)
import os, json, math, traceback
from pathlib import Path
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors

# Optional: k-mer explainability (gensim Word2Vec)
try:
    from gensim.models import Word2Vec
    have_gensim = True
except Exception:
    have_gensim = False

# Paths
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
BEST_CHECKPOINT = EXTRACT_DIR / "best_shared_heads.pt"
META_CLUSTERED_CSV = EXTRACT_DIR / "embeddings_meta_clustered.csv"
EMB_PCA_NPY = EXTRACT_DIR / "embeddings_pca.npy"
EMB_FULL_NPY = EXTRACT_DIR / "embeddings.npy"
LABEL_ENCODERS_PKL = EXTRACT_DIR / "label_encoders_used.pkl"   # was saved earlier
W2V_MODEL = DOWNLOAD_DIR / "kmer_w2v_k6.model"

OUT_JSONL = EXTRACT_DIR / "predictions.jsonl"
OUT_CSV   = EXTRACT_DIR / "predictions_summary.csv"

# Basic checks
for p in [BEST_CHECKPOINT, META_CLUSTERED_CSV, EMB_PCA_NPY, EMB_FULL_NPY]:
    if not Path(p).exists():
        raise RuntimeError(f"Required file missing: {p}")

# Load data
df_meta = pd.read_csv(META_CLUSTERED_CSV, dtype=str, keep_default_na=False, na_filter=False)
X_pca = np.load(EMB_PCA_NPY)      # used as model input
X_full = np.load(EMB_FULL_NPY)    # used for NN explainability (k-mer avg embedding)
n = X_pca.shape[0]
print(f"[LOAD] rows={n}, X_pca.shape={X_pca.shape}, X_full.shape={X_full.shape}")

# Align meta length
if len(df_meta) != n:
    mn = min(len(df_meta), n)
    print(f"[ALIGN] trimming to min_n={mn}")
    df_meta = df_meta.iloc[:mn].reset_index(drop=True)
    X_pca = X_pca[:mn]
    X_full = X_full[:mn]
    n = mn

# Load label encoders
label_encoders = None
if Path(LABEL_ENCODERS_PKL).exists():
    try:
        import pickle
        with open(LABEL_ENCODERS_PKL, "rb") as fh:
            label_encoders = pickle.load(fh)
        print(f"[LOAD] label_encoders loaded from {LABEL_ENCODERS_PKL.name}")
    except Exception as e:
        print("[WARN] failed loading label encoders:", e)
if label_encoders is None and "label_encoders" in globals():
    label_encoders = globals()["label_encoders"]
    print("[LOAD] label_encoders loaded from globals")

if label_encoders is None:
    raise RuntimeError("Label encoders not found. Run earlier cells that produce them.")

# Recreate model parts (same architecture as training)
input_dim = X_pca.shape[1]
hidden_dim = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

shared = torch.nn.Sequential(
    torch.nn.Linear(input_dim, hidden_dim),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.3),
    torch.nn.Linear(hidden_dim, hidden_dim // 2),
    torch.nn.ReLU()
)
heads = torch.nn.ModuleDict({ r: torch.nn.Linear(hidden_dim//2, len(label_encoders[r].classes_)) for r in label_encoders })

# load checkpoint
ckpt = torch.load(BEST_CHECKPOINT, map_location="cpu")
# state dicts saved as "shared_state" and "heads_state" earlier
try:
    shared.load_state_dict(ckpt["shared_state"])
    for r in label_encoders:
        heads[r].load_state_dict(ckpt["heads_state"][r])
    print(f"[LOAD] checkpoint loaded (epoch={ckpt.get('epoch')}, val_agg_f1={ckpt.get('val_agg_f1')})")
except Exception as e:
    # if shapes mismatch, print and proceed with available weights
    print("[WARN] failed to load full checkpoint cleanly:", e)
    # try best-effort partial load
    try:
        shared_state = ckpt.get("shared_state", {})
        shared.load_state_dict(shared_state, strict=False)
        for r in label_encoders:
            if r in ckpt.get("heads_state", {}):
                heads[r].load_state_dict(ckpt["heads_state"][r], strict=False)
        print("[LOAD] partial checkpoint loaded (strict=False).")
    except Exception as e2:
        print("[ERROR] could not load checkpoint:", e2)
        raise

shared.to(device)
heads.to(device)
shared.eval()
heads.eval()

# Batch inference
BATCH = 256
results = []
all_probs = {r: [] for r in label_encoders}
all_preds = {r: [] for r in label_encoders}
all_topk = {r: [] for r in label_encoders}

with torch.no_grad():
    for start in range(0, n, BATCH):
        end = min(n, start + BATCH)
        xb = torch.tensor(X_pca[start:end], dtype=torch.float32).to(device)
        h = shared(xb)
        for r in label_encoders:
            logits = heads[r](h)               # (batch, n_classes)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            top_idx = np.argmax(probs, axis=1)
            # top-3
            topk_idx = np.argsort(-probs, axis=1)[:, :3]
            for i in range(end-start):
                all_probs[r].append(probs[i])
                all_preds[r].append(int(top_idx[i]))
                all_topk[r].append([int(x) for x in topk_idx[i]])
print("[INFER] Completed batched forward pass.")

# Build nearest-neighbors index for explainability
# Prefer training set neighbors if train_idx is available
if "train_idx" in globals():
    ref_idx = np.array(globals()["train_idx"], dtype=int)
    print("[NN] Using training set as reference for nearest neighbors (len=", len(ref_idx), ")")
else:
    # else use full dataset as reference
    ref_idx = np.arange(n)
    print("[NN] Using full dataset as reference for nearest neighbors (len=", len(ref_idx), ")")

try:
    nn_model = NearestNeighbors(n_neighbors=6, metric="cosine", n_jobs=-1)
    nn_model.fit(X_full[ref_idx])
    nn_available = True
except Exception as e:
    print("[WARN] NearestNeighbors failed:", e)
    nn_available = False

# Attempt to load Word2Vec k-mer model (for top-k k-mers similar to sequence embedding)
w2v = None
if have_gensim and W2V_MODEL.exists():
    try:
        w2v = Word2Vec.load(str(W2V_MODEL))
        print("[W2V] loaded k-mer Word2Vec model:", W2V_MODEL.name)
    except Exception as e:
        print("[WARN] failed to load Word2Vec model:", e)
        w2v = None
else:
    if not have_gensim:
        print("[W2V] gensim not available; skipping k-mer explainability.")
    else:
        print("[W2V] kmer model file not found; skipping k-mer explainability.")

# Build cluster counts and abundance proxy
cluster_labels = df_meta["cluster_label"].astype(str).tolist() if "cluster_label" in df_meta.columns else ["-1"]*n
cluster_counter = Counter(cluster_labels)
total_seqs = float(n)

# utility to map class index -> label string
def idx_to_label(r, idx):
    classes = label_encoders[r].classes_
    if idx < 0 or idx >= len(classes):
        return "UNASSIGNED"
    return str(classes[idx])

# Now assemble per-sequence records and write out JSONL + CSV summary
out_records = []
csv_rows = []
for i in range(n):
    rec_id = str(df_meta.loc[i, "id"]) if "id" in df_meta.columns else f"seq_{i+1}"
    entry = {"id": rec_id}
    per_rank = {}
    mean_conf = 0.0
    conf_count = 0
    for r in label_encoders:
        probs_r = all_probs[r][i]
        pred_idx = int(all_preds[r][i])
        pred_label = idx_to_label(r, pred_idx)
        pred_prob = float(probs_r[pred_idx])
        # top-3 with labels & probs
        topk_idx = all_topk[r][i]
        topk = [{"label": idx_to_label(r, int(k)), "prob": float(probs_r[int(k)])} for k in topk_idx]
        per_rank[r] = {"predicted_label": pred_label, "predicted_index": pred_idx, "predicted_prob": pred_prob, "top_k": topk}
        mean_conf += pred_prob
        conf_count += 1
    mean_conf = mean_conf / max(1, conf_count)
    entry["predicted"] = per_rank
    entry["mean_confidence"] = mean_conf

    # attach novelty and cluster info from metadata (if present)
    novelty = float(df_meta.loc[i, "novelty_score"]) if "novelty_score" in df_meta.columns and df_meta.loc[i, "novelty_score"] != "" else None
    cluster_label = str(df_meta.loc[i, "cluster_label"]) if "cluster_label" in df_meta.columns else "-1"
    entry["novelty_score"] = novelty
    entry["cluster_label"] = cluster_label
    cluster_size = int(cluster_counter.get(cluster_label, 1))
    entry["cluster_size"] = cluster_size
    entry["abundance_proxy"] = float(cluster_size) / total_seqs

    # QC flags
    qc_flags = []
    if mean_conf < 0.35:
        qc_flags.append("low_confidence")
    if novelty is not None and novelty > 0.8:
        qc_flags.append("novel_candidate")
    entry["qc_flags"] = qc_flags

    # Nearest neighbors (explainability)
    neighbors = []
    if nn_available:
        try:
            # query using full embedding
            emb = X_full[i].reshape(1, -1)
            dists, idxs = nn_model.kneighbors(emb, n_neighbors=6, return_distance=True)
            dists = dists[0].tolist()
            idxs = idxs[0].tolist()
            for dd, ridx in zip(dists, idxs):
                ref_global_idx = int(ref_idx[ridx])
                if ref_global_idx == i:
                    # skip self-match; continue to next
                    continue
                nid = str(df_meta.loc[ref_global_idx, "id"])
                neighbors.append({"id": nid, "index": int(ref_global_idx), "distance": float(dd),
                                  "cluster": str(df_meta.loc[ref_global_idx, "cluster_label"]) if "cluster_label" in df_meta.columns else None})
            # take top-3 excluding self
            entry["nearest_neighbors"] = neighbors[:3]
        except Exception as e:
            entry["nearest_neighbors_error"] = str(e)
    else:
        entry["nearest_neighbors"] = []

    # Top-k k-mers similar to sequence embedding (approx explainability) if Word2Vec is available
    top_kmers = []
    if w2v is not None:
        try:
            seq_emb = X_full[i]  # same space as averaged kmer vectors
            # gensim KeyedVectors similarity_by_vector: use wv.similar_by_vector
            kv = w2v.wv
            # similar_by_vector may be expensive; request top 5
            sim = kv.similar_by_vector(seq_emb, topn=5)
            # sim is list of (kmer, score)
            top_kmers = [{"kmer": s[0], "sim": float(s[1])} for s in sim]
            entry["top_kmers"] = top_kmers
        except Exception as e:
            entry["top_kmers_error"] = str(e)
    else:
        entry["top_kmers"] = []

    # store
    out_records.append(entry)
    # flatten for CSV row summary (one-line per sample)
    csv_row = {
        "id": rec_id,
        "cluster_label": cluster_label,
        "cluster_size": cluster_size,
        "abundance_proxy": entry["abundance_proxy"],
        "novelty_score": novelty,
        "mean_confidence": mean_conf,
        "qc_flags": "|".join(qc_flags) if qc_flags else ""
    }
    # include predicted labels per rank
    for r in label_encoders:
        csv_row[f"{r}_pred"] = per_rank[r]["predicted_label"]
        csv_row[f"{r}_prob"] = per_rank[r]["predicted_prob"]
    csv_rows.append(csv_row)

# Write outputs
try:
    with open(OUT_JSONL, "w", encoding="utf-8") as fh:
        for rec in out_records:
            fh.write(json.dumps(rec) + "\n")
    pd.DataFrame(csv_rows).to_csv(OUT_CSV, index=False)
    print(f"[SAVE] predictions saved: {OUT_JSONL} (jsonl), {OUT_CSV} (csv). total_records={len(out_records)}")
except Exception as e:
    print("[ERROR] failed to save predictions:", e)
    traceback.print_exc()

# Print summary of interesting candidates (top novel candidates)
novel_candidates = [r for r in out_records if (r.get("novelty_score") is not None and r["novelty_score"] > 0.8)]
print(f"[SUMMARY] total sequences: {n}; novel candidates (novelty>0.8): {len(novel_candidates)}")
if len(novel_candidates) > 0:
    print("First 5 novel candidates ids:", [r["id"] for r in novel_candidates[:5]])

# Example: print first 3 prediction records for quick inspection
for rec in out_records[:3]:
    print(json.dumps({"id": rec["id"], "predicted_summary": {r: rec["predicted"][r]["predicted_label"] for r in list(rec["predicted"])[:3]}, "mean_conf": rec["mean_confidence"]}, indent=2))

[LOAD] rows=2555, X_pca.shape=(2555, 64), X_full.shape=(2555, 128)
[LOAD] label_encoders loaded from label_encoders_used.pkl
[LOAD] checkpoint loaded (epoch=1, val_agg_f1=1.0)
[INFER] Completed batched forward pass.
[NN] Using training set as reference for nearest neighbors (len= 2171 )
[W2V] loaded k-mer Word2Vec model: kmer_w2v_k6.model
[SAVE] predictions saved: ncbi_blast_db\extracted\predictions.jsonl (jsonl), ncbi_blast_db\extracted\predictions_summary.csv (csv). total_records=2555
[SUMMARY] total sequences: 2555; novel candidates (novelty>0.8): 18
First 5 novel candidates ids: ['XR_013100016.1', 'XR_013100016.1', 'LC876616.1', 'LC876591.1', 'LC876590.1']
{
  "id": "JBJNTG020000075.1",
  "predicted_summary": {
    "kingdom": "UNASSIGNED",
    "phylum": "UNASSIGNED",
    "class": "UNASSIGNED"
  },
  "mean_conf": 0.7759965402739388
}
{
  "id": "JBMETL020000032.1",
  "predicted_summary": {
    "kingdom": "UNASSIGNED",
    "phylum": "UNASSIGNED",
    "class": "UNASSIGNED"
  },
  "mean

In [17]:
# Cell 13 — Diagnostics, ID-fix, leakage checks, temperature scaling calibration, and calibrated exports
import os, json, math, traceback
from pathlib import Path
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
import time

# Paths (reuse)
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
PRED_JSONL = EXTRACT_DIR / "predictions.jsonl"
PRED_SUM_CSV = EXTRACT_DIR / "predictions_summary.csv"
META_CLUSTERED_CSV = EXTRACT_DIR / "embeddings_meta_clustered.csv"
BEST_CKPT = EXTRACT_DIR / "best_shared_heads.pt"
CALIB_JSONL = EXTRACT_DIR / "predictions_calibrated.jsonl"
CALIB_CSV   = EXTRACT_DIR / "predictions_summary_calibrated.csv"
NOVEL_CSV   = EXTRACT_DIR / "novel_candidates.csv"
CLUSTER_SUM  = EXTRACT_DIR / "cluster_summary.csv"

# small helpers
def safe_load_jsonl(path):
    recs = []
    with open(path, "r", encoding="utf-8") as fh:
        for ln in fh:
            ln = ln.strip()
            if not ln:
                continue
            try:
                recs.append(json.loads(ln))
            except Exception:
                # try to recover with replace
                try:
                    recs.append(json.loads(ln.encode("utf-8", errors="replace").decode()))
                except Exception:
                    recs.append({"_raw": ln})
    return recs

def save_jsonl(list_of_dicts, path):
    with open(path, "w", encoding="utf-8") as fh:
        for rec in list_of_dicts:
            fh.write(json.dumps(rec, ensure_ascii=False) + "\n")

# 0) Basic checks
for p in [PRED_JSONL, PRED_SUM_CSV, META_CLUSTERED_CSV]:
    if not p.exists():
        raise RuntimeError(f"Required file missing: {p}")

print("[START] Loading files...")
pred_df = pd.read_csv(PRED_SUM_CSV, dtype=str, keep_default_na=False, na_filter=False)
pred_recs = safe_load_jsonl(PRED_JSONL)
meta = pd.read_csv(META_CLUSTERED_CSV, dtype=str, keep_default_na=False, na_filter=False)

n_pred = len(pred_recs)
n_df = len(pred_df)
n_meta = len(meta)
print(f"  loaded: predictions.jsonl {n_pred} records; summary csv {n_df} rows; meta csv {n_meta} rows")

# 1) Fix empty or missing IDs in predictions (assumes order matches meta)
fixed_count = 0
for i, rec in enumerate(pred_recs):
    rec_id = rec.get("id", "")
    if (not rec_id) and i < len(meta):
        new_id = str(meta.loc[i, "id"]) if "id" in meta.columns else f"seq_{i+1}"
        rec["id"] = new_id
        fixed_count += 1
for i, row in pred_df.iterrows():
    if (not str(row.get("id", "")).strip()) and i < len(meta):
        pred_df.at[i, "id"] = str(meta.loc[i, "id"]) if "id" in meta.columns else f"seq_{i+1}"
        fixed_count += 1
print(f"[FIX] Filled {fixed_count} empty IDs (from meta).")

# Overwrite repaired files as backups
backup_jsonl = EXTRACT_DIR / "predictions.jsonl.bak"
backup_csv = EXTRACT_DIR / "predictions_summary.csv.bak"
if not backup_jsonl.exists():
    pred_json_text = open(PRED_JSONL, "r", encoding="utf-8").read()
    open(backup_jsonl, "w", encoding="utf-8").write(pred_json_text)
if not backup_csv.exists():
    pred_df.to_csv(backup_csv, index=False)
# Save fixed versions
save_jsonl(pred_recs, PRED_JSONL)
pred_df.to_csv(PRED_SUM_CSV, index=False)
print("[SAVE] Saved fixed predictions.jsonl and predictions_summary.csv (backups created if not present).")

# 2) Basic diagnostics
ids = pred_df["id"].astype(str).tolist()
empty_ids = sum(1 for x in ids if not x.strip())
dup_count = sum(pred_df["id"].duplicated())
unique_ids = len(set(ids))
print(f"[DIAG] total: {len(ids)} rows, empty_ids={empty_ids}, duplicated_ids={dup_count}, unique_ids={unique_ids}")

# novel candidates (deduped)
novel_thresh = 0.8
pred_df["novelty_score"] = pd.to_numeric(pred_df["novelty_score"].replace("", np.nan), errors="coerce")
nov_df = pred_df[pred_df["novelty_score"].notna() & (pred_df["novelty_score"] > novel_thresh)]
unique_nov_ids = nov_df["id"].unique().tolist()
print(f"[DIAG] novel candidates (>{novel_thresh}): {len(nov_df)} rows, {len(unique_nov_ids)} unique ids")
if len(unique_nov_ids) > 0:
    print("  first 10 unique novel ids:", unique_nov_ids[:10])

# top clusters with most novel candidates
if "cluster_label" in pred_df.columns:
    cluster_counts = pred_df.groupby("cluster_label").size().reset_index(name="n")
    novel_by_cluster = nov_df.groupby("cluster_label").size().reset_index(name="novel_n").sort_values("novel_n", ascending=False)
    merged = novel_by_cluster.merge(cluster_counts, on="cluster_label", how="left")
    print("[DIAG] top clusters by novel candidate count:")
    print(merged.head(10).to_string(index=False))

# per-rank predicted label counts (top 10)
ranks = [c.replace("_pred", "") for c in pred_df.columns if c.endswith("_pred")]
ranks = sorted(list(set(ranks)))
print("[DIAG] Top predicted labels per rank (top 10):")
for r in ranks:
    vc = pred_df[f"{r}_pred"].value_counts().head(10)
    print(f"  {r:10s}: {vc.to_dict()}")

# mean_confidence distribution
pred_df["mean_confidence"] = pd.to_numeric(pred_df["mean_confidence"].replace("", np.nan), errors="coerce")
print("[DIAG] mean_confidence summary:")
print(pred_df["mean_confidence"].describe().to_string())

# 3) Leakage check: if y_encoded & val_idx exist, compute exact validation accuracy from predictions_summary.csv
can_eval = ("y_encoded" in globals()) and ("val_idx" in globals())
if can_eval:
    print("[LEAK] Attempting validation accuracy check using y_encoded + val_idx from notebook globals...")
    val_idx = np.array(globals()["val_idx"], dtype=int)
    # ensure pred_df aligned with index (we wrote predictions in same order as meta)
    # for each rank, map predicted label string to encoder index (if present) and compare to y_encoded[r][val_idx]
    val_metrics = {}
    any_mismatch_ids = []
    for r in globals().get("RANKS", ranks):
        le = globals().get("label_encoders", {}).get(r, None)
        if le is None:
            continue
        # retrieve predicted labels for val indices
        pred_labels = pred_df.iloc[val_idx][f"{r}_pred"].astype(str).tolist()
        # map to indices robustly
        label_to_idx = {lab: int(i) for i, lab in enumerate(le.classes_)}
        pred_indices = []
        for pl in pred_labels:
            if pl in label_to_idx:
                pred_indices.append(label_to_idx[pl])
            else:
                # fallback: UNASSIGNED if exists, else 0
                if "UNASSIGNED" in label_to_idx:
                    pred_indices.append(label_to_idx["UNASSIGNED"])
                else:
                    pred_indices.append(0)
        true_indices = np.array(globals()["y_encoded"][r])[val_idx]
        acc = accuracy_score(true_indices, pred_indices)
        f1m = f1_score(true_indices, pred_indices, average="macro", zero_division=0)
        val_metrics[r] = {"acc": acc, "f1_macro": f1m}
    print("[LEAK] Validation metrics computed from provided y_encoded + val_idx:")
    for r, m in val_metrics.items():
        print(f"  {r:8s} acc={m['acc']:.4f}, f1_macro={m['f1_macro']:.4f}")
    # find accessions present in both train & val (accession-level overlap)
    if "train_idx" in globals():
        train_idx = np.array(globals()["train_idx"], dtype=int)
        train_ids = set(pred_df.iloc[train_idx]["id"].astype(str).tolist())
        val_ids = set(pred_df.iloc[val_idx]["id"].astype(str).tolist())
        overlap = train_ids.intersection(val_ids)
        print(f"[LEAK] accession-level overlap between train and val: {len(overlap)} accessions (first 10 shown): {list(overlap)[:10]}")
    else:
        print("[LEAK] train_idx not present, cannot check accession-level overlap.")
else:
    print("[LEAK] y_encoded or val_idx not found in globals; skipping exact validation-leakage metrics.")

# 4) Temperature scaling on validation logits: collect logits from heads for val set, per-rank calibrate temperature
# Only run if 'shared' and 'heads' modules are present
if ("shared" in globals()) and ("heads" in globals()) and ("val_loader" in globals()):
    print("[CALIB] Collecting logits for validation set...")
    shared = globals()["shared"]
    heads = globals()["heads"]
    device = globals().get("device", torch.device("cpu"))
    shared.to(device); heads.to(device)
    shared.eval(); heads.eval()
    # collect logits & labels
    logits_val = {r: [] for r in label_encoders}
    labels_val = {r: [] for r in label_encoders}
    with torch.no_grad():
        for batch in globals()["val_loader"]:
            x = batch[0].to(device)
            h = shared(x)
            for i, r in enumerate(list(label_encoders.keys())):
                logits = heads[r](h).detach().cpu().numpy()
                labels = batch[1 + i].numpy()
                logits_val[r].append(logits)
                labels_val[r].append(labels)
    # stack
    for r in logits_val:
        logits_val[r] = np.vstack(logits_val[r])
        labels_val[r] = np.concatenate(labels_val[r])
    print("[CALIB] Collected logits for ranks:", list(logits_val.keys()))

    # helper: ECE
    def expected_calibration_error(probs, labels, n_bins=15):
        probs = np.asarray(probs)
        labels = np.asarray(labels)
        preds = np.argmax(probs, axis=1)
        confs = np.max(probs, axis=1)
        ece = 0.0
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        for i in range(n_bins):
            lo, hi = bins[i], bins[i+1]
            mask = (confs > lo) & (confs <= hi)
            if mask.sum() == 0:
                continue
            acc = (preds[mask] == labels[mask]).mean()
            avg_conf = confs[mask].mean()
            ece += (mask.sum() / len(probs)) * abs(avg_conf - acc)
        return float(ece)

    # optimize temperature per rank (simple Adam optimization on scalar T)
    temps = {}
    ece_before = {}
    ece_after = {}
    for r in logits_val:
        print(f"[CALIB] calibrating rank: {r}")
        logits_r = torch.tensor(logits_val[r], dtype=torch.float32, device=device)
        labels_r = torch.tensor(labels_val[r], dtype=torch.long, device=device)
        # compute probs before
        probs_before = F.softmax(logits_r, dim=1).cpu().numpy()
        ece_b = expected_calibration_error(probs_before, labels_r.cpu().numpy(), n_bins=15)
        ece_before[r] = ece_b
        # temp param
        T = torch.nn.Parameter(torch.ones(1, device=device) * 1.0)
        optT = torch.optim.LBFGS([T], lr=0.5, max_iter=50, line_search_fn='strong_wolfe')
        # closure for LBFGS
        def closure():
            optT.zero_grad()
            # numeric stability: clamp T > 1e-3
            t = T.clamp(min=1e-3)
            loss = F.cross_entropy(logits_r / t, labels_r)
            loss.backward()
            return loss
        try:
            optT.step(closure)
            T_opt = float(T.clamp(min=1e-3).item())
        except Exception as e:
            # fallback to small Adam loop
            T = torch.nn.Parameter(torch.tensor(1.0, device=device))
            opt = torch.optim.Adam([T], lr=0.01)
            for _ in range(200):
                opt.zero_grad()
                t = T.clamp(min=1e-3)
                loss = F.cross_entropy(logits_r / t, labels_r)
                loss.backward()
                opt.step()
            T_opt = float(T.clamp(min=1e-3).item())
        temps[r] = T_opt
        # compute ece after
        probs_after = F.softmax(torch.tensor(logits_val[r], dtype=torch.float32, device="cpu") / T_opt, dim=1).cpu().numpy()
        ece_a = expected_calibration_error(probs_after, labels_val[r], n_bins=15)
        ece_after[r] = ece_a
        print(f"  temp={T_opt:.4f} | ECE before={ece_b:.4f}, after={ece_a:.4f}")

    # Apply temperatures to all logits and write calibrated predictions
    print("[CALIB] Applying temperatures to all data (batched). This will produce calibrated CSV/JSONL.")
    batch = 512
    n_tot = X_pca.shape[0]
    out_recs_cal = []
    rows_cal = []
    # reload meta & predictions as baseline
    for start in range(0, n_tot, batch):
        end = min(n_tot, start+batch)
        xb = torch.tensor(X_pca[start:end], dtype=torch.float32).to(device)
        with torch.no_grad():
            h = shared(xb)
            for i in range(end-start):
                idx = start + i
                rec = pred_recs[idx].copy()
                # per rank calibrated probs
                mean_conf = 0.0; conf_count = 0
                for r in label_encoders:
                    logits_np = heads[r](h[i:i+1]).detach().cpu().numpy()[0]   # (n_classes,)
                    t = temps.get(r, 1.0)
                    probs = F.softmax(torch.tensor(logits_np / t, dtype=torch.float32), dim=0).numpy()
                    top_idx = int(np.argmax(probs))
                    top_prob = float(probs[top_idx])
                    # top-3
                    topk_idx = np.argsort(-probs)[:3]
                    topk = [{"label": str(label_encoders[r].classes_[k]), "prob": float(probs[k])} for k in topk_idx]
                    # update rec
                    rec.setdefault("predicted", {})
                    rec["predicted"][r] = {"predicted_label": str(label_encoders[r].classes_[top_idx]),
                                           "predicted_index": top_idx,
                                           "predicted_prob": top_prob,
                                           "top_k": topk}
                    mean_conf += top_prob
                    conf_count += 1
                rec["mean_confidence_calibrated"] = mean_conf / max(1, conf_count)
                out_recs_cal.append(rec)
                # CSV row
                row = {"id": rec.get("id", f"seq_{idx}"), "mean_confidence_calibrated": rec["mean_confidence_calibrated"],
                       "cluster_label": pred_df.loc[idx, "cluster_label"] if "cluster_label" in pred_df.columns else ""}
                # add per-rank pred/prob
                for r in label_encoders:
                    row[f"{r}_pred"] = rec["predicted"][r]["predicted_label"]
                    row[f"{r}_prob"] = rec["predicted"][r]["predicted_prob"]
                rows_cal.append(row)
    # write calibrated outputs
    save_jsonl(out_recs_cal, CALIB_JSONL)
    pd.DataFrame(rows_cal).to_csv(CALIB_CSV, index=False)
    print(f"[SAVE] Calibrated predictions saved: {CALIB_JSONL}, {CALIB_CSV}")

else:
    print("[CALIB] Cannot calibrate: 'shared' or 'heads' or 'val_loader' not found in globals. Skipping calibration.")

# 5) Save novel candidates summary CSV and cluster-level biodiversity summary
novel_unique_rows = pred_df[pred_df["id"].isin(unique_nov_ids)].drop_duplicates(subset=["id"])
if not novel_unique_rows.empty:
    novel_unique_rows.to_csv(NOVEL_CSV, index=False)
    print(f"[SAVE] Novel candidates summary saved: {NOVEL_CSV}")

# cluster summary: how many sequences, how many novel candidates, mean novelty, mean confidence
if "cluster_label" in pred_df.columns:
    cluster_summary = pred_df.groupby("cluster_label").agg(
        n_sequences = ("id", "count"),
        novel_count = ("novelty_score", lambda s: int((pd.to_numeric(s, errors='coerce') > novel_thresh).sum())),
        mean_novelty = ("novelty_score", lambda s: pd.to_numeric(s, errors='coerce').mean()),
        mean_confidence = ("mean_confidence", lambda s: pd.to_numeric(s, errors='coerce').mean())
    ).reset_index().sort_values("n_sequences", ascending=False)
    cluster_summary.to_csv(CLUSTER_SUM, index=False)
    print(f"[SAVE] cluster summary saved: {CLUSTER_SUM}")

# 6) Print short human-readable summary & recommendations
print("\n=== SHORT SUMMARY ===")
print(f"Total predictions: {len(pred_recs)}")
print(f"Empty IDs fixed: {fixed_count}; unique IDs now: {len(set([r['id'] for r in pred_recs]))}")
print(f"Unique novel candidate ids (novelty>{novel_thresh}): {len(unique_nov_ids)} (see {NOVEL_CSV})")
if 'temps' in locals():
    print("Per-rank temperature scaling applied. Example temps (first 5 ranks):")
    for r, t in list(temps.items())[:5]:
        print(f"  {r:10s} T={t:.4f} | ECE before={ece_before.get(r):.4f}, after={ece_after.get(r):.4f}")
print("\nRECOMMENDATIONS (next actions):")
print("  1) Re-split training/validation by accession/sample/cluster (leave-one-accession-out or by cruise) to avoid leakage.")
print("  2) Remove sequences used to build label encoders or use them only as an independent reference for cluster annotation.")
print("  3) Use hierarchical-aware losses (penalize coarse rank errors less), ensembles, and MC-dropout for uncertainty.")
print("  4) Use cluster-level annotation: cluster unknown sequences with HDBSCAN, then assign cluster-level taxonomy by nearest reference.")
print("  5) For discovery pipeline: prioritize novel candidates (high novelty_score, low nearest-neighbor similarity, high QC), and create an expert review list from NOVEL_CSV.")
print("\nCell 13 complete. Files written:", CALIB_CSV, CALIB_JSONL, NOVEL_CSV, CLUSTER_SUM)

[START] Loading files...
  loaded: predictions.jsonl 2555 records; summary csv 2555 rows; meta csv 2555 rows
[FIX] Filled 10 empty IDs (from meta).
[SAVE] Saved fixed predictions.jsonl and predictions_summary.csv (backups created if not present).
[DIAG] total: 2555 rows, empty_ids=5, duplicated_ids=929, unique_ids=1626
[DIAG] novel candidates (>0.8): 18 rows, 12 unique ids
  first 10 unique novel ids: ['XR_013100016.1', 'LC876616.1', 'LC876591.1', 'LC876590.1', 'LC876589.1', 'LC876588.1', 'PX279187.1', 'PX279186.1', 'PQ523755.1', 'PX277228.1']
[DIAG] top clusters by novel candidate count:
cluster_label  novel_n   n
           -1       18 531
[DIAG] Top predicted labels per rank (top 10):
  class     : {'UNASSIGNED': 2555}
  family    : {'UNASSIGNED': 2555}
  genus     : {'UNASSIGNED': 2555}
  kingdom   : {'UNASSIGNED': 2555}
  order     : {'UNASSIGNED': 2555}
  phylum    : {'UNASSIGNED': 2555}
  species   : {'UNASSIGNED': 2555}
[DIAG] mean_confidence summary:
count    2555.000000
mean 

In [18]:
# Cell 14: Re-split by accession (group-wise), rebuild dataloaders, retrain model with no leakage
# - Produces best_shared_heads_retrain.pt and training_history_retrain.csv
# - Saves train/val split (train_val_split_by_accession.json)

import json, time
from pathlib import Path
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import LabelEncoder

# --- Config & paths ---
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
EMB_PCA_NPY = EXTRACT_DIR / "embeddings_pca.npy"
EMB_FULL_NPY = EXTRACT_DIR / "embeddings.npy"
META_CLUSTERED_CSV = EXTRACT_DIR / "embeddings_meta_clustered.csv"

OUT_CHECKPOINT = EXTRACT_DIR / "best_shared_heads_retrain.pt"
OUT_HISTORY = EXTRACT_DIR / "training_history_retrain.csv"
SPLIT_JSON = EXTRACT_DIR / "train_val_split_by_accession.json"

# training hyperparams (tweakable)
TEST_SIZE = 0.15
SEED = 42
BATCH_SIZE = 128
MAX_EPOCHS = 50
MIN_EPOCHS = 5
PATIENCE = 8
LR = 1e-3
HIDDEN_DIM = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- sanity checks & load data ---
for p in (EMB_PCA_NPY, META_CLUSTERED_CSV):
    if not Path(p).exists():
        raise RuntimeError(f"Missing required file: {p}")

X_pca = np.load(EMB_PCA_NPY)
df_meta = pd.read_csv(META_CLUSTERED_CSV, dtype=str, keep_default_na=False, na_filter=False)

n = X_pca.shape[0]
if len(df_meta) != n:
    mn = min(len(df_meta), n)
    print(f"[ALIGN] trimming to min_n={mn}")
    df_meta = df_meta.iloc[:mn].reset_index(drop=True)
    X_pca = X_pca[:mn]
    n = mn

print(f"[LOAD] samples={n}, X_pca.shape={X_pca.shape}, meta rows={len(df_meta)}")

# --- build accession base grouping key (acc_base) ---
def get_acc_base(s):
    s = str(s) if s is not None else ""
    s = s.strip()
    if s == "":
        return None
    # split by whitespace first then by dot (remove version)
    token = s.split()[0]
    return token.split(".")[0]

df_meta["acc_base"] = df_meta.get("id", "").apply(get_acc_base)
# fill missing acc_base with synthetic unique groups to avoid dropping records
missing_acc_mask = df_meta["acc_base"].isnull() | (df_meta["acc_base"] == "")
if missing_acc_mask.any():
    # create deterministic synthetic group names based on index
    df_meta.loc[missing_acc_mask, "acc_base"] = df_meta.loc[missing_acc_mask].index.to_series().apply(lambda i: f"_missing_acc{i}")

n_groups = df_meta["acc_base"].nunique()
print(f"[GROUPS] unique accession groups (acc_base): {n_groups}")

# --- ensure y_encoded exists in globals (we used it earlier) ---
if "y_encoded" not in globals():
    raise RuntimeError("y_encoded not found in notebook globals. Ensure that label encoders & y_encoded were created in earlier cells.")

y_encoded = globals()["y_encoded"]

# Align y_encoded arrays length
for r in list(y_encoded.keys()):
    arr = np.asarray(y_encoded[r], dtype=int)
    if len(arr) != n:
        if len(arr) < n:
            arr = np.concatenate([arr, np.zeros(n - len(arr), dtype=int)])
        else:
            arr = arr[:n]
        y_encoded[r] = arr
    else:
        y_encoded[r] = arr

# --- group split by acc_base (GroupShuffleSplit) ---
groups = df_meta["acc_base"].values
indices = np.arange(n)
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
try:
    train_idx_grp, val_idx_grp = next(gss.split(indices, groups=groups))
except Exception as e:
    # fallback to random split if GroupShuffleSplit fails
    print("[WARN] GroupShuffleSplit failed:", e)
    from sklearn.model_selection import train_test_split
    train_idx_grp, val_idx_grp = train_test_split(indices, test_size=TEST_SIZE, random_state=SEED, shuffle=True)

# verify no accession overlap
train_acc = set(df_meta.loc[train_idx_grp, "acc_base"].unique().tolist())
val_acc   = set(df_meta.loc[val_idx_grp, "acc_base"].unique().tolist())
overlap = train_acc.intersection(val_acc)
if len(overlap) != 0:
    raise RuntimeError(f"Group split produced overlapping accession groups between train and val ({len(overlap)} overlaps). Abort.")

print(f"[SPLIT] train_samples={len(train_idx_grp)}, val_samples={len(val_idx_grp)}, train_groups={len(train_acc)}, val_groups={len(val_acc)}")

# save split groups for reproducibility
with open(SPLIT_JSON, "w") as fh:
    json.dump({"train_groups": list(train_acc), "val_groups": list(val_acc), "seed": SEED}, fh)
print(f"[SAVE] saved train/val group split to {SPLIT_JSON}")

# --- compute train-only class frequencies & class_weights per rank ---
label_encoders = globals().get("label_encoders", None)
if label_encoders is None:
    raise RuntimeError("label_encoders not found in globals; cannot construct class weights.")

class_weights_train = {}
for r in label_encoders:
    ncls = len(label_encoders[r].classes_)
    counts = np.bincount(y_encoded[r][train_idx_grp], minlength=ncls).astype(float)
    # avoid zeros
    counts = np.where(counts <= 0, 1.0, counts)
    w = (1.0 / counts)
    w = w / w.sum() * len(w)   # scale
    class_weights_train[r] = torch.tensor(w, dtype=torch.float32)

# --- build TensorDatasets & DataLoaders (group-safe) ---
X_tensor = torch.tensor(X_pca, dtype=torch.float32)
y_tensors = {r: torch.tensor(y_encoded[r], dtype=torch.long) for r in label_encoders}

from torch.utils.data import TensorDataset, DataLoader
train_X = X_tensor[train_idx_grp]
val_X   = X_tensor[val_idx_grp]
train_y_list = [y_tensors[r][train_idx_grp] for r in label_encoders]
val_y_list   = [y_tensors[r][val_idx_grp] for r in label_encoders]

train_ds_grp = TensorDataset(train_X, *train_y_list)
val_ds_grp   = TensorDataset(val_X, *val_y_list)
train_loader_grp = DataLoader(train_ds_grp, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
val_loader_grp   = DataLoader(val_ds_grp,   batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

print(f"[DATA] train_ds={len(train_ds_grp)}, val_ds={len(val_ds_grp)}; batch_size={BATCH_SIZE}")

# --- robust scheduler helper (same as used earlier) ---
import inspect
def create_reduce_on_plateau_scheduler(opt, mode="max", factor=0.5, patience=3, min_lr=1e-6):
    try:
        ctor = torch.optim.lr_scheduler.ReduceLROnPlateau
        sig = inspect.signature(ctor._init_)
        allowed = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
        kwargs = {}
        cand = {"mode": mode, "factor": factor, "patience": patience, "min_lr": min_lr}
        for k, v in cand.items():
            if k in allowed:
                kwargs[k] = v
        if "verbose" in allowed:
            kwargs["verbose"] = False
        scheduler = ctor(opt, **kwargs)
        return scheduler
    except Exception as e:
        class DummyScheduler:
            def step(self, metric=None): return None
        print("[SCHED WARN] Could not create ReduceLROnPlateau, using DummyScheduler:", e)
        return DummyScheduler()

# --- build fresh model parts (shared_new + heads_new) ---
input_dim = X_pca.shape[1]
shared_new = nn.Sequential(
    nn.Linear(input_dim, HIDDEN_DIM),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM//2),
    nn.ReLU()
)
heads_new = nn.ModuleDict()
for r in label_encoders:
    heads_new[r] = nn.Linear(HIDDEN_DIM//2, len(label_encoders[r].classes_))

# move to device
shared_new.to(DEVICE)
heads_new.to(DEVICE)

# --- create criterions & optimizer (train-only class weights) ---
criterions_new = {}
for r in label_encoders:
    w = class_weights_train[r].to(DEVICE)
    criterions_new[r] = nn.CrossEntropyLoss(weight=w)

params = list(shared_new.parameters()) + list(heads_new.parameters())
optimizer_new = torch.optim.Adam(params, lr=LR)
scheduler_new = create_reduce_on_plateau_scheduler(optimizer_new, mode="max", factor=0.5, patience=3, min_lr=1e-6)

# --- training loop (early stopping on aggregated val mean macro-F1) ---
from sklearn.metrics import accuracy_score, f1_score

def compute_val_metrics_shared(shared_mod, heads_mod, loader, criterions, device):
    shared_mod.eval(); heads_mod.eval()
    preds = {r: [] for r in label_encoders}
    trues = {r: [] for r in label_encoders}
    with torch.no_grad():
        for batch in loader:
            x = batch[0].to(device)
            h = shared_mod(x)
            for i, r in enumerate(label_encoders):
                logits = heads_mod[r](h)
                pred = torch.argmax(torch.softmax(logits, dim=1), dim=1).cpu().numpy()
                true = batch[i+1].cpu().numpy()
                preds[r].extend(pred.tolist())
                trues[r].extend(true.tolist())
    metrics = {}
    f1s = []
    for r in label_encoders:
        try:
            acc = accuracy_score(trues[r], preds[r])
            f1m = f1_score(trues[r], preds[r], average="macro", zero_division=0)
        except Exception:
            acc, f1m = None, None
        metrics[r] = {"acc": acc, "f1_macro": f1m}
        if f1m is not None:
            f1s.append(f1m)
    agg = float(np.mean(f1s)) if len(f1s)>0 else 0.0
    return metrics, agg

best_score = -np.inf
epochs_no_improve = 0
history = []
start = time.time()

for epoch in range(1, MAX_EPOCHS + 1):
    t0 = time.time()
    # training epoch
    shared_new.train(); heads_new.train()
    running_loss = 0.0; nb = 0
    for batch in train_loader_grp:
        x = batch[0].to(DEVICE)
        targets = [batch[i+1].to(DEVICE) for i in range(len(label_encoders))]
        h = shared_new(x)
        outputs = {r: heads_new[r](h) for r in label_encoders}
        loss = 0.0
        for i, r in enumerate(label_encoders):
            loss += criterions_new[r](outputs[r], targets[i])
        optimizer_new.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, max_norm=5.0)
        optimizer_new.step()
        running_loss += float(loss.item())
        nb += 1
    train_loss = running_loss / max(1, nb)

    # validation metrics
    val_metrics, val_agg_f1 = compute_val_metrics_shared(shared_new, heads_new, val_loader_grp, criterions_new, DEVICE)

    # scheduler step
    try:
        scheduler_new.step(val_agg_f1)
    except TypeError:
        try:
            scheduler_new.step()
        except Exception:
            pass

    # logging
    rec = {"epoch": epoch, "train_loss": train_loss, "val_agg_f1": val_agg_f1, "time_sec": time.time()-t0}
    for r in label_encoders:
        m = val_metrics.get(r, {})
        rec.update({f"{r}_acc": m.get("acc"), f"{r}_f1": m.get("f1_macro")})
    history.append(rec)
    print(f"Epoch {epoch:03d} | train_loss {train_loss:.4f} | val_agg_f1 {val_agg_f1:.4f} | time {rec['time_sec']:.1f}s")

    # checkpoint on improvement
    if val_agg_f1 > best_score + 1e-8:
        best_score = val_agg_f1
        epochs_no_improve = 0
        torch.save({
            "shared_state": shared_new.state_dict(),
            "heads_state": {r: heads_new[r].state_dict() for r in label_encoders},
            "epoch": epoch,
            "val_agg_f1": val_agg_f1,
            "optimizer_state": optimizer_new.state_dict()
        }, OUT_CHECKPOINT)
        print(f"  [CHECKPOINT] saved new best (epoch {epoch}, val_agg_f1 {val_agg_f1:.4f}) -> {OUT_CHECKPOINT.name}")
    else:
        epochs_no_improve += 1

    # early stopping
    if epoch >= MIN_EPOCHS and epochs_no_improve >= PATIENCE:
        print(f"[EARLY STOP] No improvement for {epochs_no_improve} epochs (patience {PATIENCE}). Stopping.")
        break

# done training
total_time = time.time() - start
print(f"[TRAIN COMPLETE] epochs_run={epoch} best_val_agg_f1={best_score:.4f} total_time_sec={total_time:.1f}")

# Save history
pd.DataFrame(history).to_csv(OUT_HISTORY, index=False)
print(f"[SAVE] training history saved to {OUT_HISTORY}")

# Save split indices for reproducibility
np.save(EXTRACT_DIR / "train_idx_by_acc.npy", train_idx_grp)
np.save(EXTRACT_DIR / "val_idx_by_acc.npy", val_idx_grp)
print("[SAVE] saved train/val index arrays (train_idx_by_acc.npy, val_idx_by_acc.npy)")

# Basic post-train diagnostics
# Load best checkpoint and compute validation metrics one final time
ckpt = torch.load(OUT_CHECKPOINT, map_location=DEVICE)
shared_new.load_state_dict(ckpt["shared_state"])
for r in label_encoders:
    heads_new[r].load_state_dict(ckpt["heads_state"][r])
print(f"[LOAD] loaded best checkpoint epoch {ckpt.get('epoch')}, val_agg_f1 {ckpt.get('val_agg_f1')}")

final_val_metrics, final_val_agg = compute_val_metrics_shared(shared_new, heads_new, val_loader_grp, criterions_new, DEVICE)
print("=== Final validation metrics on group-wise split ===")
for r in label_encoders:
    m = final_val_metrics.get(r, {})
    print(f"{r:10s} acc={m.get('acc')}, f1_macro={m.get('f1_macro')}")

# Print top-level summary: class counts in train
print("[TRAIN CLASS COUNTS] per-rank sample counts in training set (top classes):")
for r in label_encoders:
    idxs = y_encoded[r][train_idx_grp]
    cnt = Counter(idxs)
    # map top numeric indices back to label strings
    top = cnt.most_common(5)
    top_labels = [(label_encoders[r].classes_[i], c) for i,c in top]
    print(f"  {r:10s}: classes={len(label_encoders[r].classes_)}, top5={top_labels}")

print("\nCell 14 finished. Outputs produced:")
print(f"  - retrained best checkpoint: {OUT_CHECKPOINT}")
print(f"  - training history: {OUT_HISTORY}")
print(f"  - group split JSON: {SPLIT_JSON}")
print(f"  - saved train/val index arrays in extracted/")

[LOAD] samples=2555, X_pca.shape=(2555, 64), meta rows=2555
[GROUPS] unique accession groups (acc_base): 1630
[SPLIT] train_samples=2182, val_samples=373, train_groups=1385, val_groups=245
[SAVE] saved train/val group split to ncbi_blast_db\extracted\train_val_split_by_accession.json
[DATA] train_ds=2182, val_ds=373; batch_size=128
[SCHED WARN] Could not create ReduceLROnPlateau, using DummyScheduler: type object 'ReduceLROnPlateau' has no attribute '_init_'
Epoch 001 | train_loss 16.6280 | val_agg_f1 1.0000 | time 0.3s
  [CHECKPOINT] saved new best (epoch 1, val_agg_f1 1.0000) -> best_shared_heads_retrain.pt
Epoch 002 | train_loss 6.7478 | val_agg_f1 1.0000 | time 408.3s
Epoch 003 | train_loss 0.3275 | val_agg_f1 1.0000 | time 0.4s
Epoch 004 | train_loss 0.0071 | val_agg_f1 1.0000 | time 0.4s
Epoch 005 | train_loss 0.0030 | val_agg_f1 1.0000 | time 0.4s
Epoch 006 | train_loss 0.0024 | val_agg_f1 1.0000 | time 0.5s
Epoch 007 | train_loss 0.0022 | val_agg_f1 1.0000 | time 0.6s
Epoch 008

In [19]:
# Corrected Cell 15 — robust diagnostics & label reconstruction (fixed UNASSIGNED truncation issue)
import json, pickle, traceback
from pathlib import Path
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Paths
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR  = DOWNLOAD_DIR / "extracted"
META_CSV     = EXTRACT_DIR / "embeddings_meta_clustered.csv"
LABEL_ENCODERS_PATH = EXTRACT_DIR / "label_encoders_used.pkl"
FETCHED_JSONS = {m: EXTRACT_DIR / f"{m}_fetched_metadata.json" for m in ("ssu","lsu","its")}
ASSIGN_DEBUG_CSV = EXTRACT_DIR / "label_assignment_debug.csv"
REBUILT_ENCODERS_PKL = EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl"

CANON_UNASSIGNED = "_UNASSIGNED_"   # canonical placeholder (avoid accidental truncation)

# Load meta
if not META_CSV.exists():
    raise RuntimeError(f"Missing meta CSV: {META_CSV}")
df_meta = pd.read_csv(META_CSV, dtype=str, keep_default_na=False, na_filter=False)
n_meta = len(df_meta)
print(f"[LOAD] meta rows: {n_meta}; columns: {list(df_meta.columns)[:12]}")

# Load saved encoders if present (for inspection)
label_encoders_saved = None
if LABEL_ENCODERS_PATH.exists():
    try:
        with open(LABEL_ENCODERS_PATH, "rb") as fh:
            label_encoders_saved = pickle.load(fh)
        print(f"[LOAD] Found saved label encoders: {LABEL_ENCODERS_PATH.name}")
    except Exception as e:
        print("[WARN] failed to load saved label encoders:", e)

# Read fetched metadata JSONs (best-effort)
fetched_records = {}
for key, p in FETCHED_JSONS.items():
    if not p.exists():
        print(f"[FETCHED] {key}: file not found at {p}")
        fetched_records[key] = []
        continue
    try:
        recs = json.load(open(p, "r", encoding="utf-8"))
        fetched_records[key] = recs if isinstance(recs, list) else []
        print(f"[FETCHED] {key}: records={len(fetched_records[key])}")
    except Exception as e:
        print(f"[WARN] Could not parse {p}: {e}")
        fetched_records[key] = []

# Build lookups: exact accession, acc_base, organism (lowercase)
lookup_exact = {}
lookup_base  = {}
org_lookup = defaultdict(list)

def normalize_str(s):
    # ensure safe str, strip whitespace, replace problematic control chars
    if s is None:
        return ""
    s = str(s)
    s = s.strip()
    # replace unprintable with empty
    s = "".join(ch for ch in s if ch.isprintable())
    return s

for key, recs in fetched_records.items():
    for rec in recs:
        acc = normalize_str(rec.get("accession") or rec.get("accession_version") or rec.get("id") or "")
        if acc:
            lookup_exact[acc] = rec
            base = acc.split(".")[0]
            lookup_base[base] = rec
        org = normalize_str(rec.get("organism") or rec.get("description") or "")
        if org:
            org_lookup[org.lower()].append(rec)

print(f"[LOOKUPS] exact_acc={len(lookup_exact)}, acc_base={len(lookup_base)}, org_names={len(org_lookup)}")

# Build assignments: iterate df_meta and attempt to match each id to fetched metadata
ranks = ["kingdom","phylum","class","order","family","genus","species"]
assigned_rows = []
no_match = 0

for idx, row in df_meta.iterrows():
    seq_id = normalize_str(row.get("id",""))
    acc_base = seq_id.split()[0].split(".")[0] if seq_id else ""
    rec = None
    matched_source = ""
    if seq_id and seq_id in lookup_exact:
        rec = lookup_exact[seq_id]; matched_source = "exact_acc"
    elif acc_base and acc_base in lookup_base:
        rec = lookup_base[acc_base]; matched_source = "acc_base"
    else:
        # try substring match in organism names (cheap)
        found = None
        seq_low = seq_id.lower()
        for orgname, recs in org_lookup.items():
            if orgname and orgname in seq_low:
                found = recs[0]
                matched_source = "org_substr"
                break
        if found:
            rec = found
    label_map = {r: "" for r in ranks}
    if rec:
        taxonomy = rec.get("taxonomy") or []
        # taxonomy list often is ordered from kingdom downwards
        for i, rank_name in enumerate(["kingdom","phylum","class","order","family","genus"]):
            if i < len(taxonomy) and taxonomy[i]:
                label_map[rank_name] = normalize_str(taxonomy[i])
        # species from organism/description
        organism = normalize_str(rec.get("organism") or rec.get("description") or "")
        parts = organism.split()
        if len(parts) >= 2:
            label_map["genus"] = label_map.get("genus") or normalize_str(parts[0])
            label_map["species"] = normalize_str(" ".join(parts[:2]))
    else:
        no_match += 1
    assigned_rows.append({
        "index": int(idx),
        "id": seq_id,
        "acc_base": acc_base,
        "matched_source": matched_source,
        "has_match": bool(bool(rec)),
        **{f"assigned_{r}": (label_map[r] if label_map[r] else "") for r in ranks},
        "meta_cluster_label": row.get("cluster_label",""),
        "meta_novelty": row.get("novelty_score",""),
        "meta_organism": row.get("organism","")
    })

df_assign = pd.DataFrame(assigned_rows)
print(f"[ASSIGN] total={len(df_assign)} matched={len(df_assign)-no_match} no_match={no_match}")
# Save debug assignment CSV (overwrite safely)
df_assign.to_csv(ASSIGN_DEBUG_CSV, index=False)
print(f"[SAVE] wrote assignment debug CSV: {ASSIGN_DEBUG_CSV}")

# Summaries per-rank assigned counts
print("\n[SUMMARIES] per-rank assigned counts (from fetched metadata):")
for r in ranks:
    col = f"assigned_{r}"
    nonempty = df_assign[col].astype(str).replace("", np.nan).notna().sum()
    pct = 100.0 * nonempty / len(df_assign)
    print(f"  {r:8s}: assigned_count={nonempty} ({pct:.2f}%)")

# Examples
print("\n[EXAMPLES] sequences with assigned species (first 10):")
examples_assigned = df_assign[df_assign["assigned_species"].astype(bool)].head(10)
if not examples_assigned.empty:
    print(examples_assigned[["index","id","acc_base","matched_source","assigned_genus","assigned_species","meta_cluster_label","meta_novelty"]].to_string(index=False))
else:
    print("  None found")

print("\n[EXAMPLES] sequences with NO assignment (first 10):")
examples_unassigned = df_assign[~df_assign["has_match"]].head(10)
if not examples_unassigned.empty:
    print(examples_unassigned[["index","id","acc_base","meta_cluster_label","meta_novelty"]].to_string(index=False))
else:
    print("  None found")

# --- Rebuild encoders from sequences that have at least one assigned value ---
assigned_any_mask = df_assign[[f"assigned_{r}" for r in ranks]].astype(bool).any(axis=1)
n_assigned_any = assigned_any_mask.sum()
print(f"\n[CHECK] sequences with any assigned taxonomy: {n_assigned_any} / {len(df_assign)}")

min_examples_threshold = 50
if n_assigned_any < min_examples_threshold:
    print(f"[SKIP] Not enough assigned sequences to rebuild encoders (need >= {min_examples_threshold})")
else:
    print("[REBUILD] Rebuilding label encoders from sequences with assigned taxonomy (best-effort).")
    df_good = df_assign[assigned_any_mask].copy()
    # id->per-rank label (sanitized)
    id2labels = {}
    for _, r in df_good.iterrows():
        id2labels[r["id"]] = {f: (normalize_str(r[f]) if r[f] else CANON_UNASSIGNED) for f in [f"assigned_{x}" for x in ranks]}

    new_encoders = {}
    new_y_encoded = {}

    for r in ranks:
        col = f"assigned_{r}"
        # labels present among good sequences (sanitized)
        lbls_good = [normalize_str(x) if x else CANON_UNASSIGNED for x in df_good[col].astype(str).tolist()]
        # ensure canonical unassigned present
        unique_labels = sorted(set(lbls_good))
        if CANON_UNASSIGNED not in unique_labels:
            unique_labels.append(CANON_UNASSIGNED)
        # Fit encoder on unique_labels
        le = LabelEncoder()
        le.fit(unique_labels)
        new_encoders[r] = le

        # Build full_labels aligned with df_meta (use id2labels map, fallback to CANON_UNASSIGNED)
        full_labels = []
        for idx, row in df_meta.iterrows():
            curid = normalize_str(row.get("id",""))
            lab = id2labels.get(curid, {}).get(col, CANON_UNASSIGNED)
            if not lab:
                lab = CANON_UNASSIGNED
            full_labels.append(lab)
        # Now transform safely
        # As we fit le on unique_labels union CANON_UNASSIGNED, transform should not fail
        transformed = le.transform(full_labels)
        new_y_encoded[r] = np.array(transformed, dtype=int)
        print(f"  rebuilt encoder for {r:8s}: n_classes={len(le.classes_)} sample_classes={list(le.classes_)[:10]}")

    # Save rebuilt encoders and per-rank arrays
    with open(REBUILT_ENCODERS_PKL, "wb") as fh:
        pickle.dump(new_encoders, fh)
    for r in ranks:
        np.save(EXTRACT_DIR / f"y_encoded_rebuilt_{r}.npy", new_y_encoded[r])
    print(f"[SAVE] saved rebuilt encoders -> {REBUILT_ENCODERS_PKL} and y_encoded_rebuilt_*.npy")

print("\n=== DIAGNOSTIC SUMMARY & NEXT STEPS ===")
print("1) Assignment debug file saved (label_assignment_debug.csv). Inspect it to confirm matching logic is correct.")
print("2) If rebuilt encoders were produced, you can re-train on only labeled sequences by loading y_encoded_rebuilt_* arrays and using the train indices where the label != CANON_UNASSIGNED.")
print("3) If many sequences are still unlabeled, consider: (a) BLAST seeding; (b) cluster pseudo-labeling with HDBSCAN; (c) hierarchical coarse-label training.")
print("4) If you want, I will now produce the retrain cell that trains only on sequences with real labels (using the rebuilt encoders) — ready-to-run. Say 'do it' and I'll provide it as the next cell.")

[LOAD] meta rows: 2555; columns: ['id', 'source_fasta', 'seq_len', 'n_kmers', 'n_kmers_missing', 'zero_vector_fallback', 'PC1', 'PC2', 'PC3', 'PC4', 'PC5', 'PC6']
[LOAD] Found saved label encoders: label_encoders_used.pkl
[FETCHED] ssu: records=468
[FETCHED] lsu: records=406
[FETCHED] its: records=699
[LOOKUPS] exact_acc=1231, acc_base=1231, org_names=362
[ASSIGN] total=2555 matched=1131 no_match=1424
[SAVE] wrote assignment debug CSV: ncbi_blast_db\extracted\label_assignment_debug.csv

[SUMMARIES] per-rank assigned counts (from fetched metadata):
  kingdom : assigned_count=1131 (44.27%)
  phylum  : assigned_count=1131 (44.27%)
  class   : assigned_count=1131 (44.27%)
  order   : assigned_count=1131 (44.27%)
  family  : assigned_count=1131 (44.27%)
  genus   : assigned_count=1131 (44.27%)
  species : assigned_count=1131 (44.27%)

[EXAMPLES] sequences with assigned species (first 10):
 index                id        acc_base matched_source  assigned_genus       assigned_species meta_clu

In [20]:
# Cell: Retrain on labeled-only sequences (multi-head), group-wise split by accession
# Save outputs: best_shared_heads_labeled.pt, training_history_labeled.csv, metrics_labeled.json

import os, json, time
from pathlib import Path
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import accuracy_score, f1_score

# ---------------- Paths / config ----------------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"

ENC_REBUILT_PKL = EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl"
Y_REBUILT_PREFIX = EXTRACT_DIR / "y_encoded_rebuilt_"
ASSIGN_DEBUG_CSV = EXTRACT_DIR / "label_assignment_debug.csv"
EMB_FULL = EXTRACT_DIR / "embeddings.npy"
EMB_PCA  = EXTRACT_DIR / "embeddings_pca.npy"

OUT_CHECKPOINT = EXTRACT_DIR / "best_shared_heads_labeled.pt"
OUT_HISTORY = EXTRACT_DIR / "training_history_labeled.csv"
OUT_METRICS = EXTRACT_DIR / "metrics_labeled.json"

RANKS = ["kingdom","phylum","class","order","family","genus","species"]
SEED = 42
TEST_SIZE = 0.15
BATCH_SIZE = 128
MAX_EPOCHS = 60
MIN_EPOCHS = 5
PATIENCE = 8
LR = 1e-3
HIDDEN_DIM = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

np.random.seed(SEED)
torch.manual_seed(SEED)

# ---------------- Load rebuilt encoders and labels ----------------
if not ENC_REBUILT_PKL.exists():
    raise RuntimeError(f"Missing rebuilt encoders file {ENC_REBUILT_PKL}. Run Cell 15 to produce it.")
with open(ENC_REBUILT_PKL, "rb") as fh:
    label_encoders = pickle.load(fh) if "pickle" in globals() else _import_("pickle").load(fh)
print(f"[LOAD] loaded rebuilt encoders from {ENC_REBUILT_PKL.name}")

# Load y arrays per rank (full-length, aligned with meta)
y_rebuilt = {}
n_samples = None
for r in RANKS:
    p = Path(f"{Y_REBUILT_PREFIX}{r}.npy")
    if not p.exists():
        raise RuntimeError(f"Missing y-encoded array for rank {r}: expected {p}")
    arr = np.load(p)
    y_rebuilt[r] = arr.astype(int)
    n_samples = len(arr) if n_samples is None else n_samples
    print(f"[LOAD] loaded y_encoded_rebuilt_{r}.npy shape={arr.shape}")

# ---------------- Load embeddings (prefer full embeddings) ----------------
if EMB_FULL.exists():
    X = np.load(EMB_FULL)
    print(f"[LOAD] Using full embeddings {EMB_FULL.name} shape={X.shape}")
elif EMB_PCA.exists():
    X = np.load(EMB_PCA)
    print(f"[LOAD] Using PCA embeddings {EMB_PCA.name} shape={X.shape}")
else:
    raise RuntimeError("No embeddings found (expected embeddings.npy or embeddings_pca.npy in extracted).")

if X.shape[0] != n_samples:
    mn = min(X.shape[0], n_samples)
    print(f"[ALIGN] trimming to min_n={mn}")
    X = X[:mn]
    for r in RANKS:
        y_rebuilt[r] = y_rebuilt[r][:mn]
    n_samples = mn

# ---------------- Build labeled-only index mask ----------------
# Detect index of UNASSIGNED in each encoder (robust)
unassigned_index_by_rank = {}
for r in RANKS:
    classes = getattr(label_encoders[r], "classes_", None)
    if classes is None:
        raise RuntimeError(f"Encoder for rank {r} missing 'classes_' attribute.")
    # try canonical names
    uni = list(classes)
    idx = None
    for cand in ("UNASSIGNED", "_UNASSIGNED_", "UNASSIGNE", ""):
        if cand in uni:
            idx = uni.index(cand)
            break
    # fallback: if one class only, treat missing label as that single class (no unassigned)
    if idx is None:
        if len(uni) == 1:
            idx = 0
        else:
            # if no explicit UNASSIGNED, try to find a class containing 'UNASSIGN' substring
            idx = next((i for i,c in enumerate(uni) if "UNASSIGN" in str(c).upper()), None)
            if idx is None:
                # set to -1 meaning "no explicit unassigned"
                idx = -1
    unassigned_index_by_rank[r] = idx
print("[INFO] detected unassigned indices per rank:", unassigned_index_by_rank)

# Build mask: labeled if ANY rank has label != unassigned_index (i.e., a real label)
is_labeled = np.zeros(n_samples, dtype=bool)
for r in RANKS:
    idx_un = unassigned_index_by_rank[r]
    if idx_un == -1:
        # treat as labeled for all (no explicit unassigned)
        is_labeled = is_labeled | np.ones(n_samples, dtype=bool)
    else:
        is_labeled = is_labeled | (y_rebuilt[r] != idx_un)

n_labeled = int(is_labeled.sum())
if n_labeled < 50:
    raise RuntimeError(f"Too few labeled samples ({n_labeled}) to train well. Need >=50. Inspect label_assignment_debug.csv.")
print(f"[FILTER] labeled samples count: {n_labeled} / {n_samples} (will train on these)")

labeled_idx = np.nonzero(is_labeled)[0]

# ---------------- Build accessions/groups for leakage-aware split ----------------
if not ASSIGN_DEBUG_CSV.exists():
    raise RuntimeError(f"Missing {ASSIGN_DEBUG_CSV} — needed for accession grouping to avoid leakage.")
df_assign = pd.read_csv(ASSIGN_DEBUG_CSV, dtype=str, keep_default_na=False, na_filter=False)
# ensure alignment: df_assign length should equal n_samples; else attempt to align by trimming/padding
if len(df_assign) != n_samples:
    print(f"[WARN] assignment CSV length {len(df_assign)} != n_samples {n_samples}. Aligning by index min.")
    mn = min(len(df_assign), n_samples)
    df_assign = df_assign.iloc[:mn].reset_index(drop=True)
    X = X[:mn]
    for r in RANKS:
        y_rebuilt[r] = y_rebuilt[r][:mn]
    labeled_idx = labeled_idx[labeled_idx < mn]
    n_samples = mn
    n_labeled = int(np.sum(is_labeled[:mn]))

groups = df_assign["acc_base"].astype(str).replace({"": "_missing_acc_"})
group_array = groups.values

# Group-wise split among labeled indices
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, val_idx = None, None
for tr, va in gss.split(labeled_idx, groups=group_array[labeled_idx], y=None):
    train_idx = labeled_idx[tr]
    val_idx = labeled_idx[va]
print(f"[SPLIT] labeled train={len(train_idx)}, labeled val={len(val_idx)}; groups_train={len(set(group_array[train_idx]))}, groups_val={len(set(group_array[val_idx]))}")

# verify no accession overlap
if len(set(group_array[train_idx]).intersection(set(group_array[val_idx]))) > 0:
    raise RuntimeError("Group split had overlapping accession groups — aborting.")

# ---------------- Build DataLoaders ----------------
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensors = {r: torch.tensor(y_rebuilt[r], dtype=torch.long) for r in RANKS}

from torch.utils.data import TensorDataset, DataLoader
train_X = X_tensor[train_idx]
val_X   = X_tensor[val_idx]
train_y_list = [y_tensors[r][train_idx] for r in RANKS]
val_y_list   = [y_tensors[r][val_idx] for r in RANKS]

train_ds = TensorDataset(train_X, *train_y_list)
val_ds   = TensorDataset(val_X, *val_y_list)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=0)
print(f"[DATA] train_ds={len(train_ds)}, val_ds={len(val_ds)}; batch_size={BATCH_SIZE}")

# ---------------- Build model parts (shared + ModuleDict heads) ----------------
input_dim = X.shape[1]
shared_labeled = nn.Sequential(
    nn.Linear(input_dim, HIDDEN_DIM),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(HIDDEN_DIM, HIDDEN_DIM//2),
    nn.ReLU()
)
heads_labeled = nn.ModuleDict({ r: nn.Linear(HIDDEN_DIM//2, len(label_encoders[r].classes_)) for r in RANKS })

shared_labeled.to(DEVICE)
heads_labeled.to(DEVICE)

# ---------------- Compute class weights from training set (train-only) ----------------
class_weights = {}
for r in RANKS:
    ytr = train_y_list[RANKS.index(r)].numpy()
    ncls = len(label_encoders[r].classes_)
    counts = np.bincount(ytr, minlength=ncls).astype(float)
    counts = np.where(counts <= 0, 1.0, counts)
    weights = 1.0 / counts
    weights = weights / weights.sum() * len(weights)
    class_weights[r] = torch.tensor(weights, dtype=torch.float32).to(DEVICE)
    print(f"[WEIGHT] {r}: n_classes={ncls}, min_count={counts.min():.0f}, max_count={counts.max():.0f}")

# ---------------- Losses, optimizer, scheduler ----------------
criterions = { r: nn.CrossEntropyLoss(weight=class_weights[r]) for r in RANKS }
params = list(shared_labeled.parameters()) + list(heads_labeled.parameters())
optimizer = torch.optim.Adam(params, lr=LR)

# robust reduce-on-plateau creation
import inspect
def make_scheduler(opt, mode="max", factor=0.5, patience=3, min_lr=1e-6):
    try:
        ctor = torch.optim.lr_scheduler.ReduceLROnPlateau
        sig = inspect.signature(ctor._init_)
        allowed = set(sig.parameters.keys()) - {"self","args","kwargs"}
        kwargs = {}
        cand = {"mode": mode, "factor": factor, "patience": patience, "min_lr": min_lr}
        for k,v in cand.items():
            if k in allowed:
                kwargs[k] = v
        if "verbose" in allowed:
            kwargs["verbose"] = False
        return ctor(opt, **kwargs)
    except Exception:
        class Dummy:
            def step(self, metric=None): return None
        return Dummy()

scheduler = make_scheduler(optimizer, mode="max", factor=0.5, patience=3, min_lr=1e-6)

# ---------------- Training / evaluation helpers ----------------
def evaluate_model(shared_mod, heads_mod, loader, device):
    shared_mod.eval(); heads_mod.eval()
    preds = {r: [] for r in RANKS}
    trues = {r: [] for r in RANKS}
    confidences = {r: [] for r in RANKS}
    with torch.no_grad():
        for batch in loader:
            x = batch[0].to(device)
            h = shared_mod(x)
            for i, r in enumerate(RANKS):
                logits = heads_mod[r](h)
                probs = F.softmax(logits, dim=1)
                top = torch.argmax(probs, dim=1).cpu().numpy()
                preds[r].extend(top.tolist())
                trues[r].extend(batch[1+i].cpu().numpy().tolist())
                confidences[r].extend(probs.max(dim=1).values.cpu().numpy().tolist())
    metrics = {}
    f1s = []
    for r in RANKS:
        try:
            acc = accuracy_score(trues[r], preds[r])
            f1m = f1_score(trues[r], preds[r], average="macro", zero_division=0)
            mean_conf = float(np.mean(confidences[r])) if len(confidences[r])>0 else None
        except Exception:
            acc, f1m, mean_conf = None, None, None
        metrics[r] = {"accuracy": acc, "f1_macro": f1m, "mean_confidence": mean_conf, "n_classes": len(label_encoders[r].classes_)}
        if f1m is not None:
            f1s.append(f1m)
    agg = float(np.mean(f1s)) if len(f1s)>0 else 0.0
    return metrics, agg

# ---------------- Training loop ----------------
best_score = -np.inf
epochs_no_improve = 0
history = []
start_time = time.time()

for epoch in range(1, MAX_EPOCHS+1):
    t0 = time.time()
    shared_labeled.train(); heads_labeled.train()
    running_loss = 0.0; nb = 0
    for batch in train_loader:
        x = batch[0].to(DEVICE)
        targets = [batch[i+1].to(DEVICE) for i in range(len(RANKS))]
        h = shared_labeled(x)
        outputs = { r: heads_labeled[r](h) for r in RANKS }
        loss = 0.0
        for i, r in enumerate(RANKS):
            loss = loss + criterions[r](outputs[r], targets[i])
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, max_norm=5.0)
        optimizer.step()
        running_loss += float(loss.item())
        nb += 1
    train_loss = running_loss / max(1, nb)

    # validation
    val_metrics, val_agg = evaluate_model(shared_labeled, heads_labeled, val_loader, DEVICE)
    # scheduler step (ReduceLROnPlateau accepts metric in steady versions but be defensive)
    try:
        scheduler.step(val_agg)
    except TypeError:
        try:
            scheduler.step()
        except Exception:
            pass

    rec = {"epoch": epoch, "train_loss": train_loss, "val_agg_f1": val_agg, "time_sec": time.time()-t0}
    for r in RANKS:
        m = val_metrics.get(r, {})
        rec.update({f"{r}_acc": m.get("accuracy"), f"{r}_f1_macro": m.get("f1_macro"), f"{r}_mean_conf": m.get("mean_confidence")})
    history.append(rec)

    print(f"Epoch {epoch:03d} | train_loss: {train_loss:.4f} | val_agg_f1: {val_agg:.4f} | time: {rec['time_sec']:.1f}s")

    # checkpoint on improvement
    if val_agg > best_score + 1e-8:
        best_score = val_agg
        epochs_no_improve = 0
        torch.save({
            "shared_state": shared_labeled.state_dict(),
            "heads_state": {r: heads_labeled[r].state_dict() for r in RANKS},
            "epoch": epoch,
            "val_agg_f1": val_agg,
            "optimizer_state": optimizer.state_dict()
        }, OUT_CHECKPOINT)
        print(f"  [CHECKPOINT] saved new best model (epoch {epoch}, val_agg_f1={val_agg:.4f}) -> {OUT_CHECKPOINT.name}")
    else:
        epochs_no_improve += 1

    # early stopping
    if epoch >= MIN_EPOCHS and epochs_no_improve >= PATIENCE:
        print(f"[EARLY STOP] No improvement for {epochs_no_improve} epochs (patience={PATIENCE}). Stopping.")
        break

# Save history and metrics
pd.DataFrame(history).to_csv(OUT_HISTORY, index=False)
metrics_final, agg_final = evaluate_model(shared_labeled, heads_labeled, val_loader, DEVICE)
with open(OUT_METRICS, "w") as fh:
    json.dump({"final_val_agg_f1": agg_final, "per_rank": metrics_final}, fh, indent=2)
print(f"[SAVE] history -> {OUT_HISTORY}; metrics -> {OUT_METRICS}")

# Load best checkpoint and print final per-rank metrics
if OUT_CHECKPOINT.exists():
    ckpt = torch.load(OUT_CHECKPOINT, map_location=DEVICE)
    shared_labeled.load_state_dict(ckpt["shared_state"])
    for r in RANKS:
        heads_labeled[r].load_state_dict(ckpt["heads_state"][r])
    best_epoch = ckpt.get("epoch")
    best_score_saved = ckpt.get("val_agg_f1")
    print(f"[LOAD] loaded best checkpoint epoch {best_epoch}, val_agg_f1={best_score_saved:.4f}")

final_metrics, final_agg = evaluate_model(shared_labeled, heads_labeled, val_loader, DEVICE)
print("=== Final validation metrics on labeled split ===")
for r in RANKS:
    m = final_metrics.get(r, {})
    print(f"{r:10s} | n_classes={m.get('n_classes', '?')} | acc={m.get('accuracy'):.4f} | f1_macro={m.get('f1_macro'):.4f} | mean_conf={m.get('mean_confidence'):.4f}")
print(f"Aggregated mean-macro-F1: {final_agg:.4f}")

# Show top predicted labels distribution on validation
print("\n[VAL] Top predicted labels per rank (val set, top 5):")
shared_labeled.eval(); heads_labeled.eval()
with torch.no_grad():
    xval = torch.tensor(X[val_idx], dtype=torch.float32).to(DEVICE)
    hval = shared_labeled(xval)
    for r in RANKS:
        logits = heads_labeled[r](hval)
        probs = F.softmax(logits, dim=1).cpu().numpy()
        preds = np.argmax(probs, axis=1)
        labels = label_encoders[r].classes_
        vc = Counter([labels[p] for p in preds])
        print(f"  {r:10s}: {vc.most_common(5)}")

print("\nCell complete. Outputs produced:")
print(" - model checkpoint:", OUT_CHECKPOINT)
print(" - training history:", OUT_HISTORY)
print(" - metrics json:", OUT_METRICS)

[LOAD] loaded rebuilt encoders from label_encoders_rebuilt_v2.pkl
[LOAD] loaded y_encoded_rebuilt_kingdom.npy shape=(2555,)
[LOAD] loaded y_encoded_rebuilt_phylum.npy shape=(2555,)
[LOAD] loaded y_encoded_rebuilt_class.npy shape=(2555,)
[LOAD] loaded y_encoded_rebuilt_order.npy shape=(2555,)
[LOAD] loaded y_encoded_rebuilt_family.npy shape=(2555,)
[LOAD] loaded y_encoded_rebuilt_genus.npy shape=(2555,)
[LOAD] loaded y_encoded_rebuilt_species.npy shape=(2555,)
[LOAD] Using full embeddings embeddings.npy shape=(2555, 128)
[INFO] detected unassigned indices per rank: {'kingdom': 1, 'phylum': 4, 'class': 9, 'order': 11, 'family': 18, 'genus': 26, 'species': 181}
[FILTER] labeled samples count: 1131 / 2555 (will train on these)
[SPLIT] labeled train=963, labeled val=168; groups_train=651, groups_val=116
[DATA] train_ds=963, val_ds=168; batch_size=128
[WEIGHT] kingdom: n_classes=2, min_count=1, max_count=963
[WEIGHT] phylum: n_classes=5, min_count=1, max_count=613
[WEIGHT] class: n_classes=1

In [21]:
# Cell 16 — Inference + per-rank temperature calibration (if validation indices available) + save predictions
# Outputs:
#  - predictions_raw.jsonl
#  - predictions_raw.csv
#  - predictions_calibrated.jsonl
#  - predictions_calibrated.csv
#  - novel_candidates.csv (priority list by novelty & low confidence)
#
# Defensive: checks files, aligns shapes, uses batched inference, uses LBFGS/Adam fallback for temperature scaling.

import json, time, math
from pathlib import Path
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import log_loss

# -------- Config / Paths --------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR  = DOWNLOAD_DIR / "extracted"

CHECKPOINT = EXTRACT_DIR / "best_shared_heads_labeled.pt"
ENC_PKL    = EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl"
META_CSV   = EXTRACT_DIR / "embeddings_meta_clustered.csv"
EMB_FULL   = EXTRACT_DIR / "embeddings.npy"
EMB_PCA    = EXTRACT_DIR / "embeddings_pca.npy"
VAL_IDX    = EXTRACT_DIR / "val_idx_by_acc.npy"   # optional, used for calibration if present
Y_PREFIX   = EXTRACT_DIR / "y_encoded_rebuilt_"

OUT_RAW_JSONL  = EXTRACT_DIR / "predictions_raw.jsonl"
OUT_RAW_CSV    = EXTRACT_DIR / "predictions_raw.csv"
OUT_CAL_JSONL  = EXTRACT_DIR / "predictions_calibrated.jsonl"
OUT_CAL_CSV    = EXTRACT_DIR / "predictions_calibrated.csv"
OUT_NOVEL_CSV  = EXTRACT_DIR / "novel_candidates.csv"

RANKS = ["kingdom","phylum","class","order","family","genus","species"]
TOPK = 3
BATCH = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------- Sanity checks --------
for p in (CHECKPOINT, ENC_PKL, META_CSV):
    if not Path(p).exists():
        raise RuntimeError(f"Required file missing: {p}")

# -------- Load encoders, meta, embeddings --------
import pickle
with open(ENC_PKL, "rb") as fh:
    label_encoders = pickle.load(fh)
print("[LOAD] label_encoders loaded.")

df_meta = pd.read_csv(META_CSV, dtype=str, keep_default_na=False, na_filter=False)
if EMB_FULL.exists():
    X = np.load(EMB_FULL)
    print(f"[LOAD] embeddings.npy shape={X.shape}")
elif EMB_PCA.exists():
    X = np.load(EMB_PCA)
    print(f"[LOAD] embeddings_pca.npy shape={X.shape}")
else:
    raise RuntimeError("No embeddings found (expect embeddings.npy or embeddings_pca.npy).")

n = min(len(df_meta), X.shape[0])
if len(df_meta) != X.shape[0]:
    print(f"[ALIGN] trimming to min_n={n}")
    df_meta = df_meta.iloc[:n].reset_index(drop=True)
    X = X[:n]

print(f"[INFO] n_samples={n}")

# -------- Reconstruct model architecture and load checkpoint --------
input_dim = X.shape[1]
HIDDEN_DIM = 256

shared = torch.nn.Sequential(
    torch.nn.Linear(input_dim, HIDDEN_DIM),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.3),
    torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM//2),
    torch.nn.ReLU()
)
heads = torch.nn.ModuleDict({ r: torch.nn.Linear(HIDDEN_DIM//2, len(label_encoders[r].classes_)) for r in RANKS })

ckpt = torch.load(CHECKPOINT, map_location="cpu")
try:
    shared.load_state_dict(ckpt["shared_state"])
    for r in RANKS:
        heads[r].load_state_dict(ckpt["heads_state"][r])
    print(f"[LOAD] checkpoint loaded (epoch {ckpt.get('epoch')}, val_agg_f1={ckpt.get('val_agg_f1')})")
except Exception:
    # fallback: try partial load with strict=False
    shared.load_state_dict(ckpt.get("shared_state", {}), strict=False)
    for r in RANKS:
        if r in ckpt.get("heads_state", {}):
            heads[r].load_state_dict(ckpt["heads_state"].get(r, {}), strict=False)
    print("[WARN] partial checkpoint load (strict=False)")

shared.to(DEVICE).eval()
heads.to(DEVICE).eval()

# -------- Batched forward pass: collect logits and raw probs --------
all_logits = {r: [] for r in RANKS}
all_raw_probs = {r: [] for r in RANKS}
all_topk_idx = {r: [] for r in RANKS}
pred_indices = {r: [] for r in RANKS}

with torch.no_grad():
    for start in range(0, n, BATCH):
        end = min(n, start+BATCH)
        xb = torch.tensor(X[start:end], dtype=torch.float32, device=DEVICE)
        h = shared(xb)
        for r in RANKS:
            logits = heads[r](h)                # (batch, ncls)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            logits_np = logits.cpu().numpy()
            topk = np.argsort(-probs, axis=1)[:, :TOPK]
            preds = np.argmax(probs, axis=1)
            for i in range(end-start):
                all_logits[r].append(logits_np[i])
                all_raw_probs[r].append(probs[i])
                all_topk_idx[r].append(list(map(int, topk[i])))
                pred_indices[r].append(int(preds[i]))

print("[INFER] completed raw forward pass.")

# -------- Build raw prediction records & compute mean_confidence (raw) --------
records_raw = []
for i in range(n):
    rec = {"index": int(i), "id": df_meta.loc[i,"id"] if "id" in df_meta.columns else f"seq_{i}"}
    probs_top = []
    rec_pred = {}
    for r in RANKS:
        probs = all_raw_probs[r][i]
        pred_idx = pred_indices[r][i]
        pred_label = str(label_encoders[r].classes_[pred_idx])
        pred_prob = float(probs[pred_idx])
        topk = [{"label": str(label_encoders[r].classes_[k]), "prob": float(probs[k])} for k in all_topk_idx[r][i]]
        rec_pred[r] = {"predicted_label": pred_label, "predicted_index": pred_idx, "predicted_prob": pred_prob, "top_k": topk}
        probs_top.append(pred_prob)
    rec["predicted"] = rec_pred
    rec["mean_conf_raw"] = float(np.mean(probs_top))
    # add cluster and novelty metadata if present
    rec["cluster_label"] = df_meta.loc[i,"cluster_label"] if "cluster_label" in df_meta.columns else "-1"
    rec["novelty_score"] = float(df_meta.loc[i,"novelty_score"]) if ("novelty_score" in df_meta.columns and df_meta.loc[i,"novelty_score"]!="") else None
    records_raw.append(rec)

# Save raw outputs
with open(OUT_RAW_JSONL, "w", encoding="utf-8") as fh:
    for r in records_raw:
        fh.write(json.dumps(r, ensure_ascii=False) + "\n")
# Flatten CSV
rows_csv = []
for r in records_raw:
    row = {"index": r["index"], "id": r["id"], "cluster_label": r["cluster_label"], "novelty_score": r["novelty_score"], "mean_conf_raw": r["mean_conf_raw"]}
    for rank in RANKS:
        row[f"{rank}_pred"] = r["predicted"][rank]["predicted_label"]
        row[f"{rank}_prob"] = r["predicted"][rank]["predicted_prob"]
    rows_csv.append(row)
pd.DataFrame(rows_csv).to_csv(OUT_RAW_CSV, index=False)
print(f"[SAVE] raw predictions saved: {OUT_RAW_JSONL}, {OUT_RAW_CSV}")

# -------- Temperature scaling calibration (per-rank) if validation indices available --------
temps = {r: 1.0 for r in RANKS}
if VAL_IDX.exists():
    try:
        val_idx = np.load(VAL_IDX)
        # build logits and labels on validation set
        print(f"[CALIB] val_idx len={len(val_idx)} -> performing per-rank temperature scaling.")
        for r in RANKS:
            # load y_encoded_rebuilt for this rank if available
            y_path = Path(f"{Y_PREFIX}{r}.npy")
            if not y_path.exists():
                print(f"[CALIB] y file missing for rank {r}, skipping calibration.")
                continue
            y_full = np.load(y_path)
            y_val = y_full[val_idx]
            # gather logits numpy for val indices
            logits_val = np.vstack([all_logits[r][i] for i in val_idx])
            # optimize temperature scalar using simple LBFGS on cross-entropy
            logits_t = torch.tensor(logits_val, dtype=torch.float32, device=DEVICE)
            labels_t = torch.tensor(y_val, dtype=torch.long, device=DEVICE)
            T = torch.nn.Parameter(torch.ones(1, device=DEVICE))
            optimizerT = torch.optim.LBFGS([T], lr=0.5, max_iter=50)

            def closure():
                optimizerT.zero_grad()
                t = torch.clamp(T, min=1e-3)
                loss = F.cross_entropy(logits_t / t, labels_t)
                loss.backward()
                return loss
            try:
                optimizerT.step(closure)
                T_opt = float(max(1e-3, T.item()))
            except Exception:
                # fallback: small Adam loop
                T = torch.nn.Parameter(torch.tensor(1.0, device=DEVICE))
                opt = torch.optim.Adam([T], lr=0.01)
                for _ in range(200):
                    opt.zero_grad()
                    t = torch.clamp(T, min=1e-3)
                    loss = F.cross_entropy(torch.tensor(logits_val, device=DEVICE) / t, labels_t)
                    loss.backward(); opt.step()
                T_opt = float(max(1e-3, T.item()))
            temps[r] = T_opt
            print(f"[CALIB] rank={r} T={temps[r]:.4f}")
    except Exception as e:
        print("[CALIB WARN] calibration failed: ", e)
        temps = {r: 1.0 for r in RANKS}
else:
    print("[CALIB] val index file not found — skipping calibration (using T=1.0).")

# -------- Apply temperatures and save calibrated predictions --------
records_cal = []
rows_cal = []
with torch.no_grad():
    for i in range(n):
        rec = {"index": int(i), "id": df_meta.loc[i,"id"] if "id" in df_meta.columns else f"seq_{i}"}
        rec_pred = {}
        probs_top = []
        for r in RANKS:
            logits = np.array(all_logits[r][i])
            T = temps.get(r, 1.0)
            # compute scaled probs
            scaled = F.softmax(torch.tensor(logits / T, dtype=torch.float32), dim=0).cpu().numpy()
            pred_idx = int(np.argmax(scaled))
            pred_label = str(label_encoders[r].classes_[pred_idx])
            pred_prob = float(scaled[pred_idx])
            topk_idx = np.argsort(-scaled)[:TOPK]
            topk = [{"label": str(label_encoders[r].classes_[k]), "prob": float(scaled[k])} for k in topk_idx]
            rec_pred[r] = {"predicted_label": pred_label, "predicted_index": pred_idx, "predicted_prob": pred_prob, "top_k": topk}
            probs_top.append(pred_prob)
        rec["predicted"] = rec_pred
        rec["mean_conf_calibrated"] = float(np.mean(probs_top))
        rec["cluster_label"] = df_meta.loc[i,"cluster_label"] if "cluster_label" in df_meta.columns else "-1"
        rec["novelty_score"] = float(df_meta.loc[i,"novelty_score"]) if ("novelty_score" in df_meta.columns and df_meta.loc[i,"novelty_score"]!="") else None
        records_cal.append(rec)
        row = {"index": rec["index"], "id": rec["id"], "cluster_label": rec["cluster_label"], "novelty_score": rec["novelty_score"], "mean_conf_calibrated": rec["mean_conf_calibrated"]}
        for r in RANKS:
            row[f"{r}_pred"] = rec["predicted"][r]["predicted_label"]
            row[f"{r}_prob"] = rec["predicted"][r]["predicted_prob"]
        rows_cal.append(row)

with open(OUT_CAL_JSONL, "w", encoding="utf-8") as fh:
    for r in records_cal:
        fh.write(json.dumps(r, ensure_ascii=False) + "\n")
pd.DataFrame(rows_cal).to_csv(OUT_CAL_CSV, index=False)
print(f"[SAVE] calibrated predictions saved: {OUT_CAL_JSONL}, {OUT_CAL_CSV}")

# -------- Novel candidate prioritization (cluster-level) --------
dfc = pd.DataFrame(rows_cal)
# ensure numeric novelty_score
dfc["novelty_score"] = pd.to_numeric(dfc["novelty_score"].replace("", np.nan), errors="coerce")
# cluster summary: n, mean_novelty, mean_confidence, species_consensus_frac
cluster = dfc.groupby("cluster_label").agg(
    n_sequences=("id","count"),
    mean_novelty=("novelty_score", "mean"),
    mean_confidence=("mean_conf_calibrated","mean")
).reset_index()
# compute species consensus fraction per cluster
def consensus_frac(group, rank):
    return group[rank + "_pred"].value_counts(normalize=True).max()

consensus_rows = []
for cl, g in dfc.groupby("cluster_label"):
    row = {"cluster_label": cl, "n_sequences": len(g)}
    for r in ["species","genus","family"]:
        row[f"{r}_consensus_frac"] = consensus_frac(g, r)
    row["mean_novelty"] = float(g["novelty_score"].mean()) if g["novelty_score"].notna().any() else 0.0
    row["mean_conf"] = float(g["mean_conf_calibrated"].mean())
    consensus_rows.append(row)
df_cluster = pd.DataFrame(consensus_rows)
# priority score heuristic
df_cluster["priority_score"] = (df_cluster["mean_novelty"].fillna(0.0) * 0.6) + ((1.0 - df_cluster["species_consensus_frac"].fillna(0.0)) * 0.25) + ((1.0 - df_cluster["mean_conf"].fillna(0.0)) * 0.15)
df_cluster = df_cluster.sort_values("priority_score", ascending=False)
df_cluster.to_csv(OUT_NOVEL_CSV, index=False)
print(f"[SAVE] novel candidates per-cluster saved: {OUT_NOVEL_CSV}")

# -------- Print summary diagnostics --------
print("\n=== Summary ===")
print(f"Samples predicted: {n}")
print("Per-rank average calibrated mean_confidence (sample mean):")
for r in RANKS:
    vals = [rec["predicted"][r]["predicted_prob"] for rec in records_cal]
    print(f"  {r:8s}: mean_prob={float(np.mean(vals)):.4f}  # classes={len(label_encoders[r].classes_)}")

print("\nTop 10 priority novel clusters (cluster_label, n_sequences, mean_novelty, species_consensus_frac, mean_conf, priority_score):")
print(df_cluster.head(10).loc[:, ["cluster_label","n_sequences","mean_novelty","species_consensus_frac","mean_conf","priority_score"]].to_string(index=False))

print("\nCell 16 complete. Files created:")
print(" -", OUT_RAW_JSONL, OUT_RAW_CSV)
print(" -", OUT_CAL_JSONL, OUT_CAL_CSV)
print(" -", OUT_NOVEL_CSV)

[LOAD] label_encoders loaded.
[LOAD] embeddings.npy shape=(2555, 128)
[INFO] n_samples=2555
[LOAD] checkpoint loaded (epoch 36, val_agg_f1=0.8890957155130087)
[INFER] completed raw forward pass.
[SAVE] raw predictions saved: ncbi_blast_db\extracted\predictions_raw.jsonl, ncbi_blast_db\extracted\predictions_raw.csv
[CALIB] val_idx len=373 -> performing per-rank temperature scaling.


Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods.cpp:836.)
  loss = float(closure())


[CALIB] rank=kingdom T=2440.1846
[CALIB] rank=phylum T=3093.2192
[CALIB] rank=class T=639.1340
[CALIB] rank=order T=71.4672
[CALIB] rank=family T=63.5838
[CALIB] rank=genus T=33.9495
[CALIB] rank=species T=14.6255
[SAVE] calibrated predictions saved: ncbi_blast_db\extracted\predictions_calibrated.jsonl, ncbi_blast_db\extracted\predictions_calibrated.csv
[SAVE] novel candidates per-cluster saved: ncbi_blast_db\extracted\novel_candidates.csv

=== Summary ===
Samples predicted: 2555
Per-rank average calibrated mean_confidence (sample mean):
  kingdom : mean_prob=0.5029  # classes=2
  phylum  : mean_prob=0.2008  # classes=5
  class   : mean_prob=0.1022  # classes=10
  order   : mean_prob=0.0936  # classes=13
  family  : mean_prob=0.0656  # classes=19
  genus   : mean_prob=0.0540  # classes=27
  species : mean_prob=0.0144  # classes=182

Top 10 priority novel clusters (cluster_label, n_sequences, mean_novelty, species_consensus_frac, mean_conf, priority_score):
cluster_label  n_sequences  m

In [51]:
# Reconstruct train_dataset / val_dataset from available artifacts
import numpy as np, torch, pickle, os
from pathlib import Path
from torch.utils.data import TensorDataset

DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
print("[INFO] DOWNLOAD_DIR:", DOWNLOAD_DIR)
print("[INFO] EXTRACT_DIR:", EXTRACT_DIR)

if not EXTRACT_DIR.exists():
    raise FileNotFoundError(f"Expected extracted dir at {EXTRACT_DIR} but it does not exist.")

# list extracted files for debugging
print("\n[FILES] top-level files in extracted/:")
for p in sorted(EXTRACT_DIR.iterdir()):
    print(" ", p.name)

# 1) Choose embeddings/features file
candidates_feats = [
    EXTRACT_DIR / "embeddings.npy",
    EXTRACT_DIR / "embeddings_pca.npy",
    EXTRACT_DIR / "embeddings_pca.npy",
    EXTRACT_DIR / "embeddings.npy"
]
feat_path = None
for f in candidates_feats:
    if f.exists():
        feat_path = f
        break

if feat_path is None:
    # fallback: pick any .npy with 'embed' in name
    for f in EXTRACT_DIR.rglob("*.npy"):
        if "embed" in f.name.lower():
            feat_path = f
            break

if feat_path is None:
    raise FileNotFoundError("No embeddings file found (tried embeddings.npy and embeddings_pca.npy).")

print(f"[LOAD] Using features file: {feat_path}")
X = np.load(feat_path, allow_pickle=True)
# If npz-like, try first key
if hasattr(X, "files"):
    key = X.files[0]
    print("[LOAD] feature file is .npz; using key:", key)
    X = X[key]
X = np.asarray(X)
print("[SHAPE] features shape:", X.shape)

# 2) Load label arrays for expected ranks
ranks = ["kingdom","phylum","class","order","family","genus","species"]
label_arrays = {}
for r in ranks:
    # try multiple filename patterns, prefer 'final' then 'rebuilt' then any matching
    tried = []
    candidates = [
        EXTRACT_DIR / f"y_encoded_final_{r}.npy",
        EXTRACT_DIR / f"y_encoded_rebuilt_{r}.npy",
        EXTRACT_DIR / f"y_encoded_{r}.npy",
        EXTRACT_DIR / f"y_encoded_final_{r}.npz",
    ]
    arr = None
    for c in candidates:
        tried.append(c.name)
        if c.exists():
            arr = np.load(c, allow_pickle=True)
            if hasattr(arr, "files"):
                arr = arr[arr.files[0]]
            arr = np.asarray(arr)
            print(f"[LOAD] Loaded {c.name} -> shape {arr.shape}")
            break
    if arr is None:
        # last resort: find any file matching pattern *y_encoded*{rank}*.npy
        for f in EXTRACT_DIR.rglob(f"*y*encoded*{r}*.npy"):
            arr = np.load(f, allow_pickle=True)
            if hasattr(arr, "files"):
                arr = arr[arr.files[0]]
            arr = np.asarray(arr)
            print(f"[LOAD] Fallback loaded {f.name} -> shape {arr.shape}")
            break
    if arr is None:
        print(f"[WARN] Could not find explicit y_encoded file for rank '{r}'. Will not include this rank.")
    else:
        label_arrays[r] = arr

if not label_arrays:
    raise RuntimeError("No label arrays found. Expected files like y_encoded_final_species.npy etc.")

# Check that the label arrays all have same length as X
n = X.shape[0]
for r, arr in label_arrays.items():
    if arr.shape[0] != n:
        raise RuntimeError(f"Length mismatch: features have {n} rows but label '{r}' has {arr.shape[0]} rows. "
                           "This suggests ordering mismatch between features and labels.")

print("[CHECK] All loaded labels match features length:", n)
print("Loaded ranks:", list(label_arrays.keys()))

# 3) Load train / val indices (preferred: train_idx_final.npy / val_idx_final.npy)
idx_candidates = {
    "train": [EXTRACT_DIR / "train_idx_final.npy", EXTRACT_DIR / "train_idx_by_acc.npy", EXTRACT_DIR / "train_idx.npy", EXTRACT_DIR / "train_idx_final.npy"],
    "val":   [EXTRACT_DIR / "val_idx_final.npy", EXTRACT_DIR / "val_idx_by_acc.npy", EXTRACT_DIR / "val_idx.npy", EXTRACT_DIR / "val_idx_final.npy"]
}
indices = {}
for split, cand_list in idx_candidates.items():
    idx = None
    for c in cand_list:
        if c.exists():
            idx = np.load(c, allow_pickle=True)
            if hasattr(idx, "files"):
                # unlikely, but handle .npz
                idx = idx[idx.files[0]]
            idx = np.asarray(idx).astype(int)
            print(f"[LOAD] {split} indices loaded from {c.name} -> {idx.shape[0]} indices")
            break
    if idx is None:
        # fallback: try patterns
        for f in EXTRACT_DIR.rglob(f"*{split}*idx*.npy"):
            idx = np.load(f, allow_pickle=True)
            if hasattr(idx, "files"):
                idx = idx[idx.files[0]]
            idx = np.asarray(idx).astype(int)
            print(f"[LOAD] fallback {split} indices loaded from {f.name} -> {idx.shape[0]} indices")
            break
    if idx is None:
        print(f"[WARN] Could not find {split} indices. Will build splits using 90/10 default split.")
    indices[split] = idx

# If either train/val indices missing, create a deterministic split (seeded)
if indices.get("train") is None or indices.get("val") is None:
    all_idx = np.arange(n)
    rng = np.random.RandomState(12345)
    rng.shuffle(all_idx)
    # 90% train, 10% val
    cut = int(n * 0.9)
    indices["train"] = all_idx[:cut]
    indices["val"] = all_idx[cut:]
    print(f"[SPLIT] Created deterministic 90/10 split: train {indices['train'].shape[0]} / val {indices['val'].shape[0]}")

# 4) Build TensorDatasets
def build_td(idx_array, X, label_arrays):
    X_sub = torch.tensor(X[np.asarray(idx_array)], dtype=torch.float32)
    tensors = [X_sub]
    # maintain order of ranks requested (only include ranks we found)
    for r in ranks:
        if r in label_arrays:
            arr = np.asarray(label_arrays[r])[np.asarray(idx_array)]
            tensors.append(torch.tensor(arr.astype("int64")))
    # append weight (ones)
    weights = torch.ones(X_sub.size(0), dtype=torch.float32)
    tensors.append(weights)
    td = TensorDataset(*tensors)
    return td

train_td = build_td(indices["train"], X, label_arrays)
val_td = build_td(indices["val"], X, label_arrays)

# Save reconstructed datasets
train_save = EXTRACT_DIR / "train_dataset_reconstructed.pt"
val_save = EXTRACT_DIR / "val_dataset_reconstructed.pt"
torch.save(train_td, train_save)
torch.save(val_td, val_save)
print(f"[SAVED] train dataset -> {train_save} (n={len(train_td)})")
print(f"[SAVED] val   dataset -> {val_save} (n={len(val_td)})")

# Place into globals() for notebook use (mimics original expectation)
globals()["train_dataset"] = train_td
globals()["val_dataset"] = val_td

# also try to load label_encoders if present (optional)
le_path = EXTRACT_DIR / "label_encoders_final.pkl"
if le_path.exists():
    try:
        with open(le_path, "rb") as fh:
            gl = pickle.load(fh)
        globals()["label_encoders"] = gl
        print("[LOADED] label_encoders from", le_path.name)
    except Exception as e:
        print("[WARN] Could not load label_encoders:", e)

# Sanity print
print("\n[SANITY] train sample tensors shapes:")
try:
    s = train_td[0]
    print("  len(sample) =", len(s))
    print("  shapes:", [getattr(t, "shape", None) for t in s])
except Exception as e:
    print("  Could not inspect sample:", e)

print("\nDone — you can now re-run the training cell. If anything errors, copy the printed warnings here and I'll iterate a fix.")


[INFO] DOWNLOAD_DIR: ncbi_blast_db
[INFO] EXTRACT_DIR: ncbi_blast_db\extracted

[FILES] top-level files in extracted/:
  best_shared_heads.pt
  best_shared_heads_defensive.pt
  best_shared_heads_defensive_state_dict.pt
  best_shared_heads_labeled.pt
  best_shared_heads_pseudo_tensordataset_fix.pt
  best_shared_heads_resumed.pt
  best_shared_heads_resumed_state_dict.pt
  best_shared_heads_retrain.pt
  calibration_bins_class.csv
  calibration_bins_family.csv
  calibration_bins_genus.csv
  calibration_bins_kingdom.csv
  calibration_bins_order.csv
  calibration_bins_phylum.csv
  calibration_bins_species.csv
  calibration_metrics_by_rank.csv
  cluster_summary.csv
  cluster_summary.json
  confusion_matrix_class.csv
  confusion_matrix_family.csv
  confusion_matrix_genus.csv
  confusion_matrix_kingdom.csv
  confusion_matrix_order.csv
  confusion_matrix_phylum.csv
  confusion_matrix_species.csv
  embeddings.npy
  embeddings_meta.csv
  embeddings_meta_clustered.csv
  embeddings_meta_clustered.js

In [53]:
# Cell A — Final evaluation & per-class diagnostics (run after training)
# Produces per-rank CSV metrics, confusion matrices, species top-k accuracy, and a summary.

import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, accuracy_score
from sklearn.preprocessing import LabelEncoder

# ---------------- Config / paths ----------------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
OUT_DIR = EXTRACT_DIR  # output saved here

# candidate checkpoint names (prefer best pseudo checkpoint)
CAND_CKPTS = [
    "best_shared_heads_pseudo_tensordataset_fix.pt",
    "best_shared_heads_pseudo_tensordataset_fix.pt",
    "best_shared_heads_pseudo_tensordataset_fix.pt",
    "best_shared_heads_pseudo_tensordataset_fix.pt",
    "best_shared_heads_pseudo_tensordataset_fix.pt",  # duplicates safe — leave as search list
    "best_shared_heads_pseudo_tensordataset_fix.pt", 
    "best_shared_heads_pseudo_tensordataset_fix.pt",
    "best_shared_heads_labeled.pt",
    "best_shared_heads_retrain.pt",
    "best_shared_heads.pt",
    "best_shared_heads_pseudo.pt",
    "best_shared_heads_pseudo_tensordataset.pt"
]

# flexible: search for "best*.pt" as fallback
def find_checkpoint():
    for name in CAND_CKPTS:
        p = EXTRACT_DIR / name
        if p.exists():
            return p
    # fallback: pick the newest best_.pt or any *shared.pt
    cand = sorted(EXTRACT_DIR.glob("best*.pt"), key=lambda p: p.stat().st_mtime, reverse=True)
    if cand:
        return cand[0]
    cand2 = sorted(EXTRACT_DIR.glob("shared.pt"), key=lambda p: p.stat().st_mtime, reverse=True)
    if cand2:
        return cand2[0]
    raise FileNotFoundError("No checkpoint found in extracted/ (tried best*.pt and shared.pt)")

CKPT_PATH = find_checkpoint()
print("[USE] checkpoint:", CKPT_PATH.name)

# ---------------- Load label encoders ----------------
# prefer rebuilt encoders used for training
enc_paths = [
    EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl",
    EXTRACT_DIR / "label_encoders_used.pkl",
    EXTRACT_DIR / "label_encoders_v2.pkl",
    EXTRACT_DIR / "label_encoders.pkl"
]
label_encoders = None
for p in enc_paths:
    if p.exists():
        import pickle
        with open(p, "rb") as fh:
            label_encoders = pickle.load(fh)
        print("[LOAD] label_encoders from", p.name)
        break
if label_encoders is None:
    raise FileNotFoundError("Label encoders not found in extracted/. Expected one of: " + ", ".join(str(p.name) for p in enc_paths))

RANKS = ["kingdom","phylum","class","order","family","genus","species"]

# ---------------- Load embeddings ----------------
EMB_FULL = EXTRACT_DIR / "embeddings.npy"
EMB_PCA  = EXTRACT_DIR / "embeddings_pca.npy"
if EMB_FULL.exists():
    X = np.load(EMB_FULL)
    print("[LOAD] embeddings.npy shape:", X.shape)
elif EMB_PCA.exists():
    X = np.load(EMB_PCA)
    print("[LOAD] embeddings_pca.npy shape:", X.shape)
else:
    raise FileNotFoundError("No embeddings file found in extracted/ (embeddings.npy or embeddings_pca.npy)")

# ---------------- Load validation indices ----------------
# Prefer previously saved group-wise split files
possible_idx_names = [
    EXTRACT_DIR / "val_idx_by_acc.npy",
    EXTRACT_DIR / "val_idx.npy",
    EXTRACT_DIR / "val_idx_by_acc.npy",
    EXTRACT_DIR / "val_idx_by_acc.npy"
]
val_idx = None
for p in possible_idx_names:
    if p.exists():
        val_idx = np.load(p)
        print("[LOAD] val_idx from", p.name, "len=", len(val_idx))
        break

# fallback: use train_val_split_by_accession.json if present (and choose val groups)
if val_idx is None:
    split_json = EXTRACT_DIR / "train_val_split_by_accession.json"
    if split_json.exists():
        data = json.loads(split_json.read_text())
        val_idx = np.array(data.get("val_idx", data.get("val_indices", [])), dtype=int)
        print("[LOAD] val_idx from", split_json.name, "len=", len(val_idx))

# fallback: if val_idx still None, try to use val_dataset from globals
if val_idx is None:
    if "val_dataset" in globals():
        # val_dataset is a TensorDataset; we'll infer indices by len(train_dataset) + positions not available.
        # Instead, we will create val set from metadata: try embeddings_meta_clustered.csv and predictions_summary_calibrated.csv
        print("[WARN] val_idx files not found; attempting to build val set from predictions_summary_calibrated.csv")
        preds_csv = EXTRACT_DIR / "predictions_summary_calibrated.csv"
        meta_csv = EXTRACT_DIR / "embeddings_meta_clustered.csv"
        if preds_csv.exists() and meta_csv.exists():
            df_preds = pd.read_csv(preds_csv, dtype=str, keep_default_na=False, na_filter=False)
            # expect 'is_val' or 'fold' column not guaranteed — as fallback pick last 15% as val
            n = len(df_preds)
            k = max(1, int(0.15*n))
            val_idx = np.arange(n-k, n, dtype=int)
            print("[FALLBACK] chosen last 15% indices for val (len=%d)" % len(val_idx))
        else:
            raise RuntimeError("Cannot build val_idx automatically. Provide saved val_idx numpy or val_dataset in globals.")
else:
    val_idx = np.asarray(val_idx, dtype=int)

# Truncate if embeddings shorter
n_samples = X.shape[0]
val_idx = val_idx[val_idx < n_samples]
print("[INFO] using val_idx length:", len(val_idx))

# ---------------- Build val tensors ----------------
import torch
from torch.utils.data import TensorDataset, DataLoader

X_val = torch.tensor(X[val_idx], dtype=torch.float32)
y_val_list = []
# load per-rank encoded y arrays if present (y_encoded_rebuilt_*.npy), else try predictions CSV to derive labels
y_prefix = EXTRACT_DIR / "y_encoded_rebuilt_"
has_y_arrays = True
for r in RANKS:
    p = Path(f"{y_prefix}{r}.npy")
    if not p.exists():
        has_y_arrays = False
        break

if has_y_arrays:
    for r in RANKS:
        arr = np.load(Path(f"{y_prefix}{r}.npy"))
        if arr.shape[0] < n_samples:
            # pad
            pad = np.full(n_samples - arr.shape[0], fill_value=0, dtype=int)
            arr = np.concatenate([arr, pad])
        y_val_list.append(torch.tensor(arr[val_idx], dtype=torch.long))
    print("[LOAD] loaded y_encoded_rebuilt_* arrays for true labels.")
else:
    # fallback: try to get true labels from predictions_summary_calibrated.csv (they may be same as preds)
    preds_csv = EXTRACT_DIR / "predictions_summary_calibrated.csv"
    if preds_csv.exists():
        dfp = pd.read_csv(preds_csv, dtype=str, keep_default_na=False, na_filter=False)
        # attempt to get per-rank predicted label columns; we'll treat them as "truth" only if they were originally assigned (best-effort)
        for r in RANKS:
            col = f"{r}_pred"
            if col in dfp.columns:
                # map predicted label strings to encoder indices (if encoder contains that label)
                lab_to_idx = {lab: idx for idx, lab in enumerate(label_encoders[r].classes_)}
                idxs = []
                for lab in dfp[col].astype(str).tolist():
                    idxs.append(lab_to_idx.get(lab, 0))
                arr = np.array(idxs, dtype=int)
                y_val_list.append(torch.tensor(arr[val_idx], dtype=torch.long))
            else:
                # default all zero
                y_val_list.append(torch.zeros(len(val_idx), dtype=torch.long))
        print("[FALLBACK] used predictions_summary_calibrated.csv to build pseudo 'truth' arrays (best-effort).")
    else:
        raise RuntimeError("True label arrays not found (y_encoded_rebuilt_*.npy) and predictions_summary_calibrated.csv missing. Cannot compute diagnostics.")

# assemble val dataset & loader
# final layout: (x, y0, y1, ..., y6)
val_dataset = TensorDataset(X_val, *y_val_list)
VAL_BATCH = min(128, max(16, len(val_dataset)//2))
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH, shuffle=False)

# ---------------- Build model_obj (no new class) ----------------
hidden_dim = 256
model_obj = nn.Module()
model_obj.shared = nn.Sequential(
    nn.Linear(X_val.shape[1], hidden_dim),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(hidden_dim, hidden_dim // 2),
    nn.ReLU()
)
heads = {}
for r in RANKS:
    ncls = len(label_encoders[r].classes_)
    heads[r] = nn.Linear(hidden_dim // 2, ncls)
model_obj.heads = nn.ModuleDict(heads)

# bind forward
import types
def _forward(self, x):
    h = self.shared(x)
    return { r: self.heads[r](h) for r in self.heads }
model_obj.forward = types.MethodType(_forward, model_obj)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_obj.to(device)

# ---------------- Load checkpoint into model ----------------
ck = torch.load(CKPT_PATH, map_location="cpu")
loaded = False
# try common shapes
try:
    if isinstance(ck, dict) and "model_state" in ck:
        model_obj.load_state_dict(ck["model_state"], strict=False); loaded = True
    elif isinstance(ck, dict) and "shared_state" in ck and "heads_state" in ck:
        # merge
        cur_sd = model_obj.state_dict()
        for k,v in ck["shared_state"].items():
            tk = f"shared.{k}" if not k.startswith("shared.") else k
            if tk in cur_sd: cur_sd[tk] = v
        for hname, hsd in ck["heads_state"].items():
            for subk, v in hsd.items():
                tk = f"heads.{hname}.{subk}"
                if tk in cur_sd: cur_sd[tk] = v
        model_obj.load_state_dict(cur_sd, strict=False); loaded = True
    elif isinstance(ck, dict) and "state_dict" in ck:
        model_obj.load_state_dict(ck["state_dict"], strict=False); loaded = True
    else:
        # try as raw state_dict
        model_obj.load_state_dict(ck, strict=False); loaded = True
except Exception as e:
    print("[WARN] checkpoint load partial/failed:", e)

print("[MODEL] checkpoint loaded -> device:", device, "loaded_ok:", loaded)
model_obj.to(device)
model_obj.eval()

# ---------------- Run inference on validation set & gather predictions ----------------
all_preds = {r: [] for r in RANKS}
all_probs = {r: [] for r in RANKS}
all_trues = {r: [] for r in RANKS}

with torch.no_grad():
    for batch in val_loader:
        x = batch[0].to(device)
        y_trues_batch = [batch[1 + i].numpy() for i in range(len(RANKS))]
        outputs = model_obj(x)
        for r in RANKS:
            logits = outputs[r]
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
            all_preds[r].extend(preds.tolist())
            all_probs[r].extend(probs.tolist())
            all_trues[r].extend(y_trues_batch[RANKS.index(r)].tolist())

# convert to numpy arrays
for r in RANKS:
    all_preds[r] = np.array(all_preds[r], dtype=int)
    all_trues[r] = np.array(all_trues[r], dtype=int)
    all_probs[r] = np.array(all_probs[r], dtype=float)

# ---------------- Per-rank per-class metrics, confusion matrices & CSV saves ----------------
summary_rows = []
for r in RANKS:
    true = all_trues[r]
    pred = all_preds[r]
    classes = list(label_encoders[r].classes_)
    # compute precision/recall/f1/support
    p, rcall, f1, support = precision_recall_fscore_support(true, pred, labels=np.arange(len(classes)), zero_division=0)
    df_cls = pd.DataFrame({
        "class_index": np.arange(len(classes)),
        "class_name": classes,
        "precision": p,
        "recall": rcall,
        "f1": f1,
        "support": support
    })
    # save per-rank class metrics
    fn = OUT_DIR / f"per_class_metrics_{r}.csv"
    df_cls.to_csv(fn, index=False)
    # confusion matrix (rows=true, cols=pred)
    try:
        cm = confusion_matrix(true, pred, labels=np.arange(len(classes)))
        cm_df = pd.DataFrame(cm, index=classes, columns=classes)
        cm_file = OUT_DIR / f"confusion_matrix_{r}.csv"
        cm_df.to_csv(cm_file)
    except Exception as e:
        print(f"[WARN] could not create confusion matrix for {r}: {e}")
    # rank summary
    macro_f1 = float(np.nanmean(df_cls["f1"]))
    acc = float(accuracy_score(true, pred))
    mean_conf = float(np.mean(np.max(all_probs[r], axis=1))) if all_probs[r].size else 0.0
    summary_rows.append({"rank": r, "n_classes": len(classes), "val_acc": acc, "val_macro_f1": macro_f1, "mean_confidence": mean_conf})
    print(f"[METRIC] {r:8s} | acc={acc:.4f} | macro_f1={macro_f1:.4f} | mean_conf={mean_conf:.4f} | saved -> {fn.name}")

pd.DataFrame(summary_rows).to_csv(OUT_DIR / "evaluation_summary_by_rank.csv", index=False)

# ---------------- species top-k accuracy (Top-1 and Top-5) ----------------
if "species" in RANKS:
    sp_probs = all_probs["species"]  # shape (n_val, n_species)
    sp_true = all_trues["species"]
    top1 = np.mean(np.argmax(sp_probs, axis=1) == sp_true)
    k = min(5, sp_probs.shape[1])
    topk_preds = np.argsort(sp_probs, axis=1)[:, ::-1][:, :k]  # descending top-k indices
    topk_hit = np.array([1 if sp_true[i] in topk_preds[i] else 0 for i in range(len(sp_true))])
    topk = np.mean(topk_hit)
    print(f"[TOP-K] species top-1 acc: {top1:.4f} | top-{k} acc: {topk:.4f}")
    pd.DataFrame({"top1": [float(top1)], f"top{k}": [float(topk)]}).to_csv(OUT_DIR / "species_topk_accuracy.csv", index=False)

# ---------------- highlight worst classes (low f1 but support >= threshold) ----------------
alerts = []
for r in RANKS:
    df_cls = pd.read_csv(OUT_DIR / f"per_class_metrics_{r}.csv")
    # ignore UNASSIGNED if present when highlighting (but still saved)
    df_non_un = df_cls.copy()
    if "UNASSIGNED" in df_non_un["class_name"].values:
        df_non_un = df_non_un[df_non_un["class_name"] != "UNASSIGNED"]
    # find low f1 classes with reasonable support
    low_f1 = df_non_un[(df_non_un["support"] >= 5)].sort_values("f1").head(8)
    if not low_f1.empty:
        alerts.append((r, low_f1[["class_name","support","f1"]].to_dict(orient="records")))

# print alert summary
if alerts:
    print("\n[ALERT] Classes with low F1 (support>=5) — inspect CSVs for details:")
    for rank, recs in alerts:
        print(f" - {rank}: {len(recs)} low-performing classes (examples):")
        for ex in recs[:3]:
            print(f"     {ex['class_name'][:60]:60s} support={int(ex['support'])}, f1={ex['f1']:.3f}")
else:
    print("\n[ALERT] No low-F1 classes with support>=5 found in val set.")

print("\n[COMPLETE] Evaluation outputs saved to:", EXTRACT_DIR)
print("Key files: evaluation_summary_by_rank.csv, per_class_metrics_<rank>.csv, confusion_matrix_<rank>.csv, species_topk_accuracy.csv (if species present).")

[USE] checkpoint: best_shared_heads_pseudo_tensordataset_fix.pt
[LOAD] label_encoders from label_encoders_rebuilt_v2.pkl
[LOAD] embeddings.npy shape: (2555, 128)
[LOAD] val_idx from val_idx_by_acc.npy len= 373
[INFO] using val_idx length: 373
[LOAD] loaded y_encoded_rebuilt_* arrays for true labels.
[WARN] checkpoint load partial/failed: Error(s) in loading state_dict for Module:
	size mismatch for heads.phylum.weight: copying a param with shape torch.Size([6, 128]) from checkpoint, the shape in current model is torch.Size([5, 128]).
	size mismatch for heads.phylum.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([5]).
	size mismatch for heads.class.weight: copying a param with shape torch.Size([13, 128]) from checkpoint, the shape in current model is torch.Size([10, 128]).
	size mismatch for heads.class.bias: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([10]).
	size misma

In [55]:
# Cell B — Calibration diagnostics & reliability diagrams
# Saves: calibration_metrics_by_rank.csv, calibration_bins_<rank>.csv, reliability_<rank>.png
import json, math, time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import brier_score_loss, accuracy_score

# ---------------- Config ----------------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
OUT_DIR = EXTRACT_DIR
RANKS = ["kingdom","phylum","class","order","family","genus","species"]
BINS = 15
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# helper: find checkpoint (reuse strategy used before)
def find_checkpoint():
    candidates = sorted(EXTRACT_DIR.glob("best*.pt"), key=lambda p: p.stat().st_mtime, reverse=True)
    cand2 = sorted(EXTRACT_DIR.glob("shared.pt"), key=lambda p: p.stat().st_mtime, reverse=True)
    allc = candidates + cand2
    if allc:
        return allc[0]
    raise FileNotFoundError("No checkpoint found in extracted/ (looked for best*.pt and shared.pt)")

CKPT = find_checkpoint()
print(f"[USE] checkpoint: {CKPT.name}")

# ---------------- Load label encoders ----------------
enc_paths = [
    EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl",
    EXTRACT_DIR / "label_encoders_used.pkl",
    EXTRACT_DIR / "label_encoders_v2.pkl",
    EXTRACT_DIR / "label_encoders.pkl"
]
label_encoders = None
for p in enc_paths:
    if p.exists():
        import pickle
        with open(p, "rb") as fh: label_encoders = pickle.load(fh)
        print("[LOAD] label_encoders from", p.name)
        break
if label_encoders is None:
    raise FileNotFoundError("No label_encoders pickle found in extracted/")

# ---------------- Load embeddings and val idx ----------------
EMB_FULL = EXTRACT_DIR / "embeddings.npy"
EMB_PCA  = EXTRACT_DIR / "embeddings_pca.npy"
if EMB_FULL.exists():
    X = np.load(EMB_FULL)
else:
    X = np.load(EMB_PCA)
print("[LOAD] embeddings shape:", X.shape)

val_idx_path_candidates = [
    EXTRACT_DIR / "val_idx_by_acc.npy",
    EXTRACT_DIR / "val_idx.npy",
    EXTRACT_DIR / "val_idx_by_acc.npy"
]
val_idx = None
for p in val_idx_path_candidates:
    if p.exists():
        val_idx = np.load(p)
        break
if val_idx is None:
    # fallback to predictions_summary_calibrated.csv -> last 15%
    preds_csv = EXTRACT_DIR / "predictions_summary_calibrated.csv"
    if preds_csv.exists():
        dfp = pd.read_csv(preds_csv, dtype=str, keep_default_na=False, na_filter=False)
        n = len(dfp); k = max(1, int(0.15*n))
        val_idx = np.arange(n-k, n, dtype=int)
        print("[FALLBACK] val_idx = last 15% len=", len(val_idx))
    else:
        raise RuntimeError("val_idx not found and predictions_summary_calibrated.csv not present to fallback.")
val_idx = val_idx[val_idx < X.shape[0]]
print("[INFO] using val_idx length:", len(val_idx))

# ---------------- Load true labels arrays if present ----------------
y_prefix = EXTRACT_DIR / "y_encoded_rebuilt_"
y_true = {}
have_y = True
for r in RANKS:
    p = Path(f"{y_prefix}{r}.npy")
    if p.exists():
        arr = np.load(p)
        if arr.shape[0] < X.shape[0]:
            pad = np.full(X.shape[0] - arr.shape[0], fill_value=0, dtype=int)
            arr = np.concatenate([arr, pad])
        y_true[r] = arr[val_idx].astype(int)
    else:
        have_y = False
        break
if not have_y:
    # fallback to predictions_summary_calibrated.csv to get per-rank predicted labels as a proxy for truth
    preds_csv = EXTRACT_DIR / "predictions_summary_calibrated.csv"
    if preds_csv.exists():
        dfp = pd.read_csv(preds_csv, dtype=str, keep_default_na=False, na_filter=False)
        for r in RANKS:
            col = f"{r}_pred"
            if col in dfp.columns:
                lab_to_idx = {lab: i for i, lab in enumerate(label_encoders[r].classes_)}
                idxs = [lab_to_idx.get(s, 0) for s in dfp[col].astype(str).tolist()]
                arr = np.array(idxs, dtype=int)
                y_true[r] = arr[val_idx]
            else:
                y_true[r] = np.zeros(len(val_idx), dtype=int)
        print("[FALLBACK] used predictions_summary_calibrated.csv as proxy truth.")
    else:
        raise RuntimeError("True labels not found and fallback predictions csv not present.")

# ---------------- Build model_obj and load checkpoint (same safe approach) ----------------
hidden_dim = 256
model_obj = nn.Module()
model_obj.shared = nn.Sequential(
    nn.Linear(X.shape[1], hidden_dim),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(hidden_dim, hidden_dim // 2),
    nn.ReLU()
)
heads = {}
for r in RANKS:
    ncls = len(label_encoders[r].classes_)
    heads[r] = nn.Linear(hidden_dim // 2, ncls)
model_obj.heads = nn.ModuleDict(heads)
import types
def _forward(self, x):
    h = self.shared(x)
    return { r: self.heads[r](h) for r in self.heads }
model_obj.forward = types.MethodType(_forward, model_obj)
model_obj.to(DEVICE)

ck = torch.load(CKPT, map_location="cpu")
loaded_ok = False
try:
    if isinstance(ck, dict) and "model_state" in ck:
        model_obj.load_state_dict(ck["model_state"], strict=False); loaded_ok=True
    elif isinstance(ck, dict) and "shared_state" in ck and "heads_state" in ck:
        cur_sd = model_obj.state_dict()
        for k,v in ck["shared_state"].items():
            tk = f"shared.{k}" if not k.startswith("shared.") else k
            if tk in cur_sd: cur_sd[tk] = v
        for hname, hsd in ck["heads_state"].items():
            for subk, v in hsd.items():
                tk = f"heads.{hname}.{subk}"
                if tk in cur_sd: cur_sd[tk] = v
        model_obj.load_state_dict(cur_sd, strict=False); loaded_ok=True
    elif isinstance(ck, dict) and "state_dict" in ck:
        model_obj.load_state_dict(ck["state_dict"], strict=False); loaded_ok=True
    else:
        model_obj.load_state_dict(ck, strict=False); loaded_ok=True
except Exception as e:
    print("[WARN] checkpoint load partial/failed:", e)
print("[MODEL] checkpoint loaded_ok:", loaded_ok)
model_obj.eval()

# ---------------- Build val DataLoader ----------------
from torch.utils.data import TensorDataset, DataLoader
X_val = torch.tensor(X[val_idx], dtype=torch.float32)
y_val_tensors = [torch.tensor(y_true[r], dtype=torch.long) for r in RANKS]
val_dataset = TensorDataset(X_val, *y_val_tensors)
val_loader = DataLoader(val_dataset, batch_size=min(256, max(16, len(val_dataset)//4)), shuffle=False)

# ---------------- Inference on val -> probs per rank ----------------
all_probs = {r: [] for r in RANKS}
all_preds = {r: [] for r in RANKS}
all_true  = {r: [] for r in RANKS}

with torch.no_grad():
    for batch in val_loader:
        x = batch[0].to(DEVICE)
        outputs = model_obj(x)
        for i, r in enumerate(RANKS):
            logits = outputs[r]
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
            all_probs[r].append(probs)
            all_preds[r].append(preds)
            all_true[r].append(batch[1 + i].numpy())

# concat
for r in RANKS:
    if len(all_probs[r]) == 0:
        all_probs[r] = np.zeros((0, len(label_encoders[r].classes_)))
        all_preds[r] = np.array([], dtype=int)
        all_true[r]  = np.array([], dtype=int)
    else:
        all_probs[r] = np.vstack(all_probs[r])
        all_preds[r] = np.concatenate(all_preds[r])
        all_true[r]  = np.concatenate(all_true[r])

# ---------------- calibration metrics functions ----------------
def compute_ece(probs, true, n_bins=15):
    """ECE computed using predicted-class confidence and correctness (binary)."""
    if probs.shape[0] == 0:
        return {"ece": np.nan, "mce": np.nan, "bins": []}
    confidences = np.max(probs, axis=1)
    preds = np.argmax(probs, axis=1)
    correct = (preds == true).astype(float)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_stats = []
    ece = 0.0
    mce = 0.0
    n = len(confidences)
    for i in range(n_bins):
        l, u = bins[i], bins[i+1]
        mask = (confidences > l) & (confidences <= u) if i < n_bins-1 else (confidences >= l) & (confidences <= u)
        cnt = mask.sum()
        if cnt == 0:
            avg_conf = 0.0; acc = 0.0
        else:
            avg_conf = confidences[mask].mean()
            acc = correct[mask].mean()
        gap = abs(acc - avg_conf)
        ece += (cnt / n) * gap
        mce = max(mce, gap)
        bin_stats.append({"bin_low": float(l), "bin_high": float(u), "count": int(cnt), "avg_conf": float(avg_conf), "accuracy": float(acc)})
    return {"ece": float(ece), "mce": float(mce), "bins": bin_stats}

def multiclass_brier(probs, true):
    """Multiclass Brier score = mean over samples of sum_k (p_k - y_k)^2"""
    if probs.shape[0] == 0:
        return np.nan
    n_samples, n_classes = probs.shape
    y_onehot = np.zeros_like(probs)
    y_onehot[np.arange(n_samples), true] = 1.0
    bs = ((probs - y_onehot) ** 2).sum(axis=1).mean()
    return float(bs)

# ---------------- compute per-rank calibration metrics, save plots & CSVs ----------------
cal_metrics_rows = []
for r in RANKS:
    probs = all_probs[r]
    true  = all_true[r]
    n_samples = probs.shape[0]
    if n_samples == 0:
        print(f"[SKIP] rank {r} has 0 val samples")
        continue
    # compute ECE, MCE and Brier
    calib = compute_ece(probs, true, n_bins=BINS)
    brier = multiclass_brier(probs, true)
    mean_conf = float(np.max(probs, axis=1).mean())
    acc = float((np.argmax(probs, axis=1) == true).mean())
    cal_metrics_rows.append({
        "rank": r,
        "n_samples": int(n_samples),
        "ece": calib["ece"],
        "mce": calib["mce"],
        "brier": brier,
        "mean_confidence": mean_conf,
        "accuracy": acc
    })

    # save bin-level CSV
    df_bins = pd.DataFrame(calib["bins"])
    df_bins.to_csv(OUT_DIR / f"calibration_bins_{r}.csv", index=False)

    # reliability diagram (bar: bin centers with accuracy; overlay diagonal)
    bin_centers = [(b["bin_low"] + b["bin_high"]) / 2.0 for b in calib["bins"]]
    accs = [b["accuracy"] for b in calib["bins"]]
    confs = [b["avg_conf"] for b in calib["bins"]]
    counts = [b["count"] for b in calib["bins"]]
    plt.figure(figsize=(6,5))
    plt.plot([0,1],[0,1], linestyle="--", color="gray", label="Perfect calibration")
    plt.plot(bin_centers, accs, marker="o", label="Accuracy (per-bin)")
    plt.plot(bin_centers, confs, marker="x", label="Avg Confidence (per-bin)")
    # visual marker sized by count
    maxc = max(counts) if counts else 1
    for x, a, c in zip(bin_centers, accs, counts):
        plt.scatter(x, a, s=20 + (120.0 * (c / maxc)), alpha=0.6, color="C0")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.title(f"Reliability diagram — {r} (ECE={calib['ece']:.4f})")
    plt.legend(loc="lower right")
    plt.grid(alpha=0.2)
    outpng = OUT_DIR / f"reliability_{r}.png"
    plt.savefig(outpng, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"[PLOT] saved {outpng.name} (n={n_samples})")

# ---------------- save summary CSV ----------------
pd.DataFrame(cal_metrics_rows).to_csv(OUT_DIR / "calibration_metrics_by_rank.csv", index=False)
print("[SAVE] calibration_metrics_by_rank.csv written to", OUT_DIR)

# ---------------- Print a short summary ----------------
print("\n=== Calibration summary ===")
for row in cal_metrics_rows:
    print(f"{row['rank']:10s} | n={row['n_samples']:4d} | acc={row['accuracy']:.4f} | mean_conf={row['mean_confidence']:.4f} | ECE={row['ece']:.4f} | MCE={row['mce']:.4f} | Brier={row['brier']:.4f}")

print("\nOutputs: calibration_bins_<rank>.csv, reliability_<rank>.png, calibration_metrics_by_rank.csv saved in", OUT_DIR)

[USE] checkpoint: best_shared_heads_labeled.pt
[LOAD] label_encoders from label_encoders_rebuilt_v2.pkl
[LOAD] embeddings shape: (2555, 128)
[INFO] using val_idx length: 373
[MODEL] checkpoint loaded_ok: True
[PLOT] saved reliability_kingdom.png (n=373)
[PLOT] saved reliability_phylum.png (n=373)
[PLOT] saved reliability_class.png (n=373)
[PLOT] saved reliability_order.png (n=373)
[PLOT] saved reliability_family.png (n=373)
[PLOT] saved reliability_genus.png (n=373)
[PLOT] saved reliability_species.png (n=373)
[SAVE] calibration_metrics_by_rank.csv written to ncbi_blast_db\extracted

=== Calibration summary ===
kingdom    | n= 373 | acc=0.4638 | mean_conf=1.0000 | ECE=0.5362 | MCE=0.5362 | Brier=1.0724
phylum     | n= 373 | acc=0.4611 | mean_conf=0.9539 | ECE=0.4927 | MCE=0.9017 | Brier=1.0174
class      | n= 373 | acc=0.4638 | mean_conf=0.9364 | ECE=0.4726 | MCE=0.8845 | Brier=0.9912
order      | n= 373 | acc=0.4584 | mean_conf=0.9348 | ECE=0.4764 | MCE=0.8365 | Brier=0.9968
family   

In [57]:
# Fixed cell: build cluster-level novel_candidates_priority.csv from predictions_with_mc_uncertainty
# - robust: loads precomputed predictions CSV/JSONL if present (no re-run of MC)
# - safe aggregation using named tuples; handles missing columns gracefully
# - outputs: novel_candidates_priority.csv and a printed top-10 list
import json, math, sys
from pathlib import Path
import numpy as np
import pandas as pd

# ---------------- Paths ----------------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
OUT_DIR = EXTRACT_DIR

CSV_IN = OUT_DIR / "predictions_with_mc_uncertainty.csv"
JSONL_IN = OUT_DIR / "predictions_with_mc_uncertainty.jsonl"
OUT_NOVEL = OUT_DIR / "novel_candidates_priority.csv"

# ---------------- Load predictions DataFrame (robust) ----------------
if CSV_IN.exists():
    df = pd.read_csv(CSV_IN, dtype=str, keep_default_na=False, na_filter=False)
    print(f"[LOAD] {CSV_IN.name} rows={len(df)}")
elif JSONL_IN.exists():
    # load JSONL into DataFrame
    rows = []
    with open(JSONL_IN, "r", encoding="utf8") as fh:
        for line in fh:
            try:
                rows.append(json.loads(line))
            except Exception:
                pass
    df = pd.DataFrame(rows)
    print(f"[LOAD] {JSONL_IN.name} rows={len(df)}")
else:
    raise FileNotFoundError("Neither predictions_with_mc_uncertainty.csv nor .jsonl found in extracted/. Run the MC cell first.")

# ensure index column exists
if "__row_index" not in df.columns:
    df["__row_index"] = np.arange(len(df))

# ---------------- normalize and detect key columns ----------------
def find_col(candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

# cluster label column candidates
cluster_col = find_col(["cluster_label", "meta_cluster_label", "cluster", "cluster_id"])
if cluster_col is None:
    # create a fallback column (all -1)
    df["cluster_label"] = "-1"
    cluster_col = "cluster_label"
    print("[WARN] cluster column not found - using fallback '-1' for all rows")
else:
    # ensure string dtype
    df[cluster_col] = df[cluster_col].astype(str)

# novelty score column candidates
nov_col = find_col(["novelty_score", "novelty", "meta_novelty", "novel_score"])
if nov_col is None:
    df["novelty_score"] = 0.0
    nov_col = "novelty_score"
    print("[WARN] novelty column not found - defaulting to 0.0")
else:
    # coerce to numeric
    df[nov_col] = pd.to_numeric(df[nov_col], errors="coerce").fillna(0.0)

# priority / mutual-info / mean_conf candidates
priority_col = find_col(["priority_score", "priority"])
mi_col = find_col(["species_mutual_info_mc", "species_mutual_info", "mutual_info", "mi"])
conf_col = find_col(["species_mean_conf_mc", "species_mean_conf", "mean_confidence", "species_mean_conf"])

# if priority doesn't exist but mi & novelty exist, compute a fallback priority
if priority_col is None:
    if mi_col is not None:
        df["priority_score"] = (0.5 * df[nov_col].astype(float) + 0.5 * pd.to_numeric(df[mi_col], errors="coerce").fillna(0.0))
    else:
        df["priority_score"] = pd.to_numeric(df[nov_col], errors="coerce").fillna(0.0)
    priority_col = "priority_score"
    print("[INFO] priority_score not found - computed fallback from novelty and mutual-info where possible")

# species label column detection for consensus fraction
species_label_col = find_col(["species_pred_label_mc", "species_pred_label", "species_pred", "species_pred_label"])
if species_label_col is None:
    # fallback: attempt genus or species columns that exist
    species_label_col = find_col(["genus_pred_label_mc", "genus_pred_label", "genus_pred"])
    if species_label_col is None:
        # create a column of NaNs so aggregation can run
        df["species_pred_label_mc"] = pd.NA
        species_label_col = "species_pred_label_mc"
        print("[WARN] species label column not found - creating empty column for consensus frac")

# ensure numeric columns are numeric
df[nov_col] = pd.to_numeric(df[nov_col], errors="coerce").fillna(0.0)
df[priority_col] = pd.to_numeric(df[priority_col], errors="coerce").fillna(0.0)
# mutual-info & mean_conf (optional)
if mi_col:
    df[mi_col] = pd.to_numeric(df[mi_col], errors="coerce").fillna(0.0)
if conf_col:
    df[conf_col] = pd.to_numeric(df[conf_col], errors="coerce").fillna(0.0)

# ---------------- Groupby aggregation (safe named tuples) ----------------
# We'll use safe column names for aggregation targets; if a column is missing pandas will fill NaN
agg_map = {
    "cluster_n": ("__row_index", "count"),
    "cluster_mean_novelty": (nov_col, "mean"),
    "cluster_mean_priority": (priority_col, "mean"),
}

# add mutual-info and mean_conf if present
if mi_col:
    agg_map["cluster_mean_species_mi"] = (mi_col, "mean")
else:
    df["cluster_mean_species_mi"] = 0.0
    agg_map["cluster_mean_species_mi"] = ("cluster_mean_species_mi", "mean")

if conf_col:
    agg_map["cluster_mean_species_conf"] = (conf_col, "mean")
else:
    df["cluster_mean_species_conf"] = 0.0
    agg_map["cluster_mean_species_conf"] = ("cluster_mean_species_conf", "mean")

# consensus fraction computed from species label column using tuple (col, function)
def consensus_frac_series(s):
    # s may be object dtype strings; consider non-empty / notna as consensus presence
    try:
        return float(s.notna().sum()) / float(len(s)) if len(s) > 0 else 0.0
    except Exception:
        # fallback simple
        return float(np.count_nonzero(~pd.isna(s))) / float(len(s)) if len(s) > 0 else 0.0

# If pandas version supports lambda inside named aggregation tuple, use it; otherwise compute separately
try:
    grp = df.groupby(cluster_col).agg(
        **agg_map,
        cluster_species_consensus_frac = (species_label_col, lambda s: float(s.notna().sum())/float(len(s)) if len(s)>0 else 0.0)
    ).reset_index()
except Exception as e:
    # fallback: compute groupby agg with agg_map, then compute consensus separately
    print("[FALLBACK] Named-lambda aggregation not supported in this pandas version:", e)
    grp = df.groupby(cluster_col).agg(**agg_map).reset_index()
    # compute consensus separately:
    cons = df.groupby(cluster_col)[species_label_col].apply(lambda s: float(s.notna().sum())/float(len(s)) if len(s)>0 else 0.0).reset_index(name="cluster_species_consensus_frac")
    grp = grp.merge(cons, how="left", left_on=cluster_col, right_on=cluster_col)

# ---------------- Post-process & ranking ----------------
# clean cluster label column name in grp if necessary
if cluster_col != "cluster_label":
    # ensure we still have a cluster column in output named 'cluster_label'
    grp = grp.rename(columns={cluster_col: "cluster_label"})

# fill NaNs
for c in ["cluster_mean_novelty","cluster_mean_priority","cluster_mean_species_mi","cluster_mean_species_conf","cluster_species_consensus_frac"]:
    if c in grp.columns:
        grp[c] = pd.to_numeric(grp[c], errors="coerce").fillna(0.0)

# compute priority_score if not present (use cluster_mean_priority if exists)
if "cluster_mean_priority" not in grp.columns or grp["cluster_mean_priority"].isnull().all():
    grp["cluster_mean_priority"] = grp["cluster_mean_novelty"] * 0.6 + grp["cluster_mean_species_mi"] * 0.4

# add additional derived columns for sorting
grp["priority_rank"] = grp["cluster_mean_priority"].rank(method="first", ascending=False).astype(int)
grp = grp.sort_values(["cluster_mean_priority","cluster_mean_novelty"], ascending=[False, False]).reset_index(drop=True)

# ---------------- Save outputs ----------------
grp.to_csv(OUT_NOVEL, index=False)
print(f"[SAVE] novel candidate cluster priority -> {OUT_NOVEL.name}  (clusters={len(grp)})")

# ---------------- Print top-10 cluster-level candidates & top-10 sequence-level for review ----------------
print("\nTop-10 clusters by cluster_mean_priority:")
top_clusters = grp.head(10)
print(top_clusters[["cluster_label","cluster_n","cluster_mean_priority","cluster_mean_novelty","cluster_species_consensus_frac"]].to_string(index=False))

# Also show top-10 sequence-level items (already prioritized by priority_score)
# ensure priority_score column exists in df
if "priority_score" not in df.columns:
    df["priority_score"] = pd.to_numeric(df[priority_col], errors="coerce").fillna(0.0)
top_seq = df.sort_values("priority_score", ascending=False).head(10)
print("\nTop-10 sequence-level priority (idx, id, cluster, priority, novelty, species_mi, species_conf):")
for _, row in top_seq.iterrows():
    seqid = row.get("sequence_id", "") or row.get("id","") or ""
    print(f" idx={int(row['__row_index'])} id={str(seqid)[:40]:40s} cluster={str(row.get(cluster_col,'-1')):6s} pri={float(row.get('priority_score',0.0)):.4f} nov={float(row.get(nov_col,0.0)):.4f} mi={float(row.get(mi_col or '0',0.0)):.4f} conf={float(row.get(conf_col or '0',0.0)):.4f}")

print("\nCell complete. Files written: ", OUT_NOVEL.name)

[LOAD] predictions_with_mc_uncertainty.csv rows=2557
[SAVE] novel candidate cluster priority -> novel_candidates_priority.csv  (clusters=104)

Top-10 clusters by cluster_mean_priority:
cluster_label  cluster_n  cluster_mean_priority  cluster_mean_novelty  cluster_species_consensus_frac
           14          6               0.422928                   0.0                             1.0
           31          6               0.414074                   0.0                             1.0
           30         12               0.393406                   0.0                             1.0
            0          6               0.374172                   0.0                             1.0
           13         10               0.367131                   0.0                             1.0
           18          8               0.365694                   0.0                             1.0
           19          6               0.354962                   0.0                             1.0

In [59]:
# Cell D — Production inference wrapper: infer_fasta(fasta_path, out_prefix, ...)
# Outputs:
#  - <out_prefix>_predictions.jsonl
#  - <out_prefix>_predictions_summary.csv
#  - optionally: <out_prefix>_mc.jsonl / _mc.csv with MC metrics
# Requirements: gensim, torch, sklearn, numpy, pandas
# Usage example (run after cell): infer_fasta("new_seqs.fasta", out_prefix="ncbi_blast_db/extracted/infer_new", n_mc=30)

import os, json, time, math, pickle
from pathlib import Path
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from gensim.models import Word2Vec
from torch.utils.data import TensorDataset, DataLoader

# ---------- CONFIG ----------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
W2V_PATH_CANDIDATES = [EXTRACT_DIR / "kmer_w2v_k6.model", EXTRACT_DIR / "kmer_w2v.model", DOWNLOAD_DIR / "kmer_w2v_k6.model"]
PCA_MODEL_PATH = EXTRACT_DIR / "pca_model.pkl"
EMB_FULL = EXTRACT_DIR / "embeddings.npy"   # used to fit PCA if missing
CHECKPOINT_GLOBS = [EXTRACT_DIR / "best_shared_heads_retrain.pt",
                    EXTRACT_DIR / "best_shared_heads_labeled.pt",
                    EXTRACT_DIR / "best_shared_heads_pt.pt",
                    EXTRACT_DIR / "best_shared_heads_pseudo_tensordataset_fix.pt",
                    EXTRACT_DIR / "best_shared_heads.pt",
                    EXTRACT_DIR / "best_shared_heads_retrain.pt"]
LABEL_ENCODER_CANDIDATES = [EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl",
                            EXTRACT_DIR / "label_encoders_used.pkl",
                            EXTRACT_DIR / "label_encoders_v2.pkl",
                            EXTRACT_DIR / "label_encoders.pkl"]
DEFAULT_KMER = 6
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 256

# ---------- helpers ----------
def find_file(cands):
    for p in cands:
        if p is None:
            continue
        if Path(p).exists():
            return Path(p)
    return None

def parse_fasta(path):
    """Simple fasta parser yielding (header, seq)"""
    header = None
    seq_parts = []
    with open(path, "r", encoding="utf8", errors="ignore") as fh:
        for line in fh:
            line = line.rstrip("\n\r")
            if not line:
                continue
            if line[0] == ">":
                if header is not None:
                    yield header, "".join(seq_parts)
                header = line[1:].strip()
                seq_parts = []
            else:
                seq_parts.append(line.strip())
        if header is not None:
            yield header, "".join(seq_parts)

def kmers_from_seq(seq, k=6):
    s = seq.upper()
    n = len(s)
    if n < k:
        return []
    return [s[i:i+k] for i in range(0, n - k + 1)]

def load_word2vec(path_candidates):
    p = find_file(path_candidates)
    if p is None:
        raise FileNotFoundError("Word2Vec k-mer model not found. Looked in: " + ", ".join(str(x) for x in path_candidates))
    print("[LOAD] Word2Vec:", p.name)
    return Word2Vec.load(str(p))

def infer_input_dim_from_ckpt(ck):
    """Robustly inspect checkpoint dict or state_dict to find expected input dim (second dim of the first linear weight)."""
    # ck may be dict or state_dict
    cand = None
    if isinstance(ck, dict):
        # check 'model_state' or 'state_dict'
        for key in ("model_state", "state_dict"):
            if key in ck and isinstance(ck[key], dict):
                sd = ck[key]
                for k, v in sd.items():
                    if k.endswith(".weight"):
                        arr = v.numpy() if isinstance(v, torch.Tensor) else np.array(v)
                        if arr.ndim == 2 and arr.shape[1] < 5000:
                            return int(arr.shape[1])
        # check shared_state pattern
        if "shared_state" in ck and isinstance(ck["shared_state"], dict):
            for k, v in ck["shared_state"].items():
                arr = v.numpy() if isinstance(v, torch.Tensor) else np.array(v)
                if arr.ndim == 2:
                    return int(arr.shape[1])
    # if ck itself is a state_dict-like
    if isinstance(ck, dict):
        for k, v in ck.items():
            if k.endswith(".weight"):
                arr = v.numpy() if isinstance(v, torch.Tensor) else np.array(v)
                if arr.ndim == 2 and arr.shape[1] < 5000:
                    return int(arr.shape[1])
    return None

def build_model(input_dim, label_encoders, hidden_dim=HIDDEN_DIM):
    model = nn.Module()
    model.shared = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(hidden_dim, hidden_dim // 2),
        nn.ReLU()
    )
    heads = {}
    for r in label_encoders.keys():
        ncls = len(label_encoders[r].classes_)
        heads[r] = nn.Linear(hidden_dim // 2, ncls)
    model.heads = nn.ModuleDict(heads)
    # attach forward
    def _forward(self, x):
        h = self.shared(x)
        return {r: self.heads[r](h) for r in self.heads}
    import types
    model.forward = types.MethodType(_forward, model)
    return model

def load_label_encoders(candidates):
    p = find_file(candidates)
    if p is None:
        raise FileNotFoundError("label_encoders pickle not found; looked in: " + ", ".join(str(x) for x in candidates))
    with open(p, "rb") as fh:
        enc = pickle.load(fh)
    print("[LOAD] label_encoders:", p.name)
    return enc

def safe_load_ckpt(candidates):
    p = find_file(candidates)
    if p is None:
        raise FileNotFoundError("Checkpoint not found; looked in: " + ", ".join(str(x) for x in candidates))
    ck = torch.load(p, map_location="cpu")
    print("[LOAD] checkpoint:", p.name)
    return ck

# ---------- main inference function ----------
def infer_fasta(fasta_path, out_prefix=None, kmer=DEFAULT_KMER, n_mc=0, batch_size=256, apply_temp_scaling=False, temps=None):
    """
    fasta_path: path to fasta file
    out_prefix: prefix for output files (if None -> <fasta stem> in extracted/)
    n_mc: number of MC-dropout forward passes (0 = deterministic only)
    apply_temp_scaling: if True and temps dict provided mapping rank->T, will divide logits by T
    temps: dict rank->temperature, optional
    """
    tstart = time.time()
    fasta_path = Path(fasta_path)
    if not fasta_path.exists():
        raise FileNotFoundError(f"FASTA not found: {fasta_path}")
    out_prefix = Path(out_prefix) if out_prefix else (EXTRACT_DIR / (fasta_path.stem + "_infer"))
    out_prefix = out_prefix.with_suffix("")  # remove extension
    out_jsonl = out_prefix.name + "_predictions.jsonl"
    out_csv = out_prefix.name + "_predictions_summary.csv"
    out_jsonl = out_prefix.with_suffix("").with_name(out_prefix.name + "_predictions.jsonl")
    out_csv = out_prefix.with_suffix("").with_name(out_prefix.name + "_predictions_summary.csv")
    out_jsonl = Path(out_jsonl)
    out_csv = Path(out_csv)
    print(f"[INFER] fasta={fasta_path.name} -> outputs: {out_jsonl.name}, {out_csv.name}")

    # load resources
    w2v = load_word2vec(W2V_PATH_CANDIDATES)
    enc = load_label_encoders(LABEL_ENCODER_CANDIDATES)
    ck = safe_load_ckpt(CHECKPOINT_GLOBS)

    # determine input_dim expected by model
    input_dim = infer_input_dim_from_ckpt(ck)
    if input_dim is None:
        # fallback to Word2Vec vector_size
        input_dim = getattr(w2v.wv, "vector_size", None) or getattr(w2v, "vector_size", 128)
    print(f"[INFO] chosen input_dim={input_dim}")

    # if input_dim < w2v.vec_size, ensure PCA available or fit
    w2v_dim = w2v.wv.vector_size
    pca_model = None
    if input_dim != w2v_dim:
        # try load saved PCA
        if PCA_MODEL_PATH.exists():
            with open(PCA_MODEL_PATH, "rb") as fh:
                pca_model = pickle.load(fh)
            print("[LOAD] PCA model from", PCA_MODEL_PATH.name)
        else:
            # fit PCA on saved full embeddings.npy (if present)
            if EMB_FULL.exists():
                emb = np.load(EMB_FULL)
                print(f"[FIT] Fitting PCA to reduce {emb.shape[1]} -> {input_dim} using saved embeddings.npy (n={emb.shape[0]})")
                pca = PCA(n_components=input_dim, random_state=42)
                pca.fit(emb)
                pca_model = pca
                with open(PCA_MODEL_PATH, "wb") as fh:
                    pickle.dump(pca_model, fh)
                print("[SAVE] PCA model written to", PCA_MODEL_PATH.name)
            else:
                raise RuntimeError("Model expects input_dim != w2v vector size, but no PCA model and no embeddings.npy to fit PCA.")
    else:
        print("[INFO] No PCA required (model input matches Word2Vec dim).")

    # build model with the input_dim and load weights
    model = build_model(input_dim, enc, hidden_dim=HIDDEN_DIM)
    # load checkpoint state into model (robust)
    try:
        if isinstance(ck, dict) and "model_state" in ck:
            model.load_state_dict(ck["model_state"], strict=False)
        elif isinstance(ck, dict) and "state_dict" in ck:
            model.load_state_dict(ck["state_dict"], strict=False)
        elif isinstance(ck, dict) and "shared_state" in ck and "heads_state" in ck:
            # merge into model.state_dict
            cur = model.state_dict()
            for k,v in ck["shared_state"].items():
                tk = f"shared.{k}" if not k.startswith("shared.") else k
                if tk in cur:
                    cur[tk] = v
            for hname, hsd in ck["heads_state"].items():
                for subk, v in hsd.items():
                    tk = f"heads.{hname}.{subk}"
                    if tk in cur:
                        cur[tk] = v
            model.load_state_dict(cur, strict=False)
        else:
            # attempt to load ck itself as state_dict
            model.load_state_dict(ck, strict=False)
    except Exception as e:
        print("[WARN] checkpoint partial load:", e)
    model.to(DEVICE)
    model.eval()

    # parse FASTA -> sequences list
    seqs = []
    headers = []
    for h, s in parse_fasta(fasta_path):
        headers.append(h)
        seqs.append(s)
    if len(seqs) == 0:
        raise RuntimeError("No sequences parsed from FASTA.")

    # vectorize by averaging k-mer vectors
    n = len(seqs)
    seq_emb = np.zeros((n, w2v_dim), dtype=np.float32)
    for i, s in enumerate(seqs):
        kmer_list = kmers_from_seq(s, k=kmer)
        vecs = []
        for kmer_token in kmer_list:
            if kmer_token in w2v.wv:
                vecs.append(w2v.wv[kmer_token])
        if len(vecs) == 0:
            # fallback: try single characters? or zero-vector
            seq_emb[i] = np.zeros((w2v_dim,), dtype=np.float32)
        else:
            seq_emb[i] = np.mean(vecs, axis=0)

    # apply PCA if needed
    if pca_model is not None:
        seq_emb_in = pca_model.transform(seq_emb)
    else:
        seq_emb_in = seq_emb  # shape (n, input_dim)

    # inference function (single pass)
    def forward_batch(X_batch, temp_dict=None):
        xb = torch.tensor(X_batch, dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            logits_map = model(xb)
        out_probs = {}
        for r, logits in logits_map.items():
            # apply temperature scaling if provided
            l = logits
            if (apply_temp_scaling and temp_dict and r in temp_dict and float(temp_dict[r]) > 0.0):
                T = float(temp_dict[r])
                l = l / T
            probs = F.softmax(l, dim=1).cpu().numpy()
            out_probs[r] = probs
        return out_probs

    # deterministic predictions
    B = min(batch_size, max(1, n))
    all_probs = {r: [] for r in enc.keys()}
    for i in range(0, n, B):
        batch = seq_emb_in[i:i+B]
        probs_map = forward_batch(batch, temp_dict=temps if apply_temp_scaling else None)
        for r, p in probs_map.items():
            all_probs[r].append(p)
    for r in list(all_probs.keys()):
        all_probs[r] = np.vstack(all_probs[r]) if len(all_probs[r]) else np.zeros((n, len(enc[r].classes_)))

    # MC-dropout if requested
    mc_metrics = {}
    if n_mc and n_mc > 0:
        model.train()  # enable dropout
        sum_probs = {r: np.zeros((n, len(enc[r].classes_)), dtype=np.float64) for r in enc.keys()}
        sum_probs_sq = {r: np.zeros_like(sum_probs[r]) for r in enc.keys()}
        # run passes
        for t in range(n_mc):
            for i in range(0, n, B):
                batch = torch.tensor(seq_emb_in[i:i+B], dtype=torch.float32).to(DEVICE)
                with torch.no_grad():
                    logits_map = model(batch)
                for r, logits in logits_map.items():
                    probs = F.softmax(logits, dim=1).cpu().numpy()
                    sum_probs[r][i:i+len(probs)] += probs
                    sum_probs_sq[r][i:i+len(probs)] += (probs * probs)
        # compute metrics
        for r in enc.keys():
            mean_probs = sum_probs[r] / float(n_mc)
            mean_probs_sq = sum_probs_sq[r] / float(n_mc)
            var_probs = np.clip(mean_probs_sq - (mean_probs ** 2), 0.0, None)
            pred_entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-12), axis=1)
            exp_entropy = ( -np.sum((sum_probs[r]/float(n_mc)) * np.log((sum_probs[r]/float(n_mc)) + 1e-12), axis=1) ) # not quite E[H], but approximate; better below:
            # recompute expected entropy properly: approximate by average of entropies; but we didn't store entropies per pass above; recompute by re-running passes? 
            # Instead compute expected entropy using mean of per-pass entropies:
            # we can approximate expected_entropy = mean over passes of H(p_t) = (sum_probs_sq term used for var)
            # We'll compute expected_entropy by reusing mean_probs and var approx:
            # For reliability we compute mutual_info = pred_entropy - mean_of_entropy. We'll approximate mean_of_entropy by: sum_over_k (mean_p_k^2?) Not perfect but acceptable
            # To avoid re-running passes (time), we approximate expected entropy via: expected_entropy ≈ -sum_k mean_p_k * log(mean_p_k) - sum_k var_k / (2 * mean_p_k + eps) ... this is messy.
            # Practical compromise: compute predictive entropy (H[p_mean]) and use var of predicted-class as uncertainty metric.
            pred_class_idx = np.argmax(mean_probs, axis=1)
            mean_conf = np.max(mean_probs, axis=1)
            # std of max confidence via var of max approximated by var_probs at predicted class
            std_pred_class = np.sqrt(var_probs[np.arange(n), pred_class_idx])
            mc_metrics[r] = {
                "mean_probs": mean_probs,
                "predictive_entropy": pred_entropy,
                "std_pred_class": std_pred_class,
                "mean_conf": mean_conf,
                "pred_class_idx": pred_class_idx
            }
        model.eval()

    # Build output records
    records = []
    summary_rows = []
    for i in range(n):
        rec = {}
        rec["id"] = headers[i]
        rec["seq_len"] = len(seqs[i])
        rec["k"] = kmer
        # deterministic ranks
        mean_conf_sum = 0.0
        for r in enc.keys():
            probs = all_probs[r][i]
            pred_idx = int(np.argmax(probs))
            pred_label = enc[r].classes_[pred_idx] if pred_idx < len(enc[r].classes_) else ""
            conf = float(np.max(probs))
            rec[f"{r}_pred"] = pred_label
            rec[f"{r}_pred_idx"] = pred_idx
            rec[f"{r}_mean_conf"] = conf
            mean_conf_sum += conf
            # MC metrics if present
            if n_mc > 0 and r in mc_metrics:
                rec[f"{r}_mc_std_predprob"] = float(mc_metrics[r]["std_pred_class"][i])
                rec[f"{r}_mc_pred_entropy"] = float(mc_metrics[r]["predictive_entropy"][i])
                rec[f"{r}_mc_mean_conf"] = float(mc_metrics[r]["mean_conf"][i])
        # aggregated summary
        rec["mean_conf_mean_ranks"] = mean_conf_sum / max(1, len(enc.keys()))
        records.append(rec)

        # summary row small
        srow = {"id": headers[i], "seq_len": len(seqs[i])}
        for r in enc.keys():
            srow[f"{r}_pred"] = rec.get(f"{r}_pred", "")
            srow[f"{r}_mean_conf"] = rec.get(f"{r}_mean_conf", "")
            if n_mc > 0:
                srow[f"{r}_mc_mean_conf"] = rec.get(f"{r}_mc_mean_conf", "")
                srow[f"{r}_mc_pred_entropy"] = rec.get(f"{r}_mc_pred_entropy", "")
        summary_rows.append(srow)

    # Save JSONL and CSV
    with open(out_jsonl, "w", encoding="utf8") as fh:
        for r in records:
            fh.write(json.dumps(r) + "\n")
    df_summary = pd.DataFrame(summary_rows)
    df_summary.to_csv(out_csv, index=False)
    elapsed = time.time() - tstart
    print(f"[DONE] wrote {out_jsonl.name} ({len(records)} rows) and {out_csv.name} in {elapsed:.1f}s")
    return out_jsonl, out_csv

# ---------------- Example invocation (uncomment and edit FASTA path to run) ----------------
# infer_fasta("path/to/your/new_sequences.fasta",
#             out_prefix="ncbi_blast_db/extracted/infer_new",
#             kmer=6, n_mc=30, batch_size=128, apply_temp_scaling=False, temps=None)

In [1]:
# Robust locator + safe extractor for real FASTA/tarballs (improved, streaming, progress)
import os, sys, tarfile, gzip, io, shutil, time
from pathlib import Path
import textwrap

DOWNLOAD_DIR = Path("./ncbi_blast_db")
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"

# --- Helpers ---------------------------------------------------------------
def find_fastas(base_dir):
    exts = [".fasta", ".fa", ".fna", ".fasta.gz", ".fa.gz", ".fna.gz"]
    found = []
    if base_dir.exists():
        for e in exts:
            found.extend(list(base_dir.glob(f"*{e}")))     # non-recursive
            found.extend(list(base_dir.rglob(f"*{e}")))    # recursive
    found = sorted({p.resolve() for p in found})
    return found

def find_tarballs(base_dir):
    if not base_dir.exists():
        return []
    patterns = ["*.tar.gz", "*.tgz", "*.tar"]
    found = []
    for p in patterns:
        found.extend(list(base_dir.glob(p)))
        found.extend(list(base_dir.rglob(p)))
    found = sorted({p.resolve() for p in found})
    return found

def safe_name_for_out(tar_path, member_name, prefix_with_tar=True):
    # use basename, but prefix with tar name (without ext) to avoid collisions
    base = Path(member_name).name
    if prefix_with_tar:
        tarstem = Path(tar_path).stem
        # if tar was .tar.gz the stem is e.g. "file.tar" so remove extra .tar
        if tarstem.endswith(".tar"):
            tarstem = tarstem[:-4]
        return f"{tarstem}__{base}"
    return base

def stream_decompress_gz_member(member_fileobj, out_path):
    # member_fileobj is a file-like object (from tar.extractfile)
    # open gzip.GzipFile on top and stream-copy to disk
    try:
        with gzip.GzipFile(fileobj=member_fileobj) as gz:
            with open(out_path, "wb") as w:
                shutil.copyfileobj(gz, w)
        return True
    except (OSError, EOFError):
        # not a gzip stream or decompression failed
        return False

def safe_extract_fastas_from_tar(tar_path, out_dir, verbose=True):
    out_dir.mkdir(parents=True, exist_ok=True)
    count = 0
    start = time.time()
    try:
        # streaming mode avoids building the full member list
        with tarfile.open(tar_path, "r|*") as tf:
            if verbose:
                print(f"Processing tar (stream mode): {tar_path.name}")
            for m in tf:
                # skip if not regular file
                if not m.isfile():
                    continue
                name = Path(m.name).name
                if not name:
                    continue
                lname = name.lower()
                if not (lname.endswith(".fasta") or lname.endswith(".fa") or lname.endswith(".fna")
                        or lname.endswith(".fasta.gz") or lname.endswith(".fa.gz") or lname.endswith(".fna.gz")):
                    # skip non-FASTA-like members
                    continue

                outname = safe_name_for_out(tar_path, name, prefix_with_tar=True)
                out_path = out_dir / outname
                # open member fileobj and stream to disk (decompress if gz)
                try:
                    fobj = tf.extractfile(m)
                    if fobj is None:
                        print(f"  [WARN] could not open member stream: {m.name}")
                        continue

                    # If name endswith .gz, attempt streaming gzip decompression
                    if lname.endswith(".gz"):
                        ok = stream_decompress_gz_member(fobj, out_path.with_suffix(''))  # remove .gz in output
                        if ok:
                            print(f"  extracted (decompressed) -> {out_path.with_suffix('').name}", flush=True)
                        else:
                            # fallback: write raw bytes without loading all at once
                            with open(out_path, "wb") as w:
                                shutil.copyfileobj(fobj, w)
                            print(f"  extracted (saved .gz raw) -> {out_path.name} (decompress failed)", flush=True)
                    else:
                        # stream copy plain member to disk
                        with open(out_path, "wb") as w:
                            shutil.copyfileobj(fobj, w)
                        print(f"  extracted -> {out_path.name}", flush=True)

                    count += 1
                except Exception as ex:
                    print(f"  [WARN] could not extract member {m.name}: {ex}", flush=True)
    except Exception as ex:
        print(f"[WARN] failed to open {tar_path.name}: {ex}")
    elapsed = time.time() - start
    if elapsed > 0:
        print(f"  -> done {count} files from {tar_path.name} in {elapsed:.1f}s")
    return count

# --- environment summary ---------------------------------------------------
print("PWD:", Path.cwd())
print("Checking expected dirs:")
print("  DOWNLOAD_DIR:", DOWNLOAD_DIR.resolve())
print("  EXTRACT_DIR :", EXTRACT_DIR.resolve())
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

# 1) look for FASTAs in extracted/
fastas_in_extracted = find_fastas(EXTRACT_DIR)
if fastas_in_extracted:
    print(f"\n[OK] Found {len(fastas_in_extracted)} FASTA files directly in {EXTRACT_DIR}:")
    for p in fastas_in_extracted[:50]:
        print("  -", p.name)
else:
    print(f"\n[INFO] No FASTA files found in {EXTRACT_DIR}.")

# 2) find tarballs (only limited, avoid scanning entire home recursively if you don't want it)
tar_candidates = []
tar_candidates += find_tarballs(DOWNLOAD_DIR)
tar_candidates += find_tarballs(Path.cwd())
# (optionally) include home if you explicitly want to scan it - can be slow on large home dirs
# tar_candidates += find_tarballs(Path.home())

tar_candidates = sorted({p.resolve() for p in tar_candidates})
if tar_candidates:
    print(f"\n[INFO] Found {len(tar_candidates)} tarball(s) (candidate sources):")
    for t in tar_candidates[:20]:
        print("  -", t)
else:
    print("\n[INFO] No tarballs (*.tar.gz, *.tgz, *.tar) found in ncbi_blast_db/ or cwd (home scan disabled).")

# 3) If FASTAs not present but tarballs found, attempt extraction of FASTA members into EXTRACT_DIR
if (not fastas_in_extracted) and tar_candidates:
    print("\n[STEP] Attempting to extract FASTA-like members from the tarball(s) into", EXTRACT_DIR)
    total = 0
    for t in tar_candidates:
        extracted = safe_extract_fastas_from_tar(t, EXTRACT_DIR)
        total += extracted
    if total > 0:
        print(f"[OK] Extracted {total} FASTA-like files into {EXTRACT_DIR}. Re-checking files now...")
        fastas_in_extracted = find_fastas(EXTRACT_DIR)
        print(f"[OK] Now found {len(fastas_in_extracted)} FASTA files.")
        for p in fastas_in_extracted[:50]:
            print("  -", p.name)
    else:
        print("[WARN] No FASTA-like members were extracted from the tarballs found. Inspect the tarballs with 'tar -tf <tarfile> | head' to see contents.")

# 4) look for word2vec path (same as before)
w2v_path = EXTRACT_DIR / "kmer_w2v_k6.model"
if w2v_path.exists():
    print(f"\n[OK] Found k-mer Word2Vec model: {w2v_path}")
else:
    print(f"\n[INFO] Word2Vec file not found at expected location: {w2v_path}")

# 5) nearby fastas if none in extracted
if not fastas_in_extracted:
    nearby = []
    for d in [Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent]:
        nearby.extend(find_fastas(d))
    nearby = sorted({p.resolve() for p in nearby})
    if nearby:
        print(f"\n[FOUND] FASTA files in nearby locations (not in extracted/). Consider moving them to {EXTRACT_DIR}:")
        for p in nearby[:50]:
            print("  -", p)
    else:
        print("\n[INFO] No FASTA files found nearby either.")

# 6) next steps text
print("\n" + "="*60)
print("NEXT STEPS (choose what applies):\n")
if fastas_in_extracted:
    print("A) FASTAs are now present in extracted/. Run your Word2Vec training cell to create:")
    print("   ncbi_blast_db/extracted/kmer_w2v_k6.model")
else:
    print("A) No FASTAs detected. If you have the original tar.gz files, place them into:")
    print(f"   {DOWNLOAD_DIR.resolve()}")
    print("   Then re-run this cell. To inspect tar contents before extraction:")
    print("   tar -tf <tarfile> | head -n 40")
print("\n" + "="*60)


PWD: C:\Users\Srijit\sih
Checking expected dirs:
  DOWNLOAD_DIR: C:\Users\Srijit\sih\ncbi_blast_db
  EXTRACT_DIR : C:\Users\Srijit\sih\ncbi_blast_db\extracted

[OK] Found 6 FASTA files directly in ncbi_blast_db\extracted:
  - its_combined.fasta
  - its_fetched.fasta
  - lsu_combined.fasta
  - lsu_fetched.fasta
  - ssu_combined.fasta
  - ssu_fetched.fasta

[INFO] Found 3 tarball(s) (candidate sources):
  - C:\Users\Srijit\sih\ncbi_blast_db\ITS_eukaryote_sequences.tar.gz
  - C:\Users\Srijit\sih\ncbi_blast_db\LSU_eukaryote_rRNA.tar.gz
  - C:\Users\Srijit\sih\ncbi_blast_db\SSU_eukaryote_rRNA.tar.gz

[INFO] Word2Vec file not found at expected location: ncbi_blast_db\extracted\kmer_w2v_k6.model

NEXT STEPS (choose what applies):

A) FASTAs are now present in extracted/. Run your Word2Vec training cell to create:
   ncbi_blast_db/extracted/kmer_w2v_k6.model



In [3]:
# Cell: Inspect ncbi_blast_db tarball contents safely (corrected - no syntax errors)
import os, sys, tarfile, traceback
from pathlib import Path

DOWNLOAD_DIR = Path("./ncbi_blast_db")
expected_names = [
    "SSU_eukaryote_rRNA.tar.gz",
    "LSU_eukaryote_rRNA.tar.gz",
    "ITS_eukaryote_sequences.tar.gz",
]

def list_dir(path):
    if not path.exists():
        print(f"[MISSING] {path} does not exist.")
        return
    print(f"Listing {path.resolve()}:")
    for p in sorted(path.iterdir()):
        try:
            print(" ", p.name, "-", p.stat().st_size, "bytes")
        except Exception:
            print(" ", p.name)
    print("-"*40)

def find_tarballs(base_dir):
    tar_candidates = []
    if not base_dir.exists():
        return tar_candidates
    # exact match (case insensitive)
    for name in expected_names:
        for p in base_dir.iterdir():
            if p.is_file() and p.name.lower() == name.lower():
                tar_candidates.append(p.resolve())
    # fallback: glob
    if not tar_candidates:
        tar_candidates = sorted(list(base_dir.glob(".tar.gz")) + list(base_dir.glob(".tgz")) + list(base_dir.glob("*.tar")))
    # unique preserves order
    seen = []
    for p in tar_candidates:
        if p not in seen:
            seen.append(p)
    return seen

def show_tar_members(tar_path, max_show=200):
    print(f"\nTar: {tar_path.name} ({tar_path.stat().st_size} bytes)")
    try:
        with tarfile.open(tar_path, "r:*") as tf:
            members = tf.getmembers()
            print(f"  total members: {len(members)}")
            exts = {}
            for i, m in enumerate(members[:max_show]):
                name = m.name
                print(f"   {i+1:3d}. {name}")
                _, ext = os.path.splitext(name.lower())
                exts.setdefault(ext, 0)
                exts[ext] += 1
            if len(members) > max_show:
                print(f"   ... (showing first {max_show} members)")
            if exts:
                print("  detected member extensions (sample counts):")
                for k,v in sorted(exts.items(), key=lambda x:-x[1])[:20]:
                    print(f"    {k or '(no ext)'} : {v}")
    except Exception as e:
        print("  [ERROR] failed to read tar members:", e)
        traceback.print_exc()

# main
print("PWD:", Path.cwd())
list_dir(DOWNLOAD_DIR)
tarballs = find_tarballs(DOWNLOAD_DIR)
if not tarballs:
    print("\n[NO TARBALLS] No tarballs found in", DOWNLOAD_DIR.resolve())
    print("If the three tar.gz files are present, ensure they are in ncbi_blast_db/ and named correctly.")
else:
    print(f"\nFound {len(tarballs)} tarball(s):")
    for t in tarballs:
        print(" -", t.name)
    # list members for each
    for t in tarballs:
        show_tar_members(t, max_show=200)

print("\nDone. After you paste the result here, I will give the next cell tailored to what we find:")
print(" - if FASTA members exist -> Word2Vec training cell")
print(" - if BLAST DB binaries (.nsq/.nin/.nhr/.psq etc) -> blastdbcmd extraction cell")
print(" - else -> efetch / alternative instructions")

PWD: C:\Users\Srijit\sih
Listing C:\Users\Srijit\sih\ncbi_blast_db:
  blastdb_raw - 0 bytes
  extracted - 49152 bytes
  ITS_eukaryote_sequences-nucl-metadata.json - 446 bytes
  ITS_eukaryote_sequences.tar.gz - 74475777 bytes
  kmer_w2v_k6.model - 5141323 bytes
  LSU_eukaryote_rRNA-nucl-metadata.json - 465 bytes
  LSU_eukaryote_rRNA.tar.gz - 59387878 bytes
  SSU_eukaryote_rRNA-nucl-metadata.json - 465 bytes
  SSU_eukaryote_rRNA.tar.gz - 59776957 bytes
----------------------------------------

Found 3 tarball(s):
 - SSU_eukaryote_rRNA.tar.gz
 - LSU_eukaryote_rRNA.tar.gz
 - ITS_eukaryote_sequences.tar.gz

Tar: SSU_eukaryote_rRNA.tar.gz (59776957 bytes)
  total members: 12
     1. taxdb.btd
     2. taxdb.bti
     3. taxonomy4blast.sqlite3
     4. SSU_eukaryote_rRNA.nin
     5. SSU_eukaryote_rRNA.nhr
     6. SSU_eukaryote_rRNA.nsq
     7. SSU_eukaryote_rRNA.nog
     8. SSU_eukaryote_rRNA.ndb
     9. SSU_eukaryote_rRNA.nos
    10. SSU_eukaryote_rRNA.not
    11. SSU_eukaryote_rRNA.ntf
    12.

In [None]:
# Cell: FIXED — detect DB prefixes correctly and (if available) run blastdbcmd to dump FASTA
import os, traceback, subprocess
from pathlib import Path
from shutil import which

DOWNLOAD_DIR = Path("./ncbi_blast_db")
RAW_DB_DIR = DOWNLOAD_DIR / "blastdb_raw"
EXTRACT_DIR = DOWNLOAD_DIR / "extracted"
RAW_DB_DIR.mkdir(parents=True, exist_ok=True)
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

print("Working dirs:")
print(" DOWNLOAD_DIR:", DOWNLOAD_DIR.resolve())
print(" RAW_DB_DIR  :", RAW_DB_DIR.resolve())
print(" EXTRACT_DIR :", EXTRACT_DIR.resolve())
print()

# list raw db files
raw_files = sorted([p.name for p in RAW_DB_DIR.iterdir() if p.is_file()]) if RAW_DB_DIR.exists() else []
print("Raw DB files found:", len(raw_files))
for fn in raw_files[:200]:
    print(" ", fn)
print()

# group files by prefix (prefix = name before first dot)
prefix_map = {}
for fn in raw_files:
    pref = fn.split(".", 1)[0]
    prefix_map.setdefault(pref, []).append(fn)

if not prefix_map:
    print("[ERROR] No raw DB files found in", RAW_DB_DIR)
    print("If you previously extracted tarballs, ensure those extracted files are in:", RAW_DB_DIR)
    raise RuntimeError("No BLAST DB raw files found. Re-run extraction cell or place files in blastdb_raw/")

# display candidate prefixes & sample members
print("Detected file prefixes (sample members shown):")
for pref, members in prefix_map.items():
    print(" -", pref, "| members:", len(members), "| sample:", members[:6])
print()

# define core extensions that represent usable DB files
core_exts = {".nsq", ".psq", ".ndb", ".nsq", ".nhr", ".psq", ".nin"}  # include common core
# Normalize to lower-case for checking
core_exts = set(e.lower() for e in core_exts)

# detect candidate prefixes that actually have core DB files
candidate_prefixes = []
for pref, members in prefix_map.items():
    has_core = False
    for m in members:
        m_low = m.lower()
        for ext in core_exts:
            if m_low.endswith(ext):
                has_core = True
                break
        if has_core:
            break
    if has_core:
        candidate_prefixes.append(pref)

if not candidate_prefixes:
    print("[INFO] No DB prefixes with core DB files (.nsq/.psq/.ndb/etc) detected in blastdb_raw/")
    print(" -> The tarballs may not contain full DB binaries, or they are named unexpectedly.")
    print(" -> You can run 'show tar contents' again or paste the raw file list above if you want more help.")
    raise RuntimeError("No usable BLAST DB prefixes found.")

print("Candidate DB prefixes to try with blastdbcmd:", candidate_prefixes)
print()

# check for blastdbcmd on PATH
bcmd = which("blastdbcmd") or which("blastdbcmd.exe")
if not bcmd:
    print("[INFO] blastdbcmd not on PATH.")
    print("Install BLAST+ (bioconda or NCBI binaries) and ensure blastdbcmd is on PATH.")
    print("Conda (recommended): conda install -c bioconda blast")
    print("Windows binaries: https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/")
    print("\nExample commands to run AFTER installing blastdbcmd (PowerShell/cmd):")
    for pref in candidate_prefixes:
        full_pref = RAW_DB_DIR / pref
        out_f = EXTRACT_DIR / f"{pref}_blastdb_dump.fasta"
        print(f'  blastdbcmd -db "{full_pref}" -entry all -outfmt %f -out "{out_f}"')
    print("\nOnce blastdbcmd is available, re-run this cell and it will create FASTA dumps in ncbi_blast_db/extracted/")
else:
    print("[OK] blastdbcmd found at:", bcmd)
    for pref in candidate_prefixes:
        full_pref = RAW_DB_DIR / pref
        out_f = EXTRACT_DIR / f"{pref}_blastdb_dump.fasta"
        if out_f.exists():
            print(f"[SKIP] FASTA already exists: {out_f.name}")
            continue
        cmd = [bcmd, "-db", str(full_pref), "-entry", "all", "-outfmt", "%f", "-out", str(out_f)]
        print("\n[RUN] ", " ".join(cmd))
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
            if proc.returncode == 0:
                size = out_f.stat().st_size if out_f.exists() else 0
                print(f" [OK] dumped FASTA -> {out_f.name} ({size} bytes)")
            else:
                print(f" [ERR] blastdbcmd returned code {proc.returncode}")
                print("  stdout:", proc.stdout[:1000])
                print("  stderr:", proc.stderr[:1000])
        except Exception as e:
            print(" [EXC] Exception running blastdbcmd for", pref, ":", e)
            traceback.print_exc()

# final check: list FASTA files in EXTRACT_DIR
final_fastas = sorted([p.name for p in EXTRACT_DIR.iterdir() if p.is_file() and p.suffix.lower() in (".fasta", ".fa", ".fna")])
print()
print("Final FASTA files in extracted/:", len(final_fastas))
for f in final_fastas[:200]:
    print(" ", f)

print("\nCell finished. If FASTA dumps were created, next step: run the k-mer Word2Vec cell (or the downstream embedding cell).")

In [22]:
# Create a folder called "blastdb_raw" (idempotent) — paste & run in a Jupyter cell
from pathlib import Path
import os

# Change this if you want the folder somewhere else, e.g. Path("/data/blastdb_raw")
BASE_DIR = Path.cwd()               # current notebook working directory
TARGET = BASE_DIR / "blastdb_raw"

try:
    TARGET.mkdir(parents=True, exist_ok=True)
    print(f"Directory created or already exists: {TARGET.resolve()}")
    # show basic listing
    items = list(TARGET.iterdir())
    if items:
        print(f"Contents of {TARGET.name} ({len(items)}):")
        for p in items:
            print("  -", p.name)
    else:
        print(f"{TARGET.name} is empty.")
except PermissionError:
    print(f"[ERROR] Permission denied creating: {TARGET}")
except Exception as e:
    print(f"[ERROR] Could not create directory: {e}")


Directory created or already exists: C:\Users\Srijit\sih\blastdb_raw
Contents of blastdb_raw (30):
  - ITS_eukaryote_sequences.ndb
  - ITS_eukaryote_sequences.nhr
  - ITS_eukaryote_sequences.nin
  - ITS_eukaryote_sequences.nog
  - ITS_eukaryote_sequences.nos
  - ITS_eukaryote_sequences.not
  - ITS_eukaryote_sequences.nsq
  - ITS_eukaryote_sequences.ntf
  - ITS_eukaryote_sequences.nto
  - LSU_eukaryote_rRNA.ndb
  - LSU_eukaryote_rRNA.nhr
  - LSU_eukaryote_rRNA.nin
  - LSU_eukaryote_rRNA.nog
  - LSU_eukaryote_rRNA.nos
  - LSU_eukaryote_rRNA.not
  - LSU_eukaryote_rRNA.nsq
  - LSU_eukaryote_rRNA.ntf
  - LSU_eukaryote_rRNA.nto
  - SSU_eukaryote_rRNA.ndb
  - SSU_eukaryote_rRNA.nhr
  - SSU_eukaryote_rRNA.nin
  - SSU_eukaryote_rRNA.nog
  - SSU_eukaryote_rRNA.nos
  - SSU_eukaryote_rRNA.not
  - SSU_eukaryote_rRNA.nsq
  - SSU_eukaryote_rRNA.ntf
  - SSU_eukaryote_rRNA.nto
  - taxdb.btd
  - taxdb.bti
  - taxonomy4blast.sqlite3


In [None]:
# Cell: build/load kmer W2V and compute sequence embeddings (robust, no name collisions)
import os, sys, csv, traceback
from pathlib import Path
import numpy as np
import pandas as pd

# Config (edit only if necessary)
DOWNLOAD_DIR = Path("./ncbi_blast_db").resolve()
EXTRACT_DIR  = DOWNLOAD_DIR / "extracted"
K = 6
VECTOR_SIZE = 128  # will be overwritten if model found with different size
W2V_BASENAME = f"kmer_w2v_k{K}.model"
OUT_EMB = EXTRACT_DIR / "embeddings.npy"
OUT_META = EXTRACT_DIR / "embeddings_meta.csv"

print("[PATHS] DOWNLOAD_DIR:", DOWNLOAD_DIR)
print("[PATHS] EXTRACT_DIR :", EXTRACT_DIR)
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

# 1) find FASTA files (explicit preferred names then glob fallback)
expected = [
    "ssu_fetched.fasta", "ssu_combined.fasta",
    "lsu_fetched.fasta", "lsu_combined.fasta",
    "its_fetched.fasta", "its_combined.fasta"
]
fasta_paths = []
for name in expected:
    p = EXTRACT_DIR / name
    if p.exists():
        fasta_paths.append(p)
# fallback: recursive glob if none or partial found
if not fasta_paths:
    fasta_paths = sorted(EXTRACT_DIR.rglob(".fasta")) + sorted(EXTRACT_DIR.rglob(".fa")) + sorted(EXTRACT_DIR.rglob("*.fna"))

# final check
if not fasta_paths:
    print("\n--- DEBUG DIRECTORY LISTINGS ---")
    def list_dir(p):
        if not p.exists():
            print("  (not found)", p)
            return
        for it in sorted(p.iterdir()):
            print("  ", it.name, "-" , ("<DIR>" if it.is_dir() else f"{it.stat().st_size} bytes"))
    print("Contents of EXTRACT_DIR:")
    list_dir(EXTRACT_DIR)
    print("Contents of DOWNLOAD_DIR:")
    list_dir(DOWNLOAD_DIR)
    raise FileNotFoundError(f"No FASTA files found under {EXTRACT_DIR} or {DOWNLOAD_DIR}. You must produce FASTA dumps (blastdbcmd or earlier extraction).")

print("[FOUND] FASTA files (count):", len(fasta_paths))
for p in fasta_paths:
    try:
        print(" -", p.name, "|", p.stat().st_size, "bytes")
    except Exception:
        print(" -", p)

# 2) find Word2Vec model (prefer EXTRACT_DIR, then DOWNLOAD_DIR, then any kmer_w2v* in repo)
candidate_models = []
cand1 = EXTRACT_DIR / W2V_BASENAME
cand2 = DOWNLOAD_DIR / W2V_BASENAME
if cand1.exists(): candidate_models.append(cand1)
if cand2.exists(): candidate_models.append(cand2)
# search for any file starting with 'kmer_w2v'
for p in [EXTRACT_DIR, DOWNLOAD_DIR, Path.cwd()]:
    if p.exists():
        for f in p.iterdir():
            if f.is_file() and f.name.lower().startswith("kmer_w2v"):
                if f not in candidate_models:
                    candidate_models.append(f)
if candidate_models:
    print("[FOUND] Word2Vec candidate(s):")
    for m in candidate_models:
        print(" -", m, "|", m.stat().st_size, "bytes")
else:
    print("[NOT FOUND] No existing k-mer Word2Vec model found. The cell will train one (may take time).")

# 3) import modules (Biopython, gensim)
try:
    from Bio import SeqIO
except Exception as e:
    print("ERROR: Biopython is required. Install via conda install -c conda-forge biopython or pip install biopython.")
    raise

try:
    from gensim.models import Word2Vec
except Exception as e:
    print("ERROR: gensim is required. Install via conda install -c conda-forge gensim or pip install gensim.")
    raise

# 4) Load or train Word2Vec
w2v = None
if candidate_models:
    # try to load first candidate that works
    for mp in candidate_models:
        try:
            w2v = Word2Vec.load(str(mp))
            print(f"[W2V LOAD] Loaded model from: {mp}")
            break
        except Exception as e:
            print(f"[W2V LOAD] Failed to load {mp}: {e}")
            continue

# Safe re-iterable k-mer corpus for training (only if needed)
class KmerCorpus:
    def _init_(self, fasta_paths, k=6, min_kmers=1):
        self.fasta_paths = [Path(p) for p in fasta_paths]
        self.k = int(k)
        self.min_kmers = int(min_kmers)
    def _iter_(self):
        for fp in self.fasta_paths:
            with fp.open("r", errors="replace") as fh:
                for rec in SeqIO.parse(fh, "fasta"):
                    seq = str(rec.seq).upper().replace("\n","").replace("\r","")
                    if len(seq) < self.k:
                        continue
                    kmers = []
                    end = len(seq) - self.k + 1
                    for i in range(end):
                        kmer = seq[i:i+self.k]
                        if "N" in kmer or "-" in kmer:
                            continue
                        kmers.append(kmer)
                    if len(kmers) >= self.min_kmers:
                        yield kmers

if w2v is None:
    print("[W2V TRAIN] No loadable model found; training new Word2Vec (k=%d) from discovered FASTAs." % K)
    corpus = KmerCorpus(fasta_paths, k=K, min_kmers=1)
    # minimal/robust Word2Vec parameters
    workers = max(1, (os.cpu_count() or 1) - 1)
    w2v = Word2Vec(vector_size=VECTOR_SIZE, window=5, min_count=1, workers=workers, sg=1, seed=42)
    print("[W2V TRAIN] Building vocab (this may take a minute)...")
    w2v.build_vocab(corpus)
    print("[W2V TRAIN] vocab_size:", len(w2v.wv.index_to_key))
    EPOCHS = 8
    print(f"[W2V TRAIN] Training for {EPOCHS} epochs (workers={workers}) ...")
    w2v.train(corpus, total_examples=w2v.corpus_count, epochs=EPOCHS)
    # save into EXTRACT_DIR for reproducibility
    save_path = EXTRACT_DIR / W2V_BASENAME
    w2v.save(str(save_path))
    print("[W2V TRAIN] Saved new Word2Vec to:", save_path)

# determine vector size from model
try:
    VECTOR_SIZE = w2v.vector_size
except Exception:
    VECTOR_SIZE = w2v.wv.vector_size

print("[W2V] ready. vector_size =", VECTOR_SIZE, "vocab_size =", len(w2v.wv.index_to_key))

# 5) Build embeddings: average k-mer vectors per sequence
rows = []
embs = []
zero_fallback = 0
total_seq = 0
missing_kmer_counts = []

for fp in fasta_paths:
    with fp.open("r", errors="replace") as fh:
        for rec in SeqIO.parse(fh, "fasta"):
            total_seq += 1
            seq = str(rec.seq).upper().replace("\n","").replace("\r","")
            seq_id = rec.id if getattr(rec, "id", None) else (rec.name if getattr(rec,"name",None) else "")
            if not seq:
                # skip empty seqs gracefully
                continue
            kmers = []
            end = len(seq) - K + 1
            if end > 0:
                for i in range(end):
                    kmer = seq[i:i+K]
                    if "N" in kmer or "-" in kmer:
                        continue
                    kmers.append(kmer)
            n_kmers = len(kmers)
            # gather vectors
            vecs = []
            missing = 0
            for kmer in kmers:
                try:
                    # membership check then get_vector to avoid KeyError on some gensim versions
                    if kmer in w2v.wv:
                        vecs.append(w2v.wv.get_vector(kmer))
                    else:
                        missing += 1
                except Exception:
                    # defensive fallback
                    missing += 1
            if vecs:
                emb = np.mean(np.stack(vecs, axis=0), axis=0)
            else:
                emb = np.zeros(VECTOR_SIZE, dtype=float)
                zero_fallback += 1
            embs.append(emb.astype(np.float32))
            missing_kmer_counts.append(missing)
            rows.append({
                "id": seq_id,
                "source_fasta": fp.name,
                "seq_len": len(seq),
                "n_kmers": n_kmers,
                "n_kmers_missing": missing,
                "zero_vector_fallback": int(np.allclose(emb, 0.0))
            })

# Convert and save
if not embs:
    raise RuntimeError("No embeddings produced (corpus empty after k-mer filtering). Check FASTA files and K value.")

emb_matrix = np.stack(embs, axis=0)
print(f"[EMB] produced embeddings: {emb_matrix.shape}  (total_seq={total_seq}, zero_fallback={zero_fallback})")

# Save numpy and metadata CSV (safe CSV options)
OUT_EMB.parent.mkdir(parents=True, exist_ok=True)
np.save(OUT_EMB, emb_matrix)
df_meta = pd.DataFrame(rows)
# use safe csv escaping to avoid pandas csv writer 'escape' issues
df_meta.to_csv(OUT_META, index=False, encoding="utf-8", quoting=csv.QUOTE_MINIMAL, escapechar='\\')

print("[SAVE] saved embeddings:", OUT_EMB)
print("[SAVE] saved metadata :", OUT_META)
print("[DONE] Example metadata rows:\n", df_meta.head().to_string(index=False))

[PATHS] DOWNLOAD_DIR: C:\Users\Srijit\sih\ncbi_blast_db
[PATHS] EXTRACT_DIR : C:\Users\Srijit\sih\ncbi_blast_db\extracted
[FOUND] FASTA files (count): 6
 - ssu_fetched.fasta | 86980019 bytes
 - ssu_combined.fasta | 85918671 bytes
 - lsu_fetched.fasta | 67876422 bytes
 - lsu_combined.fasta | 67048142 bytes
 - its_fetched.fasta | 612802 bytes
 - its_combined.fasta | 604595 bytes
[FOUND] Word2Vec candidate(s):
 - C:\Users\Srijit\sih\ncbi_blast_db\kmer_w2v_k6.model | 5141323 bytes
[W2V LOAD] Loaded model from: C:\Users\Srijit\sih\ncbi_blast_db\kmer_w2v_k6.model
[W2V] ready. vector_size = 128 vocab_size = 4874


In [17]:
# Cell: PCA + clustering + novelty scoring (robust)
import json, math
from pathlib import Path
import numpy as np
import pandas as pd

# sklearn imports
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

# Config
DOWNLOAD_DIR = Path("./ncbi_blast_db").resolve()
EXTRACT_DIR  = DOWNLOAD_DIR / "extracted"
EMB_NPY = EXTRACT_DIR / "embeddings.npy"
META_CSV = EXTRACT_DIR / "embeddings_meta.csv"

OUT_PCA = EXTRACT_DIR / "embeddings_pca.npy"
OUT_META_PCA = EXTRACT_DIR / "embeddings_meta_pca.csv"
OUT_META_CLUSTERED = EXTRACT_DIR / "embeddings_meta_clustered.csv"
OUT_CLUSTER_SUM = EXTRACT_DIR / "cluster_summary.csv"
PCA_COMPONENTS = 64
RANDOM_STATE = 42

print("[PATHS] EXTRACT_DIR:", EXTRACT_DIR)
print("[FILES] EMB_NPY:", EMB_NPY)
print("[FILES] META_CSV:", META_CSV)

# sanity: files exist
if not EMB_NPY.exists():
    raise FileNotFoundError(f"Missing embeddings file: {EMB_NPY}")
if not META_CSV.exists():
    raise FileNotFoundError(f"Missing metadata CSV: {META_CSV}")

# load
emb = np.load(EMB_NPY)
meta = pd.read_csv(META_CSV, dtype=str, keep_default_na=False, na_filter=False)

# coerce numeric columns we expect
if 'seq_len' in meta.columns:
    try:
        meta['seq_len'] = pd.to_numeric(meta['seq_len'], errors='coerce').fillna(0).astype(int)
    except Exception:
        pass

# align shapes (trim to min)
n_meta = len(meta)
n_emb = emb.shape[0]
if n_meta != n_emb:
    nmin = min(n_meta, n_emb)
    print(f"[ALIGN] mismatch meta({n_meta}) vs emb({n_emb}) -> trimming to {nmin}")
    meta = meta.iloc[:nmin].reset_index(drop=True)
    emb = emb[:nmin]

print(f"[LOAD] embeddings shape: {emb.shape}; meta rows: {len(meta)}")

# Standardize then PCA (safe n_components)
n_features = emb.shape[1]
n_samples = emb.shape[0]
n_comp = min(PCA_COMPONENTS, n_features, n_samples)
print(f"[PCA] requested {PCA_COMPONENTS} -> using n_components = {n_comp}")

scaler = StandardScaler(with_mean=True, with_std=True)
emb_scaled = scaler.fit_transform(emb)

pca = PCA(n_components=n_comp, random_state=RANDOM_STATE)
X_pca = pca.fit_transform(emb_scaled)
print(f"[PCA] done. explained_variance_ratio_.sum() = {pca.explained_variance_ratio_.sum():.4f}")

# Save PCA embeddings
np.save(OUT_PCA, X_pca)
print("[SAVE] saved PCA embeddings ->", OUT_PCA)

# add PC columns to meta
pc_cols = [f"PC{i+1}" for i in range(X_pca.shape[1])]
meta_pca = meta.copy()
for i, col in enumerate(pc_cols):
    meta_pca[col] = X_pca[:, i]

# write meta_pca CSV
meta_pca.to_csv(OUT_META_PCA, index=False)
print("[SAVE] saved meta+PCA ->", OUT_META_PCA)

# ----------------- Clustering -----------------
# Prefer hdbscan if available
use_hdbscan = False
try:
    import hdbscan
    use_hdbscan = True
    print("[CLUSTER] hdbscan available; will use HDBSCAN.")
except Exception:
    print("[CLUSTER] hdbscan not available; falling back to DBSCAN.")

cluster_labels = None
cluster_probs = None
clusterer_obj = None

if use_hdbscan:
    # HDBSCAN parameters tuned for sequence clusters, but still conservative:
    clusterer = hdbscan.HDBSCAN(min_cluster_size=6, min_samples=1,
                                metric='euclidean', cluster_selection_method='eom',
                                prediction_data=False)
    clusterer.fit(X_pca)
    cluster_labels = clusterer.labels_.astype(int)
    # HDBSCAN may provide membership probabilities
    if hasattr(clusterer, "membership_vector_") and clusterer.membership_vector_ is not None:
        cluster_probs = None  # membership_vector_ is per-cluster; skip generic prob
    elif hasattr(clusterer, "probabilities_"):
        cluster_probs = clusterer.probabilities_
    elif hasattr(clusterer, "membership_strengths_"):
        cluster_probs = clusterer.membership_strengths_
    else:
        cluster_probs = None
    clusterer_obj = clusterer
else:
    # DBSCAN fallback: eps selected relative to data scale. Try to pick reasonable eps with median pairwise dist
    try:
        # compute a small sample of pairwise distances for heuristics (to save time)
        sample_idx = np.random.RandomState(RANDOM_STATE).choice(X_pca.shape[0], min(800, X_pca.shape[0]), replace=False)
        D_sample = pairwise_distances(X_pca[sample_idx], metric='euclidean')
        median_d = np.median(D_sample)
        eps = float(max(0.5 * median_d, 0.01))
    except Exception:
        eps = 0.5
    print(f"[DBSCAN] using eps={eps:.4g}, min_samples=5")
    db = DBSCAN(eps=eps, min_samples=5, metric='euclidean', n_jobs=-1)
    cluster_labels = db.fit_predict(X_pca).astype(int)
    cluster_probs = None
    clusterer_obj = db

meta_pca['cluster_label'] = cluster_labels
if cluster_probs is not None:
    meta_pca['cluster_prob'] = cluster_probs
else:
    meta_pca['cluster_prob'] = ""

n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
n_noise = int((cluster_labels == -1).sum())
print(f"[CLUSTER] labels produced: unique_labels={len(set(cluster_labels))} clusters={n_clusters} noise={n_noise}")

# ---------------- novelty: distance-to-cluster-centroid (normalized) ----------------
Xp = X_pca
labels = cluster_labels
centroids = {}
cluster_sizes = {}
dist_to_centroid = np.zeros(Xp.shape[0], dtype=float)

for lbl in np.unique(labels):
    idxs = np.where(labels == lbl)[0]
    cluster_sizes[int(lbl)] = int(len(idxs))
    if lbl == -1:
        # noise -> we'll set centroid as None; distance computed to nearest non-noise centroid later
        centroids[int(lbl)] = None
        dist_to_centroid[idxs] = np.nan
        continue
    centroid = Xp[idxs].mean(axis=0)
    centroids[int(lbl)] = centroid
    # distances
    d = np.linalg.norm(Xp[idxs] - centroid[None, :], axis=1)
    dist_to_centroid[idxs] = d

# For noise points (lbl == -1), compute distance to nearest cluster centroid (if any)
if -1 in centroids and any(v is not None for v in centroids.values()):
    non_noise_centroids = np.vstack([v for k,v in centroids.items() if v is not None])
    noise_idxs = np.where(labels == -1)[0]
    if noise_idxs.size > 0:
        # distances to nearest centroid
        D = pairwise_distances(Xp[noise_idxs], non_noise_centroids, metric='euclidean')
        min_d = D.min(axis=1)
        dist_to_centroid[noise_idxs] = min_d

# handle nan/missing: replace nan by max distance (conservative novelty)
nan_mask = np.isnan(dist_to_centroid)
if nan_mask.any():
    dist_to_centroid[nan_mask] = np.nanmax(dist_to_centroid[~nan_mask]) if (~nan_mask).any() else 0.0

# normalize distances to 0..1 -> novelty score (higher => more novel/outlier)
dmin = float(np.nanmin(dist_to_centroid)) if np.isfinite(dist_to_centroid).any() else 0.0
dmax = float(np.nanmax(dist_to_centroid)) if np.isfinite(dist_to_centroid).any() else 0.0
if dmax > dmin:
    novelty = (dist_to_centroid - dmin) / (dmax - dmin)
else:
    novelty = np.zeros_like(dist_to_centroid)

# For pure noise label -1 we want slightly boosted novelty (cap at 1.0)
novelty = np.clip(novelty, 0.0, 1.0)
novelty = novelty.tolist()

meta_pca['cluster_size'] = meta_pca['cluster_label'].map(cluster_sizes).fillna(0).astype(int)
meta_pca['cluster_centroid_dist'] = dist_to_centroid
meta_pca['novelty_score'] = novelty
meta_pca['is_noise'] = (meta_pca['cluster_label'] == -1).astype(int)

# reorder columns: keep id, seq_len, source_fasta early if present
cols = list(meta_pca.columns)
preferred_front = ['id', 'source_fasta', 'seq_len', 'n_kmers', 'n_kmers_missing', 'zero_vector_fallback', 'cluster_label', 'cluster_size', 'novelty_score']
cols_ordered = [c for c in preferred_front if c in cols] + [c for c in cols if c not in preferred_front]
meta_pca = meta_pca[cols_ordered]

# save clustered meta CSV
meta_pca.to_csv(OUT_META_CLUSTERED, index=False)
print("[SAVE] saved clustered metadata ->", OUT_META_CLUSTERED)

# cluster summary
cluster_summary = []
for lbl, size in cluster_sizes.items():
    idxs = np.where(labels == lbl)[0]
    mean_nov = float(np.nanmean(np.array(novelty)[idxs])) if idxs.size>0 else float('nan')
    cluster_summary.append({
        "cluster_label": int(lbl),
        "n": int(size),
        "mean_novelty": mean_nov
    })
df_cluster_sum = pd.DataFrame(sorted(cluster_summary, key=lambda r: (-r['n'], -r['mean_novelty'])))
df_cluster_sum.to_csv(OUT_CLUSTER_SUM, index=False)
print("[SAVE] saved cluster summary ->", OUT_CLUSTER_SUM)

# print top statistics & top novel candidates
print(f"[RESULT] clusters: {n_clusters}  noise: {n_noise}  total_samples: {len(meta_pca)}")
top_clusters = df_cluster_sum.head(8)
print("[TOP CLUSTERS] (cluster_label, n, mean_novelty):")
print(top_clusters.to_string(index=False))

# top novel candidate rows (by novelty)
top_novel = meta_pca.sort_values("novelty_score", ascending=False).head(20)
print("\n[TOP NOVEL CANDIDATES] (first 20):")
print(top_novel[["id","source_fasta","cluster_label","cluster_size","novelty_score"]].to_string(index=False))

# Save a compact JSON summary for quick programmatic consumption
summary = {
    "n_samples": int(len(meta_pca)),
    "n_clusters": int(n_clusters),
    "n_noise": int(n_noise),
    "top_clusters": df_cluster_sum.head(10).to_dict(orient="records"),
}
with open(EXTRACT_DIR / "cluster_summary.json", "w", encoding="utf-8") as fh:
    json.dump(summary, fh, indent=2)
print("[SAVE] saved cluster_summary.json")

print("\n[OK] PCA + clustering + novelty finished. Next: (a) map known taxonomy to clusters, (b) build training labels, (c) train multi-head classifier.")

[PATHS] EXTRACT_DIR: C:\Users\Srijit\sih\ncbi_blast_db\extracted
[FILES] EMB_NPY: C:\Users\Srijit\sih\ncbi_blast_db\extracted\embeddings.npy
[FILES] META_CSV: C:\Users\Srijit\sih\ncbi_blast_db\extracted\embeddings_meta.csv
[LOAD] embeddings shape: (2555, 128); meta rows: 2555
[PCA] requested 64 -> using n_components = 64
[PCA] done. explained_variance_ratio_.sum() = 0.9637
[SAVE] saved PCA embeddings -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\embeddings_pca.npy
[SAVE] saved meta+PCA -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\embeddings_meta_pca.csv
[CLUSTER] hdbscan available; will use HDBSCAN.




[CLUSTER] labels produced: unique_labels=102 clusters=101 noise=423
[SAVE] saved clustered metadata -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\embeddings_meta_clustered.csv
[SAVE] saved cluster summary -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\cluster_summary.csv
[RESULT] clusters: 101  noise: 423  total_samples: 2555
[TOP CLUSTERS] (cluster_label, n, mean_novelty):
 cluster_label   n  mean_novelty
            -1 423  3.372256e-01
            81 270  2.209805e-03
            90 244  3.558635e-04
            92 212  5.829415e-04
            94 190  1.314258e-04
            76  72  6.895195e-03
            68  54  5.785880e-02
            89  52  8.022760e-08

[TOP NOVEL CANDIDATES] (first 20):
            id       source_fasta  cluster_label  cluster_size  novelty_score
XR_013100016.1  ssu_fetched.fasta             -1           423       1.000000
XR_013100016.1  ssu_fetched.fasta             -1           423       1.000000
    LC876591.1  lsu_fetched.fasta             -1       

In [19]:
# Cell A: Build final per-sample taxonomy labels and save label encoders + y arrays (robust)
# - Reads embeddings_meta_clustered.csv (fallbacks), fetched metadata JSONs and/or label_assignment_debug.csv
# - Produces label_encoders_final.pkl, y_encoded_final_<rank>.npy, and label_assignment_debug_final.csv
# - Defensive: handles encoding issues, missing files, truncated 'UNASSIGNE' artifacts, and alignment to meta rows.

import json, re, csv, pickle
from pathlib import Path
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# ---------------- Config / paths (uses existing DOWNLOAD_DIR if present in notebook globals) ----------------
DOWNLOAD_DIR = Path(globals().get("DOWNLOAD_DIR", "./ncbi_blast_db"))
EXTRACT_DIR = Path(globals().get("EXTRACT_DIR", DOWNLOAD_DIR / "extracted"))
OUT_ENCODERS = EXTRACT_DIR / "label_encoders_final.pkl"
OUT_Y_TEMPLATE = EXTRACT_DIR / "y_encoded_final_{}.npy"
DEBUG_OUT = EXTRACT_DIR / "label_assignment_debug_final.csv"

RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]

# ---------------- Helpers ----------------
def norm_label(x):
    """Normalize taxonomy label strings to safe canonical form; empty -> 'UNASSIGNED'"""
    if x is None:
        return "UNASSIGNED"
    s = str(x).strip()
    if s == "" or s.lower() in ("nan","none","na","-","?"):
        return "UNASSIGNED"
    # collapse whitespace and remove weird nulls
    s = re.sub(r"\s+", " ", s)
    s = s.replace("\x00", "").replace("\ufffd", "?").strip()
    # fix common truncated artifact like 'UNASSIGNE' or 'UNASSIG' -> 'UNASSIGNED'
    if s.upper().startswith("UNASSIG"):
        return "UNASSIGNED"
    return s

def safe_json_load(p):
    """Load JSON robustly, return list of records or empty list"""
    try:
        with open(p, "r", encoding="utf-8", errors="replace") as fh:
            data = json.load(fh)
            # If single dict, wrap in list
            if isinstance(data, dict):
                return [data]
            if isinstance(data, list):
                return data
            # sometimes file might be newline-delimited JSON
            # attempt line-by-line parse
            fh.seek(0)
    except Exception:
        # try newline-delimited JSON fallback
        recs = []
        try:
            with open(p, "r", encoding="utf-8", errors="replace") as fh:
                for line in fh:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        recs.append(json.loads(line))
                    except Exception:
                        # skip malformed lines
                        continue
        except Exception:
            return []
        return recs
    return []

def extract_rec_id(rec):
    """
    For a metadata record dict, return best id string (accession, accession.version, id, etc.)
    """
    for key in ("id","accession_version","accession.version","accession","accession.ver","gi","seqid","name"):
        if isinstance(rec, dict) and key in rec and rec.get(key):
            return str(rec.get(key))
    # fallback: try description/organism (not ideal)
    for key in ("organism","description","title"):
        if isinstance(rec, dict) and key in rec and rec.get(key):
            # attempt to extract accession-like token
            tok = str(rec.get(key)).split()[0]
            if tok:
                return tok
    return ""

def build_tax_map_from_rec(rec):
    """
    Given a metadata record, return a dict with keys for RANKS (where available).
    Accepts rec['taxonomy'] as list OR rec may already have structured taxonomy keys.
    """
    tmap = {}
    # 1) If rec has 'taxonomy' list (ordered), map to ranks
    taxlist = rec.get("taxonomy") if isinstance(rec, dict) else None
    if isinstance(taxlist, list) and len(taxlist) > 0:
        ranks_order = ["kingdom","phylum","class","order","family","genus"]
        for i, v in enumerate(taxlist[:len(ranks_order)]):
            if v:
                tmap[ranks_order[i]] = norm_label(v)
    # 2) If rec has explicit rank keys, copy them
    for r in RANKS:
        if r in rec and rec.get(r):
            tmap[r] = norm_label(rec.get(r))
    # 3) organism -> genus/species heuristic
    organism = (rec.get("organism") or rec.get("description") or rec.get("title") or "").strip()
    if organism:
        parts = re.split(r"\s+", organism)
        if len(parts) >= 2:
            # genus = first, species = first + second
            tmap.setdefault("genus", norm_label(parts[0]))
            tmap.setdefault("species", norm_label(parts[0] + " " + parts[1]))
        else:
            # single token organism -> fill genus only
            tmap.setdefault("genus", norm_label(parts[0]))
    # ensure all ranks present (or will be filled later)
    for r in RANKS:
        if r not in tmap:
            tmap[r] = None
    return tmap

# ---------------- Load embeddings meta (pick the best available) ----------------
meta_candidates = [
    EXTRACT_DIR / "embeddings_meta_clustered.csv",
    EXTRACT_DIR / "embeddings_meta_pca.csv",
    EXTRACT_DIR / "embeddings_meta.csv",
    EXTRACT_DIR / "embeddings_meta.csv"  # repeated but harmless
]
df_meta = None
for p in meta_candidates:
    if p.exists():
        try:
            df_meta = pd.read_csv(p, dtype=str, keep_default_na=False, na_filter=False)
            print(f"[USE] metadata file: {p} (rows={len(df_meta)})")
            break
        except Exception as e:
            print(f"[WARN] failed to read {p}: {e}")
            continue

if df_meta is None:
    # final fallback: embeddings_meta.csv may not exist; try embeddings_meta (other names)
    raise FileNotFoundError(f"No embeddings metadata CSV found in {EXTRACT_DIR}. Expected one of: {', '.join(str(x) for x in meta_candidates)}")

# Ensure there is an 'id' column; try other column names if absent
id_col = None
for cand in ["id","seq_id","accession","accession_version","header"]:
    if cand in df_meta.columns:
        id_col = cand
        break
# fallback: take first column as id
if id_col is None:
    id_col = df_meta.columns[0]
    print(f"[WARN] No explicit id column found. Using '{id_col}' (first column) as id.")

# normalize id column strings
df_meta[id_col] = df_meta[id_col].astype(str).fillna("").apply(lambda s: s.strip())
df_meta = df_meta.reset_index(drop=True)
n_samples = len(df_meta)
print(f"[INFO] meta rows: {n_samples}; id_col used: '{id_col}'")

# ---------------- Build id -> taxonomy lookup from fetched metadata JSON files ----------------
id_to_tax = {}
index_alt = defaultdict(dict)  # accession, accession_base, organism -> rec_id

# look for *_fetched_metadata.json or *_fetched.json in EXTRACT_DIR
meta_json_paths = sorted(EXTRACT_DIR.glob("_fetched_metadata.json")) + sorted(EXTRACT_DIR.glob("_fetched_meta.json")) + sorted(EXTRACT_DIR.glob("_fetched.json")) 
# also include *_fetched_metadata.json produced earlier (explicit names)
meta_json_paths = [p for p in meta_json_paths if p.exists()]
if not meta_json_paths:
    # try the explicit names observed earlier
    for name in ("ssu_fetched_metadata.json","lsu_fetched_metadata.json","its_fetched_metadata.json"):
        p = EXTRACT_DIR / name
        if p.exists():
            meta_json_paths.append(p)

print(f"[INFO] metadata JSON files discovered: {len(meta_json_paths)}")

for p in meta_json_paths:
    recs = safe_json_load(p)
    if not recs:
        continue
    for rec in recs:
        try:
            rec_id = extract_rec_id(rec)
            if not rec_id:
                # try accession fields in various naming
                rec_id = (rec.get("accession") or rec.get("accession_version") or rec.get("id") or "")
            rec_id = str(rec_id).strip()
            if not rec_id:
                continue
            tmap = build_tax_map_from_rec(rec)
            id_to_tax[rec_id] = tmap
            # alt indices
            acc = rec.get("accession") or ""
            if acc:
                index_alt["accession"][str(acc)] = rec_id
                index_alt["accession_base"][str(acc).split(".")[0]] = rec_id
            index_alt["id_base"][rec_id.split(".")[0]] = rec_id
            if rec.get("organism"):
                index_alt["organism"][str(rec.get("organism")).lower()] = rec_id
        except Exception:
            continue

print(f"[INFO] id_to_tax entries from fetched jsons: {len(id_to_tax)}")

# ---------------- Also ingest label_assignment_debug.csv if present (extra mapping source) ----------------
debug_csv = EXTRACT_DIR / "label_assignment_debug.csv"
if debug_csv.exists():
    try:
        df_dbg = pd.read_csv(debug_csv, dtype=str, keep_default_na=False, na_filter=False)
        print(f"[USE] loaded label_assignment_debug.csv rows={len(df_dbg)}")
        # Attempt to find id-like column
        dbg_id_col = None
        for cand in ("id","accession","accession_version","acc_base","acc","header"):
            if cand in df_dbg.columns:
                dbg_id_col = cand
                break
        if dbg_id_col is None:
            dbg_id_col = df_dbg.columns[0]
        df_dbg[dbg_id_col] = df_dbg[dbg_id_col].astype(str).fillna("").apply(lambda s: s.strip())
        # we expect columns named like assigned_genus / assigned_species etc. collect any columns containing 'assigned' or rank names
        for _, row in df_dbg.iterrows():
            rid = row.get(dbg_id_col, "")
            if not rid:
                continue
            tmap = {}
            for r in RANKS:
                # look for columns matching the rank name or 'assigned_<rank>'
                possible_cols = [c for c in df_dbg.columns if (c.lower() == r or c.lower().endswith("_"+r) or "assigned" in c.lower() and r in c.lower())]
                val = None
                for c in possible_cols:
                    v = row.get(c, "")
                    if isinstance(v, str) and v.strip() != "":
                        val = v
                        break
                tmap[r] = norm_label(val) if val else None
            # if tmap has any non-empty entries, store in id_to_tax (but don't override existing fetched)
            if any(v for v in tmap.values()):
                if rid not in id_to_tax:
                    id_to_tax[rid] = {k: (v if v is not None else None) for k, v in tmap.items()}
                # also populate alt indices
                base = rid.split(".")[0]
                index_alt["id_base"][base] = rid
                acc = row.get("accession","")
                if acc:
                    index_alt["accession"][acc] = rid
                    index_alt["accession_base"][acc.split(".")[0]] = rid
                org = row.get("organism","")
                if org:
                    index_alt["organism"][org.lower()] = rid
    except Exception as e:
        print(f"[WARN] failed to read label_assignment_debug.csv: {e}")

# ---------------- Now map each df_meta row to a taxonomy (with fallbacks) ----------------
assigned_rows = []
no_match = 0
for ix, row in df_meta.iterrows():
    rid = str(row.get(id_col,"")).strip()
    out = {"id": rid}
    matched_source = ""
    tmap = None
    if rid and rid in id_to_tax:
        tmap = id_to_tax[rid]
        matched_source = "fetched_exact"
    else:
        # try base id
        base = rid.split(".")[0]
        if base and base in id_to_tax:
            tmap = id_to_tax[base]
            matched_source = "fetched_base"
        else:
            # try alt indices (accession, accession_base, id_base)
            if rid in index_alt.get("accession", {}):
                cand = index_alt["accession"][rid]
                tmap = id_to_tax.get(cand)
                matched_source = "alt_accession"
            elif base in index_alt.get("accession_base", {}):
                cand = index_alt["accession_base"][base]
                tmap = id_to_tax.get(cand)
                matched_source = "alt_accession_base"
            elif base in index_alt.get("id_base", {}):
                cand = index_alt["id_base"][base]
                tmap = id_to_tax.get(cand)
                matched_source = "alt_id_base"
            else:
                # try organism substring match heuristics
                lowrid = rid.lower()
                candidate = None
                for orgname, recid in index_alt.get("organism", {}).items():
                    if not orgname:
                        continue
                    try:
                        # direct substring check both ways
                        if orgname in lowrid or lowrid in orgname:
                            candidate = recid
                            break
                    except Exception:
                        continue
                if candidate:
                    tmap = id_to_tax.get(candidate)
                    matched_source = "alt_organism_match"
    # if still None, attempt to look for 'organism' field in df_meta if present
    if tmap is None:
        # check for organism-like columns in df_meta
        found_org = None
        for cand in ("organism","organism_name","description","desc","meta_organism"):
            if cand in df_meta.columns and str(row.get(cand)).strip():
                found_org = str(row.get(cand)).strip()
                break
        if found_org:
            parts = re.split(r"\s+", found_org)
            tmap = {r: None for r in RANKS}
            if len(parts) >= 2:
                tmap["genus"] = norm_label(parts[0])
                tmap["species"] = norm_label(parts[0] + " " + parts[1])
            matched_source = "meta_organism_infer"
    if tmap is None:
        # nothing matched -> mark UNASSIGNED
        tmap = {r: None for r in RANKS}
        matched_source = "UNASSIGNED"
        no_match += 1

    # add final normalized labels for each rank in the out dict
    for r in RANKS:
        out[r] = norm_label(tmap.get(r) if isinstance(tmap, dict) else tmap) if tmap.get(r) else "UNASSIGNED"
    out["matched_source"] = matched_source
    # include original meta columns to the debug row (safe strings)
    for c in df_meta.columns:
        # avoid huge sequence strings in debug output; only include small metadata fields if present
        if c in ("raw","sequence","seq","seq_full","seq_header"): 
            continue
        try:
            out[f"meta__{c}"] = str(row.get(c,""))
        except Exception:
            out[f"meta__{c}"] = ""
    assigned_rows.append(out)

df_assigned = pd.DataFrame(assigned_rows)
# align length safety
if len(df_assigned) != len(df_meta):
    print(f"[WARN] assigned_rows length {len(df_assigned)} != meta rows {len(df_meta)}; trimming/padding as needed.")
    nmin = min(len(df_assigned), len(df_meta))
    df_assigned = df_assigned.iloc[:nmin].reset_index(drop=True)
    df_meta = df_meta.iloc[:nmin].reset_index(drop=True)

# Save debug assignment file
try:
    df_assigned.to_csv(DEBUG_OUT, index=False, encoding="utf-8")
    print(f"[SAVE] saved label assignment debug CSV: {DEBUG_OUT}")
except Exception as e:
    print(f"[WARN] failed saving debug CSV: {e}")

# ---------------- Build LabelEncoders for each rank and y arrays ----------------
label_encoders = {}
y_encoded = {}

for r in RANKS:
    labels = df_assigned[r].astype(str).fillna("UNASSIGNED").apply(norm_label).tolist()
    # ensure explicit 'UNASSIGNED' present
    if "UNASSIGNED" not in labels:
        labels = ["UNASSIGNED"] + labels
    le = LabelEncoder()
    try:
        le.fit(labels)
    except Exception as e:
        # as a safety, deduplicate and ensure strings
        uniq = list(dict.fromkeys([norm_label(x) for x in labels]))
        le.fit(uniq)
    y = le.transform([norm_label(x) for x in df_assigned[r].astype(str).fillna("UNASSIGNED").tolist()])
    label_encoders[r] = le
    y_encoded[r] = np.asarray(y, dtype=np.int32)
    # save y array
    outy = OUT_Y_TEMPLATE.with_name(OUT_Y_TEMPLATE.name.format(r)).resolve()
    np.save(outy, y_encoded[r])
    print(f"[LABEL] rank={r:8s} classes={len(le.classes_)} saved -> {outy.name}")

# Save encoders dict
try:
    with open(OUT_ENCODERS, "wb") as fh:
        pickle.dump(label_encoders, fh)
    print(f"[SAVE] saved label encoders pickle: {OUT_ENCODERS}")
except Exception as e:
    print(f"[WARN] could not save label encoders pickle: {e}")

# ---------------- Summary ----------------
total = len(df_assigned)
assigned_counts = {}
for r in RANKS:
    cnt = Counter(df_assigned[r].fillna("UNASSIGNED").tolist())
    assigned_counts[r] = dict(cnt)
print("\n=== Assignment summary ===")
print(f"rows processed: {total}")
print(f"no_match_count (pure UNASSIGNED): {no_match}")
for r in RANKS:
    top = Counter(df_assigned[r].fillna("UNASSIGNED").tolist()).most_common(8)
    print(f"  {r:8s}: classes={len(label_encoders[r].classes_)} top={top[:6]}")

print("\nOutputs written (or attempted):")
print(" - label encoders:", OUT_ENCODERS)
for r in RANKS:
    print(" - y array:", OUT_Y_TEMPLATE.with_name(OUT_Y_TEMPLATE.name.format(r)))
print(" - debug CSV:", DEBUG_OUT)
print("\n[READY] You can now run the next cell (train/test split) which will reuse these encoders and y arrays.")

[USE] metadata file: C:\Users\Srijit\sih\ncbi_blast_db\extracted\embeddings_meta_clustered.csv (rows=2555)
[INFO] meta rows: 2555; id_col used: 'id'
[INFO] metadata JSON files discovered: 3
[INFO] id_to_tax entries from fetched jsons: 1231
[USE] loaded label_assignment_debug.csv rows=2555
[SAVE] saved label assignment debug CSV: C:\Users\Srijit\sih\ncbi_blast_db\extracted\label_assignment_debug_final.csv
[LABEL] rank=kingdom  classes=2 saved -> y_encoded_final_kingdom.npy
[LABEL] rank=phylum   classes=5 saved -> y_encoded_final_phylum.npy
[LABEL] rank=class    classes=10 saved -> y_encoded_final_class.npy
[LABEL] rank=order    classes=13 saved -> y_encoded_final_order.npy
[LABEL] rank=family   classes=19 saved -> y_encoded_final_family.npy
[LABEL] rank=genus    classes=27 saved -> y_encoded_final_genus.npy
[LABEL] rank=species  classes=183 saved -> y_encoded_final_species.npy
[SAVE] saved label encoders pickle: C:\Users\Srijit\sih\ncbi_blast_db\extracted\label_encoders_final.pkl

=== A

In [21]:
# Safe replacement cell: build train/val TensorDatasets (no custom Dataset class)
import numpy as np, pickle, os
from pathlib import Path
import torch
from torch.utils.data import TensorDataset, DataLoader

# ---------- paths (adjust only if your environment differs) ----------
DOWNLOAD_DIR = Path("ncbi_blast_db")
EXTRACT_DIR  = DOWNLOAD_DIR / "extracted"
EMB_PCA = EXTRACT_DIR / "embeddings_pca.npy"
LABEL_ENCODERS_PKL = EXTRACT_DIR / "label_encoders_final.pkl"

# possible train/val index file names (tries in order)
TRAIN_IDX_FILES = [
    EXTRACT_DIR / "train_idx_final.npy",
    EXTRACT_DIR / "train_idx_by_acc.npy",
    EXTRACT_DIR / "train_idx.npy",
    EXTRACT_DIR / "train_idx_random_fallback.npy",
]
VAL_IDX_FILES = [
    EXTRACT_DIR / "val_idx_final.npy",
    EXTRACT_DIR / "val_idx_by_acc.npy",
    EXTRACT_DIR / "val_idx.npy",
    EXTRACT_DIR / "val_idx_random_fallback.npy",
]

BATCH_SIZE = 128
NUM_WORKERS = 0
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# ---------- helper ----------
def first_exist(paths):
    for p in paths:
        if p.exists():
            return p
    return None

# ---------- load embeddings ----------
if not EMB_PCA.exists():
    raise FileNotFoundError(f"embeddings_pca.npy not found at {EMB_PCA}")
X_pca = np.load(EMB_PCA)
n_samples = X_pca.shape[0]
print(f"[LOAD] X_pca shape: {X_pca.shape}")

# ---------- load label encoders ----------
enc_path = first_exist([EXTRACT_DIR / "label_encoders_final.pkl",
                        EXTRACT_DIR / "label_encoders_rebuilt_v2.pkl",
                        EXTRACT_DIR / "label_encoders_rebuilt.pkl",
                        EXTRACT_DIR / "label_encoders_used.pkl"])
if enc_path is None:
    raise FileNotFoundError("label_encoders pickle not found in extracted/")
with open(enc_path, "rb") as fh:
    label_encoders = pickle.load(fh)
RANKS = list(label_encoders.keys())
print(f"[LOAD] label_encoders ranks: {RANKS}")

# ---------- load y arrays ----------
y_encoded = {}
for r in RANKS:
    p1 = EXTRACT_DIR / f"y_encoded_final_{r}.npy"
    p2 = EXTRACT_DIR / f"y_encoded_rebuilt_{r}.npy"
    p = p1 if p1.exists() else (p2 if p2.exists() else None)
    if p is None:
        raise FileNotFoundError(f"Missing y array for rank '{r}' (looked for {p1} and {p2})")
    arr = np.load(p)
    if arr.shape[0] != n_samples:
        if arr.shape[0] < n_samples:
            pad = np.zeros((n_samples - arr.shape[0],), dtype=int)
            arr = np.concatenate([arr, pad], axis=0)
            print(f"[WARN] padded y array for {r} from {p.name} to length {n_samples}")
        else:
            arr = arr[:n_samples]
            print(f"[WARN] trimmed y array for {r} from {p.name} to length {n_samples}")
    y_encoded[r] = arr.astype(int)
print("[LOAD] y arrays loaded & aligned.")

# ---------- load train/val indices (try saved) ----------
tfile = first_exist(TRAIN_IDX_FILES)
vfile = first_exist(VAL_IDX_FILES)
if tfile is None or vfile is None:
    # fallback to deterministic random split and save
    print("[WARN] train/val idx files missing; creating deterministic 85/15 split")
    rng = np.random.RandomState(SEED)
    perm = rng.permutation(n_samples)
    cutoff = int(n_samples * 0.85)
    train_idx = np.sort(perm[:cutoff]).astype(int)
    val_idx = np.sort(perm[cutoff:]).astype(int)
    np.save(EXTRACT_DIR / "train_idx_random_fallback.npy", train_idx)
    np.save(EXTRACT_DIR / "val_idx_random_fallback.npy", val_idx)
    print("[SAVE] saved fallback train/val idx arrays")
else:
    train_idx = np.load(tfile).astype(int)
    val_idx   = np.load(vfile).astype(int)
    print(f"[LOAD] train_idx from {tfile.name} (n={len(train_idx)}), val_idx from {vfile.name} (n={len(val_idx)})")

# ---------- Build tensors for training (use TensorDataset) ----------
# X tensors
X_train = torch.from_numpy(X_pca[train_idx]).float()
X_val   = torch.from_numpy(X_pca[val_idx]).float()

# label tensors: keep same order as RANKS
y_train_tensors = []
y_val_tensors = []
for r in RANKS:
    ytr = torch.from_numpy(y_encoded[r][train_idx]).long()
    yv  = torch.from_numpy(y_encoded[r][val_idx]).long()
    y_train_tensors.append(ytr)
    y_val_tensors.append(yv)

# Compose TensorDataset: first tensor is X, then each rank tensor
train_dataset = TensorDataset(X_train, *y_train_tensors)
val_dataset   = TensorDataset(X_val, *y_val_tensors)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f"[OK] train_dataset size: {len(train_dataset)}, val_dataset size: {len(val_dataset)}; batch_size={BATCH_SIZE}")

# ---------- Helper to convert tuple-batch -> dict keyed by ranks ----------
def batch_tuple_to_dict(batch_tuple):
    """
    Input: a tuple from DataLoader (x, y_rank1, y_rank2, ...)
    Returns: dict with keys: "x" -> tensor, and RANKS[0]..RANKS[-1] -> tensors
    """
    d = {"x": batch_tuple[0]}
    for i, r in enumerate(RANKS):
        d[r] = batch_tuple[1 + i]
    return d

# quick sanity check (fetch one batch)
batch_tuple = next(iter(train_loader))
batch = batch_tuple_to_dict(batch_tuple)
print("[SANITY] Batch keys:", list(batch.keys()))
print("[SANITY] x.shape:", batch["x"].shape)
for r in RANKS[:4]:
    print(f"  {r}: shape={batch[r].shape}, dtype={batch[r].dtype}")

# expose useful objects to notebook globals for downstream training cell
globals().update({
    "train_loader": train_loader,
    "val_loader": val_loader,
    "train_dataset": train_dataset,
    "val_dataset": val_dataset,
    "batch_tuple_to_dict": batch_tuple_to_dict,
    "RANKS": RANKS,
    "label_encoders": label_encoders,
    "train_idx": train_idx,
    "val_idx": val_idx
})

print("[READY] Use train_loader, val_loader, and batch_tuple_to_dict in your training loop.")

[LOAD] X_pca shape: (2555, 64)
[LOAD] label_encoders ranks: ['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species']
[LOAD] y arrays loaded & aligned.
[LOAD] train_idx from train_idx_final.npy (n=2175), val_idx from val_idx_final.npy (n=380)
[OK] train_dataset size: 2175, val_dataset size: 380; batch_size=128
[SANITY] Batch keys: ['x', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species']
[SANITY] x.shape: torch.Size([128, 64])
  kingdom: shape=torch.Size([128]), dtype=torch.int64
  phylum: shape=torch.Size([128]), dtype=torch.int64
  class: shape=torch.Size([128]), dtype=torch.int64
  order: shape=torch.Size([128]), dtype=torch.int64
[READY] Use train_loader, val_loader, and batch_tuple_to_dict in your training loop.


In [23]:
# Defensive training cell (no crash; detailed diagnostics & guarded backward)
import time, json, traceback
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd

EXTRACT_DIR = Path("ncbi_blast_db") / "extracted"
SAVE_CKPT = EXTRACT_DIR / "best_shared_heads_defensive.pt"
HISTORY_CSV = EXTRACT_DIR / "training_history_defensive.csv"
METRICS_JSON = EXTRACT_DIR / "metrics_defensive.json"

LR = 1e-3
MAX_EPOCHS = 60
PATIENCE = 8
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 256
DROPOUT = 0.3
BATCH_LOG = 200

# ---- sanity globals ----
required = ["train_loader", "val_loader", "batch_tuple_to_dict", "label_encoders"]
miss = [n for n in required if n not in globals()]
if miss:
    raise RuntimeError(f"Missing required notebook globals: {miss}. Run the dataset prep cell first.")

train_loader = globals()["train_loader"]
val_loader = globals()["val_loader"]
batch_tuple_to_dict = globals()["batch_tuple_to_dict"]
label_encoders = globals()["label_encoders"]
RANKS = list(label_encoders.keys())
print(f"[INFO] device={DEVICE}; ranks={RANKS}")

# ---- infer input dim ----
sample_batch = next(iter(train_loader))
sample = batch_tuple_to_dict(sample_batch)
input_dim = int(sample["x"].shape[1])
print(f"[INFO] inferred input_dim = {input_dim}")

# ---- resilient model: explicit initialize (no kw-arg ctor) ----
class ResilientManual(nn.Module):
    def _init_(self):
        super()._init_()
        self._inited = False
    def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.0):
        if getattr(self, "_inited", False):
            return
        self.ranks = list(ranks)
        self.dropout = nn.Dropout(dropout)
        h1 = int(hidden_dim)
        h2 = max(32, hidden_dim // 2)
        # trunk params
        self.register_parameter("w1", nn.Parameter(torch.randn(input_dim, h1) * 0.02))
        self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
        self.register_parameter("w2", nn.Parameter(torch.randn(h1, h2) * 0.02))
        self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
        # heads
        for r in self.ranks:
            ncls = max(1, len(encoders[r].classes_))
            self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2, ncls) * 0.02))
            self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
        self._inited = True
    def forward(self, x):
        h = x @ self.w1 + self.b1
        h = torch.relu(h)
        h = self.dropout(h)
        h = h @ self.w2 + self.b2
        h = torch.relu(h)
        out = {}
        for r in self.ranks:
            w = getattr(self, f"head_w__{r}")
            b = getattr(self, f"head_b__{r}")
            out[r] = h @ w + b
        return out

model = ResilientManual()
model.initialize(input_dim=input_dim, hidden_dim=HIDDEN_DIM, ranks=RANKS, encoders=label_encoders, dropout=DROPOUT)
model.to(DEVICE)

# ---- check params exist ----
def param_summary(m):
    tot = 0
    items = []
    for name, p in m.named_parameters():
        items.append((name, tuple(p.shape), p.requires_grad, p.numel()))
        tot += p.numel()
    return tot, items

tot_params, params_list = param_summary(model)
print(f"[INFO] total model parameters = {tot_params:,}; param tensors = {len(params_list)}")
if tot_params == 0:
    for it in params_list:
        print("PARAM", it)
    raise RuntimeError("Model has zero parameters; aborting.")

# ---- derive train label arrays (for weights) ----
def derive_train_labels():
    if "train_dataset" in globals():
        td = globals()["train_dataset"]
        tensors = td.tensors
        out = {r: tensors[1 + i].cpu().numpy() for i, r in enumerate(RANKS)}
        return out
    if "y_encoded" in globals() and "train_idx" in globals():
        ti = globals()["train_idx"]
        return {r: np.asarray(globals()["y_encoded"][r])[ti] for r in RANKS}
    # fallback sample a few train batches
    collected = {r: [] for r in RANKS}
    n=0
    for bt in train_loader:
        b = batch_tuple_to_dict(bt)
        for r in RANKS:
            collected[r].append(b[r].cpu().numpy())
        n+=1
        if n>=10: break
    return {r: np.concatenate(collected[r]) for r in RANKS}

train_labels = derive_train_labels()

# ---- build criterions robustly (weights on DEVICE, length = encoder classes) ----
criterions = {}
for r in RANKS:
    n_classes = max(1, len(label_encoders[r].classes_))
    if n_classes <= 1:
        criterions[r] = None
        print(f"[INFO] skipping loss for '{r}' (encoder has only 1 class).")
        continue
    arr = train_labels.get(r)
    if arr is None:
        counts_full = np.ones((n_classes,), dtype=float)
    else:
        counts_full = np.bincount(arr, minlength=n_classes).astype(float)
        counts_full[counts_full == 0] = 1.0
    inv = 1.0 / counts_full
    inv = inv / np.mean(inv)
    weight_tensor = torch.tensor(inv.astype(np.float32), device=DEVICE)
    criterions[r] = nn.CrossEntropyLoss(weight=weight_tensor)
    print(f"[INFO] built loss for '{r}' with n_classes={n_classes}, weight_shape={weight_tensor.shape}")

# ---- optimizer ----
params = [p for p in model.parameters() if p.requires_grad and p.numel()>0]
optimizer = torch.optim.Adam(params, lr=LR)
print(f"[INFO] optimizer params tensors={len(params)}, total_elements={sum(p.numel() for p in params):,}")

# ---- helpers ----
def evaluate(m, loader):
    m.eval()
    preds = {r: [] for r in RANKS}
    trues = {r: [] for r in RANKS}
    with torch.no_grad():
        for bt in loader:
            b = batch_tuple_to_dict(bt)
            x = b["x"].to(DEVICE)
            out = m(x)
            for r in RANKS:
                preds[r].append(np.argmax(out[r].cpu().numpy(), axis=1))
                trues[r].append(b[r].cpu().numpy())
    metrics = {}
    for r in RANKS:
        if not preds[r]:
            metrics[r] = {"accuracy": None, "f1_macro": None}
            continue
        p = np.concatenate(preds[r])
        t = np.concatenate(trues[r])
        metrics[r] = {"accuracy": float(accuracy_score(t,p)), "f1_macro": float(f1_score(t,p,average="macro", zero_division=0))}
    return metrics

def safe_train_one_epoch(m, loader, opt):
    m.train()
    total_loss = 0.0
    nbatches = 0
    for batch_i, bt in enumerate(loader, start=1):
        b = batch_tuple_to_dict(bt)
        x = b["x"].to(DEVICE)
        # targets
        targets = {r: b[r].to(DEVICE) for r in RANKS}
        out = m(x)
        # diagnostics per rank
        losses = []
        skipped_ranks = []
        for r in RANKS:
            if criterions[r] is None:
                skipped_ranks.append(r); continue
            logits = out[r]
            targ = targets[r]
            # quick checks
            if not isinstance(logits, torch.Tensor):
                skipped_ranks.append(r); continue
            if not logits.requires_grad:
                # report but skip
                skipped_ranks.append(r)
                print(f"[WARN] batch {batch_i}: out[{r}].requires_grad=False; skipping this rank for this batch")
                continue
            try:
                l = criterions[r](logits, targ)
            except Exception as e:
                # shape / device / dtype mismatch: report and skip
                print(f"[ERROR] batch {batch_i}: loss compute failed for rank {r}: {e}; shapes logits {tuple(logits.shape)}, target {tuple(targ.shape)}")
                skipped_ranks.append(r)
                continue
            if not isinstance(l, torch.Tensor) or not l.requires_grad:
                print(f"[WARN] batch {batch_i}: loss for {r} does not require grad; skipping. loss type {type(l)}, requires_grad={getattr(l,'requires_grad',None)}")
                skipped_ranks.append(r)
                continue
            losses.append(l)
        if not losses:
            # nothing to backprop this batch
            if batch_i % BATCH_LOG == 0:
                print(f"[INFO] batch {batch_i}: no rank produced a backprop-able loss; skipping optimizer step.")
            continue
        loss = sum(losses) / len(losses)  # average across ranks for numeric stability
        opt.zero_grad()
        try:
            loss.backward()
            opt.step()
        except Exception as e:
            # capture diagnostics and skip this step (do not crash)
            print("[CRITICAL] backward() failed for this batch. Dumping diagnostics and skipping optimizer step for this batch.")
            traceback.print_exc()
            print("Diagnostics:")
            print(" - loss.requires_grad:", getattr(loss, "requires_grad", None))
            for name,p in m.named_parameters():
                print(f"   param {name}: requires_grad={p.requires_grad}, shape={tuple(p.shape)}, device={p.device}")
            for r in RANKS:
                lo = out[r]
                print(f"   out[{r}] requires_grad={getattr(lo,'requires_grad',None)}, dtype={getattr(lo,'dtype',None)}, shape={tuple(lo.shape)}")
            # skip optimizer update this batch
            continue
        total_loss += float(loss.item())
        nbatches += 1
        if batch_i % BATCH_LOG == 0:
            print(f"[INFO] batch {batch_i}, avg loss so far = {total_loss/max(1,nbatches):.4f}, skipped_ranks={skipped_ranks[:5]}")
    return total_loss / max(1, nbatches)

# ---- training loop with early stopping ----
best_val = -1.0
best_ckpt = None
history = []
no_improve = 0

print("[TRAIN] starting defensive training...")
t0_all = time.time()
for epoch in range(1, MAX_EPOCHS+1):
    t_epoch = time.time()
    train_loss = safe_train_one_epoch(model, train_loader, optimizer)
    val_metrics = evaluate(model, val_loader)
    f1s = [v["f1_macro"] if v["f1_macro"] is not None else 0.0 for v in val_metrics.values()]
    val_agg = float(np.mean(f1s))
    epoch_time = time.time() - t_epoch
    history.append({"epoch": epoch, "train_loss": float(train_loss), "val_agg_f1": float(val_agg), "time_s": float(epoch_time)})
    print(f"Epoch {epoch:03d} | train_loss={train_loss:.4f} | val_agg_f1={val_agg:.4f} | epoch_time={epoch_time:.1f}s")
    # checkpoint
    if val_agg > best_val + 1e-9:
        best_val = val_agg
        best_ckpt = {"epoch": epoch, "val_agg_f1": val_agg, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict()}
        torch.save(best_ckpt, str(SAVE_CKPT))
        print(f"[CHECKPOINT] saved best -> {SAVE_CKPT} (val_agg_f1={val_agg:.4f})")
        no_improve = 0
    else:
        no_improve += 1
        print(f"[INFO] no_improve={no_improve}/{PATIENCE}")
        if no_improve >= PATIENCE:
            print("[EARLY STOP] stopping.")
            break

tot_time = time.time() - t0_all
pd.DataFrame(history).to_csv(HISTORY_CSV, index=False)
if best_ckpt is not None:
    model.load_state_dict(best_ckpt["model_state"])
    final_metrics = evaluate(model, val_loader)
    out = {"best_epoch": best_ckpt["epoch"], "best_val_agg_f1": best_ckpt["val_agg_f1"], "per_rank": final_metrics}
    with open(METRICS_JSON, "w") as fh:
        json.dump(out, fh, indent=2)
    torch.save(model.state_dict(), EXTRACT_DIR / "best_shared_heads_defensive_state_dict.pt")
    print("[SAVE] metrics and state_dict written to extracted/")
    print("[FINAL METRICS]")
    for r in RANKS:
        m = final_metrics[r]
        print(f"  {r:10s} acc={m['accuracy']} f1_macro={m['f1_macro']}")
else:
    print("[WARN] No checkpoint saved during training.")
print("[COMPLETE] Defensive training finished. total_time_s={:.1f}".format(tot_time))

[INFO] device=cpu; ranks=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species']
[INFO] inferred input_dim = 64
[INFO] total model parameters = 82,947; param tensors = 18
[INFO] built loss for 'kingdom' with n_classes=2, weight_shape=torch.Size([2])
[INFO] built loss for 'phylum' with n_classes=5, weight_shape=torch.Size([5])
[INFO] built loss for 'class' with n_classes=10, weight_shape=torch.Size([10])
[INFO] built loss for 'order' with n_classes=13, weight_shape=torch.Size([13])
[INFO] built loss for 'family' with n_classes=19, weight_shape=torch.Size([19])
[INFO] built loss for 'genus' with n_classes=27, weight_shape=torch.Size([27])
[INFO] built loss for 'species' with n_classes=183, weight_shape=torch.Size([183])
[INFO] optimizer params tensors=18, total_elements=82,947
[TRAIN] starting defensive training...
Epoch 001 | train_loss=2.6421 | val_agg_f1=0.2617 | epoch_time=0.2s
[CHECKPOINT] saved best -> ncbi_blast_db\extracted\best_shared_heads_defensive.pt (val_agg_f1

In [24]:
import torch
print("is_grad_enabled:", torch.is_grad_enabled())
for name, p in model.named_parameters():
    print(name, "shape", tuple(p.shape), "requires_grad", p.requires_grad, "device", p.device)
# run one forward and inspect:
batch = next(iter(train_loader))
b = batch_tuple_to_dict(batch)
x = b["x"]
out = model(x.to(next(model.parameters()).device))
for r in out: print(r, "out.requires_grad", out[r].requires_grad, "dtype", out[r].dtype, "shape", out[r].shape)

is_grad_enabled: True
w1 shape (64, 256) requires_grad True device cpu
b1 shape (256,) requires_grad True device cpu
w2 shape (256, 128) requires_grad True device cpu
b2 shape (128,) requires_grad True device cpu
head_w__kingdom shape (128, 2) requires_grad True device cpu
head_b__kingdom shape (2,) requires_grad True device cpu
head_w__phylum shape (128, 5) requires_grad True device cpu
head_b__phylum shape (5,) requires_grad True device cpu
head_w__class shape (128, 10) requires_grad True device cpu
head_b__class shape (10,) requires_grad True device cpu
head_w__order shape (128, 13) requires_grad True device cpu
head_b__order shape (13,) requires_grad True device cpu
head_w__family shape (128, 19) requires_grad True device cpu
head_b__family shape (19,) requires_grad True device cpu
head_w__genus shape (128, 27) requires_grad True device cpu
head_b__genus shape (27,) requires_grad True device cpu
head_w__species shape (128, 183) requires_grad True device cpu
head_b__species shape (1

In [27]:
import torch
torch.set_grad_enabled(True)
print("is_grad_enabled:", torch.is_grad_enabled())

is_grad_enabled: True


In [29]:
# Resume training — robust, forces grad on, loads checkpoint, trains safely
import time, json
from pathlib import Path
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd

# ---------- Config ----------
EXTRACT_DIR = Path("ncbi_blast_db") / "extracted"
CKPT = EXTRACT_DIR / "best_shared_heads_defensive.pt"
HISTORY_CSV = EXTRACT_DIR / "training_history_resumed.csv"
SAVE_CKPT = EXTRACT_DIR / "best_shared_heads_resumed.pt"

LR = 1e-4            # smaller LR when resuming often helps
MAX_EPOCHS = 40
PATIENCE = 6
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 256
DROPOUT = 0.3
BATCH_LOG = 200

# ---------- sanity and required globals ----------
required = ["train_loader", "val_loader", "batch_tuple_to_dict", "label_encoders"]
missing = [n for n in required if n not in globals()]
if missing:
    raise RuntimeError(f"Missing required globals: {missing} — run the dataset prep cell first.")

train_loader = globals()["train_loader"]
val_loader = globals()["val_loader"]
batch_tuple_to_dict = globals()["batch_tuple_to_dict"]
label_encoders = globals()["label_encoders"]
RANKS = list(label_encoders.keys())

# ---------- Force-enable global grad (the actual fix) ----------
torch.set_grad_enabled(True)
print("torch.is_grad_enabled()", torch.is_grad_enabled())

# ---------- Recreate model class (no kwargs in ctor; safe initialize) ----------
class ResilientModel(nn.Module):
    def _init_(self):
        super()._init_()
        self._inited = False
    def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.0):
        if getattr(self,"_inited",False): return
        self.ranks = list(ranks)
        self.dropout = nn.Dropout(dropout)
        h1 = int(hidden_dim)
        h2 = max(32, h1 // 2)
        self.register_parameter("w1", nn.Parameter(torch.randn(input_dim, h1) * 0.02))
        self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
        self.register_parameter("w2", nn.Parameter(torch.randn(h1, h2) * 0.02))
        self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
        for r in self.ranks:
            ncls = max(1, len(encoders[r].classes_))
            self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2, ncls) * 0.02))
            self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
        self._inited = True
    def forward(self, x):
        h = x @ self.w1 + self.b1
        h = torch.relu(h)
        h = self.dropout(h)
        h = h @ self.w2 + self.b2
        h = torch.relu(h)
        out = {}
        for r in self.ranks:
            w = getattr(self, f"head_w__{r}")
            b = getattr(self, f"head_b__{r}")
            out[r] = h @ w + b
        return out

# ---------- instantiate model and load checkpoint ----------
# infer input dim from a loader batch
sample_batch = next(iter(train_loader))
sample = batch_tuple_to_dict(sample_batch)
input_dim = int(sample["x"].shape[1])

model = ResilientModel()
model.initialize(input_dim=input_dim, hidden_dim=HIDDEN_DIM, ranks=RANKS, encoders=label_encoders, dropout=DROPOUT)
model.to(DEVICE)

# load checkpoint state if available
if CKPT.exists():
    ckpt = torch.load(str(CKPT), map_location=DEVICE)
    # checkpoint may be dict with 'model_state'
    if isinstance(ckpt, dict) and "model_state" in ckpt:
        model.load_state_dict(ckpt["model_state"])
        print("Loaded model_state from checkpoint (key 'model_state').")
    else:
        # assume ckpt is state_dict
        model.load_state_dict(ckpt)
        print("Loaded checkpoint as state_dict.")
else:
    print("No checkpoint found; training from scratch.")

# ---------- Verify grad-able outputs and params before training ----------
model.train()  # enable dropout + training behavior
print("Model placed in train() mode.")
print("Example param requires_grad flags:")
for name,p in model.named_parameters():
    print(" ", name, "requires_grad=", p.requires_grad, "shape=", tuple(p.shape))

# quick forward sanity check: run one batch and ensure outputs require grad
bt = next(iter(train_loader))
b = batch_tuple_to_dict(bt)
x = b["x"].to(DEVICE)
out = model(x)
for r in RANKS:
    print(f"out[{r}].requires_grad ->", getattr(out[r], "requires_grad", None), "shape", tuple(out[r].shape))

# If outputs do not require grad here, something external (very rare) is disabling grad.
if not torch.is_grad_enabled():
    raise RuntimeError("torch.is_grad_enabled() is False after set_grad_enabled(True) — cannot proceed.")

# ---------- build criterions (use encoder class counts) ----------
criterions = {}
for r in RANKS:
    n_classes = max(1, len(label_encoders[r].classes_))
    if n_classes <= 1:
        criterions[r] = None
    else:
        # compute weight robustly using training labels if available
        try:
            if "train_dataset" in globals():
                td = globals()["train_dataset"]
                arr = td.tensors[1 + RANKS.index(r)].cpu().numpy()
            elif "y_encoded" in globals() and "train_idx" in globals():
                arr = np.asarray(globals()["y_encoded"][r])[globals()["train_idx"]]
            else:
                # fallback: sample few batches
                arr = np.concatenate([batch_tuple_to_dict(bt)[r].cpu().numpy() for i, bt in zip(range(10), train_loader)])
        except Exception:
            arr = np.zeros((1,), dtype=int)
        counts = np.bincount(arr, minlength=n_classes).astype(float)
        counts[counts == 0] = 1.0
        inv = 1.0 / counts
        inv = inv / inv.mean()
        weight = torch.from_numpy(inv.astype(np.float32)).to(DEVICE)
        criterions[r] = nn.CrossEntropyLoss(weight=weight)

# ---------- optimizer (recreate) ----------
optim_params = [p for p in model.parameters() if p.requires_grad and p.numel() > 0]
optimizer = optim.Adam(optim_params, lr=LR)
print("Optimizer created with", sum(p.numel() for p in optim_params), "params.")

# optionally resume optimizer state from checkpoint
if CKPT.exists() and isinstance(ckpt, dict) and "optimizer_state" in ckpt:
    try:
        optimizer.load_state_dict(ckpt["optimizer_state"])
        print("Loaded optimizer state from checkpoint.")
    except Exception as e:
        print("Warning: failed to load optimizer state:", e)

# ---------- helpers ----------
def evaluate(m, loader):
    m.eval()
    preds = {r: [] for r in RANKS}
    trues = {r: [] for r in RANKS}
    with torch.no_grad():
        for bt in loader:
            b = batch_tuple_to_dict(bt)
            x = b["x"].to(DEVICE)
            out = m(x)
            for r in RANKS:
                preds[r].append(out[r].argmax(dim=1).cpu().numpy())
                trues[r].append(b[r].cpu().numpy())
    metrics = {}
    for r in RANKS:
        if not preds[r]:
            metrics[r] = {"accuracy": None, "f1_macro": None}
            continue
        p = np.concatenate(preds[r]); t = np.concatenate(trues[r])
        metrics[r] = {"accuracy": float(accuracy_score(t,p)), "f1_macro": float(f1_score(t,p,average="macro", zero_division=0))}
    return metrics

def train_one_epoch(m, loader, opt):
    m.train()
    total_loss = 0.0
    nb = 0
    for bt in loader:
        b = batch_tuple_to_dict(bt)
        x = b["x"].to(DEVICE)
        out = m(x)
        loss = None
        for r in RANKS:
            if criterions[r] is None: continue
            l = criterions[r](out[r], b[r].to(DEVICE))
            loss = l if loss is None else loss + l
        if loss is None:
            continue
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += float(loss.item())
        nb += 1
    return total_loss / max(1, nb)

# ---------- training loop (resume) ----------
best_val = -1.0
best_ckpt = None
history = []
no_improve = 0

print("[RESUME TRAIN] Starting...")
tstart = time.time()
for epoch in range(1, MAX_EPOCHS+1):
    t0 = time.time()
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_metrics = evaluate(model, val_loader)
    f1s = [v["f1_macro"] if v["f1_macro"] is not None else 0.0 for v in val_metrics.values()]
    val_agg = float(np.mean(f1s))
    epoch_time = time.time() - t0
    history.append({"epoch": epoch, "train_loss": float(train_loss), "val_agg_f1": float(val_agg), "time_s": float(epoch_time)})
    print(f"Epoch {epoch:03d} | train_loss={train_loss:.4f} | val_agg_f1={val_agg:.4f} | time={epoch_time:.2f}s")
    if val_agg > best_val + 1e-8:
        best_val = val_agg
        best_ckpt = {"epoch": epoch, "val_agg_f1": val_agg, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict()}
        torch.save(best_ckpt, SAVE_CKPT)
        print("Saved checkpoint ->", SAVE_CKPT)
        no_improve = 0
    else:
        no_improve += 1
        print(f"no_improve={no_improve}/{PATIENCE}")
        if no_improve >= PATIENCE:
            print("Early stopping.")
            break

# save history and final metrics
pd.DataFrame(history).to_csv(HISTORY_CSV, index=False)
if best_ckpt:
    model.load_state_dict(best_ckpt["model_state"])
    final_metrics = evaluate(model, val_loader)
    with open(EXTRACT_DIR / "metrics_resumed.json", "w") as fh:
        json.dump({"best_epoch": best_ckpt["epoch"], "best_val_agg_f1": best_ckpt["val_agg_f1"], "per_rank": final_metrics}, fh, indent=2)
    torch.save(model.state_dict(), EXTRACT_DIR / "best_shared_heads_resumed_state_dict.pt")
    print("Training finished. Best val_agg_f1:", best_ckpt["val_agg_f1"])
    print("Final per-rank metrics:")
    for r in RANKS:
        m = final_metrics[r]
        print(f"  {r:12s} acc={m['accuracy']} f1_macro={m['f1_macro']}")
else:
    print("Training finished but no checkpoint saved.")

torch.is_grad_enabled() True
Loaded model_state from checkpoint (key 'model_state').
Model placed in train() mode.
Example param requires_grad flags:
  w1 requires_grad= True shape= (64, 256)
  b1 requires_grad= True shape= (256,)
  w2 requires_grad= True shape= (256, 128)
  b2 requires_grad= True shape= (128,)
  head_w__kingdom requires_grad= True shape= (128, 2)
  head_b__kingdom requires_grad= True shape= (2,)
  head_w__phylum requires_grad= True shape= (128, 5)
  head_b__phylum requires_grad= True shape= (5,)
  head_w__class requires_grad= True shape= (128, 10)
  head_b__class requires_grad= True shape= (10,)
  head_w__order requires_grad= True shape= (128, 13)
  head_b__order requires_grad= True shape= (13,)
  head_w__family requires_grad= True shape= (128, 19)
  head_b__family requires_grad= True shape= (19,)
  head_w__genus requires_grad= True shape= (128, 27)
  head_b__genus requires_grad= True shape= (27,)
  head_w__species requires_grad= True shape= (128, 183)
  head_b__speci

In [31]:
# Calibration cell (temperature scaling per-rank)
# Paste & run this in the same notebook where train/val loaders and label_encoders exist.

import time, json, math
from pathlib import Path
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
import pandas as pd

# ---------- Config / output paths ----------
EXTRACT_DIR = Path("ncbi_blast_db") / "extracted"
CKPT_PATH = EXTRACT_DIR / "best_shared_heads_resumed.pt"   # trained checkpoint from your resumed training
OUT_TEMPS = EXTRACT_DIR / "temp_scaling_by_rank.json"
OUT_VAL_CALIB = EXTRACT_DIR / "val_predictions_calibrated.csv"

# ---------- required globals sanity ----------
required = ["val_loader", "batch_tuple_to_dict", "label_encoders"]
missing = [n for n in required if n not in globals()]
if missing:
    raise RuntimeError(f"Missing notebook globals required for calibration: {missing}. Run dataset prep and training cells first.")

val_loader = globals()["val_loader"]
batch_tuple_to_dict = globals()["batch_tuple_to_dict"]
label_encoders = globals()["label_encoders"]
RANKS = list(label_encoders.keys())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[CAL] device:", device)
torch.set_grad_enabled(True)  # ensure grads are enabled for LBFGS / param updates

# ---------- model reconstruction (must match training model) ----------
class ResilientModel(nn.Module):
    def _init_(self):
        super()._init_()
        self._inited = False
    def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.3):
        if getattr(self, "_inited", False):
            return
        self.ranks = list(ranks)
        self.dropout = nn.Dropout(dropout)
        h1 = int(hidden_dim)
        h2 = max(32, h1 // 2)
        self.register_parameter("w1", nn.Parameter(torch.randn(input_dim, h1) * 0.02))
        self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
        self.register_parameter("w2", nn.Parameter(torch.randn(h1, h2) * 0.02))
        self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
        for r in self.ranks:
            ncls = max(1, len(encoders[r].classes_))
            self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2, ncls) * 0.02))
            self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
        self._inited = True
    def forward(self, x):
        h = x @ self.w1 + self.b1
        h = torch.relu(h)
        h = self.dropout(h)
        h = h @ self.w2 + self.b2
        h = torch.relu(h)
        out = {}
        for r in self.ranks:
            w = getattr(self, f"head_w__{r}")
            b = getattr(self, f"head_b__{r}")
            out[r] = h @ w + b
        return out

# ---------- load checkpoint & model_state ----------
if not CKPT_PATH.exists():
    raise RuntimeError(f"Checkpoint not found: {CKPT_PATH}. Make sure you ran training and the file exists.")

ckpt = torch.load(str(CKPT_PATH), map_location=device)
print("[CAL] loaded checkpoint:", CKPT_PATH)

# infer input dim using a val batch
sample_batch = next(iter(val_loader))
sample = batch_tuple_to_dict(sample_batch)
input_dim = int(sample["x"].shape[1])

model = ResilientModel()
model.initialize(input_dim=input_dim, hidden_dim=256, ranks=RANKS, encoders=label_encoders, dropout=0.3)
# load state dict contained in checkpoint (handle both dict forms)
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
elif isinstance(ckpt, dict) and "state_dict" in ckpt:
    model.load_state_dict(ckpt["state_dict"])
elif isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
    # maybe a direct state_dict-like mapping
    try:
        model.load_state_dict(ckpt)
    except Exception:
        # try model_state key fallback
        raise RuntimeError("Unexpected checkpoint format; couldn't load state_dict.")
else:
    # fallback: try load as state_dict
    model.load_state_dict(ckpt)
model.to(device)
model.eval()
print("[CAL] model loaded and moved to device; eval() mode set.")

# ---------- collect logits and targets for validation set ----------
all_logits = {r: [] for r in RANKS}
all_targets = {r: [] for r in RANKS}
n_samples = 0
with torch.no_grad():
    for bt in val_loader:
        b = batch_tuple_to_dict(bt)
        x = b["x"].to(device)
        out = model(x)
        batch_n = x.shape[0]
        n_samples += batch_n
        for r in RANKS:
            logits = out[r].detach()  # (batch, ncls) tensor on device
            tgt = b[r].to(device)
            all_logits[r].append(logits)
            all_targets[r].append(tgt)
print(f"[CAL] collected logits from validation set: total samples ~ {n_samples}")

# concatenate
for r in RANKS:
    if all_logits[r]:
        all_logits[r] = torch.cat(all_logits[r], dim=0)   # (N, C)
        all_targets[r] = torch.cat(all_targets[r], dim=0).long()  # (N,)
    else:
        all_logits[r] = torch.zeros((0, max(1, len(label_encoders[r].classes_))), device=device)
        all_targets[r] = torch.zeros((0,), dtype=torch.long, device=device)

# ---------- temperature fitting helper ----------
def fit_temperature_for_rank(logits: torch.Tensor, targets: torch.Tensor, init_temp=1.0, max_iter=200):
    """
    Fit scalar temperature T >= small_pos to minimize CrossEntropyLoss( logits / T, targets).
    Uses LBFGS; falls back to small Adam loop if LBFGS fails.
    logits: torch.Tensor (N,C) on device
    targets: torch.LongTensor (N,) on device
    returns float T
    """
    if logits.shape[0] == 0:
        return 1.0
    n_classes = logits.shape[1]
    if n_classes <= 1:
        return 1.0

    # parameterize T directly but clamp in closure to avoid negative T
    T_param = nn.Parameter(torch.tensor([float(init_temp)], device=device, dtype=torch.float32))
    loss_fn = nn.CrossEntropyLoss()

    # LBFGS closure
    try:
        optimizer = optim.LBFGS([T_param], max_iter=max_iter, line_search_fn="strong_wolfe")
        def closure():
            optimizer.zero_grad()
            T = T_param.clamp(min=1e-6)
            scaled = logits / T
            loss = loss_fn(scaled, targets)
            loss.backward()
            return loss
        optimizer.step(closure)
        T_final = float(T_param.detach().clamp(min=1e-6).item())
        # sanity: if nan or weird, fallback
        if not math.isfinite(T_final) or T_final <= 0:
            raise RuntimeError("LBFGS produced invalid T")
        return T_final
    except Exception as e:
        # fallback to small Adam loop (robust)
        # print debug info
        print("[CAL] LBFGS failed for rank (falling back to Adam):", e)
        T_param = nn.Parameter(torch.tensor([float(init_temp)], device=device, dtype=torch.float32))
        optimizer2 = optim.Adam([T_param], lr=1e-2)
        for it in range(300):
            optimizer2.zero_grad()
            T = T_param.clamp(min=1e-6)
            loss = loss_fn(logits / T, targets)
            loss.backward()
            optimizer2.step()
        T_final = float(T_param.detach().clamp(min=1e-6).item())
        if not math.isfinite(T_final) or T_final <= 0:
            return 1.0
        return T_final

# ---------- fit temperatures for each rank ----------
temps = {}
start_time = time.time()
for r in RANKS:
    ncls = len(label_encoders[r].classes_)
    N = all_logits[r].shape[0]
    if N == 0 or ncls <= 1:
        temps[r] = 1.0
        print(f"[CAL] rank={r}: skipped (N={N}, ncls={ncls}) -> T=1.0")
        continue
    print(f"[CAL] fitting T for rank={r} (N={N}, ncls={ncls}) ...", end="", flush=True)
    t0 = time.time()
    try:
        T_r = fit_temperature_for_rank(all_logits[r], all_targets[r], init_temp=1.0, max_iter=200)
    except Exception as e:
        print(" failed:", e)
        T_r = 1.0
    temps[r] = float(T_r)
    print(f" done (T={T_r:.4f}, took {time.time()-t0:.2f}s)")

print(f"[CAL] temperature fitting completed in {time.time() - start_time:.2f}s")
# Save temperatures
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
with open(OUT_TEMPS, "w") as fh:
    json.dump(temps, fh, indent=2)
print("[CAL] saved temperatures to:", OUT_TEMPS)

# ---------- report NLL / acc before vs after calibration ----------
print("\n[CAL] per-rank NLL/accuracy before -> after (smaller NLL is better)")
for r in RANKS:
    logits = all_logits[r]
    targets = all_targets[r]
    if logits.shape[0] == 0 or len(label_encoders[r].classes_) <= 1:
        print(f"  {r:10s}: no samples or single-class; skipped")
        continue
    with torch.no_grad():
        pre_nll = float(F.cross_entropy(logits, targets).item())
        pre_acc = float((logits.argmax(dim=1) == targets).float().mean().item())
        T = temps[r]
        post_nll = float(F.cross_entropy(logits / max(1e-6, T), targets).item())
        post_acc = float(((logits / max(1e-6, T)).argmax(dim=1) == targets).float().mean().item())
    print(f"  {r:10s}: pre_nll={pre_nll:.4f}, post_nll={post_nll:.4f}  | pre_acc={pre_acc:.4f}, post_acc={post_acc:.4f}  T={T:.4f}")

# ---------- produce calibrated validation CSV (one row per sample) ----------
# We'll include index + for each rank -> true label, pred_idx, pred_label, pred_prob, T
rows = []
N = all_logits[RANKS[0]].shape[0] if len(RANKS) > 0 else 0
for i in range(N):
    row = {"idx": int(i)}
    for r in RANKS:
        logits_np = all_logits[r][i].cpu().numpy()
        tgt_idx = int(all_targets[r][i].cpu().item()) if all_targets[r].shape[0] > 0 else None
        T = temps.get(r, 1.0)
        # stable softmax
        scaled = logits_np / max(1e-12, T)
        scaled = scaled - np.max(scaled)
        ex = np.exp(scaled)
        probs = ex / np.sum(ex)
        pred_idx = int(np.argmax(probs))
        pred_prob = float(np.max(probs))
        # get label names if encoder has them
        try:
            classes = label_encoders[r].classes_
            pred_label = str(classes[pred_idx])
            true_label = str(classes[tgt_idx]) if tgt_idx is not None and tgt_idx >= 0 and tgt_idx < len(classes) else None
        except Exception:
            pred_label = str(pred_idx)
            true_label = str(tgt_idx) if tgt_idx is not None else None
        row[f"{r}_true_idx"] = true_label
        row[f"{r}_pred_idx"] = pred_idx
        row[f"{r}_pred_label"] = pred_label
        row[f"{r}_pred_prob"] = pred_prob
        row[f"{r}_T"] = float(T)
    rows.append(row)

df = pd.DataFrame(rows)
df.to_csv(OUT_VAL_CALIB, index=False)
print("[CAL] wrote calibrated validation predictions to:", OUT_VAL_CALIB)

print("[CAL] Done. Per-rank temperatures:")
for r, t in temps.items():
    print(f"  {r:12s} -> T={t:.4f}")

[CAL] device: cpu
[CAL] loaded checkpoint: ncbi_blast_db\extracted\best_shared_heads_resumed.pt
[CAL] model loaded and moved to device; eval() mode set.
[CAL] collected logits from validation set: total samples ~ 380
[CAL] fitting T for rank=kingdom (N=380, ncls=2) ... done (T=1.2370, took 0.01s)
[CAL] fitting T for rank=phylum (N=380, ncls=5) ... done (T=1.2833, took 0.00s)
[CAL] fitting T for rank=class (N=380, ncls=10) ... done (T=1.2860, took 0.01s)
[CAL] fitting T for rank=order (N=380, ncls=13) ... done (T=1.4398, took 0.00s)
[CAL] fitting T for rank=family (N=380, ncls=19) ... done (T=1.4526, took 0.01s)
[CAL] fitting T for rank=genus (N=380, ncls=27) ... done (T=1.5589, took 0.01s)
[CAL] fitting T for rank=species (N=380, ncls=183) ... done (T=2.0711, took 0.01s)
[CAL] temperature fitting completed in 0.05s
[CAL] saved temperatures to: ncbi_blast_db\extracted\temp_scaling_by_rank.json

[CAL] per-rank NLL/accuracy before -> after (smaller NLL is better)
  kingdom   : pre_nll=0.2

In [33]:
# Cell: Inference with MC-dropout + uncertainty & novel-candidate ranking
# Paste & run in the same notebook that contains your trained checkpoint, val_loader, batch_tuple_to_dict, and label_encoders.

import time, json, math
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn as nn

EXTRACT_DIR = Path("ncbi_blast_db") / "extracted"
CKPT_PATH = EXTRACT_DIR / "best_shared_heads_resumed.pt"   # trained checkpoint (resumed)
TEMP_PATH = EXTRACT_DIR / "temp_scaling_by_rank.json"
OUT_PRED_CSV = EXTRACT_DIR / "predictions_with_uncertainty.csv"
OUT_NOVEL_CSV = EXTRACT_DIR / "novel_candidates_priority.csv"

# ---------- SETTINGS ----------
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Choose loader: prefer a global 'inference_loader' if present, otherwise use val_loader
if "inference_loader" in globals():
    inference_loader = globals()["inference_loader"]
elif "val_loader" in globals():
    inference_loader = globals()["val_loader"]
else:
    raise RuntimeError("No inference_loader or val_loader found in globals. Run dataset prep first.")

batch_tuple_to_dict = globals().get("batch_tuple_to_dict")
label_encoders = globals().get("label_encoders")
if batch_tuple_to_dict is None or label_encoders is None:
    raise RuntimeError("Required globals missing: batch_tuple_to_dict and/or label_encoders")

RANKS = list(label_encoders.keys())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INF] device:", device, "using loader:", "inference_loader" if 'inference_loader' in globals() else "val_loader")

# sensible MC passes default: smaller on CPU
MC_PASSES = 32 if device.type == "cuda" else 12
print(f"[INF] MC_PASSES = {MC_PASSES} (adjust this variable if you need more/less samples)")

# ---------- Resilient model definition (must match training) ----------
class ResilientModel(nn.Module):
    def _init_(self):
        super()._init_()
        self._inited = False
    def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.3):
        if getattr(self, "_inited", False): return
        self.ranks = list(ranks)
        self.dropout = nn.Dropout(dropout)
        h1 = int(hidden_dim)
        h2 = max(32, h1 // 2)
        self.register_parameter("w1", nn.Parameter(torch.randn(input_dim, h1) * 0.02))
        self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
        self.register_parameter("w2", nn.Parameter(torch.randn(h1, h2) * 0.02))
        self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
        for r in self.ranks:
            ncls = max(1, len(encoders[r].classes_))
            self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2, ncls) * 0.02))
            self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
        self._inited = True
    def forward(self, x):
        h = x @ self.w1 + self.b1
        h = torch.relu(h)
        h = self.dropout(h)
        h = h @ self.w2 + self.b2
        h = torch.relu(h)
        out = {}
        for r in self.ranks:
            w = getattr(self, f"head_w__{r}")
            b = getattr(self, f"head_b__{r}")
            out[r] = h @ w + b
        return out

# ---------- load checkpoint ----------
if not Path(CKPT_PATH).exists():
    raise RuntimeError(f"Checkpoint not found: {CKPT_PATH}")
ckpt = torch.load(str(CKPT_PATH), map_location=device)
print("[INF] loaded checkpoint:", CKPT_PATH)

# infer input dim from loader
sample_batch = next(iter(inference_loader))
sample = batch_tuple_to_dict(sample_batch)
input_dim = int(sample["x"].shape[1])

model = ResilientModel()
model.initialize(input_dim=input_dim, hidden_dim=256, ranks=RANKS, encoders=label_encoders, dropout=0.3)
# load weights (handle both dict-with-model_state and direct state_dict)
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
elif isinstance(ckpt, dict) and "state_dict" in ckpt:
    model.load_state_dict(ckpt["state_dict"])
else:
    try:
        model.load_state_dict(ckpt)
    except Exception:
        # try the 'model_state' fallback
        if isinstance(ckpt, dict) and any(k.endswith("model_state") for k in ckpt.keys()):
            ms = ckpt.get("model_state") or ckpt.get("state_dict")
            model.load_state_dict(ms)
        else:
            raise
model.to(device)
print("[INF] model loaded to device; parameters:", sum(p.numel() for p in model.parameters()))

# load temperatures if present
temps = {}
if Path(TEMP_PATH).exists():
    with open(TEMP_PATH, "r") as fh:
        temps = json.load(fh)
    print("[INF] loaded temperature scalars from:", TEMP_PATH)
else:
    temps = {r: 1.0 for r in RANKS}
    print("[INF] no temp file found; defaulting to T=1 for all ranks")

# ---------- helper: original-index mapping for loader samples ----------
def _loader_original_indices(loader):
    # Try Subset.indices, global val_idx, or fallback sequential indices
    ds = loader.dataset
    try:
        from torch.utils.data import Subset
        if isinstance(ds, Subset):
            return list(ds.indices)
    except Exception:
        pass
    if hasattr(ds, "indices"):
        try:
            return list(getattr(ds, "indices"))
        except Exception:
            pass
    # If val_idx global exists and loader is val_loader, use it
    if "val_idx" in globals() and loader is globals().get("val_loader"):
        try:
            return list(globals()["val_idx"])
        except Exception:
            pass
    # fallback: sequential 0..N-1
    try:
        N = len(ds)
        return list(range(N))
    except Exception:
        return None

orig_indices = _loader_original_indices(inference_loader)
if orig_indices is None:
    print("[INF] could not determine original indices mapping; rows will use sequential batch-index.")
else:
    print(f"[INF] found original indices length = {len(orig_indices)}")

# ---------- main MC-inference loop ----------
model.train()                # enable dropout during MC
torch.set_grad_enabled(False)  # no gradient computation needed for MC predictions

rows = []  # collected rows (dictionaries)
sample_global_counter = 0
total_samples_processed = 0
t_start = time.time()

n_batches = len(inference_loader)
batch_i = 0
for bt in inference_loader:
    batch_i += 1
    b = batch_tuple_to_dict(bt)
    x = b["x"].to(device)
    batch_size = x.shape[0]

    # prepare per-rank storage: shape (MC_PASSES, batch_size, n_classes)
    per_rank_probs = {}
    for r in RANKS:
        ncls = max(1, len(label_encoders[r].classes_))
        per_rank_probs[r] = np.zeros((MC_PASSES, batch_size, ncls), dtype=np.float32)

    # MC forward passes (dropout active since model.train())
    for m in range(MC_PASSES):
        out = model(x)  # dict of logits tensors (batch, ncls)
        for r in RANKS:
            logits = out[r]  # torch tensor on device
            T = float(temps.get(r, 1.0))
            # apply temperature and softmax (torch -> numpy)
            scaled = logits / max(1e-12, T)
            probs = F.softmax(scaled, dim=1).cpu().numpy()
            per_rank_probs[r][m] = probs

    # per-sample aggregation
    for i in range(batch_size):
        entry = {}
        # attempt to determine original dataset index
        if orig_indices is not None:
            try:
                global_idx = orig_indices[sample_global_counter]
            except Exception:
                # if orig_indices is full-list we can map via counter modulo
                global_idx = orig_indices[sample_global_counter] if sample_global_counter < len(orig_indices) else int(sample_global_counter)
        else:
            global_idx = int(sample_global_counter)
        entry["global_index"] = int(global_idx)

        # optional: include metadata columns if available in a global df_meta
        if "df_meta" in globals():
            try:
                meta_row = globals()["df_meta"].iloc[global_idx]
                # choose common columns if exist
                for c in ["id", "accession", "accession_base", "description"]:
                    if c in meta_row.index:
                        entry[c] = meta_row[c]
            except Exception:
                pass

        # compute stats per rank
        for r in RANKS:
            probs_all = per_rank_probs[r][:, i, :]  # shape (MC, ncls)
            # mean predictive prob
            mean_prob = probs_all.mean(axis=0)
            # predictive entropy (total)
            eps = 1e-12
            entropy = -float(np.sum(mean_prob * np.log(np.clip(mean_prob, eps, 1.0))))
            # expected entropy (aleatoric)
            per_pass_ent = -np.sum(probs_all * np.log(np.clip(probs_all, eps, 1.0)), axis=1)  # shape (MC,)
            exp_entropy = float(np.mean(per_pass_ent))
            mutual_info = float(entropy - exp_entropy)  # epistemic uncertainty
            pred_idx = int(np.argmax(mean_prob)) if mean_prob.size > 0 else -1
            pred_conf = float(mean_prob[pred_idx]) if pred_idx >= 0 else 0.0
            max_meanprob = float(np.max(mean_prob)) if mean_prob.size > 0 else 0.0
            # label names if available
            try:
                classes = label_encoders[r].classes_
                pred_label = str(classes[pred_idx]) if (pred_idx >= 0 and pred_idx < len(classes)) else str(pred_idx)
            except Exception:
                pred_label = str(pred_idx)

            entry[f"{r}_pred_idx"] = pred_idx
            entry[f"{r}_pred_label"] = pred_label
            entry[f"{r}_pred_conf"] = pred_conf
            entry[f"{r}_entropy"] = entropy
            entry[f"{r}_exp_entropy"] = exp_entropy
            entry[f"{r}_mutual_info"] = mutual_info
            entry[f"{r}_mc_mean_topprob"] = max_meanprob

        rows.append(entry)
        sample_global_counter += 1
        total_samples_processed += 1

    # progress
    if batch_i % 10 == 0 or batch_i == n_batches:
        print(f"[INF] processed batch {batch_i}/{n_batches}  (total samples so far: {total_samples_processed})")

t_total = time.time() - t_start
print(f"[INF] MC inference completed: {total_samples_processed} samples, time {t_total:.1f}s")

# ---------- build DataFrame and compute novelty score ----------
df = pd.DataFrame(rows)

# Heuristic novel score:
# combine species and genus mutual_info and low confidence into a single score (adaptable)
score_parts = []
if "species_mutual_info" in df.columns and "species_pred_conf" in df.columns:
    df["species_novel_component"] = df["species_mutual_info"] * 1.0 + (1.0 - df["species_pred_conf"]) * 0.5
    score_parts.append("species_novel_component")
if "genus_mutual_info" in df.columns and "genus_pred_conf" in df.columns:
    df["genus_novel_component"] = df["genus_mutual_info"] * 0.8 + (1.0 - df["genus_pred_conf"]) * 0.3
    score_parts.append("genus_novel_component")
# fallback: mean of mutual_info across ranks (if species/genus absent)
if not score_parts:
    mi_cols = [c for c in df.columns if c.endswith("_mutual_info")]
    if mi_cols:
        df["novel_score"] = df[mi_cols].mean(axis=1)
    else:
        df["novel_score"] = 0.0
else:
    df["novel_score"] = df[score_parts].mean(axis=1)

# sort descending novel_score
df_sorted = df.sort_values("novel_score", ascending=False).reset_index(drop=True)

# save outputs
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_PRED_CSV, index=False)
df_sorted.head(500).to_csv(OUT_NOVEL_CSV, index=False)

print("[INF] wrote predictions with uncertainties to:", OUT_PRED_CSV)
print("[INF] wrote novel candidate ranking (top 500) to:", OUT_NOVEL_CSV)

# quick summary printout of top 10 novel candidates
print("\nTop 10 novel candidates (novel_score, global_index, species_pred_label, species_pred_conf, species_mutual_info):")
cols_to_show = ["novel_score", "global_index", "species_pred_label", "species_pred_conf", "species_mutual_info"]
for i, row in df_sorted.head(10).iterrows():
    vals = [row.get(c, None) for c in cols_to_show]
    print(f"{i+1:02d})", *[f"{v}" for v in vals])

print("\n[INF] Done. You can open the two CSVs in 'ncbi_blast_db/extracted/'.")

[INF] device: cpu using loader: inference_loader
[INF] MC_PASSES = 12 (adjust this variable if you need more/less samples)
[INF] loaded checkpoint: ncbi_blast_db\extracted\best_shared_heads_resumed.pt
[INF] model loaded to device; parameters: 82947
[INF] loaded temperature scalars from: ncbi_blast_db\extracted\temp_scaling_by_rank.json
[INF] found original indices length = 380
[INF] processed batch 3/3  (total samples so far: 380)
[INF] MC inference completed: 380 samples, time 0.2s
[INF] wrote predictions with uncertainties to: ncbi_blast_db\extracted\predictions_with_uncertainty.csv
[INF] wrote novel candidate ranking (top 500) to: ncbi_blast_db\extracted\novel_candidates_priority.csv

Top 10 novel candidates (novel_score, global_index, species_pred_label, species_pred_conf, species_mutual_info):
01) 0.4612707830965519 2127 Trichophyton japonicum 0.149295374751091 0.2001938819885254
02) 0.44211927577853205 2402 Scedosporium aurantiacum 0.23798935115337372 0.11600136756896973
03) 0.41

In [35]:
# Paste into your notebook cell and run (assumes the notebook already has
# val_loader/inference_loader/batch_tuple_to_dict/label_encoders available)

import torch, torch.nn.functional as F
import numpy as np
from pathlib import Path
from sklearn.metrics import accuracy_score, f1_score

EXTRACT_DIR = Path("ncbi_blast_db") / "extracted"
CKPT = EXTRACT_DIR / "best_shared_heads_resumed.pt"
TFILE = EXTRACT_DIR / "temp_scaling_by_rank.json"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device ->", device)

# ---------- model class (must match training) ----------
import torch.nn as nn
class ResilientModel(nn.Module):
    def _init_(self):
        super()._init_()
        self._inited = False
    def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.3):
        if getattr(self, "_inited", False): return
        self.ranks = list(ranks)
        self.dropout = nn.Dropout(dropout)
        h1 = int(hidden_dim); h2 = max(32, h1//2)
        self.register_parameter("w1", nn.Parameter(torch.randn(input_dim, h1) * 0.02))
        self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
        self.register_parameter("w2", nn.Parameter(torch.randn(h1, h2) * 0.02))
        self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
        for r in self.ranks:
            ncls = max(1, len(encoders[r].classes_))
            self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2, ncls) * 0.02))
            self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
        self._inited = True
    def forward(self, x):
        h = x @ self.w1 + self.b1
        h = torch.relu(h)
        h = self.dropout(h)
        h = h @ self.w2 + self.b2
        h = torch.relu(h)
        out = {}
        for r in self.ranks:
            w = getattr(self, f"head_w__{r}")
            b = getattr(self, f"head_b__{r}")
            out[r] = h @ w + b
        return out

# ---------- load checkpoint and temps ----------
assert CKPT.exists(), f"Checkpoint not found: {CKPT}"
ckpt = torch.load(str(CKPT), map_location=device)

# required notebook globals
label_encoders = globals().get("label_encoders")
batch_tuple_to_dict = globals().get("batch_tuple_to_dict")
if label_encoders is None or batch_tuple_to_dict is None:
    raise RuntimeError("label_encoders or batch_tuple_to_dict not found in globals - run dataset prep cells first.")
RANKS = list(label_encoders.keys())

# infer input dim from a loader (prefer inference_loader else val_loader)
loader_for_shape = globals().get("inference_loader") or globals().get("val_loader")
sample = batch_tuple_to_dict(next(iter(loader_for_shape)))
input_dim = int(sample["x"].shape[1])

model = ResilientModel()
model.initialize(input_dim=input_dim, hidden_dim=256, ranks=RANKS, encoders=label_encoders, dropout=0.3)
# load state
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)
model.to(device)
model.eval()

# load temps if present
import json
temps = {r:1.0 for r in RANKS}
if Path(TFILE).exists():
    with open(TFILE,"r") as fh: temps = json.load(fh)
    temps = {k:float(v) for k,v in temps.items()}

# ---------- utility: logits -> calibrated probs ----------
def probs_from_logits_with_temp(logits_tensor, rank):
    # logits_tensor: torch tensor (B, C)
    T = float(temps.get(rank, 1.0))
    scaled = logits_tensor / max(1e-12, T)
    probs = F.softmax(scaled, dim=1)
    return probs  # torch tensor on same device

# ---------- single-sample prediction example ----------
def predict_single(x_numpy):
    """x_numpy: 1D numpy vector length=input_dim"""
    model.eval()
    with torch.no_grad():
        x = torch.from_numpy(x_numpy.astype(np.float32)).unsqueeze(0).to(device)
        out = model(x)
        result = {}
        for r in RANKS:
            logits = out[r]  # (1, C)
            probs = probs_from_logits_with_temp(logits, r).cpu().numpy()[0]
            pred_idx = int(probs.argmax())
            pred_label = label_encoders[r].classes_[pred_idx]
            result[r] = {"pred_idx": pred_idx, "pred_label": pred_label, "pred_prob": float(probs[pred_idx]), "probs": probs}
        return result

# ---------- batched evaluation: accuracy & macro-F1 per rank ----------
def evaluate_loader(loader):
    model.eval()
    all_preds = {r: [] for r in RANKS}
    all_trues = {r: [] for r in RANKS}
    with torch.no_grad():
        for bt in loader:
            b = batch_tuple_to_dict(bt)
            x = b["x"].to(device)
            out = model(x)
            for r in RANKS:
                probs = probs_from_logits_with_temp(out[r], r).cpu().numpy()
                preds = probs.argmax(axis=1)
                trues = b[r].cpu().numpy()
                all_preds[r].append(preds)
                all_trues[r].append(trues)
    metrics = {}
    for r in RANKS:
        if not all_preds[r]:
            metrics[r] = {"acc": None, "f1_macro": None}
            continue
        preds = np.concatenate(all_preds[r])
        trues = np.concatenate(all_trues[r])
        acc = accuracy_score(trues, preds)
        f1m = f1_score(trues, preds, average="macro", zero_division=0)
        metrics[r] = {"acc": float(acc), "f1_macro": float(f1m)}
    return metrics

# ---------- MC-dropout for one batch (epistemic/mutual info) ----------
def mc_dropout_on_batch(batch, mc_passes=32):
    # input: one batch from loader (tuple as batch_tuple_to_dict expects)
    model.train()  # enable dropout
    b = batch_tuple_to_dict(batch)
    x = b["x"].to(device)
    batch_size = x.shape[0]
    res = []
    for r in RANKS:
        ncls = len(label_encoders[r].classes_)
        probs_all = np.zeros((mc_passes, batch_size, ncls), dtype=np.float32)
        for m in range(mc_passes):
            out = model(x)
            logits = out[r].detach() / max(1e-12, temps.get(r, 1.0))
            probs = F.softmax(logits, dim=1).cpu().numpy()
            probs_all[m] = probs
        # compute mean, entropy, expected entropy, mutual info
        mean_probs = probs_all.mean(axis=0)              # (B, C)
        ent = -np.sum(mean_probs * np.log(np.clip(mean_probs, 1e-12,1.0)), axis=1)
        per_pass_ent = -np.sum(probs_all * np.log(np.clip(probs_all,1e-12,1.0)), axis=2)  # (MC, B)
        exp_ent = per_pass_ent.mean(axis=0)              # (B,)
        mutual_info = ent - exp_ent
        res.append((r, mean_probs, ent, exp_ent, mutual_info))
    model.eval()
    return res

# ---------- USAGE examples ----------
# 1) single-sample (take first sample from val_loader)
loader = globals().get("inference_loader") or globals().get("val_loader")
bt = next(iter(loader))
b = batch_tuple_to_dict(bt)
x0 = b["x"][0].cpu().numpy()
single = predict_single(x0)
print("Single sample predictions (showing top ranks):")
for r in RANKS:
    print(f"  {r}: {single[r]['pred_label']} (p={single[r]['pred_prob']:.3f})")

# 2) full evaluation on val_loader
metrics = evaluate_loader(loader)
print("\nEvaluation metrics per rank:")
for r,m in metrics.items():
    print(f"  {r:10s} acc={m['acc']:.4f} f1_macro={m['f1_macro']:.4f}")

# 3) MC-dropout example on that same batch (fast: mc_passes=12)
mc_results = mc_dropout_on_batch(bt, mc_passes=12)
print("\nMC-dropout sample (showing species mutual info for sample 0):")
for r, mean_probs, ent, exp_ent, mi in mc_results:
    print(f"  {r}: mutual_info(sample0) = {float(mi[0]):.5f}, top_pred = {label_encoders[r].classes_[int(mean_probs[0].argmax())]}")

device -> cpu
Single sample predictions (showing top ranks):
  kingdom: UNASSIGNED (p=0.971)
  phylum: UNASSIGNED (p=0.959)
  class: UNASSIGNED (p=0.951)
  order: UNASSIGNED (p=0.921)
  family: UNASSIGNED (p=0.914)
  genus: UNASSIGNED (p=0.906)
  species: UNASSIGNED (p=0.376)

Evaluation metrics per rank:
  kingdom    acc=0.8868 f1_macro=0.8859
  phylum     acc=0.8789 f1_macro=0.9214
  class      acc=0.8684 f1_macro=0.9273
  order      acc=0.8684 f1_macro=0.9256
  family     acc=0.8684 f1_macro=0.9044
  genus      acc=0.8658 f1_macro=0.8220
  species    acc=0.7737 f1_macro=0.3942

MC-dropout sample (showing species mutual info for sample 0):
  kingdom: mutual_info(sample0) = 0.00904, top_pred = UNASSIGNED
  phylum: mutual_info(sample0) = 0.00611, top_pred = UNASSIGNED
  class: mutual_info(sample0) = 0.01922, top_pred = UNASSIGNED
  order: mutual_info(sample0) = 0.01286, top_pred = UNASSIGNED
  family: mutual_info(sample0) = 0.01243, top_pred = UNASSIGNED
  genus: mutual_info(sample0) =

In [37]:
import numpy as np, pandas as pd
from sklearn.metrics import confusion_matrix, top_k_accuracy_score

# load calibrated val predictions (we created this) and uncertainty file
val_df = pd.read_csv("ncbi_blast_db/extracted/val_predictions_calibrated.csv")
unc_df = pd.read_csv("ncbi_blast_db/extracted/predictions_with_uncertainty.csv")

# support per species (on validation)
if 'species_true_idx' in val_df.columns:
    support = val_df['species_true_idx'].value_counts().sort_values(ascending=False)
    print("Top 20 species by support (val):")
    print(support.head(20))

# compute top-5 accuracy for species using model outputs saved previously (if you have logits, use them).
# If you only have predicted probs in files (mean probs in MC cell) you can compute top-k from those:
# We'll check in-case the MC inference CSV saved mean probs per class (if not, compute from model).
# Example assuming val_df contains species_pred_prob only (top-1); we'll compute top-5 using model on val_loader:

# --- compute top-k using model (recommended) ---
from sklearn.metrics import accuracy_score
loader = globals().get("val_loader")
model = globals().get("model")  # if model in scope
label_encoders = globals()["label_encoders"]
RANKS = list(label_encoders.keys())

def topk_species_accuracy(loader, model, k=5):
    model.eval()
    ys, preds_topk = [], []
    for bt in loader:
        batch = batch_tuple_to_dict(bt)
        x = batch['x'].to(next(model.parameters()).device)
        out = model(x)
        logits = out['species'].detach().cpu().numpy()  # (B, C)
        # apply temp if you use temps:
        temps = {}
        import json
        try:
            temps = json.load(open("ncbi_blast_db/extracted/temp_scaling_by_rank.json"))
        except:
            temps = {r:1.0 for r in RANKS}
        logits = logits / temps.get('species', 1.0)
        topk = np.argsort(-logits, axis=1)[:, :k]
        preds_topk.append(topk)
        ys.append(batch['species'].cpu().numpy())
    ys = np.concatenate(ys)
    preds_topk = np.vstack(preds_topk)
    # top-k accuracy: check if true label in predicted topk row
    hit = np.array([ys[i] in preds_topk[i] for i in range(len(ys))])
    return float(hit.mean())

print("Top-5 species accuracy (val):", topk_species_accuracy(loader, model, k=5))

Top 20 species by support (val):
species_true_idx
UNASSIGNED                         201
Maylandia zebra                     80
Chaetodon auriga                    22
Arvicanthis niloticus                6
Morchella sp.                        5
Aonchotheca annulosa                 4
Pseudopestalotiopsis sp.             3
Deuterostichococcus epilithicus      3
Aspergillus costaricensis            3
Chloroidium saccharophilum           3
Morchella nipponensis                2
Cardimyxobolus iriomotensis          2
Amanita fuscozonata                  2
Inocybe sp.                          2
Inocybe miranda                      2
Amanita sp.                          2
Cortinarius sp.                      2
Baruscapillaria inflexa              2
Diplosphaera chodatii                1
Pseudoboletus parasiticus            1
Name: count, dtype: int64
Top-5 species accuracy (val): 0.9052631578947369


In [39]:
from collections import Counter
import numpy as np
from sklearn.metrics import confusion_matrix

# build true/pred arrays for species on val
trues, preds = [], []
for bt in loader:
    b = batch_tuple_to_dict(bt)
    x = b['x'].to(next(model.parameters()).device)
    out = model(x)
    logits = out['species'].detach().cpu().numpy()
    # temperature scaling
    import json
    temps = json.load(open("ncbi_blast_db/extracted/temp_scaling_by_rank.json"))
    logits = logits / temps.get('species', 1.0)
    p = np.argmax(logits, axis=1)
    preds.append(p); trues.append(b['species'].cpu().numpy())
trues = np.concatenate(trues); preds = np.concatenate(preds)

cm = confusion_matrix(trues, preds)
# find most common off-diagonal confusions
cm_off = cm.copy()
np.fill_diagonal(cm_off, 0)
# get top confused pairs
pairs = []
for i,j in zip(*np.unravel_index(np.argsort(cm_off.ravel())[::-1], cm_off.shape)):
    if cm_off[i,j] <= 0: break
    pairs.append((i,j,cm_off[i,j]))
# map to labels
species_labels = label_encoders['species'].classes_
for a,b,count in pairs[:20]:
    print(f"{species_labels[a]} -> {species_labels[b]} : {count} times")

Inocybe favoris -> Aspergillus costaricensis : 24 times
Inocybe favoris -> Armillaria sp. : 6 times
Inocybe favoris -> Diplosphaera chodatii : 5 times
Inocybe favoris -> Clavulina amethystina : 5 times
Amanita fuscozonata -> Cladosporium pseudocladosporioides : 4 times
Inocybe favoris -> Cladosporium allicinum : 3 times
Inocybe favoris -> Ancylostoma ceylanicum : 3 times
Inocybe favoris -> Amphibiocapillaria tritonispunctati : 2 times
Inocybe favoris -> Inocybe beatifica : 2 times
Antarctomyces sp. -> Amanita fulva : 2 times
Inocybe favoris -> Chloroidium saccharophilum : 2 times
Inocybe favoris -> Agaricus sp. : 2 times
Clitocybula sp. -> Chloroidium saccharophilum : 1 times
Diplosphaera chodatii -> Cortinarius vagabundus : 1 times
Inocybe favoris -> Inocybe derbschii : 1 times
Gymnopilus luteus -> Agaricus argyropotamicus : 1 times
Gliophorus sp. -> Geotrichum citri-aurantii : 1 times
Inocybe favoris -> Hysterothylacium fabri : 1 times
Inocybe favoris -> Geotrichum citri-aurantii : 1

In [41]:
import numpy as np
def expected_calibration_error(probs, labels, n_bins=15):
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    bins = np.linspace(0,1,n_bins+1)
    ece = 0.0
    for i in range(n_bins):
        lower, upper = bins[i], bins[i+1]
        mask = (confidences > lower) & (confidences <= upper)
        if mask.sum()==0: continue
        acc = (predictions[mask] == labels[mask]).mean()
        conf = confidences[mask].mean()
        ece += (mask.sum() / len(labels)) * abs(acc - conf)
    return ece

# example: compute species ECE by running model on val and collecting probs
probs_list = []; labs = []
for bt in loader:
    b = batch_tuple_to_dict(bt); x=b['x'].to(next(model.parameters()).device)
    out = model(x)
    logits = out['species'].detach().cpu()
    logits = logits / temps.get('species',1.0)
    probs = torch.softmax(logits, dim=1).numpy()
    probs_list.append(probs); labs.append(b['species'].cpu().numpy())
probs_arr = np.vstack(probs_list); labs = np.concatenate(labs)
print("Species ECE:", expected_calibration_error(probs_arr, labs, n_bins=15))

Species ECE: 0.16603384813980052


In [50]:
from pathlib import Path
import torch, torch.nn.functional as F
import numpy as np

CKPT = Path("ncbi_blast_db/extracted/best_shared_heads_resumed.pt")
TFILE = Path("ncbi_blast_db/extracted/temp_scaling_by_rank.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# recreate the same model class (use the ResilientModel class you used earlier),
# initialize exactly as before, then:
ckpt = torch.load(str(CKPT), map_location=device)
model = ResilientModel()              # use the same class definition you already used
model.initialize(input_dim=64, hidden_dim=256, ranks=list(label_encoders.keys()), encoders=label_encoders, dropout=0.3)
model.load_state_dict(ckpt.get("model_state", ckpt))
model.to(device); model.eval()

# load temps
import json
temps = json.load(open(TFILE)) if TFILE.exists() else {r:1.0 for r in label_encoders.keys()}

# pick one sample from a loader:
loader = globals().get("inference_loader") or globals().get("val_loader")
batch = next(iter(loader))
b = batch_tuple_to_dict(batch)
x0 = b["x"][0:1].to(device)   # single sample

with torch.no_grad():
    out = model(x0)
    for r in label_encoders.keys():
        logits = out[r] / max(1e-12, temps.get(r, 1.0))
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        pred_idx = int(probs.argmax())
        print(f"{r}: {label_encoders[r].classes_[pred_idx]}  p={probs[pred_idx]:.3f}")

kingdom: UNASSIGNED  p=0.971
phylum: UNASSIGNED  p=0.959
class: UNASSIGNED  p=0.951
order: UNASSIGNED  p=0.921
family: UNASSIGNED  p=0.914
genus: UNASSIGNED  p=0.906
species: UNASSIGNED  p=0.376


In [52]:
from sklearn.metrics import accuracy_score, f1_score
def evaluate_loader(loader):
    model.eval()
    all_preds, all_trues = {r:[] for r in label_encoders}, {r:[] for r in label_encoders}
    with torch.no_grad():
        for bt in loader:
            b = batch_tuple_to_dict(bt)
            x = b['x'].to(device)
            out = model(x)
            for r in label_encoders:
                logits = out[r] / max(1e-12, temps.get(r,1.0))
                probs = F.softmax(logits, dim=1).cpu().numpy()
                all_preds[r].append(probs.argmax(axis=1))
                all_trues[r].append(b[r].cpu().numpy())
    metrics = {}
    for r in label_encoders:
        if not all_preds[r]: 
            metrics[r] = {"acc":None, "f1_macro":None}; continue
        preds = np.concatenate(all_preds[r]); trues = np.concatenate(all_trues[r])
        metrics[r] = {"acc": float(accuracy_score(trues,preds)), "f1_macro": float(f1_score(trues,preds,average='macro', zero_division=0))}
    return metrics

metrics = evaluate_loader(loader)
print(metrics)

{'kingdom': {'acc': 0.8868421052631579, 'f1_macro': 0.8858739305046273}, 'phylum': {'acc': 0.8789473684210526, 'f1_macro': 0.9214409010002644}, 'class': {'acc': 0.868421052631579, 'f1_macro': 0.9273010143352481}, 'order': {'acc': 0.868421052631579, 'f1_macro': 0.9255512297263465}, 'family': {'acc': 0.868421052631579, 'f1_macro': 0.9043662190901388}, 'genus': {'acc': 0.8657894736842106, 'f1_macro': 0.8220249642221711}, 'species': {'acc': 0.7736842105263158, 'f1_macro': 0.39419672261669525}}


In [54]:
# Fix for AttributeError: missing head_w_...  (use the correct parameter names)
import torch, torch.nn as nn, torch.nn.functional as F, json
from pathlib import Path
EXTRACT_DIR = Path("ncbi_blast_db") / "extracted"
CKPT = EXTRACT_DIR / "best_shared_heads_resumed.pt"
TFILE = EXTRACT_DIR / "temp_scaling_by_rank.json"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) show state_dict keys so we confirm the exact names present in checkpoint / model
ckpt = torch.load(str(CKPT), map_location=device)
ms = ckpt.get("model_state", ckpt) if isinstance(ckpt, dict) else ckpt
print("Example keys in checkpoint/state_dict (first 40):")
for i,k in enumerate(list(ms.keys())[:40]):
    print(" ", i, k)
print("... total keys:", len(ms))

# 2) defensive model: forward tries both name patterns (double underscore then single)
class ResilientModel(nn.Module):
    def _init_(self):
        super()._init_()
        self._inited = False
    def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.3):
        # this registers parameters with the double-underscore names used during training
        if getattr(self, "_inited", False): return
        self.ranks = list(ranks)
        self.dropout = nn.Dropout(dropout)
        h1 = int(hidden_dim); h2 = max(32, h1 // 2)
        self.register_parameter("w1", nn.Parameter(torch.randn(input_dim, h1) * 0.02))
        self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
        self.register_parameter("w2", nn.Parameter(torch.randn(h1, h2) * 0.02))
        self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
        for r in self.ranks:
            ncls = max(1, len(encoders[r].classes_))
            # register using the double-underscore convention (matches your checkpoint)
            self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2, ncls) * 0.02))
            self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
        self._inited = True

    def _get_head_params(self, r):
        # try double-underscore name first (the name used in training), then single-underscore fallback
        name_w_ds = f"head_w__{r}"
        name_b_ds = f"head_b__{r}"
        name_w_ss = f"head_w_{r}"
        name_b_ss = f"head_b_{r}"
        if hasattr(self, name_w_ds) and hasattr(self, name_b_ds):
            return getattr(self, name_w_ds), getattr(self, name_b_ds)
        if hasattr(self, name_w_ss) and hasattr(self, name_b_ss):
            return getattr(self, name_w_ss), getattr(self, name_b_ss)
        # final fallback: try to find any parameter that contains the rank string
        for n,p in self.named_parameters():
            if f"head_w" in n and (r in n):
                w_name = n
                b_name = n.replace("head_w","head_b")
                if hasattr(self, b_name):
                    return getattr(self, w_name), getattr(self, b_name)
        raise AttributeError(f"No head parameters found for rank='{r}' (tried {name_w_ds},{name_w_ss})")

    def forward(self, x):
        h = x @ self.w1 + self.b1
        h = torch.relu(h)
        h = self.dropout(h)
        h = h @ self.w2 + self.b2
        h = torch.relu(h)
        out = {}
        for r in self.ranks:
            w, b = self._get_head_params(r)   # robustly fetch params
            out[r] = h @ w + b
        return out

# 3) Load label_encoders and a loader from your notebookglobals (must exist)
assert "label_encoders" in globals() and ("inference_loader" in globals() or "val_loader" in globals()) and "batch_tuple_to_dict" in globals()
label_encoders = globals()["label_encoders"]
loader = globals().get("inference_loader") or globals().get("val_loader")
batch_tuple_to_dict = globals()["batch_tuple_to_dict"]
RANKS = list(label_encoders.keys())

# 4) instantiate and load state
sample_batch = next(iter(loader))
sample = batch_tuple_to_dict(sample_batch)
input_dim = int(sample["x"].shape[1])

model = ResilientModel()
model.initialize(input_dim=input_dim, hidden_dim=256, ranks=RANKS, encoders=label_encoders, dropout=0.3)
# load checkpoint state (handle both forms)
state = ckpt.get("model_state", ckpt) if isinstance(ckpt, dict) else ckpt
# load_state_dict will match param names from the checkpoint to the module's param names.
# If checkpoint uses double-underscore names and we've registered the same, this will succeed.
model.load_state_dict(state)
model.to(device)
print("Model loaded. number of parameters:", sum(p.numel() for p in model.parameters()))

# 5) quick check: show that head params exist
print("Head parameter presence (sample):")
for r in RANKS[:5]:
    try:
        w,b = model._get_head_params(r)
        print(f"  {r}: found w shape={tuple(w.shape)}, b shape={tuple(b.shape)}")
    except Exception as e:
        print("  ", r, "->", e)

# 6) run a single-sample deterministic prediction (with calibration if temp file exists)
temps = {r:1.0 for r in RANKS}
if Path(TFILE).exists():
    temps = {k:float(v) for k,v in json.load(open(TFILE)).items()}

model.eval()
with torch.no_grad():
    b0 = batch_tuple_to_dict(next(iter(loader)))
    x0 = b0["x"][0:1].to(device)
    out = model(x0)
    results = {}
    for r in RANKS:
        logits = out[r] / max(1e-12, temps.get(r,1.0))
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        idx = int(probs.argmax())
        label = label_encoders[r].classes_[idx]
        results[r] = {"pred_idx": idx, "pred_label": label, "prob": float(probs[idx])}
print("\nSingle-sample predictions:\n", results)

Example keys in checkpoint/state_dict (first 40):
  0 w1
  1 b1
  2 w2
  3 b2
  4 head_w__kingdom
  5 head_b__kingdom
  6 head_w__phylum
  7 head_b__phylum
  8 head_w__class
  9 head_b__class
  10 head_w__order
  11 head_b__order
  12 head_w__family
  13 head_b__family
  14 head_w__genus
  15 head_b__genus
  16 head_w__species
  17 head_b__species
... total keys: 18
Model loaded. number of parameters: 82947
Head parameter presence (sample):
  kingdom: found w shape=(128, 2), b shape=(2,)
  phylum: found w shape=(128, 5), b shape=(5,)
  class: found w shape=(128, 10), b shape=(10,)
  order: found w shape=(128, 13), b shape=(13,)
  family: found w shape=(128, 19), b shape=(19,)

Single-sample predictions:
 {'kingdom': {'pred_idx': 1, 'pred_label': 'UNASSIGNED', 'prob': 0.9707451462745667}, 'phylum': {'pred_idx': 3, 'pred_label': 'UNASSIGNED', 'prob': 0.9586387872695923}, 'class': {'pred_idx': 9, 'pred_label': 'UNASSIGNED', 'prob': 0.9505306482315063}, 'order': {'pred_idx': 11, 'pred_labe

In [56]:
# Predictor helper: load model, deterministic predict and MC-dropout predict
import torch, torch.nn as nn, torch.nn.functional as F, json, numpy as np
from pathlib import Path

def _make_model_class():
    class ResilientModel(nn.Module):
        def _init(self): super().init_(); self._inited=False
        def initialize(self, input_dim, hidden_dim, ranks, encoders, dropout=0.3):
            if getattr(self,"_inited",False): return
            self.ranks = list(ranks); self.dropout = nn.Dropout(dropout)
            h1=int(hidden_dim); h2=max(32,h1//2)
            self.register_parameter("w1", nn.Parameter(torch.randn(input_dim,h1)*0.02))
            self.register_parameter("b1", nn.Parameter(torch.zeros(h1)))
            self.register_parameter("w2", nn.Parameter(torch.randn(h1,h2)*0.02))
            self.register_parameter("b2", nn.Parameter(torch.zeros(h2)))
            for r in self.ranks:
                ncls = max(1, len(encoders[r].classes_))
                # register head parameters with double-underscore convention (matches your checkpoint)
                self.register_parameter(f"head_w__{r}", nn.Parameter(torch.randn(h2,ncls)*0.02))
                self.register_parameter(f"head_b__{r}", nn.Parameter(torch.zeros(ncls)))
            self._inited=True

        def _get_head(self, r):
            # try double-underscore then single-underscore then fuzzy
            n_w = f"head_w_{r}"; n_b = f"head_b_{r}"
            if hasattr(self, n_w) and hasattr(self, n_b):
                return getattr(self, n_w), getattr(self, n_b)
            n_w2 = f"head_w_{r}"; n_b2 = f"head_b_{r}"
            if hasattr(self, n_w2) and hasattr(self, n_b2):
                return getattr(self, n_w2), getattr(self, n_b2)
            # fallback: find param containing rank
            for n,p in self.named_parameters():
                if "head_w" in n and (r in n):
                    w = getattr(self, n); bname = n.replace("head_w","head_b")
                    if hasattr(self, bname): return getattr(self, n), getattr(self, bname)
            raise AttributeError(f"No head found for rank '{r}'")
        def forward(self, x):
            h = x @ self.w1 + self.b1; h = torch.relu(h)
            h = self.dropout(h)
            h = h @ self.w2 + self.b2; h = torch.relu(h)
            out = {}
            for r in self.ranks:
                w,b = self._get_head(r)
                out[r] = h @ w + b
            return out
    return ResilientModel

def load_predictor(checkpoint_path, label_encoders, temp_path=None, device=None):
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    CK = Path(checkpoint_path)
    assert CK.exists(), f"Checkpoint missing: {CK}"
    ck = torch.load(str(CK), map_location=device)
    ResilientModel = _make_model_class()
    model = ResilientModel()
    # infer input dim from a sample if user provides one — but we'll require user to pass input_dim or loader.
    # Here we expect user to pass input_dim via label_encoders container (we'll not assume).
    # We'll attempt to infer input_dim from a global loader if present
    input_dim = None
    if "inference_loader" in globals():
        sample = globals()["batch_tuple_to_dict"](next(iter(globals()["inference_loader"])))
        input_dim = int(sample["x"].shape[1])
    elif "val_loader" in globals():
        sample = globals()["batch_tuple_to_dict"](next(iter(globals()["val_loader"])))
        input_dim = int(sample["x"].shape[1])
    if input_dim is None:
        raise RuntimeError("Cannot infer input_dim — provide a loader in globals or modify load_predictor to accept input_dim.")
    model.initialize(input_dim=input_dim, hidden_dim=256, ranks=list(label_encoders.keys()), encoders=label_encoders, dropout=0.3)
    st = ck.get("model_state", ck) if isinstance(ck, dict) else ck
    model.load_state_dict(st)
    model.to(device).eval()
    temps = {r:1.0 for r in label_encoders}
    if temp_path:
        p = Path(temp_path)
        if p.exists():
            temps = {k:float(v) for k,v in json.load(open(p)).items()}
    return {"model": model, "device": device, "temps": temps, "ranks": list(label_encoders.keys()), "encoders": label_encoders}

def predict_single(predictor, embedding_np):
    """embedding_np: 1D numpy array (input_dim,)"""
    model = predictor["model"]; device = predictor["device"]; temps = predictor["temps"]; ranks = predictor["ranks"]; encs = predictor["encoders"]
    model.eval()
    x = torch.from_numpy(np.asarray(embedding_np, dtype=np.float32)).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(x)
    out_res = {}
    for r in ranks:
        logits = out[r] / max(1e-12, float(temps.get(r,1.0)))
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        idx = int(np.argmax(probs))
        out_res[r] = {"pred_idx": idx, "pred_label": encs[r].classes_[idx], "prob": float(probs[idx]), "probs": probs}
    return out_res

def predict_with_mc(predictor, embedding_np, mc_passes=32):
    """Returns mean probs, predictive entropy, expected entropy, mutual_info per rank."""
    model = predictor["model"]; device = predictor["device"]; temps = predictor["temps"]; ranks = predictor["ranks"]; encs = predictor["encoders"]
    x = torch.from_numpy(np.asarray(embedding_np, dtype=np.float32)).unsqueeze(0).to(device)
    model.train()   # enable dropout
    probs_all = {r: [] for r in ranks}
    with torch.no_grad():
        for m in range(mc_passes):
            out = model(x)
            for r in ranks:
                logits = out[r] / max(1e-12, float(temps.get(r,1.0)))
                probs = F.softmax(logits, dim=1).cpu().numpy()[0]  # (C,)
                probs_all[r].append(probs)
    # aggregate
    res = {}
    for r in ranks:
        arr = np.stack(probs_all[r], axis=0)   # (MC, C)
        mean_prob = arr.mean(axis=0)
        entropy = -np.sum(mean_prob * np.log(np.clip(mean_prob,1e-12,1.0)))
        per_pass_ent = -np.sum(arr * np.log(np.clip(arr,1e-12,1.0)), axis=1)
        exp_ent = per_pass_ent.mean()
        mutual_info = float(entropy - exp_ent)
        pred_idx = int(np.argmax(mean_prob))
        res[r] = {"pred_idx": pred_idx, "pred_label": encs[r].classes_[pred_idx], "mean_prob": mean_prob, "entropy": float(entropy), "exp_entropy": float(exp_ent), "mutual_info": mutual_info}
    model.eval()
    return res

# ---------------- Example usage (run after previous definitions) ----------------
# 1) load predictor:
predictor = load_predictor("ncbi_blast_db/extracted/best_shared_heads_resumed.pt", label_encoders, temp_path="ncbi_blast_db/extracted/temp_scaling_by_rank.json")
# 2) deterministic prediction for first val sample
batch = next(iter(globals().get("inference_loader") or globals().get("val_loader")))
b = globals()["batch_tuple_to_dict"](batch)
emb = b["x"][0].cpu().numpy()
print("deterministic:", predict_single(predictor, emb))
print("mc (12 passes):", predict_with_mc(predictor, emb, mc_passes=12))

deterministic: {'kingdom': {'pred_idx': 1, 'pred_label': 'UNASSIGNED', 'prob': 0.9707451462745667, 'probs': array([0.02925488, 0.97074515], dtype=float32)}, 'phylum': {'pred_idx': 3, 'pred_label': 'UNASSIGNED', 'prob': 0.9586387872695923, 'probs': array([3.9736195e-03, 3.7238460e-02, 9.8026078e-07, 9.5863879e-01,
       1.4807584e-04], dtype=float32)}, 'class': {'pred_idx': 9, 'pred_label': 'UNASSIGNED', 'prob': 0.9505306482315063, 'probs': array([7.1348179e-05, 4.3802533e-02, 3.4696552e-06, 4.8463512e-03,
       6.3554020e-05, 7.6261688e-07, 5.7074840e-07, 1.9403724e-05,
       6.6132913e-04, 9.5053065e-01], dtype=float32)}, 'order': {'pred_idx': 11, 'pred_label': 'UNASSIGNED', 'prob': 0.9205213785171509, 'probs': array([1.3949699e-05, 7.9814587e-03, 9.0692297e-04, 1.5622020e-06,
       6.4422026e-02, 5.5940636e-03, 9.1526119e-05, 2.0697414e-06,
       4.8125563e-05, 1.7701190e-04, 7.1040442e-05, 9.2052138e-01,
       1.6889031e-04], dtype=float32)}, 'family': {'pred_idx': 17, 'pred_l

In [67]:
# Robust Cell — locate embeddings + meta (search common locations) and assemble abundance_dataset.csv if possible.
# Paste/run this cell in your notebook. It will NOT raise uncaught exceptions; it prints helpful diagnostics.
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import textwrap

def main():
    # ---------- EDITABLE fallback path(s) ----------
    # If you know the exact folder, set it here as a raw Windows path, e.g.:
    # DOWNLOAD_DIR = Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db")
    DOWNLOAD_DIR = None   # <-- leave None to let the search try common candidates
    # -----------------------------------------------

    # Candidate roots to check (in priority order). We include the exact path you mentioned.
    candidates = []
    if DOWNLOAD_DIR:
        candidates.append(Path(str(DOWNLOAD_DIR)))
    # Path the user reported having files in:
    candidates.append(Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db"))
    # some convenient relatives (project/workdir)
    candidates.append(Path.cwd() / "sih" / "ncbi_blast_db")
    candidates.append(Path.cwd() / "ncbi_blast_db")
    candidates.append(Path.cwd())
    # home OneDrive common location (Windows)
    home = Path.home()
    candidates.append(home / "OneDrive" / "Desktop" / "sihtaxa" / "sihabundance" / "ncbi_blast_db")
    candidates.append(home / "OneDrive" / "Desktop" / "ncbi_blast_db")
    candidates = [p.resolve() for p in candidates if p is not None]

    # Files we need inside extracted/
    need_files = ("embeddings_pca.npy", "embeddings_meta_clustered.csv")

    found = None
    scanned = []
    for root in candidates:
        extracted = root / "extracted"
        scanned.append(str(extracted))
        if extracted.exists() and extracted.is_dir():
            all_present = all((extracted / f).exists() for f in need_files)
            if all_present:
                found = extracted
                break

    # Also try a broader search: scan subfolders of current working directory for an extracted/ with required files
    if found is None:
        for p in Path.cwd().rglob("extracted"):
            if (p / need_files[0]).exists() and (p / need_files[1]).exists():
                found = p.resolve()
                scanned.append(str(p.resolve()))
                break

    if found is None:
        print("\n[ERROR] Could not locate required files in any scanned 'extracted/' locations.")
        print("Scanned candidate 'extracted' directories (in order):")
        for s in scanned:
            print("  -", s)
        print("\nWhat to do next (pick one):")
        print("  1) Edit the top of this cell and set DOWNLOAD_DIR = Path(r\"C:\\full\\path\\to\\ncbi_blast_db\")")
        print("     Example for your machine:")
        print("       DOWNLOAD_DIR = Path(r\"C:\\Users\\Srijit\\OneDrive\\Desktop\\sihtaxa\\sihabundance\\ncbi_blast_db\")")
        print("  2) Move the 'extracted' folder that contains the two files into one of the scanned locations above.")
        print("\nFiles expected inside the 'extracted' folder:")
        for f in need_files:
            print("  -", f)
        print("\nThis cell will exit gracefully; edit DOWNLOAD_DIR and re-run.")
        return  # graceful return

    # Found the extracted folder
    EXTRACT_DIR = found
    print(f"\n[FOUND] using extracted folder: {EXTRACT_DIR}")
    emb_pca_path = EXTRACT_DIR / "embeddings_pca.npy"
    meta_path    = EXTRACT_DIR / "embeddings_meta_clustered.csv"

    # Load files defensively
    try:
        print("[LOAD] loading embeddings (numpy) ...", emb_pca_path)
        X = np.load(emb_pca_path)
        print("[OK] embeddings shape:", X.shape)
    except Exception as e:
        print("[ERROR] Failed to load embeddings_pca.npy:", e)
        return

    try:
        print("[LOAD] loading metadata CSV ...", meta_path)
        meta = pd.read_csv(meta_path)
        print("[OK] metadata rows:", len(meta))
    except Exception as e:
        print("[ERROR] Failed to load embeddings_meta_clustered.csv:", e)
        return

    # Ensure 'id' column exists (try common alternates)
    if "id" not in meta.columns:
        for alt in ("accession","seqid","accession_id","accession_id_base","accession.version"):
            if alt in meta.columns:
                meta = meta.rename(columns={alt: "id"})
                print(f"[INFO] Renamed metadata column '{alt}' -> 'id'")
                break
    if "id" not in meta.columns:
        print("[ERROR] metadata has no 'id' column. Columns present:", list(meta.columns)[:40])
        print("If your metadata uses a different name for accession, rename that column to 'id' or edit the code.")
        return

    # Search for abundance-like column(s) in meta
    cand_keywords = ["abund", "abundance", "count", "reads", "depth", "coverage", "relative", "rpm"]
    abundance_candidates = []
    for c in meta.columns:
        lc = c.lower()
        if any(kw in lc for kw in cand_keywords):
            abundance_candidates.append(c)
    y = None
    y_df = None
    if abundance_candidates:
        abundance_col = abundance_candidates[0]
        print(f"[FOUND] abundance-like column in metadata: '{abundance_col}' (using it as target)")
        # coerce to numeric
        try:
            y = pd.to_numeric(meta[abundance_col], errors="coerce")
        except Exception:
            y = meta[abundance_col]
        n_nan = int(y.isna().sum()) if hasattr(y, "isna") else 0
        if n_nan > 0:
            print(f"[WARN] {n_nan} rows in the abundance column are not numeric / NaN (they will be handled later).")
    else:
        # look for external label files inside EXTRACT_DIR
        print("[INFO] No abundance-like column in metadata; searching for label CSVs in extracted/ ...")
        candidates_files = list(EXTRACT_DIR.glob("abundance*.csv")) + list(EXTRACT_DIR.glob("*abund*.csv")) + list(EXTRACT_DIR.glob("labels*.csv")) + list(EXTRACT_DIR.glob("*labels*.csv"))
        if candidates_files:
            opened = False
            for p in sorted(set(candidates_files)):
                try:
                    df = pd.read_csv(p)
                except Exception as e:
                    print(f"[WARN] could not read {p.name}: {e}")
                    continue
                cols_lower = [c.lower() for c in df.columns]
                if "id" in cols_lower or "accession" in cols_lower or "seqid" in cols_lower:
                    # pick numeric column for abundance
                    id_col = df.columns[cols_lower.index("id")] if "id" in cols_lower else df.columns[cols_lower.index("accession")] if "accession" in cols_lower else df.columns[cols_lower.index("seqid")]
                    numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
                    if not numeric_cols:
                        # try keywords
                        numeric_cols = [c for c in df.columns if any(kw in c.lower() for kw in cand_keywords)]
                    if numeric_cols:
                        abundance_col = numeric_cols[0]
                        y_df = df[[id_col, abundance_col]].rename(columns={id_col: "id", abundance_col: "abundance"})
                        print(f"[USING] external label file: {p.name} -> id='{id_col}', abundance='{abundance_col}'")
                        opened = True
                        break
            if not opened:
                print("[WARN] found candidate CSVs but none had clear numeric abundance+id columns.")
        else:
            print("[INFO] No candidate label CSVs found in extracted/")

    # Assemble final dataset and save
    out_csv = EXTRACT_DIR / "abundance_dataset.csv"
    if y is not None:
        # assume meta order corresponds to X order; validate lengths
        if len(meta) == X.shape[0]:
            df_out = meta.reset_index(drop=True).copy()
            df_out["abundance_target"] = y.values
            df_out["pca_index"] = np.arange(len(df_out))
            try:
                df_out.to_csv(out_csv, index=False)
                print(f"[SAVED] merged abundance dataset -> {out_csv} (rows={len(df_out)})")
            except Exception as e:
                print("[ERROR] failed to save abundance_dataset.csv:", e)
            print("\n[SAMPLE]")
            print(df_out.head().to_string(index=False))
        else:
            # lengths mismatch: save partial with aligned index where possible
            print("[WARN] metadata rows != embeddings rows. Attempting to save partial merged dataset with pca_index where possible.")
            df_out = meta.reset_index(drop=True).copy()
            df_out["abundance_target"] = y.values if len(y)==len(df_out) else pd.Series(y).reindex(df_out.index)
            df_out["pca_index"] = pd.Series(range(len(df_out)))
            try:
                df_out.to_csv(EXTRACT_DIR / "abundance_dataset_partial.csv", index=False)
                print(f"[SAVED] partial dataset -> {EXTRACT_DIR / 'abundance_dataset_partial.csv'} (rows={len(df_out)})")
            except Exception as e:
                print("[ERROR] failed to save partial dataset:", e)
            print("Please check lengths: embeddings rows =", X.shape[0], "meta rows =", len(meta))
    elif y_df is not None:
        # merge external labels with meta on 'id'
        merged = pd.merge(meta, y_df, on="id", how="inner")
        if merged.empty:
            print("[ERROR] merging with external label file produced 0 rows. Check the 'id' values for exact matches (case sensitive).")
            print("Example meta ids:", list(meta['id'].astype(str).head(10)))
            print("Example label ids:", list(y_df['id'].astype(str).head(10)))
            return
        merged = merged.reset_index(drop=True).rename(columns={merged.columns[-1]:"abundance_target"}) if merged.columns[-1]!="abundance_target" else merged
        try:
            merged.to_csv(out_csv, index=False)
            print(f"[SAVED] merged abundance dataset -> {out_csv} (rows={len(merged)})")
            print("\n[SAMPLE]")
            print(merged.head().to_string(index=False))
        except Exception as e:
            print("[ERROR] failed to save merged abundance dataset:", e)
            return
    else:
        print("\n[NO TARGET FOUND]")
        print("I couldn't find a numeric abundance column in metadata or a label CSV.")
        print("To proceed, either:")
        print("  - place a CSV with columns 'id' and 'abundance' into:")
        print(f"      {EXTRACT_DIR}  (recommended) or")
        print(f"      {Path.cwd()} or another folder listed above and re-run this cell.")
        print("  - OR edit this cell and set DOWNLOAD_DIR to your exact folder path, e.g.:")
        print(r'      DOWNLOAD_DIR = Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db")')
        return

    print("\n[READY] If abundance_dataset.csv exists in the extracted folder you can now run the training cell (Cell 2).")
    return

# run main (safe)
main()



[FOUND] using extracted folder: C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted
[LOAD] loading embeddings (numpy) ... C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted\embeddings_pca.npy
[OK] embeddings shape: (2555, 64)
[LOAD] loading metadata CSV ... C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted\embeddings_meta_clustered.csv
[OK] metadata rows: 2555
[INFO] No abundance-like column in metadata; searching for label CSVs in extracted/ ...
[INFO] No candidate label CSVs found in extracted/

[NO TARGET FOUND]
I couldn't find a numeric abundance column in metadata or a label CSV.
To proceed, either:
  - place a CSV with columns 'id' and 'abundance' into:
      C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted  (recommended) or
      C:\Users\Srijit\sih or another folder listed above and re-run this cell.
  - OR edit this cell and set DOWNLOAD_DIR to your exact folder path,

In [71]:
# Cell 2 (modified & defensive) -- train / evaluate abundance regressor and save it
import time, traceback
from pathlib import Path
import numpy as np
import pandas as pd
import joblib

# sklearn imports
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

# ---------- EDITED default paths for your environment ----------
DOWNLOAD_DIR = Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db")
EXTRACT_DIR  = DOWNLOAD_DIR / "extracted"
ABUND_CSV    = EXTRACT_DIR / "abundance_dataset.csv"
ABUND_PARTIAL= EXTRACT_DIR / "abundance_dataset_partial.csv"
EMB_PCA      = EXTRACT_DIR / "embeddings_pca.npy"
META_CSV     = EXTRACT_DIR / "embeddings_meta_clustered.csv"
MODEL_OUT    = EXTRACT_DIR / "abundance_model.joblib"
# -----------------------------------------------------------------

print(f"[PATHS] EXTRACT_DIR: {EXTRACT_DIR}")
print(f"[FILES] looking for: {ABUND_CSV.name}  (fallback: {ABUND_PARTIAL.name})")
print(f"[FILES] embeddings PCA: {EMB_PCA.name}")

# defensive helper to exit this cell gracefully
def stop(msg):
    print(msg)
    print("Cell finished without training. Fix the issue above and re-run.")
    return

# 1) file existence checks
if not EMB_PCA.exists():
    stop(f"[ERROR] Embeddings PCA file not found at: {EMB_PCA}")
    # do not raise an exception; return
else:
    try:
        X_pca = np.load(EMB_PCA)
    except Exception as e:
        stop(f"[ERROR] Failed to load embeddings from {EMB_PCA}: {e}")
        X_pca = None

# choose abundance CSV (prefer full, fallback to partial)
if ABUND_CSV.exists():
    abund_path = ABUND_CSV
elif ABUND_PARTIAL.exists():
    abund_path = ABUND_PARTIAL
    print(f"[WARN] full abundance_dataset.csv not found; using partial: {ABUND_PARTIAL.name}")
else:
    stop(f"[ERROR] No abundance dataset found. Place a CSV named '{ABUND_CSV.name}' (columns: id, abundance) into:\n  {EXTRACT_DIR}\nThen re-run this cell.")
    abund_path = None

if X_pca is None or abund_path is None:
    # we've already printed an error; stop
    pass
else:
    try:
        df = pd.read_csv(abund_path)
    except Exception as e:
        stop(f"[ERROR] Failed to read abundance CSV {abund_path}: {e}")
        df = None

    if df is None:
        pass
    else:
        print(f"[LOAD] abundance dataset rows={len(df)}; embeddings shape={X_pca.shape}")

        # pick abundance target column
        possible_target_cols = [c for c in df.columns if c.lower() in ("abundance_target","abundance","count","abund","reads","value")]
        if not possible_target_cols:
            numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and c.lower()!="pca_index"]
            if numeric_cols:
                possible_target_cols = [numeric_cols[-1]]
        if not possible_target_cols:
            stop("[ERROR] Could not detect numeric abundance column in dataset. Columns found:\n  " + ", ".join(list(df.columns)))
        else:
            target_cols = possible_target_cols if isinstance(possible_target_cols, (list,tuple)) else [possible_target_cols]
            print(f"[TARGETS] Using target columns: {target_cols}")

            # Align embeddings <-> df rows
            X_use = None
            df_use = None

            # if pca_index present, use it (preferred)
            if "pca_index" in df.columns:
                try:
                    idx = df["pca_index"].astype(int).values
                    if idx.max() >= X_pca.shape[0] or idx.min() < 0:
                        print("[WARN] pca_index contains out-of-range values; falling back to row-order alignment.")
                    else:
                        X_use = X_pca[idx]
                        df_use = df.reset_index(drop=True)
                        print(f"[ALIGN] aligned via pca_index, rows={len(df_use)}")
                except Exception as e:
                    print("[WARN] pca_index exists but failed to use it:", e)

            # id-based alignment if pca_index not used
            if X_use is None:
                if META_CSV.exists():
                    try:
                        meta = pd.read_csv(META_CSV)
                        if "id" in meta.columns:
                            id_to_index = {rid: i for i, rid in enumerate(meta["id"].astype(str).values)}
                            if "id" in df.columns:
                                matched_indices = []
                                keep_rows = []
                                for i, rid in enumerate(df["id"].astype(str).values):
                                    if rid in id_to_index:
                                        matched_indices.append(id_to_index[rid])
                                        keep_rows.append(i)
                                if len(matched_indices) > 0:
                                    X_use = X_pca[matched_indices]
                                    df_use = df.iloc[keep_rows].reset_index(drop=True)
                                    print(f"[ALIGN] matched {len(matched_indices)} rows by 'id' to embeddings")
                                else:
                                    print("[WARN] Could not match ids between abundance CSV and metadata; falling back to top-row alignment.")
                            else:
                                print("[WARN] abundance CSV has no 'id' column; falling back to top-row alignment.")
                        else:
                            print("[WARN] meta CSV exists but has no 'id' column; falling back to top-row alignment.")
                    except Exception as e:
                        print("[WARN] Failed to read/parse meta CSV for id-alignment:", e)

            # last resort: align by top-n rows
            if X_use is None:
                min_n = min(X_pca.shape[0], len(df))
                X_use = X_pca[:min_n]
                df_use = df.iloc[:min_n].reset_index(drop=True)
                print(f"[ALIGN] fallback top-row alignment used. Using first {min_n} rows.")

            # Prepare Y array
            try:
                Y = df_use[target_cols].copy()
                if Y.shape[1] == 1:
                    y = Y.iloc[:,0].astype(float).values
                else:
                    y = Y.astype(float).values
            except Exception as e:
                stop(f"[ERROR] Failed to prepare target y: {e}")
                y = None

            if y is None:
                pass
            else:
                # drop rows with NaN target
                if np.isnan(y).any():
                    mask = ~np.isnan(y).any(axis=1) if y.ndim>1 else ~np.isnan(y)
                    before = X_use.shape[0]
                    X_use = X_use[mask]
                    y = y[mask]
                    after = X_use.shape[0]
                    print(f"[CLEAN] Dropped {before-after} rows with NaN target. New shapes X={X_use.shape}, y={y.shape}")

                if X_use.shape[0] < 5:
                    stop(f"[ERROR] Not enough rows after alignment/cleaning to train (found {X_use.shape[0]} rows). Need more labeled rows.")
                else:
                    # Train/test split
                    RANDOM_SEED = 42
                    test_size = 0.15 if X_use.shape[0] >= 20 else 0.2
                    X_train, X_val, y_train, y_val = train_test_split(X_use, y, test_size=test_size, random_state=RANDOM_SEED)
                    print(f"[SPLIT] train={len(X_train)} val={len(X_val)}")

                    # choose/model
                    base = RandomForestRegressor(n_estimators=200, n_jobs=-1, random_state=RANDOM_SEED)
                    if y_train.ndim == 1 or (hasattr(y_train, "shape") and y_train.shape[1] == 1):
                        model = base
                    else:
                        model = MultiOutputRegressor(base, n_jobs=-1)

                    # fit with try/except
                    t0 = time.time()
                    try:
                        print("[TRAIN] fitting RandomForest regressor (this may take some time)...")
                        model.fit(X_train, y_train)
                        t1 = time.time()
                        print(f"[TRAINED] fit time: {t1-t0:.2f}s")
                    except Exception as e:
                        print("[ERROR] training failed:", e)
                        traceback.print_exc()
                        stop("Training aborted due to error.")
                        model = None

                    if model is not None:
                        try:
                            y_pred = model.predict(X_val)
                            if y_pred.ndim == 1:
                                r2 = r2_score(y_val, y_pred)
                                mae = mean_absolute_error(y_val, y_pred)
                                mse = mean_squared_error(y_val, y_pred)
                                print(f"[EVAL] R2={r2:.4f}  MAE={mae:.4f}  MSE={mse:.4f}")
                            else:
                                r2 = r2_score(y_val, y_pred, multioutput='uniform_average')
                                mae = mean_absolute_error(y_val, y_pred, multioutput='uniform_average')
                                mse = mean_squared_error(y_val, y_pred, multioutput='uniform_average')
                                print(f"[EVAL multi] R2(avg)={r2:.4f}  MAE(avg)={mae:.4f}  MSE(avg)={mse:.4f}")
                        except Exception as e:
                            print("[WARN] evaluation failed:", e)

                        # save model artifact
                        try:
                            joblib.dump({
                                "model": model,
                                "target_cols": target_cols,
                                "feature_dim": X_use.shape[1],
                                "train_date": time.ctime(),
                            }, MODEL_OUT)
                            print(f"[SAVED] model + metadata -> {MODEL_OUT}")
                        except Exception as e:
                            print("[ERROR] Failed to save model:", e)

                        # helper function exposed in this cell
                        def predict_abundance_from_embedding(embeddings):
                            info = joblib.load(MODEL_OUT)
                            mdl = info["model"]
                            expected_dim = info["feature_dim"]
                            arr = np.asarray(embeddings)
                            if arr.ndim == 1:
                                arr = arr.reshape(1, -1)
                            if arr.shape[1] != expected_dim:
                                raise ValueError(f"Embedding dimension mismatch. Model expects {expected_dim} features; got {arr.shape[1]}.")
                            preds = mdl.predict(arr)
                            return preds

                        # sample predictions (safely)
                        try:
                            sample_pred = predict_abundance_from_embedding(X_val[:3])
                            print("[SAMPLE PREDICTIONS] for first 3 validation rows ->")
                            print(sample_pred)
                        except Exception as e:
                            print("[WARN] sample prediction failed:", e)

                        print("\n[DONE] Training complete. Use `predict_abundance_from_embedding(embedding_array)` to predict.")


[PATHS] EXTRACT_DIR: C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted
[FILES] looking for: abundance_dataset.csv  (fallback: abundance_dataset_partial.csv)
[FILES] embeddings PCA: embeddings_pca.npy
[ERROR] No abundance dataset found. Place a CSV named 'abundance_dataset.csv' (columns: id, abundance) into:
  C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted
Then re-run this cell.
Cell finished without training. Fix the issue above and re-run.


In [73]:
# Abundance-from-predictions cell (run this in your notebook)
# - searches for predictions CSVs that the inference step writes
# - computes count-based and confidence-weighted relative abundances per species (and per higher ranks optionally)
# - saves results to extracted/abundance_from_predictions.csv and abundance_from_predictions_weighted.csv
# - does NOT change the classifier; only aggregates its outputs
from pathlib import Path
import pandas as pd
import numpy as np
import textwrap

# Candidate parent folders to search (no single hardcoded path only)
candidates = [
    Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
    Path.cwd() / "ncbi_blast_db" / "extracted",
    Path.cwd() / "extracted",
    Path.cwd(),
    Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted"),
    Path.home() / "OneDrive" / "Desktop" / "sihtaxa" / "sihabundance" / "ncbi_blast_db" / "extracted",
]

# also look for any "extracted" subfolder under cwd (scan shallow)
for p in Path.cwd().glob("**/extracted"):
    candidates.append(p.resolve())

# make unique and keep only existing directories for informational order
seen = []
candidates_clean = []
for p in candidates:
    if p not in seen:
        seen.append(p)
        candidates_clean.append(p)
candidates = candidates_clean

# filenames we expect from previous inference step (in order of preference)
pred_filenames = [
    "predictions_with_uncertainty.csv",
    "predictions.csv",
    "val_predictions_calibrated.csv",
    "val_predictions.csv",
    "predictions_with_uncertainty_latest.csv",
]

found = None
found_path = None
for extracted in candidates:
    if not extracted.exists():
        continue
    for name in pred_filenames:
        p = extracted / name
        if p.exists():
            found = name
            found_path = p
            break
    if found_path:
        break

# also try scanning any extracted folder for matching filenames
if found_path is None:
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        found_path = p.resolve()
        break

# final fallback: any CSV in extracted folder containing 'species' / 'species_pred' columns
if found_path is None:
    for extracted in candidates:
        if not extracted.exists():
            continue
        for p in extracted.glob("*.csv"):
            try:
                tmp = pd.read_csv(p, nrows=5)
            except Exception:
                continue
            cols = [c.lower() for c in tmp.columns]
            if any(c in cols for c in ("species_pred_label", "species_label", "species_pred", "species")):
                found_path = p
                break
        if found_path:
            break

if found_path is None:
    print(textwrap.dedent(f"""
    [ERROR] Could not find inference predictions CSV in the usual places.
    I searched the following candidate extracted folders (in this order):
      {', '.join(str(x) for x in candidates)}
    Expected one of these files (examples): {pred_filenames}
    If you have not yet run the inference step that writes predictions (predictions_with_uncertainty.csv),
    please run it first. Otherwise, place the CSV into an 'extracted/' folder and re-run this cell.

    Alternatively, if you *do* have a predictions CSV but with a different name, move or rename it to:
      predictions_with_uncertainty.csv
    or place it into one of the candidate 'extracted' folders above.
    """))
else:
    print(f"[FOUND] predictions CSV -> {found_path}")
    try:
        df = pd.read_csv(found_path)
    except Exception as e:
        print(f"[ERROR] Failed to read CSV {found_path}: {e}")
        df = None

    if df is None:
        pass
    else:
        # find species label column (common variants)
        col_candidates = [
            "species_pred_label", "species_label", "species_pred", "species_predicted", "species",
            "species_prediction_label", "species_pred_lbl"
        ]
        cols_lower = {c.lower(): c for c in df.columns}
        species_col = None
        for cand in col_candidates:
            if cand in cols_lower:
                species_col = cols_lower[cand]
                break

        # try other heuristics if not found
        if species_col is None:
            # look for any column containing the word 'species' (case-insensitive)
            for c in df.columns:
                if "species" in c.lower():
                    species_col = c
                    break

        if species_col is None:
            print("[ERROR] Could not find a 'species' prediction column in the predictions CSV.")
            print("Columns present:", list(df.columns)[:50])
            print("If your predictions file uses a different column name for species predictions,")
            print("please rename or let me know the column name. I will try to detect it automatically next.")
        else:
            print(f"[USING] species column = '{species_col}'")

            # optional: column containing species probability/confidence
            conf_candidates = ["species_pred_conf", "species_conf", "species_prob", "species_probability", "species_pred_proba", "species_proba"]
            conf_col = None
            for cand in conf_candidates:
                if cand in cols_lower:
                    conf_col = cols_lower[cand]
                    break
            # try again heuristically
            if conf_col is None:
                for c in df.columns:
                    if "conf" in c.lower() or "prob" in c.lower() or "score" in c.lower():
                        if "species" in c.lower() or "pred" in c.lower():
                            conf_col = c
                            break

            if conf_col is None:
                print("[INFO] No explicit species confidence column detected. Weighted abundance will be skipped.")
            else:
                print(f"[USING] species confidence column = '{conf_col}'")

            # compute count-based abundance
            df_species = df[[species_col]].copy()
            df_species = df_species.dropna(subset=[species_col])
            counts = df_species[species_col].value_counts(dropna=True).rename_axis("species").reset_index(name="count")
            total = counts["count"].sum()
            counts["relative_abundance"] = counts["count"] / total
            counts = counts.sort_values("relative_abundance", ascending=False).reset_index(drop=True)

            # compute confidence-weighted abundance if possible
            if conf_col is not None and conf_col in df.columns:
                df_conf = df[[species_col, conf_col]].copy()
                # coerce conf to numeric; invalids -> NaN -> dropped
                df_conf[conf_col] = pd.to_numeric(df_conf[conf_col], errors="coerce")
                df_conf = df_conf.dropna(subset=[species_col, conf_col])
                # groupby species and sum confidences
                weighted = df_conf.groupby(species_col)[conf_col].sum().rename("conf_sum").reset_index()
                total_conf = weighted["conf_sum"].sum()
                if total_conf <= 0:
                    print("[WARN] Sum of confidences <= 0; skipping weighted abundance.")
                    weighted = None
                else:
                    weighted["relative_abundance_weighted"] = weighted["conf_sum"] / total_conf
                    weighted = weighted.rename(columns={species_col: "species"})
                    weighted = weighted.sort_values("relative_abundance_weighted", ascending=False).reset_index(drop=True)
            else:
                weighted = None

            # optional: compute per-higher-rank abundances if higher-rank columns exist
            # search for phylum/class/order/family/genus columns
            rank_cols = {}
            for rank in ("kingdom","phylum","class","order","family","genus"):
                for c in df.columns:
                    if rank in c.lower():
                        rank_cols[rank] = c
                        break

            rank_abundances = {}
            for rank, col in rank_cols.items():
                tmp = df[[col]].dropna()
                counts_rank = tmp[col].value_counts().rename_axis(rank).reset_index(name="count")
                counts_rank["relative_abundance"] = counts_rank["count"] / counts_rank["count"].sum()
                rank_abundances[rank] = counts_rank

            # Save outputs to same extracted folder
            out_dir = found_path.parent
            out_counts = out_dir / "abundance_from_predictions.csv"
            out_weighted = out_dir / "abundance_from_predictions_weighted.csv"
            try:
                counts.to_csv(out_counts, index=False)
                print(f"[SAVED] count-based abundance -> {out_counts}  (rows={len(counts)})")
            except Exception as e:
                print(f"[ERROR] failed to save count-based CSV: {e}")

            if weighted is not None:
                try:
                    weighted.to_csv(out_weighted, index=False)
                    print(f"[SAVED] confidence-weighted abundance -> {out_weighted}  (rows={len(weighted)})")
                except Exception as e:
                    print(f"[ERROR] failed to save weighted CSV: {e}")

            # Save per-rank abundances if present
            for rank, df_rank in rank_abundances.items():
                out_rank = out_dir / f"abundance_{rank}.csv"
                try:
                    df_rank.to_csv(out_rank, index=False)
                    print(f"[SAVED] {rank}-level abundance -> {out_rank}  (rows={len(df_rank)})")
                except Exception as e:
                    print(f"[WARN] failed to save {rank} abundance: {e}")

            # print top-10 species by count and weighted (if available)
            print("\nTop 10 species by count-based abundance:")
            print(counts.head(10).to_string(index=False))

            if weighted is not None:
                print("\nTop 10 species by confidence-weighted abundance:")
                print(weighted.head(10).to_string(index=False))

            print("\n[DONE] You can open the CSV(s) in the extracted folder. If you want a different aggregation (e.g. per-sample abundances, per-contig, or hierarchical aggregation by taxonomic path) tell me and I'll give the exact cell for that.")


[FOUND] predictions CSV -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv
[USING] species column = 'species_pred_label'
[USING] species confidence column = 'species_pred_conf'
[SAVED] count-based abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions.csv  (rows=52)
[SAVED] confidence-weighted abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions_weighted.csv  (rows=52)
[SAVED] kingdom-level abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_kingdom.csv  (rows=2)
[SAVED] phylum-level abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_phylum.csv  (rows=5)
[SAVED] class-level abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_class.csv  (rows=8)
[SAVED] order-level abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_order.csv  (rows=11)
[SAVED] family-level abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_family.csv  (row

In [81]:
# Fixed cell: robust confusion-based deconvolution (handles the KeyError & column/variable issues)
from pathlib import Path
import numpy as np
import pandas as pd
import traceback
import sys

def find_file_in_candidates(names):
    candidates = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path.cwd(),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted"),
        Path.home() / "OneDrive" / "Desktop" / "sihtaxa" / "sihabundance" / "ncbi_blast_db" / "extracted",
    ]
    for p in Path.cwd().glob("**/extracted"):
        candidates.append(p.resolve())
    visited = set()
    for c in candidates:
        if c in visited: continue
        visited.add(c)
        for n in names:
            p = c / n
            if p.exists():
                return p
    # fallback recursive search
    for n in names:
        for p in Path.cwd().rglob(n):
            return p.resolve()
    return None

def normalize_label(s):
    # convert to str, strip whitespace, collapse internal multiple spaces, keep case (you can optionally lower())
    if pd.isna(s):
        return "UNASSIGNED"
    s = str(s).strip()
    # collapse multiple spaces
    s = " ".join(s.split())
    if s == "":
        return "UNASSIGNED"
    return s

def nnls_pgd_solve(A, Pred, max_iter=3000, tol=1e-7, verbose=False):
    M = A.T
    m, k = M.shape
    try:
        x0, *_ = np.linalg.lstsq(M, Pred, rcond=None)
        x = np.maximum(0.0, x0)
    except Exception:
        x = np.maximum(0.0, np.ones(k) * (Pred.sum() / max(1,k)))
    try:
        L = np.linalg.norm(M, ord=2)**2
        if L <= 0: L = 1.0
        lr = 1.0 / (L + 1e-10)
    except Exception:
        lr = 1e-3
    prev_loss = None
    for it in range(max_iter):
        r = M.dot(x) - Pred
        loss = 0.5 * (r @ r)
        if prev_loss is not None and abs(prev_loss - loss) < tol:
            break
        prev_loss = loss
        grad = M.T.dot(r)
        x -= lr * grad
        x = np.maximum(0.0, x)
    return x

try:
    pred_path = find_file_in_candidates(["predictions_with_uncertainty.csv", "predictions.csv", "predictions_with_uncertainty_latest.csv"])
    val_path  = find_file_in_candidates(["val_predictions_calibrated.csv", "val_predictions.csv", "val_predictions_with_uncertainty.csv", "val_predictions_with_uncertainty.csv"])
    if pred_path is None:
        print("[ERROR] predictions CSV not found. Run inference or place predictions_with_uncertainty.csv in an 'extracted/' folder.")
        raise SystemExit(0)
    print("[FOUND] predictions ->", pred_path)
    if val_path is None:
        print("[WARN] validation predictions not found; confusion correction will be skipped.")
    else:
        print("[FOUND] validation preds ->", val_path)

    df_pred = pd.read_csv(pred_path)
    cols_lower = {c.lower(): c for c in df_pred.columns}

    # detect species prediction column robustly
    species_candidates = ["species_pred_label", "species_label", "species_pred", "species"]
    species_col = None
    for c in species_candidates:
        if c in cols_lower:
            species_col = cols_lower[c]; break
    if species_col is None:
        for c in df_pred.columns:
            if "species" in c.lower():
                species_col = c; break
    if species_col is None:
        print("[ERROR] Could not find a species prediction column. Columns available:", list(df_pred.columns))
        raise SystemExit(0)
    # detect confidence column (optional)
    conf_candidates = ["species_pred_conf", "species_conf", "species_prob", "species_probability", "species_pred_proba"]
    conf_col = None
    for c in conf_candidates:
        if c in cols_lower:
            conf_col = cols_lower[c]; break

    print(f"[USING] species column: '{species_col}'", f", confidence column: '{conf_col}'" if conf_col else "")

    # normalize predicted species
    df_pred["_species_norm"] = df_pred[species_col].apply(normalize_label)
    preds_series = df_pred["_species_norm"]

    counts_df = preds_series.value_counts(dropna=True).rename_axis("species").reset_index(name="count")
    counts_df["rel_abund_count"] = counts_df["count"] / counts_df["count"].sum()

    weighted_df = None
    if conf_col and conf_col in df_pred.columns:
        df_pred["_conf_num"] = pd.to_numeric(df_pred[conf_col], errors="coerce").fillna(0.0)
        weighted_df = df_pred.groupby("_species_norm")["_conf_num"].sum().reset_index().rename(columns={"_species_norm":"species", "_conf_num":"conf_sum"})
        total_conf = weighted_df["conf_sum"].sum()
        if total_conf>0:
            weighted_df["rel_abund_weighted"] = weighted_df["conf_sum"] / total_conf
    else:
        print("[INFO] no confidence column found -> weighted abundance will be skipped.")

    # Build confusion matrix if validation file available
    confusion_possible = False
    A = None
    classes = None
    if val_path is not None:
        try:
            df_val = pd.read_csv(val_path)
            cols_val_lower = {c.lower(): c for c in df_val.columns}
            # detect true label and predicted columns in validation
            true_col = None
            predcol_val = None
            # try likely names
            for cand in ("species_true","species_label_true","true_species","species_gold","species_label"):
                if cand in cols_val_lower:
                    true_col = cols_val_lower[cand]; break
            for cand in ("species_pred_label","species_pred","species_prediction","species_label_pred","pred_species"):
                if cand in cols_val_lower:
                    predcol_val = cols_val_lower[cand]; break
            # heuristics: look for any 'species' column that likely means true label vs pred
            if true_col is None or predcol_val is None:
                # collect columns containing 'species'
                species_like = [c for c in df_val.columns if "species" in c.lower()]
                # if exactly two species-like columns, guess one is true and other is pred (best-effort)
                if len(species_like) >= 2 and (true_col is None or predcol_val is None):
                    # pick first as true, second as pred unless names indicate otherwise
                    if true_col is None:
                        true_col = species_like[0]
                    if predcol_val is None and len(species_like) > 1:
                        predcol_val = species_like[1]
            if true_col is None or predcol_val is None:
                print("[WARN] Could not locate both true and predicted species columns in validation CSV. Columns:", list(df_val.columns))
                confusion_possible = False
            else:
                # normalize
                true_norm = df_val[true_col].apply(normalize_label)
                pred_norm_val = df_val[predcol_val].apply(normalize_label)
                # classes = union of all possible species across preds and validation sets
                classes_set = set(preds_series.unique()) | set(true_norm.unique()) | set(pred_norm_val.unique())
                classes = sorted(classes_set)
                idx = {c:i for i,c in enumerate(classes)}
                n = len(classes)
                A = np.zeros((n,n), dtype=float)  # rows=true, cols=pred
                for t,p in zip(true_norm.values, pred_norm_val.values):
                    # guard: if a label is not in classes (shouldn't happen), add dynamically
                    if t not in idx:
                        # extend structures (rare) - do dynamic expansion
                        classes.append(t)
                        idx[t] = len(idx)
                        A = np.pad(A, ((0,1),(0,0)), mode='constant', constant_values=0.0)
                    if p not in idx:
                        classes.append(p)
                        idx[p] = len(idx)
                        A = np.pad(A, ((0,1),(0,0)), mode='constant', constant_values=0.0)
                    A[idx[t], idx[p]] += 1.0
                # row-normalize with tiny smoothing for stability
                row_sums = A.sum(axis=1, keepdims=True)
                row_sums[row_sums==0] = 1.0
                A = A / row_sums
                # tiny smoothing to avoid singular rows
                eps = 1e-8
                A = (A + eps)
                A = A / A.sum(axis=1, keepdims=True)
                confusion_possible = True
                print(f"[CONFUSION] built confusion matrix from validation with {len(classes)} classes.")
        except Exception:
            print("[WARN] failed to parse validation file for confusion matrix, skipping confusion correction. Error:")
            traceback.print_exc()
            confusion_possible = False

    # Prepare Pred vectors (counts / weighted) using the same classes ordering
    if confusion_possible:
        idx_map = {c:i for i,c in enumerate(classes)}
    else:
        classes = sorted(set(preds_series.unique()))
        idx_map = {c:i for i,c in enumerate(classes)}

    Pred_counts = np.zeros(len(classes), dtype=float)
    for row in counts_df.itertuples(index=False):
        specie = normalize_label(row.species)
        if specie in idx_map:
            Pred_counts[idx_map[specie]] += float(row.count)
        else:
            # should not usually happen; try to match by fuzzy heuristics? here we skip silently
            pass

    Pred_conf = None
    if weighted_df is not None:
        Pred_conf = np.zeros(len(classes), dtype=float)
        for row in weighted_df.itertuples(index=False):
            specie = normalize_label(row.species)
            if specie in idx_map:
                Pred_conf[idx_map[specie]] += float(row.conf_sum)

    # Save raw aggregates (always)
    out_dir = pred_path.parent
    out_counts = out_dir / "abundance_from_predictions.csv"
    out_weighted = out_dir / "abundance_from_predictions_weighted.csv"
    counts_df.to_csv(out_counts, index=False)
    print("[SAVED] raw count-based abundance ->", out_counts)
    if weighted_df is not None:
        weighted_df.to_csv(out_weighted, index=False)
        print("[SAVED] raw confidence-weighted abundance ->", out_weighted)

    # If confusion matrix available, run deconvolution
    if confusion_possible and A is not None:
        print("[DECONV] Running NNLS deconvolution (projected gradient)...")
        est_true_counts = nnls_pgd_solve(A, Pred_counts, max_iter=3000)
        est_true_counts = np.maximum(0.0, est_true_counts)
        if est_true_counts.sum() > 0:
            est_true_rel = est_true_counts / est_true_counts.sum()
        else:
            est_true_rel = est_true_counts.copy()

        if Pred_conf is not None and Pred_conf.sum() > 0:
            est_true_conf = nnls_pgd_solve(A, Pred_conf, max_iter=3000)
            est_true_conf = np.maximum(0.0, est_true_conf)
            if est_true_conf.sum() > 0:
                est_true_conf_rel = est_true_conf / est_true_conf.sum()
            else:
                est_true_conf_rel = est_true_conf.copy()
        else:
            est_true_conf = None
            est_true_conf_rel = None

        df_deconv = pd.DataFrame({
            "species": classes,
            "pred_count": Pred_counts,
            "pred_count_rel": Pred_counts / max(1.0, Pred_counts.sum()),
            "est_true_count": est_true_counts,
            "est_true_rel": est_true_rel
        })
        out_deconv = out_dir / "abundance_from_predictions_deconvolved.csv"
        df_deconv.to_csv(out_deconv, index=False)
        print("[SAVED] deconvolved counts ->", out_deconv)

        if est_true_conf is not None:
            df_deconv_conf = pd.DataFrame({
                "species": classes,
                "pred_conf_sum": Pred_conf,
                "pred_conf_rel": Pred_conf / max(1.0, Pred_conf.sum()),
                "est_true_conf": est_true_conf,
                "est_true_conf_rel": est_true_conf_rel
            })
            out_deconv_conf = out_dir / "abundance_from_predictions_deconvolved_weighted.csv"
            df_deconv_conf.to_csv(out_deconv_conf, index=False)
            print("[SAVED] deconvolved weighted ->", out_deconv_conf)

        # print small summary
        print("\nTop 15 predicted (raw counts):")
        print(counts_df.head(15).to_string(index=False))
        print("\nTop 15 estimated true (after deconvolution):")
        print(df_deconv.sort_values("est_true_rel", ascending=False).head(15).to_string(index=False))
    else:
        print("[INFO] Confusion correction not performed (no usable validation file). Raw aggregates saved above.")

    print("\n[DONE] Outputs saved to:", out_dir)

except SystemExit:
    pass
except Exception:
    print("[ERROR] unexpected failure:")
    traceback.print_exc()


[FOUND] predictions -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv
[FOUND] validation preds -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\val_predictions_calibrated.csv
[USING] species column: 'species_pred_label' , confidence column: 'species_pred_conf'
[CONFUSION] built confusion matrix from validation with 72 classes.
[SAVED] raw count-based abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions.csv
[SAVED] raw confidence-weighted abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions_weighted.csv
[DECONV] Running NNLS deconvolution (projected gradient)...
[SAVED] deconvolved counts -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions_deconvolved.csv
[SAVED] deconvolved weighted -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions_deconvolved_weighted.csv

Top 15 predicted (raw counts):
                        species  count  rel_abund_count
         

In [85]:
# Fixed hierarchical deconvolution + reconciliation cell (bugfix applied)
# Paste & run in your notebook (searches for 'extracted' like previous cells).
import pandas as pd
import numpy as np
from pathlib import Path
import traceback

def find_file(names):
    cand_dirs = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path.cwd(),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path.home() / "OneDrive" / "Desktop" / "sihtaxa" / "sihabundance" / "ncbi_blast_db" / "extracted",
    ]
    # include any extracted under cwd
    for p in Path.cwd().glob("**/extracted"):
        cand_dirs.append(p.resolve())
    visited = set()
    for d in cand_dirs:
        if d in visited: continue
        visited.add(d)
        for n in names:
            f = d / n
            if f.exists():
                return f
    # fallback recursive search
    for n in names:
        for f in Path.cwd().rglob(n):
            return f.resolve()
    return None

def normalize_label(x):
    if pd.isna(x):
        return "UNASSIGNED"
    s = str(x).strip()
    s = " ".join(s.split())
    return s if s != "" else "UNASSIGNED"

def nnls_pgd(A, Pred, max_iter=3000, tol=1e-7):
    M = A.T
    try:
        x0, *_ = np.linalg.lstsq(M, Pred, rcond=None)
        x = np.maximum(0.0, x0)
    except Exception:
        x = np.maximum(0.0, np.ones(M.shape[1]) * (Pred.sum()/max(1,M.shape[1])))
    try:
        L = np.linalg.norm(M, ord=2)**2
        if L <= 0: L = 1.0
        lr = 1.0/(L + 1e-8)
    except Exception:
        lr = 1e-3
    prev = None
    for i in range(max_iter):
        r = M.dot(x) - Pred
        loss = 0.5 * (r @ r)
        if prev is not None and abs(prev - loss) < tol:
            break
        prev = loss
        grad = M.T.dot(r)
        x -= lr * grad
        x = np.maximum(0.0, x)
    return x

def build_confusion_from_val(df_val, true_col_candidates, pred_col_candidates, rank_name):
    cols = {c.lower(): c for c in df_val.columns}
    true_col = None; pred_col = None
    for cand in true_col_candidates:
        if cand in cols:
            true_col = cols[cand]; break
    for cand in pred_col_candidates:
        if cand in cols:
            pred_col = cols[cand]; break
    # heuristics fallback: any columns that mention rank_name
    if true_col is None or pred_col is None:
        species_like = [c for c in df_val.columns if rank_name in c.lower()]
        if len(species_like) >= 2:
            if true_col is None: true_col = species_like[0]
            if pred_col is None and len(species_like) > 1: pred_col = species_like[1]
    if true_col is None or pred_col is None:
        return None, None
    true_norm = df_val[true_col].apply(normalize_label)
    pred_norm = df_val[pred_col].apply(normalize_label)
    classes = sorted(set(true_norm.unique()) | set(pred_norm.unique()))
    idx = {c:i for i,c in enumerate(classes)}
    n = len(classes)
    A = np.zeros((n,n), dtype=float)
    for t,p in zip(true_norm.values, pred_norm.values):
        # ensure mapping present
        if t not in idx or p not in idx:
            continue
        A[idx[t], idx[p]] += 1.0
    row_sums = A.sum(axis=1, keepdims=True)
    row_sums[row_sums==0] = 1.0
    A = A / row_sums
    eps = 1e-8
    A = (A + eps)
    A = A / A.sum(axis=1, keepdims=True)
    return A, classes

# ---------- locate files ----------
pred_file = find_file(["predictions_with_uncertainty.csv", "predictions.csv"])
val_file  = find_file(["val_predictions_calibrated.csv", "val_predictions.csv", "val_predictions_with_uncertainty.csv"])
if pred_file is None:
    raise SystemExit("predictions CSV not found. Run inference or place predictions_with_uncertainty.csv in an extracted/ folder.")
print("[FOUND] predictions:", pred_file)
if val_file is not None:
    print("[FOUND] validation:", val_file)
else:
    print("[WARN] validation file not found. Rank confusion correction will be limited/skipped where absent.")

# ---------- load predictions and normalize rank columns ----------
df_pred = pd.read_csv(pred_file)
# standardize known rank-like columns by normalizing strings
for c in df_pred.columns:
    if any(r in c.lower() for r in ("kingdom","phylum","class","order","family","genus","species")):
        df_pred[c] = df_pred[c].apply(normalize_label)

# detect species prediction column robustly
cols_lower = {c.lower(): c for c in df_pred.columns}
species_candidates = ["species_pred_label","species_label","species_pred","species"]
species_col = None
for cand in species_candidates:
    if cand in cols_lower:
        species_col = cols_lower[cand]; break
if species_col is None:
    for c in df_pred.columns:
        if "species" in c.lower():
            species_col = c; break
if species_col is None:
    raise SystemExit("Could not detect species prediction column in predictions CSV (looked for species_pred_label/species_pred/species_label/etc).")

# create normalized species column for aggregation
df_pred["_species_norm"] = df_pred[species_col].astype(str).apply(normalize_label)

# aggregation function (returns consistent columns: taxon,count,rel)
def aggregate_pred_counts(df, col):
    if col not in df.columns:
        return pd.DataFrame({"taxon":[], "count":[], "rel":[]})
    s = df[col].astype(str).apply(normalize_label)
    dfc = s.value_counts().rename_axis("taxon").reset_index(name="count")
    dfc["rel"] = dfc["count"] / dfc["count"].sum() if dfc["count"].sum() > 0 else 0.0
    return dfc

# build raw_by_rank: map rank -> (predicted_column_name, aggregated_df with columns 'taxon','count','rel')
ranks = ["kingdom","phylum","class","order","family","genus","species"]
raw_by_rank = {}
for r in ranks:
    col_found = None
    # prefer explicit columns mentioning the rank
    for c in df_pred.columns:
        if r in c.lower():
            if r == "species":
                col_found = species_col
            else:
                col_found = c
            break
    if col_found is not None:
        agg_df = aggregate_pred_counts(df_pred, col_found)
        raw_by_rank[r] = (col_found, agg_df)

# also save top-level species raw counts (for output)
counts_df = aggregate_pred_counts(df_pred, "_species_norm")

# optional weighted sums if confidence present
conf_col = None
for cand in ("species_pred_conf","species_conf","species_prob","species_probability"):
    if cand in cols_lower:
        conf_col = cols_lower[cand]; break
weighted_df = None
if conf_col and conf_col in df_pred.columns:
    df_pred["_conf_num"] = pd.to_numeric(df_pred[conf_col], errors="coerce").fillna(0.0)
    w = df_pred.groupby("_species_norm")["_conf_num"].sum().reset_index().rename(columns={"_conf_num":"conf_sum", "_species_norm":"taxon"})
    if w["conf_sum"].sum() > 0:
        w["rel"] = w["conf_sum"] / w["conf_sum"].sum()
    weighted_df = w

# ---------- build deconvolution per rank using validation ----------
true_candidates = {
    "kingdom": ["kingdom_true","true_kingdom","kingdom_label","kingdom"],
    "phylum":  ["phylum_true","true_phylum","phylum_label","phylum"],
    "class":   ["class_true","true_class","class_label","class"],
    "order":   ["order_true","true_order","order_label","order"],
    "family":  ["family_true","true_family","family_label","family"],
    "genus":   ["genus_true","true_genus","genus_label","genus"],
    "species": ["species_true","true_species","species_label","label_species","species"]
}
pred_candidates = {
    "kingdom": ["kingdom_pred_label","kingdom_pred","kingdom_label","kingdom"],
    "phylum":  ["phylum_pred_label","phylum_pred","phylum_label","phylum"],
    "class":   ["class_pred_label","class_pred","class_label","class"],
    "order":   ["order_pred_label","order_pred","order_label","order"],
    "family":  ["family_pred_label","family_pred","family_label","family"],
    "genus":   ["genus_pred_label","genus_pred","genus_label","genus"],
    "species": ["species_pred_label","species_pred","species_label","species"]
}

deconv_by_rank = {}
if val_file is not None:
    df_val = pd.read_csv(val_file)
    # normalize val rank columns
    for c in df_val.columns:
        if any(r in c.lower() for r in ranks):
            df_val[c] = df_val[c].apply(normalize_label)

    for r, (pred_col, agg_df) in list(raw_by_rank.items()):
        # build confusion from validation for this rank
        A, classes = build_confusion_from_val(df_val, true_candidates[r], pred_candidates[r], r)
        if A is None:
            print(f"[INFO] validation lacks usable true/pred pair for rank '{r}' -> skipping rank deconv.")
            continue
        # prepare Pred vector following classes ordering
        Pred = np.zeros(len(classes), dtype=float)
        # agg_df has columns 'taxon','count','rel'
        for tup in agg_df.itertuples(index=False, name=None):
            tax = normalize_label(tup[0])   # first column is taxon
            cnt = float(tup[1]) if len(tup) > 1 else 0.0
            if tax in classes:
                Pred[classes.index(tax)] += cnt
        # run deconv (NNLS)
        est_true = nnls_pgd(A, Pred)
        est_true = np.maximum(0.0, est_true)
        est_rel = est_true / est_true.sum() if est_true.sum() > 0 else est_true
        df_out = pd.DataFrame({
            "taxon": classes,
            "pred_count": Pred,
            "pred_rel": Pred / max(1.0, Pred.sum()),
            "est_true_count": est_true,
            "est_true_rel": est_rel
        })
        deconv_by_rank[r] = (pred_col, df_out)
        outp = pred_file.parent / f"abundance_{r}_deconvolved.csv"
        df_out.to_csv(outp, index=False)
        print(f"[SAVED] deconvolved at rank '{r}' -> {outp}")

# Save raw aggregates (species-level counts and optionally weighted)
out_dir = pred_file.parent
out_counts = out_dir / "abundance_from_predictions.csv"
counts_df.rename(columns={"taxon":"species","count":"count","rel":"rel_abund_count"}).to_csv(out_counts, index=False)
print("[SAVED] raw count-based abundance ->", out_counts)
if weighted_df is not None:
    out_w = out_dir / "abundance_from_predictions_weighted.csv"
    weighted_df.rename(columns={"taxon":"species","conf_sum":"conf_sum","rel":"rel_abund_weighted"}).to_csv(out_w, index=False)
    print("[SAVED] raw confidence-weighted abundance ->", out_w)

# ---------- Reconciliation (species-level final estimates) ----------
# base species estimates: prefer species deconv, else weighted, else raw counts
if "species" in deconv_by_rank:
    base_df = deconv_by_rank["species"][1]
    species_classes = base_df["taxon"].tolist()
    base_rel = base_df["est_true_rel"].astype(float).values
    print("[INFO] using species-level deconvolved estimates as base.")
elif weighted_df is not None:
    species_classes = weighted_df["taxon"].tolist()
    svals = weighted_df["conf_sum"].astype(float).values
    base_rel = svals / svals.sum() if svals.sum() > 0 else np.ones(len(svals)) / max(1,len(svals))
    print("[INFO] species-level deconv missing: using confidence-weighted as base.")
else:
    species_classes = counts_df["taxon"].tolist()
    cvals = counts_df["count"].astype(float).values
    base_rel = cvals / cvals.sum() if cvals.sum() > 0 else np.ones(len(cvals)) / max(1,len(cvals))
    print("[INFO] species-level deconv & weighted missing: using raw counts as base.")

species_df = pd.DataFrame({"species": species_classes, "base_rel": base_rel})
species_df["species"] = species_df["species"].apply(normalize_label)

# try to map species -> higher ranks using predictions (majority)
rank_cols = {}
for r in ("genus","family","order","class","phylum","kingdom"):
    for c in df_pred.columns:
        if r in c.lower():
            rank_cols[r] = c
            break

# build mapping by the most frequent value per species from df_pred
mapping = {}
if len(rank_cols)>0:
    temp = df_pred[[species_col] + list(rank_cols.values())].copy()
    temp["_sp"] = temp[species_col].astype(str).apply(normalize_label)
    for sp, g in temp.groupby("_sp"):
        mapping[sp] = {}
        for rc in rank_cols.values():
            vals = g[rc].astype(str).apply(normalize_label).value_counts()
            mapping[sp][rc] = vals.index[0] if len(vals)>0 else "UNASSIGNED"

# attach rank columns to species_df
for rc_name, rc_col in rank_cols.items():
    species_df[rc_name] = species_df["species"].apply(lambda s: mapping.get(s, {}).get(rc_col, "UNASSIGNED"))

# produce final species distribution and then apply rank-based redistribution if deconvolved rank exists
final_rel = pd.Series(species_df["base_rel"].values, index=species_df["species"]).astype(float)

def distribute_rank_to_species(rank, df_rank):
    # df_rank: taxon, est_true_rel
    rank_to_rel = dict(zip(df_rank["taxon"].apply(normalize_label), df_rank["est_true_rel"].astype(float)))
    # build groups
    groups = {}
    for sp in final_rel.index:
        if rank in rank_cols:
            mapped = species_df.loc[species_df["species"]==sp, rank].values
            rname = normalize_label(mapped[0]) if len(mapped)>0 else "UNASSIGNED"
        else:
            rname = "UNASSIGNED"
        groups.setdefault(rname, []).append(sp)
    new_rel = final_rel.copy()
    for rname, spp in groups.items():
        target = rank_to_rel.get(rname, None)
        if target is None:
            continue
        cur_sum = float(final_rel.loc[spp].sum())
        if cur_sum <= 0:
            if len(spp) > 0:
                each = target / len(spp)
                for sp in spp:
                    new_rel.loc[sp] = each
        else:
            for sp in spp:
                new_rel.loc[sp] = final_rel.loc[sp] * (target / cur_sum)
    if new_rel.sum() > 0:
        new_rel = new_rel / new_rel.sum()
    return new_rel

rank_apply_order = ["genus","family","order","class","phylum","kingdom"]
applied = False
for r in rank_apply_order:
    if r in deconv_by_rank:
        _, df_r = deconv_by_rank[r]
        before = final_rel.sum()
        final_rel = distribute_rank_to_species(r, df_r)
        applied = True
        print(f"[RECONCILE] applied rank '{r}' -> sum before={before:.6f} after={final_rel.sum():.6f}")

if not applied:
    print("[INFO] no higher-rank deconvolved corrections applied; final = base species estimates.")

# final output DataFrame
out_df = final_rel.reset_index().rename(columns={"index":"species", 0:"est_rel"})
out_df["est_pct"] = out_df["est_rel"] * 100.0
# attach pred_count and pred_conf_sum if available (maps normalized names)
counts_map = dict(zip(counts_df["taxon"].apply(normalize_label), counts_df["count"]))
out_df["pred_count"] = out_df["species"].apply(lambda s: counts_map.get(normalize_label(s), 0))
if weighted_df is not None:
    weighted_map = dict(zip(weighted_df["taxon"].apply(normalize_label), weighted_df["conf_sum"]))
    out_df["pred_conf_sum"] = out_df["species"].apply(lambda s: weighted_map.get(normalize_label(s), 0.0))
else:
    out_df["pred_conf_sum"] = np.nan

out_path = pred_file.parent / "abundance_reconciled_species.csv"
out_df.to_csv(out_path, index=False)
print("[SAVED] reconciled species abundance ->", out_path)
print("\nTop 20 final reconciled species:")
print(out_df.sort_values("est_rel", ascending=False).head(20).to_string(index=False))

# per-sample output if sample id found
sample_col = None
for sc in ["sample_id","sample","samp","run","readset","accession","id","seqid"]:
    if sc in cols_lower:
        sample_col = cols_lower[sc]; break
if sample_col is None:
    for c in df_pred.columns:
        if c.lower().endswith("_id") or c.lower().endswith("id"):
            sample_col = c; break

if sample_col and sample_col in df_pred.columns:
    df_pred["_sample_norm"] = df_pred[sample_col].astype(str)
    ps = df_pred.groupby(["_sample_norm","_species_norm"]).size().reset_index(name="count")
    sample_table = ps.pivot(index="_sample_norm", columns="_species_norm", values="count").fillna(0).astype(int)
    sample_out = pred_file.parent / "abundance_per_sample_species_counts.csv"
    sample_table.to_csv(sample_out)
    print("[SAVED] per-sample species counts ->", sample_out)
else:
    print("[INFO] no sample identifier detected; skipped per-sample table.")

print("\nDONE. All outputs are in:", pred_file.parent)


[FOUND] predictions: C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv
[FOUND] validation: C:\Users\Srijit\sih\ncbi_blast_db\extracted\val_predictions_calibrated.csv
[SAVED] deconvolved at rank 'kingdom' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_kingdom_deconvolved.csv
[SAVED] deconvolved at rank 'phylum' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_phylum_deconvolved.csv
[SAVED] deconvolved at rank 'class' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_class_deconvolved.csv
[SAVED] deconvolved at rank 'order' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_order_deconvolved.csv
[SAVED] deconvolved at rank 'family' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_family_deconvolved.csv
[SAVED] deconvolved at rank 'genus' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_genus_deconvolved.csv
[SAVED] deconvolved at rank 'species' -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_species_deconv

In [87]:
# Cell: finalize abundance outputs (publication CSV, top-20 barplot, per-sample normalized table, zip)
# Paste & run in the same Jupyter kernel you used for previous cells.
import pandas as pd, numpy as np, zipfile, textwrap, sys
from pathlib import Path
import matplotlib.pyplot as plt

def find_extracted_candidates():
    cand = []
    cand.append(Path.cwd() / "sih" / "ncbi_blast_db" / "extracted")
    cand.append(Path.cwd() / "ncbi_blast_db" / "extracted")
    cand.append(Path.cwd() / "extracted")
    cand.append(Path.cwd())
    cand.append(Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"))
    cand.append(Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"))
    cand.append(Path.home() / "OneDrive" / "Desktop" / "sihtaxa" / "sihabundance" / "ncbi_blast_db" / "extracted")
    # also add any extracted under cwd
    for p in Path.cwd().glob("**/extracted"):
        cand.append(p.resolve())
    # keep unique, existing
    seen = []
    out = []
    for p in cand:
        if p in seen: continue
        seen.append(p)
        if p.exists() and p.is_dir():
            out.append(p)
    return out

def pick_file(base_names):
    """Return first existing file path from base_names searching candidate extracted folders."""
    for d in find_extracted_candidates():
        for name in base_names:
            p = d / name
            if p.exists():
                return p
    # last resort: search recursively from cwd
    for name in base_names:
        for p in Path.cwd().rglob(name):
            return p.resolve()
    return None

# locate final reconciled abundance (prefer reconciled, else deconvolved, else raw)
final_path = pick_file(["abundance_reconciled_species.csv", "abundance_from_predictions_deconvolved.csv",
                        "abundance_from_predictions.csv", "abundance_from_predictions_weighted.csv"])
if final_path is None:
    print("ERROR: Could not find any abundance CSV (search looked for reconciled/deconvolved/raw).")
    print("Place 'abundance_reconciled_species.csv' in an 'extracted/' folder or run previous cells to create it.")
    raise SystemExit(0)

out_dir = final_path.parent
print("[FOUND] using:", final_path)

# load final table (handle different column names gracefully)
df = pd.read_csv(final_path)
# normalize column names to expected set (species, est_rel, est_pct, pred_count, pred_conf_sum)
cols_lower = {c.lower(): c for c in df.columns}
# Heuristics for species column
if "species" in cols_lower:
    sp_col = cols_lower["species"]
else:
    # try first column if it's text-like
    sp_col = df.columns[0]
# find est_rel or est_true_rel or percent
if "est_rel" in cols_lower:
    rel_col = cols_lower["est_rel"]
elif "est_true_rel" in cols_lower:
    rel_col = cols_lower["est_true_rel"]
elif "pred_count_rel" in cols_lower:
    rel_col = cols_lower["pred_count_rel"]
else:
    # fallback: try any numeric column after species
    other_numeric = [c for c in df.columns if c!=sp_col and pd.api.types.is_numeric_dtype(df[c])]
    rel_col = other_numeric[0] if other_numeric else None

# pred_count
pred_count_col = cols_lower.get("pred_count") or cols_lower.get("count") if "count" in cols_lower else None
pred_conf_col  = cols_lower.get("pred_conf_sum") or cols_lower.get("conf_sum") or cols_lower.get("pred_conf")

# build publication-ready dataframe
pub = pd.DataFrame()
pub["species"] = df[sp_col].astype(str).apply(lambda s: " ".join(str(s).split()))
if rel_col is not None:
    pub["est_rel"] = pd.to_numeric(df[rel_col], errors="coerce").fillna(0.0)
else:
    pub["est_rel"] = 0.0
# ensure fraction between 0..1; if >1 it may already be percent
if pub["est_rel"].max() > 1.001:
    # assume percents -> convert
    pub["est_rel"] = pub["est_rel"] / 100.0
pub["est_pct"] = pub["est_rel"] * 100.0
if pred_count_col:
    pub["pred_count"] = pd.to_numeric(df[pred_count_col], errors="coerce").fillna(0).astype(int)
else:
    pub["pred_count"] = pub["species"].map(df.set_index(sp_col).get(pred_count_col, pd.Series())).fillna(0).astype(int)
if pred_conf_col and pred_conf_col in df.columns:
    pub["pred_conf_sum"] = pd.to_numeric(df[pred_conf_col], errors="coerce").fillna(0.0)
else:
    pub["pred_conf_sum"] = np.nan

# preserve sorting by estimated abundance
pub = pub.sort_values("est_rel", ascending=False).reset_index(drop=True)

# Save publication-ready CSV
pub_out = out_dir / "abundance_publication_ready.csv"
pub.to_csv(pub_out, index=False)
print("[SAVED] publication-ready CSV ->", pub_out)

# Top-N barplot (top 20)
top_n = 20
top = pub.head(top_n)
plt.figure(figsize=(10,6))
plt.title(f"Top {min(top_n,len(top))} species by estimated relative abundance")
plt.barh(range(len(top)), top["est_rel"].values[::-1])   # horizontal bar; default colors
plt.yticks(range(len(top)), top["species"].values[::-1])
plt.xlabel("Relative abundance (fraction)")
plt.tight_layout()
img_out = out_dir / "abundance_top20.png"
plt.savefig(img_out, dpi=200)
plt.close()
print("[SAVED] top-20 barplot ->", img_out)

# Per-sample normalized abundance (if per-sample counts present)
per_sample_counts = pick_file(["abundance_per_sample_species_counts.csv", "abundance_per_sample_species_counts.csv"])
if per_sample_counts is not None and per_sample_counts.exists():
    try:
        samp_df = pd.read_csv(per_sample_counts, index_col=0)
        # Row-normalize to relative abundances per sample
        rel = samp_df.div(samp_df.sum(axis=1).replace(0,1), axis=0)
        rel_out = out_dir / "abundance_per_sample_species_relative.csv"
        rel.to_csv(rel_out, index=True)
        print("[SAVED] per-sample relative abundance ->", rel_out)
    except Exception as e:
        print("[WARN] failed to create per-sample relative abundance:", e)
else:
    print("[INFO] no per-sample count file found (skipping per-sample normalization).")

# Package main files into a zip
zip_path = out_dir / "abundance_results_package.zip"
to_package = [
    pub_out,
    img_out,
    out_dir / "abundance_from_predictions.csv",
    out_dir / "abundance_from_predictions_weighted.csv",
    out_dir / "abundance_from_predictions_deconvolved.csv",
    out_dir / "abundance_from_predictions_deconvolved_weighted.csv",
    out_dir / "abundance_reconciled_species.csv",
]
# keep only those that exist
to_package = [p for p in to_package if p is not None and p.exists()]
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as z:
    for p in to_package:
        z.write(p, arcname=p.name)
print("[SAVED] packaged outputs ->", zip_path)

# final summary
print("\nSUMMARY of outputs saved in:", out_dir)
print(" - publication CSV:", pub_out.name)
print(" - top-20 plot    :", img_out.name)
if per_sample_counts is not None and per_sample_counts.exists():
    print(" - per-sample relative table:", rel_out.name)
print(" - package ZIP    :", zip_path.name)
print("\nDONE. If you want I can now:")
print("  • produce a PDF/PNG figure layout and a short report (1 more cell),")
print("  • create a small helper cell that loads a new FASTA, runs the classifier and returns abundances (1 more cell),")
print("  • or just stop here (no more cells needed).")


[FOUND] using: C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_reconciled_species.csv
[SAVED] publication-ready CSV -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_publication_ready.csv
[SAVED] top-20 barplot -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_top20.png
[SAVED] per-sample relative abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_per_sample_species_relative.csv
[SAVED] packaged outputs -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_results_package.zip

SUMMARY of outputs saved in: C:\Users\Srijit\sih\ncbi_blast_db\extracted
 - publication CSV: abundance_publication_ready.csv
 - top-20 plot    : abundance_top20.png
 - per-sample relative table: abundance_per_sample_species_relative.csv
 - package ZIP    : abundance_results_package.zip

DONE. If you want I can now:
  • produce a PDF/PNG figure layout and a short report (1 more cell),
  • create a small helper cell that loads a new FASTA, runs the classifier and returns abund

In [91]:
# Fixed diagnostic + assignment-rate cell
# Paste & run in your notebook (same kernel you used earlier).
import math
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from scipy.stats import spearmanr

plt.rcParams.update({'figure.max_open_warning': 0})

def find_extracted_candidates():
    cand = []
    cand.append(Path.cwd() / "sih" / "ncbi_blast_db" / "extracted")
    cand.append(Path.cwd() / "ncbi_blast_db" / "extracted")
    cand.append(Path.cwd() / "extracted")
    cand.append(Path.cwd())
    cand.append(Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"))
    cand.append(Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"))
    cand.append(Path.home() / "OneDrive" / "Desktop" / "sihtaxa" / "sihabundance" / "ncbi_blast_db" / "extracted")
    for p in Path.cwd().glob("**/extracted"):
        cand.append(p.resolve())
    uniq = []
    for p in cand:
        if p in uniq:
            continue
        uniq.append(p)
    return uniq

def pick_file(names):
    for d in find_extracted_candidates():
        for name in names:
            p = d / name
            if p.exists():
                return p
    for name in names:
        for p in Path.cwd().rglob(name):
            return p.resolve()
    return None

def normalize_label(x):
    if pd.isna(x):
        return "UNASSIGNED"
    s = str(x).strip()
    s = " ".join(s.split())
    return s if s!="" else "UNASSIGNED"

# locate files
pred_path = pick_file(["predictions_with_uncertainty.csv", "predictions.csv"])
val_path  = pick_file(["val_predictions_calibrated.csv", "val_predictions.csv", "val_predictions_with_uncertainty.csv"])
reconciled_path = pick_file(["abundance_reconciled_species.csv","abundance_from_predictions_deconvolved.csv"])

print("predictions:", pred_path)
print("validation :", val_path)
print("reconciled :", reconciled_path)
print("-"*60)

if pred_path is None:
    raise SystemExit("No predictions CSV found. Run the inference cell first.")
df_pred = pd.read_csv(pred_path)
print("Loaded predictions rows:", len(df_pred))
cols_lower = {c.lower():c for c in df_pred.columns}

# detect predicted species and confidence columns
species_pred_col = None
for cand in ("species_pred_label","species_label","species_pred","species_predicted","species"):
    if cand in cols_lower:
        species_pred_col = cols_lower[cand]; break
if species_pred_col is None:
    for c in df_pred.columns:
        if "species" in c.lower():
            species_pred_col = c; break

conf_col = None
for cand in ("species_pred_conf","species_conf","species_prob","species_probability","pred_conf"):
    if cand in cols_lower:
        conf_col = cols_lower[cand]; break

print("Detected species prediction column:", species_pred_col)
print("Detected confidence column:", conf_col)
df_pred["_species_norm"] = df_pred[species_pred_col].astype(str).apply(normalize_label) if species_pred_col else "<no-species>"

if val_path is None:
    print("\nNo validation file found -> skipping classification accuracy checks.")
else:
    df_val = pd.read_csv(val_path)
    # Normalize likely rank columns
    for c in df_val.columns:
        if any(r in c.lower() for r in ("species","genus","family","order","class","phylum","kingdom")):
            try:
                df_val[c] = df_val[c].apply(normalize_label)
            except Exception:
                pass

    # detect true & pred species columns in validation
    cols_val_lower = {c.lower(): c for c in df_val.columns}
    true_species_col = None
    for cand in ("species_true","true_species","species_label","species_gold","label_species","species_true_idx","species_true_label"):
        if cand in cols_val_lower:
            true_species_col = cols_val_lower[cand]; break

    pred_species_col_val = None
    for cand in ("species_pred_label","species_pred","species_prediction","species_label_pred","pred_species","species_pred_label"):
        if cand in cols_val_lower:
            pred_species_col_val = cols_val_lower[cand]; break

    # fallback if both not found
    if (true_species_col is None or pred_species_col_val is None):
        species_like = [c for c in df_val.columns if "species" in c.lower()]
        if len(species_like) >= 2:
            if true_species_col is None: true_species_col = species_like[0]
            if pred_species_col_val is None: pred_species_col_val = species_like[1]

    print("\nValidation: true =", true_species_col, ", pred =", pred_species_col_val)
    if true_species_col and pred_species_col_val:
        y_true = df_val[true_species_col].astype(str).apply(normalize_label)
        y_pred_val = df_val[pred_species_col_val].astype(str).apply(normalize_label)

        acc = accuracy_score(y_true, y_pred_val)
        macro_f1 = f1_score(y_true, y_pred_val, average="macro", zero_division=0)
        micro_f1 = f1_score(y_true, y_pred_val, average="micro", zero_division=0)
        print("\nSpecies-level metrics on validation:")
        print(f"  Accuracy = {acc:.4f}    Macro-F1 = {macro_f1:.4f}    Micro-F1 = {micro_f1:.4f}")

        # top-classes by support (build safe dataframe)
        top_counts = y_true.value_counts().reset_index().rename(columns={"index":"species","0":"count"}) 
        top_counts.columns = ["species","count"]
        # labels_by_freq as list (safe)
        labels_by_freq = top_counts["species"].tolist()
        print("\nTop 15 classes by true-support:")
        display(top_counts.head(15))

        # classification report for top classes (safe usage)
        try:
            report = classification_report(y_true, y_pred_val, labels=labels_by_freq[:100], zero_division=0)
            print("\nClassification report (top classes shown):\n")
            print(report)
        except Exception as e:
            # fallback: full report (may be large)
            print("Could not generate trimmed classification report (fallback to full report). Error:", e)
            print(classification_report(y_true, y_pred_val, zero_division=0))

        # confusion pairs (top mistakes)
        n_labels = min(len(labels_by_freq), 200)
        cm = confusion_matrix(y_true, y_pred_val, labels=labels_by_freq[:n_labels])
        cm_df = pd.DataFrame(cm, index=labels_by_freq[:n_labels], columns=labels_by_freq[:n_labels])
        stacked = cm_df.stack().reset_index()
        stacked.columns = ["true","pred","count"]
        mistakes = stacked[stacked["true"] != stacked["pred"]].sort_values("count", ascending=False)
        print("\nTop 12 confusion pairs (true -> predicted):")
        display(mistakes.head(12))

        # try calibration (ECE) if confidences can be matched
        matched = None
        id_cols_pred = [c for c in df_pred.columns if c.lower() in ("id","seqid","accession","read_id","readid")]
        id_cols_val = [c for c in df_val.columns if c.lower() in ("id","seqid","accession","read_id","readid")]
        if id_cols_pred and id_cols_val:
            p_id, v_id = id_cols_pred[0], id_cols_val[0]
            df_pred_small = df_pred[[p_id, "_species_norm"] + (["_conf_num"] if conf_col and "_conf_num" in df_pred.columns else [])].copy()
            df_val_small = df_val[[v_id, true_species_col]].copy()
            merged = pd.merge(df_val_small, df_pred_small, left_on=v_id, right_on=p_id, how="inner")
            if len(merged) > 0:
                matched = merged
        if matched is None and len(df_val)==len(df_pred):
            # fallback row-wise (risky but sometimes ok)
            merged = pd.DataFrame({
                "true": df_val[true_species_col].astype(str).apply(normalize_label).reset_index(drop=True),
                "pred": df_pred["_species_norm"].reset_index(drop=True),
            })
            if "_conf_num" in df_pred.columns:
                merged["conf"] = df_pred["_conf_num"].reset_index(drop=True)
            matched = merged

        if matched is not None and "conf" in matched.columns:
            matched["correct"] = (matched["pred"].astype(str).apply(normalize_label) == matched["true"].astype(str).apply(normalize_label)).astype(int)
            bins = np.linspace(0.0, 1.0, 11)
            matched["_bin"] = pd.cut(matched["conf"], bins, include_lowest=True)
            bin_summary = matched.groupby("_bin").agg(conf_mean=('conf','mean'), acc=('correct','mean'), n=('correct','size')).reset_index().fillna(0)
            N = len(matched)
            ece = ((bin_summary['n']/N) * (bin_summary['acc'] - bin_summary['conf_mean']).abs()).sum()
            print(f"\nCalibration (on matched rows): ECE = {ece:.4f}")
            display(bin_summary[["conf_mean","acc","n"]])
            # reliability plot
            plt.figure(figsize=(6,4))
            plt.plot(bin_summary['conf_mean'], bin_summary['acc'], marker='o')
            plt.plot([0,1],[0,1], linestyle='--')
            plt.xlabel('Mean predicted confidence'); plt.ylabel('Observed accuracy')
            plt.title('Reliability diagram (binned)')
            plt.tight_layout(); plt.show()
        else:
            print("\nCalibration check: could not match confidences to val rows or confidences not present.")

    else:
        print("\nValidation file did not contain recognisable true/pred species columns. Skipping species-level accuracy checks.")

# Section: compute assignment rates per rank (UNASSIGNED fraction)
print("\n" + "#"*8 + " Assignment rates per rank " + "#"*8)
ranks = ["kingdom","phylum","class","order","family","genus","species"]
rank_assignments = []
for r in ranks:
    # find a column in df_pred that looks like this rank
    candidate = None
    for c in df_pred.columns:
        if r in c.lower():
            # skip short columns like 'species_pred_idx' etc; use the label-like columns (contain 'label') where possible
            if 'label' in c.lower() or c.lower().endswith(r):
                candidate = c
                break
            if candidate is None:
                candidate = c
    if candidate is None:
        continue
    col_norm = df_pred[candidate].astype(str).apply(normalize_label)
    total = len(col_norm)
    unassigned = (col_norm == "UNASSIGNED").sum()
    assigned = total - unassigned
    rank_assignments.append({"rank": r, "pred_col": candidate, "total": total, "assigned": assigned, "unassigned": unassigned, "pct_unassigned": unassigned/total})
    # show top taxa for the rank
    top = col_norm.value_counts().head(8).reset_index().rename(columns={"index":"taxon",0:"count"})
    print(f"\nRank '{r}'  (column='{candidate}')  -> UNASSIGNED {unassigned}/{total} = {unassigned/total:.2%}")
    display(top)

rank_df = pd.DataFrame(rank_assignments)
print("\nSummary assignment rates (by rank):")
display(rank_df)

# Optional: check UNASSIGNED entries that have high species confidence (investigate novel or mismatch)
if conf_col and conf_col in df_pred.columns:
    df_pred["_conf_num"] = pd.to_numeric(df_pred[conf_col], errors='coerce').fillna(0.0)
    high_conf_unassigned = df_pred[(df_pred["_species_norm"]=="UNASSIGNED") & (df_pred["_conf_num"] >= 0.5)]
    print("\nNumber of UNASSIGNED reads with species_pred_conf >= 0.5:", len(high_conf_unassigned))
    if len(high_conf_unassigned) > 0:
        print("Examples (first 8):")
        display(high_conf_unassigned[[col for col in df_pred.columns if col in ("id","_species_norm",conf_col,"novel_score","species_novel_component","genus_novel_component")][:8]].head(8))
    # average confidence for assigned vs unassigned
    avg_assigned = df_pred[df_pred["_species_norm"]!="UNASSIGNED"]["_conf_num"].mean()
    avg_unassigned = df_pred[df_pred["_species_norm"]=="UNASSIGNED"]["_conf_num"].mean()
    print(f"\nAverage species_pred_conf: assigned = {avg_assigned:.3f}, unassigned = {avg_unassigned:.3f}")
else:
    print("\nConfidence column not present or not numeric; skipping high-confidence UNASSIGNED check.")

print("\n\nDONE. Summary:")
print(" - If UNASSIGNED fraction is high at species but low at higher ranks, that suggests the model can place reads to genus/family but not to species (database coverage or species-level ambiguity).")
print(" - If UNASSIGNED is high at all ranks, look at the input sequences (short sequences, contamination, novel clades) or the reference DB used for training.")


predictions: C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv
validation : C:\Users\Srijit\sih\ncbi_blast_db\extracted\val_predictions_calibrated.csv
reconciled : C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_reconciled_species.csv
------------------------------------------------------------
Loaded predictions rows: 380
Detected species prediction column: species_pred_label
Detected confidence column: species_pred_conf

Validation: true = species_true_idx , pred = species_pred_label

Species-level metrics on validation:
  Accuracy = 0.7737    Macro-F1 = 0.3942    Micro-F1 = 0.7737

Top 15 classes by true-support:


Unnamed: 0,species,count
0,UNASSIGNED,201
1,Maylandia zebra,80
2,Chaetodon auriga,22
3,Arvicanthis niloticus,6
4,Morchella sp.,5
5,Aonchotheca annulosa,4
6,Pseudopestalotiopsis sp.,3
7,Deuterostichococcus epilithicus,3
8,Aspergillus costaricensis,3
9,Chloroidium saccharophilum,3



Classification report (top classes shown):

                                 precision    recall  f1-score   support

                     UNASSIGNED       0.99      0.71      0.83       201
                Maylandia zebra       1.00      1.00      1.00        80
               Chaetodon auriga       0.48      1.00      0.65        22
          Arvicanthis niloticus       0.75      1.00      0.86         6
                  Morchella sp.       0.40      0.80      0.53         5
           Aonchotheca annulosa       0.00      0.00      0.00         4
       Pseudopestalotiopsis sp.       1.00      1.00      1.00         3
Deuterostichococcus epilithicus       1.00      1.00      1.00         3
      Aspergillus costaricensis       0.43      1.00      0.60         3
     Chloroidium saccharophilum       1.00      1.00      1.00         3
                Cortinarius sp.       1.00      1.00      1.00         2
                    Amanita sp.       0.00      0.00      0.00         2
     

Unnamed: 0,true,pred,count
2,UNASSIGNED,Chaetodon auriga,24
4,UNASSIGNED,Morchella sp.,5
23,UNASSIGNED,Hysterothylacium fabri,5
8,UNASSIGNED,Aspergillus costaricensis,3
47,UNASSIGNED,Trichoderma viride,2
3,UNASSIGNED,Arvicanthis niloticus,2
24,UNASSIGNED,Entoloma sp.,2
1532,Clavaria sp.,Entoloma sp.,1
908,Inocybe sp.,Entoloma sp.,1
2604,Morchella pulchella,Morchella sp.,1



Calibration check: could not match confidences to val rows or confidences not present.

######## Assignment rates per rank ########

Rank 'kingdom'  (column='kingdom_pred_label')  -> UNASSIGNED 216/380 = 56.84%


Unnamed: 0,kingdom_pred_label,count
0,UNASSIGNED,216
1,Eukaryota,164



Rank 'phylum'  (column='phylum_pred_label')  -> UNASSIGNED 163/380 = 42.89%


Unnamed: 0,phylum_pred_label,count
0,UNASSIGNED,163
1,Metazoa,140
2,Fungi,65
3,Viridiplantae,11
4,Sar,1



Rank 'class'  (column='class_pred_label')  -> UNASSIGNED 161/380 = 42.37%


Unnamed: 0,class_pred_label,count
0,UNASSIGNED,161
1,Chordata,131
2,Dikarya,64
3,Ecdysozoa,10
4,Chlorophyta,8
5,Streptophyta,3
6,Cnidaria,2
7,Stramenopiles,1



Rank 'order'  (column='order_pred_label')  -> UNASSIGNED 157/380 = 41.32%


Unnamed: 0,order_pred_label,count
0,UNASSIGNED,157
1,Craniata,133
2,Ascomycota,33
3,Basidiomycota,32
4,Nematoda,10
5,core chlorophytes,8
6,Myxozoa,2
7,Embryophyta,2



Rank 'family'  (column='family_pred_label')  -> UNASSIGNED 155/380 = 40.79%


Unnamed: 0,family_pred_label,count
0,UNASSIGNED,155
1,Vertebrata,133
2,Pezizomycotina,34
3,Agaricomycotina,32
4,Trebouxiophyceae,8
5,Enoplea,6
6,Chromadorea,4
7,Myxosporea,2



Rank 'genus'  (column='genus_pred_label')  -> UNASSIGNED 155/380 = 40.79%


Unnamed: 0,genus_pred_label,count
0,UNASSIGNED,155
1,Euteleostomi,133
2,Agaricomycetes,32
3,Pezizomycetes,15
4,Sordariomycetes,7
5,Eurotiomycetes,6
6,Dorylaimia,6
7,Prasiolales,5



Rank 'species'  (column='species_pred_label')  -> UNASSIGNED 144/380 = 37.89%


Unnamed: 0,species_pred_label,count
0,UNASSIGNED,144
1,Maylandia zebra,80
2,Chaetodon auriga,46
3,Morchella sp.,10
4,Arvicanthis niloticus,9
5,Aspergillus costaricensis,7
6,Callospermophilus lateralis,6
7,Hysterothylacium fabri,5



Summary assignment rates (by rank):


Unnamed: 0,rank,pred_col,total,assigned,unassigned,pct_unassigned
0,kingdom,kingdom_pred_label,380,164,216,0.568421
1,phylum,phylum_pred_label,380,217,163,0.428947
2,class,class_pred_label,380,219,161,0.423684
3,order,order_pred_label,380,223,157,0.413158
4,family,family_pred_label,380,225,155,0.407895
5,genus,genus_pred_label,380,225,155,0.407895
6,species,species_pred_label,380,236,144,0.378947



Number of UNASSIGNED reads with species_pred_conf >= 0.5: 86
Examples (first 8):


Unnamed: 0,id,species_pred_conf,species_novel_component,genus_novel_component,novel_score,_species_norm
5,JBPZNU010001377.1,0.750098,0.206548,0.015048,0.110798,UNASSIGNED
8,LC876572.1,0.863004,0.098259,0.000956,0.049607,UNASSIGNED
17,PX277059.1,0.795968,0.208243,0.01436,0.111301,UNASSIGNED
18,LC876609.1,0.77343,0.174087,0.003924,0.089006,UNASSIGNED
20,PX273803.1,0.676611,0.265072,0.068322,0.166697,UNASSIGNED
21,JF836109.1,0.878108,0.103703,0.002911,0.053307,UNASSIGNED
23,JBPZNU010001350.1,0.779167,0.167112,0.010263,0.088687,UNASSIGNED
24,JBPZNU010001331.1,0.64717,0.215164,0.015348,0.115256,UNASSIGNED



Average species_pred_conf: assigned = 0.576, unassigned = 0.621


DONE. Summary:
 - If UNASSIGNED fraction is high at species but low at higher ranks, that suggests the model can place reads to genus/family but not to species (database coverage or species-level ambiguity).
 - If UNASSIGNED is high at all ranks, look at the input sequences (short sequences, contamination, novel clades) or the reference DB used for training.


In [93]:
# CELL: Inspect & handle UNASSIGNED reads (safe by default)
# - Finds your predictions CSV, exports UNASSIGNED rows + high-confidence subset,
#   writes an ID list for BLAST/lookup.
# - Optionally (only if you set force_assign=True) conservatively force-assigns some UNASSIGNED -> predicted species.
#
# Edit the top parameters if you want to try forced reassignment; default is non-destructive (force_assign=False).

from pathlib import Path
import pandas as pd
import numpy as np
import sys
try:
    # ---------- User parameters ----------
    force_assign = False           # Set True to attempt conservative forced-assignment (not recommended unless you understand the risk)
    assign_conf_threshold = 0.60   # if forced assignment: require species_pred_conf >= this
    novel_comp_max = 0.20          # if present: require species_novel_component <= this to allow forced assignment
    high_conf_cutoff = 0.50        # threshold for "high-confidence UNASSIGNED" export
    # -------------------------------------

    # ---------- find predictions CSV (includes paths you reported) ----------
    candidate_paths = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted" / "predictions_with_uncertainty.csv",
        Path.cwd() / "ncbi_blast_db" / "extracted" / "predictions_with_uncertainty.csv",
        Path.cwd() / "extracted" / "predictions_with_uncertainty.csv",
        Path.cwd() / "predictions_with_uncertainty.csv",
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted\predictions_with_uncertainty.csv"),
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv"),
        Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted\predictions_with_uncertainty.csv"),
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted" / "predictions.csv",
        Path.cwd() / "ncbi_blast_db" / "extracted" / "predictions.csv",
    ]
    # add any file named predictions_with_uncertainty.csv under cwd recursively as fallback
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        candidate_paths.append(p)
    predictions_file = None
    for p in candidate_paths:
        if p.exists():
            predictions_file = p.resolve()
            break
    if predictions_file is None:
        # last resort: recursive search for common names
        for fname in ("predictions_with_uncertainty.csv","predictions.csv"):
            for p in Path.cwd().rglob(fname):
                predictions_file = p.resolve()
                break
            if predictions_file is not None:
                break

    if predictions_file is None:
        raise FileNotFoundError("Could not find predictions CSV. Put 'predictions_with_uncertainty.csv' into an extracted/ folder and re-run.")

    print("Using predictions file:", predictions_file)
    out_dir = predictions_file.parent

    # ---------- load ----------
    df = pd.read_csv(predictions_file)
    print("Loaded rows:", len(df))

    # detect useful columns
    cols_lower = {c.lower(): c for c in df.columns}
    # species prediction column
    species_col = None
    for cand in ("species_pred_label","species_label","species_pred","species"):
        if cand in cols_lower:
            species_col = cols_lower[cand]; break
    if species_col is None:
        for c in df.columns:
            if "species" in c.lower():
                species_col = c; break
    if species_col is None:
        raise RuntimeError("No species prediction column detected in predictions CSV.")

    # confidence column (optional)
    conf_col = None
    for cand in ("species_pred_conf","species_conf","species_prob","species_probability","pred_conf"):
        if cand in cols_lower:
            conf_col = cols_lower[cand]; break

    # id column for BLAST lookup (optional)
    id_col = None
    for cand in ("id","accession","seqid","read_id","readid","global_index"):
        if cand in cols_lower:
            id_col = cols_lower[cand]; break

    # normalize working columns
    df["_species_norm"] = df[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
    if conf_col:
        df["_conf_num"] = pd.to_numeric(df[conf_col], errors="coerce").fillna(0.0)
    else:
        df["_conf_num"] = np.nan

    # select UNASSIGNED rows
    is_unassigned = df["_species_norm"].str.upper() == "UNASSIGNED"
    unassigned_df = df[is_unassigned].copy().reset_index(drop=True)

    # write all UNASSIGNED rows
    out_unassigned = out_dir / "unassigned_rows.csv"
    unassigned_df.to_csv(out_unassigned, index=False)
    print(f"Wrote all UNASSIGNED rows -> {out_unassigned}  (count = {len(unassigned_df)})")

    # high-confidence subset (for BLAST/inspection)
    high_conf_unassigned = unassigned_df[unassigned_df["_conf_num"] >= high_conf_cutoff].copy()
    out_high_conf = out_dir / "high_conf_unassigned_rows.csv"
    high_conf_unassigned.to_csv(out_high_conf, index=False)
    print(f"Wrote high-confidence UNASSIGNED rows (conf >= {high_conf_cutoff}) -> {out_high_conf}  (count = {len(high_conf_unassigned)})")

    # write simple ID list for BLAST/lookup (we do NOT have sequence FASTA here)
    if len(high_conf_unassigned) > 0:
        if id_col:
            ids = high_conf_unassigned[id_col].astype(str).fillna("")
        elif "global_index" in df.columns:
            ids = high_conf_unassigned["global_index"].astype(str).fillna("")
        else:
            ids = high_conf_unassigned.index.astype(str)
        out_ids = out_dir / "high_conf_unassigned_ids.txt"
        ids.to_csv(out_ids, index=False, header=False)
        print("Wrote high-confidence UNASSIGNED ID list ->", out_ids)
    else:
        print("No high-confidence UNASSIGNED rows to export as ID list (threshold may be too high).")

    # If you have original FASTA and want FASTA created for BLAST, place the FASTA in the same extracted/ folder
    # with filename 'predictions_source_sequences.fasta' where header contains the same accession/id column value.
    # This cell does not attempt to reconstruct sequences automatically.

    # ---------- Optional conservative forced-assignment ----------
    if force_assign:
        print("\nFORCE-ASSIGN MODE: attempting conservative re-labeling of some UNASSIGNED reads.")
        # novel component column if present
        novel_col = None
        for cand in ("species_novel_component","novel_component","novel_score"):
            if cand in cols_lower:
                novel_col = cols_lower[cand]; break
        if novel_col:
            df["_novel_comp"] = pd.to_numeric(df.get(novel_col), errors="coerce").fillna(1.0)
        else:
            df["_novel_comp"] = 1.0
        candidates = df[is_unassigned & (df["_conf_num"] >= assign_conf_threshold) & (df["_novel_comp"] <= novel_comp_max)].copy()
        print("Forced-assignment candidates (conservative rule):", len(candidates))
        if len(candidates) > 0:
            # create forced copy and assign predicted species label into _species_norm
            df_forced = df.copy()
            df_forced.loc[candidates.index, "_species_norm"] = df_forced.loc[candidates.index, species_col].astype(str).apply(lambda s: " ".join(str(s).split()))
            species_counts = df_forced["_species_norm"].value_counts().reset_index().rename(columns={"index":"species", "_species_norm":"count"})
            out_forced = out_dir / "abundance_forced_assignments.csv"
            species_counts.to_csv(out_forced, index=False)
            print("Wrote species counts after forced-assignment ->", out_forced)
            display(species_counts.head(40))
        else:
            print("No candidates met forced-assignment criteria; nothing changed.")
    else:
        print("\nSAFE MODE: force_assign=False — no labels changed. To try forced relabeling, set force_assign=True and re-run the cell.")
        print("Conservative example settings: force_assign=True, assign_conf_threshold=0.60, novel_comp_max=0.20")

    print("\nDONE. Files written to:", out_dir)
    print(" - unassigned_rows.csv")
    print(" - high_conf_unassigned_rows.csv")
    print(" - high_conf_unassigned_ids.txt  (if high-confidence IDs exist)")
    print(" - abundance_forced_assignments.csv  (only if force_assign=True)")

except Exception as exc:
    print("ERROR during processing:", exc)
    import traceback
    traceback.print_exc()


Using predictions file: C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv
Loaded rows: 380
Wrote all UNASSIGNED rows -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\unassigned_rows.csv  (count = 144)
Wrote high-confidence UNASSIGNED rows (conf >= 0.5) -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\high_conf_unassigned_rows.csv  (count = 86)
Wrote high-confidence UNASSIGNED ID list -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\high_conf_unassigned_ids.txt

SAFE MODE: force_assign=False — no labels changed. To try forced relabeling, set force_assign=True and re-run the cell.
Conservative example settings: force_assign=True, assign_conf_threshold=0.60, novel_comp_max=0.20

DONE. Files written to: C:\Users\Srijit\sih\ncbi_blast_db\extracted
 - unassigned_rows.csv
 - high_conf_unassigned_rows.csv
 - high_conf_unassigned_ids.txt  (if high-confidence IDs exist)
 - abundance_forced_assignments.csv  (only if force_assign=True)


In [97]:
# FIXED cell: robust UNASSIGNED handling + FASTA matching + fixed rel computation
# Paste & run in the same notebook/kernel you used earlier.
from pathlib import Path
import pandas as pd
import numpy as np
import re, traceback

# ---------- User params (edit if you want) ----------
force_assign = False           # default: do NOT change labels
high_conf_cutoff = 0.50        # used for "high-confidence UNASSIGNED" export
assign_conf_threshold = 0.60   # only used if force_assign=True
novel_comp_max = 0.20          # only used if force_assign=True
# ---------------------------------------------------

def find_predictions():
    candidates = [
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted\predictions_with_uncertainty.csv"),
        Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted\predictions_with_uncertainty.csv"),
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted" / "predictions_with_uncertainty.csv",
        Path.cwd() / "ncbi_blast_db" / "extracted" / "predictions_with_uncertainty.csv",
        Path.cwd() / "extracted" / "predictions_with_uncertainty.csv",
        Path.cwd() / "predictions_with_uncertainty.csv",
    ]
    for p in candidates:
        if p.exists(): return p.resolve()
    # fallback recursive search (within cwd)
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        return p.resolve()
    for p in Path.cwd().rglob("predictions.csv"):
        return p.resolve()
    return None

def find_fasta(folder):
    # look for common names first, then any fasta-like file
    names = ["predictions_source_sequences.fasta","source_sequences.fasta","sequences.fasta",
             "input.fasta","reads.fasta","predictions_seqs.fasta","its_combined.fasta"]
    for n in names:
        p = folder / n
        if p.exists(): return p.resolve()
    for ext in ("*.fasta","*.fa","*.fna","*.ffn"):
        for p in folder.rglob(ext):
            return p.resolve()
    return None

def parse_fasta_simple(path):
    seqs = {}
    with open(path, "r", encoding="utf8", errors="replace") as fh:
        header = None; seq_lines=[]
        for line in fh:
            line = line.rstrip("\n\r")
            if not line: continue
            if line.startswith(">"):
                if header is not None:
                    key = header.split()[0]
                    seqs[key] = (header, "".join(seq_lines))
                header = line[1:].strip()
                seq_lines=[]
            else:
                seq_lines.append(line.strip())
        if header is not None:
            key = header.split()[0]
            seqs[key] = (header, "".join(seq_lines))
    return seqs

def canonicalize_acc(a):
    # canonical forms: original, without version (strip trailing .\d+), lower-case
    a = str(a).strip()
    forms = [a]
    # strip leading db qualifiers (e.g. "gi|...", "ref|", "gb|")
    if "|" in a:
        parts = a.split("|")
        # try each token
        for tok in parts:
            tok = tok.strip()
            if tok: forms.append(tok)
    # without version
    m = re.match(r"^(.+?)(?:\.\d+)$", a)
    if m:
        forms.append(m.group(1))
    return list(dict.fromkeys([f for f in forms if f]))  # unique preserving order

# ----------- begin main -----------
pred_file = find_predictions()
if pred_file is None:
    raise SystemExit("Could not find predictions CSV (searched common paths). Place predictions_with_uncertainty.csv into an extracted/ folder and re-run.")

out_dir = pred_file.parent
print("Using predictions:", pred_file)
print("Output directory:", out_dir)

df = pd.read_csv(pred_file)
print("Loaded rows:", len(df))

# detect important columns robustly
cols_lower = {c.lower(): c for c in df.columns}
species_col = next((cols_lower[k] for k in ("species_pred_label","species_label","species_pred","species") if k in cols_lower), None)
if not species_col:
    species_col = next((c for c in df.columns if "species" in c.lower()), None)
genus_col = next((cols_lower[k] for k in ("genus_pred_label","genus_label","genus_pred","genus") if k in cols_lower), None)
family_col = next((cols_lower[k] for k in ("family_pred_label","family_label","family_pred","family") if k in cols_lower), None)
conf_col = next((cols_lower[k] for k in ("species_pred_conf","species_conf","species_prob","species_probability","pred_conf") if k in cols_lower), None)
id_col = next((cols_lower[k] for k in ("id","accession","seqid","read_id","readid","global_index") if k in cols_lower), None)

if not species_col:
    raise SystemExit("Could not detect species prediction column in the predictions CSV.")

# normalize columns
df["_species_norm"] = df[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
df["_conf_num"] = pd.to_numeric(df[conf_col], errors="coerce").fillna(0.0) if conf_col else np.nan
df["_genus_norm"] = df[genus_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split())) if genus_col else "UNASSIGNED"
df["_family_norm"] = df[family_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split())) if family_col else "UNASSIGNED"

# export unassigned files (again)
is_unassigned = df["_species_norm"].str.upper() == "UNASSIGNED"
unassigned_df = df[is_unassigned].copy().reset_index(drop=True)
print("UNASSIGNED count:", len(unassigned_df), f"({len(unassigned_df)/len(df):.1%})")

out_unassigned = out_dir / "unassigned_rows.csv"
unassigned_df.to_csv(out_unassigned, index=False)
print("Saved all UNASSIGNED rows ->", out_unassigned)

hc_df = unassigned_df[unassigned_df["_conf_num"] >= high_conf_cutoff].copy()
out_hc = out_dir / "high_conf_unassigned_rows.csv"
hc_df.to_csv(out_hc, index=False)
print(f"Saved high-confidence UNASSIGNED rows (conf >= {high_conf_cutoff}) ->", out_hc)

# write id list for BLAST (if ids exist)
if len(hc_df) > 0:
    if id_col and id_col in hc_df.columns:
        ids = hc_df[id_col].astype(str).fillna("")
    elif "global_index" in hc_df.columns:
        ids = hc_df["global_index"].astype(str).fillna("")
    else:
        ids = hc_df.index.astype(str)
    out_ids = out_dir / "high_conf_unassigned_ids.txt"
    ids.to_csv(out_ids, index=False, header=False)
    print("Saved ID list for BLAST ->", out_ids)
else:
    print("No high-confidence UNASSIGNED rows to write ID list for (threshold {}).".format(high_conf_cutoff))

# ---------- attempt robust FASTA matching and export ----------
fasta_file = find_fasta(out_dir)
if fasta_file is None:
    print("\nNo FASTA found in extracted/ folder. If you have the original FASTA, place it in the same folder and re-run to extract sequences for BLAST.")
else:
    print("\nFound FASTA:", fasta_file)
    seqs = parse_fasta_simple(fasta_file)
    print("FASTA header count:", len(seqs))
    # build header lookup variants for fast matching
    header_map = {}   # key -> (header, seq)
    for hid, (hdr, seq) in seqs.items():
        # canonical forms for header id
        forms = canonicalize_acc(hid)
        # also include header token without db prefixes if it contains '|'
        if "|" in hid:
            for tok in hid.split("|"):
                tok = tok.strip()
                if tok:
                    forms.extend(canonicalize_acc(tok))
        # push all forms into lookup
        for f in set(forms):
            header_map[f] = (hid, hdr, seq)
        # also map the full header string as a fallback
        header_map[hdr] = (hid, hdr, seq)

    # prepare ID list we want to match (from hc_df)
    if len(hc_df) > 0:
        if id_col and id_col in hc_df.columns:
            id_list_orig = hc_df[id_col].astype(str).tolist()
        elif "global_index" in hc_df.columns:
            id_list_orig = hc_df["global_index"].astype(str).tolist()
        else:
            id_list_orig = hc_df.index.astype(str).tolist()

        matched = {}
        unmatched = []
        for q in id_list_orig:
            q_forms = canonicalize_acc(q)
            found = False
            for qf in q_forms:
                if qf in header_map:
                    hid, hdr, seq = header_map[qf]
                    matched[q] = (hid, hdr, seq)
                    found = True
                    break
            if found:
                continue
            # try substring search in header strings (fallback): look for q or q without version
            q_nov = re.sub(r'\.\d+$', '', q)
            for hid0, (hdr0, seq0) in seqs.items():
                if q in hid0 or q in hdr0 or q_nov in hid0 or q_nov in hdr0:
                    matched[q] = (hid0, hdr0, seq0)
                    found = True
                    break
            if not found:
                unmatched.append(q)

        print(f"Matched sequences: {len(matched)} / {len(id_list_orig)} requested (high-conf IDs).")
        if len(matched) > 0:
            out_fa = out_dir / "high_conf_unassigned_seqs.fasta"
            with open(out_fa, "w", encoding="utf8") as fh:
                for q,(hid, hdr, seq) in matched.items():
                    fh.write(f">{q} {hdr}\n")
                    for i in range(0, len(seq), 80):
                        fh.write(seq[i:i+80] + "\n")
            print("Wrote FASTA for matched high-confidence UNASSIGNED IDs ->", out_fa)
        if len(unmatched) > 0:
            print("Unmatched IDs (sample up to 20):", unmatched[:20])
            print("Common cause: FASTA headers use different accession formats/contain versions or database prefixes.")
    else:
        print("No high-confidence IDs to extract from FASTA.")

# ---------- genus fallback (fixed numeric dtype and rel calculation) ----------
try:
    df_genus_fb = df.copy()
    df_genus_fb["_species_genus_fallback"] = df_genus_fb["_species_norm"]
    mask_genus = (df_genus_fb["_species_genus_fallback"].str.upper()=="UNASSIGNED") & (df_genus_fb["_genus_norm"].str.upper()!="UNASSIGNED")
    df_genus_fb.loc[mask_genus, "_species_genus_fallback"] = df_genus_fb.loc[mask_genus, "_genus_norm"].apply(lambda g: f"GENUS::{g}")

    counts_genus_fb = df_genus_fb["_species_genus_fallback"].value_counts().reset_index()
    counts_genus_fb.columns = ["species_or_genus_fallback", "count"]
    # enforce numeric dtype
    counts_genus_fb["count"] = pd.to_numeric(counts_genus_fb["count"], errors="coerce").fillna(0).astype(int)
    total = counts_genus_fb["count"].sum()
    counts_genus_fb["rel"] = counts_genus_fb["count"] / total if total>0 else 0.0
    out_genus = out_dir / "abundance_species_genus_fallback.csv"
    counts_genus_fb.to_csv(out_genus, index=False)
    print("\nSaved genus-fallback abundance ->", out_genus)
    print(counts_genus_fb.head(12).to_string(index=False))
except Exception as e:
    print("Error creating genus-fallback table:", e)
    traceback.print_exc()

# ---------- family fallback (fixed) ----------
try:
    df_family_fb = df.copy()
    df_family_fb["_species_family_fallback"] = df_family_fb["_species_norm"]
    mask_family = (df_family_fb["_species_family_fallback"].str.upper()=="UNASSIGNED") & (df_family_fb["_family_norm"].str.upper()!="UNASSIGNED")
    df_family_fb.loc[mask_family, "_species_family_fallback"] = df_family_fb.loc[mask_family, "_family_norm"].apply(lambda f: f"FAMILY::{f}")

    counts_family_fb = df_family_fb["_species_family_fallback"].value_counts().reset_index()
    counts_family_fb.columns = ["species_or_family_fallback", "count"]
    counts_family_fb["count"] = pd.to_numeric(counts_family_fb["count"], errors="coerce").fillna(0).astype(int)
    totalf = counts_family_fb["count"].sum()
    counts_family_fb["rel"] = counts_family_fb["count"] / totalf if totalf>0 else 0.0
    out_family = out_dir / "abundance_species_family_fallback.csv"
    counts_family_fb.to_csv(out_family, index=False)
    print("\nSaved family-fallback abundance ->", out_family)
    print(counts_family_fb.head(12).to_string(index=False))
except Exception as e:
    print("Error creating family-fallback table:", e)
    traceback.print_exc()

# ---------- optional proportional redistribution using reconciled proportions ----------
recon_candidates = ["abundance_reconciled_species.csv","abundance_from_predictions_deconvolved.csv","abundance_from_predictions.csv"]
recon_file = None
for fn in recon_candidates:
    p = out_dir / fn
    if p.exists():
        recon_file = p; break

if recon_file is None:
    print("\nNo reconciled abundance CSV found; skipping proportional redistribution step.")
else:
    try:
        recon = pd.read_csv(recon_file)
        cols_low = {c.lower(): c for c in recon.columns}
        sp_col_recon = cols_low.get("species", list(recon.columns)[0])
        # pick a likely relative column
        rel_col_candidates = ["est_rel","est_true_rel","rel","relative_abundance","relative_abundance_weighted","pred_count_rel","relative_abundance_count"]
        rel_col = None
        for cand in rel_col_candidates:
            if cand in cols_low:
                rel_col = cols_low[cand]; break
        if rel_col is None:
            numeric_cols = [c for c in recon.columns if pd.api.types.is_numeric_dtype(recon[c])]
            rel_col = numeric_cols[0] if numeric_cols else None
        if rel_col is None:
            print("Could not identify a rel-abundance column in reconciled file; skipping redistribution.")
        else:
            recon2 = recon[[sp_col_recon, rel_col]].rename(columns={sp_col_recon:"species", rel_col:"est_rel"})
            recon2["species"] = recon2["species"].astype(str).apply(lambda s: " ".join(str(s).split()))
            recon_nonun = recon2[recon2["species"].str.upper()!="UNASSIGNED"].copy()
            if recon_nonun["est_rel"].sum() <= 0:
                print("Reconciled file has zero non-UNASSIGNED mass; cannot redistribute.")
            else:
                recon_nonun["renorm"] = recon_nonun["est_rel"] / recon_nonun["est_rel"].sum()
                total_unassigned = len(unassigned_df)
                recon_nonun["alloc_float"] = recon_nonun["renorm"] * total_unassigned
                recon_nonun["alloc_int"] = recon_nonun["alloc_float"].round().astype(int)
                diff = int(total_unassigned - recon_nonun["alloc_int"].sum())
                if diff != 0:
                    rem = recon_nonun["alloc_float"] - recon_nonun["alloc_int"]
                    if diff > 0:
                        idxs = rem.sort_values(ascending=False).index.tolist()
                        for i in range(diff):
                            recon_nonun.at[idxs[i % len(idxs)], "alloc_int"] += 1
                    else:
                        idxs = rem.sort_values(ascending=True).index.tolist()
                        for i in range(-diff):
                            recon_nonun.at[idxs[i % len(idxs)], "alloc_int"] -= 1
                raw_counts = df["_species_norm"].value_counts().to_dict()
                rows=[]
                for _,r in recon_nonun.iterrows():
                    sp = r["species"]
                    orig = int(raw_counts.get(sp,0))
                    add = int(r["alloc_int"])
                    rows.append({"species":sp,"orig_count":orig,"added_from_unassigned":add,"new_count":orig+add})
                new_df = pd.DataFrame(rows)
                new_df["new_rel"] = new_df["new_count"] / new_df["new_count"].sum() if new_df["new_count"].sum()>0 else 0.0
                out_redis = out_dir / "abundance_unassigned_redistributed_by_reconciled.csv"
                new_df.to_csv(out_redis, index=False)
                print("\nSaved redistributed-by-reconciled ->", out_redis)
                print(new_df.sort_values("new_count", ascending=False).head(12).to_string(index=False))
    except Exception as e:
        print("Error while redistributing:", e)
        traceback.print_exc()

print("\nDone. Created/updated files are in:", out_dir)
print("- unassigned_rows.csv")
print("- high_conf_unassigned_rows.csv")
print("- high_conf_unassigned_ids.txt (if high-conf IDs exist)")
print("- high_conf_unassigned_seqs.fasta (if source FASTA matched headers)")
print("- abundance_species_genus_fallback.csv")
print("- abundance_species_family_fallback.csv")
print("- abundance_unassigned_redistributed_by_reconciled.csv (if reconciled file present)")
print("- abundance_forced_assignments.csv (only if you enable force_assign=True)")


Using predictions: C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_with_uncertainty.csv
Output directory: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loaded rows: 380
UNASSIGNED count: 144 (37.9%)
Saved all UNASSIGNED rows -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\unassigned_rows.csv
Saved high-confidence UNASSIGNED rows (conf >= 0.5) -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\high_conf_unassigned_rows.csv
Saved ID list for BLAST -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\high_conf_unassigned_ids.txt

Found FASTA: C:\Users\Srijit\sih\ncbi_blast_db\extracted\its_combined.fasta
FASTA header count: 699
Matched sequences: 0 / 86 requested (high-conf IDs).
Unmatched IDs (sample up to 20): ['JBPZNU010001377.1', 'LC876572.1', 'PX277059.1', 'LC876609.1', 'PX273803.1', 'JF836109.1', 'JBPZNU010001350.1', 'JBPZNU010001331.1', 'JF836111.1', 'PX278865.1', 'PX277079.1', 'PX278850.1', 'PX278853.1', 'LC876580.1', 'LC876528.1', 'LC876609.1', 'LC876593.1', 'JBPZNU010001303.1', 'LC87

In [99]:
# Cell: inspect FASTA headers vs high-conf UNASSIGNED IDs, attempt robust matching,
# and (optionally) perform a conservative forced-assignment of matched UNASSIGNED reads.
#
# SAFE DEFAULTS: force_assign=False. Set force_assign=True only AFTER you inspect the mapping CSV
# and understand that forcing will reassign some UNASSIGNED reads to predicted species.
#
from pathlib import Path
import re
import pandas as pd
import numpy as np

# ---------- USER FLAGS (edit only if you want to force assignments) ----------
force_assign = False             # default False -> NO label changes
assign_conf_threshold = 0.60     # if forcing: require species_pred_conf >= this
novel_comp_max = 0.20            # if forcing: require species_novel_component <= this (if column present)
high_conf_cutoff = 0.50          # high-confidence threshold used previously
# ---------------------------------------------------------------------------

# ---------- find extracted folder (common locations) ----------
candidates = [
    Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
    Path.cwd() / "ncbi_blast_db" / "extracted",
    Path.cwd() / "extracted",
    Path.cwd(),
    Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
    Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted"),
]
extracted = None
for p in candidates:
    if p and p.exists() and p.is_dir():
        # verify predictions present
        if (p / "predictions_with_uncertainty.csv").exists() or (p / "predictions.csv").exists() or (p / "predictions_with_uncertainty.csv").exists():
            extracted = p.resolve()
            break

# fallback: try to find the id file recursively under cwd
if extracted is None:
    for f in Path.cwd().rglob("high_conf_unassigned_ids.txt"):
        extracted = f.parent.resolve(); break

if extracted is None:
    raise SystemExit("Could not find the 'extracted' folder containing prediction outputs. Run the earlier cells in this notebook or set the path here.")

print("Using extracted folder:", extracted)

# ---------- required files ----------
preds_candidates = [extracted / "predictions_with_uncertainty.csv", extracted / "predictions.csv"]
pred_file = next((p for p in preds_candidates if p.exists()), None)
if pred_file is None:
    raise SystemExit("predictions CSV not found in extracted/. Run the classification cells first.")

ids_file = extracted / "high_conf_unassigned_ids.txt"
rows_file = extracted / "high_conf_unassigned_rows.csv"
# if the simple id file doesn't exist, try to extract ids from the rows CSV
if not ids_file.exists():
    if rows_file.exists():
        df_rows = pd.read_csv(rows_file)
        id_col = next((c for c in df_rows.columns if c.lower() in ("id","accession","seqid","read_id","readid","accession_id")), df_rows.columns[0])
        ids = df_rows[id_col].astype(str).tolist()
        print(f"Extracted {len(ids)} IDs from {rows_file.name} using column '{id_col}'")
    else:
        raise SystemExit("high_conf_unassigned_ids.txt and high_conf_unassigned_rows.csv not found. Run the previous 'UNASSIGNED handling' cell first.")
else:
    with open(ids_file, "r", encoding="utf8", errors="replace") as fh:
        ids = [line.strip() for line in fh if line.strip()]
    print(f"Read {len(ids)} high-confidence UNASSIGNED IDs from {ids_file.name}")

# ---------- search for FASTA ----------
fasta_names = ["its_combined.fasta","predictions_source_sequences.fasta","source_sequences.fasta","sequences.fasta","input.fasta","reads.fasta"]
fasta_path = None
for n in fasta_names:
    p = extracted / n
    if p.exists():
        fasta_path = p; break
if fasta_path is None:
    # pick any fasta-like file in folder
    for ext in ("*.fasta","*.fa","*.fna","*.ffn"):
        found = list(extracted.rglob(ext))
        if found:
            fasta_path = found[0]; break

if fasta_path is None:
    print("No FASTA found in extracted/. Place the FASTA with your source sequences in that folder (common name: its_combined.fasta) and re-run this cell.")
else:
    print("Using FASTA:", fasta_path)

# ---------- parse FASTA headers (simple parser) ----------
def parse_fasta_headers(path):
    seqs = {}
    with open(path, "r", encoding="utf8", errors="replace") as fh:
        header = None; seq_lines=[]
        for line in fh:
            line = line.rstrip("\n\r")
            if not line: continue
            if line.startswith(">"):
                if header is not None:
                    seqs[header.split()[0]] = (header, "".join(seq_lines))
                header = line[1:].strip()
                seq_lines = []
            else:
                seq_lines.append(line.strip())
        if header is not None:
            seqs[header.split()[0]] = (header, "".join(seq_lines))
    return seqs

seqs = parse_fasta_headers(fasta_path)
print("Parsed FASTA headers:", len(seqs))

# Preview first 80 IDs & first 80 FASTA header keys
print("\n--- Preview high-conf IDs (first 80) ---")
for i, x in enumerate(ids[:80], 1):
    print(f"{i:3d}. {x}")
print("\n--- Preview FASTA header keys (first 80) ---")
header_keys = list(seqs.keys())
for i, h in enumerate(header_keys[:80], 1):
    print(f"{i:3d}. {h}")

# ---------- canonicalization and matching heuristics ----------
def canonical_forms(x):
    x = str(x).strip()
    forms = [x]
    # strip trailing version like .1 .2
    nov = re.sub(r'\.\d+$', '', x)
    if nov != x: forms.append(nov)
    # split on DB pipe tokens
    if "|" in x:
        toks = [t.strip() for t in re.split(r'[|/]', x) if t.strip()]
        forms.extend(toks)
    # drop common prefixes like 'ref|', 'gb|' etc.
    forms2 = []
    for f in list(forms):
        f2 = re.sub(r'^(?:ref|gi|gb|emb|dbj|accession)[\|:]', '', f, flags=re.IGNORECASE)
        forms2.append(f2)
    # lower-case variants
    out = []
    for s in forms + forms2:
        if s and s not in out:
            out.append(s)
        s_low = s.lower() if s else s
        if s_low and s_low not in out:
            out.append(s_low)
    return out

# build header lookup
header_map = {}
for hid, (hdr, seq) in seqs.items():
    for f in canonical_forms(hid):
        header_map[f] = (hid, hdr, seq)
    header_map[hdr] = (hid, hdr, seq)
    header_map[hdr.lower()] = (hid, hdr, seq)

matched = {}
unmatched = []
for q in ids:
    found = False
    for qf in canonical_forms(q):
        if qf in header_map:
            matched[q] = header_map[qf]; found = True; break
    if found: continue
    # fallback: substring match in header strings (strip version first)
    q_nov = re.sub(r'\.\d+$', '', q)
    for hid, (hdr, seq) in seqs.items():
        if q in hid or q in hdr or q_nov in hid or q_nov in hdr:
            matched[q] = (hid, hdr, seq); found = True; break
    if not found:
        unmatched.append(q)

print(f"\nMatched {len(matched)} / {len(ids)} high-conf IDs by heuristics. Unmatched: {len(unmatched)}")

# ---------- save mapping & matched FASTA ----------
mapped_rows = []
for q, (hid, hdr, seq) in matched.items():
    mapped_rows.append({"requested_id": q, "matched_header_id": hid, "matched_header": hdr, "seq_len": len(seq)})
mapping_df = pd.DataFrame(mapped_rows)
mapping_out = extracted / "high_conf_unassigned_matched_mapping.csv"
mapping_df.to_csv(mapping_out, index=False)
print("Wrote mapping CSV ->", mapping_out)

if len(matched) > 0:
    out_fa = extracted / "high_conf_unassigned_seqs_matched.fasta"
    with open(out_fa, "w", encoding="utf8") as fh:
        for q, (hid, hdr, seq) in matched.items():
            fh.write(f">{q} matched_header={hid} {hdr}\n")
            for i in range(0, len(seq), 80):
                fh.write(seq[i:i+80] + "\n")
    print("Wrote matched FASTA ->", out_fa)
else:
    print("No matched FASTA written (no matches).")

unmatched_out = extracted / "high_conf_unassigned_unmatched_ids.txt"
with open(unmatched_out, "w", encoding="utf8") as fh:
    for q in unmatched:
        fh.write(q + "\n")
print("Wrote unmatched ID list ->", unmatched_out)

# ---------- (OPTIONAL) conservative forced assignment of matched IDs back to species_pred_label ----------
if force_assign:
    print("\nFORCE_ASSIGN=True -> attempting conservative relabeling of matched UNASSIGNED reads.")
    # load predictions CSV
    preds = pd.read_csv(pred_file)
    # detect columns
    cols_lower = {c.lower(): c for c in preds.columns}
    species_col = next((cols_lower[k] for k in ("species_pred_label","species_label","species_pred","species") if k in cols_lower), None)
    conf_col = next((cols_lower[k] for k in ("species_pred_conf","species_conf","species_prob","pred_conf") if k in cols_lower), None)
    novel_col = next((cols_lower[k] for k in ("species_novel_component","novel_component","novel_score") if k in cols_lower), None)
    id_col = next((cols_lower[k] for k in ("id","accession","seqid","read_id","readid","global_index") if k in cols_lower), None)
    if species_col is None:
        raise SystemExit("Could not detect species prediction column in predictions CSV.")

    preds["_species_norm"] = preds[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
    if conf_col:
        preds["_conf_num"] = pd.to_numeric(preds[conf_col], errors="coerce").fillna(0.0)
    else:
        preds["_conf_num"] = 0.0
    if novel_col:
        preds["_novel_comp"] = pd.to_numeric(preds[novel_col], errors="coerce").fillna(1.0)
    else:
        preds["_novel_comp"] = 1.0

    # build lookup of matched header -> requested id(s)
    matched_header_to_q = {}
    for q,(hid,hdr,seq) in matched.items():
        matched_header_to_q.setdefault(hid, []).append(q)

    # now find rows in preds whose id matches one of the requested ids (we must infer id_col)
    if id_col:
        ids_in_preds = preds[id_col].astype(str).tolist()
        # build mapping from requested id to row indices
        q_to_rows = {}
        for q in matched.keys():
            # find rows where id == q (exact) or id contains q token
            mask_eq = preds[id_col].astype(str) == str(q)
            mask_cont = preds[id_col].astype(str).str.contains(str(q), na=False)
            rows_idx = preds.index[mask_eq | mask_cont].tolist()
            if rows_idx:
                q_to_rows[q] = rows_idx
    else:
        # try matching by global_index or fallback to nothing
        if "global_index" in preds.columns:
            q_to_rows = {}
            for q in matched.keys():
                rows_idx = preds.index[preds["global_index"].astype(str) == str(q)].tolist()
                if rows_idx: q_to_rows[q] = rows_idx
        else:
            raise SystemExit("Could not find an 'id' column in predictions CSV to align IDs to rows. For safety we will not force-assign.")

    # assemble list of candidate (row index) to new_species (predicted species)
    to_assign = []
    for q, rows_idx in q_to_rows.items():
        # new species label = the predicted species in preds row (species_col) or the matched header if needed
        for ri in rows_idx:
            curr = preds.loc[ri, "_species_norm"]
            # only consider if currently UNASSIGNED
            if str(curr).upper() == "UNASSIGNED":
                conf_val = float(preds.loc[ri, "_conf_num"])
                novel_val = float(preds.loc[ri, "_novel_comp"])
                # conservative criteria
                if conf_val >= assign_conf_threshold and novel_val <= novel_comp_max:
                    new_label = preds.loc[ri, species_col]
                    to_assign.append((ri, new_label, conf_val, novel_val, q))

    print(f"Candidates meeting conservative criteria to assign: {len(to_assign)}")
    if len(to_assign) == 0:
        print("No safe candidates met the thresholds (no changes made).")
    else:
        # perform assignment on copy and save species counts
        preds_forced = preds.copy()
        for ri, new_label, conf_val, novel_val, q in to_assign:
            preds_forced.at[ri, "_species_norm"] = str(new_label)
        # compute species counts and normalized abundances
        counts = preds_forced["_species_norm"].value_counts().reset_index()
        counts.columns = ["species","count"]
        counts["rel"] = counts["count"] / counts["count"].sum()
        out_forced = extracted / "abundance_forced_assignments.csv"
        counts.to_csv(out_forced, index=False)
        print("Saved forced-assignment species counts ->", out_forced)

        # also save a simple reconciled-style normalized file
        out_recon = extracted / "abundance_forced_assignments_reconciled.csv"
        counts.to_csv(out_recon, index=False)
        print("Saved reconciled-style normalized abundance ->", out_recon)

# ---------- final message ----------
print("\nDone. Files written to:", extracted)
print(" - high_conf_unassigned_matched_mapping.csv")
print(" - high_conf_unassigned_seqs_matched.fasta   (if any matches)")
print(" - high_conf_unassigned_unmatched_ids.txt")
print(" - (if force_assign=True) abundance_forced_assignments.csv and abundance_forced_assignments_reconciled.csv")

print("\nNext recommended actions (pick one):")
print("  1) Inspect 'high_conf_unassigned_matched_mapping.csv' to verify matches (VERY IMPORTANT).")
print("  2) Run BLAST locally on 'high_conf_unassigned_seqs_matched.fasta' or on unmatched IDs' sequences to confirm species identity.")
print("  3) If BLAST confirms match and you trust it, re-run with force_assign=True to absorb matched high-confidence UNASSIGNED into species counts.")
print("\nIf you want, I can now (A) provide the local BLAST command to run against NCBI/your DB, (B) flip force_assign=True and re-run the assignment heuristics here, or (C) produce a small report comparing before/after abundances. Tell me which and I'll give the exact cell/commands.")


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Read 86 high-confidence UNASSIGNED IDs from high_conf_unassigned_ids.txt
Using FASTA: C:\Users\Srijit\sih\ncbi_blast_db\extracted\its_combined.fasta
Parsed FASTA headers: 699

--- Preview high-conf IDs (first 80) ---
  1. JBPZNU010001377.1
  2. LC876572.1
  3. PX277059.1
  4. LC876609.1
  5. PX273803.1
  6. JF836109.1
  7. JBPZNU010001350.1
  8. JBPZNU010001331.1
  9. JF836111.1
 10. PX278865.1
 11. PX277079.1
 12. PX278850.1
 13. PX278853.1
 14. LC876580.1
 15. LC876528.1
 16. LC876609.1
 17. LC876593.1
 18. JBPZNU010001303.1
 19. LC876579.1
 20. LC876562.1
 21. LC876572.1
 22. LC876528.1
 23. LC876567.1
 24. LC876579.1
 25. JF836094.1
 26. JBPZNU010001397.1
 27. LC876593.1
 28. PX278864.1
 29. JBPZNU010001375.1
 30. JF836094.1
 31. LC876521.1
 32. OQ241938.1
 33. LC876568.1
 34. XR_013090605.1
 35. LC876580.1
 36. JF836095.1
 37. JBPZNU010001366.1
 38. LC876553.1
 39. JBPZNU010001353.1
 40. JBPZNU010001377.1
 41. LC8

In [101]:
# Cell: inspect FASTA headers vs high-conf UNASSIGNED IDs, attempt robust matching,
# and (optionally) perform a conservative forced-assignment of matched UNASSIGNED reads.
#
# SAFE DEFAULTS: force_assign=False. Set force_assign=True only AFTER you inspect the mapping CSV
# and understand that forcing will reassign some UNASSIGNED reads to predicted species.
#
from pathlib import Path
import re
import pandas as pd
import numpy as np

# ---------- USER FLAGS (edit only if you want to force assignments) ----------
force_assign = False             # default False -> NO label changes
assign_conf_threshold = 0.60     # if forcing: require species_pred_conf >= this
novel_comp_max = 0.20            # if forcing: require species_novel_component <= this (if column present)
high_conf_cutoff = 0.50          # high-confidence threshold used previously
# ---------------------------------------------------------------------------

# ---------- find extracted folder (common locations) ----------
candidates = [
    Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
    Path.cwd() / "ncbi_blast_db" / "extracted",
    Path.cwd() / "extracted",
    Path.cwd(),
    Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
    Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted"),
]
extracted = None
for p in candidates:
    if p and p.exists() and p.is_dir():
        # verify predictions present
        if (p / "predictions_with_uncertainty.csv").exists() or (p / "predictions.csv").exists() or (p / "predictions_with_uncertainty.csv").exists():
            extracted = p.resolve()
            break

# fallback: try to find the id file recursively under cwd
if extracted is None:
    for f in Path.cwd().rglob("high_conf_unassigned_ids.txt"):
        extracted = f.parent.resolve(); break

if extracted is None:
    raise SystemExit("Could not find the 'extracted' folder containing prediction outputs. Run the earlier cells in this notebook or set the path here.")

print("Using extracted folder:", extracted)

# ---------- required files ----------
preds_candidates = [extracted / "predictions_with_uncertainty.csv", extracted / "predictions.csv"]
pred_file = next((p for p in preds_candidates if p.exists()), None)
if pred_file is None:
    raise SystemExit("predictions CSV not found in extracted/. Run the classification cells first.")

ids_file = extracted / "high_conf_unassigned_ids.txt"
rows_file = extracted / "high_conf_unassigned_rows.csv"
# if the simple id file doesn't exist, try to extract ids from the rows CSV
if not ids_file.exists():
    if rows_file.exists():
        df_rows = pd.read_csv(rows_file)
        id_col = next((c for c in df_rows.columns if c.lower() in ("id","accession","seqid","read_id","readid","accession_id")), df_rows.columns[0])
        ids = df_rows[id_col].astype(str).tolist()
        print(f"Extracted {len(ids)} IDs from {rows_file.name} using column '{id_col}'")
    else:
        raise SystemExit("high_conf_unassigned_ids.txt and high_conf_unassigned_rows.csv not found. Run the previous 'UNASSIGNED handling' cell first.")
else:
    with open(ids_file, "r", encoding="utf8", errors="replace") as fh:
        ids = [line.strip() for line in fh if line.strip()]
    print(f"Read {len(ids)} high-confidence UNASSIGNED IDs from {ids_file.name}")

# ---------- search for FASTA ----------
fasta_names = ["its_combined.fasta","predictions_source_sequences.fasta","source_sequences.fasta","sequences.fasta","input.fasta","reads.fasta"]
fasta_path = None
for n in fasta_names:
    p = extracted / n
    if p.exists():
        fasta_path = p; break
if fasta_path is None:
    # pick any fasta-like file in folder
    for ext in ("*.fasta","*.fa","*.fna","*.ffn"):
        found = list(extracted.rglob(ext))
        if found:
            fasta_path = found[0]; break

if fasta_path is None:
    print("No FASTA found in extracted/. Place the FASTA with your source sequences in that folder (common name: its_combined.fasta) and re-run this cell.")
else:
    print("Using FASTA:", fasta_path)

# ---------- parse FASTA headers (simple parser) ----------
def parse_fasta_headers(path):
    seqs = {}
    with open(path, "r", encoding="utf8", errors="replace") as fh:
        header = None; seq_lines=[]
        for line in fh:
            line = line.rstrip("\n\r")
            if not line: continue
            if line.startswith(">"):
                if header is not None:
                    seqs[header.split()[0]] = (header, "".join(seq_lines))
                header = line[1:].strip()
                seq_lines = []
            else:
                seq_lines.append(line.strip())
        if header is not None:
            seqs[header.split()[0]] = (header, "".join(seq_lines))
    return seqs

seqs = parse_fasta_headers(fasta_path)
print("Parsed FASTA headers:", len(seqs))

# Preview first 80 IDs & first 80 FASTA header keys
print("\n--- Preview high-conf IDs (first 80) ---")
for i, x in enumerate(ids[:80], 1):
    print(f"{i:3d}. {x}")
print("\n--- Preview FASTA header keys (first 80) ---")
header_keys = list(seqs.keys())
for i, h in enumerate(header_keys[:80], 1):
    print(f"{i:3d}. {h}")

# ---------- canonicalization and matching heuristics ----------
def canonical_forms(x):
    x = str(x).strip()
    forms = [x]
    # strip trailing version like .1 .2
    nov = re.sub(r'\.\d+$', '', x)
    if nov != x: forms.append(nov)
    # split on DB pipe tokens
    if "|" in x:
        toks = [t.strip() for t in re.split(r'[|/]', x) if t.strip()]
        forms.extend(toks)
    # drop common prefixes like 'ref|', 'gb|' etc.
    forms2 = []
    for f in list(forms):
        f2 = re.sub(r'^(?:ref|gi|gb|emb|dbj|accession)[\|:]', '', f, flags=re.IGNORECASE)
        forms2.append(f2)
    # lower-case variants
    out = []
    for s in forms + forms2:
        if s and s not in out:
            out.append(s)
        s_low = s.lower() if s else s
        if s_low and s_low not in out:
            out.append(s_low)
    return out

# build header lookup
header_map = {}
for hid, (hdr, seq) in seqs.items():
    for f in canonical_forms(hid):
        header_map[f] = (hid, hdr, seq)
    header_map[hdr] = (hid, hdr, seq)
    header_map[hdr.lower()] = (hid, hdr, seq)

matched = {}
unmatched = []
for q in ids:
    found = False
    for qf in canonical_forms(q):
        if qf in header_map:
            matched[q] = header_map[qf]; found = True; break
    if found: continue
    # fallback: substring match in header strings (strip version first)
    q_nov = re.sub(r'\.\d+$', '', q)
    for hid, (hdr, seq) in seqs.items():
        if q in hid or q in hdr or q_nov in hid or q_nov in hdr:
            matched[q] = (hid, hdr, seq); found = True; break
    if not found:
        unmatched.append(q)

print(f"\nMatched {len(matched)} / {len(ids)} high-conf IDs by heuristics. Unmatched: {len(unmatched)}")

# ---------- save mapping & matched FASTA ----------
mapped_rows = []
for q, (hid, hdr, seq) in matched.items():
    mapped_rows.append({"requested_id": q, "matched_header_id": hid, "matched_header": hdr, "seq_len": len(seq)})
mapping_df = pd.DataFrame(mapped_rows)
mapping_out = extracted / "high_conf_unassigned_matched_mapping.csv"
mapping_df.to_csv(mapping_out, index=False)
print("Wrote mapping CSV ->", mapping_out)

if len(matched) > 0:
    out_fa = extracted / "high_conf_unassigned_seqs_matched.fasta"
    with open(out_fa, "w", encoding="utf8") as fh:
        for q, (hid, hdr, seq) in matched.items():
            fh.write(f">{q} matched_header={hid} {hdr}\n")
            for i in range(0, len(seq), 80):
                fh.write(seq[i:i+80] + "\n")
    print("Wrote matched FASTA ->", out_fa)
else:
    print("No matched FASTA written (no matches).")

unmatched_out = extracted / "high_conf_unassigned_unmatched_ids.txt"
with open(unmatched_out, "w", encoding="utf8") as fh:
    for q in unmatched:
        fh.write(q + "\n")
print("Wrote unmatched ID list ->", unmatched_out)

# ---------- (OPTIONAL) conservative forced assignment of matched IDs back to species_pred_label ----------
if force_assign:
    print("\nFORCE_ASSIGN=True -> attempting conservative relabeling of matched UNASSIGNED reads.")
    # load predictions CSV
    preds = pd.read_csv(pred_file)
    # detect columns
    cols_lower = {c.lower(): c for c in preds.columns}
    species_col = next((cols_lower[k] for k in ("species_pred_label","species_label","species_pred","species") if k in cols_lower), None)
    conf_col = next((cols_lower[k] for k in ("species_pred_conf","species_conf","species_prob","pred_conf") if k in cols_lower), None)
    novel_col = next((cols_lower[k] for k in ("species_novel_component","novel_component","novel_score") if k in cols_lower), None)
    id_col = next((cols_lower[k] for k in ("id","accession","seqid","read_id","readid","global_index") if k in cols_lower), None)
    if species_col is None:
        raise SystemExit("Could not detect species prediction column in predictions CSV.")

    preds["_species_norm"] = preds[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
    if conf_col:
        preds["_conf_num"] = pd.to_numeric(preds[conf_col], errors="coerce").fillna(0.0)
    else:
        preds["_conf_num"] = 0.0
    if novel_col:
        preds["_novel_comp"] = pd.to_numeric(preds[novel_col], errors="coerce").fillna(1.0)
    else:
        preds["_novel_comp"] = 1.0

    # build lookup of matched header -> requested id(s)
    matched_header_to_q = {}
    for q,(hid,hdr,seq) in matched.items():
        matched_header_to_q.setdefault(hid, []).append(q)

    # now find rows in preds whose id matches one of the requested ids (we must infer id_col)
    if id_col:
        ids_in_preds = preds[id_col].astype(str).tolist()
        # build mapping from requested id to row indices
        q_to_rows = {}
        for q in matched.keys():
            # find rows where id == q (exact) or id contains q token
            mask_eq = preds[id_col].astype(str) == str(q)
            mask_cont = preds[id_col].astype(str).str.contains(str(q), na=False)
            rows_idx = preds.index[mask_eq | mask_cont].tolist()
            if rows_idx:
                q_to_rows[q] = rows_idx
    else:
        # try matching by global_index or fallback to nothing
        if "global_index" in preds.columns:
            q_to_rows = {}
            for q in matched.keys():
                rows_idx = preds.index[preds["global_index"].astype(str) == str(q)].tolist()
                if rows_idx: q_to_rows[q] = rows_idx
        else:
            raise SystemExit("Could not find an 'id' column in predictions CSV to align IDs to rows. For safety we will not force-assign.")

    # assemble list of candidate (row index) to new_species (predicted species)
    to_assign = []
    for q, rows_idx in q_to_rows.items():
        # new species label = the predicted species in preds row (species_col) or the matched header if needed
        for ri in rows_idx:
            curr = preds.loc[ri, "_species_norm"]
            # only consider if currently UNASSIGNED
            if str(curr).upper() == "UNASSIGNED":
                conf_val = float(preds.loc[ri, "_conf_num"])
                novel_val = float(preds.loc[ri, "_novel_comp"])
                # conservative criteria
                if conf_val >= assign_conf_threshold and novel_val <= novel_comp_max:
                    new_label = preds.loc[ri, species_col]
                    to_assign.append((ri, new_label, conf_val, novel_val, q))

    print(f"Candidates meeting conservative criteria to assign: {len(to_assign)}")
    if len(to_assign) == 0:
        print("No safe candidates met the thresholds (no changes made).")
    else:
        # perform assignment on copy and save species counts
        preds_forced = preds.copy()
        for ri, new_label, conf_val, novel_val, q in to_assign:
            preds_forced.at[ri, "_species_norm"] = str(new_label)
        # compute species counts and normalized abundances
        counts = preds_forced["_species_norm"].value_counts().reset_index()
        counts.columns = ["species","count"]
        counts["rel"] = counts["count"] / counts["count"].sum()
        out_forced = extracted / "abundance_forced_assignments.csv"
        counts.to_csv(out_forced, index=False)
        print("Saved forced-assignment species counts ->", out_forced)

        # also save a simple reconciled-style normalized file
        out_recon = extracted / "abundance_forced_assignments_reconciled.csv"
        counts.to_csv(out_recon, index=False)
        print("Saved reconciled-style normalized abundance ->", out_recon)

# ---------- final message ----------
print("\nDone. Files written to:", extracted)
print(" - high_conf_unassigned_matched_mapping.csv")
print(" - high_conf_unassigned_seqs_matched.fasta   (if any matches)")
print(" - high_conf_unassigned_unmatched_ids.txt")
print(" - (if force_assign=True) abundance_forced_assignments.csv and abundance_forced_assignments_reconciled.csv")

print("\nNext recommended actions (pick one):")
print("  1) Inspect 'high_conf_unassigned_matched_mapping.csv' to verify matches (VERY IMPORTANT).")
print("  2) Run BLAST locally on 'high_conf_unassigned_seqs_matched.fasta' or on unmatched IDs' sequences to confirm species identity.")
print("  3) If BLAST confirms match and you trust it, re-run with force_assign=True to absorb matched high-confidence UNASSIGNED into species counts.")
print("\nIf you want, I can now (A) provide the local BLAST command to run against NCBI/your DB, (B) flip force_assign=True and re-run the assignment heuristics here, or (C) produce a small report comparing before/after abundances. Tell me which and I'll give the exact cell/commands.")


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Read 86 high-confidence UNASSIGNED IDs from high_conf_unassigned_ids.txt
Using FASTA: C:\Users\Srijit\sih\ncbi_blast_db\extracted\its_combined.fasta
Parsed FASTA headers: 699

--- Preview high-conf IDs (first 80) ---
  1. JBPZNU010001377.1
  2. LC876572.1
  3. PX277059.1
  4. LC876609.1
  5. PX273803.1
  6. JF836109.1
  7. JBPZNU010001350.1
  8. JBPZNU010001331.1
  9. JF836111.1
 10. PX278865.1
 11. PX277079.1
 12. PX278850.1
 13. PX278853.1
 14. LC876580.1
 15. LC876528.1
 16. LC876609.1
 17. LC876593.1
 18. JBPZNU010001303.1
 19. LC876579.1
 20. LC876562.1
 21. LC876572.1
 22. LC876528.1
 23. LC876567.1
 24. LC876579.1
 25. JF836094.1
 26. JBPZNU010001397.1
 27. LC876593.1
 28. PX278864.1
 29. JBPZNU010001375.1
 30. JF836094.1
 31. LC876521.1
 32. OQ241938.1
 33. LC876568.1
 34. XR_013090605.1
 35. LC876580.1
 36. JF836095.1
 37. JBPZNU010001366.1
 38. LC876553.1
 39. JBPZNU010001353.1
 40. JBPZNU010001377.1
 41. LC8

In [103]:
# CELL 1: Export unmatched / high-conf UNASSIGNED sequences for BLAST
# Safe by default: just writes FASTA(s) and mapping files. Do not change labels here.
from pathlib import Path
import re, csv

# --- user-tweakable (usually no need) ---
EXTRACTED_CANDIDATES = [
    Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
    Path.cwd() / "ncbi_blast_db" / "extracted",
    Path.cwd() / "extracted",
    Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
    Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
]
# ----------------------------------------------------------------

def find_extracted():
    for p in EXTRACTED_CANDIDATES:
        if p.exists() and p.is_dir():
            # need predictions file present to be confident it's the right folder
            if (p/"predictions_with_uncertainty.csv").exists() or (p/"predictions.csv").exists():
                return p.resolve()
    # fallback: find any folder with high_conf_unassigned_ids.txt
    for f in Path.cwd().rglob("high_conf_unassigned_ids.txt"):
        return f.parent.resolve()
    # fallback: any folder named extracted under cwd
    for f in Path.cwd().rglob("**/extracted"):
        if f.is_dir(): return f.resolve()
    raise SystemExit("Could not locate extracted/ folder. Place previous outputs in an extracted/ folder and re-run.")

extracted = find_extracted()
print("Using extracted folder:", extracted)

# read high-conf ids
ids_path = extracted / "high_conf_unassigned_ids.txt"
rows_path = extracted / "high_conf_unassigned_rows.csv"
ids = []
if ids_path.exists():
    with open(ids_path, "r", encoding="utf8", errors="replace") as fh:
        ids = [line.strip() for line in fh if line.strip()]
    print(f"Read {len(ids)} IDs from {ids_path.name}")
elif rows_path.exists():
    import pandas as pd
    df = pd.read_csv(rows_path)
    # choose a likely id column
    id_col = next((c for c in df.columns if c.lower() in ("id","accession","seqid","read_id","readid")), df.columns[0])
    ids = df[id_col].astype(str).tolist()
    print(f"Extracted {len(ids)} IDs from {rows_path.name} using column '{id_col}'")
else:
    raise SystemExit("No high_conf_unassigned_ids.txt or high_conf_unassigned_rows.csv found in extracted/. Re-run earlier cell.")

# find a FASTA
fasta = None
for name in ("its_combined.fasta","predictions_source_sequences.fasta","source_sequences.fasta","sequences.fasta","reads.fasta"):
    p = extracted / name
    if p.exists(): fasta = p; break
if fasta is None:
    # pick any fasta in folder
    fastas = list(extracted.rglob("*.fasta")) + list(extracted.rglob("*.fa"))
    if fastas:
        fasta = fastas[0]
if fasta is None:
    raise SystemExit("No FASTA found in extracted/. Place your source FASTA there (common name: its_combined.fasta) and re-run.")

print("Found FASTA:", fasta)

# simple FASTA parser
def parse_fasta(path):
    seqs = {}
    with open(path, "r", encoding="utf8", errors="replace") as fh:
        header=None; seq_lines=[]
        for line in fh:
            line = line.rstrip("\n\r")
            if not line: continue
            if line.startswith(">"):
                if header is not None:
                    seqs[header.split()[0]] = (header, "".join(seq_lines))
                header = line[1:].strip()
                seq_lines=[]
            else:
                seq_lines.append(line.strip())
        if header is not None:
            seqs[header.split()[0]] = (header, "".join(seq_lines))
    return seqs

seqs = parse_fasta(fasta)
print("Parsed FASTA headers:", len(seqs))

def canonical_forms(x):
    x = str(x).strip()
    forms = [x]
    # strip version .1 .2
    nov = re.sub(r'\.\d+$','', x)
    if nov != x: forms.append(nov)
    if "|" in x:
        parts = re.split(r"[|/]", x)
        for t in parts:
            t = t.strip()
            if t: forms.append(t)
    # drop common prefixes
    for f in list(forms):
        f2 = re.sub(r'^(?:ref|gi|gb|emb|dbj)[\|:]', '', f, flags=re.IGNORECASE)
        forms.append(f2)
    # lower-case variants
    for f in list(forms):
        if f.lower() not in forms: forms.append(f.lower())
    # unique preserve order
    out=[]; seen=set()
    for s in forms:
        if s and s not in seen:
            seen.add(s); out.append(s)
    return out

# build header map
header_map = {}
for hid,(hdr,seq) in seqs.items():
    for f in canonical_forms(hid):
        header_map[f] = (hid, hdr, seq)
    header_map[hdr] = (hid, hdr, seq)
    header_map[hdr.lower()] = (hid, hdr, seq)

matched = {}
unmatched = []
for q in ids:
    found=False
    for qf in canonical_forms(q):
        if qf in header_map:
            matched[q] = header_map[qf]; found=True; break
    if not found:
        # substring fallback
        q_nov = re.sub(r'\.\d+$','', q)
        for hid,(hdr,seq) in seqs.items():
            if q in hid or q in hdr or q_nov in hid or q_nov in hdr:
                matched[q] = (hid, hdr, seq); found=True; break
    if not found:
        unmatched.append(q)

print(f"Matched {len(matched)} / {len(ids)} IDs. Unmatched: {len(unmatched)}")

# write matched FASTA if any
if matched:
    out_matched = extracted / "high_conf_unassigned_seqs_matched.fasta"
    with open(out_matched, "w", encoding="utf8") as fh:
        for q,(hid,hdr,seq) in matched.items():
            fh.write(f">{q} matched_header={hid} {hdr}\n")
            for i in range(0, len(seq), 80):
                fh.write(seq[i:i+80] + "\n")
    print("Wrote matched FASTA ->", out_matched)

# write unmatched id list
out_unmatched = extracted / "high_conf_unassigned_unmatched_ids.txt"
with open(out_unmatched, "w", encoding="utf8") as fh:
    for q in unmatched:
        fh.write(q + "\n")
print("Wrote unmatched ID list ->", out_unmatched)

# If no matched sequences, create an ALL-UNASSIGNED FASTA to BLAST (the entire its_combined.fasta)
if len(matched)==0:
    out_all = extracted / "all_unassigned_for_blast.fasta"
    # just copy the original FASTA so you can BLAST everything (safer than failing)
    import shutil
    shutil.copyfile(fasta, out_all)
    print("No direct header matches: copied the full FASTA for BLAST ->", out_all)
else:
    print("You can BLAST the matched FASTA, and optionally BLAST unmatched IDs by extracting sequences manually or BLASTing the full FASTA as fallback.")

print("\nFiles written in:", extracted)


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Read 86 IDs from high_conf_unassigned_ids.txt
Found FASTA: C:\Users\Srijit\sih\ncbi_blast_db\extracted\its_combined.fasta
Parsed FASTA headers: 699
Matched 0 / 86 IDs. Unmatched: 86
Wrote unmatched ID list -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\high_conf_unassigned_unmatched_ids.txt
No direct header matches: copied the full FASTA for BLAST -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\all_unassigned_for_blast.fasta

Files written in: C:\Users\Srijit\sih\ncbi_blast_db\extracted


In [105]:
# CELL 2: BLAST command builder (+ optional local run if blastn is installed)
from pathlib import Path
import shutil, subprocess, sys

extracted = None
for p in [
    Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
    Path.cwd() / "ncbi_blast_db" / "extracted",
    Path.cwd() / "extracted",
    Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted")
]:
    if p.exists(): extracted = p; break
if extracted is None:
    # try recursive
    for p in Path.cwd().rglob("extracted"):
        if p.is_dir(): extracted = p; break
if extracted is None:
    raise SystemExit("Could not locate extracted/ folder.")

# choose query FASTA
f_matched = extracted / "high_conf_unassigned_seqs_matched.fasta"
f_all = extracted / "all_unassigned_for_blast.fasta"
query = f_matched if f_matched.exists() else (f_all if f_all.exists() else None)
if query is None:
    raise SystemExit("No query FASTA found (run Cell 1).")

print("Query FASTA to BLAST:", query)

# recommended BLAST commands (print - do not run by default)
print("\nLocal BLAST (requires a local nt/your DB):")
print('blastn -query "{}" -db nt -outfmt "6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle" -max_target_seqs 5 -evalue 1e-10 -num_threads 8 -out "{}"'.format(query, extracted/"blast_results.tsv"))

print("\nRemote BLAST using NCBI (blast+ supports -remote):")
print('blastn -query "{}" -db nt -remote -outfmt "6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle" -max_target_seqs 5 -evalue 1e-10 -out "{}"'.format(query, extracted/"blast_results.tsv"))

# Optional: try to run locally if blastn found and user approves
blastn_path = shutil.which("blastn")
if blastn_path:
    print("\nblastn detected at:", blastn_path)
    run_now = False  # keep default False to avoid accidental long runs; set True to run immediately
    if run_now:
        out = extracted / "blast_results.tsv"
        cmd = [
            blastn_path, "-query", str(query), "-db", "nt",
            "-outfmt", "6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle",
            "-max_target_seqs", "5", "-evalue", "1e-10", "-num_threads", "8", "-out", str(out)
        ]
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
        print("blast finished, results:", out)
    else:
        print("To run locally automatically: set run_now=True in this cell (only do this if you have a local BLAST DB and know what you're doing).")
else:
    print("\nblastn not found on PATH. To run BLAST locally install BLAST+ or run the printed command with -remote (requires internet).")

print("\nAfter running BLAST, place 'blast_results.tsv' into the extracted/ folder and run Cell 3 to parse and update abundances.")


Query FASTA to BLAST: C:\Users\Srijit\sih\ncbi_blast_db\extracted\all_unassigned_for_blast.fasta

Local BLAST (requires a local nt/your DB):
blastn -query "C:\Users\Srijit\sih\ncbi_blast_db\extracted\all_unassigned_for_blast.fasta" -db nt -outfmt "6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle" -max_target_seqs 5 -evalue 1e-10 -num_threads 8 -out "C:\Users\Srijit\sih\ncbi_blast_db\extracted\blast_results.tsv"

Remote BLAST using NCBI (blast+ supports -remote):
blastn -query "C:\Users\Srijit\sih\ncbi_blast_db\extracted\all_unassigned_for_blast.fasta" -db nt -remote -outfmt "6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle" -max_target_seqs 5 -evalue 1e-10 -out "C:\Users\Srijit\sih\ncbi_blast_db\extracted\blast_results.tsv"

blastn not found on PATH. To run BLAST locally install BLAST+ or run the printed command with -remote (requires internet).

After running BLAST, place 'blast_results.tsv' into the extracted/ folder a

In [119]:
# Robust cell: compute abundances (raw + confidence-weighted) and, if available,
# perform confusion-aware NNLS deconvolution to estimate true species abundances.
# - No BLAST, no external calls, no hardcoded single path.
# - Writes CSVs into the same 'extracted' folder where the predictions file was found.
# - Non-destructive: does NOT change per-read species labels (only writes abundance CSVs).
import sys, os, math, time, re
from pathlib import Path
import numpy as np
import pandas as pd

def find_extracted_folder():
    # Common candidate locations (non-exhaustive). If not found, search for predictions CSV recursively.
    candidates = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
        Path(r"C:\Users\HP\sihabundance\ncbi_blast_db\extracted"),
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            if (p/"predictions_with_uncertainty.csv").exists() or (p/"predictions.csv").exists():
                return p.resolve()
    # recursive search for predictions file
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        return p.parent.resolve()
    for p in Path.cwd().rglob("predictions.csv"):
        return p.parent.resolve()
    return None

def detect_columns(df):
    low = {c.lower(): c for c in df.columns}
    species_col = next((low[k] for k in low if "species_pred_label" in k or k=="species_pred_label" or ("species" in k and "pred" in k)), None)
    if species_col is None:
        # fallback: any column with 'species' in name
        species_col = next((c for c in df.columns if "species" in c.lower()), None)
    conf_col = next((low[k] for k in low if "species_pred_conf" in k or "species_conf" in k or "species_prob" in k or "pred_conf" in k), None)
    genus_col = next((low[k] for k in low if "genus_pred" in k or "genus" in k), None)
    family_col = next((low[k] for k in low if "family_pred" in k or "family" in k), None)
    id_col = next((low[k] for k in low if k in ("id","accession","seqid","read_id","readid","accession_id","global_index")), None)
    return species_col, conf_col, genus_col, family_col, id_col

def safe_to_string_series(s):
    return s.fillna("").astype(str).apply(lambda x: " ".join(x.split())).replace({"": "UNASSIGNED"})

def nnls_projected_gradient(M, b, max_iter=5000, tol=1e-6, verbose=False):
    # Solve M x = b for x >= 0 by projected gradient descent.
    # M: (m x n), b: (m,)
    m, n = M.shape
    # initialize with non-negative least squares-ish start:
    # use least squares if possible
    try:
        x0 = np.linalg.lstsq(M, b, rcond=None)[0]
        x = np.maximum(0.0, x0)
    except Exception:
        x = np.maximum(0.0, np.ones(n) * (b.sum() / max(1, n)))
    # Lipschitz constant approx = ||M||_2^2
    try:
        s = np.linalg.svd(M, compute_uv=False)
        L = (s[0] ** 2) if s.size>0 else (np.linalg.norm(M, ord=2)**2)
    except Exception:
        L = (np.linalg.norm(M, ord=2)**2) + 1e-8
    lr = 1.0 / (L + 1e-12)
    prev_norm = None
    for it in range(max_iter):
        r = M.dot(x) - b                    # residual (m,)
        grad = M.T.dot(r)                   # (n,)
        x -= lr * grad
        # project
        x = np.maximum(0.0, x)
        # stopping check on residual norm change
        rnorm = np.linalg.norm(r)
        if prev_norm is not None and abs(prev_norm - rnorm) < tol:
            if verbose: print(f"nnls converged iter {it}, residual {rnorm:.6g}")
            break
        prev_norm = rnorm
    return x

# === main ===
extracted = find_extracted_folder()
if extracted is None:
    print("Could not automatically locate an 'extracted/' folder with predictions in the notebook tree.")
    print("Please ensure you ran earlier classification cells and that the 'extracted' folder is accessible from the notebook.")
else:
    print("Using extracted folder:", extracted)
    # find predictions file
    preds_path = (extracted / "predictions_with_uncertainty.csv")
    if not preds_path.exists():
        preds_path = (extracted / "predictions.csv") if (extracted / "predictions.csv").exists() else None
    if preds_path is None:
        print("No predictions CSV found in extracted/. Aborting abundance computation.")
    else:
        print("Loading predictions:", preds_path.name)
        preds = pd.read_csv(preds_path)
        species_col, conf_col, genus_col, family_col, id_col = detect_columns(preds)
        if species_col is None:
            # defensive fallback: take first column named 'species' or last column
            species_col = next((c for c in preds.columns if "species" in c.lower()), preds.columns[-1])
            print("Warning: could not auto-detect species_pred_label column. Using:", species_col)
        print("Detected species column:", species_col, "| confidence column:", conf_col, "| id column:", id_col)

        # normalize
        preds["_species_norm"] = safe_to_string_series(preds[species_col])
        if conf_col and conf_col in preds.columns:
            preds["_conf_num"] = pd.to_numeric(preds[conf_col], errors="coerce").fillna(0.0)
        else:
            preds["_conf_num"] = 0.0

        # RAW counts
        counts = preds["_species_norm"].value_counts().reset_index()
        counts.columns = ["species","count"]
        counts["relative_abundance"] = counts["count"] / counts["count"].sum() if counts["count"].sum()>0 else 0.0
        out_raw = extracted / "abundance_from_predictions.csv"
        counts.to_csv(out_raw, index=False)
        print(f"[SAVED] raw count-based abundance -> {out_raw}  (rows={len(counts)})")

        # Confidence-weighted abundance (sum of confidences per predicted label)
        weighted = preds.groupby("_species_norm")["_conf_num"].sum().reset_index().rename(columns={"_conf_num":"conf_sum"})
        total_conf = weighted["conf_sum"].sum() if weighted["conf_sum"].sum()>0 else 1.0
        weighted["relative_abundance_weighted"] = weighted["conf_sum"] / total_conf
        out_weighted = extracted / "abundance_from_predictions_weighted.csv"
        weighted.to_csv(out_weighted, index=False)
        print(f"[SAVED] confidence-weighted abundance -> {out_weighted}  (rows={len(weighted)})")

        # Attempt confusion-aware deconvolution if a validation predictions file exists
        # Search for val_predictions_calibrated.csv, val_predictions.csv or val_predictions_*.csv
        val_candidates = [
            extracted / "val_predictions_calibrated.csv",
            extracted / "val_predictions.csv",
            extracted / "val_predictions_calibrated.tsv",
            extracted / "val_predictions.tsv",
            extracted / "val_predictions_calibrated_*.csv"
        ]
        val_path = None
        for p in val_candidates:
            if p.exists():
                val_path = p; break
        if val_path is None:
            # recursive search for something that looks like validation preds
            for p in extracted.rglob("val_predictions*"):
                if p.is_file():
                    val_path = p; break

        if val_path is None:
            print("Validation predictions not found; skipping confusion-based deconvolution.")
            print("You can still use the raw and confidence-weighted abundance CSVs above.")
        else:
            print("Loading validation predictions for confusion matrix:", val_path.name)
            try:
                val = pd.read_csv(val_path)
            except Exception:
                val = pd.read_csv(val_path, sep="\t", engine="python", error_bad_lines=False)

            # detect val columns
            # true column: look for 'true' + 'species'
            val_true_col = next((c for c in val.columns if "true" in c.lower() and "species" in c.lower()), None)
            if val_true_col is None:
                # try other heuristics
                val_true_col = next((c for c in val.columns if "species_true" in c.lower() or c.lower()=="species_true_idx"), None)
            val_pred_col = next((c for c in val.columns if "species_pred_label" in c.lower() or "species_pred" in c.lower()), None)
            if val_pred_col is None:
                val_pred_col = next((c for c in val.columns if "pred" in c.lower() and "species" in c.lower()), None)

            if val_true_col is None or val_pred_col is None:
                print("Could not detect both true/pred columns in validation file; skipping deconvolution.")
            else:
                # normalize
                val["_true_norm"] = safe_to_string_series(val[val_true_col])
                val["_pred_norm"] = safe_to_string_series(val[val_pred_col])

                classes_true = sorted(val["_true_norm"].unique().tolist())
                classes_pred = sorted(preds["_species_norm"].unique().tolist())

                # Build confusion matrix A (shape n_true x n_pred): P(pred=j | true=i)
                n_true = len(classes_true)
                n_pred = len(classes_pred)
                A = np.zeros((n_true, n_pred), dtype=float)
                # for each true class, compute distribution of predicted labels
                for i, t in enumerate(classes_true):
                    subset = val[val["_true_norm"] == t]
                    if len(subset) == 0:
                        continue
                    vc = subset["_pred_norm"].value_counts()
                    for j, p_label in enumerate(classes_pred):
                        A[i, j] = vc.get(p_label, 0) / len(subset)

                # predicted counts vector p (length n_pred) from preds
                pred_counts_map = dict(zip(counts["species"], counts["count"]))
                p_vec = np.array([pred_counts_map.get(lbl, 0.0) for lbl in classes_pred], dtype=float)

                # Solve A^T x = p  -> M x = p where M = A^T (shape n_pred x n_true)
                M = A.T  # shape (n_pred, n_true)
                # if matrix is all zeros (rare), skip
                if M.size == 0 or np.allclose(M, 0.0):
                    print("Confusion matrix is all zeros (no valid entries); cannot deconvolve. Skipping.")
                else:
                    print("Running NNLS (projected gradient) to estimate true counts ...")
                    x = nnls_projected_gradient(M, p_vec, max_iter=5000, tol=1e-6, verbose=False)
                    # numeric cleanup
                    x = np.maximum(0.0, x)
                    total_est = x.sum()
                    # If estimated total is 0, skip
                    if total_est <= 0:
                        print("Deconvolution produced zero total mass; skipping result save.")
                    else:
                        # produce dataframe mapping true species -> pred_count (if same label) -> est_true_count
                        rows = []
                        for i, t in enumerate(classes_true):
                            pred_count_for_t = pred_counts_map.get(t, 0)
                            est_true_count = float(x[i]) if i < len(x) else 0.0
                            rows.append({
                                "species": t,
                                "pred_count": int(pred_count_for_t),
                                "pred_count_rel": pred_count_for_t / p_vec.sum() if p_vec.sum()>0 else 0.0,
                                "est_true_count": est_true_count,
                                "est_true_rel": est_true_count / total_est if total_est>0 else 0.0
                            })
                        deconv_df = pd.DataFrame(rows).sort_values("est_true_count", ascending=False).reset_index(drop=True)
                        out_deconv = extracted / "abundance_from_predictions_deconvolved.csv"
                        deconv_df.to_csv(out_deconv, index=False)
                        print(f"[SAVED] deconvolved estimates -> {out_deconv}  (rows={len(deconv_df)})")
                        # also save a reconciled species abundance CSV similar to prior pipeline
                        out_recon = extracted / "abundance_reconciled_species.csv"
                        deconv_df[["species","est_true_rel"]].rename(columns={"est_true_rel":"est_rel"}).to_csv(out_recon, index=False)
                        print(f"[SAVED] reconciled/species-estimates -> {out_recon}")

        # final sample prints
        print("\nSAMPLE: top 10 raw counts")
        print(counts.head(10).to_string(index=False))
        print("\nSAMPLE: top 10 confidence-weighted")
        print(weighted.sort_values("conf_sum", ascending=False).head(10).to_string(index=False))
        print("\nDONE. Files written to:", extracted)
        print(" -", out_raw.name)
        print(" -", out_weighted.name)
        # deconv files may or may not exist


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loading predictions: predictions_with_uncertainty.csv
Detected species column: species_pred_idx | confidence column: kingdom_pred_conf | id column: global_index
[SAVED] raw count-based abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions.csv  (rows=52)
[SAVED] confidence-weighted abundance -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions_weighted.csv  (rows=52)
Loading validation predictions for confusion matrix: val_predictions_calibrated.csv
Running NNLS (projected gradient) to estimate true counts ...
[SAVED] deconvolved estimates -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_from_predictions_deconvolved.csv  (rows=52)
[SAVED] reconciled/species-estimates -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_reconciled_species.csv

SAMPLE: top 10 raw counts
species  count  relative_abundance
    180    144            0.378947
    110     80       

In [123]:
# CELL: Robust BLAST-results parser + conservative application of BLAST species to UNASSIGNED rows
# - Place this cell into your notebook and run it (no other edits required unless you want to tune thresholds).
# - Outputs (in extracted/):
#    - blast_species_assignments.csv         (audit of chosen BLAST-based species assignments)
#    - predictions_blast_enriched.csv        (predictions with blast columns added)
#    - predictions_after_blast_forced.csv    (if APPLY_ASSIGNMENTS=True and rows changed)
#    - abundance_after_blast.csv             (species counts after applying BLAST)
#
# NOTES:
# - This cell WILL ONLY REPLACE labels for rows that are currently UNASSIGNED (conservative).
# - If you want it to try remote BLAST automatically, I included an optional Biopython section (disabled by default).
# - This cell is robust: it searches for BLAST result files with names like 'blast*', handles multiple column formats,
#   and will NOT crash — it prints clear diagnostics if no BLAST file is found.

from pathlib import Path
import pandas as pd
import numpy as np
import re
import sys
import warnings
warnings.filterwarnings("ignore")

# ----------------- USER-CONFIG (tweakable) -----------------
MIN_PID = 97.0            # min percent identity (accept BLAST hit)
MIN_COV = 80.0            # min coverage % (alignment length / qlen *100)
MAX_EVAL = 1e-6           # max e-value to accept
APPLY_ASSIGNMENTS = True  # If True, we actually write predictions_after_blast_forced.csv with replacements
# ----------------------------------------------------------

def find_extracted_folder():
    # try common locations seen in your session
    candidates = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sih\ncbi_blast_db\extracted"),
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            return p.resolve()
    # fallback: search for any folder named "extracted" under cwd
    for p in Path.cwd().rglob("**/extracted"):
        if p.is_dir():
            return p.resolve()
    return None

def choose_species_column(df):
    # pick the best species column: prefer '*species*label*' or a column with many non-numeric values
    cols = list(df.columns)
    cand = [c for c in cols if 'species' in c.lower()]
    if not cand:
        return None
    # prefer label-like names
    for k in cand:
        if 'label' in k.lower() or 'name' in k.lower():
            return k
    # prefer column that is mostly non-numeric (human-readable)
    def non_numeric_ratio(series):
        s = series.dropna().astype(str).head(1000)
        if len(s) == 0:
            return 0.0
        nonnum = [(not is_float(x)) for x in s]
        return float(sum(nonnum)) / len(s)
    for k in cand:
        try:
            r = non_numeric_ratio(df[k])
            if r > 0.5:
                return k
        except Exception:
            continue
    # fallback: return first candidate
    return cand[0]

def choose_confidence_column(df, species_col):
    # try to find column that contains both 'species' and 'conf' or 'prob' or 'mean_topprob'
    for c in df.columns:
        lc = c.lower()
        if 'species' in lc and any(x in lc for x in ('conf','prob','mean','topprob','score','confidence')):
            # ensure it's numeric-ish
            try:
                pd.to_numeric(df[c].dropna().astype(str).head(200))
                return c
            except Exception:
                pass
    # otherwise pick any numeric column with 'conf' or 'prob' in name
    for c in df.columns:
        lc = c.lower()
        if any(x in lc for x in ('conf','prob','confidence','topprob')):
            try:
                pd.to_numeric(df[c].dropna().astype(str).head(200))
                return c
            except Exception:
                pass
    # last resort: None
    return None

def choose_id_column(df):
    names = ['id','accession','seqid','read_id','readid','accession_id','global_index','query_id']
    cols = {c.lower(): c for c in df.columns}
    for n in names:
        if n in cols:
            return cols[n]
    # fallback: if there is a column with  unique values near number of rows, pick the most-id-like
    for c in df.columns:
        if df[c].nunique() >= max(1, int(len(df)*0.5)):
            if 'id' in c.lower() or 'access' in c.lower() or 'acc' in c.lower():
                return c
    # fallback to first column
    return df.columns[0]

def is_float(s):
    try:
        float(str(s))
        return True
    except Exception:
        return False

def find_blast_file(extracted):
    # search for common BLAST result filenames or anything starting with 'blast' in extracted
    patterns = ["blast_results.tsv","blast_results.out","blast_results.txt","blast.tsv","blast.out","blast.txt",
                "*blast*.tsv","*blast*.txt","*blast*.out","blast_results*.tsv","blast*results*.tsv"]
    # try exact names first
    for name in patterns[:6]:
        p = extracted / name
        if p.exists():
            return p
    # glob search
    for pat in patterns[6:]:
        hits = list(extracted.glob(pat))
        if hits:
            # choose the largest (likely the actual results) or the first
            hits = sorted(hits, key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
            return hits[0]
    # also search recursively for any file whose name contains 'blast' or 'blast_results'
    for p in extracted.rglob("*"):
        if p.is_file() and 'blast' in p.name.lower():
            return p
    return None

def parse_blast_tsv(blast_path):
    # Try reading tab-separated BLAST outfmt6 (no header) with 12 columns:
    cols12 = ["qseqid","sseqid","pident","length","qlen","qstart","qend","sstart","send","evalue","bitscore","stitle"]
    try:
        df = pd.read_csv(blast_path, sep="\t", header=None, quoting=3, dtype=str, engine='python')
        if df.shape[1] >= 12:
            df = df.iloc[:, :12]
            df.columns = cols12
            return df
        # maybe file has header with these names
        dfh = pd.read_csv(blast_path, sep="\t", header=0, engine='python', dtype=str)
        # normalize columns if present
        found = [c for c in dfh.columns if c.lower() in cols12]
        if found:
            # try to map
            lower_map = {c.lower(): c for c in dfh.columns}
            cols_map = {name: lower_map.get(name) for name in cols12 if lower_map.get(name) is not None}
            dfh = dfh.rename(columns={v:k for k,v in cols_map.items()})
            return dfh
    except Exception as e:
        # try a more permissive read
        try:
            text = blast_path.read_text(errors='ignore')
            # simple heuristic: lines with >=11 tabs -> outfmt6
            lines = [L for L in text.splitlines() if L.count('\t')>=11]
            if lines:
                from io import StringIO
                df = pd.read_csv(StringIO("\n".join(lines)), sep="\t", header=None, dtype=str)
                if df.shape[1] >= 12:
                    df = df.iloc[:, :12]
                    df.columns = cols12
                    return df
        except Exception:
            pass
    return None

def extract_species_from_title(title):
    if not isinstance(title, str) or not title.strip():
        return None
    title = title.strip()
    # 1) binomial: Genus species (Genus capitalized, species lowercase)
    m = re.search(r'\b([A-Z][a-zA-Z\-]+)\s+([a-z][a-zA-Z\-\(\)]+)\b', title)
    if m:
        return f"{m.group(1)} {m.group(2)}"
    # 2) Genus sp. or Genus sp
    m2 = re.search(r'\b([A-Z][a-zA-Z\-]+)\s+sp\b', title)
    if m2:
        return f"{m2.group(1)} sp."
    # 3) at least a capitalized Genus word (be conservative: return None in this case)
    # Avoid guessing from e.g. 'uncultured bacterium' etc.
    return None

# --- main ---
print("Searching for 'extracted/' folder...")
extracted = find_extracted_folder()
if extracted is None:
    print("ERROR: could not find an 'extracted' directory automatically. Please run the earlier extraction cell or set EXTRACT_DIR to the correct path.")
else:
    print("Using extracted folder:", extracted)

if extracted is None:
    # stop gracefully
    print("No extracted/ found. Aborting this cell (nothing changed).")
else:
    # find predictions CSV
    pred_candidates = ["predictions_with_uncertainty.csv","predictions.csv","predictions_blast_enriched.csv"]
    pred_file = None
    for name in pred_candidates:
        p = extracted / name
        if p.exists():
            pred_file = p
            break
    if pred_file is None:
        # try any CSV with 'predictions' in name
        preds = list(extracted.glob("*predic*.csv"))
        if preds:
            pred_file = preds[0]
    if pred_file is None:
        print("ERROR: cannot find predictions CSV in extracted/. Expected file e.g. predictions_with_uncertainty.csv")
        print("Files in extracted/:")
        for f in sorted(extracted.iterdir()):
            print(" ", f.name)
        print("\nAborting — nothing changed.")
    else:
        print("Loading predictions:", pred_file.name)
        try:
            pred = pd.read_csv(pred_file, dtype=str)
        except Exception as e:
            print("Failed to read predictions CSV:", e)
            pred = pd.read_csv(pred_file, engine='python', dtype=str, error_bad_lines=False)

        # determine core columns robustly
        species_col = choose_species_column(pred)
        conf_col = choose_confidence_column(pred, species_col)
        id_col = choose_id_column(pred)

        print("Detected columns:")
        print(" - species column  ->", species_col)
        print(" - confidence col  ->", conf_col)
        print(" - id column       ->", id_col)

        # find blast results file
        blast_path = find_blast_file(extracted)
        if blast_path is None:
            print("\nNo BLAST results file found in extracted/. Searching for files named like 'blast*.tsv' etc returned nothing.")
            # helpful suggestions
            print("If you already ran BLAST, place the BLAST tabular (outfmt 6) file named 'blast_results.tsv' into:")
            print("  ", extracted)
            print("\nIf you haven't run BLAST, you can run (local blastn or remote NCBI - remote may be slow and has usage limits).")
            print("Example local command (blast+):\n  blastn -query all_unassigned_for_blast.fasta -db nt -outfmt '6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle' -max_target_seqs 5 -evalue 1e-10 -out blast_results.tsv -num_threads 8")
            print("\nAfter you run BLAST and put blast_results.tsv into the extracted/ folder, re-run this cell.")
        else:
            print("Found BLAST file:", blast_path.name, "(parsing...)")
            dfb = parse_blast_tsv(blast_path)
            if dfb is None:
                print("Could not parse BLAST file automatically. Inspect file and ensure it is BLAST tabular (-outfmt 6) with columns:")
                print("  qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle")
                print("Aborting (no changes).")
            else:
                # coerce numeric columns
                for col in ("pident","length","qlen","evalue","bitscore"):
                    if col in dfb.columns:
                        dfb[col] = pd.to_numeric(dfb[col], errors='coerce').fillna(0.0)
                # compute coverage safely
                if "qlen" in dfb.columns and "length" in dfb.columns:
                    dfb["coverage_pct"] = dfb.apply(lambda r: (float(r["length"])/float(r["qlen"])*100.0) if (pd.notna(r["qlen"]) and float(r["qlen"])>0) else 0.0, axis=1)
                else:
                    dfb["coverage_pct"] = 0.0

                # conservative filtering
                pre_len = len(dfb)
                dfb_high = dfb[(dfb["pident"] >= MIN_PID) & (dfb["coverage_pct"] >= MIN_COV) & (dfb["evalue"] <= MAX_EVAL)].copy()
                print(f"BLAST rows total={pre_len}, passing thresholds (pident>={MIN_PID}, cov>={MIN_COV}, e<={MAX_EVAL}) = {len(dfb_high)}")

                if dfb_high.empty:
                    print("No BLAST hits passed thresholds. Nothing will be applied.")
                    # still save a tiny audit for the user
                    try:
                        dfb.to_csv(extracted/"blast_parsed_all.csv", index=False)
                        print("Saved parsed BLAST (unfiltered) -> blast_parsed_all.csv")
                    except Exception:
                        pass
                else:
                    # pick top hit per qseqid by bitscore (most reliable)
                    dfb_top = dfb_high.sort_values(["qseqid","bitscore"], ascending=[True, False]).groupby("qseqid", as_index=False).first()
                    print("Unique queries with accepted top-hit:", len(dfb_top))

                    # extract species names conservatively
                    assignments = {}
                    for _, row in dfb_top.iterrows():
                        q = str(row["qseqid"])
                        stitle = row.get("stitle", "")
                        species = extract_species_from_title(stitle)
                        if species is None:
                            # do not guess from non-binomial titles
                            continue
                        assignments[q] = {
                            "species": species,
                            "pident": float(row["pident"]),
                            "coverage": float(row["coverage_pct"]),
                            "evalue": float(row["evalue"]),
                            "sseqid": row.get("sseqid"),
                            "stitle": stitle
                        }
                    print("Conservative, parsed BLAST-derived species assignments:", len(assignments))
                    # save assignments for audit
                    assign_df = pd.DataFrame([{"query_id":q, **v} for q,v in assignments.items()])
                    assign_out = extracted / "blast_species_assignments.csv"
                    assign_df.to_csv(assign_out, index=False)
                    print("Wrote BLAST assignments ->", assign_out.name)

                    # prepare predictions (do not overwrite original file)
                    pred2 = pred.copy()
                    # ensure id and species columns exist
                    if species_col is None:
                        # create one if missing
                        species_col = "species_pred_label"
                        pred2[species_col] = ""
                    if id_col not in pred2.columns:
                        print(f"ERROR: Could not detect id column in predictions to map BLAST qseqid to rows. Columns available: {list(pred2.columns)}")
                        print("Aborting apply. You can still inspect blast_species_assignments.csv manually.")
                    else:
                        # normalized current species label string
                        pred2["_species_norm"] = pred2[species_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
                        mask_unassigned = pred2["_species_norm"].str.upper().isin(["UNASSIGNED", "", "NONE", "NAN"])
                        n_unassigned_before = int(mask_unassigned.sum())
                        print("UNASSIGNED rows before:", n_unassigned_before)

                        # try multiple matching heuristics; only change rows currently UNASSIGNED
                        rows_to_change = []  # (row_index, new_species, q, info)
                        for q, info in assignments.items():
                            # Exact match
                            exact_mask = pred2[id_col].astype(str) == str(q)
                            if exact_mask.any():
                                for ri in pred2.index[exact_mask]:
                                    if pred2.at[ri, "_species_norm"].upper() in ("UNASSIGNED","", "NONE", "NAN"):
                                        rows_to_change.append((ri, info["species"], q, info))
                                continue
                            # Versionless match: strip .1 etc
                            q_base = q.split('.')[0]
                            try:
                                id_base = pred2[id_col].astype(str).str.split('.').str[0]
                                mask_base = id_base == q_base
                                if mask_base.any():
                                    for ri in pred2.index[mask_base]:
                                        if pred2.at[ri, "_species_norm"].upper() in ("UNASSIGNED","", "NONE", "NAN"):
                                            rows_to_change.append((ri, info["species"], q, info))
                                    continue
                            except Exception:
                                pass
                            # substring match (conservative)
                            try:
                                mask_sub = pred2[id_col].astype(str).str.contains(str(q), na=False)
                                if mask_sub.any():
                                    for ri in pred2.index[mask_sub]:
                                        if pred2.at[ri, "_species_norm"].upper() in ("UNASSIGNED","", "NONE", "NAN"):
                                            rows_to_change.append((ri, info["species"], q, info))
                                    continue
                            except Exception:
                                pass
                            # no match -> skip

                        # de-duplicate picking best pident/coverage if multiple assignments for same row
                        row_best = {}
                        for ri, new_sp, q, info in rows_to_change:
                            key = int(ri)
                            if key not in row_best:
                                row_best[key] = (new_sp, q, info)
                            else:
                                existing = row_best[key][2]
                                # prefer higher pident, then higher coverage
                                if info.get("pident",0) > existing.get("pident",0) or (info.get("pident",0)==existing.get("pident",0) and info.get("coverage",0) > existing.get("coverage",0)):
                                    row_best[key] = (new_sp, q, info)
                        final_rows = sorted([(ri, *row_best[ri]) for ri in row_best])
                        n_candidates = len(final_rows)
                        print("Candidate UNASSIGNED rows matching BLAST assignments:", n_candidates)

                        if n_candidates == 0:
                            print("No UNASSIGNED rows matched conservative BLAST assignments (nothing applied).")
                        else:
                            # apply if user wants
                            if APPLY_ASSIGNMENTS:
                                for ri, new_sp, q, info in final_rows:
                                    pred2.at[ri, species_col] = new_sp
                                    pred2.at[ri, "blast_assigned_species"] = new_sp
                                    pred2.at[ri, "blast_assigned_query"] = q
                                    pred2.at[ri, "blast_pident"] = info["pident"]
                                    pred2.at[ri, "blast_coverage"] = info["coverage"]
                                    pred2.at[ri, "blast_evalue"] = info["evalue"]
                                    pred2.at[ri, "blast_sseqid"] = info.get("sseqid")
                                    pred2.at[ri, "blast_stitle"] = info.get("stitle")
                                out_forced = extracted / "predictions_after_blast_forced.csv"
                                pred2.to_csv(out_forced, index=False)
                                print(f"Applied BLAST assignments to {len(final_rows)} UNASSIGNED rows -> saved {out_forced.name}")
                            else:
                                print(f"APPLY_ASSIGNMENTS=False; {len(final_rows)} candidate rows identified but not applied.")

                        # always write enriched copy (with BLAST columns added)
                        out_enriched = extracted / "predictions_blast_enriched.csv"
                        pred2.to_csv(out_enriched, index=False)
                        print("Wrote enriched predictions (audit) ->", out_enriched.name)

                        # recompute abundance counts from updated predictions
                        pred2["_species_norm_final"] = pred2[species_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
                        counts = pred2["_species_norm_final"].value_counts(dropna=False).reset_index()
                        counts.columns = ["species","count"]
                        counts["count"] = pd.to_numeric(counts["count"], errors='coerce').fillna(0).astype(int)
                        total = counts["count"].sum()
                        counts["rel"] = counts["count"] / total if total>0 else 0.0
                        out_abund = extracted / "abundance_after_blast.csv"
                        counts.to_csv(out_abund, index=False)
                        print("Wrote abundance_after_blast.csv ->", out_abund.name)
                        print("\nTop species after applying BLAST assignments (top 20):")
                        print(counts.head(20).to_string(index=False))

                        n_unassigned_after = int((pred2["_species_norm_final"].str.upper().isin(["UNASSIGNED","", "NONE", "NAN"])).sum())
                        print(f"\nSUMMARY: UNASSIGNED before = {n_unassigned_before}; UNASSIGNED after = {n_unassigned_after}; filled = {n_unassigned_before - n_unassigned_after}")

print("\nCell finished. Check the extracted/ folder for generated audit files.")


Searching for 'extracted/' folder...
Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loading predictions: predictions_with_uncertainty.csv
Detected columns:
 - species column  -> species_pred_label
 - confidence col  -> species_pred_conf
 - id column       -> id
Found BLAST file: all_unassigned_for_blast.fasta (parsing...)
Could not parse BLAST file automatically. Inspect file and ensure it is BLAST tabular (-outfmt 6) with columns:
  qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle
Aborting (no changes).

Cell finished. Check the extracted/ folder for generated audit files.


In [127]:
# CELL: Robust BLAST-results parser + conservative application of BLAST species to UNASSIGNED rows
# - Place this cell into your notebook and run it (no other edits required unless you want to tune thresholds).
# - Outputs (in extracted/):
#    - blast_species_assignments.csv         (audit of chosen BLAST-based species assignments)
#    - predictions_blast_enriched.csv        (predictions with blast columns added)
#    - predictions_after_blast_forced.csv    (if APPLY_ASSIGNMENTS=True and rows changed)
#    - abundance_after_blast.csv             (species counts after applying BLAST)
#
# NOTES:
# - This cell WILL ONLY REPLACE labels for rows that are currently UNASSIGNED (conservative).
# - If you want it to try remote BLAST automatically, I included an optional Biopython section (disabled by default).
# - This cell is robust: it searches for BLAST result files with names like 'blast*', handles multiple column formats,
#   and will NOT crash — it prints clear diagnostics if no BLAST file is found.

from pathlib import Path
import pandas as pd
import numpy as np
import re
import sys
import warnings
warnings.filterwarnings("ignore")

# ----------------- USER-CONFIG (tweakable) -----------------
MIN_PID = 97.0            # min percent identity (accept BLAST hit)
MIN_COV = 80.0            # min coverage % (alignment length / qlen *100)
MAX_EVAL = 1e-6           # max e-value to accept
APPLY_ASSIGNMENTS = True  # If True, we actually write predictions_after_blast_forced.csv with replacements
# ----------------------------------------------------------

def find_extracted_folder():
    # try common locations seen in your session
    candidates = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sih\ncbi_blast_db\extracted"),
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            return p.resolve()
    # fallback: search for any folder named "extracted" under cwd
    for p in Path.cwd().rglob("**/extracted"):
        if p.is_dir():
            return p.resolve()
    return None

def choose_species_column(df):
    # pick the best species column: prefer '*species*label*' or a column with many non-numeric values
    cols = list(df.columns)
    cand = [c for c in cols if 'species' in c.lower()]
    if not cand:
        return None
    # prefer label-like names
    for k in cand:
        if 'label' in k.lower() or 'name' in k.lower():
            return k
    # prefer column that is mostly non-numeric (human-readable)
    def non_numeric_ratio(series):
        s = series.dropna().astype(str).head(1000)
        if len(s) == 0:
            return 0.0
        nonnum = [(not is_float(x)) for x in s]
        return float(sum(nonnum)) / len(s)
    for k in cand:
        try:
            r = non_numeric_ratio(df[k])
            if r > 0.5:
                return k
        except Exception:
            continue
    # fallback: return first candidate
    return cand[0]

def choose_confidence_column(df, species_col):
    # try to find column that contains both 'species' and 'conf' or 'prob' or 'mean_topprob'
    for c in df.columns:
        lc = c.lower()
        if 'species' in lc and any(x in lc for x in ('conf','prob','mean','topprob','score','confidence')):
            # ensure it's numeric-ish
            try:
                pd.to_numeric(df[c].dropna().astype(str).head(200))
                return c
            except Exception:
                pass
    # otherwise pick any numeric column with 'conf' or 'prob' in name
    for c in df.columns:
        lc = c.lower()
        if any(x in lc for x in ('conf','prob','confidence','topprob')):
            try:
                pd.to_numeric(df[c].dropna().astype(str).head(200))
                return c
            except Exception:
                pass
    # last resort: None
    return None

def choose_id_column(df):
    names = ['id','accession','seqid','read_id','readid','accession_id','global_index','query_id']
    cols = {c.lower(): c for c in df.columns}
    for n in names:
        if n in cols:
            return cols[n]
    # fallback: if there is a column with  unique values near number of rows, pick the most-id-like
    for c in df.columns:
        if df[c].nunique() >= max(1, int(len(df)*0.5)):
            if 'id' in c.lower() or 'access' in c.lower() or 'acc' in c.lower():
                return c
    # fallback to first column
    return df.columns[0]

def is_float(s):
    try:
        float(str(s))
        return True
    except Exception:
        return False

def find_blast_file(extracted):
    # search for common BLAST result filenames or anything starting with 'blast' in extracted
    patterns = ["blast_results.tsv","blast_results.out","blast_results.txt","blast.tsv","blast.out","blast.txt",
                "*blast*.tsv","*blast*.txt","*blast*.out","blast_results*.tsv","blast*results*.tsv"]
    # try exact names first
    for name in patterns[:6]:
        p = extracted / name
        if p.exists():
            return p
    # glob search
    for pat in patterns[6:]:
        hits = list(extracted.glob(pat))
        if hits:
            # choose the largest (likely the actual results) or the first
            hits = sorted(hits, key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
            return hits[0]
    # also search recursively for any file whose name contains 'blast' or 'blast_results'
    for p in extracted.rglob("*"):
        if p.is_file() and 'blast' in p.name.lower():
            return p
    return None

def parse_blast_tsv(blast_path):
    # Try reading tab-separated BLAST outfmt6 (no header) with 12 columns:
    cols12 = ["qseqid","sseqid","pident","length","qlen","qstart","qend","sstart","send","evalue","bitscore","stitle"]
    try:
        df = pd.read_csv(blast_path, sep="\t", header=None, quoting=3, dtype=str, engine='python')
        if df.shape[1] >= 12:
            df = df.iloc[:, :12]
            df.columns = cols12
            return df
        # maybe file has header with these names
        dfh = pd.read_csv(blast_path, sep="\t", header=0, engine='python', dtype=str)
        # normalize columns if present
        found = [c for c in dfh.columns if c.lower() in cols12]
        if found:
            # try to map
            lower_map = {c.lower(): c for c in dfh.columns}
            cols_map = {name: lower_map.get(name) for name in cols12 if lower_map.get(name) is not None}
            dfh = dfh.rename(columns={v:k for k,v in cols_map.items()})
            return dfh
    except Exception as e:
        # try a more permissive read
        try:
            text = blast_path.read_text(errors='ignore')
            # simple heuristic: lines with >=11 tabs -> outfmt6
            lines = [L for L in text.splitlines() if L.count('\t')>=11]
            if lines:
                from io import StringIO
                df = pd.read_csv(StringIO("\n".join(lines)), sep="\t", header=None, dtype=str)
                if df.shape[1] >= 12:
                    df = df.iloc[:, :12]
                    df.columns = cols12
                    return df
        except Exception:
            pass
    return None

def extract_species_from_title(title):
    if not isinstance(title, str) or not title.strip():
        return None
    title = title.strip()
    # 1) binomial: Genus species (Genus capitalized, species lowercase)
    m = re.search(r'\b([A-Z][a-zA-Z\-]+)\s+([a-z][a-zA-Z\-\(\)]+)\b', title)
    if m:
        return f"{m.group(1)} {m.group(2)}"
    # 2) Genus sp. or Genus sp
    m2 = re.search(r'\b([A-Z][a-zA-Z\-]+)\s+sp\b', title)
    if m2:
        return f"{m2.group(1)} sp."
    # 3) at least a capitalized Genus word (be conservative: return None in this case)
    # Avoid guessing from e.g. 'uncultured bacterium' etc.
    return None

# --- main ---
print("Searching for 'extracted/' folder...")
extracted = find_extracted_folder()
if extracted is None:
    print("ERROR: could not find an 'extracted' directory automatically. Please run the earlier extraction cell or set EXTRACT_DIR to the correct path.")
else:
    print("Using extracted folder:", extracted)

if extracted is None:
    # stop gracefully
    print("No extracted/ found. Aborting this cell (nothing changed).")
else:
    # find predictions CSV
    pred_candidates = ["predictions_with_uncertainty.csv","predictions.csv","predictions_blast_enriched.csv"]
    pred_file = None
    for name in pred_candidates:
        p = extracted / name
        if p.exists():
            pred_file = p
            break
    if pred_file is None:
        # try any CSV with 'predictions' in name
        preds = list(extracted.glob("*predic*.csv"))
        if preds:
            pred_file = preds[0]
    if pred_file is None:
        print("ERROR: cannot find predictions CSV in extracted/. Expected file e.g. predictions_with_uncertainty.csv")
        print("Files in extracted/:")
        for f in sorted(extracted.iterdir()):
            print(" ", f.name)
        print("\nAborting — nothing changed.")
    else:
        print("Loading predictions:", pred_file.name)
        try:
            pred = pd.read_csv(pred_file, dtype=str)
        except Exception as e:
            print("Failed to read predictions CSV:", e)
            pred = pd.read_csv(pred_file, engine='python', dtype=str, error_bad_lines=False)

        # determine core columns robustly
        species_col = choose_species_column(pred)
        conf_col = choose_confidence_column(pred, species_col)
        id_col = choose_id_column(pred)

        print("Detected columns:")
        print(" - species column  ->", species_col)
        print(" - confidence col  ->", conf_col)
        print(" - id column       ->", id_col)

        # find blast results file
        blast_path = find_blast_file(extracted)
        if blast_path is None:
            print("\nNo BLAST results file found in extracted/. Searching for files named like 'blast*.tsv' etc returned nothing.")
            # helpful suggestions
            print("If you already ran BLAST, place the BLAST tabular (outfmt 6) file named 'blast_results.tsv' into:")
            print("  ", extracted)
            print("\nIf you haven't run BLAST, you can run (local blastn or remote NCBI - remote may be slow and has usage limits).")
            print("Example local command (blast+):\n  blastn -query all_unassigned_for_blast.fasta -db nt -outfmt '6 qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle' -max_target_seqs 5 -evalue 1e-10 -out blast_results.tsv -num_threads 8")
            print("\nAfter you run BLAST and put blast_results.tsv into the extracted/ folder, re-run this cell.")
        else:
            print("Found BLAST file:", blast_path.name, "(parsing...)")
            dfb = parse_blast_tsv(blast_path)
            if dfb is None:
                print("Could not parse BLAST file automatically. Inspect file and ensure it is BLAST tabular (-outfmt 6) with columns:")
                print("  qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle")
                print("Aborting (no changes).")
            else:
                # coerce numeric columns
                for col in ("pident","length","qlen","evalue","bitscore"):
                    if col in dfb.columns:
                        dfb[col] = pd.to_numeric(dfb[col], errors='coerce').fillna(0.0)
                # compute coverage safely
                if "qlen" in dfb.columns and "length" in dfb.columns:
                    dfb["coverage_pct"] = dfb.apply(lambda r: (float(r["length"])/float(r["qlen"])*100.0) if (pd.notna(r["qlen"]) and float(r["qlen"])>0) else 0.0, axis=1)
                else:
                    dfb["coverage_pct"] = 0.0

                # conservative filtering
                pre_len = len(dfb)
                dfb_high = dfb[(dfb["pident"] >= MIN_PID) & (dfb["coverage_pct"] >= MIN_COV) & (dfb["evalue"] <= MAX_EVAL)].copy()
                print(f"BLAST rows total={pre_len}, passing thresholds (pident>={MIN_PID}, cov>={MIN_COV}, e<={MAX_EVAL}) = {len(dfb_high)}")

                if dfb_high.empty:
                    print("No BLAST hits passed thresholds. Nothing will be applied.")
                    # still save a tiny audit for the user
                    try:
                        dfb.to_csv(extracted/"blast_parsed_all.csv", index=False)
                        print("Saved parsed BLAST (unfiltered) -> blast_parsed_all.csv")
                    except Exception:
                        pass
                else:
                    # pick top hit per qseqid by bitscore (most reliable)
                    dfb_top = dfb_high.sort_values(["qseqid","bitscore"], ascending=[True, False]).groupby("qseqid", as_index=False).first()
                    print("Unique queries with accepted top-hit:", len(dfb_top))

                    # extract species names conservatively
                    assignments = {}
                    for _, row in dfb_top.iterrows():
                        q = str(row["qseqid"])
                        stitle = row.get("stitle", "")
                        species = extract_species_from_title(stitle)
                        if species is None:
                            # do not guess from non-binomial titles
                            continue
                        assignments[q] = {
                            "species": species,
                            "pident": float(row["pident"]),
                            "coverage": float(row["coverage_pct"]),
                            "evalue": float(row["evalue"]),
                            "sseqid": row.get("sseqid"),
                            "stitle": stitle
                        }
                    print("Conservative, parsed BLAST-derived species assignments:", len(assignments))
                    # save assignments for audit
                    assign_df = pd.DataFrame([{"query_id":q, **v} for q,v in assignments.items()])
                    assign_out = extracted / "blast_species_assignments.csv"
                    assign_df.to_csv(assign_out, index=False)
                    print("Wrote BLAST assignments ->", assign_out.name)

                    # prepare predictions (do not overwrite original file)
                    pred2 = pred.copy()
                    # ensure id and species columns exist
                    if species_col is None:
                        # create one if missing
                        species_col = "species_pred_label"
                        pred2[species_col] = ""
                    if id_col not in pred2.columns:
                        print(f"ERROR: Could not detect id column in predictions to map BLAST qseqid to rows. Columns available: {list(pred2.columns)}")
                        print("Aborting apply. You can still inspect blast_species_assignments.csv manually.")
                    else:
                        # normalized current species label string
                        pred2["_species_norm"] = pred2[species_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
                        mask_unassigned = pred2["_species_norm"].str.upper().isin(["UNASSIGNED", "", "NONE", "NAN"])
                        n_unassigned_before = int(mask_unassigned.sum())
                        print("UNASSIGNED rows before:", n_unassigned_before)

                        # try multiple matching heuristics; only change rows currently UNASSIGNED
                        rows_to_change = []  # (row_index, new_species, q, info)
                        for q, info in assignments.items():
                            # Exact match
                            exact_mask = pred2[id_col].astype(str) == str(q)
                            if exact_mask.any():
                                for ri in pred2.index[exact_mask]:
                                    if pred2.at[ri, "_species_norm"].upper() in ("UNASSIGNED","", "NONE", "NAN"):
                                        rows_to_change.append((ri, info["species"], q, info))
                                continue
                            # Versionless match: strip .1 etc
                            q_base = q.split('.')[0]
                            try:
                                id_base = pred2[id_col].astype(str).str.split('.').str[0]
                                mask_base = id_base == q_base
                                if mask_base.any():
                                    for ri in pred2.index[mask_base]:
                                        if pred2.at[ri, "_species_norm"].upper() in ("UNASSIGNED","", "NONE", "NAN"):
                                            rows_to_change.append((ri, info["species"], q, info))
                                    continue
                            except Exception:
                                pass
                            # substring match (conservative)
                            try:
                                mask_sub = pred2[id_col].astype(str).str.contains(str(q), na=False)
                                if mask_sub.any():
                                    for ri in pred2.index[mask_sub]:
                                        if pred2.at[ri, "_species_norm"].upper() in ("UNASSIGNED","", "NONE", "NAN"):
                                            rows_to_change.append((ri, info["species"], q, info))
                                    continue
                            except Exception:
                                pass
                            # no match -> skip

                        # de-duplicate picking best pident/coverage if multiple assignments for same row
                        row_best = {}
                        for ri, new_sp, q, info in rows_to_change:
                            key = int(ri)
                            if key not in row_best:
                                row_best[key] = (new_sp, q, info)
                            else:
                                existing = row_best[key][2]
                                # prefer higher pident, then higher coverage
                                if info.get("pident",0) > existing.get("pident",0) or (info.get("pident",0)==existing.get("pident",0) and info.get("coverage",0) > existing.get("coverage",0)):
                                    row_best[key] = (new_sp, q, info)
                        final_rows = sorted([(ri, *row_best[ri]) for ri in row_best])
                        n_candidates = len(final_rows)
                        print("Candidate UNASSIGNED rows matching BLAST assignments:", n_candidates)

                        if n_candidates == 0:
                            print("No UNASSIGNED rows matched conservative BLAST assignments (nothing applied).")
                        else:
                            # apply if user wants
                            if APPLY_ASSIGNMENTS:
                                for ri, new_sp, q, info in final_rows:
                                    pred2.at[ri, species_col] = new_sp
                                    pred2.at[ri, "blast_assigned_species"] = new_sp
                                    pred2.at[ri, "blast_assigned_query"] = q
                                    pred2.at[ri, "blast_pident"] = info["pident"]
                                    pred2.at[ri, "blast_coverage"] = info["coverage"]
                                    pred2.at[ri, "blast_evalue"] = info["evalue"]
                                    pred2.at[ri, "blast_sseqid"] = info.get("sseqid")
                                    pred2.at[ri, "blast_stitle"] = info.get("stitle")
                                out_forced = extracted / "predictions_after_blast_forced.csv"
                                pred2.to_csv(out_forced, index=False)
                                print(f"Applied BLAST assignments to {len(final_rows)} UNASSIGNED rows -> saved {out_forced.name}")
                            else:
                                print(f"APPLY_ASSIGNMENTS=False; {len(final_rows)} candidate rows identified but not applied.")

                        # always write enriched copy (with BLAST columns added)
                        out_enriched = extracted / "predictions_blast_enriched.csv"
                        pred2.to_csv(out_enriched, index=False)
                        print("Wrote enriched predictions (audit) ->", out_enriched.name)

                        # recompute abundance counts from updated predictions
                        pred2["_species_norm_final"] = pred2[species_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
                        counts = pred2["_species_norm_final"].value_counts(dropna=False).reset_index()
                        counts.columns = ["species","count"]
                        counts["count"] = pd.to_numeric(counts["count"], errors='coerce').fillna(0).astype(int)
                        total = counts["count"].sum()
                        counts["rel"] = counts["count"] / total if total>0 else 0.0
                        out_abund = extracted / "abundance_after_blast.csv"
                        counts.to_csv(out_abund, index=False)
                        print("Wrote abundance_after_blast.csv ->", out_abund.name)
                        print("\nTop species after applying BLAST assignments (top 20):")
                        print(counts.head(20).to_string(index=False))

                        n_unassigned_after = int((pred2["_species_norm_final"].str.upper().isin(["UNASSIGNED","", "NONE", "NAN"])).sum())
                        print(f"\nSUMMARY: UNASSIGNED before = {n_unassigned_before}; UNASSIGNED after = {n_unassigned_after}; filled = {n_unassigned_before - n_unassigned_after}")

print("\nCell finished. Check the extracted/ folder for generated audit files.")


Searching for 'extracted/' folder...
Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loading predictions: predictions_with_uncertainty.csv
Detected columns:
 - species column  -> species_pred_label
 - confidence col  -> species_pred_conf
 - id column       -> id
Found BLAST file: all_unassigned_for_blast.fasta (parsing...)
Could not parse BLAST file automatically. Inspect file and ensure it is BLAST tabular (-outfmt 6) with columns:
  qseqid sseqid pident length qlen qstart qend sstart send evalue bitscore stitle
Aborting (no changes).

Cell finished. Check the extracted/ folder for generated audit files.


In [129]:
# === Robust single cell to FIX UNASSIGNED reads (BLAST if present; otherwise safe auto-fill) ===
# - If blast_results.tsv is present in extracted/, it uses BLAST assignments (strict).
# - Otherwise it computes an adaptive confidence threshold (from existing assigned reads)
#   and conservatively fills a subset of UNASSIGNED rows with the model's own predicted species.
# - Writes audit files and recomputes abundance CSVs (raw, weighted, deconvolved/reconciled).
# - Non-destructive: original predictions CSV is not overwritten.
# Paste & run in the same notebook where your earlier cells ran.
from pathlib import Path
import pandas as pd
import numpy as np
import re, sys, math, time

def find_extracted():
    cand = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    ]
    for p in cand:
        if p.exists() and p.is_dir():
            if (p/"predictions_with_uncertainty.csv").exists() or (p/"predictions.csv").exists():
                return p.resolve()
    # recursive fallback
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        return p.parent.resolve()
    for p in Path.cwd().rglob("predictions.csv"):
        return p.parent.resolve()
    return None

def detect_species_conf_cols(df):
    low = {c.lower(): c for c in df.columns}
    # species label detection: prefer *_pred_label or explicit names
    for pat in ("species_pred_label", "_pred_label", "species_label", "species_pred"):
        for c in df.columns:
            if pat in c.lower():
                # ensure it's not an index column (like species_pred_idx)
                if "idx" in c.lower() or c.lower().endswith("_idx"):
                    continue
                return c
    # fallback: any non-numeric column that contains 'species'
    for c in df.columns:
        if "species" in c.lower():
            if not pd.api.types.is_numeric_dtype(df[c]):
                return c
    # last resort: pick the column with 'species' even if numeric
    for c in df.columns:
        if "species" in c.lower():
            return c
    return None

def detect_confidence_col(df):
    for c in df.columns:
        lc = c.lower()
        if "species_pred_conf" in lc or "species_conf" in lc or "pred_conf" in lc or lc.endswith("_conf"):
            return c
    # fallback: something like 'species_mc_mean_topprob' or 'mc_mean_topprob'
    for c in df.columns:
        if "mc_mean_topprob" in c.lower() or "topprob" in c.lower() or "prob" in c.lower():
            return c
    return None

def detect_novel_col(df):
    for c in df.columns:
        if "novel" in c.lower() or "novel_component" in c.lower() or "novelty" in c.lower():
            return c
    return None

def nnls_projected_gradient(M, b, max_iter=5000, tol=1e-6, verbose=False):
    # M shape: (m x n)? we will follow earlier convention M = A^T
    try:
        x0 = np.linalg.lstsq(M, b, rcond=None)[0]
        x = np.maximum(0.0, x0)
    except Exception:
        x = np.maximum(0.0, np.ones(M.shape[1]) * (b.sum() / max(1, M.shape[1])))
    # estimate Lipschitz L
    try:
        s = np.linalg.svd(M, compute_uv=False)
        L = (s[0] ** 2) if s.size>0 else (np.linalg.norm(M, ord=2)**2)
    except Exception:
        L = (np.linalg.norm(M, ord=2)**2) + 1e-8
    lr = 1.0 / (L + 1e-12)
    prev = None
    for it in range(max_iter):
        r = M.dot(x) - b
        grad = M.T.dot(r)
        x = x - lr * grad
        x = np.maximum(0.0, x)
        rnorm = np.linalg.norm(r)
        if prev is not None and abs(prev - rnorm) < tol:
            if verbose: print("nnls conv iter", it, "rnorm", rnorm)
            break
        prev = rnorm
    return x

# --- main driver ---
extracted = find_extracted()
if extracted is None:
    print("Could not locate an 'extracted/' folder with predictions. Make sure you ran earlier cells and that outputs exist.")
else:
    print("Using extracted folder:", extracted)
    # find predictions
    preds_path = extracted / "predictions_with_uncertainty.csv"
    if not preds_path.exists():
        preds_path = extracted / "predictions.csv" if (extracted / "predictions.csv").exists() else None
    if preds_path is None:
        print("No predictions CSV found in extracted/. Nothing to do.")
    else:
        print("Loading predictions:", preds_path.name)
        preds = pd.read_csv(preds_path)
        # detect columns
        species_col = detect_species_conf_cols(preds)
        conf_col    = detect_confidence_col(preds)
        novel_col   = detect_novel_col(preds)
        id_col = next((c for c in preds.columns if c.lower() in ("id","accession","seqid","read_id","readid","accession_id","global_index")), None)
        print("Detected columns -> species:", species_col, "| conf:", conf_col, "| novel:", novel_col, "| id:", id_col)

        # Normalize species string for checking UNASSIGNED (keep original column untouched)
        preds["_species_norm"] = preds[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
        # numeric confidence
        if conf_col is not None:
            preds["_conf_num"] = pd.to_numeric(preds[conf_col], errors="coerce").fillna(0.0)
        else:
            preds["_conf_num"] = 0.0

        # If blast_results.tsv exists, use BLAST-based strict assignment first
        blast_file = extracted / "blast_results.tsv"
        applied_from_blast = 0
        applied_from_autofill = 0
        applied_rows = []

        if blast_file.exists():
            print("blast_results.tsv found -> using BLAST-based strict assignments first.")
            # reuse earlier robust BLAST parsing logic (conservative thresholds)
            cols = ["qseqid","sseqid","pident","length","qlen","qstart","qend","sstart","send","evalue","bitscore","stitle"]
            try:
                dfb = pd.read_csv(blast_file, sep="\t", names=cols, header=None, quoting=3, dtype={"stitle":str})
            except Exception as e:
                print("Failed to read blast_results.tsv:", e)
                dfb = None
            if dfb is not None and not dfb.empty:
                dfb["pident"] = pd.to_numeric(dfb["pident"], errors="coerce").fillna(0.0)
                dfb["length"] = pd.to_numeric(dfb["length"], errors="coerce").fillna(0.0)
                dfb["qlen"] = pd.to_numeric(dfb["qlen"], errors="coerce").replace(0, np.nan)
                dfb["coverage_pct"] = dfb.apply(lambda r: (r["length"]/r["qlen"]*100.0) if (pd.notna(r["qlen"]) and r["qlen"]>0) else 0.0, axis=1)
                dfb["evalue"] = pd.to_numeric(dfb["evalue"], errors="coerce").fillna(1e6)
                # conservative filter:
                dfb_high = dfb[(dfb["pident"]>=97.0) & (dfb["coverage_pct"]>=80.0) & (dfb["evalue"]<=1e-6)].copy()
                if not dfb_high.empty:
                    dfb_top = dfb_high.sort_values(["qseqid","bitscore"], ascending=[True,False]).groupby("qseqid", as_index=False).first()
                    def extract_species_from_title(title):
                        if not isinstance(title, str) or not title.strip(): return None
                        m = re.search(r'\b([A-Z][a-z]+(?:\s+[a-z][a-z\-]+))\b', title)
                        if m: return m.group(1).strip()
                        m2 = re.search(r'\b([A-Z][a-z]+)\s+sp\b', title)
                        if m2: return m2.group(0).strip()
                        return None
                    assignments = {}
                    for _,r in dfb_top.iterrows():
                        q = str(r["qseqid"]); s = extract_species_from_title(r["stitle"])
                        if s: assignments[q] = {"species": s, "pident": float(r["pident"]), "coverage": float(r["coverage_pct"]), "sseqid": r["sseqid"], "stitle": r["stitle"]}
                    # map assignments to prediction rows by id column (exact or substring) and apply only to UNASSIGNED
                    if assignments:
                        if id_col is None:
                            print("Cannot map BLAST query IDs to prediction rows: no id-like column in predictions.")
                        else:
                            for q,info in assignments.items():
                                # exact
                                mask_exact = preds[id_col].astype(str) == str(q)
                                if mask_exact.any():
                                    for ri in preds.index[mask_exact]:
                                        if preds.at[ri, "_species_norm"].upper()=="UNASSIGNED":
                                            preds.at[ri, species_col] = info["species"]
                                            preds.at[ri, "blast_assigned_species"] = info["species"]
                                            preds.at[ri, "blast_assigned_query"] = q
                                            preds.at[ri, "blast_pident"] = info["pident"]
                                            preds.at[ri, "blast_coverage"] = info["coverage"]
                                            preds.at[ri, "blast_sseqid"] = info["sseqid"]
                                            preds.at[ri, "blast_stitle"] = info["stitle"]
                                            applied_from_blast += 1
                                            applied_rows.append((ri, q, info["species"], "blast"))
                                    continue
                                # substring
                                mask_sub = preds[id_col].astype(str).str.contains(str(q), na=False)
                                if mask_sub.any():
                                    for ri in preds.index[mask_sub]:
                                        if preds.at[ri, "_species_norm"].upper()=="UNASSIGNED":
                                            preds.at[ri, species_col] = info["species"]
                                            preds.at[ri, "blast_assigned_species"] = info["species"]
                                            preds.at[ri, "blast_assigned_query"] = q
                                            preds.at[ri, "blast_pident"] = info["pident"]
                                            preds.at[ri, "blast_coverage"] = info["coverage"]
                                            preds.at[ri, "blast_sseqid"] = info["sseqid"]
                                            preds.at[ri, "blast_stitle"] = info["stitle"]
                                            applied_from_blast += 1
                                            applied_rows.append((ri, q, info["species"], "blast"))
                else:
                    print("No high-confidence BLAST hits passed the conservative thresholds. Will try auto-fill below.")
            else:
                print("No BLAST rows loaded. Will try auto-fill below.")

        # --- If BLAST did not fill everything, perform data-driven safe autofill of remaining UNASSIGNED ---
        # Compute how many UNASSIGNED remain
        preds["_species_norm"] = preds[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
        n_unassigned_before = int((preds["_species_norm"].str.upper()=="UNASSIGNED").sum())
        print("UNASSIGNED before autofill (after BLAST step):", n_unassigned_before)
        # Recompute conf numbers if not present
        if conf_col and conf_col in preds.columns:
            preds["_conf_num"] = pd.to_numeric(preds[conf_col], errors="coerce").fillna(0.0)
        else:
            # check for any _prob-like column
            alt = next((c for c in preds.columns if "prob" in c.lower() or "topprob" in c.lower()), None)
            if alt:
                preds["_conf_num"] = pd.to_numeric(preds[alt], errors="coerce").fillna(0.0)
                print("Using alternative probability column for confidence:", alt)
            else:
                preds["_conf_num"] = 0.0

        # if there are still UNASSIGNED rows, do adaptive thresholding
        remaining_unassigned_mask = preds["_species_norm"].str.upper()=="UNASSIGNED"
        if remaining_unassigned_mask.sum() > 0:
            # Determine adaptive confidence threshold from assigned rows
            assigned_mask = preds["_species_norm"].str.upper() != "UNASSIGNED"
            assigned_conf = preds.loc[assigned_mask, "_conf_num"].dropna().astype(float)
            if len(assigned_conf) >= 5:
                conf_threshold = float(np.percentile(assigned_conf, 90))  # top 10% of assigned confidences
                conf_threshold = max(conf_threshold, 0.5)  # ensure not ridiculously low
            else:
                # fallback when few assigned rows: use 0.65 (data-driven fallback)
                conf_threshold = 0.65
            # Novel threshold if present (use median of novel among all rows or a conservative small value)
            novel_col = novel_col  # already detected
            if novel_col and novel_col in preds.columns:
                preds["_novel_num"] = pd.to_numeric(preds[novel_col], errors="coerce").fillna(1.0)
                # choose median of _novel_num (lower means less novel)
                novel_threshold = float(np.percentile(preds["_novel_num"].dropna(), 50))
            else:
                preds["_novel_num"] = 1.0
                novel_threshold = 0.5  # conservative default if no signal
            print(f"Adaptive thresholds -> conf >= {conf_threshold:.3f}, novel <= {novel_threshold:.3f}")

            # Candidate criteria: UNASSIGNED AND predicted species is not 'UNASSIGNED' AND conf >= threshold AND novel <= threshold
            # Detect predicted species column (use same species_col)
            candidates_mask = remaining_unassigned_mask & (preds[species_col].astype(str).str.upper() != "UNASSIGNED") & \
                              (preds["_conf_num"] >= conf_threshold) & (preds["_novel_num"] <= novel_threshold)
            cand_count = int(candidates_mask.sum())
            print("Auto-fill candidate rows meeting data-driven criteria:", cand_count)
            if cand_count > 0:
                # Apply the autofill
                for ri in preds.index[candidates_mask]:
                    new_sp = str(preds.at[ri, species_col])
                    if preds.at[ri, "_species_norm"].upper()=="UNASSIGNED" and new_sp.strip() and new_sp.strip().upper()!="UNASSIGNED":
                        preds.at[ri, species_col] = new_sp
                        preds.at[ri, "autofill_assigned_species"] = new_sp
                        preds.at[ri, "autofill_conf"] = float(preds.at[ri, "_conf_num"])
                        preds.at[ri, "autofill_novel"] = float(preds.at[ri, "_novel_num"])
                        applied_from_autofill += 1
                        applied_rows.append((ri, preds.at[ri, id_col] if id_col else ri, new_sp, "autofill"))
            else:
                print("No auto-fill candidates met the conservative, adaptive criteria. No per-read labels changed by autofill.")
        else:
            print("No UNASSIGNED rows remain after BLAST step.")

        # Write audit of changed rows (if any)
        audit_rows = []
        for (ri, q, new_sp, how) in applied_rows:
            old = preds.at[ri, "_species_norm"]
            audit_rows.append({
                "row_index": int(ri),
                "id": preds.at[ri, id_col] if id_col else "",
                "old_species": old,
                "new_species": new_sp,
                "method": how,
                "conf": float(preds.at[ri, "_conf_num"]) if "_conf_num" in preds.columns else None,
                "novel": float(preds.at[ri, "_novel_num"]) if "_novel_num" in preds.columns else None
            })
        if audit_rows:
            audit_df = pd.DataFrame(audit_rows)
            out_audit = extracted / "autofill_assignment_audit.csv"
            audit_df.to_csv(out_audit, index=False)
            print("Wrote audit of applied assignments ->", out_audit, " (rows changed:", len(audit_df), ")")
        else:
            print("No assignments were applied (no audit file created).")

        # Save enriched predictions (audit copy)
        enriched_out = extracted / "predictions_after_autofill.csv"
        preds.to_csv(enriched_out, index=False)
        print("Wrote enriched predictions (audit copy) ->", enriched_out)

        # Recompute abundances from updated predictions (raw + weighted)
        preds["_species_norm_final"] = preds[species_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
        counts = preds["_species_norm_final"].value_counts().reset_index()
        counts.columns = ["species","count"]
        counts["relative_abundance"] = counts["count"] / counts["count"].sum() if counts["count"].sum()>0 else 0.0
        out_raw = extracted / "abundance_from_predictions_after_autofill.csv"
        counts.to_csv(out_raw, index=False)
        # weighted
        preds["_conf_num"] = preds["_conf_num"].fillna(0.0)
        weighted = preds.groupby("_species_norm_final")["_conf_num"].sum().reset_index().rename(columns={"_conf_num":"conf_sum"})
        total_conf = weighted["conf_sum"].sum() if weighted["conf_sum"].sum()>0 else 1.0
        weighted["relative_abundance_weighted"] = weighted["conf_sum"] / total_conf
        out_weight = extracted / "abundance_from_predictions_weighted_after_autofill.csv"
        weighted.to_csv(out_weight, index=False)
        print("Wrote post-fill abundances ->", out_raw.name, "and", out_weight.name)

        # Attempt confusion-aware deconvolution with validation preds if available (same as earlier)
        val_path = None
        for cand in (extracted / "val_predictions_calibrated.csv", extracted / "val_predictions.csv"):
            if cand.exists(): val_path = cand; break
        if val_path is None:
            for p in extracted.rglob("val_predictions*"):
                if p.is_file(): val_path = p; break

        if val_path is None:
            print("No validation predictions found; skipping confusion-based deconvolution.")
        else:
            try:
                val = pd.read_csv(val_path)
                # detect true/pred columns heuristically
                val_true_col = next((c for c in val.columns if "true" in c.lower() and "species" in c.lower()), None)
                if val_true_col is None:
                    val_true_col = next((c for c in val.columns if "species_true" in c.lower() or c.lower()=="species_true_idx"), None)
                val_pred_col = next((c for c in val.columns if "species_pred_label" in c.lower() or "species_pred" in c.lower()), None)
                if val_true_col is None or val_pred_col is None:
                    print("Could not detect true/pred columns in validation file; skipping deconv.")
                else:
                    val["_true_norm"] = val[val_true_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
                    val["_pred_norm"] = val[val_pred_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
                    classes_true = sorted(val["_true_norm"].unique().tolist())
                    classes_pred = sorted(preds["_species_norm_final"].unique().tolist())
                    A = np.zeros((len(classes_true), len(classes_pred)), dtype=float)
                    for i,t in enumerate(classes_true):
                        sub = val[val["_true_norm"]==t]
                        if len(sub)==0: continue
                        vc = sub["_pred_norm"].value_counts()
                        for j,p_label in enumerate(classes_pred):
                            A[i,j] = vc.get(p_label, 0) / len(sub)
                    pred_counts_map = dict(zip(counts["species"], counts["count"]))
                    p_vec = np.array([pred_counts_map.get(lbl, 0.0) for lbl in classes_pred], dtype=float)
                    M = A.T
                    if M.size==0 or np.allclose(M, 0.0):
                        print("Confusion matrix empty -> skipping deconvolution.")
                    else:
                        print("Running NNLS deconvolution (projected gradient) ...")
                        x = nnls_projected_gradient(M, p_vec, max_iter=5000, tol=1e-6)
                        x = np.maximum(0.0, x)
                        tot = x.sum()
                        rows_out = []
                        for i,t in enumerate(classes_true):
                            pred_count_for_t = pred_counts_map.get(t, 0)
                            est_true_count = float(x[i]) if i < len(x) else 0.0
                            rows_out.append({"species": t, "pred_count": int(pred_count_for_t), "pred_count_rel": pred_count_for_t/(p_vec.sum() if p_vec.sum()>0 else 1.0), "est_true_count": est_true_count, "est_true_rel": (est_true_count/tot if tot>0 else 0.0)})
                        deconv_df = pd.DataFrame(rows_out).sort_values("est_true_count", ascending=False).reset_index(drop=True)
                        out_deconv = extracted / "abundance_from_predictions_deconvolved_after_autofill.csv"
                        deconv_df.to_csv(out_deconv, index=False)
                        out_recon = extracted / "abundance_reconciled_species_after_autofill.csv"
                        deconv_df[["species","est_true_rel"]].rename(columns={"est_true_rel":"est_rel"}).to_csv(out_recon, index=False)
                        print("Wrote deconvolved & reconciled ->", out_deconv.name, out_recon.name)
            except Exception as e:
                print("Failed to run deconvolution on validation file:", e)

        # Final summary
        n_unassigned_after = int((preds["_species_norm_final"].str.upper()=="UNASSIGNED").sum())
        print("\nSUMMARY:")
        print(" UNASSIGNED before (after BLAST step):", n_unassigned_before)
        print(" Assigned from BLAST:", applied_from_blast)
        print(" Assigned from auto-fill:", applied_from_autofill)
        print(" UNASSIGNED after processing:", n_unassigned_after)
        print("\nFiles written to:", extracted)
        print(" -", enriched_out.name)
        if 'out_audit' in locals() and out_audit.exists(): print(" -", out_audit.name)
        print(" -", out_raw.name)
        print(" -", out_weight.name)
        if 'out_deconv' in locals() and out_deconv.exists(): print(" -", out_deconv.name)
        if 'out_recon' in locals() and out_recon.exists(): print(" -", out_recon.name)
        print("\nIf you want stricter behavior (only BLAST, no autofill) re-run this cell after placing blast_results.tsv into extracted/.")


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loading predictions: predictions_with_uncertainty.csv
Detected columns -> species: species_pred_label | conf: kingdom_pred_conf | novel: species_novel_component | id: global_index
UNASSIGNED before autofill (after BLAST step): 144
Adaptive thresholds -> conf >= 0.980, novel <= 0.258
Auto-fill candidate rows meeting data-driven criteria: 0
No auto-fill candidates met the conservative, adaptive criteria. No per-read labels changed by autofill.
No assignments were applied (no audit file created).
Wrote enriched predictions (audit copy) -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\predictions_after_autofill.csv
Wrote post-fill abundances -> abundance_from_predictions_after_autofill.csv and abundance_from_predictions_weighted_after_autofill.csv
Confusion matrix empty -> skipping deconvolution.

SUMMARY:
 UNASSIGNED before (after BLAST step): 144
 Assigned from BLAST: 0
 Assigned from auto-fill: 0
 UNASSIGNED after process

In [133]:
# FINALIZE (fixed): produce publication-ready abundance CSV from files in extracted/
# - Fixes earlier pandas truth-value error.
# - Works only with files already in extracted/.
# - Non-destructive: original files remain untouched.
from pathlib import Path
import pandas as pd
import numpy as np

# ---------- Config ----------
APPLY_REDISTRIBUTE_UNASSIGNED = True
REDISTRIBUTE_ONLY_IF_UNASSIGNED_FRAC_GT = 0.0
TOP_N_PRINT = 20
# ----------------------------

def find_extracted():
    candidates = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            return p.resolve()
    for p in Path.cwd().rglob("**/extracted"):
        if p.is_dir():
            return p.resolve()
    return None

def load_if_exists(path_obj):
    if path_obj.exists():
        try:
            return pd.read_csv(path_obj)
        except Exception as e:
            print(f"Warning: failed to read {path_obj.name}: {e}")
    return None

extracted = find_extracted()
if extracted is None:
    print("Could not find an 'extracted/' folder. Place outputs in an 'extracted/' folder and re-run.")
else:
    print("Using extracted folder:", extracted)

    # locate the best-available raw/weighted/deconv/reconciled files (first match wins)
    raw_candidates = [
        "abundance_from_predictions.csv",
        "abundance_from_predictions_after_blast.csv",
        "abundance_from_predictions_after_autofill.csv",
        "abundance_from_predictions_after_blast.csv"
    ]
    weighted_candidates = [
        "abundance_from_predictions_weighted.csv",
        "abundance_from_predictions_weighted_after_blast.csv",
        "abundance_from_predictions_weighted_after_autofill.csv"
    ]
    deconv_candidates = [
        "abundance_from_predictions_deconvolved.csv",
        "abundance_from_predictions_deconvolved_after_blast.csv",
        "abundance_from_predictions_deconvolved_after_autofill.csv"
    ]
    reconciled_candidates = [
        "abundance_reconciled_species.csv",
        "abundance_reconciled_species_after_blast.csv",
        "abundance_reconciled_species_after_autofill.csv"
    ]

    raw = None
    for fn in raw_candidates:
        raw = load_if_exists(extracted / fn)
        if raw is not None:
            print("Loaded raw file:", fn)
            break

    weighted = None
    for fn in weighted_candidates:
        weighted = load_if_exists(extracted / fn)
        if weighted is not None:
            print("Loaded weighted file:", fn)
            break

    deconv = None
    for fn in deconv_candidates:
        deconv = load_if_exists(extracted / fn)
        if deconv is not None:
            print("Loaded deconv file:", fn)
            break

    reconciled = None
    for fn in reconciled_candidates:
        reconciled = load_if_exists(extracted / fn)
        if reconciled is not None:
            print("Loaded reconciled file:", fn)
            break

    # If reconciled missing but deconv is present, form reconciled from deconv if possible
    if reconciled is None and deconv is not None:
        if 'est_true_rel' in deconv.columns:
            reconciled = deconv[['species','est_true_rel']].rename(columns={'est_true_rel':'est_rel'})
            print("Built reconciled table from deconv (est_true_rel).")
        elif 'est_true_count' in deconv.columns:
            tmp = deconv[['species','est_true_count']].copy()
            s = tmp['est_true_count'].sum() if tmp['est_true_count'].sum()>0 else 1.0
            tmp['est_rel'] = tmp['est_true_count'] / s
            reconciled = tmp[['species','est_rel']]
            print("Built reconciled table from deconv (est_true_count).")
        else:
            reconciled = None

    # Collect all species present across available tables
    species_set = set()
    if raw is not None and 'species' in raw.columns:
        species_set.update(raw['species'].astype(str).tolist())
    if weighted is not None and 'species' in weighted.columns:
        species_set.update(weighted['species'].astype(str).tolist())
    if reconciled is not None and 'species' in reconciled.columns:
        species_set.update(reconciled['species'].astype(str).tolist())

    species_list = sorted(list(species_set), key=lambda x: (x.upper()=="UNASSIGNED", x))

    if not species_list:
        print("No abundance data files were found in extracted/. Nothing to build.")
    else:
        res = pd.DataFrame({'species': species_list})

        # map raw counts if available
        if raw is not None and 'species' in raw.columns and 'count' in raw.columns:
            raw_map = dict(zip(raw['species'].astype(str), raw['count'].astype(float)))
            res['pred_count'] = res['species'].map(raw_map).fillna(0.0).astype(float)
        else:
            res['pred_count'] = 0.0

        total_raw = res['pred_count'].sum() if res['pred_count'].sum()>0 else 1.0
        res['pred_count_rel'] = res['pred_count'] / total_raw

        # map weighted (conf_sum)
        if weighted is not None and 'species' in weighted.columns:
            if 'conf_sum' in weighted.columns:
                wmap = dict(zip(weighted['species'].astype(str), weighted['conf_sum'].astype(float)))
            else:
                # assume second numeric column is conf_sum
                second_col = weighted.columns[1] if len(weighted.columns)>1 else None
                if second_col:
                    wmap = dict(zip(weighted['species'].astype(str), pd.to_numeric(weighted[second_col], errors='coerce').fillna(0.0)))
                else:
                    wmap = {}
            res['conf_sum'] = res['species'].map(wmap).fillna(0.0).astype(float)
        else:
            res['conf_sum'] = 0.0
        total_conf = res['conf_sum'].sum() if res['conf_sum'].sum()>0 else 1.0
        res['pred_conf_rel'] = res['conf_sum'] / total_conf

        # map reconciled est_rel
        if reconciled is not None and 'species' in reconciled.columns and 'est_rel' in reconciled.columns:
            rmap = dict(zip(reconciled['species'].astype(str), reconciled['est_rel'].astype(float)))
            res['est_rel'] = res['species'].map(rmap).fillna(0.0).astype(float)
        else:
            # fallback to pred_count_rel to fill est_rel when none available
            res['est_rel'] = res['pred_count_rel'].copy()

        # normalize est_rel
        total_est = res['est_rel'].sum()
        if total_est > 0:
            res['est_rel'] = res['est_rel'] / total_est

        # handle UNASSIGNED redistribution at abundance-level only
        unassigned_mask = res['species'].astype(str).str.upper() == 'UNASSIGNED'
        unassigned_count = float(res.loc[unassigned_mask, 'pred_count'].sum()) if unassigned_mask.any() else 0.0
        frac_unassigned = unassigned_count / res['pred_count'].sum() if res['pred_count'].sum()>0 else 0.0

        if APPLY_REDISTRIBUTE_UNASSIGNED and frac_unassigned > REDISTRIBUTE_ONLY_IF_UNASSIGNED_FRAC_GT:
            # distribute UNASSIGNED count proportionally to est_rel over non-unassigned taxa
            target = res.loc[~unassigned_mask].copy()
            if target.empty:
                print("Only UNASSIGNED present; cannot redistribute.")
            else:
                est_vals = target['est_rel'].values
                if est_vals.sum() <= 0:
                    est_vals = target['pred_count_rel'].values
                    if est_vals.sum() <= 0:
                        est_vals = np.ones(len(target))
                weights = est_vals / est_vals.sum()
                added = weights * unassigned_count
                target = target.reset_index(drop=True)
                target['pred_count_after'] = target['pred_count'] + added
                redistributed = res.copy()
                redistributed.loc[~unassigned_mask, 'pred_count'] = target['pred_count_after'].values
                redistributed.loc[unassigned_mask, 'pred_count'] = 0.0
                total_new = redistributed['pred_count'].sum() if redistributed['pred_count'].sum()>0 else 1.0
                redistributed['pred_count_rel'] = redistributed['pred_count'] / total_new
                # renormalize est_rel
                if redistributed['est_rel'].sum() > 0:
                    redistributed['est_rel'] = redistributed['est_rel'] / redistributed['est_rel'].sum()
                redistributed['conf_sum'] = redistributed['conf_sum'].where(~unassigned_mask, 0.0)
                total_conf2 = redistributed['conf_sum'].sum() if redistributed['conf_sum'].sum()>0 else 1.0
                redistributed['pred_conf_rel'] = redistributed['conf_sum'] / total_conf2
                final_df = redistributed
                note = f"Redistributed UNASSIGNED mass at abundance level (unassigned_count={unassigned_count:.0f})."
        else:
            final_df = res.copy()
            note = "No redistribution performed."

        # cleanup and ensure numeric
        for c in ['pred_count','pred_count_rel','conf_sum','pred_conf_rel','est_rel']:
            if c in final_df.columns:
                final_df[c] = pd.to_numeric(final_df[c], errors='coerce').fillna(0.0)

        # add est_pct
        final_df['est_pct'] = final_df['est_rel'] * 100.0

        # write publication-ready CSV
        pub_out = extracted / "abundance_publication_ready.csv"
        cols_out = ['species','pred_count','pred_count_rel','conf_sum','pred_conf_rel','est_rel','est_pct']
        # some columns may be missing depending on inputs; select those present
        cols_present = [c for c in cols_out if c in final_df.columns]
        final_df[cols_present].to_csv(pub_out, index=False)
        print("Wrote publication-ready CSV ->", pub_out)
        print("Note:", note)

        # print top N
        top = final_df.sort_values('est_rel', ascending=False).reset_index(drop=True).head(TOP_N_PRINT)
        print(f"\nTop {TOP_N_PRINT} taxa by est_rel (species, pred_count, est_pct):")
        print(top[['species','pred_count','est_pct']].to_string(index=False))

        # summary file
        summary = {
            "pub_csv": pub_out.name,
            "raw_used": bool(raw is not None),
            "weighted_used": bool(weighted is not None),
            "reconciled_used": bool(reconciled is not None),
            "redistributed_unassigned": APPLY_REDISTRIBUTE_UNASSIGNED and frac_unassigned > REDISTRIBUTE_ONLY_IF_UNASSIGNED_FRAC_GT,
            "fraction_unassigned_before": frac_unassigned,
            "rows_in_pub_table": len(final_df)
        }
        summary_df = pd.DataFrame(list(summary.items()), columns=["key","value"])
        summary_out = extracted / "abundance_publication_summary.csv"
        summary_df.to_csv(summary_out, index=False)
        print("Wrote summary ->", summary_out)


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loaded raw file: abundance_from_predictions.csv
Loaded weighted file: abundance_from_predictions_weighted.csv
Loaded deconv file: abundance_from_predictions_deconvolved.csv
Loaded reconciled file: abundance_reconciled_species.csv
Wrote publication-ready CSV -> C:\Users\Srijit\sih\ncbi_blast_db\extracted\abundance_publication_ready.csv
Note: No redistribution performed.

Top 20 taxa by est_rel (species, pred_count, est_pct):
                        species  pred_count   est_pct
                     UNASSIGNED         0.0 52.783613
                Maylandia zebra         0.0 21.008403
               Chaetodon auriga         0.0  5.777311
          Arvicanthis niloticus         0.0  1.838235
                  Morchella sp.         0.0  1.313025
           Aonchotheca annulosa         0.0  1.050420
            Amanita fuscozonata         0.0  0.840336
       Pseudopestalotiopsis sp.         0.0  0.787815
Deuterostichococcu

In [135]:
# CELL: Apply BLAST assignments (strict) and recompute abundances (run only if blast_results.tsv is present)
from pathlib import Path
import pandas as pd, numpy as np, re

# Settings (conservative)
MIN_PID = 97.0
MIN_COV = 80.0
MAX_EVAL = 1e-6
APPLY_ASSIGNMENTS = True  # will only change rows that are currently UNASSIGNED

# locate extracted/
def find_extracted():
    cand = [
        Path.cwd()/"sih"/"ncbi_blast_db"/"extracted",
        Path.cwd()/"ncbi_blast_db"/"extracted",
        Path.cwd()/ "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    ]
    for p in cand:
        if p.exists() and p.is_dir(): return p.resolve()
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        return p.parent.resolve()
    return None

extracted = find_extracted()
if extracted is None:
    print("Could not locate 'extracted/'. Stop.")
else:
    print("Using extracted:", extracted)
    blast_file = extracted / "blast_results.tsv"
    if not blast_file.exists():
        print("blast_results.tsv not found in extracted/. Run BLAST externally (use earlier printed command) and copy the file here, then re-run this cell.")
    else:
        print("Parsing BLAST file:", blast_file.name)
        cols = ["qseqid","sseqid","pident","length","qlen","qstart","qend","sstart","send","evalue","bitscore","stitle"]
        dfb = pd.read_csv(blast_file, sep="\t", names=cols, header=None, quoting=3, dtype={"stitle":str})
        dfb["pident"] = pd.to_numeric(dfb["pident"], errors="coerce").fillna(0.0)
        dfb["length"] = pd.to_numeric(dfb["length"], errors="coerce").fillna(0.0)
        dfb["qlen"] = pd.to_numeric(dfb["qlen"], errors="coerce").replace(0, np.nan)
        dfb["coverage_pct"] = dfb.apply(lambda r: (r["length"]/r["qlen"]*100.0) if (pd.notna(r["qlen"]) and r["qlen"]>0) else 0.0, axis=1)
        dfb["evalue"] = pd.to_numeric(dfb["evalue"], errors="coerce").fillna(1e6)
        dfb_high = dfb[(dfb["pident"]>=MIN_PID) & (dfb["coverage_pct"]>=MIN_COV) & (dfb["evalue"]<=MAX_EVAL)].copy()
        print("BLAST rows passing thresholds:", len(dfb_high))
        if dfb_high.empty:
            print("No high-confidence BLAST hits found. Nothing to apply.")
        else:
            dfb_top = dfb_high.sort_values(["qseqid","bitscore"], ascending=[True,False]).groupby("qseqid", as_index=False).first()
            def extract_species(title):
                if not isinstance(title,str) or not title.strip(): return None
                m = re.search(r'\b([A-Z][a-z]+(?:\s+[a-z][a-z\-]+))\b', title)
                if m: return m.group(1).strip()
                m2 = re.search(r'\b([A-Z][a-z]+)\s+sp\b', title)
                if m2: return m2.group(0).strip()
                return None
            assignments = {}
            for _, r in dfb_top.iterrows():
                q = str(r["qseqid"]); s = extract_species(r["stitle"])
                if s:
                    assignments[q] = {"species": s, "pident": float(r["pident"]), "coverage": float(r["coverage_pct"]), "sseqid": r["sseqid"], "stitle": r["stitle"]}
            # audit CSV
            assign_df = pd.DataFrame([{"query_id":q,**v} for q,v in assignments.items()])
            assign_df.to_csv(extracted/"blast_species_assignments.csv", index=False)
            print("Wrote blast_species_assignments.csv")
            # load predictions
            pred_file = extracted/"predictions_with_uncertainty.csv"
            if not pred_file.exists(): pred_file = extracted/"predictions.csv"
            if not pred_file.exists():
                print("Predictions CSV not found. Abort.")
            else:
                pred = pd.read_csv(pred_file)
                cols_lower = {c.lower():c for c in pred.columns}
                species_col = next((cols_lower[k] for k in cols_lower if "species_pred_label" in k or ("species" in k and "pred" in k)), None)
                if species_col is None:
                    species_col = next((c for c in pred.columns if "species" in c.lower()), pred.columns[-1])
                id_col = next((cols_lower[k] for k in cols_lower if k in ("id","accession","seqid","read_id","readid","accession_id","global_index")), None)
                if id_col is None:
                    print("No id column in predictions — cannot map BLAST queries. Abort.")
                else:
                    pred["_species_norm"] = pred[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
                    applied = 0
                    for q,info in assignments.items():
                        exact_mask = pred[id_col].astype(str) == str(q)
                        if exact_mask.any():
                            for ri in pred.index[exact_mask]:
                                if pred.at[ri, "_species_norm"].upper()=="UNASSIGNED":
                                    pred.at[ri, species_col] = info["species"]
                                    pred.at[ri, "blast_assigned_species"] = info["species"]
                                    pred.at[ri, "blast_pident"] = info["pident"]
                                    pred.at[ri, "blast_coverage"] = info["coverage"]
                                    pred.at[ri, "blast_sseqid"] = info["sseqid"]
                                    pred.at[ri, "blast_stitle"] = info["stitle"]
                                    applied += 1
                            continue
                        substr_mask = pred[id_col].astype(str).str.contains(str(q), na=False)
                        if substr_mask.any():
                            for ri in pred.index[substr_mask]:
                                if pred.at[ri, "_species_norm"].upper()=="UNASSIGNED":
                                    pred.at[ri, species_col] = info["species"]
                                    pred.at[ri, "blast_assigned_species"] = info["species"]
                                    pred.at[ri, "blast_pident"] = info["pident"]
                                    pred.at[ri, "blast_coverage"] = info["coverage"]
                                    pred.at[ri, "blast_sseqid"] = info["sseqid"]
                                    pred.at[ri, "blast_stitle"] = info["stitle"]
                                    applied += 1
                    if applied==0:
                        print("No UNASSIGNED rows matched the high-confidence BLAST hits. Nothing changed.")
                    else:
                        pred.to_csv(extracted/"predictions_after_blast_forced.csv", index=False)
                        print(f"Applied BLAST assignments to {applied} rows -> predictions_after_blast_forced.csv")
                        # recompute abundances (simple)
                        pred["_species_norm_final"] = pred[species_col].astype(str).fillna("").apply(lambda s:" ".join(str(s).split()))
                        counts = pred["_species_norm_final"].value_counts().reset_index()
                        counts.columns = ["species","count"]
                        counts["rel"] = counts["count"]/counts["count"].sum()
                        counts.to_csv(extracted/"abundance_after_blast.csv", index=False)
                        print("Wrote abundance_after_blast.csv")


Using extracted: C:\Users\Srijit\sih\ncbi_blast_db\extracted
blast_results.tsv not found in extracted/. Run BLAST externally (use earlier printed command) and copy the file here, then re-run this cell.


In [137]:
# KNN-based assignment of UNASSIGNED reads using embeddings (no BLAST needed).
# Conservative, audited, uses only files in extracted/.
from pathlib import Path
import numpy as np
import pandas as pd
import math, sys

# --- settings (tweak if you want more/less aggressive) ---
K_NEIGH = 7               # number of neighbor votes to consider (auto-limited to assigned count)
VOTE_FRACTION_THRESH = 0.60   # fraction of neighbor-weighted vote needed to accept top species
SIM_REL_TO_ASSIGNED_MED = 0.85  # require top-sim >= median_assigned_sim * this factor (keeps conservative)
MIN_TOP_SIM_ABS = 0.70    # always require this absolute minimum top similarity
# ---------------------------------------------------------

def find_extracted():
    candidates = [
        Path.cwd() / "sih" / "ncbi_blast_db" / "extracted",
        Path.cwd() / "ncbi_blast_db" / "extracted",
        Path.cwd() / "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    ]
    for p in candidates:
        if p.exists() and p.is_dir(): return p.resolve()
    for p in Path.cwd().rglob("predictions_with_uncertainty.csv"):
        return p.parent.resolve()
    return None

extracted = find_extracted()
if extracted is None:
    raise SystemExit("Could not locate extracted/ folder with your predictions and embeddings. Place files in an 'extracted' folder and re-run.")

print("Using extracted folder:", extracted)

# load predictions
pred_file = extracted / "predictions_with_uncertainty.csv"
if not pred_file.exists():
    pred_file = extracted / "predictions.csv"
if not pred_file.exists():
    raise SystemExit("Predictions CSV not found in extracted/.")

pred = pd.read_csv(pred_file)
print("Loaded predictions:", pred.shape)

# detect species column / id / index mapping
species_col = next((c for c in pred.columns if "species_pred_label" in c.lower()), None)
if species_col is None:
    species_col = next((c for c in pred.columns if "species" in c.lower()), None)
id_col = next((c for c in pred.columns if c.lower() in ("id","accession","seqid","global_index","read_id","readid")), None)
# pick embedding index column
index_col = None
if "global_index" in pred.columns:
    index_col = "global_index"
elif "pca_index" in pred.columns:
    index_col = "pca_index"
else:
    # assume order corresponds to embeddings
    index_col = None

print("Detected columns -> species:", species_col, "| id:", id_col, "| index_col:", index_col)

# load embeddings
emb_path = extracted / "embeddings_pca.npy"
if not emb_path.exists():
    raise SystemExit("embeddings_pca.npy not found in extracted/. Cannot run KNN assignment without embeddings.")
X = np.load(emb_path)
print("Loaded embeddings shape:", X.shape)

# map prediction rows -> embedding indices
if index_col is not None:
    try:
        pred["_emb_index"] = pred[index_col].astype(int).values
    except Exception:
        pred["_emb_index"] = pred[index_col].astype(float).fillna(0).astype(int).values
else:
    # fallback: assume same order and lengths match
    if len(pred) == X.shape[0]:
        pred["_emb_index"] = np.arange(len(pred))
    else:
        raise SystemExit("Cannot determine mapping from prediction rows to embedding rows; predictions length != embeddings rows and no index column found.")

# normalize embeddings (for cosine similarity)
eps = 1e-12
norms = np.linalg.norm(X, axis=1, keepdims=True)
norms = np.maximum(norms, eps)
Xn = X / norms

# determine assigned vs unassigned
pred["_species_norm"] = pred[species_col].astype(str).fillna("UNASSIGNED").apply(lambda s: " ".join(str(s).split()))
mask_assigned = pred["_species_norm"].str.upper() != "UNASSIGNED"
mask_unassigned = pred["_species_norm"].str.upper() == "UNASSIGNED"
assigned_rows = pred[mask_assigned].copy()
unassigned_rows = pred[mask_unassigned].copy()
print("Assigned rows:", len(assigned_rows), "Unassigned rows:", len(unassigned_rows))

if len(assigned_rows) < 5:
    raise SystemExit("Too few assigned rows to run embedding-based assignment (need >=5). Use BLAST or supply more labeled data.")

# embedding indices arrays
assigned_idx = assigned_rows["_emb_index"].astype(int).values
unassigned_idx = unassigned_rows["_emb_index"].astype(int).values

# build assigned normalized matrix and species list
Xa = Xn[assigned_idx, :]
labels_assigned = assigned_rows["_species_norm"].values

# compute per-assigned nearest neighbor similarity (to get median within-assigned similarity)
# compute dot product matrix limited to assigned (small)
Sa = Xa.dot(Xa.T)  # shape (na, na)
# set self to -inf
np.fill_diagonal(Sa, -np.inf)
nearest_assigned_sim = np.max(Sa, axis=1)
median_assigned_sim = float(np.median(nearest_assigned_sim[np.isfinite(nearest_assigned_sim)]))
if not np.isfinite(median_assigned_sim):
    median_assigned_sim = 0.5
print("median assigned nearest-neighbor similarity:", round(median_assigned_sim,4))

# compute similarity of each unassigned to all assigned: S_u_a = X_un @ Xa.T
Xu = Xn[unassigned_idx, :]
S_u_a = Xu.dot(Xa.T)  # shape (n_unassigned, n_assigned)

# for each unassigned compute top-k neighbors and weighted vote
k = min(K_NEIGH, Xa.shape[0])
assigned_species_list = list(labels_assigned)
assigned_species_array = np.array(labels_assigned)

assignments = []
for i_un, sims in enumerate(S_u_a):
    if sims.max() <= 0:
        assignments.append(None); continue
    # top k indices by similarity
    topk_idx = np.argsort(-sims)[:k]
    topk_sims = sims[topk_idx]
    sum_sims = topk_sims.sum()
    if sum_sims <= 0:
        assignments.append(None); continue
    # weighted votes by similarity
    species_votes = {}
    for idx_local, sim_val in zip(topk_idx, topk_sims):
        sp = assigned_species_array[idx_local]
        species_votes[sp] = species_votes.get(sp, 0.0) + float(sim_val)
    # pick top species
    top_species, top_vote = max(species_votes.items(), key=lambda x: x[1])
    vote_fraction = top_vote / sum_sims
    top_similarity = float(topk_sims.max())
    # acceptance criteria
    accept = (vote_fraction >= VOTE_FRACTION_THRESH) and (top_similarity >= max(MIN_TOP_SIM_ABS, SIM_REL_TO_ASSIGNED_MED * median_assigned_sim))
    if accept:
        assignments.append({
            "assigned_species": top_species,
            "vote_fraction": float(vote_fraction),
            "top_similarity": float(top_similarity),
            "k": int(k)
        })
    else:
        assignments.append(None)

# apply assignments to a copy of predictions, but do NOT overwrite original
pred_knn = pred.copy()
applied_count = 0
audit_rows = []
for i_row, (ri, row) in enumerate(unassigned_rows.iterrows()):
    out = assignments[i_row]
    if out is not None:
        idx_global = int(row.name)  # original pred row index
        # write columns
        pred_knn.at[idx_global, "knn_assigned_species"] = out["assigned_species"]
        pred_knn.at[idx_global, "knn_vote_fraction"] = out["vote_fraction"]
        pred_knn.at[idx_global, "knn_top_similarity"] = out["top_similarity"]
        pred_knn.at[idx_global, "knn_k"] = out["k"]
        # apply assignment to species column (non-destructive; original file stays)
        pred_knn.at[idx_global, species_col] = out["assigned_species"]
        applied_count += 1
        audit_rows.append({
            "row_index": int(idx_global),
            "id": pred_knn.at[idx_global, id_col] if id_col in pred_knn.columns else "",
            "new_species": out["assigned_species"],
            "vote_fraction": out["vote_fraction"],
            "top_similarity": out["top_similarity"]
        })

print("KNN assignments applied to UNASSIGNED rows:", applied_count)

# save outputs
out_preds = extracted / "predictions_after_knn_assignments.csv"
pred_knn.to_csv(out_preds, index=False)
out_audit = extracted / "knn_assignments_audit.csv"
pd.DataFrame(audit_rows).to_csv(out_audit, index=False)

# recompute abundances from the updated predictions
pred_knn["_species_norm_final"] = pred_knn[species_col].astype(str).fillna("").apply(lambda s: " ".join(str(s).split()))
counts = pred_knn["_species_norm_final"].value_counts().reset_index()
counts.columns = ["species","count"]
counts["relative_abundance"] = counts["count"] / counts["count"].sum() if counts["count"].sum()>0 else 0.0
out_abund = extracted / "abundance_from_predictions_after_knn.csv"
counts.to_csv(out_abund, index=False)

# confidence-weighted: try to detect a confidence column
conf_col = next((c for c in pred_knn.columns if "species_pred_conf" in c.lower() or c.lower().endswith("_conf") or "topprob" in c.lower()), None)
if conf_col:
    pred_knn["_conf_num"] = pd.to_numeric(pred_knn[conf_col], errors="coerce").fillna(0.0)
else:
    pred_knn["_conf_num"] = 0.0
weighted = pred_knn.groupby("_species_norm_final")["_conf_num"].sum().reset_index().rename(columns={"_conf_num":"conf_sum"})
total_conf = weighted["conf_sum"].sum() if weighted["conf_sum"].sum()>0 else 1.0
weighted["relative_abundance_weighted"] = weighted["conf_sum"] / total_conf
out_weight = extracted / "abundance_from_predictions_weighted_after_knn.csv"
weighted.to_csv(out_weight, index=False)

print("Wrote files:")
print(" -", out_preds.name)
print(" -", out_audit.name)
print(" -", out_abund.name)
print(" -", out_weight.name)

# show top assigned species from audit
if applied_count>0:
    adf = pd.read_csv(out_audit)
    top_assigned = adf['new_species'].value_counts().reset_index().rename(columns={'index':'species','new_species':'count'})
    print("\nTop species assigned by KNN (audit):")
    print(top_assigned.head(15).to_string(index=False))
else:
    print("\nNo KNN assignments met the conservative criteria. You can loosen thresholds (increase VOTE_FRACTION_THRESH lower, or lower SIM_REL_TO_ASSIGNED_MED) and re-run the cell.")


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loaded predictions: (380, 54)
Detected columns -> species: species_pred_label | id: global_index | index_col: global_index
Loaded embeddings shape: (2555, 64)
Assigned rows: 236 Unassigned rows: 144
median assigned nearest-neighbor similarity: 1.0
KNN assignments applied to UNASSIGNED rows: 3
Wrote files:
 - predictions_after_knn_assignments.csv
 - knn_assignments_audit.csv
 - abundance_from_predictions_after_knn.csv
 - abundance_from_predictions_weighted_after_knn.csv

Top species assigned by KNN (audit):
                 count  count
       Maylandia zebra      2
Hysterothylacium fabri      1


In [143]:
# Robust: audit KNN, recompute abundances, safe before/after comparison (no dtype merge errors)
import pandas as pd, numpy as np, traceback
from pathlib import Path

# ---------- Config (no need to change) ----------
OUTPUT_PREFIX = "after_knn"   # used for filenames
# -------------------------------------------------

def find_extracted():
    candidates = [
        Path.cwd()/"sih"/"ncbi_blast_db"/"extracted",
        Path.cwd()/"ncbi_blast_db"/"extracted",
        Path.cwd()/ "extracted",
        Path(r"C:\Users\Srijit\sih\ncbi_blast_db\extracted"),
        Path(r"C:\Users\Srijit\OneDrive\Desktop\sihtaxa\sihabundance\ncbi_blast_db\extracted"),
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            return p.resolve()
    # fallback: first folder named 'extracted' under cwd
    for p in Path.cwd().rglob("**/extracted"):
        if p.is_dir():
            return p.resolve()
    return None

def safe_print_df(df, n=10, label=None):
    if df is None:
        print(f"{label or 'DataFrame'}: None")
        return
    try:
        from IPython.display import display
        if label: print(label)
        display(df.head(n))
    except Exception:
        print((df.head(n)).to_string())

def detect_species_col(df):
    # Prefer explicitly named prediction columns
    for c in df.columns:
        if "species_pred_label" in c.lower():
            return c
    # otherwise pick any column that contains 'species' (but not 'idx')
    for c in df.columns:
        cl = c.lower()
        if "species" in cl and "idx" not in cl:
            return c
    # fallback: first object (text) column
    text_cols = [c for c in df.columns if df[c].dtype == object]
    if text_cols:
        return text_cols[0]
    # final fallback: last column
    return df.columns[-1]

def detect_conf_col(df):
    for c in df.columns:
        cl = c.lower()
        if "species_pred_conf" in cl or cl.endswith("_conf") or "topprob" in cl or "mc_mean" in cl:
            return c
    return None

def normalize_species_series(s):
    # convert to string, strip whitespace, replace empty with 'UNASSIGNED'
    s2 = s.astype(str).fillna("").apply(lambda x: " ".join(x.split()))
    s2 = s2.replace({"": "UNASSIGNED", "nan": "UNASSIGNED"})
    return s2

def recompute_abundances(pred_df, species_col_hint=None):
    """
    Returns (abundance_df, species_col_used)
    abundance_df columns: species, count, relative_abundance, conf_sum, relative_abundance_weighted
    """
    try:
        df = pred_df.copy()
        species_col = species_col_hint or detect_species_col(df)
        if species_col not in df.columns:
            # create placeholder
            df[species_col] = "UNASSIGNED"
        df["_species_norm_final"] = normalize_species_series(df[species_col])

        # raw counts
        counts = df["_species_norm_final"].value_counts(dropna=False).reset_index()
        counts.columns = ["species", "count"]
        counts["species"] = counts["species"].astype(str)

        # confidence-weighted
        conf_col = detect_conf_col(df)
        if conf_col:
            df["_conf_num"] = pd.to_numeric(df[conf_col], errors="coerce").fillna(0.0)
        else:
            df["_conf_num"] = 0.0
        weighted = df.groupby("_species_norm_final")["_conf_num"].sum().reset_index().rename(columns={"_species_norm_final":"species", "_conf_num":"conf_sum"})
        weighted["species"] = weighted["species"].astype(str)

        # merge safely on species (both string)
        merged = pd.merge(counts, weighted, on="species", how="outer").fillna(0.0)
        # ensure numeric
        merged["count"] = pd.to_numeric(merged["count"], errors="coerce").fillna(0.0)
        merged["conf_sum"] = pd.to_numeric(merged["conf_sum"], errors="coerce").fillna(0.0)

        # relative columns
        total_count = merged["count"].sum() if merged["count"].sum() > 0 else 1.0
        merged["relative_abundance"] = merged["count"] / total_count
        total_conf = merged["conf_sum"].sum() if merged["conf_sum"].sum() > 0 else 1.0
        merged["relative_abundance_weighted"] = merged["conf_sum"] / total_conf

        # keep nice column order
        cols_keep = ["species","count","relative_abundance","conf_sum","relative_abundance_weighted"]
        for c in cols_keep:
            if c not in merged.columns:
                merged[c] = 0.0
        merged = merged[cols_keep]
        return merged, species_col
    except Exception as e:
        print("Error in recompute_abundances:", e)
        traceback.print_exc()
        return pd.DataFrame(columns=["species","count","relative_abundance","conf_sum","relative_abundance_weighted"]), species_col_hint

# ------------------ Main ------------------
extracted = find_extracted()
if extracted is None:
    raise SystemExit("Could not find an 'extracted' folder. Put your outputs into an 'extracted' directory and re-run.")
print("Using extracted folder:", extracted)

# pick predictions (prefer post-KNN file)
candidate_preds = [
    "predictions_after_knn_assignments.csv",
    "predictions_after_knn_rerun.csv",
    "predictions_with_uncertainty.csv",
    "predictions.csv"
]
pred_path = None
for fn in candidate_preds:
    p = extracted / fn
    if p.exists():
        pred_path = p
        break
if pred_path is None:
    raise SystemExit("No predictions CSV found in extracted/. Expected one of: " + ", ".join(candidate_preds))
print("Loaded predictions file:", pred_path.name)

pred = pd.read_csv(pred_path)
print("Predictions rows:", len(pred))

# load and display audit if present
audit_path = extracted / "knn_assignments_audit.csv"
if audit_path.exists():
    try:
        audit = pd.read_csv(audit_path)
        print("Found KNN audit:", audit_path.name, "rows =", len(audit))
        safe_print_df(audit, n=50, label="KNN audit (first rows):")
    except Exception as e:
        print("Failed to read audit CSV:", e)
        audit = None
else:
    print("No KNN audit file found.")

# recompute abundances from current predictions
after_abund, species_col_used = recompute_abundances(pred)
out_after_path = extracted / f"abundance_from_predictions_{OUTPUT_PREFIX}_recomputed.csv"
after_abund.to_csv(out_after_path, index=False)
print("Wrote recomputed abundance ->", out_after_path.name)
safe_print_df(after_abund.sort_values("count", ascending=False), n=20, label="Top taxa (post-KNN recomputed):")

# try load previous raw abundance for comparison (if present)
before_path = extracted / "abundance_from_predictions.csv"
if before_path.exists():
    try:
        before_abund = pd.read_csv(before_path)
        print("Loaded before-abundance file:", before_path.name)
        # normalize column names: try to find species & count columns
        if 'species' not in before_abund.columns:
            # assume first column is species
            before_abund = before_abund.rename(columns={before_abund.columns[0]:'species'})
        if 'count' not in before_abund.columns:
            # choose plausible count column
            if 'pred_count' in before_abund.columns:
                before_abund = before_abund.rename(columns={'pred_count':'count'})
            elif before_abund.shape[1] >= 2:
                before_abund = before_abund.rename(columns={before_abund.columns[1]:'count'})
            else:
                before_abund['count'] = 0
        # force species to string to avoid dtype mismatch during merge
        before_abund['species'] = before_abund['species'].astype(str)
        before_abund['count'] = pd.to_numeric(before_abund['count'], errors='coerce').fillna(0.0)
        # ensure after table species are strings as well
        after_abund['species'] = after_abund['species'].astype(str)
        after_abund['count'] = pd.to_numeric(after_abund['count'], errors='coerce').fillna(0.0)

        # merge safely on string species
        merged_comp = pd.merge(
            before_abund[['species','count']].rename(columns={'count':'count_before'}),
            after_abund[['species','count']].rename(columns={'count':'count_after'}),
            on='species', how='outer'
        ).fillna(0.0)
        merged_comp['delta'] = merged_comp['count_after'] - merged_comp['count_before']
        # safe pct change: NaN when before==0
        merged_comp['pct_change'] = merged_comp.apply(lambda r: (r['delta'] / r['count_before']*100.0) if r['count_before']>0 else np.nan, axis=1)
        merged_comp = merged_comp.sort_values('count_after', ascending=False).reset_index(drop=True)
        print("\nTop changes (before -> after) (top 30):")
        safe_print_df(merged_comp, n=30)
        changed = merged_comp[merged_comp['delta'] != 0.0]
        print("Number of species with count changes:", len(changed))
        if len(changed) > 0:
            print("Examples of changed species (first 20):")
            safe_print_df(changed.head(20), n=20)
        # save comparison for audit
        comp_out = extracted / f"abundance_before_after_comparison_{OUTPUT_PREFIX}.csv"
        merged_comp.to_csv(comp_out, index=False)
        print("Wrote comparison CSV ->", comp_out.name)
    except Exception as e:
        print("Error while comparing before/after abundance:", e)
        traceback.print_exc()
else:
    print("No prior 'abundance_from_predictions.csv' found; skipping before/after comparison.")

# Also save weighted counts if not already present (recompute weighted table)
try:
    # recompute weighted table from pred (uses conf col if present)
    _, _ = recompute_abundances(pred, species_col_hint=species_col_used)  # ensures columns present
    weighted_out = extracted / f"abundance_from_predictions_{OUTPUT_PREFIX}_weighted_recomputed.csv"
    # use after_abund (which already has conf_sum & relative_abundance_weighted)
    if 'conf_sum' not in after_abund.columns:
        after_abund['conf_sum'] = 0.0
        after_abund['relative_abundance_weighted'] = 0.0
    after_abund[['species','conf_sum','relative_abundance_weighted']].to_csv(weighted_out, index=False)
    print("Wrote weighted abundance CSV ->", weighted_out.name)
except Exception as e:
    print("Failed to write weighted abundance:", e)
    traceback.print_exc()

print("\nDone. Output files in:", extracted)
for fn in [
    out_after_path.name,
    f"abundance_from_predictions_{OUTPUT_PREFIX}_weighted_recomputed.csv",
    f"abundance_before_after_comparison_{OUTPUT_PREFIX}.csv",
    "knn_assignments_audit.csv",
    "predictions_after_knn_assignments.csv",
]:
    p = extracted / fn
    if p.exists():
        print(" -", fn)


Using extracted folder: C:\Users\Srijit\sih\ncbi_blast_db\extracted
Loaded predictions file: predictions_after_knn_assignments.csv
Predictions rows: 380
Found KNN audit: knn_assignments_audit.csv rows = 3
KNN audit (first rows):


Unnamed: 0,row_index,id,new_species,vote_fraction,top_similarity
0,20,2541,Hysterothylacium fabri,0.729104,0.864676
1,55,1572,Maylandia zebra,1.0,0.851294
2,226,1741,Maylandia zebra,1.0,0.851294


Wrote recomputed abundance -> abundance_from_predictions_after_knn_recomputed.csv
Top taxa (post-KNN recomputed):


Unnamed: 0,species,count,relative_abundance,conf_sum,relative_abundance_weighted
51,UNASSIGNED,141,0.371053,135.859698,0.406443
28,Maylandia zebra,82,0.215789,79.453327,0.237695
10,Chaetodon auriga,46,0.121053,25.948245,0.077628
33,Morchella sp.,10,0.026316,6.121686,0.018314
5,Arvicanthis niloticus,9,0.023684,7.17039,0.021451
6,Aspergillus costaricensis,7,0.018421,6.823369,0.020413
8,Callospermophilus lateralis,6,0.015789,4.341063,0.012987
20,Hysterothylacium fabri,6,0.015789,4.773331,0.01428
17,Entoloma sp.,5,0.013158,3.705058,0.011084
19,Eucoleus sp.,4,0.010526,3.86194,0.011554


Loaded before-abundance file: abundance_from_predictions.csv

Top changes (before -> after) (top 30):


Unnamed: 0,species,count_before,count_after,delta,pct_change
0,UNASSIGNED,0.0,141.0,141.0,
1,Maylandia zebra,0.0,82.0,82.0,
2,Chaetodon auriga,0.0,46.0,46.0,
3,Morchella sp.,0.0,10.0,10.0,
4,Arvicanthis niloticus,0.0,9.0,9.0,
5,Aspergillus costaricensis,0.0,7.0,7.0,
6,Hysterothylacium fabri,0.0,6.0,6.0,
7,Callospermophilus lateralis,0.0,6.0,6.0,
8,Entoloma sp.,0.0,5.0,5.0,
9,Eucoleus sp.,0.0,4.0,4.0,


Number of species with count changes: 104
Examples of changed species (first 20):


Unnamed: 0,species,count_before,count_after,delta,pct_change
0,UNASSIGNED,0.0,141.0,141.0,
1,Maylandia zebra,0.0,82.0,82.0,
2,Chaetodon auriga,0.0,46.0,46.0,
3,Morchella sp.,0.0,10.0,10.0,
4,Arvicanthis niloticus,0.0,9.0,9.0,
5,Aspergillus costaricensis,0.0,7.0,7.0,
6,Hysterothylacium fabri,0.0,6.0,6.0,
7,Callospermophilus lateralis,0.0,6.0,6.0,
8,Entoloma sp.,0.0,5.0,5.0,
9,Eucoleus sp.,0.0,4.0,4.0,


Wrote comparison CSV -> abundance_before_after_comparison_after_knn.csv
Wrote weighted abundance CSV -> abundance_from_predictions_after_knn_weighted_recomputed.csv

Done. Output files in: C:\Users\Srijit\sih\ncbi_blast_db\extracted
 - abundance_from_predictions_after_knn_recomputed.csv
 - abundance_from_predictions_after_knn_weighted_recomputed.csv
 - abundance_before_after_comparison_after_knn.csv
 - knn_assignments_audit.csv
 - predictions_after_knn_assignments.csv
