In [None]:
# Cell 1: Import required libraries
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt

# Select the compute device: the CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Cell 2: Define dataset preprocessing and load images
# Training images are expected under ./dataset. torchvision's ImageFolder only
# discovers images inside class sub-directories (e.g. ./dataset/faces/*.png);
# files placed directly in ./dataset are not picked up.
root_folder = "./dataset"
IMAGE_SIZE = 512   # square side length every image is resized to
BATCH_SIZE = 8
SEED = 42          # fixes the shuffle order so Restart & Run All sees the same batches

# Preprocessing: resize -> tensor -> normalise (no cropping is performed)
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # forces a square size; aspect ratio is not preserved
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # (x - 0.5) / 0.5 maps each RGB channel to [-1, 1]
])

# Load the dataset; the seeded generator makes shuffle=True reproducible
dataset = datasets.ImageFolder(root=root_folder, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Check the dataset
print(f"Number of images in the dataset: {len(dataset)}")

# Display a batch of images
def show_batch(loader, ncols=4):
    """Plot the first batch yielded by ``loader`` as a grid of images.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs, where ``images`` is a
            ``(B, C, H, W)`` tensor. Assumes the Normalize(0.5, 0.5) preprocessing
            used above (values in [-1, 1]); other pipelines need a different inverse.
        ncols: Maximum number of images per grid row. The default of 4 reproduces
            the original 2x4 layout for a batch of 8.
    """
    images, _ = next(iter(loader))
    # (B, C, H, W) -> (B, H, W, C) for imshow; .cpu() so GPU-resident batches also work.
    images = images.detach().cpu().permute(0, 2, 3, 1)
    # Undo Normalize(mean=0.5, std=0.5); clamp away float rounding so imshow gets [0, 1].
    images = (images * 0.5 + 0.5).clamp(0, 1).numpy()

    n_images = len(images)
    ncols = max(1, min(ncols, n_images))
    nrows = (n_images + ncols - 1) // ncols  # grid grows with the batch instead of a fixed 2x4
    plt.figure(figsize=(10, 10))
    for i in range(n_images):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(images[i])
        plt.axis("off")
    plt.show()

show_batch(dataloader)

In [None]:
# Cell 3: Load the pretrained SDXL pipeline for inpainting
from diffusers import StableDiffusionXLInpaintPipeline

# Hugging Face Hub repo id. The version suffix is required; the bare
# "stabilityai/stable-diffusion-xl-base" id does not resolve to a repository.
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# On GPU, load the fp16 weight variant to roughly halve memory use;
# on CPU keep the default float32 weights.
load_kwargs = {"use_safetensors": True}
if device.type == "cuda":
    load_kwargs.update(torch_dtype=torch.float16, variant="fp16")

model = StableDiffusionXLInpaintPipeline.from_pretrained(MODEL_ID, **load_kwargs).to(device)

print("Model loaded successfully.")

In [None]:
# Cell 4: Run SDXL inpainting on one batch of the dataset (smoke test)
# Note: this cell only runs inference with the pretrained pipeline; nothing is
# fine-tuned. It processes just the first batch to check that the model works.

prompt = "A portrait of a person with clear skin and vibrant features"

# Take the first batch only; `results` is reused by the save cell below.
images, _ = next(iter(dataloader))
images = images.to(device)
batch_size = images.shape[0]
height, width = images.shape[-2:]

# The inpainting pipeline requires a mask (1 = region to repaint, 0 = keep).
# TODO: replace this placeholder centre-square mask with real per-image masks.
masks = torch.zeros(batch_size, 1, height, width, device=device)
masks[:, :, height // 4 : 3 * height // 4, width // 4 : 3 * width // 4] = 1.0

results = model(
    prompt=[prompt] * batch_size,  # one prompt per image so prompt and image batches match
    image=images * 0.5 + 0.5,      # undo the dataset's [-1, 1] normalisation; diffusers expects [0, 1] tensors
    mask_image=masks,
)

# Visualize results, one figure per generated image
for j, img in enumerate(results.images):
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title(f"Result {j+1}")
    ax.axis("off")
    plt.show()

In [None]:
# Cell 5: Save the output images
output_folder = "./output"
os.makedirs(output_folder, exist_ok=True)  # no error if the folder already exists

# Write every PIL image returned by the most recent pipeline call as result_<index>.png
for idx, pil_image in enumerate(results.images):
    target_path = os.path.join(output_folder, f"result_{idx}.png")
    pil_image.save(target_path)

print(f"Processed images saved to {output_folder}.")

In [None]:
# Cell 6: Measure processing time and efficiency
import time

# Run inpainting over every batch of the dataset. Outputs are discarded;
# only the wall-clock time is of interest.
benchmark_prompt = "A sample test prompt"
num_images = 0
start_time = time.perf_counter()  # monotonic, high-resolution clock intended for interval timing

for images, _ in dataloader:
    images = images.to(device)
    # The last batch can be smaller than the configured batch size, so size
    # the prompt list and mask from the actual batch.
    batch_size = images.shape[0]
    height, width = images.shape[-2:]
    # Placeholder centre-square mask (1 = repaint); the inpaint pipeline requires mask_image.
    masks = torch.zeros(batch_size, 1, height, width, device=device)
    masks[:, :, height // 4 : 3 * height // 4, width // 4 : 3 * width // 4] = 1.0
    model(
        prompt=[benchmark_prompt] * batch_size,
        image=images * 0.5 + 0.5,  # dataset tensors are in [-1, 1]; diffusers expects [0, 1]
        mask_image=masks,
    )
    num_images += batch_size

end_time = time.perf_counter()
elapsed = end_time - start_time
print(f"Total time taken to process dataset: {elapsed:.2f} seconds.")
if num_images:
    print(f"Average time per image: {elapsed / num_images:.2f} seconds.")

In [None]:
# Cell 7: Analyze and visualize outputs
# Reads back the PNG files that the save cell wrote to output_folder.
# os.listdir also returns directories (e.g. .ipynb_checkpoints) and unrelated
# files that Image.open cannot read, so only .png files are kept, in sorted order.
output_files = sorted(
    name for name in os.listdir(output_folder) if name.lower().endswith(".png")
)

output_images = []
for name in output_files:
    with Image.open(os.path.join(output_folder, name)) as src:
        output_images.append(src.copy())  # copy() loads the pixels so the file handle can be closed

# Display output images in a grid sized to however many files exist
if not output_images:
    print(f"No PNG images found in {output_folder}.")
else:
    ncols = min(3, len(output_images))
    nrows = (len(output_images) + ncols - 1) // ncols  # a fixed 3x3 grid would fail beyond 9 images
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for ax, img, name in zip(axes.flat, output_images, output_files):
        ax.imshow(img)
        ax.set_title(name)
    plt.show()