In [1]:
import os
import sys
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import importlib.util
from torch.utils.data import Dataset, DataLoader

# ==========================================
# 1. C·∫§U H√åNH ƒê∆Ø·ªúNG D·∫™N (ƒê√É C·∫¨P NH·∫¨T T·ª™ K·∫æT QU·∫¢ QU√âT)
# ==========================================

# ƒê∆∞·ªùng d·∫´n file ch·ª©a code model (L·∫•y t·ª´ k·∫øt qu·∫£ scan c·ªßa b·∫°n)
MODEL_CODE_FILE = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main\model\model.py"

# T√™n Class Model (L·∫•y t·ª´ k·∫øt qu·∫£ scan)
MODEL_CLASS_NAME = "SpTransformer"

# ƒê∆∞·ªùng d·∫´n file weights (.ckpt)
CKPT_PATH = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\model.ckpt"

# Th∆∞ m·ª•c ch·ª©a 4 file csv prepared
DATA_DIR = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\data_test"

# Th∆∞ m·ª•c xu·∫•t k·∫øt qu·∫£
RESULT_DIR = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\results"

BATCH_SIZE = 16
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ratios = ["1_1_1", "2_1_1", "4_1_1", "10_1_1", "100_1_1"]

# ==========================================
# 2. H√ÄM LOAD MODEL T·ª™ FILE B·∫§T K·ª≤
# ==========================================
def load_model_class(file_path, class_name):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"‚ùå Kh√¥ng t√¨m th·∫•y file code t·∫°i: {file_path}")
    
    # Load module t·ª´ ƒë∆∞·ªùng d·∫´n file
    spec = importlib.util.spec_from_file_location("dynamic_model_module", file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules["dynamic_model_module"] = module
    spec.loader.exec_module(module)
    
    if hasattr(module, class_name):
        print(f"‚úÖ ƒê√£ import th√†nh c√¥ng class '{class_name}' t·ª´ {os.path.basename(file_path)}")
        return getattr(module, class_name)
    else:
        raise AttributeError(f"‚ùå Trong file '{file_path}' kh√¥ng c√≥ class '{class_name}'")

# ==========================================
# 3. H√ÄM LOAD CHECKPOINT (.ckpt / .pth)
# ==========================================
def load_weights(model, ckpt_path):
    print(f"üîÑ ƒêang load weights t·ª´: {ckpt_path}")
    
    # Load checkpoint l√™n CPU/GPU
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    
    # Tr√≠ch xu·∫•t state_dict
    if isinstance(checkpoint, dict):
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
    else:
        state_dict = checkpoint

    # X·ª≠ l√Ω t√™n keys (x√≥a prefix th·ª´a)
    clean_state_dict = {}
    for key, val in state_dict.items():
        new_key = key
        # X√≥a c√°c ti·ªÅn t·ªë ph·ªï bi·∫øn do PyTorch Lightning ho·∫∑c DataParallel sinh ra
        for prefix in ["model.", "net.", "module.", "backbone."]:
            if new_key.startswith(prefix):
                new_key = new_key[len(prefix):]
                break
        clean_state_dict[new_key] = val

    # Load v√†o model
    try:
        model.load_state_dict(clean_state_dict, strict=False)
        print("‚úÖ Load weights ho√†n t·∫•t!")
    except Exception as e:
        print(f"‚ö†Ô∏è C·∫£nh b√°o khi load weights: {e}")

# ==========================================
# 4. DATASET & METRICS
# ==========================================
# Fallback metrics n·∫øu thi·∫øu file metrics.py
try:
    from metrics import compute_metrics
except ImportError:
    print("‚ö†Ô∏è Kh√¥ng t√¨m th·∫•y file 'metrics.py'. S·∫Ω d√πng h√†m metrics ƒë∆°n gi·∫£n.")
    def compute_metrics(labels, preds, probs=None, k=2):
        # T√≠nh accuracy ƒë∆°n gi·∫£n ƒë·ªÉ kh√¥ng b·ªã l·ªói code
        acc = (labels == preds).mean()
        return {"accuracy": float(acc), "note": "metrics.py not found"}

class InferenceDataset(Dataset):
    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)
        self.mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        # L·∫•y sequence v√† label
        seq_str = str(self.df.iloc[idx]['sequence']).upper()
        label = int(self.df.iloc[idx]['Splicing_types'])
        
        # One-hot encoding
        seq_enc = np.zeros((len(seq_str), 4), dtype=np.float32)
        for i, char in enumerate(seq_str):
            if char in self.mapping:
                seq_enc[i, self.mapping[char]] = 1.0
            else:
                seq_enc[i] = 0.25
        
        # Transpose (L, 4) -> (4, L) cho Model
        return torch.tensor(seq_enc).transpose(0, 1), torch.tensor(label, dtype=torch.long)

