In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from diffusers.models.attention_processor import LoRAAttnProcessor

from transformers import CLIPTokenizer, CLIPTextModel

from PIL import Image
from pathlib import Path
import os
from tqdm import tqdm

import numpy as np


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    revision="fp16",
).to("cuda")

pipe = pipe.to(device)
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

 The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title 'CompVis/stable-diffusion-v1-4 is missing fp16 files' so that the correct variant file can be added.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

An error occurred while trying to fetch C:\Users\adity\.cache\huggingface\hub\models--CompVis--stable-diffusion-v1-4\snapshots\2880f2ca379f41b0226444936bb7a6766a227587\unet: Error no file named diffusion_pytorch_model.safetensors found in directory C:\Users\adity\.cache\huggingface\hub\models--CompVis--stable-diffusion-v1-4\snapshots\2880f2ca379f41b0226444936bb7a6766a227587\unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch C:\Users\adity\.cache\huggingface\hub\models--CompVis--stable-diffusion-v1-4\snapshots\2880f2ca379f41b0226444936bb7a6766a227587\vae: Error no file named diffusion_pytorch_model.safetensors found in directory C:\Users\adity\.cache\huggingface\hub\models--CompVis--stable-diffusion-v1-4\snapshots\2880f2ca379f41b0226444936bb7a6766a227587\vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


In [4]:
unet = pipe.unet

In [5]:
# Set LoRA Adapters for UNet
rank = 4  # LoRA rank

# Iterate through all attention processors
for name, module in unet.attn_processors.items():
    if isinstance(module, LoRAAttnProcessor):
        continue  # Already LoRA
    cross_attention_dim = module.cross_attention_dim if hasattr(module, "cross_attention_dim") else None
    hidden_size = module.hidden_size if hasattr(module, "hidden_size") else None

    if cross_attention_dim is None or hidden_size is None:
        continue

    # Create LoRA processor
    lora_attn_processor = LoRAAttnProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
        rank=rank
    )

    # Set it
    unet.set_attn_processor(name, lora_attn_processor)

print("LoRA injected successfully into UNet!")


LoRA injected successfully into UNet!


In [6]:
class DefectImageCaptionDataset(Dataset):
    def __init__(self, data_folder, image_size=512):
        self.data_folder = data_folder
        self.image_paths = []
        self.caption_paths = []
        for file in os.listdir(data_folder):
            if file.endswith(".png") or file.endswith(".jpg"):
                img_path = os.path.join(data_folder, file)
                txt_path = img_path.replace(".png", ".txt").replace(".jpg", ".txt")
                if os.path.exists(txt_path):
                    self.image_paths.append(img_path)
                    self.caption_paths.append(txt_path)
        
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        img = self.transform(img)
        
        with open(self.caption_paths[idx], "r") as f:
            caption = f.read().strip()
        
        return {"image": img, "caption": caption}


In [7]:
dataset = DefectImageCaptionDataset(data_folder="../dataset/bottle/image/")
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [8]:
optimizer = torch.optim.AdamW(unet.parameters(), lr=0.0001)
num_epochs = 5
processor = pipe
# processor = {
#     "image_processor": pipe.feature_extractor,  # For images
#     "tokenizer": pipe.tokenizer,                # For captions
# }

In [10]:
dir(text_inputs)

['__annotations__',
 '__class__',
 '__class_getitem__',
 '__contains__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__ior__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__match_args__',
 '__module__',
 '__ne__',
 '__new__',
 '__or__',
 '__post_init__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__ror__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'images',
 'items',
 'keys',
 'move_to_end',
 'nsfw_content_detected',
 'pop',
 'popitem',
 'setdefault',
 'to_tuple',
 'update',
 'values']

In [11]:
unet.train()
from tqdm import tqdm

# Training settings
num_epochs = 5
learning_rate = 1e-4
gradient_accumulation_steps = 1  # can be >1 if you want
max_grad_norm = 1.0

# Optimizer and Scaler
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()

# Training Loop
global_step = 0
for epoch in range(num_epochs):
    unet.train()
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")
    
    for step, batch in enumerate(progress_bar):
        with torch.no_grad():
            pixel_values = (batch["image"].to(device) / 127.5) - 1.0  # scale images to [-1, 1]

            captions = batch["caption"]  # list of caption strings

            # Tokenize captions on the fly
            text_inputs = processor(
                captions,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_tensors="pt"
            )
            print(text_inputs)
            input_ids = text_inputs.input_ids.to(device)
            
            # Encode text
            encoder_hidden_states = pipe.text_encoder(input_ids)[0]

            # Sample random noise
            noise = torch.randn_like(pixel_values)
            timesteps = torch.randint(
                0, pipe.scheduler.config.num_train_timesteps,
                (pixel_values.shape[0],),
                device=device
            ).long()

            # Add noise to images
            noisy_images = pipe.scheduler.add_noise(pixel_values, noise, timesteps)

        with torch.cuda.amp.autocast():
            # Predict the noise residual
            model_pred = unet(noisy_images, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            # Loss
            loss = torch.nn.functional.mse_loss(model_pred, noise)

        # Backward
        scaler.scale(loss / gradient_accumulation_steps).backward()

        if (step + 1) % gradient_accumulation_steps == 0:
            # Gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            global_step += 1

        # Print loss
        if step % 10 == 0:
            progress_bar.set_postfix({"loss": loss.item()})

    print(f"Epoch {epoch} completed.")

print("Training Completed")

  scaler = torch.cuda.amp.GradScaler()
Epoch 0:   0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/63 [00:08<?, ?it/s]

StableDiffusionPipelineOutput(images=[<PIL.Image.Image image mode=RGB size=512x512 at 0x167036B1360>], nsfw_content_detected=[False])





AttributeError: 'StableDiffusionPipelineOutput' object has no attribute 'input_ids'