In [1]:
from huggingface_hub import notebook_login
from datasets import load_dataset, Dataset

In [2]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset

In [3]:
#notebook_login()

In [4]:
# Only the first 5000 training rows are pulled, which keeps this experiment small.
eli5: Dataset = load_dataset(
    path="eli5_category",
    split="train[:5000]",
    trust_remote_code=True,
)

In [5]:
# Hold out 20% of the rows for evaluation. The fixed seed makes the shuffle deterministic, so a
# Restart & Run All reproduces the same train/test partition instead of a new random one each time.
eli5_train_test = eli5.train_test_split(test_size=0.2, seed=42)

In [6]:
# Display one raw training record to inspect the schema before any preprocessing.
eli5_train_test["train"][0]

{'q_id': '7ggicc',
 'title': 'How can a coat check say they are not responsible for damage or loss to my coat?',
 'selftext': '',
 'category': 'Other',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['dqiwen2'],
  'text': ["Because it doesn't really mean anything legally. Statements like this, signs in parking lots, and written liability waivers are largely for show. Were you to ignore the waiver, and sue for the cost of your lost jacket, the judge would not stop you because of their claim. They could still absolutely be liable for replacing the lost jacket. So why even bother having waivers in the first place? Because it makes customers think that they have no legal recourse, so they don't push the issue after an item is stollen or someone gets hurt. It also shows the potential judge that the customer was warned of the potential danger before participating, which can be a helpful defense in some personal injury cases."],
  'score': [12],
  'text_urls': [[]]},
 'title_urls': [

In [7]:
model_id = "google/gemma-2b"

# bitsandbytes settings used when loading the base model: 4-bit NF4 weights with double quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    # todo explore whether it actually increases overall memory usage to change this to float32
    bnb_4bit_compute_dtype=torch.float16,
    # reminder: double quantization makes computing the total memory used by the frozen weights
    # even more complicated. Per the docs it shrinks the quantization constants (the values kept
    # around to dequantize each block of quantized weights) by the equivalent of 0.4 bits per parameter.
    bnb_4bit_use_double_quant=True,
)

In [8]:
# Load the base causal LM identified by model_id, quantized according to bnb_config.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)


Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
# Load the tokenizer that matches the base model checkpoint (same model_id as the model above).
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [10]:
# tokenizer.padding_side = "right" #todo based on runtime warning while building SFTTrainer, can try this if problems occur, but by default I think I should respect the default for the Gemma Tokenizer

In [11]:
# LoRA adapter configuration passed to SFTTrainer as peft_config.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=4,
    # Just LoRA'ing the attention projections for now, to mimic the original LoRA paper.
    # Not targeted yet: "gate_proj", "up_proj", "down_proj".
    # todo what is gate_proj? I think up_proj is the first weight matrix of the MLP block (fan out)
    # and down_proj is the second weight matrix of the MLP block (fan in), but no idea what gate_proj is.
    # reminder: can use "all-linear" (a bare string, not inside a list) for the expansive case:
    # https://huggingface.co/docs/peft/developer_guides/lora#qlora-style-training
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj"],
    use_rslora=True,
    # todo investigate whether it's worth trying DoRA; iirc that was said to be especially helpful
    # when the LoRA rank is low
)

In [12]:
# Sequence length handed to ConstantLengthDataset (as seq_length) for both the train and eval sets.
seq_len: int = 128

In [13]:
# tokenizer

In [14]:
# Flatten nested columns so the answer text is addressable as the top-level "answers.text" column,
# which is the column name fix_data and ConstantLengthDataset use later in this notebook.
eli5_train_test = eli5_train_test.flatten()

In [15]:
# eli5_train_test["train"][0]

