## Data preprocessing

In [1]:
import pandas as pd
from tqdm import tqdm

In [2]:
def read_file(filepath):
    with open(filepath) as fp:
        content=fp.read();
    return content

In [3]:
file = read_file('./arg_v5.fasta').split('>')
# len(file)

In [4]:
# file[0] is null
del file[0]

In [5]:
data = pd.DataFrame()
anti_label = {}
mech_label = {}
type_label = {}
for f in tqdm(file):
    line = f.split('|')
    r6 = line[6].split('\n')
    seq = ''
    for i in range(1, len(r6)):
        seq += r6[i]
    if line[3] not in anti_label:
        anti_label[line[3]] = 0
    if line[5] not in mech_label:
        mech_label[line[5]] = 0
    if r6[0] not in type_label:
        type_label[r6[0]] = 0
    anti_label[line[3]] += 1
    mech_label[line[5]] += 1
    type_label[r6[0]] += 1
    data = data._append({'id':line[0], 'antibiotic':line[3],'arg':line[4],'mechanism':line[5],'type': r6[0], 'seq': seq}, ignore_index=True)
data

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17282/17282 [00:14<00:00, 1182.54it/s]


Unnamed: 0,id,antibiotic,arg,mechanism,type,seq
0,ACN58946.1,macrolide-lincosamide-streptogramin,macB,antibiotic target protection,0,MAEPVLSVKDLDIRFTTPDGNVHAVKKVSFDIAPGECLGVVGESGS...
1,ACN58871.1,macrolide-lincosamide-streptogramin,macB,antibiotic efflux,0,MADYLLEMKNIVKEFGGVRALNGIDIKLKAGECAGLCGENGAGKST...
2,ACN58991.1,multidrug,cmeB,antibiotic efflux,0,MKNDRGEMVPFSAFMTIKKKQGANEINRYNMYNTAAIRGGPATGYS...
3,ACN58776.1,macrolide-lincosamide-streptogramin,macA,antibiotic efflux,0,MGNLPRPTLSPSLSGIRPTMNRETTTRVDSSTPAARLGMRVPSTSR...
4,ACN58740.1,pleuromutilin,TaeA,antibiotic efflux,0,MRQAVMQGVGDAFKKLVRFNEISEKFAEPMSDDEMNALLEEQAKLQ...
...,...,...,...,...,...,...
17277,DQ993182_1,trimethoprim,dfrB7,antibiotic target replacement,0,MDQSSKEVSSPATDQFALPFRATFGLGDRVRKKSGAAWQGQVVGWY...
17278,CAH16951.1,sulfonamide,folp,antibiotic target alteration,0,MGVLNVTSDSFYDGGKYLSVDRACDQALKLIACGADLIDIGGESTK...
17279,WP_000918666.1,rifampin,rpoB,antibiotic target alteration,0,MAGQVVQYGRHRKRRNYARISEVLELPNLIEIQTKSYEWFLREGLI...
17280,CAJ66881.1,rifampin,rpoB,antibiotic target alteration,0,MPHPVTIGKRTRMSFSKIKEIADVPNLIEIQVDSYEWFLKEGLKEV...


In [6]:
maxlen = 0
for index, row in data.iterrows():
    l = len(row['seq'])
    maxlen = max(l, maxlen)
maxlen

1576

In [7]:
word_map = {'M':0, 'A':1,'E':2, 'P':3, 'V':4, 'L':5, 'S':6, 'K':7, 'D':8, 'I':9, 'R':10, 'F':11, 'T':12,'G':13, 
         'N':14, 'H':15, 'C':16, 'Q':17, 'Y':18, 'W':19, 'X':20, 'Z':21}
used_anti_label = {'beta_lactam': 0, 'bacitracin': 1,'multidrug': 2,'macrolide-lincosamide-streptogramin': 3,'aminoglycoside': 4,
                   'polymyxin': 5,'chloramphenicol': 6, 'tetracycline': 7,'fosfomycin': 8,'glycopeptide': 9,'quinolone': 10,
                   'trimethoprim': 11, 'sulfonamide': 12, 'rifampin': 13,'others': 14}

