In [None]:
import os
import sys
sys.path.append("..")

import torch
from tokenizers import Tokenizer
from safetensors import safe_open

from model.siglip import SiglipModel, SiglipConfig

In [None]:
# Hyper-parameters handed to SiglipModel: one sub-dict per tower.
config: SiglipConfig = {
    # Text tower.
    "text_config": {
        "hidden_size": 1152,
        "intermediate_size": 4304,
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "vocab_size": 32000,
        "max_position_embeddings": 64,  # also reused below as the tokenizer padding length
        "attention_dropout": 0.0,
        "layer_norm_eps": 1e-6,
    },
    # Vision tower.
    "vision_config": {
        "hidden_size": 1152,
        "image_size": 384,
        "intermediate_size": 4304,
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "num_channels": 3,
        "attention_dropout": 0.0,
        "layer_norm_eps": 1e-6,
    },
}

In [None]:
# Instantiate the model from the config above and load the matching tokenizer.
model = SiglipModel(config=config)
tokenizer = Tokenizer.from_file(os.path.join('../weights/siglip/', 'tokenizer.json'))

# Every encoded prompt must come out exactly `max_position_embeddings` ids long:
# shorter prompts are padded with id 1, and longer ones are truncated so that
# batches stay rectangular and never exceed the configured sequence length.
max_text_len = config['text_config']['max_position_embeddings']
tokenizer.enable_padding(pad_id=1, length=max_text_len)
tokenizer.enable_truncation(max_length=max_text_len)

In [None]:
def get_state_dict_from_safetensors(path: str | list[str], device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.bfloat16) -> dict:
    state_dict = {}
    if isinstance(path, str): path = [path]
    if path:
        d = device.type if device.type == 'cpu' else device.index
        for p in path:
            with safe_open(p, framework="pt", device=d) as f:
                for k in f.keys(): state_dict[k] = f.get_tensor(k).to(dtype=dtype)
    else: print("No weights found.")
    return state_dict

# Read the checkpoint from disk (function defaults: CPU, bfloat16), then copy
# the tensors into the model; the load result is shown as the cell output.
sd = get_state_dict_from_safetensors(path='../weights/siglip/model.safetensors')
model.load_state_dict(state_dict=sd)

In [None]:
import torchvision.transforms as transforms
from PIL import Image
from io import BytesIO
from typing import Union
import requests

def preprocess_image(image_input: Union[str, bytes, os.PathLike, Image.Image], image_size: int = 384, timeout: float = 10.0) -> torch.Tensor:
    """Load an image and turn it into a normalised, batched tensor.

    Args:
        image_input: One of
            * an already opened ``PIL.Image.Image``,
            * raw encoded image ``bytes``,
            * an ``http://`` / ``https://`` URL (downloaded with ``requests``),
            * a local file path (``str`` or ``os.PathLike``).
        image_size: Side length of the square the image is resized to.
        timeout: Seconds to wait for the HTTP download before giving up.

    Returns:
        Tensor with a leading batch dimension of 1, produced by bicubic resize
        to ``image_size`` x ``image_size``, ``ToTensor`` and per-channel
        normalisation with mean 0.5 / std 0.5.

    Raises:
        requests.HTTPError: If a URL download returns an error status.
    """
    if isinstance(image_input, Image.Image):
        image = image_input
    elif isinstance(image_input, bytes):
        image = Image.open(BytesIO(image_input))
    elif isinstance(image_input, str) and image_input.startswith(('http://', 'https://')):
        # Bounded wait, and fail loudly on 4xx/5xx instead of handing an error
        # page to PIL. `.content` is the decoded body, unlike the raw stream.
        response = requests.get(image_input, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.open(image_input)

    # Collapse modes whose bands are not plain colour/luminance channels:
    # alpha is dropped, palette indices and CMYK are expanded to RGB.
    # Greyscale ('L') and RGB images pass through untouched.
    mode_fixups = {'RGBA': 'RGB', 'P': 'RGB', 'CMYK': 'RGB', 'LA': 'L'}
    target_mode = mode_fixups.get(image.mode)
    if target_mode is not None:
        image = image.convert(target_mode)

    # Normalisation statistics must have one entry per remaining band.
    num_channels = len(image.getbands())

    normalize_transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * num_channels, std=[0.5] * num_channels)
    ])

    tensor_image = normalize_transform(image)
    # Add the batch dimension expected by the model call sites below.
    return tensor_image.unsqueeze(0)

In [None]:
# Candidate captions to score against one image.
texts = ["a photo of earth from moon", "a photo of 2 people on moon", "2 people sitting on moon"]

# NOTE(review): the image path is absolute and machine-specific — point it at a
# local file before running this cell elsewhere.
inputs = dict(
    input_ids=torch.tensor([encoding.ids for encoding in tokenizer.encode_batch(texts)]),
    pixel_values=preprocess_image("/home/andrew264/Downloads/Screenshot_23.png"),
)

# Forward pass only; no gradients are needed.
with torch.no_grad():
    outputs = model(**inputs)

# The first model output is treated as the logits; each one is squashed
# independently with a sigmoid and row 0 is reported per caption.
logits = outputs[0]
probs = logits.sigmoid()
for i, text in enumerate(texts):
    print(f"{probs[0][i]:.1%} that image is '{text}'")

In [None]:
from datasets import load_dataset
# Second, smaller configuration used with the MNIST weights loaded below.
# Rebinds `config` and `model`, replacing the objects created earlier.
config: SiglipConfig = {
    # Text tower: 10-entry vocabulary with a single position.
    "text_config": {
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "vocab_size": 10,
        "max_position_embeddings": 1,
        "attention_dropout": 0.0,
        "layer_norm_eps": 1e-6,
    },
    # Vision tower: single-channel input.
    "vision_config": {
        "image_size": 224,
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "patch_size": 16,
        "num_channels": 1,
        "attention_dropout": 0.0,
        "layer_norm_eps": 1e-6,
    },
}

model = SiglipModel(config=config)
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(
        '../weights/siglip/siglip.pt',
        map_location='cpu',
        weights_only=True,
    )
)

# MNIST test split, fetched through the `datasets` library.
dataset = load_dataset("ylecun/mnist", split='test')


In [None]:
# Show the image of the first test sample as this cell's output.
dataset[0]['image']

In [None]:
# Ten single-token "prompts", one per id 0-9, scored against the first test image.
# NOTE(review): image_size=28 here differs from vision_config["image_size"] (224)
# — confirm this matches how the siglip.pt weights were trained.
inputs = dict(
    input_ids=torch.arange(10).unsqueeze(1),
    pixel_values=preprocess_image(dataset[0]['image'], image_size=28),
)

# Forward pass only; no gradients are needed.
with torch.no_grad():
    outputs = model(**inputs)

# The first model output is treated as the logits; row 0 is reported per id.
logits = outputs[0]
probs = logits.sigmoid()
for i, text in enumerate(range(10)):
    print(f"{probs[0][i]:.1%} that image is '{text}'")