In [139]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

In [140]:
from models.knowledge_grounded_generator.kg_model import KnowledgeGroundedDecoder, KG_loss

In [141]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

opt = {
    "num_hops": 2,
    "aggregate_method": "max",
    "embedding_size": 768,
    "alpha": 0.7,
    "beta": 0.2,
    "gamma": 0.33,
    'fixed_lm': False,
    'block_src': False,
    'gate': None
}

model = KnowledgeGroundedDecoder(opt, tokenizer)

2023-03-15 21:10:22,359 INFO     | Initialized TripleEncoder
2023-03-15 21:10:22,368 INFO     | Initialized KnowledgeGroundedDecoder


In [142]:
model

KnowledgeGroundedDecoder(
  (gpt2model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (triple_encoder):

In [143]:
model.gpt2model.generation_config
model.gpt2model.generation_config.pad_token_id = model.gpt2model.config.eos_token_id

In [144]:
s = "what should i do"
s2 = "the wheather"
enc = tokenizer([s, s2], padding=True, return_tensors='pt')

In [145]:
enc

{'input_ids': tensor([[10919,   815,  1312,   466],
        [50256,  1169,   483,  1032]]), 'attention_mask': tensor([[1, 1, 1, 1],
        [0, 1, 1, 1]])}

In [149]:
gen = model.gpt2model.generate(**enc)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [137]:
class MyCustomLogitsProcessor(LogitsProcessor):
    def __init__(self):
        self.call_counter=0
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        self.call_counter += 1
        print(self.call_counter, input_ids)
#         scores = torch.zeros_like(scores)
        scores[:, 1:] = scores[:, :-1]
#         print(scores.shape)
#         scores[0, 50] = 10
        return scores # Minimally working


In [138]:
logits_processor_list = LogitsProcessorList([
    MyCustomLogitsProcessor(),
])
gen = model.gpt2model.generate(
   **enc,
   num_beams=1, do_sample=False,
   return_dict_in_generate=False,
   output_scores=False,
   logits_processor=logits_processor_list, 
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


1 tensor([[10919,   815,  1312,   466],
        [50256,  1169,   483,  1032]])
2 tensor([[10919,   815,  1312,   466,   352],
        [50256,  1169,   483,  1032,    83]])
3 tensor([[10919,   815,  1312,   466,   352,    14],
        [50256,  1169,   483,  1032,    83,   287]])
4 tensor([[10919,   815,  1312,   466,   352,    14,    18],
        [50256,  1169,   483,  1032,    83,   287,   263]])
5 tensor([[10919,   815,  1312,   466,   352,    14,    18,   287],
        [50256,  1169,   483,  1032,    83,   287,   263,   325]])
6 tensor([[10919,   815,  1312,   466,   352,    14,    18,   287,   263],
        [50256,  1169,   483,  1032,    83,   287,   263,   325,   322]])
7 tensor([[10919,   815,  1312,   466,   352,    14,    18,   287,   263,   325],
        [50256,  1169,   483,  1032,    83,   287,   263,   325,   322,    12]])
8 tensor([[10919,   815,  1312,   466,   352,    14,    18,   287,   263,   325,
            31],
        [50256,  1169,   483,  1032,    83,   287,   26

In [147]:
gen

tensor([[10919,   815,  1312,   466,   351,   428,    30,   198,   198,    40,
          1101,   407,  1654,   644,   284,   466,   351,   428,    13,   198],
        [50256,  1169,   483,  1032,    82,    11,   290,   262,   584,   734,
           389,   262,  3392,   326,   389,   407,    13,   198,   198,   464]])

In [136]:
tokenizer.batch_decode(gen)

['what should i do 1/3 inerse@t/u000000',
 '<|endoftext|>the wheathert inerseow-u.\x0b/j00000']

In [124]:
gen

tensor([[10919,   815,  1312,   466,   351,   428,    30,   198,   198,    40,
          1101,   407,  1654,   644,   284,   466,   351,   428,    13,   198],
        [50256,  1169,   483,  1032,    82,    11,   290,   262,   584,   734,
           389,   262,  3392,   326,   389,   407,    13,   198,   198,   464]])