In [7]:
# WITHOUT COMPILATION

import os
import logging
from typing import Optional, Tuple
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
import torch
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import Cache, Unpack, FlashAttentionKwargs, Callable, eager_attention_forward, apply_rotary_pos_emb, ALL_ATTENTION_FUNCTIONS, logger, BaseModelOutputWithPast, Union, DynamicCache
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.modeling_utils import PreTrainedModel
from peft import get_peft_model, LoraConfig, TaskType

def custom_compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Compute the causal-LM cross-entropy outside the model.

    Intended to replace ``SFTTrainer.compute_loss``. The forward pass is run
    WITHOUT ``labels`` so the model does not compute a second loss that would
    only be discarded. Logits and labels are shifted by one position, and
    cross-entropy is averaged over positions whose label is not -100.

    Args:
        self: The trainer instance (unused; present because this function is
            installed as a method on ``SFTTrainer``).
        model: Callable returning an object with a ``logits`` attribute.
        inputs: Mapping holding ``"input_ids"`` and ``"labels"``, and optionally
            ``"attention_mask"``. Plain dicts and ``BatchEncoding`` both work,
            since only ``[]`` / ``.get`` access is used.
        return_outputs: When True, return ``(loss, outputs)``. The Trainer
            evaluation path requests this form; otherwise only ``loss`` is
            returned.
        num_items_in_batch: Accepted for signature compatibility; not used.

    Returns:
        The scalar loss tensor, or ``(loss, outputs)`` when ``return_outputs``.
    """
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
    )
    logits = outputs.logits  # presumably (B, S, V) — TODO confirm for other model types
    labels = inputs["labels"]

    # Causal LM: the prediction at position t is scored against the token at
    # t+1, so drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    # Flatten to (N, V) / (N,) for cross_entropy; -100 marks ignored positions.
    loss = torch.nn.functional.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
    # NOTE(review): this is a mean over the current micro-batch and ignores
    # num_items_in_batch. With gradient accumulation, the effective
    # normalization depends on how the installed transformers version rescales
    # the returned loss — confirm.
    return (loss, outputs) if return_outputs else loss

# Patch the class itself: every SFTTrainer created after this point in the
# process uses custom_compute_loss instead of the library's compute_loss.
setattr(SFTTrainer, "compute_loss", custom_compute_loss)

max_seq_length = 1024

# Global side effect: floating-point tensors created from here on without an
# explicit dtype default to float16.
# NOTE(review): from_pretrained below gets no torch_dtype, so this presumably
# also decides the dtype the model is loaded in. That dtype in turn drives the
# fp16/bf16 flags passed to SFTConfig — confirm.
torch.set_default_dtype(torch.float16)
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16

# Not passed to from_pretrained: this checkpoint already ships its own
# quantization config. Kept for reference only.
bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_compute_dtype    = dtype,
)

# Place the whole model on cuda:0 at load time.
# "auto" followed by .to("cuda:0") could first spread or offload layers and
# then force-move them, and .to() is restricted on bitsandbytes-quantized
# models.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = {"": 0},
    attn_implementation = "sdpa",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# LoRA adapters (rank 32, alpha 64) on the attention and MLP projections.
lora_config = LoraConfig(
    task_type = TaskType.CAUSAL_LM,
    r = 32,
    lora_alpha = 64,
    lora_dropout = 0,
    bias = "none",
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

model = get_peft_model(model, lora_config)

# Only the LoRA A/B matrices are trainable; every other parameter is frozen.
for param_name, weight in model.named_parameters():
    is_lora_weight = any(tag in param_name for tag in (".lora_A.", ".lora_B."))
    weight.requires_grad_(is_lora_weight)

# NOTE(review): typically needed together with gradient checkpointing, which
# is not visibly enabled in this cell — confirm whether it is still required.
model.enable_input_require_grads()

# Training data: first 20% of the OIG unified_chip2 JSONL file.
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset(
    "json",
    data_files = {"train" : url},
    split = "train[:20%]",
)

import time
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class TimingCallback(TrainerCallback):
    """Report time spent in the first ``warmup_steps`` steps vs. the rest.

    Each optimizer step is timed from ``on_step_begin`` to ``on_step_end``.
    Steps whose ``state.global_step`` is at most ``warmup_steps`` go into
    ``warmup_time``; later steps go into ``main_time``. Totals are printed in
    ``on_train_end``.
    """

    def __init__(self, warmup_steps: int = 3):
        self.warmup_steps = warmup_steps
        self.warmup_time = 0.0
        self.main_time = 0.0
        self.step_start = None  # perf_counter() stamp set by on_step_begin

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Reset, so a second trainer.train() with the same callback instance
        # does not add onto the previous run's totals.
        self.warmup_time = 0.0
        self.main_time = 0.0
        self.step_start = None

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self.step_start = time.perf_counter()

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if self.step_start is None:
            return  # no matching on_step_begin; nothing to measure
        step_time = time.perf_counter() - self.step_start
        self.step_start = None
        # global_step has already been advanced for this step here, so the
        # first step is 1.
        if state.global_step <= self.warmup_steps:
            self.warmup_time += step_time
        else:
            self.main_time += step_time

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        total_steps = state.global_step
        # Clamp both counts so a run shorter than warmup_steps reports
        # truthful, non-negative numbers.
        warmup_steps_run = min(self.warmup_steps, total_steps)
        main_steps = max(total_steps - self.warmup_steps, 0)
        print(f"\nTotal warmup time for {warmup_steps_run} steps: {self.warmup_time:.4f} seconds")
        print(f"Total main training time for {main_steps} steps: {self.main_time:.4f} seconds")


# Short benchmark run: 150 optimizer steps, batch 1 with 2-step gradient
# accumulation, logging every step. TimingCallback splits the wall time into
# the first 5 steps and the remainder.
trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    processing_class = tokenizer,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 2,
        warmup_steps = 5,
        max_steps = 150,
        logging_steps = 1,
        output_dir = "outputs",
        seed = 3407,
        max_seq_length = max_seq_length,
        # Mixed-precision mode follows the dtype the embeddings were loaded in.
        fp16 = model.get_input_embeddings().weight.dtype == torch.float16,
        bf16 = model.get_input_embeddings().weight.dtype == torch.bfloat16,
        report_to = "none", # For W&B
        dataset_num_proc = 4,
        # Only "labels" holds targets; input_ids and attention_mask are model
        # inputs. Listing them here made evaluation treat them as labels.
        # Set explicitly because PeftModel hides the base model's signature
        # from Trainer's automatic label detection.
        label_names = ["labels"],
    ),
    callbacks = [TimingCallback(warmup_steps=5)],
)

trainer.train()

Step,Training Loss
1,3.6572
2,5.9999
3,4.3757
4,5.2595
5,4.4185
6,3.7118
7,3.6374
8,3.3688
9,4.254
10,3.8025



Total warmup time for 5 steps: 2.5751 seconds
Total main training time for 145 steps: 78.9863 seconds


TrainOutput(global_step=150, training_loss=3.8708507696787517, metrics={'train_runtime': 86.8897, 'train_samples_per_second': 3.453, 'train_steps_per_second': 1.726, 'total_flos': 159975600611328.0, 'train_loss': 3.8708507696787517})