In [1]:
from src.model import Model
from src.trainer import Trainer
from src.data_reader import amazon_dataset_iters
from src.beam_search import Beam

import pickle  # DEBUG
import torch
import numpy as np

import warnings
warnings.filterwarnings('ignore')

import os

In [2]:
# Pin this notebook to GPU 2 unless a device mask was already supplied through the
# environment. The mask only takes effect if it is set before CUDA is initialised
# in this process, so keep this cell ahead of any `.cuda()` call.
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ['CUDA_VISIBLE_DEVICES']

'2'

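Load the Amazon dataset from `./data/average_dataset/`: the text and tips vocabularies plus train/validation/test iterators.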

In [3]:
# Build the vocabularies and the train/validation/test iterators from the
# dataset directory. `device=None` is passed through unchanged; its meaning is
# defined by amazon_dataset_iters (src.data_reader) — TODO confirm CPU vs GPU.
text_vocab, tips_vocab, train_iter, val_iter, test_iter = amazon_dataset_iters(
    './data/average_dataset/',
    device=None,
)

Loading datasets...
datasets loaded
item vocab built
user vocab built
text vocab built
tips vocab built


In [4]:
def _max_field_id(iterators, field):
    """Return the largest id in ``batch.<field>`` over every batch of ``iterators``.

    ``np.max`` reduces each per-batch result to a scalar. This works whether
    ``Tensor.max()`` returns a 0-d tensor (current PyTorch) or a one-element
    tensor (older releases). The previous ``[0]`` indexing raised IndexError on
    the 0-d case.

    Raises:
        ValueError: if ``iterators`` yield no batches at all.
    """
    batch_maxima = [
        int(np.max(getattr(batch, field).max().cpu().data.numpy()))
        for iterator in iterators
        for batch in iterator
    ]
    if not batch_maxima:
        raise ValueError('no batches found while scanning field {!r}'.format(field))
    return max(batch_maxima)


# These are the largest item/user ids, not counts; the model cell adds headroom
# on top of them. The validation split is scanned as well, so ids that appear
# only there are also covered.
items_count = _max_field_id((train_iter, val_iter, test_iter), 'item')
users_count = _max_field_id((train_iter, val_iter, test_iter), 'user')
items_count, users_count

(11, 936)

Create the model. Id-based sizes are the maximum ids found above plus 10 of headroom.

In [91]:
# items_count / users_count hold the maximum ids seen in the data, not sizes.
# The +10 headroom presumably keeps id-indexed lookups inside the model in range.
# NOTE(review): the reason for exactly 10 is not documented; confirm.
# All latent/context/hidden sizes are 50. `.cuda()` places the model on the GPU
# selected by CUDA_VISIBLE_DEVICES.
model = Model(
    vocabulary_size=len(text_vocab.itos),
    items_count=items_count + 10,
    users_count=users_count + 10,
    context_size=50,
    hidden_size=50,
    user_latent_factors_count=50,
    item_latent_factors_count=50,
).cuda()

In [92]:
# Build the training driver for `model` (Trainer is defined in src.trainer).
trainer = Trainer(model)

Train the model for one epoch.

In [134]:
# Run a single epoch over the training iterator and keep whatever Trainer.train
# returns as `history`.
# NOTE(review): the recorded progress bar for this cell reports `Loss: inf`.
# Investigate (possible divergence or log(0)) before trusting the decoded
# outputs further down.
history = trainer.train(train_iter, n_epochs=1)

Epochs: 0 / 1, Loss: inf: 100%|██████████| 32/32 [00:03<00:00,  8.84it/s]


Decode the model's predictions for one training batch with beam search.

In [135]:
# Take the first training batch as the sample to decode.
# NOTE(review): if train_iter shuffles, this sample changes from run to run.
batch_sample = next(iter(train_iter))

In [155]:
# Run the model on the sample batch's user and item ids. Call the module itself
# rather than `.forward()`, so any registered nn.Module hooks also run. This
# assumes Model is a torch.nn.Module, which its `.cuda()`/`.forward()` API suggests.
batch_predict_sample = model(batch_sample.user, batch_sample.item)

In [156]:
# Beam width; the cells below read back this many hypotheses from the beam.
beam_size = 22
# Beam (src.beam_search) receives the token->id map, presumably to find special
# tokens such as $start/$end — TODO confirm. cuda=True matches the GPU-resident model.
beam = Beam(beam_size, text_vocab.stoi, cuda=True)

In [157]:
# Number of decoding steps to run.
N_DECODE_STEPS = 5

# exp() of model output [2] at index 2 of its first axis (presumably batch element
# 2), computed once because it does not change between steps. The original
# recomputed this inside the loop.
# NOTE(review): this presumably turns log-probabilities into probabilities;
# confirm that Beam.advance expects probabilities rather than log-probs.
word_probs = torch.exp(batch_predict_sample[2][2, :, :]).data

# NOTE(review): the same distribution is fed at every step, so later steps are not
# conditioned on the tokens already chosen. The repeated `$start $start ...`
# hypotheses in the recorded output are consistent with that.
for _ in range(N_DECODE_STEPS):
    # Pass a fresh copy each step, as the per-step recomputation did, in case
    # advance() modifies or keeps its argument.
    beam.advance(word_probs.clone())

In [158]:
# Collect the hypothesis held in each of the beam_size slots, in slot order
# (unsorted by score), as a NumPy array with one row per hypothesis.
results = np.array(list(map(beam.get_hyp, range(beam_size))))

In [159]:
# Rank the beam by score and transpose the top `n_best` hypotheses into
# per-position tuples. Slicing past the end of `ks` is harmless, so no more
# hypotheses than the beam actually holds (presumably beam_size) come back,
# even though n_best is larger.
n_best = 60
scores, ks = beam.sort_best()
hyps = list(zip(*(beam.get_hyp(k) for k in ks[:n_best])))

In [160]:
# Print each hypothesis in `results` as one tab-separated row of tokens.
# The loop iterates `results` itself instead of a hard-coded range(22), so the row
# count follows beam_size. Ids outside the text vocabulary are shown as '<!>'
# rather than raising IndexError.
print('\n'.join(
    '\t'.join(text_vocab.itos[i] if i < len(text_vocab.itos) else '<!>'
              for i in hypothesis)
    for hypothesis in results
))

$start	$start	$start	$start	$start
$start	$start	$start	to	<pad>
$start	$start	$start	$start	a
$start	$start	$start	$start	<unk>
$start	$start	$start	$start	the
$start	$start	$start	$start	$end
$start	$start	$start	$start	<pad>
$start	$start	$start	$start	great
$start	$start	$start	$start	of
$start	$start	$start	$start	this
$start	$start	$start	best	<pad>
$start	$start	$start	$start	,
$start	$start	$start	$start	classic
$start	$start	$start	$start	it
$start	$start	$start	$start	an
$start	$start	$start	$start	not
$start	$start	$start	$start	christmas
$start	$start	$start	$start	movie
$start	$start	$start	$start	!
$start	$start	$start	$start	-
$start	$start	$start	$start	best
$start	$start	$start	$start	to
