In [1]:
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
# Explore the Embedding layer: each integer id in `x` is replaced by a learned
# vector, so a [4, 2] index tensor becomes a [4, 2, 16] tensor of vectors.
input_cls_num = 3     # vocabulary size: valid ids are 0, 1 and 2
target_hid_dim = 16   # length of every embedding vector
x = torch.tensor([[0, 2], [1, 1], [2, 0], [0, 0]], dtype=torch.long)
embed = torch.nn.Embedding(num_embeddings=input_cls_num, embedding_dim=target_hid_dim)
embeded_seq = embed(x)
print(embeded_seq.shape)

torch.Size([4, 2, 16])


In [3]:
# Embed one sequence of ids. Id 0 occurs at positions 0 and 3, so the first
# and last rows of the printed tensor are identical.
x_tp = torch.tensor([[0, 1, 2, 0]]).long()
embeded_tp = embed(x_tp)
print(embeded_tp)

tensor([[[-1.8788, -0.0331,  1.1378, -0.9246, -0.2182,  0.6418, -0.2694,
          -0.4144, -1.2269,  1.4005,  0.1025,  0.9638,  0.6893,  1.1649,
           1.9673, -1.3042],
         [ 0.0096,  1.4777, -0.7478,  1.8785,  0.5272,  0.5485, -1.6951,
          -0.3182, -0.4823, -0.0222, -1.4294,  1.0353,  0.8258,  1.1885,
          -0.2497,  1.1072],
         [-0.8538,  1.2686, -0.7242, -0.1459, -0.2994,  0.5838, -0.0340,
           2.0798, -0.1340,  0.0988, -0.3270, -0.1567, -1.5562, -0.5270,
           0.4882, -0.3418],
         [-1.8788, -0.0331,  1.1378, -0.9246, -0.2182,  0.6418, -0.2694,
          -0.4144, -1.2269,  1.4005,  0.1025,  0.9638,  0.6893,  1.1649,
           1.9673, -1.3042]]], grad_fn=<EmbeddingBackward>)


In [6]:
# Playing with GRU layers (bidirectional vs. one-way).
# With the default batch_first=False a GRU reads its input as
# [seq len, batch size, embedding dim], so embeded_seq ([4, 2, 16]) is
# interpreted as 4 time steps for a batch of 2.
input_seq = embeded_seq
input_size = target_hid_dim
hidden_size = 32
bi_gru = torch.nn.GRU(input_size, hidden_size, bidirectional=True)
ow_gru = torch.nn.GRU(input_size, hidden_size, bidirectional=False)
# Each GRU returns (output, h_n):
#   output: [seq len, batch size, hidden dim * num directions] - one vector per time step
#   h_n:    [num layers * num directions, batch size, hidden dim] - only the final
#           state, so its size does not depend on the sequence length
bi_output, bi_hidden_states = bi_gru(input_seq)
ow_output, ow_hidden_states = ow_gru(input_seq)
print("The shape of bidirectional gru layer's output: ", bi_output.shape)
print("The shape of simple gru layer's output: ", ow_output.shape)
print("When bidirectional, the shape of hidden states will be: ", bi_hidden_states.shape)
print("Thus the simple one is: ", ow_hidden_states.shape)

The shape of bidirectional gru layer's output:  torch.Size([4, 2, 64])
The shape of simple gru layer's output:  torch.Size([4, 2, 32])
When bidirectional, the shape of hidden states will be:  torch.Size([2, 2, 32])
Thus the simple one is:  torch.Size([1, 2, 32])


In [3]:
# Tensor transform tricks: joining the two [2, 4] slices of tensor_a side by
# side (along the last axis) yields a [2, 8] tensor.
tensor_a = torch.tensor([[[1, 1, 6, 6], [1, 1, 6, 6]], [[1, 1, 6, 6], [2, 4, 3, 9]]])
print(tensor_a.shape)
tensor_a_cat = torch.cat(tensor_a.unbind(0), dim=-1)
print(tensor_a_cat.shape)

torch.Size([2, 2, 4])
torch.Size([2, 8])


In [4]:
# Stack two copies of tensor_a_cat along a new leading axis: [2, 8] -> [2, 2, 8]
tensor_b = tensor_a_cat.repeat((2, 1, 1))
print(tensor_b.shape)

torch.Size([2, 2, 8])


