In [None]:
%matplotlib inline
import os
import pandas as pd
from llama_4bit_wrapper import import_llama, lora_model_zeros_and_scales_to_half
from peft import LoraConfig, get_peft_model
from llama_memorizing_transformers.memory_collection import CosineKnnMemoryCollection
from llama_memorizing_transformers.context_choice import ContextChoiceLinear
from llama_memorizing_transformers.model_wrapper import replace_llama_layer_with_memory
from llama_memorizing_transformers.document_trainer import MemorizingLlamaDocumentTrainer
from torch.optim import Adam
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from itertools import chain
from torch.nn.functional import softmax
import gc
import openai

In [None]:
# Read the OpenAI key from a local file kept out of the notebook. Surrounding
# whitespace (e.g. the trailing newline most editors add) is stripped: otherwise
# it would be sent inside the HTTP Authorization header, where it is invalid,
# and every API call made by the fact generator below would fail.
with open("09_train_longvicuna__pretrain_api_key.txt", "r", encoding="utf-8") as src:
    openai.api_key = src.read().strip()

In [None]:
# --- Reproducibility ---------------------------------------------------------
RANDOM_STATE = 42

# --- Dataset -----------------------------------------------------------------
DATASET_PATH = "long-vicuna-set-lessgpt4all-vicuna13b-processed"

# --- Training / inference windowing (token counts) ---------------------------
CONTEXT_LENGTH = 512   # tokens per forward pass when feeding a long prompt
CONTEXT_STEP = 256     # stride between consecutive windows (windows overlap)
PRETRAIN_LENGTH = 1024

# --- Model -------------------------------------------------------------------
COSINE_KNN_MAX_TEMPORARY_BUFFER_SIZE = 1024  # passed to CosineKnnMemoryCollection
REPLACE_LAYER = 21                           # passed to replace_llama_layer_with_memory
BASE_MODEL = "../vicuna-13b-GPTQ-4bit-128g"
BASE_MODEL_WEIGHTS = f"{BASE_MODEL}/vicuna-13b-4bit-128g.safetensors"

# --- Optimisation (not referenced by the cells below) ------------------------
USE_FP16 = True
LR_PRETRAIN = 3e-4

In [None]:
# import_llama returns a 9-tuple of helpers; only positions 2, 4 and 8 are bound
# to names here, the rest are discarded. 4-bit autograd uses the Triton kernels;
# the CUDA kernels, xformers and flash attention are all left disabled.
(
    _, _,
    load_llama_model_4bit_low_ram,
    _,
    model_to_half,
    _, _, _,
    AMPWrapper,
) = import_llama(
    autograd_4bit_triton=True,
    autograd_4bit_cuda=False,
    use_xformers=False,
    use_flash_attention=False,
)

## Generate memory-testing facts

In [None]:
def generate_memoryset_item(model="gpt-3.5-turbo"):
    """Ask the OpenAI chat API for one synthetic (fact, question) pair.

    Two chat completions are issued. The first asks for a single fact about a
    named, fictional person. The second replays that exchange and asks for a
    short question of the form "What do we know about <name>" that refers back
    to the same person.

    Args:
        model: Chat model name passed to ``openai.ChatCompletion.create``.
            Defaults to the model this notebook was written against.

    Returns:
        tuple[str, str]: ``(fact, question)``, each with surrounding whitespace
        stripped. This matters because the question length is filtered later
        and the question is pasted verbatim into the evaluation prompt.

    Raises:
        Any exception raised by ``openai.ChatCompletion.create`` (network,
        authentication, rate limits); nothing is caught here.
    """
    initial_message = (
        "I am generating a dataset to evaluate other AI system memory. To do it I need your help.\n"
        + "Write some fact (one) about one fake person. You should mention this person name inside this fact. You shouldn't mention it's fake person."
    )
    question_message = (
        "Now write some question like 'What do we know about John Smith'.\n"
        + "Just replace John Smith with the name from the previously generated fact. Do not give any more hints."
    )

    # First turn: the API's default sampling temperature.
    response = openai.ChatCompletion.create(
        messages=[
            {"role": "user", "content": initial_message}
        ],
        model=model,
    )
    fact = response.choices[0].message.content.strip()

    # Second turn: replay the first exchange so the question names the same person.
    response = openai.ChatCompletion.create(
        messages=[
            {"role": "user", "content": initial_message},
            {"role": "assistant", "content": fact},
            {"role": "user", "content": question_message}
        ],
        model=model,
        temperature=0.7,
    )
    question = response.choices[0].message.content.strip()

    return fact, question

