In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere beneath the Kaggle input directory.
for folder, _, names_here in os.walk('/kaggle/input'):
    for full_path in (os.path.join(folder, entry) for entry in names_here):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# NOTE(review): none of these packages is version-pinned. The recorded output of this cell shows
# the run installed trl 0.15.1 and bitsandbytes 0.45.2 — consider pinning for reproducibility.
!pip install torch transformers peft trl bitsandbytes

Collecting trl
  Downloading trl-0.15.1-py3-none-any.whl.metadata (11 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Downloading trl-0.15.1-py3-none-any.whl (318 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.9/318.9 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl (69.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl, bitsandbytes
Successfully installed bitsandbytes-0.45.2 trl-0.15.1


In [3]:
# NOTE(review): unpinned upgrade. The recorded output of this cell shows torch 2.6.0 and
# torchvision 0.21.0 being downloaded — consider pinning those versions for reproducibility.
!pip install --upgrade torch torchvision

Collecting torch
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft

In [4]:
# Environment Setup
import os

# Set in a single call, before the heavy libraries further down this cell are imported.
# The keys are applied in the order written.
os.environ.update(
    {
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
        "CUDA_VISIBLE_DEVICES": "0,1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

# Clean module cache
import sys
def clear_cache(packages=("trl", "transformers", "peft", "bitsandbytes")):
    """Evict the given top-level packages and all of their submodules from ``sys.modules``.

    A later ``import`` of an evicted package therefore loads it afresh instead of reusing
    the cached module object.

    Args:
        packages: Iterable of top-level package names to evict. Defaults to the four
            fine-tuning libraries this notebook uses.

    Only the package itself (``"trl"``) and dotted submodules (``"trl.trainer"``) are removed.
    A bare prefix test would also drop unrelated modules whose names merely begin with the
    same characters (for example ``"trlx"``), which is why the match is done on the dotted
    boundary.
    """
    targets = tuple(packages)
    # Snapshot the keys: entries are deleted while scanning.
    for name in list(sys.modules):
        if any(name == pkg or name.startswith(pkg + ".") for pkg in targets):
            del sys.modules[name]
clear_cache()

from datasets import load_dataset
import torch
from accelerate import notebook_launcher

# Import FSDP2 primitives.
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

###############################################################################
# HELPER FUNCTIONS
###############################################################################

# post_order_apply: Recursively traverse module tree in post‑order and apply fn(module, **kwargs)
def post_order_apply(fn, module, policy, ignored_modules=(), **kwargs):
    """Apply ``fn`` to every module accepted by ``policy``, visiting children before parents.

    Args:
        fn: Callable invoked as ``fn(module, **kwargs)`` on each accepted module.
        module: Root of the tree to traverse; must expose ``children()``.
        policy: Predicate called with a module; a truthy result means ``fn`` is applied to it.
        ignored_modules: Children found in this collection are skipped together with their
            entire subtree. The root passed in is never checked against it.
        **kwargs: Extra keyword arguments forwarded unchanged to every ``fn`` call.
    """
    # Lazily filter the children so each one is tested right before it is descended into.
    eligible_children = (sub for sub in module.children() if sub not in ignored_modules)
    for sub in eligible_children:
        post_order_apply(fn, sub, policy, ignored_modules, **kwargs)

    # The node itself is handled last, after its whole subtree.
    if not policy(module):
        return
    fn(module, **kwargs)

# convert_frozen_int_params_to_buffers: Convert frozen, non-floating-point parameters to buffers.
def convert_frozen_int_params_to_buffers(module):
    """Re-register frozen, non-floating-point parameters as buffers throughout the module tree.

    For ``module`` and, recursively, each of its children, every directly owned parameter
    that has ``requires_grad`` false and a non-floating-point dtype is removed from the
    parameter registry and registered under the same name as a buffer. The very same tensor
    object is reused; nothing is copied. All other parameters are left untouched.
    """
    # Decide what to move before mutating the registries.
    to_move = [
        (pname, tensor)
        for pname, tensor in module.named_parameters(recurse=False)
        if not (tensor.requires_grad or tensor.dtype.is_floating_point)
    ]

    for pname, tensor in to_move:
        # The name has to be free before register_buffer can take it over.
        module._parameters.pop(pname, None)
        if hasattr(module, pname):
            delattr(module, pname)
        module.register_buffer(pname, tensor)

    for submodule in module.children():
        convert_frozen_int_params_to_buffers(submodule)

# mark_self_attn_ignore: Mark any submodule whose name contains "self_attn" with fsdp_ignore=True.
def mark_self_attn_ignore(module):
    """Tag every descendant registered under a name containing ``"self_attn"``.

    Walks the whole tree below ``module`` via ``named_children()`` and sets
    ``fsdp_ignore = True`` on each matching child. The match is a plain substring test on
    the child's own name, not on its full dotted path. ``module`` itself is never tagged.
    """
    for child_name, submodule in module.named_children():
        is_attention_child = child_name.find("self_attn") != -1
        if is_attention_child:
            setattr(submodule, "fsdp_ignore", True)
        # Keep descending even below a tagged child.
        mark_self_attn_ignore(submodule)

# compile_lora_modules: Recursively traverse the model and compile any submodule that contains LoRA adapters.
def compile_lora_modules(module):
    """Swap every LoRA-carrying descendant of ``module`` for its ``torch.compile`` wrapper.

    The tree is walked depth-first. A child counts as LoRA-carrying when it exposes a
    ``lora_A`` or ``lora_B`` attribute; such a child is replaced on its parent, under the
    same name, by the object ``torch.compile`` returns. A failure for one child is printed
    and does not stop the traversal. ``module`` itself is never compiled.
    """
    lora_markers = ("lora_A", "lora_B")
    for child_name, submodule in module.named_children():
        # Handle the subtree first, then the child itself.
        compile_lora_modules(submodule)
        if not any(hasattr(submodule, marker) for marker in lora_markers):
            continue
        try:
            setattr(module, child_name, torch.compile(submodule))
            print(f"Compiled LoRA module: {child_name}")
        except Exception as e:
            print(f"Compilation failed for module {child_name}: {e}")

###############################################################################
# MAIN FUNCTION
###############################################################################

def main():
    """Fine-tune a 4-bit Llama-3.1-8B-Instruct checkpoint with LoRA adapters.

    Meant to be executed once per worker process by ``notebook_launcher``. In order, it:

    1. loads the pre-quantized model onto the GPU selected by ``LOCAL_RANK``;
    2. freezes the base model and attaches LoRA adapters to the attention projections;
    3. turns frozen non-floating-point parameters into buffers and tags ``self_attn`` children;
    4. offers every module to ``fully_shard`` through the ``should_fully_shard`` policy;
    5. wraps LoRA-carrying modules with ``torch.compile``;
    6. trains for 60 optimizer steps with ``SFTTrainer`` on the first 10% of ``unified_chip2.jsonl``.

    NOTE(review): whether step 4 shards anything is doubtful. ``should_fully_shard`` only
    inspects parameters owned *directly* by a ``LlamaDecoderLayer`` (``recurse=False``);
    presumably such a layer keeps all of its weights in child modules, in which case the
    policy never matches and ``fully_shard`` is never called. Confirm before relying on
    FSDP memory savings from this function.
    """
    # Imported here rather than at cell level so they are resolved inside the process that runs main().
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        BitsAndBytesConfig
    )
    from peft import LoraConfig, get_peft_model
    from trl import SFTTrainer

    # Model config
    model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
    
    # 4-bit Quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    )

    # For distributed training, load on the local GPU.
    # Falls back to GPU 0 when LOCAL_RANK is not set (e.g. a plain single-process call).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device_map = {"": local_rank}  # Each process loads on its own GPU.

    # 1. Load model on the local GPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
        device_map=device_map
    )

    # 2. Freeze the base model so that its 4-bit quantized weights are not updated.
    model.requires_grad_(False)

    # 3. Apply LoRA adapters; these add new trainable (floating point) parameters.
    lora_config = LoraConfig(
        r=64,
        lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lora_config)
    # Now, only the LoRA parameters are trainable.
    # Defensive pass: make sure no non-floating-point parameter is left marked as trainable.
    for p in model.parameters():
        if not p.dtype.is_floating_point:
            p.requires_grad = False

    # 3.5. Convert frozen, non-floating-point (quantized) parameters to buffers.
    convert_frozen_int_params_to_buffers(model)

    # 3.6. Move the model (and its buffers) to GPU.
    device = torch.device(f"cuda:{local_rank}")
    model = model.to(device)

    # 3.7. Mark self-attention submodules (e.g. those with "self_attn" in their name) to be ignored by FSDP.
    # NOTE(review): the flag is only read by should_fully_shard below, and only on
    # LlamaDecoderLayer instances, so tagging the self_attn children looks like it has no
    # effect on wrapping — confirm the intent.
    mark_self_attn_ignore(model)

    # 4. Define FSDP2 policies.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        output_dtype=torch.float16
    )
    offload_policy = CPUOffloadPolicy(pin_memory=True)
    # Keyword arguments forwarded verbatim to fully_shard for every wrapped module.
    # "sync_module_states" is intentionally absent: it is an option of the FSDP1 wrapper class,
    # not a parameter of fully_shard, so forwarding it would raise TypeError on the first wrap.
    fsdp_kwargs = {
        "mp_policy": mp_policy,
        "offload_policy": offload_policy,
        "reshard_after_forward": True,
    }

    # 5. Define an auto-wrap policy:
    # We'll wrap a module if it is a LlamaDecoderLayer, is not marked with fsdp_ignore,
    # and has at least one trainable floating-point parameter.
    def should_fully_shard(module):
        if isinstance(module, LlamaDecoderLayer) and not getattr(module, "fsdp_ignore", False):
            # NOTE(review): recurse=False looks only at parameters registered directly on the
            # layer, not on its children — see the docstring; this may never be true.
            return any(p.requires_grad and p.dtype.is_floating_point for p in module.parameters(recurse=False))
        return False

    # 6. Manually apply FSDP wrapping via post‑order traversal.
    post_order_apply(fully_shard, model, should_fully_shard, **fsdp_kwargs)
    # (Only submodules meeting the auto-wrap policy get wrapped; the frozen base weights remain untouched.)

    # 6.5. Compile only the trainable LoRA adapter modules.
    compile_lora_modules(model)

    # 7. Load the dataset.
    # The map call rewrites each record's "text" field with its own value.
    dataset = load_dataset(
        "json",
        data_files={"train": "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"},
        split="train[:10%]"
    ).map(lambda x: {"text": x["text"]})

    # 8. Training arguments.
    # IMPORTANT: Remove fsdp and fsdp_config from TrainingArguments to avoid re-wrapping by the accelerator.
    training_args = TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=2e-4,
        max_steps=60,
        logging_steps=10,
        optim="paged_adamw_8bit",
        report_to="none",
    )

    # 9. Setup Trainer.
    # No tokenizer is passed; the trainer is left to obtain one itself.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
    )

    # 10. Start training.
    trainer.train()

