In [1]:
"""
First version for training
"""
"""
targets
    size: batch_size(15) x fi(32 = 30 + eos + bos) x fj(4)
"""

'\ntargets\n    size: batch_size(15) x fi(32 = 30 + eos + bos) x fj(4)\n'

In [2]:
import os
import torch
import numpy as np
import math
import json
import time
import random
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
from sklearn.preprocessing import normalize
from torch import nn, tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer,\
    TransformerDecoder, TransformerDecoderLayer
from torch.nn.functional import softmax
from itertools import compress
from fns import masked_dist, get_padding_mask, get_batch, indicator, init_lambda, EM_criterion

In [3]:
# --- Transformer architecture hyper-parameters (passed to TransformerModel below) ---
d_model, nhead, dim_ff = 512, 8, 2048
encoder_layer_nums = decoder_layer_nums = 6
dropout = 0.1
max_len = 500  # number of positions pre-computed by the positional embedding

# --- Vocabularies ---
# Source token ids (18 in total):
#   0      <BOS> / <SOS>
#   1-4    single bases A, T, C, G
#   5-15   base combinations AT, AC, AG, TC, TG, CG, ATC, ATG, ACG, TCG, ATCG (11 ids)
#   16     <EOS>
#   17     <PAD>
# Target token ids (4 in total):
#   0      <BOS> / <SOS>
#   1      motif position found
#   2      not a motif position
#   3      <EOS>
src_ntoken = 18
tgt_ntoken = 4

# Single-base letter -> source token id (A=1, T=2, C=3, G=4).
alphabet_dict = {base: code for code, base in enumerate('ATCG', start=1)}

In [4]:
class PositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor, then apply dropout.

    The encoding table is computed once in ``__init__`` and stored as the
    non-trainable buffer ``pe`` of shape ``(1, max_len, d_model)``: even feature
    columns hold sines, odd feature columns hold cosines.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. May be odd.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions pre-computed; inputs must not be longer
            than this along dimension 1.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p = dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp( torch.arange(0, d_model, 2) * (-math.log(10000) / d_model) )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose last
                dimension has size ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # The (seq_len, d_model) slice broadcasts over the leading batch dimension.
        x = x + self.pe[0, :x.size(1), :]
        return self.dropout(x)
    

