In [1]:
# pip install pyfaidx biopython

### Inference

In [21]:
import torch
import pandas as pd
import numpy as np
import os
import json
from tqdm import tqdm
from pyfaidx import Fasta
from sklearn.metrics import confusion_matrix
from data_preparation import get_splam_official_logic_seq, one_hot_encode_splam
from metrics import compute_metrics

# --- Configuration ---
GENOME_FA = r"D:\my_project\Bio_paper\Homo_sapiens.GRCh38.dna.primary_assembly.fa"
MODEL_PT = r"D:\my_project\Bio_paper\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SPLAM\pretrained_model\splam_script.pt"
DATA_FOLDER = r"D:\my_project\Bio_paper\Bio_sequence_Research_AITALAB\train\task1_splicing_prediction\data_preparation\train_val"
OUTPUT_FOLDER = r"D:\my_project\Bio_paper\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SPLAM\results"

BATCH_SIZE = 256
# A sample is called a splice site when its normalised donor or acceptor
# score exceeds this value (see classify_scores).
THRESHOLD = 0.21195
TEST_FILES = ['test_1_1_1.csv', 'test_2_1_1.csv', 'test_4_1_1.csv', 'test_10_1_1.csv', 'test_data.csv']

# Positions in the model input that are scored, and the half-width of the
# max-pooling window around them (window = center +/- WINDOW_HALF, inclusive).
DONOR_CENTER = 200
ACCEPTOR_CENTER = 600
WINDOW_HALF = 10

# Model output channel indices, following the mapping the original notebook
# established from confusion-matrix experiments: 0 = neither, 1 = acceptor,
# 2 = donor.
CH_NULL, CH_ACC, CH_DON = 0, 1, 2


def pool_window_scores(probs, labels, window_half=WINDOW_HALF):
    """Max-pool per-position class probabilities into one score row per sample.

    Args:
        probs: array indexed as ``probs[sample, channel, position]`` (the
            softmaxed model output for one batch).
        labels: ground-truth labels (0 = null, 1 = donor, 2 = acceptor), one
            per sample; used only to choose the window centre.
        window_half: half-width of the pooling window.

    Returns:
        Float array of shape ``(n_samples, 3)`` with columns
        ``[null, donor, acceptor]`` (the label order used by the CSV), each row
        normalised to sum to ~1.

    NOTE(review): the window centre is chosen from the ground-truth label
    (labels 0/1 -> DONOR_CENTER, anything else -> ACCEPTOR_CENTER), so the
    score is not label-independent and leaks the label into the evaluation.
    Kept unchanged for continuity with already-reported numbers; a leak-free
    evaluation would score both windows for every sample.
    """
    rows = []
    for sample, label in zip(probs, labels):
        center = DONOR_CENTER if label in (0, 1) else ACCEPTOR_CENTER
        window = sample[:, center - window_half : center + window_half + 1]
        peak = window.max(axis=1)  # per-channel maximum over the window
        # Reorder model channels into the CSV label order [null, donor, acceptor].
        rows.append([peak[CH_NULL], peak[CH_DON], peak[CH_ACC]])
    scores = np.asarray(rows, dtype=float).reshape(-1, 3)
    return scores / (scores.sum(axis=1, keepdims=True) + 1e-9)


