In [2]:
import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [14]:
def getImgtSequence(filepath, maxlen=140):
    """Read the CA trace of an IMGT PDB-format file.

    ATOM records are parsed by their fixed PDB columns rather than by
    whitespace splitting, which mis-reads records with a blank chain ID and
    records whose negative coordinates touch (e.g. "-100.125-200.500").

    Args:
        filepath: path to a PDB-format file.
        maxlen: maximum number of CA atoms to collect (default 140).

    Returns:
        (seqres, xyz): one-letter sequence string and a list of [x, y, z]
        lists of the same length, in file order. Residue names missing from
        the table map to 'X'; CA records with unparsable coordinates get
        [0, 0, 0].
    """
    letters = {'ALA':'A','ARG':'R','ASN':'N','ASP':'D','CYS':'C','GLU':'E','GLN':'Q','GLY':'G','HIS':'H','ILE':'I','LEU':'L','LYS':'K','MET':'M','PHE':'F','PRO':'P','SER':'S','THR':'T','TRP':'W','TYR':'Y','VAL':'V', 'UNK':'X'}
    seqres = ''
    xyz = []
    last_res = None
    with open(filepath) as f:
        for line in f:
            # Record name in columns 1-6, atom name in columns 13-16.
            if not line.startswith('ATOM') or line[12:16].strip() != 'CA':
                continue
            # Chain ID + residue number + insertion code (columns 22-27): skip a
            # CA that repeats the previous residue (alternate locations).
            res_key = line[21:27]
            if res_key == last_res:
                continue
            last_res = res_key
            if len(seqres) >= maxlen:
                break
            # Residue name in columns 18-20.
            seqres += letters.get(line[17:20].strip(), 'X')
            try:
                # x, y, z in columns 31-38, 39-46, 47-54.
                xyz.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
            except ValueError:
                xyz.append([0, 0, 0])
    return seqres, xyz

def getImgtCaData(filepath='./data/imgt/', maxlen=140, outpath='./data/imgtCaResult/'):
    """Write one fixed-length CA text file per IMGT .pdb file.

    For every ``*.pdb`` in ``filepath`` (processed in sorted order), writes
    ``<outpath>/<filename>.txt`` with exactly ``maxlen`` lines of the form
    ``<letter> <x> <y> <z>``. Longer chains are truncated; shorter ones are
    padded with ``X 0 0 0``.

    Args:
        filepath: directory containing the IMGT .pdb files.
        maxlen: number of residue lines written per file (default 140).
        outpath: output directory; created if it does not exist.
    """
    os.makedirs(outpath, exist_ok=True)
    for filename in sorted(os.listdir(filepath)):
        if not filename.endswith('.pdb'):
            continue
        seqres, xyz = getImgtSequence(os.path.join(filepath, filename))
        # Truncate to maxlen, then pad each list on its own so a length
        # mismatch between them cannot cause an IndexError below.
        seqres = seqres[:maxlen].ljust(maxlen, 'X')
        xyz = xyz[:maxlen]
        xyz += [[0, 0, 0] for _ in range(maxlen - len(xyz))]
        with open(os.path.join(outpath, filename + '.txt'), 'w') as file:
            for letter, (x, y, z) in zip(seqres, xyz):
                file.write(letter + ' ' + str(x) + ' ' + str(y) + ' ' + str(z) + '\n')

In [15]:
# Convert every IMGT .pdb file in data/imgt/ into a CA-trace text file.
filepath = 'data/imgt/'
getImgtCaData(filepath=filepath)

以上代码仅用于处理 IMGT 数据集：提取每个结构中 CA 原子的坐标及对应的氨基酸单字母代码，并统一截断或补齐为 140 个残基（补齐位为 `X` 及坐标 `0 0 0`）。

In [None]:
def clean(filepath='./data/train_data/train_data/', maxlines=141):
    """Normalize every .txt file in ``filepath`` to exactly ``maxlines`` lines, in place.

    Files with more lines are truncated. Files with fewer lines are padded with
    ``X 0 0 0`` rows. Files that already have ``maxlines`` lines are left
    untouched. The default of 141 is one header line plus 140 residue rows.

    Args:
        filepath: directory containing the .txt files.
        maxlines: target line count per file (default 141).
    """
    for filename in os.listdir(filepath):
        if not filename.endswith('.txt'):
            continue
        path = os.path.join(filepath, filename)
        # Read fully and close before rewriting the same file.
        with open(path, 'r') as file:
            lines = file.readlines()
        if len(lines) == maxlines:
            continue
        if len(lines) > maxlines:
            new_lines = lines[:maxlines]
        else:
            new_lines = list(lines)
            # Without this, the first pad row would be glued onto a final
            # line that lacks a trailing newline.
            if new_lines and not new_lines[-1].endswith('\n'):
                new_lines[-1] += '\n'
            new_lines += ['X 0 0 0\n'] * (maxlines - len(lines))
        with open(path, 'w') as file:
            file.writelines(new_lines)


