# Assignment 02: Validation of CLIP

#### Note: The pre-trained model and the dataset are too large to upload.

In [1]:
import clip
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
from sentence_transformers import SentenceTransformer, util
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup, VisionEncoderDecoderModel, GPT2TokenizerFast, ViTImageProcessor
from tqdm import tqdm, trange
import skimage.io as io
import PIL.Image
from IPython.display import Image
import pandas as pd
from PIL import Image
import io
import urllib.parse as parse
import matplotlib.pyplot as plt
import requests

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Basic settings for CLIP: short type aliases used in annotations, plus the
# location of the trained captioning weights.
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
# Device aliases: `D` is the device type, `CPU` a ready-made CPU device
# (e.g. usable as `map_location` when loading a checkpoint).
D = torch.device
CPU = torch.device('cpu')

# Relative path of the captioning checkpoint.
model_path = os.path.join('pretrained_models', 'model_weights.pt')

In [3]:
class MLP(nn.Module):
    """Stack of linear layers with an activation between consecutive layers.

    ``sizes`` lists the layer widths: ``sizes[k]`` inputs feed ``sizes[k + 1]``
    outputs. An ``act()`` module follows every linear layer except the last.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        final_index = len(sizes) - 2
        stack = []
        for index, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(fan_in, fan_out, bias=bias))
            # No activation after the output layer.
            if index < final_index:
                stack.append(act())
        self.model = nn.Sequential(*stack)

    def forward(self, x: T) -> T:
        """Apply the layer stack to ``x``."""
        return self.model(x)


class ClipCaptionModel(nn.Module):
    """GPT-2 language model conditioned on a projected prefix embedding.

    ``clip_project`` maps a ``prefix_size``-wide vector to ``prefix_length``
    vectors of GPT-2's embedding width; these are placed in front of the
    caption token embeddings before the sequence is passed to GPT-2.
    """

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        """Load GPT-2 from ``./gpt2_tokenizer`` and build the prefix projection.

        Args:
            prefix_length: number of prefix positions fed to GPT-2.
            prefix_size: width of the incoming prefix vector.
        """
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained("./gpt2_tokenizer")
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_size = self.gpt_embedding_size * prefix_length
        if prefix_length > 10:  # not enough memory
            # Long prefixes: a single linear layer keeps the projection small.
            self.clip_project = nn.Linear(prefix_size, projected_size)
        else:
            self.clip_project = MLP((prefix_size, projected_size // 2, projected_size))

    # The annotation uses torch.device directly; the former `D` alias is not
    # defined anywhere in this file and raised NameError at class creation.
    def get_dummy_token(self, batch_size: int, device: torch.device) -> T:
        """Return a ``(batch_size, prefix_length)`` int64 tensor of zeros on ``device``."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        """Run GPT-2 on the projected prefix followed by the token embeddings.

        Args:
            tokens: caption token ids, embedded with GPT-2's ``wte`` table.
            prefix: prefix vectors passed through ``clip_project``.
            mask: optional attention mask forwarded to GPT-2.
            labels: when not ``None``, its value is replaced by zero-valued
                dummy tokens for the prefix positions followed by ``tokens``,
                so the labels line up with the concatenated input.

        Returns:
            The output object returned by the GPT-2 model.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

class ClipCaptionPrefix(ClipCaptionModel):
    """ClipCaptionModel variant that exposes only the prefix projection's parameters."""

    def train(self, mode: bool = True):
        """Set the training mode as usual, then put the GPT sub-module back in eval mode."""
        super().train(mode)
        self.gpt.eval()
        return self

    def parameters(self, recurse: bool = True):
        """Return the parameters of ``clip_project`` only; ``recurse`` is not used."""
        projection = self.clip_project
        return projection.parameters()

