In [1]:
import torch
import pickle
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffle order
# are reproducible under Restart Kernel -> Run All (cuDNN kernels on the GPU may
# still introduce some nondeterminism).
SEED = 0
torch.manual_seed(SEED)

# Frames per sample: selects which ./bindata/{N}-frame_*.npy files are loaded
# and sets the sequence length the model's head is sized for.
N = 7

In [3]:
# Load the pre-built N-frame feature and label arrays from ./bindata.
# NOTE(review): allow_pickle=True lets np.load unpickle (and thus execute) object
# data — only keep it for files you trust. The arrays are converted with
# torch.tensor later, which implies a numeric dtype, so it is likely unnecessary.
x_data, y_data = (
    np.load(f'./bindata/{N}-frame_x_data.npy', allow_pickle=True),
    np.load(f'./bindata/{N}-frame_y_data.npy', allow_pickle=True),
)

In [4]:
# Pair each feature window with its label.
# Features are cast explicitly to float32 to match the model's default parameter
# dtype (a float64 .npy would otherwise make the LSTM raise a dtype mismatch);
# labels are int64 class indices, which is what CrossEntropyLoss expects.
dataset = TensorDataset(
    torch.tensor(x_data, dtype=torch.float32),
    torch.tensor(y_data, dtype=torch.int64),
)

In [5]:
# Shuffled mini-batches of 32 (features, label) pairs, reshuffled every epoch.
trainLoader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

In [6]:
class LSTM_based(torch.nn.Module):
    """Classify a fixed-length window of N frames with an LSTM and an MLP head.

    The LSTM hidden states for every time step are flattened into one vector,
    passed through a hidden linear layer, and projected to ``num_classes``
    logits (fed to ``torch.nn.CrossEntropyLoss`` in this notebook).

    Args:
        N: Number of frames (time steps) per sample. The head's input width is
            ``hidden_size * N``, so inputs must have exactly N time steps.
            Defaults to the notebook-level ``N`` as it was when this cell ran.
        input_size: Features per frame. Default 113.
        hidden_size: LSTM hidden-state width. Default 113.
        head_size: Width of the hidden linear layer in the head. Default 64.
        num_classes: Number of output logits. Default 2.
    """

    def __init__(self, N = N, *, input_size=113, hidden_size=113, head_size=64, num_classes=2) -> None:
        super().__init__()
        # batch_first=True: the LSTM consumes (batch, time, features) tensors.
        self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        # Attribute names are unchanged so previously saved state_dicts still load.
        self.h_linear = torch.nn.Linear(hidden_size * N, head_size)
        self.o_linear = torch.nn.Linear(head_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of frame sequences.

        Args:
            x: Batch-first tensor of shape (batch, N, input_size); its dtype
                must match the model parameters (float32 by default).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised logits.
        """
        h_s, _ = self.lstm(x)                    # (batch, N, hidden_size); final (h, c) unused
        h_f = torch.flatten(h_s, start_dim=1)    # (batch, N * hidden_size)
        h_f = torch.nn.functional.leaky_relu(h_f)
        h_f = self.h_linear(h_f)                 # (batch, head_size)
        h_f = torch.nn.functional.leaky_relu(h_f)
        return self.o_linear(h_f)                # (batch, num_classes)

In [7]:
# Train on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end to end on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM_based().to(device)
# CrossEntropyLoss takes raw logits and int64 class-index labels.
loss_fn = torch.nn.CrossEntropyLoss()
# NOTE(review): lr=1e-5 is small for Adam and the logged loss plateaus near 0.40;
# consider tuning it.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [8]:
# Train for a fixed number of epochs, reporting the running mean batch loss.
NUM_EPOCHS = 10
# Move each batch to whatever device the model lives on instead of hard-coding CUDA.
device = next(model.parameters()).device
model.train()  # no dropout/batch-norm in the model today, but keeps the mode explicit
for epoch in range(NUM_EPOCHS):
    trainProcess = tqdm(trainLoader)
    total_loss = 0.
    for batch, (x, y) in enumerate(trainProcess, start=1):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # total_loss / batch is the mean loss over the batches seen so far this epoch.
        trainProcess.set_postfix({"EPOCH" : epoch + 1, "AVG_LOSS" : total_loss / batch})


100%|██████████| 2199/2199 [00:05<00:00, 393.72it/s, EPOCH=1, AVG_LOSS=0.441]
100%|██████████| 2199/2199 [00:05<00:00, 397.37it/s, EPOCH=2, AVG_LOSS=0.408]
100%|██████████| 2199/2199 [00:05<00:00, 412.36it/s, EPOCH=3, AVG_LOSS=0.405]
100%|██████████| 2199/2199 [00:05<00:00, 406.53it/s, EPOCH=4, AVG_LOSS=0.403]
100%|██████████| 2199/2199 [00:05<00:00, 408.62it/s, EPOCH=5, AVG_LOSS=0.401]
100%|██████████| 2199/2199 [00:05<00:00, 406.00it/s, EPOCH=6, AVG_LOSS=0.4]  
100%|██████████| 2199/2199 [00:05<00:00, 407.52it/s, EPOCH=7, AVG_LOSS=0.398]
100%|██████████| 2199/2199 [00:05<00:00, 407.16it/s, EPOCH=8, AVG_LOSS=0.397]
100%|██████████| 2199/2199 [00:05<00:00, 405.47it/s, EPOCH=9, AVG_LOSS=0.396]
100%|██████████| 2199/2199 [00:05<00:00, 406.30it/s, EPOCH=10, AVG_LOSS=0.396]
