In [3]:
# GCN model
# 환경설정을 PIP 기반으로 간략하게 변경하였음

import os
import sys
import torch

# Check the compatibility of the python and torch version (Comments are wriiten @2024.08.30)
print("Python version: {}".format(sys.version))                    # 3.10.12
print("PyTorch version:{}".format(torch.__version__))              # 2.4.0+cu124
print("cuda version:   {}".format(torch.version.cuda))             # 12.1
print("cudnn version:  {}".format(torch.backends.cudnn.version())) # 8700

# Install required python libraries
!pip install -q pyg_lib -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch_scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch_sparse  -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch_cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-geometric
!pip install -q e3nn vapeplot
!pip install rdkit

# Check the compatibilites of the torch-geometric
import torch_geometric as pyg
print("torch_geometric version:{}".format(pyg.__version__))        # 2.4.0

Python version: 3.9.23 (main, Jun  5 2025, 13:40:20) 
[GCC 11.2.0]
PyTorch version:2.8.0+cu128
cuda version:   12.8
cudnn version:  91002
[33m  DEPRECATION: Building 'vapeplot' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'vapeplot'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0mCollecting rdkit
  Downloading rdkit-2025.3.5-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Downloading rdkit-2025.3.5-cp39-cp39-manylinux_2_28_x86_64.whl (36.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.3/36.3 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-202

In [5]:
import os
import numpy as np
import pandas as pd
import json,pickle
import networkx as nx
from math import sqrt
from random import shuffle
from collections import OrderedDict
from scipy import stats
from IPython.display import SVG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU

In [7]:
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import MolFromSmiles
from torch_geometric import data as DATA
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_max_pool as gmp
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_add_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

# 시각화 라이브러리
from matplotlib import pyplot as plt
import seaborn as sns
%matplotlib inline

#### 원자 특성 인코딩

In [8]:
def feature_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set`` (strict variant).

    Args:
        x: The value to encode.
        allowable_set: Sequence of permitted values; its order fixes the
            position of each flag in the result.

    Returns:
        A list of booleans, one per entry of ``allowable_set``, where an
        entry is True when it compares equal to ``x``.

    Raises:
        Exception: If ``x`` is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input{0} not allowed in set {1}:".format(x, allowable_set))

In [9]:
def feature_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set`` with a fallback bucket.

    A value that is not in ``allowable_set`` is treated as if it were the
    last element of the set, so the call never fails on unseen values.

    Args:
        x: The value to encode.
        allowable_set: Indexable sequence of permitted values; its last
            element is the fallback category.

    Returns:
        A list of booleans, one per entry of ``allowable_set``.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]

In [10]:
# Element symbols with a dedicated one-hot slot; the final 'Unknown' entry is the
# fallback bucket that feature_encoding_unk uses for any other symbol (44 entries).
_ATOM_SYMBOLS = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']

# Buckets 0..10 shared by the degree, hydrogen-count and implicit-valence encodings;
# values above 10 fall into the last bucket (11 entries).
_COUNT_BUCKETS = list(range(11))


def atom_features(atom):
    """Build the boolean feature vector of a single atom.

    The vector concatenates, in this order:
        * one-hot of the element symbol      (44 entries)
        * one-hot of the atom degree         (11 entries)
        * one-hot of the total number of Hs  (11 entries)
        * one-hot of the implicit valence    (11 entries)
        * aromaticity flag                   (1 entry)

    That is 78 entries, which equals the default ``num_features_xd`` of the
    ``GCN`` model. The degree used to be encoded twice (once with the strict
    ``feature_encoding``), which produced 89 entries and raised for degrees
    above 10; it is now encoded once with the fallback encoder.

    Args:
        atom: An object exposing ``GetSymbol``, ``GetDegree``,
            ``GetTotalNumHs``, ``GetImplicitValence`` and ``GetIsAromatic``
            (called below); in this notebook it is an atom of an RDKit molecule.

    Returns:
        A numpy array built from the concatenated boolean flags.
    """
    return np.array(
        feature_encoding_unk(atom.GetSymbol(), _ATOM_SYMBOLS)
        + feature_encoding_unk(atom.GetDegree(), _COUNT_BUCKETS)
        + feature_encoding_unk(atom.GetTotalNumHs(), _COUNT_BUCKETS)
        + feature_encoding_unk(atom.GetImplicitValence(), _COUNT_BUCKETS)
        + [atom.GetIsAromatic()]
    )

#### SMILES to Graph

In [None]:
def smiles_to_graph(smiles):
    """Convert a SMILES string into the pieces of a molecular graph.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        A tuple ``(c_size, features, edge_index)`` where
        ``c_size`` is the number of atoms,
        ``features`` is a list with one normalized feature vector per atom, and
        ``edge_index`` is a list of ``[source, target]`` atom-index pairs in
        which every bond appears once per direction.

    Raises:
        ValueError: If the SMILES string cannot be parsed into a molecule.
    """
    # String -> molecule. MolFromSmiles returns None when parsing fails, so fail
    # here with a clear message instead of an AttributeError further down.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string: {!r}".format(smiles))

    # Number of atoms
    c_size = mol.GetNumAtoms()

    # Per-atom features, scaled so that the entries of each vector sum to 1
    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        features.append(feature / sum(feature))

    # Bonds as [begin atom index, end atom index] pairs
    edges = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()]

    # Use networkx to turn the undirected bond list into a directed graph, so each
    # bond yields both (a, b) and (b, a); the edge order is networkx's iteration order.
    g = nx.Graph(edges).to_directed()
    edge_index = [[e1, e2] for e1, e2 in g.edges]

    return c_size, features, edge_index

