In [1]:
!pip install diffusers accelerate torchvision transformers datasets ftfy tensorboard Jinja2 peft wandb bitsandbytes clip pillow==10.3.0



In [1]:
import logging
import math
import os
import time
import random
import shutil
import clip
from pathlib import Path
from types import SimpleNamespace

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, set_seed
from datasets import load_dataset
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from PIL import Image

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import (
    check_min_version,
    convert_state_dict_to_diffusers,
)
from diffusers.utils.import_utils import is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from transformers import CLIPTextModel, CLIPProcessor, CLIPModel, CLIPTextModelWithProjection # Added CLIPTextModelWithProjection
from torch.utils.tensorboard import SummaryWriter # Added for SummaryWriter

# Weights & Biases is optional: import it only when diffusers reports it installed.
if is_wandb_available():
    import wandb

# Require a minimum diffusers version for this notebook
# (check_min_version presumably raises if the installed version is older — confirm).
check_min_version("0.30.0.dev0")

# Module-level logger used by load_pretrained_model() and main() below.
logger = get_logger(__name__, log_level="INFO")


def save_lora_state(save_dir: str, unet, unwrap_model_fn):
    """Write the PEFT/LoRA adapter weights of ``unet`` to ``<save_dir>/lora_state.pt``.

    Args:
        save_dir: Target directory; created (with parents) if it does not exist.
        unet: The (possibly accelerate-wrapped) model carrying the LoRA adapter.
        unwrap_model_fn: Callable that strips wrappers (DDP, torch.compile, ...)
            before the adapter state dict is extracted.

    Returns:
        The path of the written file.
    """
    target_file = os.path.join(save_dir, "lora_state.pt")
    os.makedirs(save_dir, exist_ok=True)
    adapter_weights = get_peft_model_state_dict(unwrap_model_fn(unet))
    torch.save(adapter_weights, target_file)
    return target_file

In [2]:
# --- Optimisation / bookkeeping ---
# NOTE: the learning rate is not configured here; main() derives it as
# 1e-4 * gradient_accumulation_steps * train_batch_size * num_processes.
gradient_accumulation_steps = 4 # micro-batches per optimizer update (passed to Accelerator)
train_batch_size = 1 # per-process DataLoader batch size
num_train_epochs = 2
max_grad_norm = 4 # gradient-clipping threshold; NOTE(review): confirm this value actually reaches the trainer's clip_grad_norm_ call
checkpointing_steps = 2 # checkpoint interval; NOTE(review): only effective if main() passes it to StableDiffusionTrainer — confirm
checkpoints_total_limit = 5 # max checkpoint-* dirs kept; NOTE(review): same caveat as checkpointing_steps
logging_steps = 50 # TensorBoard logging interval (StableDiffusionTrainer's logging_steps argument)

# Model parameters
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" # Hub id or local path of the SDXL base model
revision = None # Hub revision forwarded to the tokenizer/model from_pretrained calls
variant = None # weight variant (e.g. "fp16") forwarded to the model from_pretrained calls
LOCAL_MODEL_PATH = "" # Base path for local models, e.g., "/content/models/" — main() concatenates it directly with source_model_path, so keep the trailing "/"

# LoRA loading/saving parameters
# If loading an existing LoRA model, specify the path to its directory;
# main() loads "pytorch_lora_weights.safetensors" from that directory.
# If starting training from scratch, leave empty.
source_model_path = "" # e.g., "my_finetuned_lora_weights_dir" or full path "/content/my_finetuned_lora_weights_dir"

# Dataset parameters
dataset_type = "huggingface" # Options: "huggingface" or "imagefolder" (any value other than "huggingface" takes the imagefolder branch)
dataset_name = "lambdalabs/naruto-blip-captions" # Required for dataset_type="huggingface"
dataset_config_name = None # Optional for dataset_type="huggingface"
dataset_path = "" # Required for dataset_type="imagefolder" (passed to load_dataset as data_dir)
image_column = "image" # only used for dataset_type="huggingface"; the imagefolder branch always uses "image"
caption_column = "text" # only used for dataset_type="huggingface"; the imagefolder branch always uses "text"

