In [1]:
import os
import sys
import subprocess
from google.colab import drive

# 1. Mount Google Drive
print("🔌 Mounting Google Drive...")
drive.mount('/content/drive')

# 2. Install Training Libraries
# We need 'peft' for LoRA and 'datasets' for MagicBrush
print("⏳ Installing libraries (this takes ~45s)...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                       "git+https://github.com/huggingface/diffusers.git",
                       "accelerate", "transformers", "datasets", "peft", "bitsandbytes"])
print("✅ Environment Ready.")

🔌 Mounting Google Drive...
Mounted at /content/drive
⏳ Installing libraries (this takes ~45s)...
✅ Environment Ready.


In [2]:
import torch
import os
import shutil
import gc
import datasets
from accelerate import Accelerator
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from transformers import CLIPTokenizer, CLIPTextModel
from datasets import load_dataset
from torchvision import transforms
from tqdm.auto import tqdm
from safetensors.torch import load_file

# ================= CONFIGURATION =================
# We save directly to Drive so we never lose progress
CHECKPOINT_DIR = "/content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints"
FINAL_OUTPUT_DIR = "/content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/lora_instruction_tuned"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

MAX_STEPS = 2000
SAVE_INTERVAL = 10
CLEANUP_INTERVAL = 30
# =================================================

def cleanup_disk():
    """Frees up disk space by clearing caches."""
    print("🧹 Cleaning up disk space...")
    try:
        hf_cache = "/root/.cache/huggingface/datasets"
        if os.path.exists(hf_cache):
            shutil.rmtree(hf_cache)
            os.makedirs(hf_cache, exist_ok=True)
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"   ⚠️ Cleanup warning: {e}")

print("🚀 Starting Robust LoRA Training (500 Steps)...")

# 1. Setup Accelerator
accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="fp16")
device = accelerator.device
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# 2. Load MagicBrush
print("📥 Loading MagicBrush...")
dataset = load_dataset("osunlp/MagicBrush", split="train[:800]")

# 3. Load Models
print("📦 Loading Stable Diffusion...")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
scheduler = DDPMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

# 4. Inject LoRA
print("💉 Injecting LoRA Adapters...")
lora_config = LoraConfig(
    r=16, # Higher rank for better capacity
    lora_alpha=32,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    init_lora_weights="gaussian"
)
unet = get_peft_model(unet, lora_config)

# 5. Resume Logic
start_step = 0
checkpoints = sorted([int(x.split("_")[1]) for x in os.listdir(CHECKPOINT_DIR) if x.startswith("step_")])
if checkpoints:
    latest_step = checkpoints[-1]
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{latest_step}")
    print(f"🔄 Resuming from checkpoint: {ckpt_path}")
    state_dict = load_file(os.path.join(ckpt_path, "adapter_model.safetensors"))
    set_peft_model_state_dict(unet, state_dict)
    start_step = latest_step
    print(f"   ✅ Resumed at Step {start_step}")
else:
    print("✨ Starting fresh training.")

unet.print_trainable_parameters()

# 6. Optimizer
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
unet, optimizer, text_encoder, vae = accelerator.prepare(
    unet, optimizer, text_encoder, vae
)

# 7. Training Loop
unet.train()
transform = transforms.Compose([
    transforms.Resize((512,512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5])
])

print(f"🔥 Training from Step {start_step} to {MAX_STEPS}...")
progress_bar = tqdm(range(start_step, MAX_STEPS))
data_iter = iter(dataset)