def classify_scores(y_probs, threshold=THRESHOLD):
    """Convert normalised ``[null, donor, acceptor]`` scores into class ids.

    A sample is a splice site when its donor or acceptor score exceeds
    ``threshold``. It is then labelled 1 (donor) if the donor score is
    strictly larger, else 2 (acceptor, which also wins ties). Otherwise it
    is labelled 0 (null).

    Returns:
        Integer array of predicted labels, one per row of ``y_probs``.
    """
    y_probs = np.asarray(y_probs, dtype=float).reshape(-1, 3)
    p_don, p_acc = y_probs[:, 1], y_probs[:, 2]
    is_site = np.maximum(p_don, p_acc) > threshold
    site_class = np.where(p_don > p_acc, 1, 2)
    return np.where(is_site, site_class, 0).astype(int)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-saved TorchScript model load on a CPU-only host.
    model = torch.jit.load(MODEL_PT, map_location=device)
    model.eval()
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    genome = Fasta(GENOME_FA, sequence_always_upper=True)

    for csv_file in TEST_FILES:
        input_path = os.path.join(DATA_FOLDER, csv_file)
        file_base = os.path.splitext(csv_file)[0]
        if not os.path.exists(input_path):
            print(f"Skipping {csv_file}: not found at {input_path}")
            continue

        print(f"\nüöÄ Ch·∫°y Benchmark Final: {csv_file}")
        df = pd.read_csv(input_path)
        if df.empty:
            # Guard: with no batches there is nothing to score, and writing
            # `results` here would reuse a previous file's metrics.
            print(f"Skipping {csv_file}: no rows")
            continue

        all_y_true, all_y_probs = [], []
        for start in tqdm(range(0, len(df), BATCH_SIZE)):
            batch_df = df.iloc[start : start + BATCH_SIZE]
            batch_seqs, batch_labels = [], []
            for _, row in batch_df.iterrows():
                # The id is split on "_" and fields 1..3 are read as chrom, pos, strand.
                parts = row['id'].split('_')
                chrom, pos, strand = parts[1], int(parts[2]), parts[3]
                label = int(row['Splicing_types'])
                # NOTE(review): the ground-truth label is also passed to the
                # sequence builder; presumably it controls where the site is
                # placed in the model input. Confirm in data_preparation.
                seq = get_splam_official_logic_seq(chrom, pos, strand, label, genome)
                batch_seqs.append(one_hot_encode_splam(seq))
                batch_labels.append(label)

            input_tensor = torch.from_numpy(np.stack(batch_seqs)).to(device)
            with torch.no_grad():
                probs = torch.softmax(model(input_tensor), dim=1).cpu().numpy()

            all_y_true.extend(batch_labels)
            all_y_probs.append(pool_window_scores(probs, batch_labels))

        # Predictions and metrics are computed once per file, over all batches.
        y_true = np.asarray(all_y_true)
        y_probs = np.concatenate(all_y_probs, axis=0)
        y_preds = classify_scores(y_probs, THRESHOLD)

        results = compute_metrics(y_true, y_preds, probs=y_probs)
        # A fixed label set keeps the matrix 3x3 even when a class is absent.
        results['confusion_matrix'] = confusion_matrix(y_true, y_preds, labels=[0, 1, 2]).tolist()

        with open(os.path.join(OUTPUT_FOLDER, f"{file_base}_results.json"), 'w') as f:
            json.dump(results, f, indent=4)

        print(f"‚úÖ Ho√†n t·∫•t {file_base} | AUC: {results['auc']:.4f}")


üöÄ Ch·∫°y Benchmark Final: test_1_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 103/103 [00:20<00:00,  4.96it/s]


‚úÖ Ho√†n t·∫•t test_1_1_1 | AUC: 0.9475

üöÄ Ch·∫°y Benchmark Final: test_2_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 138/138 [00:26<00:00,  5.24it/s]


‚úÖ Ho√†n t·∫•t test_2_1_1 | AUC: 0.9426

üöÄ Ch·∫°y Benchmark Final: test_4_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 207/207 [00:40<00:00,  5.07it/s]


‚úÖ Ho√†n t·∫•t test_4_1_1 | AUC: 0.9381

üöÄ Ch·∫°y Benchmark Final: test_10_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 413/413 [01:31<00:00,  4.50it/s]


‚úÖ Ho√†n t·∫•t test_10_1_1 | AUC: 0.9353

üöÄ Ch·∫°y Benchmark Final: test_data.csv


  0%|          | 2/3666 [00:00<13:59,  4.37it/s]



  0%|          | 4/3666 [00:00<10:31,  5.80it/s]



  0%|          | 6/3666 [00:01<09:37,  6.34it/s]



  0%|          | 8/3666 [00:01<09:16,  6.58it/s]



  0%|          | 10/3666 [00:01<09:09,  6.66it/s]



  0%|          | 12/3666 [00:01<09:09,  6.65it/s]



  0%|          | 14/3666 [00:02<09:08,  6.65it/s]



  0%|          | 16/3666 [00:02<09:11,  6.62it/s]



  0%|          | 18/3666 [00:02<09:14,  6.58it/s]



  1%|          | 20/3666 [00:03<09:17,  6.54it/s]



  1%|          | 22/3666 [00:03<09:20,  6.50it/s]



  1%|          | 24/3666 [00:03<09:21,  6.49it/s]



  1%|          | 26/3666 [00:04<09:26,  6.43it/s]



  1%|          | 28/3666 [00:04<09:26,  6.42it/s]



  1%|          | 30/3666 [00:04<09:29,  6.39it/s]



  1%|          | 32/3666 [00:05<09:32,  6.35it/s]



  1%|          | 34/3666 [00:05<09:35,  6.31it/s]



  1%|          | 36/3666 [00:05<09:39,  6.27it/s]



  1%|          | 38/3666 [00:06<09:41,  6.24it/s]



  1%|          | 40/3666 [00:06<09:46,  6.18it/s]



  1%|          | 42/3666 [00:06<09:49,  6.14it/s]



  1%|          | 44/3666 [00:07<09:54,  6.09it/s]



  1%|‚ñè         | 46/3666 [00:07<09:56,  6.07it/s]



  1%|‚ñè         | 48/3666 [00:07<09:57,  6.05it/s]



  1%|‚ñè         | 50/3666 [00:08<10:00,  6.02it/s]



  1%|‚ñè         | 52/3666 [00:08<10:04,  5.98it/s]



  1%|‚ñè         | 54/3666 [00:08<10:05,  5.96it/s]



  2%|‚ñè         | 56/3666 [00:09<10:11,  5.90it/s]



  2%|‚ñè         | 58/3666 [00:09<10:11,  5.90it/s]



  2%|‚ñè         | 60/3666 [00:09<10:15,  5.86it/s]



  2%|‚ñè         | 62/3666 [00:10<10:22,  5.79it/s]



  2%|‚ñè         | 64/3666 [00:10<10:21,  5.80it/s]



  2%|‚ñè         | 66/3666 [00:10<10:26,  5.75it/s]



  2%|‚ñè         | 68/3666 [00:11<10:27,  5.73it/s]



100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3666/3666 [51:51<00:00,  1.18it/s]

‚úÖ Ho√†n t·∫•t test_data | AUC: 0.9331





In [8]:
import torch
import pandas as pd
import numpy as np
import os
import json
from tqdm import tqdm
from pyfaidx import Fasta
from collections import Counter
from data_preparation import get_splam_official_logic_seq, one_hot_encode_splam

# --- Configuration ---
GENOME_FA = r"D:\my_project\Bio_paper\Homo_sapiens.GRCh38.dna.primary_assembly.fa"
MODEL_PT = r"D:\my_project\Bio_paper\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\SPLAM\pretrained_model\splam_script.pt"
DATA_FOLDER = r"D:\my_project\Bio_paper\Bio_sequence_Research_AITALAB\train\task1_splicing_prediction\data_preparation\train_val"
TEST_FILES = ['test_1_1_1.csv', 'test_2_1_1.csv', 'test_4_1_1.csv', 'test_10_1_1.csv', 'test_data.csv']
BATCH_SIZE = 128

# Positions counted as "on target" for each site type.
DONOR_TARGET = 200
ACCEPTOR_TARGET = 600


def locate_peaks(probs, labels):
    """Return the argmax position of the site channel for each sample.

    Donor samples (label 1) are read from model channel 2, and every other
    sample from channel 1 (acceptor). This follows the channel mapping used
    in this notebook. The argmax runs over all positions.

    Args:
        probs: array indexed as ``probs[sample, channel, position]``.
        labels: one ground-truth label per sample.

    Returns:
        ``(donor_peaks, acceptor_peaks)`` as two lists of Python ints.
    """
    labels = np.asarray(labels)
    is_donor = labels == 1
    channels = np.where(is_donor, 2, 1)
    peaks = probs[np.arange(len(labels)), channels].argmax(axis=1)
    return peaks[is_donor].tolist(), peaks[~is_donor].tolist()


def summarize_peaks(peaks, target):
    """Summarise peak positions: the top-3 most frequent positions, the hits
    exactly at ``target`` (key ``On_Target_<target>``), and the sample count."""
    counts = Counter(peaks)
    return {
        "Top_3_Peaks": counts.most_common(3),
        f"On_Target_{target}": counts.get(target, 0),
        "Total_Samples": len(peaks),
    }


