# Challenge C: torch.compile without Graph Breaks

This notebook checks for graph breaks and recompilations when running a QLoRA model under `torch.compile`.

In [None]:
!pip install "torch>=2.4" transformers peft bitsandbytes -U --quiet

In [None]:
import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Let compilation failures surface. With suppress_errors enabled, a frame that
# fails to compile silently runs in eager mode, which would make the
# graph-break diagnostics below look clean even though nothing was compiled.
torch._dynamo.config.suppress_errors = False
# NOTE: set torch._dynamo.config.verbose = True for fuller traces when debugging.

# Checkpoint to load; the name suggests it is already quantized to 4-bit
# with bitsandbytes (confirm on the model card).
model_name = "unsloth/llama-3-8b-bnb-4bit"

# 4-bit loading with float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# device_map="auto" delegates device placement to transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# Rank-16 LoRA adapters on the query and value projections only.
lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
# Rebinds `model` to the PEFT-wrapped model used by every later cell.
model = get_peft_model(model, lora_config)

# Challenge: Eliminate graph breaks
# We wrap the forward pass
@torch.compile(dynamic=True, fullgraph=False)
def compiled_forward(model, input_ids, labels):
    """Run one forward pass of ``model`` and return its loss.

    ``input_ids`` and ``labels`` are passed to ``model`` as keyword
    arguments; the ``loss`` attribute of whatever the model returns is
    handed back to the caller.
    """
    outputs = model(input_ids=input_ids, labels=labels)
    return outputs.loss

# Diagnostic check
# Build the dummy batch directly on the model's device rather than creating
# it on the CPU and copying it to the default CUDA device afterwards.
input_ids = torch.randint(0, 32000, (1, 128), device=model.device)
# Labels equal the inputs, so the model computes a language-modeling loss.
labels = input_ids.clone()

print("Running diagnostic...")
# explain(fn)(*args) is the supported call form; passing the arguments to
# explain() itself is deprecated.
explanation = torch._dynamo.explain(compiled_forward)(model, input_ids, labels)
print(f"Graph breaks: {explanation.graph_break_count}")
# The individual breaks live in `break_reasons`; the explain output has no
# `graph_breaks` attribute.
for break_reason in explanation.break_reasons:
    print(f"Reason: {break_reason.reason}")

# Run multiple iterations to check for recompilations
print("Checking recompilations...")
# Log each recompilation (with the guard that failed) so the loop below
# actually reports recompiles instead of only printing losses.
torch._logging.set_logs(recompiles=True)
for i in range(5):
    loss = compiled_forward(model, input_ids, labels)
    print(f"Iteration {i} loss: {loss.item()}")