# Zellige Generation with a Stable Diffusion Model (LoRA Fine-Tuning)

In [None]:
# ---------------------------
#   SECTION 1 — INSTALL LIBS
# ---------------------------
!pip install -q diffusers==0.30.0 transformers accelerate safetensors datasets torchvision
# PEFT backs diffusers' LoRA adapters (unet.add_adapter) used in Section 6
!pip install -q peft

import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import math

from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_cosine_schedule_with_warmup
from transformers import AutoTokenizer
from safetensors.torch import save_file

# The rest of the notebook loads the pipeline in fp16 and moves every tensor to
# "cuda", so a GPU is mandatory. Fail fast with an actionable message instead of
# an obscure error deep inside model loading.
if not torch.cuda.is_available():
    raise RuntimeError(
        "No CUDA GPU available - enable a GPU runtime "
        "(Runtime > Change runtime type) before running this notebook."
    )
device = "cuda"
print("Using:", device)


In [None]:
# ---------------------------
#  SECTION 2 — MOUNT DRIVE
# ---------------------------
from google.colab import drive

# Mount Google Drive at /content/drive; the dataset and output folders below
# both live under it.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# CHANGE DATASET PATH HERE
DATASET_DIR = "/content/drive/MyDrive/ZelligeDataset/train"

# Destination for the trained LoRA weights; created if it does not exist yet.
OUTPUT_DIR = "/content/drive/MyDrive/Zellige_LoRA"
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# ---------------------------
#  SECTION 3 — TRAIN CONFIG
# ---------------------------
class TrainConfig:
    """Paths and hyper-parameters for the Zellige LoRA fine-tuning run."""

    # Base checkpoint. The former "runwayml/stable-diffusion-v1-5" repository was
    # removed from the Hugging Face Hub; this is the maintained mirror of the
    # same Stable Diffusion 1.5 weights.
    model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

    # Data locations (defined in the Drive-mount cell).
    train_data_dir = DATASET_DIR
    output_dir = OUTPUT_DIR

    # Data / batching.
    image_size = 512  # side length of the square images fed to the VAE
    batch_size = 2

    # Optimisation.
    num_epochs = 10
    lr = 1e-4
    lr_warmup_steps = 200  # warm-up steps of the cosine LR schedule
    gradient_accumulation_steps = 1
    max_train_steps = None

    # LoRA.
    lora_rank = 4

config = TrainConfig()


In [None]:
# ---------------------------
#  SECTION 4 — DATASET CLASS
# ---------------------------
from torchvision import transforms

# Per-image preprocessing: resize to a config.image_size square (the aspect
# ratio is not preserved), convert to a float tensor in [0, 1], then map it to
# [-1, 1] via (x - 0.5) / 0.5.
_target_hw = (config.image_size, config.image_size)
transform = transforms.Compose(
    [
        transforms.Resize(_target_hw),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

class ZelligeDataset(Dataset):
    """Folder of Zellige images, each paired with the same fixed caption.

    Args:
        folder: Directory scanned (non-recursively) for ``.jpg``/``.jpeg``/``.png``
            files (extension match is case-insensitive).
        tokenizer: Called as ``tokenizer(caption, padding="max_length",
            truncation=True, max_length=tokenizer.model_max_length,
            return_tensors="pt")``; must expose ``model_max_length``.
        caption: Text prompt paired with every image. It is tokenized once at
            construction time.
        image_transform: Optional callable applied to each RGB PIL image. When
            ``None``, the module-level ``transform`` is used.

    Raises:
        ValueError: If ``folder`` contains no matching image files.
    """

    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self, folder, tokenizer, caption="Moroccan Zellige pattern, geometric mosaic",
                 image_transform=None):
        # os.listdir order is arbitrary; sorting gives a deterministic
        # index -> file mapping. Directories named like images are skipped.
        self.paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder, name))
        )
        if not self.paths:
            raise ValueError(f"No .jpg/.jpeg/.png images found in {folder!r}")

        self.tokenizer = tokenizer
        self.caption = caption
        self.image_transform = image_transform

        # The caption is identical for every sample, so tokenize it once here
        # instead of on every __getitem__ call.
        self._input_ids = tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        tf = self.image_transform if self.image_transform is not None else transform
        # The context manager closes the file handle; convert() loads the pixel
        # data before the file is closed.
        with Image.open(self.paths[idx]) as im:
            img = tf(im.convert("RGB"))

        return {
            "pixel_values": img,
            "input_ids": self._input_ids,
        }