# Training parameters
resolution = 512 # Square training resolution: images are resized, then cropped to resolution x resolution
center_crop = False # True: center crop; False: random crop
random_flip = False # Whether to randomly flip images horizontally
max_train_samples = None # If set (> 0), train on a shuffled subset of this many samples; None uses all
dataloader_num_workers = 0 # Number of DataLoader worker processes
lr_warmup_steps = 5 # Warmup steps for the LR scheduler (main() multiplies this by the number of processes)
max_train_steps = None # Maximum number of optimizer updates (None or <= 0: derived from num_train_epochs)
lr_scheduler = "constant" # Scheduler name for diffusers get_scheduler; main() reads it via globals() because it shadows this name locally
use_8bit_adam = True # True: bitsandbytes AdamW8bit; False: torch.optim.AdamW
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-04
gradient_checkpointing = False # Whether to enable UNet gradient checkpointing

# Resume and output parameters
resume_from_checkpoint = None # Any truthy value resumes; NOTE(review): main() always picks the latest checkpoint-* in output_dir, so a specific checkpoint path is not honoured
output_dir = "sd-finetune-model" # Directory to save checkpoint-* subdirectories and the final model
logging_dir = "logs" # Intended TensorBoard log directory — not referenced in the cells shown here; TODO confirm

# Validation and logging
wandb_project = "sd-finetune" # Set to None to disable Weights & Biases logging; W&B trackers are only initialised when WANDB_API_KEY is also set
num_validation_images = 4 # Number of images to generate for validation (not referenced in the cells shown here — TODO confirm)

