In [1]:
import numpy as np
import torch

from torch.utils.data import DataLoader, random_split

from data_loaders.pulja_data_loader import PuljaDataLoader
from models._20220520_00 import UserModel
from models.utils import collate_fn

# Choose legacy tensor constructors for the available device. On a GPU machine,
# also make CUDA float tensors the default so new tensors are created on the GPU.
# NOTE(review): torch.set_default_tensor_type is deprecated as of PyTorch 2.1
# (alternatives: torch.set_default_dtype / torch.set_default_device) — revisit
# when upgrading.
if not torch.cuda.is_available():
    from torch import FloatTensor, LongTensor
else:
    from torch.cuda import FloatTensor, LongTensor
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [2]:
batch_size = 256  # mini-batch size for the training DataLoader
seq_len = 100
train_ratio = 0.9  # fraction of sequences used for training; the rest is held out
# Seed every RNG the notebook relies on so random_split, DataLoader shuffling and
# model weight initialisation are reproducible across Restart & Run All.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the default generator on all devices

In [3]:
# Build the dataset of per-user sequences.
# NOTE(review): for seq_len=100 the displayed c_seqs[0].shape is (101,), i.e.
# seq_len + 1 — presumably so shifted next-step targets can be formed; confirm
# in PuljaDataLoader.
dataset = PuljaDataLoader(seq_len)

In [4]:
# Id-vocabulary sizes reported by the dataset (presumably distinct c ids and d ids).
(dataset.num_c, dataset.num_d)

(37, 7)

In [5]:
# Length of the first c-id sequence.
np.shape(dataset.c_seqs[0])

(101,)

In [6]:
# Hold out (1 - train_ratio) of the sequences for evaluation.
train_size = int(len(dataset) * train_ratio)
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size]
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_fn
)
# The whole test split is served as a single batch. Evaluation order does not
# matter, so it is not shuffled. max(1, ...) keeps DataLoader valid when the
# held-out split is empty, because it rejects batch_size=0.
test_loader = DataLoader(
    test_dataset, batch_size=max(1, test_size), shuffle=False,
    collate_fn=collate_fn
)

In [7]:
# Model width. In the printed module tree, 50 appears as Embedding(7, 50) for D2
# and as the GRU hidden size (GRU(100, 50)); presumably the embedding/hidden
# dimension of UserModel — confirm against its signature.
hidden_dim = 50
model = UserModel(dataset.num_c, dataset.num_d, hidden_dim)

In [8]:
# Show the module tree to check layer sizes (embeddings, GRU, the two MLP heads).
model

UserModel(
  (D1): Embedding(7, 1)
  (D2): Embedding(7, 50)
  (gru): GRU(100, 50, batch_first=True)
  (linear_1): Sequential(
    (0): Linear(in_features=50, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
  )
  (linear_2): Sequential(
    (0): Linear(in_features=150, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
  )
)

In [9]:
# Smoke-test the forward pass on a single training batch. next(iter(...)) takes
# one batch directly instead of looping and breaking. If the loader is empty,
# this fails loudly (StopIteration) rather than silently doing nothing.
data = next(iter(train_loader))
print(data)

# collate_fn yields 7 items per batch; only the first three feed the model.
c_seq, d_seq, r_seq, cshft_seq, dshft_seq, rshft_seq, m_seq = data

# Inspection only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    print(model(c_seq, d_seq, r_seq))

(tensor([[11, 11, 11,  ..., 20, 20, 20],
        [ 5,  5,  5,  ...,  7,  7,  7],
        [17, 17, 17,  ..., 17, 17, 17],
        ...,
        [ 6,  6,  6,  ..., 18, 18, 18],
        [ 3,  3,  3,  ..., 11, 11, 11],
        [14, 14, 14,  ..., 19, 19, 19]]), tensor([[1, 1, 1,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 1],
        [2, 2, 2,  ..., 2, 2, 2],
        ...,
        [2, 2, 2,  ..., 2, 1, 2],
        [2, 3, 2,  ..., 2, 3, 2],
        [1, 2, 1,  ..., 2, 2, 0]]), tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.6667, 0.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000],
        [0.5294, 0.4510, 0.4314,  ..., 0.2745, 0.4902, 1.0000],
        ...,
        [0.0000, 1.0000, 0.2549,  ..., 0.4706, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.6471,  ..., 0.1373, 0.0000, 0.9412],
        [0.0000, 1.0000, 0.7974,  ..., 0.0000, 0.9412, 0.5098]]), tensor([[11, 11,  1,  ..., 20, 20, 20],
        [ 5,  5,  5,  ...,  7,  7,  7],
        [17, 17, 17,  ..., 17, 17, 17],
      

In [10]:
# Forward pass on the first five raw sequences, built straight from the dataset
# arrays rather than through a DataLoader; the cell displays the model output.
model(
    LongTensor(dataset.c_seqs[:5]),
    LongTensor(dataset.d_seqs[:5]),
    FloatTensor(dataset.r_seqs[:5]),
)

(tensor([[ 5.4157e-03, -7.3011e-04,  3.7128e-03,  1.0386e-02,  8.8851e-03,
           1.4107e-02, -3.7402e-04,  2.9285e-03,  7.5906e-03,  1.0560e-02,
           1.2527e-02,  1.3942e-02,  1.5032e-02,  1.5534e-02,  1.5869e-02,
           3.9514e-04,  7.5081e-03,  7.1528e-03,  5.7702e-03,  4.3111e-03,
          -5.1521e-03, -8.7487e-04,  5.0705e-03,  8.7535e-03, -1.0570e-04,
           9.2263e-04,  5.5977e-03,  9.5595e-03,  1.2108e-02,  1.3897e-02,
           1.5168e-02,  3.4366e-04,  5.6957e-03,  1.2130e-02,  1.2275e-02,
          -1.2863e-03,  2.3888e-03,  6.8883e-03,  1.0155e-02,  1.2325e-02,
          -2.2110e-03, -6.3190e-03,  7.4509e-04,  1.0327e-02,  5.6188e-03,
           5.0394e-03, -5.4558e-04, -1.0409e-03,  1.0600e-02,  2.7821e-05,
           4.3111e-03,  8.3538e-03,  1.1131e-02,  2.1979e-03,  1.9401e-03,
           6.2234e-03,  1.0037e-02,  1.2403e-02,  1.4010e-02,  1.5136e-02,
          -4.9225e-04,  6.3860e-03,  1.0166e-02,  1.2915e-02,  5.0475e-03,
           7.8291e-05,  6

In [11]:
# Forward pass on the first five raw sequences, unpacking the model's five
# returned values by name.
(alpha_seq, beta_seq, gamma_seq,
 h_seq, C_seq) = model(
    LongTensor(dataset.c_seqs[:5]),
    LongTensor(dataset.d_seqs[:5]),
    FloatTensor(dataset.r_seqs[:5]),
)

In [12]:
# Shape of alpha_seq: one row per input example (five were passed).
alpha_seq.size()

torch.Size([5, 101])