In [9]:
import requests
import torch
from PIL import Image
from torch import nn
from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Number of learned soft-prompt tokens; passed to SoftEmbedding further down.
n_tokens = 20

# Sample image used to smoke-test the image encoder.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout keeps a stalled download from hanging the kernel, and
# raise_for_status() turns an HTTP error into an explicit exception instead of
# letting PIL fail later on an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

#Not contained in the __init__.py so I needed to recreate it here
class BeitPooler(nn.Module):
    """Reduce a sequence of hidden states to one vector per sample.

    When ``config.use_mean_pooling`` is truthy, every token except the first
    is averaged and the mean is layer-normalised. Otherwise the hidden state
    of the first token is returned unchanged.
    """

    def __init__(self, config):
        super().__init__()
        if config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layernorm = None

    def forward(self, hidden_states):
        """Pool ``hidden_states`` (indexed as batch, token, feature) over the token axis."""
        # Guard clause: no layer norm means "take the first token as-is".
        if self.layernorm is None:
            return hidden_states[:, 0]
        # Skip the first token, average the remaining ones, then normalise.
        mean_of_patches = hidden_states[:, 1:, :].mean(1)
        return self.layernorm(mean_of_patches)

# Set up the BEiT backbone and attach the pooling layer to it.
# `c` is the pretrained checkpoint name (or path) handed to `from_pretrained`.
class BeitForCaptioning(nn.Module):
    """BEiT image encoder that returns one pooled embedding per image.

    The masked-image-modelling checkpoint is loaded only to obtain its BEiT
    backbone; a freshly constructed ``BeitPooler`` is attached to the backbone
    so that its output exposes ``pooler_output``.

    Args:
        c (str): pretrained checkpoint name or path passed to
            ``from_pretrained`` for both the feature extractor and the model.
    """

    def __init__(self, c):
        super().__init__()
        self.feature_extractor = BeitFeatureExtractor.from_pretrained(c)
        # Keep only the backbone of the masked-image-modelling model.
        self.model = BeitForMaskedImageModeling.from_pretrained(c).beit
        self.pooler = BeitPooler(self.model.config)
        self.model.pooler = self.pooler

    def forward(self, image_batch):
        """Encode ``image_batch`` and return the backbone's pooled output.

        Args:
            image_batch: image or list of images accepted by the feature
                extractor.

        Returns:
            The ``pooler_output`` tensor produced by the backbone.
        """
        # Use the argument, not the notebook-level `image` global, so the
        # module encodes whatever the caller actually passes in.
        inputs = self.feature_extractor(images=image_batch, return_tensors="pt")
        # The feature extractor does not know where the backbone lives; move
        # its tensors to the backbone's device before the forward pass.
        device = next(self.model.parameters()).device
        inputs = {name: value.to(device) for name, value in inputs.items()}
        return self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
    
# Build the BEiT-large image encoder and smoke-test it on the sample image.
# The recorded output below shows the pooled embedding's shape for one image.
image_model = BeitForCaptioning("microsoft/beit-large-patch16-224-pt22k")
print(image_model(image).shape)

torch.Size([1, 1024])


