In [33]:
import os
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torchvision.datasets import CIFAR100
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

In [5]:
class VisionTransformer(nn.Module):
    """ViT image encoder that maps a batch of images to one embedding per image.

    The image is cut into non-overlapping patches by a strided convolution, a
    learned class token is prepended, learned positional embeddings are added,
    and the sequence is run through a ``Transformer``.  The class-token output
    is layer-normalised and projected to ``embed_dim``.

    Args:
        embed_dim: size of the output embedding.
        image_resolution: side length (pixels) of the square input image; it
            fixes the number of positional embeddings.
        vision_layers: number of transformer blocks.
        vision_width: channel width of the patch embeddings / transformer.
        vision_heads: number of attention heads.
        vision_patch_size: side length (pixels) of each square patch.
    """

    def __init__(self, embed_dim, image_resolution, vision_layers, vision_width, vision_heads,
                 vision_patch_size):
        super().__init__()
        # Patch embedding: kernel == stride, so every patch is seen exactly once.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=vision_width, kernel_size=vision_patch_size,
                               stride=vision_patch_size, bias=False)
        scale = vision_width ** -0.5
        self.cls = nn.Parameter(scale * torch.randn(vision_width))
        num_patches = (image_resolution // vision_patch_size) ** 2
        # +1 position for the class token that forward() prepends.
        self.pos_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, vision_width))
        self.ln1 = nn.LayerNorm(vision_width)
        self.transformer = Transformer(vision_layers, vision_width, vision_heads)
        self.ln2 = nn.LayerNorm(vision_width)
        self.projection = nn.Parameter(scale * torch.randn(vision_width, embed_dim))

    def forward(self, x):
        """Encode images.

        Args:
            x: image batch of shape (batch, 3, H, W); H and W must yield the
               number of patches the positional embedding was built for.

        Returns:
            Tensor of shape (batch, embed_dim), or (batch, vision_width) when
            ``self.projection`` is None.
        """
        x = self.conv1(x)                           # (batch, width, grid, grid)
        x = x.reshape(x.shape[0], x.shape[1], -1)   # (batch, width, grid * grid)
        x = x.permute(0, 2, 1)                      # (batch, grid * grid, width)
        cls_token = self.cls.to(x.dtype).expand(x.shape[0], 1, -1)
        x = torch.cat([cls_token, x], dim=1)        # (batch, grid * grid + 1, width)
        x = x + self.pos_embedding.to(x.dtype)
        x = self.ln1(x)
        # nn.MultiheadAttention defaults to sequence-first layout.
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        # Keep only the class token: callers normalise along dim=1 and take
        # `image_features @ text_features.t()`, i.e. they need (batch, dim).
        x = self.ln2(x[:, 0, :])
        # `if self.projection:` would ask for the truth value of a whole tensor
        # and raise; test for presence instead.
        if self.projection is not None:
            x = x @ self.projection
        return x

In [6]:
class AttentionBlock(nn.Module):
    """Pre-LayerNorm residual transformer block.

    Computes ``x = x + Attn(LN(x))`` followed by ``x = x + MLP(LN(x))``.

    Args:
        width: model (embedding) dimension.
        heads: number of attention heads; must divide ``width``.
        attention_mask: optional additive mask of shape (seq, seq) passed to
            the attention layer (e.g. a causal mask filled with ``-inf``
            above the diagonal).  ``None`` means full attention.
    """

    def __init__(self, width, heads, attention_mask=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(width, heads)
        self.ln1 = nn.LayerNorm(width)
        # nn.Sequential takes modules as arguments, not a list.  The layers are
        # registered by name so they can be reached as `mlp.c_fc` / `mlp.c_proj`
        # as well as by index.
        self.mlp = nn.Sequential()
        self.mlp.add_module("c_fc", nn.Linear(width, width * 4))
        self.mlp.add_module("gelu", nn.GELU())
        self.mlp.add_module("c_proj", nn.Linear(width * 4, width))
        self.ln2 = nn.LayerNorm(width)
        self.attention_mask = attention_mask

    def attention(self, x):
        """Self-attention over ``x`` of shape (seq, batch, width)."""
        attn_mask = None
        if self.attention_mask is not None:
            # Match the activations' dtype/device so the mask can be added to the scores.
            attn_mask = self.attention_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x):
        """Apply the block to ``x`` of shape (seq, batch, width); shape is preserved."""
        # Residual connections keep the input signal flowing through the stack.
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [7]:

