In [1]:
!pip install -q transformers peft datasets accelerate bitsandbytes trl>=0.8.0 sentencepiece

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import torch

In [3]:
ds = load_dataset("databricks/databricks-dolly-15k")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
def format_sample(example):
    return {
        "text": f"Instruction:\n{example['instruction']}\n\nResponse:\n{example['response']}"
    }

In [5]:
ds=ds.map(format_sample,remove_columns=ds['train'].column_names)

In [6]:
train_data=ds['train']

In [7]:
train_data['text']

Column(['Instruction:\nWhen did Virgin Australia start operating?\n\nResponse:\nVirgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.', 'Instruction:\nWhich is a species of fish? Tope or Rope\n\nResponse:\nTope', 'Instruction:\nWhy can camels survive for long without water?\n\nResponse:\nCamels use the fat in their humps to keep them filled with energy and hydration for long periods of time.', "Instruction:\nAlice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?\n\nResponse:\nThe name of the third daughter is Alice", 'Instruction:\nWhen was Tomoaki Komorida born?\n\nResponse:\nTomoaki Komorida was born on July 10,1981.'])

In [8]:
train_data = train_data.shuffle(seed=42).select(range(2000))

In [9]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [11]:
model_id='google/gemma-2b'

In [12]:
tokenizer=AutoTokenizer.from_pretrained(model_id,use_fast=True)

In [13]:
model=AutoModelForCausalLM.from_pretrained(model_id,device_map="auto",load_in_4bit=True,dtype=torch.float16)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear4bit(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (n

In [15]:
lora_config=LoraConfig(r=8,
lora_alpha=16,
lora_dropout=0.05,task_type='CAUSAL_LM',target_modules=['q_proj','k_proj','v_proj','o_proj'])

In [16]:
args = TrainingArguments(
    output_dir="gemma-dolly-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    logging_steps=25,
    save_steps=500,
    report_to="none",
)


In [17]:
def formatting_func(example):
    return example["text"][:512]

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    peft_config=lora_config,
    args=args,
    formatting_func=formatting_func,
)


In [18]:
trainer.train()

Step,Training Loss
25,2.599
50,2.2657
75,2.2412
100,2.1352
125,2.2357
150,2.2644
175,2.1232
200,2.295
225,2.1003
250,2.2137


TrainOutput(global_step=1000, training_loss=2.1642517013549805, metrics={'train_runtime': 1082.2904, 'train_samples_per_second': 1.848, 'train_steps_per_second': 0.924, 'total_flos': 1754146851545088.0, 'train_loss': 2.1642517013549805, 'epoch': 1.0})

In [19]:
trainer.model.save_pretrained("gemma-dolly-lora")
tokenizer.save_pretrained("gemma-dolly-lora")

('gemma-dolly-lora/tokenizer_config.json',
 'gemma-dolly-lora/special_tokens_map.json',
 'gemma-dolly-lora/tokenizer.model',
 'gemma-dolly-lora/added_tokens.json',
 'gemma-dolly-lora/tokenizer.json')

In [20]:
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
model = PeftModel.from_pretrained(base, "gemma-dolly-lora").eval()
tok = AutoTokenizer.from_pretrained(model_id)

prompt = "Summarize the importance of data preprocessing in machine learning."
inputs = tok(f"Instruction:\n{prompt}\n\nResponse:\n", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=150)
print(tok.decode(outputs[0], skip_special_tokens=True))


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Instruction:
Summarize the importance of data preprocessing in machine learning.

Response:
Data preprocessing is an essential step in building a machine learning model. It involves cleaning, standardizing, and standardizing the data. It helps to remove outliers, missing values, and other errors in the data. It also helps to identify patterns and trends in the data.
