<a href="https://colab.research.google.com/github/Rishardmunene/Stable-Diffusion-test/blob/train/SDXL_trial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import gc
import psutil
from dataclasses import dataclass
from typing import Optional, Union, List, Tuple

import torch
from torch import amp
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers import StableDiffusionXLPipeline, ControlNetModel
from accelerate import Accelerator
from PIL import Image
from google.colab import files

In [2]:
def monitor_memory():
    """Report current GPU and process memory usage.

    Returns:
        dict: ``"gpu_memory_mb"`` holds the value of
        ``torch.cuda.memory_allocated()`` and ``"ram_memory_mb"`` the
        resident set size of the current process, both divided by 1024**2.
    """
    bytes_per_mb = 1024**2
    cuda_allocated = torch.cuda.memory_allocated()
    process_rss = psutil.Process().memory_info().rss
    return dict(
        gpu_memory_mb=cuda_allocated / bytes_per_mb,
        ram_memory_mb=process_rss / bytes_per_mb,
    )

In [3]:
@dataclass
class OptimizationConfig:
    """Memory/precision settings consumed by ``OptimizedSDXL``.

    Every field has a default, so ``OptimizationConfig()`` is a valid
    configuration.
    """

    # Output resolution; OptimizedSDXL.generate_image passes index 0 as
    # ``height`` and index 1 as ``width``.
    image_size: Tuple[int, int] = (256, 256)
    # "fp16" selects torch.float16 when the pipeline is loaded; any other
    # value (e.g. "bf16") selects torch.bfloat16.
    precision: str = "fp16"
    # Gates OptimizedSDXL.apply_gradient_checkpointing.
    enable_checkpointing: bool = True
    # When true, pipeline.enable_attention_slicing() is called at setup.
    enable_attention_slicing: bool = True
    # When true, pipeline.enable_sequential_cpu_offload() is called at setup.
    enable_sequential_cpu_offload: bool = False
    # When true, pipeline.enable_vae_slicing() is called at setup.
    vae_slicing: bool = True

