In [1]:
"""
The model is designed to analyze consensus sequences in DNA sequences.

The input of this model will be in the 2D matrix:    
    dimension: n x 1000, n is the amount of the input samples

The output of the model will be in the 3D matrix:
    dimension: n x 1000 x 4, n is the amount of the input samples
               the output gives the position-wise appearance of each letter of the alphabet (ATCG)

"""

'\nThe model is designed to analyze the consensus sequences in DNA sequences\n\nThe input of this model will be in the 2D matrix:    \n    dimension: n x 1000, n is the amount of the input samples\n\nThe output of the model will be in the 3D matrix:\n    dimension: n x 1000 x 4, n is the amoutn of the input samples\n               the meaning of the output is the position-wise appearance for each alphabets(ATCG)\n\n'

In [2]:
import math
import torch
import time
import os
import numpy as np
from collections import Counter, defaultdict
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer,\
    TransformerDecoder, TransformerDecoderLayer
from torch.nn.functional import softmax

In [3]:
# ---------------------------------------------------------------
# Global variables shared by the cells below
# ---------------------------------------------------------------

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1-based integer code of every nucleotide letter (A=1, T=2, C=3, G=4).
alphabet_dict = dict(zip('ATCG', range(1, 5)))

In [None]:
# Path of the FASTA file to load.  It is still the empty placeholder from the
# notebook: set it before running this cell (open('') raises FileNotFoundError).
fasta_path = ''

seq = {}     # record name -> concatenated sequence string
name = None  # name of the record currently being read

# `with` closes the handle even when parsing raises part-way through the file.
with open(fasta_path, 'r') as f:
    for line in f:
        if line.startswith('>'):
            # Header line: the record name is the first whitespace-delimited
            # token after '>'.  A repeated name restarts that record.
            name = line.replace('>', '').split()[0]
            seq[name] = ''
        else:
            residues = line.strip()
            if not residues:
                # Blank lines carry no sequence data.
                continue
            if name is None:
                # Previously this surfaced as a NameError on `name`.
                raise ValueError('FASTA sequence data found before the first ">" header line.')
            seq[name] += residues

# Sequences only, in the order their headers first appeared in the file.
seq_ls = list(seq.values())

In [None]:
class TransformerModel(nn.Module): #done
    '''
    Encoder-decoder Transformer over integer token sequences.

    Source and target tokens are embedded with the same `nn.Embedding`,
    receive the positional encoding, run through a `TransformerEncoder` /
    `TransformerDecoder` stack (both built with batch_first = True, so the
    batch dimension comes first) and are projected to a softmax distribution
    over `ntoken` classes.

    Arguments:
    d_model: embedding / model width
    dropout: dropout probability used by the positional encoding and by
             every encoder / decoder layer
    max_len: maximum sequence length handed to PositionalEmbedding
    nhead: number of attention heads
    encoder_layer_nums: number of stacked encoder layers
    decoder_lyaer_nums: number of stacked decoder layers (the misspelled
                        parameter name is kept so existing keyword callers
                        keep working)
    dim_ff: width of the feed-forward sub-layer
    ntoken: vocabulary size; also the size of the output distribution
    '''
    def __init__(
        self,
        d_model,
        dropout,
        max_len,
        nhead,
        encoder_layer_nums,
        decoder_lyaer_nums,
        dim_ff,
        ntoken,
    ):
        super(TransformerModel, self).__init__()

        self.d_model = d_model
        self.pos_encode = PositionalEmbedding(d_model, dropout, max_len)

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first = True)
        self.encoder = TransformerEncoder(encoder_layer, encoder_layer_nums)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first = True)
        # The body used to read the undefined name `decoder_layer_nums`, so
        # building the model always raised NameError; use the real parameter.
        self.decoder = TransformerDecoder(decoder_layer, decoder_lyaer_nums)

        self.embedding = nn.Embedding(ntoken, d_model)
        # NOTE(review): out_embed is never used by forward() (the target goes
        # through self.embedding as well).  It is kept so the module's
        # parameter set / state_dict keys do not change -- confirm whether the
        # target was meant to have its own embedding.
        self.out_embed = nn.Embedding(ntoken, d_model)
        self.output_linear = nn.Linear(d_model, ntoken)
        self.output_softmax = nn.Softmax(dim = -1)

        self.init_weights()


    def init_weights(self):
        '''
        Uniform(-0.1, 0.1) initialisation for the shared embedding and the
        output projection weights; the output projection bias is zeroed.
        '''
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.output_linear.bias.data.zero_()
        self.output_linear.weight.data.uniform_(-initrange, initrange)

    def forward(
        self,
        src,
        tgt,
        src_mask = None,
        tgt_mask = None,
        memory_mask = None,
        src_key_padding_mask = None,
        tgt_key_padding_mask = None,
        memory_key_padding_mask = None,
    ):
        '''
        Arguments:
        src: source token indices, batch first
        tgt: target token indices, batch first
        *_mask / *_key_padding_mask: forwarded unchanged to the encoder and
                                     the decoder

        return:
        output: softmax probabilities over the ntoken classes for every
                target position (last dimension has size ntoken)
        '''

        # token ids -> vectors (same table for both sides), then add positions
        src = self.pos_encode(self.embedding(src))
        tgt = self.pos_encode(self.embedding(tgt))

        memory = self.encoder(src, mask = src_mask, src_key_padding_mask = src_key_padding_mask)
        output = self.decoder(tgt, memory,
                              tgt_mask = tgt_mask,
                              memory_mask = memory_mask,
                              tgt_key_padding_mask = tgt_key_padding_mask,
                              memory_key_padding_mask = memory_key_padding_mask
                             )
        output = self.output_linear(output)
        output = self.output_softmax(output)

        return output

