In [None]:
"""
The model is designed to analyze the consensus sequences in DNA sequences

The input of this model will be in the 2D matrix:    
    dimension: n x 1000, n is the amount of the input samples

The output of the model will be in the 3D matrix:
    dimension: n x 1000 x 4, n is the amoutn of the input samples
               the meaning of the output is the position-wise appearance for each alphabets(ATCG)

"""

In [2]:
import os
import torch
import numpy as np
import math
import json
import time
import random
from collections import Counter, defaultdict
from sklearn.preprocessing import normalize
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer,\
    TransformerDecoder, TransformerDecoderLayer
from torch.nn.functional import softmax
from itertools import compress

In [3]:
'''
Global variables
'''

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

alphabet_dict = {'A' : 1, 'T' : 2, 'C' : 3, 'G' : 4}

combi_letters = {'A' : 1, 'T' : 2, 'C' : 3, 'G' : 4, 
                 'AT' : 5, 'AC': 6, 'AG' : 7, 'TC' : 8, 'TG' : 9, 
                 'CG' : 10, 'ATC' : 11, 'ATG' : 12, 'ACG' : 13,
                 'TCG' : 14, 'ATCG' : 15}
val_to_char = {1 : 'A', 2 : 'T', 3 : 'C', 4 : 'G'}

fileDir = r'D:\motif data'
fileExt = r'.fa'
file_list = [f for f in os.listdir(fileDir) if f.endswith(fileExt)]

In [4]:
# read the fasta files and make the list of DNA sequences
def read_file(path):
    f = open(path,'r')
    seq = {}

    for line in f:
        if line.startswith('>'):
            name = line.replace('>', '').split()[0]
            seq[name] = ''
        else:
            seq[name] += line.replace('\n','').strip()

    f.close()

    seq_ls = []
    for name in seq:
        seq_ls.append(seq[name])
    return seq_ls

In [5]:
d_model = 512
dropout = 0.1
max_len = 500
nhead = 8
encoder_layer_nums = 6
decoder_layer_nums = 6
dim_ff = 512
# token: ATCG(1~4), <BOS or SOS>(0), <EOS>(16), <PAD>(17)
    # combinations: AT, AC, AG, TC, TG, CG, ATC, ATG, ACG, TCG, ATCG(5~15, total 11)
ntoken = 18

In [6]:
class TransformerModel(nn.Module): #done
    def __init__(
        self,
        d_model,
        dropout,
        max_len,
        nhead,
        encoder_layer_nums,
        decoder_layer_nums,
        dim_ff,
        ntoken,
    ):
        '''
        
        '''
        super(TransformerModel, self).__init__()
        
        self.d_model = d_model
        self.pos_encode = PositionalEmbedding(d_model, dropout, max_len)
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first = True)
        self.encoder = TransformerEncoder(encoder_layer, encoder_layer_nums)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first = True)
        self.decoder = TransformerDecoder(decoder_layer, decoder_layer_nums)
        self.src_embedding = nn.Embedding(ntoken, d_model, scale_grad_by_freq = False)
        self.tgt_embedding = nn.Embedding(4, d_model, scale_grad_by_freq = False)
        self.out_embed = nn.Embedding(ntoken, d_model)
        self.output_linear = nn.Linear(d_model, ntoken)
        self.output_softmax = nn.Softmax(dim = -1)
        
        self.init_weights()

    
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.output_linear.bias.data.zero_()
        self.output_linear.weight.data.uniform_(-initrange, initrange)
    
    def forward(
        self,
        src,
        tgt,
        src_mask = None,
        tgt_mask = None,
        memory_mask = None,
        src_key_padding_mask = None,
        tgt_key_padding_mask = None,
        memory_key_padding_mask = None,
    ):
        
        src = self.src_embedding(src)
        tgt = self.tgt_embedding(tgt)
        
        src = self.pos_encode(src)
        tgt = self.pos_encode(tgt)
        
        memory = self.encoder(src, mask = src_mask, src_key_padding_mask = src_key_padding_mask)
        output = self.decoder(tgt, memory, 
                              tgt_mask = tgt_mask, 
                              memory_mask = memory_mask, 
                              tgt_key_padding_mask = tgt_key_padding_mask,
                              memory_key_padding_mask = memory_key_padding_mask
                             )
        output = self.output_linear(output)
        output = self.output_softmax(output)
        
        return output