# ==========================================
# 5. MAIN RUN
# ==========================================
def main():
    if not os.path.exists(RESULT_DIR):
        os.makedirs(RESULT_DIR)
        
    try:
        # 1. Load Class Model t·ª´ file model/model.py
        ModelClass = load_model_class(MODEL_CODE_FILE, MODEL_CLASS_NAME)
        
        # 2. Kh·ªüi t·∫°o Model
        # L∆∞u √Ω: N·∫øu __init__ c·ªßa model y√™u c·∫ßu tham s·ªë, c·∫ßn truy·ªÅn v√†o ƒë√¢y. 
        # Th∆∞·ªùng SpTransformer m·∫∑c ƒë·ªãnh params chu·∫©n.
        try:
            model = ModelClass().to(DEVICE)
        except TypeError as e:
            print(f"‚ö†Ô∏è Model y√™u c·∫ßu tham s·ªë kh·ªüi t·∫°o: {e}")
            print("üëâ ƒêang th·ª≠ kh·ªüi t·∫°o v·ªõi tham s·ªë m·∫∑c ƒë·ªãnh (n·∫øu c·∫ßn)...")
            # V√≠ d·ª•: model = ModelClass(maxlen=...)
            model = ModelClass().to(DEVICE)

        # 3. Load Weights
        if os.path.exists(CKPT_PATH):
            load_weights(model, CKPT_PATH)
        else:
            print(f"‚ùå Kh√¥ng t√¨m th·∫•y file weights t·∫°i: {CKPT_PATH}")
            return
            
        model.eval()
        
        # 4. Ch·∫°y v√≤ng l·∫∑p Inference
        for ratio in ratios:
            csv_path = os.path.join(DATA_DIR, f"prepared_inference_{ratio}.csv")
            if not os.path.exists(csv_path):
                print(f"‚è© B·ªè qua {ratio}: File kh√¥ng t·ªìn t·∫°i")
                continue
                
            print(f"\n>>> üèÉ ƒêang ch·∫°y: {ratio}")
            dataset = InferenceDataset(csv_path)
            loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
            
            all_labels, all_preds, all_probs = [], [], []
            
            with torch.no_grad():
                for inputs, labels in loader:
                    inputs = inputs.to(DEVICE)
                    
                    # Forward pass
                    outputs = model(inputs)
                    
                    # X·ª≠ l√Ω output n·∫øu l√† tuple/list
                    if isinstance(outputs, (tuple, list)):
                        outputs = outputs[0]
                    
                    # T√≠nh x√°c su·∫•t
                    probs = torch.softmax(outputs, dim=1)
                    preds = torch.argmax(probs, dim=1)
                    
                    all_labels.extend(labels.cpu().numpy())
                    all_preds.extend(preds.cpu().numpy())
                    all_probs.extend(probs.cpu().numpy())
            
            # T√≠nh v√† l∆∞u k·∫øt qu·∫£
            results = compute_metrics(np.array(all_labels), np.array(all_preds), np.array(all_probs))
            
            out_file = os.path.join(RESULT_DIR, f"results_{ratio}.json")
            with open(out_file, 'w') as f:
                json.dump(results, f, indent=4)
                
            print(f"   ‚úÖ Ho√†n t·∫•t. Accuracy: {results.get('accuracy', 0):.4f}")
            print(f"   üìÇ K·∫øt qu·∫£ l∆∞u t·∫°i: {out_file}")

    except Exception as e:
        print(f"\n‚ùå L·ªñI TRONG QU√Å TR√åNH CH·∫†Y: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

‚ö†Ô∏è Kh√¥ng t√¨m th·∫•y file 'metrics.py'. S·∫Ω d√πng h√†m metrics ƒë∆°n gi·∫£n.
‚úÖ ƒê√£ import th√†nh c√¥ng class 'SpTransformer' t·ª´ model.py
‚ùå Kh√¥ng t√¨m th·∫•y file weights t·∫°i: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\model.ckpt


In [1]:
import os
import sys
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import importlib.util
from torch.utils.data import Dataset, DataLoader

# ==========================================
# 1. C·∫§U H√åNH ƒê∆Ø·ªúNG D·∫™N (ƒê√É GI·ªÆ NGUY√äN C·ª¶A B·∫†N)
# ==========================================

METRICS_FILE_PATH = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\metrics.py"
MODEL_CODE_FILE = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main\model\model.py"
MODEL_CLASS_NAME = "SpTransformer"
CKPT_PATH = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpTransformer_pytorch.ckpt"
DATA_DIR = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\data"
RESULT_DIR = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\results"

# --- C·∫§U H√åNH M·ªöI: ƒê·ªò D√ÄI CHU·ªñI ---
# SpliceTransformer th∆∞·ªùng y√™u c·∫ßu 5000 ho·∫∑c 10000. 
# N·∫øu ch·∫°y v·∫´n l·ªói k√≠ch th∆∞·ªõc, h√£y th·ª≠ ƒë·ªïi s·ªë n√†y th√†nh 10000.
MAX_SEQ_LEN = 10000  

BATCH_SIZE = 16
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ratios = ["1_1_1", "2_1_1", "4_1_1", "10_1_1", "100_1_1"]

# ==========================================
# 2. H√ÄM LOAD MODULE
# ==========================================
def load_module_from_path(file_path, module_name="custom_module"):
    """Load m·ªôt file python nh∆∞ m·ªôt module t·ª´ ƒë∆∞·ªùng d·∫´n tuy·ªát ƒë·ªëi"""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"‚ùå Kh√¥ng t√¨m th·∫•y file t·∫°i: {file_path}")
    
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"‚ùå Kh√¥ng th·ªÉ ƒë·ªçc specs t·ª´ file: {file_path}")
        
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

