In [5]:
# ============================================================
# SpliceTransformer FINAL INFERENCE (ALL-IN BEST PRACTICE)
# ============================================================

import os
import json
import torch
import numpy as np
import pandas as pd
import importlib.util

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from collections import Counter

# ================= CONFIG =================
# Input length the model is built for (passed as max_seq_len when the model is created).
MODEL_MAX_LEN = 8192
# Length of the short benchmark sequences; these are centre-padded up to MODEL_MAX_LEN.
SEQ_LEN = 601
BATCH_SIZE = 64
# Use the GPU when torch reports one, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- INFERENCE HYPERPARAMS ----
WINDOW_RADIUS = 25          # max-pool over the centre output position +/-25
TEMPERATURE = 0.7           # pooled logits are divided by this before softmax (<1 sharpens, >1 softens)
TOP_SPLICE_RATIO = 0.30     # fraction of samples, ranked by splice score, kept as splice calls (author's suggested range 0.2-0.4)

# ================= PATHS =================
# Directory scanned for the test_*.csv benchmark files.
DATA_DIR = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\data"
# Directory that receives one metrics_<ratio>.json per test file.
RESULT_DIR = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result"
# Checkpoint file read with torch.load.
CKPT_PATH = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\SpTransformer_pytorch.ckpt"
# Source file that defines the SpTransformer class.
MODEL_CODE = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\SpliceTransformer-main\model\model.py"
# Source file that defines compute_metrics.
METRICS_FILE = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\metrics.py"

# Make sure the output directory exists before any result is written.
os.makedirs(RESULT_DIR, exist_ok=True)