以上代码仅用于处理 train_data 数据集：每个文件首行为表头，其后为 140 行残基数据，因此统一截断或补齐为 141 行（补齐行为 `X 0 0 0`）。

In [16]:
def getSequence(filepath):
    """Return the one-letter sequence from the first SEQRES block of a PDB file.

    SEQRES serial numbers (second field) restart at 1 for each chain, so
    reading stops at the second record whose serial is '1'. Residue names
    not in the table (modified residues, ligands, ...) map to 'X' instead of
    raising KeyError. Blank lines are ignored.

    Args:
        filepath: path to a PDB-format file.

    Returns:
        The concatenated one-letter sequence (empty if there are no SEQRES records).
    """
    letters = {'ALA':'A','ARG':'R','ASN':'N','ASP':'D','CYS':'C','GLU':'E','GLN':'Q','GLY':'G','HIS':'H','ILE':'I','LEU':'L','LYS':'K','MET':'M','PHE':'F','PRO':'P','SER':'S','THR':'T','TRP':'W','TYR':'Y','VAL':'V', 'UNK':'X', 'PCA':'X', 'MSE':'X', 'GYS':'X', 'CRO':'X', 'YCM':'X',
    'EDO':'X'}
    seqres = ''
    blocks_seen = 0
    with open(filepath, 'r') as file:
        for line in file:
            toks = line.split()
            if not toks or toks[0] != 'SEQRES':
                continue
            if toks[1] == '1':
                blocks_seen += 1
                if blocks_seen > 1:
                    break
            # Fields 0-3 are record name, serial, chain ID, residue count.
            seqres += ''.join(letters.get(res, 'X') for res in toks[4:])
    return seqres


def getXYZ(filepath, length):
    """Collect up to ``length`` CA coordinates from the ATOM records of a PDB file.

    Records are parsed by fixed PDB columns, not by whitespace splitting.
    Splitting mis-reads records with a blank chain ID and merges touching
    negative coordinates. A CA that repeats the previous residue (alternate
    locations) is skipped.

    Args:
        filepath: path to a PDB-format file.
        length: maximum number of coordinates to return.

    Returns:
        List of [x, y, z] floats in file order.

    Raises:
        ValueError: if a CA record's coordinate columns are not numeric.
    """
    # NOTE(review): CA atoms of HETATM residues (e.g. MSE, which getSequence
    # maps to 'X') are not collected, so xyz can be shorter than the SEQRES
    # sequence and offset from it — confirm whether HETATM should be included.
    xyz = []
    last_res = None
    with open(filepath, 'r') as file:
        for line in file:
            if len(xyz) >= length:
                break
            if not line.startswith('ATOM') or line[12:16].strip() != 'CA':
                continue
            res_key = line[21:27]  # chain ID + residue number + insertion code
            if res_key == last_res:
                continue
            last_res = res_key
            xyz.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
    return xyz


def getCaData(filepath='./data/otherPdb/', maxlen=140, outpath='./data/imgtCaResult/'):
    """Write one fixed-length CA text file per .pdb file in ``filepath``.

    The sequence comes from getSequence() (SEQRES) and the coordinates from
    getXYZ() (ATOM CA records). For each ``*.pdb`` (sorted order), writes
    ``<outpath>/<filename>.txt`` with exactly ``maxlen`` lines
    ``<letter> <x> <y> <z>``, truncating longer entries and padding with
    ``X`` / ``0 0 0``.

    Args:
        filepath: directory containing the .pdb files.
        maxlen: number of residue lines per output file (default 140).
        outpath: output directory; created if it does not exist.
    """
    os.makedirs(outpath, exist_ok=True)
    for filename in sorted(os.listdir(filepath)):
        if not filename.endswith('.pdb'):
            continue
        src = os.path.join(filepath, filename)
        seqres = getSequence(src)
        xyz = getXYZ(src, len(seqres))
        # NOTE(review): SEQRES also lists residues that are unresolved in the
        # ATOM records, so seqres[i] and xyz[i] may not describe the same
        # residue — confirm whether an alignment step is needed.
        # Pad each list independently: xyz is often shorter than seqres, and
        # padding it by the sequence deficit alone caused an IndexError below.
        seqres = seqres[:maxlen].ljust(maxlen, 'X')
        xyz = xyz[:maxlen]
        xyz += [[0, 0, 0] for _ in range(maxlen - len(xyz))]
        with open(os.path.join(outpath, filename + '.txt'), 'w') as file:
            for letter, (x, y, z) in zip(seqres, xyz):
                file.write(letter + ' ' + str(x) + ' ' + str(y) + ' ' + str(z) + '\n')


