In [None]:
import time
import pandas as pd
import random
import argparse
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import  DataLoader
from torch.utils.data import TensorDataset
from torch.autograd import Variable

from sklearn.metrics import confusion_matrix,accuracy_score,roc_auc_score,f1_score, matthews_corrcoef
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import precision_recall_curve, auc

# Preferred GPU for this notebook. It is only selected when that device really
# exists, so the notebook still starts on single-GPU and CPU-only machines
# (an unconditional set_device(1) raises there).
GPU_INDEX = 1
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)

import model

In [None]:

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    The encoding table is precomputed once and stored as the non-trainable
    buffer ``pe`` of shape ``(1, max_len, d_model)``.

    Args:
        d_model: size of the last (feature) dimension of the input. Both even
            and odd values are supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions in the precomputed table; inputs must not
            be longer than this along dimension 1.
    """

    def __init__(self, d_model, dropout, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])`` for ``x`` of shape (batch, seq, d_model)."""
        # Buffers never require grad, so no Variable wrapper is needed.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class Discriminator(nn.Module):
    """Sequence classifier: embedding -> positional encoding -> Transformer encoder -> 1-D CNN -> linear.

    ``forward`` returns log-probabilities (``LogSoftmax`` over dim 1).

    Args:
        num_classes: number of output classes.
        vocab_size: number of distinct token ids accepted by the embedding.
        embedding_dim: embedding width (also ``d_model`` of the positional encoding).
        nhead: attention heads of each Transformer encoder layer.
        dropout: dropout probability; it is only forwarded to the positional
            encoding (the Transformer layer is built without it).
        seq_len: sequence length; the final linear layer has ``256 * seq_len``
            inputs, so inputs must have exactly this length.
    """
    def __init__(self, num_classes = 2, vocab_size = 4, embedding_dim = 4, nhead =2, dropout  = 0.2, seq_len = 200 ):
        super(Discriminator, self).__init__()
        # Width used by the linear projection and the Transformer encoder.
        self.hidden_dim = 64
        self.embed = nn.Embedding(vocab_size , embedding_dim)
        ######## positional encoding ########
        self.pe = PositionalEncoding(d_model = embedding_dim, dropout = dropout, max_len = seq_len)
        # Projects the embedding width up to hidden_dim.
        self.ln = nn.Linear(embedding_dim, self.hidden_dim )
        # NOTE(review): the layer is stored as an attribute AND handed to the
        # encoder stack, so its parameters are registered under
        # `encoder_layer.*` as well — presumably the stack keeps its own
        # copies; confirm before removing (state_dict keys would change).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.hidden_dim, nhead = nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        ######## CNN layers (kernel 5, stride 1, padding 2: length is preserved) ########
        self.conv1 = nn.Sequential(
                    nn.Conv1d(self.hidden_dim, 128, 5, 1, 2),  
                    nn.ReLU(),
                )
        self.conv2 = nn.Sequential(
                    nn.Conv1d(128, 256, 5, 1, 2), 
                    nn.ReLU(),
                )
        self.conv3 = nn.Sequential(
                    nn.Conv1d(256, 256, 5, 1, 2),  
                    nn.ReLU(),
                )
        ######## layer norm ########
        # NOTE(review): normlayer1 is never used in forward().
        self.normlayer1 = nn.LayerNorm(self.hidden_dim)
        self.normlayer2 = nn.LayerNorm(256)
        ######## fc ########
        self.fc = nn.Linear(256 * seq_len , num_classes)
        self.activation = nn.LogSoftmax(dim=1)
    def forward(self,x):
        """Classify a batch of token-id sequences.

        Args:
            x: integer token ids fed to the embedding; presumably shape
               (batch_size, seq_len) — TODO confirm against the caller.

        Returns:
            Log-probabilities of shape (batch_size, num_classes).
        """
        ######## Embedding ########
        x = self.embed(x)  # (batch_size, seq_len, embedding_dim)
        ######## Positional encoding with a residual from the raw embedding ########
        se = x        
        x = self.pe(x)  # (batch_size, seq_len, embedding_dim)
        x = x + se
        x = self.ln(x)  # (batch_size, seq_len, hidden_dim)
        se = x
        x = self.transformer_encoder(x)        
        x = x + se  # residual around the whole encoder stack
        ######## 1-D convolutions (no LSTM is used here) ########
        x = x.permute(0,2,1)  # (batch_size, hidden_dim, seq_len): channels first for Conv1d
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.permute(0,2,1)  # back to (batch_size, seq_len, 256)
        ####### layer norm ########
        # NOTE(review): normlayer2 is applied twice in a row. This looks
        # unintended, but checkpoints trained with this forward pass depend on
        # it — confirm against the training code before changing.
        x = self.normlayer2(x)
        x = self.normlayer2(x)
        ######## fc ########
        x =  x.contiguous().view((x.size()[0], -1)) # flatten to (batch_size, seq_len * 256)
        x = self.fc(x)
        x = self.activation(x)
        return x     