uesd_mech_label = {'antibiotic target protection': 0,  'antibiotic efflux': 1, 'antibiotic inactivation': 2, 
                   'antibiotic target alteration': 3,'antibiotic target replacement': 4, 'others': 5}
def seq2Onehot(row):
    seq = row['seq']
    seq_mat = []
    for w in seq:
        if w not in word_map:
            word_map[w] = len(word_map)
        one_hot = [0] * 23
        one_hot[word_map[w]] = 1
        seq_mat.append(one_hot)
#     zero-padding
    for i in range(len(seq_mat), maxlen):
        one_hot = [0] * 23
        seq_mat.append(one_hot)
    row['seq_map'] = seq_mat
    
    if row['antibiotic'] not in used_anti_label:
        row['anti_label'] = 14
    else:
        row['anti_label'] = used_anti_label[row['antibiotic']]
        
    if row['mechanism'] not in uesd_mech_label:
        row['mech_label'] = 5
        
    else:
        row['mech_label'] = uesd_mech_label[row['mechanism']]
        
    row['type_label'] = int(row['type'])
    return row
print('data processing...')
data = data.apply(seq2Onehot, axis=1)
print('finish')
data

data processing...
finish


Unnamed: 0,id,antibiotic,arg,mechanism,type,seq,seq_map,anti_label,mech_label,type_label
0,ACN58946.1,macrolide-lincosamide-streptogramin,macB,antibiotic target protection,0,MAEPVLSVKDLDIRFTTPDGNVHAVKKVSFDIAPGECLGVVGESGS...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",3,0,0
1,ACN58871.1,macrolide-lincosamide-streptogramin,macB,antibiotic efflux,0,MADYLLEMKNIVKEFGGVRALNGIDIKLKAGECAGLCGENGAGKST...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",3,1,0
2,ACN58991.1,multidrug,cmeB,antibiotic efflux,0,MKNDRGEMVPFSAFMTIKKKQGANEINRYNMYNTAAIRGGPATGYS...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",2,1,0
3,ACN58776.1,macrolide-lincosamide-streptogramin,macA,antibiotic efflux,0,MGNLPRPTLSPSLSGIRPTMNRETTTRVDSSTPAARLGMRVPSTSR...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",3,1,0
4,ACN58740.1,pleuromutilin,TaeA,antibiotic efflux,0,MRQAVMQGVGDAFKKLVRFNEISEKFAEPMSDDEMNALLEEQAKLQ...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",14,1,0
...,...,...,...,...,...,...,...,...,...,...
17277,DQ993182_1,trimethoprim,dfrB7,antibiotic target replacement,0,MDQSSKEVSSPATDQFALPFRATFGLGDRVRKKSGAAWQGQVVGWY...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",11,4,0
17278,CAH16951.1,sulfonamide,folp,antibiotic target alteration,0,MGVLNVTSDSFYDGGKYLSVDRACDQALKLIACGADLIDIGGESTK...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",12,3,0
17279,WP_000918666.1,rifampin,rpoB,antibiotic target alteration,0,MAGQVVQYGRHRKRRNYARISEVLELPNLIEIQTKSYEWFLREGLI...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",13,3,0
17280,CAJ66881.1,rifampin,rpoB,antibiotic target alteration,0,MPHPVTIGKRTRMSFSKIKEIADVPNLIEIQVDSYEWFLKEGLKEV...,"[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",13,3,0


In [8]:
import os
os.mkdir("./data")

In [9]:
data.to_pickle('./data/arg_v5_processed.pickle')

## Dataset Dividing

In [10]:
from sklearn.model_selection import KFold
import pandas as pd

In [11]:
os.mkdir("./data/train_val")
os.mkdir("./data/test")

In [12]:
def save_KFold_data(data, K):
    kf = KFold(n_splits=K)
    cross = 1
    for train_index, val_index in kf.split(data):
        train_data = data.iloc[train_index]
        val_data = data.iloc[val_index]
        train_data.to_pickle('data/train_val/cross_' + str(cross) +'_train.pickle')
        val_data.to_pickle('data/train_val/cross_' + str(cross) +'_val.pickle')
        print('cross_' + str(cross) +' train_val data saved...')
        cross += 1

def save_test_data(test_data):
    test_data.to_pickle('./data/test/test.pickle')
    print('test data saved...')

