In [73]:
from rich import print

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the distilgpt2 tokenizer and reuse its end-of-sequence token as the
# padding token, so batched calls with padding=True have a pad token to use.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

# add_cross_attention=True adds cross-attention layers to the decoder; the
# load warning printed by this cell reports those weights as newly initialized.
model = AutoModelForCausalLM.from_pretrained(
    "distilgpt2",
    add_cross_attention=True,
)
print(model.config)


Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.5.crossattention.c_attn.weight', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.5.crossattention.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.2.crossattention.masked_bias', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.0.crossattention.masked_bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.3.crossattention.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.ln_cross_a

In [75]:
# Toy batch of input sentences used to exercise the decoder below.
sample_sequence = [
    "I love deep learning and transformers.",
    "Wow, look at that beautiful sunset.",
    "I wonder what is going to happen next.",
]

def generate_qa_tokens(input_sequence_list):
    """Split each string at its character midpoint and tokenize both halves.

    The first half of every string is treated as the "question" text and the
    second half as the "answer" text. Each group is tokenized as its own batch
    with the notebook-level ``tokenizer`` (padding and truncation enabled,
    PyTorch tensors returned).

    Args:
        input_sequence_list: Iterable of strings to split.

    Returns:
        A ``(question_tokens, answer_tokens)`` tuple of tokenizer outputs.
    """
    first_halves, second_halves = [], []
    for text in input_sequence_list:
        midpoint = len(text) // 2
        first_halves.append(text[:midpoint])
        second_halves.append(text[midpoint:])

    # Both batches are encoded with identical settings; each is padded
    # independently of the other.
    encode_kwargs = dict(padding=True, truncation=True, return_tensors="pt")
    question_batch = tokenizer(first_halves, **encode_kwargs)
    answer_batch = tokenizer(second_halves, **encode_kwargs)
    return question_batch, answer_batch
        
question_encoder_tokens, answer_encoder_tokens = generate_qa_tokens(sample_sequence)

# Random tensor standing in for encoder output fed to the decoder's
# cross-attention. Batch size and feature width are derived from the data and
# the model config instead of being hard-coded, so the cell keeps working if
# `sample_sequence` or the checkpoint changes. 10 is an arbitrary number of
# encoder positions.
# NOTE(review): the labels come from a batch that was padded separately from
# the question batch; the loss presumably needs both to share the same
# sequence length - confirm for other inputs.
outputs = model(
    **question_encoder_tokens,
    encoder_hidden_states=torch.randn(
        len(sample_sequence), 10, model.config.hidden_size
    ),
    labels=answer_encoder_tokens["input_ids"],
)
loss = outputs.loss
logits = outputs.logits

print(loss, logits.shape)

In [80]:
from typing import Dict, Optional