Set up the command-line parameters and the label vocabulary.

In [None]:
# Arguments. The help texts use %(default)s so the documented default can
# never drift away from the real one (they previously advertised a different
# data path and a vocabulary size of 10).
parser = argparse.ArgumentParser(description='SeqGAN')
parser.add_argument('--hpc', action='store_true', default=False,
                    help='set to hpc mode')
parser.add_argument('--data_path', type=str, default='.seq_gan/', metavar='PATH',
                    help='data path to save files (default: %(default)s)')
parser.add_argument('--vocab_size', type=int, default=4, metavar='N',
                    help='vocabulary size (default: %(default)s)')
parser.add_argument('--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')


In [None]:
# Parse with an empty argument list so that only the declared defaults apply.
args = parser.parse_args(args=[])

# Seed the CPU generator; the CUDA generator is seeded too when a GPU is used.
torch.manual_seed(args.seed)
args.cuda = torch.cuda.is_available() and not args.no_cuda
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Outside HPC mode the data path is blanked.
args.data_path = args.data_path if args.hpc else ''

In [None]:


# Label vocabulary. The position of each entry in this list is the integer id
# that label2int assigns to it, so the order must not be changed.
vocab =['Bac', 'Ent', 'Esc','Lis','Pse', 'Sal', 'Sta','Sac','chr1','chr2','chr3','chr4','chr5','chr6','chr7','chr8','chr9','chr10','chr11','chr12','chr13','chr14','chr15','chr16','chr17','chr18','chr19','chr20','chr21','chr22','chrX','chrY','human','WH','Alpha','Beta','Delta','Gamma','Omicron','yeast',">1",	">2",	">4",	">3",	">6",	">5",	">X",	">7",	">8",	">12",	">10",	">11",	">9",	">16",	">13",	">17",	">14",	">15",	">18",	">19",	">20",	">21",	">22",	">KI270438.1",	">KI270733.1",	">Y",	">GL000220.1",	">GL000225.1",	">MT",	">GL000224.1",	">KI270709.1",	">KI270728.1",	">KI270729.1",	">KI270736.1",	">GL000216.2",	">GL000218.1",	">KI270442.1",	">KI270467.1",	">GL000195.1",	">GL000205.2",	">GL000221.1",	">GL000226.1",	">KI270333.1",	">KI270538.1",	">KI270589.1",	">KI270715.1",	">KI270720.1",	">KI270722.1",	">KI270723.1",	">KI270735.1",	">KI270749.1",	">KI270750.1",	">KI270754.1",	">Staphylococcus_aureus_chromosome",	">Pseudomonas_aeruginosa_complete_genome",	">Listeria_monocytogenes_complete_genome",	">Enterococcus_faecalis_complete_genome",	">BS.pilon.polished.v3.ST170922",	">Escherichia_coli_chromosome",	">Salmonella_enterica_complete_genome",	">Escherichia_coli_plasmid",	">tig00000001",	">tig00000003",	">tig00000018",	">Staphylococcus_aureus_plasmid1",	">tig00000023",	">tig00000308",	">tig00000036",	">tig00000136",	">tig00000031",	">tig00000051",	">tig00000071",	">tig00000072",	">tig00000104",	">tig00000109",	">tig00000011",	">tig00000055",	">tig00000306",	">tig00000063",	">tig00000069",	">tig00000139",	">tig00000140",	">tig00000006",	">tig00000042",	">tig00000080",	">tig00000094",	">tig00000105",	">tig00000307",'fly','omicron','phage','sars','ebola','mouse','tick','HAdV']

