In [1]:
import random, gc, os, pickle, csv, time, re
import torch
from torch.utils import data
import numpy as np
import datasets.utils
from datasets.lifelong_fewrel_dataset import LifelongFewRelDataset
from models.rel_baseline import Baseline
import models.utils
from tqdm import tqdm

# Constants

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
BASE_MODEL_PATH = "/data/model_runs/original_oml/"
model_name = "Baseline-order5-id4-2023-05-28_10-52-41.172695.pt"
model_path = os.path.join(BASE_MODEL_PATH, model_name)

data_dir = '/data/omler_data/LifelongFewRel'
relation_file = os.path.join(data_dir, 'relation_name.txt')
validation_file = os.path.join(data_dir, 'val_data.txt')
relation_names = datasets.utils.read_relations(relation_file)
val_data = datasets.utils.read_rel_data(validation_file)
val_dataset = LifelongFewRelDataset(val_data, relation_names)

In [4]:
relation_names

['fill',
 ['place', 'served', 'by', 'transport', 'hub'],
 ['mountain', 'range'],
 ['religion'],
 ['participating', 'team'],
 ['contains', 'administrative', 'territorial', 'entity'],
 ['head', 'of', 'government'],
 ['country', 'of', 'citizenship'],
 ['original', 'network'],
 ['heritage', 'designation'],
 ['performer'],
 ['participant', 'of'],
 ['position', 'held'],
 ['has', 'part'],
 ['location', 'of', 'formation'],
 ['located', 'on', 'terrain', 'feature'],
 ['architect'],
 ['country', 'of', 'origin'],
 ['publisher'],
 ['director'],
 ['father'],
 ['developer'],
 ['military', 'branch'],
 ['mouth', 'of', 'the', 'watercourse'],
 ['nominated', 'for'],
 ['movement'],
 ['successful', 'candidate'],
 ['followed', 'by'],
 ['manufacturer'],
 ['instance', 'of'],
 ['after', 'a', 'work', 'by'],
 ['member', 'of', 'political', 'party'],
 ['licensed', 'to', 'broadcast', 'to'],
 ['headquarters', 'location'],
 ['sibling'],
 ['instrument'],
 ['country'],
 ['occupation'],
 ['residence'],
 ['work', 'locatio

# Model

In [5]:
args = {
    'n_epochs': 1, 
    'lr': 0.0004, 
    'inner_lr': 0.001, 
    'meta_lr': 3e-05, 
    'model': 'roberta', 
    'learner': 'sequential', 
    'mini_batch_size': 4, 
    'updates': 5, 
    'write_prob': 1.0, 
    'max_length': 64, 
    'seed': 42, 
    'replay_rate': 0.01, 
    'order': 5, 
    'num_clusters': 10, 
    'replay_every': 1600, 
    'model_dir': '/data/model_runs/original_oml'
}
mini_batch_size = args["mini_batch_size"]
model_name = args['model']

In [6]:
torch.manual_seed(args["seed"])
random.seed(args["seed"])
np.random.seed(args["seed"])

In [7]:
learner = Baseline(device=device, training_mode='sequential', **args)
print('Using {} as learner'.format(learner.__class__.__name__))
learner.load_model(model_path)

2023-05-29 08:42:58,036 - transformers.tokenization_utils_base - INFO - loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json from cache at /root/.cache/torch/transformers/d0c5776499adc1ded22493fae699da0971c1ee4c2587111707a4d177d20257a2.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b
2023-05-29 08:42:58,039 - transformers.tokenization_utils_base - INFO - loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt from cache at /root/.cache/torch/transformers/b35e7cd126cd4229a746b5d5c29a749e8e84438b14bcdb575950584fe33207e8.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
2023-05-29 08:42:59,206 - transformers.configuration_utils - INFO - loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /root/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c2

Using Baseline as learner


# Testing

In [8]:
dataloader = data.DataLoader(val_dataset, batch_size=mini_batch_size, shuffle=False,
                                  collate_fn=datasets.utils.rel_encode)

In [9]:
all_losses, all_predictions, all_labels = [], [], []
learner.model.eval()
_iter = 0
for text, label, candidates in tqdm(dataloader):
    replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(text, label, candidates)

    add_prefix_space = False
    if model_name == "roberta":
        replicated_text = [ " ".join(_t) for _t in replicated_text ]
        replicated_relations = [ " ".join(_t) for _t in replicated_relations ]
        add_prefix_space = True

    with torch.no_grad():
        input_dict = learner.model.encode_text(list(zip(replicated_text, replicated_relations)),add_prefix_space)
        output = learner.model(input_dict)

    pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
    all_predictions.extend(pred.tolist())
    all_labels.extend(true_labels.tolist())


    _iter += 1
    if _iter % 1000 == 0:
        acc = models.utils.calculate_accuracy(all_predictions, all_labels)
        print('Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(1, np.mean(all_losses), acc))
        print("replicated_text")
        for text, relation in zip(replicated_text[:5], replicated_relations[:5]):
            print(text + relation)
        print("tokenized text")
        for ids in input_dict['input_ids'][:5]:
            print(" ".join(learner.model.tokenizer.convert_ids_to_tokens([_id for _id in ids if _id != 1])))
        print("RankingLabel: ", ranking_label)
        print("PRED: ", pred.tolist())
        print("ANS: ", true_labels.tolist())

acc = models.utils.calculate_accuracy(all_predictions, all_labels)
acc

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
 36%|███████████████████████▎                                         | 1002/2800 [01:10<02:13, 13.50it/s]

Epoch 1 metrics: Loss = nan, accuracy = 1.0000
replicated_text
after union with greece , thessaly became divided into four prefectures : larissa prefecture , magnesia prefecture , karditsa prefecture , and trikala prefecture .country
after union with greece , thessaly became divided into four prefectures : larissa prefecture , magnesia prefecture , karditsa prefecture , and trikala prefecture .performer
after union with greece , thessaly became divided into four prefectures : larissa prefecture , magnesia prefecture , karditsa prefecture , and trikala prefecture .developer
after union with greece , thessaly became divided into four prefectures : larissa prefecture , magnesia prefecture , karditsa prefecture , and trikala prefecture .field of work
after union with greece , thessaly became divided into four prefectures : larissa prefecture , magnesia prefecture , karditsa prefecture , and trikala prefecture .place served by transport hub
tokenized text
<s> Ġafter Ġunion Ġwith Ġg ree ce Ġ

 72%|██████████████████████████████████████████████▍                  | 2002/2800 [02:19<00:57, 13.88it/s]

Epoch 1 metrics: Loss = nan, accuracy = 1.0000
replicated_text
it was soon expanded with the addition of the lord privy seal , arthur greenwood , and the chancellor of the exchequer , hugh dalton .position held
it was soon expanded with the addition of the lord privy seal , arthur greenwood , and the chancellor of the exchequer , hugh dalton .said to be the same as
it was soon expanded with the addition of the lord privy seal , arthur greenwood , and the chancellor of the exchequer , hugh dalton .developer
it was soon expanded with the addition of the lord privy seal , arthur greenwood , and the chancellor of the exchequer , hugh dalton .sports season of league or competition
it was soon expanded with the addition of the lord privy seal , arthur greenwood , and the chancellor of the exchequer , hugh dalton .member of
tokenized text
<s> Ġit Ġwas Ġsoon Ġexpanded Ġwith Ġthe Ġaddition Ġof Ġthe Ġlord Ġpriv y Ġseal Ġ, Ġar thur Ġgreen wood Ġ, Ġand Ġthe Ġchancellor Ġof Ġthe Ġex che quer Ġ, Ġh 

100%|█████████████████████████████████████████████████████████████████| 2800/2800 [03:15<00:00, 14.35it/s]


1.0

In [23]:
text, label, candidates = next(dataloader.__iter__())

In [32]:
replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(text, label, candidates)
replicated_text = replicated_text[:11]
replicated_relations = replicated_relations[:11]
ranking_label = ranking_label[:11]
replicated_text

[['in',
  'august',
  '2016',
  ',',
  'duplass',
  'brothers',
  'announced',
  'another',
  'television',
  'project',
  '"',
  'room',
  '104',
  '"',
  'to',
  'air',
  'on',
  'hbo',
  'in',
  '2017',
  '.'],
 ['in',
  'august',
  '2016',
  ',',
  'duplass',
  'brothers',
  'announced',
  'another',
  'television',
  'project',
  '"',
  'room',
  '104',
  '"',
  'to',
  'air',
  'on',
  'hbo',
  'in',
  '2017',
  '.'],
 ['in',
  'august',
  '2016',
  ',',
  'duplass',
  'brothers',
  'announced',
  'another',
  'television',
  'project',
  '"',
  'room',
  '104',
  '"',
  'to',
  'air',
  'on',
  'hbo',
  'in',
  '2017',
  '.'],
 ['in',
  'august',
  '2016',
  ',',
  'duplass',
  'brothers',
  'announced',
  'another',
  'television',
  'project',
  '"',
  'room',
  '104',
  '"',
  'to',
  'air',
  'on',
  'hbo',
  'in',
  '2017',
  '.'],
 ['in',
  'august',
  '2016',
  ',',
  'duplass',
  'brothers',
  'announced',
  'another',
  'television',
  'project',
  '"',
  'room',
  '104

In [33]:
with torch.no_grad():
    input_dict = learner.model.encode_text(list(zip(replicated_text, replicated_relations)),add_prefix_space)
    output = learner.model(input_dict)

In [37]:
input_dict['input_ids'].size()

torch.Size([11, 64])

In [38]:
output.size()

torch.Size([11, 1])

In [39]:
output

tensor([[-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018],
        [-2.3018]], device='cuda:0')

In [40]:
pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
print(ranking_label)
print(pred)
print(true_labels)

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
tensor([0])
tensor([0])
