In [1]:
import os
import argparse
import json

import torch
import pytorch_lightning as pl
import torchmetrics
import transformers
import sys

sys.path.append("..")

from utils import (
    GenerativeCollator,
    RetrievalCollator,
)
from models import BERT_RetrievalModel, GPT_GenerativeModel


  from .autonotebook import tqdm as notebook_tqdm


In [72]:
class BiEncoder_GPT:
    """Retrieve-then-generate dialogue pipeline.

    Wraps two already-loaded models:

    * ``retrieval_model`` - scores candidate texts against the dialogue context.
      It must expose ``collator.CandidateCollator``, ``collator.ContextCollator``,
      ``encode_candidats``, ``encode_context`` and ``compute_sim``.
    * ``generative_model`` - produces the reply. It must expose
      ``collator.test``, ``GPT.generate`` and ``tokenizer.decode``.
    """

    # Text markers this class searches for in the decoded generation.
    _GK_TAG = "[Gk]"
    _P2_TAG = "[P2u]"
    _P1_TAG = "[P1u]"
    _EOS_TEXT = "<|endoftext|>"

    def __init__(
        self,
        retrieval_model,
        generative_model,
    ):
        self.retrieval_model = retrieval_model
        self.generative_model = generative_model

    def calculate_candidats(self, candidats_texts: list[str]) -> torch.Tensor:
        """Compute the vector base of candidates.

        Args:
            candidats_texts (list[str]): list of persona facts.

        Returns:
            torch.Tensor: vectors of the persona facts, as returned by
            ``retrieval_model.encode_candidats``.
        """
        candidats_tokens = self.retrieval_model.collator.CandidateCollator(
            candidats_texts
        )
        candidats_vec = self.retrieval_model.encode_candidats(candidats_tokens)
        return candidats_vec

    @staticmethod
    def _to_matrix(candidats_vecs) -> torch.Tensor:
        """Turn candidate vectors into one detached tensor without copy-constructing."""
        if isinstance(candidats_vecs, torch.Tensor):
            # torch.tensor(<tensor>) copies and emits a UserWarning; detach() is enough
            # because the vectors are only read here.
            return candidats_vecs.detach()
        if len(candidats_vecs) > 0 and all(
            isinstance(vec, torch.Tensor) for vec in candidats_vecs
        ):
            # A list of per-candidate tensors has to be stacked into a matrix.
            return torch.stack([vec.detach() for vec in candidats_vecs])
        return torch.as_tensor(candidats_vecs)

    @classmethod
    def _parse_generation(cls, decoded: str):
        """Split decoded text into the reply message and the ``[Gk]``-separated prefix parts."""
        segments = decoded.split(cls._GK_TAG)
        turns = segments[-1].split(cls._P2_TAG)
        # The reply is the text right after the first "[P2u]"; when the marker is
        # missing, fall back to the whole segment instead of raising IndexError.
        reply = turns[1] if len(turns) > 1 else turns[0]
        msg = reply.split(cls._P1_TAG)[0].replace(cls._EOS_TEXT, "")
        return msg, segments[:-1]

    def retrieve_gk(
        self,
        context_texts: list[tuple[str, str]],
        candidats_texts: list[str],
        candidats_vecs,
        top_k: int = 3,
        min_score: float = 1,
    ) -> tuple[list[tuple[float, str]], list[tuple[float, str]]]:
        """Find the candidates relevant to the dialogue context.

        Only the text of the last context turn is used as the query.

        Args:
            context_texts: dialogue history as ``(speaker, text)`` pairs.
            candidats_texts: candidate texts, aligned with ``candidats_vecs``.
            candidats_vecs: candidate vectors - either a single tensor (as
                returned by ``calculate_candidats``), a list of tensors, or
                nested Python numbers.
            top_k: how many of the best-scored candidates are considered.
            min_score: a considered candidate is kept only if its score is
                strictly greater than this value.

        Returns:
            ``(selected, ranked)``: ``ranked`` is every ``(score, text)`` pair
            sorted by descending score; ``selected`` is the subset of the first
            ``top_k`` pairs whose score exceeds ``min_score``.

        Raises:
            IndexError: if ``context_texts`` is empty (and candidates are not).
        """
        candidats_matrix = self._to_matrix(candidats_vecs)
        if candidats_matrix.size(0) == 0:
            # Nothing to rank.
            return [], []

        last_utterance = context_texts[-1][1]
        context_tokens = self.retrieval_model.collator.ContextCollator(
            [[last_utterance]]
        )
        # Pure inference: the scores leave this method as Python floats.
        with torch.no_grad():
            context_vec = self.retrieval_model.encode_context(context_tokens)
            # One copy of the context vector per candidate row.
            context_vec = context_vec.repeat(candidats_matrix.size(0), 1)
            scores = self.retrieval_model.compute_sim(context_vec, candidats_matrix)[
                0
            ].tolist()

        ranked = sorted(
            zip(scores, candidats_texts), key=lambda pair: pair[0], reverse=True
        )
        selected = [
            (score, text) for score, text in ranked[:top_k] if score > min_score
        ]
        return selected, ranked

    def generate_reply(self, context_texts, gks, max_new_tokens: int = 32):
        """Generate a reply for the dialogue context using the retrieved candidates.

        Args:
            context_texts: dialogue history as ``(speaker, text)`` pairs.
            gks: ``(score, text)`` pairs, e.g. the first element returned by
                ``retrieve_gk``.
            max_new_tokens: generation budget forwarded to ``GPT.generate``.

        Returns:
            ``(msg, new_gks)``: ``msg`` is the reply text; ``new_gks`` is the
            list of decoded parts that precede the last ``[Gk]`` marker.
        """
        # TODO: extend the regexes
        utterances = [turn[1] for turn in context_texts]
        facts = [scored[1] for scored in gks]
        batch = [{"context": utterances, "gk": facts, "candidate": ""}]
        # The last two prompt tokens are dropped before generation - presumably the
        # tokens the collator appends for the empty candidate; TODO confirm.
        prompt_ids = self.generative_model.collator.test(batch)[0]["input_ids"][:, :-2]
        generated = self.generative_model.GPT.generate(
            prompt_ids,
            max_new_tokens=max_new_tokens,
        )
        # Decode only what was generated. Slicing a fixed number of trailing tokens
        # would pull prompt tokens in whenever generation stops early.
        new_ids = generated[0][prompt_ids.shape[1]:]
        decoded = self.generative_model.tokenizer.decode(
            new_ids, skip_special_tokens=False
        )
        return self._parse_generation(decoded)