In [None]:
def generate_memoryset_facts(max_count, item_generator=None):
    """Generate up to ``max_count`` (fact, question) pairs as a DataFrame.

    Each attempt calls ``item_generator`` once. An attempt that raises is skipped
    (best effort), so the result may have fewer than ``max_count`` rows. The
    number of skipped attempts and the last error are printed, so a systematic
    failure (bad API key, no network) is visible instead of silently producing
    an empty frame. Only ``Exception`` subclasses are caught, which means
    interrupting the kernel (KeyboardInterrupt) still stops the loop.

    Args:
        max_count: Number of generation attempts.
        item_generator: Zero-argument callable returning ``(fact, question)``.
            Defaults to ``generate_memoryset_item``.

    Returns:
        pd.DataFrame with columns ``["fact", "question"]``. The columns are
        present even when every attempt failed, so downstream column access
        does not raise KeyError.
    """
    if item_generator is None:
        item_generator = generate_memoryset_item

    records = []
    failures = 0
    last_error = None
    for _ in tqdm(range(max_count)):
        try:
            fact, question = item_generator()
        except Exception as exc:  # skip this attempt, but keep Ctrl-C working
            failures += 1
            last_error = exc
            continue
        records.append({"fact": fact, "question": question})

    if failures:
        print(f"{failures} of {max_count} generation attempts failed; last error: {last_error!r}")
    return pd.DataFrame(records, columns=["fact", "question"])

In [None]:
# Generate (or reuse the cached) memory-testing facts. Generation calls the
# OpenAI API twice per attempt, so the result is cached to disk and re-runs
# load it instead.
MEMORYSET_FACTS_PATH = "09_train_longvicuna__pretrain_memoryset_facts.pkl"
N_MEMORYSET_FACTS = 100    # generation attempts (failed attempts yield no row)
MAX_QUESTION_LENGTH = 60   # keep only short questions (presumably to drop ones adding extra hints)

if not os.path.exists(MEMORYSET_FACTS_PATH):
    df_memoryset_facts = generate_memoryset_facts(N_MEMORYSET_FACTS)
    if "question" not in df_memoryset_facts.columns:
        raise RuntimeError(
            "No memory-testing facts were generated; check the OpenAI API key and connectivity."
        )
    df_memoryset_facts = df_memoryset_facts.loc[
        df_memoryset_facts["question"].str.len() <= MAX_QUESTION_LENGTH
    ]
    if df_memoryset_facts.empty:
        # Never write an empty cache: it would be silently reused on every later run.
        raise RuntimeError(
            f"No usable facts (questions of at most {MAX_QUESTION_LENGTH} characters); nothing cached."
        )
    df_memoryset_facts.to_pickle(MEMORYSET_FACTS_PATH)
else:
    # NOTE: unpickling can execute arbitrary code; only reuse a cache this notebook wrote itself.
    df_memoryset_facts = pd.read_pickle(MEMORYSET_FACTS_PATH)
df_memoryset_facts.head()

## Load model

In [None]:
# Load the 4-bit quantised Vicuna-13B checkpoint (GPTQ, group size 128, per the
# paths and loader name) together with its tokenizer.
model, tokenizer = load_llama_model_4bit_low_ram(
    config_path=BASE_MODEL,
    model_path=BASE_MODEL_WEIGHTS,
    groupsize=128,
    is_v1_model=False,
)
# Use token id 0 for padding (presumably because the LLaMA tokenizer defines no
# pad token — TODO confirm).
tokenizer.pad_token_id = 0

# Components of the memory-augmented layer. ContextChoiceLinear is sized from the
# model's attention-head count and hidden size. The k-NN memory starts with
# remember_until_position=0; generate_response sets it per prompt.
context_choice = ContextChoiceLinear(model.config.num_attention_heads,
                                     model.config.hidden_size)
memory = CosineKnnMemoryCollection(COSINE_KNN_MAX_TEMPORARY_BUFFER_SIZE,
                                   remember_until_position=0)
