In [None]:
"""
The model is designed to find consensus sequences (motifs) in DNA sequences.

The input of this model is a 2D matrix:
    dimension: n x 1000, where n is the number of input samples

The output of the model is a 3D matrix:
    dimension: n x 1000 x 4, where n is the number of input samples;
               the output gives the position-wise frequency of each letter (A, T, C, G)

"""

In [None]:
import math
import torch
import time
import os
import numpy as np
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer,\
    TransformerDecoder, TransformerDecoderLayer
from torch.nn.functional import softmax

In [None]:
# ---- Global variables ----

# Pick the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Integer code for each nucleotide, in the order A, T, C, G.
# Codes start at 1 because indicator_function writes 0 for positions
# past the end of a sequence.
alphabet_dict = dict(zip('ATCG', range(1, 5)))

In [None]:
def read_fasta(path):
    '''
    Parse a FASTA file into a dict mapping record name -> sequence string.

    The record name is the first whitespace-separated token after '>'.
    Sequence lines are stripped and concatenated; blank lines are skipped.
    Records keep file order (dict insertion order); a repeated name
    overwrites the earlier record's sequence.

    Raises:
    ValueError: if path is empty, a header has no name, or sequence data
                appears before the first '>' header
    OSError: if the file cannot be opened
    '''
    if not path:
        raise ValueError('FASTA path is empty: set FASTA_PATH (or the FASTA_PATH environment variable)')
    chunks = {}
    name = None
    # `with` guarantees the file is closed even if parsing raises.
    with open(path, 'r') as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                fields = line[1:].split()
                if not fields:
                    raise ValueError(f'{path}:{line_no}: FASTA header has no name')
                name = fields[0]
                chunks[name] = []  # a repeated name restarts that record
            elif name is None:
                raise ValueError(f'{path}:{line_no}: sequence data before the first ">" header')
            else:
                # Collect pieces and join once instead of repeated string +=.
                chunks[name].append(line)
    return {record: ''.join(parts) for record, parts in chunks.items()}


# Path to the input FASTA file; may also be supplied via the FASTA_PATH
# environment variable. TODO: set this before running the notebook.
FASTA_PATH = os.environ.get('FASTA_PATH', '')

seq = read_fasta(FASTA_PATH)
seq_ls = list(seq.values())