In [None]:
# Restore the retrieval model from its checkpoint and switch it to eval mode.
# The checkpoint location can be overridden with the BI_ENCODER_CKPT environment
# variable; when it is unset, the original hard-coded path is used.
BI_ENCODER_CKPT = os.environ.get(
    "BI_ENCODER_CKPT",
    '/home/stc/persona/logs/bi_encoder/36037371cee4404b80aa618268a2e24c/checkpoints/epoch=29-step=22080.ckpt',
)
bi_encoder = BERT_RetrievalModel.load_from_checkpoint(BI_ENCODER_CKPT)
bi_encoder.eval()

In [None]:
# Restore the generative model from its checkpoint and switch it to eval mode.
# The checkpoint location can be overridden with the GPT_GENERATIVE_CKPT
# environment variable; when it is unset, the original hard-coded path is used.
GPT_GENERATIVE_CKPT = os.environ.get(
    "GPT_GENERATIVE_CKPT",
    '/home/stc/persona/logs/gpt-epoch=00-val_loss=3.62.ckpt',
)
generative = GPT_GenerativeModel.load_from_checkpoint(GPT_GENERATIVE_CKPT)
generative.eval()

In [73]:
# Combine the retrieval and generative models into a single pipeline object
# (arguments are positional: retrieval model first, generative model second).
full_model = BiEncoder_GPT(bi_encoder, generative)

In [74]:
# Toy example: candidate persona facts plus a short dialogue history of
# (speaker, utterance) pairs.
candidats_texts = [
    'я маша',
    'я работаю садовником',
    'я люблю кофе',
    'я люблю собак',
    'я люблю кошек',
    'я бегаю по утрам',
]
context_texts = [
    ('user', 'привет'),
    ('model', 'привет'),
    ('user', 'ты занимаешься спортом?'),
]

# Encode the candidates once, then keep only the selected candidates
# (the first element of the pair returned by retrieve_gk).
candidats_vecs = full_model.calculate_candidats(candidats_texts)
gks = full_model.retrieve_gk(
    context_texts=context_texts,
    candidats_texts=candidats_texts,
    candidats_vecs=candidats_vecs,
)[0]
gks

  candidats_vecs = torch.tensor(candidats_vecs)


[]

In [75]:
# Generate a reply for the dialogue above using the candidates selected by retrieval.
full_model.generate_reply(
    context_texts=context_texts,
    gks=gks,
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[P2u]не[P1u]я хожу в фитнес-клуб[P2u]это хорошо[P1u]а ты?[P2u]да, люблю лыжи[P1u]круто[P2u]ты любишь готовить?


('не', [])