In [3]:
#pip install pyfaidx

In [52]:
import pandas as pd
import numpy as np
import torch
import json
import os
from pyfaidx import Fasta
from sklearn.metrics import confusion_matrix
from model import Pangolin
from tqdm import tqdm

# Import h√†m c·ªßa b·∫°n
from metrics import compute_metrics

# --- C·∫§U H√åNH H·ªÜ TH·ªêNG ---
DATA_FOLDER = r"D:\Bio_sequence_Research_AITALAB\train\task1_splicing_prediction\data_preparation\train_val"
RESULTS_FOLDER = "results"
MODEL_FOLDER = r"D:\Bio_sequence_Research_AITALAB\benchmark\task1_splicing_prediction\Pangolin\pretrained_model"
FASTA_PATH = r"D:\Homo_sapiens.GRCh38.dna.primary_assembly.fa"

# Th√¥ng s·ªë k·ªπ thu·∫≠t
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 256
OFFSET = -1      
WINDOW_SIZE = 5
DELTA_MARGIN = 0.000005  # Hi·ªáu s·ªë t·ªëi thi·ªÉu gi·ªØa D v√† A ƒë·ªÉ tin t∆∞·ªüng
BASE_THRESHOLD = 0.9 # Ng∆∞·ª°ng c∆° b·∫£n

# Ki·∫øn tr√∫c Pangolin v3
CONTEXT_WINDOW = 15000
HALF_WINDOW = CONTEXT_WINDOW // 2
N_CHANNELS = 32
W = np.asarray([11, 11, 11, 11, 11, 11, 11, 11, 21, 21, 21, 21, 41, 41, 41, 41])
AR = np.asarray([1, 1, 1, 1, 4, 4, 4, 4, 10, 10, 10, 10, 25, 25, 25, 25])

TEST_FILES = ['test_1_1_1.csv', 'test_2_1_1.csv', 'test_4_1_1.csv', 'test_10_1_1.csv', 'test_data.csv']
os.makedirs(RESULTS_FOLDER, exist_ok=True)

# --- H√ÄM H·ªñ TR·ª¢ ---
genome = Fasta(FASTA_PATH)

def get_sequence(chrom, pos):
    chrom_key = chrom if chrom in genome.keys() else (f"chr{chrom}" if f"chr{chrom}" in genome.keys() else chrom)
    start, end = pos - HALF_WINDOW, pos + HALF_WINDOW
    try:
        return genome[chrom_key][start:end].seq.upper()
    except:
        return "N" * CONTEXT_WINDOW 

