### 导入一些包以及系统路径

In [1]:
import time

import numpy as np

import random

from torch.utils.data import DataLoader

from tqdm import tqdm
import os

#### Local import
import sys
### 这样的操作通常用于解决模块导入的问题，特别是当你的代码和模块位于不同的目录中时
sys.path.append('../../src')
from utils.dataset import PDB_complex_training

### 定义训练集和测试集并且定义一些常数

In [18]:

DIM=16

DATA_DIR = os.getcwd()+'/data_preparation/'
print(DATA_DIR)

# 目前训练只采用训练集
TRAIN_LIST_FILE = '../data/lists/final_training.txt'
VAL_LIST_FILE = '../data/lists/validation.txt'

PATIENCE=10 # number of consequitive times when model don't improve before early stopping
SEED_ID = 7272
BATCH_SIZE=1
# 最大的训练次数
MAX_EPOCH = 200

# 加载piston的训练好的数据来最为我们模型训练的初始化参数
MODEL_NAME=f'PIsToN_multiAttn_contrast'
MODEL_DIR='./saved_models'
IMG_SIZE=64
##创建自己的参数模型
ANTIGEN_MODEL = f'antigen'
ANTIBODY_MODEL = f'antibody'

/mnt/Data1/23wxy/piston/training_example/data_preparation/


### 定义config配置文件

In [3]:
import os

config = {}

config['dirs'] = {}
config['dirs']['data_prepare'] = DATA_DIR
config['dirs']['grid'] = config['dirs']['data_prepare'] + '07-grid'
config['dirs']['docked'] = config['dirs']['data_prepare'] + 'docked/'
config['dirs']['tmp'] = '/aul/homes/vsteb002/tmp'

config['ppi_const'] = {}
config['ppi_const']['patch_r'] = 32 # 16

os.environ["TMP"] = config['dirs']['tmp']
os.environ["TMPDIR"] = config['dirs']['tmp']
os.environ["TEMP"] = config['dirs']['tmp']
print(config['dirs']['grid'])

/mnt/Data1/23wxy/piston/training_example/data_preparation/07-grid


In [4]:
train_list = [x.strip('\n') for x in open(TRAIN_LIST_FILE, 'r').readlines()]
for list in train_list:
    print(list)
    antigen = list.split(',')[0]
    antibody = list.split(',')[1]
    print(antigen)
    antigen_grid_path = f"{config['dirs']['grid']}/training/{antigen}.npy"
    antibody_grid_path = f"{config['dirs']['grid']}/training/{antibody}.npy"
    test = np.load(antibody_grid_path,allow_pickle=True)
    
print("loading validation lists")
val_list = [x.strip('\n') for x in open(VAL_LIST_FILE,'r').readlines()]
for val in val_list:
    print(val)
    antigen = val.split(',')[0]
    antibody = val.split(',')[1]
    print(antigen)
    antigen_grid_path = f"{config['dirs']['grid']}/validation/{antigen}.npy"
    antibody_grid_path = f"{config['dirs']['grid']}/validation/{antibody}.npy"
    val = np.load(antibody_grid_path,allow_pickle=True)

1a2y_C,1a2y_B
1a2y_C
1bgx_T,1bgx_L
1bgx_T
1e6j_P,1e6j_H
1e6j_P
1egj_A,1egj_H
1egj_A
1fj1_E,1fj1_A
1fj1_E
1fns_A,1fns_L
1fns_A
1fsk_A,1fsk_C
1fsk_A
1h0d_C,1h0d_B
1h0d_C
1iqd_C,1iqd_B
1iqd_C
1jrh_I,1jrh_H
1jrh_I
1kb5_A,1kb5_H
1kb5_A
1lk3_A,1lk3_L
1lk3_A
1mhp_B,1mhp_L
1mhp_B
1nca_N,1nca_H
1nca_N
1nfd_A,1nfd_E
1nfd_A
1nsn_S,1nsn_L
1nsn_S
1oaz_B,1oaz_H
1oaz_B
1ors_C,1ors_B
1ors_C
1SY6_A,1SY6_H
1SY6_A
1UJ3_C,1UJ3_B
1UJ3_C
1YJD_C,1YJD_H
1YJD_C
1YY9_A,1YY9_D
1YY9_A
2BDN_A,2BDN_H
2BDN_A
2H32_B,2H32_H
2H32_B
2OZ4_A,2OZ4_H
2OZ4_A
2Q8A_A,2Q8A_H
2Q8A_A
2QQK_A,2QQK_H
2QQK_A
2UZI_R,2UZI_H
2UZI_R
2XRA_A,2XRA_H
2XRA_A
3GI9_C,3GI9_H
3GI9_C
3GRW_A,3GRW_H
3GRW_A
3KR3_D,3KR3_H
3KR3_D
3LD8_A,3LD8_C
3LD8_A
3LEV_A,3LEV_H
3LEV_A
3MJ9_A,3MJ9_H
3MJ9_A
3MXW_A,3MXW_H
3MXW_A
3O2D_A,3O2D_H
3O2D_A
loading validation lists
3R1G_B,3R1G_H
3R1G_B
3RU8_X,3RU8_H
3RU8_X
3S35_X,3S35_H
3S35_X
3SOB_B,3SOB_H
3SOB_B
3TJE_F,3TJE_H
3TJE_F
3VG9_A,3VG9_C
3VG9_A
3W9E_C,3W9E_A
3W9E_C
4EDW_V,4EDW_H
4EDW_V
4F3F_C,4F3F_B
4F3F_C
4FQJ_A,4F