def format_hit_rate(hits, total):
    """Format ``hits/total (pct%)``; reports ``n/a`` instead of dividing by zero."""
    if total == 0:
        return f"{hits}/{total} (n/a)"
    return f"{hits}/{total} ({hits / total * 100:.1f}%)"


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-saved TorchScript model load on a CPU-only host.
    model = torch.jit.load(MODEL_PT, map_location=device)
    model.eval()
    genome = Fasta(GENOME_FA, sequence_always_upper=True)

    peak_report = {}

    for csv_file in TEST_FILES:
        input_path = os.path.join(DATA_FOLDER, csv_file)
        if not os.path.exists(input_path):
            continue

        print(f"\nüîç ƒêang qu√©t Peak cho file: {csv_file}")
        df = pd.read_csv(input_path)

        # Only donor (1) and acceptor (2) rows have a peak to locate.
        df_sites = df[df['Splicing_types'].isin([1, 2])]

        donor_peaks, acceptor_peaks = [], []
        for start in tqdm(range(0, len(df_sites), BATCH_SIZE)):
            batch_df = df_sites.iloc[start : start + BATCH_SIZE]
            batch_seqs, batch_labels = [], []
            for _, row in batch_df.iterrows():
                parts = row['id'].split('_')
                chrom, pos, strand = parts[1], int(parts[2]), parts[3]
                label = int(row['Splicing_types'])
                seq = get_splam_official_logic_seq(chrom, pos, strand, label, genome)
                batch_seqs.append(one_hot_encode_splam(seq))
                batch_labels.append(label)

            input_tensor = torch.from_numpy(np.stack(batch_seqs)).to(device)
            with torch.no_grad():
                probs = torch.softmax(model(input_tensor), dim=1).cpu().numpy()

            batch_don, batch_acc = locate_peaks(probs, batch_labels)
            donor_peaks.extend(batch_don)
            acceptor_peaks.extend(batch_acc)

        peak_report[csv_file] = {
            "Donor": summarize_peaks(donor_peaks, DONOR_TARGET),
            "Acceptor": summarize_peaks(acceptor_peaks, ACCEPTOR_TARGET),
        }

    # --- Summary report ---
    print("\n" + "="*50)
    print("B√ÅO C√ÅO T·ªîNG H·ª¢P V·ªä TR√ç PEAK (ALIGNMENT REPORT)")
    print("="*50)
    don_key = f"On_Target_{DONOR_TARGET}"
    acc_key = f"On_Target_{ACCEPTOR_TARGET}"
    for file_name, data in peak_report.items():
        print(f"\nüìÇ File: {file_name}")
        d = data["Donor"]
        a = data["Acceptor"]
        print(f"  [Donor]    - On-Target ({DONOR_TARGET}): {format_hit_rate(d[don_key], d['Total_Samples'])}")
        print(f"             - Top Peaks: {d['Top_3_Peaks']}")
        print(f"  [Acceptor] - On-Target ({ACCEPTOR_TARGET}): {format_hit_rate(a[acc_key], a['Total_Samples'])}")
        print(f"             - Top Peaks: {a['Top_3_Peaks']}")

    # Save the report as JSON in the current working directory.
    with open("peak_alignment_report.json", "w") as f:
        json.dump(peak_report, f, indent=4)


üîç ƒêang qu√©t Peak cho file: test_1_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 137/137 [00:13<00:00, 10.50it/s]



üîç ƒêang qu√©t Peak cho file: test_2_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 137/137 [00:12<00:00, 11.16it/s]



üîç ƒêang qu√©t Peak cho file: test_4_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 137/137 [00:09<00:00, 15.17it/s]



üîç ƒêang qu√©t Peak cho file: test_10_1_1.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 137/137 [00:08<00:00, 15.40it/s]



üîç ƒêang qu√©t Peak cho file: test_data.csv


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 137/137 [00:08<00:00, 15.48it/s]


B√ÅO C√ÅO T·ªîNG H·ª¢P V·ªä TR√ç PEAK (ALIGNMENT REPORT)

üìÇ File: test_1_1_1.csv
  [Donor]    - On-Target (200): 5002/8822 (56.7%)
             - Top Peaks: [(200, 5002), (600, 1652), (410, 130)]
  [Acceptor] - On-Target (600): 6308/8666 (72.8%)
             - Top Peaks: [(600, 6308), (200, 1596), (398, 32)]

üìÇ File: test_2_1_1.csv
  [Donor]    - On-Target (200): 5002/8822 (56.7%)
             - Top Peaks: [(200, 5002), (600, 1652), (410, 130)]
  [Acceptor] - On-Target (600): 6308/8666 (72.8%)
             - Top Peaks: [(600, 6308), (200, 1596), (398, 32)]

üìÇ File: test_4_1_1.csv
  [Donor]    - On-Target (200): 5002/8822 (56.7%)
             - Top Peaks: [(200, 5002), (600, 1652), (410, 130)]
  [Acceptor] - On-Target (600): 6308/8666 (72.8%)
             - Top Peaks: [(600, 6308), (200, 1596), (398, 32)]

üìÇ File: test_10_1_1.csv
  [Donor]    - On-Target (200): 5002/8822 (56.7%)
             - Top Peaks: [(200, 5002), (600, 1652), (410, 130)]
  [Acceptor] - On-Target (600):


