In [3]:
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F

def score_distillation_sampling(image, text_prompt, num_steps=100, lr=0.01,
                                pipeline=None, torch_device=None):
    """Iteratively move ``image`` toward images generated for ``text_prompt``.

    Every step calls the diffusion pipeline once, turns its first output image
    into a tensor resized to the input's spatial size, and takes one
    gradient-descent step on the mean squared error between that fixed target
    and the image being optimized. Each step is a full generation, so runtime
    grows with ``num_steps`` times the pipeline's own denoising steps.

    NOTE(review): despite the name, this is not DreamFusion-style Score
    Distillation Sampling (which back-propagates the UNet's noise-prediction
    residual on a noised latent). The target here is a finished, detached
    image, so the update only blends the input toward it. Two things to confirm:

    * ``loss = mse.mean()`` gives a gradient of ``2 * (x - target) / N`` with
      ``N = C * H * W`` (786,432 for 512x512 RGB), so with ``lr=0.01`` each
      pixel moves by at most ~2.5e-8 per step -- the result is visually
      unchanged from the input.
    * A text-to-image ``StableDiffusionPipeline`` has no ``image`` input;
      presumably the keyword is ignored and the target does not depend on
      ``image`` at all. TODO confirm against the installed diffusers version.

    Args:
        image: PIL image or ``(H, W, C)`` ``uint8`` numpy array -- anything
            ``transforms.ToTensor`` converts to floats in [0, 1].
        text_prompt: Prompt passed to the pipeline.
        num_steps: Number of pipeline calls / update steps.
        lr: Step size of each gradient-descent update.
        pipeline: Pipeline to call; defaults to the notebook-global ``pipe``.
        torch_device: Device for the optimized tensor; defaults to the
            notebook-global ``device``.

    Returns:
        PIL image of the optimized tensor, clamped to [0, 1] before conversion.
    """
    # Fall back to the globals from the model-loading cell only when the caller
    # does not inject them explicitly.
    if pipeline is None:
        pipeline = pipe
    if torch_device is None:
        torch_device = device

    to_tensor = transforms.ToTensor()

    # (1, C, H, W) leaf tensor that is optimized in place; requires_grad_
    # replaces the deprecated torch.autograd.Variable wrapper.
    image_tensor = to_tensor(image).unsqueeze(0).to(torch_device)
    image_tensor.requires_grad_(True)

    for _ in range(num_steps):
        # The output is converted with ToTensor below, so no gradient flows
        # back through the pipeline; pass a detached tensor to make that explicit.
        output = pipeline(prompt=text_prompt, image=image_tensor.detach())['images'][0]

        target = to_tensor(output).unsqueeze(0).to(image_tensor.device)
        # Match the input's spatial size in case the pipeline generates at a
        # different resolution.
        target = F.interpolate(target, size=image_tensor.shape[2:], mode='bilinear', align_corners=False)

        loss = ((target - image_tensor) ** 2).mean()
        loss.backward()

        # Plain gradient-descent step outside autograd, then reset the
        # accumulated gradient for the next iteration.
        with torch.no_grad():
            image_tensor -= lr * image_tensor.grad
        image_tensor.grad.zero_()

    # ToPILImage expects floats in [0, 1]; clamp so out-of-range values cannot
    # overflow its uint8 conversion, and detach so no autograd state leaks out.
    final_tensor = image_tensor.detach().squeeze().clamp(0.0, 1.0).cpu()
    return transforms.ToPILImage()(final_tensor)


In [4]:
# Adaptation Process for Score Distillation Sampling (SDS) with Stable Diffusion V1.4
# Import necessary libraries
import torch
from diffusers import StableDiffusionPipeline
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


import torch
from diffusers import StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
# Use the GPU when available instead of hard-coding "cuda", so the cell still
# runs (slowly) on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Half-precision weights halve the model's GPU memory relative to float32.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    # Return unoccupied cached GPU memory (e.g. from a previous run of this
    # cell) before the pipeline starts using the device.
    torch.cuda.empty_cache()
    # Keep sub-models on the CPU and move each to the GPU only while it runs.
    # Do not also call pipe.to("cuda"): that defeats the offloading.
    pipe.enable_model_cpu_offload()
    # Compute attention in slices to lower peak GPU memory; this notebook has
    # hit CUDA OutOfMemoryError without it.
    pipe.enable_attention_slicing()
