### Import Libraries

In [1]:
from collections import OrderedDict
import torch
import numpy as np
import dill as pickle
from collections import Counter

In [2]:
import sys
# Add the "sopa_master" directory to the import search path. The path is
# relative, so it is resolved against the kernel's current working directory.
# NOTE(review): the project imports in the next cell (data, soft_patterns, util,
# ...) presumably live in this directory — confirm.
sys.path.append("sopa_master")

In [3]:
from data import read_embeddings, read_docs, read_labels
from soft_patterns import ProbSemiring, MaxPlusSemiring, LogSpaceMaxTimesSemiring, SoftPatternClassifier, train, Batch, evaluate_accuracy
from util import to_cuda
from interpret_classification_results import interpret_documents
from visualize import visualize_patterns as vp

### Load Files

In [4]:
# Locations of the test split: one file for the documents and one for their
# labels. Both paths are relative to the kernel's working directory.
test_file = "data/time_data_clean/test.data"
test_label_file = "data/time_data_clean/test.labels"

In [5]:
# Pattern specification string: "_"-separated entries, each of the form "<int>-<int>".
# NOTE(review): presumably "<pattern length>-<number of patterns>" — confirm.
patterns = "7-10_6-10_5-10_4-10_3-10_2-10"

# Parse every entry into an integer pair and build a mapping ordered by
# ascending first element of the pair.
pattern_specs = OrderedDict(
    sorted(
        (tuple(map(int, spec.split("-"))) for spec in patterns.split("_")),
        key=lambda pair: pair[0],
    )
)

In [6]:
# Load the pickled vocabulary, embeddings and word dimension.
# Each file is opened in a context manager so its handle is closed as soon as
# the object has been read, instead of being left open for the kernel's lifetime.
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
with open("data/embeddings/vocab.p", "rb") as pickle_file:
    vocab = pickle.load(pickle_file)
with open("data/embeddings/embeddings.p", "rb") as pickle_file:
    embeddings = pickle.load(pickle_file)
with open("data/embeddings/word_dim.p", "rb") as pickle_file:
    word_dim = pickle.load(pickle_file)

### Create Model

In [7]:
# Build the classifier from the pattern specs, vocabulary and embeddings
# prepared in the cells above.
# NOTE(review): a saved state dict is loaded into this model in the next
# section, so these hyperparameters presumably have to match the ones used
# when that checkpoint was trained — confirm against the training run.
model = SoftPatternClassifier(
    pattern_specs=pattern_specs,
    mlp_hidden_dim=25,
    num_mlp_layers=5,
    num_classes=2,
    embeddings=embeddings,
    vocab=vocab,
    semiring=LogSpaceMaxTimesSemiring,
    bias_scale_param=0.1,
    shared_sl=False,
    no_sl=False
)

60 OrderedDict([(2, 10), (3, 10), (4, 10), (5, 10), (6, 10), (7, 10)])
# params: 256727


### Load Weights

In [8]:
# Restore the trained weights into the freshly constructed model.
# map_location keeps every deserialised tensor on the CPU, so a checkpoint that
# was saved from a GPU still loads on a CPU-only machine (the evaluation below
# runs with gpu=False). The callable form is accepted by old and new PyTorch
# releases alike.
model_dict = torch.load(
    "data/models/best_sopa.pth",
    map_location=lambda storage, loc: storage,
)
model.load_state_dict(model_dict)

### Load Test Data

In [9]:
# Read the test documents with the loaded vocabulary and no padding tokens.
# read_docs returns two values; the first is paired with labels and fed to the
# model below, the second is passed alongside it to the visualisation and
# interpretation calls.
test_input, test_text = read_docs(test_file, vocab, num_padding_tokens=0)

In [11]:
# Read the gold labels and pair each test document with its label.
# zip stops at the shorter of the two sequences.
test_labels = read_labels(test_label_file)
test_data = [(doc, label) for doc, label in zip(test_input, test_labels)]

In [12]:
# Score the loaded model on the full test set, on CPU, in batches of 150.
# The returned value is the cell's displayed result.
evaluate_accuracy(model, test_data, batch_size=150, gpu=False)