# ================= UTILS =================
def load_module(path, name):
    """Import a Python source file from an explicit path and return the module.

    Args:
        path: Filesystem path of the source file to execute.
        name: Name given to the resulting module object.

    Returns:
        The executed module object. It is not registered in ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be derived from ``path``
            (for example, the file has no recognised Python suffix).
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None (it does not raise) when it cannot
    # choose a loader; fail here with a clear message instead of an
    # AttributeError on the next line.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# ================= DATASET =================
class SpliceInferenceDataset(Dataset):
    """Serve (one-hot input, label) pairs read from a benchmark CSV.

    The CSV must provide a ``sequence`` column (nucleotide string) and a
    ``Splicing_types`` column (integer class label).
    """

    def __init__(self, csv_path):
        # Left offset that centres a SEQ_LEN sequence inside MODEL_MAX_LEN.
        self.pad_left = (MODEL_MAX_LEN - SEQ_LEN) // 2
        # Row index of each nucleotide in the one-hot matrix.
        self.map = dict(zip("ACGT", range(4)))
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        return len(self.df)

    def encode_onehot(self, seq):
        """Return a (4, len(seq)) float32 one-hot matrix.

        Characters missing from ``self.map`` (e.g. ``N``) are left as
        all-zero columns.
        """
        encoded = np.zeros((4, len(seq)), dtype=np.float32)
        for position, base in enumerate(seq):
            channel = self.map.get(base)
            if channel is not None:
                encoded[channel, position] = 1.0
        return encoded

    def __getitem__(self, idx):
        """Return ``(input tensor, label tensor)`` for row ``idx``.

        Raises:
            ValueError: If the sequence is neither MODEL_MAX_LEN nor SEQ_LEN
                characters long.
        """
        record = self.df.iloc[idx]
        bases = record["sequence"].upper().strip()
        target = int(record["Splicing_types"])
        n_bases = len(bases)

        if n_bases == MODEL_MAX_LEN:
            # Already full length: encode as-is.
            features = self.encode_onehot(bases)
        elif n_bases == SEQ_LEN:
            # Short sequence: place it in the middle of a canvas filled with
            # 0.25 (the original author's choice of a non-zero pad "to avoid
            # background bias").
            features = np.full((4, MODEL_MAX_LEN), 0.25, dtype=np.float32)
            begin = self.pad_left
            features[:, begin:begin + SEQ_LEN] = self.encode_onehot(bases)
        else:
            raise ValueError(f"Unexpected sequence length {n_bases}")

        return torch.tensor(features), torch.tensor(target, dtype=torch.long)

# ================= LOAD MODEL & METRICS =================
print("üì¶ Loading model & metrics...")
metrics_mod = load_module(METRICS_FILE, "metrics")
model_mod = load_module(MODEL_CODE, "model")

model = model_mod.SpTransformer(
    dim=128,
    tissue_num=15,
    attn_depth=6,
    max_seq_len=MODEL_MAX_LEN
).to(DEVICE)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints obtained from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# Some checkpoints wrap the weights under "state_dict"; others are the
# weights mapping itself.
state_dict = ckpt.get("state_dict", ckpt)

# Strip a leading "model." wrapper prefix only. A plain str.replace would also
# delete "model." occurring in the middle of a key (e.g. "submodel.weight")
# and silently rename unrelated parameters.
_WRAPPER_PREFIX = "model."
state_dict = {
    (k[len(_WRAPPER_PREFIX):] if k.startswith(_WRAPPER_PREFIX) else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates key mismatches, so surface them: parameters that are
# not matched keep their initial values and would otherwise go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint keys do not fully match the model - "
        f"missing: {len(load_result.missing_keys)}, "
        f"unexpected: {len(load_result.unexpected_keys)}"
    )
    print("  first missing keys:", load_result.missing_keys[:5])
    print("  first unexpected keys:", load_result.unexpected_keys[:5])
model.eval()

print("‚úÖ Model loaded")

# ================= SORT TEST FILES =================
def parse_ratio(fname):
    """Return the leading ratio number encoded in a test file name.

    The number is the token after the first underscore, e.g. ``1`` for
    ``test_1_1_1.csv`` and ``100`` for ``test_100_1_1.csv``. The file
    extension is removed first so single-token names such as ``test_5.csv``
    also parse.

    Args:
        fname: File name (not a full path) of a test file.

    Returns:
        The leading ratio as an int; used as the sort key for test files.

    Raises:
        ValueError: If the token after the first underscore is not an integer.
        IndexError: If the name contains no underscore.
    """
    stem = os.path.splitext(fname)[0]
    return int(stem.split("_")[1])

# Collect the benchmark inputs. Only CSV files are kept, so stray entries that
# merely start with "test_" (notes, sub-directories, temp files) are not
# handed to the CSV reader. Sorting by the numeric leading ratio puts 10 after
# 4, which a plain lexicographic sort would not.
test_files = sorted(
    (
        f for f in os.listdir(DATA_DIR)
        if f.startswith("test_") and f.endswith(".csv")
    ),
    key=parse_ratio
)

print("üß™ Test files:", test_files)

# ================= INFERENCE =================
# Run the model over every test file, turn the pooled outputs into class
# predictions with a rank-based rule, and write one metrics JSON per file.
for fname in test_files:
    ratio = fname.replace("test_", "").replace(".csv", "")
    print(f"\nüöÄ Inference for ratio {ratio}")

    dataset = SpliceInferenceDataset(os.path.join(DATA_DIR, fname))
    if len(dataset) == 0:
        # The percentile threshold and column slicing below need >= 1 sample.
        print(f"Skipping {fname}: no samples")
        continue
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    label_batches, prob_batches = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f"Inference {ratio}"):
            inputs = inputs.to(DEVICE)

            outputs = model(inputs)             # original author's note: (B, C, L)
            # Original author's note: channels are bg, acc, donor.
            # NOTE(review): the later "MANUAL THRESHOLD" cell in this file
            # swaps channels 1 and 2 to match the label encoding; this cell
            # does not - confirm which order the labels use.
            splice_logits = outputs[:, :3, :]

            # ---- CENTER WINDOW POOLING ----
            center = splice_logits.shape[-1] // 2
            # Clamp at 0: a negative start would wrap around to the end of the
            # axis and yield an empty or wrong window for short outputs.
            start = max(center - WINDOW_RADIUS, 0)
            window = splice_logits[:, :, start:center + WINDOW_RADIUS + 1]
            pooled_logits = window.max(dim=2).values

            # ---- TEMPERATURE SCALING ----
            probs = torch.softmax(pooled_logits / TEMPERATURE, dim=1)

            label_batches.append(labels.numpy())
            prob_batches.append(probs.cpu().numpy())

    # One array per batch, joined once; avoids building per-sample lists.
    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)

    # =====================================================
    # RANK-BASED + SPLICE-CONFIDENCE DECISION (CORE)
    # =====================================================
    # Splice score = the larger of the two non-background probabilities.
    splice_scores = all_probs[:, 1:].max(axis=1)

    # Samples scoring below this percentile are called background, so roughly
    # the top TOP_SPLICE_RATIO fraction of each file is called a splice site.
    threshold = np.percentile(
        splice_scores,
        100 * (1 - TOP_SPLICE_RATIO)
    )

    # Vectorised form of: 0 if score < threshold, else 1 when column 1 is
    # strictly larger than column 2, else 2.
    splice_class = np.where(all_probs[:, 1] > all_probs[:, 2], 1, 2)
    all_preds = np.where(splice_scores < threshold, 0, splice_class)

    print("üìä Prediction distribution:", Counter(all_preds))

    # ================= METRICS =================
    metrics = metrics_mod.compute_metrics(
        labels=all_labels,
        preds=all_preds,
        probs=all_probs,
        k=2
    )

    # ================= SAVE =================
    output = {
        "test_file": fname,
        "ratio": ratio,
        "window_radius": WINDOW_RADIUS,
        "temperature": TEMPERATURE,
        "top_splice_ratio": TOP_SPLICE_RATIO,
        "num_samples": len(all_labels),
        "metrics": metrics
    }

    out_path = os.path.join(RESULT_DIR, f"metrics_{ratio}.json")
    with open(out_path, "w") as f:
        json.dump(output, f, indent=4)

    print(f"üíæ Saved ‚Üí {out_path}")

print("\nüéâ FINAL INFERENCE FINISHED (ALL TRICKS APPLIED)")


üì¶ Loading model & metrics...
‚úÖ Model loaded
üß™ Test files: ['test_1_1_1.csv', 'test_2_1_1.csv', 'test_4_1_1.csv', 'test_10_1_1.csv', 'test_100_1_1.csv']

üöÄ Inference for ratio 1_1_1


Inference 1_1_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 412/412 [06:29<00:00,  1.06it/s]


üìä Prediction distribution: Counter({0: 18417, 1: 4166, 2: 3727})
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_1_1_1.json

üöÄ Inference for ratio 2_1_1


Inference 2_1_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 549/549 [08:33<00:00,  1.07it/s]


üìä Prediction distribution: Counter({0: 24592, 1: 5470, 2: 5070})
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_2_1_1.json

üöÄ Inference for ratio 4_1_1


Inference 4_1_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 825/825 [12:30<00:00,  1.10it/s]


üìä Prediction distribution: Counter({0: 36943, 2: 7921, 1: 7912})
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_4_1_1.json

üöÄ Inference for ratio 10_1_1


Inference 10_1_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1652/1652 [24:13<00:00,  1.14it/s]


üìä Prediction distribution: Counter({0: 73995, 2: 16948, 1: 14765})
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_10_1_1.json

üöÄ Inference for ratio 100_1_1


Inference 100_1_1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13502/13502 [3:22:27<00:00,  1.11it/s] 


üìä Prediction distribution: Counter({0: 604862, 2: 144236, 1: 114991})
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_100_1_1.json

üéâ FINAL INFERENCE FINISHED (ALL TRICKS APPLIED)


In [1]:
# ============================================================
# SpliceTransformer FINAL INFERENCE (MANUAL THRESHOLD)
# ============================================================

import os
import json
import torch
import numpy as np
import pandas as pd
import importlib.util

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from collections import Counter
from sklearn.metrics import confusion_matrix

# ================= CONFIG (CUSTOMISE HERE) =================
# Input length the model is built for (passed as max_seq_len when the model is created).
MODEL_MAX_LEN = 8192
# Length of the short benchmark sequences; these are centre-padded up to MODEL_MAX_LEN.
SEQ_LEN = 601
BATCH_SIZE = 64
# Use the GPU when torch reports one, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- INFERENCE HYPERPARAMS ----
WINDOW_RADIUS = 25          # max-pool over the centre output position +/-25
TEMPERATURE = 0.7           # pooled logits are divided by this before softmax (<1 sharpens, >1 softens)

# THRESHOLD CONFIGURATION (IMPORTANT)
# If True: use the fixed threshold set below (SPLICE_THRESHOLD).
# If False: use the top-percentage rule of the earlier code (TOP_SPLICE_RATIO).
USE_FIXED_THRESHOLD = True  

# Probability threshold for accepting a sample as a splice site (0.0 -> 1.0).
# Example: 0.5 means the probability must exceed 50% to be called class 1 or 2;
# anything lower is assigned class 0.
# 0.00001
SPLICE_THRESHOLD = 0.00001 

# Legacy setting (only used when USE_FIXED_THRESHOLD = False)
TOP_SPLICE_RATIO = 0.30     

# Swap the two splice classes (True: 1->2, 2->1)
FIX_REVERSED_LABELS = True  

# ================= PATHS =================
# Directory scanned for the test_*.csv benchmark files.
DATA_DIR = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\data"
# Directory that receives one metrics_<ratio>.json per test file.
RESULT_DIR = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result"
# Checkpoint file read with torch.load.
CKPT_PATH = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\SpTransformer_pytorch.ckpt"
# Source file that defines the SpTransformer class.
MODEL_CODE = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\SpliceTransformer-main\model\model.py"
# Source file that defines compute_metrics.
METRICS_FILE = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\metrics.py"

# Make sure the output directory exists before any result is written.
os.makedirs(RESULT_DIR, exist_ok=True)

# ================= UTILS & DATASET =================
def load_module(path, name):
    """Execute the Python source file at ``path`` and return it as a module.

    Args:
        path: Filesystem path of the source file.
        name: Module name to assign to the loaded module.

    Returns:
        The executed module object (not inserted into ``sys.modules``).

    Raises:
        ImportError: If an import spec or loader cannot be built for ``path``,
            e.g. because the file suffix is not a recognised Python suffix.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # A None spec (or a spec without a loader) means importlib could not work
    # out how to load the file; report that directly rather than letting the
    # following calls fail with an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