class label2int:
    """Two-way lookup between label strings and their integer ids.

    The id of a label is its position in ``vocab``.
    """

    def __init__(self, vocab =vocab):
        # label -> id and id -> label tables built from the same enumeration
        self.int_map = {label: idx for idx, label in enumerate(vocab)}
        self.base_map = dict(enumerate(vocab))

    def text_to_int(self, text):
        """Look up one label and return its id wrapped in a single-element list."""
        return [self.int_map[text]]

    def int_to_text(self, labels):
        """Map each id in ``labels`` to its label and concatenate them without a separator."""
        return ''.join(self.base_map[idx] for idx in labels)


# Shared converter built from the module-level vocabulary.
convert = label2int()

def get_words_from_indices(indices, vocab):
    """Return the entries of ``vocab`` found at each position in ``indices``, in order."""
    words = []
    for position in indices:
        words.append(vocab[position])
    return words

def compare_tensors(t1, t2, t3):
    """Select the entries of ``t3`` at the positions where ``t1`` and ``t2`` differ."""
    return t3[t1 != t2]

def count_words(words):
    """Tally how many times each item occurs; keys keep first-seen order."""
    tally = {}
    for item in words:
        tally[item] = tally.get(item, 0) + 1
    return tally

def words_fun(err_labels):
    """Count how often each vocabulary label occurs among the collected error labels.

    ``err_labels`` is flattened by two levels; the resulting items are used as
    indices into the module-level ``vocab``.
    """
    flat_indices = [idx for batch in err_labels for row in batch for idx in row]
    return count_words(get_words_from_indices(flat_indices, vocab))


def add_first_elements(data, data_list):
    """Append element 0 of every row in ``data`` to ``data_list`` (modified in place)."""
    heads = []
    for row in data:
        heads.append(row[0])
    data_list.extend(heads)
def add_chrom_elements(data, target_list):
    """Append element 1 of every row in ``data`` to ``target_list`` (modified in place)."""
    seconds = []
    for row in data:
        seconds.append(row[1])
    target_list.extend(seconds)


def read_test_file(data_file):
    """Load a ``.npy`` test set and split it into token sequences and labels.

    Element 0 of every record is a whitespace-separated string of integers;
    element 1 is collected unchanged as the record's label.

    Returns:
        A pair ``(sequences, labels)``: a list of integer lists and the list
        of second elements.
    """
    # NOTE(security): allow_pickle=True unpickles arbitrary objects, so only
    # load files from a trusted source.
    records = np.load(data_file, allow_pickle=True).tolist()

    raw_sequences = []
    target_list = []
    add_first_elements(records, raw_sequences)
    add_chrom_elements(records, target_list)

    lis = [[int(token) for token in line.strip().split()] for line in raw_sequences]
    return lis, target_list

In [None]:
# Evaluation set and the count of leading reads that the evaluation cell
# labels as class 0; every read after them is labelled class 1.
test_file = '/public/data1/zhangyx/Projects/Project_seqGAN/data/test_data/human/human_200000_Zymo_200000.npy'
num = 200_000  # the number of host

In [None]:
time_start = time.time()

