In [None]:
import torch
import tqdm
import json
import spacy
import numpy as np
from fairseq.models.roberta import alignment_utils

In [None]:
# Evaluation configuration.
NUM_WAY = 5  # how many distinct relations are drawn for each evaluation episode
FEWREL_SIZE = 5_000  # how many evaluation episodes the main loop runs

In [None]:
# Load the 'roberta.base' entry point from the 'pytorch/fairseq' hub repository.
# The resulting object is used below for encode/decode and feature extraction.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
# Switch the model to evaluation mode before extracting features.
# NOTE(review): as the last expression of the cell this also echoes the
# returned object in the notebook output.
roberta.eval() 

In [None]:
# Smoke test: the tokenizer must map a known sentence to the expected ids and
# decoding those ids must give the sentence back unchanged.
_GREETING = 'Hello world!'
tokens = roberta.encode(_GREETING)
_expected_ids = [0, 31414, 232, 328, 2]
assert tokens.tolist() == _expected_ids
assert roberta.decode(tokens) == _GREETING

In [None]:
def encode(example, return_enitity_embeddings, bag_of_tokens):
    """Encode one example into a single-row feature tensor using `roberta`.

    Exactly one of three representations is produced:

    * ``bag_of_tokens=True``: mean over axis 1 of
      ``roberta.extract_features`` applied to the encoded sentence.
    * both flags ``False``: mean of every word vector in the document returned
      by ``roberta.extract_features_aligned_to_words``.
    * ``return_enitity_embeddings=True``: mean head-entity vector concatenated
      (along axis 1) with the mean tail-entity vector.

    Args:
        example: mapping with a ``'tokens'`` list of strings. For entity
            embeddings it must also have ``'h'`` and ``'t'`` entries whose
            element ``[2]`` is a list of lists of token positions.
        return_enitity_embeddings: request the head/tail entity representation.
        bag_of_tokens: request the mean over all token-level features. Must not
            be combined with ``return_enitity_embeddings``.

    Returns:
        A tensor detached from autograd (computed under ``torch.no_grad()``).
        Presumably shaped ``(1, hidden)``, or ``(1, 2 * hidden)`` for entity
        embeddings -- TODO confirm against the fairseq feature shapes.

    Raises:
        AssertionError: if both flags are true.
    """
    sentence = ' '.join(example['tokens'])

    # This function is only used for evaluation, so autograd tracking is pure
    # overhead: disable it to avoid building a graph on every forward pass.
    with torch.no_grad():
        if bag_of_tokens:
            assert not return_enitity_embeddings
            return roberta.extract_features(roberta.encode(sentence)).mean(1)

        doc = roberta.extract_features_aligned_to_words(sentence)
        if not return_enitity_embeddings:
            word_vectors = [doc[i].vector for i in range(len(doc))]
            return torch.stack(word_vectors).mean(0, keepdim=True)

        # Flatten the per-mention position lists into one list per entity.
        head_tokens = [pos for mention in example['h'][2] for pos in mention]
        tail_tokens = [pos for mention in example['t'][2] for pos in mention]

        # NOTE(review): the +1 offset presumably skips a sentence-start token
        # that the aligned document prepends; it also assumes the aligned
        # tokenization matches example['tokens'] one-to-one -- confirm.
        head_encoded = torch.stack([doc[pos + 1].vector for pos in head_tokens]).mean(0, keepdim=True)
        tail_encoded = torch.stack([doc[pos + 1].vector for pos in tail_tokens]).mean(0, keepdim=True)

        return torch.cat([head_encoded, tail_encoded], dim=1)

In [None]:
# Path of the FewRel validation split used for the evaluation below.
FEWREL_VAL_PATH = '/data2/urikz/fewrel/val_wiki.json'

# JSON is read with an explicit encoding so the result does not depend on the
# locale's default encoding.
with open(FEWREL_VAL_PATH, encoding='utf-8') as f:
    data = json.load(f)

# The top-level keys of the file are used as relation identifiers.
FEWREL_RELATIONS = list(data.keys())
print(len(data))
print(FEWREL_RELATIONS)

In [None]:
def sample_sentences(relation_idx, size=1):
    """Draw `size` distinct examples of one relation uniformly at random.

    Args:
        relation_idx: key into the module-level `data` mapping.
        size: number of examples to draw (without replacement).

    Returns:
        A list with `size` elements taken from ``data[relation_idx]``.

    Raises:
        KeyError: if `relation_idx` is not a key of `data`.
        ValueError: if `size` exceeds the number of available examples.
    """
    pool = data[relation_idx]
    # Sample positions rather than the examples themselves so NumPy never has
    # to coerce the example objects into an array.
    picked = np.random.choice(len(pool), size=size, replace=False)
    return [pool[i] for i in picked]

In [None]:
def encode_fn(example):
    """Encode `example` with the representation chosen for this experiment.

    Delegates to `encode` with entity embeddings turned off and the
    bag-of-tokens representation turned on.
    """
    use_entities, use_bag = False, True
    return encode(example, return_enitity_embeddings=use_entities, bag_of_tokens=use_bag)

# Few-shot matching evaluation: for each episode draw NUM_WAY relations, take
# two sentences of the first (target) relation and one sentence of every other
# relation, then check that the query sentence scores highest (dot product)
# against the exemplar that shares its relation (index 0).

# Seed the sampler so that re-running this cell reproduces the same episodes.
EVAL_SEED = 0
np.random.seed(EVAL_SEED)

num_correct, num_total, num_failed = 0, 0, 0
first_error = None
for i in range(FEWREL_SIZE):
    try:
        relations = np.random.choice(FEWREL_RELATIONS, NUM_WAY, replace=False)
        sentences_for_the_target_relation = sample_sentences(relations[0], 2)
        # Turn the query encoding into a column so it can be right-multiplied.
        target_encoded = encode_fn(sentences_for_the_target_relation[0]).squeeze(0).unsqueeze(1)
        # Exemplar 0 shares the target relation; the rest are distractors.
        examplars = [sentences_for_the_target_relation[1]] + [sample_sentences(x, 1)[0] for x in relations[1:]]
        examplars_encoded = torch.stack([encode_fn(x) for x in examplars], dim=1).squeeze(0)
        num_correct += int(torch.mm(examplars_encoded, target_encoded).argmax().item() == 0)
        num_total += 1
    except Exception as error:
        # Best effort: skip episodes that cannot be encoded, but do not swallow
        # KeyboardInterrupt/SystemExit and do not hide systematic failures.
        num_failed += 1
        if first_error is None:
            first_error = error
            print('-- First failure at iteration #%d: %r' % (i + 1, error))

    if i % 100 == 0 and num_total > 0:
        print('-- Iteration #%d: accuracy %.2f%% (total %d, failed %d)' % (
            i + 1, 100.0 * num_correct / num_total, num_total, num_failed))

if num_total > 0:
    print('FINISHED: accuracy %.2f%% (for %d)' % (100.0 * num_correct / num_total, num_total))
else:
    # Every episode failed (or none was requested); avoid dividing by zero.
    print('FINISHED: no episode could be evaluated (failed %d)' % num_failed)