In [1]:
from diffusers import StableDiffusionImageVariationPipeline
import torch
from torch.optim import AdamW
import glob
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image
import numpy as np
from datasets import load_dataset
import torchvision

from load_data import CustomImageDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pipeline and components
pipeline = StableDiffusionImageVariationPipeline.from_pretrained("/home/rmuproject/rmuproject/users/sandesh/models/80_epochs/",
                                                                 requires_safety_checker=False,
                                                                 device='cuda')
pipeline.enable_model_cpu_offload()
unet = pipeline.unet
vae = pipeline.vae
clip_encoder = pipeline.image_encoder.to('cpu')
# feature_extractor = pipeline.feature_extractor.to('cpu')
# Freeze the VAE and CLIP encoder
for param in vae.parameters():
    param.requires_grad = False
for param in clip_encoder.parameters():
    param.requires_grad = False



Keyword arguments {'device': 'cuda'} are not expected by StableDiffusionImageVariationPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00, 15.51it/s]


In [3]:

dataset_name = "/home/rmuproject/rmuproject/data"  # @param
dataset = load_dataset(dataset_name, split="train")

class CustomImageDataset(Dataset):
    def __init__(self, dataset, clip_encoder, feature_extractor, size=512):
        """
        Args:
            dataset: A dataset object from the `datasets` library.
            clip_encoder: The CLIP image encoder model (e.g., CLIPVisionModelWithProjection).
            feature_extractor: The feature extractor from the StableDiffusionImageVariationPipeline.
            size: The size to which images should be resized (default: 512x512).
        """
        self.dataset = dataset
        self.clip_encoder = clip_encoder
        self.feature_extractor = feature_extractor
        self.size = size

        # Transformations for the input images (resize, normalize, etc.)
        self.transforms = transforms.Compose(
            [
                transforms.Resize((size, size)),  # Resize to the required size
                transforms.CenterCrop(size),      # Center crop to ensure square images
                transforms.ToTensor(),            # Convert to tensor
                transforms.Normalize([0.5], [0.5]),  # Normalize to [-1, 1]
            ]
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        example = {}
        try:
            # Load the image
            image = self.dataset[index]["image"]
            if not isinstance(image, Image.Image):
                image = Image.open(image).convert("RGB")  # Ensure it's a PIL image

            if image.mode != "RGB":
                image = image.convert("RGB")

            # Transform the image for the UNet
            example["instance_images"] = self.transforms(image)

            # Generate the CLIP embedding
            with torch.no_grad():
                # Preprocess the image for CLIP using the feature_extractor
                clip_input = self.feature_extractor(image, return_tensors="pt").pixel_values.squeeze(0)  # Shape: [3, 224, 224]
                clip_input = clip_input.to(self.clip_encoder.device)  # Move to the correct device
                clip_embedding = self.clip_encoder(clip_input.unsqueeze(0)).image_embeds  # Shape: [1, embedding_dim]

            # Add a sequence length dimension to the CLIP embedding
            clip_embedding = clip_embedding.unsqueeze(1).to('cpu')  # Shape: [1, 1, embedding_dim]
            example["clip_embeddings"] = clip_embedding.squeeze(0)  # Remove batch dimension for collation

        except Exception as e:
            # Skip corrupted or invalid images
            print(f"Error processing image at index {index}: {e}")
            return self.__getitem__((index + 1) % len(self))  # Skip to the next image

        return example  

In [4]:
feature_extractor = pipeline.feature_extractor
# Create the dataset
image_variation_dataset = CustomImageDataset(dataset, clip_encoder, feature_extractor)

# Example: Access a single item
example = image_variation_dataset[0]
print(example["instance_images"].shape)  # Should be [3, 512, 512]
print(example["clip_embeddings"].shape)  # Should be [embedding_dim]
example = image_variation_dataset[0]
print(example.keys())

torch.Size([3, 512, 512])
torch.Size([1, 768])
dict_keys(['instance_images', 'clip_embeddings'])


In [5]:
import torch

def collate_fn(examples):
    """
    Collate function for the ImageVariationDataset.
    Args:
        examples: A list of dictionaries, where each dictionary contains:
            - "instance_images": A tensor of shape [C, H, W].
            - "clip_embeddings": A tensor of shape [embedding_dim].
    Returns:
        A dictionary containing:
            - "instance_images": A stacked tensor of shape [batch_size, C, H, W].
            - "clip_embeddings": A stacked tensor of shape [batch_size, embedding_dim].
    """
    # Extract instance_images and clip_embeddings from the examples
    instance_images = [example["instance_images"] for example in examples]
    clip_embeddings = [example["clip_embeddings"] for example in examples]

    # Stack the tensors along the batch dimension
    instance_images = torch.stack(instance_images)
    clip_embeddings = torch.stack(clip_embeddings)

    # Ensure the instance_images tensor is in contiguous memory format and cast to float
    instance_images = instance_images.to(memory_format=torch.contiguous_format).float()

    # Return the batch as a dictionary
    batch = {
        "instance_images": instance_images,  # Shape: [batch_size, C, H, W]
        "clip_embeddings": clip_embeddings,  # Shape: [batch_size, embedding_dim]
    }
    return batch

In [6]:
from torch.utils.data import DataLoader
batch_size = 1
# Assuming `dataset` is your ImageVariationDataset
dataloader = DataLoader(image_variation_dataset, batch_size= batch_size, shuffle=True, collate_fn=collate_fn)

for batch in dataloader:
    instance_images = batch["instance_images"]  # Shape: [batch_size, C, H, W]
    clip_embeddings = batch["clip_embeddings"]  # Shape: [batch_size, embedding_dim]
    # Pass these to your model or pipeline
    print(instance_images.shape)
    print(clip_embeddings.shape)
    break

torch.Size([1, 3, 512, 512])
torch.Size([1, 1, 768])


In [7]:
learning_rate = 2e-06
max_train_steps = 400
from argparse import Namespace

args = Namespace(
    # pretrained_model_name_or_path=model_id,
    resolution=512,  # Reduce this if you want to save some memory
    train_dataset=image_variation_dataset,
    resume_from_checkpoint = None,
     checkpointing_steps = 200,
    # instance_prompt=instance_prompt,
    learning_rate=learning_rate,
    max_train_steps=max_train_steps,
    train_batch_size=1,
    gradient_accumulation_steps=1,  # Increase this if you want to lower memory usage
    max_grad_norm=1.0,
    gradient_checkpointing=True,  # Set this to True to lower the memory usage
    use_8bit_adam=True,  # Use 8bit optimizer from bitsandbytes
    seed=3434554,
    sample_batch_size=2,
    output_dir="/home/rmuproject/rmuproject/users/sandesh/models/new",  # Where to save the pipeline
)

In [8]:
import math
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, PNDMScheduler, StableDiffusionImageVariationPipeline
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from lpips import LPIPS

def training_function(vae, unet, clip_encoder):
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="fp16",  # Enable mixed precision
    )

    set_seed(args.seed)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        import bitsandbytes as bnb

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        unet.parameters(),  # Only optimize unet
        lr=args.learning_rate,
    )

    lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.max_train_steps, eta_min=1e-7)
    
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    train_dataloader = DataLoader(
        args.train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    # Move vae and clip_encoder to the accelerator device
    vae.to(accelerator.device)
    clip_encoder.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["instance_images"].to(accelerator.device)).latent_dist.sample()
                    latents = latents * 0.18215  # Scale the latents (VAE scaling factor)

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the CLIP embeddings for conditioning
                clip_embeddings = batch["clip_embeddings"].to(accelerator.device)
                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=clip_embeddings).sample

                # Calculate the loss
                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

            # Update progress bar
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item() * args.gradient_accumulation_steps}  # Rescale loss for logging
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it
    if accelerator.is_main_process:
        print(f"Loading pipeline and saving to {args.output_dir}...")
        scheduler = PNDMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True,
            steps_offset=1,
        )
        pipeline = StableDiffusionImageVariationPipeline(
            vae=vae,
            unet=accelerator.unwrap_model(unet),
            scheduler=scheduler,
            image_encoder=clip_encoder,
            safety_checker= None,
            feature_extractor=feature_extractor
        )
        pipeline.save_pretrained(args.output_dir)

In [None]:
training_function(vae, unet, clip_encoder)