In [5]:
import torch
import torch.nn as nn

In [13]:
# Random 4-D sample tensor of shape (2, 3, 4, 4); the LayerNorm cell further
# down reads it through this exact name.
# NOTE: `input` shadows the Python builtin of the same name inside this kernel.
input = torch.randn((2, 3, 4, 4))

print(input.size())

torch.Size([2, 3, 4, 4])


In [16]:
# Normalize over every dimension except the leading (batch) one.
print(input.shape[1:])

# Same normalized shape, once without and once with the learnable
# element-wise scale/shift parameters.
m1 = nn.LayerNorm(input.shape[1:], elementwise_affine=False)
m2 = nn.LayerNorm(input.shape[1:], elementwise_affine=True)

boutput1, boutput2 = m1(input), m2(input)

# Element-wise difference between the two results for the same input.
print(boutput1 - boutput2)

torch.Size([3, 4, 4])
tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]],


        [[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], grad_fn=<SubBackward0>)


In [28]:
import random
import time

import pandas as pd
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score


class RNNBaseModel(nn.Module):
    """Single-layer bidirectional recurrent model built in float64.

    Inputs are layer-normalized over their last dimension, passed through the
    recurrent layer whose class name is given by ``rnn_type`` (looked up on
    ``torch.nn``), and then linearly projected to ``output_size`` values per
    time step.

    Args:
        rnn_type: Name of a recurrent layer class in ``torch.nn``
            (the notebook uses "RNN", "LSTM" and "GRU").
        input_size: Number of features per time step.
        hidden_size: Hidden width of each direction of the recurrent layer.
        output_size: Number of values produced per time step.
    """

    def __init__(self, rnn_type, input_size, hidden_size, output_size):
        super().__init__()

        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )

        double = torch.float64
        self.layerNorm = nn.LayerNorm(input_size, dtype=double)

        recurrent_cls = getattr(nn, rnn_type)
        self.rnn = recurrent_cls(
            input_size,
            hidden_size,
            num_layers=1,
            batch_first=True,
            dtype=double,
            bidirectional=True,
        )

        # The recurrent layer is bidirectional, so its output carries
        # 2 * hidden_size features per time step.
        self.out = nn.Linear(2 * hidden_size, output_size, dtype=double)

    def forward(self, inputs, hidden=None):
        """Run the model and return ``(projected_output, recurrent_state)``.

        Args:
            inputs: Batch-first tensor whose last dimension is ``input_size``.
            hidden: Optional initial state forwarded unchanged to the
                recurrent layer.
        """
        normalized = self.layerNorm(inputs)
        recurrent_out, hidden = self.rnn(normalized, hidden)
        return self.out(recurrent_out), hidden


def fit(rnn_type):
    """Train an ``RNNBaseModel`` on random data for one epoch and time the run.

    A random float64 tensor of shape (100, 67, 11) is generated. For every
    time step the leading 10 features are the model input and the last
    feature is the regression target. The model is optimized with SGD on an
    MSE loss in mini-batches of 5 samples.

    Args:
        rnn_type: Name of the recurrent layer class in ``torch.nn`` to use
            (e.g. "RNN", "LSTM" or "GRU").

    Returns:
        None. The shapes of the inputs, targets and predictions are printed
        for every mini-batch, followed by the elapsed wall-clock time.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start = time.perf_counter()

    inputs = torch.randn(100, 67, 11, dtype=torch.float64)

    # The last feature column is held out as the target.
    input_size = inputs.shape[2] - 1

    model = RNNBaseModel(rnn_type, input_size=input_size, hidden_size=input_size * 2, output_size=1)
    loss_func = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    epochs = 1
    bs = 5
    num_samples = inputs.shape[0]

    for epoch in range(epochs):
        model.train()

        # Stepping by `bs` also covers a trailing partial batch when the
        # sample count is not a multiple of the batch size (with the current
        # constants this is the same 20 batches as before).
        for batch_start in range(0, num_samples, bs):
            batch = inputs[batch_start:batch_start + bs]
            X = batch[:, :, :-1]
            y = batch[:, :, -1]

            outputs, _ = model(X, None)
            # Drop only the trailing size-1 output dimension. A bare squeeze()
            # would also collapse a batch (or sequence) dimension of length 1
            # and make the prediction shape disagree with the target.
            outputs = outputs.squeeze(-1)

            print(X.shape)
            print(y.shape)
            print(outputs.shape)

            optimizer.zero_grad()
            loss = loss_func(outputs, y)
            loss.backward()
            optimizer.step()

        # TODO: add a periodic evaluation pass once a held-out dataset exists
        # (none is defined in this notebook).

    print(f'Run time: {time.perf_counter() - start}')


if __name__ == "__main__":
    # Run the training demo once for each recurrent layer type.
    rnn_type_list = [
        "RNN",
        "LSTM",
        "GRU",
    ]

    for rnn_type in rnn_type_list:
        fit(rnn_type)


torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 67])
torch.Size([5, 67])
torch.Size([5, 67, 10])
torch.Size([5, 6

In [None]:
class LayerNormLSTMCell(nn.LSTMCell):
    """LSTM cell with layer normalization on the gates and the cell state.

    The input-to-hidden and hidden-to-hidden projections are normalized
    separately before being summed into the gate pre-activations, and the new
    cell state is normalized before the output gate is applied.

    NOTE: the summed ``4 * hidden_size`` pre-activations are split as
    ``[input | forget | output | candidate]``. That layout is this class's own
    convention, so its weights should not be assumed interchangeable with a
    stock ``nn.LSTMCell``.

    Args:
        input_size: Number of input features.
        hidden_size: Number of hidden/cell features.
        bias: Whether the two linear projections use bias terms.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__(input_size, hidden_size, bias)

        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_ho = nn.LayerNorm(hidden_size)

    def _check_input(self, input):
        """Raise ``RuntimeError`` unless ``input`` is (batch, input_size)."""
        if input.dim() != 2 or input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected (batch, {})".format(
                    tuple(input.shape), self.input_size))

    def _check_hidden(self, input, hx, hidden_label=''):
        """Raise ``RuntimeError`` unless ``hx`` is (batch, hidden_size)."""
        if hx.dim() != 2 or input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} shape {}".format(
                    input.size(0), hidden_label, tuple(hx.shape)))
        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    def forward(self, input, hidden=None):
        """Advance the cell by one time step.

        Args:
            input: Tensor of shape (batch, input_size).
            hidden: Optional ``(h, c)`` pair, each of shape
                (batch, hidden_size). Zeros are used when omitted.

        Returns:
            ``(hy, cy)``: the new hidden and cell states, each of shape
            (batch, hidden_size).

        Raises:
            RuntimeError: If the input or state shapes are inconsistent.
        """
        # Validation is done locally instead of through the
        # check_forward_input / check_forward_hidden helpers of the base
        # class, so the cell does not depend on those helpers being present.
        self._check_input(input)
        if hidden is None:
            hx = input.new_zeros(input.size(0), self.hidden_size)
            cx = input.new_zeros(input.size(0), self.hidden_size)
        else:
            hx, cx = hidden
        self._check_hidden(input, hx, '[0]')
        self._check_hidden(input, cx, '[1]')

        gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \
            + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh))
        # First 3H columns are the sigmoid gates, the last H the tanh candidate.
        i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1)
        g = gates[:, (3 * self.hidden_size):].tanh()

        cy = (f * cx) + (i * g)
        hy = o * self.ln_ho(cy).tanh()
        return hy, cy