In [4]:
class OptimizedSDXL:
    """Stable Diffusion XL pipeline wrapper with optional memory optimizations.

    The wrapper loads a ``StableDiffusionXLPipeline``, applies the switches
    found in an ``OptimizationConfig`` and exposes helpers for image
    generation and for a (placeholder) training loop.
    """

    def __init__(
        self,
        model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
        config: Optional["OptimizationConfig"] = None
    ):
        """Load and configure the pipeline.

        Args:
            model_id: Identifier handed to ``from_pretrained``.
            config: Optimization settings; a default ``OptimizationConfig``
                is used when omitted.
        """
        self.config = config or OptimizationConfig()
        self.accelerator = Accelerator()
        self.setup_pipeline(model_id)

    def _torch_dtype(self):
        """Return the torch dtype selected by ``config.precision``."""
        return torch.float16 if self.config.precision == "fp16" else torch.bfloat16

    def setup_pipeline(self, model_id: str):
        """Initialize and optimize SDXL pipeline with memory saving techniques."""
        # Enable TF32 for better performance on Ampere GPUs
        torch.backends.cuda.matmul.allow_tf32 = True

        # Initialize pipeline with memory optimizations
        self.pipeline = StableDiffusionXLPipeline.from_pretrained(
            model_id,
            torch_dtype=self._torch_dtype(),
            use_safetensors=True,
            variant="fp16",
        )

        # Apply memory optimization techniques
        if self.config.enable_attention_slicing:
            self.pipeline.enable_attention_slicing()

        if self.config.vae_slicing:
            self.pipeline.enable_vae_slicing()

        if self.config.enable_sequential_cpu_offload:
            # Sequential offload manages device placement itself; diffusers
            # rejects moving an offloaded pipeline to the GPU afterwards, so
            # the explicit ``.to(...)`` is skipped in this mode.
            self.pipeline.enable_sequential_cpu_offload()
        else:
            # Move pipeline to accelerator device
            self.pipeline = self.pipeline.to(self.accelerator.device)

    def apply_gradient_checkpointing(self):
        """Enable gradient checkpointing for training.

        Does nothing when ``config.enable_checkpointing`` is false. Text
        encoders that are missing or set to ``None`` are skipped.
        """
        if not self.config.enable_checkpointing:
            return

        self.pipeline.unet.enable_gradient_checkpointing()
        for encoder_name in ("text_encoder", "text_encoder_2"):
            # ``hasattr`` is not enough: the attribute may exist but be None.
            encoder = getattr(self.pipeline, encoder_name, None)
            if encoder is not None:
                encoder.gradient_checkpointing_enable()

    def train_single_stage(self, dataset, num_epochs):
        """Training loop for a single resolution stage.

        Args:
            dataset: Iterable of batches; tensor batches are first passed
                through ``_process_batch``.
            num_epochs: Number of passes over ``dataset``.
        """
        self.apply_gradient_checkpointing()

        autocast_dtype = self._torch_dtype()
        # Loss scaling is only needed for fp16; with bf16 the scaler is
        # created disabled and simply forwards to ``optimizer.step()``.
        # NOTE(review): the UNet weights are loaded in half precision;
        # GradScaler may refuse to unscale fp16 gradients - confirm whether
        # fp32 master weights are required.
        scaler = amp.GradScaler('cuda', enabled=autocast_dtype == torch.float16)
        optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)

        for _ in range(num_epochs):
            for batch in dataset:
                # Convert tensor prompts to strings (no-op for other types)
                batch = self._process_batch(batch)

                # Autocast to the configured precision rather than the
                # CUDA default (fp16).
                with amp.autocast('cuda', dtype=autocast_dtype):
                    loss = self._training_step(batch)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

    def _process_batch(self, batch):
        """Process input batch to ensure correct format.

        Tensor input is replaced by a fixed placeholder prompt; anything
        else is returned unchanged.
        """
        if isinstance(batch, torch.Tensor):
            # Convert tensor to appropriate string format
            # This is a placeholder - modify according to your data format
            return "default prompt"  # Replace with actual tensor -> string conversion
        return batch

    def _training_step(self, batch):
        """Perform a single training step and return its loss."""
        # Implement your training logic here
        # This is a placeholder - replace with your actual training step
        # NOTE(review): the inference pipeline output is not expected to
        # expose ``.loss``; a real noise-prediction objective is still needed.
        return self.pipeline(batch).loss

    def generate_image(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 30
    ):
        """Generate image with memory-optimized inference.

        Args:
            prompt: A prompt string or list of prompt strings. A tensor is
                first converted with ``_process_batch``.
            negative_prompt: Optional negative prompt(s), forwarded as-is.
            num_inference_steps: Number of denoising steps.

        Returns:
            The first image of the pipeline output.

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list of
                strings.
        """
        # Convert tensor to string if necessary
        if isinstance(prompt, torch.Tensor):
            prompt = self._process_batch(prompt)

        # Validate prompt type
        if not isinstance(prompt, (str, list)):
            raise ValueError(
                f"`prompt` must be a string or a list of strings, but got {type(prompt)}"
            )

        # If prompt is a list, ensure all elements are strings
        if isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt):
            raise ValueError("All elements in the `prompt` list must be strings.")

        # Clear CUDA cache and garbage collect
        torch.cuda.empty_cache()
        gc.collect()

        # image_size is interpreted as (height, width)
        height, width = self.config.image_size[0], self.config.image_size[1]
        return self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
        ).images[0]