num predicted 1s: 245
num gold 1s:      288


0.8384279475982532

In [13]:
# Visualise patterns on the first 100 test documents.
# NOTE(review): the recorded output of this cell is a RuntimeError ("not enough
# memory: you tried to allocate 6GB"), so the call did not complete as written.
# Consider a smaller slice or a machine with more RAM before re-running.
vp(model, test_data[:100], test_text[:100], k_best=5, max_doc_len=-1, num_padding_tokens=0)

.

RuntimeError: $ Torch: not enough memory: you tried to allocate 6GB. Buy new RAM! at /pytorch/torch/lib/TH/THGeneral.c:270

In [13]:
# Run the interpretation pass over the whole test set. The call returns two
# collections, which are analysed and pickled in the following cells.
# NOTE(review): the path argument says "prob_semiring" although the model above
# is built with LogSpaceMaxTimesSemiring — confirm the file name is intentional.
top_pos, top_neg = interpret_documents(model, 1, test_data, test_text,"data/prob_semiring_interpret.txt", 20)

MM: 0.006, other: 0.013
ss torch.Size([1, 60])
MM: 0.001, other: 0.005
ss torch.Size([1, 60])
MM: 0.0, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.005
ss torch.Size([1, 60])
MM: 0.0, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.008
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.006
ss torch.Size([1, 60])
MM: 0.0, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.005
ss torch.Size([1, 60])
MM: 0.0, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.0, other: 0.004
ss torch.Size([1, 60])
MM: 0.0, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.007
ss torch.Size([1, 60])
MM: 0.001, other: 0.081
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.0, other: 0.003
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss torch.Size([1, 60])
MM: 0.001, other: 0.004
ss 

In [18]:
# Scaled frequency of every distinct value in top_neg, most common first.
# Each raw count is divided by len(top_neg) / 10.
# NOTE(review): the divisor 10 is a hard-coded scale factor whose origin is not
# visible in this notebook — confirm.
[
    (pattern_idx, count / (len(top_neg) / 10))
    for pattern_idx, count in Counter(top_neg).most_common(len(set(top_neg)))
]

[(43, 0.5107084019769358),
 (22, 0.5107084019769358),
 (44, 0.5107084019769358),
 (57, 0.5107084019769358),
 (34, 0.5107084019769358),
 (41, 0.48929159802306427),
 (45, 0.48819330038440417),
 (53, 0.48819330038440417),
 (6, 0.4876441515650741),
 (16, 0.48215266337177376),
 (27, 0.47885777045579353),
 (50, 0.47885777045579353),
 (37, 0.4332784184514003),
 (58, 0.37891268533772654),
 (7, 0.35914332784184516),
 (20, 0.3470620538165843),
 (52, 0.29873695771554093),
 (51, 0.2833607907742998),
 (47, 0.27347611202635913),
 (28, 0.22899505766062603),
 (19, 0.1971444261394838),
 (4, 0.18341570565623283),
 (32, 0.17627677100494235),
 (33, 0.1696869851729819),
 (39, 0.13289401427786932),
 (42, 0.0785282811641955),
 (56, 0.0785282811641955),
 (31, 0.06644700713893466),
 (30, 0.06534870950027458),
 (40, 0.06040637012630423),
 (5, 0.054365733113673806),
 (35, 0.04887424492037342),
 (38, 0.033498077979132346),
 (0, 0.03185063152114223),
 (24, 0.025260845689181768),
 (21, 0.009884678747940691),
 (2, 0

In [20]:
# Persist both interpretation results for later use.
# Context managers flush and close each file when the block exits, so the
# pickles are complete on disk even if the kernel keeps running.
with open("data/top_patterns/top_pos_model1.p", "wb") as pickle_file:
    pickle.dump(top_pos, pickle_file)
with open("data/top_patterns/top_neg_model1.p", "wb") as pickle_file:
    pickle.dump(top_neg, pickle_file)

In [17]:
# Largest value stored in top_pos.
# NOTE(review): the values are presumably pattern indices; the model output
# above reports 60 patterns, so 0..59 would be the expected range — confirm.
max(top_pos)

58