In [3]:
class StableDiffusionTrainer:
    """SDXL LoRA training loop built around an ``accelerate.Accelerator``.

    The caller builds (and ``accelerator.prepare``-s) the UNet, optimizer, LR
    scheduler and dataloader; this class runs the epoch/step loop, logs metrics to
    TensorBoard (and optionally Weights & Biases) and writes rotating checkpoints.

    ``global_step`` counts *optimizer updates* (one per
    ``gradient_accumulation_steps`` micro-batches). ``max_train_steps``,
    ``logging_steps`` and ``checkpointing_steps`` are all expressed in that unit.

    Optional keyword arguments (appended after the original positional ones, so
    existing callers keep working; defaults mirror the previous hard-coded values):
        max_grad_norm: gradient-clipping threshold for the trainable (LoRA) params.
        checkpointing_steps: save a checkpoint every N updates (falsy disables).
        checkpoints_total_limit: keep at most this many ``checkpoint-*`` dirs
            (None = unlimited).
        wandb_project: when set and ``WANDB_API_KEY`` exists, log each update to wandb.
        resolution: SDXL target size for ``time_ids`` (int or ``(h, w)``); ``None``
            uses the spatial size of each batch's ``pixel_values``.
    """

    def __init__(self, output_dir, num_train_epochs, max_train_steps, logging_steps, train_dataloader, vae, unet, text_encoder, text_encoder_2, noise_scheduler, optimizer, lr_scheduler, accelerator, weight_dtype, tensorboard_writer, logger, progress_bar, first_epoch, initial_global_step, unwrap_model,
                 max_grad_norm=1.0, checkpointing_steps=500, checkpoints_total_limit=3, wandb_project=None, resolution=None):
        self.output_dir = output_dir
        self.num_train_epochs = num_train_epochs
        self.max_train_steps = max_train_steps
        self.logging_steps = logging_steps
        self.train_dataloader = train_dataloader
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.noise_scheduler = noise_scheduler
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self.weight_dtype = weight_dtype
        self.tensorboard_writer = tensorboard_writer
        self.logger = logger
        self.global_step = initial_global_step
        self.progress_bar = progress_bar
        self.first_epoch = first_epoch
        self.metrics = {}
        self.unwrap_model = unwrap_model
        self.train_batch_size = getattr(self.accelerator.state, "train_batch_size", None)
        if self.train_batch_size is None:
            self.train_batch_size = train_dataloader.batch_size

        self.gradient_accumulation_steps = accelerator.gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.checkpointing_steps = checkpointing_steps
        self.checkpoints_total_limit = checkpoints_total_limit
        self.wandb_project = wandb_project
        self.resolution = resolution
        self._wandb_warned = False

    # Added for SDXL
    def compute_time_ids(self, original_size, crops_coords_top_left, resolution):
        """Build the SDXL ``time_ids`` micro-conditioning tensor.

        Each row is ``original_size`` (2 values) + ``crops_coords_top_left``
        (2 values) + the target size (2 values), cast to ``weight_dtype`` on the
        accelerator device.

        Args:
            original_size: (batch, 2) tensor.
            crops_coords_top_left: (batch, 2) tensor.
            resolution: int for a square target, or an ``(h, w)`` pair.
        """
        device = self.accelerator.device
        if isinstance(resolution, (tuple, list)):
            target_hw = [int(v) for v in resolution]
        else:
            target_hw = [resolution, resolution]
        target = torch.tensor(target_hw, device=device, dtype=original_size.dtype)
        target = target.unsqueeze(0).expand(original_size.shape[0], -1)
        add_time_ids = torch.cat(
            [original_size.to(device), crops_coords_top_left.to(device), target], dim=1
        )
        return add_time_ids.to(device, dtype=self.weight_dtype)

    def _training_step(self, batch, device):
        """Forward one micro-batch and return its scalar MSE diffusion loss (no backward)."""
        # Low-VRAM strategy: the VAE lives on the CPU between steps and is brought to
        # the accelerator device only while encoding.
        self.vae.to(device, dtype=torch.float32)
        pixel_values = batch["pixel_values"].to(device=device, dtype=torch.float32, non_blocking=True)

        with torch.no_grad():
            # Frozen modules run under autocast so they use fp16/bf16 when configured.
            with self.accelerator.autocast():
                # NOTE(review): under autocast the VAE convolutions still run in half
                # precision despite the float32 cast above; the stock SDXL VAE is known
                # to be unstable in fp16 — encode outside autocast if NaN losses appear.
                latents = self.vae.encode(pixel_values).latent_dist.sample()
                model_input = (latents * self.vae.config.scaling_factor).to(self.weight_dtype)
                self.vae.to("cpu")
                torch.cuda.empty_cache()

                # SDXL conditions on the *penultimate* hidden layer of both text
                # encoders (as StableDiffusionXLPipeline.encode_prompt does at
                # inference) plus the pooled projection of text_encoder_2.
                te1_out = self.text_encoder(batch["input_ids"].to(device), output_hidden_states=True)
                te2_out = self.text_encoder_2(
                    batch["input_ids_2"].to(device), output_hidden_states=True, return_dict=True
                )
                pooled_text_embeds = te2_out.text_embeds
                encoder_hidden_states = torch.cat(
                    [te1_out.hidden_states[-2], te2_out.hidden_states[-2]], dim=-1
                )

        # Sample noise and a random timestep per example.
        noise = torch.randn_like(model_input)
        bsz = model_input.shape[0]
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=model_input.device,
        ).long()
        noisy_model_input = self.noise_scheduler.add_noise(model_input, noise, timesteps)

        # Target size: explicit resolution if configured, else the crop size of this batch.
        configured = getattr(self, "resolution", None)
        target_size = configured if configured is not None else tuple(pixel_values.shape[-2:])
        add_time_ids = self.compute_time_ids(
            batch["original_sizes"].squeeze(1),
            batch["crops_coords_top_left"].squeeze(1),
            resolution=target_size,
        )

        with self.accelerator.autocast():
            model_pred = self.unet(
                noisy_model_input,
                timesteps,
                encoder_hidden_states,
                added_cond_kwargs={"text_embeds": pooled_text_embeds, "time_ids": add_time_ids},
                return_dict=False,
            )[0]
        torch.cuda.empty_cache()

        prediction_type = self.noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unsupported noise_scheduler prediction_type: {prediction_type!r}")

        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    def _current_lr(self):
        """Current learning rate of the first param group (0 when no scheduler is set)."""
        return self.lr_scheduler.get_last_lr()[0] if self.lr_scheduler is not None else 0

    def _postfix(self, loss):
        """Progress-bar postfix for the most recent micro-batch."""
        return {"step_loss": loss.detach().item(), "lr": self._current_lr()}

    def _reached_max_steps(self):
        """True once ``global_step`` hits ``max_train_steps`` (never when it is None)."""
        return self.max_train_steps is not None and self.global_step >= self.max_train_steps

    def train(self):
        """Run epochs until ``num_train_epochs`` or ``max_train_steps`` updates are reached."""
        grad_acc_steps = self.gradient_accumulation_steps
        device = self.accelerator.device
        # The trainable (LoRA) parameter set is fixed for the run; collect it once.
        trainable_params = [p for group in self.optimizer.param_groups for p in group["params"]]

        for epoch in range(self.first_epoch, self.num_train_epochs):
            self.unet.train()
            train_loss = 0.0  # loss accumulated over the current accumulation window
            current_loss = 0.0  # loss of the most recent optimizer update
            epoch_loss_sum = 0.0
            epoch_updates = 0

            for batch in self.train_dataloader:
                total_norm = None
                with self.accelerator.accumulate(self.unet):
                    loss = self._training_step(batch, device)

                    # Gather across processes (for logging only).
                    avg_loss = self.accelerator.gather(loss.detach().repeat(self.train_batch_size)).mean()
                    train_loss += avg_loss.item() / grad_acc_steps

                    self.accelerator.backward(loss)

                    if self.accelerator.sync_gradients:
                        total_norm = self.accelerator.clip_grad_norm_(trainable_params, self.max_grad_norm)
                        self.optimizer.step()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad()

                if self.accelerator.sync_gradients:
                    # One optimizer update completed: report it, then start a new window.
                    current_loss = train_loss
                    train_loss = 0.0
                    epoch_loss_sum += current_loss
                    epoch_updates += 1
                    self.at_step_end(loss, total_norm, epoch, current_loss)
                else:
                    # Inside an accumulation window: refresh the display only; the step
                    # counter and progress bar advance once per optimizer update.
                    self.progress_bar.set_postfix(**self._postfix(loss))

                if self._reached_max_steps():
                    break

            epoch_loss = epoch_loss_sum / epoch_updates if epoch_updates else current_loss
            self.at_epoch_end(epoch, epoch_loss)
            if self._reached_max_steps():
                break

    def at_step_end(self, loss, total_norm, epoch, current_loss):
        """Bookkeeping after one optimizer update: step counter, logging, checkpointing.

        Args:
            loss: last micro-batch loss tensor (shown in the progress bar).
            total_norm: value returned by ``clip_grad_norm_``; may be None.
            epoch: zero-based epoch index.
            current_loss: process-averaged loss of this optimizer update.
        """
        self.progress_bar.update(1)
        self.global_step += 1

        # clip_grad_norm_ may return None (accelerate does so under DeepSpeed);
        # float() also upcasts half-precision norms for logging.
        grad_norm = float(total_norm) if total_norm is not None else None
        lr = self._current_lr()

        if self.logging_steps and self.global_step % self.logging_steps == 0:
            self.tensorboard_writer.add_scalar("train/loss", current_loss, self.global_step)
            if self.lr_scheduler is not None:
                self.tensorboard_writer.add_scalar("train/learning_rate", lr, self.global_step)
            self.tensorboard_writer.add_scalar("train/epoch", epoch + 1, self.global_step)
            if grad_norm is not None:
                self.tensorboard_writer.add_scalar("train/grad_norm", grad_norm, self.global_step)

        if self.wandb_project and os.environ.get("WANDB_API_KEY"):
            self._log_wandb(current_loss, lr, epoch, grad_norm)

        # Checkpointing (only main process or DeepSpeed)
        if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
            if self.checkpointing_steps and self.global_step % self.checkpointing_steps == 0:
                self._save_checkpoint()

        self.progress_bar.set_postfix(**self._postfix(loss))

    def _log_wandb(self, current_loss, lr, epoch, grad_norm):
        """Best-effort wandb logging; the first failure is reported, later ones suppressed."""
        payload = {"train_loss": current_loss, "learning_rate": lr, "epoch": epoch + 1}
        if grad_norm is not None:
            payload["grad_norm"] = grad_norm
        try:
            import wandb
            wandb.log(payload, step=self.global_step)
        except Exception as exc:
            if not getattr(self, "_wandb_warned", False):
                self.logger.warning(f"wandb logging failed (further errors suppressed): {exc}")
                self._wandb_warned = True

    def _list_checkpoints(self):
        """Names of ``checkpoint-<int>`` dirs in ``output_dir``, oldest step first."""
        if not os.path.isdir(self.output_dir):
            return []
        names = [
            d for d in os.listdir(self.output_dir)
            if d.startswith("checkpoint-") and d.split("-", 1)[1].isdigit()
        ]
        return sorted(names, key=lambda d: int(d.split("-", 1)[1]))

    def _save_checkpoint(self):
        """Rotate old checkpoints, then save accelerate state plus the LoRA state dict."""
        if self.checkpoints_total_limit is not None:
            checkpoints = self._list_checkpoints()
            # Leave room for the checkpoint about to be written.
            if len(checkpoints) >= self.checkpoints_total_limit:
                num_to_remove = len(checkpoints) - self.checkpoints_total_limit + 1
                for removing_checkpoint in checkpoints[:num_to_remove]:
                    shutil.rmtree(os.path.join(self.output_dir, removing_checkpoint))

        save_dir = os.path.join(self.output_dir, f"checkpoint-{self.global_step}")
        self.accelerator.save_state(save_dir)
        # Only one process writes the LoRA file, avoiding concurrent writes under DeepSpeed.
        if self.accelerator.is_main_process:
            save_lora_state(save_dir, self.unet, self.unwrap_model)
        self.logger.info(f"Saved checkpoint to {save_dir}")

    def at_epoch_end(self, epoch, current_loss):
        """Log and record the epoch's loss (mean of its per-update losses)."""
        self.logger.info(f"EPOCH: {epoch+1}, TRAIN_LOSS: {current_loss}")
        self.metrics = {"epochs": epoch+1, "train_loss": current_loss}

    def get_final_metrics(self):
        """Metrics recorded by the most recent ``at_epoch_end`` call."""
        return self.metrics