# Entry point: hand main() to accelerate's notebook_launcher with num_processes=2
# (the recorded output of this cell reports "Launching training on 2 GPUs.").
if __name__ == "__main__":
    notebook_launcher(main, num_processes=2)


Launching training on 2 GPUs.


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Compiled LoRA module: q_proj
Compiled LoRA module: k_proj
Compiled LoRA module: v_proj
Compiled LoRA module: o_proj
Compiled LoRA module: q_proj
Compiled LoRA module: k_proj
Compiled LoRA module: v_proj
Compiled LoRA module: o_proj
Compiled LoRA module: q_proj
Compiled LoRA module: k_projCompiled LoRA module: q_proj

Compiled LoRA module: v_projCompiled LoRA module: k_proj

Compiled LoRA module: v_proj
Compiled LoRA module: o_projCompiled LoRA module: o_proj

Compiled LoRA module: q_proj
Compiled LoRA module: q_proj
Compiled LoRA module: k_projCompiled LoRA module: k_proj

Compiled LoRA module: v_projCompiled LoRA module: v_proj

Compiled LoRA module: o_projCompiled LoRA module: o_proj

Compiled LoRA module: q_projCompiled LoRA module: q_proj

Compiled LoRA module: k_proj
Compiled LoRA module: k_proj
Compiled LoRA module: v_proj
Compiled LoRA module: v_proj
Compiled LoRA module: o_projCompiled LoRA module: o_proj

