In [1]:
import torch
from networks import SCLIPNN
import clip
from PIL import Image
from sentence_transformers import SentenceTransformer

We can now test what happens when a Spanish sentence is embedded with a multilingual SBERT model, mapped into CLIP's embedding space, and then compared against an image embedded by CLIP. As a first step, test a few arbitrary example images. For example, pair one Spanish sentence with several images: some that match the sentence well, some that match it less well, and some that do not match it at all.

In [2]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Spanish probe captions:
#   "a white dog", "a big animal", "the girl with blue hair".
spanish_sentences = [
    "un perro blanco",
    "un animal grande",
    "la muchacha de pelo azul",
]

# Multilingual sentence-transformers encoder used to embed the Spanish captions.
sbert_model = SentenceTransformer('distiluse-base-multilingual-cased-v1')

In [4]:
# Load CLIP (ViT-B/32) along with the image preprocessing transform that matches it.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Preprocess the test image, then add a leading batch dimension and move it to the device.
poodle_tensor = preprocess(Image.open("imgs/poodle.jpg"))
image = poodle_tensor.unsqueeze(0).to(device)

# Tokenize the Spanish captions with CLIP's own tokenizer (one row of token IDs per sentence).
text = clip.tokenize(spanish_sentences).to(device)

In [5]:
# Snapshot of the first five token IDs of caption 0, for comparison with the "Text After" print.
first_tokens = text[0][:5]
print("Text Before: Type: {}, Shape: {}, Text: {}".format(type(text), text.shape, first_tokens))

Text Before: Type: <class 'torch.Tensor'>, Shape: torch.Size([3, 77]), Text: tensor([49406,  2271,   703,  2795, 26801], device='cuda:0')


In [6]:
with torch.no_grad():
    # Embed the image and the tokenized Spanish captions with CLIP.
    image_features = clip_model.encode_image(image)
    # sentence-transformers returns a NumPy array here; wrap it as a (CPU) torch tensor.
    sbert_features = torch.from_numpy(sbert_model.encode(spanish_sentences))
    clip_features = clip_model.encode_text(text)
    # Sanity check: these token IDs should match the "Text Before" print,
    # i.e. encode_text must not have modified `text` in place.
    print("Text After: Type: {}, Shape: {}, Text: {}".format(type(text), text.shape, text[0][:5]))
    # Baseline: CLIP scores the raw Spanish captions directly against the image.
    logits_per_image, logits_per_text = clip_model(image, text)
    # Softmax over the captions gives one probability per sentence for this image.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# The SBERT embedding width is the input size of SCLIPNN
# (presumably the SBERT -> CLIP mapping network described above).
input_size = sbert_features.shape[1]
print(input_size)
PATH = "models/best_model.pt"
# NOTE(review): 850 is SCLIPNN's second constructor argument (presumably a hidden
# layer size); it must match the architecture stored in the checkpoint — TODO confirm.
model = SCLIPNN(input_size, 850)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine;
# load_state_dict copies the weights into the model's existing parameters, so the
# model (not moved to `device` in this cell) ends up with the same weights either way.
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
model.eval()
# TODO: the mapping network is loaded but not applied to sbert_features in this cell.
# The exact probabilities depend on the image and the captions.
print("Label probs:", probs)

Text After: Type: <class 'torch.Tensor'>, Shape: torch.Size([3, 77]), Text: tensor([49406,  2271,   703,  2795, 26801], device='cuda:0')
512
Label probs: [[0.967   0.02347 0.00963]]
