In [None]:
# pip install "git+https://github.com/openai/CLIP.git"
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
perceptor, preprocess = clip.load("ViT-B/32", device=device, jit=False)
! pip install ftfy regex
! wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image

In [105]:
from pathlib import Path
from random import randint, choice


from typing import Union, List
from torch.utils.data import Dataset
from torchvision import transforms as T
from clip.simple_tokenizer import SimpleTokenizer
model, preprocess = clip.load("ViT-B/32", device='cuda', jit=False)

_tokenizer = SimpleTokenizer()
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)
    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result



class TextImageDataset(Dataset):
    def __init__(self,
                 folder,
                 text_len=77,
                 image_size=128,
                 truncate_captions=False,
                 resize_ratio=0.75,
                 shuffle=False
                 ):
        """
        @param folder: Folder containing images and text files matched by their paths' respective "stem"
        @param truncate_captions: Rather than throw an exception, captions which are too long will be truncated.
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}

        keys = (image_files.keys() & text_files.keys())

        self.keys = list(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.resize_ratio = resize_ratio
        self.tokenizer = tokenizer
        clip_mean = [0.48145466, 0.4578275, 0.40821073]
        clip_std = [0.26862954, 0.26130258, 0.27577711]
        self.image_transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop((image_size, image_size)),
            T.Normalize(mean=clip_mean, std=clip_std),

        ])

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        try:
            description = choice(descriptions)
        except IndexError as zero_captions_in_file_ex:
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        tokenized_text = tokenize(
            description,
            self.text_len,
        ).squeeze(0)
        try:
            image_tensor = T.ToTensor(self.image_transform(image_file))
        except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success
        return tokenized_text, image_tensor


In [106]:
# from clip import tokenizer

ds = TextImageDataset(
    "/home/samsepiol/Datasets/Previews/Previews10M/previews_000/",
    text_len=256,
    image_size=256,
    resize_ratio=0.8,
    truncate_captions=True,
    shuffle=False,
)
assert len(ds) > 0, 'dataset is empty'



In [107]:
from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=1, shuffle=False, drop_last=True)


In [108]:
for i, (text, images) in enumerate(dl):
    with torch.no_grad():
        text, images = map(lambda t: t.cuda(), (text, images))
        image_features = ds.image_transform(model.encode_image(images[0]))
        text_features = model.encode_text(text)
    
        logits_per_image, logits_per_text = model(image_features, text_features).cuda()
        probs = logits_per_text.softmax(dim=-1).numpy()
        print(probs)

    

TypeError: img should be PIL Image. Got <class 'pathlib.PosixPath'>