---
---
---
<a name="COMPILE"></a>
## C) Make `torch.compile` work without graph breaks for QLoRA [Difficulty: Easy to Medium] [Max points: 9]

1. Goal: Write a single Python script like task B), except the goal is to `torch.compile` all modules if possible.

2. There must NOT be graph breaks, and excessive re-compilations should not be seen.

3. You should have say max 30 compilations. Over 60 is definitely wrong.

4. The loss must match with the non compiled module.

5. Utilize patching as much as possible.

6. Think about which areas might need disabling for compilation. Think about regional compilation. How do we compile sections efficiently?

7. Log memory / VRAM usage, and monitor speedups as well.

8. Must work for QLoRA.

We provided a script below, and showcased how to detect if graph breaks are seen. We also torch compiled the MLP for Llama:

In [3]:
!pip install flash-attn
!pip install gputil

Collecting flash-attn
  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m55.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.7.4.post1-cp310-cp310-linux_x86_64.whl size=187797312 sha256=b267f80a08e516292cdd748056a2178a45b8abedf7fca123292eb17c21c8c87c
  Stored in directory: /root/.cache/pip/wheels/59/ce/d5/08ea07bfc16ba218dc65a3a7ef9b6a270530bcbd2cea2ee1ca
Successfully built flash-attn
Installing collected packages: flash-attn
Successfully installed flash-attn-2.7.4.post1
Collecting gputil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25

In [4]:
import torch
import os
import logging
import gc
import psutil
import time
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers.trainer_callback import TrainerCallback

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configure PyTorch and CUDA settings
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
    "expandable_segments:True,"\
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"

# Enable verbose logging for torch.compile
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCH_LOGS"] = "dynamic"
os.environ["TORCHINDUCTOR_VERBOSE"] = "1"

# Configure torch._dynamo settings for better compilation
torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.dynamic_shapes = True
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.optimize_ddp = False
torch._dynamo.config.replay_record_enabled = True

# Compilation counter for monitoring
compilation_counter = 0

def log_memory_usage():
    process = psutil.Process(os.getpid())
    logger.info(f"CPU Memory: {process.memory_info().rss / 1024 / 1024:.2f} MB")
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            logger.info(f"GPU {i} Memory: {torch.cuda.memory_allocated(i) / 1024 / 1024:.2f} MB / {torch.cuda.memory_reserved(i) / 1024 / 1024:.2f} MB")

# Configure torch.compile options with better defaults for QLoRA
torch_compile_options = {
    "mode": "reduce-overhead",
    "dynamic": True,
    "fullgraph": False,
    "options": {
        "epilogue_fusion": True,
        "max_autotune": True,
        "shape_padding": True,
        "triton.cudagraphs": False,
        "triton.unique_kernel_names": True,
        "trace.enabled": True,
        "trace.graph_diagram": True,
        "trace.action_set_logging": True
    }
}

def count_compilation():
    global compilation_counter
    compilation_counter += 1
    if compilation_counter > 30:
        logger.warning(f"Excessive compilations detected: {compilation_counter}")
    logger.info(f"Current compilation count: {compilation_counter}")

# Patch the forward methods
def patch_llama_modules():
    import transformers.models.llama.modeling_llama as llama_module
    from functools import partial
    
    def create_compiled_function(fn, name):
        def wrapped(*args, **kwargs):
            count_compilation()
            return fn(*args, **kwargs)
        return wrapped
    
    def mlp_forward(self, x):
        # Split operations to reduce graph complexity
        gate_out = self.gate_proj(x)
        up_out = self.up_proj(x)
        act_gate = self.act_fn(gate_out)
        prod = act_gate * up_out
        return self.down_proj(prod)
    
    def attention_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False):
        from torch.nn.attention import flex_attention
        
        bsz, q_len, _ = hidden_states.shape
        
        # Projections with 4-bit quantization
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
    
        # Reshape for flex_attention (bsz, num_heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
    
        past_key_value = (key_states, value_states) if use_cache else None
        
        # Use flex_attention with dynamic sequence length support
        attn_output = flex_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            is_causal=True,
        )
        
        # Restore original shape
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        
        return attn_output, None, past_key_value
    
    def rmsnorm_forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
    
    # Create compiled versions with proper error handling
    def safe_compile(fn, name):
        try:
            compiled_fn = torch.compile(
                fn,
                **torch_compile_options
            )
            return create_compiled_function(compiled_fn, name)
        except Exception as e:
            logger.warning(f"Could not compile {name}: {str(e)}")
            return fn

    # Apply patches with safe compilation
    logger.info("Applying module patches...")
    llama_module.LlamaMLP.forward = safe_compile(mlp_forward, "mlp_forward")
    llama_module.LlamaAttention.forward = safe_compile(attention_forward, "attention_forward")
    llama_module.LlamaRMSNorm.forward = safe_compile(rmsnorm_forward, "rmsnorm_forward")
    logger.info("Module patches applied successfully")

