In [9]:
import json
import os

import torch
from PIL import Image

import cloob.clip as clip
import cloob.zeroshot_data as zeroshot_data
from cloob.clip import _transform
from cloob.model import CLIPGeneral

In [3]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device is ", device)

Device is  cpu


In [8]:
# Load CLOOB: read the checkpoint, build the architecture from the JSON config
# file named inside the checkpoint, then restore the stored weights.
checkpoint_path = './checkpoints/cloob_rn50_yfcc_epoch_28.pt'
configs_path = './cloob/model_configs/'

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_config_file = os.path.join(configs_path, checkpoint['model_config_file'])

print('Loading model from', model_config_file)
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not os.path.exists(model_config_file):
    raise FileNotFoundError(f'Model config file not found: {model_config_file}')
with open(model_config_file, 'r') as f:
    model_info = json.load(f)
model = CLIPGeneral(**model_info)
# Evaluation-time preprocessing, sized to the visual encoder's input resolution.
preprocess = _transform(model.visual.input_resolution, is_train=False)

# Without CUDA the model stays on the CPU and is cast to float32; with CUDA it
# is moved to the selected device.
if not torch.cuda.is_available():
    model.float()
else:
    model.to(device)

# Strip the 'module.' prefix only from keys that actually carry it (presumably
# left by a DataParallel/DistributedDataParallel wrapper at save time -- confirm).
# Slicing unconditionally would corrupt the keys of an unwrapped checkpoint.
module_prefix = 'module.'
sd = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}
# Dropped before loading; presumably CLIPGeneral has no such parameter and the
# default strict load would reject it as an unexpected key -- confirm.
sd.pop('logit_scale_hopfield', None)
model.load_state_dict(sd)
model.eval()

Loading model from ./cloob/model_configs/RN50.json


CLIPGeneral(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    

In [53]:
# Score one GTSRB test image against one class name using prompt ensembling.
image_paths = [
    "./data/gtsrb/test/025/10553.png",
    "./data/gtsrb/test/009/08086.png",
    "./data/gtsrb/test/014/00621.png",
    "./data/gtsrb/test/001/11095.png",
    "./data/gtsrb/test/022/03702.png",
    "./data/gtsrb/test/018/10780.png",
    "./data/gtsrb/test/013/05310.png",
    "./data/gtsrb/test/004/02686.png",
]
classnames = zeroshot_data.gtsrb_classnames
prompt_templates = zeroshot_data.gtsrb_templates

# image_paths[2] lives in the "014" directory and is paired with classnames[14];
# presumably the directory name is the class index -- confirm.
image_path = image_paths[2]
classname = classnames[14]

# Image processing: preprocess, add a batch dimension, encode, L2-normalise.
with torch.no_grad(), Image.open(image_path) as im:
    image = preprocess(im).to(device)
    image_embedding = model.encode_image(image.unsqueeze(0))
    image_embedding /= image_embedding.norm(dim=-1, keepdim=True)

# Text processing: encode every prompt template filled with the class name,
# normalise each embedding, average them, and normalise the mean again.
with torch.no_grad():
    texts = [template(classname) for template in prompt_templates]
    texts = clip.tokenize(texts).to(device)
    class_embeddings = model.encode_text(texts)
    class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
    class_embedding = class_embeddings.mean(dim=0)
    class_embedding /= class_embedding.norm()

# We scale by 30 just so the difference between scores is more stark.
similarity_scale = 30
similarity = (image_embedding @ class_embedding) * similarity_scale
# Move to host memory first: Tensor.numpy() raises TypeError on a CUDA tensor,
# and .cpu() is a no-op when the tensor is already on the CPU.
similarity.cpu().numpy()[0]

5.4933786