In [1]:
import torch
import torch.nn as nn
import random
import numpy as np
import pickle
from gru_model import Model

In [2]:
def load_dataset(path="./prepared_dataset.p"):
    """Load the test sessions and the vocabulary from a pickled dataset dict.

    Args:
        path: location of the pickle file. Defaults to the file this notebook
            was written against; an alternative previously used here was
            "./data/short_sessions.p".

    Returns:
        (x_test, vocab): the values stored under the 'x_test' and 'vocab' keys.
    """
    # pickle.load can execute arbitrary code -- only load files you produced yourself.
    with open(path, "rb") as f:  # `with` guarantees the handle is closed
        d = pickle.load(f)
    return d['x_test'], d['vocab']

# Unpack the test sessions and the id -> event-name vocabulary used by the cells below.
x_test, vocab = load_dataset()

In [3]:
# Prefer the GPU when one is visible to this kernel, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Hyper-parameters must match those the checkpoint was trained with.
# NOTE(review): vocab_size=214 is a magic number -- confirm it agrees with len(vocab).
model = Model(vocab_size=214, embedding_dim=20, hidden_dim=100, gru_layers=1, dropout=0.0).to(device)
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# from a GPU session still loads when this kernel only has a CPU.
model.load_state_dict(torch.load('./state_dict.pth', map_location=device))
# NOTE(review): loss_func is passed to test_accuracy/step, but no visible code uses it.
loss_func = nn.CrossEntropyLoss()

In [5]:
def print_session(s, names=None):
    """Print each event id in session `s` next to its human-readable name.

    Args:
        s: iterable of event ids usable as indices into the name lookup
            (plain ints or 0-dim tensors).
        names: index -> name lookup. Defaults to the notebook's global `vocab`,
            so existing calls `print_session(s)` behave as before.

    The listing is framed by a blank line before and after.
    """
    print()
    lookup = vocab if names is None else names
    for i in s:
        print(i, lookup[i])
    print()

In [6]:
def batches(data, batch_size):
    """Yield shuffled batches of sessions from `data`, each batch ordered by length.

    The session order is shuffled before batching. Within each batch the
    sessions are sorted longest-first, which `nn.utils.rnn.pack_sequence`
    requires with its default `enforce_sorted=True`.

    Args:
        data: sequence of sessions (lists or 1-D tensors of event ids).
            It is NOT modified; a shuffled copy is batched instead.
        batch_size: maximum number of sessions per batch.

    Yields:
        list of 1-D int64 tensors, longest first.
    """
    # Shuffle a copy: shuffling in place would silently reorder the caller's
    # list (e.g. x_test), changing what later index-based lookups refer to.
    shuffled = list(data)
    random.shuffle(shuffled)
    for start in range(0, len(shuffled), batch_size):
        sentences = sorted(shuffled[start:start + batch_size], key=len, reverse=True)
        yield [torch.as_tensor(s, dtype=torch.long) for s in sentences]

def step(model, sents, loss_func, device):
    """Run `model` on a batch of sessions set up as next-event prediction.

    Each session is split into an input (every event but the last) and a
    target (every event but the first), so input position t is paired with
    event t+1. Both are packed with `pack_sequence`, which expects `sents`
    ordered longest-first (as `batches` yields them). Every session needs at
    least two events; otherwise its input/target slice would be empty, and
    pack_sequence rejects zero-length entries.

    Args:
        model: callable taking a PackedSequence.
        sents: list of 1-D tensors of event ids, longest first.
        loss_func: unused here; kept so existing call sites keep working.
        device: torch.device the packed input and target are moved to.

    Returns:
        (out, y): the model output for the packed input and the packed
        targets (use `y.data` for the flat target tensor).
    """
    x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])
    y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])
    # PackedSequence.to honours the exact device (type and index), not only CUDA.
    x, y = x.to(device), y.to(device)
    out = model(x)
    return out, y

def calc_accuracy(output_distribution, targets, verbose=True):
    """Fraction of positions whose arg-max prediction equals the target.

    Args:
        output_distribution: 2-D tensor of per-class scores, one row per
            target position (the arg-max is taken over dim 1).
        targets: 1-D tensor of target class ids, one per row.
        verbose: when True (the original behaviour), print the predicted and
            target sequences via `print_session` for inspection. Pass False
            when scoring many batches to avoid flooding the output.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    prediction = torch.argmax(output_distribution, dim=1)
    if verbose:
        print('prediction')
        print_session(prediction.tolist())
        print('targets')
        print_session(targets.tolist())
    num_correct_prediction = (prediction == targets).float().sum()
    return num_correct_prediction.item() / targets.shape[0]

def test_accuracy(test_data, model, loss_func, device, batch_size=200):
    """Evaluate the next-event prediction accuracy of `model` on `test_data`.

    Sessions are batched with `batches`, run through `step` with gradients
    disabled, and scored per batch with `calc_accuracy`.

    The overall figure weights each batch by its number of target positions,
    so it equals total correct / total predicted positions. A plain mean of
    per-batch accuracies would over-weight small batches, such as the last,
    partial one or batches of short sessions.

    Args:
        test_data: sequence of sessions to evaluate.
        model, loss_func, device: forwarded to `step`.
        batch_size: sessions per batch (default 200, as before).

    Returns:
        The overall accuracy as a float (nan when `test_data` is empty).
        The value is also printed.
    """
    model.eval()  # evaluation mode (e.g. dropout disabled)
    batch_accuracies = []
    batch_sizes = []
    with torch.no_grad():
        for sents in batches(test_data, batch_size):
            out, y = step(model, sents, loss_func, device)
            targets = y.data
            batch_accuracies.append(calc_accuracy(out, targets))
            batch_sizes.append(targets.shape[0])
    if batch_sizes:
        accuracy = float(np.average(batch_accuracies, weights=batch_sizes))
    else:
        accuracy = float('nan')
    print('test accuracy:', accuracy)
    return accuracy

In [11]:
# Evaluate on a single test session (index 48) for close inspection.
test_accuracy(
    test_data=x_test[48:49],
    model=model,
    loss_func=loss_func,
    device=device,
)

prediction

122 click_on_subscription
16 click_on_confirm
96 click_on_number_details
170 click_log_out

targets

3 submit_order
122 click_on_subscription
166 load_homepage
16 click_on_confirm

0.0


In [11]:
# Inspect one raw test session (index 16) using the shared helper rather than
# re-implementing print_session's id/name loop inline.
for x in x_test[16:17]:
    print_session(x)


tensor(68) load_other_page
tensor(16) click_on_confirm
tensor(79) click_on_accept_continue
tensor(41) scroll_on_homepage
tensor(86) adding_additional_services