In [5]:
class ProgressiveTrainer:
    def __init__(self, sdxl_model: OptimizedSDXL):
        self.model = sdxl_model
        self.progressive_sizes = [(256, 256), (512, 512), (768, 768), (1024, 1024)]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def train_progressive(self, dataset, num_epochs_per_size=5):
        """Implement progressive training with increasing image sizes."""
        for size in self.progressive_sizes:
            try:
                print(f"Starting training at resolution {size}")
                # Update model config and resize dataset images if needed
                self.model.config.image_size = size
                resized_dataset = self.resize_dataset(dataset, size)
                self.train_single_stage(resized_dataset, num_epochs_per_size)
                print(f"Completed training at resolution {size}")
            except Exception as e:
                print(f"Error during training at resolution {size}: {e}")
                continue

    def resize_dataset(self, dataset, size):
        """Resize dataset images to target size."""
        # Implement dataset resizing logic here
        # Placeholder implementation; modify as per dataset structure
        return dataset

    def train_single_stage(self, dataset, num_epochs):
        """Training loop for a single resolution stage."""
        try:
            # Ensure the model and subcomponents are in training mode
            if hasattr(self.model, 'train'):
                self.model.train()
            else:
                self.model.pipeline.train()

            # Move the model to the appropriate device
            self.model.pipeline.to(self.device)

            # Enable gradient checkpointing if supported
            if hasattr(self.model, 'apply_gradient_checkpointing'):
                self.model.apply_gradient_checkpointing()

            scaler = torch.cuda.amp.GradScaler()
            optimizer = torch.optim.AdamW(self.model.pipeline.unet.parameters(), lr=1e-5)

            for epoch in range(num_epochs):
                total_loss = 0
                num_batches = 0

                for batch in dataset:
                    try:
                        # Move batch to the correct device
                        batch = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in batch.items()}

                        # Clear gradients
                        optimizer.zero_grad()

                        # Forward pass with automatic mixed precision
                        with torch.cuda.amp.autocast():
                            loss = self.model.pipeline(batch)  # Forward pass
                            total_loss += loss.item()

                        # Backward pass with gradient scaling
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                        num_batches += 1
                    except Exception as batch_error:
                        print(f"Error during batch processing: {batch_error}")
                        continue

                # Print epoch statistics
                avg_loss = total_loss / max(num_batches, 1)  # Avoid division by zero
                print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

                # Clear CUDA cache and garbage collect after each epoch
                torch.cuda.empty_cache()
                gc.collect()

        except Exception as e:
            print(f"Error during training stage: {e}")


In [6]:
def upload_images():
    """Allow users to upload images for training.

    Opens the Colab upload dialog, then re-saves every uploaded file that
    can be read as an image into the ``uploaded_images`` directory. Files
    that cannot be read or written as images are reported and skipped.
    """
    target_dir = "uploaded_images"
    os.makedirs(target_dir, exist_ok=True)
    uploaded_files = files.upload()
    for filename in uploaded_files:
        # Keep only the final path component so an unusual upload name
        # cannot place the copy outside the target directory.
        destination = os.path.join(target_dir, os.path.basename(filename))
        try:
            # The context manager closes the source file handle promptly.
            with Image.open(filename) as img:
                img.save(destination)
        except (OSError, ValueError) as error:
            print(f"Skipping {filename}: {error}")
            continue
        print(f"Saved {filename} to uploaded_images/")

In [7]:
class ImageDataset(Dataset):
    """Dataset yielding the image files of one directory as RGB images."""

    def __init__(self, image_dir, transform=None):
        """Collect the regular files found directly inside ``image_dir``.

        Args:
            image_dir: Directory to scan (not recursive).
            transform: Optional callable applied to each loaded image.
        """
        self.image_dir = image_dir
        candidates = (os.path.join(image_dir, name) for name in os.listdir(image_dir))
        # Skip sub-directories (e.g. notebook checkpoint folders) and sort so
        # that indices are stable across runs and platforms.
        self.image_paths = sorted(path for path in candidates if os.path.isfile(path))
        self.transform = transform

    def __len__(self):
        """Return the number of collected files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load item ``idx`` as RGB and apply ``transform`` when one is set."""
        # ``convert`` returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

