### Import Libraries

In [1]:
from collections import OrderedDict
import torch
import tensorboardX
import dill as pickle

In [2]:
import sys
sys.path.append("sopa_master/")

In [3]:
from data import read_embeddings, read_docs, read_labels
from soft_patterns import ProbSemiring, MaxPlusSemiring, LogSpaceMaxTimesSemiring, SoftPatternClassifier, train, Batch, evaluate_accuracy
from util import to_cuda
from interpret_classification_results import interpret_documents
from visualize import visualize_patterns

### Files

In [13]:
train_data_file = "data/time_data_clean/train.data"
train_label_file ="data/time_data_clean/train.labels"
dev_data_file = "data/time_data_clean/dev.data"
dev_label_file = "data/time_data_clean/dev.labels"
test_file = "data/time_data_clean/test.data"
test_label="data/time_data_clean/test.labels"

### Model's parameters

In [5]:
patterns = "7-10_6-10_5-10_4-10_3-10_2-10"
pattern_specs = OrderedDict(sorted(([int(y) for y in x.split("-")] for x in patterns.split("_")),
                                key=lambda t: t[0]))

In [6]:
pattern_specs

OrderedDict([(2, 10), (3, 10), (4, 10), (5, 10), (6, 10), (7, 10)])

In [7]:
num_padding_tokens=max(list(pattern_specs.keys())) - 1
num_padding_tokens

6

### Loading Embeddings

In [8]:
vocab = pickle.load(open("data/embeddings/vocab.p","rb"))
embeddings = pickle.load(open("data/embeddings/embeddings.p","rb"))
word_dim = pickle.load(open("data/embeddings/word_dim.p","rb"))

### Create Model

In [17]:
model = SoftPatternClassifier(
    pattern_specs=pattern_specs,
    mlp_hidden_dim=25,
    num_mlp_layers=5,
    num_classes=2,
    embeddings=embeddings,
    vocab=vocab,
    semiring=LogSpaceMaxTimesSemiring,
    bias_scale_param=1,
    shared_sl=False,
    no_sl=False,
    no_eps=False
)

60 OrderedDict([(2, 10), (3, 10), (4, 10), (5, 10), (6, 10), (7, 10)])
# params: 256727


### Training

In [15]:
train_input, train_text = read_docs(train_data_file, vocab, num_padding_tokens=num_padding_tokens)
train_labels = read_labels(train_label_file)
dev_input, dev_text = read_docs(dev_data_file, vocab, num_padding_tokens=num_padding_tokens)
dev_labels = read_labels(dev_label_file)

In [16]:
train_data = list(zip(train_input, train_labels))
dev_data = list(zip(dev_input, dev_labels))

In [18]:
train(
    train_data=train_data,
    dev_data=dev_data,
    model=model,
    model_save_dir="data/models/modelstime/",
    num_iterations=250,
    model_file_prefix="best_sopa",
    learning_rate=0.005,
    batch_size=150,
    num_classes=2,
    patience=30
)

...................

num predicted 1s: 0
num gold 1s:      426
num predicted 1s: 0
num gold 1s:      217
iteration:       0 train time:     0.133m, eval time:     0.014m train loss:        0.685 train_acc:   57.400% dev loss:        0.675 dev_acc:   60.545%
New best acc!
New best dev!
saving model to data/models/modelstime/best_sopa_0.pth
...................

num predicted 1s: 0
num gold 1s:      420
num predicted 1s: 0
num gold 1s:      217
iteration:       1 train time:     0.276m, eval time:     0.015m train loss:        0.682 train_acc:   58.000% dev loss:        0.671 dev_acc:   60.545%
New best dev!
saving model to data/models/modelstime/best_sopa_1.pth
...................

num predicted 1s: 0
num gold 1s:      403
num predicted 1s: 0
num gold 1s:      217
iteration:       2 train time:     0.418m, eval time:     0.015m train loss:        0.681 train_acc:   59.700% dev loss:        0.671 dev_acc:   60.545%
New best dev!
saving model to data/models/modelstime/best_sopa_2.pth
.....

SoftPatternClassifier (
  (mlp): MLP (
    (layers): ModuleList (
      (0): Linear (60 -> 25)
      (1): Linear (25 -> 25)
      (2): Linear (25 -> 25)
      (3): Linear (25 -> 25)
      (4): Linear (25 -> 2)
    )
  )
)

In [19]:
torch.save(model.state_dict(), "data/models/best_sopa.pth")

### Forward

In [20]:
test_input, test_text = read_docs(test_file, vocab, num_padding_tokens=0)
labels=read_labels(test_label)

In [21]:
test_data = list(zip(test_input, labels))

In [22]:
evaluate_accuracy(model, test_data, batch_size=150, gpu=False)

num predicted 1s: 295
num gold 1s:      288


0.8500727802037845