# Contrastive Language-Image Pretraining (CLIP)

This lab looks at models that sit at the intersection of computer vision and natural language processing, an area that has grown rapidly thanks to the development of robust vision and language models.

CLIP is a notable example of such a model. It relates natural language descriptions to images and produces image embeddings that can be used for a range of downstream tasks. The model was trained on a very large dataset of images and their corresponding captions, which lets it learn a rich representation of both images and language.

## Setup

Let's start with the installation of the necessary packages.

In [None]:
!pip install -q ftfy regex tqdm scikit-learn scikit-image
!pip install -q git+https://github.com/openai/CLIP.git

We can now import all the pertinent packages.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
import skimage
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from clip import clip

%matplotlib inline
%config InlineBackend.figure_format = "retina"

First, we can list the available pre-trained models...

In [None]:
clip.available_models()

...and load one of them. For this, we can use the `clip.load()` function and specify the name of the model we want to use.

In [None]:
model, preprocess = clip.load("RN50")
model = model.cuda().eval()

By inspecting some variables, we can better understand the structure of the architecture we loaded.

In [None]:
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

## Zero-shot evaluation

CLIP's most distinctive feature is its zero-shot capability: it can perform a task without any additional training or fine-tuning on the target dataset.

This is possible because CLIP is pre-trained on a large corpus of text and images, from which it learns a flexible representation of language and visual concepts. As a result, it can be applied to a wide range of NLP and computer vision tasks without further training.

To assess this capability, we need to prepare some test samples. Earlier, `clip.load()` returned two objects: the pre-trained CLIP model and a `preprocess` object. Inspecting it shows that it is a list of transformations applied to input images to make them compatible with the model.

In [None]:
preprocess

Therefore, we can use the `preprocess` transform to convert `PIL.Image`s into tensors that are compatible with CLIP.

However, for the textual counterpart, we cannot apply the same transformation. Instead, we can use the `clip.tokenize()` function, which accepts textual input and converts it into an integer tensor.

In [None]:
clip.tokenize("tokenize me!")
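
To make the layout of that integer tensor concrete, here is a minimal pure-Python sketch of how a tokenized row is structured: a start token, the word-piece ids, an end token, then zero padding up to the context length. The `pad_tokens` helper and the three word-piece ids are hypothetical; only the start/end ids and the context length of 77 are taken from CLIP's tokenizer.

```python
# Toy illustration of the layout produced by a CLIP-style tokenizer.
# SOT/EOT ids match CLIP's BPE vocabulary; the word-piece ids are made up.
SOT_TOKEN = 49406  # <|startoftext|>
EOT_TOKEN = 49407  # <|endoftext|>, the largest id in the vocabulary


def pad_tokens(word_ids, context_length=77):
    """Wrap word-piece ids with SOT/EOT and right-pad with zeros."""
    tokens = [SOT_TOKEN] + list(word_ids) + [EOT_TOKEN]
    if len(tokens) > context_length:
        raise ValueError("input is too long for the context length")
    return tokens + [0] * (context_length - len(tokens))


row = pad_tokens([1001, 1002, 1003])  # hypothetical ids for three word pieces
eot_position = row.index(max(row))    # EOT is the largest id, so argmax finds it
print(len(row), row[:6], eot_position)
```

Because the end token has the largest id, taking the argmax of a row is enough to locate the end of the sentence, regardless of the padding.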

We will now examine how CLIP behaves on a few images and textual descriptions. We start by defining the list of images we want to test and visualizing them along with their descriptions.

In [None]:
# Create a description for some images
DESCRIPTIONS = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

def get_data():
    images = []
    texts = []
    
    # Get all the filenames in the data directory
    data_dir = Path(skimage.data_dir)
    filenames = [
        filename for filename in data_dir.glob('*')
            if filename.suffix in {'.png', '.jpg'}
        ]
    
    for filename in filenames:
        # Skip images we do not care about
        name = filename.stem
        if name not in DESCRIPTIONS:
            continue

        images.append(filename)
        texts.append(DESCRIPTIONS[name])

    return images, texts

def visualise_data(images_path, texts):
    plt.figure(figsize=(16, 5))
    
    for i, (image_path, text) in enumerate(zip(images_path, texts)):
        # Load and visualize the image along with its text description
        image = Image.open(image_path).convert("RGB")
        
        plt.subplot(2, 4, i+1)
        plt.imshow(image)
        plt.title(text)
        plt.xticks([])
        plt.yticks([])
        plt.tight_layout()

images_path, texts = get_data()
visualise_data(images_path, texts)

To obtain the features, we preprocess the images and tokenize the texts as discussed above, then pass them through the image and text encoders.

In [None]:
def encode_data(images_path, texts):
    # Preprocess the images to transform from filenames to images to tensors
    images = [preprocess(Image.open(image_path)) for image_path in images_path]
    images = torch.stack(images).cuda()

    # Tokenize the texts to transform them from strings to integer tensors
    text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
    
    # Encode the inputs
    with torch.no_grad():
        images_z = model.encode_image(images).float()
        texts_z = model.encode_text(text_tokens).float()

    return images_z, texts_z

We can now evaluate the similarity between the set of images and the set of textual descriptions that we created. We expect the similarity to be highest for the matching pairs that we plotted earlier.

First, **it is vital** to normalize the features: for unit-norm vectors the dot product equals the cosine similarity, so the scores of different pairs become directly comparable.

In [None]:
def cosine_similarity(images_z, texts_z):
    # Normalise the image and the text features
    # (out of place, so the tensors passed by the caller are not modified)
    images_z = images_z / images_z.norm(dim=-1, keepdim=True)
    texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
    
    # Evaluate the cosine similarity between the sets of features
    similarity = images_z @ texts_z.T

    return similarity.cpu()

images_path, texts = get_data()
images_z, texts_z = encode_data(images_path, texts)
similarity = cosine_similarity(images_z, texts_z)
print(similarity)
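
As a sanity check on the role of normalisation, the following small sketch computes cosine similarities on hand-made vectors whose magnitudes differ a lot. The toy tensors are assumptions for illustration; `torch.nn.functional.normalize` is used here as an equivalent way of dividing each row by its L2 norm.

```python
import torch
import torch.nn.functional as F

# Toy features: 3 "images" and 2 "texts" in a 4-dimensional space
images_z = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                         [0.0, 2.0, 0.0, 0.0],
                         [3.0, 3.0, 0.0, 0.0]])
texts_z = torch.tensor([[5.0, 0.0, 0.0, 0.0],
                        [0.0, 0.1, 0.0, 0.0]])

# F.normalize returns new tensors, so the inputs are left unchanged
similarity = F.normalize(images_z, dim=-1) @ F.normalize(texts_z, dim=-1).T
print(similarity)  # shape (3, 2), entries in [-1, 1]
```

Without normalisation, the first text would dominate every comparison simply because its vector is 50 times longer than the second one.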

As anticipated, the similarity is higher on the diagonal, i.e., for the pairs that we defined earlier. We can also visualize these values to make it more explicit.

In [None]:
def visualise_similarity(similarity, images_path, texts):
    # Flip similarity (just for visualization)
    similarity = similarity.permute(1, 0)
    
    similarity = similarity.numpy()
    count = len(texts)

    plt.figure(figsize=(18, 12))
    plt.imshow(similarity, vmin=0.1, vmax=0.3)
    plt.yticks(range(count), texts, fontsize=18)
    plt.xticks([])
    
    for i, image_path in enumerate(images_path):
        image = Image.open(image_path).convert("RGB")
        plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
        
    # Print the scores
    for x in range(similarity.shape[1]):
        for y in range(similarity.shape[0]):
            plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
    
    # Update spines
    for side in ["left", "top", "right", "bottom"]:
        plt.gca().spines[side].set_visible(False)
    
    # Change plot limits
    plt.xlim([-0.5, count - 0.5])
    plt.ylim([count + 0.5, -2])
    
    # Set title
    plt.title("Cosine similarity between text and image features", size=20)

visualise_similarity(similarity, images_path, texts)

These scores are informative enough to be used for classification. Suppose that, instead of descriptions such as "a cup of coffee on a saucer", we had "a photo of a mug", "a photo of a horse", and "a photo of a person". Comparing an image against these three descriptions tells us whether it most likely contains a mug, a horse, or a person.

This is what gives CLIP its zero-shot capabilities: to classify an image, we list the objects of interest, turn each of them into a textual description (a prompt), and compare every prompt with the image. The image is assigned to the class whose prompt has the highest similarity, e.g., it is classified as a mug if the mug prompt scores highest.

Let's explore this further with an example on CIFAR10.

In [None]:
DATASETS = {
    "mnist": torchvision.datasets.MNIST,
    "cifar10": torchvision.datasets.CIFAR10,
}

def embed_dataset_classnames(dataset_name):
    # Create the list of descriptions and tokenize them
    dataset = DATASETS[dataset_name]("./data", transform=preprocess, download=True, train=False)
    classnames = dataset.classes
    descriptions = [f"a photo of a {label}." for label in classnames]
    text_tokens = clip.tokenize(descriptions).cuda()
    
    # Get the normalized textual features
    with torch.no_grad():
        texts_z = model.encode_text(text_tokens).float()
        texts_z /= texts_z.norm(dim=-1, keepdim=True)
    
    return classnames, descriptions, texts_z

def visualise_probabilities(images_path, classnames, texts_p, k=5):
    topk_p, topk_labels = texts_p.cpu().topk(k, dim=-1)
    plt.figure(figsize=(12, 12))
    
    for i, image_path in enumerate(images_path):
        # Read the image
        image = Image.open(image_path).convert("RGB")
    
        # Visualise the image
        plt.subplot(4, 4, 2 * i + 1)
        plt.imshow(image)
        plt.axis("off")
        
        # Visualise the probabilities for the image
        plt.subplot(4, 4, 2 * i + 2)
        y = np.arange(topk_p.shape[-1])
        plt.grid()
        plt.barh(y, topk_p[i])
        plt.gca().invert_yaxis()
        plt.gca().set_axisbelow(True)
        plt.yticks(y, [classnames[index] for index in topk_labels[i].numpy()])
        plt.xlabel("probability")

    plt.subplots_adjust(wspace=0.5, hspace=0.3)
    plt.show()

# Get the text descriptions and their embeddings
classes, texts, texts_z = embed_dataset_classnames("cifar10")
print(f"Classes: {classes}")
print(f"Prompts (text): {texts}")
print(f"Prompts (embedded): {texts_z.shape}")

# Evaluate the softmax from the cosine similarities
similarity = cosine_similarity(images_z, texts_z)
texts_p = (100 * similarity).softmax(dim=-1)

# Visualise the similarity
visualise_similarity(similarity, images_path, texts)

# Visualise the top-5 predictions
# visualise_probabilities(images_path, texts, texts_p, k=5)
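
The factor of 100 applied before the softmax above acts as an inverse temperature. The sketch below, in plain Python, suggests why it matters: cosine similarities typically live in a narrow range, so an unscaled softmax is almost uniform. The `softmax` helper and the similarity values are illustrative assumptions, not CLIP outputs.

```python
import math


def softmax(scores, scale=1.0):
    """Softmax over a list of scores after multiplying them by `scale`."""
    scaled = [scale * s for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical cosine similarities of one image against three prompts
prompts = ["a photo of a mug", "a photo of a horse", "a photo of a person"]
similarities = [0.27, 0.18, 0.15]

unscaled = softmax(similarities)             # almost uniform
scaled = softmax(similarities, scale=100)    # sharply peaked on the best prompt
prediction = prompts[scaled.index(max(scaled))]
print(unscaled, scaled, prediction)
```

The predicted class is the same in both cases, since scaling does not change the ranking; only the confidence attached to it changes.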

Note that the list of descriptions acts as a classifier. The neural networks we have trained so far can generally be split into an encoder network and a classifier network. As we saw in the last lab, adapting such a pre-trained network to another task requires dropping the classifier and retraining it from scratch.

By design, CLIP has no classifier: we change what it classifies simply by changing the list of textual descriptions. The list of descriptions can therefore be seen as a flexible classifier.

## Some considerations

So far, we compared an image with a list of textual descriptions and selected the most similar description for that image. The opposite is also possible: given a textual description, we can select the most similar image from a list of images.

This is only one example of what you can achieve by reasoning about the components you have. Similarly, you can compare images with images, or texts with texts, to measure how similar they are.

In [None]:
# Compare how text is similar to other text
similarity = cosine_similarity(texts_z, texts_z)
print(similarity)

# Compare how images are similar one with the other
similarity = cosine_similarity(images_z, images_z)
print(similarity)

Furthermore, until now we considered a single "prompt" for each class name (i.e., "a photo of a [CLS]"), but we can use multiple templates to generate more textual features. For instance, the CLIP authors [1] define a list of 80 templates for the ImageNet dataset and show that using the mean representation of the textual features improves performance.

In [None]:
def embed_dataset_classnames(dataset_name, templates=("a photo of a {}.",)):
    # Create the list of descriptions and tokenize them
    dataset = DATASETS[dataset_name]("./data", transform=preprocess, download=True, train=False)
    classnames = dataset.classes
    
    texts_z_views = []
    for template in templates:
        descriptions = [template.format(c) for c in classnames]
        text_tokens = clip.tokenize(descriptions).cuda()
        
        # Get the normalized textual features
        with torch.no_grad():
            texts_z = model.encode_text(text_tokens).float()
            texts_z /= texts_z.norm(dim=-1, keepdim=True)
            texts_z_views.append(texts_z)

    # Evaluate the mean representation
    texts_z = torch.stack(texts_z_views).mean(dim=0)

    # Renormalise
    texts_z /= texts_z.norm(dim=-1, keepdim=True)
    
    return classnames, texts_z

# Get the text descriptions and their embeddings
texts, texts_z = embed_dataset_classnames(
  "cifar10",
  templates=["a photo of a {}", "a low-res picture of a {}"]
)

# Evaluate the softmax from the cosine similarities
similarity = cosine_similarity(images_z, texts_z)
texts_p = (100 * similarity).softmax(dim=-1)

# Visualise the top-5 predictions
visualise_probabilities(images_path, texts, texts_p, k=5)
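
The renormalisation step after averaging is easy to overlook. A minimal sketch in plain Python, using two made-up 2-D embeddings for the same class, shows that the mean of unit vectors is generally shorter than one, so it has to be rescaled before computing cosine similarities. The `normalise` and `ensemble` helpers are hypothetical names used only here.

```python
import math


def normalise(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def ensemble(embeddings):
    """Average unit-norm embeddings element-wise, then renormalise."""
    unit = [normalise(e) for e in embeddings]
    mean = [sum(col) / len(unit) for col in zip(*unit)]
    return mean, normalise(mean)


# Two hypothetical embeddings of the same class from different templates
mean, class_embedding = ensemble([[1.0, 0.0], [0.0, 1.0]])
print(math.hypot(*mean))             # about 0.707: the mean is shorter than 1
print(math.hypot(*class_embedding))  # about 1.0 after renormalising
```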

# Fine-tune CLIP

Although CLIP lacks a linear layer for classifying samples, there are several ways to adapt it to downstream tasks. Fine-tuning the entire network (as we did in the last lab with the AlexNet examples) is generally not recommended, but there are lighter methods to adapt CLIP for transfer learning.

One straightforward solution is to add a linear layer on top of CLIP's visual features and train it to classify a specific dataset.

Let's consider an example with MNIST. To do this, we will reuse code from previous labs.

## Setup

First, we need to modify the `get_data` function from the previous labs to support different datasets and custom transforms (since CLIP has its own).

In [None]:
def get_data(dataset_name, batch_size=64, transform=None, test_batch_size=256):
    dataset = DATASETS[dataset_name]
    
    if not transform:
        # Convert the PIL images to Tensors
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    
    # Load data
    full_training_data = dataset('./data', train=True, transform=transform, download=True)
    test_data = dataset('./data', train=False, transform=transform, download=True)

    # Create train and validation splits
    num_samples = len(full_training_data)
    training_samples = int(num_samples * 0.5 + 1)
    validation_samples = num_samples - training_samples
    
    training_data, validation_data = torch.utils.data.random_split(full_training_data, [training_samples, validation_samples])
    
    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, num_workers=8)
    val_loader = torch.utils.data.DataLoader(validation_data, test_batch_size, shuffle=False, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_data, test_batch_size, shuffle=False, num_workers=8)

    return train_loader, val_loader, test_loader
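
The split above is random, so it changes between runs. If a reproducible split is wanted, one option is to pass a seeded generator to `random_split`. The sketch below assumes a small stand-in dataset and a hypothetical `split` helper; it is not part of the lab's `get_data`.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset with 11 samples (hypothetical; replaces MNIST/CIFAR10 here)
dataset = TensorDataset(torch.arange(11).float().unsqueeze(1), torch.arange(11))

num_samples = len(dataset)
training_samples = int(num_samples * 0.5 + 1)  # same rule as in get_data
validation_samples = num_samples - training_samples


def split(seed):
    """Split the dataset with a dedicated, seeded random generator."""
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [training_samples, validation_samples], generator=generator)


train_a, val_a = split(seed=0)
train_b, val_b = split(seed=0)
print(len(train_a), len(val_a))            # 6 5
print(train_a.indices == train_b.indices)  # True: same seed, same split
```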

Next, we write a custom `test_step` function to evaluate how well zero-shot CLIP performs.

In [None]:
def test_step_zero_shot_clip(net, data_loader, texts_z, device='cuda'):
    samples = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets) in tqdm(enumerate(data_loader), total=len(data_loader), position=0, leave=True):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)
            
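            # Forward pass (use the network passed as argument and normalise
            # the image features so that the product is a cosine similarity)
            images_z = net.encode_image(inputs).float()
            images_z = images_z / images_z.norm(dim=-1, keepdim=True)
            outputs = (100 * images_z @ texts_z.T).softmax(dim=-1)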

            # Fetch prediction and loss value
            samples += inputs.shape[0]
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_accuracy / samples * 100

And we evaluate the performance of zero-shot CLIP on a dataset.

In [None]:
dataset_name = "cifar10"

_, _, test_loader = get_data(dataset_name, transform=preprocess, batch_size=128)
texts, texts_z = embed_dataset_classnames(dataset_name)
test_accuracy = test_step_zero_shot_clip(model, test_loader, texts_z)

print(f"Test accuracy {test_accuracy:.2f}")

Now, let's add a linear layer on top of CLIP's visual encoder and see whether it improves performance.

The first step is to create a custom neural network that builds upon the visual encoder.

In [None]:
class OurCLIP(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        model, _ = clip.load("RN50")

        self.encoder = model.visual.float()
        self.classifier = nn.Linear(1024, num_classes)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        x = self.classifier(x)
        
        return x

Next, we can define the other components of training, such as the optimizer, the loss function, and the training and testing steps.

In [None]:
def get_optimizer(model, lr, wd, momentum):
    optimizer = torch.optim.SGD([
        {"params": model.classifier.parameters(), "lr": lr}
    ], lr=lr / 10, weight_decay=wd, momentum=momentum)

    return optimizer
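
Since only the classifier's parameters are handed to the optimizer, the encoder is never updated even though it stays part of the network. The sketch below checks this on a tiny stand-in model; `TinyProbe` and its sizes are hypothetical and merely mimic the encoder/classifier structure of `OurCLIP`.

```python
import torch
import torch.nn as nn


class TinyProbe(nn.Module):
    """Hypothetical stand-in for OurCLIP: a small encoder plus a linear classifier."""

    def __init__(self, in_dim=8, feat_dim=4, num_classes=3):
        super().__init__()
        self.encoder = nn.Linear(in_dim, feat_dim)
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.classifier(self.encoder(x))


torch.manual_seed(0)
net = TinyProbe()
encoder_before = net.encoder.weight.detach().clone()
classifier_before = net.classifier.weight.detach().clone()

# Same structure as get_optimizer: one parameter group with its own learning rate.
# The top-level lr is only a default for groups that do not set one.
optimizer = torch.optim.SGD(
    [{"params": net.classifier.parameters(), "lr": 0.1}], lr=0.01, momentum=0.9
)

loss = nn.functional.cross_entropy(net(torch.randn(5, 8)), torch.tensor([0, 1, 2, 0, 1]))
loss.backward()
optimizer.step()
optimizer.zero_grad()

encoder_unchanged = torch.equal(net.encoder.weight, encoder_before)
classifier_changed = not torch.equal(net.classifier.weight, classifier_before)
print(optimizer.param_groups[0]["lr"], encoder_unchanged, classifier_changed)
```

Note that the encoder's parameters still have `requires_grad=True` in this setup, so gradients are computed for them even though they are never applied. Calling `requires_grad_(False)` on them would likely save memory and time.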

In [None]:
def get_cost_function():
    cost_function = torch.nn.CrossEntropyLoss()
    return cost_function

In [None]:
def training_step(net, data_loader, optimizer, cost_function, device="cuda"):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0
    
    # Set the network to training mode
    net.train()
    
    # Iterate over the training set
    pbar = tqdm(data_loader, desc="Training", position=0, leave=True, total=len(data_loader))
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        # Load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Forward pass
        outputs = net(inputs)
        
        # Loss computation
        loss = cost_function(outputs, targets)
        
        # Backward pass
        loss.backward()
        
        # Parameters update
        optimizer.step()
        
        # Gradients reset
        optimizer.zero_grad()
        
        # Fetch prediction and loss value
        samples += inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(dim=1) # max() returns (maximum_value, index_of_maximum_value)
        
        # Compute training accuracy
        cumulative_accuracy += predicted.eq(targets).sum().item()

        pbar.set_postfix(train_loss=loss.item(), train_acc=cumulative_accuracy / samples * 100)
        pbar.update(1)
    
    return cumulative_loss / samples, cumulative_accuracy / samples * 100

def test_step(net, data_loader, cost_function, device="cuda"):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0
    
    # Set the network to evaluation mode
    net.eval()
    
    # Disable gradient computation (we are only testing, we do not want our model to be modified in this step!)
    pbar = tqdm(data_loader, desc="Testing", position=0, leave=True, total=len(data_loader))
    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            # Forward pass
            outputs = net(inputs)
            
            # Loss computation
            loss = cost_function(outputs, targets)
            
            # Fetch prediction and loss value
            samples += inputs.shape[0]
            cumulative_loss += loss.item() # Note: the .item() is needed to extract scalars from tensors
            _, predicted = outputs.max(1)
            
            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

            pbar.set_postfix(test_acc=cumulative_accuracy / samples * 100)
            pbar.update(1)

    return cumulative_loss / samples, cumulative_accuracy / samples * 100
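
One detail of the two steps above: `CrossEntropyLoss` returns the mean loss of each batch by default, so `cumulative_loss / samples` divides a sum of batch means by the number of samples. The value is still useful for tracking trends, but it is not the per-sample mean. A sketch of a sample-weighted alternative, with made-up numbers and a hypothetical `weighted_mean_loss` helper:

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss from per-batch mean losses and batch sizes."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-batch mean losses; the last batch is smaller
batch_losses = [2.0, 1.0, 4.0]
batch_sizes = [16, 16, 8]

per_sample = weighted_mean_loss(batch_losses, batch_sizes)
as_in_steps = sum(batch_losses) / sum(batch_sizes)  # what cumulative_loss / samples computes
print(per_sample, as_in_steps)
```

In the training loop this would correspond to accumulating `loss.item() * inputs.shape[0]` instead of `loss.item()`.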

## Put it all together!

We need a compact procedure that combines all the components and functions defined so far into the actual optimization loop. In particular, we want to alternate training and evaluation steps for multiple epochs while tracking the intermediate results.

In [None]:
def log_values(writer, step, loss, accuracy, prefix):
    writer.add_scalar(f"{prefix}/loss", loss, step)
    writer.add_scalar(f"{prefix}/accuracy", accuracy, step)

def main(
    dataset_name="cifar10",
    batch_size=16,
    num_classes=10,
    device='cuda:0',
    learning_rate=0.002,
    weight_decay=0.0005,
    momentum=0.9,
    epochs=2,
    run_name="exp1",
):
    # Create a logger for the experiment
    writer = SummaryWriter(log_dir=f"runs/{run_name}")
    
    # Get dataloaders
    train_loader, val_loader, test_loader = get_data(dataset_name, transform=preprocess, batch_size=batch_size)
    
    # Instantiate the network and move it to the chosen device (GPU)
    net = OurCLIP(num_classes=num_classes).to(device)

    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
    print(f"Total trainable parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad):,}")

    # Instantiate the optimizer
    optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)
    
    # Define the cost function
    cost_function = get_cost_function()
    
    # Computes evaluation results before training
    print("Before training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function)
    
    # Log to TensorBoard
    log_values(writer, -1, train_loss, train_accuracy, "train")
    log_values(writer, -1, val_loss, val_accuracy, "validation")
    log_values(writer, -1, test_loss, test_accuracy, "test")
    
    print(f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}")
    print(f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}")
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")
    
    # For each epoch, train the network and then compute evaluation results
    for e in range(epochs):
        train_loss, train_accuracy = training_step(net, train_loader, optimizer, cost_function)
        val_loss, val_accuracy = test_step(net, val_loader, cost_function)

        log_values(writer, e, train_loss, train_accuracy, "train")
        log_values(writer, e, val_loss, val_accuracy, "validation")

    # Compute final evaluation results
    print("After training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function)
    
    log_values(writer, epochs, train_loss, train_accuracy, "train")
    log_values(writer, epochs, val_loss, val_accuracy, "validation")
    log_values(writer, epochs, test_loss, test_accuracy, "test")
    print(f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}")
    print(f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}")
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")
    
    # closes the logger
    writer.close()

## Run it!

In [None]:
main()

# Context Optimization (CoOp)

We will now see how to implement CoOp and use it to improve the performance of CLIP on downstream datasets.
CoOp performs prompt tuning: it learns the prompts while keeping CLIP frozen. To do so, we need to make some slight changes to the CLIP pipeline and introduce a few new components.

NOTE: the following is a simpler version of the original code of CoOp available [here](https://github.com/KaiyangZhou/CoOp/tree/main).

In [None]:
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()

## The text encoder

While we don't want to fine-tune the text encoder, we still need to pass the "CoOp" prompts (i.e., those we want to learn) through it. CLIP's text encoder expects token ids rather than already-embedded prompts, so we wrap its components in our own custom implementation.

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding
        x = x.permute(1, 0, 2)  # [batch_size, n_ctx, transformer.width] -> [n_ctx, batch_size, transformer.width]
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # [n_ctx, batch_size, transformer.width] -> [batch_size, n_ctx, transformer.width]
        x = self.ln_final(x)

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x
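
The indexing on the last line of `forward` picks, for each prompt, the feature vector at the position of its end-of-text token. A small sketch on toy tensors illustrates the mechanics; the shapes and the word-piece ids 320 and 1929 are made up for illustration.

```python
import torch

# Toy batch: 2 prompts, context length 5, width 3 (all values hypothetical)
tokenized_prompts = torch.tensor([[49406, 320, 49407, 0, 0],
                                  [49406, 320, 1929, 49407, 0]])
x = torch.arange(2 * 5 * 3, dtype=torch.float32).reshape(2, 5, 3)

# EOT has the largest id, so argmax over the sequence gives its position
eot_positions = tokenized_prompts.argmax(dim=-1)     # tensor([2, 3])
pooled = x[torch.arange(x.shape[0]), eot_positions]  # shape (2, 3)
print(eot_positions, pooled)
```

This is why `forward` still needs `tokenized_prompts` even though the prompts are already embedded: the token ids are only used to find where each sentence ends.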

Then, we define our `PromptLearner`, a class that holds the learnable prompts.

In [None]:
class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution

        # Use given words to initialize context vectors
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # print("+++")
        # print("Prompts:")
        # for p in prompts:
        #     print(p)
        # print("+++")

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors are registered as buffers, so they end up in the
        # state dict; when restoring a checkpoint they should be ignored, as we
        # want to use those computed from the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def forward(self):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx
        
        # If the context is generic (shared, 2-D), expand it to all classes
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        
        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(f"Unknown class_token_position: {self.class_token_position!r}")

        return prompts
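
To see how the pieces fit together in the `"end"` configuration, the following sketch rebuilds the concatenation on random tensors and checks the resulting shapes. The sizes are toy values (not CLIP's), and the random `embedding` stands in for the frozen token embeddings.

```python
import torch

n_cls, n_ctx, dim, context_length = 3, 4, 8, 12  # toy sizes, not CLIP's

# Stand-in for the frozen token embeddings of "X X X X <classname>."
embedding = torch.randn(n_cls, context_length, dim)
token_prefix = embedding[:, :1, :]          # SOS
token_suffix = embedding[:, 1 + n_ctx:, :]  # class name, ".", EOS, padding

# One shared learnable context, broadcast to every class
ctx = torch.nn.Parameter(torch.randn(n_ctx, dim) * 0.02)
ctx_expanded = ctx.unsqueeze(0).expand(n_cls, -1, -1)

prompts = torch.cat([token_prefix, ctx_expanded, token_suffix], dim=1)
print(prompts.shape)  # torch.Size([3, 12, 8])

# Gradients flow only into the shared context vectors
prompts.sum().backward()
print(ctx.grad.shape)  # torch.Size([4, 8])
```

The placeholder `"X"` tokens only reserve `n_ctx` positions in the tokenized prompt; their embeddings are cut out and replaced by the learnable context.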

We redefine our CLIP model with the updated `TextEncoder` and the new `PromptLearner`.

In [None]:
class OurCLIP(nn.Module):
    def __init__(self, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        clip_model, _ = clip.load("RN50")
        # clip_model = clip_model.cpu()
        clip_model = clip_model.float()
        
        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=csc)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def forward(self, image):
        image_features = self.image_encoder(image)

        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits
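The last lines of `forward` compute scaled cosine similarities between image and text features. The sketch below reproduces that step on random tensors; the feature sizes are hypothetical, and `log(100)` is only an assumed stand-in for the learned `logit_scale` parameter.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: 2 images, 3 classes, 8-dimensional embeddings.
image_features = torch.randn(2, 8)
text_features = torch.randn(3, 8)

# L2-normalise so that the dot product equals the cosine similarity.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# The scale is stored in log space; exp() recovers the multiplier.
# log(100) is an assumed placeholder, not a value read from the model.
logit_scale = torch.tensor(100.0).log().exp()

logits = logit_scale * image_features @ text_features.t()  # shape (2, 3)
probs = logits.softmax(dim=-1)      # per-image distribution over classes
predictions = probs.argmax(dim=-1)  # index of the best-matching class

print(logits.shape, predictions)
```

Because both sets of features have unit norm, each logit lies within `[-logit_scale, logit_scale]`, and a larger scale makes the softmax distribution sharper.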

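We then define the optimizer and the training routine. Both CLIP encoders are frozen, so only the parameters of the prompt learner are updated.

In [None]: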
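def get_optimizer(model, lr, wd, momentum):
    # Optimise only the parameters that still require gradients
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD([
        {"params": trainable_params}
    ], lr=lr, weight_decay=wd, momentum=momentum)

    return optimizer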

def main_coop(
    dataset_name="cifar10",
    batch_size=16,
    num_classes=10,
    device="cuda:0",
    learning_rate=0.002,
    weight_decay=0.0005,
    momentum=0.9,
    epochs=2,
    run_name="exp1",
    n_ctx=4,
    ctx_init="",
    class_token_position="end",
    csc=False,
):
    # Create a logger for the experiment
    writer = SummaryWriter(log_dir=f"runs/{run_name}")
    
    # Get dataloaders
    train_loader, val_loader, test_loader = get_data(dataset_name, transform=preprocess, batch_size=batch_size)
    classnames, _ = embed_dataset_classnames(dataset_name)
    
    # Instantiate the network and move it to the chosen device (GPU)
    net = OurCLIP(
        classnames=classnames, n_ctx=n_ctx, ctx_init=ctx_init, class_token_position=class_token_position, csc=csc
    ).to(device)

    print("Turning off gradients in both the image and the text encoder")
    for name, param in net.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)
    
    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
    print(f"Total trainable parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad):,}")
    
    # Instantiate the optimizer
    optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)
    
    # Define the cost function
    cost_function = get_cost_function()
    
    # Compute evaluation results before training
    print("Before training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function)
    
    # Log to TensorBoard
    log_values(writer, -1, train_loss, train_accuracy, "train")
    log_values(writer, -1, val_loss, val_accuracy, "validation")
    log_values(writer, -1, test_loss, test_accuracy, "test")
    
    print(f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}")
    print(f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}")
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")
    
    # For each epoch, train the network and then compute evaluation results
    for e in range(epochs):
        train_loss, train_accuracy = training_step(net, train_loader, optimizer, cost_function)
        val_loss, val_accuracy = test_step(net, val_loader, cost_function)

        log_values(writer, e, train_loss, train_accuracy, "train")
        log_values(writer, e, val_loss, val_accuracy, "validation")

    # Compute final evaluation results
    print("After training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function)
    
    log_values(writer, epochs, train_loss, train_accuracy, "train")
    log_values(writer, epochs, val_loss, val_accuracy, "validation")
    log_values(writer, epochs, test_loss, test_accuracy, "test")
    print(f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}")
    print(f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}")
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")
    
    # Close the logger
    writer.close()
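The freezing step in `main_coop` selects parameters by name. Its effect can be checked in isolation on a small stand-in network. In this sketch `ToyModel` is hypothetical and only mimics the attribute names of `OurCLIP`; the loss is an arbitrary placeholder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyModel(nn.Module):
    """Hypothetical stand-in for OurCLIP with the same attribute naming."""

    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(8, 4)
        self.prompt_learner = nn.Linear(4, 4, bias=False)

    def forward(self, x):
        return self.prompt_learner(self.image_encoder(x))


net = ToyModel()

# Same rule as in main_coop: freeze everything outside the prompt learner.
for name, param in net.named_parameters():
    if "prompt_learner" not in name:
        param.requires_grad_(False)

trainable = [n for n, p in net.named_parameters() if p.requires_grad]

optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
frozen_before = net.image_encoder.weight.clone()
learned_before = net.prompt_learner.weight.clone()

# One optimisation step on a placeholder loss.
loss = net(torch.randn(2, 8)).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(trainable)
print("encoder unchanged:", torch.equal(net.image_encoder.weight, frozen_before))
print("prompt learner updated:", not torch.equal(net.prompt_learner.weight, learned_before))
```

Frozen parameters never receive a gradient, so the optimizer leaves them untouched even when they are passed to it.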

In [None]:
main_coop()

# References

1. Radford, Alec, et al. "Learning Transferable Visual Models from Natural Language Supervision." International Conference on Machine Learning. PMLR, 2021.