In [16]:
def fix_data(record):
    '''
    Make a record usable by TRL (i.e. its classes have a dataset_text_field param, and that column
    must be string-type, not list<string> type).

    The first entry of the "answers.text" list is kept as the record's text. The function is
    idempotent: if "answers.text" is already a string it is left untouched, so re-running the
    map cell on already-fixed data does not truncate each answer to its first character.
    An empty answer list is mapped to the empty string instead of raising IndexError.

    :param record: flattened record whose "answers.text" value is a list of strings
        (or already a plain string, if this function was applied before)
    :return: the same record object, with "answers.text" holding a single string
    '''
    answer_text = record["answers.text"]
    if isinstance(answer_text, str):
        # Already converted on a previous pass; indexing a str would keep only its first character.
        return record
    record["answers.text"] = answer_text[0] if len(answer_text) > 0 else ""
    return record

In [17]:
# Apply fix_data to every record of both splits so "answers.text" becomes a plain string column.
eli5_train_test = eli5_train_test.map(fix_data)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [18]:
# Wrap each split in a ConstantLengthDataset that reads the "answers.text" column and emits
# sequences of seq_len tokens; both splits are built with identical settings.
fixed_len_train_dset, fixed_len_eval_dset = (
    ConstantLengthDataset(
        tokenizer,
        eli5_train_test[split_name],
        dataset_text_field="answers.text",
        seq_length=seq_len,
    )
    for split_name in ("train", "test")
)

In [19]:
# fixed_len_train_dset.__iter__().__next__()

In [20]:
# Hyperparameters for a very short run: 10 steps at batch size 1, logging the loss every step.
training_args = TrainingArguments(
    output_dir="outputs",
    max_steps=10,
    warmup_steps=2,
    per_device_train_batch_size=1,
    # gradient_accumulation_steps=4,  # todo don't want to touch this until I understand it
    learning_rate=2e-4,
    fp16=True,
    optim="paged_adamw_32bit",  # can try paged_adamw_8bit in absolute worst case
    logging_steps=1,
)

# The train/eval sets are the ConstantLengthDataset instances built earlier, so the
# dataset_text_field argument is left commented out here.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=lora_config,
    train_dataset=fixed_len_train_dset,
    eval_dataset=fixed_len_eval_dset,
    packing=True,
    # dataset_text_field="answers.text",
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [21]:
# TODO: set up PyTorch GPU memory tracking.
# Turn on recording of CUDA memory history before training, so that the
# torch.cuda.memory._dump_snapshot(...) call later in this notebook has data to write out.
# NOTE(review): this is an underscore-prefixed (private) torch API; its defaults may differ
# between torch releases — confirm against the installed version.
torch.cuda.memory._record_memory_history()


# TODO: investigate whether WSL 2 could be used to get around the "no windows support for gpu mem visualization" issue



In [22]:
# Release cached-but-unused GPU memory held by PyTorch before starting the training run.
torch.cuda.empty_cache()

In [23]:
# Run the fine-tuning job (max_steps=10 in the TrainingArguments); the returned TrainOutput is
# shown as the cell result alongside the per-step training loss.
trainer.train()

Step,Training Loss
1,6.4495
2,5.7394
3,5.2198
4,5.1831
5,4.4235
6,5.8971
7,5.0102
8,4.2755
9,6.3148
10,5.8561


TrainOutput(global_step=10, training_loss=5.436914730072021, metrics={'train_runtime': 45.1367, 'train_samples_per_second': 0.222, 'train_steps_per_second': 0.222, 'total_flos': 15227950202880.0, 'train_loss': 5.436914730072021, 'epoch': 0.0})

In [24]:
import os

In [25]:
# Show the current working directory, which is where relative output paths used in this
# notebook (the "outputs" directory and the memory snapshot pickle) are resolved.
os.getcwd()

'/mnt/c/Users/ssili/PycharmProjects/local-finetuning-calculator'

In [26]:
# Write the CUDA memory history recorded since _record_memory_history() was called to a pickle
# file; the path is relative, so it is created in the current working directory.
# NOTE(review): private torch API — confirm it is still available in the installed version.
torch.cuda.memory._dump_snapshot("gemma2b_2nd_mem_snapshot.pickle")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av