In [None]:
import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig

In [None]:
from models.knowledge_grounded_generator.kg_model import KnowledgeGroundedDecoder, KG_loss
from models.knowledge_grounded_generator.kg_agent import KG_enriched_MSC_Session

In [None]:
# Left-pad batches and reuse the EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Options handed to the knowledge-grounded decoder.
opt = dict(
    num_hops=2,
    aggregate_method="max",
    alpha=0.7,
    beta=0.2,
    gamma=0.33,
    fixed_lm=False,
    block_src=False,
    gate=0.0,  # Gate=0.0 means output should be equal to regular GPT2 output
)

model = KnowledgeGroundedDecoder(opt, tokenizer, config=PretrainedConfig())


In [None]:
# Display the config object of the model's `gpt2model` attribute as the cell output.
model.gpt2model.config

In [None]:
# Dataset options and construction.
#
# The two file-system locations default to the original developer-machine paths
# but can be overridden through the KG_DATADIR and MSC_DATAPATH environment
# variables, so the notebook can run elsewhere without editing this cell.
opt_dataset = {
    # os.path.join(..., "") keeps a trailing separator on the directory, as the
    # original value had (presumably file names are appended to it — TODO confirm).
    "kg_datadir": os.path.join(
        os.environ.get("KG_DATADIR")
        or "/users/FrankVerhoef/Programming/Project_AI/ParlAI/data/kg_data/",
        "",
    ),
    "dataset_concepts": "total_concepts.txt",
    "kg": "kg.graph-sm",
    "speaker_prefixes": None,
    "include_persona": False,
    "max_concepts": 256,
    "max_triples": 768,
    "max_branch": 64,
    "overlapping_concepts": "excl-src-in-tgt",
    "num_hops": 2,
}

datapath = (
    os.environ.get("MSC_DATAPATH")
    or "/Users/FrankVerhoef/Programming/PEX/data/msc/msc_dialogue/session_2/train.txt"
)
dataset = KG_enriched_MSC_Session(
    opt_dataset,
    datapath,
    tokenizer,
    max_samples=None,
    batch_format="huggingface",
    batch_pad_id=tokenizer.pad_token_id,
)

In [None]:
class Mini_dataset:
    """Tiny in-memory dataset of three hand-written samples.

    Each entry of ``self.data`` is a dict with a ``"text"`` string and a
    ``"labels"`` list holding one reply string. Indexing returns the pair
    ``(text, labels)``.
    """

    def __init__(self):
        # (context text, reply) pairs; every reply is wrapped in a one-element list.
        pairs = [
            (
                "I like my mother and sister. It is good to be with them.",
                "Your family is important since birth",
            ),
            (
                "Shall we play soccer?",
                "It is fun and a great sport to play as a team",
            ),
            (
                "The dinner was great, but now I want to go home.",
                "Yes, the food was delicious",
            ),
        ]
        self.data = [{"text": text, "labels": [reply]} for text, reply in pairs]

    def __getitem__(self, i):
        """Return the ``(text, labels)`` tuple of sample ``i``."""
        sample = self.data[i]
        return sample["text"], sample["labels"]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

# Build the toy dataset and pair every (text, labels) sample with the result of
# dataset._get_kg_info(text, labels); the resulting list is shown as the cell output.
testdata = Mini_dataset()
enriched = [
    (text, labels, dataset._get_kg_info(text, labels))
    for text, labels in map(testdata.__getitem__, range(len(testdata)))
]
enriched

In [None]:
# Tokenize only the context texts of the toy samples as one padded batch of
# PyTorch tensors; the encoding is displayed as the cell output.
tokenizer(
    text=[text for text, _ in map(testdata.__getitem__, range(len(testdata)))],
    padding=True,
    return_tensors="pt",
)

In [None]:
# Collate the enriched samples; the batch unpacks into (inputs, labels, kg_input).
batch = dataset.batchify(enriched)
inputs, labels, kg_input = batch
input_ids = inputs.input_ids
L = input_ids.shape[1]  # size of dimension 1 of input_ids (presumably the padded length)

In [None]:
# Greedy generation (single beam, no sampling) of at most 10 new tokens with the
# knowledge-grounded model.
#
# The tokenizer left-pads with the EOS token, so pad and EOS ids coincide and
# generate() cannot infer which positions are padding. Pass the attention mask
# produced by batchify explicitly so padded positions are ignored.
output = model.generate(
    inputs=input_ids,
    attention_mask=inputs.attention_mask,
    kg_input=kg_input,
    generation_config=GenerationConfig(
        pad_token_id=model.gpt2model.config.eos_token_id,
        output_hidden_states=True,
        use_cache=True,
        num_beams=1,
        do_sample=False,
        max_new_tokens=10,
    ),
)
for context, out in zip(enriched, output):
    print("Context:  ", context[0])
    # print("Label:    ", context[1])
    print("Tensor:   ", out)
    # NOTE(review): `out` is a single row of token ids, so batch_decode yields one
    # string per token; use decode(out) if one joined string is wanted.
    print("Response: ", dataset.tokenizer.batch_decode(out))
    print("-" * 20)

In [None]:
# Same greedy generation (single beam, no sampling, at most 10 new tokens), run
# directly on the model's `gpt2model` attribute without knowledge-graph input.
#
# The tokenizer left-pads with the EOS token, so pad and EOS ids coincide and
# generate() cannot infer which positions are padding. Pass the attention mask
# produced by batchify explicitly so padded positions are ignored.
output = model.gpt2model.generate(
    inputs=input_ids,
    attention_mask=inputs.attention_mask,
    generation_config=GenerationConfig(
        pad_token_id=model.gpt2model.config.eos_token_id,
        output_hidden_states=True,
        use_cache=True,
        num_beams=1,
        do_sample=False,
        max_new_tokens=10,
    ),
)
for context, out in zip(enriched, output):
    print("Context:  ", context[0])
    # print("Label:    ", context[1])
    print("Tensor:   ", out)
    # NOTE(review): `out` is a single row of token ids, so batch_decode yields one
    # string per token; use decode(out) if one joined string is wanted.
    print("Response: ", dataset.tokenizer.batch_decode(out))
    print("-" * 20)

In [None]:
# Single forward pass of the knowledge-grounded model on the padded batch, then
# print the inputs, the hidden-state shape and the arg-max token id per position.
output = model.forward(
    input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, kg_input=kg_input
)
print(inputs.input_ids, inputs.attention_mask, sep="\n")
print(output.last_hidden_state.shape)
print(torch.argmax(output.logits, dim=-1))

In [None]:
# Forward pass through the model's `gpt2model` attribute with explicit position ids.
# Position ids are the running count of attended tokens minus one, floored at 0,
# so padded positions (mask 0) do not advance the position counter.
attention_mask = inputs.attention_mask
position_ids = torch.clamp(attention_mask.cumsum(dim=1) - 1, min=0)
# Keep only the last input_ids.shape[1] columns.
position_ids = position_ids[:, -input_ids.shape[1]:]
output = model.gpt2model.forward(
    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids
)
print(torch.argmax(output.logits, dim=-1))