In [19]:
# Create tensor_c here so this cell does not depend on a later cell having
# been executed first (it must survive Restart Kernel -> Run All).
tensor_c = torch.tensor([[[1], [8], [6], [2]]])
# Insert a leading singleton axis: [1, 4, 1] -> [1, 1, 4, 1]
tensor_c.reshape((1, 1, 4, 1))

tensor([[[[1],
          [8],
          [6],
          [2]]]])

In [7]:
# squeeze(-1) drops the trailing size-1 axis; unsqueeze(0) inserts a new leading one.
tensor_c = torch.tensor([[[1], [8], [6], [2]]])
print(tensor_c.shape)
tensor_c_seq = torch.squeeze(tensor_c, -1)
print(tensor_c_seq.shape)
tensor_c_unseq = torch.unsqueeze(tensor_c_seq, 0)
print(tensor_c_unseq.shape)

torch.Size([1, 4, 1])
torch.Size([1, 4])
torch.Size([1, 1, 4])


In [55]:
# Matrix helpers: drop the trailing singleton axis of a [4, 1, 1] weight
# tensor, leaving one weight per position in a [4, 1] column.
attn_tensor = torch.tensor([[[.3]], [[.2]], [[.4]], [[.1]]], dtype=torch.float32)
print(attn_tensor.shape)
attn_tensor = torch.squeeze(attn_tensor, dim=-1)
print(attn_tensor)
print(attn_tensor.shape)

torch.Size([4, 1, 1])
tensor([[0.3000],
        [0.2000],
        [0.4000],
        [0.1000]])
torch.Size([4, 1])


In [56]:
# Prepare two batched operands for torch.bmm (used in the next cell).
# logits: softmax over the 4 positions of attn_tensor ([4, 1] -> [1, 4]),
# duplicated for a batch of 2 and shaped [2, 1, 4].
logits = F.softmax(attn_tensor.t(), dim=-1)
logits = torch.cat([logits, logits]).unsqueeze(-1)  # [2, 4, 1]
# Swap the last two axes -> [2, 1, 4]. transpose is the general axis swap;
# reshape only coincides with it here because one of the swapped axes has size 1.
logits = logits.transpose(1, 2)
print(logits)
print(logits.shape)

input_tensor = torch.tensor([[[1, 2, 4]], [[2, 2, 3]], [[4, 3, 6]], [[3, 5, 6]]], dtype=torch.float32)
# [4, 1, 3] -> [1, 4, 3]: swap the first two axes
input_tensor = input_tensor.transpose(0, 1)
input_tensor = torch.cat([input_tensor, input_tensor])  # [2, 4, 3]
print(input_tensor)
print(input_tensor.shape)

tensor([[[0.2612, 0.2363, 0.2887, 0.2138]],

        [[0.2612, 0.2363, 0.2887, 0.2138]]])
torch.Size([2, 1, 4])
tensor([[[1., 2., 4.],
         [2., 2., 3.],
         [4., 3., 6.],
         [3., 5., 6.]],

        [[1., 2., 4.],
         [2., 2., 3.],
         [4., 3., 6.],
         [3., 5., 6.]]])
torch.Size([2, 4, 3])


In [57]:
# Batched matrix product: [a, b, c] @ [a, c, d] -> [a, b, d], here [2, 1, 4] @ [2, 4, 3] -> [2, 1, 3]
answer = logits.bmm(input_tensor)
print(answer.shape)
print(answer)

torch.Size([2, 1, 3])
tensor([[[2.5300, 2.9302, 4.7687]],

        [[2.5300, 2.9302, 4.7687]]])


