In [1]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor, CLIPModel
from accelerate import Accelerator
from torchmetrics.image.inception import InceptionScore

  from .autonotebook import tqdm as notebook_tqdm
A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "c:\Users\Dinesh\AppData\Local\Programs\Python\Python311\Lib\site-packages\xformers\__init__.py", line 57, in _is_triton_available
    import triton  # noqa
    ^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'triton'


# Part 1: Fine-Tuning Stable Diffusion with LoRA

In [2]:
class TextImageDataset(Dataset):
    def __init__(self, image_folder, captions_file, tokenizer, resolution=512):
        self.image_folder = image_folder
        self.resolution = resolution
        self.tokenizer = tokenizer

        # Load captions
        with open(captions_file, "r") as f:
            self.captions = [line.strip() for line in f.readlines()]

        # Load image filenames
        self.image_files = sorted(
            [f for f in os.listdir(image_folder) if f.endswith((".png", ".jpg", ".jpeg"))]
        )

        # Ensure the number of captions matches the number of images
        assert len(self.captions) == len(self.image_files), "Number of captions and images must match!"

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.image_folder, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        image = image.resize((self.resolution, self.resolution), Image.LANCZOS)
        # Convert to tensor and normalize to [-1, 1], ensuring float16 type
        image = torch.tensor(np.array(image) / 255.0, dtype=torch.float16).permute(2, 0, 1)  # Shape: (3, 512, 512)
        image = image * 2.0 - 1.0  # Normalize to [-1, 1]

        # Load caption
        caption = self.captions[idx]
        text_inputs = self.tokenizer(
            caption,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        return {
            "pixel_values": image,
            "input_ids": text_inputs["input_ids"].squeeze(),
        }

In [3]:
# Paths to your dataset
captions_file = "dataset/captions2k.txt"
image_folder = "dataset/images_2k"

In [4]:
# Initialize tokenizer (from CLIP)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Create dataset and dataloader
dataset = TextImageDataset(
    image_folder=image_folder,
    captions_file=captions_file,
    tokenizer=tokenizer,
    resolution=512,
)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

# Initialize Accelerator for distributed training
accelerator = Accelerator()

In [5]:
# Load Stable Diffusion model
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_auth_token=False,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Keyword arguments {'use_auth_token': False} are not expected by StableDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]


In [6]:
# Enable LoRA
pipe.unet.enable_lora()

In [7]:
# Prepare model, optimizer, and dataloader for training
unet = pipe.unet
vae = pipe.vae
text_encoder = pipe.text_encoder
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

unet, vae, text_encoder, optimizer, dataloader = accelerator.prepare(
    unet, vae, text_encoder, optimizer, dataloader
)

In [8]:
# Learning rate scheduler
num_training_steps = len(dataloader) * 10  # 10 epochs
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=num_training_steps,
)

In [9]:
print(f"Training on device: {accelerator.device}", flush=True)

Training on device: cuda


In [10]:
# Training loop with gradient accumulation
num_epochs = 3
global_step = 0
accumulation_steps = 2  # Accumulate gradients over 2 steps

print("Starting fine-tuning...", flush=True)
print(f"Training on device: {accelerator.device}", flush=True)
print(f"Model loaded on device: {pipe.device}", flush=True)

unet = unet.to(accelerator.device)
vae = vae.to(accelerator.device)
text_encoder = text_encoder.to(accelerator.device)

import time
start_time = time.time()

for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch+1}/{num_epochs}", flush=True)
    unet.train()
    for batch in dataloader:
        batch_start = time.time()
        print("Processing batch...", flush=True)
        try:
            pixel_values = batch["pixel_values"].to(accelerator.device)
            input_ids = batch["input_ids"].to(accelerator.device)

            with torch.no_grad():
                text_embeddings = text_encoder(input_ids)[0].to(torch.float16)  # Ensure float16

            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
                latents = latents.to(torch.float16)  # Ensure float16

            noise = torch.randn_like(latents, dtype=torch.float16)  # Create noise in float16
            timesteps = torch.randint(0, pipe.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise)

            # Scale loss for gradient accumulation
            loss = loss / accumulation_steps
            accelerator.backward(loss)

            global_step += 1

            if global_step % accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if global_step % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Step {global_step}, Loss: {(loss.item() * accumulation_steps):.4f}", flush=True)
        except Exception as e:
            print(f"Error in training step {global_step}: {e}", flush=True)
            break

        print(f"Batch time: {time.time() - batch_start:.2f} seconds")