#### Protein Representation

In [12]:
# Alphabet (25 letters, 'J' is not included) used to encode the target protein sequence
seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ"

# Map every letter to a 1-based integer code.
# The value must be a plain int: the previous `{i+1}` built a one-element set per
# letter, which cannot be stored into a numeric array.
seq_dict = {v: i + 1 for i, v in enumerate(seq_voc)}
seq_dict_len = len(seq_dict)

# Fixed sequence length used for truncation / padding
max_seq_len = 1000

In [13]:
def seq_cat(prot):
    """Encode a protein sequence as a fixed-length numeric vector.

    The sequence is cut to at most ``max_seq_len`` symbols, every symbol is
    looked up in ``seq_dict``, and the remaining positions stay 0.

    Args:
        prot: Sliceable sequence of symbols that are keys of ``seq_dict``.

    Returns:
        A numpy array of length ``max_seq_len``.
    """
    # Start from an all-zero vector; untouched positions act as padding
    encoded = np.zeros(max_seq_len)
    truncated = prot[:max_seq_len]
    for position in range(len(truncated)):
        encoded[position] = seq_dict[truncated[position]]
    return encoded

In [None]:
# # datasets
# all_prots = []
# datasets = ['kiba']

In [15]:
class ProteinESMLinear(nn.Module):
    """Two-layer MLP that projects a protein embedding to a smaller vector.

    Layout: Linear(in_dim -> 1024) -> ReLU -> Dropout -> Linear(1024 -> out).

    Args:
        in_dim: Size of the incoming embedding (default 1280).
        out: Size of the projected vector (default 256).
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, in_dim=1280, out=256, dropout=0.2):
        super().__init__()
        hidden_dim = 1024
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, esm_repr):
        """Project ``esm_repr``; the original note asks for a mean-pooled [B, 1280] input."""
        projected = self.net(esm_repr)
        return projected

In [None]:
# class DTIDataset(InMemoryDataset):
#     def __init__(self, csv_path, esm, transform = None, pre_transform=None):
#         super().__init__('.', transform, pre_transform)
#         self.df = pd.read_csv(csv_path)
#         if isinstance(esm, str):
#             if esm.endswith('.npy'):
#                 self.esm = np.load(esm, allow_pickle=True).item()

In [None]:
class GCN(torch.nn.Module):
    """Drug-target model: GCN encoder for the drug graph plus an MLP over a protein embedding.

    The drug graph goes through three GCNConv layers, global max pooling and two
    linear layers. The protein embedding read from ``data.protein_esm`` goes
    through a two-layer MLP. Both vectors (``output_dim`` each) are concatenated
    and passed through the fusion head that produces ``n_output`` values.

    Args:
        n_output: Number of values predicted per sample.
        n_filters: Not used by this implementation; kept for interface compatibility.
        embed_dim: Not used by this implementation; kept for interface compatibility.
        num_features_xd: Size of the per-atom feature vector.
        num_features_xt: Not used by this implementation; kept for interface compatibility.
        output_dim: Size of the drug and of the protein representation.
        dropout: Dropout probability shared by all dropout layers.
        esm_in_dim: Size of one protein embedding vector.
    """

    def __init__(self, n_output = 1, n_filters=32, embed_dim=128, num_features_xd=78, num_features_xt = 25, output_dim=128, dropout=0.2, esm_in_dim=1280):
        super().__init__()
        self.n_output = n_output  # number of values the model outputs per sample
        self.esm_in_dim = esm_in_dim  # needed in forward() to restore the [B, esm_in_dim] layout

        # Drug representation: the channel width grows 1x -> 2x -> 4x
        self.conv1 = GCNConv(num_features_xd, num_features_xd)
        self.conv2 = GCNConv(num_features_xd, num_features_xd*2)
        self.conv3 = GCNConv(num_features_xd*2, num_features_xd*4)

        # Fully connected layers: pooled graph vector -> 1024 -> output_dim
        self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024)
        self.fc_g2 = torch.nn.Linear(1024, output_dim)

        # Activation function
        self.relu = nn.ReLU()
        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Protein representation (protein embedding -> MLP projection)
        self.protein_proj = nn.Sequential(
            nn.Linear(esm_in_dim, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, output_dim)
        )

        # Fusion head over the concatenated drug + protein representation
        self.fc1 = nn.Linear(2*output_dim, 1024)
        self.fc2 = nn.Linear(1024,512)

        self.out = nn.Linear(512, self.n_output)

    def forward(self, data):
        """Run the model on one batch.

        Args:
            data: Batch object providing ``x``, ``edge_index``, ``batch`` and
                ``protein_esm``.

        Returns:
            Tensor with ``n_output`` columns and one row per graph in the batch.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # GCN layers, each followed by ReLU
        x = self.relu(self.conv1(x, edge_index))
        x = self.relu(self.conv2(x, edge_index))
        x = self.relu(self.conv3(x, edge_index))

        # Global max pooling: one vector per graph
        x = gmp(x, batch)

        x = self.relu(self.fc_g1(x))
        x = self.dropout(x)
        x = self.fc_g2(x)
        x = self.dropout(x)

        # Restore the [num_graphs, esm_in_dim] layout. Embeddings stored as a 1-D
        # vector per graph are concatenated along dim 0 by PyG batching into one flat
        # tensor; for input that already is [B, esm_in_dim] the reshape is a no-op.
        xt = data.protein_esm.float().reshape(-1, self.esm_in_dim)
        xt = self.protein_proj(xt)

        # Join drug and protein representation along the feature axis
        xc = torch.cat((x, xt), 1)

        # Fusion head down to the final output
        xc = self.fc1(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        xc = self.fc2(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        out = self.out(xc)
        return out

In [None]:
# def make_data_item(graph_x, edge_index, y, esm_vec):
#     d = Data(
#         x=torch.tensor(graph_x, dtype=torch.float),
#         edge_index=torch.tensor(edge_index, dtype=torch.long),
#         y=torch.tensor(y, dtype=torch.float)
#     )
#     d.protein_esm = torch.tensor(esm_vec, dtype=torch.float)
#     return d

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch.

    Relies on two module-level names that must be defined before the call:
    ``loss_fn`` (the loss callable) and ``LOG_INTERVAL`` (log every N batches).

    Args:
        model: Module called as ``model(data)``.
        device: Device the batch and the target are moved to.
        train_loader: Iterable of batches exposing ``.dataset`` and ``len()``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the log line.
    """
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    print('Training on {} samples...'.format(n_samples))
    model.train()

    samples_seen = 0  # samples processed before the current batch
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        target = data.y.view(-1, 1).float().to(device)

        optimizer.zero_grad()
        output = model(data)

        loss = loss_fn(output, target)

        loss.backward()
        optimizer.step()

        if batch_idx % LOG_INTERVAL == 0:
            # Progress is counted in samples (rows of the target); len(data.x) would
            # count rows of data.x instead. '\t' replaces the former literal '|t'.
            print('Train epoch: {}[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,
                                                                          samples_seen,
                                                                          n_samples,
                                                                          100. * batch_idx / n_batches,
                                                                          loss.item()))
        samples_seen += target.size(0)

In [None]:
def predicting(model, device, loader):
    """Run ``model`` over ``loader`` without gradients and collect labels and predictions.

    Args:
        model: Module called as ``model(data)``; it is switched to eval mode.
        device: Device each batch is moved to before the forward pass.
        loader: Iterable of batches exposing ``.dataset``; every batch provides
            ``.to(device)`` and a label tensor ``.y``.

    Returns:
        A tuple ``(labels, predictions)`` of flat numpy arrays. Both are empty
        when the loader yields no batch.
    """
    model.eval()
    # Collect per-batch chunks and concatenate once at the end. The label
    # accumulator was previously never initialised, so the first batch failed.
    pred_chunks = []
    label_chunks = []

    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred_chunks.append(output.cpu())
            label_chunks.append(data.y.view(-1, 1).cpu())

    total_preds = torch.cat(pred_chunks, 0) if pred_chunks else torch.Tensor()
    total_labels = torch.cat(label_chunks, 0) if label_chunks else torch.Tensor()
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()