In [7]:
class PositionalEmbedding(nn.Module): #done
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p = dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp( torch.arange(0, d_model, 2) * (-math.log(10000) / d_model) )
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x: Tensor):
        x = x + self.pe[0, :x.size(1), :].requires_grad_(False)
        output = self.dropout(x)
        return output
    

In [9]:
def preprocess_seq(input_seq, IO, pad_num = 17):
    '''
    adjust the length for single sequence
    tokenize the sequence
    return:
    res: result for the single sequence adjustment and return in ndarray type
    '''
    
    L = len(input_seq)
    
    if IO:
        if L > 498:
            mid = L//2
            seq = input_seq[mid - 249: mid + 249]
            res = np.array(list(map(lambda x: combi_letters[x], seq)))
            res = np.concatenate((np.array([0]), res, np.array[16]))
        else:
            res = np.array(list(map(lambda x: combi_letters[x], input_seq)))
            res = np.concatenate((np.array([0]), res, np.array([16]), np.array([pad_num] * (498 - L) )))
        assert len(res) == 500, 'Preprocess the input sequence not equal to 500'
    
    else:
        assert L <= 48, 'Preprocess the output sequence get the length error, > 48'
        res = np.array(list(map(lambda x: combi_letters[x], input_seq)))
        res = np.concatenate((np.array([0]), res, np.array([16]), np.array([pad_num] * (48 - L))))
    
    return res.astype('int')

def make_outputs(data, row_size):
    '''
    make the outputs for the model, i.e. make the tgt
    return:
    tgt: the target array, has transfered the string data into tokens, np.array
    '''
    tgt = list()
    l = list()
    for seq in data:
        l.append(len(seq))
    assert min(l) > 3, 'The minimum length of the inputs are less than 3'
    if min(l) < row_size:
        row_size = min(l)
        print("Mention: the target size is greater than the minimum length of the data, it will set the size as the minimum length of the data")

    for motif_len in range(row_size - 3):
        #print('At step ', motif_len)
        #print(f'making the length {motif_len+3}\'s MEME output')
        m = MEME(data, motif_len + 3)
        start_time = time.time()
        m.iter()
        print(f'making the length {motif_len+3}\'s MEME output use the time {time.time()-start_time}')
        motif, logo = m.prob_out()
        tgt.append(preprocess_seq(logo, 0))
    tgt = np.array(tgt)
    assert len(tgt.shape) == 2, f'tgt dimension error in make outputs and the shape is {tgt.shape}'
    return tgt

def expand_dim(data, IO, pad_num = 17, row_size = 40):
    '''
    To expand input and output data into specific dimensions
    The input variable of this function is transfered matrix, that is the matrix with tokenized elements
    return:
    res: the results of expanding the dimension
    '''

    row, col = data.shape
    assert row < row_size, 'Row size out of range.'

    if IO:# expand input to row_size x 500
        assert col <= 500, 'Column size out of range(Input)'
        col_expand = pad_num * np.ones((row, 500 - col))
        row_expand = pad_num * np.ones((row_size - row, 500))

    elif not IO:# expand output to row_size x 50
        assert col <= 50, 'Column size out of range(Output)'
        col_expand = pad_num * np.ones((row, 50 - col))
        row_expand = pad_num * np.ones((row_size - row, 50))

    res = np.concatenate( (np.concatenate((data, col_expand), axis = 1), row_expand), axis = 0)

    return res


def get_padding_mask(data, pad_num = 17):
    '''
    create the padding mask for data
    return:
    mask: the padding mask in Tensor, with size same as the input data for the function
    '''
    mask = torch.zeros(data.shape)
    mask[data == pad_num] = -torch.inf
    
    return mask

#tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(-1))

In [11]:
src = torch.LongTensor([
    [0,3,2,4,3,3,2,1,5,6,6],
    [0,4,3,1,2,3,1,2,1,5,6]
])

tgt = torch.LongTensor([
    [0,3,2,4,3,3,2,1,5,6],
    [0,4,3,1,2,3,1,2,1,5]    
])

In [10]:
mdl = TransformerModel(d_model, dropout, max_len, nhead, encoder_layer_nums, decoder_layer_nums, dim_ff, ntoken)

In [13]:
mdl(src, tgt).shape

torch.Size([2, 10, 18])

In [11]:
criterion = nn.CrossEntropyLoss()
lr = 5.0
#optimizer = torch.optim.SGD(model.parameters(), lr = lr)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1,0, gamma = 0.8)