# Replace layer REPLACE_LAYER (presumably a decoder-layer index) with a version
# wired to `context_choice` and `memory`. The same `memory` object is later
# passed to generate_response, which resets it; this presumably works because
# the layer keeps a reference to it — confirm in model_wrapper.
model.model = replace_llama_layer_with_memory(
    model.model,
    REPLACE_LAYER,
    context_choice,
    memory,
)
# LoRA adapters (rank 8, alpha 16, no dropout, no bias training) on the
# attention query and value projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
lora_model = get_peft_model(model, lora_config)
# Project helper; by its name it presumably casts the quantisation zeros/scales
# to fp16 — confirm in llama_4bit_wrapper.
lora_model = lora_model_zeros_and_scales_to_half(lora_model)
# Disable the KV cache (generate_response also passes use_cache=False).
lora_model.config.use_cache = False
# AMPWrapper patches forward() and generate(); presumably so they run under
# mixed-precision autocast — TODO confirm in llama_4bit_wrapper.
wrapper = AMPWrapper(lora_model)
wrapper.apply_forward()
wrapper.apply_generate()

In [None]:
# Move the model to CPU RAM and release GPU memory before the checkpoint is
# loaded in the next cell (presumably to keep peak GPU usage low while loading).
lora_model.cpu()
gc.collect()               # collect unreachable objects so any tensors they hold can be freed
torch.cuda.empty_cache()   # release cached, currently unused CUDA allocator blocks

In [None]:
# Restore the pretrained weights. The checkpoint tensors are materialised on the
# CPU first, and only then is the whole model moved back onto the GPU.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# checkpoints (consider weights_only=True on torch >= 1.13).
lora_model.load_state_dict(
    torch.load(
        "long-vicuna--pretrain--state-dict--checkpoint-500.pth",
        map_location="cpu",
    )
)
lora_model.cuda()

In [None]:
def generate_response(lora_model, memory, input_ids, context_length=None,
                      context_step=None, max_new_tokens=100):
    """Feed a long prompt through the memory-augmented model, then sample a reply.

    The prompt is processed without gradients in overlapping windows of
    ``context_length`` tokens, advancing ``context_step`` tokens each time. This
    presumably lets the memory layer ingest the whole prompt. Generation is then
    conditioned only on the last ``context_step`` prompt tokens, so anything
    earlier has to be recalled through ``memory``.

    Args:
        lora_model: Model supporting ``eval()``, ``.device``,
            ``__call__(input_ids=...)`` and ``generate(...)``.
        memory: Memory collection used by the replaced layer. It is reset, and
            its ``remember_until_position`` is set to the prompt length.
        input_ids: Token ids indexed as ``[batch, position]``. Only row 0 of the
            generated output is returned, so a batch of 1 is assumed.
        context_length: Window size; defaults to ``CONTEXT_LENGTH``.
        context_step: Window stride, and also the length of the generation
            prompt; defaults to ``CONTEXT_STEP``.
        max_new_tokens: Generation budget (default 100, as before).

    Returns:
        1-D tensor holding only the newly generated token ids. The prompt is
        sliced off by its length, so the result stays correct when generation
        stops early at EOS (taking the last ``max_new_tokens`` ids would then
        include prompt tokens).

    Raises:
        ValueError: If the prompt is empty, or if the window size or stride is
            not positive (a zero stride would loop forever).
    """
    if context_length is None:
        context_length = CONTEXT_LENGTH
    if context_step is None:
        context_step = CONTEXT_STEP
    if context_length <= 0 or context_step <= 0:
        raise ValueError("context_length and context_step must be positive")
    seq_len = input_ids.shape[1]
    if seq_len == 0:
        raise ValueError("input_ids must contain at least one token")

    lora_model.eval()
    memory.reset()
    memory.remember_until_position = seq_len
    with torch.no_grad():
        # Windows [0, L), [S, S + L), [2S, 2S + L), ... until the prompt is exhausted.
        for start in range(0, seq_len, context_step):
            window = input_ids[:, start:start + context_length].to(lora_model.device)
            lora_model(input_ids=window)

    prompt_ids = input_ids[:, -context_step:].to(lora_model.device)
    generated = lora_model.generate(
        inputs=prompt_ids,
        do_sample=True,
        use_cache=False,
        repetition_penalty=1.1,
        max_new_tokens=max_new_tokens,
        temperature=0.9,
        top_p=0.95,
        top_k=40,
        return_dict_in_generate=True,
        output_attentions=False,
        output_hidden_states=False,
        output_scores=False,
    )
    # For decoder-only models, `sequences` is prompt + continuation.
    return generated.sequences[0][prompt_ids.shape[1]:]

In [None]:
# Put every fact into one prompt, then ask the question paired with the first row.
question = df_memoryset_facts["question"].iloc[0]
text = (
    "<msg_prompter> "
    + "\n\n".join(df_memoryset_facts["fact"])
    + "\n\n<msg_prompter> Now answer the following question: "
    + question
    + "\n<msg_assistant> "
)
input_ids = torch.tensor([tokenizer(text)["input_ids"]], dtype=torch.long)

