In [1]:

# !pip3 install -q -U bitsandbytes==0.42.0
# !pip3 install -q -U peft==0.8.2
# !pip3 install -q -U trl==0.7.10
# !pip3 install -q -U accelerate==0.27.1
# !pip3 install -q -U datasets==2.17.0
# !pip3 install -q -U transformers==4.38.1

In [1]:
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
text = "Plan a trip to China"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Plan a trip to China with the help of our China travel guide. We have included the best things to do in China, the best time to visit China, the best places to visit in China, and the best places to stay in China.

<h2><strong>China Travel Guide</strong></h2>

China is a country that is hard to describe. It is a country that is hard to describe. It is a country that is hard to describe. It is a country that is hard to describe. It is a country that is hard


In [3]:
os.environ["WANDB_DISABLED"] = "true"

In [4]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

In [5]:
import transformers

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        if example['context'][i] == "":
            text = f"### instruction: {example['instruction'][i]}\n ### output: {example['response'][i]} <eos>"
        else:
            text = f"### instruction: {example['instruction'][i]}\n ### input: {example['context'][i]}\n ### output: {example['response'][i]} <eos>"
        if len(tokenizer(text)["input_ids"]) <= 1024:
            output_texts.append(text)
    return output_texts

response_template = " ### output:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)


trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    peft_config=lora_config,
    max_seq_length=1024,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=1000, # only a demo
        num_train_epochs=3,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=100,
        output_dir="outputs",
        optim="paged_adamw_8bit",
    ),
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
100,1.8074
200,1.4487
300,1.4709
400,1.4195
500,1.4119
600,1.4194
700,1.4361
800,1.445
900,1.3426
1000,1.461


Checkpoint destination directory outputs/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory outputs/checkpoint-1000 already exists and is non-empty. Saving will proceed but saved results may be invalid.


TrainOutput(global_step=1000, training_loss=1.4662649230957032, metrics={'train_runtime': 447.2686, 'train_samples_per_second': 8.943, 'train_steps_per_second': 2.236, 'total_flos': 7759896044912640.0, 'train_loss': 1.4662649230957032, 'epoch': 0.27})

In [6]:
text = "Plan a trip to China"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Plan a trip to China and you will be greeted with a plethora of options. From the bustling city of Shanghai to the ancient city of Xi’an, China has a lot to offer.

The country is home to some of the most beautiful and historic cities in the world.

The country is also home to some of the most beautiful and historic cities in the world.

The country is home to some of the most beautiful and historic cities in the world.

The country is home to some of the most beautiful and


In [7]:
peft_model_name = "gemma-2b-lora-short"
model = trainer.model
model.save_pretrained(peft_model_name)

In [8]:
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer


os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])


from peft import get_peft_model, LoraConfig, PeftModel, PeftConfig
peft_model_name = "gemma-2b-lora-short"
model = PeftModel.from_pretrained(model, peft_model_name, quantization_config=bnb_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
text = "### instruction:: Plan a trip to China \n ### output: "
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

### instruction:: Plan a trip to China 
 ### output: 1. Plan your trip to China 
2. Book your flight 
3. Book your hotel 
4. Book your tours 
5. Book your tickets 
6. Book your transportation 
7. Book your tours 


In [11]:
# !nvidia-smi