In [None]:
class PositionalEmbedding(nn.Module): #done
    '''
    Fixed sinusoidal positional encoding followed by dropout.

    Even feature columns hold sin(position * div_term) and odd columns hold
    cos(position * div_term).  The table is stored as the non-trainable
    buffer `pe` of size 1 x max_len x d_model.

    Arguments:
    d_model: feature width of the vectors to encode (odd widths are supported)
    dropout: dropout probability applied after the encoding is added
    max_len: number of positions in the table
    '''
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p = dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp( torch.arange(0, d_model, 2) * (-math.log(10000) / d_model) )
        # angles: max_len x ceil(d_model / 2)
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # An odd d_model has one cosine column fewer than sine columns;
        # without the slice the assignment fails with a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor):
        '''
        Add the encoding of the first x.size(1) positions to x (broadcast over
        the leading batch dimension) and apply dropout.
        '''
        # `pe` is a buffer, so it never carries gradients.
        x = x + self.pe[0, :x.size(1), :]
        output = self.dropout(x)
        return output
    

In [4]:
class MEME(object):
    '''
    Two-component mixture (motif model vs. background model) fitted with EM
    over every W-mer of the input sequences, following the MEME article
    equations referenced in the method docstrings.

    Letters are translated to 1-based codes with the module-level
    `alphabet_dict`; frequency matrices store letter `code` in column
    `code - 1`.  Row 0 of `f_ij` is the background, rows 1..W are the motif
    positions.
    '''
    def __init__(self, source_data, W):
        '''
        source_data will be list of strings
        W is the motif width (length of the W-mers)
        '''

        self.source_data = source_data
        self.W = W
        self.L = 4

        # fixed variables
        self.N = None
        self.l = None
        self.X = None
        self.n = None
        self.I = None
        self.letter_counts = None

        # mutable variables, i.e. the value will be changed
        self.p_X1 = None
        self.p_X2 = None
        self.Z = None
        self.z = None
        self.f_ij = None
        self.e_ij = None
        self.lambda1 = 0
        self.lambda2 = 0

        self.init_variables()


    def init_variables(self):
        '''
        initialize the followings: N, small l, z, e_ij, X, n, I, lambda,
        letter_counts, f_ij, p_X1, p_X2

        Raises:
        ValueError: if W < 1, if there is no sequence, or if a sequence is
                    shorter than W (it would contribute no W-mer and corrupt
                    the z bookkeeping)
        '''

        # record N and small l
        self.N = len(self.source_data)
        self.l = [len(seq) for seq in self.source_data]
        if self.W < 1:
            raise ValueError('W must be a positive integer.')
        if self.N == 0 or min(self.l) < self.W:
            raise ValueError('source_data must hold at least one sequence and every sequence needs at least W letters.')

        # init small z: 1 for every valid W-mer start, 0 for the tail
        # positions where no W-mer can start (and for row padding)
        self.z = np.ones((self.N, max(self.l)))
        for idx in range(self.N):
            self.z[idx][(self.l[idx] - self.W + 1):] = 0
        self.e_ij = self.z.copy()

        # init W-mer set X
        # X will be the list of strings, sequence by sequence, left to right
        self.X = [
            seq[j:j + self.W]
            for seq in self.source_data
            for j in range(len(seq) - self.W + 1)
        ]
        self.n = len(self.X)

        # init indicator
        self.I = self.indicator_function()

        # init lambda
        lamb_range_min = min( np.sqrt(self.N)/self.n, 1/(2*self.W) )
        lamb_range_max = max( np.sqrt(self.N)/self.n, 1/(2*self.W) )
        self.lambda1 = np.random.uniform(lamb_range_min, lamb_range_max)
        self.lambda2 = 1 - self.lambda1

        # init f_ij, size: ( 1 + W ) x L
        self.f_ij = np.zeros( ( (self.W + 1), self.L ) )

        # init letter_counts
        self.letter_counts = self.count_letter_appearance()

        # f_i parts: per-position letter frequencies over all W-mers
        for pos in range(self.W):
            per_letter = np.bincount(self.I[:, pos], minlength = self.L + 1)[1:]
            self.f_ij[pos + 1] = per_letter / self.n

        # f_0 part: overall letter frequencies over all W-mers.
        # This used to be disabled (left inside a string literal), which kept
        # the background row at zero, made every background likelihood
        # exp(log 0) = 0 and produced NaN in the M-step.
        overall = np.bincount(self.I.ravel(), minlength = self.L + 1)[1:]
        self.f_ij[0] = overall / self.I.size

        # Every row is a count vector divided by its own total, so the rows
        # already sum to 1.  The former extra normalisation indexed rows by
        # letter (IndexError for W < 3, wrong rows rescaled otherwise) and has
        # been dropped.

        # conditional probabilities p_X1, p_X2
        self.p_X1, self.p_X2 = self.condi_distribution(self.f_ij)

        return

    def indicator_function(self): #done
        '''
        In article, it is the I(k,a) function for eq(7), (8)
        There will transfer the alphabets to the index
        Thus, the results look like
        [
        [1,2,3,4],
        [4,3,2,1]
        ]

        return
        indicator: integer letter codes of every W-mer, size is n x W
        '''
        assert isinstance(self.X, list), 'Type of X is not list'
        indicator = [[alphabet_dict[letter] for letter in subseq] for subseq in self.X]

        return np.array(indicator, dtype = 'int')


    def condi_distribution(self, freq_letter):
        '''
        Calculate the conditional distribution p(Xi | theta_j)
        eq(7),(8) in the MEME article
        The products are accumulated as sums of logarithms and exponentiated
        at the end.

        Arguments:

        freq_letter: the frequences for each letter in each position, size: (W + 1) x L
                     background ( 1 x L ) + motif ( W x L )
                     dtype: np.array
        ====================================================
        return: plain probabilities (the log-sums are exponentiated)

        p_Xi_1: conditional distribution of motif sequence, size: n
        p_Xi_2: conditional distribution of background, size: n
        '''

        freq_letter = np.asarray(freq_letter, dtype = float)
        letter_idx = self.I - 1                      # n x W, zero-based columns

        # A zero frequency legitimately gives log 0 = -inf and a probability
        # of exactly 0 after exp(); silence only that warning.
        with np.errstate(divide = 'ignore'):
            log_f_0 = np.log(freq_letter[0])         # L
            log_f_j = np.log(freq_letter[1:])        # W x L

        # entry [i, pos] = log f[pos][letter of W-mer i at pos]
        log_p_1 = log_f_j[np.arange(self.W), letter_idx].sum(axis = 1)
        log_p_2 = log_f_0[letter_idx].sum(axis = 1)

        return np.exp(log_p_1), np.exp(log_p_2)

    def count_letter_appearance(self): #done
        '''
        count the total appearance times for each alphabets in every W-mer
        return:
        count: counting results in the column order A, T, C, G, size: n x 4
        '''

        count = np.zeros( (self.n, 4) , dtype = 'int')

        for i, subseq in enumerate(self.X):
            tally = Counter(subseq)
            count[i] = [tally['A'], tally['T'], tally['C'], tally['G']]

        return count

    def update_erasing(self):
        '''
        compute the erasing values e_ij

        Each e_ij[i][j] is multiplied by prod(1 - z[i][k]) over the window
        k = max(j - W + 1, 0) .. j - 1, or by 1e-6 as soon as one window value
        exceeds 0.999999.
        '''

        for i in range(self.N):
            for j in range(self.l[i]):
                lower_bound = max(j - self.W + 1, 0)
                # `vals` is a view, so the rescaling below also rewrites self.z.
                # NOTE(review): the window stops at j - 1 and therefore holds at
                # most W - 1 values; confirm against the article whether the
                # W-mer starting at j itself should be included.
                vals = self.z[i][lower_bound:j]
                if sum(vals) > 1:
                    vals /= sum(vals)
                if any(vals > 9.99999e-1):
                    self.e_ij[i][j] *= 1e-6
                else:
                    self.e_ij[i][j] *= np.exp( sum( np.log( np.ones(len(vals)) - vals ) ) )

        return

    def update_z(self):
        '''
        update the small z values from the motif column of Z
        (W-mer idx of the flat list X <-> sequence i, start position j)
        '''

        assert sum(self.l)-self.N*(self.W - 1) == self.n, 'Error: index division not equal for updating z.'
        idx = 0
        for i in range(self.N):
            for j in range(self.l[i] - self.W + 1):
                self.z[i][j] = self.Z[idx][0]
                idx += 1

        return

    def E_step(self, condi_dis, lamb): #done
        '''
        calculate the Z_ij, in article's eq(4)
        also update small z depends on the results of Z

        Arguments:
        condi_dis: conditional distribution, from the defined function, size: n x 2
                   (column 0 motif, column 1 background).  A 2 x n array is
                   transposed when n != 2; for n == 2 the n x 2 layout is assumed.
        lamb: probability for using models, size: 1 x 2
        return:
        Z: membership probability, size: n x 2
        '''

        # regulate the shape
        p = np.asarray(condi_dis, dtype = float)
        if p.ndim == 2 and p.shape == (2, self.n) and self.n != 2:
            # A plain reshape of a 2 x n array would interleave motif and
            # background values instead of pairing them per W-mer.
            p = p.T
        p = p.reshape(self.n, 2)
        lamb = np.asarray(lamb, dtype = float).reshape(1, 2)

        multi_results = p * lamb
        summation = np.sum(multi_results, axis = 1, keepdims = True)
        Z = multi_results / summation
        self.Z = Z.copy()
        self.update_z()

        return Z

    def M_step(self):
        '''
        calculate the lambda and f_ij from the current Z, I and e_ij

        return:
        lamb: updated mixing probabilities, size: 2
        f_ij: updated letter frequencies, size: (W + 1) x L
        '''

        Z = self.Z.transpose()

        # update lambda, eq(5)
        lamb = np.mean(Z, axis = 1)

        # calculate the c_0k: background membership weighted by how often
        # letter k occurs in each W-mer
        c_0k = np.zeros((1, self.L))
        for k in range(self.L):
            c_0k[0][k] = np.dot(Z[1], np.sum(self.I == (k + 1), axis = 1))

        # calculate the c_jk
        # first make erasing being 1 x n
        E = np.hstack([self.e_ij[i][:(self.l[i] - self.W + 1)] for i in range(self.N)])
        EZ = E * Z[0]

        c_jk = np.zeros((self.W, self.L))
        for pos in range(self.W):
            c_jk[pos] = np.bincount(self.I[:, pos] - 1, weights = EZ, minlength = self.L)

        # start updating f
        # won't directly update self.f since it will use the previous values to check the disparity
        c_0k /= np.sum(c_0k, axis = 1, keepdims = True)
        c_jk /= np.sum(c_jk, axis = 1, keepdims = True)
        f_ij = np.vstack((c_0k, c_jk))

        # The erasing factors are refreshed once per iteration by
        # update_variables(); refreshing them here as well applied the same
        # factor twice.

        return lamb, f_ij

    def update_variables(self, lamb, f_ij):
        '''
        update the variables' values after M step
        it will update: f_ij, lambda, conditional probabilities p_X1 and p_X2, erasing

        Inputs:
        lamb: lambda, size: 1 x 2
        f_ij: the updated values for position-wise letter frequencies, size: (W+1) x L
        '''

        self.lambda1 = lamb[0]
        self.lambda2 = lamb[1]
        self.f_ij = f_ij.copy()
        self.p_X1, self.p_X2 = self.condi_distribution(self.f_ij)
        self.update_erasing()

        return

    def iter(self, epsilon = 1e-6, max_iter = 100):
        '''
        iterate the E and M steps

        Arguments:
        epsilon: convergence threshold on the Euclidean change of every row of f_ij
        max_iter: maximum number of EM rounds
        '''

        for step in range(max_iter):
            # n x 2: motif likelihood next to background likelihood per W-mer
            condi_dis = np.column_stack((self.p_X1, self.p_X2))
            lamb = np.array([self.lambda1, self.lambda2])
            self.E_step(condi_dis, lamb)
            new_lambda, new_f_ij = self.M_step()

            err = np.sqrt(np.sum( (self.f_ij - new_f_ij) ** 2  , axis = 1))
            self.update_variables(new_lambda, new_f_ij)

            # converged only when every row stopped moving
            # (a single stable row used to end the loop)
            if np.all(err < epsilon):
                break

        return
    