In [4]:
def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of generated tokens
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
):
    """Generate text token by token with ``model.gpt`` and return the first entry.

    The starting context is ``embed`` when given; otherwise it is the embedding
    of ``tokens``, or of the encoded ``prompt`` when ``tokens`` is ``None``.
    Generation of an entry stops after ``entry_length`` tokens or as soon as the
    first token id of ``stop_token`` is produced.

    Args:
        model: object exposing ``gpt`` (called with ``inputs_embeds``), its
            ``gpt.transformer.wte`` embedding table, and ``parameters()``.
        tokenizer: object providing ``encode`` and ``decode``.
        tokens: optional ``(1, n)`` tensor of token ids; generated ids are
            appended to it in the decoded output.
        prompt: text encoded into ``tokens`` when neither ``embed`` nor
            ``tokens`` is given.
        embed: optional embedding tensor used directly as the initial context.
        entry_count: number of entries to generate (only the first is returned).
        entry_length: maximum number of generated tokens per entry.
        top_p: cumulative-probability threshold of the nucleus filter.
        temperature: divisor applied to the logits; values <= 0 fall back to 1.
        stop_token: text whose first token id ends an entry.

    Returns:
        The decoded text of the first generated entry.

    Raises:
        ValueError: if none of ``embed``, ``tokens`` and ``prompt`` is given.
        IndexError: if ``entry_count`` is less than 1 (nothing was generated).
    """
    if embed is None and tokens is None and prompt is None:
        raise ValueError("generate2 requires at least one of embed, tokens or prompt")

    model.eval()
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device
    scale = temperature if temperature > 0 else 1.0

    with torch.no_grad():
        if embed is None and tokens is None:
            tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

        for _ in trange(entry_count):
            # Every entry starts from the caller's tokens; previously the ids
            # generated for one entry leaked into the next one.
            entry_tokens = tokens
            if embed is not None:
                generated = embed
            else:
                generated = model.gpt.transformer.wte(entry_tokens)

            for _step in range(entry_length):
                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits[:, -1, :] / scale

                # Nucleus filter: mask every token that lies beyond the
                # cumulative-probability threshold, always keeping the top one.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                # NOTE: the next token is chosen by argmax and the top-ranked
                # token is never masked, so (barring exact ties) the filter
                # does not change which token is picked.
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if entry_tokens is None:
                    entry_tokens = next_token
                else:
                    entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            # reshape(-1) always yields a 1-D id list; squeeze() collapsed a
            # single generated token to a 0-d array that list() cannot iterate.
            if entry_tokens is None:
                token_ids = []
            else:
                token_ids = entry_tokens.reshape(-1).cpu().tolist()
            generated_list.append(tokenizer.decode(token_ids))

    return generated_list[0]

In [6]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-B/32 encoder with its image preprocessing transform, and the
# GPT-2 tokenizer stored locally under ./gpt2_tokenizer.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("./gpt2_tokenizer")

In [7]:
prefix_length = 10

model = ClipCaptionModel(prefix_length)

print(model_path)
# Deserialize the checkpoint onto the CPU first (the previous `CPU` name was
# never defined in this file); the model is moved to `device` below.
# strict=False tolerates keys that do not match between checkpoint and model.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)

model = model.eval()
model = model.to(device)

pretrained_models/model_weights.pt


In [8]:
# Load the fine-tuned image-captioning baseline from the local "./ViT-GPT"
# directory: the encoder-decoder model, its tokenizer and its image processor.
finetuned_model = VisionEncoderDecoderModel.from_pretrained("./ViT-GPT")
finetuned_tokenizer = GPT2TokenizerFast.from_pretrained("./ViT-GPT")
finetuned_image_processor = ViTImageProcessor.from_pretrained("./ViT-GPT")

