In [1]:
import os
import pandas as pd
from llama_4bit_wrapper import import_llama, lora_model_zeros_and_scales_to_half
from peft import LoraConfig, get_peft_model
from llama_memorizing_transformers.memory_collection import CosineKnnMemoryCollection
from llama_memorizing_transformers.context_choice import ContextChoiceLinear
from llama_memorizing_transformers.model_wrapper import replace_llama_layer_with_memory
from llama_memorizing_transformers.document_trainer import MemorizingLlamaDocumentTrainer
from torch.optim import Adam
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Single seed for every stochastic step in this notebook.
RANDOM_STATE = 42
# Seed the global NumPy and PyTorch generators as well: without this only the
# explicit `random_state=` pandas calls are reproducible, while randomly
# initialised weights differ on every run.
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# Dataset
# Directory holding texts.pkl and the indices-*.pkl session tables.
DATASET_PATH = "long-vicuna-set-lessgpt4all-vicuna13b-processed"

# Training procedure
# Tokens per training chunk; also the exclusive lower bound on session length
# when selecting pretraining sessions.
CONTEXT_LENGTH = 512
# Token step handed to the document trainer alongside the chunk size.
CONTEXT_STEP = 256
# Inclusive upper bound on session length for pretraining sessions.
PRETRAIN_LENGTH = 2048
# Number of sessions sampled for the pretraining pass.
PRETRAIN_DOCUMENTS = 2048

# Model
# First positional argument of CosineKnnMemoryCollection.
COSINE_KNN_MAX_TEMPORARY_BUFFER_SIZE = 1024
# Index of the LLaMA layer that is replaced with the memory-augmented layer.
REPLACE_LAYER = 21
# Directory of the quantized base model and the path of its weights file.
BASE_MODEL = "../vicuna-13b-GPTQ-4bit-128g"
BASE_MODEL_WEIGHTS = "../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors"

# Passed to the document trainer as its `float16` flag.
USE_FP16 = True
# Learning rate of the Adam optimizer used for pretraining.
LR_PRETRAIN = 3e-4

In [3]:
# Backend selection for the 4-bit LLaMA wrapper: only the Triton autograd
# implementation is enabled.
LLAMA_BACKEND_OPTIONS = {
    "use_flash_attention": False,
    "use_xformers": False,
    "autograd_4bit_cuda": False,
    "autograd_4bit_triton": True,
}

# import_llama returns nine objects; only the model loader (position 2) and
# the gradient-checkpointing helper (position 6) are used in this notebook.
(
    _, _,
    load_llama_model_4bit_low_ram,
    _, _, _,
    apply_gradient_checkpointing,
    _, _,
) = import_llama(**LLAMA_BACKEND_OPTIONS)

Using Triton implementation.


## Data reading

In [4]:
# Load the table of pre-tokenized messages and preview its first rows.
# NOTE: unpickling can execute arbitrary code - only load trusted dataset files.
texts_path = os.path.join(DATASET_PATH, "texts.pkl")
df_texts = pd.read_pickle(texts_path)
df_texts.head()

Unnamed: 0,processed_text,input_ids,length
0,<msg_prompter> Can you write a short introduct...,"[529, 7645, 29918, 14032, 29886, 357, 29958, 1...",50
1,"<msg_assistant> ""Monopsony"" refers to a market...","[529, 7645, 29918, 465, 22137, 29958, 376, 718...",351
2,<msg_prompter> Now explain it to a dog,"[529, 7645, 29918, 14032, 29886, 357, 29958, 2...",13
3,<msg_assistant> Monopsony is a market structur...,"[529, 7645, 29918, 465, 22137, 29958, 2598, 45...",238
4,<msg_prompter> How can one fight back when a m...,"[529, 7645, 29918, 14032, 29886, 357, 29958, 1...",22