def get_batch(batch_size, ith):
    
    dna_seq = read_file(os.path.join(fileDir, file_list[ith]))
    assert len(dna_seq) < batch_size, f' The (row) size of the input data is out of range(>{batch_size}).'
    row = len(dna_seq)
    target = make_outputs(dna_seq, batch_size)
    data = np.zeros((row, 500))
    for idx in range(len(dna_seq)):
        print(f'doing the {idx} step in for loop')
        data[idx] = preprocess_seq(dna_seq[idx], 1)
        
    data = expand_dim(data, 1)
    return data, target

def save_data():
    '''
    save the preprocessed data
    since the 
    '''

def train(model):
    model.train()
    total_loss = 0.
    start_time = time.time()
    
    for step in range(1):
        data, target = get_batch(40, step)
        data_key_padding_mask = get_padding_mask(data)
        target_key_padding_mask = get_padding_mask(target)
        target_mask = nn.Transformer.generate_square_subsequent_mask(target.size(-1))
        output = model(data, target, 
                       src_key_padding_mask = data_key_padding_mask,
                       tgt_key_padding_mask = target_key_padding_mask,
                       tgt_mask = target_mask
                      )
        output_flat = output.view(-1, ntokens)
        loss = criterion(output_flat, target.veiw(-1))
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 0.7)
        optimizer.step()
        
        if step% 20 == 0:
            print(f'Step {step}, total loss: {total_loss}')
            total_loss = 0
    
    return

def evaluate(model, eval_data):
    model.eval()
    total_loss = 0.
    
    with torch.no_grad():
        for i in range(1):
            data, targets = get_batch()
            seq_len = data.size(0)
            output = model(data)
            output_flat = output.view(-1, ntoken)
            total_loss += seq_len * criterion(output_flat, targets.view(-1)).item()
        
    return total_loss / (len(eval_data) - 1)
        
            

In [13]:
m = MEME(dna_seq, 20)
start_time = time.time()
m.iter()
print(time.time()-start_time)

NameError: name 'dna_seq' is not defined

In [72]:
if __name__ == '__main__':
    
    mdl = TransformerModel(d_model, dropout, max_len, nhead, encoder_layer_nums, decoder_layer_nums, dim_ff, ntoken)
    
    # read data
    #data = read_file(path)
    # build model and train
    
    train(mdl)

Mention: the target size is greater than the minimum length of the data, it will set the size as the minimum length of the data
making the length 3's MEME output use the time 10.256240367889404
making the length 4's MEME output use the time 11.786333799362183
making the length 5's MEME output use the time 1.0089032649993896
making the length 6's MEME output use the time 14.54584813117981
making the length 7's MEME output use the time 15.272658824920654
making the length 8's MEME output use the time 0.17101764678955078
making the length 9's MEME output use the time 17.363539934158325
making the length 10's MEME output use the time 18.07251811027527
making the length 11's MEME output use the time 0.16188907623291016
making the length 12's MEME output use the time 0.10732913017272949
making the length 13's MEME output use the time 0.2673914432525635
making the length 14's MEME output use the time 0.19844818115234375
making the length 15's MEME output use the time 0.3223719596862793
making

IndexError: too many indices for tensor of dimension 2

In [75]:
data = ['CAGTCAGCAC','AGTTACGTAG']
data1 = ['AAAAA','AAAAA']
m = MEME(data, 3)

In [76]:
mdl = TransformerModel(d_model, dropout, max_len, nhead,
                       encoder_layer_nums, decoder_layer_nums,
                       dim_ff, ntoken)

In [78]:
mdl.cuda()