In [None]:
class Seq2SeqModel(torch.nn.Module):
    """Attention-free GRU encoder-decoder followed by a feed-forward head.

    Both GRUs are built with ``batch_first=True``, so ``x`` and ``prev_y`` are
    ``[batch, seq len, input_dim]``. The decoder is initialised with the
    encoder's final hidden state, and the result is
    ``[batch, target len, output_dim]`` logits.

    Args:
        seq_len: stored for reference; not used by ``forward``.
        input_dim: feature width of ``x`` and ``prev_y``.
        output_dim: width of the returned logits.
        gru_num_layers: number of stacked layers in each GRU.
        bidirectional: use bidirectional GRUs (the per-direction hidden size is
            then ``input_dim // 2``).
        dropout: dropout rate for the feed-forward head; also used between
            stacked GRU layers when ``gru_num_layers > 1``.
        hidden_layer: widths of the feed-forward layers between the decoder and
            the output layer; ``None`` or empty for no hidden layers.
        max_seq_len: stored for reference; not used by ``forward``.
    """

    def __init__(self, seq_len, input_dim, output_dim, gru_num_layers=1, bidirectional=False, dropout=.3, hidden_layer=(64, 32, 64), max_seq_len=128):
        super(Seq2SeqModel, self).__init__()
        # store the variables
        self.seq_len = seq_len
        self.input_dim = input_dim
        # halve the per-direction size so the concatenated bidirectional output
        # stays aligned with input_dim
        self.hidden_dim = input_dim // 2 if bidirectional else input_dim
        self.output_dim = output_dim
        self.gru_num_layers = gru_num_layers
        self.max_seq_len = max_seq_len
        num_directions = 2 if bidirectional else 1
        # actual width of the GRU outputs (input_dim - 1 when input_dim is odd and bidirectional)
        self.gru_output_dim = self.hidden_dim * num_directions
        # PyTorch applies GRU dropout only between stacked layers and warns otherwise
        gru_dropout = dropout if gru_num_layers > 1 else 0.0
        # Encoder GRU
        self.encoder = torch.nn.GRU(
            input_size=self.input_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.gru_num_layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=gru_dropout,
        )
        # Decoder GRU: same configuration, so the encoder's hidden state fits it
        self.decoder = torch.nn.GRU(
            input_size=self.input_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.gru_num_layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=gru_dropout,
        )
        # The FFN to adjust the outputs: consecutive Linear layers whose first
        # input is the decoder's output width
        layer_dims = [self.gru_output_dim] + (list(hidden_layer) if hidden_layer else [])
        self.hidden_layer_list = torch.nn.ModuleList(
            [torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])]
        )
        for layer in self.hidden_layer_list:
            torch.nn.init.kaiming_normal_(layer.weight)
        self.hidden_out_dim = layer_dims[-1]
        # Output layer
        self.output = torch.nn.Linear(self.hidden_out_dim, self.output_dim)
        torch.nn.init.kaiming_normal_(self.output.weight)
        # Other functions
        self.activate = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, prev_y, teacher=True):
        """Encode ``x`` and decode ``prev_y`` with teacher forcing.

        Args:
            x: source sequence, ``[batch, src len, input_dim]``.
            prev_y: decoder inputs (ground truth), ``[batch, tgt len, input_dim]``.
            teacher: must be ``True``; free-running decoding is not implemented.

        Returns:
            Logits of shape ``[batch, tgt len, output_dim]``.

        Raises:
            NotImplementedError: if ``teacher`` is false.
        """
        if not teacher:
            raise NotImplementedError("Seq2SeqModel only supports teacher forcing (teacher=True)")
        # encode process: only the final hidden state is passed on
        _, encoder_hidden_info = self.encoder(x)
        # decode process, starting from the encoder's final state
        decode_y, _ = self.decoder(prev_y, encoder_hidden_info)
        output = self.activate(decode_y)
        # ffn process
        for layer in self.hidden_layer_list:
            output = self.dropout(self.activate(layer(output)))
        # output layer, get logits
        return self.output(output)

In [None]:
class Seq2SqeEncoder(torch.nn.Module):
    """GRU encoder for the attention seq2seq model.

    The GRU is ``batch_first=True``: ``x`` is ``[batch, seq len, input_dim]``
    and ``forward`` returns ``(outputs [batch, seq len, output_dim],
    hidden [num_layers * num_directions, batch, hidden_dim])``, where
    ``output_dim`` is ``2 * hidden_dim`` when bidirectional.

    Args:
        seq_len: stored for reference; not used by ``forward``.
        input_dim: feature width of ``x``.
        hidden_dim: per-direction GRU hidden size.
        bidirectional: use a bidirectional GRU.
        dropout: dropout between stacked GRU layers; PyTorch ignores it (and
            warns) for a single layer, so it is only passed on when ``num_layers > 1``.
        num_layers: number of stacked GRU layers (default 1).
    """

    def __init__(self, seq_len, input_dim, hidden_dim, bidirectional=False, dropout=.3, num_layers=1):
        super(Seq2SqeEncoder, self).__init__()
        # keep the parameters
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # define encoder
        self.encoder = torch.nn.GRU(
            input_size=self.input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )

    def forward(self, x):
        """Run the GRU over ``x`` and return ``(outputs, final hidden states)``."""
        # if bidirectional, the dim of outputs is doubled and there are 2 hidden states per layer
        encoder_output, hidden_states = self.encoder(x)
        return encoder_output, hidden_states

