In [1]:
import os
import gc
import psutil
from typing import Optional, Union, List, Tuple, Dict, Any
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from google.colab import drive
from tqdm import tqdm

from diffusers import StableDiffusionXLPipeline
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSequenceClassification

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [2]:
# Mount Drive and setup paths
drive.mount('/content/drive')
ROOT_DIR = '/content/drive/MyDrive/SDXL_images'
IMAGE_DIR = os.path.join(ROOT_DIR, 'landscape_images')
OUTPUT_DIR = os.path.join(ROOT_DIR, 'generated_images')
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoints')

# Create directories
for dir_path in [ROOT_DIR, IMAGE_DIR, OUTPUT_DIR, CHECKPOINT_DIR]:
    os.makedirs(dir_path, exist_ok=True)

Mounted at /content/drive


In [3]:
# Dataset Class
class DriveImageDataset(Dataset):
    def __init__(self, image_size: Tuple[int, int] = (512, 512), transform=None):
        self.folder_path = Path(IMAGE_DIR)
        self._validate_drive_folder()
        self.image_files = self._get_drive_images()
        self.image_size = image_size
        self.transform = transform or self._default_transform()
        print(f"Loaded dataset with {len(self.image_files)} images from {IMAGE_DIR}")

    def _validate_drive_folder(self):
        if not self.folder_path.exists():
            raise ValueError(f"Drive image directory not found: {IMAGE_DIR}")

    def _get_drive_images(self) -> List[Path]:
        files = list(self.folder_path.glob("*.jpg")) + list(self.folder_path.glob("*.png"))
        if not files:
            raise ValueError(f"No images found in Drive folder {IMAGE_DIR}")
        return files

    def _default_transform(self):
        return transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        try:
            image = Image.open(image_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return {"image": image, "prompt": f"A landscape photo of {image_path.stem}", "path": str(image_path)}
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None

@dataclass
class OptimizationConfig:
    image_size: Tuple[int, int] = (512, 512)
    precision: str = "fp16"
    enable_checkpointing: bool = True
    enable_attention_slicing: bool = True
    enable_sequential_cpu_offload: bool = False
    vae_slicing: bool = True

class OptimizedSDXL:
    def __init__(self, model_id: str = "stabilityai/stable-diffusion-xl-base-1.0", config: Optional[OptimizationConfig] = None):
        self.config = config or OptimizationConfig()
        self.accelerator = Accelerator()
        self.setup_pipeline(model_id)

    def setup_pipeline(self, model_id: str):
        try:
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if self.config.precision == "fp16" else torch.bfloat16,
                use_safetensors=True
            )
            if self.config.enable_attention_slicing:
                self.pipeline.enable_attention_slicing()
            if self.config.enable_sequential_cpu_offload:
                self.pipeline.enable_sequential_cpu_offload()
            if self.config.vae_slicing:
                self.pipeline.enable_vae_slicing()
            self.pipeline = self.pipeline.to(self.accelerator.device)
            if self.config.enable_checkpointing:
                self.pipeline.unet.enable_gradient_checkpointing()
            print(f"Pipeline setup complete on device: {self.accelerator.device}")
        except Exception as e:
            print(f"Error setting up pipeline: {str(e)}")
            raise

    def generate_image(self, prompt: str, negative_prompt: Optional[str] = None, num_inference_steps: int = 50, **kwargs):
        try:
            with torch.cuda.amp.autocast():
                result = self.pipeline(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, **kwargs).images[0]
            return result
        except Exception as e:
            print(f"Error generating image: {str(e)}")
            return None