def getOneHot(filepath):
    """Load a ``<letter> <x> <y> <z>`` text file as one-hot rows and coordinates.

    The first non-blank line is treated as a header and skipped only when it
    does not parse as a letter followed by three numbers. train_data files
    carry a header; the files written by getCaData/getImgtCaData do not, and
    previously lost their first residue. Blank lines are ignored.

    Args:
        filepath: path to the text file.

    Returns:
        (onehot, xyz): ``onehot`` is a list of 21-element 0/1 lists over the
        alphabet 'ARNDCEQGHILKMFPSTWYVX' (a letter outside it yields an
        all-zero row); ``xyz`` is a list of [x, y, z] floats.

    Raises:
        ValueError / IndexError: if a line after the first is malformed.
    """
    letters = ['A','R','N','D','C','E','Q','G','H','I','L','K','M','F','P','S','T','W','Y','V','X']
    seqres = []
    xyz = []
    first = True
    with open(filepath, 'r') as file:
        for line in file:
            toks = line.split()
            if not toks:
                continue
            try:
                coords = [float(toks[1]), float(toks[2]), float(toks[3])]
            except (IndexError, ValueError):
                if first:
                    first = False
                    continue  # header row
                raise
            first = False
            seqres.append(toks[0])
            xyz.append(coords)
    onehot = [[1 if letter == res else 0 for letter in letters] for res in seqres]
    return onehot, xyz


def getTrainData(filepath='./data/imgtCaResult/', seq_len=None):
    """Load every .txt file in ``filepath`` into input/label tensors.

    Files are read in sorted order so the dataset order is reproducible.
    All samples must share one length for ``torch.tensor`` to build a
    rectangular tensor. Files whose row count differs from ``seq_len`` are
    skipped and counted; ``seq_len`` defaults to the first file's length.

    Args:
        filepath: directory of ``<letter> <x> <y> <z>`` text files.
        seq_len: required number of residue rows per file, or None.

    Returns:
        (trainData, trainLabel): tensors built from the nested lists of
        getOneHot(). trainData holds the integer one-hot rows; trainLabel holds
        the float coordinates. Both are 1-D empty tensors if no file qualifies.
    """
    trainData = []
    trainLabel = []
    skipped = 0
    for filename in sorted(os.listdir(filepath)):
        if not filename.endswith('.txt'):
            continue
        onehot, xyz = getOneHot(os.path.join(filepath, filename))
        if seq_len is None and onehot:
            seq_len = len(onehot)
        if len(onehot) != len(xyz) or len(onehot) != seq_len:
            skipped += 1
            continue
        trainData.append(onehot)
        trainLabel.append(xyz)
    if skipped:
        print('getTrainData: skipped %d file(s) with unexpected length' % skipped)
    return torch.tensor(trainData), torch.tensor(trainLabel)

def main():
    """Regenerate the CA text files and load them as training tensors.

    Calls getCaData() with its defaults, then getTrainData(), prints the shape
    of both returned tensors, and returns them.

    Returns:
        (trainData, trainLabel) as produced by getTrainData().
    """
    getCaData()
    data, labels = getTrainData()
    for tensor in (data, labels):
        print(tensor.shape)
    return data, labels

In [4]:
class MyDataset(Dataset):
    """Map-style dataset returning ``(data[i], label[i])`` pairs.

    Args:
        data: indexable collection of samples; its length is the dataset length.
        label: indexable collection of targets aligned with ``data``.
    """

    def __init__(self, data, label):
        self.data, self.label = data, label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return sample, target