### 统计有多少复合物被处理了

In [5]:
from utils.utils import get_processed
# 目前阶段只采用训练集的数据
train_list = [x.strip('\n') for x in open(TRAIN_LIST_FILE, 'r').readlines()]
print(train_list)
val_list = [x.strip('\n') for x in open(VAL_LIST_FILE, 'r').readlines()]

train_list_updated = get_processed(train_list, config)
val_list_updated = get_processed(val_list, config)

print(f"{len(train_list_updated)}/{len(train_list)} training complexes were processed for 12A")
print(f"{len(val_list_updated)}/{len(val_list)} validation complexes were processed for 12A")

['1a2y_C,1a2y_B', '1bgx_T,1bgx_L', '1e6j_P,1e6j_H', '1egj_A,1egj_H', '1fj1_E,1fj1_A', '1fns_A,1fns_L', '1fsk_A,1fsk_C', '1h0d_C,1h0d_B', '1iqd_C,1iqd_B', '1jrh_I,1jrh_H', '1kb5_A,1kb5_H', '1lk3_A,1lk3_L', '1mhp_B,1mhp_L', '1nca_N,1nca_H', '1nfd_A,1nfd_E', '1nsn_S,1nsn_L', '1oaz_B,1oaz_H', '1ors_C,1ors_B', '1SY6_A,1SY6_H', '1UJ3_C,1UJ3_B', '1YJD_C,1YJD_H', '1YY9_A,1YY9_D', '2BDN_A,2BDN_H', '2H32_B,2H32_H', '2OZ4_A,2OZ4_H', '2Q8A_A,2Q8A_H', '2QQK_A,2QQK_H', '2UZI_R,2UZI_H', '2XRA_A,2XRA_H', '3GI9_C,3GI9_H', '3GRW_A,3GRW_H', '3KR3_D,3KR3_H', '3LD8_A,3LD8_C', '3LEV_A,3LEV_H', '3MJ9_A,3MJ9_H', '3MXW_A,3MXW_H', '3O2D_A,3O2D_H']
37/37 training complexes were processed for 12A
14/14 validation complexes were processed for 12A


### 学习图像标准缩放的平均值和标准差

In [6]:
## get all antigen and antibody lists
grid_antigen_list = []
grid_antibody_list = []
for ppi in train_list:
     antigen = ppi.split(',')[0]
     antibody = ppi.split(',')[1]
     antigen_grid_path = f"{config['dirs']['grid']}/training/{antigen}.npy"
     antibody_grid_path = f"{config['dirs']['grid']}/training/{antibody}.npy"
     # if os.path.exists(antigen_grid_path ) and os.path.exists(antibody_grid_path):
     grid_antigen_list.append(np.load(antigen_grid_path,allow_pickle=True))
     grid_antibody_list.append(np.load(antibody_grid_path,allow_pickle=True))

print(f"Loaded {len(grid_antigen_list)} antigen complexes ")
print(f"Loaded {len(grid_antibody_list)} antibody complexes ")
# n个抗原，每个抗原的形状是(n,32,32,7)，抗体同理
antigen_all_grid = np.stack(grid_antigen_list,axis=0)
antibody_all_grid = np.stack(grid_antibody_list,axis=0)
radius = config['ppi_const']['patch_r']
antigen_std_array = np.ones(7)
antigen_mean_array = np.zeros(7)
antibody_std_array = np.ones(7)
antibody_mean_array = np.zeros(7)
antigen_feature_pairs = {
    'shape_index': (0,),
    'ddc': (1,),
    'electrostatics':(2,),
    'charge': (3,),
    'hydrophobicity': (4,),
    'patch_dist':(5,),
    'SASA': (6,)
    }