In [None]:
class TransformerModel(nn.Module):
    '''
    Encoder-decoder Transformer over integer-encoded DNA tokens.

    src and tgt are batch-first tensors of token ids, shape (batch, seq_len);
    the output has shape (batch, tgt_len, ntoken) and is softmax-normalised
    over the last dimension.

    Arguments:
    d_model: embedding / model width
    dropout: dropout probability for the positional embedding and the
             encoder/decoder layers
    max_len: longest sequence the positional embedding supports
    nhead: number of attention heads (PyTorch requires it to divide d_model)
    encoder_layer_nums: number of stacked encoder layers
    decoder_lyaer_nums: number of stacked decoder layers (the misspelt name
                        is kept so keyword callers keep working)
    dim_ff: hidden width of the feed-forward sublayers
    ntoken: vocabulary size; with alphabet_dict codes 1-4 plus 0 for padding
            this is presumably 5 -- TODO confirm
    '''
    def __init__(
        self,
        d_model,
        dropout,
        max_len,
        nhead,
        encoder_layer_nums,
        decoder_lyaer_nums,
        dim_ff,
        ntoken,
    ):
        super().__init__()

        self.d_model = d_model
        self.pos_encode = PositionalEmbedding(d_model, dropout, max_len)
        # batch_first=True: PositionalEmbedding adds its table along dim 1,
        # i.e. it works on (batch, seq, d_model). With the default
        # batch_first=False the layers would treat dim 0 (the samples) as the
        # sequence axis and attend across samples instead of across positions.
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.encoder = TransformerEncoder(encoder_layer, encoder_layer_nums)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        # The original used the undefined name `decoder_layer_nums` here.
        self.decoder = TransformerDecoder(decoder_layer, decoder_lyaer_nums)
        self.embedding = nn.Embedding(ntoken, d_model)
        # NOTE(review): out_embed is not used in forward (tgt is embedded with
        # self.embedding). Kept to preserve the module's parameter layout;
        # TODO decide whether tgt should use it.
        self.out_embed = nn.Embedding(ntoken, d_model)
        self.output_linear = nn.Linear(d_model, ntoken)
        self.output_softmax = nn.Softmax(dim=-1)

        self.init_weights()

    def init_weights(self):
        '''Uniformly initialise the input embedding and output projection in [-0.1, 0.1]; zero the output bias.'''
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.output_linear.bias.data.zero_()
        self.output_linear.weight.data.uniform_(-initrange, initrange)

    def forward(
        self,
        src,
        tgt,
        src_mask=None,
        tgt_mask=None,
        memory_mask=None,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        '''
        Embed src/tgt, add positional encodings, run encoder then decoder,
        and return per-position probabilities of shape (batch, tgt_len, ntoken).

        The mask arguments are forwarded unchanged to nn.TransformerEncoder /
        nn.TransformerDecoder. NOTE(review): the output is already softmaxed;
        nn.CrossEntropyLoss expects raw logits, so do not pair the two.
        '''
        src = self.pos_encode(self.embedding(src))
        tgt = self.pos_encode(self.embedding(tgt))

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_softmax(self.output_linear(output))

In [None]:
class PositionalEmbedding(nn.Module):
    '''
    Fixed sinusoidal positional encoding followed by dropout.

    Input is batch-first, shape (batch, seq_len, d_model); the encoding for
    positions 0 .. seq_len-1 is added along dim 1. Even feature columns get
    sin(pos * w_i), odd columns get cos(pos * w_i), with w_i = 10000^(-2i/d_model).
    '''
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd column than even columns, so the
        # cosine half uses only the first d_model // 2 frequencies.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: follows .to(device) and is saved in
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        '''Add positional encodings to x (batch, seq_len, d_model) and apply dropout.'''
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len {self.pe.size(1)} of PositionalEmbedding'
            )
        return self.dropout(x + self.pe[0, :seq_len, :])
    

In [None]:
def condi_distribution(freq_letter, indicator, n):
    '''
    Calculate the conditional distributions p(Xi | theta_j)
    eq(7),(8) in the MEME article.
    To avoid numerical underflow from multiplying many small probabilities,
    the products are computed as sums of ln() terms, so the returned values
    are LOG-probabilities.

    Arguments:

    freq_letter: letter frequencies, size: (W + 1) x 4
                 row 0 is the background ( 1 x 4 ), rows 1..W the motif
                 positions ( W x 4 ); columns follow alphabet_dict order (A, T, C, G)
    indicator: letter codes, size: n x W (e.g. n x 1000); codes 1-4 are
               letters, 0 marks padding and is skipped
    n: amount of the input sequences
    ====================================================
    return:

    p_Xi_1: log p(Xi | motif), size: n
    p_Xi_2: log p(Xi | background), size: n

    Raises ValueError if freq_letter has fewer than W + 1 rows.
    '''
    freq = torch.as_tensor(freq_letter)
    if not freq.is_floating_point():
        freq = freq.float()
    codes = torch.as_tensor(indicator).long().reshape(n, -1)  # (n, W)
    W = codes.shape[1]
    if freq.shape[0] < W + 1:
        raise ValueError(
            f'freq_letter has {freq.shape[0]} rows; need {W + 1} (background + {W} motif positions)'
        )

    log_f0 = torch.log(freq[0])          # (4,)   background letter frequencies
    log_fj = torch.log(freq[1:W + 1])    # (W, 4) one row per motif position

    valid = codes > 0                    # 0 is padding, not a letter
    letters = (codes - 1).clamp(min=0)   # letter column 0..3 (padding clamped, then masked)

    # Motif term is position-specific: log f[pos, letter] for every (i, pos).
    motif_terms = log_fj[torch.arange(W), letters]   # (n, W)
    background_terms = log_f0[letters]               # (n, W)

    p_Xi_1 = motif_terms.masked_fill(~valid, 0.0).sum(dim=1)
    p_Xi_2 = background_terms.masked_fill(~valid, 0.0).sum(dim=1)
    return p_Xi_1, p_Xi_2