In [5]:
# Load the training sessions: each row lists the `df_texts` row labels that
# make up one session, together with its source and total length.
# NOTE: unpickling can execute arbitrary code - only load trusted dataset files.
train_indices_path = os.path.join(DATASET_PATH, "indices-train.pkl")
df_indices_train = pd.read_pickle(train_indices_path)
df_indices_train.head()

Unnamed: 0,indices,source,session_length
0,"[0, 1, 2]",openassistant,414
1,"[0, 3, 4]",openassistant,310
2,"[0, 5, 6, 7]",openassistant,426
3,"[0, 5, 6, 8]",openassistant,595
4,"[0, 5, 6, 9]",openassistant,334


In [6]:
# Load the validation sessions (same layout as the training session table).
# NOTE: unpickling can execute arbitrary code - only load trusted dataset files.
validation_indices_path = os.path.join(DATASET_PATH, "indices-validation.pkl")
df_indices_validation = pd.read_pickle(validation_indices_path)
df_indices_validation.head()

Unnamed: 0,indices,source,session_length
0,"[82483, 82484]",openassistant,302
1,"[82483, 82485]",openassistant,218
2,"[82483, 82486]",openassistant,79
3,"[82487, 82488]",openassistant,561
4,"[82487, 82489, 82490, 82491]",openassistant,546


## Model preparation

In [7]:
# Load the 4-bit quantized base model and its tokenizer from the locations
# declared in the configuration cell, so changing BASE_MODEL /
# BASE_MODEL_WEIGHTS there actually takes effect here.
model, tokenizer = load_llama_model_4bit_low_ram(
    # os.path.join(..., "") keeps the trailing separator the loader was
    # previously given.
    config_path=os.path.join(BASE_MODEL, ""),
    model_path=BASE_MODEL_WEIGHTS,
    # presumably matches the "128g" group size in the checkpoint name - confirm
    groupsize=128,
    is_v1_model=False,
)
# Use token id 0 for padding.
tokenizer.pad_token_id = 0

Loading Model ...


The safetensors archive passed at ../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors does not contain metadata. Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.


Loaded the model in 3.45 seconds.


In [8]:
# Context-choice module sized from the base model's attention head count and
# hidden size.
llama_config = model.config
context_choice = ContextChoiceLinear(
    llama_config.num_attention_heads,
    llama_config.hidden_size,
)

In [9]:
# Memory collection used by the memory-augmented layer and by the document
# trainer.
# NOTE(review): the meaning of remember_until_position=0 is defined by
# CosineKnnMemoryCollection - confirm against its implementation.
memory = CosineKnnMemoryCollection(
    COSINE_KNN_MAX_TEMPORARY_BUFFER_SIZE,
    remember_until_position=0,
)

In [10]:
# Swap layer REPLACE_LAYER of the base model for its memory-augmented variant.
# NOTE(review): this cell reassigns `model.model` from itself, so re-running
# it without reloading the model presumably wraps the layer again - confirm.
llama_with_memory = replace_llama_layer_with_memory(
    model.model, REPLACE_LAYER, context_choice, memory
)
model.model = llama_with_memory

In [11]:
# LoRA adapter settings: rank 8, alpha 16, no dropout, no bias terms, applied
# to the query and value projections of a causal language model.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
)
# Wrap the model with the adapters, then post-process it with the 4-bit
# wrapper's zeros-and-scales-to-half helper.
lora_model = lora_model_zeros_and_scales_to_half(
    get_peft_model(model, lora_config)
)

In [12]:
# Patch the model's blocks with gradient checkpointing (the helper prints one
# "Forward Patch Applied" line per block). The trailing semicolon keeps the
# returned list of patch objects from being echoed into the cell output.
apply_gradient_checkpointing(lora_model);

