In [1]:
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer
from PIL import Image
import numpy as np
from accelerate import Accelerator
import os
from torch.utils.data import Dataset, DataLoader

In [4]:
# model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
# pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32, 
#     lora_dropout=0.1, 
#     bias="none"
# )

In [8]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/527 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.60k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

In [9]:
def generate_caption(image_path):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    caption = model.generate(**inputs)
    return processor.decode(caption[0], skip_special_tokens=True)

In [15]:
image_folder = "datasets/designs/"
caption_file = "datasets/caption.txt"
for img in os.listdir(image_folder):
    caption = generate_caption(os.path.join(image_folder, img))
    print(f"{img}: {caption}")
    with open(caption_file, "a") as f:
        f.write(f'{img}\t{caption}\n')

p03.jpg: gnome with a flower and butterfly on a white background
p06.jpg: a black and white drawing of a hand with a triangle and all seeing symbols
ghouldfriend_p01.png: a close up of a black bat with big eyes and a big nose
p02.jpg: cartoon illustration of a bee with a pot of honey
p12.png: two wooden nutcrackers with pine cones and hollyconnets on them
p10.jpg: there is a snowman with a hat and scarf holding a present
p01.jpg: gnome with a butterfly on his head holding a flower pot
p11.jpg: a close up of a toy figure of a cat with a cup
p09.jpg: there is a snowman with a scarf and hat holding a red ribbon
p08.jpg: a close up of a penguin with a christmas present on a calendar
p04.jpg: a close up of a halloween candle holder with a witch hat on top
p05.jpg: there are three pumpkins stacked on top of each other
p07.jpg: a close up of a skeleton hand holding a skull on a stand


In [2]:
class CustomDataset(Dataset):
    def __init__(self, image_dir, caption_file, tokenizer_1, tokenizer_2, target_size=1024):
        self.image_dir = image_dir
        self.target_size = target_size
        self.tokenizer_1 = tokenizer_1
        self.tokenizer_2 = tokenizer_2
        
        # Load captions
        self.image_caption_pairs = []
        with open(caption_file, 'r') as f:
            for line in f:
                image_name, caption = line.strip().split('\t')
                self.image_caption_pairs.append((image_name, caption))

    def resize_and_pad(self, image):
        """Resize image maintaining aspect ratio and pad if necessary."""
        # Get original dimensions
        original_width, original_height = image.size
        
        # Calculate aspect ratio
        aspect_ratio = original_width / original_height
        
        if aspect_ratio > 1:  # Width > Height
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:  # Height >= Width
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)
            
        # Resize image
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        
        # Create new image with padding
        new_image = Image.new('RGB', (self.target_size, self.target_size), (0, 0, 0))
        
        # Calculate padding
        left_padding = (self.target_size - new_width) // 2
        top_padding = (self.target_size - new_height) // 2
        
        # Paste resized image onto padded background
        new_image.paste(image, (left_padding, top_padding))
        
        return new_image

    def __len__(self):
        return len(self.image_caption_pairs)

    def __getitem__(self, idx):
        image_name, caption = self.image_caption_pairs[idx]
        
        # Load and preprocess image
        image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB')
        
        # Resize and pad image
        image = self.resize_and_pad(image)
        
        # Convert to tensor and normalize
        image = torch.from_numpy(np.array(image)).float() / 127.5 - 1
        image = image.permute(2, 0, 1)
        
        # Tokenize caption with both tokenizers
        tokens_1 = self.tokenizer_1(
            caption,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )
        
        tokens_2 = self.tokenizer_2(
            caption,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )
        
        return {
            "pixel_values": image,
            "input_ids_1": tokens_1.input_ids[0],
            "input_ids_2": tokens_2.input_ids[0],
            "attention_mask_1": tokens_1.attention_mask[0],
            "attention_mask_2": tokens_2.attention_mask[0]
        }