# Gi·ªØ nguy√™n h√†m m√£ h√≥a c·ªßa b·∫°n
IN_MAP = np.asarray([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
def one_hot_encode_torch(seq, strand):
    seq = seq.upper().replace('A', '1').replace('C', '2').replace('G', '3').replace('T', '4').replace('N', '0')
    if strand == '-':
        seq = np.asarray(list(map(int, list(seq[::-1]))))
        seq = np.where(seq == 0, 0, 5 - seq)
    else:
        seq = np.asarray(list(map(int, list(seq))))
    return IN_MAP[seq.astype('int8')]

# --- KH·ªûI T·∫†O ENSEMBLE ---
models = []
WEIGHT_FILES = [f"final.3.{i}.3.v2" for i in range(8)]
for weight_file in WEIGHT_FILES:
    m = Pangolin(L=N_CHANNELS, W=W, AR=AR)
    path = os.path.join(MODEL_FOLDER, weight_file)
    if os.path.exists(path):
        m.load_state_dict(torch.load(path, map_location=DEVICE))
        m.to(DEVICE).eval()
        models.append(m)

# --- V√íNG L·∫∂P CH√çNH ---
for file_name in TEST_FILES:
    input_path = os.path.join(DATA_FOLDER, file_name)
    if not os.path.exists(input_path): continue
        
    print(f"üöÄ Benchmarking: {file_name}")
    df = pd.read_csv(input_path)
    y_true = df['Splicing_types'].values
    
    # FIX L·ªñI 1: Kh·ªüi t·∫°o danh s√°ch d·ª± ƒëo√°n
    probs_list = []
    preds_list = []

    for i in tqdm(range(0, len(df), BATCH_SIZE), desc="Inference"):
        batch_df = df.iloc[i : i + BATCH_SIZE]
        batch_seqs = []
        
        for _, row in batch_df.iterrows():
            parts = row['id'].split('_')
            raw = get_sequence(parts[1], int(parts[2]) + OFFSET)
            encoded = one_hot_encode_torch(raw, parts[3])
            batch_seqs.append(torch.from_numpy(encoded).float().permute(1, 0))
        
        seq_tensor = torch.stack(batch_seqs).to(DEVICE)

        with torch.no_grad():
            # Kh·ªüi t·∫°o tensor l∆∞u gi√° tr·ªã MAX qua 8 m√¥ h√¨nh (FIX L·ªñI 4)
            max_d_ensemble = torch.zeros(len(batch_df)).to(DEVICE)
            max_a_ensemble = torch.zeros(len(batch_df)).to(DEVICE)
            
            for m in models:
                output = m(seq_tensor) 
                c_idx, n_t = output.shape[1] // 2, output.shape[-1] // 2
                
                # Sliding Window Max-pooling trong kh√¥ng gian (Window=5)
                half_w = WINDOW_SIZE // 2
                window_out = output[:, c_idx - half_w : c_idx + half_w + 1, :]
                
                # L·∫•y Max spatial v√† Max tissue cho t·ª´ng m√¥ h√¨nh
                p_d_m = torch.max(torch.max(window_out[:, :, :n_t], dim=1)[0], dim=1)[0]
                p_a_m = torch.max(torch.max(window_out[:, :, n_t:], dim=1)[0], dim=1)[0]
                
                # FIX L·ªñI 4: L·∫•y Max qua c√°c m√¥ h√¨nh thay v√¨ Mean
                max_d_ensemble = torch.max(max_d_ensemble, p_d_m)
                max_a_ensemble = torch.max(max_a_ensemble, p_a_m)
            
            d_final = max_d_ensemble.cpu().numpy()
            a_final = max_a_ensemble.cpu().numpy()
            
            # FIX L·ªñI 2: X·ª≠ l√Ω x√°c su·∫•t li√™n t·ª•c v√† Ph√¢n lo·∫°i
            for d, a in zip(d_final, a_final):
                diff = abs(d - a)
                
                # T√≠nh to√°n x√°c su·∫•t l·ªõp Neither (0) m·ªôt c√°ch t·ª± nhi√™n
                # N·∫øu d v√† a th·∫•p, p_neither s·∫Ω cao
                sum_da = d + a
                p_neither = max(0.0001, 1.0 - sum_da)
                
                # Normalize ƒë·ªÉ t·ªïng x√°c su·∫•t b·∫±ng 1 (Duy tr√¨ AUC chu·∫©n)
                norm_factor = p_neither + d + a
                prob_vec = [p_neither/norm_factor, a/norm_factor, d/norm_factor] # ƒê√£ swap nh√£n 1-2

                # Logic ph√¢n lo·∫°i d√πng Delta Margin v√† Threshold
                if diff < DELTA_MARGIN or (d < BASE_THRESHOLD and a < BASE_THRESHOLD):
                    preds_list.append(0)
                    probs_list.append([1.0, 0.0, 0.0]) # √âp nh√£n c·ª©ng cho Prediction
                else:
                    if d > a:
                        preds_list.append(2) # d (Donor model) -> Acceptor (2)
                        probs_list.append(prob_vec)
                    else:
                        preds_list.append(1) # a (Acceptor model) -> Donor (1)
                        probs_list.append(prob_vec)

    probs = np.array(probs_list)
    y_pred = np.array(preds_list)
    
    res = compute_metrics(y_true, y_pred, probs)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
    
    final = {"metrics": res, "confusion_matrix": cm.tolist()}
    output_path = os.path.join(RESULTS_FOLDER, file_name.replace('.csv', '_results.json'))
    with open(output_path, 'w') as f:
        json.dump(final, f, indent=4)

print("‚úÖ Benchmark ho√†n t·∫•t!")

üöÄ Benchmarking: test_1_1_1.csv


Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 103/103 [05:25<00:00,  3.16s/it]


üöÄ Benchmarking: test_2_1_1.csv


Inference:  22%|‚ñà‚ñà‚ñè       | 31/138 [01:40<05:47,  3.25s/it]


KeyboardInterrupt: 

In [16]:
def diagnostic_check(file_name, num_samples=5):
    df = pd.read_csv(os.path.join(DATA_FOLDER, file_name))
    
    # L·∫•y m·∫´u Donor (l·ªõp 1) v√† Acceptor (l·ªõp 2)
    donor_samples = df[df['Splicing_types'] == 1].head(num_samples)
    acc_samples = df[df['Splicing_types'] == 2].head(num_samples)
    diagnose_df = pd.concat([donor_samples, acc_samples])

    print(f"{'ID (Pos)':<20} | {'Type':<6} | {'Motif (Center)':<15} | {'Avg_D':<8} | {'Avg_A':<8} | {'Pred'}")
    print("-" * 85)

    # Ng∆∞·ª°ng Threshold th·ª±c t·∫ø ƒë·ªÉ ch·∫©n ƒëo√°n
    HARD_THRESHOLD = 0.1 

    for _, row in diagnose_df.iterrows():
        parts = row['id'].split('_')
        raw_seq = get_sequence(parts[1], int(parts[2]))
        
        # 1. Ki·ªÉm tra Motif t·∫°i v·ªã tr√≠ trung t√¢m (7500)
        # L·∫•y 4nt quanh ƒëi·ªÉm c·∫Øt ƒë·ªÉ d·ªÖ quan s√°t (V·ªã tr√≠ 7498 ƒë·∫øn 7502)
        center_motif = raw_seq[HALF_WINDOW-2 : HALF_WINDOW+2] 
        
        # 2. S·ª≠a l·ªói AttributeError: Chuy·ªÉn NumPy sang Tensor tr∆∞·ªõc khi unsqueeze
        # V√¨ h√†m c·ªßa b·∫°n tr·∫£ v·ªÅ NumPy array (L, 4), ta c·∫ßn permute(1, 0) ƒë·ªÉ th√†nh (4, L)
        encoded_seq = one_hot_encode_torch(raw_seq, parts[3])
        seq_tensor = torch.from_numpy(encoded_seq).float().permute(1, 0).unsqueeze(0).to(DEVICE)
        
        sum_d, sum_a = 0, 0
        with torch.no_grad():
            for m in models:
                out = m(seq_tensor) # Output shape: [1, Out_Len, 2*Tissues]
                
                c_idx = out.shape[1] // 2
                n_t = out.shape[-1] // 2
                
                # T√°ch Donor v√† Acceptor t·ª´ tensor g·ªôp
                p_d = out[0, c_idx, :n_t]
                p_a = out[0, c_idx, n_t:]
                
                sum_d += torch.max(p_d).item()
                sum_a += torch.max(p_a).item()
        
        avg_d, avg_a = sum_d/len(models), sum_a/len(models)
        
        # 3. √Åp d·ª•ng Threshold th·ª±c t·∫ø ƒë·ªÉ x√°c ƒë·ªãnh nh√£n d·ª± ƒëo√°n
        if avg_d < HARD_THRESHOLD and avg_a < HARD_THRESHOLD:
            pred = 0 # Neither
        elif avg_d > avg_a:
            pred = 1 # Donor
        else:
            pred = 2 # Acceptor
            
        print(f"{parts[2]:<20} | {row['Splicing_types']:<6} | {center_motif:^15} | {avg_d:.4f} | {avg_a:.4f} | {pred}")

# G·ªçi h√†m ki·ªÉm tra
diagnostic_check(TEST_FILES[0])

ID (Pos)             | Type   | Motif (Center)  | Avg_D    | Avg_A    | Pred
-------------------------------------------------------------------------------------
63733353             | 1      |      GGTG       | 0.9945 | 0.9947 | 2
46150118             | 1      |      ACCA       | 0.9946 | 0.9942 | 1
26842338             | 1      |      ACCT       | 0.9946 | 0.9949 | 2
47786047             | 1      |      ACTC       | 0.9942 | 0.9947 | 2
34017506             | 1      |      GGTA       | 0.9942 | 0.9936 | 1
32743755             | 2      |      CCTA       | 0.9945 | 0.9945 | 2
46137074             | 2      |      TCTG       | 0.9945 | 0.9941 | 1
33819240             | 2      |      AGCA       | 0.9945 | 0.9944 | 1
3035587              | 2      |      AGAA       | 0.9946 | 0.9948 | 2
45455744             | 2      |      AGAG       | 0.9938 | 0.9939 | 2


In [17]:
def find_best_offset(file_name, sample_idx=0):
    df = pd.read_csv(os.path.join(DATA_FOLDER, file_name))
    row = df.iloc[sample_idx]
    parts = row['id'].split('_')
    chrom, pos, strand = parts[1], int(parts[2]), parts[3]
    
    print(f"--- Ki·ªÉm tra Offset cho {parts[0]} t·∫°i {pos} ({strand}) ---")
    # Ki·ªÉm tra trong kho·∫£ng offset t·ª´ -3 ƒë·∫øn +3
    for offset in range(-3, 4):
        actual_pos = pos + offset
        raw_seq = get_sequence(chrom, actual_pos)
        
        # Tr√≠ch xu·∫•t 2nt ngay t·∫°i ƒëi·ªÉm c·∫Øt (v·ªã tr√≠ 7500-7501)
        # L∆∞u √Ω: V·ªõi strand '-', ta l·∫•y reverse complement c·ªßa v√πng ƒë√≥
        center_2nt = raw_seq[HALF_WINDOW : HALF_WINDOW+2]
        
        # Donor chu·∫©n l√† GT, Acceptor chu·∫©n l√† AG
        status = "MATCH!" if (row['Splicing_types'] == 1 and center_2nt == "GT") or \
                             (row['Splicing_types'] == 2 and center_2nt == "AG") else ""
        
        print(f"Offset {offset:>2}: {center_2nt} {status}")

# Ch·∫°y th·ª≠ ƒë·ªÉ t√¨m offset ƒë√∫ng
find_best_offset(TEST_FILES[0], sample_idx=0) # Th·ª≠ v·ªõi m·∫´u Donor ƒë·∫ßu ti√™n

--- Ki·ªÉm tra Offset cho Neg t·∫°i 46105846 (+) ---
Offset -3: AT 
Offset -2: TG 
Offset -1: GC 
Offset  0: CA 
Offset  1: AA 
Offset  2: AA 
Offset  3: AA 


In [20]:
def find_best_offset(file_name, sample_idx=0):
    df = pd.read_csv(os.path.join(DATA_FOLDER, file_name))
    row = df.iloc[sample_idx]
    parts = row['id'].split('_')
    chrom, pos, strand = parts[1], int(parts[2]), parts[3]
    
    print(f"--- Ki·ªÉm tra Offset cho {parts[0]} t·∫°i {pos} ({strand}) ---")
    # Ki·ªÉm tra trong kho·∫£ng offset t·ª´ -3 ƒë·∫øn +3
    for offset in range(-3, 4):
        actual_pos = pos + offset
        raw_seq = get_sequence(chrom, actual_pos)
        
        # Tr√≠ch xu·∫•t 2nt ngay t·∫°i ƒëi·ªÉm c·∫Øt (v·ªã tr√≠ 7500-7501)
        # L∆∞u √Ω: V·ªõi strand '-', ta l·∫•y reverse complement c·ªßa v√πng ƒë√≥
        center_2nt = raw_seq[HALF_WINDOW : HALF_WINDOW+2]
        
        # Donor chu·∫©n l√† GT, Acceptor chu·∫©n l√† AG
        status = "MATCH!" if (row['Splicing_types'] == 1 and center_2nt == "GT") or \
                             (row['Splicing_types'] == 2 and center_2nt == "AG") else ""
        
        print(f"Offset {offset:>2}: {center_2nt} {status}")

# Ch·∫°y th·ª≠ ƒë·ªÉ t√¨m offset ƒë√∫ng
find_best_offset(TEST_FILES[0], sample_idx=2) # Th·ª≠ v·ªõi m·∫´u Donor ƒë·∫ßu ti√™n

--- Ki·ªÉm tra Offset cho Donor t·∫°i 63733353 (+) ---
Offset -3: CG 
Offset -2: GG 
Offset -1: GT MATCH!
Offset  0: TG 
Offset  1: GA 
Offset  2: AG 
Offset  3: GG 


In [21]:
def find_best_offset(file_name, sample_idx=0):
    df = pd.read_csv(os.path.join(DATA_FOLDER, file_name))
    row = df.iloc[sample_idx]
    parts = row['id'].split('_')
    chrom, pos, strand = parts[1], int(parts[2]), parts[3]
    
    print(f"--- Ki·ªÉm tra Offset cho {parts[0]} t·∫°i {pos} ({strand}) ---")
    # Ki·ªÉm tra trong kho·∫£ng offset t·ª´ -3 ƒë·∫øn +3
    for offset in range(-3, 4):
        actual_pos = pos + offset
        raw_seq = get_sequence(chrom, actual_pos)
        
        # Tr√≠ch xu·∫•t 2nt ngay t·∫°i ƒëi·ªÉm c·∫Øt (v·ªã tr√≠ 7500-7501)
        # L∆∞u √Ω: V·ªõi strand '-', ta l·∫•y reverse complement c·ªßa v√πng ƒë√≥
        center_2nt = raw_seq[HALF_WINDOW : HALF_WINDOW+2]
        
        # Donor chu·∫©n l√† GT, Acceptor chu·∫©n l√† AG
        status = "MATCH!" if (row['Splicing_types'] == 1 and center_2nt == "GT") or \
                             (row['Splicing_types'] == 2 and center_2nt == "AG") else ""
        
        print(f"Offset {offset:>2}: {center_2nt} {status}")

# Ch·∫°y th·ª≠ ƒë·ªÉ t√¨m offset ƒë√∫ng
find_best_offset(TEST_FILES[0], sample_idx=3) # Th·ª≠ v·ªõi m·∫´u Donor ƒë·∫ßu ti√™n

--- Ki·ªÉm tra Offset cho Acc t·∫°i 32743755 (-) ---
Offset -3: AC 
Offset -2: CC 
Offset -1: CT 
Offset  0: TA 
Offset  1: AA 
Offset  2: AA 
Offset  3: AA 
