In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

#Load the model and Tokenizer
model_id = "google/gemma-2b-it"
#
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def generate_prompt(data_point):
    """Gen. input text based on a prompt, task instruction, (context info.), and answer

    :param data_point: dict: Data point
    :return: dict: tokenzed prompt
    """

    # Generate prompt
    prefix_text = 'Below is an instruction that describes a task. Write a response that ' \
               'appropriately completes the request.\n\n'
    # Samples with additional context into.
    if data_point['input']:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} here are the inputs {data_point["input"]} <end_of_turn>\n<start_of_turn>model{data_point["output"]} <end_of_turn>"""
    # Without
    else:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} <end_of_turn>\n<start_of_turn>model{data_point["output"]} <end_of_turn>"""
    return text


Dataset({
    features: ['instruction', 'output', 'input', 'prompt'],
    num_rows: 3702
})

In [8]:
from peft import LoraConfig, PeftModel, get_peft_model
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="mps")
model.gradient_checkpointing_enable()
#
print(model)

def find_all_linear_names(model):
  cls = torch.nn.Linear
  lora_module_names = set()
  for name, module in model.named_modules():
    if isinstance(module, cls):
      names = name.split('.')
      lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names: # needed for 16-bit
      lora_module_names.remove('lm_head')
  return list(lora_module_names)
#
modules = find_all_linear_names(model)

limit=4

lora_config = LoraConfig(
    r=4,
    lora_alpha=2,
    target_modules=modules if len(modules) < limit else modules[:limit],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.95s/it]


GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRM

  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [10]:
import transformers

from trl import SFTTrainer

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side='right'


trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    dataset_text_field="prompt",
    peft_config=lora_config,
    max_seq_length=2500,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=0.03,
        max_steps=100,
        learning_rate=2e-4,
        logging_steps=1,
        output_dir="outputs",
        optim="adamw_torch",
        save_strategy="epoch",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
#
trainer.train()

Map: 100%|██████████| 3331/3331 [00:00<00:00, 6973.39 examples/s]
Map: 100%|██████████| 371/371 [00:00<00:00, 7403.70 examples/s]
  0%|          | 0/100 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  1%|          | 1/100 [00:12<20:03, 12.15s/it]

{'loss': 3.3987, 'grad_norm': 0.43529394268989563, 'learning_rate': 0.00019805941782534764, 'epoch': 0.0}


  2%|▏         | 2/100 [00:22<17:53, 10.96s/it]

{'loss': 3.4251, 'grad_norm': 0.4118957221508026, 'learning_rate': 0.00019605881764529358, 'epoch': 0.0}


  3%|▎         | 3/100 [00:32<17:10, 10.62s/it]

{'loss': 3.2128, 'grad_norm': 0.41236865520477295, 'learning_rate': 0.00019405821746523957, 'epoch': 0.0}


  4%|▍         | 4/100 [00:43<17:25, 10.89s/it]

{'loss': 3.2085, 'grad_norm': 0.523223340511322, 'learning_rate': 0.00019205761728518557, 'epoch': 0.0}


  5%|▌         | 5/100 [00:52<16:15, 10.27s/it]

{'loss': 3.3047, 'grad_norm': 0.7127281427383423, 'learning_rate': 0.00019005701710513156, 'epoch': 0.01}


  6%|▌         | 6/100 [01:04<16:31, 10.55s/it]

{'loss': 3.0771, 'grad_norm': 0.7587288618087769, 'learning_rate': 0.00018805641692507753, 'epoch': 0.01}


  7%|▋         | 7/100 [01:12<15:10,  9.79s/it]

{'loss': 3.3018, 'grad_norm': 1.0235854387283325, 'learning_rate': 0.00018605581674502352, 'epoch': 0.01}


  8%|▊         | 8/100 [01:21<14:58,  9.76s/it]

{'loss': 2.9145, 'grad_norm': 0.9993389844894409, 'learning_rate': 0.00018405521656496952, 'epoch': 0.01}


  9%|▉         | 9/100 [01:34<15:58, 10.53s/it]

