In [14]:
!git clone https://huggingface.co/datasets/uygarkurt/simple-image-captions

fatal: destination path 'simple-image-captions' already exists and is not an empty directory.


In [15]:
!pip install transformers
!pip install sentencepiece
!pip install protobuf



In [16]:
import base64
import io
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from transformers import ViTConfig, ViTModel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random
import numpy as np


In [17]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
BATCH_SIZE = 16
N_HIDDEN_LAYERS = 16
MAX_LENGTH = 16
EVAL_INTERVAL = 10
LEARNING_RATE = 9e-4
EPOCHS = 6
N_EMBD = 128
N_HEAD = 8
N_LAYER = 8
DROPOUT = 0.4
IMG_SIZE = 96
PATCH_SIZE = 16
IMAGE_EMBED_DIM = 512
N_CHANNELS = 3
MAX_POSITION_EMBEDDINGS = 128

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [20]:
image_dir = '/mnt/d/Dysk/REPO/fine_tune_llm/simple-image-captions/'  # Directory containing images

def image_file_to_base64(image_filename):
    image_path = image_dir + image_filename
    with open(image_path, 'rb') as img_file:
        b64_str = base64.b64encode(img_file.read()).decode('utf-8')
    return b64_str

df = pd.read_csv(image_dir + 'inputs.csv', sep=";").dropna(axis=1, how="all")
df['b64string_images'] = df['file'].apply(image_file_to_base64)
df.head()

Unnamed: 0,file,caption,b64string_images
0,car.png,red car,dmVyc2lvbiBodHRwczovL2dpdC1sZnMuZ2l0aHViLmNvbS...
1,astronaut.png,astronaut in a white space suit,dmVyc2lvbiBodHRwczovL2dpdC1sZnMuZ2l0aHViLmNvbS...
2,tv.png,black television on a table,dmVyc2lvbiBodHRwczovL2dpdC1sZnMuZ2l0aHViLmNvbS...
3,horse.png,brown horse running,dmVyc2lvbiBodHRwczovL2dpdC1sZnMuZ2l0aHViLmNvbS...
4,wine.png,wine bottle,dmVyc2lvbiBodHRwczovL2dpdC1sZnMuZ2l0aHViLmNvbS...


In [21]:
config = ViTConfig(
    image_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_channels=N_CHANNELS,
    hidden_size=IMAGE_EMBED_DIM,
    num_attention_heads=N_HEAD,
    num_hidden_layers=N_HIDDEN_LAYERS,
    intermediate_size=4 * IMAGE_EMBED_DIM,
    hidden_dropout_prob=DROPOUT,
    attention_probs_dropout_prob=DROPOUT,
)

testvit = ViTModel(config)
vit_input = torch.zeros(BATCH_SIZE, N_CHANNELS, IMG_SIZE, IMG_SIZE)
testvit_out = testvit(vit_input).last_hidden_state[:, 0] # Get the [CLS] token representation
testvit_out.shape # (BATCH_SIZE, IMAGE_EMBED_DIM)

torch.Size([16, 512])

In [22]:
class VisionLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed,
        image_embed_dim,
        vocab_size,
        n_layer,
        n_head,
        img_size,
        patch_size,
        n_hidden_layers,
        dropout,
        pad_token_id,
        max_position_embeddings,
        n_channels,
    ):
        super().__init__()
        vit_config = ViTConfig(
            image_size=img_size,
            patch_size=patch_size,
            num_channels=n_channels,
            hidden_size=image_embed_dim,
            num_attention_heads=n_head,
            num_hidden_layers=n_hidden_layers,
            intermediate_size=4 * image_embed_dim,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
        )
        self.vision_encoder = ViTModel(vit_config)
        self.image_projector = nn.Linear(image_embed_dim, n_embed)

        llama_config = LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=n_embed,
            num_hidden_layers=n_layer,
            num_attention_heads=n_head,
            max_position_embeddings=max_position_embeddings,
            pad_token_id=int(pad_token_id),
        )
        self.llama = LlamaForCausalLM(llama_config)
        self.llama = self.llama.to(dtype=torch.bfloat16)  # Move Llama to bfloat16

    def forward(self, img_array, input_ids, targets=None):
        # img_array: [BATCH_SIZE, N_CHANNELS, IMG_SIZE, IMG_SIZE]
        # input_ids: [BATCH_SIZE, MAX_LENGTH]
        image_embeds = self.vision_encoder(img_array).last_hidden_state[:, 0]  # [BATCH_SIZE, IMAGE_EMBED_DIM]
        image_embeds_proj = self.image_projector(image_embeds).to(dtype=torch.bfloat16)  # [BATCH_SIZE, N_EMBED]
        image_embeds_proj = image_embeds_proj.unsqueeze(1) # [BATCH_SIZE, 1, N_EMBED]

        text_embeds = self.llama.model.embed_tokens(input_ids).to(dtype=torch.bfloat16)  # [BATCH_SIZE, MAX_LENGTH, N_EMBED]

        input_embeds = torch.cat([image_embeds_proj, text_embeds], dim=1)  # [BATCH_SIZE, MAX_LENGTH + 1, N_EMBED]

        attention_mask = torch.ones(input_embeds.shape[:2], dtype=torch.long, device=input_embeds.device) # [BATCH_SIZE, MAX_LENGTH + 1]

        if targets is not None:
            #target: [BATCH_SIZE, MAX_LENGTH]
            targets = torch.cat([torch.full((targets.size(0), 1), -100, dtype=targets.dtype, device=targets.device), targets], dim=1) # [BATCH_SIZE, MAX_LENGTH + 1]
            outputs = self.llama(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
                labels=targets,
            )
            return outputs.logits, outputs.loss
        else:
            outputs = self.llama(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
            )
            return outputs.logits

    @torch.no_grad()
    def generate(self, img_array, input_ids, max_new_tokens=20):
        # img_array: [BATCH_SIZE, N_CHANNELS, IMG_SIZE, IMG_SIZE]
        # input_ids: [BATCH_SIZE, MAX_LENGTH]
        image_embeds = self.vision_encoder(img_array).last_hidden_state[:, 0]
        image_embeds_proj = self.image_projector(image_embeds).unsqueeze(1).to(dtype=torch.bfloat16)

        input_embeds = self.llama.model.embed_tokens(input_ids).to(dtype=torch.bfloat16)
        inputs_embeds = torch.cat([image_embeds_proj, input_embeds], dim=1)
        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)
       
        generated = self.llama.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.llama.config.pad_token_id,
            eos_token_id=self.llama.config.eos_token_id,
        )
        return generated

model = VisionLanguageModel(
        N_EMBD,
        IMAGE_EMBED_DIM,
        tokenizer.vocab_size,
        N_LAYER,
        N_HEAD,
        IMG_SIZE,
        PATCH_SIZE,
        N_HIDDEN_LAYERS,
        DROPOUT,
        tokenizer.pad_token_id,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        n_channels=N_CHANNELS,
)
model.to(device)

dummy_img = torch.randn(1, N_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)
dummy_idx = torch.randint(0, tokenizer.vocab_size, (1, MAX_LENGTH)).to(device)
output = model(dummy_img, dummy_idx)
print(output.shape)

NameError: name 'MAX_LENGTH' is not defined