# Fine-Tune Stable Diffusion with Flickr8k using LoRA

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPTokenizer
from diffusers import StableDiffusionPipeline, DDPMScheduler
from peft import get_peft_model, LoraConfig
from datasets import Dataset as HFDataset
from tqdm import tqdm

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip

--2025-04-22 00:06:28--  https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/124585957/47f52b80-3501-11e9-8f49-4515a2a3339b?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250422%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250422T000628Z&X-Amz-Expires=300&X-Amz-Signature=b2b3ce6c8448de59a241fb6e28d92acf4d3b838c04f2dabe871b00af19494c2f&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3DFlickr8k_Dataset.zip&response-content-type=application%2Foctet-stream [following]
--2025-04-22 00:06:28--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/124585957/47f52b80-3501-11e9-8f49-4515a2a3339b?X-Amz-Algorithm=AWS4-HMAC-SHA2

In [None]:
!unzip -q Flickr8k_Dataset.zip
!unzip -q Flickr8k_text.zip

In [None]:
import torch

# Report whether a CUDA device is visible to PyTorch.
has_cuda = torch.cuda.is_available()
print("CUDA Available:", has_cuda)

# Without a GPU there is nothing more to report; otherwise describe device 0.
if not has_cuda:
    print("No GPU found.")
else:
    gpu_props = torch.cuda.get_device_properties(0)
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("Total GPU Memory (MB):", gpu_props.total_memory / (1024**2))


CUDA Available: True
GPU Name: NVIDIA A100-SXM4-40GB
Total GPU Memory (MB): 40506.8125


In [None]:
# === Config ===
# Dataset locations. The archives are unzipped into the current working directory,
# so the paths are resolved from there instead of being hard-coded to /content.
# When the working directory is /content (as the original paths assumed) these
# evaluate to exactly the same strings as before.
image_dir = os.path.abspath("Flicker8k_Dataset")  # "Flicker" (sic): kept exactly as in the original path
captions_file = os.path.abspath("Flickr8k.token.txt")

# Base model to fine-tune and where the LoRA adapter is written.
pretrained_model = "CompVis/stable-diffusion-v1-4"
output_dir = "./sd-fine-tuned-lora"

# Training hyper-parameters.
image_size = 512   # images are resized to image_size x image_size
batch_size = 12
num_epochs = 10
lr = 5e-6          # NOTE(review): quite low for LoRA-only training; confirm this is intended

