Code compatible with transformers==3.1.0. Freezing some libraries in requirements.txt, ready for Release 0.2
Ubuntu committed Sep 8, 2020
1 parent 83c66c4 commit 533bb59
Showing 3 changed files with 12 additions and 15 deletions.
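The driving change is that transformers 3.1.0 renames the past argument of the GPT-2 forward pass to past_key_values; the model_generator.py edits below apply that rename mechanically. A minimal sketch of the new calling convention, assuming GPT2LMHeadModel as a stand-in for the repository's self.model (prompt and variable names are illustrative):

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    input_ids = tokenizer("A short prompt", return_tensors="pt")["input_ids"]

    # transformers<=3.0.2: model(input_ids=input_ids, past=None)
    # transformers>=3.1.0: the cache argument is called past_key_values
    logits, past_key_values = model(input_ids=input_ids, past_key_values=None)

    # The returned cache can be fed back in to decode one token at a time.
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    logits, past_key_values = model(input_ids=next_token, past_key_values=past_key_values)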
2 changes: 0 additions & 2 deletions coverage.py
@@ -12,7 +12,6 @@
 class KeywordExtractor():
     def __init__(self, n_kws=15):
         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-        self.tokenizer.max_len = 10000
         self.n_kws = n_kws
 
         self.bert_w2i = {w: i for i, w in enumerate(self.tokenizer.vocab)}
@@ -70,7 +69,6 @@ def extract_keywords(self, unmasked):
 class KeywordCoverage():
     def __init__(self, device, keyword_model_file, model_file=None, n_kws=15):
         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-        self.tokenizer.max_len = 10000
         self.vocab_size = self.tokenizer.vocab_size
         self.n_kws = n_kws
 
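The only coverage.py change is dropping the tokenizer.max_len override, an attribute that transformers 3.1.0 deprecates in favour of model_max_length. If an explicit length cap is still wanted, a rough equivalent (illustrative, not part of this commit) is:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Stands in for the removed `tokenizer.max_len = 10000` from transformers<=3.0.2.
    tokenizer.model_max_length = 10000

    # Or cap length explicitly per call instead of changing the tokenizer-wide limit.
    ids = tokenizer.encode("a very long document ...", max_length=512, truncation=True)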
18 changes: 9 additions & 9 deletions model_generator.py
@@ -73,8 +73,8 @@ def train_batch(self, bodies, summaries, special_append=None, no_preinput=False)
         inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries, special_append)
         past = None
         if not no_preinput:
-            _, past = self.model(input_ids=inputs, past=None)
-        logits, _ = self.model(input_ids=summ_inp, past=past)
+            _, past = self.model(input_ids=inputs, past_key_values=None)
+        logits, _ = self.model(input_ids=summ_inp, past_key_values=past)
         crit = torch.nn.CrossEntropyLoss(ignore_index=-1)
         loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.contiguous().view(-1))
         return loss
@@ -97,11 +97,11 @@ def decode_batch(self, bodies, special_append=None, max_output_length=100, sampl
         # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
         if input_past is None:
             inputs = self.preprocess_input(bodies, special_append)
-            _, input_past = self.model(input_ids=inputs, past=None)
+            _, input_past = self.model(input_ids=inputs, past_key_values=None)
 
         past = input_past
         while build_up is None or (build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])):
-            logits, past = self.model(input_ids=current, past=past)
+            logits, past = self.model(input_ids=current, past_key_values=past)
             probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
             logprobs = torch.nn.functional.log_softmax(logits, dim=2)
             if sample:
@@ -149,12 +149,12 @@ def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=F
         one_every_k = torch.FloatTensor([1] + [0] * (beam_size-1)).repeat(batch_size*beam_size).to(self.device)
 
         # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
-        _, input_past = self.model(input_ids=inputs, past=None)
+        _, input_past = self.model(input_ids=inputs, past_key_values=None)
         input_past = [torch.repeat_interleave(p, repeats=beam_size, dim=1) for p in input_past]
 
         past = input_past
         while build_up is None or (build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])):
-            logits, past = self.model(input_ids=next_words, past=past)
+            logits, past = self.model(input_ids=next_words, past_key_values=past)
             probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
             logprobs = torch.nn.functional.log_softmax(logits, dim=2)
 
@@ -254,7 +254,7 @@ def score(self, summaries, bodies, bodies_tokenized=None, lengths=None, extra=No
         summ_out = summ_out.contiguous()
 
         with torch.no_grad():
-            logits, _ = self.model(input_ids=summ_inp, past=None)
+            logits, _ = self.model(input_ids=summ_inp, past_key_values=None)
 
         crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
         loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
@@ -272,8 +272,8 @@ def score_pairs(self, bodies, summaries):
         inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries)
 
         with torch.no_grad():
-            _, past = self.model(input_ids=inputs, past=None)
-            logits, _ = self.model(input_ids=summ_inp, past=past)
+            _, past = self.model(input_ids=inputs, past_key_values=None)
+            logits, _ = self.model(input_ids=summ_inp, past_key_values=past)
 
         crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
         loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
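All five model_generator.py hunks are the same substitution of past_key_values for past; the cached incremental-decoding pattern around them is unchanged. A stripped-down sketch of that pattern under transformers 3.1.0 (greedy decoding, batch size 1, GPT2LMHeadModel again assumed as a stand-in for self.model; the repository's start/end tokens, sampling, and beam search are omitted):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    body_ids = tokenizer("Document to summarize.", return_tensors="pt")["input_ids"]

    with torch.no_grad():
        # Run the body once; its key/value cache is reused for every decoding step.
        logits, past = model(input_ids=body_ids, past_key_values=None)
        current = logits[:, -1, :].argmax(dim=-1, keepdim=True)

        generated = [current]
        for _ in range(20):
            # Only the newest token is fed in; the cache carries the rest of the context.
            logits, past = model(input_ids=current, past_key_values=past)
            current = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated.append(current)

    print(tokenizer.decode(torch.cat(generated, dim=1)[0].tolist()))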
7 changes: 3 additions & 4 deletions requirements.txt
@@ -1,7 +1,6 @@
-transformers==3.0.2
-sklearn
-nltk
+transformers==3.1.0
+sklearn==0.22.1
+nltk==3.5
 h5py
 tqdm
 matplotlib
-sklearn
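With the dependency versions now frozen (and the duplicate sklearn entry removed), a quick sanity check that an installed environment matches the pins (illustrative, not part of this commit):

    import nltk, sklearn, transformers

    # Expected under the pins above: transformers 3.1.0, nltk 3.5, scikit-learn 0.22.1.
    print(transformers.__version__)
    print(nltk.__version__)
    print(sklearn.__version__)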