else:
    # float16 inference is poorly supported on CPU, so load full precision there.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    pipe = pipe.to(device)

def display_image(image):
    """Show an image inline with matplotlib, with axes hidden.

    Args:
        image: Either a torch tensor laid out channels-first -- ``(C, H, W)``,
            possibly with extra singleton dims such as ``(1, C, H, W)`` -- or a
            single-channel tensor that squeezes to ``(H, W)``; tensor values
            are clipped to [0, 1] before display. Anything else (PIL image,
            numpy array) is passed to ``plt.imshow`` unchanged.
    """
    if isinstance(image, torch.Tensor):
        # Drop singleton dims (e.g. a batch dim) and move to host memory.
        array = image.detach().cpu().squeeze().numpy()
        if array.ndim == 3:
            # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
            array = array.transpose(1, 2, 0)
        # A single-channel tensor squeezes to 2-D; show it as grayscale rather
        # than crashing on the channel permute.
        plt.imshow(np.clip(array, 0, 1), cmap='gray' if array.ndim == 2 else None)
    else:
        # PIL images and numpy arrays can be displayed directly.
        plt.imshow(image)
    plt.axis('off')
    plt.show()

# Sample input: instead of loading an image from disk, generate a random RGB image for demonstration
def generate_random_image(height, width, seed=None):
    """Return a random RGB image as a ``(height, width, 3)`` ``uint8`` array.

    Pixel values are drawn uniformly from the full 0-255 range.

    Args:
        height: Number of rows.
        width: Number of columns.
        seed: Optional seed for a private ``np.random.default_rng`` generator,
            making the image reproducible without touching NumPy's global
            state. When ``None`` (the default) the global ``np.random`` state
            is used, so ``np.random.seed(...)`` still controls the result.
    """
    # The upper bound is exclusive in both APIs, so 256 (not 255) is needed
    # for the value 255 to be reachable.
    if seed is None:
        return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)

# Basic SDS run on a 512x512 random starting image.
text_prompt = "A beautiful landscape painting of mountains during sunset"
initial_image = generate_random_image(height=512, width=512)
transformed_image = score_distillation_sampling(initial_image, text_prompt=text_prompt)

# Show the final transformed image.
display_image(transformed_image)

# # Bonus Test 1: Experiment with different guiding signals (e.g., different queries)
# alternative_prompt = "A futuristic cityscape at night"
# alternative_transformed_image = score_distillation_sampling(initial_image, alternative_prompt)
# display_image(alternative_transformed_image)

# Bonus Test 2: experiment with corrupted images (e.g. zero out some pixels).
def corrupt_image(image, corruption_type="zero_out"):
    """Return a corrupted copy of an image array; the input is left untouched.

    Args:
        image: ``uint8`` numpy array of shape ``(H, W, C)``, e.g. the output of
            ``generate_random_image``.
        corruption_type: ``"zero_out"`` sets every pixel whose row and column
            indices are both even to 0 (a quarter of the pixels);
            ``"color_jitter"`` adds independent integer noise in [-30, 30] to
            every channel value and clips the result to [0, 255].

    Returns:
        A new ``uint8`` array with the same shape as ``image``.

    Raises:
        ValueError: If ``corruption_type`` is not one of the values above.
    """
    corrupted_image = image.copy()
    if corruption_type == "zero_out":
        corrupted_image[::2, ::2] = 0
    elif corruption_type == "color_jitter":
        # Do the arithmetic in a wider signed type so it cannot wrap, clip to
        # the valid range (modulo 256 would turn bright pixels nearly black),
        # and cast back to uint8 so transforms.ToTensor rescales it to [0, 1].
        noise = np.random.randint(-30, 31, image.shape)
        corrupted_image = np.clip(corrupted_image.astype(np.int32) + noise, 0, 255).astype(np.uint8)
    else:
        raise ValueError(f"Unknown corruption_type: {corruption_type!r}")
    return corrupted_image

# Run SDS on a corrupted copy of the starting image and show the result.
corrupted_image = corrupt_image(initial_image, corruption_type="zero_out")
corrupted_transformed_image = score_distillation_sampling(
    corrupted_image,
    text_prompt,
)
display_image(corrupted_transformed_image)



Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  5.94it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 