In [1]:
%load_ext autoreload
%autoreload 2

from models.ower import Ower
from pathlib import Path
from random import shuffle
from typing import List, Tuple

import torch
from sklearn.metrics import precision_recall_fscore_support
from torch import Tensor, tensor

from data.ower.ower_dir import OwerDir, Sample
from models.base import Base

In [2]:
ower_dir_path = '../data/ower/ower-v4-fb-irt-100-5/'
class_count = 100
sent_count = 5

batch_size = 1024
device = 'cuda'
emb_size = None
epoch_count = 20
log_dir = None
log_steps = False
lr = 0.01
mode = 'mean'
model_name = 'base'
save_dir = None
sent_len = 64
update_vectors = False
vectors = 'glove.6B.300d'

In [3]:
ower_dir = OwerDir(Path(ower_dir_path))
ower_dir.check()

train_set, valid_set, test_set, vocab = ower_dir.read_datasets(class_count, sent_count)



In [4]:
def generate_batch(batch: List[Sample]) -> Tuple[Tensor, Tensor, Tensor]:
    """
    :param batch: [Sample(ent, [class], [sent])]

    :return: ent_batch      IntTensor(batch_size),
             sents_batch    IntTensor(batch_size, sent_count, sent_len),
             classes_batch  IntTensor(batch_size, class_count)
    """

    ent_batch, classes_batch, sents_batch = zip(*batch)

    for sents in sents_batch:
        shuffle(sents)

    cropped_sents_batch = [[sent[:sent_len] for sent in sents] for sents in sents_batch]
    padded_sents_batch = [[sent + [0] * (sent_len - len(sent)) for sent in sents] for sents in cropped_sents_batch]

    return tensor(ent_batch), tensor(padded_sents_batch), tensor(classes_batch)

In [5]:
def ids_to_sent(ids: List[int]) -> str:
    return ' '.join([vocab.itos[id] for id in ids])

ent_to_lbl = ower_dir.ent_labels_txt.load()
rel_tail_freq_lbl_tuples = ower_dir.classes_tsv.load()



In [6]:
ower = Base.from_random(154289, 300, 100, 'mean')
ower.load_state_dict(torch.load('../models/model_base_0/model.pt'))
ower.eval()

Base(
  (embedding_bag): EmbeddingBag(154289, 300, mode=mean)
  (linear): Linear(in_features=300, out_features=100, bias=True)
)

In [10]:
samples = valid_set[:20]

ent_batch, sents_batch, gt_batch, = generate_batch(samples)

logits = ower(sents_batch)
pred_batch = (logits > 0).int()

prec, rec, f1, supp = precision_recall_fscore_support(gt_batch, pred_batch, average='macro')

foo_logits = ower.foo(sents_batch)

print(f'Precision = {prec:.2f}, Recall = {rec:.2f}, F1 = {f1:.2f}')
print()

for i, sample in enumerate(samples):
    print(ent_to_lbl[sample.ent])
    print()
    for sent in sample.sents:
        print('-', ids_to_sent(sent))
    print()
    print('Ground Truth:')
    print(gt_batch[i])
    print()
    print('Predicted:')
    print(pred_batch[i])

    for c in range(1, 5):
        print()
        print(f'Class {c}:', rel_tail_freq_lbl_tuples[c][3])
        for s, sent in enumerate(sample.sents):
            print('{:5.2f} '.format(foo_logits[i][s][c].item()), ids_to_sent(sent))

    print()
    print()

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Precision = 0.12, Recall = 0.34, F1 = 0.16

A&M Records

- wright recalls in his autobiography that, despite a&m and the music press being enthusiastic about the potential of "stand for our rights" and harrison's involvement, neither the single nor the album met with any commercial success – a situation that <unk> him after the failure of <unk> the previous year.
- in 2002, gwen stefani was invited to sing and perform with the dolls and brought along interscope geffen a&m chairman jimmy iovine and then-president of a&m records; both took interest into turning it into a singing group.
- although this song entered the country charts, the album itself was not released due to the closure of <unk> nashville unit.
- at one new york show they were discovered by an a&m records talent scout, patrick clifford, and the band signed their first recording contract.
- the song gained success shortly after styx left wooden nickel records to move to a&m records in 1974 as it began picking up airplay <u

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