class SpliceInferenceDataset(Dataset):
    """Dataset yielding (one-hot input, label) tensors from a benchmark CSV.

    Reads the ``sequence`` column for the nucleotide string and the
    ``Splicing_types`` column for the integer class label.
    """

    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)
        # Nucleotide -> row index in the one-hot matrix.
        self.map = {base: row for row, base in enumerate("ACGT")}
        # Offset that centres a SEQ_LEN sequence inside MODEL_MAX_LEN.
        self.pad_left = (MODEL_MAX_LEN - SEQ_LEN) // 2

    def __len__(self):
        return len(self.df)

    def encode_onehot(self, seq):
        """Encode ``seq`` as a (4, len(seq)) float32 one-hot matrix.

        Characters that are not keys of ``self.map`` stay all-zero.
        """
        matrix = np.zeros((4, len(seq)), dtype=np.float32)
        for col, symbol in enumerate(seq):
            row = self.map.get(symbol)
            if row is None:
                continue
            matrix[row, col] = 1.0
        return matrix

    def __getitem__(self, idx):
        """Return the ``(input, label)`` tensor pair for row ``idx``.

        Raises:
            ValueError: If the sequence length is neither MODEL_MAX_LEN nor
                SEQ_LEN.
        """
        row = self.df.iloc[idx]
        sequence = row["sequence"].upper().strip()
        label = int(row["Splicing_types"])
        length = len(sequence)

        if length == MODEL_MAX_LEN:
            # Full-length input: no padding needed.
            x = self.encode_onehot(sequence)
        elif length == SEQ_LEN:
            # Centre the short sequence on a canvas filled with 0.25.
            x = np.full((4, MODEL_MAX_LEN), 0.25, dtype=np.float32)
            stop = self.pad_left + SEQ_LEN
            x[:, self.pad_left:stop] = self.encode_onehot(sequence)
        else:
            raise ValueError(f"Unexpected sequence length {length}")

        return torch.tensor(x), torch.tensor(label, dtype=torch.long)