def prepare_dataset(image_dir: str = "uploaded_images",
                    image_size: Tuple[int, int] = (256, 256)):
    """Prepare dataset from uploaded images.

    Args:
        image_dir: Directory holding the training images.
        image_size: Size passed to ``transforms.Resize``.

    Returns:
        An ``ImageDataset`` whose items are resized, converted to tensors and
        normalized per channel with mean 0.5 and std 0.5.
    """
    image_transforms = transforms.Compose([
        transforms.Resize(tuple(image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return ImageDataset(image_dir, transform=image_transforms)

In [8]:
def train_model():
    """Run the end-to-end workflow: upload, build a loader, train progressively."""
    # Step 1: collect training images from the user.
    upload_images()

    # Step 2: wrap the uploaded images in a shuffled loader of 4-image batches.
    loader = DataLoader(prepare_dataset(), batch_size=4, shuffle=True)

    # Step 3: build the SDXL wrapper, starting at 256x256 resolution.
    sdxl = OptimizedSDXL(config=OptimizationConfig(image_size=(256, 256)))

    # Step 4: train through every resolution of the progressive schedule.
    ProgressiveTrainer(sdxl_model=sdxl).train_progressive(
        dataset=loader,
        num_epochs_per_size=5,
    )

In [9]:
# Run the upload -> dataset -> progressive-training workflow end to end.
train_model()

Saving landscape5.jpg to landscape5 (4).jpg
Saving landscape6.jpg to landscape6 (5).jpg
Saving landscape7.jpg to landscape7 (5).jpg
Saved landscape5 (4).jpg to uploaded_images/
Saved landscape6 (5).jpg to uploaded_images/
Saved landscape7 (5).jpg to uploaded_images/


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Starting training at resolution (256, 256)
Error during training stage: 'StableDiffusionXLPipeline' object has no attribute 'train'
Completed training at resolution (256, 256)
Starting training at resolution (512, 512)
Error during training stage: 'StableDiffusionXLPipeline' object has no attribute 'train'
Completed training at resolution (512, 512)
Starting training at resolution (768, 768)
Error during training stage: 'StableDiffusionXLPipeline' object has no attribute 'train'
Completed training at resolution (768, 768)
Starting training at resolution (1024, 1024)
Error during training stage: 'StableDiffusionXLPipeline' object has no attribute 'train'
Completed training at resolution (1024, 1024)


In [11]:
# Assuming the training process is done
# NOTE(review): a new OptimizedSDXL is constructed here, so this cell does not
# reuse the instance created inside train_model() - confirm that is intended.
config = OptimizationConfig(image_size=(256, 256))
optimized_sdxl = OptimizedSDXL(config=config)

# Generate an image
# generate_landscape_image is defined in a later cell of this file; that cell
# must have been executed before this call.
landscape_prompt = "A picturesque mountain range with a tranquil river flowing through it"
generate_landscape_image(optimized_sdxl, prompt=landscape_prompt)


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Generating landscape image for prompt: 'A picturesque mountain range with a tranquil river flowing through it'...


  0%|          | 0/50 [00:00<?, ?it/s]

Landscape image saved at landscape_image.png.


In [None]:
def generate_landscape_image(model: "OptimizedSDXL", prompt: str,
                             negative_prompt: Optional[str] = None,
                             num_inference_steps: int = 50,
                             save_path: str = "landscape_image.png"):
    """
    Generate a landscape image with memory optimization and save it to disk.

    Args:
        model: OptimizedSDXL instance (any object with a compatible
            ``generate_image`` method works)
        prompt: Text prompt for image generation
        negative_prompt: Optional negative prompt
        num_inference_steps: Number of denoising steps
        save_path: File the generated image is written to

    Returns:
        The generated image, or None when generation or saving failed (the
        error is printed rather than raised).
    """
    print(f"Generating landscape image for prompt: '{prompt}'...")

    try:
        # Clear CUDA cache before generation
        torch.cuda.empty_cache()
        gc.collect()

        # Generate the image
        image = model.generate_image(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps
        )

        # Save the generated image
        image.save(save_path)
        print(f"Landscape image saved at {save_path}.")

        return image

    except Exception as e:
        # Best-effort helper for interactive use: report and signal failure
        # with None instead of interrupting the notebook.
        print(f"Error generating image: {e}")
        return None

In [None]:
# Demo entry point: render a few sample landscapes with the optimized pipeline.
if __name__ == "__main__":
    # Half-precision configuration at 512x512 with the memory savers enabled.
    config = OptimizationConfig(
        image_size=(512, 512),
        precision="fp16",
        enable_checkpointing=True,
        enable_attention_slicing=True,
        vae_slicing=True,
    )

    # Build the pipeline wrapper once and reuse it for every prompt.
    optimized_sdxl = OptimizedSDXL(config=config)

    # Sample prompts to render.
    prompts = [
        "A majestic mountain landscape with snow-capped peaks and a crystal clear lake reflecting the sky",
        "Dense forest with morning mist, rays of sunlight piercing through ancient trees",
        "Serene coastal scene with gentle waves lapping at a sandy beach, pastel sunset colors",
    ]

    # Render each prompt; successful results are also saved under a numbered name.
    for number, prompt in enumerate(prompts, start=1):
        image = generate_landscape_image(
            model=optimized_sdxl,
            prompt=prompt,
            negative_prompt="blurry, low quality, distorted",
            num_inference_steps=50,
        )
        if image:
            image.save(f"landscape_generation_{number}.png")