# The tokenizer is loaded from the same checkpoint (config.model_id) as the
# text encoder in the model-loading cell, so token ids line up.
tokenizer = AutoTokenizer.from_pretrained(config.model_id, subfolder="tokenizer")

dataset = ZelligeDataset(config.train_data_dir, tokenizer)
train_dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=config.batch_size,
)

print("Number of training images:", len(dataset))


In [None]:
# ---------------------------
#  SECTION 5 — LOAD SD MODEL
# ---------------------------
# Load the pretrained pipeline in half precision and move it to the GPU.
# safety_checker=None skips loading the safety checker, which this notebook
# never uses.
pipe = StableDiffusionPipeline.from_pretrained(
    config.model_id,
    safety_checker=None,
    torch_dtype=torch.float16,
).to(device)

vae, unet, text_encoder = pipe.vae, pipe.unet, pipe.text_encoder

# Training noise schedule: a DDPM scheduler built from the pipeline's own
# scheduler configuration.
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Freeze every base weight; only the LoRA parameters added afterwards train.
for frozen_model in (vae, unet, text_encoder):
    frozen_model.requires_grad_(False)

In [None]:
# ---------------------------
#  SECTION 6 — ADD LORA LAYERS
# ---------------------------
def add_lora_attention(unet, rank=4, target_modules=("to_k", "to_q", "to_v", "to_out.0")):
    """Inject trainable LoRA adapters into the UNet's attention projections.

    Uses the PEFT backend (``unet.add_adapter``). The previous approach of
    building ``LoRAAttnProcessor`` objects had several problems:

    * It keyed them by bare module name, whereas ``set_attn_processor``
      expects ``"<module>.processor"`` keys.
    * It omitted ``cross_attention_dim``.
    * It left the new weights on the CPU.

    Args:
        unet: The diffusers UNet to modify in place.
        rank: LoRA rank ``r``. ``lora_alpha`` is set equal to it, so the
            adapter scale alpha / r is 1.0.
        target_modules: Names of the linear sub-modules that receive adapters.

    Returns:
        List of the trainable LoRA parameters, cast to float32.
    """
    from peft import LoraConfig  # requires the `peft` package

    unet.add_adapter(
        LoraConfig(
            r=rank,
            lora_alpha=rank,
            init_lora_weights="gaussian",
            target_modules=list(target_modules),
        )
    )

    # The base UNet was frozen earlier, so requires_grad now selects exactly
    # the adapter weights.
    params = [p for p in unet.parameters() if p.requires_grad]

    # PEFT matches adapters to the base layers' dtype (fp16 here). Keep the
    # trainable weights in fp32 so small AdamW updates are not lost to fp16
    # rounding.
    for p in params:
        p.data = p.data.float()
    return params


lora_params = add_lora_attention(unet, rank=config.lora_rank)
print("Trainable LoRA params:", sum(p.numel() for p in lora_params))

optimizer = torch.optim.AdamW(lora_params, lr=config.lr)

In [None]:
# ---------------------------
#  SECTION 7 — TRAINING SETUP
# ---------------------------
# Size the LR schedule in optimizer steps. With gradient accumulation, one
# optimizer step covers `gradient_accumulation_steps` batches (a trailing
# partial group still counts as a step, hence ceil). An explicit
# `max_train_steps` caps the total.
accum_steps = max(1, config.gradient_accumulation_steps)
num_steps_per_epoch = math.ceil(len(train_dataloader) / accum_steps)
total_steps = config.num_epochs * num_steps_per_epoch
if config.max_train_steps is not None:
    total_steps = min(total_steps, config.max_train_steps)

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=total_steps,
)


