In [1]:
import json
import os
from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
from PIL import Image
import requests
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer


import torchvision.datasets as dset
import torchvision.transforms as transforms

In [2]:
captions = dset.CocoCaptions(root='data/coco/train2014',
                              annFile = 'data/coco/annotations/captions_train2014.json',
                              transform=transforms.ToTensor())

loading annotations into memory...
Done (t=0.36s)
creating index...
index created!


In [3]:
n_tokens = 20 

#Not contained in the __init__.py so I needed to recreate it here
class BeitPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layernorm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None)

    def forward(self, hidden_states):
        if self.layernorm is not None:
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(patch_tokens.mean(1))
        else:
            pooled_output = hidden_states[:, 0]
        return pooled_output

#Set up the model and the pool layer 
#c denotes the config string 
class BeitForCaptioning(nn.Module): 
    def __init__(self, c):
        super().__init__()
        self.feature_extractor = BeitFeatureExtractor.from_pretrained(c)
        self.model = BeitForMaskedImageModeling.from_pretrained(c)
        self.pooler = BeitPooler(self.model.config)
        self.model = self.model.beit
        self.model.pooler = self.pooler
    def forward(self, image_batch):
        inputs = self.feature_extractor(images=image_batch, return_tensors="pt").to('cuda')
        return self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output

