In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
# focus on 2 layer single direction lstms before fucking around

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=2, bidirectional = False):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size=input_size, 
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=True,
                            bidirectional=bidirectional,)
        self.fc = nn.Linear(2*hidden_size,hidden_size)

    def forward(self, x):
        # x: (seq_len, batch_size, input_size)      seq len would be the number of touchpoints
        outputs, (hidden, cell) = self.lstm(x)
        print(next(self.parameters()).device)
        # if self.bidirectional:
        #     outputs = self.fc(outputs)
        # outputs: (seq_len, batch_size, hidden_size)
        # hidden, cell: (num_layers, batch_size, hidden_size)
        return outputs, hidden, cell

In [None]:
test_encoder = Encoder(input_size=6,
                       hidden_size=32,
                       num_layers=2,
                       bidirectional=False).to(device)
test_en_in = torch.rand((8, 60, 6)).to(device)
en_out, en_hidden, en_cell = test_encoder(test_en_in)

In [None]:
tt = torch.Tensor().to(device)

In [None]:
new = torch.rand((8,1, 27))
tt = torch.cat((tt, new), dim = 1)

In [2]:
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, force_ratio = 0.7, num_layers = 2, bidirectional = False):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.force_ratio = force_ratio

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(input_size=2*hidden_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            batch_first=True)
        # the output size doubles if the lstm is bidirectional
        fc_in_size = 2*hidden_size if bidirectional else hidden_size
        self.fc = nn.Linear(fc_in_size,output_size)

    def forward(self, encoder_output, encoder_hidden, encoder_cell, word_tensor):
        """
        Args:
            encoder_output: output from encoder (N, L, decoder hidden)
            encoder_hidden: hidden state from encoder (D * Layers, N, Hidden) for encoder
            encoder_cell: cell state from encoder (D * Layers, N, Hidden) for encoder
            word_tensor: word tensors containing indicies of letters (N, max length)
        """
        batch_size = encoder_output.shape[0]
        max_word_length = word_tensor.shape[-1]
        # create the first decoder input, which is blank inputs
        decoder_input = torch.tensor([0] * batch_size, dtype=torch.long, device=next(self.parameters()).device)
        # initial hidden and cell state comes from final layer of encoder, shape is (N, encoder hidden)
        decoder_hidden = encoder_hidden
        decoder_cell = encoder_cell

        decoder_outputs = torch.Tensor().to(device)
        # for each letter in word
        for i in range(max_word_length):
            decoder_output, decoder_hidden, decoder_cell = self.step(decoder_input,
                                                                     decoder_hidden,
                                                                     decoder_cell,
                                                                     encoder_output)
            decoder_outputs = torch.cat((decoder_outputs,decoder_output), dim=1)  # add to list of outputs
            print(decoder_output.shape)

            teacher_force = torch.rand(1).item() < self.force_ratio
            
            if teacher_force:
                decoder_input = word_tensor[:,i]    # next letter
            else:
                # for using when not teacher forcing, use model prediction
                decoder_input = torch.argmax(decoder_output.squeeze(1), dim=-1) # 1D with length N
        
        probs = F.softmax(decoder_outputs, dim=-1)

        return probs, decoder_hidden, decoder_cell

    def step(self, decoder_input, hidden, cell, encoder_output):
        """
        Args:
            decoder_input: shape (N), should be indicies for the letters, dtype either int or long
            hidden: hidden state of lstm (D*Layers, N, decoder hidden)
            cell: cell state of lstm (D*Layers, N, deocder hidden)
            encoder_output: output from encoder (N, L, decoder hidden)
        """
        embedded = self.embedding(decoder_input.unsqueeze(1))  # (N, 1 decoder hidden)
        query = hidden[-1].unsqueeze(0).permute(1, 0, 2) # (N, D*Layers, decoder hidden)
        # print(query.shape)
        context = F.scaled_dot_product_attention(query, encoder_output, encoder_output) # (N, 1, decoder hidden)
        # print(context.shape)
        input_lstm = torch.cat((embedded, context), dim=-1) # (N, 1, 2*decoder hidden)
        output_lstm, (hidden, cell) = self.lstm(input_lstm, (hidden, cell)) # output lstm (N, 1, D*hidden)
        output_fc = self.fc(output_lstm)

        return output_fc, hidden, cell

In [3]:
test_decoder = Decoder(hidden_size=32,
                       output_size=27,
                       force_ratio=0.7,
                       num_layers=2,
                       bidirectional=False).to(device)

test_en_output = torch.rand((8, 60, 32)).to(device)
test_en_hid = torch.rand((2, 8, 32)).to(device)
test_en_cell = torch.rand((2, 8, 32)).to(device)
test_word = torch.randint(0, 28, (8, 10)).to(device)

de_out, de_hid, de_cell = test_decoder(test_en_output, test_en_hid, test_en_cell, test_word)

torch.Size([8, 1, 27])
torch.Size([8, 1, 27])
torch.Size([8, 1, 27])
torch.Size([8, 1, 27])
torch.Size([8, 1, 27])
torch.Size([8, 1, 27])
torch.Size([8, 1, 27])


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
