In [1]:
import numpy as np
import torch

from data_loaders.pulja_data_loader import PuljaDataLoader
from models._20220520_00 import UserModel
from models.utils import collate_fn

# Pick the tensor constructors that match the available hardware so the rest
# of the notebook can build inputs with plain `FloatTensor(...)` /
# `LongTensor(...)` calls regardless of device.
if not torch.cuda.is_available():
    # CPU-only machine: use the regular host tensor types.
    from torch import FloatTensor, LongTensor
else:
    from torch.cuda import FloatTensor, LongTensor
    # Make CUDA float32 the default tensor type for newly created tensors.
    # NOTE(review): `torch.set_default_tensor_type` is deprecated in recent
    # PyTorch releases in favour of `torch.set_default_dtype` /
    # `torch.set_default_device` — confirm against the installed version.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [2]:
# Sequence length handed to PuljaDataLoader when the dataset is built.
seq_len = 100

In [3]:
# Build the Pulja dataset with the configured sequence length.
# NOTE(review): the saved run shows c_seqs[0] with 101 entries (seq_len + 1)
# — confirm how PuljaDataLoader windows/pads sequences.
dataset = PuljaDataLoader(seq_len)

In [4]:
# Show the two size attributes that are later passed to UserModel.
# Presumably the number of distinct concepts and difficulty levels —
# TODO confirm against PuljaDataLoader.
dataset.num_c, dataset.num_d

(37, 7)

In [5]:
# Sanity-check the shape of the first entry of c_seqs.
dataset.c_seqs[0].shape

(101,)

In [6]:
# Instantiate the model from the dataset's size attributes.
# NOTE(review): the literal 50 matches the 50-wide GRU/Linear layers in the
# saved model repr, so it is presumably the hidden size — confirm against
# UserModel.__init__ and consider naming it in a config cell.
model = UserModel(
    dataset.num_c,
    dataset.num_d,
    50,
)

In [7]:
# Display the module repr to inspect the layer layout.
model

UserModel(
  (D1): Embedding(7, 1)
  (D2): Embedding(7, 50)
  (gru): GRU(100, 50, batch_first=True)
  (linear_1): Sequential(
    (0): Linear(in_features=50, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
  )
  (linear_2): Sequential(
    (0): Linear(in_features=150, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
  )
)

In [8]:
# Smoke-test the forward pass on the first five sequences and display the
# raw return value.
# NOTE(review): this prints the full returned tensors and repeats the call
# made in the following cell — consider displaying shapes only.
model(
    LongTensor(dataset.c_seqs[:5]),
    LongTensor(dataset.d_seqs[:5]),
    FloatTensor(dataset.r_seqs[:5]),
)

(tensor([[ 3.1342e-02, -4.4428e-03, -5.2711e-03, -2.9942e-02, -1.9846e-02,
          -3.8592e-02, -2.5297e-03,  2.0252e-02,  3.2853e-02,  3.3471e-02,
           3.2829e-02,  3.2256e-02,  3.2022e-02,  3.2006e-02,  3.2128e-02,
           4.4421e-04, -1.8993e-02, -3.3210e-02, -4.1893e-02, -4.8580e-02,
          -1.2496e-02,  1.1197e-02,  2.7971e-02,  3.2050e-02, -2.1221e-03,
           1.4818e-02,  2.9051e-02,  3.1664e-02,  3.2173e-02,  3.2176e-02,
           3.2305e-02, -1.7708e-04,  3.2186e-04, -2.5884e-02, -4.0259e-02,
          -6.8394e-03,  1.6284e-02,  3.1460e-02,  3.2939e-02,  3.2705e-02,
           1.6322e-02,  1.9358e-03, -4.9612e-03, -2.9183e-02, -2.0786e-02,
          -1.8688e-02,  2.7420e-02,  6.4171e-03, -2.6028e-02,  6.5535e-03,
           2.7785e-02,  3.4508e-02,  3.4021e-02, -3.7536e-04,  1.7728e-02,
           3.0645e-02,  3.1913e-02,  3.2159e-02,  3.2403e-02,  3.2495e-02,
           1.9487e-02, -1.7048e-02, -3.1693e-02, -1.9857e-02, -1.7140e-02,
           2.2774e-02,  3

In [9]:
# Run the forward pass on the first five sequences and unpack the five
# values the model returns so each can be inspected separately.
alpha_seq, beta_seq, gamma_seq, h_seq, C_seq = model(
    LongTensor(dataset.c_seqs[:5]),
    LongTensor(dataset.d_seqs[:5]),
    FloatTensor(dataset.r_seqs[:5]),
)

In [11]:
# Check the shape of the first model output for the five-sequence batch.
alpha_seq.shape

torch.Size([5, 101])