In [210]:
import torch.nn as nn
import torch
import numpy as np
import random
from torch.nn import functional as F
import sys
torch.backends.cudnn.enabled = False
class Causal_ARG(nn.Module):
    """Sequence classifier: CNN encoder -> latent sampling -> chained heads.

    The convolutional stack and ``fc`` compress a one-hot sequence map into an
    ``X_dim``-wide feature vector.  ``Gauussion`` samples a latent ``G`` from
    that vector concatenated with the three label columns, ``Hidden`` derives
    ``z1``/``z2`` from the features and ``G``, and ``Causal`` turns
    ``[z1, z2]`` into three probability vectors.

    Args:
        X_dim: Width of the feature vector produced by ``fc``.
        G_dim: Width of the latent produced by ``Gauussion``.
        z1_dim: Width of the first output of ``Hidden``.
        z2_dim: Width of the second output of ``Hidden``.
        transfer_count: Number of outputs of the transfer head.
        mechanism_count: Number of outputs of the mechanism head.
        antibiotic_count: Number of outputs of the antibiotic head.
    """

    def __init__(self, X_dim, G_dim, z1_dim, z2_dim, transfer_count, mechanism_count,
                 antibiotic_count):
        super().__init__()
        # No padding anywhere and every pooling layer uses stride 1, so each
        # layer shrinks the map by (kernel - 1) per axis.  The layer order
        # (parameter-free layers included) determines the numeric indices in
        # the state-dict keys, so it must not change if checkpoints are to
        # keep loading.
        self.feature = nn.Sequential(
            # (batch, 1, 1576, 23) -> (batch, 32, 1537, 20)
            nn.Conv2d(1, 32, kernel_size=(40, 4)),
            nn.LeakyReLU(),
            # -> (batch, 32, 1533, 19)
            nn.MaxPool2d(kernel_size=(5, 2), stride=1),
            # -> (batch, 64, 1504, 16)
            nn.Conv2d(32, 64, kernel_size=(30, 4)),
            nn.LeakyReLU(),
            # -> (batch, 128, 1475, 13)
            nn.Conv2d(64, 128, kernel_size=(30, 4)),
            nn.LeakyReLU(),
            # -> (batch, 128, 1471, 12)
            nn.MaxPool2d(kernel_size=(5, 2), stride=1),
            # -> (batch, 256, 1452, 10)
            nn.Conv2d(128, 256, kernel_size=(20, 3)),
            nn.LeakyReLU(),
            # -> (batch, 256, 1433, 8)
            nn.Conv2d(256, 256, kernel_size=(20, 3)),
            nn.LeakyReLU(),
            # -> (batch, 256, 1430, 8)
            nn.MaxPool2d(kernel_size=(4, 1), stride=1),
            # -> (batch, 1, 1411, 6)
            nn.Conv2d(256, 1, kernel_size=(20, 3)),
            nn.LeakyReLU(),
            # -> (batch, 1, 1410, 6)
            nn.MaxPool2d(kernel_size=(2, 1), stride=1)
        )
        self.fc = nn.Sequential(
            # 8460 = 1 * 1410 * 6, i.e. the flattened map of a 1576 x 23 input.
            nn.Linear(8460, 1024),
            nn.LeakyReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, X_dim),
            nn.LeakyReLU()
        )

        # "+ 3": one column for each of the three labels appended in forward().
        self.gauus = Gauussion(X_dim + 3, G_dim)
        self.hidden = Hidden(X_dim, G_dim, z1_dim, z2_dim)
        self.causal = Causal(z1_dim + z2_dim, transfer_count, mechanism_count, antibiotic_count)

    def forward(self, seq_map, transfer_label, mechanism_label, antibiotic_label):
        """Predict transfer, mechanism and antibiotic distributions.

        Args:
            seq_map: Input of shape (batch, 1, 1576, 23); other spatial sizes
                do not match the 8460 inputs expected by ``fc``.
            transfer_label: Tensor of shape (batch, 1).
            mechanism_label: Tensor of shape (batch, 1).
            antibiotic_label: Tensor of shape (batch, 1).

        Returns:
            Tuple ``(transfer_pre, mechanism_pre, antibiotic_pre)`` as returned
            by ``Causal``.
        """
        x = self.feature(seq_map)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        # The labels are concatenated exactly as given, so each must already be
        # (batch, 1) to yield the three extra columns `gauus` expects.  Earlier
        # code called `label.unsqueeze(-1)` here and dropped the result;
        # unsqueeze is out-of-place, so those calls never did anything and
        # have been removed.
        labels = torch.cat((transfer_label, mechanism_label, antibiotic_label), dim=1)

        # Only the sampled latent G is used; the four mean/logvar outputs and
        # the mixture probabilities are computed by `gauus` but not returned.
        _, _, _, _, _, G = self.gauus(torch.cat((x, labels), dim=1))

        z1, z2 = self.hidden(x, G)
        hidden_representation = torch.cat((z1, z2), dim=1)

        transfer_pre, mechanism_pre, antibiotic_pre = self.causal(hidden_representation)
        return transfer_pre, mechanism_pre, antibiotic_pre