In [None]:
class trainer(object):
    '''
    Holds a model, its input data and W for a training run.

    The training loop itself has not been written yet.
    '''
    def __init__(self, model, data_in, W):
        '''
        Arguments:
        model: stored unchanged as self.model
        data_in: stored unchanged as self.data
        W: stored unchanged as self.W (presumably the W-mer width used by
           MEME -- confirm)
        '''

        self.model = model
        self.data = data_in
        self.W = W

    def train(self):
        '''
        Placeholder for the training loop.

        The method used to have no body at all, which made the whole cell a
        SyntaxError; it now fails loudly until it is implemented.
        '''
        raise NotImplementedError('trainer.train has not been implemented yet.')

In [None]:
if __name__ == '__main__':
    # Planned driver (not implemented yet):
    #   1. read data and describe it in data_info = {amount of sequence, ...}
    #   2. build model and train
    #
    # The cell used to contain the bare expression `data_info`, which raised
    # NameError because nothing defines that name.  Start from an empty
    # description until the loading code exists.
    data_info = {}  # TODO: fill in once the data is read (amount of sequence, ...)

In [None]:
def expand_dim(data, IO, pad_value = 7):
    '''
    To expand input and output data into specific dimensions

    The original values stay in the top-left corner; every added cell holds
    pad_value.  The result keeps the dtype and device of `data`, so integer
    token ids stay integers.

    Arguments:
    data: 2D tensor with at most 50 rows and at most 500 (input) or
          50 (output) columns
    IO: truthy -> expand an input to 50 x 500
        falsy  -> expand an output to 50 x 50
    pad_value: filler value (default 7, the value get_padding_mask looks for)

    return:
    res: padded tensor, 50 x 500 or 50 x 50

    Raises:
    AssertionError: if data does not fit into the target size
    '''

    target_rows = 50
    target_cols = 500 if IO else 50

    row, col = data.size()
    # The column limit depends on IO: outputs wider than 50 used to slip
    # through and crash on a negative padding size.  Exact-fit sizes are fine.
    assert row <= target_rows and col <= target_cols, 'Index out of range.'

    res = data.new_full((target_rows, target_cols), pad_value)
    res[:row, :col] = data

    return res