Compiled LoRA module: q_projCompiled LoRA module: q_proj

Compiled LoRA 

unified_chip2.jsonl:   0%|          | 0.00/95.6M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/21029 [00:00<?, ? examples/s]

Map:   0%|          | 0/21029 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/55.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

Converting train dataset to ChatML:   0%|          | 0/21029 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

  torch._dynamo.utils.warn_once(msg)
  torch._dynamo.utils.warn_once(msg)
[rank1]:W0223 19:29:50.222000 72 torch/_inductor/utils.py:1137] [15/0] Not enough SMs to use max_autotune_gemm mode
[rank0]:W0223 19:29:50.232000 71 torch/_inductor/utils.py:1137] [15/0] Not enough SMs to use max_autotune_gemm mode
[rank0]:W0223 19:30:08.944000 71 torch/_dynamo/convert_frame.py:906] [16/8] torch._dynamo hit config.cache_size_limit (8)
[rank0]:W0223 19:30:08.944000 71 torch/_dynamo/convert_frame.py:906] [16/8]    function: 'torch_dynamo_resume_in_forward_at_496' (/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py:496)
[rank0]:W0223 19:30:08.944000 71 torch/_dynamo/convert_frame.py:906] [16/8]    last reason: 16/0: tensor 'L['x']' requires_grad mismatch. expected requires_grad=0
[rank0]:W0223 19:30:08.944000 71 torch/_dynamo/convert_frame.py:906] [16/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank0]:W0223 19:30:08.944000 71 torch/_dynamo/convert_frame.py:906] [1

Step,Training Loss
10,3.3962
20,3.1992
30,3.0218
40,2.9874
50,2.883
60,3.0733


Step,Training Loss
10,3.3962
20,3.1992
30,3.0218
40,2.9874
50,2.883
60,3.0733
