In [None]:
import open_clip
from collections import OrderedDict
import torch

# Build the OpenCLIP ViT-L/14 model with the 'datacomp_xl_s13b_b90k' pretrained tag.
# The call returns three values; the second is discarded and the third is kept
# as `preprocess` (presumably the train / eval transforms — confirm in open_clip docs).
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-L-14',
    pretrained='datacomp_xl_s13b_b90k',
)

In [None]:
# Show the exponentiated logit scale stored in the loaded model.
logit_scale = model.logit_scale.exp()
print(logit_scale)

In [None]:
# Show the preprocessing transform that open_clip returned together with the model.
print(preprocess)

In [None]:
# Inspect the shape of the "visual.proj" entry of the model's state dict.
visual_proj = model.state_dict()["visual.proj"]
print(visual_proj.shape)

In [None]:
# Extract the vision tower: keep only entries whose key starts with 'visual.'
# and strip that prefix so the weights can be loaded into a standalone encoder.
# A prefix test (rather than a substring test) guarantees that the slice below
# removes exactly the matched text.
prefix = 'visual.'
new_state_dict = OrderedDict(
    (k[len(prefix):], v)
    for k, v in model.state_dict().items()
    if k.startswith(prefix)
)
# NOTE(review): '{openclip_path}' is a literal placeholder — this string is not
# an f-string and no `openclip_path` variable is defined in this notebook.
# Replace it with the real checkpoint root before running.
torch.save(new_state_dict, '{openclip_path}/datacomp_xl_s13b_b90k/vit_l14.pth')

In [None]:
# Extract the text tower: keep every entry that is neither part of the vision
# tower (keys starting with 'visual.') nor a logit-scale parameter. Keys are
# stored unchanged.
new_state_dict = OrderedDict(
    (k, v)
    for k, v in model.state_dict().items()
    if not k.startswith('visual.') and 'logit_scale' not in k
)
# NOTE(review): '{openclip_path}' is a literal placeholder — this string is not
# an f-string and no `openclip_path` variable is defined in this notebook.
# Replace it with the real checkpoint root before running.
torch.save(new_state_dict, '{openclip_path}/datacomp_xl_s13b_b90k/vit_l14_text.pth')

In [None]:
# Test the extracted models

from models.backbones.clip.clip_vision import clip_joint_l14
from models.backbones.clip.clip_text import clip_text_l14

# Instantiate the project's standalone encoders.
# NOTE(review): presumably these builders load the checkpoints written by the
# extraction cells — confirm against models/backbones/clip.
# num_frames=1 matches the single frame that load_video samples by default.
visual_encoder = clip_joint_l14(num_frames=1)
# context_length=32 is passed to the text builder; presumably the maximum
# number of text tokens — TODO confirm.
text_encoder = clip_text_l14(context_length=32)

In [None]:
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import decord
import numpy as np

def load_video(path, num_frames=1):
    """Decode evenly spaced frames from a video file with decord.

    Args:
        path: Location of the video; passed straight to ``decord.VideoReader``.
        num_frames: Number of frames to sample, evenly spaced between the
            first and the last frame. The default of 1 selects only frame 0.

    Returns:
        A uint8 torch tensor holding the sampled frames, with the axes of the
        decoded batch reordered by ``permute(0, 3, 1, 2)`` (assumes decord
        yields frames as (T, H, W, C), giving (T, C, H, W)).
    """
    video_reader = decord.VideoReader(path, num_threads=1, ctx=decord.cpu(0))
    # Ask decord to hand back torch tensors from get_batch.
    decord.bridge.set_bridge('torch')
    video_len = len(video_reader)
    # `np.int` was removed in NumPy 1.24 (AttributeError); use an explicit
    # integer dtype. astype truncates the fractional linspace positions.
    frame_indices = np.linspace(0, video_len - 1, num_frames).astype(np.int64)
    video = video_reader.get_batch(frame_indices).byte()
    # Move the last axis (channels) to position 1.
    video = video.permute(0, 3, 1, 2)

    return video

# Preprocessing pipeline modelled on VindLU: resize the frames to 224x224 with
# bicubic interpolation, convert to float and divide by 255, then normalize
# each channel with the mean / std below.
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
normalize = transforms.Normalize(mean, std)
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
    type_transform,
    normalize,
])

# Decode the example clip and run it through the pipeline.
video = transform(load_video('examples/yoga.mp4'))

In [None]:
import torch.nn.functional as F

# Inference only: disable autograd so no graph is recorded while encoding and
# the resulting embeddings carry no grad_fn.
# NOTE(review): if the builders do not already return the encoders in eval
# mode, call .eval() on both before this cell — confirm against the builders.
with torch.no_grad():
    # Swap the first two axes of the frame tensor and add a leading batch
    # axis (assumes `video` is (T, C, H, W) as returned by load_video, so the
    # encoder receives (1, C, T, H, W)).
    visual_embedding = visual_encoder(video.permute(1, 0, 2, 3).unsqueeze(0))
    text = text_encoder.tokenize([
        'a woman doing yoga',
        'a woman doing yoga on the roof',
        'a man doing yoga',
        'a person doing yoga',
        'a dog running on the grass',
        'a cat sitting on the sofa',
    ])
    text_embedding = text_encoder(text)

    # L2-normalize along the last axis so dot products between the two sets
    # of embeddings are cosine similarities.
    visual_embedding = F.normalize(visual_embedding, dim=-1, p=2)
    text_embedding = F.normalize(text_embedding, dim=-1, p=2)

In [None]:
# Report both embedding shapes, then the softmax (over the last axis) of the
# video-to-text similarities after scaling the visual embedding by 100.
print(visual_embedding.shape, text_embedding.shape)
logits = (100 * visual_embedding) @ text_embedding.T
print(logits.softmax(-1))