# In [None]:
import os
# Expose only GPU 0 to CUDA. The variable is assigned before torch is
# imported below, so it is already in place when torch first looks at CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


from pathlib import Path
import sys
import logging
import typing as T
from timeit import default_timer as timer
import numpy as np
import torch
import esm
from esm.data import read_fasta
# Module-level torch device: CUDA when torch reports it available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Logging setup
# Message layout shared by every handler attached below.
formatter = logging.Formatter(
    fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%y/%m/%d %H:%M:%S",
)

# Handler that writes INFO-and-above records to standard output.
console_handler = logging.StreamHandler(stream=sys.stdout)
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

# Configure the root logger and attach the console handler to it.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(console_handler)

# Helper functions
def enable_cpu_offloading(model):
    """Wrap ``model`` and each child of ``model.layers`` in FSDP with CPU offload.

    Every child module of ``model.layers`` is replaced in place by its
    wrapped version, then the model itself is wrapped.

    Returns:
        The wrapped model (the result of ``wrap(model)``).
    """
    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import enable_wrap, wrap

    offload_config = CPUOffload(offload_params=True)

    with enable_wrap(wrapper_cls=FullyShardedDataParallel, cpu_offload=offload_config):
        # Children first, so the outer wrap sees already-wrapped layers.
        for child_name, child in model.layers.named_children():
            setattr(model.layers, child_name, wrap(child))
        wrapped_model = wrap(model)

    return wrapped_model


def init_model_on_gpu_with_cpu_offloading(model):
    """Put ``model`` on the GPU while its ``esm`` submodule uses CPU offloading.

    Returns:
        The same model object, in eval mode, with ``model.esm`` replaced by
        the offload-wrapped module.
    """
    model = model.eval()

    # Wrap the ESM submodule and detach it from the model before the move,
    # so the ``.cuda()`` call below does not touch it.
    offloaded_esm = enable_cpu_offloading(model.esm)
    del model.esm

    model.cuda()

    # Re-attach the wrapped submodule once everything else is on the GPU.
    model.esm = offloaded_esm
    return model


