In [1]:
# !pip install torch
# !pip install transformers
# !pip install SentencePiece
# !pip install accelerate

In [1]:
from string import Template
from sentencepiece import SentencePieceProcessor
from logging import getLogger
from typing import List
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# getLogger() is called without a name, so this is the root logger.
logger = getLogger()


class Tokenizer:
    """Minimal wrapper around a SentencePiece model file.

    After construction the instance exposes the underlying processor as
    ``sp_model`` together with the vocabulary size (``n_words``) and the
    BOS / EOS / PAD token ids reported by that processor.
    """

    def __init__(self, model_path: str):
        """Load the SentencePiece model stored at ``model_path``.

        Raises:
            AssertionError: if ``model_path`` is not an existing file, or if
                the processor's vocab size and piece count disagree.
        """
        # Fail early (with the offending path as the message) on a bad path.
        assert os.path.isfile(model_path), model_path
        processor = SentencePieceProcessor(model_file=model_path)
        self.sp_model = processor
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # Vocabulary size and special-token ids, as reported by the model.
        self.n_words: int = processor.vocab_size()
        self.bos_id: int = processor.bos_id()
        self.eos_id: int = processor.eos_id()
        self.pad_id: int = processor.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert processor.vocab_size() == processor.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Turn ``s`` into token ids, optionally wrapped in BOS / EOS ids.

        Args:
            s: text to encode; must be exactly ``str``.
            bos: when true, ``self.bos_id`` is placed in front of the ids.
            eos: when true, ``self.eos_id`` is placed after the ids.

        Raises:
            AssertionError: if ``s`` is not a ``str``.
        """
        assert type(s) is str
        token_ids = self.sp_model.encode(s)
        if bos:
            token_ids = [self.bos_id, *token_ids]
        if eos:
            token_ids = [*token_ids, self.eos_id]
        return token_ids

    def decode(self, t: List[int]) -> str:
        """Turn a list of token ids back into text via the wrapped processor."""
        text = self.sp_model.decode(t)
        return text

In [2]:
# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("eachadea/vicuna-7b-1.1")

# model = AutoModelForCausalLM.from_pretrained("eachadea/vicuna-7b-1.1")

In [2]:
# Uses the notebook's own SentencePiece wrapper (the Tokenizer class defined in
# this notebook), not AutoTokenizer; the vocabulary file is read from a path
# relative to the working directory.
# NOTE(review): presumably this LLaMA vocabulary matches the Vicuna checkpoint
# loaded below — confirm.
tokenizer = Tokenizer('./LLaMA_tokenizer.model')

# Fetches the checkpoint from the Hugging Face hub on first use (see the
# download progress output). device_map='auto' delegates weight placement to
# the library instead of loading everything onto a single fixed device.
model = AutoModelForCausalLM.from_pretrained("eachadea/vicuna-7b-1.1", device_map='auto')

Downloading (…)lve/main/config.json:   0%|          | 0.00/554 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [35]:
def gen_answer(model, tokenizer, prompt, subs):
    """Generate the assistant's reply to a single templated user message.

    The template is filled in, wrapped in a one-turn USER/ASSISTANT
    conversation (two-separator style), encoded without BOS/EOS ids, and fed
    to ``model.generate`` with sampling enabled.

    Args:
        model: object exposing ``generate(input_ids=..., ...)`` that returns a
            tensor whose first row holds the prompt ids followed by new ids.
        tokenizer: object exposing ``encode(text, bos, eos)`` and
            ``decode(ids)``.
        prompt: a ``string.Template`` (anything with ``substitute``).
        subs: mapping of template placeholder names to values.

    Returns:
        The text generated after the prompt.
    """
    user_message = prompt.substitute(**subs)

    conv = Conversation(
                system="A chat between a curious user and an artificial intelligence assistant. "
                "The assistant gives helpful, detailed, and polite answers to the user's questions.",
                roles=("USER", "ASSISTANT"),
                messages=[],
                offset=0,
                sep_style=SeparatorStyle.TWO,
                sep=" ",
                sep2="</s>",
            )

    conv.append_message("USER", user_message)
    # A None message leaves the assistant turn open for the model to complete.
    conv.append_message("ASSISTANT", None)

    input_text = conv.get_prompt()

    input_ids = tokenizer.encode(input_text, False, False)
    # Batch of one, moved to device index 0 (as in the original code).
    tokenized_input_text = torch.tensor([input_ids]).to(0)

    generated_indices = model.generate(input_ids=tokenized_input_text, max_new_tokens=512,
                                       temperature=.7, do_sample=True).detach().cpu()
    output_ids = generated_indices[0].tolist()
    decoded = tokenizer.decode(output_ids)

    # Cutting by character count is only valid when decoding reproduces the
    # prompt verbatim; otherwise the cut would land in the wrong place.
    if decoded.startswith(input_text):
        return decoded[len(input_text):]

    # Fallback: drop the prompt by token count and decode only the new ids.
    return tokenizer.decode(output_ids[len(input_ids):])

In [36]:
def gen_speakers_facts(**kwargs):
    """Generate one answer per speaker for the given dialog.

    Keyword Args:
        speakers: iterable of speaker names; each is substituted as
            ``$speaker`` in the prompt.
        dialog: dialog text substituted as ``$dialog``.
        prompt: template passed through to ``gen_answer``.
        model: optional; falls back to the notebook-level ``model``.
        tokenizer: optional; falls back to the notebook-level ``tokenizer``.

    Returns:
        dict mapping each speaker name to the generated text.

    Raises:
        KeyError: if ``speakers``, ``dialog`` or ``prompt`` is missing.
    """
    speakers = kwargs['speakers']
    dialog = kwargs['dialog']
    prompt = kwargs['prompt']
    # Honour the model/tokenizer the caller passed in; the globals are looked
    # up only when the keyword is absent, so older call sites keep working.
    active_model = kwargs['model'] if 'model' in kwargs else model
    active_tokenizer = kwargs['tokenizer'] if 'tokenizer' in kwargs else tokenizer

    return {
        speaker: gen_answer(active_model, active_tokenizer, prompt,
                            {'speaker': speaker, 'dialog': dialog})
        for speaker in speakers
    }

In [37]:
def gen_summarization(**kwargs):
    """Generate a summary for the given dialog.

    Keyword Args:
        dialog: dialog text substituted as ``$dialog``.
        prompt: template passed through to ``gen_answer``.
        model: optional; falls back to the notebook-level ``model``.
        tokenizer: optional; falls back to the notebook-level ``tokenizer``.

    Other keyword arguments are accepted and ignored.

    Returns:
        dict with the single key ``'Summarization'`` holding the generated text.

    Raises:
        KeyError: if ``dialog`` or ``prompt`` is missing.
    """
    dialog = kwargs['dialog']
    prompt = kwargs['prompt']
    # Honour the model/tokenizer the caller passed in; the globals are looked
    # up only when the keyword is absent, so older call sites keep working.
    active_model = kwargs['model'] if 'model' in kwargs else model
    active_tokenizer = kwargs['tokenizer'] if 'tokenizer' in kwargs else tokenizer

    summary_text = gen_answer(active_model, active_tokenizer, prompt, {'dialog': dialog})
    return {'Summarization': summary_text}

In [38]:
# Registry of generation stages: each stage name maps to the prompt template
# to fill in and the function that runs the generation for that stage.
prompts = dict(
    # Per-speaker fact extraction; the template needs $dialog and $speaker.
    facts=dict(
        prompt=Template("$dialog\nPlease, write only relevant facts about $speaker in number list."),
        method=gen_speakers_facts,
    ),
    # Whole-dialog summary; the template needs only $dialog.
    summarization=dict(
        prompt=Template("$dialog\nProvide a summary of the exchange between bot_0 and bot_1."),
        method=gen_summarization,
    ),
)

In [39]:
import json

# Parse the input dialogs straight from the file handle.
with open('./sample_data_summ_sess.json', 'r') as f:
    dialogs = json.load(f)

# Speaker names for which per-speaker outputs are generated.
speakers = ['bot_0', 'bot_1']

In [40]:
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any


class SeparatorStyle(Enum):
    """Enumerates the prompt layouts a conversation can be rendered in."""

    # Values are written out explicitly; they are the same 1..4 sequence
    # that auto() would assign in this order.
    SINGLE = 1
    TWO = 2
    DOLLY = 3
    OASST_PYTHIA = 4


@dataclasses.dataclass
class Conversation:
    """Holds the turns of a chat and renders them into a single prompt string."""

    # Text placed at the very start of every rendered prompt.
    system: str
    # Speaker labels; stored, copied and exported, but not read when rendering.
    roles: List[str]
    # [role, message] pairs in order; a falsy message marks a turn left open.
    messages: List[List[str]]
    # Number of leading messages skipped by to_gradio_chatbot().
    offset: int
    # Layout used by get_prompt().
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Primary separator, and the alternate one used by the two-separator layouts.
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the conversation according to ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not one of the handled styles.
        """
        style = self.sep_style
        if style == SeparatorStyle.SINGLE:
            return self._render_single()
        if style == SeparatorStyle.TWO:
            return self._render_two()
        if style == SeparatorStyle.DOLLY:
            return self._render_dolly()
        if style == SeparatorStyle.OASST_PYTHIA:
            return self._render_oasst_pythia()
        raise ValueError(f"Invalid style: {self.sep_style}")

    def _render_single(self):
        """SINGLE layout: every turn is introduced by ``sep`` and a space."""
        text = self.system
        for role, message in self.messages:
            text += self.sep + " " + role
            text += (": " + message) if message else ":"
        return text

    def _render_two(self):
        """TWO layout: ``sep`` after the system text, then turns closed by sep / sep2 in alternation."""
        turn_ends = (self.sep, self.sep2)
        text = self.system + turn_ends[0]
        for index, (role, message) in enumerate(self.messages):
            if not message:
                # Open turn: role label only, no closing separator.
                text += role + ":"
                continue
            text += role + ": " + message + turn_ends[index % 2]
        return text

    def _render_dolly(self):
        """DOLLY layout: role label on its own line; odd-indexed turns are followed by a blank line."""
        turn_ends = (self.sep, self.sep2)
        text = self.system
        for index, (role, message) in enumerate(self.messages):
            if not message:
                text += role + ":\n"
                continue
            text += role + ":\n" + message + turn_ends[index % 2]
            if index % 2 == 1:
                text += "\n\n"
        return text

    def _render_oasst_pythia(self):
        """OASST_PYTHIA layout: role and message are joined directly, each filled turn closed by ``sep``."""
        text = self.system
        for role, message in self.messages:
            text += (role + message + self.sep) if message else role
        return text

    def append_message(self, role, message):
        """Add one ``[role, message]`` turn at the end of the history."""
        turn = [role, message]
        self.messages.append(turn)

    def to_gradio_chatbot(self):
        """Pair up the messages after ``offset`` as ``[first, second]`` rows.

        Even positions open a new row with ``None`` as the second entry; odd
        positions fill in the second entry of the most recent row.
        """
        rows = []
        for position, (_role, text) in enumerate(self.messages[self.offset:]):
            if position % 2:
                rows[-1][-1] = text
            else:
                rows.append([text, None])
        return rows

    def copy(self):
        """Return a new Conversation with a fresh list of message pairs.

        ``skip_next`` is not passed along, so the copy gets its default value.
        """
        cloned_messages = [[role, text] for role, text in self.messages]
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=cloned_messages,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Export selected fields as a plain dict.

        ``sep_style`` and ``skip_next`` are not included; the values are the
        attribute objects themselves, not copies.
        """
        exported = ("system", "roles", "messages", "offset", "sep", "sep2", "conv_id")
        return {field_name: getattr(self, field_name) for field_name in exported}

In [41]:
# Ask PyTorch to release cached, currently unused CUDA memory before the long generation run.
torch.cuda.empty_cache()

In [43]:
from tqdm import tqdm
from os.path import exists

def transform_txt_dialog(x):
    """Format one dialog turn as ``'<id>: <text>'``.

    Args:
        x: mapping with string values under the keys ``'id'`` and ``'text'``.
    """
    return x['id'] + ': ' + x['text']


# Stage names; each must be a key of the `prompts` registry.
stages = ['facts', 'summarization']

# Not used by the cells visible here; kept so later cells that expect it still find it.
results = []

# Checkpoint file: dialogs already processed by earlier runs are stored here.
OUTPUT_PATH = './sess_outputs.json'
output_dialogs = list()
if exists(OUTPUT_PATH):
    with open(OUTPUT_PATH, 'r') as f:
        output_dialogs = json.load(f)

# Resume after the dialogs that are already in the checkpoint.
start_position = len(output_dialogs)
print(start_position)

55


In [None]:
# Upper bound on how many dialogs a single run of this cell processes.
# NOTE(review): presumably 500 dialogs split across 9 runs (500 // 9) — confirm.
DIALOGS_PER_RUN = 55

for dialog in tqdm(dialogs[start_position:start_position + DIALOGS_PER_RUN]):
    # The transcript of a session does not depend on the stage, so render it once.
    session_texts = {
        session: '\n'.join(map(transform_txt_dialog, sess_dialog['dialog']))
        for session, sess_dialog in dialog.items()
    }

    for stage in stages:
        prompt = prompts[stage]['prompt']
        method = prompts[stage]['method']
        for session, sess_dialog in dialog.items():
            prompt_result = method(model=model, tokenizer=tokenizer, prompt=prompt,
                                   dialog=session_texts[session],
                                   speakers=speakers)
            # Store the stage output next to the session's dialog.
            sess_dialog[stage] = prompt_result

    output_dialogs.append(dialog)
    # Checkpoint after every dialog. Write to a temporary file and swap it in,
    # so an interruption mid-write cannot leave a truncated checkpoint that
    # would break resuming.
    tmp_output_path = OUTPUT_PATH + '.tmp'
    with open(tmp_output_path, 'w') as f:
        json.dump(output_dialogs, f)
    os.replace(tmp_output_path, OUTPUT_PATH)

  5%|▌         | 3/55 [09:17<2:46:39, 192.29s/it]

In [15]:
# Spot check: per-speaker facts stored for session_1 of whatever `dialog` is currently bound to.
dialog['session_1']['facts']

{'bot_0': ' * bot\\_0: mentioned enjoying the Archey Center and loving to go there every hour and kind of\n* bot\\_0: mentioned having a favorite singer\n* bot\\_0: asked if bot\\_1 likes to ride bikes\n* bot\\_1: responded that they do not like to ride bikes because it is too dangerous\n* bot\\_0: mentioned getting a new job\n* bot\\_0: mentioned cleaning gutters for a living\n* bot\\_1: mentioned sewing as a hobby to keep the voices at bay',
 'bot_1': ' * bot\\_1 does not like to ride bikes\n* bot\\_1 is a male\n* bot\\_1 works as a waiter\n* bot\\_1 has a hobby of sewing which helps to keep the voices at bay.'}

In [33]:
# Scratch arithmetic: floor division, 500 split 9 ways gives 55 per part.
# NOTE(review): presumably the origin of the 55 used as the slice length in the processing loop — confirm.
500 // 9

55

In [31]:
# Scratch arithmetic: 9 parts of 55 cover 495 items.
55 * 9

495

In [25]:
# Scratch arithmetic: `/` is true division, so the result is a float.
20 / 10

2.0

In [21]:
# Print every session of the current `dialog`: the transcript, the facts
# generated for each speaker, the summary, and a divider line.
for session_name, session in dialog.items():
    print(session_name)
    print('\n'.join(transform_txt_dialog(turn) for turn in session['dialog']))
    for speaker_name in ('bot_0', 'bot_1'):
        print('\nfacts ' + speaker_name + ':')
        print(session['facts'][speaker_name])
    print('\n', session['summarization']['Summarization'])
    print('-' * 10)

session_1
bot_0: Hey there how are you at this?
bot_1: Well, hey yourself! I can handle myself. Are you shy?
bot_0: You like the archey center? I love it
bot_1: Heck, I never even heard of such a thing. Sounds sporty. Is it fun?
bot_0: I love to go every hour an kind of, got a fav singer?
bot_1: Not unless my cat counts as a singer. Who is yours?
bot_0: Do you like to ride bikes?
bot_1: Oh no. Too dangerous. My momma was a nurse and she warned me about that.
bot_0: Really? I just got a new job
bot_1: No kidding. I wait tables. What is it that you do?
bot_0: I'll be cleaning gutters outside
bot_1: Well, that's good honest work right there. What do you do for fun?
bot_0: I like to ride my bike when I have time
bot_1: Oh, right. You mentioned bikes earlier. I just sew. It keeps the voices at bay.

facts bot_0:
 * bot\_0: mentioned enjoying the Archey Center and loving to go there every hour and kind of
* bot\_0: mentioned having a favorite singer
* bot\_0: asked if bot\_1 likes to ride bi