In [None]:
# ---------------------------
#  SECTION 8 — TRAIN LOOP
# ---------------------------
# Fine-tune the LoRA weights with the diffusion objective:
# 1. Encode images to latents.
# 2. Noise them at a random timestep.
# 3. Regress the UNet output onto the scheduler's training target.
global_step = 0
unet.train()

accum_steps = max(1, config.gradient_accumulation_steps)
# Read these from the configs instead of hard-coding them (the VAE scaling
# factor was previously hard-coded as 0.18215).
scaling_factor = vae.config.scaling_factor
num_train_timesteps = noise_scheduler.config.num_train_timesteps
prediction_type = noise_scheduler.config.prediction_type
if prediction_type not in ("epsilon", "v_prediction"):
    raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

# GradScaler multiplies the loss before backward so that small gradients
# flowing through the fp16 UNet do not underflow. It can only unscale fp32
# gradients, so it is enabled only when every trainable parameter is fp32.
trainable = [p for group in optimizer.param_groups for p in group["params"]]
scaler = torch.cuda.amp.GradScaler(
    enabled=bool(trainable) and all(p.dtype == torch.float32 for p in trainable)
)

stop_training = False
for epoch in range(config.num_epochs):
    print(f"\n--- Epoch {epoch+1}/{config.num_epochs} ---")
    num_batches = len(train_dataloader)

    for step, batch in enumerate(train_dataloader):
        # The frozen VAE was loaded in fp16, while the pixel tensors arrive as
        # float32. Cast them to the VAE's dtype to avoid a dtype-mismatch error.
        images = batch["pixel_values"].to(device, dtype=vae.dtype)
        input_ids = batch["input_ids"].to(device)

        # Encode images into latents
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * scaling_factor

        # Add noise at a random timestep per sample
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (latents.size(0),), device=device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Text embeddings (frozen encoder)
        with torch.no_grad():
            enc = text_encoder(input_ids)[0]

        # Training target depends on what the scheduler's model predicts
        if prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            target = noise

        # Predict
        noise_pred = unet(noisy_latents, timesteps, enc).sample

        # Loss in fp32 for numerical stability. Dividing by accum_steps makes
        # the accumulated gradient an average over the group's batches.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
        scaler.scale(loss / accum_steps).backward()

        # Take an optimizer step only at the end of each accumulation group
        # (or at the last batch of the epoch).
        if (step + 1) % accum_steps != 0 and step + 1 != num_batches:
            continue

        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        global_step += 1

        if global_step % 50 == 0:
            print(f"Step {global_step}/{total_steps} | Loss: {loss.item():.4f}")

        if config.max_train_steps is not None and global_step >= config.max_train_steps:
            stop_training = True
            break

    if stop_training:
        break

print("\nTraining completed!")

In [None]:
# ---------------------------
#  SECTION 9 — SAVE LORA
# ---------------------------
print("Saving LoRA...")

save_path = os.path.join(OUTPUT_DIR, "zellige_lora.safetensors")

if getattr(unet, "peft_config", None):
    # Adapters were injected with unet.add_adapter (PEFT). Export them in
    # diffusers' LoRA format, the one written by save_lora_weights and read
    # back by load_lora_weights.
    from diffusers.utils import convert_state_dict_to_diffusers
    from peft.utils import get_peft_model_state_dict

    unet_lora_state = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
    StableDiffusionPipeline.save_lora_weights(
        save_directory=OUTPUT_DIR,
        unet_lora_layers=unet_lora_state,
        weight_name="zellige_lora.safetensors",
        safe_serialization=True,
    )
else:
    # Legacy path: LoRA attention processors. Their weights are stored under
    # "<processor name>.<param name>" keys.
    lora_state = {}
    for name, module in unet.attn_processors.items():
        if isinstance(module, LoRAAttnProcessor):
            for k, v in module.state_dict().items():
                lora_state[f"{name}.{k}"] = v.cpu()
    # Refuse to write an empty file that would silently "load" as a no-op.
    if not lora_state:
        raise RuntimeError("No LoRA weights found on the UNet; nothing to save.")
    save_file(lora_state, save_path)

print("LoRA saved to:", save_path)