In [1]:
from torch import nn
import torch
import numpy as np
import utils
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy,softmax

In [3]:
class Seq2Seq(nn.Module):
    """LSTM encoder / LSTMCell decoder sequence-to-sequence model.

    The encoder embeds the source tokens and runs a single-layer LSTM; its
    final ``(h, c)`` state seeds an ``nn.LSTMCell`` decoder that is unrolled
    one step at a time and projected to vocabulary logits.

    The module owns its own Adam optimizer (``self.opt``), so a training
    update is a single call to :meth:`step`.

    Args:
        enc_v_dim: size of the source vocabulary.
        dec_v_dim: size of the target vocabulary.
        emb_dim: embedding width used by both embedding tables.
        units: hidden size of the encoder LSTM and the decoder cell.
        max_pred_len: number of tokens :meth:`inference` generates.
        start_token: index fed to the decoder as its first input.
        end_token: index of the end-of-sequence token (stored only; decoding
            always runs for ``max_pred_len`` steps).
    """

    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units
        self.dec_v_dim = dec_v_dim

        # encoder
        self.enc_embeddings = nn.Embedding(enc_v_dim, emb_dim)
        self.enc_embeddings.weight.data.normal_(0, 0.1)
        self.encoder = nn.LSTM(emb_dim, units, 1, batch_first=True)

        # decoder
        self.dec_embeddings = nn.Embedding(dec_v_dim, emb_dim)
        self.dec_embeddings.weight.data.normal_(0, 0.1)
        self.decoder_cell = nn.LSTMCell(emb_dim, units)
        self.decoder_dense = nn.Linear(units, dec_v_dim)

        # Must be created after every sub-module so it sees all parameters.
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        """Encode a batch of source token ids.

        Args:
            x: integer tensor of shape ``[n, step]``.

        Returns:
            The final ``(h, c)`` of the encoder LSTM, each ``[1, n, units]``.
        """
        embedded = self.enc_embeddings(x)   # [n, step, emb]
        # With no explicit state nn.LSTM starts from zeros allocated on the
        # input's own device/dtype, so this also works off-CPU.
        _, (h, c) = self.encoder(embedded)
        return h, c

    def inference(self, x):
        """Greedily decode ``max_pred_len`` tokens for each source sequence.

        Runs in eval mode without gradient tracking and restores whatever
        train/eval mode the module was in before the call.

        Args:
            x: integer tensor of shape ``[n, step]``.

        Returns:
            Long tensor of shape ``[n, max_pred_len]`` with predicted ids.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hx, cx = self.encode(x)
                hx, cx = hx[0], cx[0]   # drop the num_layers axis -> [n, units]
                token = torch.full(
                    (x.shape[0],), int(self.start_token),
                    dtype=torch.long, device=x.device,
                )
                output = []
                for _ in range(self.max_pred_len):
                    hx, cx = self.decoder_cell(self.dec_embeddings(token), (hx, cx))
                    # Greedy choice; it becomes the next decoder input.
                    token = self.decoder_dense(hx).argmax(dim=1)
                    output.append(token)
        finally:
            # Put the module back in the mode the caller left it in.
            self.train(was_training)

        return torch.stack(output, dim=1)

    def train_logit(self, x, y):
        """Teacher-forced decoder logits.

        Args:
            x: source ids, shape ``[n, step]``.
            y: target ids, shape ``[n, t]``; ``y[:, :-1]`` is fed to the decoder.

        Returns:
            Float tensor of logits with shape ``[n, t - 1, dec_v_dim]``.
        """
        hx, cx = self.encode(x)
        hx, cx = hx[0], cx[0]
        dec_emb_in = self.dec_embeddings(y[:, :-1])   # [n, t-1, emb]
        output = []
        for emb_t in dec_emb_in.unbind(dim=1):
            hx, cx = self.decoder_cell(emb_t, (hx, cx))
            output.append(self.decoder_dense(hx))
        return torch.stack(output, dim=1)

    def step(self, x, y):
        """Run one optimizer update on a batch and return the loss.

        The decoder is trained to predict ``y[:, 1:]`` from ``y[:, :-1]``.

        Returns:
            The scalar cross-entropy loss as a 0-d NumPy array.
        """
        self.opt.zero_grad()
        logit = self.train_logit(x, y)
        dec_out = y[:, 1:]
        loss = cross_entropy(logit.reshape(-1, self.dec_v_dim), dec_out.reshape(-1))
        loss.backward()
        self.opt.step()
        # .cpu() is a no-op on CPU and keeps this working for off-CPU losses.
        return loss.detach().cpu().numpy()


In [4]:
def train(n_samples=4000, epochs=100, batch_size=32, log_every=70):
    """Train a ``Seq2Seq`` model on ``utils.DateData`` and print progress.

    Calling ``train()`` with no arguments reproduces the original
    hard-coded configuration.

    Args:
        n_samples: value passed to ``utils.DateData`` (presumably the number
            of generated date pairs -- confirm against ``utils``).
        epochs: number of passes over the loader.
        batch_size: mini-batch size for the ``DataLoader``.
        log_every: print a sample translation every ``log_every`` batches;
            must be a positive integer.
    """
    dataset = utils.DateData(n_samples)
    print("Chinese time order: yy/mm/dd ",dataset.date_cn[:3],"\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
    print("Vocabularies: ", dataset.vocab)
    print(f"x index sample:  \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
    f"\ny index sample:  \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = Seq2Seq(
        dataset.num_word,
        dataset.num_word,
        emb_dim=16,
        units=32,
        max_pred_len=11,
        start_token=dataset.start_token,
        end_token=dataset.end_token,
    )
    for epoch in range(epochs):
        # Each batch is a 3-item collection; the third item is not used here.
        for batch_idx, (bx, by, _) in enumerate(loader):
            bx = bx.long()
            by = by.long()
            loss = model.step(bx, by)
            if batch_idx % log_every != 0:
                continue
            # Show the first sample of the batch: the target without its
            # first and last token, next to the model's greedy decoding.
            target = dataset.idx2str(by[0, 1:-1].numpy())
            pred = model.inference(bx[0:1])
            res = dataset.idx2str(pred[0].numpy())
            src = dataset.idx2str(bx[0].numpy())
            print(
                "Epoch: ",epoch,
                "| t: ", batch_idx,
                "| loss: %.3f" % loss,
                "| input: ", src,
                "| target: ", target,
                "| inference: ", res,
            )

In [5]:
# Entry point: run the full training loop when this cell/script is executed
# directly (in a notebook kernel __name__ is "__main__", so this runs on cell execution).
if __name__ == "__main__":
    train()
    

Chinese time order: yy/mm/dd  ['31-04-25', '04-07-17', '33-06-06'] 
English time order: dd/M/yyyy ['25/Apr/2031', '17/Jul/2004', '06/Jun/2033']
Vocabularies:  {'/', '-', '8', '<PAD>', '<EOS>', 'Jan', 'Mar', '9', '7', '6', '5', 'May', 'Dec', 'Apr', '0', 'Nov', '4', '2', 'Jul', 'Jun', 'Oct', 'Aug', 'Sep', 'Feb', '1', '3', '<GO>'}
x index sample:  
31-04-25
[6 4 1 3 7 1 5 8] 
y index sample:  
<GO>25/Apr/2031<EOS>
[14  5  8  2 15  2  5  3  6  4 13]
Epoch:  0 | t:  0 | loss: 3.312 | input:  01-03-15 | target:  15/Mar/2001 | inference:  44444444444
Epoch:  0 | t:  70 | loss: 2.501 | input:  02-04-28 | target:  28/Apr/2002 | inference:  0//////////
Epoch:  1 | t:  0 | loss: 2.177 | input:  01-10-20 | target:  20/Oct/2001 | inference:  1/////2<EOS>
Epoch:  1 | t:  70 | loss: 1.792 | input:  02-08-24 | target:  24/Aug/2002 | inference:  22///200<EOS>
Epoch:  2 | t:  0 | loss: 1.530 | input:  34-04-08 | target:  08/Apr/2034 | inference:  00//20000<EOS>
Epoch:  2 | t:  70 | loss: 1.282 | input: 

Epoch:  37 | t:  0 | loss: 0.036 | input:  96-01-02 | target:  02/Jan/1996 | inference:  02/Jan/1996<EOS>
Epoch:  37 | t:  70 | loss: 0.029 | input:  06-02-03 | target:  03/Feb/2006 | inference:  03/Feb/2006<EOS>
Epoch:  38 | t:  0 | loss: 0.030 | input:  16-12-19 | target:  19/Dec/2016 | inference:  19/Dec/2016<EOS>
Epoch:  38 | t:  70 | loss: 0.024 | input:  85-08-12 | target:  12/Aug/1985 | inference:  12/Aug/1985<EOS>
Epoch:  39 | t:  0 | loss: 0.031 | input:  25-04-01 | target:  01/Apr/2025 | inference:  01/Apr/2025<EOS>
Epoch:  39 | t:  70 | loss: 0.030 | input:  30-03-13 | target:  13/Mar/2030 | inference:  13/Mar/2030<EOS>
Epoch:  40 | t:  0 | loss: 0.027 | input:  09-02-10 | target:  10/Feb/2009 | inference:  10/Feb/2009<EOS>
Epoch:  40 | t:  70 | loss: 0.023 | input:  04-04-12 | target:  12/Apr/2004 | inference:  12/Apr/2004<EOS>
Epoch:  41 | t:  0 | loss: 0.027 | input:  18-12-10 | target:  10/Dec/2018 | inference:  10/Dec/2018<EOS>
Epoch:  41 | t:  70 | loss: 0.018 | input:

Epoch:  75 | t:  70 | loss: 0.001 | input:  00-09-25 | target:  25/Sep/2000 | inference:  25/Sep/2000<EOS>
Epoch:  76 | t:  0 | loss: 0.001 | input:  03-07-06 | target:  06/Jul/2003 | inference:  06/Jul/2003<EOS>
Epoch:  76 | t:  70 | loss: 0.001 | input:  18-02-20 | target:  20/Feb/2018 | inference:  20/Feb/2018<EOS>
Epoch:  77 | t:  0 | loss: 0.001 | input:  33-05-02 | target:  02/May/2033 | inference:  02/May/2033<EOS>
Epoch:  77 | t:  70 | loss: 0.001 | input:  88-08-23 | target:  23/Aug/1988 | inference:  23/Aug/1988<EOS>
Epoch:  78 | t:  0 | loss: 0.001 | input:  09-04-10 | target:  10/Apr/2009 | inference:  10/Apr/2009<EOS>
Epoch:  78 | t:  70 | loss: 0.001 | input:  11-04-08 | target:  08/Apr/2011 | inference:  08/Apr/2011<EOS>
Epoch:  79 | t:  0 | loss: 0.001 | input:  26-02-03 | target:  03/Feb/2026 | inference:  03/Feb/2026<EOS>
Epoch:  79 | t:  70 | loss: 0.001 | input:  82-02-11 | target:  11/Feb/1982 | inference:  11/Feb/1982<EOS>
Epoch:  80 | t:  0 | loss: 0.001 | input: