In [None]:
# Install ESMFold (fair-esm with the esmfold extra) plus the helper libraries it imports.
# NOTE(review): --force-reinstall reinstalls the whole dependency tree, possibly including
# torch, which could replace the preinstalled CUDA build on Kaggle — confirm it is needed.
# NOTE(review): biopython==1.81 is pinned here and again in the next cell; keep them in sync.
!pip install --quiet --upgrade --force-reinstall "fair-esm[esmfold]" biopython==1.81 omegaconf einops ml-collections

In [None]:
%cd /kaggle/working
# Clone and install OpenFold in editable mode (presumably needed by esm.pretrained.esmfold_v1()).
# NOTE(review): on a re-run `git clone` prints an error because ./openfold already exists;
# a `!` command failing does not stop the cell, so the installs below still run.
# NOTE(review): this tracks OpenFold master; the fair-esm README installs OpenFold from a
# pinned commit — verify that master is compatible with esmfold_v1.
!git clone https://github.com/aqlaboratory/openfold.git
!pip install --quiet "biopython==1.81" omegaconf einops ml-collections
!pip install --quiet -r openfold/requirements.txt
!pip install --quiet -e openfold

In [None]:
import os, re
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from Bio import SeqIO
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (kích hoạt 3D proj)

# ==== PATHS (edit to match your environment) ====
ROOT = Path("/kaggle/input/cafa-6-protein-function-prediction")

P = {
    "train_fasta": ROOT/"Train"/"train_sequences.fasta",
    "test_fasta":  ROOT/"Test"/"testsuperset.fasta",
}

# Directory for per-protein <id>.npz files. It must be a Path, not a str:
# later cells build file paths with `OUTDIR / f"{id}.npz"`.
OUTDIR = Path("/kaggle/working/pred_struct")
OUTDIR.mkdir(parents=True, exist_ok=True)

# ==== GPU ====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)
print("GPU count:", torch.cuda.device_count())

MIXED_PRECISION = True  # fp16 autocast on CUDA: faster and uses less VRAM
MAX_LEN = 1022          # ESMFold is commonly limited to ~1022 aa; longer proteins are skipped (handle separately)
BATCH_SIZE = 2          # proteins folded per forward pass (bounded by VRAM, e.g. T4 x2)
SKIP_IF_EXISTS = True   # do not re-fold proteins whose .npz already exists

In [None]:
def norm_uniprot_id(header_id: str) -> str:
    """
    Extract the UniProt accession from a FASTA record id.

    'sp|P9WHI7|RECN_MYCT ...' -> 'P9WHI7'. If the first whitespace-separated
    token has no non-empty second '|' field, that token is returned unchanged.
    """
    token = str(header_id).split()[0]
    fields = token.split("|")
    # A token without '|' splits into a single field, so there is no accession field.
    accession = fields[1] if len(fields) > 1 else ""
    return accession or token

def load_fasta_list(fasta_path: Path):
    """
    Parse a FASTA file into a list of per-protein dicts.

    Each dict has:
      "target": accession extracted with norm_uniprot_id from the record id,
      "seq":    amino-acid sequence (str),
      "len":    length of that sequence.
    No taxon information is attached: the DeepFRI step only needs sequence -> structure.
    """
    def _record_to_row(record):
        sequence = str(record.seq)
        return {
            "target": norm_uniprot_id(record.id),
            "seq": sequence,
            "len": len(sequence),
        }

    return [_record_to_row(record) for record in SeqIO.parse(str(fasta_path), "fasta")]

# Parse both FASTA files into lists of {"target", "seq", "len"} dicts.
train_list, test_list = (load_fasta_list(P[key]) for key in ("train_fasta", "test_fasta"))

for label, proteins in (("Train proteins:", train_list), ("Test  proteins:", test_list)):
    print(label, len(proteins))
print("Sample train entry:", train_list[0])

In [None]:
import esm

# Load pretrained ESMFold once, in inference mode.
esmfold_model = esm.pretrained.esmfold_v1().eval()

if DEVICE == "cpu":
    # NOTE: fair-esm's fold script casts the ESM-2 language model to fp32 before
    # running on CPU (it is held in fp16, which CPU kernels do not support).
    esmfold_model.esm.float()

esmfold_model = esmfold_model.to(DEVICE)

