In [None]:
import os
import torch
from torchvision.datasets import CIFAR100
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

In [None]:
class ResidualAttentionBlock(torch.nn.Module):
    """Placeholder for a residual attention block (not implemented yet).

    Subclasses ``torch.nn.Module`` so that instances are callable
    (``block(x)`` dispatches to ``forward``) and can later register
    parameters and sub-modules. Until the body is written, ``forward``
    returns ``None``.
    """

    def __init__(self):
        # nn.Module bookkeeping must be initialised before any attribute is set.
        super().__init__()

    def forward(self, x):
        """Process ``x``. Not implemented: currently returns ``None``."""
        return None

In [None]:
class VisionTransformer(torch.nn.Module):
    """Placeholder for the image encoder (not implemented yet).

    Subclasses ``torch.nn.Module`` so that an instance can be invoked as
    ``encoder(img)`` and can later own parameters that ``load_state_dict``
    fills in. Until the body is written, ``forward`` returns ``None``.

    Args:
        ViTconfig: Configuration for the encoder. Currently unused.
    """

    def __init__(self, ViTconfig):
        # nn.Module bookkeeping must be initialised before any attribute is set.
        super().__init__()

    def forward(self, x):
        """Encode ``x``. Not implemented: currently returns ``None``."""
        return None

In [None]:
class Transformer(torch.nn.Module):
    """Placeholder for the text encoder (not implemented yet).

    Subclasses ``torch.nn.Module`` so that an instance can be invoked as
    ``encoder(text)`` and can later own parameters that ``load_state_dict``
    fills in. Until the body is written, ``forward`` returns ``None``.

    Args:
        text_config: Configuration for the encoder. Currently unused.
    """

    def __init__(self, text_config):
        # nn.Module bookkeeping must be initialised before any attribute is set.
        super().__init__()

    def forward(self, x):
        """Encode ``x``. Not implemented: currently returns ``None``."""
        return None

In [None]:
class CLIP(torch.nn.Module):
    """Two-tower model pairing a vision encoder with a text encoder.

    Subclasses ``torch.nn.Module`` because the loader in this notebook calls
    ``load_state_dict`` and ``eval`` on the instance.

    Args:
        clip_config: Configuration for the whole model. It is forwarded
            unchanged to both encoders.
            NOTE(review): if the encoders are meant to receive separate
            vision/text sub-configs, split ``clip_config`` here — the intended
            schema is not visible in this notebook.
    """

    def __init__(self, clip_config):
        super().__init__()
        # Both encoder constructors take one required config argument;
        # calling them with no arguments raised TypeError.
        self.vision_encoder = VisionTransformer(clip_config)
        self.text_encoder = Transformer(clip_config)

    def forward(self, text, img):
        """Run both encoders and return their outputs.

        The encoders are invoked as callables, so they must implement
        ``__call__`` (for example by subclassing ``torch.nn.Module``).

        Args:
            text: Input handed to the text encoder.
            img: Input handed to the vision encoder.

        Returns:
            Tuple ``(vision_output, text_output)`` — note the order is the
            reverse of the argument order.
        """
        vision_output = self.vision_encoder(img)
        text_output = self.text_encoder(text)
        return vision_output, text_output

In [None]:

