In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

from data import Dictionary, Corpus, PAD_INDEX
from mst import mst
import os

In [2]:
def _save_heatmap(matrix, title, path):
    """Render `matrix` as a heatmap on a fixed [0, 1] scale, save it to `path`, then close the figure.

    Closing the figure explicitly matters: pyplot keeps every figure created by
    `plt.subplots()` alive until it is closed, so repeated calls would leak.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(matrix, vmin=0, vmax=1)
    fig.colorbar(im)
    # Rows are candidate heads, columns are dependents (see the gold matrix construction in `plot`).
    ax.set(title=title, xlabel='dependent', ylabel='head')
    fig.savefig(path)
    plt.close(fig)


def plot(S_arc, heads, out_dir='img'):
    """Save heatmaps of the gold and predicted arc matrices for a single sentence.

    Two PDFs are written into `out_dir`, which is created if missing:
      * ``gold.pdf``: a 0/1 matrix G with ``G[h, d] = 1`` when token ``h`` is
        the gold head of token ``d``.
      * ``a.pdf``: the arc scores softmax-normalised over dim 0 (the head axis),
        so each column is a distribution over candidate heads.

    Args:
        S_arc: arc scores for a batch holding one sentence. The leading batch
            dimension is squeezed away (presumably shape (1, n, n); TODO confirm).
        heads: gold head indices with a leading batch dimension of size 1;
            ``heads.size(1)`` is the sentence length n.
        out_dir: directory the figures are written to (default ``'img'``).
    """
    os.makedirs(out_dir, exist_ok=True)

    # Gold adjacency matrix: one 1 per column, at the row of that token's head.
    gold_heads = heads.data.cpu().numpy().reshape(-1)
    n = gold_heads.shape[0]
    gold = np.zeros((n, n))
    gold[gold_heads, np.arange(n)] = 1.
    _save_heatmap(gold, 'Gold arcs', os.path.join(out_dir, 'gold.pdf'))

    # Predicted adjacency: normalise over heads so it is comparable with the gold matrix.
    probs = F.softmax(S_arc.squeeze(0), dim=0).data.cpu().numpy()
    _save_heatmap(probs, 'Predicted arc probabilities', os.path.join(out_dir, 'a.pdf'))


def predict(model, words, tags):
    """Parse the first sentence fed to `model` and return its heads and labels.

    Args:
        model: the parser. It is called as ``model(words, tags)`` and must return
            ``(S_arc, S_lab)``, both carrying a leading batch dimension.
        words, tags: either plain lists of ids for one sentence, which are wrapped
            here into a batch of one, or already-batched tensors. Both must have
            the same type.

    Returns:
        ``(heads, labels)`` for the first sentence in the batch, decoded by
        `predict_batch` (MST over the arc scores, then the best label per arc).

    Side effect:
        Switches `model` to eval mode (dropout off) and leaves it there.
    """
    assert type(words) == type(tags), 'words and tags must have the same type'
    if isinstance(words, list):
        # Wrap the id lists as a batch of one sentence for the PyTorch model.
        words = Variable(torch.LongTensor([words]))
        tags = Variable(torch.LongTensor([tags]))
    # Disable dropout so predictions are deterministic.
    model.eval()
    S_arc, S_lab = model(words, tags)
    # Decoding is shared with predict_batch instead of being duplicated here;
    # it works on one un-batched sentence, hence the [0] indexing.
    return predict_batch(S_arc[0], S_lab[0], tags[0])


def predict_batch(S_arc, S_lab, tags):
    """Decode heads and labels from the score tensors of ONE sentence.

    Despite the name, this expects un-batched scores; callers index the batch
    first (e.g. ``S_arc[0]``).

    Args:
        S_arc: arc score matrix passed to `mst`; presumably indexed
            [head, dependent] with shape (n, n). TODO confirm.
        S_lab: label scores indexed [label, head, dependent], shape (L, n, n).
        tags: unused; kept so existing callers keep working.

    Returns:
        heads: whatever `mst` returns for the arc scores (one head per token).
        labels: integer array of shape (n,). ``labels[d]`` is the highest-scoring
            label on the arc ``heads[d] -> d``.
    """
    # .cpu() makes this work for models living on a GPU as well; .data drops autograd.
    heads = mst(S_arc.data.cpu().numpy())

    # selected[l, d] = S_lab[l, heads[d], d]: the same selection as torch.gather
    # along dim 1. Done in NumPy so it is device-agnostic and builds no graph.
    label_scores = S_lab.data.cpu().numpy()
    head_idx = np.asarray(heads, dtype=np.int64)
    selected = label_scores[:, head_idx, np.arange(len(head_idx))]
    labels = selected.argmax(axis=0)
    return heads, labels

In [6]:
data_path = 'data/ud/UD_English-EWT'
vocab_path = 'vocab/train'
model_path = 'checkpoints/enmodel.pt'

dictionary = Dictionary(vocab_path)
corpus = Corpus(data_path=data_path, vocab_path=vocab_path)
# torch.load unpickles the whole model object, so only load trusted checkpoints.
# Storages are mapped to CPU so a GPU-saved checkpoint still loads; everything
# downstream calls .numpy(), which needs CPU tensors anyway.
model = torch.load(model_path, map_location=lambda storage, loc: storage)
# Disable dropout BEFORE the first forward pass. Otherwise the scores plotted
# below are noisy and differ from the ones `predict` decodes.
model.eval()
print(model)

batches = corpus.train.batches(1)
words, tags, heads, labels = next(batches)
S_arc, S_lab = model(words, tags)
plot(S_arc, heads)

# Predict on the SAME sentence so the predictions line up with the gold values printed below.
heads_pred, labels_pred = predict(model, words, tags)
print(heads_pred, '\n', heads[0].data.numpy())
print(labels_pred, '\n', labels[0].data.numpy())

# Raw id lists are also accepted. Separate names keep the batch above intact.
toy_ids = [1, 2, 3, 4]
toy_heads, toy_labels = predict(model, toy_ids, toy_ids)
print(toy_heads, '\n', toy_labels)

BiAffineParser(
  (embedding): WordEmbedding(
    (embedding): Embedding(19343, 300, padding_idx=0)
    (dropout): Dropout(p=0.3)
  )
  (encoder): RecurrentEncoder(
    (rnn): LSTM(300, 400, num_layers=3, batch_first=True, dropout=0.3, bidirectional=True)
  )
  (arc_mlp_h): MLP(
    (layers): Sequential(
      (fc_0): Linear(in_features=800, out_features=500, bias=True)
      (ReLU_0): ReLU()
      (dropout_0): Dropout(p=0.3)
      (fc_1): Linear(in_features=500, out_features=500, bias=True)
      (ReLU_1): ReLU()
      (dropout_1): Dropout(p=0.3)
    )
  )
  (arc_mlp_d): MLP(
    (layers): Sequential(
      (fc_0): Linear(in_features=800, out_features=500, bias=True)
      (ReLU_0): ReLU()
      (dropout_0): Dropout(p=0.3)
      (fc_1): Linear(in_features=500, out_features=500, bias=True)
      (ReLU_1): ReLU()
      (dropout_1): Dropout(p=0.3)
    )
  )
  (lab_mlp_h): MLP(
    (layers): Sequential(
      (fc_0): Linear(in_features=800, out_features=100, bias=True)
      (ReLU_0): ReL

KeyError: 'words'