antibody_feature_pairs = {
    'shape_index': (0,),
    'ddc': (1,),
    'electrostatics':(2,),
    'charge': (3,),
    'hydrophobicity': (4,),
    'patch_dist':(5,),
    'SASA': (6,)
}

Loaded 37 antigen complexes 
Loaded 37 antibody complexes 


In [7]:
## compute mean and std values of antigen 
for feature in antigen_feature_pairs.keys():
    print(f"Obtaining pixel values for {feature}")
    pixel_values = []
    for feature_i in antigen_feature_pairs[feature]:
        print(f"Index {feature_i}")
        for image_i in tqdm(range(antigen_all_grid.shape[0])):
            for row_i in range(antigen_all_grid.shape[1]):
                for column_i in range(antigen_all_grid.shape[2]):
                    # Check if coordinates are within the radius
                    x = column_i - radius
                    y = radius - row_i
                    if x**2 + y**2 < radius**2:
                        pixel_values.append(antigen_all_grid[image_i][row_i][column_i][feature_i])

    antigen_mean_value = np.mean(pixel_values)
    antigen_std_value = np.std(pixel_values)
    print(f"antigen  : Feature {feature}; Mean: {antigen_mean_value}; std: {antigen_std_value}")
    for feature_i in antigen_feature_pairs[feature]:
        antigen_mean_array[feature_i] = antigen_mean_value
        antigen_std_array[feature_i] = antigen_std_value
    
## compute mean and std values of antibody 
for feature in antibody_feature_pairs.keys():
    print(f"Obtaining pixel values for {feature}")
    pixel_values = []
    for feature_i in antibody_feature_pairs[feature]:
        print(f"Index {feature_i}")
        for image_i in tqdm(range(antibody_all_grid.shape[0])):
            for row_i in range(antibody_all_grid.shape[1]):
                for column_i in range(antibody_all_grid.shape[2]):
                    # Check if coordinates are within the radius
                    x = column_i - radius
                    y = radius - row_i
                    if x**2 + y**2 < radius**2:
                        pixel_values.append(antibody_all_grid[image_i][row_i][column_i][feature_i])

    antibody_mean_value = np.mean(pixel_values)
    antibody_std_value = np.std(pixel_values)
    print(f"antibody : Feature {feature}; Mean: {antibody_mean_value}; std: {antibody_std_value}")
    for feature_i in antigen_feature_pairs[feature]:
        antibody_mean_array[feature_i] = antibody_mean_value
        antibody_std_array[feature_i] = antibody_std_value

Obtaining pixel values for shape_index
Index 0


100%|██████████| 37/37 [00:00<00:00, 97.96it/s] 


antigen  : Feature shape_index; Mean: 0.08677603366047008; std: 0.4693663624419205
Obtaining pixel values for ddc
Index 1


100%|██████████| 37/37 [00:00<00:00, 115.52it/s]


antigen  : Feature ddc; Mean: 0.018924626890510887; std: 0.09985762926227935
Obtaining pixel values for electrostatics
Index 2


100%|██████████| 37/37 [00:00<00:00, 97.69it/s] 


antigen  : Feature electrostatics; Mean: -0.025685580756326612; std: 0.186244606531557
Obtaining pixel values for charge
Index 3


100%|██████████| 37/37 [00:00<00:00, 114.50it/s]


antigen  : Feature charge; Mean: -0.013862121896561776; std: 0.24762478927377363
Obtaining pixel values for hydrophobicity
Index 4


100%|██████████| 37/37 [00:00<00:00, 97.67it/s] 


antigen  : Feature hydrophobicity; Mean: -0.33676715630053944; std: 0.5386591687709986
Obtaining pixel values for patch_dist
Index 5


100%|██████████| 37/37 [00:00<00:00, 116.15it/s]


antigen  : Feature patch_dist; Mean: 55.04107114901713; std: 18.152144856456268
Obtaining pixel values for SASA
Index 6


100%|██████████| 37/37 [00:00<00:00, 96.61it/s] 


