In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset
from models.slm import ScalableLM
from models.config import LlamaCLConfig
import torch

# Identifiers of the tokenizer and dataset used throughout the notebook.
TOKENIZER_NAME = "meta-llama/Llama-2-7b-chat-hf"
DATASET_NAME = "AdiOO7/llama-2-finance"
# NOTE(review): machine-specific absolute path; point this at your own checkpoint.
CHECKPOINT_PATH = "/home/user/bhpeng/SLM-llama/outputs/finance/finance/checkpoint-222/pytorch_model_only_retriever.bin"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
dataset = load_dataset(DATASET_NAME, split='train')
model = ScalableLM(LlamaCLConfig())

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# map_location="cpu" keeps the load working on machines without the device the
# checkpoint was saved from; load_state_dict then copies the values into the
# model's existing parameters.
retriever_state = torch.load(CHECKPOINT_PATH, map_location="cpu")
# strict=False tolerates keys that do not line up between checkpoint and model
# (the file name suggests it holds only the retriever weights; TODO confirm).
# The call stays the last expression so the cell displays the returned report
# of missing/unexpected keys.
model.load_state_dict(retriever_state, strict=False)

In [None]:
import re
# Prompt layouts built around the [INST] / <<SYS>> markers. Each value is a
# format string with named fields:
#   'with_sys'    expects {instruction} (system block) and {input} (user turn)
#   'without_sys' expects only {input}
# Both end in a single trailing space after [/INST].
PROMPT_TEMPLATE = dict(
    with_sys=(
        "[INST] <<SYS>>\n"
        "{instruction}\n"
        "<</SYS>>\n"
        "\n"
        "{input} [/INST] "
    ),
    without_sys="[INST] {input} [/INST] ",
)
def extract_from_finance(example):
    """Split one dataset record into instruction, input and output.

    The record's ``example['text']`` is expected to contain the markers
    ``### Instruction:``, ``### Human:`` and ``### Assistant:`` in that order.

    Args:
        example: mapping with a ``'text'`` string.

    Returns:
        dict with keys ``instruction``, ``input`` and ``output``:
        * ``instruction``: stripped text between ``### Instruction:`` and
          ``### Human:``, or ``None`` when that section is absent.
        * ``input``: stripped text between ``### Human:`` and
          ``### Assistant:``, or ``None`` when that section is absent.
        * ``output``: stripped text after the first ``### Assistant:`` (up to
          the next such marker, if any), or ``""`` when the marker is absent.
    """
    text = example['text']

    # re.DOTALL lets '.' cross newlines; without it, any section that spans
    # more than one line fails to match and is reported as missing.
    m = re.search('### Instruction:(.+?)### Human:', text, re.DOTALL)
    instruction = m.group(1).strip() if m else None

    m = re.search('### Human:(.+?)### Assistant:', text, re.DOTALL)
    user_input = m.group(1).strip() if m else None

    # Guard the index so a record without an assistant turn yields an empty
    # answer instead of raising IndexError.
    pieces = text.split("### Assistant:")
    output = pieces[1].strip() if len(pieces) > 1 else ""

    return dict(
        instruction=instruction,
        input=user_input,
        output=output
    )
def prepare_data(example):
    """Turn one raw dataset record into generation inputs plus its reference answer.

    The record is parsed with ``extract_from_finance``, rendered through
    ``PROMPT_TEMPLATE`` and tokenized with the notebook-level ``tokenizer``.

    Args:
        example: mapping with a ``'text'`` string, as accepted by
            ``extract_from_finance``.

    Returns:
        A 2-tuple ``(inputs, target)``:
        * ``inputs``: dict with ``raw_text`` (un-templated text: instruction and
          input joined by ``" ### "``, or the input alone), ``input_ids`` and
          ``attention_mask`` (tokenizer outputs requested with
          ``return_tensors='pt'``; presumably shaped (1, seq_len), TODO confirm)
          and ``task`` (always ``'finance'``).
        * ``target``: the extracted answer with ``"</s>"`` appended.

    Raises:
        ValueError: if the record has no human/input section to prompt with.
    """
    fields = extract_from_finance(example)
    instruction = fields['instruction']
    user_input = fields['input']

    if user_input is None:
        raise ValueError(
            "example has no '### Human:' ... '### Assistant:' section to build a prompt from"
        )

    # The extractor always returns an 'instruction' key, so key presence says
    # nothing. Branch on the value: a missing or empty instruction selects the
    # template without a system block, so None is never concatenated.
    if instruction:
        raw_text = instruction + " ### " + user_input
        prompt = PROMPT_TEMPLATE['with_sys'].format_map(fields)
    else:
        raw_text = user_input
        prompt = PROMPT_TEMPLATE['without_sys'].format_map(fields)

    encoded = tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        return_tensors='pt'
    )
    task = 'finance'
    return dict(
        raw_text=raw_text,
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        task=task
    ), fields['output'] + "</s>"

In [None]:
import random
# Draw one record uniformly from the whole split. A fixed upper bound of 1000
# would fail on a smaller split and never reach the rows beyond it.
inps, outs = prepare_data(dataset[random.randrange(len(dataset))])
# Inference only: no gradients are needed while generating.
with torch.no_grad():
    out = model.generate(**inps)
print(outs)                      # reference answer from the dataset
print(tokenizer.decode(out[0]))  # text decoded from the model's generation