In [None]:
# Installing necessary libraries

!pip install torch torchvision diffusers  # Installing PyTorch, torchvision (for datasets), and diffusers (for Stable Diffusion)

In [None]:
# Importing necessary modules

import os  # Importing os for file and directory handling

import matplotlib.pyplot as plt  # Importing matplotlib for displaying the generated images
import torch  # Importing PyTorch for tensor manipulation and GPU management
from diffusers import AutoPipelineForImage2Image  # Importing the auto pipeline used to derive an image-to-image pipeline
from diffusers import StableDiffusionPipeline  # Importing StableDiffusionPipeline for generating images
from PIL import Image  # Importing PIL for image processing (e.g., opening, saving images)
from torch import autocast  # Importing autocast for automatic mixed precision (helps speed up inference)
from torchvision import datasets, transforms  # Importing torchvision for datasets and image transformations

In [None]:
# Choosing the compute device for inference: the GPU when CUDA is available, otherwise the CPU

if torch.cuda.is_available():
    device = "cuda"  # CUDA is available, so run on the GPU
else:
    device = "cpu"  # CUDA is not available, so fall back to the CPU

In [None]:
# Loading the pre-trained Stable Diffusion v2.1 base pipeline

model_name = "stabilityai/stable-diffusion-2-1-base"  # Model identifier handed to from_pretrained

pipe = StableDiffusionPipeline.from_pretrained(model_name)  # Building the pipeline from the pre-trained weights
pipe = pipe.to(device)  # Moving the pipeline to the selected device (GPU or CPU)

In [None]:
# Preparing the CIFAR-10 training split

# Preprocessing applied to every sample: resize to 512x512, then convert to a tensor
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)

# Training split stored under ./data (downloaded when requested via download=True),
# with the preprocessing above attached to it
cifar10 = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Generation settings shared by every image produced below

# Text prompt that steers the Stable Diffusion output
prompt = "A realistic photo"

# How much the original image may be altered (lower value = more of the original image preserved)
strength = 0.3

# Guidance scale: how strongly the generation is pushed towards the prompt
guidance_scale = 7.5

In [None]:
# Defining the output directory for generated images

# On Kaggle the writable working directory is /kaggle/working, so the images end up in
# /kaggle/working/generated_images exactly as before. Outside Kaggle that directory does not exist,
# so the current working directory is used instead of a Kaggle-only absolute path.
_kaggle_working = '/kaggle/working'
_base_dir = _kaggle_working if os.path.isdir(_kaggle_working) else os.getcwd()

output_dir = os.path.join(_base_dir, 'generated_images')  # Folder where the generated images are saved

os.makedirs(output_dir, exist_ok=True)  # Creating the output directory if it doesn't exist

In [None]:
# Image-to-image generation: each CIFAR-10 sample is used as the starting image for Stable Diffusion.
#
# The text-to-image StableDiffusionPipeline does not take an `init_image` input, so the CIFAR-10
# sample could not steer the result and `strength` had nothing to act on. An image-to-image pipeline
# is therefore derived from the already-loaded `pipe` and the sample is passed through `image=`.

NUM_IMAGES = 10  # Number of CIFAR-10 samples to transform
SEED = 0  # Base seed; sample `idx` uses SEED + idx so a re-run reproduces the same images

# Built once, outside the loop, from the pipeline loaded earlier (no second model load)
img2img_pipe = AutoPipelineForImage2Image.from_pipe(pipe)

# Mixed precision is only switched on when the pipeline actually runs on a GPU
use_amp = device == "cuda"

for idx in range(NUM_IMAGES):
    img_tensor, label = cifar10[idx]  # Getting an image tensor and its integer class label

    # Converting the CxHxW float tensor in [0, 1] to an HxWxC uint8 array in [0, 255]
    img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8')

    # Creating the PIL image that serves as the starting point for generation
    init_image = Image.fromarray(img_array).convert("RGB")

    # Seeding per image so that every output is reproducible on its own
    generator = torch.Generator(device=device).manual_seed(SEED + idx)

    # Generating the image; autocast is disabled on CPU instead of requesting CUDA autocast there
    with autocast("cuda", enabled=use_amp):
        generated_image = img2img_pipe(
            prompt=prompt,
            image=init_image,
            strength=strength,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]

    # Saving the generated image with a unique name
    generated_image_path = os.path.join(output_dir, f'generated_image_{idx}.jpg')
    generated_image.save(generated_image_path)

    # Displaying the generated image, then closing the figure so figures do not pile up in memory
    fig, ax = plt.subplots()
    ax.imshow(generated_image)
    ax.axis('off')
    plt.show()
    plt.close(fig)

    # Printing the index of the generated image and its integer CIFAR-10 label
    print(f"Generated Image {idx} with label: {label}")