class Hidden(nn.Module):
    """Pair of two-layer MLPs that produce the codes ``z1`` and ``z2``.

    ``z1`` is computed from ``X`` and ``G`` concatenated along dim 1;
    ``z2`` is computed from ``G`` alone.

    Args:
        X_dim: Number of columns in ``X``.
        G_dim: Number of columns in ``G``.
        z1_dim: Output width of the ``[X, G]`` branch.
        z2_dim: Output width of the ``G`` branch.
    """

    def __init__(self, X_dim, G_dim, z1_dim, z2_dim):
        super().__init__()

        def two_layer(width, out_width):
            # Linear(width -> width), LeakyReLU, Linear(width -> out_width).
            return nn.Sequential(
                nn.Linear(width, width),
                nn.LeakyReLU(),
                nn.Linear(width, out_width),
            )

        self.concat_dim = X_dim + G_dim
        self.hidden1 = two_layer(self.concat_dim, z1_dim)
        self.hidden2 = two_layer(G_dim, z2_dim)

    def forward(self, X, G):
        """Return ``(z1, z2)`` for features ``X`` and latent ``G``."""
        joined = torch.cat((X, G), dim=1)
        return self.hidden1(joined), self.hidden2(G)


class Causal(nn.Module):
    """Three chained softmax heads: transfer -> mechanism -> antibiotic.

    Each later head receives the original input plus the probability vectors
    produced by every head before it.

    Args:
        input_dim: Number of columns in the input representation.
        transfer_count: Output width of the transfer head.
        mechanism_count: Output width of the mechanism head.
        antibiotic_count: Output width of the antibiotic head.
    """

    def __init__(self, input_dim, transfer_count, mechanism_count, antibiotic_count):
        super().__init__()
        mechanism_in = input_dim + transfer_count
        antibiotic_in = mechanism_in + mechanism_count

        self.transfer_layer = nn.Linear(input_dim, transfer_count)
        self.softmax = nn.Softmax(dim=1)

        self.mechanism_layer = nn.Linear(mechanism_in, mechanism_count)
        self.antibiotic_layer = nn.Linear(antibiotic_in, antibiotic_count)

    def forward(self, input):
        """Return ``(transfer_pre, mechanism_pre, antibiotic_pre)``.

        Every returned tensor has been passed through a softmax over dim 1.
        """
        transfer_pre = self.softmax(self.transfer_layer(input))

        # Grow the context one head at a time instead of re-concatenating.
        context = torch.cat((input, transfer_pre), dim=1)
        mechanism_pre = self.softmax(self.mechanism_layer(context))

        context = torch.cat((context, mechanism_pre), dim=1)
        antibiotic_pre = self.softmax(self.antibiotic_layer(context))

        return transfer_pre, mechanism_pre, antibiotic_pre