def create_batched_sequence_datasest(
    sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
    """Group ``(header, sequence)`` pairs into batches bounded by total length.

    Sequences are consumed in the given order. A batch is emitted as soon as
    adding the next sequence would push the summed sequence length above
    ``max_tokens_per_batch``. A single sequence longer than the limit still
    gets a batch of its own.

    Args:
        sequences: ``(header, sequence)`` pairs.
        max_tokens_per_batch: Upper bound on the summed sequence length of a
            batch (not enforced for a batch holding one oversized sequence).

    Yields:
        ``(headers, sequences)`` lists for each batch. Nothing is yielded for
        empty input, so callers never receive an empty batch.
    """
    batch_headers, batch_sequences, num_tokens = [], [], 0
    for header, seq in sequences:
        if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
            yield batch_headers, batch_sequences
            batch_headers, batch_sequences, num_tokens = [], [], 0
        batch_headers.append(header)
        batch_sequences.append(seq)
        num_tokens += len(seq)

    # Only flush a trailing batch that actually holds something; with empty
    # input there is nothing to predict.
    if batch_headers:
        yield batch_headers, batch_sequences


def run(args):
    """Predict structures for every sequence in a FASTA file and write PDB files.

    Attributes read from ``args``: ``fasta`` (input path), ``pdb`` (output
    directory), ``model_dir`` (torch hub directory, skipped when ``None``),
    ``chunk_size``, ``cpu_only``, ``cpu_offload``, ``max_tokens_per_batch``
    and ``num_recycles``.

    Sequences are sorted by length, batched, and passed to the ESMFold model.
    A batch that fails with a CUDA out-of-memory error is reported and
    skipped; any other ``RuntimeError`` propagates.

    Raises:
        FileNotFoundError: If ``args.fasta`` does not exist.
    """
    if not args.fasta.exists():
        raise FileNotFoundError(args.fasta)

    # parents=True so a nested output directory is created in one go.
    args.pdb.mkdir(parents=True, exist_ok=True)

    # Read fasta and sort sequences by length
    print(f"Reading sequences from {args.fasta}")
    all_sequences = sorted(read_fasta(args.fasta), key=lambda header_seq: len(header_seq[1]))
    print(f"Loaded {len(all_sequences)} sequences from {args.fasta}")

    print("Loading model")

    # Use pre-downloaded ESM weights from model_pth.
    if args.model_dir is not None:
        torch.hub.set_dir(args.model_dir)

    model = esm.pretrained.esmfold_v1()

    model = model.eval()
    model.set_chunk_size(args.chunk_size)

    if args.cpu_only:
        model.esm.float()  # convert to fp32 as ESM-2 in fp16 is not supported on CPU
        model.cpu()
    elif args.cpu_offload:
        model = init_model_on_gpu_with_cpu_offloading(model)
    else:
        model.cuda()
    print("Starting Predictions")
    batched_sequences = create_batched_sequence_datasest(all_sequences, args.max_tokens_per_batch)

    num_completed = 0
    num_sequences = len(all_sequences)
    for headers, sequences in batched_sequences:
        start = timer()
        try:
            output = model.infer(sequences, num_recycles=args.num_recycles)
        except RuntimeError as e:
            # Guard the message lookup: a RuntimeError raised without args (or
            # with a non-string first arg) must be re-raised as-is instead of
            # being masked by an IndexError/AttributeError from this check.
            message = e.args[0] if e.args else ""
            if isinstance(message, str) and message.startswith("CUDA out of memory"):
                if len(sequences) > 1:
                    print(
                        f"Failed (CUDA out of memory) to predict batch of size {len(sequences)}. "
                        "Try lowering `--max-tokens-per-batch`."
                    )
                else:
                    print(
                        f"Failed (CUDA out of memory) on sequence {headers[0]} of length {len(sequences[0])}."
                    )

                continue
            raise

        # Move results to the CPU before converting them to PDB text.
        output = {key: value.cpu() for key, value in output.items()}
        pdbs = model.output_to_pdb(output)
        tottime = timer() - start
        time_string = f"{tottime / len(headers):0.1f}s"
        if len(sequences) > 1:
            time_string = time_string + f" (amortized, batch size {len(sequences)})"
        for header, seq, pdb_string, mean_plddt, ptm in zip(
            headers, sequences, pdbs, output["mean_plddt"], output["ptm"]
        ):
            # One PDB file per sequence, named after its FASTA header.
            output_file = args.pdb / f"{header}.pdb"
            output_file.write_text(pdb_string)
            num_completed += 1
            print(
                f"Predicted structure for {header} with length {len(seq)}, pLDDT {mean_plddt:0.1f}, "
                f"pTM {ptm:0.3f} in {time_string}. "
                f"{num_completed} / {num_sequences} completed."
            )


class Args:
    """Settings consumed by ``run`` through plain attribute access.

    The three paths are empty placeholders and have to be filled in before
    the script can do anything useful.
    """

    # Input FASTA file, output directory for PDB files, torch hub directory.
    fasta: Path = Path("")
    pdb: Path = Path("")
    model_dir: Path = Path("")

    # Inference settings.
    num_recycles: int = 6
    chunk_size: int = 128
    max_tokens_per_batch: int = 256

    # Device placement switches.
    cpu_only: bool = False
    cpu_offload: bool = False

# Build the settings object from the defaults declared on Args.
args = Args()


# Start prediction right away; this runs whenever the cell/module is executed.
run(args)


# In [None]:
import os
import torch
import pandas as pd

# Set directory paths (modify to your actual local paths)
# NOTE: these are empty placeholders; os.makedirs('') fails until they are set.
pt_dir = ''
surface_dir = ''
output_dir = ''
os.makedirs(output_dir, exist_ok=True)

# Traverse all .pt files
for file in os.listdir(pt_dir):
    if not file.endswith('.pt'):
        continue

    # Strip only the trailing extension; str.replace would also remove any
    # other ".pt" occurring inside the file name.
    base = file[:-len('.pt')]  # e.g., "0_positive_training"
    pt_path = os.path.join(pt_dir, file)

    # Find the corresponding surface CSV file
    surface_csv_name = f"{base} dataset_surface.csv"
    surface_path = os.path.join(surface_dir, surface_csv_name)

    # Skip if the corresponding surface file is not found
    if not os.path.exists(surface_path):
        print(f"⚠️ Surface file not found: {surface_csv_name}")
        continue

    # Load graph features
    data = torch.load(pt_path)
    features = data['x']  # [num_nodes, 1217]

    # Load exposed residue indices (1-based residue numbers -> 0-based rows).
    # An explicit integer dtype keeps an empty CSV from producing a float
    # tensor, which cannot be used for indexing.
    surface_df = pd.read_csv(surface_path)
    residue_numbers = surface_df['Residue_Number'].astype(int).tolist()
    indices = torch.tensor(residue_numbers, dtype=torch.long) - 1

    # Reject out-of-range residues: a residue number of 0 would otherwise
    # wrap around to the last node without any error.
    num_nodes = features.shape[0]
    if indices.numel() and (int(indices.min()) < 0 or int(indices.max()) >= num_nodes):
        print(f"⚠️ Residue numbers out of range for {file} ({num_nodes} nodes), skipped")
        continue

    surface_features = features[indices]

    # Create output DataFrame with Residue_Number
    # detach()/cpu() make the conversion work for grad-tracking or GPU tensors.
    output_df = pd.DataFrame(surface_features.detach().cpu().numpy())
    output_df.insert(0, 'Residue_Number', surface_df['Residue_Number'].values)

    # Save the extracted features
    out_csv_path = os.path.join(output_dir, f"{base}_surface_node_features.csv")
    output_df.to_csv(out_csv_path, index=False)
    print(f"✅ Saved: {out_csv_path}")


# In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
import numpy as np
from scipy.spatial.distance import cdist
from torch_geometric.nn import radius_graph
import os
from Bio import pairwise2

# Process PDB file
def get_pdb_xyz(pdb_file):
    """Extract per-residue N, CA, C, O and side-chain centroid coordinates.

    Args:
        pdb_file: Iterable of PDB text lines (e.g. ``f.readlines()``).

    Returns:
        ``np.ndarray`` with one ``[N, CA, C, O, R]`` row of xyz coordinates
        per residue, where ``R`` is the mean position of all remaining atoms
        of the residue, or the CA position when there are none.

    Raises:
        KeyError: If a residue lacks one of the N, CA, C or O atoms.
    """
    backbone = ("N", "CA", "C", "O")

    def _residue_row(atoms):
        # Everything that is not a backbone atom counts as side chain.
        side_chain = [xyz for name, xyz in atoms.items() if name not in backbone]
        r_group = np.array(side_chain).mean(0) if side_chain else atoms["CA"]
        return [atoms["N"], atoms["CA"], atoms["C"], atoms["O"], r_group]

    current_pos = -1000
    X = []
    current_aa = {}  # atom name -> xyz for the residue currently being read
    for line in pdb_file:
        record = line[0:4].strip()
        is_atom = record == "ATOM"

        # A new residue number or a TER record closes the residue in progress.
        if (is_atom and int(line[22:26].strip()) != current_pos) or record == "TER":
            if current_aa:
                X.append(_residue_row(current_aa))
                current_aa = {}
            if is_atom:
                current_pos = int(line[22:26].strip())

        if is_atom:
            atom = line[13:16].strip()
            if atom != "H":
                xyz = np.array([line[30:38].strip(), line[38:46].strip(), line[46:54].strip()]).astype(np.float32)
                current_aa[atom] = xyz

    # Files without a trailing TER record would otherwise lose their last
    # residue, because nothing else triggers the final flush.
    if current_aa:
        X.append(_residue_row(current_aa))

    return np.array(X)

# Process DSSP file
def process_dssp(dssp_file):
    """Parse a per-residue DSSP table into a sequence and feature vectors.

    Each data line holds, in order: index, amino acid, secondary structure,
    ASA and (optionally) RSA. Tab-separated lines are split on tabs so that a
    blank secondary-structure column stays in place; other lines are split on
    whitespace.

    Args:
        dssp_file: Path of the table to read.

    Returns:
        ``(seq, dssp_feature)`` where ``seq`` is the string of accepted
        one-letter residues and ``dssp_feature`` is a list with one length-9
        array per residue: ``[RSA, one-hot secondary structure (8)]``.
        RSA is ASA divided by the residue's reference value, capped at 1.
        Lines with an unknown amino acid or fewer than 4 fields are skipped.
    """
    aa_type = "ACDEFGHIKLMNPQRSTVWY"
    SS_type = "HBEGITSC"
    rASA_std = [115, 135, 150, 190, 210, 75, 195, 175, 200, 170,
                185, 160, 145, 180, 225, 115, 140, 155, 255, 230]
    coil_index = SS_type.index("C")

    with open(dssp_file, "r") as f:
        lines = f.readlines()

    # Skip header (guarded so an empty file yields an empty result).
    if lines and lines[0].strip() == "No.\tAA\tSS\tASA (Angstrom**2)\tRSA (%)":
        lines = lines[1:]

    residues = []
    dssp_feature = []
    for line in lines:
        if "\t" in line:
            # Splitting on tabs keeps an empty SS field as its own column;
            # a plain split() would shift ASA/RSA one column to the left.
            parts = [field.strip() for field in line.rstrip("\r\n").split("\t")]
        else:
            parts = line.split()
        if len(parts) < 4:
            continue

        aa = parts[1]  # amino acid
        # Require exactly one known letter ("in" alone is a substring test).
        if len(aa) != 1 or aa not in aa_type:
            continue  # skip unknown amino acids

        # Blank or unrecognised secondary structure is treated as coil.
        ss_index = SS_type.find(parts[2]) if parts[2] else -1
        if ss_index < 0:
            ss_index = coil_index

        residues.append(aa)
        SS_vec = np.zeros(len(SS_type))
        SS_vec[ss_index] = 1

        ASA = float(parts[3])
        RSA = min(1, ASA / rASA_std[aa_type.find(aa)])

        dssp_feature.append(np.concatenate((np.array([RSA]), SS_vec)))

    return "".join(residues), dssp_feature

# Match DSSP sequence to reference sequence
def match_dssp(seq, dssp, ref_seq):
    """Re-index DSSP features so there is one entry per ``ref_seq`` residue.

    ``seq`` is globally aligned to ``ref_seq`` and the first alignment is
    used. Reference residues aligned to a gap in ``seq`` receive a zero
    vector of length 9; residues of ``seq`` aligned to a gap in the
    reference are dropped.

    Args:
        seq: Sequence the DSSP features were computed for.
        dssp: Per-residue features, in ``seq`` order.
        ref_seq: Sequence the result should follow.

    Returns:
        List of feature vectors, one per residue of ``ref_seq``.
    """
    best = pairwise2.align.globalxx(ref_seq, seq)[0]
    aligned_ref, aligned_seq = best.seqA, best.seqB

    gap_feature = np.zeros(9)  # ensure same dimension as DSSP features

    # One feature per alignment column, following the aligned ``seq``.
    per_column = []
    cursor = 0
    for residue in aligned_seq:
        if residue != "-":
            per_column.append(dssp[cursor])
            cursor += 1
        else:
            per_column.append(gap_feature)

    # Keep only the columns where the reference has a residue.
    return [
        per_column[col]
        for col, ref_residue in enumerate(aligned_ref)
        if ref_residue != "-"
    ]

# Compute geometric features
def get_geo_feat(X, edge_index):
    """Assemble geometric node and edge features for a residue graph.

    Args:
        X: Per-residue atom coordinates, forwarded to the helper functions.
        edge_index: Edge list of the graph, forwarded to the helper functions.

    Returns:
        ``(geo_node_feat, geo_edge_feat)``. Node features concatenate angles,
        distances and directions; edge features concatenate positional
        embeddings, orientations, distances and directions, all along the
        last dimension.
    """
    edge_position = _positional_embeddings(edge_index)
    angles = _get_angle(X)
    node_rbf, edge_rbf = _get_distance(X, edge_index)
    node_dir, edge_dir, edge_orient = _get_direction_orientation(X, edge_index)

    node_parts = (angles, node_rbf, node_dir)
    edge_parts = (edge_position, edge_orient, edge_rbf, edge_dir)
    return torch.cat(node_parts, dim=-1), torch.cat(edge_parts, dim=-1)

# Internal: positional encoding
def _positional_embeddings(edge_index, num_embeddings=16):
    d = edge_index[0] - edge_index[1]

    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=edge_index.device)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d.unsqueeze(-1) * frequency
    PE = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
    return PE

