In [1]:
import numpy as np
from Bio.PDB import PDBParser
from collections import defaultdict

In [3]:
VOXEL_SIZE = 1.0  # Å
GRID_SIZE = 32  # 32x32x32
ATOM_TYPES = ['C', 'N', 'O', 'S']  # 주요 원자 종류 (채널)
CHANNELS = len(ATOM_TYPES)

In [4]:
def get_atom_channel(atom_name):
    for i, t in enumerate(ATOM_TYPES):
        if atom_name.startswith(t):
            return i
        return None

In [13]:
def get_structure_coords(pdb_file):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('mol', pdb_file)
    atoms = []
    for atom in structure.get_atoms():
        pos = atom.get_coord()
        name = atom.element.strip()
        atoms.append((pos, name))
    return atoms

In [81]:
def get_ligand_center(ligand_pdb_path, ligand_resname='UNK'):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('ligand', ligand_pdb_path)
    
    coords = []
    
    for model in structure:
        for chain in model:
            for residue in chain:
                if residue.get_resname() == ligand_resname:
                    for atom in residue:
                        coords.append(atom.coord)

    if len(coords) == 0:
        raise ValueError(f"No ligand atoms found with resname {ligand_resname}")
    
    coords = np.array(coords)
    center = coords.mean(axis=0)
    
    return center

