## Prerequisites

In this blog post, I assume that the reader is familiar with the following:

1. ResNet Architecture 
2. Vision Transformer

## Introduction

In this blog post, we will uncover the inner workings of CLIP - [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) (@clip) by looking at its PyTorch implementation. For a gentle introduction to CLIP, please refer to [part-1](https://amaarora.github.io/posts/2023-03-06_Understanding_CLIP.html) of the blog.

## CLIP Architecture

![Summary of CLIP approach](../images/clip.png){#fig-clip}

From @fig-clip, we can see that we have a text encoder and an image encoder. These encoders are responsible for taking in the image and the text and mapping them into a shared embedding space.

The image encoder encodes images to embeddings $I_1, I_2, I_3, \ldots, I_N$, and the text encoder encodes the corresponding image captions to $T_1, T_2, T_3, \ldots, T_N$.
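
To make the notation concrete, here is a minimal sketch of how the two sets of embeddings relate. It assumes random tensors as stand-ins for the encoder outputs, and the batch size `N` and embedding width `d` are hypothetical values chosen only for illustration. It builds the $N \times N$ grid of image-text similarities shown in @fig-clip.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d = 4, 8  # hypothetical batch size and embedding width

# Stand-ins for encoder outputs; real embeddings would come from the image and text encoders.
image_embeddings = torch.randn(N, d)  # I_1 ... I_N
text_embeddings = torch.randn(N, d)   # T_1 ... T_N

# L2-normalise so that the dot product below is a cosine similarity.
image_embeddings = F.normalize(image_embeddings, dim=-1)
text_embeddings = F.normalize(text_embeddings, dim=-1)

# Entry (i, j) compares image i with caption j; the diagonal holds the matching pairs.
similarity = image_embeddings @ text_embeddings.T
print(similarity.shape)  # torch.Size([4, 4])
```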

But what do these text and image encoders look like? Let's start with the image encoder.

## Image Encoder

*We consider two different architectures for the image encoder. For the first, we use ResNet-50 (@resnet) as the base architecture for the image encoder due to its widespread adoption and proven performance. We make several modifications to the original version using the ResNetD improvements from @bag_of_tricks and the antialiased rect-2 blur pooling from @blurpool. We also replace the global average pooling layer with an attention pooling mechanism. The attention pooling is implemented as a single layer of “transformer-style” multi-head QKV attention where the query is conditioned on the global average-pooled representation of the image. For the second architecture, we experiment with the recently introduced Vision Transformer (ViT) (@vit). We closely follow their implementation with only the minor modification of adding an additional layer normalization to the combined patch and position embeddings before the transformer and use a slightly different initialization scheme.*

### Modified ResNet

Let's start with the first architecture. 

*For the first, we use ResNet-50 (@resnet) as the base architecture for the image encoder due to its widespread adoption and proven performance. We make several modifications to the original version using the ResNetD improvements from @bag_of_tricks and the antialiased rect-2 blur pooling from @blurpool. We also replace the global average pooling layer with an attention pooling mechanism. The attention pooling is implemented as a single layer of “transformer-style” multi-head QKV attention where the query is conditioned on the global average-pooled representation of the image.*

As mentioned, CLIP makes 3 major changes to the ResNet architecture:

- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool

#### ResNet stem

Let's look at all of them one by one in code. First, we start with *There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.*

-- Add image here

In the vanilla ResNet architecture, the stem starts with a single 7x7 stride-2 convolution. In code, this looks like:

```python
import torch.nn as nn

class VanillaResNet(nn.Module):
    def __init__(self, in_chans=3, inplanes=64):
        super().__init__()
        self.stem = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)  # 7x7, stride 2
```

However, in @bag_of_tricks, the authors describe a set of tweaks that, at the time, raised *ResNet-50’s top-1 validation accuracy from 75.3% to 79.29% on ImageNet*. One of these tweaks changes the stem. From the paper:

*A 7 × 7 convolution is 5.4 times more expensive than a 3 × 3 convolution. So this tweak replacing the 7 × 7 convolution in the input stem with three conservative 3 × 3 convolutions.*
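
The 5.4 figure can be checked with a quick calculation. This is a rough sketch that assumes the cost of a convolution scales with the kernel area, with the channel counts and output resolution held fixed:

```python
# Multiply-accumulates per output position and per (input, output) channel pair
# scale with the kernel area.
cost_7x7 = 7 * 7  # 49
cost_3x3 = 3 * 3  # 9

ratio = cost_7x7 / cost_3x3
print(round(ratio, 1))  # 5.4
```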

-- Add image here

In code this looks like: 

```python
import torch.nn as nn

class ModifiedResNet(nn.Module):
    def __init__(self, width: int = 64):
        super().__init__()
        # The 3-layer stem: three 3x3 convolutions instead of a single 7x7 convolution
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        # Average pool instead of the max pool used in the vanilla stem
        self.avgpool = nn.AvgPool2d(2)

    def stem(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x):
        x = self.stem(x)
        # The residual stages and the final pooling layer are omitted in this excerpt
        return x
```
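
As a sanity check, the two stems can be compared side by side. The sketch below assumes the vanilla stem is the usual conv-BN-ReLU-maxpool sequence and uses a stem width of 64. The names `vanilla_stem` and `modified_stem` are mine, used only for illustration. Both stems reduce a 224x224 image to a 56x56 feature map, so the modified stem works as a drop-in replacement.

```python
import torch
import torch.nn as nn

width = 64  # assumed stem width

# Vanilla stem: one 7x7 stride-2 convolution followed by a max pool.
vanilla_stem = nn.Sequential(
    nn.Conv2d(3, width, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(width),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# Modified stem: three 3x3 convolutions followed by an average pool.
modified_stem = nn.Sequential(
    nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(width // 2),
    nn.ReLU(inplace=True),
    nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(width // 2),
    nn.ReLU(inplace=True),
    nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(width),
    nn.ReLU(inplace=True),
    nn.AvgPool2d(2),
)

x = torch.randn(1, 3, 224, 224)
vanilla_out = vanilla_stem(x)
modified_out = modified_stem(x)
print(vanilla_out.shape, modified_out.shape)  # both are (1, 64, 56, 56)
```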

#### Blur Pool

The next change is to use `BlurPooling` - *Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1*.

In this section I will introduce BlurPooling and share how it is implemented in the `ModifiedResNet` architecture. 

From the research paper, 

*Modern convolutional networks are not shift-invariant, as small input shifts or translations can cause drastic changes in the output. Commonly used downsampling methods, such as max-pooling, strided-convolution, and average-pooling, ignore the sampling theorem. The well-known signal processing fix is anti-aliasing by low-pass filtering before downsampling.*
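
Here is a minimal sketch of how this idea can be applied inside a ResNet bottleneck block, following the description above: every convolution keeps stride 1, and an average pool does the downsampling whenever the block's stride is greater than 1. The average pool plays the role of the low-pass filter (my reading is that a 2x2 average corresponds to the "rect-2" filter mentioned in the CLIP paper). The class name `AntiAliasedBottleneck` and the layer sizes are assumptions for illustration, not the exact CLIP code.

```python
import torch
import torch.nn as nn

class AntiAliasedBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes: int, planes: int, stride: int = 1):
        super().__init__()
        out_planes = planes * self.expansion
        # All convolutions use stride 1; downsampling is delegated to the average pool.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut branch: the average pool is prepended to a stride-1 1x1 convolution.
        self.downsample = None
        if stride > 1 or inplanes != out_planes:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(stride) if stride > 1 else nn.Identity(),
                nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)  # low-pass filter and downsample in one step
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)

block = AntiAliasedBottleneck(inplanes=256, planes=128, stride=2)
out = block(torch.randn(2, 256, 56, 56))
print(out.shape)  # torch.Size([2, 512, 28, 28])
```

Note that the spatial size is halved even though no convolution in the block has a stride greater than 1.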

#### Final pooling layer

The last change in the network architecture is to use QKV attention instead of an average pool. *We also replace the global average pooling layer with an attention pooling mechanism. The attention pooling is implemented as a single layer of “transformer-style” multi-head QKV attention where the query is conditioned on the global average-pooled representation of the image.*

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One position per spatial location, plus one for the prepended mean token
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        # Prepend the global average-pooled representation as an extra token
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # The output at position 0 (the mean token) is the pooled image representation
        return x[0]
```
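
To see the idea in isolation, here is a simplified sketch of attention pooling built on `nn.MultiheadAttention`. It is not the CLIP implementation: the class name `SimpleAttentionPool` is hypothetical, the projections are the ones built into `nn.MultiheadAttention` rather than separate `q_proj`/`k_proj`/`v_proj` layers, and the output width is kept equal to `embed_dim`. The point it illustrates is that the mean token acts as the query and attends over all spatial positions.

```python
import torch
import torch.nn as nn

class SimpleAttentionPool(nn.Module):
    def __init__(self, spatial_dim: int, embed_dim: int, num_heads: int):
        super().__init__()
        # One position per spatial location, plus one for the mean token
        self.positional_embedding = nn.Parameter(
            torch.randn(spatial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
        )
        # Expects inputs shaped (sequence, batch, embed) by default
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        tokens = x.flatten(2).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)  # 1NC, global average pool
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :]
        # Only the mean token is used as the query; it attends over every position.
        pooled, _ = self.attn(query=tokens[:1], key=tokens, value=tokens, need_weights=False)
        return pooled[0]  # NC

pool = SimpleAttentionPool(spatial_dim=7, embed_dim=64, num_heads=8)
feature_map = torch.randn(2, 64, 7, 7)  # e.g. the final ResNet feature map
pooled = pool(feature_map)
print(pooled.shape)  # torch.Size([2, 64])
```

Compared to a plain global average pool, the output is still one vector per image, but each spatial position now contributes with a learned, input-dependent weight.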

### Modified ViT

From the paper:

*For the second architecture, we experiment with the recently introduced Vision Transformer (ViT) (@vit). We closely follow their implementation with only the minor modification of adding an additional layer normalization to the combined patch and position embeddings before the transformer and use a slightly different initialization scheme.*

Since the architecture is very similar to the vanilla Vision Transformer, with the only minor change being an additional LayerNorm applied to the combined patch and position embeddings before the transformer, I will not be covering the architecture in detail in this blog post.
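
That said, the one modification is small enough to sketch. The module below is a minimal illustration of the embedding stage only, under my own assumptions: the name `PatchEmbedWithPreNorm`, the `ln_pre` attribute and the default sizes are chosen for illustration, and the transformer blocks that would follow are left out.

```python
import torch
import torch.nn as nn

class PatchEmbedWithPreNorm(nn.Module):
    def __init__(self, image_size: int = 64, patch_size: int = 16, width: int = 32):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        scale = width ** -0.5
        # A strided convolution splits the image into patches and projects each one to `width`
        self.patch_embed = nn.Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
        # The additional layer normalization applied before the transformer
        self.ln_pre = nn.LayerNorm(width)

    def forward(self, x):
        x = self.patch_embed(x)  # (N, width, grid, grid)
        x = x.flatten(2).transpose(1, 2)  # (N, grid * grid, width)
        cls = self.class_embedding.expand(x.shape[0], 1, -1)  # (N, 1, width)
        x = torch.cat([cls, x], dim=1)  # (N, grid * grid + 1, width)
        x = x + self.positional_embedding  # combined patch and position embeddings
        return self.ln_pre(x)  # normalized tokens, ready for the transformer

embed = PatchEmbedWithPreNorm()
tokens = embed(torch.randn(2, 3, 64, 64))
print(tokens.shape)  # torch.Size([2, 17, 32])
```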

For a detailed walkthrough of ViT, please refer to my previous blog post, which covers the architecture along with a PyTorch implementation - [Vision Transformer](https://amaarora.github.io/posts/2021-01-18-ViT.html).

## Text Encoder

In the previous section, we covered the two types of Image Encoders used in CLIP. As mentioned in part-1 of the blog series on CLIP, the image encoders are actually architecture agnostic - that is, any standard architecture can be used to extract embeddings from images. 

In this section, let's look at the text encoder.