In [1]:
import math
import json
import os
import pickle
import copy
import einops

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.container import ModuleList
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import (GATConv,
                                SAGPooling,
                                LayerNorm,
                                global_mean_pool,
                                max_pool_neighbor_x,
                                global_add_pool)
import pandas as pd
import numpy as np

import pytorch_lightning as pl
# from params import N_CHEM_NODE_FEAT, N_CHEM_EDGE_FEAT, N_PROT_EDGE_FEAT, N_PROT_NODE_FEAT
from sklearn.utils import shuffle
from dataset import pygDataSet

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from metrics import get_metrics_reg
from dataset import testFpDataModule

import json
import pickle

Using backend: pytorch


## Model

In [2]:
class SSI_DDI_Block(nn.Module):
    """One GAT message-passing layer with a residual projection and an attention readout.

    Args:
        n_heads: number of GATConv attention heads.
        in_features: width of the incoming node features ``data.x``.
        head_out_feats: output width per head; the block emits ``n_heads * head_out_feats``
            features per node.
        ifConv1d: if True, apply a 1-D convolution along the node axis before GATConv.
        window_size: Conv1d kernel size (padding ``window_size // 2`` keeps the node count
            unchanged for odd sizes).
        dropout_1: attention dropout passed to GATConv.
        dropout_2: dropout applied to the block's node-level output.
        layernorm: use torch_geometric ``LayerNorm`` if True, otherwise torch_geometric
            ``BatchNorm``.
    """

    def __init__(self, n_heads, in_features, head_out_feats, ifConv1d = False, window_size = 5, dropout_1 = 0.4, dropout_2 = 0.1, layernorm = True):
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats  # per-head width, not the total block width

        # Residual branch: projects the input straight to the block's output width.
        self.linear = nn.Linear(in_features, n_heads * head_out_feats)
        self.ifConv1d = ifConv1d
        if self.ifConv1d:
            self.conv1d = nn.Conv1d(in_features, in_features, kernel_size=window_size, padding=window_size//2)
        self.conv = GATConv(in_features, head_out_feats, n_heads, dropout = dropout_1)
        # min_score=-1 keeps SAGPooling in score-threshold mode rather than top-k ratio mode.
        self.readout = SAGPooling(n_heads * head_out_feats, min_score=-1)
        if layernorm:
            self.norm = LayerNorm(n_heads * head_out_feats)
        else:
            # BatchNorm is not among the notebook's top-level imports; import it here so
            # layernorm=False works instead of raising NameError.
            from torch_geometric.nn import BatchNorm
            self.norm = BatchNorm(n_heads * head_out_feats)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_2)

    def forward(self, data):
        """Update ``data.x`` in place and return ``(data, graph_emb, attention_weights, att_scores)``.

        ``graph_emb`` is the per-graph sum (global_add_pool) of the node features kept by
        SAGPooling; ``attention_weights`` is GATConv's ``(edge_index, alpha)`` output and
        ``att_scores`` are SAGPooling's node scores.
        """
        res_x = self.linear(data.x)
        if self.ifConv1d:
            # Conv1d expects (N, C, L): the node axis is used as the sequence axis.
            # NOTE(review): every node of the mini-batch forms one sequence, so windows at a
            # graph boundary also see nodes of the neighbouring graph — confirm this is intended.
            data.x = self.conv1d(data.x.t().unsqueeze(0)).squeeze(0).t()
        data.x, attention_weights = self.conv(data.x, data.edge_index, return_attention_weights=True)
        att_x, _, _, att_batch, _, att_scores = self.readout(data.x, data.edge_index, batch=data.batch)
        global_graph_emb = global_add_pool(att_x, att_batch)

        # NOTE(review): the norm is called without `batch`; torch_geometric LayerNorm in its
        # default "graph" mode then normalises over all nodes of the mini-batch — confirm.
        data.x = self.relu(self.norm(data.x + res_x))
        data.x = self.dropout(data.x)
        return data, global_graph_emb, attention_weights, att_scores