In [4]:
import subprocess as sp

def load_tokenizers(args):
    """Load the two SDXL tokenizers (``tokenizer`` and ``tokenizer_2`` subfolders).

    Args:
        args: namespace exposing ``pretrained_model_name_or_path`` and ``revision``.

    Returns:
        ``(tokenizer, tokenizer_2)``, both loaded as slow (``use_fast=False``) tokenizers.
    """
    def _load(subfolder):
        return AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder=subfolder,
            revision=args.revision,
            use_fast=False,
        )

    first = _load("tokenizer")
    second = _load("tokenizer_2")
    return first, second


def load_pretrained_model(args):
    """Load the SDXL noise scheduler, both text encoders, the VAE and the UNet.

    Args:
        args: namespace exposing ``pretrained_model_name_or_path``, ``revision``
            and ``variant``.

    Returns:
        ``(noise_scheduler, text_encoder, text_encoder_2, vae, unet)``.
    """
    model_path = args.pretrained_model_name_or_path

    text_encoder = CLIPTextModel.from_pretrained(
        model_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    # The second SDXL encoder also provides the pooled, projected text embedding.
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        model_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    noise_scheduler = DDPMScheduler.from_pretrained(
        model_path, subfolder="scheduler", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
        torch_dtype=torch.float16,
    )
    unet = UNet2DConditionModel.from_pretrained(
        model_path,
        subfolder="unet",
        revision=args.revision,
        variant=args.variant,
    )

    # Prefer xformers. Sliced attention is only a fallback: set_attention_slice
    # installs sliced attention processors, which would replace xformers.
    try:
        unet.enable_xformers_memory_efficient_attention()
        logger.info("xformers enabled")
    except Exception as e:
        logger.warning(f"xformers not available: {e}")
        # UNet2DConditionModel exposes set_attention_slice; enable_attention_slicing
        # is a pipeline-level method, so calling it on the UNet always failed.
        try:
            unet.set_attention_slice("auto")
            logger.info("attention slicing enabled")
        except Exception as slice_err:
            logger.warning(f"attention slicing failed: {slice_err}")

    return noise_scheduler, text_encoder, text_encoder_2, vae, unet

def decode_base64(encoded_string):
    """Decode a base64-encoded value and return the UTF-8 text it contains."""
    from base64 import b64decode

    raw = b64decode(encoded_string)
    return raw.decode('utf-8')

def gpu_memory():
    """Return the free memory of each GPU, as reported by ``nvidia-smi``.

    The CSV header row and the trailing empty line are skipped. Each remaining
    row's leading number is returned as an int; the unit suffix that nvidia-smi
    prints is dropped.
    """
    argv = "nvidia-smi --query-gpu=memory.free --format=csv".split()
    lines = sp.check_output(argv).decode('ascii').split('\n')
    data_rows = lines[1:-1]
    return [int(row.split()[0]) for row in data_rows]

In [5]:
def main():
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision="fp16", # Using fp16, changed from pf16 for consistency. Can be "bf16" as well.
        log_with="wandb" if wandb_project else None,
        kwargs_handlers=[kwargs],
    )
    validation_prompt = decode_base64(" ") # For Validation prompt
    # gpufree = gpu_memory() # Commented out, `sp` was not imported in its original location, now fixed in fpsBH2pURmK8. Re-enable if needed.

    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    seed_size = 25
    set_seed(seed_size)

    # Handle the repository creation
    if accelerator.is_main_process:
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

        # if args.push_to_hub:
        #     repo_id = create_repo(
        #         repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
        #     ).repo_id

    # Load the tokenizers
    accelerator.print("Loading the tokenizers, text endcoders, schedulers and model:")
    tokenizer, tokenizer_2 = load_tokenizers(SimpleNamespace(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision
    ))

    # import correct text encoder classes, scheduler and models
    accelerator.print("Loading the Stable diffusion Model:")
    noise_scheduler, text_encoder, text_encoder_2, vae, unet = load_pretrained_model(SimpleNamespace(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant
    ))

    # Load the clip_model for CLIP score
    accelerator.print("Loading the CLIP Model for validation:")
    # Fix: The 'clip' module installed via pip does not have a 'load' attribute.
    # We will use CLIPModel and CLIPProcessor from the transformers library instead.
    # clip_model, preprocess = clip.load("ViT-B/32", device=accelerator.device)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(accelerator.device)
    preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    accelerator.print(f"LOADING PRE-TRAINED MODEL FROM {LOCAL_MODEL_PATH}{source_model_path}")
    trained_model_path = f"{LOCAL_MODEL_PATH}{source_model_path}"
    if trained_model_path:
        unet.load_attn_procs(trained_model_path, weight_name="pytorch_lora_weights.safetensors")
        unet.enable_lora()
    else:
        accelerator.print("No source model path provided, LoRA will not be loaded initially.")

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Freeze the unet parameters before adding adapters
    for param in unet.parameters():
        param.requires_grad_(False)

    # for 10-12 VRAM
    unet_lora_config = LoraConfig(
      r=1,
      lora_alpha=8,
      init_lora_weights="gaussian",
      target_modules=["to_k", "to_q", "to_v"],
    )

    # Move vae and text_encoder to device and cast to weight_dtype
    # vae.to("cpu")
    # text_encoder.to("cpu")
    # text_encoder_2.to("cpu")
    # unet.to("cpu")
    # torch.cuda.empty_cache()
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)

    # Ensure unet is on device, but keep its parameters in float32 initially before adding adapter
    unet.to(accelerator.device) # Keep UNet in float32 for now

    # Add adapter and make sure the trainable params are in float32.
    if not trained_model_path: # Only add adapter if no pre-trained LoRA was loaded
        unet.add_adapter(unet_lora_config)
    else: # If a trained model path is provided, it means LoRA was loaded, so ensure it's enabled
        unet.enable_lora() # Explicitly enable LoRA for loaded weights

    # This ensures LoRA layers, once added, are always in float32, which is important for GradScaler.
    # The base UNet weights will be handled by accelerator during autocast.
    # We explicitly cast trainable LoRA parameters to float32.
    if accelerator.mixed_precision == "fp16": # Apply this only if we are using fp16
        for name, param in unet.named_parameters():
            if param.requires_grad:
                # keep param on same device but force float32 dtype
                param.data = param.data.to(torch.float32)
                # if a grad exists (unlikely at this point) cast it too
                if param.grad is not None:
                    param.grad.data = param.grad.data.to(torch.float32)

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        logger.info("enabled unet.gradient_checkpointing()")

    learning_rate = (
          1e-4 * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
      )

    # Initialize the optimizer
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if dataset_type == "huggingface":
        accelerator.print(f"Loading the dataset: {dataset_name}")
        dataset = load_dataset(
            dataset_name,
            dataset_config_name,
        )
        column_names = dataset["train"].column_names
        image_column_name = image_column
        caption_column_name = caption_column
    else:
        accelerator.print(f"Loading the dataset from: {dataset_path}")
        dataset = load_dataset(
            "imagefolder",
            data_dir=dataset_path,
        )
        image_column_name = "image"
        caption_column_name = "text"

    # Preprocessing the datasets.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column_name]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column_name}` should contain either strings or lists of strings."
                )
        # For SDXL, we need two tokenizers
        # First tokenizer for text_encoder
        input_ids = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids
        # Second tokenizer for text_encoder_2
        input_ids_2 = tokenizer_2(
            captions, max_length=tokenizer_2.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids
        return input_ids, input_ids_2

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
            transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column_name]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        input_ids, input_ids_2 = tokenize_captions(examples) # Get both tokenized inputs
        examples["input_ids"] = input_ids
        examples["input_ids_2"] = input_ids_2

        # For SDXL, we need to add original_size and crop_coords_top_left as well
        # Assuming all images are resized to 'resolution' and then cropped, set original_size and crop_coords_top_left accordingly
        examples["original_size"] = [(resolution, resolution)] * len(images)
        examples["crops_coords_top_left"] = [(0, 0)] * len(images)

        return examples

    with accelerator.main_process_first():
        if max_train_samples is not None and max_train_samples > 0:
            dataset["train"] = dataset["train"].shuffle(seed=seed_size).select(range(max_train_samples))
        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        input_ids_2 = torch.stack([example["input_ids_2"] for example in examples])
        original_sizes = torch.tensor([example["original_size"] for example in examples], dtype=torch.long)
        crops_coords_top_left = torch.tensor([example["crops_coords_top_left"] for example in examples], dtype=torch.long)

        return {"pixel_values": pixel_values, "input_ids": input_ids, "input_ids_2": input_ids_2, "original_sizes": original_sizes, "crops_coords_top_left": crops_coords_top_left}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=train_batch_size,
        num_workers=dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    num_warmup_steps_for_scheduler = lr_warmup_steps * accelerator.num_processes
    if max_train_steps is None or max_train_steps <= 0:
        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
        num_training_steps_for_scheduler = (
            num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
        )
        max_train_steps_actual = num_train_epochs * num_update_steps_per_epoch
    else:
        num_training_steps_for_scheduler = max_train_steps * accelerator.num_processes
        max_train_steps_actual = max_train_steps

    # To avoid UnboundLocalError, explicitly access the global lr_scheduler string value
    # before assigning a local lr_scheduler object.
    scheduler_type_str = globals()['lr_scheduler']

    lr_scheduler = get_scheduler(
        scheduler_type_str, # Use the global lr_scheduler string directly
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps_for_scheduler,
        num_training_steps=num_training_steps_for_scheduler,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

    # Afterwards we recalculate our number of training epochs
    num_train_epochs_actual = math.ceil(max_train_steps_actual / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    global_step = 0
    first_epoch = 0
    training_start_time = time.time()

    # Initialize a local variable for resume_from_checkpoint to avoid UnboundLocalError
    _resume_from_checkpoint = resume_from_checkpoint

    # Potentially load in the weights and states from a previous save
    if _resume_from_checkpoint and os.path.exists(output_dir):
        # Get the most recent checkpoint
        dirs = os.listdir(output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                "Checkpoint does not exist. Starting a new training run."
            )
            _resume_from_checkpoint = None # Update local copy
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        accelerator.print("No previous checkpoint found. Starting a new training run.")
        initial_global_step = 0
        path = None

    # Weights & Biases integration
    if wandb_project and os.environ.get('WANDB_API_KEY'):
        # initialize_wandb(args, path) # initialize_wandb is not defined, replaced with accelerator.init_trackers
        accelerator.init_trackers(project_name=wandb_project, config={
            "output_dir": output_dir,
            "resolution": resolution,
            "train_batch_size": train_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "lr": learning_rate,
            "num_train_epochs": num_train_epochs_actual,
            "max_train_steps": max_train_steps_actual,
            "lr_scheduler": lr_scheduler,
            "seed": seed_size,
            "mixed_precision": accelerator.mixed_precision,
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
            "dataset_name": dataset_name,
            "image_column": image_column,
            "caption_column": caption_column,
        }, init_kwargs={"resume": "allow" if _resume_from_checkpoint else None})
    else:
        # If wandb not enabled, still need to initialize accelerator trackers for TensorBoard
        accelerator.init_trackers(project_name="tensorboard_logs", config={
            "output_dir": output_dir,
            "resolution": resolution,
            "train_batch_size": train_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "lr": learning_rate,
            "num_train_epochs": num_train_epochs_actual,
            "max_train_steps": max_train_steps_actual,
            "lr_scheduler": lr_scheduler,
            "seed": seed_size,
            "mixed_precision": accelerator.mixed_precision,
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
            "dataset_name": dataset_name,
            "image_column": image_column,
            "caption_column": caption_column,
        }) # Passing a project_name for tensorboard only, it won't push to wandb.

    progress_bar = tqdm(
        range(0, max_train_steps_actual),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # Initialize TensorBoard writer
    tensorboard_writer = SummaryWriter(log_dir=logging_dir)

    # Initialize the Trainer
    trainer = StableDiffusionTrainer(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs_actual,
        max_train_steps=max_train_steps_actual,
        logging_steps=logging_steps,
        train_dataloader=train_dataloader, vae=vae, unet=unet, text_encoder=text_encoder, text_encoder_2=text_encoder_2, noise_scheduler=noise_scheduler, optimizer=optimizer, lr_scheduler=lr_scheduler, accelerator=accelerator, weight_dtype=weight_dtype, tensorboard_writer=tensorboard_writer, logger=logger, progress_bar=progress_bar, first_epoch=first_epoch, initial_global_step=initial_global_step, unwrap_model=unwrap_model,
    )
    # Start the training
    trainer.train()
    training_end_time = time.time()

    # Save the final model as checkpoint-final
    save_path = os.path.join(output_dir, "final")
    accelerator.save_state(save_path)
    unwrapped_unet = unwrap_model(unet)
    unet_lora_state_dict = convert_state_dict_to_diffusers(
        get_peft_model_state_dict(unwrapped_unet)
    )
    StableDiffusionPipeline.save_lora_weights(
        save_directory=save_path,
        unet_lora_layers=unet_lora_state_dict,
        safe_serialization=True,
    )
    metrics = trainer.get_final_metrics()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # run inference
        validation_start_time = time.time()
        if validation_prompt and num_validation_images > 0:
            # Make sure vae.dtype is consistent with the unet.dtype
            if accelerator.mixed_precision == "fp16":
                vae.to(weight_dtype)
            # Load previous pipeline
            pipeline = DiffusionPipeline.from_pretrained(
                pretrained_model_name_or_path,
                revision=revision,
                variant=variant,
                torch_dtype=weight_dtype,
            )
            # load attention processors
            pipeline.load_lora_weights(save_path)

            # log_validation function is not defined in the provided context, commenting out.
            # images = log_validation(pipeline, SimpleNamespace(num_validation_images=num_validation_images,
            #                                                 resolution=resolution, output_dir=output_dir,
            #                                                 validation_prompt=validation_prompt),
            #                         accelerator, num_train_epochs_actual, tensorboard_writer, is_final_validation=True)
            # For demonstration, creating dummy images if log_validation is skipped.
            images = [Image.new("RGB", (resolution, resolution), color = 'red')] * num_validation_images

            # calculate_clip_score function is not defined in the provided context, commenting out.
            # clip_score = calculate_clip_score(images, validation_prompt, clip_model, preprocess, accelerator.device)
            # logger.info(f"CLIP score: {clip_score}")
            # metrics["clip_score"] = clip_score
            logger.info("Skipped CLIP score calculation because `calculate_clip_score` is not defined.")
            del pipeline
            torch.cuda.empty_cache()

        validation_end_time = time.time()

    # Close the Writers
    tensorboard_writer.close()

    accelerator.end_training()
    # push_model(output_dir, metrics) # push_model function is not defined, commenting out.

In [6]:
# Entry point for this notebook/script. In a Jupyter/IPython kernel the
# module name is "__main__", so executing this cell runs `main` end to end
# (training, saving the final LoRA weights, and the validation step).
if __name__ == "__main__":
    main()

Loading the tokenizers, text encoders, schedulers and model:


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading the Stable diffusion Model:


{'variance_type', 'thresholding', 'clip_sample_range', 'dynamic_thresholding_ratio', 'rescale_betas_zero_snr'} was not found in config. Values will be initialized to default values.
Instantiating AutoencoderKL model under default dtype torch.float16.
{'mid_block_add_attention', 'latents_std', 'use_post_quant_conv', 'latents_mean', 'shift_factor', 'use_quant_conv'} was not found in config. Values will be initialized to default values.
All model checkpoint weights were used when initializing AutoencoderKL.

All the weights of AutoencoderKL were initialized from the model checkpoint at stabilityai/stable-diffusion-xl-base-1.0.
If your task is similar to the task the model of the checkpoint was trained on, you can already use AutoencoderKL for predictions without further training.
{'attention_type', 'dropout', 'reverse_transformer_layers_per_block'} was not found in config. Values will be initialized to default values.
All model checkpoint weights were used when initializing UNet2DConditio

Loading the CLIP Model for validation:


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


LOADING PRE-TRAINED MODEL FROM 
No source model path provided, LoRA will not be loaded initially.
Loading the dataset: lambdalabs/naruto-blip-captions


Repo card metadata block was not found. Setting CardData to empty.


No previous checkpoint found. Starting a new training run.


[34m[1mwandb[0m: Currently logged in as: [33manonymous7770777[0m ([33manonymous7770777-e2e[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Steps:   0%|          | 0/612 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 16.12 MiB is free. Process 499691 has 14.72 GiB memory in use. Of the allocated memory 14.38 GiB is allocated by PyTorch, and 217.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)