# Chunk axial attention in the folding trunk to cut peak VRAM (smaller = less memory,
# slower). Set to None to disable chunking.
ESMFOLD_CHUNK_SIZE = 128
esmfold_model.set_chunk_size(ESMFOLD_CHUNK_SIZE)

# Deliberately NOT wrapped in nn.DataParallel: DataParallel only splits tensors, so a
# list[str] batch would be copied whole to every GPU (duplicate work, duplicated
# outputs). It would also hide ESMFold.infer behind `.module`.
if torch.cuda.device_count() > 1:
    print("Found", torch.cuda.device_count(), "GPUs; ESMFold runs on a single GPU "
          "(nn.DataParallel cannot split a list of sequences)")

# Helper: fold a list[str] of amino-acid sequences and return one C-alpha trace per sequence.
def fold_batch_sequences(seq_list):
    """
    Fold a batch of sequences with ESMFold and return their C-alpha coordinates.

    Input:
      seq_list: list of B amino-acid strings (lengths may differ).
    Output:
      coords_ca_list: list of B float32 np.arrays; element b has shape
        [len(seq_list[b]), 3] (batch padding removed). An empty input gives [].

    Uses the module-level ``esmfold_model`` (plain, or wrapped in nn.DataParallel),
    ``DEVICE`` and ``MIXED_PRECISION``.
    """
    if not seq_list:
        return []

    # ESMFold.forward expects a tokenized tensor; `infer` is the entry point for raw
    # strings (it tokenizes and pads the batch). nn.DataParallel hides `infer` behind
    # `.module` and would only replicate a list of strings, so unwrap it.
    model = getattr(esmfold_model, "module", esmfold_model)

    # enabled=False turns autocast into a no-op (CPU, or mixed precision switched off).
    use_amp = MIXED_PRECISION and DEVICE == "cuda"
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        output = model.infer(list(seq_list))

    # output["positions"] holds atom14 coordinates stacked over structure-module
    # iterations, i.e. [n_iter, B, L_max, 14, 3]; fair-esm's own output_to_pdb reads
    # positions[-1], so keep only the final iteration.
    positions = output["positions"][-1].detach().float().cpu().numpy()  # [B, L_max, 14, 3]

    # atom14 index 1 is C-alpha. Slice each row back to its true length, because
    # sequences are padded to the longest one in the batch.
    return [positions[b, :len(seq), 1, :] for b, seq in enumerate(seq_list)]

In [None]:
def save_coords_npz(prot_id, seq, coords, out_dir: Path):
    """
    Save one protein's C-alpha trace and sequence to ``<out_dir>/<prot_id>.npz``.

    Args:
        prot_id: identifier used as the file stem.
        seq: amino-acid sequence, stored as a 0-d string array under key "seq".
        coords: array-like of coordinates, stored as float32 under key "coords".
        out_dir: target directory (Path or str); it must already exist.

    Returns:
        Path of the written .npz file.

    The archive is first written under a temporary name and then renamed. An
    interrupted run therefore cannot leave a truncated ``<prot_id>.npz`` that an
    existence check would later mistake for a finished structure.
    """
    out_dir = Path(out_dir)  # accept str as well as Path
    out_path = out_dir / f"{prot_id}.npz"
    # Name already ends in ".npz", so savez_compressed will not append another suffix.
    tmp_path = out_dir / f"{prot_id}.tmp.npz"
    np.savez_compressed(tmp_path, coords=np.asarray(coords, dtype=np.float32), seq=np.array(seq))
    tmp_path.replace(out_path)
    return out_path

def generate_structures_for_list(protein_list, out_dir: Path, split_name="train"):
    """
    Fold every protein in ``protein_list`` with ESMFold and save its C-alpha trace.

    Args:
        protein_list: list of dicts {"target", "seq", "len"} (see load_fasta_list).
        out_dir: output directory (Path or str); created if missing. One
            ``<target>.npz`` is written per folded protein.
        split_name: label used in the progress bar and log lines.

    A protein is skipped when its .npz already exists (if SKIP_IF_EXISTS) or when
    it is longer than MAX_LEN. If a batch runs out of CUDA memory, each of its
    proteins is retried alone. Proteins that still fail are reported and skipped,
    so the rest of the run continues.
    """
    out_dir = Path(out_dir)  # `out_dir / name` below needs a Path, not a str
    out_dir.mkdir(parents=True, exist_ok=True)

    miss_too_long = 0
    already = 0
    done = 0
    failed_ids = []

    # Build the list of proteins that actually need folding.
    jobs = []
    for p in protein_list:
        prot_id = p["target"]
        seq     = p["seq"]
        L       = p["len"]

        out_path = out_dir / f"{prot_id}.npz"
        if SKIP_IF_EXISTS and out_path.exists():
            already += 1
            continue

        if L > MAX_LEN:
            # Too long for the memory budget; skip it.
            miss_too_long += 1
            continue

        jobs.append((prot_id, seq))

    # Every sequence in a batch is padded to the longest one. Grouping similar
    # lengths together therefore cuts wasted compute.
    jobs.sort(key=lambda job: len(job[1]))

    print(f"[{split_name}] total={len(protein_list)}, to_run={len(jobs)}, already={already}, too_long={miss_too_long}")

    for i in tqdm(range(0, len(jobs), BATCH_SIZE), desc=f"Folding {split_name}"):
        batch_jobs = jobs[i:i+BATCH_SIZE]
        folded, failed = _fold_jobs_with_oom_fallback(batch_jobs)
        for (pid, seq), coords in folded:
            save_coords_npz(pid, seq, coords, out_dir)
            done += 1
        failed_ids.extend(failed)

    print(f"[{split_name}] generated new {done} structures")
    if failed_ids:
        print(f"[{split_name}] {len(failed_ids)} proteins failed (CUDA out of memory), e.g. {failed_ids[:5]}")


def _fold_jobs_with_oom_fallback(batch_jobs):
    """Fold (id, seq) jobs together; on CUDA OOM retry each one alone. Returns (folded, failed_ids)."""
    try:
        coords_batch = fold_batch_sequences([seq for _, seq in batch_jobs])
        return list(zip(batch_jobs, coords_batch)), []
    except RuntimeError as err:
        # torch.cuda.OutOfMemoryError subclasses RuntimeError; re-raise anything else.
        if "out of memory" not in str(err):
            raise
    # Release cached blocks only after leaving the except clause. Until then the
    # traceback still references the failed batch's tensors.
    torch.cuda.empty_cache()
    if len(batch_jobs) == 1:
        return [], [batch_jobs[0][0]]
    folded, failed = [], []
    for job in batch_jobs:
        job_folded, job_failed = _fold_jobs_with_oom_fallback([job])
        folded.extend(job_folded)
        failed.extend(job_failed)
    return folded, failed

In [None]:
# Fold all train proteins first, then all test proteins, into the same output folder.
for split, proteins in (("train", train_list), ("test", test_list)):
    generate_structures_for_list(proteins, OUTDIR, split_name=split)

In [None]:
def plot_protein_coords(coords, title="Protein 3D"):
    """
    Draw a 3D C-alpha trace: one dot per residue, joined by a faint line in residue order.

    coords: array indexed as coords[:, 0], coords[:, 1], coords[:, 2] for x, y, z.
    """
    x, y, z = (coords[:, axis] for axis in range(3))
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(x, y, z, s=8)
    ax.plot(x, y, z, linewidth=1, alpha=0.4)
    for set_label, text in ((ax.set_xlabel, "x"), (ax.set_ylabel, "y"), (ax.set_zlabel, "z")):
        set_label(text)
    ax.set_title(title)
    plt.show()

# Example: plot the first training protein that already has a saved structure.
# train_list[0] itself may have been skipped (e.g. longer than MAX_LEN), so search
# for one; fall back to train_list[0] so the "not found" message below still fires.
out_root = Path(OUTDIR)  # works whether OUTDIR is a str or a Path
example_pid = next(
    (p["target"] for p in train_list if (out_root / f"{p['target']}.npz").exists()),
    train_list[0]["target"],
)
example_path = out_root / f"{example_pid}.npz"
if example_path.exists():
    # Close the NpzFile handle once the array has been read.
    with np.load(example_path) as data:
        coords_ex = data["coords"]  # (L, 3) C-alpha coordinates
    plot_protein_coords(coords_ex, title=f"{example_pid} (len={coords_ex.shape[0]})")
else:
    print("Chưa có file ví dụ để plot. Chạy cell generate_structures_for_list trước đã.")