# ---- Load the reads and build the label tensors ----
real_data_lis, target_list = read_test_file(test_file)
datas = real_data_lis
data_label = [convert.text_to_int(i) for i in target_list]
# The first `num` reads are labelled 0 and every remaining read is labelled 1.
targets = [0] * num + [1] * (len(datas) - num)
tensor_dat = TensorDataset(torch.tensor(datas), torch.tensor(targets).long(), torch.tensor(data_label))
test_data_loader = DataLoader(dataset=tensor_dat, batch_size=500, shuffle=True)

print('Finish load data..')

# ---- Build the model and restore the checkpoint ----
# `seqlen` is otherwise only assigned in a later cell, which made this cell
# fail with a NameError on "Restart & Run All"; define it before first use.
seqlen = 200
# Honour the --no_cuda flag / missing GPU instead of calling .cuda() blindly.
device = torch.device('cuda' if args.cuda else 'cpu')

nll_loss = nn.NLLLoss()
discriminator = Discriminator(num_classes=2, vocab_size=4, embedding_dim=4, nhead=2, dropout=0.2, seq_len=seqlen)

time1 = time.time()
dis_path = '../../model/human_model.pt'
dis_dict = torch.load(dis_path, map_location='cpu')  # load the saved parameters first
# Then restore them into the model. strict=False tolerates key mismatches, so
# report them instead of silently evaluating partly-initialised weights.
load_report = discriminator.load_state_dict(dis_dict, False)
if load_report.missing_keys or load_report.unexpected_keys:
    print('WARNING: checkpoint keys do not fully match the model:', load_report)

if args.cuda:
    discriminator = discriminator.cuda()
    nll_loss = nll_loss.cuda()
    cudnn.benchmark = True

# torch.no_grad() does not disable dropout; eval() is required for
# deterministic inference.
discriminator.eval()

# ---- Accumulators (unused counters are kept in case later cells read them) ----
correct = 0
error = 0
total_loss = 0.
flag = 0
neg_flag = 0
pos_flag = 0
y_true = []
y_pred = []
y_score = []
err_labels = []
err_datas = []
pred_lst = []

# A read is predicted as class 0 when the model's probability for its output
# index 1 exceeds this threshold, and as class 1 otherwise.
threshold = 0.1

with torch.no_grad():
    for data, target, batch_label in test_data_loader:
        data, target, batch_label = data.to(device), target.to(device), batch_label.to(device)

        target = target.contiguous().view(-1)
        output = discriminator(data)  # log-probabilities, one row per read

        result = torch.exp(output)
        score = output[:, 0]
        # Vectorised form of the per-element threshold test.
        pred = (~(result[:, 1] > threshold)).long()

        # Labels and raw reads of the misclassified samples in this batch.
        error_label = compare_tensors(pred, target, batch_label).tolist()
        error_data = compare_tensors(pred, target, data).tolist()
        correct += pred.eq(target).cpu().sum()
        flag += 1

        y_true += target.cpu().tolist()
        y_pred += pred.cpu().tolist()
        y_score += score.cpu().tolist()
        err_labels.append(error_label)
        err_datas.append(error_data)

# ---- Metrics ----
# labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
cm_dict = {'tn': cm[0, 0], 'fp': cm[0, 1], 'fn': cm[1, 0], 'tp': cm[1, 1]}
tn, fp, fn, tp = cm_dict['tn'], cm_dict['fp'], cm_dict['fn'], cm_dict['tp']
acc = accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_score)
f1 = f1_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
mcc = matthews_corrcoef(y_true, y_pred)
specificity = tn / (tn + fp)
NPV = tn / (tn + fn)

p, r, t = precision_recall_curve(y_true, y_score)
pr_auc = auc(r, p)
# Enrichment of class 1: precision divided by the class-1 prevalence.
# For class-0 (host) enrichment use (tn / (tn + fn)) / ((tn + fp) / total).
total = tp + tn + fp + fn
enrich_ratio = (tp / (tp + fp)) / ((tp + fn) / total)