antigen  : Feature SASA; Mean: 0.4223890650525582; std: 0.2573474004041363
Obtaining pixel values for shape_index
Index 0


100%|██████████| 37/37 [00:00<00:00, 114.88it/s]


antibody : Feature shape_index; Mean: 0.12825944028080027; std: 0.4367312007653308
Obtaining pixel values for ddc
Index 1


100%|██████████| 37/37 [00:00<00:00, 97.01it/s] 


antibody : Feature ddc; Mean: 0.040910091799603164; std: 0.08927263061103818
Obtaining pixel values for electrostatics
Index 2


100%|██████████| 37/37 [00:00<00:00, 111.90it/s]


antibody : Feature electrostatics; Mean: -0.023686035975559825; std: 0.18450419334408213
Obtaining pixel values for charge
Index 3


100%|██████████| 37/37 [00:00<00:00, 97.82it/s] 


antibody : Feature charge; Mean: 0.029256066937372913; std: 0.1969441834975774
Obtaining pixel values for hydrophobicity
Index 4


100%|██████████| 37/37 [00:00<00:00, 116.06it/s]


antibody : Feature hydrophobicity; Mean: -0.2552092752213807; std: 0.46755264919753026
Obtaining pixel values for patch_dist
Index 5


100%|██████████| 37/37 [00:00<00:00, 96.95it/s] 


antibody : Feature patch_dist; Mean: 55.04107114901713; std: 18.152144856456268
Obtaining pixel values for SASA
Index 6


100%|██████████| 37/37 [00:00<00:00, 115.94it/s]

antibody : Feature SASA; Mean: 0.42173882297545834; std: 0.2668331917681385





In [8]:
print("Antigen Mean array:")
print(antigen_mean_array)
print("")
print("Antigen Standard deviation array:")
print(antigen_std_array)
print("Antibody Mean array:")
print(antibody_mean_array)
print("")
print("Antibody Standard deviation array:")
print(antibody_std_array)

Antigen Mean array:
[ 8.67760337e-02  1.89246269e-02 -2.56855808e-02 -1.38621219e-02
 -3.36767156e-01  5.50410711e+01  4.22389065e-01]

Antigen Standard deviation array:
[ 0.46936636  0.09985763  0.18624461  0.24762479  0.53865917 18.15214486
  0.2573474 ]
Antibody Mean array:
[ 1.28259440e-01  4.09100918e-02 -2.36860360e-02  2.92560669e-02
 -2.55209275e-01  5.50410711e+01  4.21738823e-01]

Antibody Standard deviation array:
[ 0.4367312   0.08927263  0.18450419  0.19694418  0.46755265 18.15214486
  0.26683319]


## apply module

In [22]:
import torch
from torch import nn
from networks.ViT_pytorch import Encoder
from networks.ViT_hybrid import ViT_Hybrid_encoder

class PISToN_proto(nn.Module):
    
    #定义模型，从父类方法中继承
    def __init__(self,config,img_size=24,zero_head=False):
        #构造调用模型的父类
        super(PISToN_proto, self).__init__()
        self.index_dict = {
            'all features' : (0,1,2,3,4,5,6)
        }
        
        self.img_size = img_size
        self.zero_head = zero_head
        ### 创建一个ModuleList，将多个模块连接起来，方便管理
        self.spatial_transformers_list = nn.ModuleList()
        for feature in self.index_dict.keys():
            self.spatial_transformers_list.append(self.init_transformer(config, channels=len(self.index_dict[feature])))
        
        self.feature_transformer = Encoder(config, vis=True) 
        
        
        
        
    def init_transformer(self, config, channels):
        """
        Initialize Transformer Network for a given tupe of features
        :param model_config:
        :param channels:
        :param n_individual:
        :return:
        """
        return ViT_Hybrid_encoder(config, img_size=self.img_size, channels=channels, vis=True)
        
    def forward(self,img):
        
         for i, feature in enumerate(self.index_dict.keys()):
             # 从图片中得到我们所提取的特征
             img_tmp = img[:,self.index_dict[feature],:,:]
             x, attn = self.spatial_transformers_list[i](img_tmp)
        
         # 把所有的特征综合起来（其实这里只有一个，因为我们只有一组）
         all_x = x
         all_attn = attn
         x = nn.functional.normalize(all_x)
         
         return x

In [23]:
### 定义模型的参数
params = {'dim_head': DIM,
          'hidden_size': DIM,
          'dropout': 0,
          'attn_dropout': 0,
          'lr': 0.0001,
          'n_heads': 8,
          'neg_pos_ratio': 5,
          'patch_size': 4,
          'transformer_depth': 8,
          'weight_decay': 0.0001,
          }