Starting fine-tuning...
Training on device: cuda
Model loaded on device: cuda:0
Starting Epoch 1/3
Processing batch...


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Batch time: 1.72 seconds
Processing batch...
Batch time: 4.87 seconds
Processing batch...
Batch time: 21.95 seconds
Processing batch...
Batch time: 18.40 seconds
Processing batch...
Batch time: 19.33 seconds
Processing batch...
Batch time: 18.40 seconds
Processing batch...
Batch time: 18.80 seconds
Processing batch...
Batch time: 20.91 seconds
Processing batch...
Batch time: 19.50 seconds
Processing batch...
Batch time: 18.96 seconds
Processing batch...
Batch time: 20.75 seconds
Processing batch...
Batch time: 19.41 seconds
Processing batch...
Batch time: 19.73 seconds
Processing batch...
Batch time: 19.82 seconds
Processing batch...
Batch time: 19.19 seconds
Processing batch...
Batch time: 17.00 seconds
Processing batch...
Batch time: 19.69 seconds
Processing batch...
Batch time: 18.76 seconds
Processing batch...
Batch time: 19.00 seconds
Processing batch...
Batch time: 19.27 seconds
Processing batch...
Batch time: 20.85 seconds
Processing batch...
Batch time: 20.82 seconds
Processing

In [12]:
# Save the fine-tuned LoRA weights
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Extract LoRA layers from the UNet
    unet_lora_layers = pipe.unet.state_dict()  # Get the fine-tuned UNet state dict
    # Optionally, you can extract LoRA layers from the text encoder if fine-tuned
    text_encoder_lora_layers = None  # Set to pipe.text_encoder.state_dict() if text encoder was fine-tuned

    # Save the LoRA weights
    pipe.save_lora_weights(
        "fine_tuned_lora_weights",
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers
    )
    print("Fine-tuned LoRA weights saved to 'fine_tuned_lora_weights'")

# Optionally, save the entire pipeline
pipe.save_pretrained("fine_tuned_stable_diffusion")
print("Fine-tuned model saved to 'fine_tuned_stable_diffusion'")

Fine-tuned LoRA weights saved to 'fine_tuned_lora_weights'
Fine-tuned model saved to 'fine_tuned_stable_diffusion'


# Part 2: Generate Images Based on Text Prompts

In [30]:
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
import torch
import os

# Disable gradient computation to save memory (since we're only doing inference)
torch.no_grad()

# Load the base model on CPU
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # Use float32 for CPU (float16 is not supported on CPU)
    use_auth_token=False,
)

# Load the full fine-tuned UNet weights
weights_path = "fine_tuned_lora_weights/pytorch_lora_weights.safetensors"
state_dict = load_file(weights_path)

# Load the weights into the UNet
pipe.unet.load_state_dict(state_dict, strict=False)
print("Loaded fine-tuned UNet weights from", weights_path)

# Disable the safety checker
pipe.safety_checker = None

# Define the prompts
prompts = [
    "Anime girl with long silver hair and green eyes, standing in a school classroom, wearing a school uniform, holding a notebook, with a chalkboard in the background",
    "Anime boy with short black hair and blue eyes, standing in a school hallway, wearing a school uniform, with lockers in the background",
    "Anime girl with long brown hair, wearing a traditional school uniform, walking on a school campus, with cherry blossom trees in the background",
    "Anime boy with short blonde hair, wearing a school uniform, sitting at a school desk, reading a textbook, with a window showing a sunny day",
    "Anime girl with short pink hair, wearing a school uniform, standing in a school library, holding a book, with bookshelves and a window in the background"
]

# Create a directory to save generated images
output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)

# Generate images one at a time with a fixed seed
print("\nStarting image generation (on CPU, this may take several minutes per image)...")
for i, prompt in enumerate(prompts):
    print(f"Generating image {i+1}/{len(prompts)}: {prompt[:50]}...")
    
    image = pipe(
        prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        num_images_per_prompt=1,
        generator=torch.Generator(device="cpu").manual_seed(42 + i),  # Use CPU generator
    ).images[0]
    
    # Create a short, descriptive filename based on the prompt
    prompt_snippet = prompt.lower().replace(" ", "_").replace(",", "").replace("anime_", "")[:50]
    filename = f"image_{i+1}_{prompt_snippet}.png"
    image_path = os.path.join(output_dir, filename)
    
    # Save the image
    image.save(image_path)
    print(f"Saved {image_path}")

print("Image generation complete!")

Keyword arguments {'use_auth_token': False} are not expected by StableDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00,  9.58it/s]


Loaded fine-tuned UNet weights from fine_tuned_lora_weights/pytorch_lora_weights.safetensors

Starting image generation (on CPU, this may take several minutes per image)...
Generating image 1/5: Anime girl with long silver hair and green eyes, s...


100%|██████████| 50/50 [02:42<00:00,  3.25s/it]


Saved generated_images\image_1_girl_with_long_silver_hair_and_green_eyes_standing.png
Generating image 2/5: Anime boy with short black hair and blue eyes, sta...


100%|██████████| 50/50 [02:49<00:00,  3.38s/it]


Saved generated_images\image_2_boy_with_short_black_hair_and_blue_eyes_standing_i.png
Generating image 3/5: Anime girl with long brown hair, wearing a traditi...


100%|██████████| 50/50 [02:47<00:00,  3.35s/it]


Saved generated_images\image_3_girl_with_long_brown_hair_wearing_a_traditional_sc.png
Generating image 4/5: Anime boy with short blonde hair, wearing a school...


100%|██████████| 50/50 [02:49<00:00,  3.40s/it]