# Configure model parameters
max_seq_length = 1024
torch.set_default_dtype(torch.float16)
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16

# Log initial memory state
logger.info("Initial memory state:")
log_memory_usage()

# Configure BitsAndBytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)

# Load model with optimized settings
logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=bnb_config,
    torch_dtype=dtype,
)

# Load and configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# Configure LoRA
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Apply patches before LoRA
logger.info("Applying patches...")
patch_llama_modules()

# Setup LoRA
logger.info("Setting up LoRA...")
model = get_peft_model(model, lora_config)
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

model.enable_input_require_grads()

# Log memory after model setup
logger.info("Memory state after model setup:")
log_memory_usage()

# Load dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files={"train": url}, split="train[:10%]")

class MemoryTrackingCallback(TrainerCallback):
    def __init__(self):
        self.start_time = None
    
    def on_init_end(self, args, state, control, **kwargs):
        logger.info("Training initialization completed")
        log_memory_usage()
    
    def on_train_begin(self, args, state, control, **kwargs):
        logger.info("Training started")
        self.start_time = time.time()
        log_memory_usage()
    
    def on_step_begin(self, args, state, control, **kwargs):
        log_memory_usage()
        logger.info(f"Compilation count: {compilation_counter}")
    
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            logger.info(f"Step {state.global_step} completed")
            log_memory_usage()
    
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            # Add memory info to logs
            if torch.cuda.is_available():
                logs["gpu_memory"] = torch.cuda.memory_allocated() / 1024 / 1024
            logs["compilation_count"] = compilation_counter
            logger.info(f"Training stats: {logs}")
    
    def on_train_end(self, args, state, control, **kwargs):
        logger.info("Training completed")
        if self.start_time:
            total_time = time.time() - self.start_time
            logger.info(f"Total training time: {total_time:.2f} seconds")
        log_memory_usage()

# Configure training arguments
training_args = SFTConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    warmup_steps=1,
    max_steps=10,
    logging_steps=1,
    output_dir="outputs",
    max_seq_length=max_seq_length,
    fp16=model.get_input_embeddings().weight.dtype == torch.float16,
    bf16=model.get_input_embeddings().weight.dtype == torch.bfloat16,
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    seed=3407,
    report_to="none",
    dataloader_pin_memory=True,
    dataloader_num_workers=1,
)

# Initialize trainer with memory tracking
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    args=training_args,
    callbacks=[MemoryTrackingCallback()],
)

# Start training
logger.info("Starting training...")
trainer.train()

config.json:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


model.safetensors:   0%|          | 0.00/1.03G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.7k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

unified_chip2.jsonl:   0%|          | 0.00/95.6M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Converting train dataset to ChatML:   0%|          | 0/21029 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/21029 [00:00<?, ? examples/s]

Step,Training Loss
1,3.2979
2,5.2637
3,6.0405
4,6.154
5,4.1729
6,5.219
7,3.4887
8,2.4648
9,3.921
10,4.1579


TrainOutput(global_step=10, training_loss=4.4180361270904545, metrics={'train_runtime': 6.161, 'train_samples_per_second': 3.246, 'train_steps_per_second': 1.623, 'total_flos': 10592155496448.0, 'train_loss': 4.4180361270904545})

Log all your steps for debugging in a Colab (maybe this one). Edward's blog http://blog.ezyang.com/, Horace's blogs https://www.thonking.ai/, Slaying OOMs by Jane & Mark: ttps://www.youtube.com/watch?v=UvRl4ansfCg could be useful.

## Marking Criteria for C) Max points = 9
```python
if attemped_C:
    C_score = 0
    if uses_flex_attention:
        if dynamic_sequence_length_works: C_score += 3
        else: C_score += 1
    if no_torch_compile_BnB: C_score -= 2
    elif use_part_A: C_score += 1
    elif torch_compile_BnB: C_score += 1

    if attention_compiled:
        if excessive_recompilation: C_score -= 3
        else: C_score += 2
    if mlp_compiled:
        if excessive_recompilation: C_score -= 3
        C_score += 1

    if not loss_compiled: C_score -= 1
    if not layernorms_compiled: C_score -= 3

    if max_autotune_triton_matmul:
        if excessive_recompilation: C_score -= 2
        else: C_score += 2
    
    final_score += C_score
else:
    final_score -= 1
```