In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import open_clip
import torch
from PIL import Image

# Single source of truth for the checkpoint so the model and tokenizer always match.
MODEL_ID = "hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# `_` discards the training-time transform; `preprocess` is the inference transform
# applied to PIL images in the cells below.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_ID, device="cuda")
tokenizer = open_clip.get_tokenizer(MODEL_ID)

# Inference only. The trailing `;` stops Jupyter from echoing the full module repr.
model.eval();

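# ELSA dataset

Before exploring the dataset, sanity-check the CLIP model by zero-shot classifying local cat images against three text labels.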

In [None]:
img = Image.open("../../img/cat.png")

image = preprocess(img).unsqueeze(0).cuda()
labels = ["a diagram", "a dog", "a cat"]
text = tokenizer(labels).cuda()

display(img)

# Inference only: no autograd graph, mixed precision on the GPU.
# `torch.autocast("cuda")` replaces the deprecated `torch.cuda.amp.autocast()`.
with torch.no_grad(), torch.autocast("cuda"):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalise both embeddings so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scale similarities by 100.0 (as in the open_clip README example) before softmax over labels.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# One row for the image, one column per entry of `labels`. For a cat image the
# largest probability is expected in the last column ("a cat").
print("Label probs:", text_probs)

In [None]:
img = Image.open("../../img/cat.jpg")

image = preprocess(img).unsqueeze(0).cuda()
labels = ["a diagram", "a dog", "a cat"]
text = tokenizer(labels).cuda()

display(img)

# Same zero-shot check as for the PNG, on the JPEG version of the image.
# `torch.autocast("cuda")` replaces the deprecated `torch.cuda.amp.autocast()`.
with torch.no_grad(), torch.autocast("cuda"):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalise both embeddings so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scale similarities by 100.0 (as in the open_clip README example) before softmax over labels.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# One row for the image, one column per entry of `labels`. For a cat image the
# largest probability is expected in the last column ("a cat").
print("Label probs:", text_probs)

## Loading the dataset

In [None]:
from datasets import load_dataset

# Stream the training split instead of downloading it up front. The result is
# iterated example by example in the cells below.
dataset = load_dataset(
    "elsaEU/ELSA_D3",
    split="train",
    streaming=True,
)

## Converting the dataset to PyTorch format

In [None]:
# Peek at the first streamed example and embed its first generated image to
# check the shape of the CLIP image embedding.
e = next(iter(dataset))
img = preprocess(e["image_gen0"]).unsqueeze(0).cuda()
# Inference only: don't build an autograd graph through the image encoder.
with torch.no_grad():
    img_features = model.encode_image(img)
display(img_features.shape)
display(img_features.flatten().shape)


In [None]:
# Re-wrap the stream so array-like columns (e.g. the generated images plotted
# below) are yielded as torch tensors.
ds = dataset.with_format(type="torch")

## Inspecting one sample from the dataset

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import requests
import sys
sys.path.append("../tools/")
from utils import plot_tensor

# Take the first sample of the torch-formatted stream and show its fields.
e = next(iter(ds))
display(e)

# Real image, downloaded from the sample's source URL.
# The timeout keeps a dead host from hanging the kernel; `with` closes the connection.
with requests.get(e["url"], stream=True, timeout=30) as response:
    response.raise_for_status()
    response.raw.decode_content = True  # undo gzip/deflate transfer encoding if the server applied it
    real_img = np.asarray(Image.open(response.raw))  # np.asarray forces a full decode while the stream is open

fig, ax = plt.subplots()
ax.imshow(real_img)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Real image")
plt.show()

# Generated images: columns image_gen0 .. image_gen3, one per subplot.
base_name = "image_gen"
fig, axs = plt.subplots(2, 2)
for i, ax in enumerate(axs.flatten()):
    fig.sca(ax)  # assumes plot_tensor draws on the current axes — TODO confirm in ../tools/utils.py
    plot_tensor(e[base_name + str(i)])
fig.suptitle("Generated images")
fig.tight_layout()
plt.show()