In [5]:
# ============================================================
# SpliceTransformer FINAL INFERENCE (ALL-IN BEST PRACTICE)
# ============================================================

import os
import json
import torch
import numpy as np
import pandas as pd
import importlib.util

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from collections import Counter

# ================= CONFIG =================
MODEL_MAX_LEN = 8192
SEQ_LEN = 601
BATCH_SIZE = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- INFERENCE HYPERPARAMS ----
WINDOW_RADIUS = 25          # pooling Â±25
TEMPERATURE = 0.7           # soften / sharpen probs
TOP_SPLICE_RATIO = 0.30     # percentile-based decision (0.2â€“0.4)

# ================= PATHS =================
DATA_DIR = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\data"
RESULT_DIR = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result"
CKPT_PATH = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\SpTransformer_pytorch.ckpt"
MODEL_CODE = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\SpliceTransformer-main\model\model.py"
METRICS_FILE = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\metrics.py"

os.makedirs(RESULT_DIR, exist_ok=True)

# ================= UTILS =================
def load_module(path, name):
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# ================= DATASET =================
class SpliceInferenceDataset(Dataset):
    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)
        self.map = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
        self.pad_left = (MODEL_MAX_LEN - SEQ_LEN) // 2

    def __len__(self):
        return len(self.df)

    def encode_onehot(self, seq):
        onehot = np.zeros((4, len(seq)), dtype=np.float32)
        for i, c in enumerate(seq):
            if c in self.map:
                onehot[self.map[c], i] = 1.0
        return onehot

    def __getitem__(self, idx):
        seq = self.df.iloc[idx]["sequence"].upper().strip()
        label = int(self.df.iloc[idx]["Splicing_types"])

        # ---- CASE 1: already full length ----
        if len(seq) == MODEL_MAX_LEN:
            x = self.encode_onehot(seq)

        # ---- CASE 2: short â†’ center pad ----
        elif len(seq) == SEQ_LEN:
            # ðŸ”¥ non-zero pad to avoid background bias
            x = np.full((4, MODEL_MAX_LEN), 0.25, dtype=np.float32)
            onehot = self.encode_onehot(seq)
            x[:, self.pad_left:self.pad_left + SEQ_LEN] = onehot

        else:
            raise ValueError(f"Unexpected sequence length {len(seq)}")

        return torch.tensor(x), torch.tensor(label, dtype=torch.long)

# ================= LOAD MODEL & METRICS =================
print("ðŸ“¦ Loading model & metrics...")
metrics_mod = load_module(METRICS_FILE, "metrics")
model_mod = load_module(MODEL_CODE, "model")

model = model_mod.SpTransformer(
    dim=128,
    tissue_num=15,
    attn_depth=6,
    max_seq_len=MODEL_MAX_LEN
).to(DEVICE)

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
state_dict = ckpt.get("state_dict", ckpt)
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
model.eval()

print("âœ… Model loaded")

# ================= SORT TEST FILES =================
def parse_ratio(fname):
    return int(fname.split("_")[1])

test_files = sorted(
    [f for f in os.listdir(DATA_DIR) if f.startswith("test_")],
    key=parse_ratio
)

print("ðŸ§ª Test files:", test_files)

# ================= INFERENCE =================
for fname in test_files:
    ratio = fname.replace("test_", "").replace(".csv", "")
    print(f"\nðŸš€ Inference for ratio {ratio}")

    dataset = SpliceInferenceDataset(os.path.join(DATA_DIR, fname))
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    all_labels, all_probs = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f"Inference {ratio}"):
            inputs = inputs.to(DEVICE)

            outputs = model(inputs)             # (B, C, L)
            splice_logits = outputs[:, :3, :]   # bg, acc, donor

            # ---- CENTER WINDOW POOLING ----
            center = splice_logits.shape[-1] // 2
            window = splice_logits[:, :, center-WINDOW_RADIUS:center+WINDOW_RADIUS+1]
            pooled_logits = torch.max(window, dim=2)[0]

            # ---- TEMPERATURE SCALING ----
            probs = torch.softmax(pooled_logits / TEMPERATURE, dim=1)

            all_labels.extend(labels.numpy())
            all_probs.extend(probs.cpu().numpy())

    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # =====================================================
    # ðŸ”¥ RANK-BASED + SPLICE-CONFIDENCE DECISION (CORE)
    # =====================================================
    splice_scores = np.max(all_probs[:, 1:], axis=1)

    threshold = np.percentile(
        splice_scores,
        100 * (1 - TOP_SPLICE_RATIO)
    )

    all_preds = []
    for p, s in zip(all_probs, splice_scores):
        if s < threshold:
            all_preds.append(0)
        else:
            all_preds.append(1 if p[1] > p[2] else 2)

    all_preds = np.array(all_preds)

    print("ðŸ“Š Prediction distribution:", Counter(all_preds))

    # ================= METRICS =================
    metrics = metrics_mod.compute_metrics(
        labels=all_labels,
        preds=all_preds,
        probs=all_probs,
        k=2
    )

    # ================= SAVE =================
    output = {
        "test_file": fname,
        "ratio": ratio,
        "window_radius": WINDOW_RADIUS,
        "temperature": TEMPERATURE,
        "top_splice_ratio": TOP_SPLICE_RATIO,
        "num_samples": len(all_labels),
        "metrics": metrics
    }

    out_path = os.path.join(RESULT_DIR, f"metrics_{ratio}.json")
    with open(out_path, "w") as f:
        json.dump(output, f, indent=4)

    print(f"ðŸ’¾ Saved â†’ {out_path}")

print("\nðŸŽ‰ FINAL INFERENCE FINISHED (ALL TRICKS APPLIED)")


ðŸ“¦ Loading model & metrics...
âœ… Model loaded
ðŸ§ª Test files: ['test_1_1_1.csv', 'test_2_1_1.csv', 'test_4_1_1.csv', 'test_10_1_1.csv', 'test_100_1_1.csv']

ðŸš€ Inference for ratio 1_1_1


Inference 1_1_1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 412/412 [06:29<00:00,  1.06it/s]


ðŸ“Š Prediction distribution: Counter({0: 18417, 1: 4166, 2: 3727})
ðŸ’¾ Saved â†’ D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_1_1_1.json

ðŸš€ Inference for ratio 2_1_1


Inference 2_1_1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 549/549 [08:33<00:00,  1.07it/s]


ðŸ“Š Prediction distribution: Counter({0: 24592, 1: 5470, 2: 5070})
ðŸ’¾ Saved â†’ D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_2_1_1.json

ðŸš€ Inference for ratio 4_1_1


Inference 4_1_1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 825/825 [12:30<00:00,  1.10it/s]


ðŸ“Š Prediction distribution: Counter({0: 36943, 2: 7921, 1: 7912})
ðŸ’¾ Saved â†’ D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_4_1_1.json

ðŸš€ Inference for ratio 10_1_1


Inference 10_1_1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1652/1652 [24:13<00:00,  1.14it/s]


ðŸ“Š Prediction distribution: Counter({0: 73995, 2: 16948, 1: 14765})
ðŸ’¾ Saved â†’ D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_10_1_1.json

ðŸš€ Inference for ratio 100_1_1


Inference 100_1_1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 13502/13502 [3:22:27<00:00,  1.11it/s] 


ðŸ“Š Prediction distribution: Counter({0: 604862, 2: 144236, 1: 114991})
ðŸ’¾ Saved â†’ D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_100_1_1.json

ðŸŽ‰ FINAL INFERENCE FINISHED (ALL TRICKS APPLIED)