In [4]:
def train_model(model: OptimizedSDXL, dataset: DriveImageDataset, num_epochs: int = 5, batch_size: int = 1, save_interval: int = 100):
    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(model.pipeline.unet.parameters(), lr=1e-5)
        scaler = torch.cuda.amp.GradScaler()

        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch + 1}/{num_epochs}")
            total_loss = 0

            for batch_idx, batch in enumerate(tqdm(dataloader)):
                try:
                    optimizer.zero_grad()
                    with torch.cuda.amp.autocast():
                        loss = model.pipeline.unet(batch["image"]).mean()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    total_loss += loss.item()
                    if batch_idx % save_interval == 0:
                        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_b{batch_idx}.pt")
                        torch.save(model.pipeline.unet.state_dict(), checkpoint_path)

                except Exception as e:
                    print(f"Error during batch {batch_idx}: {e}")
            print(f"Epoch {epoch + 1} completed with total loss: {total_loss:.4f}")

    except Exception as e:
        print(f"Training error: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()

In [6]:
def run_training():
    try:
        print("Initializing training components...")
        config = OptimizationConfig(image_size=(512, 512))
        model = OptimizedSDXL(config=config)
        dataset = DriveImageDataset(image_size=(512, 512))
        print(f"Dataset loaded with {len(dataset)} images")

        train_model(model=model, dataset=dataset, num_epochs=5, batch_size=1, save_interval=100)
        print("Training completed successfully")
    except Exception as e:
        print(f"Training failed: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    run_training()

Initializing training components...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.78G [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

text_encoder_2/config.json:   0%|          | 0.00/575 [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

tokenizer_2/tokenizer_config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer_2/special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/10.3G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  scaler = torch.cuda.amp.GradScaler()


Pipeline setup complete on device: cuda
Loaded dataset with 7 images from /content/drive/MyDrive/SDXL_images/landscape_images
Dataset loaded with 7 images
Starting epoch 1/5


  with torch.cuda.amp.autocast():
 14%|█▍        | 1/7 [00:00<00:05,  1.05it/s]

Error during batch 0: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 29%|██▊       | 2/7 [00:01<00:04,  1.13it/s]

Error during batch 1: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 43%|████▎     | 3/7 [00:02<00:03,  1.02it/s]

Error during batch 2: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 57%|█████▋    | 4/7 [00:03<00:02,  1.08it/s]

Error during batch 3: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 71%|███████▏  | 5/7 [00:04<00:01,  1.06it/s]

Error during batch 4: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 86%|████████▌ | 6/7 [00:05<00:00,  1.07it/s]

Error during batch 5: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


100%|██████████| 7/7 [00:06<00:00,  1.06it/s]


Error during batch 6: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
Epoch 1 completed with total loss: 0.0000
Starting epoch 2/5


 14%|█▍        | 1/7 [00:00<00:02,  2.21it/s]

Error during batch 0: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 29%|██▊       | 2/7 [00:00<00:02,  2.15it/s]

Error during batch 1: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 43%|████▎     | 3/7 [00:01<00:02,  1.96it/s]

Error during batch 2: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 57%|█████▋    | 4/7 [00:02<00:01,  1.87it/s]

Error during batch 3: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 71%|███████▏  | 5/7 [00:02<00:01,  1.71it/s]

Error during batch 4: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 86%|████████▌ | 6/7 [00:03<00:00,  1.79it/s]

Error during batch 5: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


100%|██████████| 7/7 [00:03<00:00,  1.86it/s]


Error during batch 6: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
Epoch 2 completed with total loss: 0.0000
Starting epoch 3/5


 14%|█▍        | 1/7 [00:00<00:03,  1.79it/s]

Error during batch 0: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 29%|██▊       | 2/7 [00:01<00:02,  1.79it/s]

Error during batch 1: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 43%|████▎     | 3/7 [00:01<00:02,  1.92it/s]

Error during batch 2: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 57%|█████▋    | 4/7 [00:02<00:01,  2.00it/s]

Error during batch 3: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 71%|███████▏  | 5/7 [00:02<00:00,  2.16it/s]

Error during batch 4: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 86%|████████▌ | 6/7 [00:02<00:00,  2.33it/s]

Error during batch 5: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


100%|██████████| 7/7 [00:03<00:00,  2.20it/s]


Error during batch 6: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
Epoch 3 completed with total loss: 0.0000
Starting epoch 4/5


 14%|█▍        | 1/7 [00:00<00:02,  2.78it/s]

Error during batch 0: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 29%|██▊       | 2/7 [00:00<00:01,  2.61it/s]

Error during batch 1: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 43%|████▎     | 3/7 [00:01<00:01,  2.56it/s]

Error during batch 2: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 57%|█████▋    | 4/7 [00:01<00:01,  2.39it/s]

Error during batch 3: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 71%|███████▏  | 5/7 [00:02<00:00,  2.27it/s]

Error during batch 4: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 86%|████████▌ | 6/7 [00:02<00:00,  2.29it/s]

Error during batch 5: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


100%|██████████| 7/7 [00:02<00:00,  2.43it/s]


Error during batch 6: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
Epoch 4 completed with total loss: 0.0000
Starting epoch 5/5


 14%|█▍        | 1/7 [00:00<00:02,  2.50it/s]

Error during batch 0: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 29%|██▊       | 2/7 [00:00<00:02,  2.29it/s]

Error during batch 1: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 43%|████▎     | 3/7 [00:01<00:01,  2.45it/s]

Error during batch 2: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 57%|█████▋    | 4/7 [00:01<00:01,  2.24it/s]

Error during batch 3: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 71%|███████▏  | 5/7 [00:02<00:00,  2.34it/s]

Error during batch 4: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


 86%|████████▌ | 6/7 [00:02<00:00,  2.33it/s]

Error during batch 5: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'


100%|██████████| 7/7 [00:02<00:00,  2.40it/s]

Error during batch 6: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
Epoch 5 completed with total loss: 0.0000





Training completed successfully


The error occurs because the UNet forward pass is missing required arguments `timestep` and `encoder_hidden_states`. Let me help you fix the training code.

Here's the corrected `train_model` function:



In [None]:
def train_model(model: OptimizedSDXL, dataset: DriveImageDataset, num_epochs: int = 5, batch_size: int = 1, save_interval: int = 100):
    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(model.pipeline.unet.parameters(), lr=1e-5)
        scaler = torch.cuda.amp.GradScaler()
        noise_scheduler = model.pipeline.scheduler

        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch + 1}/{num_epochs}")
            total_loss = 0

            for batch_idx, batch in enumerate(tqdm(dataloader)):
                try:
                    optimizer.zero_grad()
                    
                    # Get image and convert to latents
                    images = batch["image"].to(model.accelerator.device)
                    latents = model.pipeline.vae.encode(images).latent_dist.sample()
                    latents = latents * model.pipeline.vae.config.scaling_factor
                    
                    # Add noise
                    noise = torch.randn_like(latents)
                    batch_size = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_size,), device=latents.device)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                    
                    # Get prompt embeddings
                    prompt_ids = model.pipeline.tokenizer(
                        batch["prompt"], 
                        padding="max_length",
                        max_length=model.pipeline.tokenizer.model_max_length,
                        truncation=True,
                        return_tensors="pt"
                    ).input_ids.to(model.accelerator.device)
                    
                    encoder_hidden_states = model.pipeline.text_encoder(prompt_ids)[0]

                    with torch.cuda.amp.autocast():
                        model_pred = model.pipeline.unet(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states
                        ).sample
                        
                        loss = torch.nn.functional.mse_loss(model_pred, noise)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    total_loss += loss.item()
                    
                    if batch_idx % save_interval == 0:
                        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_b{batch_idx}.pt")
                        torch.save(model.pipeline.unet.state_dict(), checkpoint_path)

                except Exception as e:
                    print(f"Error during batch {batch_idx}: {e}")
                    
            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    except Exception as e:
        print(f"Training error: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()



Key changes made:

1. Added noise scheduler from the pipeline
2. Convert images to latents using VAE
3. Add noise to latents using scheduler
4. Get text embeddings using text encoder
5. Pass all required arguments to UNet:
   - noisy_latents
   - timesteps 
   - encoder_hidden_states
6. Calculate loss between predicted and original noise
7. Added average loss calculation per epoch

The UNet now receives all required inputs for training:
- The noisy latents to denoise
- The timesteps indicating noise level
- The text embeddings providing conditioning

This should resolve the missing arguments error and allow proper training of the diffusion model.