In [None]:
!pip install transformers

In [1]:
from transformers import AutoTokenizer, BertForMaskedLM, CLIPModel, AutoProcessor, CLIPTokenizer, GPT2LMHeadModel
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

# for image captioning
# load a fine-tuned image captioning model and corresponding tokenizer and image processor
ic_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
ic_tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
ic_image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# for CLIP
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")


Some weights of the model checkpoint at nlpconnect/vit-gpt2-image-captioning were not used when initializing VisionEncoderDecoderModel: ['decoder.transformer.h.2.crossattention.masked_bias', 'decoder.transformer.h.0.crossattention.bias', 'decoder.transformer.h.6.attn.masked_bias', 'decoder.transformer.h.2.attn.bias', 'decoder.transformer.h.7.crossattention.masked_bias', 'decoder.transformer.h.3.crossattention.masked_bias', 'decoder.transformer.h.6.attn.bias', 'decoder.transformer.h.8.attn.bias', 'decoder.transformer.h.4.crossattention.masked_bias', 'decoder.transformer.h.3.attn.bias', 'decoder.transformer.h.5.attn.bias', 'decoder.transformer.h.5.crossattention.masked_bias', 'decoder.transformer.h.1.attn.masked_bias', 'decoder.transformer.h.10.crossattention.masked_bias', 'decoder.transformer.h.1.crossattention.bias', 'decoder.transformer.h.0.attn.masked_bias', 'decoder.transformer.h.9.crossattention.masked_bias', 'decoder.transformer.h.7.attn.bias', 'decoder.transformer.h.1.crossattent

In [2]:
import requests
from PIL import Image

url = "https://farm8.staticflickr.com/7020/6810252887_01e3d8e4e6_z.jpg"
image = Image.open(requests.get(url, stream=True).raw)

In [9]:
import torch.nn as nn


def compute_loss(model, pixel_values, labels, sample_weights):
    print("compute_loss", flush=True)
    output = model(pixel_values=pixel_values, labels=labels)

    criterion = nn.CrossEntropyLoss(reduction='none')

    print(output.logits.size())
    loss = criterion(
        output.logits[:, :-1, :].reshape(-1, 50257),
        labels[:, 1:].reshape(-1))

    loss = loss.reshape(output.logits.size(dim=0), -1)
    sample_weights = sample_weights.squeeze(0).unsqueeze(1).repeat(1, output.logits.size(dim=1) - 1)

    print(loss.size(), sample_weights.size(), flush=True)
    loss = loss * sample_weights
    # loss.mean().backward()

    return loss.mean()


def compute_rl_loss(model, image_processor, tokenizer, img, gt_labels, reward_fct):
    pixel_values = image_processor(img, return_tensors="pt").pixel_values

    generated_ids = model.generate(
        pixel_values,
        max_new_tokens=40,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
        num_return_sequences=3,
    )
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    print("Generated texts")
    print(generated_ids.size())
    print(generated_texts, flush=True)

    score = reward_fct(generated_texts, generated_ids, gt_labels, img)  # return tensor
    mu = score.mean()
    sample_weights = score - mu
    print(score)
    return compute_loss(model, pixel_values, generated_ids, sample_weights)

In [4]:
import torch
from PIL import Image


def compute_image_representation_from_image_instance(clip_processor, clip_model, image):
    inputs = clip_processor(images=image, return_tensors="pt")
    pixel_values = inputs['pixel_values']
    visual_outputs = clip_model.vision_model(pixel_values=pixel_values)
    image_embeds = visual_outputs[1]
    image_embeds = clip_model.visual_projection(image_embeds)  # [1 x embed_dim]
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    return image_embeds


def compute_image_text_similarity_via_embeddings(image_embeds, text_embeds, clip_model):
    text_embeds = text_embeds.view(image_embeds.shape[0], -1, text_embeds.shape[-1])
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    image_embeds = image_embeds.unsqueeze(-1)
    logit_scale = clip_model.logit_scale.exp()
    logits_per_text = torch.matmul(text_embeds, image_embeds) * logit_scale
    logits_per_image = logits_per_text.squeeze(-1)
    return logits_per_image.softmax(dim=1)  # , logits_per_image/logit_scale # batch x len(text_list)


def compute_text_representation(text_list, clip_model, clip_tokenizer):
    # text_list: a list of text
    text_inputs = clip_tokenizer(text_list, padding=True, return_tensors="pt",
                                 max_length=clip_tokenizer.max_len_single_sentence + 2, truncation=True)
    # self.tokenizer.max_len_single_sentence + 2 = 77
    input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']

    text_outputs = clip_model.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    text_embeds = text_outputs[1]
    text_embeds = clip_model.text_projection(text_embeds)
    return text_embeds


def compute_image_text_similarity_via_raw_text(image_embeds, text_list, clip_model, clip_tokenizer):
    text_embeds = compute_text_representation(text_list, clip_model, clip_tokenizer)
    return compute_image_text_similarity_via_embeddings(image_embeds, text_embeds, clip_model)

In [5]:
def reward_clip(clip_model, clip_processor, clip_tokenizer):
    def _reward_fct(generated_texts, generated_ids, gt_labels, img):
        image_embeds = compute_image_representation_from_image_instance(clip_processor, clip_model, img)

        return compute_image_text_similarity_via_raw_text(image_embeds, generated_texts, clip_model, clip_tokenizer)

    return _reward_fct



In [6]:
gt_labels = ic_tokenizer("a kitten plays with a colorful ball with a dog next to it.", return_tensors="pt")["input_ids"]

In [10]:
loss = compute_rl_loss(
    ic_model, ic_image_processor, ic_tokenizer,
    image, gt_labels,
    reward_clip(clip_model, clip_processor, clip_tokenizer))

Generated texts
torch.Size([3, 13])
['a dog and a cat playing with a ball ', 'a dog and a cat playing with a toy ', 'a dog and a cat are playing with a ball ']
tensor([[0.2021, 0.0723, 0.7257]], grad_fn=<SoftmaxBackward0>)
compute_loss
torch.Size([3, 13, 50257])
torch.Size([3, 12]) torch.Size([3, 12])


In [8]:
loss

tensor(0.1059, grad_fn=<MeanBackward0>)