Saved generated_images\image_4_boy_with_short_blonde_hair_wearing_a_school_unifor.png
Generating image 5/5: Anime girl with short pink hair, wearing a school ...


100%|██████████| 50/50 [02:49<00:00,  3.39s/it]


Saved generated_images\image_5_girl_with_short_pink_hair_wearing_a_school_uniform.png
Image generation complete!


Here the images generated were completely black not because of NSFW restrications but in our debugging we found that its because the weights were originally being infered wrong.

# Part 3: Evaluate Generated Images (Inception Score and CLIP Similarity Score)

In [58]:
from transformers import CLIPProcessor, CLIPModel

# Load CLIP model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Define prompts
prompts = [
    "Anime girl with long silver hair and green eyes, standing in a school classroom, wearing a school uniform, holding a notebook, with a chalkboard in the background",
    "Anime boy with short black hair and blue eyes, standing in a school hallway, wearing a school uniform, with lockers in the background",
    "Anime girl with long brown hair, wearing a traditional school uniform, walking on a school campus, with cherry blossom trees in the background",
    "Anime boy with short blonde hair, wearing a school uniform, sitting at a school desk, reading a textbook, with a window showing a sunny day",
    "Anime girl with short pink hair, wearing a school uniform, standing in a school library, holding a book, with bookshelves and a window in the background"
]

# Load images
print("Loading existing generated images...")
images = []
image_paths = []
output_dir = "generated_images"
for i in range(len(prompts)):
    prompt_snippet = prompts[i].lower().replace(" ", "_").replace(",", "").replace("anime_", "")[:50]
    img_path = os.path.join(output_dir, f"image_{i+1}_{prompt_snippet}.png")
    try:
        img = Image.open(img_path).convert("RGB")
        images.append(img)
        image_paths.append(img_path)
        # print(f"Loaded {img_path}")
    except Exception as e:
        print(f"Error loading image {i+1}: {e}")

# Process images and text with CLIP
print("\nComputing CLIP similarity scores...")
inputs = clip_processor(text=prompts[:len(images)], images=images, return_tensors="pt", padding=True)
outputs = clip_model(**inputs)
raw_scores = outputs.logits_per_image.diagonal()  # Raw logits

# Normalize raw scores to [0, 1] assuming a range of 0 to 100
normalized_scores = raw_scores / 100.0

# Print raw and normalized scores
for i, (prompt, raw_score, norm_score, img_path) in enumerate(zip(prompts[:len(images)], raw_scores, normalized_scores, image_paths)):
    # print(f"Raw CLIP Score for image {i+1} ({prompt[:50]}...): {raw_score.item():.2f}")
    print(f"Normalized CLIP Score (0-1) for image {i+1} ({prompt[:50]}...): {norm_score.item():.2f}")

# Compute average scores
avg_raw_score = raw_scores.mean().item()
avg_norm_score = normalized_scores.mean().item()
# print(f"\nAverage Raw CLIP Score: {avg_raw_score:.2f}")
print(f"Average Normalized CLIP Score (0-1): {avg_norm_score:.2f}")

print("Evaluation complete!")

Loading existing generated images...

Computing CLIP similarity scores...
Normalized CLIP Score (0-1) for image 1 (Anime girl with long silver hair and green eyes, s...): 0.35
Normalized CLIP Score (0-1) for image 2 (Anime boy with short black hair and blue eyes, sta...): 0.38
Normalized CLIP Score (0-1) for image 3 (Anime girl with long brown hair, wearing a traditi...): 0.38
Normalized CLIP Score (0-1) for image 4 (Anime boy with short blonde hair, wearing a school...): 0.41
Normalized CLIP Score (0-1) for image 5 (Anime girl with short pink hair, wearing a school ...): 0.42
Average Normalized CLIP Score (0-1): 0.39
Evaluation complete!


In [None]:
import torch.nn.functional as F
from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision import transforms
from scipy.stats import entropy

# Configuration
image_folder = "generated_images"
batch_size = 1
splits = 10

# Transform (no resize)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Dataset
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, folder):
        self.images = [os.path.join(folder, f)
                       for f in os.listdir(folder)
                       if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert("RGB")
        return transform(img)

# Inception Score
def inception_score(dataset, batch_size=1, splits=10):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    device = torch.device("cpu")

    weights = Inception_V3_Weights.IMAGENET1K_V1
    model = inception_v3(weights=weights, transform_input=False, aux_logits=True)
    model.to(device)
    model.eval()

    preds = []
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            out = model(batch)
            softmax = F.softmax(out, dim=1)
            preds.append(softmax.cpu().numpy())

    preds = np.concatenate(preds, axis=0)
    N = preds.shape[0]
    splits = min(splits, N)

    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits): (k + 1) * (N // splits)]
        py = np.mean(part, axis=0)
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

# Run
dataset = ImageDataset(image_folder)
print(f"Number of images: {len(dataset)}")
mean, std = inception_score(dataset, batch_size=batch_size, splits=splits)
print(f"Inception Score: {mean:.3f} ± {std:.3f}")


Number of images: 6
Inception Score: 1.000 ± 0.000