Forward Patch Applied For Block 0
Forward Patch Applied For Block 1
Forward Patch Applied For Block 2
Forward Patch Applied For Block 3
Forward Patch Applied For Block 4
Forward Patch Applied For Block 5
Forward Patch Applied For Block 6
Forward Patch Applied For Block 7
Forward Patch Applied For Block 8
Forward Patch Applied For Block 9
Forward Patch Applied For Block 10
Forward Patch Applied For Block 11
Forward Patch Applied For Block 12
Forward Patch Applied For Block 13
Forward Patch Applied For Block 14
Forward Patch Applied For Block 15
Forward Patch Applied For Block 16
Forward Patch Applied For Block 17
Forward Patch Applied For Block 18
Forward Patch Applied For Block 19
Forward Patch Applied For Block 20
Forward Patch Applied For Block 21
Forward Patch Applied For Block 22
Forward Patch Applied For Block 23
Forward Patch Applied For Block 24
Forward Patch Applied For Block 25
Forward Patch Applied For Block 26
Forward Patch Applied For Block 27
Forward Patch Applied For Bloc

([<alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c384310>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888f496e10>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c27f690>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f8890625c50>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888f469810>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c285a50>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c2652d0>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888f497550>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c2fd050>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c287710>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c2d6950>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c2d6cd0>,
  <alpaca_lora_4bit.gradient_checkpointing.NewForward at 0x7f888c249450>,
  <alpaca_lora_4bit.gradient_checkpoin

In [13]:
# Turn off the model's `use_cache` option before training.
# NOTE(review): presumably because the key/value cache is not needed for
# training and does not combine with gradient checkpointing - confirm.
lora_model.config.use_cache = False

## Checking the tokenized texts

In [14]:
# Preview five reproducibly sampled rows of the tokenized texts.
df_texts.sample(n=5, random_state=RANDOM_STATE)

Unnamed: 0,processed_text,input_ids,length
110934,<msg_prompter> Below is an instruction that de...,"[529, 7645, 29918, 14032, 29886, 357, 29958, 1...",59
74656,"<msg_prompter> Куда падает ударение в слове ""т...","[529, 7645, 29918, 14032, 29886, 357, 29958, 7...",23
39950,<msg_prompter> А можешь побольше рассказать о ...,"[529, 7645, 29918, 14032, 29886, 357, 29958, 1...",48
156432,<msg_prompter> Below is an instruction that de...,"[529, 7645, 29918, 14032, 29886, 357, 29958, 1...",72
174184,<msg_prompter> Below is an instruction that de...,"[529, 7645, 29918, 14032, 29886, 357, 29958, 1...",57


In [15]:
def _decode_with_special_tokens(token_ids) -> str:
    """Decode ``token_ids`` wrapped in the tokenizer's BOS and EOS token ids."""
    wrapped_ids = [tokenizer.bos_token_id, *token_ids, tokenizer.eos_token_id]
    return tokenizer.decode(wrapped_ids)


# Decode the same five sampled rows back to text to check the stored token ids.
sampled_input_ids = df_texts.sample(5, random_state=RANDOM_STATE)["input_ids"]
sampled_input_ids.apply(_decode_with_special_tokens)

110934    <s> <msg_prompter> Below is an instruction tha...
74656     <s> <msg_prompter> Куда падает ударение в слов...
39950     <s> <msg_prompter> А можешь побольше рассказат...
156432    <s> <msg_prompter> Below is an instruction tha...
174184    <s> <msg_prompter> Below is an instruction tha...
Name: input_ids, dtype: object

In [16]:
# Median session length per source, restricted to sessions that are at most
# PRETRAIN_LENGTH long.
fits_pretrain_limit = df_indices_train["session_length"] <= PRETRAIN_LENGTH
(
    df_indices_train
    .loc[fits_pretrain_limit]
    .groupby("source")["session_length"]
    .quantile(0.5)
)

source
alpaca            125.0
booksum          1311.5
govreport        1548.0
gpt4all           300.0
openassistant     388.0
qasper           1652.0
Name: session_length, dtype: float64

## Pretraining

In [17]:
# Adam over all parameters exposed by the LoRA-wrapped model, using the
# pretraining learning rate from the configuration cell.
pretrain_parameters = lora_model.parameters()
optimizer = Adam(pretrain_parameters, lr=LR_PRETRAIN)

In [18]:
# Pretraining subset: sessions longer than one context window but no longer
# than PRETRAIN_LENGTH, from which PRETRAIN_DOCUMENTS rows are sampled.
session_lengths = df_indices_train["session_length"]
within_pretrain_limit = session_lengths <= PRETRAIN_LENGTH
longer_than_context = session_lengths > CONTEXT_LENGTH
df_indices_pretrain = (
    df_indices_train
    .loc[within_pretrain_limit & longer_than_context]
    .sample(PRETRAIN_DOCUMENTS, random_state=RANDOM_STATE)
)
df_indices_pretrain.head()

Unnamed: 0,indices,source,session_length
12096,"[23183, 23189, 23190, 23195, 23196]",openassistant,1594
286331,"[795568, 795569]",gpt4all,972
150295,"[244918, 244919]",gpt4all,584
294713,"[952816, 952817]",gpt4all,1005
355534,"[301340, 301341]",gpt4all,874


In [19]:
# Running count of batch-counter updates; starts at zero when this cell runs.
_batch_counter = 0


def batch_counter_update() -> int:
    """Advance the module-level batch counter by one and return the new value.

    Returns:
        The counter value after the increment (1 on the first call).
    """
    global _batch_counter
    next_value = _batch_counter + 1
    _batch_counter = next_value
    return next_value

In [20]:
def _log_pretrain_loss(document_index, document_batch, loss):
    """Record one training loss value under the "Loss/pretrain" TensorBoard tag.

    Args:
        document_index: Position of the session in the pretraining loop; it is
            forwarded through ``callback_kwargs`` and not used for logging.
        document_batch: Supplied by the trainer; not used for logging.
        loss: Loss value to record.

    The x-axis step is the global batch counter, so it keeps increasing across
    documents.
    """
    log_writer.add_scalar("Loss/pretrain", loss, batch_counter_update())


log_writer = SummaryWriter("long-vicuna--pretrain--tensorboard")
document_trainer = MemorizingLlamaDocumentTrainer(
    model=lora_model,
    tokenizer=tokenizer,
    memory=memory,
    tokens_per_chunk=CONTEXT_LENGTH,
    tokens_step=CONTEXT_STEP,
    optimizer=optimizer,
    scheduler=None,
    accumulate_gradients=1,
    float16=USE_FP16,
    train_callback=_log_pretrain_loss,
    eval_callback=None,
)
try:
    for i, indices in enumerate(tqdm(df_indices_pretrain["indices"])):
        # The first entry of a session is passed as the document; the
        # remaining entries are concatenated and passed as the prompt.
        main_document_index = indices[0]
        rest_session_index = indices[1:]
        document_tokens = torch.LongTensor(df_texts.loc[main_document_index, "input_ids"].astype(np.int32))
        # NOTE(review): torch.cat raises on an empty list, so a session with a
        # single entry would fail here - confirm such sessions cannot occur.
        rest_session_tokens = torch.cat([
            torch.LongTensor(array.astype(np.int32))
            for array in df_texts.loc[rest_session_index, "input_ids"]
        ]).view((1, -1))
        document_trainer.train_document(
            document_tokens=document_tokens,
            prompt_tokens=rest_session_tokens,
            sample_weight=1.0,
            callback_kwargs={
                "document_index": i,
            }
        )
finally:
    # SummaryWriter buffers events; flush so the logged losses reach disk even
    # if this multi-hour loop is interrupted or fails part-way through.
    log_writer.flush()

100%|██████████| 2048/2048 [7:46:25<00:00, 13.66s/it]  


In [22]:
# Persist the pretrained model's state dict to disk.
pretrained_state_dict = lora_model.state_dict()
torch.save(pretrained_state_dict, "long-vicuna--pretrain--state-dict.pth")