In [None]:
# --- Setup ---
!pip install diffusers transformers accelerate safetensors torch torchvision pillow


In [None]:

import os, time, math, json
import torch, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DDIMScheduler
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from safetensors.torch import save_file
from transformers import CLIPTokenizer, CLIPTextModel

# --- Dataset ---
class TextImageDataset(Dataset):
    def __init__(self, data_dir, resolution=512):
        self.images_dir = os.path.join(data_dir, "images")
        self.prompts_path = os.path.join(data_dir, "prompts.json")
        self.items = [json.loads(line) for line in open(self.prompts_path)]
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
    def __len__(self): return len(self.items)
    def __getitem__(self, idx):
        item = self.items[idx]
        img = Image.open(os.path.join(self.images_dir, item["image"])).convert("RGB")
        return {"pixel_values": self.transform(img), "prompt": item["prompt"]}

# --- LoRA helpers ---
def inject_lora_unet(unet, r=4):
    lora_cls = LoRAAttnProcessor2_0 if hasattr(torch, "compile") else LoRAAttnProcessor
    for _, module in unet.named_modules():
        if hasattr(module, "set_processor"):
            module.set_processor(lora_cls(r=r))
    return unet

def save_unet_lora(unet, save_path):
    state = {}
    for name, module in unet.named_modules():
        proc = getattr(module, "processor", None)
        if proc is None: continue
        for pname, param in module.named_parameters():
            if param.requires_grad:
                state[f"{name}.{pname}"] = param.detach().cpu()
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_file(state, save_path)
    print(f"Saved LoRA weights to {save_path}")
    
    



def encode_long_prompt(prompt, device):
    tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"].to(device)
    # Split into chunks of 77
    chunks = tokens[0].split(77)
    embeddings = []
    for chunk in chunks:
        emb = text_encoder(chunk.unsqueeze(0))[0]
        embeddings.append(emb)
    # Concatenate along sequence dimension
    final_emb = torch.cat(embeddings, dim=1)
    return final_emb

# --- Training loop ---
data_dir = "./data_out"   # your preprocessed dataset
out_dir  = "./sdxl_lora_out"
os.makedirs(out_dir, exist_ok=True)

accelerator = Accelerator(mixed_precision="fp16")
set_seed(42)

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
tokenizer = CLIPTokenizer.from_pretrained(pipe.text_encoder.config._name_or_path)
text_encoder = CLIPTextModel.from_pretrained(pipe.text_encoder.config._name_or_path)

vae, unet = pipe.vae, pipe.unet
vae.requires_grad_(False)
inject_lora_unet(unet, r=4)
lora_params = [p for p in unet.parameters() if p.requires_grad]

dataset = TextImageDataset(data_dir, resolution=512)
dl = DataLoader(dataset, batch_size=4, shuffle=True)

optimizer = torch.optim.AdamW(lora_params, lr=5e-5)
unet, optimizer, dl = accelerator.prepare(unet, optimizer, dl)
vae.to(accelerator.device)

vae_scale_factor = 0.18215
global_step, last_save = 0, time.time()

for epoch in range(3):   # adjust epochs
    for batch in dl:
        with accelerator.accumulate(unet):
            latents = vae.encode(batch["pixel_values"].to(accelerator.device)).latent_dist.sample()
            latents = latents * vae_scale_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=accelerator.device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            prompt_embeds = encode_long_prompt(batch["prompt"], accelerator.device)
            model_pred = unet(noisy_latents, timesteps,
                                 prompt_embeds=prompt_embeds).sample
            loss = F.mse_loss(model_pred, noise)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        global_step += 1
        # Save every 45 minutes
        if time.time() - last_save > 45*60:
            save_unet_lora(unet, os.path.join(out_dir, f"lora_step{global_step}.safetensors"))
            last_save = time.time()

save_unet_lora(unet, os.path.join(out_dir, "lora_final.safetensors"))


In [1]:
import os, time, math, json
import torch, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DDIMScheduler
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from safetensors.torch import save_file

# --- Dataset ---
class TextImageDataset(Dataset):
    def __init__(self, data_dir, resolution=512):
        self.images_dir = os.path.join(data_dir, "images")
        self.prompts_path = os.path.join(data_dir, "prompts.jsonl")
        self.items = [json.loads(line) for line in open(self.prompts_path)]
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
    def __len__(self): return len(self.items)
    def __getitem__(self, idx):
        item = self.items[idx]
        img = Image.open(os.path.join(self.images_dir, item["image"])).convert("RGB")
        return {"pixel_values": self.transform(img), "prompt": item["prompt"]}

# --- LoRA helpers ---
def inject_lora_unet(unet, r=4):
    lora_cls = LoRAAttnProcessor2_0 if hasattr(torch, "compile") else LoRAAttnProcessor
    for _, module in unet.named_modules():
        if hasattr(module, "set_processor"):
            module.set_processor(lora_cls(r=r))
    return unet

def save_unet_lora(unet, save_path):
    state = {}
    for name, module in unet.named_modules():
        proc = getattr(module, "processor", None)
        if proc is None: continue
        for pname, param in module.named_parameters():
            if param.requires_grad:
                state[f"{name}.{pname}"] = param.detach().cpu()
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_file(state, save_path)
    print(f"Saved LoRA weights to {save_path}")


2025-12-14 08:56:13.462422: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-14 08:56:13.475105: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-14 08:56:13.492108: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-14 08:56:13.498020: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-14 08:56:13.511674: I tensorflow/core/platform/cpu_feature_guar

In [None]:
# --- Training loop ---
data_dir = "./data_out"   # your preprocessed dataset
out_dir  = "./sdxl_lora_out"
os.makedirs(out_dir, exist_ok=True)

accelerator = Accelerator(mixed_precision="fp16")
set_seed(42)

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

vae, unet = pipe.vae, pipe.unet
vae.requires_grad_(False)
inject_lora_unet(unet, r=4)
lora_params = [p for p in unet.parameters() if p.requires_grad]

dataset = TextImageDataset(data_dir, resolution=512)
dl = DataLoader(dataset, batch_size=4, shuffle=True)

optimizer = torch.optim.AdamW(lora_params, lr=5e-5)
unet, optimizer, dl = accelerator.prepare(unet, optimizer, dl)
vae.to(accelerator.device)

vae_scale_factor = 0.18215
global_step, last_save = 0, time.time()

for epoch in range(3):   # adjust epochs
    for batch in dl:
        with accelerator.accumulate(unet):
            latents = vae.encode(batch["pixel_values"].to(accelerator.device)).latent_dist.sample()
            latents = latents * vae_scale_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=accelerator.device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            enc = pipe.encode_prompt(batch["prompt"], device=accelerator.device)
            model_pred = unet(noisy_latents, timesteps,
                              prompt_embeds=enc["prompt_embeds"],
                              pooled_prompt_embeds=enc["pooled_prompt_embeds"]).sample
            loss = F.mse_loss(model_pred, noise)
            accelerator.backward(loss)
            optimizer.step(); optimizer.zero_grad()

        global_step += 1
        # Save every 45 minutes
        if time.time() - last_save > 45*60:
            save_unet_lora(unet, os.path.join(out_dir, f"lora_step{global_step}.safetensors"))
            last_save = time.time()

save_unet_lora(unet, os.path.join(out_dir, "lora_final.safetensors"))

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


TypeError: LoRAAttnProcessor2_0.__init__() got an unexpected keyword argument 'r'

: 