In [None]:
# Install the CLIP package that provides the `clip` module imported below; pip's stdout is discarded.
# NOTE(review): the version is not pinned, so a later release may change behaviour.
!pip install clip-by-openai > /dev/null

In [None]:
########## Imports ##########

import os 
import datetime 
import time

from PIL import Image

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchmetrics
import pytorch_lightning as pl

import clip

import warnings
warnings.filterwarnings('ignore')


# Use the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
if dev != "cpu":
    print('Wohooo, GPU found!!')

device = torch.device(dev)

# Seed PyTorch's global RNG.
torch.manual_seed(0)

# Loading in CLIP

In [None]:
# Load the "RN50" CLIP model onto `device`; `preprocess` is the image transform returned with it.
# NOTE(review): jit=False presumably yields a regular (non-TorchScript) module — confirm against the clip docs.
CLIP, preprocess = clip.load("RN50", device=device, jit=False)
# Alternative backbone, currently disabled:
#CLIP, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
########## Decoder ##########

# Path to the BPE merges text file; SimpleTokenizer uses it as its default `bpe_path`.
tokenizer_filepath = '../input/clip-backend-resources/bpe_simple_vocab_16e6.txt'
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re

@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character lookup table used by the BPE tokenizer.

    Each of the 256 byte values is mapped to a distinct one-character string.
    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte (whitespace, control bytes,
    ...) is assigned, in increasing byte order, the next unused code point
    starting at 256, so no byte ever maps to a whitespace/control character.

    The result is cached: every call returns the same dict object.
    """
    kept = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    kept_lookup = set(kept)

    # Bytes outside the kept ranges, in ascending order.
    shifted = [byte for byte in range(2**8) if byte not in kept_lookup]

    byte_values = kept + shifted
    code_points = kept + [2**8 + offset for offset in range(len(shifted))]
    return dict(zip(byte_values, (chr(point) for point in code_points)))


def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols (a tuple of variable-length strings in the
            tokenizer). May be empty or contain a single symbol.

    Returns:
        A set of ``(left, right)`` tuples, one for every pair of neighbouring
        symbols. Empty and single-symbol words yield an empty set (an empty
        word previously raised IndexError).
    """
    # zip stops at the shorter argument, so empty / length-1 words give no pairs.
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Pass ``text`` through ``ftfy.fix_text``, HTML-unescape it twice, and strip the ends."""
    repaired = ftfy.fix_text(text)
    # Two passes so that doubly-escaped entities such as "&amp;lt;" end up as "<".
    unescaped_once = html.unescape(repaired)
    unescaped_twice = html.unescape(unescaped_once)
    return unescaped_twice.strip()
            


def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()



class SimpleTokenizer(object):
    """Byte-pair-encoding vocabulary plus helpers to merge tokens and decode token ids.

    The vocabulary is built from the byte/unicode table returned by
    ``bytes_to_unicode`` (each character once as-is and once with a ``</w>``
    suffix), one entry per merge rule read from ``bpe_path``, and the two
    special tokens ``<|startoftext|>`` and ``<|endoftext|>``.
    """

    def __init__(self, bpe_path: str = tokenizer_filepath):
        """Load the merge rules from ``bpe_path`` and build the lookup tables.

        Args:
            bpe_path: Path to a plain-text merges file, one merge per line.
                The first line is skipped (presumably a header — confirm
                against the file).
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read with an explicit encoding (the merges contain non-ASCII characters)
        # and close the handle deterministically.
        with open(bpe_path, encoding='utf-8') as bpe_file:
            merges = bpe_file.read().split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank == merge appears earlier in the file == applied first in bpe().
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
        self.sot_token = self.encoder['<|startoftext|>']
        self.eot_token = self.encoder['<|endoftext|>']

    def bpe(self, token):
        """Apply the learned merges to ``token`` and return the pieces joined by spaces.

        The last character of the token is tagged with ``</w>`` before merging.
        Results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # tuple.index found no further occurrence of `first`:
                    # copy the remainder unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def decode(self, tokens):
        """Turn an iterable of token ids into text; each ``</w>`` marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def padded_decode(self, tokens):
        """Decode one padded token sequence, dropping the start token, the end token and the padding.

        Args:
            tokens: Tensor of token ids, either 2-D with the sequence in row 0
                (the shape used by the callers in this file) or 1-D.

        Returns:
            The decoded text between the first position and the first
            ``<|endoftext|>`` token, with trailing whitespace removed.

        Raises:
            ValueError: If the sequence contains no ``<|endoftext|>`` token.
        """
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        eot_positions = (tokens[0] == self.eot_token).nonzero(as_tuple=True)[0]
        if eot_positions.numel() == 0:
            raise ValueError("padded_decode: sequence contains no <|endoftext|> token")
        # Use the first match as a plain int, so the slice below is well defined
        # even when the end token occurs more than once.
        length = int(eot_positions[0])

        # Position 0 is skipped (start token); everything from the end token on is dropped.
        token_ids = tokens[0, 1:length].cpu().tolist()

        text = self.decode(token_ids)
        return text.rstrip()
        