In [4]:
class SoftEmbedding(nn.Module):
    def __init__(self, 
                wte: nn.Embedding,
                n_tokens: int = 10, 
                random_range: float = 0.5,
                initialize_from_vocab: bool = True):
        """appends learned embedding to 
        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
                                                                               n_tokens, 
                                                                               random_range, 
                                                                               initialize_from_vocab))
    def initialize_embedding(self, 
                             wte: nn.Embedding,
                             n_tokens: int = 10, 
                             random_range: float = 0.5, 
                             initialize_from_vocab: bool = True):
        """initializes learned embedding
        Args:
            same as __init__
        Returns:
            torch.float: initialized using original schemes
        """
        if initialize_from_vocab:
            return self.wte.weight[:n_tokens].clone().detach()
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)
            
    def forward(self, tokens, extra_token_embeddings=None):
        """run forward pass
        Args:
            tokens (torch.long): input tokens before encoding
            extra_token_embeddings (torch.float): Used for image or video captioning
        Returns:
            torch.float: encoding of text concatenated with learned task specifc embedding
        """
        print(tokens.shape)
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        if extra_token_embeddings is None:
            return torch.cat([learned_embedding, input_embedding], 1)
        else:
            print(learned_embedding.shape)
            print(extra_token_embeddings.shape)
            print(input_embedding.shape)
            return torch.cat([learned_embedding, extra_token_embeddings.unsqueeze(1), input_embedding], 1)


In [5]:
#TODO: Add caption token BEFORE adding soft prompting

class captioning_model(nn.Module):
    def __init__(self, img_model = "microsoft/beit-large-patch16-224-pt22k", lm = "gpt2", n_tokens = 20):
        super().__init__()
        self.embed = BeitForCaptioning(img_model).to('cuda')
        self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm)
        self.lm_model = GPT2LMHeadModel.from_pretrained(lm).to('cuda')
        self.lm_tokenizer.pad_token = self.lm_tokenizer.eos_token
        
        self.s_wte = SoftEmbedding(self.lm_model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=True).to('cuda')
        self.project = nn.Linear(1024, 768).to('cuda')
        #Freeze the language model. We only want to tune the image model and the prefix.
        for param in self.lm_model.parameters():
            param.requires_grad = False
        self.loss_start = n_tokens+1
            
    #Captions effectively acts as "labels"
    """
    Args:
        images (list PIL): List of images
        captions (list str): List of strings representing captions. Prepend [CAP] token. Can eventually be expanded
        to multiple kinds of captions.
    """
    
    def forward(self, images, captions=None):
        labels = None
        input_ids = None
        bs = 1
        if captions is not None:
            if type(captions) == list:
                bs = len(captions)
            print(bs)
            toks = self.lm_tokenizer(captions, padding=True, truncation=True, return_tensors="pt").to('cuda')
            input_ids = toks['input_ids']
            #We do not want to perform loss on the prefix or image toks
            #The 1 at the beginning signifies a batch size of one. Might need to be more careful when
            #we add batching
            labels = torch.cat([torch.full((bs, self.loss_start), -100).to("cuda"), input_ids], 1).cuda()
            input_ids = torch.cat([torch.full((bs, self.loss_start - 1), -100).to("cuda"), input_ids], 1).cuda()

        else:
            #If we arent using labels, just append a caption token.
            start_string = ['[CAP]'] * len(images)
            toks = self.lm_tokenizer(captions, padding=True, truncation=True, return_tensors="pt").to('cuda')
            input_ids = toks['input_ids']
            input_ids = torch.cat([torch.full((bs, self.loss_start - 1), -100).to("cuda"), input_ids], 1).cuda()

            
            
            
        #1024 dim representation per image, patch size 16x16
        #~ bs x hdim
        image_embeddings = self.project(self.embed(images))
        inputs_embeds = self.s_wte(input_ids, image_embeddings)
            
        print(inputs_embeds.shape)
        print(labels.shape)
        return self.lm_model(inputs_embeds=inputs_embeds, labels=labels)
    
    def generate(self, images):       
        labels = None
        input_ids = None
        if captions is not None:
            toks = self.lm_tokenizer(captions, return_tensors="pt").cuda()
            input_ids = toks['input_ids']
            #We do not want to perform loss on the prefix or image toks
            #The 1 at the beginning signifies a batch size of one. Might need to be more careful when
            #we add batching
            labels = torch.cat([torch.full((bs, self.loss_start), -100).to("cuda"), input_ids], 1).cuda()

            input_ids = torch.cat([torch.full((bs, self.loss_start - 1), -100).to("cuda"), input_ids], 1).cuda()
        else:
            #If we arent using labels, just append a caption token.
            start_string = ['[CAP]'] * len(images)
            toks = self.lm_tokenizer(captions, return_tensors="pt").cuda()
            input_ids = toks['input_ids']
            
            
            
        #1024 dim representation per image, patch size 16x16
        #~ bs x hdim
        image_embeddings = self.project(self.embed(images))
        inputs_embeds = self.s_wte(input_ids, image_embeddings)
            
        return self.lm_model.generate(inputs_embeds=inputs_embeds)

In [6]:
model = captioning_model()

In [7]:
from numpy.random import choice
from torch.optim import AdamW
from tqdm import tqdm
 
class CaptionChoice(Dataset):
    def __init__(self, data):
        self.data = data
        pass
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item[0], choice(item[1])
    
    
    
data = CaptionChoice(captions)
train_dataloader = train_dataloader(data, batch_size=8)
optimizer = AdamW(model, lr=3e-5)



optimizer.zero_grad()
#model.lm_tokenizer(text, padding=True, truncation=True, return_tensors="pt")
loss = model(images=images, captions=text).loss()

loss.backwards()
optimizer.step()

2
torch.Size([2, 32])
torch.Size([2, 20, 768])
torch.Size([2, 768])
torch.Size([2, 12, 768])
torch.Size([2, 33, 768])
torch.Size([2, 33])


CausalLMOutputWithCrossAttentions(loss=tensor(7.1717, device='cuda:0', grad_fn=<NllLossBackward>), logits=tensor([[[ -37.7948,  -38.3075,  -40.5287,  ...,  -47.4307,  -46.8450,
           -36.8976],
         [ -69.7115,  -68.8415,  -68.4465,  ...,  -75.0614,  -74.2622,
           -64.0662],
         [ -50.3697,  -47.9626,  -51.2256,  ...,  -63.0675,  -60.9991,
           -52.0711],
         ...,
         [-107.9593, -108.6686, -113.3835,  ..., -115.2313, -113.6276,
          -110.0203],
         [ -75.9620,  -77.0478,  -81.2779,  ...,  -88.5497,  -88.9182,
           -80.3618],
         [-106.2581, -103.6757, -102.5214,  ..., -116.6622, -117.3206,
          -101.5299]],

        [[ -37.7948,  -38.3075,  -40.5287,  ...,  -47.4307,  -46.8450,
           -36.8976],
         [ -69.7115,  -68.8415,  -68.4465,  ...,  -75.0614,  -74.2622,
           -64.0662],
         [ -50.3697,  -47.9626,  -51.2256,  ...,  -63.0675,  -60.9991,
           -52.0711],
         ...,
         [ -97.8747,  -91.0