# In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import gc

# --- CONFIGURATION TUNED FOR 2x T4 GPUs ---
# Single shared settings dictionary read by the extraction code below.
CFG = dict(
    model_name='facebook/esm2_t33_650M_UR50D',
    input_dim=1280,
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),

    # Embedding-extraction settings
    extract_batch_size=16,
    max_len=1024,

    data_path='/kaggle/input/cafa-6-protein-function-prediction',
)

print(f">>> ĐANG CHẠY TRÊN: {CFG['device']}")
# Report the GPU count only when more than one device is visible.
if torch.cuda.device_count() <= 1:
    print("Chỉ tìm thấy 1 GPU")
else:
    print(f"{torch.cuda.device_count()}")

# ====================================================
# PART 1: EMBEDDING EXTRACTION (ESM-2 650M)
# ====================================================
print("\n[1/5] CHECKING EMBEDDINGS...")

# Output files that cache the extracted train / test embeddings.
TRAIN_EMB_FILE, TEST_EMB_FILE = 'X_train_650M.npy', 'X_test_650M.npy'

def extract_embeddings_heavy(sequences, save_path):
    """Embed sequences with the configured model and cache them at ``save_path``.

    Each batch is tokenized (padded to the longest item, truncated to
    ``CFG['max_len']`` tokens), passed through the model named by
    ``CFG['model_name']``, and reduced to one vector per sequence by averaging
    the last hidden state over every position the attention mask marks as
    real (padding is excluded; whatever special tokens the tokenizer adds are
    included). The per-batch results are stacked in input order, stored as
    float16 and written with ``np.save``.

    If ``save_path`` already exists the function returns immediately without
    loading the model, so a previous run acts as a cache.

    Args:
        sequences: Sized, sliceable collection of sequence strings.
        save_path: Destination file. The array is written to exactly this
            path (no extension is appended).

    Returns:
        None. The result is the file written to ``save_path``.

    Raises:
        ValueError: If ``sequences`` is empty and no cached file exists.
    """
    if os.path.exists(save_path):
        print(f" -> File {save_path} đã tồn tại. Bỏ qua bước này.")
        return

    # Fail before the expensive model load; np.vstack would reject an empty
    # list with the same exception type only after all the setup work.
    if len(sequences) == 0:
        raise ValueError(f"No sequences to embed for {save_path!r}")

    print(f" -> Đang khởi tạo model {CFG['model_name']}...")
    tokenizer = AutoTokenizer.from_pretrained(CFG['model_name'])
    model = AutoModel.from_pretrained(CFG['model_name'])

    # Spread each batch across all visible GPUs when there is more than one.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model.to(CFG['device'])
    model.eval()

    batch_size = CFG['extract_batch_size']
    use_amp = torch.cuda.is_available()
    embeddings = []
    print(f" -> Bắt đầu xử lý {len(sequences)} sequences...")

    try:
        with torch.no_grad():
            for start in tqdm(range(0, len(sequences), batch_size), desc="Extracting"):
                batch_seqs = sequences[start:start + batch_size]

                # Tokenize
                inputs = tokenizer(
                    batch_seqs,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=CFG['max_len'],
                )
                inputs = {k: v.to(CFG['device']) for k, v in inputs.items()}

                # Forward pass; mixed precision on GPU for speed and lower VRAM.
                if use_amp:
                    with torch.autocast(device_type="cuda"):
                        outputs = model(**inputs)
                else:
                    outputs = model(**inputs)

                # Masked mean pooling, done explicitly in float32 so the sum
                # over the token axis does not depend on the autocast dtype.
                hidden = outputs.last_hidden_state.float()
                mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
                summed = (hidden * mask).sum(dim=1)
                # Guard against a zero divisor should a row ever be fully masked.
                counts = mask.sum(dim=1).clamp(min=1.0)
                batch_emb = summed / counts

                # Store as float16 to halve the size of the cached matrix.
                embeddings.append(batch_emb.cpu().numpy().astype(np.float16))

        # Merge and save. Writing through a file object keeps the exact path
        # (np.save appends ".npy" to string paths that lack it, which would
        # make the existence check above never match), and the temp-file +
        # rename means an interrupted run cannot leave a truncated file that
        # a later run would mistake for a finished cache.
        full_emb = np.vstack(embeddings)
        tmp_path = f"{save_path}.tmp"
        with open(tmp_path, "wb") as fh:
            np.save(fh, full_emb)
        os.replace(tmp_path, save_path)
        print(f" -> Đã lưu {save_path}: {full_emb.shape}")
    finally:
        # Release GPU memory even when extraction fails part-way through.
        del model, tokenizer, embeddings
        gc.collect()
        torch.cuda.empty_cache()

# --- FASTA READER ---
def _fasta_record_id(header):
    """Return the record identifier encoded in a FASTA header line.

    When the header contains ``|`` the identifier is the field between the
    first and second pipe (e.g. ``>sp|P12345|NAME`` -> ``P12345``); otherwise
    it is the first whitespace-delimited token after ``>``.

    Raises:
        ValueError: If nothing follows the ``>`` marker.
    """
    body = header.strip()[1:]
    if '|' in body:
        return body.split('|')[1]
    tokens = body.split()
    if not tokens:
        raise ValueError(f"FASTA header without an identifier: {header!r}")
    return tokens[0]


def _iter_fasta_records(lines):
    """Yield ``(header, sequence)`` for each record that has residue lines."""
    header = None
    chunks = []
    for raw in lines:
        if raw.startswith('>'):
            if header is not None and chunks:
                yield header, "".join(chunks)
            header, chunks = raw, []
            continue
        residues = raw.strip()
        # Skip blank lines, and anything that appears before the first
        # header: it belongs to no record.
        if residues and header is not None:
            chunks.append(residues)
    if header is not None and chunks:
        yield header, "".join(chunks)


def get_seqs(path):
    """Read a FASTA file into parallel lists of identifiers and sequences.

    Multi-line sequences are joined into a single string. Records whose
    header is not followed by any residues are skipped, and text preceding
    the first header is ignored, so the two returned lists always have the
    same length and order.

    Args:
        path: Path of the FASTA file to read.

    Returns:
        A ``(ids, seqs)`` tuple of lists; ``ids[i]`` identifies ``seqs[i]``.

    Raises:
        ValueError: If a record's header carries no identifier.
    """
    ids = []
    seqs = []
    with open(path, 'r') as f:
        for header, sequence in _iter_fasta_records(f):
            ids.append(_fasta_record_id(header))
            seqs.append(sequence)
    return ids, seqs

# --- RUN EXTRACTION ---
print(" -> Reading FASTA files...")
# Locate the train and test FASTA files under the Kaggle input directory.
train_fasta_path = f"{CFG['data_path']}/Train/train_sequences.fasta"
test_fasta_path = f"{CFG['data_path']}/Test/testsuperset.fasta"
train_ids, train_seqs = get_seqs(train_fasta_path)
test_ids, test_seqs = get_seqs(test_fasta_path)

# Extract embeddings for the train split first, then the test split.
for _seqs, _emb_file in ((train_seqs, TRAIN_EMB_FILE), (test_seqs, TEST_EMB_FILE)):
    extract_embeddings_heavy(_seqs, _emb_file)

print("\n>>> HOÀN THÀNH EXTRACT EMBEDDINGS!")