import torch.nn as nn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class SimpleVQATransformer(nn.Module):
    """Visual question answering model: image + text encoders feeding a GPT-2 decoder.

    The image and question are encoded separately, their feature vectors are
    concatenated and linearly projected, and the result is passed to a
    ``distilgpt2`` causal language model (loaded with cross-attention enabled)
    as a single-position ``encoder_hidden_states`` sequence.

    Args:
        image_encoder: Module whose output is indexed with ``"image_features"``.
        image_encoder_transforms: Preprocessing for images; returned unchanged
            by :meth:`get_transforms`.
        image_encoder_num_features: Width of the image feature vector.
        text_encoder: Module whose output is indexed with ``"text_features"``.
        text_encoder_num_features: Width of the text feature vector.
        text_encoder_transforms: Preprocessing for encoder-side text; returned
            unchanged by :meth:`get_transforms`.
    """

    def __init__(
        self,
        image_encoder: nn.Module,
        image_encoder_transforms: nn.Module,
        image_encoder_num_features: int,
        text_encoder: nn.Module,
        text_encoder_num_features: int,
        text_encoder_transforms: nn.Module,
    ):
        super().__init__()

        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        self.image_encoder_num_features = image_encoder_num_features
        self.text_encoder_num_features = text_encoder_num_features

        self.image_encoder_transforms = image_encoder_transforms
        self.text_encoder_transforms = text_encoder_transforms

        self.text_decoder_tokenizer = AutoTokenizer.from_pretrained(
            "distilgpt2"
        )
        # Use this tokenizer's own EOS token as the pad token. The previous
        # code read a notebook-global `tokenizer`, which made the class depend
        # on an earlier cell having been executed.
        self.text_decoder_tokenizer.pad_token = (
            self.text_decoder_tokenizer.eos_token
        )

        self.text_decoder = AutoModelForCausalLM.from_pretrained(
            "distilgpt2", add_cross_attention=True
        )

        # The projected vector is consumed by the decoder's cross-attention as
        # `encoder_hidden_states`, so it must have the decoder's hidden width
        # (not the vocabulary size).
        self.combine_embeddings_linear = nn.Linear(
            image_encoder_num_features + text_encoder_num_features,
            self.text_decoder.config.hidden_size,
        )

    def forward(
        self,
        input_dict: Optional[Dict] = None,
        image_encoder_tokens: Optional[torch.Tensor] = None,
        question_encoder_tokens: Optional[torch.Tensor] = None,
        question_decoder_tokens: Optional[Dict[str, torch.Tensor]] = None,
        answer_decoder_tokens: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the encoders, fuse their features, and call the text decoder.

        Args:
            input_dict: Optional mapping; when given, its ``"image_tokens"``
                and ``"question_tokens"`` entries override the corresponding
                encoder arguments.
            image_encoder_tokens: Input passed to the image encoder.
            question_encoder_tokens: Input passed to the text encoder.
            question_decoder_tokens: Mapping of keyword arguments unpacked
                into the decoder call (e.g. tokenizer output). Required.
            answer_decoder_tokens: Optional mapping whose ``"input_ids"``
                entry is passed to the decoder as ``labels``.

        Returns:
            Whatever the text decoder returns.

        Raises:
            ValueError: If ``question_decoder_tokens`` is not provided.
        """
        if input_dict is not None:
            image_encoder_tokens = input_dict["image_tokens"]
            question_encoder_tokens = input_dict["question_tokens"]

        if question_decoder_tokens is None:
            # Fail early with a clear message instead of a `**None` TypeError.
            raise ValueError(
                "question_decoder_tokens must be provided to run the text decoder"
            )

        image_embeddings = self.image_encoder(image_encoder_tokens)[
            "image_features"
        ]
        question_text_embeddings = self.text_encoder(question_encoder_tokens)[
            "text_features"
        ]

        # Concatenate along the feature axis, project, then add a length-1
        # axis at position 1 so the decoder receives a 3-D tensor.
        concat_embeddings = torch.cat(
            [image_embeddings, question_text_embeddings], dim=1
        )
        combine_embeddings = self.combine_embeddings_linear(concat_embeddings)
        combine_embeddings = combine_embeddings.unsqueeze(1)

        # Branch on the *parameter* `answer_decoder_tokens`; the previous code
        # tested an unrelated name (`answer_encoder_tokens`) that is not
        # defined in this scope.
        if answer_decoder_tokens is not None:
            return self.text_decoder(
                **question_decoder_tokens,
                encoder_hidden_states=combine_embeddings,
                labels=answer_decoder_tokens["input_ids"],
            )
        return self.text_decoder(
            **question_decoder_tokens,
            encoder_hidden_states=combine_embeddings,
        )

    def get_transforms(self):
        """Return the preprocessing callables keyed by the component they feed.

        Returns:
            Dict with keys ``"text_decoder"`` (calls the decoder tokenizer on
            its argument), ``"image_encoder"`` and ``"text_encoder"`` (the
            transforms supplied at construction time).
        """
        return {
            "text_decoder": lambda x: self.text_decoder_tokenizer(x),
            "image_encoder": self.image_encoder_transforms,
            "text_encoder": self.text_encoder_transforms,
        }


In [81]:
from gate.models.backbones.clip import CLIPAdapter


# Build the CLIP backbone; its vision and text sub-models serve as the two
# encoders of the VQA model, and its transforms as their preprocessing.
backbone_model = CLIPAdapter(
    model_name="openai/clip-vit-base-patch16",
    pretrained=True,
)
clip_transforms = backbone_model.get_transforms()

vqa_model = SimpleVQATransformer(
    image_encoder=backbone_model.vision_model,
    image_encoder_transforms=clip_transforms["image"],
    image_encoder_num_features=512,
    text_encoder=backbone_model.text_model,
    text_encoder_num_features=512,
    text_encoder_transforms=clip_transforms["text"],
)
vqa_transforms = vqa_model.get_transforms()

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.5.crossattention.c_attn.weight', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.5.crossattention.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.2.crossattention.masked_bias', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.0.crossattention.masked_bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.3.crossattention.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.ln_cross_a

In [88]:
from collections import defaultdict
from typing import Optional
from urllib.request import urlopen
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput
import timm
import PIL.Image as Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
vqa_model.to(device)

# Download the sample image; decode it while the HTTP response is still open
# so the connection can be closed afterwards.
with urlopen(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
) as response:
    img = Image.open(response)
    img.load()

images = [img, img, img]
questions = ["What is in the image?", "Who is in the image?", "What is the weather like?"]
answers = ["beignets", "a cat", "sunny"]

# Encoder inputs: preprocess each sample, then stack into one batch.
encoder_images = [vqa_transforms["image_encoder"](image) for image in images]
encoder_questions = [vqa_transforms["text_encoder"](question) for question in questions]

encoder_images = torch.stack(encoder_images).to(device)
encoder_questions = torch.stack(encoder_questions).to(device)

# Decoder inputs: the model unpacks `question_decoder_tokens` as keyword
# arguments and reads `answer_decoder_tokens["input_ids"]`, so both must be
# mappings of batched tensors rather than lists of per-sample encodings.
# Questions and answers are tokenized in a single call so they are padded to
# the same sequence length, then split back into the two groups.
decoder_batch = vqa_model.text_decoder_tokenizer(
    questions + answers, padding=True, return_tensors="pt"
)
num_questions = len(questions)
decoder_questions = {
    key: value[:num_questions].to(device) for key, value in decoder_batch.items()
}
decoder_answers = {
    key: value[num_questions:].to(device) for key, value in decoder_batch.items()
}

output = vqa_model(
    image_encoder_tokens=encoder_images,
    question_encoder_tokens=encoder_questions,
    question_decoder_tokens=decoder_questions,
    answer_decoder_tokens=decoder_answers,
)




: 

: 