In [3]:
class CoAttentionLayer(nn.Module):
    """Scores every (v1 level, v2 level, v3 level) combination with one shared MLP.

    Args:
        n_features: width ``h`` of each input vector; the MLP consumes ``3 * h`` features.
        n1: channel count of the BatchNorm3d layers. BatchNorm3d normalises dimension 1 of
            its 5-D input, which here is the ``c1`` axis, so this must equal ``v1.shape[1]``.
        n2, n3: accepted for interface compatibility; not used.
        dropout: dropout probability after each hidden layer.
    """

    def __init__(self, n_features, n1, n2, n3, dropout=0.5):
        super().__init__()
        self.h_dim = n_features

        # Three hidden stages (Linear -> BatchNorm3d -> ReLU -> Dropout) and a scalar head.
        # Layers are created in the same order as a hand-written Sequential, so parameter
        # names (mlp.0, mlp.1, ...) and initialisation order are preserved.
        widths = (self.h_dim * 3, 1024, 512, 256)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm3d(n1),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
        layers.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, v1, v2, v3):
        """Map inputs (b, c1, h), (b, c2, h), (b, c3, h) to scores of shape (b, c1, c2, c3)."""
        n1, n2, n3 = v1.shape[1], v2.shape[1], v3.shape[1]
        # Tile each input across the other two level axes, then join along the feature axis.
        grid = torch.cat(
            [
                einops.repeat(v1, 'b c1 h -> b c1 c2 c3 h', c2=n2, c3=n3),
                einops.repeat(v2, 'b c2 h -> b c1 c2 c3 h', c1=n1, c3=n3),
                einops.repeat(v3, 'b c3 h -> b c1 c2 c3 h', c1=n1, c2=n2),
            ],
            dim=-1,
        )
        return self.mlp(grid).squeeze(-1)

