# Edge of Hardware Limits: Scaling Inputs with Flash Attention 2 🚀

## Learning Objectives 🎯
- Understand the hardware requirements for scaling large models and using advanced techniques like Flash Attention.
- Explore how Flash Attention 2 can optimize large input processing, provided the correct GPU architecture (Ampere or newer) is available.
- Configure and launch a specialized training session while staying within hardware limits.

## Library Installation 🛠️
Install the necessary libraries, including Flash Attention 2, which handles long inputs more efficiently. Flash Attention 2 requires an Ampere or newer GPU, so check your hardware before running this step.

In [2]:
# !pip install --no-build-isolation axolotl[flash-attn,deepspeed]

## Importing Libraries

In [3]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda
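
Because Flash Attention 2 needs an Ampere or newer GPU, it can help to check the device's CUDA compute capability before enabling it. The sketch below assumes that "Ampere or newer" corresponds to compute capability 8.0 and above; the helper names `supports_flash_attention_2` and `detect_capability` are illustrative and not part of any library.

```python
def supports_flash_attention_2(capability):
    """Return True if a (major, minor) compute capability is Ampere or newer.

    Assumption: Ampere starts at compute capability 8.0 (e.g. A100 is 8.0),
    while older cards such as the T4 (7.5) fall below the threshold.
    """
    major, _minor = capability
    return major >= 8


def detect_capability():
    """Return the (major, minor) capability of GPU 0, or None without CUDA."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)


capability = detect_capability()
if capability is None:
    print("No CUDA GPU detected; keep flash_attention disabled.")
elif supports_flash_attention_2(capability):
    print(f"Compute capability {capability}: Flash Attention 2 should be usable.")
else:
    print(f"Compute capability {capability}: keep flash_attention disabled.")
```

The result of this check can guide whether to set `flash_attention` to `True` in the training configuration.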


## Configuration Setup for Large-Scale Training 📝
Set up the YAML configuration for training with extended input lengths. The configuration below defaults to the 4-bit "unsloth/gemma-2-9b-it-bnb-4bit" model, which fits on the free tier; switch to the commented-out "unsloth/gemma-2-27b-it-bnb-4bit" model if you have access to a 24GB GPU.
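
A back-of-the-envelope estimate helps explain why model size matters here. The sketch below counts memory for the weights only, assuming nominal parameter counts of 9 and 27 billion and exactly 4 or 16 bits per parameter; it ignores activations, LoRA parameters, optimizer state, and quantization overhead, so real usage is higher. The function name `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Compare nominal model sizes at 16-bit and 4-bit precision.
for name, num_params in [("9B", 9e9), ("27B", 27e9)]:
    for bits in (16, 4):
        print(f"{name} at {bits}-bit: ~{weight_memory_gb(num_params, bits):.1f} GB")
```

Under these assumptions the 4-bit 27B weights alone take roughly 13.5 GB, which leaves little headroom on a small GPU once activations and optimizer state are added.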

In [7]:
import yaml

train_config = {
    # "base_model": "unsloth/gemma-2-27b-it-bnb-4bit",  # the 27B model doesn't fit on the free tier; use it only if you have access to a 24GB GPU
    "base_model": "unsloth/gemma-2-9b-it-bnb-4bit",

    # dataset params
    "datasets": [{"path": "Yukang/LongAlpaca-12k", "type": "alpaca"}],
    "output_dir": "./models/LongAlpaca",

    # model params
    "sequence_len": 1024,  # axolotl's key for the maximum sequence length
    "bf16": "auto",
    "tf32": False,

    # training params
    "micro_batch_size": 1,
    "num_epochs": 1,
    "optimizer": "adamw_bnb_8bit",
    "learning_rate": 0.0002,

    "logging_steps": 1,

    # LoRA / qLoRA
    "adapter": "qlora",
    "lora_r": 32,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "lora_target_linear": True,

    # Gradient Accumulation
    "gradient_accumulation_steps": 1,

    # Gradient Checkpointing
    "gradient_checkpointing": True,

    # Low Precision
    "load_in_8bit": False,
    "load_in_4bit": True,

    # Flash Attention (enable only on an Ampere or newer GPU with flash-attn installed)
    "flash_attention": False,
}





# Write the YAML file
with open("specialised_train_flash.yml", 'w') as file:
    yaml.dump(train_config, file)
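
The batch-related settings in this config interact: the number of tokens contributing to each optimizer update is at most the product of the micro batch size, the gradient accumulation steps, the sequence length, and the number of GPUs. The sketch below works through that arithmetic with the values from the config; `max_tokens_per_optimizer_step` is an illustrative helper, and the single-GPU default is an assumption.

```python
def max_tokens_per_optimizer_step(micro_batch_size, gradient_accumulation_steps,
                                  sequence_len, num_gpus=1):
    """Upper bound on tokens per optimizer update (real sequences may be shorter)."""
    sequences = micro_batch_size * gradient_accumulation_steps * num_gpus
    return sequences * sequence_len


# Values taken from the config: micro batch 1, accumulation 1, 1024 tokens.
current = max_tokens_per_optimizer_step(1, 1, 1024)
# Hypothetical variant: accumulate gradients over 8 micro batches.
larger = max_tokens_per_optimizer_step(1, 8, 1024)

print(f"Current config: up to {current} tokens per update")
print(f"With 8 accumulation steps: up to {larger} tokens per update")
```

Raising `gradient_accumulation_steps` grows the effective batch while only one micro batch is held in memory at a time, which is why it is a common lever when hardware is the limiting factor.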


## Training Launch 🚀
Start training with the `accelerate launch` command, pointing it at the `specialised_train_flash.yml` file written above. With `flash_attention` enabled on an Ampere or newer GPU, the run uses Flash Attention 2 to process long inputs more efficiently.

In [5]:
# !accelerate launch -m axolotl.cli.train specialised_train_flash.yml
# Optional: merge the trained adapter into the base model
# !accelerate launch -m axolotl.cli.merge_lora specialised_train_flash.yml
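
When running outside a notebook, the same launch can be scripted. The following is a minimal sketch that builds the `accelerate launch` command and checks that the config file exists first; `build_launch_command` and `launch` are hypothetical helper names, and the config path is assumed to be the file written in the configuration step.

```python
import subprocess
from pathlib import Path


def build_launch_command(config_path, module="axolotl.cli.train"):
    """Build the argument list for `accelerate launch -m <module> <config>`."""
    return ["accelerate", "launch", "-m", module, str(config_path)]


def launch(config_path, module="axolotl.cli.train"):
    """Run the command, failing early if the config file is missing."""
    if not Path(config_path).is_file():
        raise FileNotFoundError(f"Config not found: {config_path}")
    subprocess.run(build_launch_command(config_path, module), check=True)


# Show the command without running it.
print(" ".join(build_launch_command("specialised_train_flash.yml")))
# launch("specialised_train_flash.yml")  # uncomment to start training
```

Passing `module="axolotl.cli.merge_lora"` to the same helpers would cover the optional adapter merge.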