In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from Timer import Timer

In [2]:
BATCH_SIZE = 5
EPOCHS = 1

In [3]:
logging.set_verbosity_error()

In [4]:
#model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased');

In [5]:
from LamaTRExData import LamaTRExData
from LamaPrompts import LamaPrompts
from PromptModel import PromptModel, BERT_VOCAB_SIZE

In [6]:
train_data = LamaTRExData()
train_data.load()
train_dataloader = DataLoader(train_data.data_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [7]:
test_data = LamaTRExData(train=False)
test_data.load()

In [8]:
lama_prompts = LamaPrompts()
lama_prompts.load()

In [9]:
prompts = [prmpt for prmpt, _ in lama_prompts.generate_for_rel("mine","P1001", "Eibenstock")]

In [10]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

In [11]:
unmasker = pipeline('fill-mask', model='bert-base-uncased', top_k=2, tokenizer=('bert-base-cased', {"use_fast": True}))

In [12]:
%%time
unmasker(prompts)

CPU times: user 1.02 s, sys: 281 ms, total: 1.3 s
Wall time: 974 ms


[[{'score': 0.017790066078305244,
   'token': 1012,
   'token_str': '夫',
   'sequence': 'Eibenstock in the australian 夫.'},
  {'score': 0.006468340754508972,
   'token': 2605,
   'token_str': '##ni',
   'sequence': 'Eibenstock in the australianni.'}],
 [{'score': 0.030957939103245735,
   'token': 2015,
   'token_str': 'Society',
   'sequence': 'Eibenstock of morbihan association insee Society.'},
  {'score': 0.019493697211146355,
   'token': 1012,
   'token_str': '夫',
   'sequence': 'Eibenstock of morbihan association insee 夫.'}],
 [{'score': 0.06836647540330887,
   'token': 100,
   'token_str': '[UNK]',
   'sequence': 'Eibenstock in the australian state of.'},
  {'score': 0.06667706370353699,
   'token': 2078,
   'token_str': 'official',
   'sequence': 'Eibenstock in the australian state of official.'}],
 [{'score': 0.25777125358581543,
   'token': 1037,
   'token_str': '月',
   'sequence': '月 politician and Eibenstock.'},
  {'score': 0.06471894681453705,
   'token': 1998,
   'token_st

In [13]:
mined_prompt_model = PromptModel(unmasker, "mine", lama_prompts)
paraphrased_prompt_model = PromptModel(unmasker, "paraphrase", lama_prompts)

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(mined_prompt_model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(EPOCHS):
    for batch in train_dataloader: #tqdm(train_dataloader):
        (subjs, rels), labels = batch
        label_ids = torch.tensor(tokenizer.convert_tokens_to_ids(labels))
        output = mined_prompt_model(subjs, rels)
        loss = criterion(output, label_ids)
        loss.backward()
        optimizer.step()
    print(loss)

text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)

In [15]:
%%time
output = unmasker(prompts, batch_size=8)

CPU times: user 452 ms, sys: 110 ms, total: 562 ms
Wall time: 250 ms


In [16]:
output

[[{'score': 0.017790047451853752,
   'token': 1012,
   'token_str': '夫',
   'sequence': 'Eibenstock in the australian 夫.'},
  {'score': 0.006468352396041155,
   'token': 2605,
   'token_str': '##ni',
   'sequence': 'Eibenstock in the australianni.'}],
 [{'score': 0.030957939103245735,
   'token': 2015,
   'token_str': 'Society',
   'sequence': 'Eibenstock of morbihan association insee Society.'},
  {'score': 0.019493697211146355,
   'token': 1012,
   'token_str': '夫',
   'sequence': 'Eibenstock of morbihan association insee 夫.'}],
 [{'score': 0.06836634874343872,
   'token': 100,
   'token_str': '[UNK]',
   'sequence': 'Eibenstock in the australian state of.'},
  {'score': 0.06667643040418625,
   'token': 2078,
   'token_str': 'official',
   'sequence': 'Eibenstock in the australian state of official.'}],
 [{'score': 0.2577711045742035,
   'token': 1037,
   'token_str': '月',
   'sequence': '月 politician and Eibenstock.'},
  {'score': 0.06471903622150421,
   'token': 1998,
   'token_str

In [17]:
tokenizer.convert_tokens_to_ids(['england', "germany"])

[2563, 2762]

In [18]:
tokenizer.decode(2762)

'g e r m a n y'

In [19]:
mined_prompt_model.load()

AttributeError: 'PromptModel' object has no attribute 'load'