In [9]:
# a function to perform inference
def get_caption(model, image_processor, tokenizer, image):
    """Return the caption that ``model`` generates for ``image``.

    Args:
        model: captioning model providing ``generate``.
        image_processor: callable turning ``image`` into model inputs; called
            with ``return_tensors="pt"`` and expected to return an object with
            a ``to`` method that can be unpacked as keyword arguments.
        tokenizer: tokenizer providing ``batch_decode``.
        image: the image to caption.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    # Send the inputs to wherever the model lives rather than relying on a
    # notebook-level `device` global that may be missing or out of sync.
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = next(model.parameters()).device
    # preprocess the image
    inputs = image_processor(image, return_tensors="pt").to(target_device)
    # generate the caption (using greedy decoding by default)
    output = model.generate(**inputs)
    # decode the output
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

In [10]:
# Evaluation data and the sentence-embedding model used for scoring.
df = pd.read_parquet('COCO.parquet')
sentence_model = SentenceTransformer('./sentence_transformer')

image_size = df.shape[0]
image_data, label_data = df['image'], df['sentences_raw']

# Per-image mean similarity scores, filled by the evaluation loop.
baseline_scores, clip_scores = [], []

In [None]:
# Move the baseline captioner to the target device once, not on every image.
finetuned_model.to(device)

# Walk the image and reference-caption columns in lockstep; iterating the
# values avoids label-based `series[i]` lookups that assume a 0..n-1 index.
for image_record, label in zip(image_data, label_data):
    image = Image.open(io.BytesIO(image_record['bytes']))
    # Both captioners are fed three-channel input, so convert every non-RGB
    # mode (grayscale, palette, RGBA, ...), not only 'L'.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Baseline caption from the fine-tuned ViT-GPT model.
    caption_baseline = get_caption(finetuned_model, finetuned_image_processor, finetuned_tokenizer, image)

    # CLIP-prefix caption: encode the image with CLIP, project it to the
    # GPT-2 prefix and decode from that prefix.
    clip_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(clip_input).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    caption_clip = generate2(model, tokenizer, embed=prefix_embed)

    # Embed both captions and all reference captions in one batch each, then
    # average each caption's cosine similarity over the references.
    caption_embeddings = sentence_model.encode([caption_baseline, caption_clip], convert_to_tensor=True)
    reference_embeddings = sentence_model.encode(list(label), convert_to_tensor=True)
    mean_similarity = util.cos_sim(caption_embeddings, reference_embeddings).mean(dim=1)

    # Keep the (1, 1) tensor shape of the original accumulation so later cells
    # can still call `.item()` on each stored score.
    baseline_scores.append(mean_similarity[0].reshape(1, 1))
    clip_scores.append(mean_similarity[1].reshape(1, 1))

# Dataset-level averages of the per-image scores.
baseline_value = sum(baseline_scores) / len(baseline_scores)
clip_value = sum(clip_scores) / len(clip_scores)

In [None]:
# Report the dataset-level mean similarity of each captioner to the reference captions.
print("Baseline_value: ", baseline_value)
print("Clip_value: ", clip_value)

In [23]:
# Number of per-image baseline scores collected so far.
len(baseline_scores)

6455

In [26]:
# Convert each stored score to a plain Python float. float() accepts both
# single-element tensors and floats, so re-running this cell is harmless
# (calling .item() on an already-converted float raised AttributeError).
clip_scores = [float(score) for score in clip_scores]

In [43]:
# Convert each stored score to a plain Python float. float() accepts both
# single-element tensors and floats, so re-running this cell is harmless
# (calling .item() on an already-converted float raised AttributeError).
baseline_scores = [float(score) for score in baseline_scores]

In [45]:
# Index of the first image with the lowest CLIP-caption score; min() finds it
# directly instead of building and sorting a set of all scores.
clip_scores.index(min(clip_scores))

2732

In [46]:
# Index of the first image whose CLIP-caption score equals the second-smallest distinct score.
clip_scores.index(sorted(set(clip_scores))[1])

2725

In [47]:
# Index of the first image with the lowest baseline-caption score; min() finds
# it directly instead of building and sorting a set of all scores.
baseline_scores.index(min(baseline_scores))

1057

In [48]:
# Index of the first image whose baseline-caption score equals the second-smallest distinct score.
baseline_scores.index(sorted(set(baseline_scores))[1])

3312