def get_padding_mask(data, pad_value = 7):
    '''
    Build an additive float mask with the same size as data:
    -inf where data equals pad_value, 0 everywhere else.

    The mask is created on the device of `data`, so it can be indexed with
    `data == pad_value` on any device.

    Arguments:
    data: tensor that may contain padding cells
    pad_value: value that marks a padding cell (default 7)
    '''
    mask = torch.zeros(data.size(), dtype = torch.get_default_dtype(), device = data.device)
    mask[data == pad_value] = float('-inf')
    return mask

# Square subsequent mask whose side length is the last dimension of `tgt`.
# NOTE(review): `tgt` is not defined in any visible cell, so on a fresh kernel
# this line raises NameError -- confirm where `tgt` is supposed to come from.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(-1))

In [10]:
# Toy inputs for a quick look at the MEME initialisation.
data = ['CAGTCAGCAC', 'AGTTACGTAG']   # two mixed-letter sequences (not used below)
data1 = ['AAAAA'] * 2                 # two identical all-A sequences
m = MEME(source_data = data1, W = 3)

In [24]:
# Sum the frequency matrix over its rows (axis 0): one total per letter column.
np.sum(m.f_ij, axis = 0)

array([4., 0., 0., 0.])

In [23]:
# Display the (W + 1) x L frequency matrix that MEME.init_variables built for the toy input.
m.f_ij

array([[1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.]])