<a href="https://colab.research.google.com/github/Rishardmunene/Stable-Diffusion-test/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
!pip install torch torchvision accelerate diffusers transformers PIL tqdm

[31mERROR: Could not find a version that satisfies the requirement PIL (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for PIL[0m[31m
[0m

In [16]:
import os
import logging
from typing import Optional, List, Tuple
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import StableDiffusionXLPipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [17]:
@dataclass
class SDXLConfig:
    model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"
    vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
    image_size: Tuple[int, int] = (1024, 1024)
    train_batch_size: int = 1
    num_train_epochs: int = 5
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-5
    max_grad_norm: float = 1.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    mixed_precision: str = "fp16"
    save_interval: int = 500
    root_dir: str = "trained_models"
    seed: int = 42

In [18]:
class ImageDataset(Dataset):
    def __init__(self, image_dir: str, config: SDXLConfig):
        self.image_dir = Path(image_dir)
        self.config = config
        self.image_paths = list(self.image_dir.glob("*.jpg")) + list(self.image_dir.glob("*.png"))
        self.transform = transforms.Compose([
            transforms.Resize(config.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return {
            "pixel_values": image,
            "prompt": f"A high quality photo of {image_path.stem}"
        }

In [19]:
class SDXLTrainer:
    def __init__(self, config: SDXLConfig):
        self.config = config
        self.accelerator = Accelerator(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            mixed_precision=config.mixed_precision
        )
        self._setup_pipeline()
        torch.manual_seed(config.seed)

    def _setup_pipeline(self):
        self.pipeline = StableDiffusionXLPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        ).to(self.accelerator.device)
        self.pipeline.enable_vae_slicing()
        self.pipeline.enable_attention_slicing()

    def train_step(self, batch):
        latents = self.pipeline.vae.encode(
            batch["pixel_values"].to(dtype=torch.float16)
        ).latent_dist.sample()
        latents = latents * self.pipeline.vae.config.scaling_factor

        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.pipeline.scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=latents.device
        )
        noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)

        prompt_embeds = self.pipeline.tokenizer(
            batch["prompt"],
            padding="max_length",
            max_length=self.pipeline.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(self.accelerator.device)

        encoder_hidden_states = self.pipeline.text_encoder(prompt_embeds)[0]

        with torch.amp.autocast('cuda'):
            model_pred = self.pipeline.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states
            ).sample

        loss = F.mse_loss(model_pred.float(), noise.float())
        return loss

    def train(self, train_dataset):
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.train_batch_size,
            shuffle=True,
            pin_memory=True
        )

        optimizer = torch.optim.AdamW(
            self.pipeline.unet.parameters(),
            lr=self.config.learning_rate,
            betas=(self.config.adam_beta1, self.config.adam_beta2),
            weight_decay=self.config.adam_weight_decay
        )

        self.pipeline.unet, optimizer, train_dataloader = self.accelerator.prepare(
            self.pipeline.unet, optimizer, train_dataloader
        )

        global_step = 0
        for epoch in range(self.config.num_train_epochs):
            self.pipeline.unet.train()
            for step, batch in enumerate(tqdm(train_dataloader)):
                with self.accelerator.accumulate(self.pipeline.unet):
                    loss = self.train_step(batch)
                    self.accelerator.backward(loss)

                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.pipeline.unet.parameters(),
                            self.config.max_grad_norm
                        )

                    optimizer.step()
                    optimizer.zero_grad()

                if (step + 1) % 10 == 0:
                    logger.info(f"Epoch {epoch}, Step {step}: Loss = {loss.item():.4f}")

                if (step + 1) % self.config.save_interval == 0:
                    self._save_checkpoint(global_step)
                global_step += 1

    def _save_checkpoint(self, step: int):
        save_path = os.path.join(self.config.root_dir, f"checkpoint-{step}")
        self.accelerator.save(
            self.pipeline.unet.state_dict(),
            os.path.join(save_path, "unet.pt")
        )

In [20]:
def setup_training_environment():
    os.chdir('/content')

    root_dir = '/content'
    image_dir = os.path.join(root_dir, 'training_images')
    output_dir = os.path.join(root_dir, 'generated_images')

    for dir_path in [root_dir, image_dir, output_dir]:
        os.makedirs(dir_path, exist_ok=True)

    image_files = list(Path(image_dir).glob('*.jpg')) + list(Path(image_dir).glob('*.png'))
    if not image_files:
        logger.error(f"No images found in {image_dir}")
        logger.info("Please add .jpg or .png images to the training_images folder")
        logger.info(f"Current directory contents: {os.listdir(image_dir)}")
        raise ValueError(f"No training images found in {image_dir}")

    return root_dir, image_dir, output_dir, len(image_files)

In [21]:
def main():
    root_dir, image_dir, output_dir, num_images = setup_training_environment()
    logger.info(f"Found {num_images} training images")

    config = SDXLConfig(root_dir=root_dir)
    trainer = SDXLTrainer(config)
    train_dataset = ImageDataset(image_dir, config)

    trainer.train(train_dataset)
    logger.info("Training completed.")