class Transformer(nn.Module):
    """A stack of ``AttentionBlock`` layers applied one after another.

    Args:
        transformer_layers: number of blocks in the stack.
        transformer_width: model dimension of every block.
        transformer_heads: number of attention heads in every block.
        attention_mask: optional mask handed unchanged to every block.
    """

    def __init__(self, transformer_layers, transformer_width, transformer_heads, attention_mask=None):
        super().__init__()
        self.width = transformer_width
        self.layers = transformer_layers
        # nn.Sequential expects the modules as positional arguments, so the
        # list has to be unpacked; passing the list itself raises TypeError.
        self.blocks = nn.Sequential(*[
            AttentionBlock(transformer_width, transformer_heads, attention_mask)
            for _ in range(transformer_layers)
        ])

    @property
    def resblocks(self):
        """Alias of ``blocks`` (the name used when iterating blocks for initialisation)."""
        return self.blocks

    def forward(self, x):
        """Run ``x`` through every block in order and return the result."""
        return self.blocks(x)

In [9]:
class CLIP(nn.Module):
    """Contrastive image/text model: a ViT image tower and a causal text transformer.

    Args:
        embed_dim: size of the shared image/text embedding space.
        image_resolution: side length (pixels) of the square input image.
        vision_layers: number of blocks in the vision transformer.
        vision_width: model dimension of the vision transformer.
        vision_patch_size: side length (pixels) of each image patch.
        context_length: number of token positions in a text sequence.
        vocab_size: number of rows in the token embedding table.
        transformer_width: model dimension of the text transformer.
        transformer_heads: number of attention heads in the text transformer.
        transformer_layers: number of blocks in the text transformer.
    """

    def __init__(self, embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, vocab_size, transformer_width,
                 transformer_heads, transformer_layers):
        super().__init__()
        # Must be set before build_attention_mask() and pos_embedding are created.
        self.context_length = context_length

        vision_heads = vision_width // 64
        self.vision_encoder = VisionTransformer(embed_dim, image_resolution, vision_layers, vision_width,
                                                vision_heads, vision_patch_size)
        # Use Transformer (not a bare nn.Sequential) so that `.width` and
        # `.layers`, which initialize_parameters() reads, are available.
        self.text_encoder = Transformer(transformer_layers, transformer_width, transformer_heads,
                                        attention_mask=self.build_attention_mask())

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.pos_embedding = nn.Parameter(torch.empty(context_length, transformer_width))
        self.ln = nn.LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature, stored in log space; starts at log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()

    @property
    def dtype(self):
        """Floating-point dtype of the text tower, taken from the token embedding."""
        return self.token_embedding.weight.dtype

    def initialize_parameters(self):
        """Initialise the text-side embeddings, transformer weights and projection."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embedding, std=0.01)

        width = self.text_encoder.width
        proj_std = (width ** -0.5) * ((2 * self.text_encoder.layers) ** -0.5)
        attn_std = width ** -0.5
        fc_std = (2 * width) ** -0.5
        for block in self.text_encoder.blocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            # Index the MLP positionally (first / last layer) so this works
            # whether or not its layers were registered under names.
            nn.init.normal_(block.mlp[0].weight, std=fc_std)
            nn.init.normal_(block.mlp[-1].weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=width ** -0.5)

    def build_attention_mask(self):
        """Return a (context_length, context_length) additive causal mask.

        Entries strictly above the diagonal are ``-inf`` (future positions are
        hidden); the diagonal and everything below it are 0.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero the rest
        return mask

    def encode_text(self, x):
        """Embed token ids into the shared space.

        Args:
            x: integer token ids of shape (batch, context_length).

        Returns:
            Tensor of shape (batch, embed_dim).
        """
        tokens = x
        x = self.token_embedding(tokens).type(self.dtype)   # (batch, seq, width)
        x = x + self.pos_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)                              # sequence-first for attention
        x = self.text_encoder(x)
        x = x.permute(1, 0, 2)
        x = self.ln(x).type(self.dtype)
        # Pool the features at the position of the largest token id in each
        # sequence.  NOTE(review): assumes the end-of-text token has the highest
        # id in the vocabulary, as in the CLIP BPE tokenizer - confirm.
        eot_positions = tokens.argmax(dim=-1)
        batch_index = torch.arange(x.shape[0], device=x.device)
        return x[batch_index, eot_positions] @ self.text_projection

    def encode_image(self, x):
        """Embed an image batch with the vision tower."""
        return self.vision_encoder(x)

    def forward(self, text, img):
        """Return scaled cosine-similarity logits between images and texts.

        Args:
            text: token ids, shape (n_texts, context_length).
            img: image batch accepted by the vision tower, shape (n_images, ...).

        Returns:
            ``(logits_per_image, logits_per_text)`` with shapes
            (n_images, n_texts) and (n_texts, n_images).
        """
        image_features = self.encode_image(img)
        text_features = self.encode_text(text)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

In [25]:
# Fetch ViT-B-32.pt into the working directory with wget; `-c` continues a
# partially downloaded file instead of starting over.
!wget -c https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt

--2022-12-20 17:26:47--  https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
Resolving openaipublic.azureedge.net (openaipublic.azureedge.net)... 13.107.246.69, 13.107.213.69, 2620:1ec:bdf::69, ...
Connecting to openaipublic.azureedge.net (openaipublic.azureedge.net)|13.107.246.69|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 353976522 (338M) [application/octet-stream]
Saving to: ‘ViT-B-32.pt’


