In [1]:
! pip install -q -U trl transformers accelerate git+https://github.com/huggingface/peft.git --q
! pip install -q datasets bitsandbytes einops wandb --q
! pip install transformers==4.28.0 --q
! pip install --upgrade datasets transformers --q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.9/150.9 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m270.9/270.9 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.7/79.7 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.1 MB/s[

In [2]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

In [3]:
file_name = f'captions_val2017.json'

with open(file_name, 'r') as f:
    data = json.load(f)

def retrieve_caption_image(data, image_name):
    captions = []
    
    img_occurence = next((img for img in data['images'] if img['file_name'] == image_name), None)

    if img_occurence is not None:
        annotation_ids = [annotation['id'] for annotation in data['annotations'] if annotation['image_id'] == img_occurence['id']]
        captions = [annotation['caption'] for annotation in data['annotations'] if annotation['id'] in annotation_ids]

    return captions

In [4]:
class Projection_Model(nn.Module):
    def __init__(
        self, 
        input_hidden_size: int, 
        hidden_size: int, 
        num_layers: int, 
        width: int
    ):
        super(Projection_Model, self).__init__()
        self.layers = nn.ModuleList()

        for _ in range(width):
            layer = [nn.Linear(input_hidden_size, hidden_size)]

            for _ in range(1, num_layers):
                layer.append(nn.GELU())
                layer.append(nn.Linear(hidden_size, hidden_size))

            self.layers.append(nn.Sequential(*layer))

    def forward(self, x):
        return torch.cat([layer(x) for layer in self.layers], dim=-2)


def build_model(
    input_hidden_size: int, 
    hidden_size: int, 
    num_layers: int, 
    num_tokens: int
):
    return Projection_Model(
        input_hidden_size, 
        hidden_size, 
        num_layers, 
        num_tokens
    )

In [5]:
class Image_Projection_Phi(nn.Module):
    def __init__(
        self,
        clip_embeddings : int = 512,
        token_embeddings : int = 2560,
        projection_tokens : int = 4,
        projection_layers : int = 4
    ):
        super().__init__()
        model_name = "microsoft/phi-2"

        self.projection_tokens = projection_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.vocab_size = len(tokenizer)
        self.tokenizer.pad_token = tokenizer.eos_token
        self.phi2Model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to("cuda")
        self.token_embedding = self.phi2Model.get_submodule('model.embed_tokens')

        self.layer1 = build_model(
            clip_embeddings, token_embeddings, 1, self.projection_tokens
        ).to("cuda")

        for name, param in self.phi2Model.named_parameters():
                param.requires_grad = False


    def forward(self, x, captions):
        x = self.layer1(x)
        caption_token_embeddings = self.token_embedding(captions)
        inputs = torch.concat((x, caption_token_embeddings), axis=-2)
        outputs = self.phi2Model(inputs_embeds=inputs)
        predictions = self.generate_text_from_embeddings(outputs.logits)

        loss = F.cross_entropy(
            outputs.logits[:, self.projection_tokens:, :].reshape(-1, outputs.logits.size(-1)), captions.reshape(-1)
        )

        return loss, predictions
    

    def generate_text_from_embeddings(self, logits):
        probabilities = logits.softmax(dim=2)  # Softmax along the last dimension
        predicted_indices = torch.argmax(probabilities, dim=2) # Find the index of the class with highest probability
        predicted_texts = [self.tokenizer.decode(sequence) for sequence in predicted_indices] # Decode each sequence

        return predicted_texts

In [6]:
class ImageEmbeddingDataset(Dataset):

    def __init__(self, imageIDs, imageEmbeddings, annotations_data, tokenizer):
        self.imageIDs = imageIDs
        self.imageEmbeddings = imageEmbeddings
        self.annotations_data = annotations_data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.imageIDs)

    def __getitem__(self, index):
        max_len = 256

        image_id = self.imageIDs[index]
        image_embedding = self.imageEmbeddings[index]

        image_captions = retrieve_caption_image(self.annotations_data, image_id)

        if len(image_captions) == 0:
            captions = ""
        else:
            captions = image_captions[0]

        caption_tokens = tokenizer.encode(captions, add_special_tokens=True)
        padded_caption_tokens = caption_tokens + [tokenizer.pad_token_id] * (max_len - len(caption_tokens))

        return {
            'image_id': image_id,
            'image_embedding': image_embedding,
            'caption_tokens': torch.tensor(padded_caption_tokens),
            'captions': captions
        }

In [7]:
model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

image_embedding_dict = torch.load("img_embeddings.pth")

dataset = ImageEmbeddingDataset(
    list(image_embedding_dict.keys()),
    list(image_embedding_dict.values()),
    data,
    tokenizer
)

dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/7.34k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
for batch in dataloader:
        captions = batch['caption_tokens'] # considering only 1 caption to verify shape
        break

print(captions.shape)

torch.Size([1, 256])


In [9]:
model = Image_Projection_Phi()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/861 [00:00<?, ?B/s]

configuration_phi.py:   0%|          | 0.00/9.26k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:
- configuration_phi.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_phi.py:   0%|          | 0.00/62.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:
- modeling_phi.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json:   0%|          | 0.00/35.7k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/564M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]

In [10]:
num_epochs = 15
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for epoch in range(num_epochs):
    model.train()

    for batch in dataloader:
        captions = batch['caption_tokens'].to('cuda')
        embeddings = batch['image_embedding'].to('cuda')
        loss, predictions = model(embeddings, captions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch : {epoch + 1}/{num_epochs}\nLoss  : {loss.item()}")
    print("Caption    : ", batch['captions'])
    print("Prediction : ", [text.rstrip('\n') for text in predictions])
    print("=" * 30)

Epoch : 1/15
Loss  : 7.348170757293701
Caption    :  ['A tray holding a sandwich and cappuccino, next to the pastry.']
Prediction :  [',-.s, of the few and appuccino. with to a coffee case\nThe\nTheTheTheThe\n\nThe\n\nThe\n\nTheThe\n\n\n\n\n\n\nTheThe\n\n\n\n\n\n\nThe']
Epoch : 2/15
Loss  : 7.897108554840088
Caption    :  ['two zebras are standing together in the woods']
Prediction :  [',isea,,bras, two in, a middle.The']
Epoch : 3/15
Loss  : 6.406453609466553
Caption    :  ['A woman in a hat sitting next to luggage.']
Prediction :  [",,es,,'s the red and on to a,\nThe"]
Epoch : 4/15
Loss  : 6.641895294189453
Caption    :  ['Two Zebras grazing together in a grassy area.']
Prediction :  [',-es,,ekan, on, a fieldland field.\nThe']
Epoch : 5/15
Loss  : 6.686727523803711
Caption    :  ['A man prepares to cross the street at a crosswalk']
Prediction :  [',ing.,, with to be a street. a crosswalk.TheThe']
Epoch : 6/15
Loss  : 6.352148532867432
Caption    :  ['Two Zebras grazing together in a 