In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext



class Decoder(nn.Module):
    """Transformer decoder stack mapping target token ids + encoder memory to logits.

    All tensors are sequence-first (the default ``batch_first=False`` of
    ``nn.TransformerDecoderLayer``).

    Args:
        trg_dim: size of the target vocabulary (rows of the token embedding).
        output_dim: width of the final linear projection (logits per position).
        d_model: embedding / hidden size.
        nhead: number of attention heads per layer.
        num_dec_layers: number of stacked ``nn.TransformerDecoderLayer``s.
        dim_feedforward: hidden size of each layer's feed-forward sublayer.
        dropout: dropout probability, used inside the layers and after the
            embedding sum.
        device: stored as ``self.device`` for backward compatibility; masks and
            positions are built on the devices of the input tensors.
        src_pad_index: token id marking padding in ``src``.
        trg_pad_index: token id marking padding in ``trg``.
        max_len: number of learned positions; ``trg`` may not be longer.
    """

    def __init__(
        self,
        trg_dim,
        output_dim,
        d_model,
        nhead,
        num_dec_layers,
        dim_feedforward,
        dropout,
        device,
        src_pad_index,
        trg_pad_index,
        max_len=100
    ):
        super().__init__()

        # sqrt(d_model) embedding scale. A non-persistent buffer follows the module
        # through .to()/.cuda() but is not written to state_dict.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([d_model])), persistent=False
        )

        self.trg_embedding = nn.Embedding(trg_dim, d_model)
        # Learned absolute positional embedding: one row per position < max_len.
        self.trg_pos_decoder = nn.Embedding(max_len, d_model)

        transformer_dec_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        dec_norm = nn.LayerNorm(d_model)
        self.transformer_dec = nn.TransformerDecoder(transformer_dec_layer, num_dec_layers, dec_norm)

        self.output_layer = nn.Linear(d_model, output_dim)

        self.device = device

        self.dropout = nn.Dropout(p=dropout)

        self.src_pad_index = src_pad_index
        self.trg_pad_index = trg_pad_index

    @staticmethod
    def generate_square_subsequent_mask(sz, device=None):
        """Return an additive causal mask of shape [sz, sz].

        Entries above the diagonal are ``-inf`` (future positions blocked) and all
        others are ``0.0``. Defined here so no throwaway ``nn.Transformer`` has to
        be instantiated just to borrow its helper.
        """
        return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

    def forward(self, src, memory, trg):
        """Decode ``trg`` against ``memory``.

        Args:
            src: [src_len, batch] source token ids; only used to build the memory
                key padding mask.
            memory: [src_len, batch, d_model] encoder output.
            trg: [trg_len, batch] target token ids.

        Returns:
            [trg_len, batch, output_dim] output of the final linear layer.

        Raises:
            ValueError: if ``trg_len`` exceeds the number of learned positions.
        """
        trg_len = trg.shape[0]
        max_len = self.trg_pos_decoder.num_embeddings
        if trg_len > max_len:
            raise ValueError(
                f"target length {trg_len} exceeds max_len={max_len} of the positional embedding"
            )

        # Scaled token embeddings plus learned positions; the [trg_len, 1, d_model]
        # positional term broadcasts over the batch dimension.
        positions = torch.arange(trg_len, device=trg.device)
        trg_embedded = self.trg_embedding(trg) * self.scale
        trg_embedded = self.dropout(trg_embedded + self.trg_pos_decoder(positions).unsqueeze(1))
        # trg_embedded = [trg_len, batch, d_model]

        # Key padding masks are [batch, len] with True at padded positions; they are
        # built on the inputs' own device (no CPU round-trip) and then placed next
        # to the tensors they are combined with.
        memory_padding_mask = (src == self.src_pad_index).transpose(0, 1).to(memory.device)
        trg_padding_mask = (trg == self.trg_pad_index).transpose(0, 1).to(trg_embedded.device)

        # trg_mask = [trg_len, trg_len], additive causal mask
        trg_mask = self.generate_square_subsequent_mask(trg_len, device=trg_embedded.device)

        output = self.transformer_dec(
            trg_embedded,
            memory,
            tgt_mask=trg_mask,
            tgt_key_padding_mask=trg_padding_mask,
            memory_key_padding_mask=memory_padding_mask
        )
        # output = [trg_len, batch, d_model]

        return self.output_layer(output)

In [21]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [22]:
# NOTE: despite its name, `enc` holds a Decoder; the name is kept because later
# cells may refer to it.
# Decoder's `device` argument is only used for internal tensors, not for the
# module's parameters, so the module itself is moved onto `device` explicitly.
enc = Decoder(
    trg_dim=100,
    output_dim=100,
    d_model=100,
    nhead=100,
    num_dec_layers=100,
    dim_feedforward=100,
    dropout=0.1,
    device=device,
    src_pad_index=0,
    trg_pad_index=0,
    max_len=100,
).to(device)