def load_data():
    data = pd.read_pickle('./data/arg_v5_processed.pickle')
    anti_count, mech_count, type_count = 15, 6, 2
    return data, anti_count, mech_count, type_count

def init_data(data, train_rate):
    train_data = data.sample(int(len(data) * train_rate))
    test_data = data.drop(labels=train_data.index)
    return train_data, test_data

### The dataset dividing takes a longer time, please be patient.

In [13]:
data, anti_count, mech_count, type_count = load_data()
train_data, test_data = init_data(data, 0.8)
save_test_data(test_data)
save_KFold_data(train_data, 5)

test data saved...
cross_1 train_val data saved...
cross_2 train_val data saved...
cross_3 train_val data saved...
cross_4 train_val data saved...
cross_5 train_val data saved...


## Model architecture

In [14]:
import torch.nn as nn
import torch
import random

torch.backends.cudnn.enabled = False

In [15]:
class Causal_ARG(nn.Module):
    def __init__(self, X_dim, G_dim, z1_dim, z2_dim, transfer_count, mechanism_count,
                 antibiotic_count):
        super(Causal_ARG, self).__init__()
        self.feature = nn.Sequential(
            # (batch * 1 * 1576 * 23) -> (batch * 32 * 1537 * 20)
            nn.Conv2d(1, 32, kernel_size=(40, 4), ),
            nn.LeakyReLU(),
            # (batch * 32 * 1537 * 20) -> (batch * 32 * 1533 * 19)
            nn.MaxPool2d(kernel_size=(5, 2), stride=1),
            # (batch * 32 * 1533 * 19) -> (batch * 64 * 1504 * 16)
            nn.Conv2d(32, 64, kernel_size=(30, 4)),
            nn.LeakyReLU(),
            # (batch * 64 * 1504 * 16) -> (batch * 128 * 1475 * 13)
            nn.Conv2d(64, 128, kernel_size=(30, 4)),
            nn.LeakyReLU(),
            # (batch * 128 * 1475 * 13) -> (batch * 128 * 1471 * 12)
            nn.MaxPool2d(kernel_size=(5, 2), stride=1),
            # (batch * 128 * 1471, 12) -> (batch * 256 * 1452 * 10)
            nn.Conv2d(128, 256, kernel_size=(20, 3)),
            nn.LeakyReLU(),
            # (batch * 256 * 1452 * 10) -> (batch * 256 * 1433 * 8)
            nn.Conv2d(256, 256, kernel_size=(20, 3)),
            nn.LeakyReLU(),
            # (batch * 256 * 1433 * 8) -> (batch * 256 * 1430 * 8)
            nn.MaxPool2d(kernel_size=(4, 1), stride=1),
            # (batch * 256 * 1430 * 8) -> (batch * 1 * 1411 * 6)
            nn.Conv2d(256, 1, kernel_size=(20, 3)),
            nn.LeakyReLU(),
            # (batch * 1 * 1411 * 6) -> (batch * 1 * 1410 * 6)
            nn.MaxPool2d(kernel_size=(2, 1), stride=1)
        )
        self.fc = nn.Sequential(
            nn.Linear(8460, 1024),
            nn.LeakyReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, X_dim),
            nn.LeakyReLU()
        )

        self.gauus = Gauussion(X_dim + 3, G_dim)
        self.hidden = Hidden(X_dim, G_dim, z1_dim, z2_dim)
        self.causal = Causal(z1_dim + z2_dim, transfer_count, mechanism_count, antibiotic_count)

    def forward(self, seq_map, transfer_label, mechanism_label, antibiotic_label):
        x = self.feature(seq_map)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        transfer_label.unsqueeze(-1)
        mechanism_label.unsqueeze(-1)
        antibiotic_label.unsqueeze(-1)

        labels = [transfer_label, mechanism_label, antibiotic_label]
        labels = torch.cat(labels, dim=1)
        mean_logvar1, mean_logvar2, mean_logvar3, mean_logvar4, prob, G = self.gauus(torch.cat((x, labels), dim=1))

        z1, z2 = self.hidden(x, G)
        hidden_representation = torch.cat((z1, z2), dim=1)

        transfer_pre, mechanism_pre, antibiotic_pre = self.causal(hidden_representation)

        return transfer_pre, mechanism_pre, antibiotic_pre, mean_logvar1, mean_logvar2, mean_logvar3, mean_logvar4, prob


class Hidden(nn.Module):
    def __init__(self, X_dim, G_dim, z1_dim, z2_dim):
        super(Hidden, self).__init__()
        self.concat_dim = X_dim + G_dim
        self.hidden1 = nn.Sequential(
            nn.Linear(self.concat_dim, self.concat_dim),
            nn.LeakyReLU(),
            nn.Linear(self.concat_dim, z1_dim),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(G_dim, G_dim),
            nn.LeakyReLU(),
            nn.Linear(G_dim, z2_dim),
        )

    def forward(self, X, G):
        input = torch.cat((X, G), dim=1)

        z1 = self.hidden1(input)
        z2 = self.hidden2(G)
        return z1, z2


class Causal(nn.Module):
    def __init__(self, input_dim, transfer_count, mechanism_count, antibiotic_count):
        super(Causal, self).__init__()
        self.transfer_layer = nn.Linear(input_dim, transfer_count)
        self.softmax = nn.Softmax(dim=1)

        self.mechanism_layer = nn.Linear(input_dim + transfer_count, mechanism_count)
        self.antibiotic_layer = nn.Linear(input_dim + transfer_count + mechanism_count, antibiotic_count)

    def forward(self, input):
        transfer_pre = self.softmax(self.transfer_layer(input))
        mechanism_pre = self.softmax(self.mechanism_layer(torch.cat((input, transfer_pre), dim=1)))
        antibiotic_pre = self.softmax(self.antibiotic_layer(torch.cat((input, transfer_pre, mechanism_pre), dim=1)))

        return transfer_pre, mechanism_pre, antibiotic_pre


class Gauussion(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Gauussion, self).__init__()
        self.hidden = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )
        self.mean_logvar1 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.mean_logvar2 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.mean_logvar3 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.mean_logvar4 = nn.Linear(hidden_dim, 2 * hidden_dim)

        self.softmax = nn.Softmax(dim=1)
        self.prob = nn.Linear(hidden_dim, 4)

    def forward(self, x):
        hidden = self.hidden(x)
        mean_logvar1 = self.mean_logvar1(hidden)
        mean_logvar2 = self.mean_logvar2(hidden)
        mean_logvar3 = self.mean_logvar3(hidden)
        mean_logvar4 = self.mean_logvar4(hidden)
        prob = self.softmax(self.prob(hidden))
        values = [0, 1, 2, 3]
        g_list = []
        mid = hidden.size()[1]
        for i, pro in enumerate(prob):

            value = random.choices(values, pro.tolist())[0]

            if (value == 0):
                g = mean_logvar1[i][:mid] + torch.rand_like(mean_logvar1[i][mid:]) * mean_logvar1[i][mid:]
            elif (value == 1):
                g = mean_logvar2[i][:mid] + torch.rand_like(mean_logvar2[i][mid:]) * mean_logvar2[i][mid:]
            elif (value == 2):
                g = mean_logvar3[i][:mid] + torch.rand_like(mean_logvar3[i][mid:]) * mean_logvar3[i][mid:]
            else:
                g = mean_logvar4[i][:mid] + torch.rand_like(mean_logvar4[i][mid:]) * mean_logvar4[i][mid:]

            g_list.append(g)


        final_g = torch.stack(g_list)
        return mean_logvar1, mean_logvar2, mean_logvar3, mean_logvar4, prob, final_g


## Dataloader

In [16]:
import torch.utils.data as tud
import pandas as pd
import torch

