In [1]:
import transformers

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

In [3]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [4]:
from peft.optimizers import create_loraplus_optimizer

In [5]:
import bitsandbytes as bnb

In [6]:
import datasets

In [7]:
from datasets import load_from_disk

In [8]:
import torch

In [9]:
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none"
)

In [10]:
quantization_config = transformers.BitsAndBytesConfig(
    #llm_int8_enable_fp32_cpu_offload=True,
    #load_in_8bit=True,
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    #llm_int8_threshold=0.0,
)

In [11]:
model = AutoModelForCausalLM.from_pretrained(
    "canopylabs/3b-fr-pretrain-research_release",
    use_cache=False,
    quantization_config=quantization_config
)
#tokenizer = AutoTokenizer.from_pretrained("canopylabs/3b-fr-pretrain-research_release")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [12]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(156940, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,)

In [13]:
model.model.layers[-1].post_attention_layernorm.weight

Parameter containing:
tensor([0.5459, 0.5225, 0.5991,  ..., 0.5273, 0.5273, 0.4956], device='cuda:0',
       dtype=torch.float16, requires_grad=True)

In [14]:
model = prepare_model_for_kbit_training(model)
lora_model = get_peft_model(model, config)

In [15]:
lora_model.print_trainable_parameters()

trainable params: 18,350,080 || all params: 3,319,217,152 || trainable%: 0.5528


In [16]:
pad_token = 128263

In [17]:
def data_collator(features):
    input_ids = [f["input_ids"] for f in features]

    if any("attention_mask" not in f for f in features):
        attention_mask = [[1]*len(ids) for ids in input_ids]
    else:
        attention_mask = [f["attention_mask"] for f in features]

    if any("labels" not in f for f in features):
        labels = input_ids
    else:
        labels = [f["labels"] for f in features]

    input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(x, dtype=torch.long) for x in input_ids],
        batch_first=True,
        padding_value=pad_token
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(m, dtype=torch.long) for m in attention_mask],
        batch_first=True, padding_value=0
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(l, dtype=torch.long) for l in labels],
        batch_first=True, padding_value=-100
    )

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [18]:
data_train = load_from_disk("data/encoded/")

In [None]:
#lora_model = lora_model.compile()

In [19]:
batch_size = 8

args = TrainingArguments(
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    #per_device_eval_batch_size=batch_size,
    bf16=True,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=5,
    #eval_strategy="epoch",
    #load_best_model_at_end=True,
    label_names=["labels"],
    remove_unused_columns=True,
)

In [20]:
optimizer = create_loraplus_optimizer(
    model=lora_model,
    optimizer_cls=bnb.optim.Adam8bit,
    lr=5e-5,
    loraplus_lr_ratio=16,
)
scheduler = None

In [21]:
trainer = Trainer(
    lora_model,
    args,
    optimizers=(optimizer, scheduler),
    train_dataset=data_train,
    #eval_dataset=val_ds,
    #tokenizer=tokenizer,
    data_collator=data_collator,
)

In [22]:
trainer.train()

  return fn(*args, **kwargs)


Step,Training Loss
10,4.8419
20,4.6957
30,4.6193
40,4.5987
50,4.5336
60,4.5236
70,4.5198
80,4.5019
90,4.4948
100,4.4608


TrainOutput(global_step=111, training_loss=4.5686386726998, metrics={'train_runtime': 3093.7677, 'train_samples_per_second': 1.159, 'train_steps_per_second': 0.036, 'total_flos': 4.621749237723341e+16, 'train_loss': 4.5686386726998, 'epoch': 2.9333333333333336})

In [24]:
merged_model = lora_model.merge_and_unload()



In [25]:
merged_model.save_pretrained("ft_merged_model")