for global_step in range(start_step, MAX_STEPS):
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(dataset)
        batch = next(data_iter)

    with accelerator.accumulate(unet):
        target_pixels = transform(batch['target_img'].convert("RGB")).unsqueeze(0).to(device)
        latents = vae.encode(target_pixels).latent_dist.sample() * 0.18215

        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=device).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        tokens = tokenizer(batch['instruction'], padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
        encoder_hidden_states = text_encoder(tokens)[0]

        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = torch.nn.functional.mse_loss(model_pred, noise, reduction="mean")
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    progress_bar.update(1)
    progress_bar.set_description(f"Loss: {loss.item():.4f}")

    if (global_step + 1) % SAVE_INTERVAL == 0:
        save_path = os.path.join(CHECKPOINT_DIR, f"step_{global_step + 1}")
        print(f"\n💾 Saving checkpoint to: {save_path}")
        unet.save_pretrained(save_path)

    if (global_step + 1) % CLEANUP_INTERVAL == 0:
        cleanup_disk()

print(f"\n🎉 Training Complete! Saving final model to: {FINAL_OUTPUT_DIR}")
unet.save_pretrained(FINAL_OUTPUT_DIR)
print("✅ Done.")

Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.


🚀 Starting Robust LoRA Training (500 Steps)...
📥 Loading MagicBrush...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/51 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/51 [00:00<?, ?files/s]

data/train-00000-of-00051-9fd9f23e2b1cb3(…):   0%|          | 0.00/506M [00:00<?, ?B/s]

data/train-00001-of-00051-7fc041c6c75c7a(…):   0%|          | 0.00/493M [00:00<?, ?B/s]

data/train-00002-of-00051-4683be833ed06f(…):   0%|          | 0.00/480M [00:00<?, ?B/s]

data/train-00003-of-00051-dbbf01b8d47907(…):   0%|          | 0.00/533M [00:00<?, ?B/s]

data/train-00004-of-00051-d373bd832c3222(…):   0%|          | 0.00/491M [00:00<?, ?B/s]

data/train-00005-of-00051-261c379dee1873(…):   0%|          | 0.00/498M [00:00<?, ?B/s]

data/train-00006-of-00051-1601e4998b8705(…):   0%|          | 0.00/494M [00:00<?, ?B/s]

data/train-00007-of-00051-09790524710489(…):   0%|          | 0.00/505M [00:00<?, ?B/s]

data/train-00008-of-00051-45475d9537033a(…):   0%|          | 0.00/528M [00:00<?, ?B/s]

data/train-00009-of-00051-cf6ff6a53b3552(…):   0%|          | 0.00/512M [00:00<?, ?B/s]

data/train-00010-of-00051-d73c61985d7081(…):   0%|          | 0.00/522M [00:00<?, ?B/s]

data/train-00011-of-00051-8b4c59c53a36a8(…):   0%|          | 0.00/481M [00:00<?, ?B/s]

data/train-00012-of-00051-7a326334a53d22(…):   0%|          | 0.00/500M [00:00<?, ?B/s]

data/train-00013-of-00051-7a32a843bae55f(…):   0%|          | 0.00/494M [00:00<?, ?B/s]

data/train-00014-of-00051-34fbf4e78f18f5(…):   0%|          | 0.00/499M [00:00<?, ?B/s]

data/train-00015-of-00051-b0616f1aa69cc3(…):   0%|          | 0.00/503M [00:00<?, ?B/s]

data/train-00016-of-00051-2ce592a2aa4c0b(…):   0%|          | 0.00/519M [00:00<?, ?B/s]

data/train-00017-of-00051-29b16824f54f17(…):   0%|          | 0.00/504M [00:00<?, ?B/s]

data/train-00018-of-00051-b8a3cc2bebe485(…):   0%|          | 0.00/491M [00:00<?, ?B/s]

data/train-00019-of-00051-7496aa939d5aaf(…):   0%|          | 0.00/530M [00:00<?, ?B/s]

data/train-00020-of-00051-0e4f76bd7de39d(…):   0%|          | 0.00/520M [00:00<?, ?B/s]

data/train-00021-of-00051-aea46908b3d257(…):   0%|          | 0.00/492M [00:00<?, ?B/s]

data/train-00022-of-00051-3a7846cdc795c3(…):   0%|          | 0.00/511M [00:00<?, ?B/s]

data/train-00023-of-00051-581ce36816474c(…):   0%|          | 0.00/490M [00:00<?, ?B/s]

data/train-00024-of-00051-63f85f8694db99(…):   0%|          | 0.00/509M [00:00<?, ?B/s]

data/train-00025-of-00051-eb6b4b387abb3b(…):   0%|          | 0.00/500M [00:00<?, ?B/s]

data/train-00026-of-00051-a37258445ad77c(…):   0%|          | 0.00/503M [00:00<?, ?B/s]

data/train-00027-of-00051-0c3caef58833e3(…):   0%|          | 0.00/487M [00:00<?, ?B/s]

data/train-00028-of-00051-a864102cdd7e0d(…):   0%|          | 0.00/519M [00:00<?, ?B/s]

data/train-00029-of-00051-d4b5816a0785e2(…):   0%|          | 0.00/493M [00:00<?, ?B/s]

data/train-00030-of-00051-d57316916f4cce(…):   0%|          | 0.00/511M [00:00<?, ?B/s]

data/train-00031-of-00051-baa9fa3e29dfb8(…):   0%|          | 0.00/489M [00:00<?, ?B/s]

data/train-00032-of-00051-cfc5f479ca5625(…):   0%|          | 0.00/515M [00:00<?, ?B/s]

data/train-00033-of-00051-80e659de48bfed(…):   0%|          | 0.00/515M [00:00<?, ?B/s]

data/train-00034-of-00051-d5a8a32783d7b7(…):   0%|          | 0.00/500M [00:00<?, ?B/s]

data/train-00035-of-00051-6df799054d7c06(…):   0%|          | 0.00/492M [00:00<?, ?B/s]

data/train-00036-of-00051-3c85ce9e4996c7(…):   0%|          | 0.00/489M [00:00<?, ?B/s]

data/train-00037-of-00051-24003f9a1733a2(…):   0%|          | 0.00/501M [00:00<?, ?B/s]

data/train-00038-of-00051-a035b244c8b6e2(…):   0%|          | 0.00/494M [00:00<?, ?B/s]

data/train-00039-of-00051-0a94fb0d0e1e35(…):   0%|          | 0.00/508M [00:00<?, ?B/s]

data/train-00040-of-00051-3770c96bedf6c3(…):   0%|          | 0.00/499M [00:00<?, ?B/s]

data/train-00041-of-00051-45bf2e58112437(…):   0%|          | 0.00/489M [00:00<?, ?B/s]

data/train-00042-of-00051-494734a9b0704e(…):   0%|          | 0.00/506M [00:00<?, ?B/s]

data/train-00043-of-00051-340358803f3655(…):   0%|          | 0.00/485M [00:00<?, ?B/s]

data/train-00044-of-00051-12cc4bb9cdfcfd(…):   0%|          | 0.00/498M [00:00<?, ?B/s]

data/train-00045-of-00051-ac7a4fc63484f5(…):   0%|          | 0.00/495M [00:00<?, ?B/s]

data/train-00046-of-00051-caf11c5fbe9bb1(…):   0%|          | 0.00/523M [00:00<?, ?B/s]

data/train-00047-of-00051-a99476a212efcb(…):   0%|          | 0.00/494M [00:00<?, ?B/s]

data/train-00048-of-00051-9aa83b8e4abb3e(…):   0%|          | 0.00/481M [00:00<?, ?B/s]

data/train-00049-of-00051-e24aa47c28573c(…):   0%|          | 0.00/514M [00:00<?, ?B/s]

data/train-00050-of-00051-7a5506bb37822b(…):   0%|          | 0.00/500M [00:00<?, ?B/s]

data/dev-00000-of-00004-f147d414270a90e1(…):   0%|          | 0.00/387M [00:00<?, ?B/s]

data/dev-00001-of-00004-8ef3de1dc8cb8a6a(…):   0%|          | 0.00/380M [00:00<?, ?B/s]

data/dev-00002-of-00004-54c4d7b0a9e49db5(…):   0%|          | 0.00/374M [00:00<?, ?B/s]

data/dev-00003-of-00004-384b81a61c93b7e3(…):   0%|          | 0.00/380M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8807 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/528 [00:00<?, ? examples/s]

📦 Loading Stable Diffusion...


tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

text_encoder/model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

unet/diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

💉 Injecting LoRA Adapters...
🔄 Resuming from checkpoint: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_540
   ✅ Resumed at Step 540
trainable params: 3,188,736 || all params: 862,709,700 || trainable%: 0.3696
🔥 Training from Step 540 to 2000...


  0%|          | 0/1460 [00:00<?, ?it/s]


💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_550

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_560

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_570
🧹 Cleaning up disk space...

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_580

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_590

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_600
🧹 Cleaning up disk space...

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Natural-Language-Constraints/checkpoints/step_610

💾 Saving checkpoint to: /content/drive/MyDrive/Projects/Image-Editing-by-Na