In [17]:
class ARGDataSet(tud.Dataset):
    def __init__(self, data):
        super(ARGDataSet, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return torch.FloatTensor(self.data.iloc[item]['seq_map']), self.data.iloc[item]['type_label'], self.data.iloc[item]['mech_label'], self.data.iloc[item]['anti_label']

class ARGDataLoader(object):
    def __init__(self):
        print("loading data...")
        self.antibiotic_count, self.mechanism_count, self.transfer_count = 15, 6, 2

    def load_test_dataSet(self, batch_size):
        print('loading test data...')
        test_data = pd.read_pickle('./data/test/test.pickle')
        test_data = tud.DataLoader(ARGDataSet(test_data), batch_size=batch_size, shuffle=True, num_workers=0)
        return test_data

    def load_n_cross_data(self, k, batch_size):
        print('loading cross_' + str(k) + ' train_val data ...')
        train_data = pd.read_pickle('data/train_val/cross_' + str(k) + '_train.pickle')
        val_data = pd.read_pickle('data/train_val/cross_' + str(k) + '_val.pickle')
        train_data = tud.DataLoader(ARGDataSet(train_data), batch_size=batch_size, shuffle=True, num_workers=0)
        val_data = tud.DataLoader(ARGDataSet(val_data), batch_size=batch_size, shuffle=True, num_workers=0)
        return train_data, val_data

    def get_data_shape(self):
        return self.transfer_count, self.mechanism_count, self.antibiotic_count

## Evaluation Functions

In [18]:
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
import numpy as np

def arr2hot(arr, N):
    res = [0] * N
    for e in arr:
        res[e - 1] = 1
    return res

def evaluate(pred, trues, classes):
    count = pred.shape[0]
    preds = []
    lebels = [i for i in range(classes)]
    for i in range(count):
        preds.append(np.argmax(pred[i]))
    acc = accuracy_score(trues, preds)
    # micro-precision
    micro_p = precision_score(trues, preds, labels=lebels, average='micro')
    # micro-recall
    micro_r = recall_score(trues, preds, labels=lebels, average='micro')
    # micro f1-score
    micro_f1 = f1_score(trues, preds, labels=lebels, average='micro')

    # macro-precision
    macro_p = precision_score(trues, preds, average='macro')
    # macro-recall
    macro_r = recall_score(trues, preds, average='macro')
    # macro f1-score
    macro_f1 = f1_score(trues, preds, average='macro')

    return acc, macro_p, macro_r, macro_f1

## Train

In [19]:
import torch
import numpy as np
import pandas as pd
import torch.optim as optim
import torch.nn as nn
from argparse import ArgumentParser
torch.backends.cudnn.enabled = False
import warnings
import random
from random import randint
warnings.filterwarnings("ignore")

In [24]:
device = torch.device("cuda:0")

train_rate = 0.8
batch_size = 16
lr = 0.0001
epoch = 20
K = 1
X_dim, G_dim = 64, 64
z1_dim, z2_dim = 64, 64
# print(str(args))
dataloader = ARGDataLoader()

transfer_count, mechanism_count, antibiotic_count = dataloader.get_data_shape()
os.mkdir("./res")
alpha, beta, yita, tao= 1, 0.2, 0.2, 0.2

loading data...


In [25]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_num = randint(1,1000)
# seed_num = 456
setup_seed(493)
print(seed_num)

493


In [26]:
def gauss_loss(mean_logvar1, mean_logvar2, mean_logvar3, mean_logvar4 ,
               batch_mean1, batch_mean2, batch_mean3, batch_mean4):
    loss1 = torch.sum((mean_logvar1 - batch_mean1) ** 2, dim=0)
    loss2 = torch.sum((mean_logvar2 - batch_mean2) ** 2, dim=0)
    loss3 = torch.sum((mean_logvar3 - batch_mean3) ** 2, dim=0)
    loss4 = torch.sum((mean_logvar4 - batch_mean4) ** 2, dim=0)

    loss = torch.sum(loss1) + torch.sum(loss2) + torch.sum(loss3) + torch.sum(loss4)
    return loss

# Add constraint for mean differences
def add_mean_constraint(batch_mean_list, lambda_value):
    mean_diff_loss = 0.0
    num_distributions = len(batch_mean_list)

    # Calculate mean differences
    for i in range(num_distributions - 1):
        for j in range(i + 1, num_distributions):
            mean_diff_loss -= lambda_value * torch.sum((10 * (batch_mean_list[i] - batch_mean_list[j])) ** 2)

    return mean_diff_loss

In [None]:
t_transfer_acc, t_transfer_precision, t_transfer_recall, t_transfer_f1 = 0, 0, 0, 0
t_antibiotic_acc, t_antibiotic_precision, t_antibiotic_recall, t_antibiotic_f1 = 0, 0, 0, 0
t_mechanism_acc, t_mechanism_precision, t_mechanism_recall, t_mechanism_f1 = 0, 0, 0, 0

test_dataloader = dataloader.load_test_dataSet(batch_size)

for k in range(K):
    print('Cross ', k + 1, ' of ', K)

    model = Causal_ARG(X_dim, G_dim, z1_dim, z2_dim, transfer_count, mechanism_count, antibiotic_count)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    transfer_loss_function = nn.NLLLoss()
    antibiotic_loss_function = nn.NLLLoss()
    mechanism_loss_function = nn.NLLLoss()

    train_dataloader, val_dataloader = dataloader.load_n_cross_data(k + 1, batch_size)

    running_loss = 0.0
    for e in range(epoch):
        mean1_list = []
        mean2_list = []
        mean3_list = []
        mean4_list = []
        prob_list = []
        df = pd.DataFrame()
        model.train()
        print('train batch: ', len(train_dataloader))
        for index, (seq_map, transfer_label, mechanism_label, antibiotic_label) in enumerate(train_dataloader):


            seq_map, transfer_label, mechanism_label, antibiotic_label = seq_map.view(-1, 1, 1576, 23).to(device), \
                transfer_label.to(device), mechanism_label.to(device), antibiotic_label.to(device)


            optimizer.zero_grad()
            transfer_output, mechanism_output, antibiotic_output, mean_logvar1, mean_logvar2, mean_logvar3,mean_logvar4, prob = model.forward(seq_map, transfer_label.view(-1, 1), mechanism_label.view(-1, 1), antibiotic_label.view(-1, 1))

            loss_transfer = transfer_loss_function(torch.log(transfer_output + 0.000001), transfer_label)
            loss_mechanism = mechanism_loss_function(torch.log(mechanism_output + 0.000001), mechanism_label)
            loss_antibiotic = antibiotic_loss_function(torch.log(antibiotic_output + 0.000001), antibiotic_label)

            batch_mean1 = (torch.sum(prob[:, 0, None] * mean_logvar1, dim=0)) / torch.sum(prob[:, 0], dim=0)
            mean1_list.append(batch_mean1)

            # batch_mean2 = torch.mean(outputs[1], dim=0)
            batch_mean2 = (torch.sum(prob[:, 1, None] * mean_logvar2, dim=0)) / torch.sum(prob[:, 1], dim=0)
            mean2_list.append(batch_mean2)

            # batch_mean3 = torch.mean(outputs[2], dim=0)
            batch_mean3 = (torch.sum(prob[:, 2, None] * mean_logvar3, dim=0)) / torch.sum(prob[:, 2], dim=0)
            mean3_list.append(batch_mean3)

            # batch_mean4 = torch.mean(outputs[2], dim=0)
            batch_mean4 = (torch.sum(prob[:, 3, None] * mean_logvar4, dim=0)) / torch.sum(prob[:, 3], dim=0)
            mean4_list.append(batch_mean4)

            batch_prob = torch.mean(prob, dim=0)
            prob_list.append(batch_prob)

            batch_mean_list = [batch_mean1[:len(batch_mean1)], batch_mean2[:len(batch_mean2)],
                               batch_mean3[:len(batch_mean3)], batch_mean4[:len(batch_mean4)]]

            loss_gauss = gauss_loss(mean_logvar1, mean_logvar2, mean_logvar3, mean_logvar4,
                                    batch_mean1, batch_mean2, batch_mean3, batch_mean4)

            loss = alpha * loss_antibiotic + beta * loss_mechanism + yita * loss_transfer + tao * loss_gauss
            +add_mean_constraint(batch_mean_list, 0.2)
            

            running_loss += loss.item()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            df = df._append({'loss_transfer': loss_transfer.item(), 'loss_antibiotic': loss_antibiotic.item(), 'loss_mechanism': loss_mechanism.item(), 'loss': loss.item(), 'running_loss': running_loss}, ignore_index=True)
            if index % 50 == 49:
                print('[%d, %2d, %5d] loss: %.3f' % (k + 1, e + 1, index + 1, running_loss / 50))
                running_loss = 0.0

        df.to_csv('./res/loss_cross' + str(k + 1) + '_epoch' + str(e) + '.csv')
        model.eval()
        val_transfer_pred, val_transfer_label = np.empty(shape=[0, transfer_count]), np.array([])
        val_mechanism_pred, val_mechanism_label = np.empty(shape=[0, mechanism_count]), np.array([])
        val_antibiotic_pred, val_antibiotic_label = np.empty(shape=[0, antibiotic_count]), np.array([])

        for index, (seq_map, transfer_label, mechanism_label, antibiotic_label) in enumerate(val_dataloader):
            seq_map, transfer_label, mechanism_label, antibiotic_label = seq_map.view(-1, 1, 1576, 23).to(device), transfer_label.to(device), mechanism_label.to(device), antibiotic_label.to(device)
            transfer_output, mechanism_output, antibiotic_output, mean_logvar1, mean_logvar2, \
                mean_logvar3, mean_logvar4, prob = model.forward(seq_map,
                torch.full((transfer_label.view(-1, 1).shape), -1).to(device),
                torch.full((transfer_label.view(-1, 1).shape), -1).to(device),
                torch.full((transfer_label.view(-1, 1).shape), -1).to(device))

            transfer_output, transfer_label = transfer_output.cpu().detach().numpy(), transfer_label.cpu().numpy()
            val_transfer_pred = np.append(val_transfer_pred, transfer_output, axis=0)
            val_transfer_label = np.concatenate((val_transfer_label, transfer_label))

            antibiotic_output, antibiotic_label = antibiotic_output.cpu().detach().numpy(), antibiotic_label.cpu().numpy()
            val_antibiotic_pred = np.append(val_antibiotic_pred, antibiotic_output, axis=0)
            val_antibiotic_label = np.concatenate((val_antibiotic_label, antibiotic_label))

            mechanism_output, mechanism_label = mechanism_output.cpu().detach().numpy(), mechanism_label.cpu().numpy()
            val_mechanism_pred = np.append(val_mechanism_pred, mechanism_output, axis=0)
            val_mechanism_label = np.concatenate((val_mechanism_label, mechanism_label))

        print('-------------Val: epoch ' + str(e + 1) + '-----------------')
        acc, macro_p, macro_r, macro_f1 = evaluate(val_transfer_pred, val_transfer_label, transfer_count)
        print('transfer -> acc: {}, precision: {}, recall: {}, f1: {}'.format(acc, macro_p, macro_r, macro_f1))
        acc, macro_p, macro_r, macro_f1 = evaluate(val_mechanism_pred, val_mechanism_label, mechanism_count)
        print('mechanism -> acc: {}, precision: {}, recall: {}, f1: {}'.format(acc, macro_p, macro_r, macro_f1))
        acc, macro_p, macro_r, macro_f1 = evaluate(val_antibiotic_pred, val_antibiotic_label, antibiotic_count)
        print('antibiotic -> acc: {}, precision: {}, recall: {}, f1: {}'.format(acc, macro_p, macro_r, macro_f1))

    model.eval()
    test_transfer_pred, test_transfer_label = np.empty(shape=[0, transfer_count]), np.array([])
    test_antibiotic_pred, test_antibiotic_label = np.empty(shape=[0, antibiotic_count]), np.array([])
    test_mechanism_pred, test_mechanism_label = np.empty(shape=[0, mechanism_count]), np.array([])
    for index, (seq_map, transfer_label, mechanism_label, antibiotic_label) in enumerate(test_dataloader):
        seq_map, transfer_label, mechanism_label, antibiotic_label = seq_map.view(-1, 1, 1576, 23).to(device), transfer_label.to(device), mechanism_label.to(device), antibiotic_label.to(device)

        transfer_output, mechanism_output, antibiotic_output, mean_logvar1, mean_logvar2, \
            mean_logvar3, mean_logvar4, prob = model.forward(seq_map,
            torch.full((transfer_label.view(-1, 1).shape), -1).to(device),
            torch.full((transfer_label.view(-1, 1).shape), -1).to(device),
            torch.full((transfer_label.view(-1, 1).shape), -1).to(device))

        transfer_output, transfer_label = transfer_output.cpu().detach().numpy(), transfer_label.cpu().numpy()
        test_transfer_pred = np.append(test_transfer_pred, transfer_output, axis=0)
        test_transfer_label = np.concatenate((test_transfer_label, transfer_label))

        antibiotic_output, antibiotic_label = antibiotic_output.cpu().detach().numpy(), antibiotic_label.cpu().numpy()
        test_antibiotic_pred = np.append(test_antibiotic_pred, antibiotic_output, axis=0)
        test_antibiotic_label = np.concatenate((test_antibiotic_label, antibiotic_label))

        mechanism_output, mechanism_label = mechanism_output.cpu().detach().numpy(), mechanism_label.cpu().numpy()
        test_mechanism_pred = np.append(test_mechanism_pred, mechanism_output, axis=0)
        test_mechanism_label = np.concatenate((test_mechanism_label, mechanism_label))

    print('========Test: Cross ' + str(k + 1) + '===============')
    acc, macro_p, macro_r, macro_f1 = evaluate(test_transfer_pred, test_transfer_label, transfer_count)
    print('transfer -> acc: {}, precision: {}, recall: {}, f1: {}'.format(acc, macro_p, macro_r, macro_f1))
    t_transfer_acc += acc
    t_transfer_precision += macro_p
    t_transfer_recall += macro_r
    t_transfer_f1 += macro_f1

    acc, macro_p, macro_r, macro_f1 = evaluate(test_mechanism_pred, test_mechanism_label, mechanism_count)
    print('mechanism -> acc: {}, precision: {}, recall: {}, f1: {}'.format(acc, macro_p, macro_r, macro_f1))
    t_mechanism_acc += acc
    t_mechanism_precision += macro_p
    t_mechanism_recall += macro_r
    t_mechanism_f1 += macro_f1

    acc, macro_p, macro_r, macro_f1 = evaluate(test_antibiotic_pred, test_antibiotic_label, antibiotic_count)
    print('antibiotic -> acc: {}, precision: {}, recall: {}, f1: {}'.format(acc, macro_p, macro_r, macro_f1))
    t_antibiotic_acc += acc
    t_antibiotic_precision += macro_p
    t_antibiotic_recall += macro_r
    t_antibiotic_f1 += macro_f1



    torch.save(model.state_dict(), './res/model{}.pth'.format(k))
print('transfer => final acc: {}, final precision: {}, final recall: {}, final f1: {}\n'.format(t_transfer_acc / K, t_transfer_precision / K,
                                                                                             t_transfer_recall / K,
                                                                                             t_transfer_f1 / K))
print('mechanism => final acc: {}, final precision: {}, final recall: {}, final f1: {}\n'.format(t_mechanism_acc / K, t_mechanism_precision / K,
                                                                                             t_mechanism_recall / K,
                                                                                             t_mechanism_f1 / K))
print('antibiotic => final acc: {}, final precision: {}, final recall: {}, final f1: {}\n'.format(t_antibiotic_acc / K, t_antibiotic_precision / K,
                                                                                             t_antibiotic_recall / K,
                                                                                             t_antibiotic_f1 / K))

loading test data...
Cross  1  of  1
loading cross_1 train_val data ...
train batch:  692
[1,  1,    50] loss: 7.443
[1,  1,   100] loss: 5.881
[1,  1,   150] loss: 5.271
[1,  1,   200] loss: 5.077
[1,  1,   250] loss: 4.866
[1,  1,   300] loss: 4.880
[1,  1,   350] loss: 4.853
[1,  1,   400] loss: 4.856
[1,  1,   450] loss: 4.681
[1,  1,   500] loss: 4.654
[1,  1,   550] loss: 4.473
[1,  1,   600] loss: 4.401
[1,  1,   650] loss: 4.420
-------------Val: epoch 1-----------------
transfer -> acc: 0.49005424954792043, precision: 0.4513135419350529, recall: 0.49466073368342395, f1: 0.34789306560259486
mechanism -> acc: 0.6264014466546112, precision: 0.2942309708441038, recall: 0.2950448513898634, f1: 0.2938222824327348
antibiotic -> acc: 0.4835443037974684, precision: 0.08807422518435136, recall: 0.14264141841485362, f1: 0.10747968675498125
train batch:  692
[1,  2,    50] loss: 7.377
[1,  2,   100] loss: 3.676
[1,  2,   150] loss: 3.646
[1,  2,   200] loss: 3.592
[1,  2,   250] loss: 3.5