# ================= LOAD MODEL =================
print("üì¶ Loading model & metrics...")
metrics_mod = load_module(METRICS_FILE, "metrics")
model_mod = load_module(MODEL_CODE, "model")

model = model_mod.SpTransformer(dim=128, tissue_num=15, attn_depth=6, max_seq_len=MODEL_MAX_LEN).to(DEVICE)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints obtained from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare mapping.
state_dict = ckpt.get("state_dict", ckpt)

# Remove only a leading "model." wrapper prefix. str.replace would also strip
# "model." from the middle of a key and silently rename unrelated parameters.
_WRAPPER_PREFIX = "model."
state_dict = {
    (k[len(_WRAPPER_PREFIX):] if k.startswith(_WRAPPER_PREFIX) else k): v
    for k, v in state_dict.items()
}

# strict=False ignores key mismatches; report them so that parameters left at
# their initial values do not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint keys do not fully match the model - "
        f"missing: {len(load_result.missing_keys)}, "
        f"unexpected: {len(load_result.unexpected_keys)}"
    )
    print("  first missing keys:", load_result.missing_keys[:5])
    print("  first unexpected keys:", load_result.unexpected_keys[:5])
model.eval()
print("‚úÖ Model loaded")

# Benchmark inputs: CSV files named test_<ratio>..., ordered by the numeric
# leading ratio. Non-CSV entries that merely start with "test_" are skipped.
test_files = sorted(
    [f for f in os.listdir(DATA_DIR) if f.startswith("test_") and f.endswith(".csv")],
    key=lambda x: int(x.split("_")[1]),
)