# Internal: get dihedral and bond angles
def _get_angle(X, eps=1e-7):
    X = torch.reshape(X[:, :3], [3 * X.shape[0], 3])
    dX = X[1:] - X[:-1]
    U = F.normalize(dX, dim=-1)
    u_2 = U[:-2]
    u_1 = U[1:-1]
    u_0 = U[2:]

    n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
    n_1 = F.normalize(torch.cross(u_1, u_0), dim=-1)

    cosD = torch.sum(n_2 * n_1, -1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
    D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
    D = F.pad(D, [1, 2])
    D = torch.reshape(D, [-1, 3])
    dihedral = torch.cat([torch.cos(D), torch.sin(D)], 1)

    cosD = (u_2 * u_1).sum(-1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
    D = torch.acos(cosD)
    D = F.pad(D, [1, 2])
    D = torch.reshape(D, [-1, 3])
    bond_angles = torch.cat((torch.cos(D), torch.sin(D)), 1)

    node_angles = torch.cat((dihedral, bond_angles), 1)
    return node_angles

# Internal: radial basis function for distances
def _rbf(D, D_min=0., D_max=20., D_count=16):
    D_mu = torch.linspace(D_min, D_max, D_count, device=D.device).view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)
    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF

# Internal: compute distances
def _get_distance(X, edge_index):
    """RBF-encoded atom distances within residues and across edges.

    Args:
        X: Tensor of per-residue atom coordinates, indexed as
            ``[residue, atom, xyz]`` with atoms 0..4 read as N, Ca, C, O, R.
        edge_index: Tensor whose first two rows hold the edge endpoints.

    Returns:
        ``(node_dist, edge_dist)``: encodings of ten intra-residue atom pairs
        per node, and of all 5 x 5 atom pairs between the two endpoints of
        every edge, each concatenated along the last dimension.
    """
    # Explicit lookup table; insertion order fixes the edge-pair ordering.
    atoms = {
        "N": X[:, 0],
        "Ca": X[:, 1],
        "C": X[:, 2],
        "O": X[:, 3],
        "R": X[:, 4],
    }

    node_list = ['Ca-N', 'Ca-C', 'Ca-O', 'N-C', 'N-O', 'O-C', 'R-N', 'R-Ca', "R-C", 'R-O']
    node_parts = []
    for pair in node_list:
        first, second = pair.split('-')
        offset = atoms[first] - atoms[second]
        node_parts.append(_rbf(offset.norm(dim=-1)))
    node_dist = torch.cat(node_parts, dim=-1)

    src, dst = edge_index[0], edge_index[1]
    edge_parts = [
        _rbf((atoms[first][src] - atoms[second][dst]).norm(dim=-1))
        for first in atoms
        for second in atoms
    ]
    edge_dist = torch.cat(edge_parts, dim=-1)

    return node_dist, edge_dist

# Internal: direction and orientation features
def _get_direction_orientation(X, edge_index):
    """Direction and orientation features expressed in per-residue local frames.

    A local frame is built for every residue from its N, Ca and C atoms.

    Args:
        X: Tensor of per-residue atom coordinates, indexed as
            ``[residue, atom, xyz]`` with atoms 0..4 read as N, Ca, C, O, R.
        edge_index: Edge endpoints; row 0 is unpacked as ``node_j`` and
            row 1 as ``node_i``.

    Returns:
        ``(node_direction, edge_direction, edge_orientation)``: unit vectors
        from Ca to the other four atoms of the residue; unit vectors between
        the endpoints of each edge, seen from both ends; and the relative
        rotation of the two endpoint frames as a quaternion.
    """
    X_N = X[:, 0]
    X_Ca = X[:, 1]
    X_C = X[:, 2]
    u = F.normalize(X_Ca - X_N, dim=-1)
    v = F.normalize(X_C - X_Ca, dim=-1)
    b = F.normalize(u - v, dim=-1)
    # dim=-1 is given explicitly: without it torch.cross picks the first
    # dimension of size 3, which is the residue axis for a 3-residue input,
    # and the implicit form is deprecated.
    n = F.normalize(torch.cross(u, v, dim=-1), dim=-1)
    local_frame = torch.stack([b, n, torch.cross(b, n, dim=-1)], dim=-1)

    node_j, node_i = edge_index

    # Directions from Ca to N, C, O and R within each residue.
    t = F.normalize(X[:, [0, 2, 3, 4]] - X_Ca.unsqueeze(1), dim=-1)
    node_direction = torch.matmul(t, local_frame).reshape(t.shape[0], -1)

    # Directions to the atoms of the other endpoint, in both directions.
    t = F.normalize(X[node_j] - X_Ca[node_i].unsqueeze(1), dim=-1)
    edge_direction_ji = torch.matmul(t, local_frame[node_i]).reshape(t.shape[0], -1)
    t = F.normalize(X[node_i] - X_Ca[node_j].unsqueeze(1), dim=-1)
    edge_direction_ij = torch.matmul(t, local_frame[node_j]).reshape(t.shape[0], -1)
    edge_direction = torch.cat([edge_direction_ji, edge_direction_ij], dim=-1)

    # Relative rotation between the two endpoint frames.
    r = torch.matmul(local_frame[node_i].transpose(-1, -2), local_frame[node_j])
    edge_orientation = _quaternions(r)

    return node_direction, edge_direction, edge_orientation

# Internal: convert rotation matrix to quaternion
def _quaternions(R):
    diag = torch.diagonal(R, dim1=-2, dim2=-1)
    Rxx, Ryy, Rzz = diag.unbind(-1)
    magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
        Rxx - Ryy - Rzz,
        -Rxx + Ryy - Rzz,
        -Rxx - Ryy + Rzz
    ], -1)))
    _R = lambda i, j: R[:, i, j]
    signs = torch.sign(torch.stack([
        _R(2, 1) - _R(1, 2),
        _R(0, 2) - _R(2, 0),
        _R(1, 0) - _R(0, 1)
    ], -1))
    xyz = signs * magnitudes
    w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
    Q = torch.cat((xyz, w), -1)
    Q = F.normalize(Q, dim=-1)
    return Q