TransformerModel(
  (pos_encode): PositionalEmbedding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLi

In [None]:
src = torch.LongTensor([
    [0,3,2,4,3,3,2,1,5,6,6],
    [0,4,3,1,2,3,1,2,1,5,6]
])

tgt = torch.LongTensor([
    [0,3,2,4,3,3,2,1,5,6],
    [0,4,3,1,2,3,1,2,1,5]    
])

mdl(src,tgt).shape

In [None]:
emb = mdl.embedding
mdl.decoder(emb(tgt),emb(src)).shape

In [77]:
data=['TTCGATGCCGT','CTTAATGCTAC']
m=MEME(data,4)
m.iter()
m.prob_out()

(array([[0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]),
 ['T', 'T', 'C', 'G'])

In [76]:
m.prob_out()

(array([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]),
 ['A', 'T', 'C', 'G'])

In [20]:
dna_seq = read_file(os.path.join(fileDir, file_list[0]))
print(dna_seq)

['TGCTGGATAAGAATGTTTTAGCAATCTCTTT', 'TCAGCGAAAAAAATTAAAGCGCAAGATTGTT', 'GCGACAACCGGAATATGAAAGCAAAGCGCAG', 'TGCTGGATAAGAATGTTTTAGCAATCTCTTT', 'TCAGCGAAAAAAATTAAAGCGCAAGATTGTT', 'GCGACAACCGGAATATGAAAGCAAAGCGCAG', 'TGCTGGATAAGAATGTTTTAGCAATCTCTTT', 'TCAGCGAAAAAAATTAAAGCGCAAGATTGTT', 'GCGACAACCGGAATATGAAAGCAAAGCGCAG']


31

In [44]:
m = MEME(dna_seq, 18)

In [65]:
stt = time.time()
record = m.iter()
print(time.time() - stt)

15.269115686416626


In [66]:
print(record)

[0.0, 0.010507583618164062, 0.0065097808837890625, 0.0, 0.010503530502319336, 0.006506919860839844, 0.0, 0.009590625762939453, 0.004101276397705078, 0.0, 0.006003856658935547, 0.002595186233520508, 0.0, 0.006006479263305664, 0.0025844573974609375, 0.0, 0.0059604644775390625, 0.002000093460083008, 0.0, 0.00650477409362793, 0.0029997825622558594, 0.0, 0.005504608154296875, 0.0030002593994140625, 0.0, 0.005504131317138672, 0.0035047531127929688, 0.0, 0.004999876022338867, 0.0032253265380859375, 0.0, 0.00599980354309082, 0.002000570297241211, 0.0, 0.00621485710144043, 0.003000020980834961, 0.0, 0.006071567535400391, 0.003000020980834961, 0.0, 0.005505561828613281, 0.0029997825622558594, 0.0, 0.005125999450683594, 0.003000020980834961, 0.0, 0.00599980354309082, 0.0031998157501220703, 0.0, 0.008270740509033203, 0.00599980354309082, 0.0, 0.010504484176635742, 0.0060002803802490234, 0.0009999275207519531, 0.010010242462158203, 0.006038188934326172, 0.0, 0.010000228881835938, 0.0059998035430908

In [67]:
rec = np.array(record).reshape(-1,3)

In [68]:
np.sum(rec, axis = 0)

array([0.04378748, 9.54441357, 5.65681934])

In [71]:
rec.argmax()

2479

In [73]:
m.prob_out()

(array([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.]]),
 ['A',
  'G',
  'C',
  'G',
  'A',
  'A',
  'A',
  'A',
  'A',
  'A',
  'A',
  'T',
  'T',
  'A',
  'A',
  'A',
  'G',
  'C'])

In [25]:
m.beta_i.reshape(1,-1)

array([[0.03984674, 0.02145594, 0.0164751 , 0.02222222]])

In [26]:
m.beta_i

array([0.03984674, 0.02145594, 0.0164751 , 0.02222222])

In [28]:
src, tgt = get_batch(40,0)

Mention: the target size is greater than the minimum length of the data, it will set the size as the minimum length of the data
making the length 3's MEME output use the time 0.1671743392944336
making the length 4's MEME output use the time 0.17180752754211426
making the length 5's MEME output use the time 0.25025105476379395
making the length 6's MEME output use the time 0.3084275722503662
making the length 7's MEME output use the time 0.7732706069946289
making the length 8's MEME output use the time 0.5844018459320068
making the length 9's MEME output use the time 0.9474921226501465
making the length 10's MEME output use the time 0.16924285888671875
making the length 11's MEME output use the time 0.1607062816619873
making the length 12's MEME output use the time 0.17380309104919434
making the length 13's MEME output use the time 0.19881176948547363
making the length 14's MEME output use the time 0.13724112510681152
making the length 15's MEME output use the time 0.1459197998046875
ma

  freq_letter = np.log(freq_letter)


making the length 23's MEME output use the time 0.1393427848815918
making the length 24's MEME output use the time 0.12112307548522949
making the length 25's MEME output use the time 0.10052275657653809
making the length 26's MEME output use the time 0.08383536338806152
making the length 27's MEME output use the time 0.08089590072631836
making the length 28's MEME output use the time 0.07557892799377441
making the length 29's MEME output use the time 0.05318450927734375
making the length 30's MEME output use the time 0.047606706619262695
doing the 0 step in for loop
doing the 1 step in for loop
doing the 2 step in for loop
doing the 3 step in for loop
doing the 4 step in for loop
doing the 5 step in for loop
doing the 6 step in for loop
doing the 7 step in for loop
doing the 8 step in for loop