# ================= INFERENCE LOOP =================
# Run the model over every test file, convert pooled outputs into class
# predictions (fixed or percentile threshold), and write metrics plus a
# confusion matrix to one JSON file per test file.
for fname in test_files:
    ratio = fname.replace("test_", "").replace(".csv", "")
    print(f"\nüöÄ Inference for ratio {ratio} | Threshold: {SPLICE_THRESHOLD if USE_FIXED_THRESHOLD else 'Top %'}")

    dataset = SpliceInferenceDataset(os.path.join(DATA_DIR, fname))
    if len(dataset) == 0:
        # The thresholding and column indexing below need >= 1 sample.
        print(f"Skipping {fname}: no samples")
        continue
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    label_batches, prob_batches = [], []

    # 1. Run the model and collect per-sample class probabilities.
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Running Model"):
            inputs = inputs.to(DEVICE)
            outputs = model(inputs)
            # Keep the first three output channels only.
            splice_logits = outputs[:, :3, :]

            # Max-pool over a window around the centre position, then softmax.
            center = splice_logits.shape[-1] // 2
            # Clamp at 0: a negative start would wrap around to the end of the
            # axis and yield an empty or wrong window for short outputs.
            start = max(center - WINDOW_RADIUS, 0)
            window = splice_logits[:, :, start:center + WINDOW_RADIUS + 1]
            pooled_logits = window.max(dim=2).values
            batch_probs = torch.softmax(pooled_logits / TEMPERATURE, dim=1)

            label_batches.append(labels.numpy())
            prob_batches.append(batch_probs.cpu().numpy())

    # One array per batch, joined once; avoids building per-sample lists.
    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)

    # -------------------------------------------------------------
    # IMPORTANT FIX: swap the probability columns up front.
    # After the swap, columns 1 and 2 are meant to line up with label
    # classes 1 and 2 (per the original author), so no label remapping is
    # needed further down.
    # -------------------------------------------------------------
    if FIX_REVERSED_LABELS:
        print("üîÑ Swapping Probability Columns (1 <-> 2) to match Labels...")
        all_probs[:, [1, 2]] = all_probs[:, [2, 1]]

    # 2. Label assignment (thresholding).
    # Splice score = the larger of the two non-background probabilities.
    splice_scores = np.max(all_probs[:, 1:], axis=1)

    if USE_FIXED_THRESHOLD:
        decision_threshold = SPLICE_THRESHOLD
    else:
        # Legacy percentile mode: threshold derived from this file's scores.
        decision_threshold = np.percentile(splice_scores, 100 * (1 - TOP_SPLICE_RATIO))
        print(f"   ‚ÑπÔ∏è  Dynamic Threshold (Percentile {TOP_SPLICE_RATIO}): {decision_threshold:.4f}")

    # Step 1: a splice score below the threshold means background (0).
    # Step 2: otherwise class 1 when column 1 is strictly larger than
    # column 2, else class 2. Vectorised over all samples.
    splice_class = np.where(all_probs[:, 1] > all_probs[:, 2], 1, 2)
    all_preds = np.where(splice_scores < decision_threshold, 0, splice_class)

    print(f"üìä Prediction Stats: {Counter(all_preds)}")

    # ================= METRICS & CONFUSION MATRIX =================
    # 3x3 confusion matrix (rows: true label, columns: predicted label).
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])
    print("\nConfusion Matrix:\n", cm)

    metrics = metrics_mod.compute_metrics(
        labels=all_labels,
        preds=all_preds,
        probs=all_probs,
        k=2
    )

    # ================= SAVE =================
    output = {
        "test_file": fname,
        "ratio": ratio,
        "config": {
            "window_radius": WINDOW_RADIUS,
            "temperature": TEMPERATURE,
            "use_fixed_threshold": USE_FIXED_THRESHOLD,
            # Plain float so the value serialises regardless of numpy type.
            "threshold_value": float(decision_threshold),
            "fix_reversed_labels": FIX_REVERSED_LABELS
        },
        "confusion_matrix": cm.tolist(),
        "metrics": metrics
    }

    out_path = os.path.join(RESULT_DIR, f"metrics_{ratio}.json")
    with open(out_path, "w") as f:
        json.dump(output, f, indent=4)

    print(f"üíæ Saved ‚Üí {out_path}")