# Module-level tokenizer instance; later cells use it as the default `decoder` argument.
decoder = SimpleTokenizer()

# CLIP in a PyTorch Lightning Module

In [None]:
########## Making a PyTorch Lightning Module ##########

class CLIP_ZEROSHOT(pl.LightningModule):
    """LightningModule wrapper that uses a CLIP model as a zero-shot classifier.

    Args:
        classes: Optional tensor of tokenized class prompts, one row per class.
            When given, accuracy and confusion-matrix metrics are created and
            ``forward`` / ``test_step`` can be used. When ``None`` only
            ``encode_image`` and ``encode_text`` are usable.
        CLIP: The CLIP model to wrap.
        decoder: Tokenizer exposing ``padded_decode``; used to turn tokenized
            prompts back into text keys.
    """

    def __init__(self, classes = None, CLIP = CLIP, decoder = decoder ):
        super(CLIP_ZEROSHOT, self).__init__()

        self.CLIP = CLIP

        self.decoder = decoder
        self.classes = classes

        # `is not None` rather than `!= None`: `classes` is a tensor, and an
        # identity test is the unambiguous way to check for a missing argument.
        if self.classes is not None:

            self.acc = torchmetrics.Accuracy()
            # Maps decoded prompt text -> class index.
            self.classes_dict = {}
            for i, class_label in enumerate(classes):
                self.classes_dict[decoder.padded_decode(class_label.unsqueeze(0))] = i
            # Accumulated over test batches, for plotting after testing epochs.
            self.test_cm = torchmetrics.ConfusionMatrix(num_classes = len(self.classes))

    def encode_image(self, images):
        """Return CLIP image features for ``images`` without tracking gradients."""
        with torch.no_grad():
            image_features = self.CLIP.encode_image(images)

        return image_features

    def encode_text(self, texts):
        """Return CLIP text features for tokenized ``texts`` without tracking gradients."""
        with torch.no_grad():
            text_features = self.CLIP.encode_text(texts)

        return text_features

    def forward(self, images):
        """Return, for every image, the index of the best-matching class prompt.

        Raises:
            ValueError: If the module was built without ``classes``.
        """
        if self.classes is None:
            raise ValueError("CLIP_ZEROSHOT.forward requires `classes` to be set")

        with torch.no_grad():
            # Follow the images' device instead of assuming CUDA is present.
            class_tokens = self.classes.to(images.device)
            logits_per_image, _ = self.CLIP(images, class_tokens)

        # Softmax is monotonic, so the argmax of the logits is the argmax of the
        # probabilities.
        predicted_labels_indexes = torch.argmax(logits_per_image, dim = -1)
        return predicted_labels_indexes

    def test_step(self, batch, batch_idx):
        """Classify one batch, update the metrics and log the batch accuracy.

        ``batch`` must provide ``'image'`` and ``'label'``; labels are tokenized
        prompts that decode to one of the keys of ``self.classes_dict``.
        """
        images = batch['image']
        labels = batch['label']
        batch_device = labels.device

        predicted_label_indexes = self.forward(images)

        class_tokens = self.classes.to(batch_device)
        predicted_labels = torch.zeros_like(labels)

        for i in range(len(predicted_label_indexes)):
            predicted_labels[i] = class_tokens[predicted_label_indexes[i]]

        # Created on the batch's device so the metric modules receive inputs on
        # the same device as the rest of the batch.
        decoded_predicted_labels = torch.zeros(len(labels), dtype = torch.int, device = batch_device)
        decoded_labels = torch.zeros(len(labels), dtype = torch.int, device = batch_device)

        for i in range(len(predicted_labels)):
            decoded_predicted_labels[i] = self.classes_dict[self.decoder.padded_decode(predicted_labels[i])]
            decoded_labels[i] = self.classes_dict[self.decoder.padded_decode(labels[i])]

        test_acc = self.acc(decoded_predicted_labels, decoded_labels)
        self.test_cm(decoded_predicted_labels, decoded_labels) #for plotting after testing epochs
        self.log('Accuracy', test_acc, prog_bar = True, on_epoch=True, sync_dist=True)
        
        
        