In [15]:
def make_voxel_grid(atom_list, center, grid_size=32, voxel_size=1.0):
    # grid 0으로 초기화
    grid = np.zeros((CHANNELS, grid_size, grid_size, grid_size), dtype=np.float32)
    half = grid_size * voxel_size / 2
    for pos, name in atom_list:
        x, y, z = pos - center + half
        i, j, k = (int(x//voxel_size), int(y//voxel_size), int(z//voxel_size))    #위치 좌표를 grid index로 변환
        ch = get_atom_channel(name)
        if 0 <= i < grid_size and 0 <= j < grid_size and 0 <= k < grid_size and ch is not None:
            grid[ch, i, j, k] = 1.0
    return grid

In [8]:
def make_label_grid(water_list, center, grid_size=32, voxel_size=1.0):
    label = np.zeros((grid_size, grid_size, grid_size), dtype=np.uint8)
    half = grid_size * voxel_size / 2
    for pos, _ in water_list:
        x, y, z = pos - center + half
        i, j, k = (int(x // voxel_size), int(y // voxel_size), int(z // voxel_size))
        if 0 <= i < grid_size and 0 <= j < grid_size and 0 <= k < grid_size:
            label[i, j, k] = 1
    return label

In [9]:
def prepare_sample_with_ligand(prot_pdb, hs_pdb, lig_pdb):
    prot_atoms = get_structure_coords(prot_pdb)
    water_atoms = get_structure_coords(hs_pdb)
    center = get_center_ligand(lig_pdb)

    voxel_X = make_voxel_grid(prot_atoms, center)
    voxel_Y = make_label_grid(water_atoms, center)

    return voxel_X, voxel_Y

In [17]:
prot_file = "/Users/yeonji/Dropbox/myfolder_data/wbp_last/ahr_eq/avg/xiap_ahr_eq_l20_avg.pdb"
hs_file = "/Users/yeonji/Dropbox/myfolder_data/wbp_last/ahr_eq/cc/xiap_ahr_eq_cc.pdb"
lig_file = "/Users/yeonji/Dropbox/myfolder_data/Binding_Site_Reorganization/min_ahr_aligned/ahr_aligned_complex/lig/xiap_lig_min_ahr_aligned.pdb"
X, Y = prepare_sample_with_ligand(prot_file, hs_file, lig_file)
print(X.shape)
print(Y.shape)

(4, 32, 32, 32)
(32, 32, 32)


3. PyTorch Dataset Class

In [19]:
import torch
from torch.utils.data import Dataset

class HydrationDataset(Dataset):
    def __init__(self, prot_list, hs_list, lig_list):
        self.prot_list = prot_list
        self.hs_list = hs_list
        self.lig_list = lig_list

    def __len__(self):
        return len(self.prot_list)

    def __getitem__(self, idx):
        X, Y = prepare_sample_with_ligand(self.prot_list[idx], self.hs_list[idx], self.lig_list[idx])
        X = torch.tensor(X, dtype=torch.float32)
        Y = torch.tensor(Y, dtype=torch.float32).unsqueeze(0)    # (1, D, H, W)

        return X, Y

4. 3D CNN Model

In [29]:
import torch.nn as nn
class HydrationCNN(nn.Module):
    def __init__(self, in_channels=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_channels, 32, 3, padding=1),      # 5 -> 32 channels
            nn.ReLU(),
            nn.MaxPool3d(2),                               # downsampling 32 to 16 (voxel not channel)

            nn.Conv3d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),

            nn.Conv3d(64, 128, 3, padding=1),
            nn.ReLU(),

            nn.ConvTranspose3d(128, 64, 2, stride=2),      # Upsampling (8->16)
            nn.ReLU(),

            nn.ConvTranspose3d(64, 32, 2, stride=2),       # 16->32
            nn.ReLU(),

            nn.Conv3d(32, 1, 1),                           # 1 channel: possibility of HS each voxel
            nn.Sigmoid()                                   # 0-1 possibility
            
        )

    def forward(self, x):
        return self.net(x)

5. Train Loop

In [41]:
from torch.utils.data import DataLoader
device = torch.device("cpu")
def train_model(model, dataloader, epochs=10, lr=1e-3, device="cpu"):
    model = model.to(device)
    criterion = nn.BCELoss()                                 # Binary Cross Entropy, HS or not-> binary
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0                                     # initialize loss as 0 each epoch
        for X, Y in dataloader:
            X = X.to(device)
            Y = Y.to(device)
            
            pred = model(X)
            loss = criterion(pred, Y)
            
            optimizer.zero_grad()                            # initialize prev gradient 
            loss.backward()                                  # calculate gradient for current loss (back propagation)
            optimizer.step()                                 # update model parameters
            
            total_loss += loss.item()
            
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss / len(dataloader):.4f}")

In [46]:
import matplotlib.pyplot as plt
def vis_predictions(model, dataset, index=0):
    model.eval()
    X, Y = dataset[index]
    with torch.no_grad():                                    # not training, grad off
        pred = model(X.unsqueeze(0)).squeeze(0).squeeze(0).numpy()        # add bach dim (unsqueeze), squeeze & squeeze -> 3D
        true = Y.squeeze(0).numpy()
        
    z = pred.shape[2] // 2

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.title("Prediction")
    plt.imshow(pred[:, :, z], cmap='hot')
    plt.subplot(1, 2, 2)
    plt.title("Ground Truth")
    plt.imshow(true[:, :, z], cmap='hot')
    plt.show

In [25]:
import glob
prot_file = sorted(glob.glob("/Users/yeonji/Dropbox/myfolder_data/wbp_last/ahr_eq/avg/*.pdb"))
hs_file = sorted(glob.glob("/Users/yeonji/Dropbox/myfolder_data/wbp_last/ahr_eq/cc/*.pdb"))
lig_file = sorted(glob.glob("/Users/yeonji/Dropbox/myfolder_data/Binding_Site_Reorganization/min_ahr_aligned/ahr_aligned_complex/lig/*.pdb"))

In [71]:
dataset = HydrationDataset(prot_file, hs_file, lig_file)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [72]:
model = HydrationCNN(in_channels=4)

In [73]:
train_model(model, dataloader)

Epoch 1/10 | Loss: 0.2219
Epoch 2/10 | Loss: 0.0516
Epoch 3/10 | Loss: 0.0540
Epoch 4/10 | Loss: 0.0540
Epoch 5/10 | Loss: 0.0539
Epoch 6/10 | Loss: 0.0536
Epoch 7/10 | Loss: 0.0533
Epoch 8/10 | Loss: 0.0530
Epoch 9/10 | Loss: 0.0526
Epoch 10/10 | Loss: 0.0523


In [90]:
def save_prediction_as_pdb_real(pred, center, file_name, threshold=0.5, spacing=1.0):

    output_path = "/Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/"
    grid_size = np.array(pred.shape)  # (32, 32, 32)
    half_grid = grid_size / 2.0

    coords = np.argwhere(pred > threshold)  # (i, j, k) 인덱스 리스트

    with open(output_path + file_name, 'w') as f:
        for i, (x_idx, y_idx, z_idx) in enumerate(coords):
            # voxel index를 real-world coordinate로 변환
            fx = center[0] + (x_idx - half_grid[0]) * spacing
            fy = center[1] + (y_idx - half_grid[1]) * spacing
            fz = center[2] + (z_idx - half_grid[2]) * spacing

            f.write(
                f"HETATM{i:5d}  O   HOH A{i%10000:4d}    {fx:8.3f}{fy:8.3f}{fz:8.3f}  1.00 20.00           O\n"
            )
    print(f"PDB saved to {output_path} with real-world coordinates.")

In [89]:
for i, prot in enumerate(prot_file):
    prot_name = prot.split("/")[-1].split("_")[0]
    file_name = prot_name + "_hs_pred.pdb"
    for lig in lig_file:
        if prot_name in lig:
            center = get_ligand_center(lig)
            X, Y = dataset[i]
            
            model.eval()
            with torch.no_grad():
                pred = model(X.unsqueeze(0)).squeeze(0).squeeze(0).numpy()
    save_prediction_as_pdb_real(pred, center, file_name, threshold=0.1)

PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
PDB saved to /Users/yeonji/Desktop/ComputerProject/CNN_HS_Predict/3d_CNN_HS_preds/threshold0.5/ with real-world coordinates.
