In [5]:
# !pip install -U transformers
# !pip install -U SentencePiece
# !pip install accelerate

Collecting accelerate
  Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.3/215.3 kB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.18.0
[0m

In [1]:
from string import Template
from sentencepiece import SentencePieceProcessor
from logging import getLogger
from typing import List
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


logger = getLogger()


class Tokenizer:
    """Thin wrapper around a SentencePiece model exposing encode/decode.

    Attributes set by ``__init__``:
        sp_model: the loaded ``SentencePieceProcessor``.
        n_words: vocabulary size reported by the model.
        bos_id, eos_id, pad_id: special-token ids reported by the model.
    """

    def __init__(self, model_path: str):
        """Load the SentencePiece model stored at ``model_path``.

        Raises:
            AssertionError: if ``model_path`` is not an existing file, or if
                the model's vocab size and piece count disagree.
        """
        assert os.path.isfile(model_path), model_path
        processor = SentencePieceProcessor(model_file=model_path)
        self.sp_model = processor
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # Cache the vocabulary size and the special-token ids.
        self.n_words: int = processor.vocab_size()
        self.bos_id: int = processor.bos_id()
        self.eos_id: int = processor.eos_id()
        self.pad_id: int = processor.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert processor.vocab_size() == processor.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Tokenize ``s`` into ids, optionally wrapped with BOS and/or EOS.

        Args:
            s: text to tokenize; must be exactly of type ``str``.
            bos: when true, prepend ``self.bos_id``.
            eos: when true, append ``self.eos_id``.

        Returns:
            The list of token ids.
        """
        assert type(s) is str
        token_ids = self.sp_model.encode(s)
        if bos:
            token_ids = [self.bos_id, *token_ids]
        if eos:
            token_ids = [*token_ids, self.eos_id]
        return token_ids

    def decode(self, t: List[int]) -> str:
        """Turn a list of token ids back into text via the underlying model."""
        text = self.sp_model.decode(t)
        return text

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# SentencePiece-backed tokenizer, loaded from a vocabulary file next to the notebook.
tokenizer = Tokenizer(model_path='./LLaMA_tokenizer.model')

# Causal LM checkpoint. device_map='auto' leaves weight placement to the
# library (presumably the reason `accelerate` is installed above -- confirm).
model = AutoModelForCausalLM.from_pretrained(
    "eachadea/vicuna-13b-1.1",
    device_map='auto',
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def gen_answer(model, tokenizer, prompt, subs):
    """Generate the assistant's reply to a single-turn USER/ASSISTANT prompt.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: object with ``encode(text, bos, eos)`` and ``decode(ids)``.
        prompt: ``string.Template`` for the user message.
        subs: mapping of placeholder names to values for ``prompt``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        excluded).

    Raises:
        KeyError: if ``subs`` lacks a placeholder used by ``prompt``.

    The prompt is removed by slicing the generated *token ids* at the prompt
    length rather than slicing the decoded string by ``len(input_text)``.
    The string-based slice silently cuts at the wrong offset whenever
    encode -> decode does not reproduce the prompt character-for-character.
    Because the continuation is decoded on its own, a leading space that the
    old string slice kept may no longer be present.
    """
    user_message = prompt.substitute(**subs)

    conv = Conversation(
                system="A chat between a curious user and an artificial intelligence assistant. "
                "The assistant gives helpful, detailed, and polite answers to the user's questions.",
                roles=("USER", "ASSISTANT"),
                messages=[],
                offset=0,
                sep_style=SeparatorStyle.TWO,
                sep=" ",
                sep2="</s>",
            )

    conv.append_message("USER", user_message)
    # A None message renders as a bare "ASSISTANT:" cue for the model to continue.
    conv.append_message("ASSISTANT", None)

    input_text = conv.get_prompt()

    # NOTE(review): the prompt is encoded without a BOS token (bos=False);
    # confirm this matches what the checkpoint expects.
    input_ids = torch.tensor([tokenizer.encode(input_text, False, False)]).to(model.device)
    prompt_len = input_ids.shape[1]

    generated_indices = model.generate(input_ids=input_ids, max_new_tokens=512,
                                       temperature=.9, do_sample=True).detach().cpu()

    # For a causal LM the returned sequence starts with the prompt tokens;
    # keep only what was generated after them.
    new_token_ids = generated_indices[0, prompt_len:].tolist()

    return tokenizer.decode(new_token_ids)

In [4]:
def gen_speakers_facts(**kwargs):
    """Run the facts prompt once per speaker and collect the answers.

    Keyword Args:
        speakers: iterable of speaker identifiers.
        dialog: dialog text substituted into the prompt as ``$dialog``.
        prompt: ``string.Template`` with ``$speaker`` and ``$dialog``.
        model: optional; model passed to ``gen_answer``.
        tokenizer: optional; tokenizer passed to ``gen_answer``.

    Returns:
        dict mapping each speaker to the text returned by ``gen_answer``.

    Raises:
        KeyError: if ``speakers``, ``dialog`` or ``prompt`` is missing.

    ``model``/``tokenizer`` supplied by the caller used to be silently ignored
    in favour of the notebook globals. They are now honoured; the globals
    remain the fallback so existing calls keep working.
    """
    speakers = kwargs['speakers']
    dialog = kwargs['dialog']
    prompt = kwargs['prompt']
    lm = kwargs['model'] if 'model' in kwargs else model
    tok = kwargs['tokenizer'] if 'tokenizer' in kwargs else tokenizer

    return {
        speaker: gen_answer(lm, tok, prompt, {'speaker': speaker, 'dialog': dialog})
        for speaker in speakers
    }

In [5]:
def gen_summarization(**kwargs):
    """Run the summarization prompt on one dialog.

    Keyword Args:
        dialog: dialog text substituted into the prompt as ``$dialog``.
        prompt: ``string.Template`` with a ``$dialog`` placeholder.
        model: optional; model passed to ``gen_answer``.
        tokenizer: optional; tokenizer passed to ``gen_answer``.

    Returns:
        ``{'Summarization': <text returned by gen_answer>}``.

    Raises:
        KeyError: if ``dialog`` or ``prompt`` is missing.

    ``model``/``tokenizer`` supplied by the caller used to be silently ignored
    in favour of the notebook globals. They are now honoured; the globals
    remain the fallback so existing calls keep working.
    """
    dialog = kwargs['dialog']
    prompt = kwargs['prompt']
    lm = kwargs['model'] if 'model' in kwargs else model
    tok = kwargs['tokenizer'] if 'tokenizer' in kwargs else tokenizer

    return {'Summarization': gen_answer(lm, tok, prompt, {'dialog': dialog})}

In [6]:
# Registry of processing stages: stage name -> prompt template plus the
# function that runs that template against a dialog.
prompts = {
    stage: {'prompt': Template(text), 'method': method}
    for stage, text, method in (
        (
            'facts',
            """$dialog\nPlease, write only relevant facts about $speaker in number list.""",
            gen_speakers_facts,
        ),
        (
            'summarization',
            '$dialog\nPlease, write 3 the most important and relevant bulletpoints that summarize conversation.',
            gen_summarization,
        ),
    )
}

In [7]:
import json

# Input dialogs. JSON text is read explicitly as UTF-8 so the result does not
# depend on the platform's default encoding.
with open('./sample_data_summ_sess.json', 'r', encoding='utf-8') as f:
    dialogs = json.load(f)

# Speaker identifiers handed to the per-speaker stage.
speakers = ['bot_0', 'bot_1']

In [8]:
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any


class SeparatorStyle(Enum):
    """Prompt layouts that ``Conversation.get_prompt`` knows how to render."""

    # Explicit values; identical to what auto() assigns (1, 2, 3, ...).
    SINGLE = 1
    TWO = 2
    DOLLY = 3
    OASST_PYTHIA = 4


@dataclasses.dataclass
class Conversation:
    """Chat transcript plus the formatting rules that render it as a prompt."""

    # Preamble text placed at the start of every rendered prompt.
    system: str
    # Role labels; stored and copied but not read by the rendering methods.
    roles: List[str]
    # Transcript as [role, message] pairs; a falsy message renders as a bare role cue.
    messages: List[List[str]]
    # Index of the first message exposed by to_gradio_chatbot().
    offset: int
    # Layout used by get_prompt().
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Primary separator.
    sep: str = "###"
    # Secondary separator, alternated with `sep` by the TWO and DOLLY layouts.
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the conversation as one prompt string for ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not a handled ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            return self._render_single()
        if self.sep_style == SeparatorStyle.TWO:
            return self._render_two()
        if self.sep_style == SeparatorStyle.DOLLY:
            return self._render_dolly()
        if self.sep_style == SeparatorStyle.OASST_PYTHIA:
            return self._render_oasst_pythia()
        raise ValueError(f"Invalid style: {self.sep_style}")

    def _render_single(self):
        """SINGLE layout: every turn is prefixed with ``sep`` and a space."""
        prompt = self.system
        for role, message in self.messages:
            prompt += self.sep + " " + role + (": " + message if message else ":")
        return prompt

    def _render_two(self):
        """TWO layout: non-empty turns end with ``sep``/``sep2`` alternately."""
        separators = (self.sep, self.sep2)
        prompt = self.system + separators[0]
        for index, (role, message) in enumerate(self.messages):
            if not message:
                prompt += role + ":"
                continue
            prompt += role + ": " + message + separators[index % 2]
        return prompt

    def _render_dolly(self):
        """DOLLY layout: role header on its own line, blank line after odd turns."""
        separators = (self.sep, self.sep2)
        prompt = self.system
        for index, (role, message) in enumerate(self.messages):
            prompt += role + ":\n"
            if message:
                prompt += message + separators[index % 2]
                if index % 2 == 1:
                    prompt += "\n\n"
        return prompt

    def _render_oasst_pythia(self):
        """OASST_PYTHIA layout: role and message joined directly, then ``sep``."""
        prompt = self.system
        for role, message in self.messages:
            prompt += (role + message + self.sep) if message else role
        return prompt

    def append_message(self, role, message):
        """Append one ``[role, message]`` pair to the transcript."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair up messages from ``offset`` onward as ``[first, second]`` lists.

        Even-positioned messages open a new pair (second slot ``None``);
        odd-positioned messages fill the second slot of the latest pair.
        """
        pairs = []
        for position, (_role, text) in enumerate(self.messages[self.offset :]):
            if position % 2:
                pairs[-1][-1] = text
            else:
                pairs.append([text, None])
        return pairs

    def copy(self):
        """Return a new ``Conversation`` with a freshly built message list.

        ``skip_next`` is not passed on, so the copy gets its default value.
        """
        cloned_messages = [[role, text] for role, text in self.messages]
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=cloned_messages,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return selected fields as a plain dict.

        ``sep_style`` and ``skip_next`` are not included.
        """
        exported = ("system", "roles", "messages", "offset", "sep", "sep2", "conv_id")
        return {field: getattr(self, field) for field in exported}