# Default instance built without a class list; the plotting helpers use it as their default `model`.
model = CLIP_ZEROSHOT()

# Prompt engineering

In [None]:
########## Prompt Engine Development Cell ##########

def control_engine(text):
    """Baseline prompt engine: hand the label back untouched."""
    return text
    
    
def basic_prompt_engine(text):
    """Append the fixed suffix ", a type of animal" to the label."""
    return '{}, a type of animal'.format(text)


from transformers import pipeline

# Zero-shot text classification pipeline; no model is named here, so the library's default is used.
classifier_zero_shot = pipeline("zero-shot-classification")

def object_category_engine(text, classifier = classifier_zero_shot):
    """Build a prompt "<text>, a type of <category>" with a classifier-chosen category.

    ``classifier`` is called with the label and a fixed list of candidate
    categories; the candidate with the highest returned score is used.
    """
    candidate_categories = [
        'animal', 'food', 'fruit', 'car', 'boat', 'airplane', 'appliance',
        'electronic', 'accessory', 'furniture', 'kitchen', 'cutlery',
        'crockery', 'person', 'fish', 'instrument', 'tool',
        'sports equipment', 'vehicle', 'holy place', 'power tool',
    ]
    result = classifier(text, candidate_categories)

    best_position = np.argmax(result['scores'])
    category = result['labels'][best_position]

    return '{}, a type of {}'.format(text, category)

# Experiments

## Plotting Functions

In [None]:
########## Plotting a similarity matrix ##########

# Requires at least as many labels as images; extra labels are fine.

