In [22]:
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer
from PIL import Image
from accelerate import Accelerator
import os
from torch.utils.data import Dataset, DataLoader

In [4]:
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32, 
    lora_dropout=0.1, 
    bias="none"
)

In [8]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/527 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.60k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

In [9]:
def generate_caption(image_path):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    caption = model.generate(**inputs)
    return processor.decode(caption[0], skip_special_tokens=True)


In [15]:
image_folder = "datasets/designs/"
caption_file = "datasets/caption.txt"
for img in os.listdir(image_folder):
    caption = generate_caption(os.path.join(image_folder, img))
    print(f"{img}: {caption}")
    with open(caption_file, "a") as f:
        f.write(f'{img}\t{caption}\n')

p03.jpg: gnome with a flower and butterfly on a white background
p06.jpg: a black and white drawing of a hand with a triangle and all seeing symbols
ghouldfriend_p01.png: a close up of a black bat with big eyes and a big nose
p02.jpg: cartoon illustration of a bee with a pot of honey
p12.png: two wooden nutcrackers with pine cones and hollyconnets on them
p10.jpg: there is a snowman with a hat and scarf holding a present
p01.jpg: gnome with a butterfly on his head holding a flower pot
p11.jpg: a close up of a toy figure of a cat with a cup
p09.jpg: there is a snowman with a scarf and hat holding a red ribbon
p08.jpg: a close up of a penguin with a christmas present on a calendar
p04.jpg: a close up of a halloween candle holder with a witch hat on top
p05.jpg: there are three pumpkins stacked on top of each other
p07.jpg: a close up of a skeleton hand holding a skull on a stand


In [19]:
class CustomDataset(Dataset):
    def __init__(self, image_dir, caption_file, tokenizer, image_size=1024):
        self.image_dir = image_dir
        self.image_size = image_size
        self.tokenizer = tokenizer
        
        # Load captions
        self.image_caption_pairs = []
        with open(caption_file, 'r') as f:
            for line in f:
                image_name, caption = line.strip().split('\t')
                self.image_caption_pairs.append((image_name, caption))

    def __len__(self):
        return len(self.image_caption_pairs)

    def __getitem__(self, idx):
        image_name, caption = self.image_caption_pairs[idx]
        
        # Load and preprocess image
        image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB')
        image = image.resize((self.image_size, self.image_size))
        image = torch.from_numpy(np.array(image)).float() / 127.5 - 1
        image = image.permute(2, 0, 1)
        
        # Tokenize caption
        tokens = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )
        
        return {
            "pixel_values": image,
            "input_ids": tokens.input_ids[0],
            "attention_mask": tokens.attention_mask[0]
        }


In [29]:
def train_stable_diffusion(
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    image_dir="datasets/designs",
    caption_file="datasets/caption.txt",
    output_dir="fine_tuned_model",
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5,
    gradient_accumulation_steps=4
):
    # Initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision="fp16"
    )

    pipeline = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16"
    )

    # Get tokenizer from the pipeline
    tokenizer = pipeline.tokenizer
    if tokenizer is None:
        # Fallback to default CLIP tokenizer if pipeline tokenizer is not available
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    
    # Freeze VAE and text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    
    # Create dataset and dataloader
    dataset = CustomDataset(image_dir, caption_file, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=learning_rate
    )
    
    # Prepare for training
    unet, optimizer, dataloader = accelerator.prepare(
        unet, optimizer, dataloader
    )
    
    # Training loop
    for epoch in range(num_epochs):
        unet.train()
        total_loss = 0
        
        for batch in dataloader:
            with accelerator.accumulate(unet):
                # Get latent representation
                latents = vae.encode(
                    batch["pixel_values"].to(dtype=vae.dtype)
                ).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                
                # Add noise
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, pipeline.scheduler.config.num_train_timesteps,
                    (latents.shape[0],), device=latents.device
                )
                noisy_latents = pipeline.scheduler.add_noise(
                    latents, noise, timesteps
                )
                
                # Get text embeddings
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                
                # Predict noise
                noise_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states
                ).sample
                
                # Calculate loss
                loss = torch.nn.functional.mse_loss(
                    noise_pred, noise, reduction="none"
                ).mean()
                
                # Backward pass
                accelerator.backward(loss)
                
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
                    
                optimizer.step()
                optimizer.zero_grad()
                
                total_loss += loss.detach().item()
        
        # Print progress
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            pipeline.save_pretrained(
                os.path.join(output_dir, f"checkpoint-{epoch+1}")
            )
    
    # Save final model
    pipeline.save_pretrained(output_dir)
    return pipeline

In [None]:
trained_pipeline = train_stable_diffusion(
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    image_dir="datasets/designs",
    caption_file="datasets/caption.txt",
    output_dir="fine_tuned_sd3",
    num_epochs=10
)

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

model.fp16.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

model.fp16.safetensors:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/5.14G [00:00<?, ?B/s]