print("\nüéâ DONE.")

üì¶ Loading model & metrics...
‚úÖ Model loaded

üöÄ Inference for ratio 1_1_1 | Threshold: 1e-05


Running Model: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 412/412 [06:31<00:00,  1.05it/s]


üîÑ Swapping Probability Columns (1 <-> 2) to match Labels...
üìä Prediction Stats: Counter({0: 9465, 1: 8520, 2: 8325})

Confusion Matrix:
 [[7815  577  430]
 [ 892 7858   72]
 [ 758   85 7823]]
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_1_1_1.json

üöÄ Inference for ratio 2_1_1 | Threshold: 1e-05


Running Model: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 549/549 [08:38<00:00,  1.06it/s]


üîÑ Swapping Probability Columns (1 <-> 2) to match Labels...
üìä Prediction Stats: Counter({0: 17315, 1: 9089, 2: 8728})

Confusion Matrix:
 [[15666  1146   832]
 [  892  7858    72]
 [  757    85  7824]]
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_2_1_1.json

üöÄ Inference for ratio 4_1_1 | Threshold: 1e-05


Running Model: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 825/825 [12:56<00:00,  1.06it/s]


üîÑ Swapping Probability Columns (1 <-> 2) to match Labels...
üìä Prediction Stats: Counter({0: 32929, 1: 10275, 2: 9572})

Confusion Matrix:
 [[31279  2332  1677]
 [  892  7858    72]
 [  758    85  7823]]
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_4_1_1.json

üöÄ Inference for ratio 10_1_1 | Threshold: 1e-05


Running Model: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1652/1652 [25:58<00:00,  1.06it/s]


üîÑ Swapping Probability Columns (1 <-> 2) to match Labels...
üìä Prediction Stats: Counter({0: 80042, 1: 13602, 2: 12064})

Confusion Matrix:
 [[78393  5659  4168]
 [  892  7858    72]
 [  757    85  7824]]
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_10_1_1.json

üöÄ Inference for ratio 100_1_1 | Threshold: 1e-05


Running Model: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13502/13502 [3:28:52<00:00,  1.08it/s] 


üîÑ Swapping Probability Columns (1 <-> 2) to match Labels...
üìä Prediction Stats: Counter({0: 756328, 1: 61099, 2: 46662})

Confusion Matrix:
 [[754817  53744  39364]
 [   819   7279     63]
 [   692     76   7235]]
üíæ Saved ‚Üí D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SpliceTransformer\result\metrics_100_1_1.json

üéâ DONE.