In [None]:
class Seq2SeqAttention(torch.nn.Module):
    """Additive attention over time-major encoder outputs.

    Each source position is scored with
    ``v^T tanh(W [encoder_output ; decoder_query])``, and the scores are
    normalised with a softmax over the source positions.

    Args:
        encoder_output_dim: width of each encoder output vector.
        decoder_hidden_dim: width of the decoder query, i.e. the per-direction
            decoder hidden size times the number of directions whose final
            states are concatenated (2x for a bidirectional decoder).
    """

    def __init__(self, encoder_output_dim, decoder_hidden_dim):
        super(Seq2SeqAttention, self).__init__()
        self.encoder_output_dim = encoder_output_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        # define attention mapping
        self.attn = torch.nn.Linear(encoder_output_dim + decoder_hidden_dim, decoder_hidden_dim)
        self.logits = torch.nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden_states):
        """Return attention weights of shape ``[src len, batch]`` summing to 1 over ``src len``.

        Args:
            encoder_outputs: ``[src len, batch, encoder_output_dim]`` (time-major).
            decoder_hidden_states: GRU hidden state
                ``[num_layers * num_directions, batch, hidden]``. The last
                ``decoder_hidden_dim // hidden`` slices (the top layer, forward
                then backward when bidirectional) are concatenated into the query.

        Raises:
            ValueError: if ``hidden`` does not divide ``decoder_hidden_dim`` or
                fewer slices than required are provided.
        """
        per_direction = decoder_hidden_states.shape[-1]
        num_directions, remainder = divmod(self.decoder_hidden_dim, per_direction)
        if remainder or not 1 <= num_directions <= decoder_hidden_states.shape[0]:
            raise ValueError(
                f"decoder hidden size {per_direction} (with {decoder_hidden_states.shape[0]} states) "
                f"is incompatible with decoder_hidden_dim={self.decoder_hidden_dim}"
            )
        # [batch, decoder_hidden_dim]: concatenate the top layer's direction states
        query = torch.cat(list(decoder_hidden_states[-num_directions:]), dim=-1)
        # attach the same query to every source position: [src len, batch, decoder_hidden_dim]
        query = query.unsqueeze(0).expand(encoder_outputs.shape[0], -1, -1)
        attn_input = torch.cat([encoder_outputs, query], dim=-1)
        attn = torch.tanh(self.attn(attn_input))
        # one score per source position: [src len, batch]
        scores = self.logits(attn).squeeze(-1)
        # normalise over the source positions (dim 0), not over the batch
        return F.softmax(scores, dim=0)