def plot_matrix(images, texts, model = model, decoder = decoder):
    """Plot a heat map of cosine similarities between text prompts (rows) and images (columns).

    Args:
        images: Tensor of preprocessed images, one per column.
        texts: Tensor of tokenized prompts with a singleton second dimension
            (as produced by ``text_example_producer``), one per row.
        model: Object exposing ``encode_image`` / ``encode_text`` and ``parameters()``.
        decoder: Tokenizer exposing ``padded_decode``; used for the row labels.
    """
    texts = texts.squeeze(1)

    # Run on whatever device the model weights live on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    image_features = model.encode_image(images.to(model_device)).float()
    text_features = model.encode_text(texts.to(model_device)).float()

    # The raw encoder outputs are not unit length; normalise them so the dot
    # product below really is the cosine similarity the title announces.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

    count = len(texts)

    decoded_texts = [decoder.padded_decode(texts[i].unsqueeze(0)) for i in range(count)]

    plt.figure(figsize=(20, 14))
    # Cosine values lie in [-1, 1]; colour by the real values (rounding to
    # integers would flatten the heat map to a single colour).
    plt.imshow(similarity)

    plt.yticks(range(count), decoded_texts, fontsize=18)
    plt.xticks([])

    # Draw each image as a thumbnail above its column.
    for i, image in enumerate(images):
        plt.imshow(image.permute(1, 2, 0).cpu(), extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")

    for x in range(similarity.shape[1]):
        for y in range(similarity.shape[0]):
            plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

    for side in ["left", "top", "right", "bottom"]:
        plt.gca().spines[side].set_visible(False)

    plt.xlim([-0.5, count - 0.5])
    plt.ylim([count + 0.5, -2])

    plt.title("Cosine similarity between text and image features", size=20)
    plt.show()

In [None]:
########## Plotting similarity between an image and multiple labels ##########

def plot_similarity_scores(images, texts, model = model, decoder = decoder, top_k = 5, padding = 15.0):
    """Show each image next to a bar chart of its ``top_k`` most probable prompts.

    Args:
        images: Tensor of preprocessed images.
        texts: Tensor of tokenized prompts with a singleton second dimension
            (as produced by ``text_example_producer``).
        model: Object exposing ``encode_image`` / ``encode_text`` and ``parameters()``.
        decoder: Tokenizer exposing ``padded_decode``; used for the bar labels.
        top_k: Number of prompts to show per image (capped at ``len(texts)``).
        padding: ``pad`` value forwarded to ``fig.tight_layout``.
    """
    top_k = min(top_k, len(texts))

    texts = texts.squeeze(1)
    decoded_texts = [decoder.padded_decode(texts[i].unsqueeze(0)) for i in range(len(texts))]

    # Run on whatever device the model weights live on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    image_features = model.encode_image(images.to(model_device)).float()
    text_features = model.encode_text(texts.to(model_device)).float()

    # Unit-normalise first: the x100 scaling below is meant for cosine
    # similarities; on raw features the softmax saturates.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.topk(top_k, dim=-1)

    top_probs = top_probs.cpu().numpy()
    top_labels = top_labels.cpu().numpy()

    # squeeze=False keeps `ax` two-dimensional even for a single image, so one
    # loop serves both cases.
    fig, ax = plt.subplots(len(images), 2, figsize=(15, 15), squeeze=False)
    fig.tight_layout(pad=padding)

    y = np.arange(top_probs.shape[-1])
    single_image = len(images) == 1

    for i, image in enumerate(images):
        image_ax, bar_ax = ax[i]

        image_ax.imshow(image.permute(1, 2, 0).cpu())

        bar_ax.barh(y, top_probs[i])
        # Operate on this row's axes (not plt.gca()) so every chart lists the
        # most probable prompt at the top.
        bar_ax.invert_yaxis()
        bar_ax.set_axisbelow(True)
        bar_ax.set_yticks(y)
        bar_ax.set_yticklabels([decoded_texts[index] for index in top_labels[i]])
        bar_ax.set_xlabel("probability")

        if single_image:
            # Square bar chart for the single-image layout; abs() because the
            # inverted y axis reports its limits in descending order.
            asp = abs(np.diff(bar_ax.get_xlim())[0] / np.diff(bar_ax.get_ylim())[0])
            bar_ax.set_aspect(asp)

    plt.show()

In [None]:
########## Functions to produce examples ##########

def image_example_producer(images):
    """Preprocess a sequence of images and stack them into one float32 tensor.

    Args:
        images: Non-empty sequence of images accepted by the module-level
            ``preprocess`` transform. The sequence itself is left unmodified.

    Returns:
        Tensor of shape ``(len(images), *preprocessed_image_shape)``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("image_example_producer needs at least one image")

    # Build a new list rather than overwriting the caller's entries in place.
    processed = [preprocess(image) for image in images]

    return torch.stack(processed).to(torch.float32)
    
    
def text_example_producer(texts):
    """Tokenize a sequence of text labels and stack them into one long tensor.

    Args:
        texts: Non-empty sequence of strings. The sequence itself is left
            unmodified.

    Returns:
        Long tensor of shape ``(len(texts), *tokenized_shape)`` where
        ``tokenized_shape`` is the shape returned by ``clip.tokenize`` for a
        single string.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if len(texts) == 0:
        raise ValueError("text_example_producer needs at least one text")

    # Build a new list rather than overwriting the caller's strings in place.
    tokenized = [clip.tokenize(text) for text in texts]

    return torch.stack(tokenized).to(torch.long)
        
        
def read_image_folder(filepath):
    """Load every file in ``filepath`` as a PIL image.

    Files are read in sorted name order so the result (and any seeded random
    sampling done on it) is the same on every run; ``os.listdir`` alone gives
    no ordering guarantee.

    Args:
        filepath: Directory whose entries are all image files.

    Returns:
        List of fully loaded PIL images, one per directory entry.
    """
    images = []
    for name in sorted(os.listdir(filepath)):
        # Image.open is lazy and keeps the file open; copy() forces the pixel
        # data into memory so the handle can be closed right away.
        with Image.open(os.path.join(filepath, name)) as image_file:
            images.append(image_file.copy())

    return images
    
def random_sample(list_of_images, n):
    """Draw ``n`` distinct items from ``list_of_images`` using NumPy's global RNG.

    Indexes are drawn one at a time with ``np.random.randint`` and duplicates
    are rejected, so results depend on the current global NumPy seed.

    Args:
        list_of_images: Sequence to sample from.
        n: Number of distinct items to return.

    Returns:
        List of ``n`` items in the order they were drawn (empty when ``n <= 0``).

    Raises:
        ValueError: If ``n`` exceeds ``len(list_of_images)``; without this
            check the rejection loop below could never terminate.
    """
    population = len(list_of_images)
    if n > population:
        raise ValueError(
            f"cannot sample {n} distinct items from a list of length {population}"
        )

    indexes = []
    while len(indexes) < n:
        rand = np.random.randint(0, population)
        if rand not in indexes:
            indexes.append(rand)

    return [list_of_images[i] for i in indexes]
        
        
# Seed NumPy's global RNG, which random_sample draws from.
np.random.seed(42)
        

## Reading in random images and storing random noise texts


In [None]:
# Pool of unrelated images; the experiment cells sample distractor images from it.
random_noise_images = read_image_folder('../input/small-clip-comparison-dataset/Noise and Random')


# Distractor prompts appended to most experiments' label lists.
random_noise_texts = ['Abstract Painting, a type of art', 'Bicycle, a type of vehicle', 'Chael Sonnen, a type of bad guy', 'Empire State Building, a type of building', 'Rubiks Cube, a type of toy', 'fhasfdgbeasd, a type of noise']


## Fruit Example

In [None]:
# Fruit photos from disk.
fruit_images = read_image_folder('../input/small-clip-comparison-dataset/Fruit')

# Append two sampled noise images, then preprocess everything into one tensor.
# Note: `fruit_images` is rebound from a list of PIL images to a tensor here.
fruit_images = image_example_producer(fruit_images + random_sample(random_noise_images, 2))


# Each fruit appears as a bare label and with the ", a type of fruit" suffix.
fruit_texts = ['Avocado', 'Avocado, a type of fruit', 'orange', 'Orange, a type of fruit', 'Apple', 'Apple, a type of fruit', 'Banana', 'Banana, a type of fruit']
fruit_texts_tensor = text_example_producer(fruit_texts + random_noise_texts)
plot_similarity_scores(fruit_images, fruit_texts_tensor, padding = 15)
plot_matrix(fruit_images, fruit_texts_tensor)

## Orange Example

In [None]:
# Orange photos from disk.
orange_images = read_image_folder('../input/small-clip-comparison-dataset/Oranges')

# Append two sampled noise images, then preprocess everything into one tensor.
orange_images = image_example_producer(orange_images + random_sample(random_noise_images, 2))

# Bare label versus the same label with a category suffix.
orange_texts = ['Orange', 'Orange, a type of fruit']

orange_texts_tensor = text_example_producer(orange_texts + random_noise_texts)
plot_similarity_scores(orange_images, orange_texts_tensor, padding = 2)
plot_matrix(orange_images, orange_texts_tensor)

## Satellite Images Example

In [None]:
# Satellite photos from disk; no noise images or noise prompts are added in this experiment.
satellite_images = read_image_folder('../input/small-clip-comparison-dataset/Satelite')

# NOTE(review): the tensor is stored under `satelite_images` (single "l"),
# a different name from the list above — looks like a typo; confirm before renaming.
satelite_images = image_example_producer(satellite_images)

# The same prompt with and without the word "Satellite".
satellite_texts = ['Picture of a buckingham palace', 'Satellite picture of a buckingham palace']
satellite_texts_tensor = text_example_producer(satellite_texts)

plot_similarity_scores(satelite_images, satellite_texts_tensor)
plot_matrix(satelite_images, satellite_texts_tensor)

## Vehicle Example 

In [None]:
# Vehicle photos from disk.
vehicle_images = read_image_folder('../input/small-clip-comparison-dataset/Vehicles')

# Append two sampled noise images, then preprocess everything into one tensor.
vehicle_images = image_example_producer(vehicle_images + random_sample(random_noise_images, 2))

# Labels at several levels of specificity: bare, ", a type of vehicle", and a narrower category.
boat_texts = ['Yacht', 'Yacht, a type of vehicle', 'Yacht, a type of boat', 'boat', 'rowboat', 'rowboat, a type of boat', 'rowboat, a type of vehicle' ]
car_texts = ['Ferrari', 'Ferrari, a type of car', 'Ferrari, a type of vehicle', 'Jetta', 'Jetta, a type of car', 'Jetta, a type of vehicle']
other_texts = ['Boeing 747', 'Boeing 747, a type of vehicle', 'Boeing 747, a type of airplane', 'Tank', 'Tank, a type of vehicle']

vehicle_texts = car_texts + boat_texts + other_texts

# Unlike most other experiments, the noise prompts are not appended here.
vehicle_texts_tensor = text_example_producer(vehicle_texts)

plot_similarity_scores(vehicle_images, vehicle_texts_tensor)
plot_matrix(vehicle_images, vehicle_texts_tensor)

## Computer and Laptop Example


In [None]:
# Computer and laptop photos from disk.
pc_and_laptop_images = read_image_folder('../input/small-clip-comparison-dataset/Computers and Laptops')

# Append two sampled noise images, then preprocess everything into one tensor.
pc_and_laptop_images = image_example_producer(pc_and_laptop_images + random_sample(random_noise_images, 2))

# Product names bare, as "a type of laptop", and as "a type of computer".
pc_and_laptop_texts = ['Macbook', 'Macbook, a type of laptop', 'Macbook, a type of computer', 'Lenovo Ideapad', 'Lenovo Ideapad, a type of laptop', 'Lenovo Ideapad, a type of computer', 'A desktop computer']
pc_and_laptop_texts_tensor = text_example_producer(pc_and_laptop_texts + random_noise_texts)

plot_similarity_scores(pc_and_laptop_images, pc_and_laptop_texts_tensor, padding = 5)
plot_matrix(pc_and_laptop_images, pc_and_laptop_texts_tensor)

## Accessory Example

In [None]:
# Accessory photos from disk.
accessory_images = read_image_folder('../input/small-clip-comparison-dataset/Accessory')

# Append two sampled noise images, then preprocess everything into one tensor.
accessory_images = image_example_producer(accessory_images + random_sample(random_noise_images, 2))

# Each label appears bare and with the ", a type of accessory" suffix.
accessory_texts = ['Backpack', 'Backpack, a type of accessory', 'waist bag', 'waist bag, a type of accessory', 'fanny pack', 'fanny pack, a type of accessory', 'belt bag', 'belt bag, a type of accessory', 'Watch', 'Watch, a type of accessory', 'Sunglasses', 'Sunglasses, a type of accessory', ]
accessory_texts_tensor = text_example_producer(accessory_texts + random_noise_texts)

plot_similarity_scores(accessory_images, accessory_texts_tensor, padding = 2)
plot_matrix(accessory_images, accessory_texts_tensor)

## Fannypack Example

In [None]:
# Fanny-pack photos from disk.
fannypack_images = read_image_folder('../input/small-clip-comparison-dataset/FannyPack')

# Append two sampled noise images, then preprocess everything into one tensor.
fannypack_images = image_example_producer(fannypack_images + random_sample(random_noise_images, 2))

# Three names for the bag (bare and suffixed) plus one extra unrelated prompt.
fannypack_texts = ['waist bag', 'waist bag, a type of accessory', 'fanny pack', 'fanny pack, a type of accessory', 'belt bag', 'belt bag, a type of accessory', 'Zelda, a type of video game']
fannypack_texts_tensor = text_example_producer(fannypack_texts + random_noise_texts)

plot_similarity_scores(fannypack_images, fannypack_texts_tensor, padding = 2)
plot_matrix(fannypack_images, fannypack_texts_tensor)

## Food Example

In [None]:
# Food photos from disk.
food_images = read_image_folder('../input/small-clip-comparison-dataset/Food')

# Append two sampled noise images, then preprocess everything into one tensor.
food_images = image_example_producer(food_images + random_sample(random_noise_images, 2))

# Each label appears bare and with the ", a type of food" suffix.
food_texts = ['Candy Apple', 'Candy Apple, a type of food', 'Pizza', 'Pizza, a type of food', 'Cake', 'Cake, a type of food', 'Hotdog', 'Hotdog, a type of food']
food_texts_tensor = text_example_producer(food_texts + random_noise_texts)

plot_similarity_scores(food_images, food_texts_tensor, padding = 2)
plot_matrix(food_images, food_texts_tensor)

## People and Action Figures Example

In [None]:
# Photos of people and action figures from disk.
people_and_action_figures_images = read_image_folder('../input/small-clip-comparison-dataset/People and action figures')

# Append two sampled noise images, then preprocess everything into one tensor.
people_and_action_figures_images = image_example_producer(people_and_action_figures_images + random_sample(random_noise_images, 2))

# Character names bare and as "an action figure"; "Man"/"Woman" bare and as "a type of human".
people_and_action_figures_texts = ['Superman', 'Superman, an action figure', 'Ironman', 'Ironman, an action figure', 'Man', 'Man, a type of human', 'Woman', 'Woman, a type of human']
people_and_action_figures_texts_tensor = text_example_producer(people_and_action_figures_texts + random_noise_texts)

plot_similarity_scores(people_and_action_figures_images, people_and_action_figures_texts_tensor, padding = 2.0)
plot_matrix(people_and_action_figures_images, people_and_action_figures_texts_tensor)

## Weapons Example

In [None]:
# Weapon photos from disk.
weapons_images = read_image_folder('../input/small-clip-comparison-dataset/Weapons')

# Append two sampled noise images, then preprocess everything into one tensor.
weapons_images = image_example_producer(weapons_images + random_sample(random_noise_images, 2))

# Each label appears bare and with a category suffix ("weapon" or "toy").
weapons_texts = ['Pistol', 'Pistol, a type of weapon', 'Rifle', 'Rifle, a type of weapon', 'Nerf gun', 'Nerf gun, a type of toy']
weapons_texts_tensor = text_example_producer(weapons_texts + random_noise_texts)

plot_similarity_scores(weapons_images, weapons_texts_tensor, padding = 2.0)
plot_matrix(weapons_images, weapons_texts_tensor)