In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Run on the GPU when CUDA is usable (developed on an NVIDIA GeForce GTX 1650);
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
class LSTM(nn.Module):
    """Single-layer, batch-first LSTM written out gate by gate.

    Each gate has an input weight ``W_*`` of shape ``(input_size, hidden_size)``,
    a recurrent weight ``U_*`` of shape ``(hidden_size, hidden_size)`` and a
    bias ``b_*`` of shape ``(hidden_size,)``.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # torch.empty only allocates; init_weights() below fills every tensor.

        # forget gate
        self.W_f = nn.Parameter(torch.empty(input_size, hidden_size))
        self.U_f = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.empty(hidden_size))

        # input gate
        self.W_i = nn.Parameter(torch.empty(input_size, hidden_size))
        self.U_i = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.empty(hidden_size))

        # candidate cell state
        self.W_c = nn.Parameter(torch.empty(input_size, hidden_size))
        self.U_c = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_c = nn.Parameter(torch.empty(hidden_size))

        # output gate
        self.W_o = nn.Parameter(torch.empty(input_size, hidden_size))
        self.U_o = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.empty(hidden_size))

        self.init_weights()

    def init_weights(self):
        """Zero every 1-D parameter (biases); Xavier-uniform everything else."""
        for weight in self.parameters():
            if weight.dim() < 2:  # Bias or 1D weights
                nn.init.zeros_(weight)
            else:
                nn.init.xavier_uniform_(weight)

    def forward(self, x, hidden=None):
        """Run the LSTM over a whole sequence.

        Args:
            x: Input of shape ``(batch, seq, input_size)``.
            hidden: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(batch, hidden_size)``. When omitted, both start at zero on
                the same device and with the same dtype as ``x``.

        Returns:
            ``(hidden_seq, (h_n, c_n))`` where ``hidden_seq`` has shape
            ``(batch, seq, hidden_size)`` and ``h_n`` / ``c_n`` are the states
            after the last time step (the initial states if ``seq == 0``).
        """
        bs, seq, _ = x.size()

        if hidden is None:
            # Follow the input's device/dtype instead of a module-level global
            # so the layer works wherever the caller has put the data.
            h_prev = x.new_zeros(bs, self.hidden_size)
            c_prev = x.new_zeros(bs, self.hidden_size)
        else:
            h_prev, c_prev = hidden

        hidden_seq = []
        for t in range(seq):
            x_t = x[:, t, :]
            i_t = torch.sigmoid(x_t @ self.W_i + h_prev @ self.U_i + self.b_i)
            f_t = torch.sigmoid(x_t @ self.W_f + h_prev @ self.U_f + self.b_f)
            g_t = torch.tanh(x_t @ self.W_c + h_prev @ self.U_c + self.b_c)
            o_t = torch.sigmoid(x_t @ self.W_o + h_prev @ self.U_o + self.b_o)
            c_t = f_t * c_prev + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
            # Carry the state forward; without this every step would see only
            # the initial state and the layer would not be recurrent at all.
            h_prev, c_prev = h_t, c_t

        if not hidden_seq:
            # Zero-length sequence: nothing to stack, state is unchanged.
            return x.new_zeros(bs, 0, self.hidden_size), (h_prev, c_prev)

        # Stacking on dim=1 yields the batch-first layout directly.
        return torch.stack(hidden_seq, dim=1), (h_prev, c_prev)

In [7]:
class LSTMNet(nn.Module):
    """Sequence classifier: a batch-first ``nn.LSTM`` followed by a linear head.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Hidden size of the LSTM (and input size of the head).
        num_classes: Number of scores produced per sequence.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        self.num_classes = num_classes
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Score each sequence in ``x`` from the LSTM output at its last step.

        The linear head's output is passed through ReLU, so the returned
        ``(batch, num_classes)`` scores are non-negative.
        """
        sequence_out, (_h_last, _c_last) = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        scores = self.fc(final_step)
        return F.relu(scores)

In [8]:
class CustomUnit(nn.Module):
    """Batch-first recurrent layer with a single forget gate.

    For each time step the layer computes::

        z_t = tanh(x_t @ W_i)
        f_t = sigmoid(z_t @ W_f + h_{t-1} @ U_f + b_f)
        h_t = f_t * h_{t-1} + (1 - f_t) * z_t

    The cell state is not used by the recurrence; it is accepted and returned
    unchanged so the call signature mirrors an LSTM's ``(h, c)`` state tuple.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # torch.empty only allocates; init_weights() below fills every tensor.

        # input projection
        self.W_i = nn.Parameter(torch.empty(input_size, hidden_size))

        # forget gate (W_f acts on the already-projected input, hence square)
        self.W_f = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.U_f = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.empty(hidden_size))

        self.init_weights()

    def init_weights(self):
        """Zero every 1-D parameter (biases); Xavier-uniform everything else."""
        for weight in self.parameters():
            if weight.dim() < 2:  # Bias or 1D weights
                nn.init.zeros_(weight)
            else:
                nn.init.xavier_uniform_(weight)

    def forward(self, x, hidden=None):
        """Run the unit over a whole sequence.

        Args:
            x: Input of shape ``(batch, seq, input_size)``.
            hidden: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(batch, hidden_size)``. When omitted, both start at zero on
                the same device and with the same dtype as ``x``.

        Returns:
            ``(hidden_seq, (h_n, c_n))`` where ``hidden_seq`` has shape
            ``(batch, seq, hidden_size)``, ``h_n`` is the hidden state after
            the last step (``h_0`` if ``seq == 0``) and ``c_n`` is ``c_0``
            passed through untouched.
        """
        bs, seq, _ = x.size()

        if hidden is None:
            # Follow the input's device/dtype instead of a module-level global
            # so the layer works wherever the caller has put the data.
            h_prev = x.new_zeros(bs, self.hidden_size)
            c_prev = x.new_zeros(bs, self.hidden_size)
        else:
            h_prev, c_prev = hidden

        hidden_seq = []
        for t in range(seq):
            z_t = torch.tanh(x[:, t, :] @ self.W_i)
            f_t = torch.sigmoid(z_t @ self.W_f + h_prev @ self.U_f + self.b_f)
            h_t = (f_t * h_prev) + ((1 - f_t) * z_t)
            hidden_seq.append(h_t)
            # Carry the state forward; without this every step would be gated
            # against the initial state only and nothing would be remembered.
            h_prev = h_t

        if not hidden_seq:
            # Zero-length sequence: nothing to stack, state is unchanged.
            return x.new_zeros(bs, 0, self.hidden_size), (h_prev, c_prev)

        # Stacking on dim=1 yields the batch-first layout directly.
        return torch.stack(hidden_seq, dim=1), (h_prev, c_prev)

In [9]:
class CustomNet(nn.Module):
    """Sequence classifier: a ``CustomUnit`` layer followed by a linear head.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Hidden size of the recurrent layer (and input size of
            the head).
        num_classes: Number of scores produced per sequence.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        self.num_classes = num_classes
        self.custom_layer = CustomUnit(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Score each sequence in ``x`` from the layer output at its last step.

        The linear head's output is passed through ReLU, so the returned
        scores are non-negative.
        """
        sequence_out, (_h_last, _c_last) = self.custom_layer(x)
        final_step = sequence_out[:, -1, :]
        scores = self.fc(final_step)
        return F.relu(scores)