os.makedirs(MODEL_DIR, exist_ok=True)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [21]:
"this code meant to compute ground_truth matrix. When using ,just load the matrix with ppi "

import numpy as np

## compute patch
def split_into_patches(array, patch_size):
    patches = []
    for i in range(0, array.shape[0], patch_size):
        for j in range(0, array.shape[1], patch_size):
            patch = array[i:i+patch_size, j:j+patch_size]
            patches.append(patch.reshape(-1, 3))
    return np.array(patches)  # 转换成 NumPy 数组


distance = {}
for list in train_list:
    antigen = list.split(",")[0]
    antibody = list.split(",")[1]
    ## load npy data
    antigen_grid_path = f"{config['dirs']['grid']}/training/{antigen}.npy"
    antibody_grid_path = f"{config['dirs']['grid']}/training/{antibody}.npy"
    antigen_true_grid = np.load(antigen_grid_path, allow_pickle=True)
    antibody_true_grid = np.load(antibody_grid_path, allow_pickle=True)
    antigen_grid = antigen_true_grid[:, :, -3:]
    antibody_grid = antibody_true_grid[:, :, -3:]
    ## compute patches of antigen and antibody
    antigen_patch = split_into_patches(antigen_grid,params['patch_size'])
    antibody_patch = split_into_patches(antibody_grid,params['patch_size'])
    ## compute distance matrix dimension should be (64,64,16,16)
    distance_matrix = np.linalg.norm(antigen_patch[:, None, :, None, :] - antibody_patch[None, :, None, :, :], axis=-1)
    ## set threshold
    result_matrix = np.any(distance_matrix < 4.5, axis=(2, 3)).astype(float)
    distance[list] = result_matrix

for list in val_list:
    antigen = list.split(",")[0]
    antibody = list.split(",")[1]
    ## load npy data
    antigen_grid_path = f"{config['dirs']['grid']}/validation/{antigen}.npy"
    antibody_grid_path = f"{config['dirs']['grid']}/validation/{antibody}.npy"
    antigen_true_grid = np.load(antigen_grid_path, allow_pickle=True)
    antibody_true_grid = np.load(antibody_grid_path, allow_pickle=True)
    antigen_grid = antigen_true_grid[:, :, -3:]
    antibody_grid = antibody_true_grid[:, :, -3:]
    ## compute patches of antigen and antibody
    antigen_patch = split_into_patches(antigen_grid,params['patch_size'])
    antibody_patch = split_into_patches(antibody_grid,params['patch_size'])
    ## compute distance matrix dimension should be (64,64,16,16)
    distance_matrix = np.linalg.norm(antigen_patch[:, None, :, None, :] - antibody_patch[None, :, None, :, :], axis=-1)
    ## set threshold
    result_matrix = np.any(distance_matrix < 4.5, axis=(2, 3)).astype(float)
    distance[list] = result_matrix







In [1]:
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import time
import json
from tqdm import tqdm
from sklearn import metrics
from datetime import datetime


def get_date():
    return datetime.now().strftime("%d/%m/%Y %H:%M:%S")


def set_device(model, device_ids, device):
    if device_ids is None and device is None:
        if not torch.cuda.is_available():
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:0")
        model = model.to(device, non_blocking=False)
    elif device is not None:
        model = model.to(device, non_blocking=False)
    elif device_ids is not None:
        print("Setting up the following GPUs: {}".format(device_ids))
        device=torch.device("cuda:{}".format(device_ids[0]))
        model = nn.DataParallel(model, device_ids=device_ids).to(device, non_blocking=False)
    return model,device

def plot_metrics(history, saved_model_dir, model_name="training_plot"):
    plt.style.use('ggplot')

    figures_dir = saved_model_dir + '/' + model_name +'_figs/'
    if not os.path.exists(figures_dir):
        os.mkdir(figures_dir)
    # plot losses:
    plt.plot(history['train_loss'], marker='.', color='b', label='Train loss')
    plt.plot(history['val_loss'], marker='.', color='r', label='Validation loss')
    plt.legend(loc="upper right")
    plt.savefig(figures_dir+'/loss_{}.png'.format(model_name))
    plt.clf()
    plt.plot(history['train_auc'], marker='.', color='b', label='Train AUC')
    plt.plot(history['val_auc'], marker='.', color='r', label='Validation AUC')
    plt.legend(loc="lower right")
    plt.savefig(figures_dir+'/auc_{}.png'.format(model_name))