class Gauussion(nn.Module):
    """Four-component mixture head that samples one latent vector per row.

    A shared two-layer MLP feeds four linear heads, each emitting
    ``2 * hidden_dim`` values, plus a 4-way softmax over components.  For
    every row one component is drawn with Python's ``random.choices`` and a
    sample is built from that component's output.

    Args:
        input_dim: Number of columns in the input.
        hidden_dim: Width of the shared MLP and of the sampled latent.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )
        # One head per mixture component; each emits 2 * hidden_dim values.
        self.mean_logvar1 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.mean_logvar2 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.mean_logvar3 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.mean_logvar4 = nn.Linear(hidden_dim, 2 * hidden_dim)

        self.softmax = nn.Softmax(dim=1)
        self.prob = nn.Linear(hidden_dim, 4)

    def forward(self, x):
        """Return the four head outputs, the mixture weights and the samples.

        Returns:
            ``(mean_logvar1, mean_logvar2, mean_logvar3, mean_logvar4, prob,
            final_g)``; ``final_g`` has one ``hidden_dim``-wide row per input
            row.
        """
        shared = self.hidden(x)
        heads = (
            self.mean_logvar1(shared),
            self.mean_logvar2(shared),
            self.mean_logvar3(shared),
            self.mean_logvar4(shared),
        )
        prob = self.softmax(self.prob(shared))

        component_ids = [0, 1, 2, 3]
        half = shared.size()[1]
        samples = []
        for row, weights in enumerate(prob):
            # Component selection uses Python's `random` module, not torch's RNG.
            picked = random.choices(component_ids, weights.tolist())[0]
            params = heads[picked][row]
            offset, scale = params[:half], params[half:]
            # NOTE(review): the noise is uniform on [0, 1) (`rand_like`) and the
            # second half is used directly as the scale even though the heads
            # are named "logvar" -- confirm this is intended before changing
            # it, since existing checkpoints were presumably trained this way.
            samples.append(offset + torch.rand_like(scale) * scale)

        final_g = torch.stack(samples)
        return heads[0], heads[1], heads[2], heads[3], prob, final_g

In [211]:
# Load the trained model
import torch   

In [212]:
# Rebuild the network with the sizes the checkpoint expects:
# X_dim = G_dim = z1_dim = z2_dim = 64, and 2 / 6 / 15 outputs for the
# transfer / mechanism / antibiotic heads.
model = Causal_ARG(64, 64, 64, 64, 2, 6, 15)
# A freshly built nn.Module is in training mode, which keeps the Dropout layer
# inside `fc` active.  This notebook only runs inference, so switch to
# evaluation mode.  It is done before loading so that load_state_dict() stays
# the last expression of the cell and its result is still displayed.
model.eval()
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (newer PyTorch releases offer `weights_only=True`).
model.load_state_dict(torch.load('./model0.pth', map_location=torch.device('cpu')))

<All keys matched successfully>

In [213]:
model

Causal_ARG(
  (feature): Sequential(
    (0): Conv2d(1, 32, kernel_size=(40, 4), stride=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): MaxPool2d(kernel_size=(5, 2), stride=1, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(30, 4), stride=(1, 1))
    (4): LeakyReLU(negative_slope=0.01)
    (5): Conv2d(64, 128, kernel_size=(30, 4), stride=(1, 1))
    (6): LeakyReLU(negative_slope=0.01)
    (7): MaxPool2d(kernel_size=(5, 2), stride=1, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(20, 3), stride=(1, 1))
    (9): LeakyReLU(negative_slope=0.01)
    (10): Conv2d(256, 256, kernel_size=(20, 3), stride=(1, 1))
    (11): LeakyReLU(negative_slope=0.01)
    (12): MaxPool2d(kernel_size=(4, 1), stride=1, padding=0, dilation=1, ceil_mode=False)
    (13): Conv2d(256, 1, kernel_size=(20, 3), stride=(1, 1))
    (14): LeakyReLU(negative_slope=0.01)
    (15): MaxPool2d(kernel_size=(2, 1), stride=1, padding=0, dilation=1, ceil_mode=False)

Enter your protein sequence here.

The sequence below is an example; you can replace it with your own.

In [242]:
# Query protein sequence, one letter per residue; replace it with your own.
# The triple-quoted literal keeps its line breaks, so the string contains
# '\n' characters in addition to the residue letters.
seq="""
MLATLPLAVHASPQPLEQIKLSESQLSGRVGMIEMDLASGRTLTAWRADERFPMMSTFKV
VLCGAVLARVDAGDEQLERKIHYRQQDLVDYSPVSEKHLADGMTVGELCAAAITMSDNSA
ANLLLATVGGPAGLTAFLRQIGDNVTRLDRWETELNEALPGDARATTTPASMAATLRKLL
TSQRLSARSQRQLLQWMVDDRVAGPLIRSVLPAGWFIADKTGAGERGARGIVALLGPNNK
AERIVVIYLRDTPASMAERN
"""

In [243]:
# Class-name -> index tables for the three prediction heads.  The *_list
# variables give the reverse lookup (index -> name) used when printing.
used_anti_label = {'beta_lactam': 0, 'bacitracin': 1, 'multidrug': 2,
                   'macrolide-lincosamide-streptogramin': 3, 'aminoglycoside': 4,
                   'polymyxin': 5, 'chloramphenicol': 6, 'tetracycline': 7,
                   'fosfomycin': 8, 'glycopeptide': 9, 'quinolone': 10,
                   'trimethoprim': 11, 'sulfonamide': 12, 'rifampin': 13, 'others': 14}
anti_list = list(used_anti_label.keys())
uesd_mech_label = {'antibiotic target protection': 0, 'antibiotic efflux': 1,
                   'antibiotic inactivation': 2, 'antibiotic target alteration': 3,
                   'antibiotic target replacement': 4, 'others': 5}
uesd_transfer_label = {'intrinsic': 0, 'acquired': 1}
tranfer_list = list(uesd_transfer_label.keys())
mech_list = list(uesd_mech_label.keys())

# Residue letter -> one-hot column.  22 letters are fixed; the model input has
# 23 columns, which leaves exactly one free column for an unlisted symbol.
word_map = {'M': 0, 'A': 1, 'E': 2, 'P': 3, 'V': 4, 'L': 5, 'S': 6, 'K': 7,
            'D': 8, 'I': 9, 'R': 10, 'F': 11, 'T': 12, 'G': 13, 'N': 14,
            'H': 15, 'C': 16, 'Q': 17, 'Y': 18, 'W': 19, 'X': 20, 'Z': 21}

# Drop line breaks and any other whitespace from the pasted sequence so they
# are not encoded as residues (previously '\n' consumed the one free column).
# NOTE(review): assumes the training sequences contained no whitespace
# characters -- confirm against the training pipeline.
residues = "".join(seq.split())
if not residues:
    raise ValueError("seq contains no residues")
if len(residues) > 1576:
    raise ValueError(
        "seq has %d residues; the model accepts at most 1576" % len(residues))

# One-hot encode each residue (built once; the encoding loop used to be nested
# inside a second, identical loop and was redone for every character).
seq_mat = []
for w in residues:
    if w not in word_map:
        if len(word_map) >= 23:
            raise ValueError(
                "cannot encode %r: all 23 alphabet columns are already in use" % w)
        word_map[w] = len(word_map)
    one_hot = [0] * 23
    one_hot[word_map[w]] = 1
    seq_mat.append(one_hot)

# Zero-pad to the fixed length of 1576 rows expected by the network.
seq_mat.extend([0] * 23 for _ in range(1576 - len(seq_mat)))

seq_map = torch.tensor(seq_mat, dtype=torch.float32).view(-1, 1, 1576, 23)

# The labels are unknown at prediction time; the model still expects three
# (batch, 1) label tensors, so -1 is passed as a placeholder for each.
# Call the module itself (not .forward) and skip autograd bookkeeping, since
# only the predictions are needed.
with torch.no_grad():
    transfer_output, mechanism_output, antibiotic_output = model(
        seq_map,
        torch.full((1, 1), -1), torch.full((1, 1), -1), torch.full((1, 1), -1))

# Report the most probable class of each head for the single input sequence.
index = np.argmax(antibiotic_output.detach().numpy(), axis=1)
print(anti_list[index[0]])
mech_index = np.argmax(mechanism_output.detach().numpy(), axis=1)
print(mech_list[mech_index[0]])
transfer_index = np.argmax(transfer_output.detach().numpy(), axis=1)
print(tranfer_list[transfer_index[0]])

beta_lactam
antibiotic inactivation
acquired
