# prepare

In [1]:
# imports
import os
import argparse
import json

import torch
import pytorch_lightning as pl
import torchmetrics
import transformers

from utils import (
    PersonaDataset,
    GenerativeCollator,
    RetrievalCollator,
    aggregate_encoder_output,
    sim_func,
)
from models import RetrievalModel, GenerativeModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# proxy
os.environ["http_proxy"] = "http://proxy.ad.speechpro.com:3128"
os.environ["https_proxy"] = "http://proxy.ad.speechpro.com:3128"
os.environ["ftp_proxy"] = "http://proxy.ad.speechpro.com:3128"

In [3]:
# config bert
parser = argparse.ArgumentParser()
bert_args = parser.parse_args("")
with open("configs/bert_config.json", "r") as config:
    opt = json.load(config)
vars(bert_args).update(opt)

# config gpt
parser = argparse.ArgumentParser()
gpt_args = parser.parse_args("")
with open("configs/gpt_config.json", "r") as config:
    opt = json.load(config)
vars(gpt_args).update(opt)

# pretrained model

In [4]:
with open(bert_args.special_tokens_dict, "r") as config:
    special_tokens_dict = json.load(config)

# bert tokenizer
bert_tokenizer = transformers.AutoTokenizer.from_pretrained(
    bert_args.pretrained_bert,
    truncation_side=bert_args.truncation_side,
    padding_side=bert_args.padding_side,
)
bert_tokenizer.add_special_tokens(special_tokens_dict)

# gpt tokenizer
gpt_tokenizer = transformers.AutoTokenizer.from_pretrained(
    gpt_args.pretrained_gpt,
    truncation_side=gpt_args.truncation_side,
    padding_side=gpt_args.padding_side,
)
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_tokenizer.add_special_tokens(special_tokens_dict)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


7

In [None]:
# retrieva model
bert_ckpt = "/home/stc/persona/logs/bi_encoder/ba9a2503126b46e2b2ec8049c669b0f1/checkpoints/epoch=29-step=22770.ckpt"
retrieval_model = RetrievalModel.load_from_checkpoint(bert_ckpt)
retrieval_model.eval()

# generative model
gpt_ckpt = "logs/gpt_answer/gpt-epoch=02-val_loss=0.61.ckpt"
generative_model = GenerativeModel.load_from_checkpoint(gpt_ckpt)
generative_model.eval()

# data

In [9]:
# bert callator
bert_callator = RetrievalCollator(
    bert_tokenizer, padding=bert_args.padding, max_length=bert_args.context_len
)

# gpt callator
gpt_callator = GenerativeCollator(
    gpt_tokenizer, padding=gpt_args.padding, max_length=gpt_args.max_len
)

Using eos_token, but it is not set yet.


# inference

In [10]:
# encode functions
def encode_persona(text_batch, encoder):
    inp_persona_tokens = bert_callator.CandidateCollator(text_batch)
    vec_batch = aggregate_encoder_output(
        encoder.candidat_BERT(**inp_persona_tokens), mod=bert_args.aggregation_mod
    )
    return vec_batch


def encode_context(text_batch, encoder):
    inp_context_tokens = bert_callator.ContextCollator([text_batch])
    print(bert_tokenizer.batch_decode(inp_context_tokens['input_ids']))
    vec_batch = aggregate_encoder_output(
        encoder.context_BERT(**inp_context_tokens), mod=bert_args.aggregation_mod
    )
    return vec_batch

In [11]:
persona = [
    "У меня любимая работа",
    "Я уважаю людей",
    "У меня есть попугай",
    "Я был в Париже",
    "Я люблю кофе",
    "У меня есть собака",
    "У меня есть кошка",
    "Я играю на гитаре",
    "Я кассир",
    "Я работаю в магазине",
]
context = []

vec_persona = encode_persona(persona, retrieval_model)
# vec_context = encode_context(context, retrieval_model)
# ranks = sim_func(vec_context, vec_persona, mod=bert_args.sim_mod)[0].tolist()
# gks = sorted(list(zip(ranks, persona)), key=lambda x: x[0], reverse=True)
# print(gks)

In [12]:
user_msg = '?'
context.append(user_msg)
for c in context:
    print('-',c)
vec_context = encode_context(context[-2:], retrieval_model)
ranks = sim_func(vec_context, vec_persona, mod=bert_args.sim_mod)[0].tolist()
gks = sorted(list(zip(ranks, persona)), key=lambda x: x[0], reverse=True)
print("знания о персоне:", gks)
gks = [gk[1] for gk in gks[:1]]

# generate
dict_inp = [{"context": context, "gk": gks, "candidate": ""}]
gpt_inp = gpt_callator.test(dict_inp)[0]["input_ids"][:, :-2]
len_gpt_inp = gpt_inp.size()[-1]
print()
print(("         (" + gpt_tokenizer.batch_decode(gpt_inp)[0] + ")"))
print()
gpt_out = generative_model.GPT.generate(
    gpt_inp,
    max_new_tokens=32,
)
gpt_answer = gpt_out
answer_raw = gpt_tokenizer.decode(gpt_answer[0], skip_special_tokens=False)
print()
print(("         (" + answer_raw + ")"))
print()

# proc answer
answer = gpt_tokenizer.decode(
    gpt_answer[0, len_gpt_inp + 1 :], skip_special_tokens=True
)
context.append(answer)
print("model:", answer)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


- а как ты относишься к собакам собакам собакам собакам собакам?
['[CLS] [P1u] а как ты относишься к собакам собакам собакам собакам собакам? [P2u] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']
знания о персоне: [(0.6285446882247925, 'У меня есть собака.'), (-2.215956926345825, 'У меня есть кошка.'), (-5.754154682159424, 'У меня любимая работа.'), (-6.454430