def add_to_history(history, train_loss, val_loss,train_auc, val_auc):
    history['train_loss'].append(float(train_loss))
    history['val_loss'].append(float(val_loss))
    history['train_auc'].append(float(train_auc))
    history['val_auc'].append(float(val_auc))

    return history

def compute_performance(output, label):
    pred_probabilities = output.cpu().detach().numpy()
    label = label.cpu().detach().numpy()
    flat_true_labels = label.flatten()
    flat_pred_probabilities = pred_probabilities.flatten()
    auc = metrics.roc_auc_score(flat_true_labels, flat_pred_probabilities)
    return auc

def train_one_epoch(antigen_model,antibody_model,train_loader,device,criterion,optimizer,disable_tqdm = False):
     running_loss = 0
     running_auc = 0
     antigen_model.train()
     antibody_model.train()
     for i, data in tqdm(enumerate(train_loader), total=len(train_loader), position=0, leave=True, disable=disable_tqdm):
         antigen_grid, antibody_grid,_ , _, ppi = data
         print(antigen_grid.shape)
         antigen_grid = antigen_grid.to(device,dtype=torch.float)
         antibody_grid = antibody_grid.to(device,dtype=torch.float)
         ppi = ppi[0]
         label = distance[ppi]
         label = torch.tensor(label,dtype=torch.float)
         label = label.to(device)
         antigen_output = antigen_model(antigen_grid)
         antibody_output = antibody_model(antibody_grid)
         antigen_vector = antigen_output
         antibody_vector = antibody_output
         ## compute predict matrix
         antibody_vector = antibody_vector.transpose(1, 2)
         predict = torch.bmm(antigen_vector, antibody_vector)
         predict = predict[0]
         loss = criterion(predict,label)
         auc = compute_performance(predict, label)
         
         loss.backward()
         optimizer.step()  # update weight
         optimizer.zero_grad()
         optimizer.step()
         
         running_loss += loss
         running_auc += auc      
         
         

     antigen_model.eval()
     antibody_model.eval()
     train_loss = running_loss / len(train_loader)
     train_auc = running_auc / len(train_loader)
     print("Average training loss: {}; train AUC: {};".format(train_loss, train_auc))
     return antigen_model,antibody_model, train_loss, train_auc


def evaluate_val(loader, antigen_model, antibody_model,device, criterion=None, include_attn=False, inside_loss=False):
     running_loss, running_auc = 0, 0
     antigen_model.eval()
     antibody_model.eval()
     with torch.no_grad():
        for i, data in enumerate(loader):
            antigen_grid, antibody_grid,_ , _, ppi = data
            antigen_grid = antigen_grid.to(device,dtype=torch.float)
            antibody_grid = antibody_grid.to(device,dtype=torch.float)
            ppi = ppi[0]
            label = distance[ppi]
            label = torch.tensor(label,dtype=torch.float)
            label = label.to(device)
            antigen_output = antigen_model(antigen_grid)
            antibody_output = antibody_model(antibody_grid)
            antigen_vector = antigen_output
            antibody_vector = antibody_output
            ## compute predict matrix
            antibody_vector = antibody_vector.transpose(1, 2)
            predict = torch.bmm(antigen_vector, antibody_vector)
            predict = predict[0]
            loss = criterion(predict,label)
            auc = compute_performance(predict, label)
            running_loss += loss
            running_auc += auc 
             
     val_loss = running_loss / len(loader)
     val_auc = running_auc / len(loader)
     
     return val_loss,val_auc
    
     
    

