In [None]:
import torch
import torchvision
import clip

from pathlib import Path

In [None]:
class MemeFolder:
    """Index a folder of images with CLIP so it can be searched by text.

    On construction the CLIP model is loaded and every image found by
    ``torchvision.datasets.ImageFolder`` under the folder is encoded. The
    encodings are cached in ``memery.pt`` inside that folder and read back
    from there on later runs.

    Attributes:
        clear_cache: If true, an existing cache file is deleted before indexing.
        path: The image folder as a ``Path``.
        savefile: Location of the encoding cache (``<folder>/memery.pt``).
        device: CUDA when available, otherwise CPU.
        model, preprocess: The pair returned by ``clip.load``.
        logit_scale: ``exp()`` of the model's ``logit_scale`` parameter.
        names: Image file names as strings.
        features: Tensor with one row of image encodings per entry in ``names``.
    """

    def __init__(self, folder_str, clip_model="ViT-B/32", clear_cache=False):
        """Load the CLIP model and build (or load) the encodings for a folder.

        Args:
            folder_str: Path of the image folder.
            clip_model: Model name passed to ``clip.load``.
            clear_cache: Delete any existing cache file and re-encode.
        """
        self.clear_cache = clear_cache
        self.path = Path(folder_str)
        # Defined up front so save() always has a destination.
        self.savefile = self.path / 'memery.pt'
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.preprocess = clip.load(clip_model, device=self.device)
        self.logit_scale = self.model.logit_scale.exp()

        self.names, self.features = self.process_images(self.path)

    def process_images(self, path):
        """Return ``(image_names, image_features)`` for the images under *path*.

        If ``<path>/memery.pt`` exists (and ``clear_cache`` is false) the
        encodings are loaded from it; otherwise the images are encoded with the
        CLIP model and the result is written to that file.

        The cache is only checked for existence, not for staleness: images
        added or removed since it was written are not noticed.

        TODO: Fix recursive loading to include self folder

        Args:
            path: Folder to index (``str`` or ``Path``).

        Returns:
            A tuple of a list of file names (``str``) and a tensor holding one
            encoding row per name, on ``self.device``.
        """
        savefile = Path(path) / 'memery.pt'
        # Keep save() pointed at the folder that is actually being processed.
        self.savefile = savefile

        # Only unlink when there is something to remove; a missing cache is
        # not an error when the caller merely asked for a fresh encode.
        if self.clear_cache and savefile.exists():
            savefile.unlink()

        if savefile.exists():
            # NOTE(security): torch.load unpickles the file; only point this
            # class at folders whose memery.pt you trust.
            # map_location lets a cache written on GPU load on a CPU-only host.
            save_dict = torch.load(savefile, map_location=self.device)
            image_names = list(save_dict)
            image_features = torch.stack(list(save_dict.values())).to(self.device)
            return (image_names, image_features)

        imagefiles = torchvision.datasets.ImageFolder(
            root=str(path), transform=self.preprocess
        )
        img_loader = torch.utils.data.DataLoader(
            imagefiles, batch_size=128, shuffle=False, num_workers=4
        )

        # Collect per-batch encodings and concatenate once at the end instead
        # of re-allocating the full tensor on every iteration.
        batches = []
        with torch.no_grad():
            for images, _labels in img_loader:
                # Inputs must live on the same device as the model. Encodings
                # are stored as float32 regardless of the model's precision.
                encoded = self.model.encode_image(images.to(self.device))
                batches.append(encoded.float())
        image_features = torch.cat(batches).to(self.device)

        # Names are kept as strings so they match what a cache load returns.
        image_names = [str(f[0]) for f in imagefiles.imgs]
        self.save(image_names, image_features)

        return (image_names, image_features)

    def save(self, filenames, enc_tensors):
        """Write a ``{filename: encoding}`` dictionary to ``self.savefile``.

        Args:
            filenames: Iterable of names; each is converted with ``str``.
            enc_tensors: Iterable of encodings, paired positionally with
                *filenames* (iterating a 2-D tensor yields its rows).
        """
        save_dict = {str(k): v for k, v in zip(filenames, enc_tensors)}
        torch.save(save_dict, self.savefile)

    def predict_from_text(self, query):
        """Score every indexed image against a text query.

        The query is tokenized and encoded with CLIP; image and text encodings
        are L2-normalised and their scaled cosine similarity is used as score.

        Args:
            query: Text passed to ``clip.tokenize``. NOTE(review): the sort
                below compares score tensors directly, which presumably only
                works for a single query string -- confirm before passing a list.

        Returns:
            A list of ``(name, score_tensor)`` tuples sorted by descending score.
        """
        with torch.no_grad():
            text = clip.tokenize(query).to(self.device)
            # Cast both sides to float32 so the matmul never mixes dtypes.
            text_features = self.model.encode_text(text).float()
            image_features = self.features.float()

            # Normalise into locals: self.features keeps the raw encodings.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logits_per_image = self.logit_scale * image_features @ text_features.t()

        scores = dict(zip(self.names, logits_per_image))
        top_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return top_scores

In [None]:
# Index the 'images' folder: loads CLIP, then encodes the images or reads the
# existing memery.pt cache in that folder.
small = MemeFolder('images')

In [None]:
# Score every indexed image against the text query 'cat'; the result is a list
# of (name, score) tuples sorted from highest to lowest score.
cats = small.predict_from_text('cat')

In [None]:
# Best match for the query.
cats[0]

('images/memes/wgglo1jpy4l61.jpg', tensor([27.2811], device='cuda:0'))