In [3]:
def train_stable_diffusion(
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    image_dir="datasets/designs",
    caption_file="datasets/caption.txt",
    output_dir="fine_tuned_model",
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5,
    gradient_accumulation_steps=4,
    project_name="sdxl-finetuning",
    run_name=None
):
    # Initialize wandb
    wandb.init(
        project=project_name,
        name=run_name,
        config={
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "model": "stabilityai/stable-diffusion-xl-base-1.0",
            "image_dir": image_dir,
            "caption_file": caption_file
        }
    )

    # Initialize accelerator without mixed precision
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=None  # Changed from "fp16" to None
    )

    # Set device
    device = accelerator.device
    print(f"Using device: {device}")

    # Load base model
    print("Loading SDXL model...")
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32,  # Changed from float16 to float32
        use_safetensors=True
    ).to(device)
    
    # Get components
    tokenizer_1 = pipeline.tokenizer
    tokenizer_2 = pipeline.tokenizer_2
    vae = pipeline.vae.to(device)
    unet = pipeline.unet.to(device)
    text_encoder_1 = pipeline.text_encoder.to(device)
    text_encoder_2 = pipeline.text_encoder_2.to(device)
    
    # Freeze VAE and text encoders
    vae.requires_grad_(False)
    text_encoder_1.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    
    # Create dataset and dataloader
    dataset = CustomDataset(image_dir, caption_file, tokenizer_1, tokenizer_2)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=learning_rate,
    )
    
    # Prepare for training
    unet, optimizer, dataloader = accelerator.prepare(
        unet, optimizer, dataloader
    )
    
    # Training loop
    flobal_setp = 0
    for epoch in range(num_epochs):
        unet.train()
        total_loss = 0
        
        for step, batch in enumerate(dataloader):
            with accelerator.accumulate(unet):
                # Move batch to device and convert to float32
                pixel_values = batch["pixel_values"].to(device, dtype=torch.float32)
                input_ids_1 = batch["input_ids_1"].to(device)
                input_ids_2 = batch["input_ids_2"].to(device)
                
                # Get latent representation
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                
                # Add noise
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, pipeline.scheduler.config.num_train_timesteps,
                    (latents.shape[0],), device=device
                )
                noisy_latents = pipeline.scheduler.add_noise(
                    latents, noise, timesteps
                )
                
                # Get text embeddings
                with torch.no_grad():
                    # First encoder
                    text_outputs_1 = text_encoder_1(
                        input_ids_1,
                        output_hidden_states=True,
                        return_dict=True
                    )
                    text_hidden_states_1 = text_outputs_1.hidden_states[-2]
                    
                    # Second encoder
                    text_outputs_2 = text_encoder_2(
                        input_ids_2,
                        output_hidden_states=True,
                        return_dict=True
                    )
                    text_hidden_states_2 = text_outputs_2.hidden_states[-2]
                    pooled_text_embeds_2 = text_outputs_2.text_embeds
                
                # Ensure proper dtype for embeddings
                text_hidden_states_1 = text_hidden_states_1.to(dtype=torch.float32)
                text_hidden_states_2 = text_hidden_states_2.to(dtype=torch.float32)
                pooled_text_embeds_2 = pooled_text_embeds_2.to(dtype=torch.float32)
                
                # Concatenate embeddings
                prompt_embeds = torch.cat([text_hidden_states_1, text_hidden_states_2], dim=-1)
                
                # Create time embeddings
                add_time_ids = torch.tensor([
                    1024, 1024,  # Original Size
                    0, 0,        # Crops top-left
                    1024, 1024,  # Target Size
                ], device=device, dtype=torch.float32)
                add_time_ids = add_time_ids.unsqueeze(0).repeat(batch_size, 1)
                
                # Add conditioning
                added_cond_kwargs = {
                    "text_embeds": pooled_text_embeds_2,
                    "time_ids": add_time_ids
                }
                
                # Predict noise
                noise_pred = unet(
                    noisy_latents,
                    timesteps,
                    prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs
                ).sample
                
                # Calculate loss
                loss = torch.nn.functional.mse_loss(
                    noise_pred,
                    noise,
                    reduction="mean"
                )
                
                # Backward pass
                accelerator.backward(loss)
                
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
                
                optimizer.step()
                optimizer.zero_grad()
                
                total_loss += loss.detach().item()
                
                if step % 10 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Step {step}, Loss: {loss.item():.4f}")


                    # Log to wandb
                    wandb.log({
                        "loss": current_loss,
                        "learning_rate": learning_rate,
                        "epoch": epoch,
                        "global_step": global_step,
                    })

                global_step += 1
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} completed. Average Loss: {avg_loss:.4f}")
        
        if (epoch + 1) % 5 == 0:
            pipeline.save_pretrained(
                os.path.join(output_dir, f"checkpoint-{epoch+1}")
            )
            wandb.save(os.path.join(checkpoint_dir, "*"))
    
    pipeline.save_pretrained(output_dir)
    wandb.save(os.path.join(checkpoint_dir, "*"))
    return pipeline

In [4]:
trained_pipeline = train_stable_diffusion(
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    image_dir="datasets/designs",
    caption_file="datasets/caption.txt",
    output_dir="fine_tuned_sd3",
    num_epochs=2
)

Using device: cuda
Loading SDXL model...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 1/2, Step 0, Loss: 0.5764
Epoch 1/2, Step 10, Loss: 0.1260
Epoch 1/2 completed. Average Loss: 0.5182
Epoch 2/2, Step 0, Loss: 0.0961
Epoch 2/2, Step 10, Loss: 0.3688
Epoch 2/2 completed. Average Loss: 0.2821