In [None]:
def indicator_function(data, length=1000, alphabet=None):
    '''
    In article, it is the I(k,a) function for eq(7), (8)
    It converts the letters of a sequence to integer codes.

    Arguments
    data: input sequence, dtype is string; letters are upper-cased before
          lookup, so soft-masked (lower-case) bases are accepted
    length: number of positions in the output (default 1000); longer
            sequences are truncated, shorter ones are padded with 0
    alphabet: letter -> code mapping with upper-case keys; defaults to
              alphabet_dict

    return
    indicator: int32 tensor of codes, size is 1 x length

    Raises KeyError for a letter not present in the mapping.
    '''
    mapping = alphabet_dict if alphabet is None else alphabet
    codes = [mapping[base] for base in data[:length].upper()]
    indicator = torch.zeros(length, dtype=torch.int32)  # 0 = padding
    indicator[:len(codes)] = torch.tensor(codes, dtype=torch.int32)
    return indicator.reshape(1, length)

In [None]:
def group_membership(condi_dis, lamb):
    '''
    Calculate the group memberships Z_ij (E-step), article's eq(4):

        Z_ij = lambda_j * p(Xi | theta_j) / sum_k lambda_k * p(Xi | theta_k)

    Computed in log space as softmax_j( ln(lambda_j) + ln p(Xi | theta_j) ),
    so very small sequence likelihoods do not underflow to 0/0.

    Arguments:
    condi_dis: pair (ln p(Xi | theta_1), ln p(Xi | theta_2)), each of size n,
               e.g. the (motif, background) log-probabilities returned by
               condi_distribution
    lamb: pair (lambda_1, lambda_2) of mixing weights for motif and background
    ====================================================
    return:
    Z: tensor of size n x 2; column 0 is the motif membership, column 1 the
       background membership; each row sums to 1
    '''
    log_p = torch.stack([torch.as_tensor(c) for c in condi_dis], dim=-1)  # (n, 2)
    if not log_p.is_floating_point():
        log_p = log_p.float()
    lamb1 = torch.as_tensor(lamb[0], dtype=log_p.dtype)
    lamb2 = torch.as_tensor(lamb[1], dtype=log_p.dtype)
    log_lamb = torch.log(torch.stack([lamb1, lamb2]))  # (2,)
    # softmax subtracts the row max internally, which keeps this stable.
    return torch.softmax(log_p + log_lamb, dim=-1)
    

In [None]:
class trainer:
    '''
    Training driver for the motif model (work in progress).

    posterior_computation and train are placeholders that raise
    NotImplementedError until they are implemented; empty method bodies
    would be a SyntaxError and stop the whole cell from running.
    '''
    def __init__(self, model):
        # model: presumably a TransformerModel instance -- TODO confirm
        self.model = model

    def posterior_computation(self):
        '''Placeholder for the posterior computation; not implemented yet.'''
        raise NotImplementedError('trainer.posterior_computation is not implemented yet')

    def train(self):
        '''Placeholder for the training loop; not implemented yet.'''
        raise NotImplementedError('trainer.train is not implemented yet')

In [None]:
if __name__ == '__main__':
    # read data: seq_ls is built by the FASTA-loading cell above.
    # data_info summarises the input; currently only the number of sequences.
    # (The original bare `data_info` expression raised NameError.)
    data_info = {'n_sequences': len(seq_ls)}
    # Stack one (1, 1000) indicator row per sequence into an (n, 1000) matrix.
    indicators = (
        torch.cat([indicator_function(s) for s in seq_ls], dim=0)
        if seq_ls
        else torch.zeros(0, 1000, dtype=torch.int32)
    )
    # TODO: build model and train