In [5]:
class TransformerModel(nn.Module):
    """Batch-first encoder-decoder Transformer with a softmax output head.

    Source and target token ids are embedded separately, given positional
    encodings by a shared ``PositionalEmbedding``, and passed through a
    ``TransformerEncoder`` / ``TransformerDecoder`` pair. The decoder output is
    projected to 4 values per target position and normalised with a softmax
    over the last dimension.

    NOTE(review): the output width is hard-coded to 4 and does not depend on
    ``tgt_ntoken`` -- confirm this is the intended output alphabet size.

    Args:
        d_model: Embedding / model width.
        dropout: Dropout probability used by the positional embedding and layers.
        max_len: Number of positions supported by the positional embedding.
        nhead: Number of attention heads per layer.
        encoder_layer_nums: Number of stacked encoder layers.
        decoder_layer_nums: Number of stacked decoder layers.
        dim_ff: Hidden width of each layer's feed-forward sub-network.
        src_ntoken: Size of the source vocabulary.
        tgt_ntoken: Size of the target vocabulary.
    """

    def __init__(
        self,
        d_model,
        dropout,
        max_len,
        nhead,
        encoder_layer_nums,
        decoder_layer_nums,
        dim_ff,
        src_ntoken,
        tgt_ntoken
    ):
        super().__init__()

        self.d_model = d_model
        self.pos_encode = PositionalEmbedding(d_model, dropout, max_len)

        # Sub-modules are created in a fixed order (encoder, decoder, embeddings,
        # head) so that seeded initialisation and state_dict layout stay stable.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True),
            encoder_layer_nums,
        )
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True),
            decoder_layer_nums,
        )
        self.src_embedding = nn.Embedding(src_ntoken, d_model)
        self.tgt_embedding = nn.Embedding(tgt_ntoken, d_model)
        self.output_linear = nn.Linear(d_model, 4)
        self.output_softmax = nn.Softmax(dim=-1)

        self.init_weights()

    def init_weights(self):
        """Re-initialise both embeddings and the output projection.

        Weights are drawn uniformly from [-0.1, 0.1]; the output bias is zeroed.
        """
        bound = 0.1
        for embedding in (self.src_embedding, self.tgt_embedding):
            nn.init.uniform_(embedding.weight, -bound, bound)
        nn.init.zeros_(self.output_linear.bias)
        nn.init.uniform_(self.output_linear.weight, -bound, bound)

    def forward(
        self,
        src,
        tgt,
        src_mask = None,
        tgt_mask = None,
        memory_mask = None,
        src_key_padding_mask = None,
        tgt_key_padding_mask = None,
        memory_key_padding_mask = None,
    ):
        """Run the encoder-decoder and return per-position probabilities.

        Args:
            src: Source token ids, fed to the source embedding.
            tgt: Target token ids, fed to the target embedding.
            src_mask: Optional attention mask forwarded to the encoder.
            tgt_mask: Optional attention mask forwarded to the decoder.
            memory_mask: Optional mask on decoder-to-encoder attention.
            src_key_padding_mask: Optional padding mask for the encoder input.
            tgt_key_padding_mask: Optional padding mask for the decoder input.
            memory_key_padding_mask: Optional padding mask for the encoder output.

        Returns:
            Softmax-normalised tensor whose last dimension has size 4.
        """
        src_encoded = self.pos_encode(self.src_embedding(src))
        tgt_encoded = self.pos_encode(self.tgt_embedding(tgt))

        memory = self.encoder(
            src_encoded,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        decoded = self.decoder(
            tgt_encoded,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_softmax(self.output_linear(decoded))

In [6]:
# Build the model from the configuration constants and pair it with plain SGD.
lr = 1e-6
model = TransformerModel(
    d_model=d_model,
    dropout=dropout,
    max_len=max_len,
    nhead=nhead,
    encoder_layer_nums=encoder_layer_nums,
    decoder_layer_nums=decoder_layer_nums,
    dim_ff=dim_ff,
    src_ntoken=src_ntoken,
    tgt_ntoken=tgt_ntoken,
)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [7]:
# Device selection: `use_gpu` is a manual switch; CUDA is only picked when it
# is both requested and actually available, otherwise everything runs on CPU.
use_gpu = 0
device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

model = model.to(device)
model.train()

# Convergence threshold compared against masked_dist(...) in the training loop.
epsilon = 1e-6

# Mask of shape (15, 30, 4): entry w has its first w + 10 rows set to 1 and
# the remaining rows left at 0.
motif_mask = torch.zeros((15, 30, 4), requires_grad=False)
for w in range(motif_mask.size(0)):
    motif_mask[w, :w + 10] = 1

# Bookkeeping shared with the training loop.
total_loss = 0.
times = 1
loss_list = []
mdl_states = []
lamb_ls = []
nll_ls = []
time_set = dict.fromkeys(('t1', 't2', 't3', 't4'), 0)  # accumulated seconds per phase
iter_count = 0
file_count = 0
loss_log = r"D:\DNA data\Artificial\Loss.txt"

In [8]:
# Collect every `.fa` file in the data directory and visit them in random order.
fileDir = r"D:\DNA data\Artificial"
fileExt = r".fa"
file_list = list(filter(lambda name: name.endswith(fileExt), os.listdir(fileDir)))
# Shuffle indices rather than the names so file_list keeps its listing order.
file_idx = list(range(len(file_list)))
random.shuffle(file_idx)

In [9]:
# Train on every input file in shuffled order. For each file, run up to 50
# gradient steps on the EM criterion, refreshing lambda every 5th step.
# Progress is appended to `loss_log`; per-phase timings accumulate in `time_set`.
for idx in file_idx:

    # ---- t1: load one file and build its masks / initial EM state ----
    T = time.time()

    pwd = os.path.join(fileDir, file_list[idx])
    data, src, target = get_batch(pwd)
    I = indicator(data)
    data_key_padding_mask = get_padding_mask(src, 17)
    # NOTE(review): 2 is described in the config cell's token comment as the
    # "not a motif position" target token -- confirm it is the intended pad id.
    target_key_padding_mask = get_padding_mask(target, 2)
    target_mask = nn.Transformer.generate_square_subsequent_mask(target.size(-1))
    # Previous-step model output used by the convergence check; starts as all ones.
    # NOTE(review): the original followed this with a loop assigning 1 to the
    # leading rows, which is a no-op on an all-ones tensor and was dropped --
    # confirm whether a zero-initialised tensor was intended instead.
    last_f = torch.ones((15, 30, 4))
    subseq_len = [len(seq) for seq in data]
    last_lambda = init_lambda(subseq_len)
    last_lambda.requires_grad_(False)
    loss = tensor(float('nan'))

    time_set['t1'] += time.time() - T

    # ---- t2: move per-file tensors to the training device ----
    T = time.time()

    if use_gpu:
        src = src.to(device)
        target = target.to(device)
        data_key_padding_mask = data_key_padding_mask.to(device)
        target_key_padding_mask = target_key_padding_mask.to(device)
        target_mask = target_mask.to(device)
        last_f = last_f.to(device)
        last_lambda = last_lambda.to(device)
        motif_mask = motif_mask.to(device)
        loss = loss.to(device)

    time_set['t2'] += time.time() - T

    file_count += 1
    # Append ('a') rather than write ('w'): opening with 'w' truncates the log
    # on every call, so only the last line ever written would survive.
    with open(loss_log, 'a') as f:
        f.write(f'file {file_count}')
        f.write('\n')

    # The EM training process: at most 50 steps per file.
    for step in range(50):
        output = model(src, target,
                       src_key_padding_mask = data_key_padding_mask,
                       tgt_key_padding_mask = target_key_padding_mask,
                       tgt_mask = target_mask
                      )
        # Stop early once the output no longer moves between consecutive steps.
        if masked_dist(last_f, output, motif_mask) < epsilon:
            break

        # ---- t3: evaluate the EM criterion ----
        T = time.time()
        nll_value, new_lambda = EM_criterion(output, I, last_lambda, device)
        time_set['t3'] += time.time() - T

        lamb_ls.append(last_lambda)
        loss = nll_value
        with open(loss_log, 'a') as f:
            f.write(str(loss.item()))
            f.write('\n')
        # Keep only the values for the next convergence check; detaching avoids
        # holding on to the previous step's autograd graph.
        last_f = output.detach()

        optimizer.zero_grad()

        # ---- t4: backward pass, gradient clipping and parameter update ----
        T = time.time()

        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        time_set['t4'] += time.time() - T

        # Update the lambda value manually on every 5th step.
        if (step + 1) % 5 == 0:
            iter_count += 1
            last_lambda = new_lambda.detach().clone()

    with open(loss_log, 'a') as f:
        f.write('='*100)
        f.write('\n')
    if file_count % 10 == 0:
        print('Pass 10 times')


In [None]:
#torch.save(model, 'Transformo_100x100.pt')

In [None]:
# Display the wall-clock seconds accumulated per phase of the training loop:
#   t1 = loading a file and building masks / initial lambda,
#   t2 = moving tensors to the device,
#   t3 = the EM_criterion call,
#   t4 = backward pass, gradient clipping and optimizer step.
time_set

In [None]:
# Summarise the model's parameters and buffers as {name: shape} instead of
# dumping every weight tensor into the cell output.
{name: tuple(weights.shape) for name, weights in model.state_dict().items()}