In [29]:
"this code meant to train model "
from torchsummaryX import summary
def fit_training(epochs,antigen_model,antibody_model,train_loader,val_loader,optimer,criterion = None, model_name = 'default', image_size = 64,channels = 7,device_ids = None,device = None ,saved_model_dir = ',/savedModels/',print_summary = True,patience = 10,antigen_model_name = "antigen",antibody_model_name = 'antibody'):
    start = time.time()
    if not os.path.exists(saved_model_dir):
        os.mkdir(saved_model_dir)
    ## load the pre-trained model
    elif os.path.exists(saved_model_dir + '/{}.pth'.format(model_name)):
        antigen_model.load_state_dict(torch.load(saved_model_dir + '/{}.pth'.format(model_name)))
        antibody_model.load_state_dict(torch.load(saved_model_dir + '/{}.pth'.format(model_name)))
        
    history = {'train_loss': [], 'val_loss': [],
               'train_auc':[], 'val_auc':[]
               }
    antigen_model,device = set_device(antigen_model, device_ids, device)
    antibody_model,device = set_device(antibody_model, device_ids, device)
    print("Start training our own model")
    
    if print_summary:
         summary(antigen_model, torch.rand((1, channels, image_size, image_size)).to(device, non_blocking=False))
         summary(antibody_model, torch.rand((1, channels, image_size, image_size)).to(device, non_blocking=False))
        
    max_auc = 0
    decrease = 0
    not_improved = 0
    saved_index = 0
    for e in range(epochs):
        print("[{}] Starting training for epoch {}...".format(get_date(), e))
        antigen_model, antibody_model ,train_loss, train_auc = train_one_epoch(antigen_model,antibody_model,train_loader, device, criterion, optimizer=optimer)
        print("Evaluating the model on validation set...".format(e))
        val_loss, val_auc = evaluate_val(val_loader, antigen_model,antibody_model, device, criterion)
        print("Average val loss: {}; val AUC: {};".format(val_loss, val_auc))
        if val_auc>max_auc:
            print('AUC increasing.. {:.4f} >> {:.4f} '.format(max_auc, val_auc))
            max_auc = val_auc
            decrease += 1
            print('saving model on epoch {}...'.format(e))
            torch.save(antigen_model.state_dict(), saved_model_dir + '/{}.pth'.format(antigen_model_name))
            torch.save(antibody_model.state_dict(), saved_model_dir + '/{}.pth'.format(antibody_model_name))
            saved_index = e
            not_improved=0
        else:
            not_improved +=1
        
        history = add_to_history(history, train_loss, val_loss, train_auc, val_auc)  
        if not_improved==patience:
            print("Stopping training...")
            break
    antigen_model.load_state_dict(torch.load(saved_model_dir + '/{}.pth'.format(antigen_model_name)))
    antibody_model.load_state_dict(torch.load(saved_model_dir + '/{}.pth'.format(antibody_model_name)))
    print("[{}] Done with training.".format(get_date()))
    plot_metrics(history, saved_model_dir)
    with open(saved_model_dir+'/history.json', 'w') as outfile:
        json.dump(history, outfile)
    
    return antigen_model,antibody_model, history, saved_index

In [30]:
from networks.ViT_pytorch import get_ml_config
import torch.optim 

def train_piston(search_space, train_list,val_list,SEED_ID, IMG_SIZE, 
                  antigen_mean, antigen_std, antibody_mean, antibody_std,
                  MAX_EPOCHS=50, N_FEATURES=7, feature_subset=None,data_prepare_dir='./data_preparation/',MODEL_NAME=MODEL_NAME, MODEL_DIR=MODEL_DIR,threshold = 0 ,antigen_name='antigen',antibody_name='antibody'):
     assert len(antigen_mean) == N_FEATURES
     if feature_subset is not None:
        assert len(antigen_mean) == len(feature_subset)
     # print(saved_model_dir_antigen)
     ## load all test data and validation data 
     train_db = PDB_complex_training(train_list,
                                  training_mode=True,
                                  feature_subset=feature_subset,
                                  data_prepare_dir=data_prepare_dir,
                                  neg_pos_ratio=search_space['neg_pos_ratio'],
                                  antigen_mean=antigen_mean,
                                  antigen_std=antigen_std,
                                  antibody_mean=antibody_mean, 
                                  antibody_std=antibody_std)
     print(len(train_db))
     val_pdb = PDB_complex_training(val_list,
                                  training_mode=True,
                                  feature_subset=feature_subset,
                                  data_prepare_dir=data_prepare_dir,
                                  neg_pos_ratio=search_space['neg_pos_ratio'],
                                  antigen_mean=antigen_mean,
                                  antigen_std=antigen_std,
                                  antibody_mean=antibody_mean, 
                                  antibody_std=antibody_std)
     ##### Initialize data loaders
     def worker_init_fn(worker_id):
        random.seed(SEED_ID + worker_id)
        
     ###load the test data 
     trainloader = DataLoader(train_db, batch_size=1, shuffle=False, pin_memory=True)
     valiloader = DataLoader(val_pdb, batch_size=1, shuffle=False, pin_memory=True)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     ## get a model 
     model_config = get_ml_config(search_space)
     print(model_config)
     model_antigen = PISToN_proto(model_config, img_size=IMG_SIZE).float()
     model_antibody = PISToN_proto(model_config, img_size=IMG_SIZE).float()
     optimizer = torch.optim.Adam([
        {'params': model_antigen.parameters(), 'lr': 0.0001,'weight_decay':0.0001}, 
        {'params': model_antibody.parameters(), 'lr': 0.0001,'weight_decay':0.0001}
        ])
     critertion = nn.CrossEntropyLoss()
     antigen_model, antibody_model, history, saved_index =fit_training(MAX_EPOCHS,model_antigen,model_antibody,trainloader,valiloader,optimizer,criterion = critertion, model_name = MODEL_NAME, image_size =IMG_SIZE,channels = 7,device_ids = None,device =device ,saved_model_dir = './savedModels/',print_summary = True,patience = 10,antigen_model_name = "antigen",antibody_model_name = 'antibody')

     
     return antigen_model,antibody_model, history, saved_index
     
             