Predicted:
tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0], dtype=torch.int32)

Class 1: /people/person/gender male organism
-3.19  the won their 1st mvc tournament title to earn an automatic bid to the 1992 ncaa tournament.
-2.37  usf has reached the ncaa division i men's basketball tournament 3 times in their history (1990, 1992, and 2012).
 0.24  he was also an

 2.70  brother, the great spirit made us all is a 1974 studio album by dave brubeck accompanied by his sons <unk> chris and dan.

Class 4: /people/person/profession actor
 0.89  instead, over dave <unk> <unk> dance", statements of the kind, <unk> mallard would like to thank chris <unk> appear.
 0.20  the album was reviewed by scott <unk> at allmusic who wrote, "one of the more obscure dave brubeck albums is really a showcase for the young singer carmen mcrae who performs nine numbers ...
-0.99  in your own sweet way at <unk> retrieved on august 7, 2020 composed by dave brubeck with lyrics by <unk> brubeck, popularized by brubeck and by miles davis.
-0.14  the group was often billed as opening act for the dave brubeck quartet, and at these concerts <unk> was performing with such jazz greats as gerry mulligan and paul desmond, as well as dave <unk>
 0.90  brother, the great spirit made us all is a 1974 studio album by dave brubeck accompanied by his sons <unk> chris and dan.


party game

In [8]:
ower = Ower.from_random(154289, 300, 100, 'mean')
ower.load_state_dict(torch.load('../models/model_ower_0/model.pt'))
ower.eval()

Ower(
  (embedding_bag): EmbeddingBag(154289, 300, mode=mean)
)

In [9]:
samples = valid_set[:20]

ent_batch, sents_batch, gt_batch, = generate_batch(samples)

logits = ower(sents_batch)
pred_batch = (logits > 0).int()

prec, rec, f1, supp = precision_recall_fscore_support(gt_batch, pred_batch, average='macro')

foo_logits = ower.foo(sents_batch)
bar_logits = ower.bar(sents_batch)

print(f'Precision = {prec:.2f}, Recall = {rec:.2f}, F1 = {f1:.2f}')
print()

for i, sample in enumerate(samples):
    print(ent_to_lbl[sample.ent])
    print()
    for sent in sample.sents:
        print(ids_to_sent(sent))
    print()
    print('Ground Truth:')
    print(gt_batch[i])
    print()
    print('Predicted:')
    print(pred_batch[i])

    for c in range(1, 5):
        print()
        print(f'Class {c}:', rel_tail_freq_lbl_tuples[c][3])
        for s, sent in enumerate(sample.sents):
            print('{:5.2f} * {:4.2f}  {}'.format(
                foo_logits[i][s][c].item(),
                bar_logits[i][c][s].item(),
                ids_to_sent(sent)))

    print()
    print()

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Precision = 0.12, Recall = 0.34, F1 = 0.16

A&M Records

at one new york show they were discovered by an a&m records talent scout, patrick clifford, and the band signed their first recording contract.
although this song entered the country charts, the album itself was not released due to the closure of <unk> nashville unit.
wright recalls in his autobiography that, despite a&m and the music press being enthusiastic about the potential of "stand for our rights" and harrison's involvement, neither the single nor the album met with any commercial success – a situation that <unk> him after the failure of <unk> the previous year.
the song gained success shortly after styx left wooden nickel records to move to a&m records in 1974 as it began picking up airplay <unk> at the 2010 great jones county fair eventually peaking at #6 on the billboard hot 100 in march 1975.
in 2002, gwen stefani was invited to sing and perform with the dolls and brought along interscope geffen a&m chairman jimmy iovi


Ground Truth:
tensor([1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

Predicted:
tensor([1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1,
        0, 0, 1, 0], dtype=torch.int32)

Class 1: /people/person/gender male organism
 0.94 * 0.14  <unk> said he deliberately cast rajinikanth against type since he wanted to <unk> with his acting skills".
-1.44 * 0.12  she spends her days <unk> people in the village to <unk> her loans and is childhood friends