In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to imported project code take effect without a
# kernel restart.
%load_ext autoreload
%autoreload 2

from models.ower import Ower
from pathlib import Path
from random import shuffle
from typing import List, Tuple

import torch
from sklearn.metrics import precision_recall_fscore_support
from torch import Tensor, tensor

from data.ower.ower_dir import OwerDir, Sample
from models.base import Base

In [None]:
# --- Dataset ---------------------------------------------------------------
# Path handed to OwerDir, and the two counts passed to `read_datasets`.
ower_dir_path = '../data/ower/ower-v4-fb-irt-100-5/'
class_count, sent_count = 100, 5

# Every sentence is cropped / zero-padded to this many token ids in
# `generate_batch`.
sent_len = 64

# --- Model -----------------------------------------------------------------
# NOTE(review): the remaining settings are not referenced by the cells visible
# in this notebook; presumably they mirror the training script's options —
# confirm before relying on them.
model_name = 'base'
mode = 'mean'
emb_size = None
vectors = 'glove.6B.300d'
update_vectors = False

# --- Training --------------------------------------------------------------
device = 'cuda'
batch_size = 1024
epoch_count = 20
lr = 0.01

# --- Logging / persistence -------------------------------------------------
log_dir = None
log_steps = False
save_dir = None

In [None]:
# Wrap the configured dataset path and run the directory's own `check()`
# before reading from it (presumably a layout/consistency check — confirm
# in OwerDir).
ower_dir = OwerDir(Path(ower_dir_path))
ower_dir.check()

# `read_datasets` hands back the three splits plus the vocabulary; `vocab` is
# what `ids_to_sent` uses to turn token ids back into text.
(
    train_set,
    valid_set,
    test_set,
    vocab,
) = ower_dir.read_datasets(class_count, sent_count)