In [27]:
### 训练模型
antigen_model,antibody_model, history, saved_index=train_piston(params,train_list = train_list,val_list=val_list,SEED_ID=SEED_ID,IMG_SIZE=IMG_SIZE,antigen_mean=antigen_mean_array,antigen_std=antigen_std_array,antibody_mean=antibody_mean_array,antibody_std=antibody_std_array,data_prepare_dir=config['dirs']['grid'],MODEL_NAME=MODEL_NAME, MODEL_DIR=MODEL_DIR,threshold= 4.5, antigen_name = ANTIGEN_MODEL, antibody_name = ANTIBODY_MODEL)

Length of PPI list: 37
37
Length of PPI list: 14
classifier: token
hidden_size: 16
patches:
  size: !!python/tuple
  - 4
  - 4
representation_size: null
transformer:
  attention_dropout_rate: 0
  dropout_rate: 0
  mlp_dim: 16
  num_heads: 8
  num_layers: 8

Start training our own model
                                                     Kernel Shape  \
Layer                                                               
0_spatial_transformers_list.0.transformer.embed...  [7, 16, 4, 4]   
1_spatial_transformers_list.0.transformer.embed...              -   
2_spatial_transformers_list.0.transformer.encod...           [16]   
3_spatial_transformers_list.0.transformer.encod...       [16, 16]   
4_spatial_transformers_list.0.transformer.encod...       [16, 16]   
5_spatial_transformers_list.0.transformer.encod...       [16, 16]   
6_spatial_transformers_list.0.transformer.encod...              -   
7_spatial_transformers_list.0.transformer.encod...              -   
8_spatial_transformers_

  0%|          | 0/37 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 1

In [None]:
 # for epoch in range(MAX_EPOCHS):
 #         for data in trainloader:
 #             antigen_grid, antibody_grid,_ , _, ppi = data
 #             ## compute ground_truth matrix  should be (64*64)
 #             ppi = ppi[0]
 #             distance_matrix = distance[ppi]
 #             ## convert to tensor and put it to device 
 #             distance_matrix = torch.tensor(distance_matrix)
 #             distance_matrix = distance_matrix.to(device)
 #             ## convert to tensor and load it tpo device
 #             antigen = antigen_grid.to(device=device, dtype=torch.float)
 #             antibody = antibody_grid.to(device=device, dtype=torch.float)
 #             ## input dimension (1,7,32,32)
 #             ## output dimension (1,64,16)
 #             antigen_output = model_antigen(antigen)
 #             antibody_output = model_antibody(antibody)
 #             antigen_vector = antigen_output
 #             print(antigen_output.shape)
 #             antibody_vector = antibody_output
 #             ## calculate predict matrix 
 #             antibody_vector = antibody_vector.transpose(1, 2)
 #             predict = torch.bmm(antigen_vector, antibody_vector)
 #             predict = predict[0]
 #             print(predict)
 #             ## using CrossEntropy loss to modify model 
 #             crossentropy = nn.CrossEntropyLoss()
 #             loss = crossentropy(predict,distance_matrix)
 #             ## bp
 #             optimizer.zero_grad()
 #             loss.backward()
 #             optimizer.step()
 #         print("Epoch:", epoch+1 , "Loss:", loss.item())
 #         torch.save(model_antigen.state_dict(), 'antigen.pth')
 #         torch.save(model_antibody.state_dict(), "antibody.pth")
 #     
 #     print('training over')
