In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import Normalize
import torch.nn.functional as F
import open_clip
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

from helper import count_parameters

In [None]:
class CLIP(nn.Module):
    """Frozen open_clip model that returns L2-normalised image/text embeddings.

    The wrapped model is put in eval mode with gradients disabled on all of
    its parameters, so this module is a fixed feature extractor.

    Args:
        name: Model name handed to ``open_clip.create_model``.
        pretrained: Pretrained tag handed to ``open_clip.create_model``.
    """

    def __init__(self, name='ViT-L/14', pretrained='openai'):
        super().__init__()
        self.model = open_clip.create_model(name, pretrained=pretrained)
        # Inference only: eval mode, and no parameter receives gradients.
        self.model = self.model.eval().requires_grad_(False)
        # First entry of the visual tower's image_size; used as the resize target.
        self.img_resolution = self.model.visual.image_size[0]
        self.norm = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)
        # NOTE(review): both dims are read from the text tower's final LayerNorm;
        # this assumes the image embedding has the same width -- confirm for
        # architectures other than the default.
        self.im_dim = self.txt_dim = self.model.ln_final.normalized_shape[0]

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the wrapped model."""
        return next(self.model.parameters()).device

    def encode_images(self, images: torch.Tensor, div255: bool = False) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            images: 4-D image batch (bicubic interpolation requires 4-D input);
                presumably (batch, 3, H, W) -- TODO confirm channel layout.
            div255: If True, cast to float32 and divide by 255 before resizing.

        Returns:
            Image features, L2-normalised along the last dimension.
        """
        if div255:
            images = images.to(torch.float32) / 255
        # An int `size` is applied to every spatial dimension.
        images = F.interpolate(images, size=self.img_resolution, mode="bicubic", align_corners=False)
        images = self.norm(images)
        image_features = self.model.encode_image(images)
        image_features = F.normalize(image_features, dim=-1)
        return image_features

    def encode_texts(self, texts: list[str]) -> torch.Tensor:
        """Tokenize and embed ``texts``.

        Tokens are moved to the model's device before encoding.

        Returns:
            Text features, L2-normalised along the last dimension.
        """
        text = open_clip.tokenize(texts).to(self.device)
        text_features = self.model.encode_text(text)
        text_features = F.normalize(text_features, dim=-1)
        return text_features

    def encode_text(self, texts: list[str]) -> torch.Tensor:
        """Alias of ``encode_texts`` kept for callers using the singular name."""
        return self.encode_texts(texts)

    def forward(self, images: torch.Tensor, texts: list[str], div255: bool = False) -> torch.Tensor:
        """Embed paired images and texts and concatenate them feature-wise.

        Args:
            images: Image batch, forwarded to ``encode_images``.
            texts: One string per image.
            div255: Forwarded to ``encode_images``.

        Returns:
            Image features followed by text features, concatenated along dim 1.

        Raises:
            AssertionError: If ``images`` and ``texts`` differ in length.
        """
        assert len(images) == len(texts), "images and texts must have the same length"
        image_features = self.encode_images(images, div255=div255)
        # Was `self.encode_text`, which is not the name of the text encoder above.
        text_features = self.encode_texts(texts)
        joint_features = torch.cat([image_features, text_features], 1)
        return joint_features

In [3]:
# Smoke test: embed two random 224x224 RGB-shaped tensors with the default model.
clip = CLIP()
random_batch = torch.rand(2, 3, 224, 224)
image_features = clip.encode_images(random_batch)
image_features



tensor([[-0.0132,  0.0293,  0.0270,  ...,  0.0142, -0.0109, -0.0343],
        [-0.0096,  0.0299,  0.0264,  ...,  0.0153, -0.0137, -0.0324]])

In [5]:
# The text encoder defined on CLIP is `encode_texts` (plural).
text_features = clip.encode_texts(["hi"])
text_features.shape

  return torch._native_multi_head_attention(


torch.Size([1, 768])

In [11]:
# Input size stored on the visual tower; CLIP.__init__ reads element [0] of it.
clip.model.visual.image_size

(224, 224)

In [6]:
# `joint_features` was never assigned in any cell, so this failed on a fresh
# kernel. Build it from the image batch embedded earlier plus two prompts.
paired_text_features = clip.encode_texts(["hi", "hello"])
joint_features = torch.cat([image_features, paired_text_features], dim=1)
joint_features.shape

torch.Size([2, 1536])

In [7]:
# Parameter count of the full wrapper (image + text towers).
count_parameters(clip)
print("---")
# Destructive: removes the visual tower from `clip` in place. After this cell,
# clip.encode_images and any cell reading clip.model.visual will fail until
# `clip` is re-created.
del clip.model.visual
# Parameter count of what is left once the visual tower is gone.
count_parameters(clip)

Total parameters: 427,616,513
Trainable parameters: 0
---
Total parameters: 123,650,305
Trainable parameters: 0


In [32]:
# Load the same checkpoint outside the CLIP wrapper and freeze it for inspection.
model = open_clip.create_model("ViT-L/14", pretrained='openai')
model.eval()
model.requires_grad_(False)
img_resolution = model.visual.image_size[0]




In [33]:
# First entry of model.visual.image_size, as assigned in the previous cell.
img_resolution

224

In [29]:
# Device of the model's first parameter (same lookup as the CLIP.device property).
next(model.parameters()).device

device(type='cpu')

In [25]:
# Width of the text tower's final LayerNorm; CLIP.__init__ uses it for im_dim/txt_dim.
model.ln_final.normalized_shape[0]

768