In [1]:
import os
import gc
import psutil
from typing import Optional, Union, List, Tuple, Dict, Any
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from google.colab import drive
from tqdm import tqdm

from diffusers import StableDiffusionXLPipeline
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
# Mount Drive and setup paths
drive.mount('/content/drive')
ROOT_DIR = '/content/drive/MyDrive/SDXL_images'
IMAGE_DIR = os.path.join(ROOT_DIR, 'landscape_images')
OUTPUT_DIR = os.path.join(ROOT_DIR, 'generated_images')
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoints')

# Create directories
for dir_path in [ROOT_DIR, IMAGE_DIR, OUTPUT_DIR, CHECKPOINT_DIR]:
    os.makedirs(dir_path, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Dataset Class
class DriveImageDataset(Dataset):
    def __init__(self, image_size: Tuple[int, int] = (512, 512), transform=None):
        self.folder_path = Path(IMAGE_DIR)
        self._validate_drive_folder()
        self.image_files = self._get_drive_images()
        self.image_size = image_size
        self.transform = transform or self._default_transform()
        print(f"Loaded dataset with {len(self.image_files)} images from {IMAGE_DIR}")

    def _validate_drive_folder(self):
        if not self.folder_path.exists():
            raise ValueError(f"Drive image directory not found: {IMAGE_DIR}")

    def _get_drive_images(self) -> List[Path]:
        files = list(self.folder_path.glob("*.jpg")) + list(self.folder_path.glob("*.png"))
        if not files:
            raise ValueError(f"No images found in Drive folder {IMAGE_DIR}")
        return files

    def _default_transform(self):
        return transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        try:
            image = Image.open(image_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return {"image": image, "prompt": f"A landscape photo of {image_path.stem}", "path": str(image_path)}
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None

@dataclass
class OptimizationConfig:
    image_size: Tuple[int, int] = (512, 512)
    precision: str = "fp16"
    enable_checkpointing: bool = True
    enable_attention_slicing: bool = True
    enable_sequential_cpu_offload: bool = False
    vae_slicing: bool = True

class OptimizedSDXL:
    def __init__(self, model_id: str = "stabilityai/stable-diffusion-xl-base-1.0", config: Optional[OptimizationConfig] = None):
        self.config = config or OptimizationConfig()
        self.accelerator = Accelerator()
        self.setup_pipeline(model_id)

    def setup_pipeline(self, model_id: str):
        try:
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if self.config.precision == "fp16" else torch.bfloat16,
                use_safetensors=True
            )
            if self.config.enable_attention_slicing:
                self.pipeline.enable_attention_slicing()
            if self.config.enable_sequential_cpu_offload:
                self.pipeline.enable_sequential_cpu_offload()
            if self.config.vae_slicing:
                self.pipeline.enable_vae_slicing()
            self.pipeline = self.pipeline.to(self.accelerator.device)
            if self.config.enable_checkpointing:
                self.pipeline.unet.enable_gradient_checkpointing()
            print(f"Pipeline setup complete on device: {self.accelerator.device}")
        except Exception as e:
            print(f"Error setting up pipeline: {str(e)}")
            raise

    def generate_image(self, prompt: str, negative_prompt: Optional[str] = None, num_inference_steps: int = 50, **kwargs):
        try:
            with torch.cuda.amp.autocast():
                result = self.pipeline(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, **kwargs).images[0]
            return result
        except Exception as e:
            print(f"Error generating image: {str(e)}")
            return None

In [4]:
def train_model(model: OptimizedSDXL, dataset: DriveImageDataset, num_epochs: int = 5, batch_size: int = 1, save_interval: int = 100):
    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(model.pipeline.unet.parameters(), lr=1e-5)
        scaler = torch.cuda.amp.GradScaler()
        noise_scheduler = model.pipeline.scheduler

        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch + 1}/{num_epochs}")
            total_loss = 0

            for batch_idx, batch in enumerate(tqdm(dataloader)):
                try:
                    optimizer.zero_grad()

                    # Get image and convert to latents
                    images = batch["image"].to(model.accelerator.device)
                    latents = model.pipeline.vae.encode(images).latent_dist.sample()
                    latents = latents * model.pipeline.vae.config.scaling_factor

                    # Add noise
                    noise = torch.randn_like(latents)
                    batch_size = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_size,), device=latents.device).long()
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get prompt embeddings
                    prompt_ids = model.pipeline.tokenizer(
                        batch["prompt"],
                        padding="max_length",
                        max_length=model.pipeline.tokenizer.model_max_length,
                        truncation=True,
                        return_tensors="pt"
                    ).input_ids.to(model.accelerator.device)

                    encoder_hidden_states = model.pipeline.text_encoder(prompt_ids)[0]

                    with torch.cuda.amp.autocast():
                        # Pass noisy_latents, timesteps, and encoder_hidden_states to the UNet forward method
                        model_pred = model.pipeline.unet(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states
                        ).sample

                        # Calculate loss between predicted noise and actual noise
                        loss = torch.nn.functional.mse_loss(model_pred, noise)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    total_loss += loss.item()

                    if batch_idx % save_interval == 0:
                        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_b{batch_idx}.pt")
                        torch.save(model.pipeline.unet.state_dict(), checkpoint_path)

                except Exception as e:
                    print(f"Error during batch {batch_idx}: {e}")

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    except Exception as e:
        print(f"Training error: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()

In [5]:
def run_training():
    try:
        print("Initializing training components...")
        config = OptimizationConfig(image_size=(512, 512))
        model = OptimizedSDXL(config=config)
        dataset = DriveImageDataset(image_size=(512, 512))
        print(f"Dataset loaded with {len(dataset)} images")

        train_model(model=model, dataset=dataset, num_epochs=5, batch_size=1, save_interval=100)
        print("Training completed successfully")
    except Exception as e:
        print(f"Training failed: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    run_training()

Initializing training components...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  scaler = torch.cuda.amp.GradScaler()


Pipeline setup complete on device: cuda
Loaded dataset with 7 images from /content/drive/MyDrive/SDXL_images/landscape_images
Dataset loaded with 7 images
Starting epoch 1/5


 14%|█▍        | 1/7 [00:00<00:03,  1.60it/s]

Error during batch 0: Input type (float) and bias type (c10::Half) should be the same


 29%|██▊       | 2/7 [00:01<00:02,  1.99it/s]

Error during batch 1: Input type (float) and bias type (c10::Half) should be the same


 43%|████▎     | 3/7 [00:01<00:01,  2.26it/s]

Error during batch 2: Input type (float) and bias type (c10::Half) should be the same


 57%|█████▋    | 4/7 [00:01<00:01,  2.17it/s]

Error during batch 3: Input type (float) and bias type (c10::Half) should be the same


 71%|███████▏  | 5/7 [00:02<00:00,  2.28it/s]

Error during batch 4: Input type (float) and bias type (c10::Half) should be the same


 86%|████████▌ | 6/7 [00:02<00:00,  2.28it/s]

Error during batch 5: Input type (float) and bias type (c10::Half) should be the same


100%|██████████| 7/7 [00:03<00:00,  2.26it/s]


Error during batch 6: Input type (float) and bias type (c10::Half) should be the same
Epoch 1 completed with average loss: 0.0000
Starting epoch 2/5


 14%|█▍        | 1/7 [00:00<00:02,  2.07it/s]

Error during batch 0: Input type (float) and bias type (c10::Half) should be the same


 29%|██▊       | 2/7 [00:00<00:02,  2.30it/s]

Error during batch 1: Input type (float) and bias type (c10::Half) should be the same


 43%|████▎     | 3/7 [00:01<00:01,  2.20it/s]

Error during batch 2: Input type (float) and bias type (c10::Half) should be the same


 57%|█████▋    | 4/7 [00:01<00:01,  2.29it/s]

Error during batch 3: Input type (float) and bias type (c10::Half) should be the same


 71%|███████▏  | 5/7 [00:02<00:00,  2.48it/s]

Error during batch 4: Input type (float) and bias type (c10::Half) should be the same


 86%|████████▌ | 6/7 [00:02<00:00,  2.55it/s]

Error during batch 5: Input type (float) and bias type (c10::Half) should be the same


100%|██████████| 7/7 [00:02<00:00,  2.40it/s]


Error during batch 6: Input type (float) and bias type (c10::Half) should be the same
Epoch 2 completed with average loss: 0.0000
Starting epoch 3/5


 14%|█▍        | 1/7 [00:00<00:02,  2.00it/s]

Error during batch 0: Input type (float) and bias type (c10::Half) should be the same


 29%|██▊       | 2/7 [00:00<00:02,  2.03it/s]

Error during batch 1: Input type (float) and bias type (c10::Half) should be the same


 43%|████▎     | 3/7 [00:01<00:02,  1.71it/s]

Error during batch 2: Input type (float) and bias type (c10::Half) should be the same


 57%|█████▋    | 4/7 [00:02<00:01,  1.53it/s]

Error during batch 3: Input type (float) and bias type (c10::Half) should be the same


 71%|███████▏  | 5/7 [00:03<00:01,  1.51it/s]

Error during batch 4: Input type (float) and bias type (c10::Half) should be the same


 86%|████████▌ | 6/7 [00:04<00:00,  1.24it/s]

Error during batch 5: Input type (float) and bias type (c10::Half) should be the same


100%|██████████| 7/7 [00:04<00:00,  1.48it/s]


Error during batch 6: Input type (float) and bias type (c10::Half) should be the same
Epoch 3 completed with average loss: 0.0000
Starting epoch 4/5


 14%|█▍        | 1/7 [00:00<00:04,  1.44it/s]

Error during batch 0: Input type (float) and bias type (c10::Half) should be the same


 29%|██▊       | 2/7 [00:01<00:02,  1.69it/s]

Error during batch 1: Input type (float) and bias type (c10::Half) should be the same


 43%|████▎     | 3/7 [00:01<00:02,  1.73it/s]

Error during batch 2: Input type (float) and bias type (c10::Half) should be the same


 57%|█████▋    | 4/7 [00:02<00:01,  1.74it/s]

Error during batch 3: Input type (float) and bias type (c10::Half) should be the same


 71%|███████▏  | 5/7 [00:02<00:01,  1.83it/s]

Error during batch 4: Input type (float) and bias type (c10::Half) should be the same


 86%|████████▌ | 6/7 [00:03<00:00,  1.91it/s]

Error during batch 5: Input type (float) and bias type (c10::Half) should be the same


100%|██████████| 7/7 [00:03<00:00,  1.89it/s]


Error during batch 6: Input type (float) and bias type (c10::Half) should be the same
Epoch 4 completed with average loss: 0.0000
Starting epoch 5/5


 14%|█▍        | 1/7 [00:00<00:02,  2.71it/s]

Error during batch 0: Input type (float) and bias type (c10::Half) should be the same


 29%|██▊       | 2/7 [00:00<00:01,  2.76it/s]

Error during batch 1: Input type (float) and bias type (c10::Half) should be the same


 43%|████▎     | 3/7 [00:01<00:01,  2.45it/s]

Error during batch 2: Input type (float) and bias type (c10::Half) should be the same


 57%|█████▋    | 4/7 [00:01<00:01,  2.25it/s]

Error during batch 3: Input type (float) and bias type (c10::Half) should be the same


 71%|███████▏  | 5/7 [00:02<00:00,  2.33it/s]

Error during batch 4: Input type (float) and bias type (c10::Half) should be the same


 86%|████████▌ | 6/7 [00:02<00:00,  2.38it/s]

Error during batch 5: Input type (float) and bias type (c10::Half) should be the same


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]

Error during batch 6: Input type (float) and bias type (c10::Half) should be the same
Epoch 5 completed with average loss: 0.0000





Training completed successfully


The error "Input type (float) and bias type (c10::Half) should be the same" indicates a mismatch between the data types of the input and the model parameters. This typically happens when the model is using mixed precision (e.g., `float16` for the model parameters) but the input data is in `float32`.

To fix this, you need to ensure that the input data is also in `float16` when using mixed precision. Here are the steps to fix the issue:

1. Convert the input data to `float16` before passing it to the model.
2. Ensure that all operations involving the model and input data are consistent in terms of data types.

Here's the corrected `train_model` function:



In [None]:
def train_model(model: OptimizedSDXL, dataset: DriveImageDataset, num_epochs: int = 5, batch_size: int = 1, save_interval: int = 100):
    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(model.pipeline.unet.parameters(), lr=1e-5)
        scaler = torch.cuda.amp.GradScaler()
        noise_scheduler = model.pipeline.scheduler

        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch + 1}/{num_epochs}")
            total_loss = 0

            for batch_idx, batch in enumerate(tqdm(dataloader)):
                try:
                    optimizer.zero_grad()
                    
                    # Get image and convert to latents
                    images = batch["image"].to(model.accelerator.device, dtype=torch.float16)
                    latents = model.pipeline.vae.encode(images).latent_dist.sample()
                    latents = latents * model.pipeline.vae.config.scaling_factor
                    
                    # Add noise
                    noise = torch.randn_like(latents, dtype=torch.float16)
                    batch_size = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_size,), device=latents.device).long()
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                    
                    # Get prompt embeddings
                    prompt_ids = model.pipeline.tokenizer(
                        batch["prompt"], 
                        padding="max_length",
                        max_length=model.pipeline.tokenizer.model_max_length,
                        truncation=True,
                        return_tensors="pt"
                    ).input_ids.to(model.accelerator.device)
                    
                    encoder_hidden_states = model.pipeline.text_encoder(prompt_ids)[0]

                    with torch.cuda.amp.autocast():
                        # Pass noisy_latents, timesteps, and encoder_hidden_states to the UNet forward method
                        model_pred = model.pipeline.unet(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states
                        ).sample
                        
                        # Calculate loss between predicted noise and actual noise
                        loss = torch.nn.functional.mse_loss(model_pred, noise)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    total_loss += loss.item()
                    
                    if batch_idx % save_interval == 0:
                        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_b{batch_idx}.pt")
                        torch.save(model.pipeline.unet.state_dict(), checkpoint_path)

                except Exception as e:
                    print(f"Error during batch {batch_idx}: {e}")
                    
            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    except Exception as e:
        print(f"Training error: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()



Key changes made:
1. Convert `images` to `float16` before passing them to the model.
2. Ensure `noise` is also in `float16`.

This should resolve the data type mismatch error and allow the training to run seamlessly.