In [1]:
import os
import gc
import psutil
from typing import Optional, Union, List, Tuple, Dict, Any
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from google.colab import drive
from tqdm import tqdm

from diffusers import StableDiffusionXLPipeline
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
# Mount Drive and setup paths
drive.mount('/content/drive')
ROOT_DIR = '/content/drive/MyDrive/SDXL_images'
IMAGE_DIR = os.path.join(ROOT_DIR, 'landscape_images')
OUTPUT_DIR = os.path.join(ROOT_DIR, 'generated_images')
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoints')

# Create directories
for dir_path in [ROOT_DIR, IMAGE_DIR, OUTPUT_DIR, CHECKPOINT_DIR]:
    os.makedirs(dir_path, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Dataset Class
class DriveImageDataset(Dataset):
    @staticmethod
    def collate_fn(batch):
        # Filter out None values
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        return {
            'image': torch.stack([item['image'] for item in batch]),
            'prompt': [item['prompt'] for item in batch],
            'path': [item['path'] for item in batch]
        }
    def __init__(self, image_size: Tuple[int, int] = (512, 512), transform=None):
        self.folder_path = Path(IMAGE_DIR)
        self._validate_drive_folder()
        self.image_files = self._get_drive_images()
        self.image_size = image_size
        self.transform = transform or self._default_transform()
        print(f"Loaded dataset with {len(self.image_files)} images from {IMAGE_DIR}")

    def _validate_drive_folder(self):
        if not self.folder_path.exists():
            raise ValueError(f"Drive image directory not found: {IMAGE_DIR}")

    def _get_drive_images(self) -> List[Path]:
        files = list(self.folder_path.glob("*.jpg")) + list(self.folder_path.glob("*.png"))
        if not files:
            raise ValueError(f"No images found in Drive folder {IMAGE_DIR}")
        return files

    def _default_transform(self):
        return transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        try:
            image = Image.open(image_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return {"image": image, "prompt": f"A landscape photo of {image_path.stem}", "path": str(image_path)}
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None

@dataclass
class OptimizationConfig:
    image_size: Tuple[int, int] = (512, 512)
    precision: str = "fp16"
    enable_checkpointing: bool = True
    enable_attention_slicing: bool = True
    enable_sequential_cpu_offload: bool = False
    vae_slicing: bool = True

class OptimizedSDXL:
    def __init__(self, model_id: str = "stabilityai/stable-diffusion-xl-base-1.0", config: Optional[OptimizationConfig] = None, checkpoint_path: Optional[str] = None):
        self.config = config or OptimizationConfig()
        self.accelerator = Accelerator()
        self.setup_pipeline(model_id)
        if checkpoint_path and os.path.exists(checkpoint_path):
            print(f"Loading checkpoint from {checkpoint_path}")
            self.pipeline.unet.load_state_dict(torch.load(checkpoint_path))


    def setup_pipeline(self, model_id: str):
        try:
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if self.config.precision == "fp16" else torch.bfloat16,
                use_safetensors=True
            )
            if self.config.enable_attention_slicing:
                self.pipeline.enable_attention_slicing()
            if self.config.enable_sequential_cpu_offload:
                self.pipeline.enable_sequential_cpu_offload()
            if self.config.vae_slicing:
                self.pipeline.enable_vae_slicing()
            self.pipeline = self.pipeline.to(self.accelerator.device)
            if self.config.enable_checkpointing:
                self.pipeline.unet.enable_gradient_checkpointing()
            print(f"Pipeline setup complete on device: {self.accelerator.device}")
        except Exception as e:
            print(f"Error setting up pipeline: {str(e)}")
            raise

    def generate_image(self, prompt: str, negative_prompt: Optional[str] = None, num_inference_steps: int = 50, **kwargs):
        try:
            with torch.amp.autocast('cuda'):
                result = self.pipeline(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, **kwargs).images[0]
            return result
        except Exception as e:
            print(f"Error generating image: {str(e)}")
            return None

In [4]:
def train_model(model: OptimizedSDXL, dataset: DriveImageDataset, num_epochs: int = 5, batch_size: int = 1, save_interval: int = 100):
    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=DriveImageDataset.collate_fn)
        optimizer = torch.optim.AdamW(model.pipeline.unet.parameters(), lr=1e-5)
        scaler = torch.cuda.amp.GradScaler()
        noise_scheduler = model.pipeline.scheduler

        report = []
        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch + 1}/{num_epochs}")
            total_loss = 0
            num_batches = 0

            for batch_idx, batch in enumerate(tqdm(dataloader)):
                if batch is None:
                    continue
                try:
                    optimizer.zero_grad()
                    images = batch["image"].to(model.accelerator.device, dtype=torch.float16)
                    latents = model.pipeline.vae.encode(images).latent_dist.sample()
                    latents = latents * model.pipeline.vae.config.scaling_factor

                    # Add noise
                    noise = torch.randn_like(latents, dtype=torch.float16)
                    batch_size = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device).long()
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get prompt embeddings
                    prompt_ids = model.pipeline.tokenizer(batch["prompt"], padding="max_length", max_length=model.pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt").input_ids.to(model.accelerator.device)

                    encoder_hidden_states = model.pipeline.text_encoder(prompt_ids)[0]

                    with torch.amp.autocast('cuda'):
                        model_pred = model.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states).sample
                        loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float())

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    total_loss += loss.item()
                    num_batches += 1

                    if batch_idx % save_interval == 0:
                        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_b{batch_idx}.pt")
                        torch.save(model.pipeline.unet.state_dict(), checkpoint_path)

                except Exception as e:
                    print(f"Error during batch {batch_idx}: {e}")
                    continue

            avg_loss = total_loss / num_batches if num_batches > 0 else 0
            memory_usage = {
                "gpu_memory_mb": torch.cuda.memory_allocated() / 1024**2,
                "ram_memory_mb": psutil.Process().memory_info().rss / 1024**2,
            }
            report.append({
                "epoch": epoch + 1,
                "average_loss": avg_loss,
                "gpu_memory_mb": memory_usage["gpu_memory_mb"],
                "ram_memory_mb": memory_usage["ram_memory_mb"]
            })

            print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}, Memory: {memory_usage}")

        # Save report
        report_path = os.path.join(ROOT_DIR, "training_report.json")
        with open(report_path, "w") as f:
            json.dump(report, f, indent=4)
        print(f"Training report saved at {report_path}")

    except Exception as e:
        print(f"Training error: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()

In [5]:
def run_training():
    try:
        print("Initializing training components...")
        config = OptimizationConfig(image_size=(512, 512))
        model = OptimizedSDXL(config=config)
        dataset = DriveImageDataset(image_size=(512, 512))
        print(f"Dataset loaded with {len(dataset)} images")

        train_model(
            model=model,
            dataset=dataset,
            num_epochs=5,
            batch_size=1,
            save_interval=100
        )
        print("Training completed successfully")
    except Exception as e:
        print(f"Training failed: {str(e)}")
    finally:
        torch.cuda.empty_cache()
        gc.collect()

In [6]:
def generate_image(self, prompt: str, negative_prompt: Optional[str] = None, num_inference_steps: int = 50, **kwargs):
    try:
        print(f"Generating image for prompt: {prompt}")

        # Generate latents
        with torch.cuda.amp.autocast():
            output = self.pipeline(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, **kwargs)
            image = output.images[0]

        # Debug latents and image
        latents = output.latents if hasattr(output, 'latents') else None
        if latents is not None:
            print("Latents min/max:", latents.min(), latents.max())
            if torch.isnan(latents).any():
                print("Latents contain NaN values!")
                latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=-1.0)

        if image is not None:
            print("Image tensor min/max:", torch.min(image), torch.max(image))
            if torch.isnan(image).any():
                print("Image contains NaN values!")
                image = torch.nan_to_num(image, nan=0.0, posinf=1.0, neginf=-1.0)

        return image
    except Exception as e:
        print(f"Error generating image: {e}")
        return None

In [7]:
prompt = "A beautiful mountain landscape with a sunset and snow-capped peaks"
negative_prompt = "blurry, low quality, distorted"

generated_image = test_generation(
    prompt=prompt,
    negative_prompt=negative_prompt
)

NameError: name 'test_generation' is not defined