# ==========================================
# 3. H√ÄM LOAD WEIGHTS TH√îNG MINH
# ==========================================
def load_weights_smart(model, ckpt_path):
    print(f"üîÑ ƒêang load weights t·ª´: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    
    # 1. L·∫•y state_dict
    if isinstance(checkpoint, dict):
        if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict']
        elif 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict']
        else: state_dict = checkpoint
    else:
        state_dict = checkpoint

    # 2. X√≥a prefix th·ª´a
    clean_state_dict = {}
    for key, val in state_dict.items():
        new_key = key
        for prefix in ["model.", "net.", "module.", "backbone."]:
            if new_key.startswith(prefix):
                new_key = new_key[len(prefix):]
                break
        clean_state_dict[new_key] = val

    # 3. Load v√†o model
    try:
        model.load_state_dict(clean_state_dict, strict=False)
        print("‚úÖ Load weights th√†nh c√¥ng!")
    except Exception as e:
        print(f"‚ö†Ô∏è C·∫£nh b√°o load weight: {e}")

# ==========================================
# 4. DATASET (ƒê√É S·ª¨A: TH√äM PADDING/CROP)
# ==========================================
class InferenceDataset(Dataset):
    def __init__(self, csv_path, max_len=MAX_SEQ_LEN):
        self.df = pd.read_csv(csv_path)
        self.mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
        self.max_len = max_len
        
    def __len__(self): return len(self.df)
    
    def __getitem__(self, idx):
        # L·∫•y chu·ªói v√† chu·∫©n h√≥a
        seq = str(self.df.iloc[idx]['sequence']).upper().strip()
        label = int(self.df.iloc[idx]['Splicing_types'])
        
        curr_len = len(seq)
        
        # --- X·ª¨ L√ù ƒê·ªò D√ÄI (PADDING HO·∫∂C CROP) ---
        if curr_len < self.max_len:
            # N·∫øu ng·∫Øn h∆°n -> Th√™m N v√†o 2 ƒë·∫ßu (ƒë·ªÉ t√¢m ·ªü gi·ªØa)
            pad_total = self.max_len - curr_len
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            seq = "N" * pad_left + seq + "N" * pad_right
        elif curr_len > self.max_len:
            # N·∫øu d√†i h∆°n -> C·∫Øt l·∫•y ƒëo·∫°n gi·ªØa
            start = (curr_len - self.max_len) // 2
            seq = seq[start : start + self.max_len]
            
        # --- ONE-HOT ENCODING ---
        seq_enc = np.zeros((len(seq), 4), dtype=np.float32)
        for i, base in enumerate(seq):
            if base in self.mapping: 
                seq_enc[i, self.mapping[base]] = 1.0
            else: 
                seq_enc[i] = 0.25 # Gi√° tr·ªã cho N
        
        # Transpose (L, 4) -> (4, L)
        return torch.tensor(seq_enc).transpose(0, 1), torch.tensor(label, dtype=torch.long)

# ==========================================
# 5. CH∆Ø∆†NG TR√åNH CH√çNH
# ==========================================
def main():
    if not os.path.exists(RESULT_DIR): os.makedirs(RESULT_DIR)

    print("üöÄ B·∫ÆT ƒê·∫¶U CH·∫†Y INFERENCE...")

    # A. LOAD METRICS
    try:
        metrics_mod = load_module_from_path(METRICS_FILE_PATH, "metrics_mod")
        compute_metrics = metrics_mod.compute_metrics
        print(f"‚úÖ ƒê√£ load metrics t·ª´: {METRICS_FILE_PATH}")
    except Exception as e:
        print(f"‚ö†Ô∏è L·ªói load metrics.py: {e}")
        print("üëâ Chuy·ªÉn sang ch·∫ø ƒë·ªô t√≠nh Accuracy ƒë∆°n gi·∫£n (kh√¥ng c·∫ßn sklearn)")
        def compute_metrics(labels, preds, probs=None, k=2):
            return {"accuracy": float((labels == preds).mean()), "note": "fallback mode"}

    # B. LOAD MODEL & WEIGHTS
    try:
        model_mod = load_module_from_path(MODEL_CODE_FILE, "model_mod")
        
        if not hasattr(model_mod, MODEL_CLASS_NAME):
            print(f"‚ùå Kh√¥ng t√¨m th·∫•y class '{MODEL_CLASS_NAME}'")
            return
            
        ModelClass = getattr(model_mod, MODEL_CLASS_NAME)
        
        # Kh·ªüi t·∫°o model (Th·ª≠ v·ªõi tham s·ªë ƒë·ªô d√†i n·∫øu c·∫ßn)
        try:
            model = ModelClass().to(DEVICE)
        except:
            print(f"‚ÑπÔ∏è Init m·∫∑c ƒë·ªãnh l·ªói, th·ª≠ init v·ªõi L={MAX_SEQ_LEN}...")
            model = ModelClass(L=MAX_SEQ_LEN).to(DEVICE)

        load_weights_smart(model, CKPT_PATH)
        model.eval()

    except Exception as e:
        print(f"‚ùå L·ªói kh·ªüi t·∫°o Model: {e}")
        return

    # C. CH·∫†Y V√íNG L·∫∂P
    for ratio in ratios:
        csv_path = os.path.join(DATA_DIR, f"prepared_inference_{ratio}.csv")
        if not os.path.exists(csv_path):
            print(f"‚è© B·ªè qua {ratio}: File kh√¥ng t·ªìn t·∫°i")
            continue
            
        print(f"\n>>> üèÉ ƒêang x·ª≠ l√Ω: {ratio} (Len: {MAX_SEQ_LEN})")
        
        # Truy·ªÅn max_len v√†o Dataset
        dataset = InferenceDataset(csv_path, max_len=MAX_SEQ_LEN)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        all_labels, all_preds, all_probs = [], [], []

        with torch.no_grad():
            for inputs, labels in loader:
                inputs = inputs.to(DEVICE)
                
                outputs = model(inputs)
                if isinstance(outputs, (tuple, list)): outputs = outputs[0]
                
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(probs, dim=1)

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        # T√≠nh Metrics
        try:
            results = compute_metrics(np.array(all_labels), np.array(all_preds), np.array(all_probs))
            
            out_file = os.path.join(RESULT_DIR, f"results_{ratio}.json")
            with open(out_file, 'w') as f:
                json.dump(results, f, indent=4)
            print(f"   ‚úÖ Xong. Acc: {results.get('accuracy', 0):.4f} | Saved: {out_file}")
            
        except Exception as e:
            print(f"‚ùå L·ªói t√≠nh metrics cho {ratio}: {e}")

if __name__ == "__main__":
    main()

üöÄ B·∫ÆT ƒê·∫¶U CH·∫†Y INFERENCE...
‚úÖ ƒê√£ load metrics t·ª´: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\metrics.py
üîÑ ƒêang load weights t·ª´: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpTransformer_pytorch.ckpt
‚ö†Ô∏è C·∫£nh b√°o load weight: Error(s) in loading state_dict for SpTransformer:
	size mismatch for conv1.0.weight: copying a param with shape torch.Size([128, 4, 1]) from checkpoint, the shape in current model is torch.Size([32, 4, 1]).
	size mismatch for conv1.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for conv1.1.weight: copying a param with shape torch.Size([128, 128, 1]) from checkpoint, the shape in current model is torch.Size([32, 32, 1]).
	size mismatch for conv1.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for conv2.weight: co

  checkpoint = torch.load(ckpt_path, map_location=DEVICE)


KeyboardInterrupt: 

In [3]:
import os
import ast

# ƒê∆∞·ªùng d·∫´n th∆∞ m·ª•c g·ªëc c·ªßa repo b·∫°n ƒë√£ t·∫£i
SEARCH_DIR = r"D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main"

print(f"üîç ƒêang qu√©t t√¨m Model trong: {SEARCH_DIR}...\n")

found_any = False

for root, dirs, files in os.walk(SEARCH_DIR):
    for file in files:
        if file.endswith(".py"):
            full_path = os.path.join(root, file)
            try:
                with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                    file_content = f.read()
                    
                    # D√πng AST ƒë·ªÉ ph√¢n t√≠ch code m√† kh√¥ng c·∫ßn import (tr√°nh l·ªói thi·∫øu th∆∞ vi·ªán)
                    try:
                        tree = ast.parse(file_content)
                    except SyntaxError:
                        continue # B·ªè qua file l·ªói c√∫ ph√°p

                    for node in ast.walk(tree):
                        if isinstance(node, ast.ClassDef):
                            is_model = False
                            # Ki·ªÉm tra xem class c√≥ k·∫ø th·ª´a nn.Module hay LightningModule kh√¥ng
                            for base in node.bases:
                                # Case 1: class Model(nn.Module) -> base l√† Attribute
                                if isinstance(base, ast.Attribute) and base.attr in ['Module', 'LightningModule']:
                                    is_model = True
                                # Case 2: class Model(Module) -> base l√† Name
                                elif isinstance(base, ast.Name) and base.id in ['Module', 'LightningModule']:
                                    is_model = True
                            
                            if is_model:
                                print(f"‚úÖ T√åM TH·∫§Y ·ª®NG VI√äN!")
                                print(f"   üìÇ File: {full_path}")
                                print(f"   üß© Class Name: {node.name}")
                                print("-" * 50)
                                found_any = True
            except Exception as e:
                pass

if not found_any:
    print("‚ùå V·∫´n kh√¥ng t√¨m th·∫•y class n√†o k·∫ø th·ª´a Module. B·∫°n c√≥ ch·∫Øc ƒë√£ t·∫£i ƒë·ªß code kh√¥ng?")

üîç ƒêang qu√©t t√¨m Model trong: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main...

‚úÖ T√åM TH·∫§Y ·ª®NG VI√äN!
   üìÇ File: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main\model\model.py
   üß© Class Name: ResBlock
--------------------------------------------------
‚úÖ T√åM TH·∫§Y ·ª®NG VI√äN!
   üìÇ File: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main\model\model.py
   üß© Class Name: SpEncoder
--------------------------------------------------
‚úÖ T√åM TH·∫§Y ·ª®NG VI√äN!
   üìÇ File: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main\model\model.py
   üß© Class Name: SpEncoder_4tis
--------------------------------------------------
‚úÖ T√åM TH·∫§Y ·ª®NG VI√äN!
   üìÇ File: D:\Study\5-FA25\AiTa_Lab_Research\Code\Inference_Model\SpliceTransformer\SpliceTransformer-main\model\model.py
  