In [1]:
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL
from diffusers.training_utils import EMAModel
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import os
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CustomImageDataset(Dataset):
    """Dataset over the image files found directly inside one folder.

    Each item is a single image (no label), converted to ``mode`` and then
    passed through ``transform`` when one is given.

    Args:
        image_folder: Directory scanned (non-recursively) for image files.
        transform: Optional callable applied to every loaded PIL image.
        mode: PIL mode each image is converted to. Defaults to ``"RGBA"``
            (4 channels: red, green, blue, alpha); pass ``"RGB"`` for
            3-channel output.
        extensions: File suffixes treated as images, matched
            case-insensitively (so ``photo.JPG`` is picked up as well).
    """

    def __init__(self, image_folder, transform=None, mode="RGBA", extensions=(".jpg", ".png")):
        self.image_folder = image_folder
        self.transform = transform
        self.mode = mode
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and filesystems.
        self.image_paths = sorted(
            os.path.join(image_folder, name)
            for name in os.listdir(image_folder)
            if name.lower().endswith(suffixes)
        )

    def __len__(self):
        """Return the number of image files discovered in the folder."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx``, convert it, and apply the transform."""
        image_path = self.image_paths[idx]
        # convert() yields an independent in-memory image, so the underlying
        # file handle can be released as soon as the conversion is done.
        with Image.open(image_path) as source:
            image = source.convert(self.mode)
        if self.transform:
            image = self.transform(image)
        return image

In [3]:
# Preprocessing applied to every image: resize to 512x512, convert to a tensor,
# then normalize each of the 4 channels as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 4, std=[0.5] * 4),
    ]
)

In [5]:
# Folder holding the training images. Set the FACE_SD_IMAGE_DIR environment
# variable to point somewhere else; the fallback is the original local path.
IMAGE_DIR = os.environ.get(
    "FACE_SD_IMAGE_DIR",
    "/Users/adhithyaasabareeswaran/Desktop/Face_SD_Finetune/Humans",
)
BATCH_SIZE = 8

dataset = CustomImageDataset(IMAGE_DIR, transform=transform)
# Fail early with a readable message instead of an obscure sampler error.
if len(dataset) == 0:
    raise ValueError(f"No training images found in {IMAGE_DIR!r}")

# shuffle=True draws the images in a new random order every epoch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
# Pick the compute device: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pretrained Stable Diffusion v1.4 pipeline by its model id.
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Kept as the cell's last expression so the notebook displays the pipeline config.
pipeline.to(device)

Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 10.81it/s]


StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.30.3",
  "_name_or_path": "CompVis/stable-diffusion-v1-4",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [7]:
# Caption used as the text conditioning for every training image.
default_caption = "A photo of a human face"

# Put the UNet in training mode; it is the only component handed to the optimizer.
pipeline.unet.train()

# AdamW over all UNet parameters with a small fine-tuning learning rate.
optimizer = AdamW(params=pipeline.unet.parameters(), lr=5e-6)


In [None]:
# Fine-tune the UNet with the noise-prediction objective in VAE latent space.
NUM_EPOCHS = 5

# The caption is identical for every image, so tokenize and encode it once.
# Padding to the tokenizer's fixed length gives the text encoder its usual
# input shape.
text_inputs = pipeline.tokenizer(
    default_caption,
    padding="max_length",
    max_length=pipeline.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    caption_embedding = pipeline.text_encoder(text_inputs.input_ids.to(device))[0]

# Upper bound for sampled timesteps, taken from the scheduler's own config.
num_train_timesteps = pipeline.scheduler.config.num_train_timesteps

for epoch in range(NUM_EPOCHS):
    progress = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    for images in progress:
        images = images.to(device)
        batch_size = images.size(0)

        # The VAE encodes 3-channel images: keep only the first three (RGB)
        # channels, discarding any alpha channel.
        pixel_values = images[:, :3]

        # The UNet operates on VAE latents, not raw pixels. Encode without
        # tracking gradients (the VAE is not being trained) and apply the
        # VAE's configured latent scaling factor.
        with torch.no_grad():
            latents = pipeline.vae.encode(pixel_values).latent_dist.sample()
            latents = latents * pipeline.vae.config.scaling_factor

        # Sample the target noise and one random timestep per image.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, num_train_timesteps, (batch_size,), device=device
        ).long()

        # Forward diffusion: noise the latents according to each timestep.
        noisy_latents = pipeline.scheduler.add_noise(latents, noise, timesteps)

        # One copy of the caption embedding per image in the batch.
        encoder_hidden_states = caption_embedding.repeat(batch_size, 1, 1)

        # Predict the noise that was added.
        noise_pred = pipeline.unet(
            noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states
        ).sample

        loss = F.mse_loss(noise_pred, noise)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the loss on the progress bar instead of printing a line per step.
        progress.set_postfix(loss=loss.item())

In [None]:
# The optimizer updates pipeline.unet's parameters in place, so the pipeline
# already holds the fine-tuned weights; no separate `unet` variable exists to
# assign back. Switch to eval mode before exporting.
pipeline.unet.eval()
pipeline.save_pretrained("finetuned_model")