### Import Libraries

In [1]:
from collections import OrderedDict
import torch
import numpy as np
import dill as pickle
from collections import Counter

In [2]:
import sys
# Extend the module search path before the project imports in the next cell.
# The path is relative, so this presumably only works when the notebook is
# launched from the directory that contains `sopa_master` — TODO confirm.
sys.path.append("sopa_master")

In [3]:
from data import read_embeddings, read_docs, read_labels
from soft_patterns import ProbSemiring, MaxPlusSemiring, LogSpaceMaxTimesSemiring, SoftPatternClassifier, train, Batch, evaluate_accuracy
from util import to_cuda
from interpret_classification_results import interpret_documents
from visualize import visualize_patterns as vp

### Load Files

In [4]:
# Locations of the held-out split: one file with the documents and one with
# the labels that are read alongside them further down.
test_file, test_label_file = (
    "data/time_data_clean/test.data",
    "data/time_data_clean/test.labels",
)

In [5]:
# Pattern specification string: "_"-separated entries, each of the form "<a>-<b>".
# Every entry becomes one (a, b) integer pair; the pairs are ordered by `a`
# (ascending) and stored in an OrderedDict mapping a -> b.
# NOTE(review): judging by the model output, `a` looks like the pattern length
# and `b` the number of patterns of that length — confirm in soft_patterns.py.
patterns = "7-10_6-10_5-10_4-10_3-10_2-10"
pattern_specs = OrderedDict(
    sorted(
        (tuple(map(int, spec.split("-"))) for spec in patterns.split("_")),
        key=lambda pair: pair[0],
    )
)

In [6]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    The file is opened in binary mode inside a `with` block so the handle is
    closed as soon as the object has been read, instead of being left open
    for the lifetime of the kernel.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Pre-computed artifacts used to build the model: the vocabulary, the
# embeddings and the stored word dimension.
vocab = _load_pickle("data/embeddings/vocab.p")
embeddings = _load_pickle("data/embeddings/embeddings.p")
word_dim = _load_pickle("data/embeddings/word_dim.p")

### Create Model

In [7]:
# Instantiate the classifier whose weights are restored from a checkpoint in
# the "Load weights" section; `load_state_dict` there needs a model with
# matching parameters, so these settings should mirror the ones used for training.
# The constructor prints the pattern summary and parameter count shown below.
model = SoftPatternClassifier(
    # Objects prepared in the earlier cells.
    vocab=vocab,
    embeddings=embeddings,
    pattern_specs=pattern_specs,
    semiring=LogSpaceMaxTimesSemiring,
    # Classifier head settings.
    num_classes=2,
    num_mlp_layers=5,
    mlp_hidden_dim=25,
    # Remaining options.
    # NOTE(review): "sl" presumably stands for self-loops — confirm in soft_patterns.py.
    bias_scale_param=0.1,
    no_sl=False,
    shared_sl=False,
)

60 OrderedDict([(2, 10), (3, 10), (4, 10), (5, 10), (6, 10), (7, 10)])
# params: 256727


### Load Weights

In [8]:
# Restore the trained weights into the model built above.
# `map_location` keeps every loaded storage on the CPU, so a checkpoint that
# was saved from a GPU run still loads on a machine without CUDA; the
# evaluation below runs with gpu=False anyway.
model_dict = torch.load(
    "data/models/best_sopa.pth",
    map_location=lambda storage, location: storage,
)
model.load_state_dict(model_dict)

### Load Test Data

In [10]:
# Read the test documents with the loaded vocabulary. The call returns a pair
# that is unpacked into `test_input` and `test_text`; no padding tokens are
# requested (num_padding_tokens=0).
# NOTE(review): presumably `test_input` is the vocab-indexed form used by the
# model and `test_text` the readable tokens — confirm against data.read_docs.
test_input, test_text = read_docs(test_file, vocab, num_padding_tokens=0)

In [11]:
# Read the gold labels and pair each document with its label.
test_labels = read_labels(test_label_file)
# zip() silently stops at the shorter input, which would hide a docs/labels
# mismatch and skew every metric computed from `test_data` — fail loudly instead.
assert len(test_input) == len(test_labels), (
    "docs/labels length mismatch: {} documents vs {} labels".format(
        len(test_input), len(test_labels)
    )
)
test_data = list(zip(test_input, test_labels))

In [12]:
# Score the restored model on the complete test set, in batches of 150 with
# gpu=False. As the last expression of the cell, the returned value is what
# gets displayed below the counts the function prints.
evaluate_accuracy(model, test_data, batch_size=150, gpu=False)

num predicted 1s: 212
num gold 1s:      232


0.9394773039889959

In [13]:
# Print a per-pattern summary using only the first 100 test documents and
# their text, with k_best=5 (the output lists five highest-scoring spans per pattern).
# NOTE(review): max_doc_len=-1 presumably means "no length limit" — confirm in visualize.py.
vp(model, test_data[:100], test_text[:100], k_best=5, max_doc_len=-1, num_padding_tokens=0)

.Pattern: 0 of length 2
Highest scoring spans:
0 -0.290  b'HP now            #label=1'
1 -0.290  b'HP now            #label=1'
2 -0.337  b'HP today          #label=1'
3 -0.349  b'HP 1988           #label=1'
4 -0.357  b'HP todays         #label=1'
self-loops:   1.00 * carefree        + -0.29,  1.00 * bOy             +  0.38
fwd 1s:       5.49 * 18:01:12        +  1.20
epsilons:                              -1.59

Pattern: 1 of length 2
Highest scoring spans:
0 -0.211  b'HP Friday,        #label=1'
1 -0.211  b'HP Friday,        #label=1'
2 -0.224  b'HP today          #label=1'
3 -0.233  b'HP Friday         #label=1'
4 -0.233  b'HP Friday         #label=1'
self-loops:   1.00 * THEA            +  0.60,  1.00 * Strella         +  0.11
fwd 1s:       5.59 * May-25-2013     + -0.47
epsilons:                              -2.35

Pattern: 2 of length 2
Highest scoring spans:
0 -0.309  b'HP agree          #label=0'
1 -0.327  b'HP abortion       #label=0'
2 -0.327  b'HP abortion       #label=1'
3 -

In [None]:
# Run the interpretation step over the whole test set.
# NOTE(review): the positional arguments are opaque from here — `1` and `20`
# are magic numbers and "data/interpret_best.txt" is presumably the file the
# results are written to; confirm against
# interpret_classification_results.interpret_documents.
# NOTE(review): this call emits a long stream of debug prints ("MM: ...",
# "ss torch.Size(...)"); consider clearing that output before sharing.
interpret_documents(model, 1, test_data, test_text, "data/interpret_best.txt", 20)

MM: 0.002, other: 0.003
ss torch.Size([1, 60])


  output = softmax(res).data
  forwarded = softmax(model.mlp.forward(Variable(torch.FloatTensor(scores_data)))).data.numpy()


MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.002, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.002, other: 0.006
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, ot

MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.002, other: 0.005
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.002, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.002, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.002, other: 0.006
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, ot

MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.002, other: 0.006
ss torch.Size([1, 60])
MM: 0.001, ot

MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.002
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.002, other: 0.005
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.001
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, ot