In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from Timer import Timer

In [2]:
# Training hyperparameters: BATCH_SIZE is the DataLoader mini-batch size,
# EPOCHS is the number of full passes made by the training loop.
BATCH_SIZE, EPOCHS = 5, 1

In [3]:
# `logging` here is transformers.logging (see the imports cell): raise the
# library's log threshold to ERROR so its warnings and info messages are hidden.
logging.set_verbosity_error()

In [4]:
# Disabled alternative: load bert-base-uncased through torch.hub instead of
# transformers' BertModel.from_pretrained, which this notebook uses further down.
#model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased');

In [5]:
from LamaTRExData import LamaTRExData
from LamaPrompts import LamaPrompts
from PromptModel import PromptModel, BERT_VOCAB_SIZE

In [6]:
# Build the dataset with its default constructor (presumably the training
# split, since the test cell passes train=False -- confirm in LamaTRExData)
# and wrap it in a shuffling, batched loader backed by 4 worker processes.
train_data = LamaTRExData()
train_data.load()
train_dataloader = DataLoader(
    train_data.data_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [7]:
# Load the held-out split (train=False). None of the cells visible in this
# notebook read test_data afterwards.
test_data = LamaTRExData(train=False)
test_data.load()

In [8]:
# Prompt store shared by the prompt-generation cell and both PromptModel
# instances below. What load() actually reads is defined in LamaPrompts
# (not visible from this notebook).
lama_prompts = LamaPrompts()
lama_prompts.load()

In [9]:
# generate_for_rel yields pairs; keep only the first element of each pair.
# The resulting list is what gets fed to the fill-mask pipeline later on.
prompts = [
    prompt_text
    for prompt_text, _unused in lama_prompts.generate_for_rel("mine", "P1001", "Eibenstock")
]

In [10]:
# Tokenizer and bare encoder for the uncased base checkpoint.
# `tokenizer` is also what the training loop uses to turn label tokens into
# vocabulary ids; `model` is only used by the stand-alone encoding cell below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

In [11]:
# Fill-mask pipeline shared by the prompt models (top 2 predictions per mask).
# The tokenizer checkpoint must match the model checkpoint: the uncased model's
# embedding table is indexed by the uncased vocabulary, so ids produced by the
# cased tokenizer would look up unrelated tokens. Using the uncased tokenizer
# also keeps predictions in the same id space as the `tokenizer` that encodes
# the training labels.
unmasker = pipeline(
    'fill-mask',
    model='bert-base-uncased',
    top_k=2,
    tokenizer=('bert-base-uncased', {"use_fast": True}),
)

In [12]:
# Two PromptModel instances sharing the same pipeline and prompt store; they
# differ only in the second argument ("mine" vs "paraphrase"). Only
# mined_prompt_model is handed to the optimizer in the cells visible here.
mined_prompt_model = PromptModel(unmasker, "mine", lama_prompts)
paraphrased_prompt_model = PromptModel(unmasker, "paraphrase", lama_prompts)

In [13]:
# Cross-entropy objective plus SGD with momentum; the optimizer is bound to
# the parameters of mined_prompt_model only.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    mined_prompt_model.parameters(),
    lr=0.01,
    momentum=0.9,
)

In [None]:
# Train mined_prompt_model for EPOCHS passes over train_dataloader and print
# the loss of the last batch of each epoch.
for epoch in range(EPOCHS):
    for batch in train_dataloader: #tqdm(train_dataloader):
        # Each batch unpacks into (subjects, relations) and label tokens.
        (subjs, rels), labels = batch
        # Targets for CrossEntropyLoss: vocabulary ids of the label tokens.
        label_ids = torch.tensor(tokenizer.convert_tokens_to_ids(labels))
        # PyTorch accumulates gradients into .grad across backward() calls, so
        # they must be cleared every iteration; otherwise each step would apply
        # the sum of the gradients of all batches seen so far.
        optimizer.zero_grad()
        output = mined_prompt_model(subjs, rels)
        loss = criterion(output, label_ids)
        loss.backward()
        optimizer.step()
    print(loss)

start with 2.725754976272583 seconds
prompts with 0.0008020401000976562 seconds
pipline with 4.3178322315216064 seconds
preprocessing with 0.005151033401489258 seconds
start with 0.0033767223358154297 seconds
prompts with 0.0002791881561279297 seconds
pipline with 4.267877817153931 seconds
preprocessing with 0.004385948181152344 seconds
start with 0.0029850006103515625 seconds
prompts with 0.00017309188842773438 seconds
pipline with 4.505741119384766 seconds
preprocessing with 0.003969907760620117 seconds
start with 0.0031290054321289062 seconds
prompts with 0.0002391338348388672 seconds
pipline with 4.7076380252838135 seconds
preprocessing with 0.0038487911224365234 seconds
start with 0.003874063491821289 seconds
prompts with 0.0004329681396484375 seconds
pipline with 3.7717628479003906 seconds
preprocessing with 0.0031092166900634766 seconds
start with 0.0025358200073242188 seconds
prompts with 0.0002512931823730469 seconds
pipline with 4.595362663269043 seconds
preprocessing with 0.

# Stand-alone check: tokenize one sample sentence into PyTorch tensors and run
# it through the bare BertModel.
# NOTE: this rebinds the global `output`, which the training loop and the
# pipeline cell also assign.
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)

In [None]:
%%time
# Run the fill-mask pipeline over every generated prompt, 8 prompts per batch.
# NOTE: this rebinds the global `output`.
output = unmasker(prompts, batch_size=8)

In [None]:
# Display whatever `output` currently holds. Several cells assign this name,
# so what is shown depends on which of them ran last.
output

In [None]:
# Look up the vocabulary ids of two example tokens with the uncased tokenizer.
tokenizer.convert_tokens_to_ids(['england', "germany"])

In [None]:
# Map a single vocabulary id back to text (presumably one of the ids returned
# by the token lookup in the previous cell -- confirm).
tokenizer.decode(2762)