print("dis eval acc:{:.4f},dis auc: {:.4f}: ,dis pr-auc: {:.4f}: ,  f1 {:.4f}, mcc {:.4f}, recall {:.4f}, precision {:.4f}, specificity: {:.4f} ,NPV,{:.4f} enrich_ratio,{:.2f}\n".format(acc,roc_auc , pr_auc, f1, mcc, recall, precision, specificity, NPV, enrich_ratio))
pred_lst.append([ 'human & Zymo', cm_dict['tp'], cm_dict['tn'], cm_dict['fp'], cm_dict['fn'], acc, roc_auc,pr_auc, f1, mcc, recall, precision, specificity, NPV, enrich_ratio])

# ---- Timing ----
elapsed_model = time.time() - time1
print('time', elapsed_model)
print('time per read', elapsed_model / len(targets))

# The second per-read figure previously reused time1; it must be measured
# from time_start to include data loading.
elapsed_total = time.time() - time_start
print('time_with_data_process', elapsed_total)
print('time_with_data_process per read', elapsed_total / len(targets))



In [None]:
# Sequence length passed to the Discriminator as `seq_len`.
seqlen = 200

In [None]:
# One row per evaluated data set. The first field of every row is the data-set
# name (e.g. 'human & Zymo'), not an epoch, so the column is labelled
# 'host & target' like the final summary table of this notebook.
# Remaining columns follow the order used when rows are appended to pred_lst:
# tp, tn, fp, fn, then the metrics.
pred_df = pd.DataFrame( pred_lst, columns = [ 'host & target','t-target', 't-host', 'f-target', 'f-nontarget', 'acc', 'roc_auc', 'pr-auc', 'f1', 'mcc', 'recall', 'precision', 'specificity', 'NPV','enrich_ratio'] )
pred_df


In [None]:
# ---- Data set for this run ----
npy_file = '/public/data1/zhangyx/Projects/Project_seqGAN/data/test_data/human/NA12878_200_ebola_200.npy'
num = 200  # number of leading reads labelled as class 0

time_start = time.time()

real_data_lis, target_list = read_test_file(npy_file)
datas = real_data_lis
print(len(datas))

data_label = [convert.text_to_int(i) for i in target_list]

# The first `num` reads are labelled 0 and every remaining read is labelled 1.
targets = [0] * num + [1] * (len(datas) - num)

tensor_dat = TensorDataset(torch.tensor(datas), torch.tensor(targets).long(), torch.tensor(data_label))
test_data_loader = DataLoader(dataset=tensor_dat, batch_size=500, shuffle=True)

print('Finish load data..')
print('len(test_data_loader)',len(test_data_loader))

# ---- Re-parse the defaults and re-seed so this cell can be re-run on its own ----
args = parser.parse_args(args=[])
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
if not args.hpc:
    args.data_path = ''
nll_loss = nn.NLLLoss()
# Honour the --no_cuda flag / missing GPU instead of calling .cuda() blindly.
device = torch.device('cuda' if args.cuda else 'cpu')

discriminator = Discriminator(num_classes=2, vocab_size=4, embedding_dim=4, nhead=2, dropout=0.2, seq_len=seqlen)

