# Fine-tuning Stable Diffusion with LoRA on TPUs using Unsloth

This notebook demonstrates how to fine-tune a pre-trained Stable Diffusion model on your own image-caption data using LoRA (Low-Rank Adaptation) with Unsloth's `FastDiffusionModel` and `DiffusionTrainer` on a TPU runtime.

**Key Steps:**
1. Setup: Install libraries and configure the TPU environment.
2. Check TPU: Confirm that a TPU runtime is available for distributed training.
3. Load Model: Use `FastDiffusionModel` to load a Stable Diffusion pipeline.
4. Prepare Dataset: Create or load an image-caption dataset and preprocess it.
5. Configure LoRA: Apply LoRA adapters to the UNet and define the training arguments.
6. Define Training Function: Create the function that each TPU core runs.
7. Launch Distributed Training: Use `DiffusionTrainer.launch_distributed` to train on all TPU cores.
8. Inference: Load the fine-tuned LoRA adapters and generate images.

## 1. Setup

**Ensure TPU Runtime:**
If you are using Google Colab, make sure to select a TPU runtime:
1. Go to `Runtime` -> `Change runtime type`.
2. Select `TPU` from the `Hardware accelerator` dropdown.

**Install Libraries:**

In [None]:
# Install necessary libraries
!pip install "unsloth[tpu] @ git+https://github.com/unslothai/unsloth.git" # Install unsloth with TPU extras
!pip install diffusers transformers accelerate datasets peft Pillow  # peft provides PeftModel, used for inference below

# It's crucial that torch_xla is installed. Unsloth's TPU extras should handle this.
# If you encounter issues, you might need to manually install a compatible torch_xla version:
# import os
# if os.environ.get('COLAB_TPU_ADDR'):
#   !pip install torch_xla cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.1-cp310-cp310-linux_x86_64.whl
# else:
#   print('Not running on a Colab TPU')

# Verify the torch_xla installation (optional).
# Only the version is checked here: with the PJRT runtime, initializing the XLA
# device in this parent process (e.g. via xm.xla_device()) can make the later
# xmp.spawn call fail with "Runtime is already initialized".
try:
    import torch_xla
    print(f"torch_xla version: {torch_xla.__version__}")
except ImportError:
    print("torch_xla not found. Please ensure it's installed correctly for TPU usage.")

**Import Libraries:**

In [None]:
import os
import torch
from datasets import Dataset
from PIL import Image
from io import BytesIO
import requests
import random

from unsloth import FastDiffusionModel, DiffusionTrainer, DiffusionTrainingArguments

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from torchvision import transforms
from transformers import CLIPTokenizer

## 2. Initialize TPU Distributed Training

A TPU host exposes several cores (for example, 8 on a v2-8 or v3-8), and distributed training runs one process per core, each working on a different slice of the data. Unsloth's `DiffusionTrainer` provides a `launch_distributed` method that simplifies this process. It uses `torch_xla.distributed.xla_multiprocessing.spawn` (`xmp.spawn`) to run the training function on all available TPU cores.

The `train_fn` we define later will encapsulate the training logic for a single process. `launch_distributed` will handle spawning this function across all TPU cores.
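`xmp.spawn` calls the target function once per process as `fn(index, *args)`, where `index` is the process ordinal and `args` is the tuple you pass in. That is why the training function defined later takes `index` as its first parameter. The sketch below imitates only this calling convention: `fake_spawn` and `train_fn` are hypothetical, everything runs sequentially in one process, and no TPU is involved.

```python
# Hypothetical stand-in for xmp.spawn, for illustration only: it runs the
# workers one after another in this process instead of one process per core.
def fake_spawn(fn, args=(), nprocs=8):
    """Call fn(index, *args) once per simulated core, mirroring xmp.spawn's convention."""
    results = []
    for index in range(nprocs):
        results.append(fn(index, *args))
    return results


def train_fn(index, message):
    # A real worker would call xm.xla_device() and train here; we only report the index.
    return f"process {index}: {message}"


outputs = fake_spawn(train_fn, args=("hello",), nprocs=4)
print(outputs[0])  # process 0: hello
```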

In [None]:
def check_tpu_availability():
    if 'COLAB_TPU_ADDR' in os.environ:
        print(f"TPU available. Address: {os.environ['COLAB_TPU_ADDR']}")
        print(f"Number of XLA devices: {xm.xrt_world_size()}")
        return True
    else:
        print("TPU not detected. This notebook is designed for TPU runtimes.")
        print("If on Colab, ensure Runtime > Change runtime type > TPU is selected.")
        # For local testing without a real TPU, one might mock xm, but launch_distributed won't work.
        return False

IS_TPU_AVAILABLE = check_tpu_availability()

## 3. Load Model and Tokenizer

We'll use `FastDiffusionModel.from_pretrained` to load a tiny pre-trained Stable Diffusion model. This method returns the UNet (as the main `model` object), the tokenizer, and the full diffusers pipeline. The other components (VAE, text encoder, and scheduler) are available as attributes of the pipeline, and the UNet returned by Unsloth also keeps references to them.
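Stable Diffusion's UNet denoises VAE latents rather than raw pixels, so the image size you train and sample at is divided by the VAE's downsampling factor. The sketch below computes the latent shape with the rule diffusers' Stable Diffusion pipelines use for `vae_scale_factor`, namely `2 ** (len(vae.config.block_out_channels) - 1)`. The helper `latent_shape` is hypothetical, and the default of 4 latent channels matches the SD 1.x/2.x VAEs; for the tiny test model, read `block_out_channels` and `latent_channels` from `vae.config` rather than assuming them.

```python
def latent_shape(image_size, block_out_channels, latent_channels=4):
    """Return the (C, H, W) shape of the VAE latent for a square image.

    The downsampling factor follows diffusers' vae_scale_factor rule:
    2 ** (len(block_out_channels) - 1).
    """
    scale = 2 ** (len(block_out_channels) - 1)
    if image_size % scale != 0:
        raise ValueError(f"image_size {image_size} must be divisible by {scale}")
    side = image_size // scale
    return (latent_channels, side, side)


# SD 1.x VAE: four blocks -> downsampling factor 8.
print(latent_shape(512, [128, 256, 512, 512]))  # (4, 64, 64)
```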

In [None]:
# Using a tiny model for quick demonstration
model_name = "hf-internal-testing/tiny-stable-diffusion-pipe"
# For a more standard model, you might use:
# model_name = "runwayml/stable-diffusion-v1-5" 
# model_name = "stabilityai/stable-diffusion-2-1-base"

unet, tokenizer, pipeline = FastDiffusionModel.from_pretrained(
    model_name_or_path=model_name,
    torch_dtype=torch.bfloat16, # bfloat16 is recommended for TPUs
)

# Access the other components. They are standard attributes of the diffusers
# pipeline; the UNet returned by FastDiffusionModel also references them.
vae = pipeline.vae
text_encoder = pipeline.text_encoder
scheduler = pipeline.scheduler

print(f"UNet type: {type(unet)}")
print(f"Tokenizer type: {type(tokenizer)}")
print(f"Pipeline type: {type(pipeline)}")
print(f"VAE type: {type(vae)}")
print(f"Text Encoder type: {type(text_encoder)}")
print(f"Scheduler type: {type(scheduler)}")

## 4. Prepare Dataset

For fine-tuning, we need a dataset of images and corresponding text captions. The `DiffusionTrainer` expects inputs in a dictionary format, typically including `pixel_values` (for images) and `input_ids` (for tokenized captions).

Here, we'll create a small dummy dataset for demonstration purposes. In a real scenario, you would load your own dataset (e.g., from Hugging Face Hub or local files).
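Because `collate_fn` later combines `input_ids` with `torch.stack`, every caption must become a sequence of exactly the same length; passing `padding="max_length", truncation=True` to the tokenizer guarantees that. The pure-Python sketch below models only this fixed-length outcome. `pad_or_truncate` and `PAD_ID = 0` are hypothetical stand-ins (the real pad id comes from `tokenizer.pad_token_id`), and real tokenizers additionally insert and preserve BOS/EOS tokens.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Simplified model of padding="max_length", truncation=True.

    Keeps at most max_length ids, then right-pads with pad_id. Real Hugging Face
    tokenizers also add special tokens and keep them when truncating.
    """
    ids = list(token_ids[:max_length])
    ids.extend([pad_id] * (max_length - len(ids)))
    return ids


PAD_ID = 0  # hypothetical; use tokenizer.pad_token_id in practice
batch = [[5, 17, 42], list(range(1, 100))]  # one short and one over-long caption
input_ids = [pad_or_truncate(ids, 77, PAD_ID) for ids in batch]
print([len(ids) for ids in input_ids])  # [77, 77]
```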

In [None]:
# Dummy dataset generation
num_samples = 64  # Divisible by (number of TPU cores * per-device batch size), e.g. 8 * 2 = 16
image_size = 64   # Very small images for the tiny test model
                  # For runwayml/stable-diffusion-v1-5 use 512
                  # For stabilityai/stable-diffusion-2-1-base use 512 (the non-base 2-1 model uses 768)

random.seed(0)  # Reproducible dummy data

def generate_dummy_image_data(width, height, color):
    """Return a solid-color RGB image."""
    return Image.new('RGB', (width, height), color=color)

dataset_dict = {"image": [], "text": []}
colors = ["red", "blue", "green", "yellow"]
shapes = ["square", "circle", "triangle", "star"]

for _ in range(num_samples):
    color = random.choice(colors)
    shape = random.choice(shapes)
    # Solid-color images keep this a purely technical demo. The image color matches
    # the caption's color so the captions are at least partly accurate.
    # With a real model, use real images and captions.
    img = generate_dummy_image_data(image_size, image_size, color)
    dataset_dict["image"].append(img)
    dataset_dict["text"].append(f"A photo of a {color} {shape}")

hf_dataset = Dataset.from_dict(dataset_dict)

# Preprocessing functions
image_transforms = transforms.Compose([
    transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]), # Normalize to [-1, 1]
])

def preprocess_dataset(examples):
    images = [image_transforms(image.convert("RGB")) for image in examples["image"]]
    captions = examples["text"]
    
    # Tokenize captions
    # Max length should be based on tokenizer's model_max_length
    # For tiny-sd, it might be small. For SD 1.5/2.1, it's usually 77.
    max_len = tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') else 77
    inputs = tokenizer(
        captions, max_length=max_len, padding="max_length", 
        truncation=True, return_tensors="pt"
    )
    
    return {"pixel_values": images, "input_ids": inputs.input_ids}


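# Apply preprocessing
# This is done on the main process before distributing data
processed_dataset = hf_dataset.map(
    function=preprocess_dataset,
    batched=True,
    remove_columns=["image", "text"]
)
# map() stores outputs as Python lists; return them as torch tensors so that
# collate_fn can torch.stack them and the .shape checks below work.
processed_dataset.set_format(type="torch", columns=["pixel_values", "input_ids"])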

# Define the data collator (stacks tensors)
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}

print(f"Processed dataset features: {processed_dataset.features}")
print(f"Example pixel_values shape: {processed_dataset[0]['pixel_values'].shape}")
print(f"Example input_ids shape: {processed_dataset[0]['input_ids'].shape}")
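The `ToTensor` + `Normalize([0.5], [0.5])` pair in `image_transforms` maps 8-bit pixel values into [-1, 1], the input range the Stable Diffusion VAE works with and the convention used in diffusers' training examples. Per channel, `ToTensor` computes x / 255 and `Normalize` then computes (x - 0.5) / 0.5. The helper names below are hypothetical; `from_model_range` is the inverse you would apply to turn a decoded image back into displayable pixel values.

```python
def to_model_range(pixel):
    """Map an 8-bit channel value to [-1, 1], as ToTensor + Normalize([0.5], [0.5]) do."""
    x = pixel / 255.0        # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - 0.5) / 0.5   # Normalize: [0.0, 1.0] -> [-1.0, 1.0]


def from_model_range(value):
    """Inverse mapping from [-1, 1] back to [0, 255], clamped and rounded for display."""
    x = (value + 1.0) / 2.0
    return round(min(max(x, 0.0), 1.0) * 255)


# A pure red pixel (255, 0, 0) becomes (1.0, -1.0, -1.0).
print([to_model_range(c) for c in (255, 0, 0)])  # [1.0, -1.0, -1.0]
```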

## 5. Configure LoRA and Training

Now we apply LoRA to the UNet with `FastDiffusionModel.get_peft_model`, which wraps the UNet so that only the low-rank adapter weights are trained. Then we define `DiffusionTrainingArguments`. The `DiffusionTrainer` itself is instantiated later, inside the per-core training function.
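To see why LoRA trains so few parameters: for a frozen linear weight W of shape d_out × d_in, LoRA learns B (d_out × r) and A (r × d_in), so the trainable count per layer drops from d_out·d_in to r·(d_in + d_out). PEFT scales the adapter's contribution by `lora_alpha / r`, which is 32 / 16 = 2 with the settings below. The helper is hypothetical and the layer sizes are chosen for illustration (320 and 1280 are channel widths that appear in the SD 1.x UNet); with `bias="none"`, biases stay frozen and are ignored here.

```python
def lora_param_counts(d_out, d_in, r):
    """Return (full, lora) trainable parameter counts for one linear layer.

    Full fine-tuning updates the d_out x d_in weight; LoRA freezes it and
    trains B (d_out x r) and A (r x d_in) instead.
    """
    full = d_out * d_in
    lora = r * (d_in + d_out)
    return full, lora


full, lora = lora_param_counts(320, 320, r=16)
print(full, lora, f"{lora / full:.1%}")  # 102400 10240 10.0%
```

The saving grows with layer width: at 1280 × 1280 and r=16 the adapter holds 40,960 parameters versus 1,638,400, about 2.5%.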

In [None]:
# Apply LoRA to the UNet
# The 'unet' object obtained from FastDiffusionModel.from_pretrained is the UNet itself.
unet_lora = FastDiffusionModel.get_peft_model(
    unet, # Pass the UNet model
    r=16, # LoRA rank
    lora_alpha=32, # LoRA alpha
    target_modules=None, # Let Unsloth determine default UNet targets or specify manually
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing=True, # Recommended for large models
    random_state=42
)

print(f"UNet with LoRA type: {type(unet_lora)}")
unet_lora.print_trainable_parameters() # PEFT model utility

# Define Training Arguments
output_dir = "./tpu_diffusion_finetuned_lora"
training_args = DiffusionTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=2, # Small number of epochs for demo
    per_device_train_batch_size=2, # Adjust based on TPU memory and image size
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    logging_dir=f"{output_dir}/logs",
    logging_steps=10,
    save_steps=50, # How often to save checkpoints (if desired, not critical for LoRA only)
    save_total_limit=2,
    dataloader_num_workers=2, # If using torch Dataloader
    tpu_num_cores=xm.xrt_world_size() if IS_TPU_AVAILABLE else None, # Specify number of TPU cores
    # bf16=True, # Handled by dtype in from_pretrained for model, trainer will use what model uses
    report_to="none", # Disable wandb/tensorboard for this demo
    remove_unused_columns=False, # Ensure all columns needed by compute_loss are kept
)

# Note: DiffusionTrainer is instantiated inside the train_fn for distributed training
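It is worth checking how many optimizer steps these arguments actually produce. With data parallelism, one optimizer step consumes `per_device_train_batch_size` × number of cores × `gradient_accumulation_steps` samples. The sketch assumes an 8-core TPU (such as a v2-8 or v3-8) and that `DiffusionTrainer` counts steps like the Hugging Face `Trainer`; `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(num_samples, per_device_batch, num_cores, grad_accum, epochs):
    """Return (global batch size, steps per epoch, total optimizer steps).

    Assumes each core gets an equal, non-overlapping shard and that a partial
    final batch still counts as a step (ceiling division).
    """
    global_batch = per_device_batch * num_cores * grad_accum
    steps_per_epoch = math.ceil(num_samples / global_batch)
    return global_batch, steps_per_epoch, steps_per_epoch * epochs


# Values from this notebook, assuming 8 TPU cores.
print(training_schedule(64, 2, 8, 1, 2))  # (16, 4, 8)
```

Under these assumptions the demo runs only 8 optimizer steps, so with `warmup_steps=10` the cosine schedule never finishes warming up, and neither `logging_steps=10` nor `save_steps=50` is ever reached. For a run this short, smaller values are worth considering.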

## 6. Fine-tuning Function

This function will be executed on each TPU core. It sets up the `DiffusionTrainer` and starts the training process.
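The notebook does not show what `DiffusionTrainer` computes at each step. The sketch below assumes it uses the standard DDPM noise-prediction (epsilon) objective, which diffusers' text-to-image training example also uses when the scheduler's `prediction_type` is `"epsilon"`: encode images to latents `x0`, sample noise and a timestep `t`, form `x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, and minimize the mean squared error between the true noise and the UNet's prediction. Plain lists stand in for latent tensors, a single `alpha_bar` stands in for `scheduler.alphas_cumprod[t]`, and the two toy predictors are hypothetical (a real UNet also receives `t` and the text embeddings).

```python
import math
import random


def add_noise(x0, noise, alpha_bar_t):
    """DDPM forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * n for x, n in zip(x0, noise)]


def mse(pred, target):
    """Mean squared error between two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def noise_prediction_loss(predict_noise, x0, alpha_bar_t, rng):
    """Sample noise, corrupt x0, and score the model's estimate of that noise."""
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = add_noise(x0, noise, alpha_bar_t)
    return mse(predict_noise(x_t), noise)


latents = [0.3, -1.2, 0.8, 0.05]  # stand-in for flattened VAE latents
alpha_bar = 0.5                   # stand-in for scheduler.alphas_cumprod[t]


def zero_model(x_t):
    """An untrained 'model' that always predicts zero noise."""
    return [0.0] * len(x_t)


def oracle(x_t):
    """Inverts add_noise using the true latents (only possible in a toy check)."""
    return [(xt - math.sqrt(alpha_bar) * x) / math.sqrt(1.0 - alpha_bar)
            for xt, x in zip(x_t, latents)]


print(noise_prediction_loss(zero_model, latents, alpha_bar, random.Random(0)))
print(noise_prediction_loss(oracle, latents, alpha_bar, random.Random(0)))  # close to 0
```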

In [None]:
def training_function(index, unet_lora_model, tokenizer_obj, vae_obj, scheduler_obj, 
                        train_ds, collate_fn_obj, training_args_obj):
    """The main training function to be executed on each TPU core."""
    
    # Set the XLA device for this process
    device = xm.xla_device()
    print(f"Process {index} using device: {device}")

    # Move models to the current TPU core's device. The trainer may also move the
    # UNet itself, but the VAE and text encoder are passed separately.
    unet_lora_model.to(device)
    vae_obj.to(device)
    text_encoder_obj = unet_lora_model.text_encoder  # text_encoder is attached to the UNet here
    if text_encoder_obj is not None:
        text_encoder_obj.to(device)

    # Instantiate the trainer
    trainer = DiffusionTrainer(
        model=unet_lora_model, # This is the UNet with LoRA adapters
        args=training_args_obj,
        train_dataset=train_ds,
        tokenizer=tokenizer_obj, # CLIP Tokenizer
        data_collator=collate_fn_obj,
        # Pass other pipeline components needed for compute_loss
        text_encoder=text_encoder_obj,
        vae=vae_obj,
        scheduler=scheduler_obj,
    )
    
    print(f"Process {index}: Starting training...")
    trainer.train()
    
    # Wait for all processes to finish training before saving.
    # The tag names the rendezvous point, so it must be the same on every process.
    xm.rendezvous("train_done")  # Barrier
    
    # Save the LoRA adapters (only on the master process)
    if xm.is_master_ordinal():
        print(f"Process {index} (master): Saving LoRA model adapters...")
        # The model passed to trainer is unet_lora. save_pretrained will save LoRA adapters.
        # If full model was trained, trainer.save_model() would save the whole UNet.
        # For PEFT models, model.save_pretrained() saves adapters correctly.
        save_lora_dir = os.path.join(training_args_obj.output_dir, "lora_adapters")
        unet_lora_model.save_pretrained(save_lora_dir)
        print(f"Process {index} (master): LoRA adapters saved to {save_lora_dir}")
    
    xm.rendezvous("save_done")  # Barrier: same tag on every process
    print(f"Process {index}: Training and saving finished.")
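`xm.rendezvous(tag)` acts as a barrier: each process blocks until all processes reach the rendezvous with that name, so the tag should be the same fixed string on every process. The two barriers in `training_function` ensure that the master saves only after everyone has finished training, and that no process exits before the save completes. The thread-based sketch below reproduces that ordering on a CPU, with Python's `threading.Barrier` standing in for `xm.rendezvous` and `index == 0` standing in for `xm.is_master_ordinal()`; it is an analogy, not PyTorch/XLA code.

```python
import threading

NUM_WORKERS = 4
train_done = threading.Barrier(NUM_WORKERS)  # plays the role of xm.rendezvous("train_done")
save_done = threading.Barrier(NUM_WORKERS)   # plays the role of xm.rendezvous("save_done")
saved_by = []
lock = threading.Lock()


def worker(index):
    # ... each worker's training would happen here ...
    train_done.wait()        # no one saves until every worker has finished training
    if index == 0:           # analogue of xm.is_master_ordinal()
        with lock:
            saved_by.append(index)
    save_done.wait()         # no one exits until the master has saved


threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(saved_by)  # [0]
```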

## 7. Launch Distributed Training

We use `DiffusionTrainer.launch_distributed` (which is a static method calling `xmp.spawn`) to start the training on all TPU cores.
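Once launched, each core normally trains on its own shard of `processed_dataset`. The notebook does not show how `DiffusionTrainer` shards data, so the sketch below follows PyTorch's `DistributedSampler` with `shuffle=False` as a reference point: pad the index list (by repeating from the start) to a multiple of the world size, then give each rank every `world_size`-th index starting at its own rank. The helper `shard_indices` is hypothetical. With 64 samples on 8 cores, each core sees 8 samples, or 4 steps per epoch at a per-device batch size of 2.

```python
def shard_indices(num_samples, world_size, rank):
    """Indices one process reads, modeled on DistributedSampler(shuffle=False)."""
    total_size = -(-num_samples // world_size) * world_size  # round up to a multiple
    indices = list(range(num_samples))
    padding = total_size - num_samples
    indices += [indices[i % num_samples] for i in range(padding)]  # repeat from the start
    return indices[rank:total_size:world_size]


print(shard_indices(64, 8, 0))  # [0, 8, 16, 24, 32, 40, 48, 56]
```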

In [None]:
if IS_TPU_AVAILABLE:
    # Prepare arguments for the training function
    # Note: the models and dataset are passed as arguments. With xmp.spawn's default
    # 'spawn' start method they are pickled into every child process, which is slow for
    # large models and can fail for functions defined interactively in a notebook.
    # PyTorch/XLA's Colab examples call xmp.spawn(..., start_method='fork'); loading the
    # models inside the spawned function is another common option. This notebook assumes
    # launch_distributed handles the PEFT-wrapped UNet correctly.
    
    # Make sure components not part of unet_lora (like original VAE, scheduler) are passed
    # The text_encoder is part of the unet_lora object as unet_lora.text_encoder
    args_for_spawn = (unet_lora, tokenizer, vae, scheduler, 
                        processed_dataset, collate_fn, training_args)
    
    print("Launching distributed training on TPUs...")
    # DiffusionTrainer.launch_distributed calls xmp.spawn
    DiffusionTrainer.launch_distributed(training_function, args=args_for_spawn)
    print("Distributed training finished.")
else:
    print("Skipping distributed training as TPU is not available or configured.")

## 8. Inference with Fine-tuned Model

After training, the LoRA adapters are saved. Now, we'll load the original base model and apply these fine-tuned adapters to perform inference.
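Loading adapters with `PeftModel.from_pretrained` keeps LoRA as a separate low-rank path: for an adapted linear layer the output is `W x + (lora_alpha / r) * B (A x)`. PEFT's `merge_and_unload()` folds this into a single weight `W + (lora_alpha / r) * B A`, which gives the same outputs without the extra per-step matrix multiplications. The sketch checks that equivalence on a 2 × 3 toy layer with plain Python lists; `matmul`, `matvec`, and `merge_lora` are hypothetical helpers, and LoRA dropout is ignored because it is inactive in eval mode.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def merge_lora(w, b, a, alpha, r):
    """Return the merged weight W + (alpha / r) * B @ A."""
    delta = matmul(b, a)
    scale = alpha / r
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen 2 x 3 base weight
B = [[0.5], [-1.0]]                        # 2 x r, with r = 1
A = [[1.0, 2.0, 0.0]]                      # r x 3
alpha, r = 2.0, 1
x = [1.0, -1.0, 0.5]

# Unmerged: base output plus the scaled low-rank path.
base = matvec(W, x)
low_rank = matvec(B, matvec(A, x))
unmerged = [y + (alpha / r) * z for y, z in zip(base, low_rank)]

# Merged: a single matrix-vector product with the folded weight.
merged = matvec(merge_lora(W, B, A, alpha, r), x)
print(unmerged, merged)  # [1.0, 0.5] [1.0, 0.5]
```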

In [None]:
from peft import PeftModel
from diffusers import StableDiffusionPipeline # For loading base pipeline for inference
import matplotlib.pyplot as plt

def display_images(images, prompts, cols=2):
    rows = (len(images) + cols - 1) // cols
    plt.figure(figsize=(15, 5 * rows))
    for i, (image, prompt) in enumerate(zip(images, prompts)):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(image)
        plt.title(f"Prompt: {prompt}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

if IS_TPU_AVAILABLE: # Only run inference if training was attempted
    # Path where LoRA adapters were saved
    lora_adapter_path = os.path.join(output_dir, "lora_adapters")

    if not os.path.exists(lora_adapter_path):
        print(f"LoRA adapters not found at {lora_adapter_path}. Skipping inference.")
    else:
        print("Loading base model for inference...")
        # Load the original (non-LoRA) pipeline again and run inference on a single TPU core.
        # This branch only runs when IS_TPU_AVAILABLE is True, so the XLA device is the target.
        device_for_inference = xm.xla_device()

        # Load the base pipeline using diffusers' StableDiffusionPipeline for convenience
        base_pipeline = StableDiffusionPipeline.from_pretrained(
            model_name, 
            torch_dtype=torch.bfloat16, # Use bfloat16 for TPUs
            # Add other components if they were modified or need specific loading
        )
        base_pipeline.to(device_for_inference)
        unet_for_inference = base_pipeline.unet

        print(f"Loading LoRA adapters from {lora_adapter_path} into UNet...")
        # Load LoRA weights into the base UNet
        unet_with_lora = PeftModel.from_pretrained(unet_for_inference, lora_adapter_path)
        unet_with_lora = unet_with_lora.to(device_for_inference) # Ensure it's on device
        unet_with_lora.eval() # Set to evaluation mode

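        # Merge the LoRA weights into the base UNet (PEFT's merge_and_unload returns the
        # plain UNet with fused weights) and put it back into the pipeline
        base_pipeline.unet = unet_with_lora.merge_and_unload()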

        print("Generating images with fine-tuned LoRA model...")
        prompts = [
            "A photo of a red square", 
            "A drawing of a blue circle",
            "A painting of a green star", # Test generalization slightly
            "A photo of a yellow triangle"
        ]
        generated_images = []

        with torch.no_grad():
            for prompt in prompts:
                # For tiny-sd, output size is small. Adjust height/width if using a different base model.
                image = base_pipeline(prompt, num_inference_steps=20, height=image_size, width=image_size).images[0]
                generated_images.append(image)
        
        # Display images (requires matplotlib)
        print("Displaying generated images...")
        display_images(generated_images, prompts)
else:
    print("Skipping inference as training was not run on TPU.")

This concludes the example of fine-tuning a Stable Diffusion model with LoRA on TPUs using Unsloth. Remember that for real tasks, you'll need a larger, more diverse dataset and potentially more epochs and hyperparameter tuning.