2022-12-20 17:26:49 (221 MB/s) - ‘ViT-B-32.pt’ saved [353976522/353976522]



In [50]:
# Delete ViT-B-2.pt from the working directory.
# NOTE(review): the wget cell saves `ViT-B-32.pt`, which is a different
# filename from the one removed here - confirm which file is meant.
!rm ViT-B-2.pt

In [51]:
# List the working directory (long format) to see which files are present.
!ls -l

total 4
drwxr-xr-x 1 root root 4096 Dec 16 21:15 sample_data


In [None]:

# Checkpoint download URLs, keyed by model name.
# The azureedge URLs are taken from the OpenAI CLIP loader:
# https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py
MODELS = {
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
    # NOTE(review): the "_jit" suffix presumably marks this file as a TorchScript archive - confirm.
    "ViT-B/32_jit": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    # Unlike the entries above, this one points at an open_clip GitHub release, not the OpenAI CDN.
    "ViT-B/32": "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"
}

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
    """Build the image preprocessing pipeline for a square input of side ``n_px``."""
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def _checkpoint_filename(url):
    """Derive a local cache filename from the last path component of ``url``."""
    from urllib.parse import urlparse

    return os.path.basename(urlparse(url).path) or "clip_checkpoint.pt"


def _download_checkpoint(url, destination):
    """Stream ``url`` to ``destination`` with a progress bar."""
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    # Write to a temporary name first so an interrupted download is never
    # mistaken for a complete checkpoint on the next run.
    partial_path = destination + ".part"
    with urllib.request.urlopen(url) as source, open(partial_path, "wb") as output:
        content_length = source.headers.get("Content-Length")
        total = int(content_length) if content_length else None  # header may be absent
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))
    os.replace(partial_path, destination)


def _extract_state_dict(checkpoint):
    """Return a plain ``{name: tensor}`` dict from a loaded checkpoint object."""
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # Drop a leading "module." from keys that carry it; other keys are left untouched.
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


def load_params(param_path, device):
    """Build a ``CLIP`` model sized from a checkpoint and load the checkpoint into it.

    Args:
        param_path: URL of the checkpoint, or a path to an existing local file.
            A URL is downloaded once into the working directory, under a
            filename taken from the URL, and reused afterwards.
        device: torch device (or device string) the model is moved to.

    Returns:
        ``(model, preprocess)``: the model in eval mode on ``device`` and the
        image transform matching the model's input resolution.

    Raises:
        RuntimeError: from ``load_state_dict`` when checkpoint keys do not
            match the model's parameter names.
    """
    if os.path.exists(param_path):
        checkpoint_path = param_path
    else:
        # One cache file per URL; a single fixed name would silently reuse
        # whichever checkpoint happened to be downloaded first.
        checkpoint_path = _checkpoint_filename(param_path)
        if not os.path.exists(checkpoint_path):
            _download_checkpoint(param_path, checkpoint_path)

    # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
    with open(checkpoint_path, 'rb') as opened_file:
        checkpoint = torch.load(opened_file, map_location="cpu")
    state_dict = _extract_state_dict(checkpoint)

    # Infer the architecture hyper-parameters from the stored tensor shapes.
    vision_width = state_dict["visual.conv1.weight"].shape[0]
    vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
    vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
    # One positional embedding per patch plus one for the class token.
    grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
    image_resolution = vision_patch_size * grid_size

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)

    # Metadata entries that are not model parameters.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    # NOTE(review): the keys read above ("visual.*", "transformer.resblocks.*",
    # "ln_final.*", "positional_embedding") differ from the attribute names the
    # CLIP class in this notebook registers ("vision_encoder", "text_encoder",
    # "ln", "pos_embedding"), so this strict load is expected to report
    # missing/unexpected keys until the keys are remapped - confirm.
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    return model, _transform(image_resolution)


# Pick the GPU when one is available, otherwise fall back to the CPU,
# then build the model and its matching image preprocessing.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = load_params(MODELS['ViT-B/32'], device)

In [None]:
# Zero-shot classification of one CIFAR-100 test image: compare the image
# embedding with the embedding of one text prompt per class name.

# Download the dataset (test split) into ~/.cache
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
# unsqueeze(0) adds the batch dimension expected by the image encoder.
image_input = preprocess(image).unsqueeze(0).to(device)
# NOTE(review): `clip` is not imported in any cell visible in this notebook, so
# `clip.tokenize` raises NameError on a fresh kernel - presumably the tokenizer
# from the OpenAI `clip` package is intended; confirm and import it.
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features (no gradients needed for inference)
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
# L2-normalise both sides so the matrix product below is a cosine similarity.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# 100.0 is a fixed scale applied to the similarities before the softmax over classes.
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")