# Sweep over decision thresholds (this range yields the single value 0.1).
for t_i in np.arange(0.1,0.2,0.1):
    threshold = float(t_i)

    time1 = time.time()
    dis_path = '../../save/20230924_human_step200_200bp_es10/model_dis10.pt'

    dis_dict = torch.load(dis_path, map_location='cpu')  # load the saved parameters first
    # Then restore them into the model. strict=False tolerates key mismatches,
    # so report them instead of silently evaluating partly-initialised weights.
    load_report = discriminator.load_state_dict(dis_dict, False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print('WARNING: checkpoint keys do not fully match the model:', load_report)

    if args.cuda:
        discriminator = discriminator.cuda()
        nll_loss = nll_loss.cuda()
        cudnn.benchmark = True

    # torch.no_grad() does not disable dropout; eval() is required for
    # deterministic inference.
    discriminator.eval()

    # Accumulators (unused counters are kept in case later cells read them).
    correct = 0
    error = 0
    total_loss = 0.
    flag = 0
    neg_flag = 0
    pos_flag = 0
    y_true = []
    y_pred = []
    y_score = []
    err_labels = []
    err_datas = []

    with torch.no_grad():
        for data, target, batch_label in test_data_loader:
            data, target, batch_label = data.to(device), target.to(device), batch_label.to(device)

            target = target.contiguous().view(-1)
            output = discriminator(data)  # log-probabilities, one row per read

            result = torch.exp(output)
            score = output[:, 0]
            # Class 0 when the probability of model output index 1 exceeds the
            # threshold, class 1 otherwise (vectorised per-element test).
            pred = (~(result[:, 1] > threshold)).long()

            # Labels and raw reads of the misclassified samples in this batch.
            error_label = compare_tensors(pred, target, batch_label).tolist()
            error_data = compare_tensors(pred, target, data).tolist()

            correct += pred.eq(target).cpu().sum()
            flag += 1

            y_true += target.cpu().tolist()
            y_pred += pred.cpu().tolist()
            y_score += score.cpu().tolist()
            err_labels.append(error_label)
            err_datas.append(error_data)

    # ---- Metrics ----
    # labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    cm_dict = {'tn': cm[0, 0], 'fp': cm[0, 1], 'fn': cm[1, 0], 'tp': cm[1, 1]}
    tn, fp, fn, tp = cm_dict['tn'], cm_dict['fp'], cm_dict['fn'], cm_dict['tp']
    acc = accuracy_score(y_true, y_pred)
    roc_auc = roc_auc_score(y_true, y_score)
    f1 = f1_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp)
    NPV = tn / (tn + fn)

    p, r, t = precision_recall_curve(y_true, y_score)
    pr_auc = auc(r, p)
    # Enrichment of class 1: precision divided by the class-1 prevalence.
    total = tp + tn + fp + fn
    enrich_ratio = (tp / (tp + fp)) / ((tp + fn) / total)

    print("dis eval acc:{:.4f},dis auc: {:.4f}: ,dis pr-auc: {:.4f}: ,  f1 {:.4f}, mcc {:.4f}, recall {:.4f}, precision {:.4f}, specificity: {:.4f} ,NPV,{:.4f} enrich_ratio,{:.2f}\n".format(acc,roc_auc , pr_auc, f1, mcc, recall, precision, specificity, NPV, enrich_ratio))
    pred_lst.append([ 'ebola', cm_dict['tp'], cm_dict['tn'], cm_dict['fp'], cm_dict['fn'], acc, roc_auc,pr_auc, f1, mcc, recall, precision, specificity, NPV, enrich_ratio])

    # ---- Timing ----
    elapsed_model = time.time() - time1
    print('time', elapsed_model)
    print('time per read', elapsed_model / len(targets))

    # The second per-read figure previously reused time1; it must be measured
    # from time_start to include data loading.
    elapsed_total = time.time() - time_start
    print('time_with_data_process', elapsed_total)
    print('time_with_data_process per read', elapsed_total / len(targets))






In [None]:
# Summary table: one row per evaluated data set, with the column headers in
# the same order as the values appended to `pred_lst`.
pred_df = pd.DataFrame(
    data=pred_lst,
    columns=[
        'host & target',
        't-target', 't-host', 'f-target', 'f-nontarget',
        'acc', 'roc_auc', 'pr-auc', 'f1', 'mcc',
        'recall', 'precision', 'specificity', 'NPV',
        'enrich_ratio',
    ],
)
pred_df