In [4]:
class DTAProtGraphChemGraph(torch.nn.Module):
    """Affinity regressor fusing a chemical graph, a protein graph and a fingerprint vector.

    Required keys in ``param_dict``:
        chem_in_features, prot_in_features: node-feature widths of the input graphs.
        chem_heads_out_feat_params, chem_blocks_params: per-block width per head and head
            count for the chemical SSI_DDI_Block stack (must have equal length).
        prot_heads_out_feat_params, prot_blocks_params, prot_windows_params: the same for
            the protein stack, plus each block's Conv1d kernel size (equal lengths).
        dropout_1, dropout_2: forwarded to every SSI_DDI_Block.
    Optional keys:
        dropout_3 (default 0.5, CoAttentionLayer's own default): co-attention MLP dropout.
        fp_dim (default 2048): width of the fingerprint input ``chem_fp``.

    Raises:
        ValueError: if the per-block lists are inconsistent or empty, or if the chemical
            and protein stacks end at different widths (the co-attention layer needs one
            shared width).
    """

    def __init__(self, **param_dict):
        super().__init__()

        chem_head_feats = param_dict["chem_heads_out_feat_params"]
        chem_n_heads = param_dict["chem_blocks_params"]
        prot_head_feats = param_dict["prot_heads_out_feat_params"]
        prot_n_heads = param_dict["prot_blocks_params"]
        prot_windows = param_dict["prot_windows_params"]
        # zip() below would silently drop trailing entries and desynchronise the number of
        # blocks from the co-attention / `rel` shapes, so validate the lengths up front.
        if len(chem_head_feats) != len(chem_n_heads):
            raise ValueError(
                "chem_heads_out_feat_params and chem_blocks_params must have the same length "
                f"(got {len(chem_head_feats)} and {len(chem_n_heads)})"
            )
        if not (len(prot_head_feats) == len(prot_n_heads) == len(prot_windows)):
            raise ValueError(
                "prot_heads_out_feat_params, prot_blocks_params and prot_windows_params must "
                f"have the same length (got {len(prot_head_feats)}, {len(prot_n_heads)}, "
                f"{len(prot_windows)})"
            )
        n_chem_blocks = len(chem_n_heads)
        n_prot_blocks = len(prot_n_heads)
        if n_chem_blocks == 0 or n_prot_blocks == 0:
            raise ValueError("both the chemical and the protein stack need at least one block")

        # NOTE: module creation order is kept as-is so parameter order, state_dict keys and
        # RNG consumption during initialisation match previously saved checkpoints.
        self.chem_initial_norm = LayerNorm(param_dict["chem_in_features"])
        self.prot_initial_norm = LayerNorm(param_dict["prot_in_features"])

        self.chem_blocks = ModuleList()

        chem_in_features = param_dict["chem_in_features"]
        prot_in_features = param_dict["prot_in_features"]

        for i, (head_out_feats, n_heads) in enumerate(zip(chem_head_feats, chem_n_heads)):
            block = SSI_DDI_Block(n_heads, chem_in_features, head_out_feats,
                                  dropout_1=param_dict["dropout_1"], dropout_2=param_dict["dropout_2"])
            # NOTE(review): this registers an extra alias "block{i}". The protein loop below
            # reuses the same names, so block0..block{n_prot-1} end up aliasing protein blocks.
            # parameters() de-duplicates shared modules, but state_dict stores duplicate
            # entries. Kept so existing checkpoints still load with strict key matching.
            self.add_module(f"block{i}", block)
            self.chem_blocks.append(block)
            chem_in_features = head_out_feats * n_heads

        # Projects the fingerprint to the chemical branch's final width.
        self.fp_linear = nn.Sequential(
            nn.Linear(param_dict.get("fp_dim", 2048), chem_in_features),
            nn.BatchNorm1d(chem_in_features),
            nn.ReLU()
        )

        self.prot_blocks = ModuleList()
        for i, (head_out_feats, n_heads, window_size) in enumerate(zip(prot_head_feats, prot_n_heads, prot_windows)):
            block = SSI_DDI_Block(n_heads, prot_in_features, head_out_feats, True, window_size,
                                  dropout_1=param_dict["dropout_1"], dropout_2=param_dict["dropout_2"])
            self.add_module(f"block{i}", block)  # see NOTE in the chemical loop
            self.prot_blocks.append(block)
            prot_in_features = head_out_feats * n_heads

        # CoAttentionLayer concatenates chem, prot and fingerprint embeddings and expects
        # 3 * h_dim features, so both branches must end at the same width.
        if chem_in_features != prot_in_features:
            raise ValueError(
                "final chemical and protein block widths must match for co-attention "
                f"(got {chem_in_features} and {prot_in_features})"
            )

        self.co_attention = CoAttentionLayer(prot_in_features, n_chem_blocks, n_prot_blocks, 1,
                                             dropout=param_dict.get("dropout_3", 0.5))

        # Learnable weight per (chem block, prot block) pair, initialised uniformly so the
        # weights sum to 1.
        self.rel = nn.Parameter(torch.ones(n_chem_blocks, n_prot_blocks, 1) / (n_chem_blocks * n_prot_blocks))

    def forward(self, chem_fp, chem_graph, prot_graph):
        """Return one affinity prediction per graph in the batch.

        ``chem_fp`` is assumed to be (batch, fp_dim) — TODO confirm against the data module.
        NOTE: ``chem_graph.x`` and ``prot_graph.x`` are overwritten in place, here and
        inside every block.
        """
        # NOTE(review): torch_geometric LayerNorm is called without `batch`, so statistics
        # are taken over the whole mini-batch — confirm this is intended.
        chem_graph.x = self.chem_initial_norm(chem_graph.x)
        prot_graph.x = self.prot_initial_norm(prot_graph.x)

        # (batch, 1, d): the fingerprint acts as a single "level" in the co-attention grid.
        repr_fp = self.fp_linear(chem_fp).unsqueeze(dim = 1)

        repr_chem = []
        for block in self.chem_blocks:
            chem_graph, r_chem, _, _ = block(chem_graph)
            repr_chem.append(r_chem)
        repr_chem = torch.stack(repr_chem, dim=-2)  # (num_graphs, n_chem_blocks, d)

        repr_prot = []
        for block in self.prot_blocks:
            prot_graph, r_prot, _, _ = block(prot_graph)
            repr_prot.append(r_prot)
        repr_prot = torch.stack(repr_prot, dim=-2)  # (num_graphs, n_prot_blocks, d)

        # (batch, n_chem_blocks, n_prot_blocks, 1) scores, weighted by `rel` and summed.
        fusion = self.co_attention(repr_chem, repr_prot, repr_fp)
        y = (fusion * self.rel).sum(dim=(-1, -2, -3))

        return y