In [None]:
def generate_batch(batch: List[Sample]) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Collate samples into tensors.

    The sentences of every sample are shuffled, then each sentence is cut to at
    most `sent_len` token ids and right-padded with 0 to exactly `sent_len`.

    :param batch: [Sample(ent, [class], [sent])]

    :return: ent_batch      integer tensor(batch_size),
             sents_batch    integer tensor(batch_size, sent_count, sent_len),
             classes_batch  integer tensor(batch_size, class_count)
    """

    ents, classes, sents_per_sample = zip(*batch)

    # random.shuffle reorders in place: the sentence lists shuffled here are
    # the very objects unpacked from the samples, so the samples themselves
    # observe the new order as well.
    for sample_sents in sents_per_sample:
        shuffle(sample_sents)

    fixed_len_batch = []
    for sample_sents in sents_per_sample:
        fixed_len_sents = []
        for sent in sample_sents:
            clipped = sent[:sent_len]
            padding = [0] * (sent_len - len(clipped))
            fixed_len_sents.append(clipped + padding)
        fixed_len_batch.append(fixed_len_sents)

    return tensor(ents), tensor(fixed_len_batch), tensor(classes)

In [None]:
def ids_to_sent(ids: List[int]) -> str:
    """
    Turn token ids back into text.

    :param ids: indices into `vocab.itos`
    :return: the looked-up entries joined by single spaces
    """
    lookup = vocab.itos
    return ' '.join(lookup[token_id] for token_id in ids)

# Lookup tables used only for printing in the inspection cells:
# `ent_to_lbl` is indexed with `sample.ent`, and element [3] of
# `rel_tail_freq_lbl_tuples[c]` is printed as the label of class c.
# NOTE(review): the name suggests (rel, tail, freq, label) rows — confirm
# against what classes_tsv.load() actually returns.
ent_to_lbl = ower_dir.ent_labels_txt.load()
rel_tail_freq_lbl_tuples = ower_dir.classes_tsv.load()

In [None]:
# Rebuild the baseline model and restore its trained weights.
# The class count and pooling mode come from the config cell so the model's
# output width stays in step with the label tensors built from the dataset.
# NOTE(review): 154289 and 300 must match the checkpoint — presumably the
# vocabulary size and the embedding width of the vectors used in training;
# confirm.
ower = Base.from_random(154289, 300, class_count, mode)

# map_location='cpu' lets a checkpoint that was saved from a GPU be opened on
# a machine without CUDA; the batches fed to the model below are CPU tensors
# anyway.
ower.load_state_dict(torch.load('../models/model_base_0/model.pt', map_location='cpu'))

# Switch to evaluation mode; as the cell's last expression this also displays
# the module.
ower.eval()

In [None]:
# Inspect the currently loaded model on the first 20 validation samples:
# macro precision/recall/F1, followed by a per-sample dump.
samples = valid_set[:20]

# generate_batch shuffles each sample's sentence list in place, so
# `sample.sents` below is in the same order as the rows of `sents_batch`.
# The per-sentence scores printed further down depend on that alignment.
ent_batch, sents_batch, gt_batch = generate_batch(samples)

# Pure inference: no autograd graph is needed for either forward pass.
with torch.no_grad():
    logits = ower(sents_batch)
    foo_logits = ower.foo(sents_batch)

# A class counts as predicted when its logit is positive.
pred_batch = (logits > 0).int()

prec, rec, f1, supp = precision_recall_fscore_support(gt_batch, pred_batch, average='macro')

print(f'Precision = {prec:.2f}, Recall = {rec:.2f}, F1 = {f1:.2f}')
print()

for i, sample in enumerate(samples):
    print(ent_to_lbl[sample.ent])
    print()
    for sent in sample.sents:
        print('-', ids_to_sent(sent))
    print()
    print('Ground Truth:')
    print(gt_batch[i])
    print()
    print('Predicted:')
    print(pred_batch[i])

    # Only classes 1..4 are broken down per sentence.
    # NOTE(review): class 0 is skipped — confirm that this is intentional.
    for c in range(1, 5):
        print()
        print(f'Class {c}:', rel_tail_freq_lbl_tuples[c][3])
        for s, sent in enumerate(sample.sents):
            # foo_logits is indexed [sample][sentence][class].
            print('{:5.2f} '.format(foo_logits[i][s][c].item()), ids_to_sent(sent))

    print()
    print()

In [None]:
# Rebuild the Ower model and restore its trained weights, replacing the model
# previously bound to `ower`.
# The class count and pooling mode come from the config cell so the model's
# output width stays in step with the label tensors built from the dataset.
# NOTE(review): 154289 and 300 must match the checkpoint — presumably the
# vocabulary size and the embedding width of the vectors used in training;
# confirm.
ower = Ower.from_random(154289, 300, class_count, mode)

# map_location='cpu' lets a checkpoint that was saved from a GPU be opened on
# a machine without CUDA; the batches fed to the model below are CPU tensors
# anyway.
ower.load_state_dict(torch.load('../models/model_ower_0/model.pt', map_location='cpu'))

# Switch to evaluation mode; as the cell's last expression this also displays
# the module.
ower.eval()

In [None]:
# Inspect the currently loaded model on the first 20 validation samples:
# macro precision/recall/F1, followed by a per-sample dump that pairs each
# sentence's `foo` score with its `bar` score.
samples = valid_set[:20]

# generate_batch shuffles each sample's sentence list in place, so
# `sample.sents` below is in the same order as the rows of `sents_batch`.
# The per-sentence scores printed further down depend on that alignment.
ent_batch, sents_batch, gt_batch = generate_batch(samples)

# Pure inference: no autograd graph is needed for any of the forward passes.
with torch.no_grad():
    logits = ower(sents_batch)
    foo_logits = ower.foo(sents_batch)
    bar_logits = ower.bar(sents_batch)

# A class counts as predicted when its logit is positive.
pred_batch = (logits > 0).int()

prec, rec, f1, supp = precision_recall_fscore_support(gt_batch, pred_batch, average='macro')

print(f'Precision = {prec:.2f}, Recall = {rec:.2f}, F1 = {f1:.2f}')
print()

for i, sample in enumerate(samples):
    print(ent_to_lbl[sample.ent])
    print()
    for sent in sample.sents:
        print(ids_to_sent(sent))
    print()
    print('Ground Truth:')
    print(gt_batch[i])
    print()
    print('Predicted:')
    print(pred_batch[i])

    # Only classes 1..4 are broken down per sentence.
    # NOTE(review): class 0 is skipped — confirm that this is intentional.
    for c in range(1, 5):
        print()
        print(f'Class {c}:', rel_tail_freq_lbl_tuples[c][3])
        for s, sent in enumerate(sample.sents):
            # Note the differing axis order: foo_logits is indexed
            # [sample][sentence][class], bar_logits [sample][class][sentence].
            print('{:5.2f} * {:4.2f}  {}'.format(
                foo_logits[i][s][c].item(),
                bar_logits[i][c][s].item(),
                ids_to_sent(sent)))

    print()
    print()