{'loss': 2.8404, 'grad_norm': 1.0065335035324097, 'learning_rate': 0.00018205461638491548, 'epoch': 0.01}


 10%|█         | 10/100 [01:47<16:57, 11.30s/it]

{'loss': 2.9337, 'grad_norm': 1.116021990776062, 'learning_rate': 0.00018005401620486148, 'epoch': 0.01}


 11%|█         | 11/100 [01:55<15:24, 10.38s/it]

{'loss': 2.8629, 'grad_norm': 1.2370742559432983, 'learning_rate': 0.00017805341602480744, 'epoch': 0.01}


 12%|█▏        | 12/100 [02:07<16:03, 10.95s/it]

{'loss': 2.4597, 'grad_norm': 0.9787382483482361, 'learning_rate': 0.00017605281584475344, 'epoch': 0.01}


 13%|█▎        | 13/100 [02:18<15:47, 10.89s/it]

{'loss': 2.4824, 'grad_norm': 1.0393227338790894, 'learning_rate': 0.00017405221566469943, 'epoch': 0.02}


 14%|█▍        | 14/100 [02:35<18:16, 12.74s/it]

{'loss': 2.4507, 'grad_norm': 1.1128212213516235, 'learning_rate': 0.00017205161548464542, 'epoch': 0.02}


 15%|█▌        | 15/100 [02:45<16:39, 11.76s/it]

{'loss': 2.5095, 'grad_norm': 1.1867575645446777, 'learning_rate': 0.00017005101530459136, 'epoch': 0.02}


 16%|█▌        | 16/100 [02:56<16:09, 11.54s/it]

{'loss': 2.2462, 'grad_norm': 0.7535495758056641, 'learning_rate': 0.00016805041512453736, 'epoch': 0.02}


 17%|█▋        | 17/100 [03:07<15:53, 11.49s/it]

{'loss': 2.2696, 'grad_norm': 1.0727155208587646, 'learning_rate': 0.00016604981494448335, 'epoch': 0.02}


 18%|█▊        | 18/100 [03:17<15:09, 11.09s/it]

{'loss': 2.4703, 'grad_norm': 1.1087111234664917, 'learning_rate': 0.00016404921476442935, 'epoch': 0.02}


 19%|█▉        | 19/100 [03:37<18:35, 13.77s/it]

{'loss': 2.2819, 'grad_norm': 0.8488591313362122, 'learning_rate': 0.0001620486145843753, 'epoch': 0.02}


 20%|██        | 20/100 [03:54<19:39, 14.75s/it]

{'loss': 2.1958, 'grad_norm': 1.0698051452636719, 'learning_rate': 0.0001600480144043213, 'epoch': 0.02}


 21%|██        | 21/100 [04:07<18:33, 14.10s/it]

{'loss': 2.2579, 'grad_norm': 0.9298002123832703, 'learning_rate': 0.0001580474142242673, 'epoch': 0.03}


 22%|██▏       | 22/100 [04:32<22:40, 17.44s/it]

{'loss': 2.0994, 'grad_norm': 0.7968690395355225, 'learning_rate': 0.00015604681404421327, 'epoch': 0.03}


 23%|██▎       | 23/100 [05:01<26:52, 20.94s/it]

{'loss': 2.2638, 'grad_norm': 0.9316089153289795, 'learning_rate': 0.00015404621386415926, 'epoch': 0.03}


 24%|██▍       | 24/100 [05:12<22:50, 18.03s/it]

{'loss': 2.1221, 'grad_norm': 0.8660854697227478, 'learning_rate': 0.00015204561368410523, 'epoch': 0.03}


 25%|██▌       | 25/100 [05:48<28:59, 23.19s/it]

{'loss': 2.187, 'grad_norm': 0.7943320274353027, 'learning_rate': 0.00015004501350405122, 'epoch': 0.03}


 26%|██▌       | 26/100 [06:06<27:02, 21.92s/it]

{'loss': 1.9732, 'grad_norm': 0.8320187330245972, 'learning_rate': 0.00014804441332399721, 'epoch': 0.03}


KeyboardInterrupt: 