class LayerNormLSTM(nn.Module):
    """Multi-layer, optionally bidirectional LSTM built from ``LayerNormLSTMCell``.

    Only sequence-first input of shape (seq_len, batch, input_size) is
    supported.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Hidden width of each direction.
        num_layers: Number of stacked layers.
        bias: Whether the cells use bias terms.
        bidirectional: If True, every layer also runs a second set of cells
            over the sequence in reverse and concatenates both outputs.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        num_directions = 2 if bidirectional else 1
        # Layers above the first consume the (possibly concatenated) output of
        # the layer below.
        self.hidden0 = nn.ModuleList([
            LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                              hidden_size=hidden_size, bias=bias)
            for layer in range(num_layers)
        ])

        if self.bidirectional:
            self.hidden1 = nn.ModuleList([
                LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                                  hidden_size=hidden_size, bias=bias)
                for layer in range(num_layers)
            ])

    def forward(self, input, hidden=None):
        """Run the stacked cells over a whole sequence.

        Args:
            input: Tensor of shape (seq_len, batch, input_size).
            hidden: Optional ``(h0, c0)`` pair, each of shape
                (num_layers * num_directions, batch, hidden_size), indexed as
                ``layer * num_directions + direction``. Zeros when omitted.

        Returns:
            ``(y, (hy, cy))`` where ``y`` has shape
            (seq_len, batch, num_directions * hidden_size) and holds the top
            layer's hidden state at every time step, and ``hy`` / ``cy`` have
            shape (num_layers * num_directions, batch, hidden_size) and hold
            the state of each layer/direction after it has consumed the whole
            sequence.
        """
        seq_len, batch_size, _ = input.size()  # supports TxNxH only
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size)
            cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size)
        else:
            hx, cx = hidden

        # Per-time-step inputs of the current layer. Each layer builds a fresh
        # list of outputs; sharing one inner list across time steps (as
        # `[[None] * k] * seq_len` does) would make every step overwrite the
        # previous one.
        layer_input = input.unbind(0)
        final_h, final_c = [], []

        for layer in range(self.num_layers):
            idx = layer * num_directions

            # Forward direction: t = 0 .. seq_len - 1.
            h, c = hx[idx], cx[idx]
            forward_out = []
            for x in layer_input:
                h, c = self.hidden0[layer](x, (h, c))
                forward_out.append(h)
            final_h.append(h)
            final_c.append(c)

            if self.bidirectional:
                # Backward direction: t = seq_len - 1 .. 0. The output is
                # stored at the position of the input it was computed from, so
                # the final backward state is the one at t = 0.
                h, c = hx[idx + 1], cx[idx + 1]
                backward_out = [None] * seq_len
                for t in range(seq_len - 1, -1, -1):
                    h, c = self.hidden1[layer](layer_input[t], (h, c))
                    backward_out[t] = h
                final_h.append(h)
                final_c.append(c)
                layer_input = [torch.cat((fwd, bwd), dim=1)
                               for fwd, bwd in zip(forward_out, backward_out)]
            else:
                layer_input = forward_out

        y = torch.stack(layer_input)
        hy = torch.stack(final_h)
        cy = torch.stack(final_c)

        return y, (hy, cy)