# Model name -> checkpoint URL. Reference:
# https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py
MODELS = dict([
    (
        "ViT-B/16",
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    ),
    (
        "ViT-L/14",
        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    ),
    (
        "ViT-L/14@336px",
        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
    ),
])

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _infer_clip_config(state_dict):
    """Derive the model hyperparameters from the shapes stored in a checkpoint.

    The key names and formulas mirror ``build_model`` (ViT branch) in the
    upstream OpenAI CLIP repository.
    NOTE(review): confirm they match the checkpoint you load; a checkpoint
    with a different layout raises ``KeyError`` here.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        dict with the keys ``embed_dim``, ``image_resolution``,
        ``vision_layers``, ``vision_width``, ``vision_patch_size``,
        ``context_length``, ``vocab_size``, ``transformer_width``,
        ``transformer_heads`` and ``transformer_layers``.
    """
    patch_embedding = state_dict["visual.conv1.weight"]
    vision_patch_size = patch_embedding.shape[-1]
    # Upstream formula: the positional embedding has grid_size**2 + 1 rows.
    grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
    transformer_width = state_dict["ln_final.weight"].shape[0]

    return {
        "embed_dim": state_dict["text_projection"].shape[1],
        "image_resolution": vision_patch_size * grid_size,
        "vision_layers": sum(
            1
            for key in state_dict
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        ),
        "vision_width": patch_embedding.shape[0],
        "vision_patch_size": vision_patch_size,
        "context_length": state_dict["positional_embedding"].shape[0],
        "vocab_size": state_dict["token_embedding.weight"].shape[0],
        "transformer_width": transformer_width,
        # Upstream derives the head count as width // 64.
        "transformer_heads": transformer_width // 64,
        # Keys look like "transformer.resblocks.<index>.<...>"; count distinct indices.
        "transformer_layers": len({
            key.split(".")[2]
            for key in state_dict
            if key.startswith("transformer.resblocks")
        }),
    }


def load_params(param_path, device):
    """Build a ``CLIP`` model from a checkpoint file and return it with its preprocessing.

    Args:
        param_path: Path of a local checkpoint file: either a TorchScript
            archive or a pickled state dict.
        device: Device the model is moved to before it is returned.

    Returns:
        Tuple ``(model, preprocess)``: the model in eval mode on ``device``,
        and a torchvision ``Compose`` that resizes, center-crops, converts to
        RGB, converts to a tensor and normalises an image for the model's
        input resolution.

    Raises:
        FileNotFoundError: If ``param_path`` does not exist.
        KeyError: If the checkpoint lacks a tensor needed to infer the
            hyperparameters.
    """
    def _transform(n_px):
        # Normalize takes (per-channel mean, per-channel std).
        return Compose([
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    # Checkpoints are binary; opening in text mode breaks deserialisation.
    with open(param_path, "rb") as f:
        try:
            # TorchScript archive: take the weights from the scripted module.
            state_dict = torch.jit.load(f, map_location="cpu").state_dict()
        except RuntimeError:
            # Not a TorchScript archive; fall back to a pickled state dict.
            # NOTE: torch.load unpickles the file, which can execute arbitrary
            # code. Only load checkpoints from a trusted source.
            f.seek(0)
            state_dict = torch.load(f, map_location="cpu")

    # CLIP takes a single config object, so the inferred values are passed as one dict.
    clip_config = _infer_clip_config(state_dict)
    model = CLIP(clip_config)

    # These entries are metadata rather than parameters, so drop them before loading.
    for key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(key, None)

    model.load_state_dict(state_dict)
    # Weights are deserialised on CPU; move them to where the caller puts its inputs.
    model = model.to(device).eval()

    return model, _transform(clip_config["image_resolution"])


# Load the model
# "ViT-B/32" is not a key of MODELS, so use one of the checkpoints that is listed.
MODEL_NAME = "ViT-B/16"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load_params() opens a local file, so fetch the checkpoint into a cache directory first.
checkpoint_url = MODELS[MODEL_NAME]
checkpoint_dir = os.path.expanduser("~/.cache/clip")
checkpoint_path = os.path.join(checkpoint_dir, os.path.basename(checkpoint_url))
if not os.path.isfile(checkpoint_path):
    os.makedirs(checkpoint_dir, exist_ok=True)
    # The URL segment before the file name is used as the expected SHA-256 of the
    # download, as the upstream CLIP downloader does.
    torch.hub.download_url_to_file(
        checkpoint_url,
        checkpoint_path,
        hash_prefix=checkpoint_url.split("/")[-2],
    )

model, preprocess = load_params(checkpoint_path, device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
# NOTE(review): `clip` is not imported or defined in the visible cells, so this
# line raises NameError until a tokenizer is provided.
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
# NOTE(review): the CLIP class in this notebook defines only forward(text, img);
# encode_image / encode_text still have to be implemented.
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
# Unit-normalise so the matrix product below is a cosine similarity.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")