In [5]:

# res = []
class testModel(pl.LightningModule):
    """Lightning wrapper that trains and evaluates DTAProtGraphChemGraph as a regressor.

    All keyword arguments are saved as hyperparameters and forwarded to
    DTAProtGraphChemGraph. This class additionally reads:
        criterion: loss module, called as ``criterion(prediction, target)``.
        lr: Adam learning rate.
        batch_size: kept as ``self.batch_size`` (the configured loader batch size).

    Batches are assumed to be ``(y, chem_fp, ligand_graph, protein_graph)`` tuples, as
    unpacked below; presumably produced by testFpDataModule — confirm.
    """

    def __init__(self, **param_dict):
        super().__init__()
        self.save_hyperparameters(param_dict)

        self.model = DTAProtGraphChemGraph(**param_dict)
        self.criterion = param_dict["criterion"]
        self.lr = param_dict["lr"]
        self.batch_size = param_dict["batch_size"]
        # Per-sample (prediction - target) residuals from the most recent test run.
        self.test_residuals = []

    def forward(self, chem_fp, ligand_graph, protein_graph):
        """Return the model's predictions for one batch."""
        return self.model(chem_fp, ligand_graph, protein_graph)

    def _shared_step(self, batch):
        """Unpack a batch, predict, and return ``(y, pred_y, np_y, np_pred_y)``."""
        y, chem_fp, ligand_graph, protein_graph = batch
        pred_y = self(chem_fp, ligand_graph, protein_graph)
        # Detached CPU/NumPy copies for the metric helpers.
        return y, pred_y, y.detach().cpu().numpy(), pred_y.detach().cpu().numpy()

    def training_step(self, batch, batch_idx):
        y, pred_y, np_y, np_pred_y = self._shared_step(batch)
        # PyTorch losses take (input, target); identical to the old order for MSELoss.
        # NOTE(review): if y and pred_y differ in shape (e.g. (B, 1) vs (B,)), MSELoss
        # broadcasts with only a warning — confirm the data module yields 1-D targets.
        loss = self.criterion(pred_y, y)

        metrics = get_metrics_reg(np_y, np_pred_y, "train")
        metrics["loss"] = loss

        # Pass the real sample count: Lightning may not infer it from graph batches, and the
        # final batch can be smaller than the configured batch_size.
        self.log_dict(metrics, batch_size=len(y))

        return metrics

    def validation_step(self, batch, batch_idx):
        y, pred_y, np_y, np_pred_y = self._shared_step(batch)
        loss = self.criterion(pred_y, y)

        # NOTE(review): rm2 and CI are computed per batch and then averaged over batches by
        # Lightning, which is not the same as computing them over the whole validation set.
        metrics = get_metrics_reg(np_y, np_pred_y, "valid", with_rm2=True, with_ci=True)
        metrics["valid_loss"] = loss

        self.log_dict(metrics, prog_bar =True, batch_size=len(y))

        return metrics

    def on_test_epoch_start(self):
        # Start each test run with an empty residual list.
        self.test_residuals = []

    def test_step(self, batch, batch_idx):
        """Collect per-sample residuals (prediction - target); no metrics are logged."""
        _, _, np_y, np_pred_y = self._shared_step(batch)
        residuals = (np_pred_y - np_y).tolist()
        self.test_residuals.extend(residuals)

        # Backward compatibility: earlier versions appended to a notebook-global `res` whose
        # initializer was commented out, so a fresh kernel raised NameError. Keep publishing
        # it, creating it on first use.
        global res
        res = globals().get("res", []) + residuals

    def configure_optimizers(self):
        """Plain Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## Parameters (参数)

In [6]:
# Hyperparameters shared by testModel / DTAProtGraphChemGraph and testFpDataModule.
# Each graph branch ends at width heads_out_feat * n_heads of its last block (32 * 8 = 256
# here). The chemical and protein branches must end at the same width because their
# per-block embeddings are concatenated in the co-attention layer.
param_dict = {
    # Node-feature widths of the input chemical / protein graphs.
    "chem_in_features": 23,
    "prot_in_features": 41,

    # Not read by the model classes in this notebook; presumably used by the data module
    # or kept only as a logged hyperparameter — confirm.
    "hidden_dim": 256,

    # Chemical branch: per-block output width per GAT head, and number of heads.
    "chem_heads_out_feat_params": [32, 32, 32, 32, 32, 32],
    "chem_blocks_params": [8, 8, 8, 8, 8, 8],

    # dropout_1: GATConv attention dropout; dropout_2: dropout on each block's output;
    # dropout_3: dropout inside the co-attention MLP (read by DTAProtGraphChemGraph).
    # The last entry was previously a second "dropout_2" key, which silently overrode 0.1
    # with 0.5 and left "dropout_3" undefined.
    "dropout_1": 0.4,
    "dropout_2": 0.1,
    "dropout_3": 0.5,

    # Protein branch: as above, plus the Conv1d kernel size applied before each block.
    "prot_heads_out_feat_params": [32, 32, 32, 32],
    "prot_blocks_params": [8, 8, 8, 8],
    "prot_windows_params": [7, 7, 7, 7],

    "batch_size": 512,
    "lr": 5e-4,
    "dataset_name": "kiba_cold_drug",
    "criterion": nn.MSELoss(),
    "model_name": "ssr-dta",
}


In [7]:
# Keep the single best checkpoint by validation MSE, plus `last.ckpt`, under
# ./lightning_logs/checkpoints. The monitored names must match keys logged in
# testModel.validation_step (presumably "valid_mse" / "valid_ci" from
# get_metrics_reg(..., "valid", with_ci=True) — confirm).
dirpath = os.path.join(os.getcwd(), 'lightning_logs', 'checkpoints')
checkpoint_callback = ModelCheckpoint(
    monitor='valid_mse',
    dirpath=dirpath,
    # Lightning inserts the metric names itself, producing e.g.
    # "epoch=012-valid_mse=0.1234-valid_ci=0.8765.ckpt" (no stray leading or doubled dashes).
    filename='{epoch:03d}-{valid_mse:.4f}-{valid_ci:.4f}',
    save_top_k=1,
    mode='min',
    save_last=True,
)

In [8]:
# Seed Python, NumPy and torch (and dataloader workers) so weight initialisation and
# shuffling are reproducible across Restart & Run All.
pl.seed_everything(42, workers=True)

# Pass the checkpoint callback at construction time. This way Lightning registers it as
# the trainer's ModelCheckpoint instead of adding its own default one, and orders it after
# the other callbacks. Appending to `trainer.callbacks` afterwards left the default
# checkpoint first, so `trainer.checkpoint_callback` pointed at that default.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=2000,
    check_val_every_n_epoch=1,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Build the LightningModule and the data module from the same hyperparameters, then train.
model = testModel(**param_dict)
dm = testFpDataModule(**param_dict)

trainer.fit(model, datamodule=dm)

  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                  | Params
----------------------------------------------------
0 | model     | DTAProtGraphChemGraph | 4.5 M 
1 | criterion | MSELoss               | 0     
----------------------------------------------------
4.5 M     Trainable params
0         Non-trainable params
4.5 M     Total params
17.839    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

### 