# Exercise: CLIP-Based Image/Text Retrieval and Synthesis

In this exercise, we will build an intuitive understanding of how general-purpose image/text correspondence models, such as OpenAI's Contrastive Language-Image Pretraining (CLIP) model, can be combined with diffusion models for generative and creative purposes.

#### Learning goals:

* Use CLIP encoders/decoders to perform image retrieval and synthesis
* Examine the use of diffusion models for image generation

#### After you've read, run, and understood the provided code, perform these tasks:

* Easy: change the prompt for image synthesis until you get an interesting image.
    * You can share your image (as long as it is appropriate) on the class Discord.
* Easy: using your own text samples, perform latent blending between two text samples and visualize the resulting image using image synthesis.
    * Suggestion: check for yourself whether latent math like "king" - "man" + "woman" = "queen" holds, at least approximately
    * Suggestion: try blending a "content prompt" with a "style prompt", e.g. "an aerial view of a city" + "a Van Gogh painting"
    * Again, feel free to share your results!
* Medium: modify the code to perform image search (by text and image)
    * Download the COCO Val2017 dataset (http://images.cocodataset.org/zips/val2017.zip)
    * Find the image in the dataset whose CLIP embedding is closest to that of the following image (https://i.imgur.com/S0iR4Zr.png):
    <div>
    <img src="https://i.imgur.com/S0iR4Zr.png" width="300"/>
    </div>
    * Find the image in the dataset whose CLIP embedding is closest to that of the following text prompt:
```
man eating pie
```
* Easy (optional): if your image/text retrieval is too slow, pre-compute embeddings to accelerate things.
* Easy: using your own text samples, perform latent blending between two text samples and visualize the resulting image using image retrieval.
* Easy: using your own text samples, perform latent blending between two text samples and search for the best sample text describing the blended text among Flickr8k captions (https://www.dropbox.com/s/4dgs7e0r3ypaqus/just_captions.txt?dl=1).
* Medium: perform latent blending between two image samples from COCO Val2017. Search for the best caption describing the blended image among Flickr8k captions.
* Hard: Use a CLIP encoder to create a "style embedding" from an image or a text, and then perform a style transfer on any image of your choice.
* Hard: Fork the clipcoders repository (https://github.com/namheegordonkim/clipcoders) and add another decoder that performs supersampling diffusion from 256x256 images to 512x512 images. (See https://github.com/crowsonkb/v-diffusion-pytorch/issues/10)
* Project idea: Fine-tune CLIP with an image caption dataset of your choice. Visualize the fine-tuned results with image retrieval and synthesis (for synthesis, use a classifier-guided diffusion decoder).
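
The latent-math suggestion in the task list can be prototyped before touching CLIP. The sketch below uses hand-made three-dimensional toy vectors (the axes loosely meaning royalty, male, female), so the analogy works exactly by construction. Real CLIP embeddings are much noisier, and whether the analogy holds there is what the task asks you to check. The vocabulary and the helper `closest_word` are made up for illustration.

```python
import numpy as np

def normalize(v):
    """Scale a vector to unit length."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

# Toy "embeddings": axes are (royalty, male, female). These are NOT CLIP outputs.
vocab = {
    "king":  np.array([1.0, 1.0, 0.0]),
    "queen": np.array([1.0, 0.0, 1.0]),
    "man":   np.array([0.0, 1.0, 0.0]),
    "woman": np.array([0.0, 0.0, 1.0]),
}

def closest_word(vector, vocab):
    """Return the vocabulary word whose embedding has the highest cosine similarity."""
    best_word, best_sim = None, -np.inf
    for word, emb in vocab.items():
        sim = float(np.dot(normalize(vector), normalize(emb)))
        if sim > best_sim:
            best_word, best_sim = word, sim
    return best_word

# Latent arithmetic: remove the "man" direction, add the "woman" direction
result = vocab["king"] - vocab["man"] + vocab["woman"]
print(closest_word(result, vocab))
```

With CLIP, the same pattern would apply to the vectors returned by `clip_text_encoder.encode`, with the candidate set replaced by whatever texts or images you want to search over.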

## Dependencies

We will use customized encoder/decoder objects to make our code easier to read. Once you understand what is going on, you are welcome to read their source code: understanding the underlying array/tensor manipulations will be more useful for training and scaling up these operations.

In [None]:
# CLIP
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

# Custom Encoder-Decoder wrappers
!pip install git+https://github.com/namheegordonkim/clipcoders

# Demo images
!wget https://i.imgur.com/eB1otqV.png -O child.png
!wget https://i.imgur.com/YETS1O1.png -O two-dogs-0.png
!wget https://i.imgur.com/H4tTA0r.png -O two-dogs-1.png
!wget https://i.imgur.com/6mpZTzW.jpg -O cat.jpg
!wget https://i.imgur.com/S0iR4Zr.png -O cat2.png

As usual, we import necessary dependencies at the top of the notebook:

In [None]:
import clip
import torch
import numpy as np
from clipcoders.encoders import CLIPTextEncoder, CLIPImageEncoder
from clipcoders.decoders import ClassifierFreeGuidanceDecoder
from clipcoders.diffusion.models import get_model
from clipcoders.diffusion.utils import to_pil_image, from_pil_image
from PIL import Image
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from IPython.display import display, HTML
from glob import glob
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)
np.random.seed(0)

### Load a Pre-Trained CLIP Model

In [None]:
clip_model = clip.load("ViT-B/16")[0]
clip_model.eval().requires_grad_(False).to(device)

### Make CLIP-Based Encoders

In [None]:
clip_text_encoder = CLIPTextEncoder(
    clip_model=clip_model,
    device=device
)
clip_image_encoder = CLIPImageEncoder(
    clip_model=clip_model,
    device=device,
    cutn=16,
    cut_pow=1
)

What does a CLIP-based text encoder do?

In [None]:
sample_text_embedding = clip_text_encoder.encode("here is a sample text.")

In [None]:
print(sample_text_embedding.shape)

What does a CLIP-based image encoder do?

In [None]:
sample_image = Image.open("cat.jpg")
sample_image_tensor = from_pil_image(sample_image).unsqueeze(0).to(device)
sample_image_embedding = clip_image_encoder.encode(sample_image_tensor)

In [None]:
print(sample_image_embedding.shape)

## Measuring Latent Similarity with Embeddings

Why would one want to use image and text encoders? Image-text correspondence models like CLIP establish a shared latent space in which images and texts can be compared in terms of semantic content. Below, we provide a simple example.

Consider the following three image-text pairs:

<div>
    <table class="tg">
    <thead>
      <tr>
        <th>Index</th>
        <th class="tg-0pky">Image</th>
        <th class="tg-0pky">Text</th>
      </tr>
    </thead>
    <tbody>
      <tr>
        <td>0</td>
        <td class="tg-fymr"><img src="https://i.imgur.com/YETS1O1.png" width="300"/></td>
        <td class="tg-fymr">Two dogs of different breeds looking at each other on the road.</td>
      </tr>
      <tr>
        <td>1</td>
        <td class="tg-0pky"><img src="https://i.imgur.com/H4tTA0r.png" width="300"/></td>
        <td class="tg-0pky">Two dogs playing together on a beach.</td>
      </tr>
      <tr>
        <td>2</td>
        <td class="tg-0pky"><img src="https://i.imgur.com/eB1otqV.png" width="300"/></td>
        <td class="tg-0pky">A child biting into a baked good.</td>
      </tr>
    </tbody>
    </table>
</div>

Images 0 and 1 are similar in concept, and image 2 is very different from the others. How does CLIP help us quantify this difference?

In [None]:
# A distance metric to be used in latent spaces
def spherical_dist(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()

# Sometimes, you will want group-to-group distances
def spherical_dists(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

# To make images more uniform-sized
def resize_and_center_crop(image, size):
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
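
The expression inside `spherical_dist` may look opaque. For unit vectors separated by an angle θ, the chord length ‖x − y‖ equals 2·sin(θ/2), so `arcsin(‖x − y‖ / 2)` recovers θ/2, and squaring then doubling gives θ²/2. The NumPy sketch below is a re-implementation for illustration only (`spherical_dist_np` and `angle_between` are names introduced here); it checks that identity numerically on random vectors.

```python
import numpy as np

def normalize(v):
    """Scale a vector (or batch of row vectors) to unit length along the last axis."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def spherical_dist_np(x, y):
    """NumPy mirror of the notebook's formula: 2 * arcsin(|x - y| / 2) ** 2."""
    x, y = normalize(x), normalize(y)
    chord = np.linalg.norm(x - y, axis=-1)
    # Clip guards against values marginally above 1 caused by floating-point round-off
    return 2 * np.arcsin(np.clip(chord / 2, 0.0, 1.0)) ** 2

def angle_between(x, y):
    """Angle in radians between two vectors, computed from the cosine similarity."""
    x, y = normalize(x), normalize(y)
    cos = np.clip(np.sum(x * y, axis=-1), -1.0, 1.0)
    return np.arccos(cos)

rng = np.random.default_rng(0)
a = rng.normal(size=512)
b = rng.normal(size=512)
theta = angle_between(a, b)

# The two printed numbers should agree up to floating-point error
print(spherical_dist_np(a, b), theta ** 2 / 2)
```

Because θ²/2 grows monotonically with the angle, ranking candidates by this distance gives the same order as ranking them by decreasing cosine similarity.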

In [None]:
# To quantify the difference among pairs of images
# Prepare image tensors
two_dogs_0_url = "https://i.imgur.com/YETS1O1.png"
two_dogs_1_url = "https://i.imgur.com/H4tTA0r.png"
child_url = "https://i.imgur.com/eB1otqV.png"

two_dogs_0_filename = "two-dogs-0.png"
two_dogs_1_filename = "two-dogs-1.png"
child_filename = "child.png"

two_dogs_0_image = resize_and_center_crop(Image.open(two_dogs_0_filename).convert("RGB"), (640, 480))
two_dogs_0_image_tensor = from_pil_image(two_dogs_0_image).unsqueeze(0).to(device)
two_dogs_1_image = resize_and_center_crop(Image.open(two_dogs_1_filename).convert("RGB"), (640, 480))
two_dogs_1_image_tensor = from_pil_image(two_dogs_1_image).unsqueeze(0).to(device)
child_image = resize_and_center_crop(Image.open(child_filename).convert("RGB"), (640, 480))
child_image_tensor = from_pil_image(child_image).unsqueeze(0).to(device)

# Prepare image embeddings
two_dogs_0_image_embedding = clip_image_encoder.encode(two_dogs_0_image_tensor)
two_dogs_1_image_embedding = clip_image_encoder.encode(two_dogs_1_image_tensor)
child_image_embedding = clip_image_encoder.encode(child_image_tensor)

# Finally, compute distances between embeddings
display(HTML(
    f"""
    <table>
    <tr>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td>dist={spherical_dist(two_dogs_0_image_embedding, two_dogs_0_image_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td><img src={two_dogs_1_url} width=200px /></td>
    <td>dist={spherical_dist(two_dogs_0_image_embedding, two_dogs_1_image_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td><img src={child_url} width=200px /></td>
    <td>dist={spherical_dist(two_dogs_0_image_embedding, child_image_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_1_url} width=200px /></td>
    <td><img src={child_url} width=200px /></td>
    <td>dist={spherical_dist(two_dogs_1_image_embedding, child_image_embedding):.3f}</td>
    </tr>
    </table>
    """
))

In [None]:
# To quantify the difference among pairs of texts
# Prepare texts
two_dogs_0_text = "Two dogs of different breeds looking at each other on the road."
two_dogs_1_text = "Two dogs playing together on a beach."
child_text = "A child biting into a baked good."

# Prepare text embeddings
two_dogs_0_text_embedding = clip_text_encoder.encode(two_dogs_0_text)
two_dogs_1_text_embedding = clip_text_encoder.encode(two_dogs_1_text)
child_text_embedding = clip_text_encoder.encode(child_text)

# Compute distance between embeddings
display(HTML(
    f"""
    <table>
    <tr>
    <td>{two_dogs_0_text}</td>
    <td>{two_dogs_0_text}</td>
    <td>dist={spherical_dist(two_dogs_0_text_embedding, two_dogs_0_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td>{two_dogs_0_text}</td>
    <td>{two_dogs_1_text}</td>
    <td>dist={spherical_dist(two_dogs_0_text_embedding, two_dogs_1_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td>{two_dogs_0_text}</td>
    <td>{child_text}</td>
    <td>dist={spherical_dist(two_dogs_0_text_embedding, child_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td>{two_dogs_1_text}</td>
    <td>{child_text}</td>
    <td>dist={spherical_dist(two_dogs_1_text_embedding, child_text_embedding):.3f}</td>
    </tr>
    </table>
    """
))

In [None]:
# To compute distance between images and texts
display(HTML(
    f"""
    <table>
    <tr>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td>{two_dogs_0_text}</td>
    <td>dist={spherical_dist(two_dogs_0_image_embedding, two_dogs_0_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td>{two_dogs_1_text}</td>
    <td>dist={spherical_dist(two_dogs_0_image_embedding, two_dogs_1_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_0_url} width=200px /></td>
    <td>{child_text}</td>
    <td>dist={spherical_dist(two_dogs_0_image_embedding, child_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_1_url} width=200px /></td>
    <td>{two_dogs_1_text}</td>
    <td>dist={spherical_dist(two_dogs_1_image_embedding, two_dogs_1_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_1_url} width=200px /></td>
    <td>{two_dogs_0_text}</td>
    <td>dist={spherical_dist(two_dogs_1_image_embedding, two_dogs_0_text_embedding):.3f}</td>
    </tr>
    
    <tr>
    <td><img src={two_dogs_1_url} width=200px /></td>
    <td>{child_text}</td>
    <td>dist={spherical_dist(two_dogs_1_image_embedding, child_text_embedding):.3f}</td>
    </tr>
    </table>
    """
))

## CLIP-Based Image Retrieval from Text

Since CLIP's image embeddings and text embeddings can be used to compute similarities across these domains, we can do something akin to Google's image search, where a text input is matched with images. Generally, these steps are involved:

* Encode text
* Encode images
* Get distance between text and each image
* Return the image with the lowest distance

In [None]:
# This is our "database" of images
images = [two_dogs_0_image, two_dogs_1_image, child_image]

# A function to make it easier to do image retrieval
def get_image_from_prompt(prompt, images):
    # Disable gradient tracking so intermediate activations do not accumulate in memory
    with torch.no_grad():
        prompt_embedding = clip_text_encoder.encode(prompt)
        image_embeddings = [clip_image_encoder.encode(from_pil_image(image).unsqueeze(0).to(device)) for image in images]
        # One scalar distance per image; the smallest one is the best match
        dists = np.array([spherical_dist(prompt_embedding, embedding).item() for embedding in image_embeddings])
    return images[np.argmin(dists)]

In [None]:
get_image_from_prompt("beach", images)

In [None]:
get_image_from_prompt("child", images)

In [None]:
get_image_from_prompt("fighting", images)

## Exercise: Image Retrieval from COCO Val2017 Dataset with Text


In [None]:
!wget http://images.cocodataset.org/zips/val2017.zip
!unzip val2017.zip

In [None]:
class MyImageDataset(Dataset):
    """
    Makes it easy to interact with the downloaded dataset
    """
    
    def __init__(self, root):
        self.filenames = sorted(glob(f"{root}/*.jpg"))
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, index):
        return resize_and_center_crop(Image.open(self.filenames[index]).convert("RGB"), (640, 480))
    

In [None]:
image_dataset = MyImageDataset("val2017")
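
Encoding every COCO image again for each query is what makes a naive search slow. One option, sketched below with NumPy and random stand-in vectors rather than real CLIP outputs, is to encode the dataset once, store the unit-normalized embeddings as a single matrix on disk, and answer each query with one matrix-vector product. The file name `embeddings.npy`, the helper `nearest_indices`, and the 512-dimensional size are assumptions made for illustration.

```python
import os
import tempfile

import numpy as np

def normalize(v):
    """Scale a vector (or batch of row vectors) to unit length along the last axis."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def nearest_indices(query, database, k=3):
    """Return indices of the k database rows with the highest cosine similarity to query."""
    sims = normalize(database) @ normalize(query)
    # argsort is ascending, so negate to rank the most similar rows first
    return np.argsort(-sims)[:k]

rng = np.random.default_rng(0)
# Stand-in for the stacked image embeddings (one row per image)
database = normalize(rng.normal(size=(1000, 512)))

# Save once, then reload in later sessions instead of re-encoding
path = os.path.join(tempfile.mkdtemp(), "embeddings.npy")
np.save(path, database)
cached = np.load(path)

# A query that is a slightly perturbed copy of row 42 should retrieve row 42 first
query = database[42] + 0.01 * rng.normal(size=512)
top = nearest_indices(query, cached, k=3)
print(top)
```

In the notebook, the rows of `database` would come from running `clip_image_encoder.encode` over `image_dataset` once, and `query` from encoding the text prompt or the query image.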

In [None]:
# TODO

## Exercise: Image Retrieval from COCO Val2017 Dataset with Image

In [None]:
# TODO

## Exercise: Finding Image for Latent Blending of Texts
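
One possible way to blend two embeddings, sketched here with NumPy on toy vectors, is spherical linear interpolation (slerp), which moves along the arc between two unit vectors instead of cutting through the interior of the sphere. This is an illustrative choice and not necessarily the blending the exercise has in mind; a normalized weighted average is a simpler alternative. The function name `slerp` and the weight `t` are introduced here.

```python
import numpy as np

def normalize(v):
    """Scale a vector to unit length."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def slerp(a, b, t):
    """Spherical interpolation between unit vectors a and b; t=0 gives a, t=1 gives b."""
    a, b = normalize(a), normalize(b)
    omega = np.arccos(np.clip(np.dot(a, b), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        # The vectors (almost) coincide, so there is nothing to interpolate
        return a
    return (np.sin((1 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)

# Two orthogonal toy vectors standing in for two text embeddings
a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 1.0, 0.0])
mid = slerp(a, b, 0.5)
print(mid, np.linalg.norm(mid))
```

The blended vector stays on the unit sphere and can be used as the query in the retrieval step. Note that this sketch does not handle exactly opposite vectors, for which the arc is not unique.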

In [None]:
# TODO

## Exercise: Finding Text for Latent Blending of Texts

In [None]:
!wget "https://www.dropbox.com/s/4dgs7e0r3ypaqus/just_captions.txt?dl=1" -O just_captions.txt

In [None]:
class MyTextDataset(Dataset):
    """
    Makes it easy to interact with the downloaded dataset
    """
    
    def __init__(self, file):
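        with open(file, "r") as f:
            # Drop blank lines (e.g. from a trailing newline) so every item is a real caption
            self.lines = [line for line in f.read().splitlines() if line.strip()]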
        
    def __len__(self):
        return len(self.lines)
    
    def __getitem__(self, index):
        return self.lines[index]

In [None]:
text_dataset = MyTextDataset("just_captions.txt")

In [None]:
# TODO

## Exercise: Finding Text for Latent Blending of Images

In [None]:
# TODO

## CLIP-Based Image Synthesis 

While *retrieval* returns matching images that already exist in the dataset, *synthesis* renders an embedding into a brand-new image. In 2021, critical breakthroughs were made with the development of CLIP and diffusion models. Below, we will use a diffusion-based decoder to render a prompt into an image and appreciate the beauty and power of AI-generated art.

### Load a Pre-Trained Diffusion Model (CLIP-Conditioned)

In [None]:
!wget https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1_cfg.pth

In [None]:
diffusion_model = get_model("cc12m_1_cfg")()
state_dict = torch.load("cc12m_1_cfg.pth", map_location="cpu")
diffusion_model.load_state_dict(state_dict)
if device.type == 'cuda':
    diffusion_model = diffusion_model.half()
diffusion_model.eval().requires_grad_(False).to(device)

### Instantiate a Diffusion-Based Decoder

In [None]:
decoder = ClassifierFreeGuidanceDecoder(
    diffusion_model=diffusion_model,
    device=device,
    n_steps=500
)
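
The decoder's name refers to classifier-free guidance. In its commonly described form, the model is evaluated twice per denoising step, once with the conditioning embedding and once without, and the two predictions are extrapolated by a guidance scale. The NumPy sketch below shows only that combination step on dummy arrays; `guided_prediction` and `guidance_scale` are names chosen here, and the internals of `ClassifierFreeGuidanceDecoder` may differ.

```python
import numpy as np

def guided_prediction(uncond, cond, guidance_scale):
    """Extrapolate from the unconditional prediction toward the conditional one."""
    return uncond + guidance_scale * (cond - uncond)

# Dummy stand-ins for the two model outputs at a single denoising step
uncond = np.zeros((3, 4, 4))
cond = np.ones((3, 4, 4))

for scale in (0.0, 1.0, 3.0):
    out = guided_prediction(uncond, cond, scale)
    print(scale, out.mean())
```

A scale of 1 reproduces the conditional prediction, 0 ignores the prompt entirely, and larger values push the sample further toward the prompt, typically at some cost in diversity.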

In [None]:
# Sensitive to random seed
prompt = "Two dogs of different breeds looking at each other on the road."
prompt_embedding = clip_text_encoder.encode(prompt)
image = decoder.decode(prompt_embedding)
to_pil_image(image[0])

The results don't look so good. Why would that be the case? Some possible explanations:

1. The dataset used to pre-train the diffusion model did not have enough examples of dogs
2. Our prompt did not emphasize how realistic the dogs must look

Our particular CLIP-conditioned diffusion pipeline is not well suited to photorealism. In December 2021, OpenAI released GLIDE (https://github.com/openai/glide-text2im), a text-guided diffusion model aimed at photorealistic generation, which addresses this limitation.

However, the same pipeline does much better at generating stylized art. See below:

In [None]:
# Sensitive to random seed
prompt = "magical ruins at dawn. watercolor. by artists on artstation."
prompt_embedding = clip_text_encoder.encode(prompt)
image = decoder.decode(prompt_embedding)
to_pil_image(image[0])

## Exercise: Latent Blending of Prompts, Visualized via Image Synthesis

In [None]:
# TODO