In [1]:
import numpy as np
import torch

from torch.utils.data import DataLoader, random_split

from data_loaders.pulja_data_loader import PuljaDataLoader
from models._20220520_00 import UserModel
from models.utils import collate_fn

if torch.cuda.is_available():
    from torch.cuda import FloatTensor, LongTensor
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    from torch import FloatTensor, LongTensor

In [2]:
batch_size = 256
seq_len = 100
train_ratio = 0.9

In [3]:
dataset = PuljaDataLoader(seq_len)

In [4]:
dataset.num_c, dataset.num_d

(37, 7)

In [5]:
dataset.c_seqs[0].shape

(101,)

In [6]:
train_size = int(len(dataset) * train_ratio)
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size]
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=test_size, shuffle=True,
    collate_fn=collate_fn
)

In [7]:
model = UserModel(dataset.num_c, dataset.num_d, 50)

In [8]:
model

UserModel(
  (D1): Embedding(7, 1)
  (D2): Embedding(7, 50)
  (gru): GRU(100, 50, batch_first=True)
  (linear_1): Sequential(
    (0): Linear(in_features=50, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
  )
  (linear_2): Sequential(
    (0): Linear(in_features=150, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
  )
)

In [9]:
for data in train_loader:
    print(data)
    
    c_seq, d_seq, r_seq, cshft_seq, dshft_seq, rshft_seq, m_seq = data
    
    print(model(c_seq, d_seq, r_seq))
    
    break

(tensor([[ 0,  0,  0,  ...,  9,  0,  6],
        [ 2,  2,  2,  ...,  5,  5,  5],
        [12, 12, 12,  ..., 33, 33, 33],
        ...,
        [27, 27, 27,  ...,  0,  0,  0],
        [ 5,  5,  5,  ...,  7,  7,  7],
        [ 6,  6,  6,  ...,  0,  0,  0]]), tensor([[0, 2, 2,  ..., 3, 4, 4],
        [1, 1, 1,  ..., 2, 1, 2],
        [2, 2, 2,  ..., 1, 1, 2],
        ...,
        [2, 1, 2,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0]]), tensor([[0.4314, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.2484, 1.0000, 0.6405,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5882, 0.0000],
        ...,
        [0.0000, 1.0000, 0.1176,  ..., -0.0000, -0.0000, -0.0000],
        [1.0000, 0.0000, 0.9020,  ..., 1.0000, 1.0000, 1.0000],
        [0.0000, 0.5098, 0.9412,  ..., -0.0000, -0.0000, -0.0000]]), tensor([[ 0,  0,  0,  ...,  0,  6,  6],
        [ 2,  2,  2,  ...,  5,  5,  5],
        [12, 12, 12,  ..., 33, 33, 33],


In [10]:
model(LongTensor(dataset.c_seqs[:5]), LongTensor(dataset.d_seqs[:5]), FloatTensor(dataset.r_seqs[:5]))

(tensor([[-0.1114, -0.0992, -0.0781, -0.1010, -0.0684, -0.1023, -0.1032, -0.0754,
          -0.0550, -0.0443, -0.0372, -0.0315, -0.0279, -0.0255, -0.0240, -0.0664,
          -0.0935, -0.1113, -0.1176, -0.1237, -0.1232, -0.0932, -0.0640, -0.0497,
          -0.0822, -0.0868, -0.0598, -0.0471, -0.0393, -0.0337, -0.0294, -0.0694,
          -0.0695, -0.0948, -0.1110, -0.1095, -0.0840, -0.0576, -0.0461, -0.0386,
          -0.0693, -0.0676, -0.0630, -0.0952, -0.0649, -0.0569, -0.0699, -0.0698,
          -0.0862, -0.0932, -0.0650, -0.0495, -0.0402, -0.0760, -0.0826, -0.0562,
          -0.0450, -0.0375, -0.0324, -0.0285, -0.0668, -0.0840, -0.1042, -0.0690,
          -0.0577, -0.0738, -0.0600, -0.0452, -0.0361, -0.0731, -0.0686, -0.0748,
          -0.0565, -0.0439, -0.0355, -0.0303, -0.0268, -0.0648, -0.0948, -0.1001,
          -0.1116, -0.0768, -0.1176, -0.1137, -0.1199, -0.1184, -0.1264, -0.1312,
          -0.1331, -0.1336, -0.0790, -0.0587, -0.0577, -0.0934, -0.1099, -0.1223,
          -0.115

In [11]:
alpha_seq, beta_seq, gamma_seq, h_seq, C_seq = model(LongTensor(dataset.c_seqs[:5]), LongTensor(dataset.d_seqs[:5]), FloatTensor(dataset.r_seqs[:5]))

In [12]:
alpha_seq.shape

torch.Size([5, 101])