In [None]:
class Seq2SeqDecoder(torch.nn.Module):
    """One-step GRU decoder that attends over the encoder outputs.

    Layouts are time-major, matching the shape comments used in this notebook:

    * ``x``: ``[1, batch, input_dim]`` - embedding of the previous target token
      (the ground truth in teacher mode).
    * ``prev_hidden_states``: ``[num_directions, batch, hidden_dim]``.
    * ``encoder_outputs``: ``[src len, batch, decoder_output_dim]``. The context
      vector is concatenated with the GRU output before ``output_layer``, so its
      width must equal ``decoder_output_dim``.

    ``attention_layer(encoder_outputs=..., decoder_hidden_states=...)`` must
    return weights of shape ``[src len, batch]``.

    ``dropout`` is accepted for API compatibility but has no effect: this GRU has
    a single layer, and PyTorch applies GRU dropout only between stacked layers.
    """

    def __init__(self, seq_len, input_dim, hidden_dim, output_dim, attention_layer, bidirectional=False, dropout=.3):
        super(Seq2SeqDecoder, self).__init__()
        # keep the parameters
        self.seq_len = seq_len
        self.output_dim = output_dim
        self.decoder_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.hidden_dim = hidden_dim
        # each step's GRU input is the token embedding plus the attention context vector
        self.input_dim = input_dim + self.decoder_output_dim
        # create the decoder
        self.attention = attention_layer
        self.decoder = torch.nn.GRU(
            input_size=self.input_dim,
            hidden_size=self.hidden_dim,
            bidirectional=bidirectional,
            # inputs are time-major [1, batch, dim]; batch_first=True would read batch as time
            batch_first=False,
        )
        # logits are computed from [GRU output ; context vector]
        self.output_layer = torch.nn.Linear(self.decoder_output_dim * 2, self.output_dim)

    def forward(self, x, prev_hidden_states, encoder_outputs):
        """Decode one step.

        Returns:
            ``(logits [batch, output_dim], hidden_states [num_directions, batch, hidden_dim])``.
        """
        # attention weights over source positions: [src len, batch]
        attn = self.attention(encoder_outputs=encoder_outputs, decoder_hidden_states=prev_hidden_states)
        # -> [batch, 1, src len]
        attn = attn.t().unsqueeze(1)
        # transpose (not reshape) to [batch, src len, dim] so every batch row keeps its own sequence;
        # the weighted sum gives the context vector [batch, 1, dim]
        context = torch.bmm(attn, encoder_outputs.transpose(0, 1))
        # back to time-major: [1, batch, dim]
        context = context.transpose(0, 1)
        # [1, batch, input dim + dim]
        decoder_input = torch.cat((x, context), dim=-1)
        output, hidden_states = self.decoder(decoder_input, prev_hidden_states)
        # [1, batch, decoder_output_dim + dim] -> [batch, output_dim]
        output = self.output_layer(torch.cat([output, context], dim=-1).squeeze(0))
        return output, hidden_states

In [None]:
class Seq2SeqAttnModel(torch.nn.Module):
    """GRU encoder + additive attention + one-step GRU decoder (teacher forcing only).

    Expected layouts follow the sub-modules' configuration; the caller is responsible for them:

    * ``x``: ``[batch, src len, embedding_dim]`` - the encoder GRU is ``batch_first=True``.
    * ``y``: presumably a tensor ``[tgt len, batch, embedding_dim]`` - it is iterated over its
      first axis, one decoding step per element (TODO confirm against the data pipeline).

    ``forward`` returns a list with one ``[batch, num_class]`` logits tensor per target step.
    Sub-modules are moved to CUDA at construction when it is available.
    """

    def __init__(self, seq_len, num_class, embedding_dim=32, hidden_dim=32, bidirectional=False, dropout=.3):
        super(Seq2SeqAttnModel, self).__init__()
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        num_directions = 2 if bidirectional else 1
        # halve the per-direction size so the concatenated bidirectional width equals hidden_dim
        self.encoder_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        # the decoder is initialised with the encoder's hidden state, so the sizes must match
        self.decoder_hidden_dim = self.encoder_hidden_dim
        self.encoder = Seq2SqeEncoder(
            seq_len=seq_len,
            input_dim=embedding_dim,
            hidden_dim=self.encoder_hidden_dim,
            bidirectional=bidirectional,
            dropout=dropout
        ).to(device)
        # attention sees full-width encoder outputs and a query made of all decoder directions
        self.attention = Seq2SeqAttention(
            self.encoder_hidden_dim * num_directions,
            self.decoder_hidden_dim * num_directions,
        ).to(device)
        self.decoder = Seq2SeqDecoder(
            seq_len=seq_len,
            input_dim=embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            output_dim=num_class,
            attention_layer=self.attention,
            bidirectional=bidirectional,
            dropout=dropout
        ).to(device)

    def forward(self, x, y, teacher=True):
        """Encode ``x`` and decode ``y`` step by step with teacher forcing.

        Raises:
            NotImplementedError: if ``teacher`` is false (free-running decoding
                needs an output embedding this model does not have).
        """
        if not teacher:
            raise NotImplementedError("Seq2SeqAttnModel only supports teacher forcing (teacher=True)")
        encoder_outputs, hidden_states = self.encoder(x)
        # the encoder is batch-first; attention and decoder expect [src len, batch, dim]
        encoder_outputs = encoder_outputs.transpose(0, 1)
        predictions = []
        for prev_label in y:
            # a single decoding step needs a length-1 time axis: [1, batch, embedding_dim]
            pred_y, hidden_states = self.decoder(
                x=prev_label.unsqueeze(0),
                prev_hidden_states=hidden_states,
                encoder_outputs=encoder_outputs,
            )
            predictions.append(pred_y)
        return predictions