# Dataset class
class GPSite_Dataset(torch.utils.data.Dataset):
    """Dataset that turns PDB, DSSP and embedding files into graph objects.

    Sample IDs are the names of the ``.pdb`` files found in ``pdb_path``.
    For each ID the matching ``<id>.dssp`` and ``<id>.tensor`` files are read
    from ``dssp_path`` and ``transpro_path``.

    Args:
        pdb_path: Directory containing ``<id>.pdb`` files.
        dssp_path: Directory containing ``<id>.dssp`` files.
        transpro_path: Directory containing ``<id>.tensor`` files, each
            expected to hold a tensor with 1024 columns.
        radius: Cut-off passed to ``radius_graph`` on the second atom of
            every residue.
    """

    def __init__(self, pdb_path, dssp_path, transpro_path, radius=15):
        super().__init__()
        self.pdb_path = pdb_path
        self.dssp_path = dssp_path
        self.transpro_path = transpro_path
        self.radius = radius
        self.failed_samples = []  # store failed samples and reasons
        self.IDs = self._get_ids()

    def _get_ids(self):
        """Return the sorted sample IDs found in ``pdb_path``."""
        # Strip only the trailing ".pdb": splitting on the first dot would
        # truncate IDs that contain dots and the files would not be found.
        # Sorting makes the index -> sample mapping deterministic.
        suffix = '.pdb'
        return sorted(
            f[:-len(suffix)] for f in os.listdir(self.pdb_path) if f.endswith(suffix)
        )

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, idx):
        """Return the graph for sample ``idx``, or ``None`` if it fails.

        Failures are recorded in ``failed_samples`` instead of being raised,
        so one bad sample does not stop a whole run.
        """
        try:
            return self._featurize_graph(idx)
        except Exception as e:
            sample_id = self.IDs[idx]
            self.failed_samples.append((sample_id, str(e)))
            print(f"⚠️ Skipped sample {sample_id}, reason: {e}")
            return None

    def _featurize_graph(self, idx):
        """Build the ``Data`` graph (node features, edges, edge features)."""
        name = self.IDs[idx]
        with torch.no_grad():
            # Load PDB file
            with open(os.path.join(self.pdb_path, f"{name}.pdb"), 'r') as pdb_file:
                X = torch.tensor(get_pdb_xyz(pdb_file.readlines()), dtype=torch.float32)

            # Load DSSP file
            dssp_seq, dssp_feat = process_dssp(os.path.join(self.dssp_path, f"{name}.dssp"))
            # Convert through one ndarray; building a tensor from a list of
            # arrays is slow.
            dssp_feat = torch.tensor(np.array(dssp_feat), dtype=torch.float32)

            # Load TransPro file
            transpro_feat = torch.load(os.path.join(self.transpro_path, f"{name}.tensor")).clone().detach()
            if transpro_feat.shape[1] != 1024:
                raise ValueError(f"Expected TransPro feature dimension to be 1024, but got {transpro_feat.shape[1]}.")

            # Compute geometry features
            edge_index = radius_graph(X[:, 1], r=self.radius)
            geo_node_feat, geo_edge_feat = get_geo_feat(X, edge_index)

            # Truncate all node features to the shortest length
            min_nodes = min(dssp_feat.shape[0], geo_node_feat.shape[0], transpro_feat.shape[0])
            dssp_feat = dssp_feat[:min_nodes]
            geo_node_feat = geo_node_feat[:min_nodes]
            transpro_feat = transpro_feat[:min_nodes]

            # Drop edges that touch a truncated node; otherwise edge_index
            # would point past the end of the node feature matrix.
            if min_nodes < X.shape[0]:
                keep = (edge_index < min_nodes).all(dim=0)
                edge_index = edge_index[:, keep]
                geo_edge_feat = geo_edge_feat[keep]

            pre_computed_node_feat = torch.cat([dssp_feat, geo_node_feat, transpro_feat], dim=-1)

            # Normalize node features
            node_mean = pre_computed_node_feat.mean(dim=0, keepdim=True)
            node_std = pre_computed_node_feat.std(dim=0, keepdim=True)
            pre_computed_node_feat = (pre_computed_node_feat - node_mean) / (node_std + 1e-6)

            # Normalize edge features
            edge_features = geo_edge_feat
            edge_mean = edge_features.mean(dim=0, keepdim=True)
            edge_std = edge_features.std(dim=0, keepdim=True)
            edge_features = (edge_features - edge_mean) / (edge_std + 1e-6)

        graph_data = Data(name=name, x=pre_computed_node_feat, edge_index=edge_index, edge_attr=edge_features)
        return graph_data

    def save_graphs(self, output_dir):
        """Featurize every sample and save each graph as ``<name>.pt``.

        Samples that fail are skipped and listed at the end.
        """
        os.makedirs(output_dir, exist_ok=True)

        for idx in range(len(self)):
            graph_data = self[idx]
            if graph_data is None:
                continue  # Skip failed samples

            file_path = os.path.join(output_dir, f"{graph_data.name}.pt")
            torch.save(graph_data, file_path)
            print(f"✅ Saved graph: {file_path}")

        # Print failed samples
        if self.failed_samples:
            print("\n❌ The following samples failed:")
            for sample_id, error in self.failed_samples:
                print(f"- {sample_id}: {error}")

# Example usage
# Hard-coded input/output directories for this run; adjust to the local setup.
pdb_folder = '/home/shliu/odorant/yanzhen/PDB/'
dssp_folder = '/home/shliu/odorant/yanzhen/dssp/'
transpro_folder = '/home/shliu/odorant/yanzhen/prot/ProtTrans/'
output_folder = '/home/shliu/odorant/yanzhen/graph/'

# Build the dataset (lists the .pdb files) and write one graph file per
# sample; both statements run as soon as the cell/module is executed.
dataset = GPSite_Dataset(pdb_folder, dssp_folder, transpro_folder)
dataset.save_graphs(output_folder)