In [9]:
#Please, write 3 the most important and relevant bulletpoints that summarize conversation.

In [10]:
from tqdm import tqdm
from os.path import exists

def transform_txt_dialog(x):
    """Render one utterance record as ``"<id>: <text>"``.

    Args:
        x: mapping with string values under ``'id'`` and ``'text'``.
    """
    return x['id'] + ': ' + x['text']


# Stages (keys of `prompts`) to run for every session.
stages = ['summarization']

results = []

# Checkpoint file: processed dialogs are stored here so that an interrupted
# run can resume from where it stopped.
OUTPUT_PATH = './sess_outputs.json'
output_dialogs = list()
if exists(OUTPUT_PATH):
    with open(OUTPUT_PATH, 'r', encoding='utf-8') as f:
        output_dialogs = json.load(f)

# Number of dialogs already processed; used as the resume offset into `dialogs`.
start_position = len(output_dialogs)
print(start_position)

0


In [11]:
# Upper bound on how many dialogs are processed in one run of this cell.
MAX_DIALOGS_PER_RUN = 57

for dialog in tqdm(dialogs[start_position:start_position + MAX_DIALOGS_PER_RUN]):
    for stage in stages:
        prompt = prompts[stage]['prompt']
        method = prompts[stage]['method']
        for session, sess_dialog in dialog.items():
            # Flatten the session's utterances into "<id>: <text>" lines.
            dialog_text = '\n'.join(map(transform_txt_dialog, sess_dialog['dialog']))
            prompt_result = method(model=model, tokenizer=tokenizer, prompt=prompt,
                                   dialog=dialog_text,
                                   speakers=speakers)
            # The result is stored on the session itself, under the stage name.
            sess_dialog[stage] = prompt_result
    output_dialogs.append(dialog)

    # Checkpoint after every dialog. Write to a temporary file and swap it in,
    # so an interruption mid-write cannot leave a truncated OUTPUT_PATH that
    # would break the resume logic on the next run.
    tmp_output_path = OUTPUT_PATH + '.tmp'
    with open(tmp_output_path, 'w') as f:
        json.dump(output_dialogs, f)
    os.replace(tmp_output_path, OUTPUT_PATH)

100%|██████████| 57/57 [1:29:28<00:00, 94.19s/it] 