In [5]:
# Model definition
class ResidualBlock(nn.Module):
    """1-D convolutional residual block.

    ``conv1`` maps ``in_channels`` to ``out_channels`` and serves as the
    shortcut. The branch conv2 -> ReLU -> conv3 -> BatchNorm is added to it,
    followed by a final ReLU. Padding ``(kernel_size - 1) * dilation // 2``
    keeps the sequence length unchanged whenever ``(kernel_size - 1) * dilation``
    is even (e.g. any odd ``kernel_size``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, dilation=dilation,
                           padding=(kernel_size - 1) * dilation // 2)
        # Attribute names are part of the state_dict; keep them stable.
        self.conv1 = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv1d(out_channels, out_channels, **conv_kwargs)
        self.conv3 = nn.Conv1d(out_channels, out_channels, **conv_kwargs)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        shortcut = self.conv1(x)
        branch = self.bn(self.conv3(self.relu(self.conv2(shortcut))))
        return self.relu(branch + shortcut)


class Net(nn.Module):
    """Fully convolutional regressor from one-hot residues to 3-D coordinates.

    Input is permuted from (batch, length, 21) to channels-first and cast to
    float32. Output is (batch, length, 3).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: three blocks, 1-D kernel size 25, dilation 1.
        self.ResNet1_block1 = ResidualBlock(21, 64, 25, 1)
        self.ResNet1_block2 = ResidualBlock(64, 64, 25, 1)
        self.ResNet1_block3 = ResidualBlock(64, 64, 25, 1)

        # Stage 2: five blocks, 1-D kernel size 5, dilations 1, 2, 4, 8, 16.
        self.ResNet2_block1 = ResidualBlock(64, 140, 5, 1)
        self.ResNet2_block2 = ResidualBlock(140, 140, 5, 2)
        self.ResNet2_block3 = ResidualBlock(140, 140, 5, 4)
        self.ResNet2_block4 = ResidualBlock(140, 140, 5, 8)
        self.ResNet2_block5 = ResidualBlock(140, 140, 5, 16)

        # Head: dropout, then 1x1 convolutions 140 -> 70 -> 3.
        self.dropout = nn.Dropout(0.25)
        self.conv70 = nn.Conv1d(140, 70, 1)
        self.conv3 = nn.Conv1d(70, 3, 1)

    def forward(self, x):
        h = x.permute(0, 2, 1).float()
        for block in (self.ResNet1_block1, self.ResNet1_block2, self.ResNet1_block3,
                      self.ResNet2_block1, self.ResNet2_block2, self.ResNet2_block3,
                      self.ResNet2_block4, self.ResNet2_block5):
            h = block(h)
        h = F.elu(self.conv70(self.dropout(h)))
        return self.conv3(h).permute(0, 2, 1)


In [None]:
# Pick the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (weight init, DataLoader shuffling, dropout) so Restart &
# Run All reproduces the same run; cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Regenerate the CA text files and load them as (inputs, labels) tensors.
(trainData, trainLabel) = main()

In [None]:
# Cast to float64, move to the training device, and wrap in a shuffled loader.
trainData, trainLabel = (t.to(float).to(device) for t in (trainData, trainLabel))
dataset = MyDataset(trainData, trainLabel)
# The name `dataloder` (sic) is kept because the training cell refers to it.
dataloder = DataLoader(dataset, shuffle=True, batch_size=16)

In [None]:
# Build the network and move its parameters to the selected device.
model = Net().to(device)

In [None]:
# Mean-squared error on the coordinates, optimized with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def train(epoch, net=None, loader=None, loss_fn=None, opt=None, dev=None, plot=True):
    """Train for ``epoch`` epochs and plot the per-epoch mean training loss.

    The reported loss is the mean over batches of ``loss_fn``. An earlier
    version divided it by 140 on top of MSELoss's element-wise mean, so the
    printed number was not the MSE.

    Args:
        epoch: number of passes over ``loader``; 0 trains nothing.
        net, loader, loss_fn, opt, dev: training objects. Each defaults to the
            notebook globals ``model``, ``dataloder``, ``criterion``,
            ``optimizer`` and ``device`` when left as None.
        plot: draw the loss curve with matplotlib when True.

    Returns:
        List of per-epoch mean losses (floats), one per epoch.

    Raises:
        ValueError: if ``epoch > 0`` and ``loader`` yields no batches.
    """
    net = model if net is None else net
    loader = dataloder if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    n_batches = len(loader)
    if epoch > 0 and n_batches == 0:
        raise ValueError('loader yields no batches')

    losslist = []
    print('training on ', dev)
    net.train()  # enable dropout / batch-norm updates even if eval() was called earlier
    for _ in range(epoch):
        train_loss = 0.0
        for inputs, labels in loader:
            inputs = inputs.to(dev, dtype=torch.float64)
            labels = labels.to(dev, dtype=torch.float64)
            opt.zero_grad()
            outputs = net(inputs).to(torch.float64)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()
            train_loss += loss.item()
        losslist.append(train_loss / n_batches)

    if plot and losslist:
        fig, ax = plt.subplots()
        ax.plot(np.arange(1, len(losslist) + 1), losslist, label="train loss")
        ax.set(xlabel='epoch', ylabel='loss', title='Training loss')
        ax.legend()
        plt.show()
    # Previously raised NameError here when epoch == 0 (loop variable never bound).
    if losslist:
        print('epoch: %d, loss: %.03f' % (len(losslist), losslist[-1]))
    print('Finished Training')
    return losslist

In [None]:
# 100 epochs over the training loader; `;` suppresses the cell's return-value echo.
train(epoch=100);