generated = generate_response(lora_model, memory, input_ids)
print(question, tokenizer.decode(generated), sep="\n")

In [None]:
# Show the facts mentioning this (hard-coded) name. The name presumably matches
# the question printed by the previous cell, so the answer can be judged by eye.
print(
    *df_memoryset_facts.loc[
        df_memoryset_facts["fact"].str.lower().str.contains("samantha"),
        "fact",
    ],
    sep="\n",
)

In [None]:
# Reshuffle the facts so that (most likely) a different question comes first and
# the facts appear in a new order in the prompt. Uses the configured seed instead
# of a duplicated literal. NOTE: this reassigns df_memoryset_facts, so re-running
# the cell reshuffles the already-shuffled frame and picks a different question.
# Run the notebook top-to-bottom to reproduce results.
df_memoryset_facts = df_memoryset_facts.sample(len(df_memoryset_facts), random_state=RANDOM_STATE)

question = df_memoryset_facts["question"].iloc[0]
text = (
    "<msg_prompter> "
    + "\n\n".join(df_memoryset_facts["fact"])
    + "\n\n<msg_prompter> Now answer the following question: "
    + question
    + "\n<msg_assistant> "
)
input_ids = torch.tensor([tokenizer(text)["input_ids"]], dtype=torch.long)

generated = generate_response(lora_model, memory, input_ids)
print(question)
print(tokenizer.decode(generated))

In [None]:
# Show the facts mentioning this (hard-coded) name. The name presumably matches
# the question printed by the previous cell, so the answer can be judged by eye.
print(
    *df_memoryset_facts.loc[
        df_memoryset_facts["fact"].str.lower().str.contains("michael"),
        "fact",
    ],
    sep="\n",
)

In [None]:
# Reshuffle the facts so that (most likely) a different question comes first and
# the facts appear in a new order in the prompt. Uses the configured seed instead
# of a duplicated literal. NOTE: this reassigns df_memoryset_facts, so re-running
# the cell reshuffles the already-shuffled frame and picks a different question.
# Run the notebook top-to-bottom to reproduce results.
df_memoryset_facts = df_memoryset_facts.sample(len(df_memoryset_facts), random_state=RANDOM_STATE)

question = df_memoryset_facts["question"].iloc[0]
text = (
    "<msg_prompter> "
    + "\n\n".join(df_memoryset_facts["fact"])
    + "\n\n<msg_prompter> Now answer the following question: "
    + question
    + "\n<msg_assistant> "
)
input_ids = torch.tensor([tokenizer(text)["input_ids"]], dtype=torch.long)

generated = generate_response(lora_model, memory, input_ids)
print(question)
print(tokenizer.decode(generated))

In [None]:
# Show the facts mentioning this (hard-coded) name. The name presumably matches
# the question printed by the previous cell, so the answer can be judged by eye.
print(
    *df_memoryset_facts.loc[
        df_memoryset_facts["fact"].str.lower().str.contains("sophie"),
        "fact",
    ],
    sep="\n",
)

In [None]:
# Reshuffle the facts so that (most likely) a different question comes first and
# the facts appear in a new order in the prompt. Uses the configured seed instead
# of a duplicated literal. NOTE: this reassigns df_memoryset_facts, so re-running
# the cell reshuffles the already-shuffled frame and picks a different question.
# Run the notebook top-to-bottom to reproduce results.
df_memoryset_facts = df_memoryset_facts.sample(len(df_memoryset_facts), random_state=RANDOM_STATE)

question = df_memoryset_facts["question"].iloc[0]
text = (
    "<msg_prompter> "
    + "\n\n".join(df_memoryset_facts["fact"])
    + "\n\n<msg_prompter> Now answer the following question: "
    + question
    + "\n<msg_assistant> "
)
input_ids = torch.tensor([tokenizer(text)["input_ids"]], dtype=torch.long)

generated = generate_response(lora_model, memory, input_ids)
print(question)
print(tokenizer.decode(generated))

In [None]:
# Show the facts mentioning this (hard-coded) name. The name presumably matches
# the question printed by the previous cell, so the answer can be judged by eye.
print(
    *df_memoryset_facts.loc[
        df_memoryset_facts["fact"].str.lower().str.contains("john"),
        "fact",
    ],
    sep="\n",
)