In [2]:
# Load the GPT-2 language model and its matching tokenizer from the same checkpoint.
lm_model = GPT2LMHeadModel.from_pretrained("gpt2")
lm_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [3]:
class SoftEmbedding(nn.Module):
    """Word embedding wrapper that prepends a learned soft prompt.

    The first ``n_tokens`` positions of the incoming token ids are treated as
    placeholders: they are dropped and replaced by ``learned_embedding``.
    Optional extra embeddings (e.g. image features) are inserted between the
    soft prompt and the embedded text.
    """

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True):
        """
        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of soft-prompt tokens. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialized from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initializes from the first rows of the vocab embedding. Defaults to True.
        """
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte, n_tokens, random_range, initialize_from_vocab))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True):
        """Build the initial value of the soft prompt.

        Args:
            same as __init__
        Returns:
            torch.Tensor: ``(n_tokens, embedding_dim)`` tensor, either a detached
            copy of the first ``n_tokens`` rows of ``wte`` or uniform noise in
            ``[-random_range, random_range]``.
        Raises:
            ValueError: if initializing from the vocab and ``wte`` has fewer
                than ``n_tokens`` rows.
        """
        weight = wte.weight
        if initialize_from_vocab:
            # Slicing past the vocab would silently give a shorter prompt than
            # forward() assumes, so fail loudly instead.
            if n_tokens > weight.size(0):
                raise ValueError(
                    f"n_tokens ({n_tokens}) exceeds the vocabulary size ({weight.size(0)})")
            return weight[:n_tokens].clone().detach()
        # Match the wrapped embedding's dtype/device so the prompt can be
        # concatenated with its output without a conversion.
        return torch.empty(n_tokens, weight.size(1),
                           dtype=weight.dtype,
                           device=weight.device).uniform_(-random_range, random_range)

    def forward(self, tokens, extra_token_embeddings=None):
        """run forward pass
        Args:
            tokens (torch.long): input tokens before encoding; the first
                ``n_tokens`` columns are placeholders and are discarded
            extra_token_embeddings (torch.float): Used for image or video
                captioning. Either ``(batch, k, dim)`` or a pooled
                ``(batch, dim)`` tensor, which is treated as a single token.
        Returns:
            torch.float: learned prompt, then the extra embeddings (if any),
            then the embedded text, concatenated along the sequence axis
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        batch_size = input_embedding.size(0)
        # expand() broadcasts the prompt over the batch without copying;
        # torch.cat below materialises the result.
        learned_embedding = self.learned_embedding.unsqueeze(0).expand(batch_size, -1, -1)
        if extra_token_embeddings is None:
            return torch.cat([learned_embedding, input_embedding], 1)
        if extra_token_embeddings.dim() == 2:
            # Pooled features carry no sequence axis; give them one so they
            # line up with the other (batch, seq, dim) pieces.
            extra_token_embeddings = extra_token_embeddings.unsqueeze(1)
        return torch.cat([learned_embedding, extra_token_embeddings, input_embedding], 1)


In [4]:
# Wrap GPT-2's input embedding with a soft prompt seeded from the vocabulary.
s_wte = SoftEmbedding(
    wte=lm_model.get_input_embeddings(),
    n_tokens=n_tokens,
    initialize_from_vocab=True,
)

In [7]:
# Encode the sample image again and print the shape of its pooled embedding.
input_image = image
pooled_embedding = image_model(input_image)
print(pooled_embedding.shape)


#TODO: Add caption token BEFORE adding soft prompting

class captioning_model(nn.Module):
    """Image captioner: BEiT image features plus a soft prompt, decoded by a frozen GPT-2.

    The sequence fed to the language model is
    ``[soft prompt (n_tokens)] [image token(s)] [caption tokens]``.
    Only the image encoder, the image projection and the soft prompt are
    trainable; the language model's own parameters are frozen.

    Args:
        img_model (str): checkpoint name/path for ``BeitForCaptioning``.
        lm (str): checkpoint name/path for the GPT-2 tokenizer and model.
        n_tokens (int): number of learned soft-prompt tokens.
    """

    def __init__(self,
                 img_model="microsoft/beit-large-patch16-224-pt22k", lm="gpt2",
                 n_tokens=20):
        super().__init__()
        self.embed = BeitForCaptioning(img_model)
        self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm)
        self.lm_model = GPT2LMHeadModel.from_pretrained(lm)
        # Batched captions need padding; fall back to EOS when the tokenizer
        # defines no pad token of its own.
        if self.lm_tokenizer.pad_token is None:
            self.lm_tokenizer.pad_token = self.lm_tokenizer.eos_token

        word_embeddings = self.lm_model.get_input_embeddings()
        self.s_wte = SoftEmbedding(word_embeddings,
                                   n_tokens=n_tokens,
                                   initialize_from_vocab=True)

        # Image features and word embeddings generally differ in width; they
        # must match before they can be concatenated along the sequence axis.
        image_dim = self.embed.model.config.hidden_size
        text_dim = word_embeddings.embedding_dim
        self.image_projection = (nn.Linear(image_dim, text_dim)
                                 if image_dim != text_dim else nn.Identity())

        # Freeze the language model. We only want to tune the image model and the prefix.
        for param in self.lm_model.parameters():
            param.requires_grad = False
        # Index of the first position that carries a caption label
        # (soft prompt + one pooled image token).
        self.loss_start = n_tokens + 1

    def _build_inputs(self, images, captions):
        """Tokenize ``captions``, encode ``images`` and assemble the LM inputs.

        Returns:
            tuple: ``(inputs_embeds, attention_mask, input_ids, text_mask)``
            where the first two cover the full sequence and the last two cover
            the caption tokens only.
        """
        device = next(self.lm_model.parameters()).device
        toks = self.lm_tokenizer(captions, return_tensors="pt", padding=True)
        input_ids = toks['input_ids'].to(device)
        text_mask = toks['attention_mask'].to(device)
        batch_size = input_ids.size(0)

        # SoftEmbedding discards the first n_tokens ids and substitutes the
        # learned prompt, so prepend that many placeholders (never embedded).
        placeholder_ids = torch.zeros((batch_size, self.s_wte.n_tokens),
                                      dtype=input_ids.dtype, device=device)
        padded_ids = torch.cat([placeholder_ids, input_ids], 1)

        # ~ bs x hdim pooled features, projected to the LM width and given a
        # sequence axis so they act as a single token per image.
        image_embeddings = self.image_projection(self.embed(images).to(device))
        if image_embeddings.dim() == 2:
            image_embeddings = image_embeddings.unsqueeze(1)
        inputs_embeds = self.s_wte(padded_ids, image_embeddings)

        # Everything in front of the caption (prompt + image) is always attended to.
        prefix_len = inputs_embeds.size(1) - input_ids.size(1)
        prefix_mask = torch.ones((batch_size, prefix_len),
                                 dtype=text_mask.dtype, device=device)
        attention_mask = torch.cat([prefix_mask, text_mask], 1)
        return inputs_embeds, attention_mask, input_ids, text_mask

    def forward(self, images, captions=None):
        """Run the language model over the image-conditioned sequence.

        Captions effectively act as "labels".

        Args:
            images (list PIL): List of images (a single image is also accepted).
            captions (list str): List of strings representing captions. Prepend [CAP] token. Can eventually be expanded
                to multiple kinds of captions. When omitted, a bare '[CAP]'
                prompt is used per image and no loss is computed.

        Returns:
            The output of the wrapped language model (includes the loss when
            ``captions`` is given).
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        has_labels = captions is not None
        if not has_labels:
            # If we aren't using labels, just append a caption token.
            captions = ['[CAP]'] * len(images)

        inputs_embeds, attention_mask, input_ids, text_mask = self._build_inputs(images, captions)

        labels = None
        if has_labels:
            # We do not want to perform loss on the prefix or image toks, nor
            # on caption padding; -100 marks positions to ignore.
            prefix_len = inputs_embeds.size(1) - input_ids.size(1)
            ignored = torch.full((input_ids.size(0), prefix_len), -100,
                                 dtype=input_ids.dtype, device=input_ids.device)
            caption_labels = input_ids.masked_fill(text_mask == 0, -100)
            labels = torch.cat([ignored, caption_labels], 1)

        return self.lm_model(inputs_embeds=inputs_embeds,
                             attention_mask=attention_mask,
                             labels=labels)

    # Backward-compatible alias for the original (misspelled) method name.
    foward = forward

    @torch.no_grad()
    def generate(self, images, **generate_kwargs):
        """Generate caption token ids for ``images``.

        Args:
            images (list PIL): List of images (a single image is also accepted).
            **generate_kwargs: forwarded to the language model's ``generate``.

        Returns:
            Whatever the language model's ``generate`` returns.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        # Each sequence starts from the soft prompt, the image token and '[CAP]'.
        prompts = ['[CAP]'] * len(images)
        inputs_embeds, attention_mask, _, _ = self._build_inputs(images, prompts)
        generate_kwargs.setdefault("pad_token_id", self.lm_tokenizer.eos_token_id)
        return self.lm_model.generate(inputs_embeds=inputs_embeds,
                                      attention_mask=attention_mask,
                                      **generate_kwargs)

torch.Size([1, 1024])
