In [2]:
# Developer Lab Notebook
# Experiment: Measure Inference Cost and GPU Usage

In [3]:
%matplotlib inline
import sys, os
sys.path.append(os.path.abspath(".."))

In [None]:
"""
4. The current result analysis does not fully address the risk of overfitting. Achieving 99.43% accuracy in CDS classification might be influenced 
by data imbalances or the characteristics of negative sampling (e.g., downsampling of non-CDS regions). It remains unclear whether the model relies 
on genuine sequence features or on confounding factors like ORF length. To improve confidence in the results, it would be helpful to include a 
confusion matrix, detailed performance on short ORFs (<300 bp), and further exploration of the model’s decision criteria. 
"""

In [None]:
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from Bio import SeqIO
import matplotlib.pyplot as plt
from api.core import AnnotatorPipeline
import orfipy_core as oc

# Shared state for the experiment cell
results = []                      # one dict per scored ORF; converted to a DataFrame afterwards
data_dir = Path("./data3")        # directory scanned for *.fa files (a sibling .gff is looked up per file)
annotator = AnnotatorPipeline()   # project pipeline; supplies the ORF parsing and CDS prediction helpers used below

# Helper: parse .gff
def parse_gff(gff_file):
    """Collect the CDS features of a GFF file.

    Comment/directive lines (starting with ``#``) are skipped, as are rows
    that are not of type ``CDS`` or that have fewer than the seven
    tab-separated columns needed to read start, end and strand.

    Parameters
    ----------
    gff_file : str or path-like
        Path of the tab-separated GFF annotation file.

    Returns
    -------
    set of tuple
        ``(start, end, strand)`` for every CDS row, with ``start``/``end``
        converted to ``int`` exactly as written in columns 4 and 5 and
        ``strand`` taken verbatim from column 7.

    Raises
    ------
    ValueError
        If the start or end column of a CDS row is not an integer.
    """
    cds_coords = set()
    with open(gff_file, 'r') as f:
        for line in f:
            if line.startswith("#"):
                continue
            parts = line.strip().split("\t")
            # The strand lives at index 6, so at least 7 columns are required;
            # shorter (truncated/malformed) rows are ignored instead of
            # raising IndexError.
            if len(parts) > 6 and parts[2] == "CDS":
                cds_coords.add((int(parts[3]), int(parts[4]), parts[6]))
    return cds_coords

# Main experiment: for every FASTA file that has a sibling .gff annotation,
# enumerate ORFs on both strands, classify them, and record one row per ORF.
for fa_file in sorted(data_dir.glob("*.fa")):
    gff_file = fa_file.with_suffix(".gff")
    if not gff_file.exists():
        # No annotation next to this FASTA file, so there is no ground truth.
        continue

    print(f"📄 Processing {fa_file.name}")
    cds_truth = parse_gff(gff_file)

    for record in SeqIO.parse(fa_file, "fasta"):
        forward_seq = str(record.seq)
        reverse_seq = str(record.seq.reverse_complement())

        # NOTE(review): the positional arguments presumably mean min length 10,
        # max length 1000000, both strands ('b'), start codons, stop codons and
        # translation table '1', followed by search/output flags - confirm
        # against the orfipy_core signature. Only element [2] of the return
        # value is consumed here.
        search_output = oc.start_search(
            forward_seq, reverse_seq, record.id, 10, 1000000, 'b',
            ['TTG', 'CTG', 'ATG', 'GTG'],
            ['TAA', 'TAG', 'TGA'], '1', True, False,
            False, False, True,
            [False, False, True, False, False],
        )
        orfs_pos, orfs_neg = annotator._parse_orfs(search_output[2])

        plus_inputs = annotator._cds_input_parser(orfs_pos, "+")
        minus_inputs = annotator._cds_input_parser(orfs_neg, "-")
        inputs = plus_inputs + minus_inputs

        # Split every parsed ORF into the sequence to classify and its metadata.
        sequences = []
        meta = []
        for orf in inputs:
            sequences.append(orf["sequence"])
            meta.append({
                "start": orf["start"],
                "end": orf["end"],
                "strand": orf["strand"],
                "len": orf["end"] - orf["start"],
            })

        preds = annotator._prediction(annotator.model_cds, annotator.tokenizer, sequences)

        # An ORF counts as a true CDS only when (start, end, strand) matches an
        # annotated CDS exactly.
        # NOTE(review): GFF coordinates are 1-based and inclusive; confirm that
        # `_cds_input_parser` reports start/end in the same convention,
        # otherwise genuine CDS would be labelled 0 here.
        results.extend(
            {
                "Genome": fa_file.name,
                "Length": info["len"],
                "TrueLabel": int((info["start"], info["end"], info["strand"]) in cds_truth),
                "PredLabel": pred,
            }
            for pred, info in zip(preds, meta)
        )

# To CSV and visualization
df = pd.DataFrame(results)
if df.empty:
    # Without rows there is no "Length" column and the binning below would
    # fail with an opaque KeyError.
    raise ValueError("No ORF predictions were collected; check the data directory and the .fa/.gff pairs.")

# Left-closed bins [lo, hi) so that every label describes its bin exactly
# ("<100" excludes 100, ">=2000" includes 2000). The last edge is open-ended so
# ORFs longer than 10000 are counted instead of being dropped with a NaN bin.
bins = [0, 100, 200, 300, 400, 500, 1000, 2000, np.inf]
labels = ["<100", "100-200", "200-300", "300-400", "400-500", "500-1000", "1000-2000", "≥2000"]
df["LengthBin"] = pd.cut(df["Length"], bins=bins, labels=labels, right=False)


def _length_bin_metrics(g):
    """Return Count, Precision, Recall and Accuracy for one length bin.

    The ratios are computed from integer counts. Precision/Recall are reported
    as 0.0 when their denominator is zero (no predicted / no true positives).
    """
    predicted_pos = g["PredLabel"] == 1
    actual_pos = g["TrueLabel"] == 1
    tp = int((predicted_pos & actual_pos).sum())
    n_predicted = int(predicted_pos.sum())
    n_actual = int(actual_pos.sum())
    return pd.Series({
        "Count": len(g),
        # Counts rather than means clamped at 1e-6: with more than a million
        # ORFs in a bin and only a few positives, the clamp would replace the
        # real denominator and understate the score.
        "Precision": tp / n_predicted if n_predicted else 0.0,
        "Recall": tp / n_actual if n_actual else 0.0,
        "Accuracy": np.mean(g["PredLabel"] == g["TrueLabel"]),
    })


# observed=False keeps every length bin in the summary, including empty ones.
summary = df.groupby("LengthBin", observed=False).apply(_length_bin_metrics).reset_index()

summary.to_csv("cds_performance_by_orf_length.csv", index=False)

# Plot
summary.plot(x="LengthBin", y=["Precision", "Recall", "Accuracy"], kind="bar", figsize=(10, 6))
plt.title("CDS Classifier Performance by ORF Length")
plt.ylabel("Metric Score")
plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()
plt.show()