# Use the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# === Load Captions ===
def load_captions(captions_path, image_dir):
    """Read a Flickr8k-style caption file into a list of image/caption records.

    Each useful line has the form ``<image name>#<index>\\t<caption>``. Blank lines
    and a line starting with ``image,caption`` are ignored. A record is only kept
    when the referenced image is an existing file inside ``image_dir``.

    Args:
        captions_path: Path to the UTF-8 caption file.
        image_dir: Directory that contains the image files.

    Returns:
        A list of ``{'image': <full image path>, 'caption': <caption text>}`` dicts,
        in file order.

    Raises:
        OSError: If ``captions_path`` cannot be opened.
    """
    pairs = []
    with open(captions_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("image,caption"):
                continue
            # Split on the FIRST tab only, so a caption that itself contains a tab
            # is kept whole instead of being rejected as malformed.
            img_with_id, sep, caption = line.partition("\t")
            if not sep:
                print(f"Skipping malformed line (no tab separator): {line}")
                continue
            # Drop the "#<index>" suffix to recover the image file name.
            img = img_with_id.split("#")[0]
            full_img_path = os.path.join(image_dir, img)
            # isfile (not exists): an empty image name would otherwise resolve to
            # image_dir itself and be accepted as an "image".
            if os.path.isfile(full_img_path):
                pairs.append({'image': full_img_path, 'caption': caption.strip()})
    return pairs



# Load and convert to HuggingFace Dataset
# (HFDataset is already imported in the notebook's import cell.)
max_pairs = 1000  # cap on the number of (image, caption) pairs used for this run
pairs = load_captions(captions_file, image_dir)[:max_pairs]

# Fail fast with a clear message: an empty list here would otherwise only surface
# later, as an empty DataLoader and an undefined `loss` in the training cell.
if not pairs:
    raise RuntimeError(
        f"No image/caption pairs found; check captions_file={captions_file!r} "
        f"and image_dir={image_dir!r}"
    )

hf_dataset = HFDataset.from_list(pairs)



In [None]:
# === Tokenizer and Transform ===
# Text side: the CLIP tokenizer used to turn captions into token ids.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Image side: resize to a square of image_size, convert to a tensor, then
# normalise with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# === Custom Dataset ===
class FlickrDataset(Dataset):
    """Torch dataset that yields preprocessed images with tokenized captions.

    Args:
        data: Indexable collection whose items have an ``'image'`` entry (path to
            an image file) and a ``'caption'`` entry (caption text).
        image_transform: Callable applied to the RGB PIL image. Defaults to the
            notebook-level ``transform``.
        text_tokenizer: Tokenizer applied to the caption. Defaults to the
            notebook-level ``tokenizer``.
        max_length: Length captions are padded/truncated to (default 77).
    """

    def __init__(self, data, image_transform=None, text_tokenizer=None, max_length=77):
        self.data = data
        # None means "use the notebook globals"; they are looked up lazily in
        # __getitem__ so the dataset can be built before those cells have run.
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of image/caption records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample.

        Returns:
            A dict with ``'pixel_values'`` (the transformed image) plus
            ``'input_ids'`` and ``'attention_mask'`` (tokenizer outputs with the
            leading batch dimension squeezed out).
        """
        image_transform = self.image_transform if self.image_transform is not None else transform
        text_tokenizer = self.text_tokenizer if self.text_tokenizer is not None else tokenizer

        example = self.data[idx]
        image = Image.open(example['image']).convert('RGB')
        pixel_values = image_transform(image)
        text_inputs = text_tokenizer(
            example['caption'],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return {
            'pixel_values': pixel_values,
            'input_ids': text_inputs.input_ids.squeeze(0),
            'attention_mask': text_inputs.attention_mask.squeeze(0)
        }

# Wrap the HuggingFace dataset of image/caption records in the torch Dataset
# that the training DataLoader iterates over.
dataset = FlickrDataset(hf_dataset)

In [None]:
# === Load Pipeline and Freeze ===
# Half precision on the GPU, full precision on the CPU fallback.
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe.to(device)

# xformers attention is an optional optimisation. The call raises when xformers is
# not installed or no CUDA device is present, which would break the CPU fallback
# configured above, so treat it as best-effort and report why it was skipped.
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as err:
    print(f"xformers memory-efficient attention not enabled: {err}")

# The VAE and text encoder are never trained here, so freeze them.
# The trailing ';' stops the notebook from echoing the full module repr.
pipe.vae.requires_grad_(False);
pipe.text_encoder.requires_grad_(False);

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), ep

In [None]:
from peft import LoraConfig, PeftModel, get_peft_model

# LoRA adapter settings: rank-16 adapters with lora_alpha=64, attached to the
# attention projections (to_q/to_k/to_v/to_out.0) and the feed-forward layers
# (ff.net.0.proj / ff.net.2) of the UNet. Bias terms are left untouched.
lora_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=[
        "to_q", "to_k", "to_v", "to_out.0",
        "ff.net.0.proj", "ff.net.2"
    ],
    bias="none",
)

# Wrap the UNet only once: re-running this cell on an already wrapped model would
# otherwise stack a second PEFT wrapper on top of the first.
if not isinstance(pipe.unet, PeftModel):
    pipe.unet = get_peft_model(pipe.unet, lora_config)

# Sanity check: show how many parameters are actually trainable.
pipe.unet.print_trainable_parameters()


In [None]:
# === Training ===
# Seed PyTorch so shuffling, noise and timestep sampling are repeatable.
torch.manual_seed(0)

# Hand the optimizer only the parameters that require gradients (the LoRA
# adapters); frozen base weights never receive updates.
# NOTE(review): if the adapter weights are stored in fp16, plain Adam without
# gradient scaling may underflow; confirm the adapter dtype.
trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=lr)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Follow the dtype the pipeline was loaded with (fp16 on GPU, fp32 on the CPU
# fallback) instead of hard-coding float16.
weight_dtype = pipe.vae.dtype
# Latent scaling factor taken from the VAE config rather than a literal 0.18215.
latent_scale = pipe.vae.config.scaling_factor
num_train_timesteps = pipe.scheduler.config.num_train_timesteps

for epoch in range(num_epochs):
    pipe.unet.train()
    epoch_loss, num_batches = 0.0, 0
    for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = batch['pixel_values'].to(device, dtype=weight_dtype)
        input_ids = batch['input_ids'].to(device)

        # The VAE and text encoder are frozen: no autograd graph is needed for them.
        with torch.no_grad():
            latents = pipe.vae.encode(images).latent_dist.sample() * latent_scale
            encoder_hidden_states = pipe.text_encoder(input_ids)[0].to(dtype=weight_dtype)

        # Noise the latents at a random timestep per sample; the UNet is trained
        # to predict that noise.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=device)
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        # Compute the loss in fp32 for numerical stability.
        loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch; the last batch alone is too noisy to read a
    # trend from. max(..., 1) keeps an empty dataloader from raising.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f'Epoch {epoch+1}: Mean loss = {mean_loss:.4f}')

Epoch 1/10: 100%|██████████| 84/84 [00:46<00:00,  1.82it/s]


Epoch 1: Loss = 0.1607666015625


Epoch 2/10: 100%|██████████| 84/84 [00:46<00:00,  1.81it/s]


Epoch 2: Loss = 0.326416015625


Epoch 3/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 3: Loss = 0.1600341796875


Epoch 4/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 4: Loss = 0.08660888671875


Epoch 5/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 5: Loss = 0.218994140625


Epoch 6/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 6: Loss = 0.317626953125


Epoch 7/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 7: Loss = 0.2587890625


Epoch 8/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 8: Loss = 0.25341796875


Epoch 9/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]


Epoch 9: Loss = 0.1951904296875


Epoch 10/10: 100%|██████████| 84/84 [00:45<00:00,  1.83it/s]

Epoch 10: Loss = 0.16650390625





In [None]:
import torch
# Environment summary: CUDA visibility, the installed torch build and, when a GPU
# is present, the name of device 0.
cuda_ready = torch.cuda.is_available()
print("CUDA available:", cuda_ready)
print("Torch version:", torch.__version__)
if cuda_ready:
    device_name = torch.cuda.get_device_name(0)
    print("Device:", device_name)

CUDA available: True
Torch version: 2.6.0+cu124
Device: NVIDIA A100-SXM4-40GB


In [None]:
# === Save final LoRA fine-tuned UNet ===
# Write the PEFT-wrapped UNet to output_dir, then confirm the destination.
lora_unet = pipe.unet
lora_unet.save_pretrained(output_dir)
print("Fine-tuned U-Net saved to: {}".format(output_dir))


Fine-tuned U-Net saved to: ./sd-fine-tuned-lora


In [None]:
from diffusers import StableDiffusionPipeline
import torch
from peft import PeftModel

# Reload a clean copy of the base pipeline. The model id, device and adapter path
# come from the config cell so training and inference cannot drift apart.
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe.to(device)

# Attach the LoRA adapter saved by the training run and switch to inference mode.
pipe.unet = PeftModel.from_pretrained(pipe.unet, output_dir)
pipe.unet.eval()

# Attention slicing lowers peak memory use during generation; it is a memory
# optimisation, not a speed-up.
pipe.enable_attention_slicing()


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Prompt to test
prompt = "A dog playing basketball"

# Fixed seed so the same prompt reproduces the same image across runs.
generator = torch.Generator(device="cuda").manual_seed(0)

# Generate
with torch.autocast("cuda"):
    image = pipe(prompt=prompt, guidance_scale=7.5, generator=generator).images[0]

# Save a copy to disk
image.save("generated_output.png")

# A bare last expression renders the image inline in the notebook;
# image.show() would try to launch an external viewer instead.
image

  0%|          | 0/50 [00:00<?, ?it/s]