In [1]:
# Requires the Hugging Face `datasets` package (install with `pip install datasets`);
# running this cell without it raises `ModuleNotFoundError: No module named 'datasets'`.
from datasets import load_dataset

# Load the training images as a single `Dataset`.
# Without `split=...`, `load_dataset` returns a `DatasetDict` (a dict keyed by split
# name), which cannot be passed directly to a torch `DataLoader`.
dataset = load_dataset("imagefolder", data_dir="dataset/train", split="train")

ModuleNotFoundError: No module named 'datasets'

In [None]:
from diffusers import StableDiffusionInpaintPipeline
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import AdamW

# Load the inpainting pipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")

# Freeze VAE weights
for param in pipe.vae.parameters():
    param.requires_grad = False

# Freeze Text Encoder weights (if using text conditioning)
for param in pipe.text_encoder.parameters():
    param.requires_grad = False

# Fine-tune only the U-Net
for param in pipe.unet.parameters():
    param.requires_grad = True

pipe.to("cuda")

# Define data loader
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Optimizer
optimizer = AdamW(pipe.unet.parameters(), lr=5e-5)

# Accelerator for distributed training
accelerator = Accelerator()
pipe, optimizer, train_dataloader = accelerator.prepare(pipe, optimizer, train_dataloader)

# Training loop
for epoch in range(10):  # Adjust based on dataset size
    pipe.train()
    for batch in train_dataloader:
        masked_images = batch["masked_images"].to("cuda")
        full_images = batch["full_images"].to("cuda")

        # Forward pass
        loss = pipe(masked_images, full_images)["loss"]

        # Backward pass and optimization
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch+1} completed")
