In [2]:
import os
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from torchvision.models import *
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
torch.manual_seed(123)

<torch._C.Generator at 0x116742e90>

# Объединяем две модальности: ViT и GPT-2

## Вспомогательные функции и классы

In [16]:
def preprocess_image(image_path):
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        )
    ])
    img = Image.open(image_path).convert('RGB')
    img = transform(img)
    return img

In [17]:
def train(model, dataloader, optimizer, loss_fn, num_epochs=5, device="mps"):
    model.vit.to(device)
    model.gpt2.to(device)
    model.adapter.to(device)
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        tqdm_iterator = tqdm(dataloader)
        tqdm_iterator.set_description(f"Epoch {epoch+1}/{num_epochs} ")
        for batch in tqdm_iterator:
            images, captions = batch
            images = images.to(device)
            captions = captions.to(device)

            outputs = model(images, captions)
            # убираем последний токен из предсказаний
            logits = outputs.logits[:, :- 1, :]

            # Compute loss
            loss = loss_fn(
                logits.contiguous().view(-1, logits.size(-1)),
                captions.contiguous().view(-1)
            )

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            tqdm_iterator.set_postfix(loss=loss.item())

In [18]:
class Flickr8kDataset(Dataset):
    def __init__(
            self,
            tokenizer,
            data_dir="./flickr8k",
            captions_file="captions.txt",
            max_len=50,
            ):

        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.image_dir = os.path.join(data_dir, "Images")
        self.max_len = max_len

        self.images = []
        self.captions = []
        captions_file = pd.read_csv(os.path.join(data_dir, captions_file))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_name, caption = self.images[idx], self.captions[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # Load and preprocess image
        image = preprocess_image(image_path)
        input_ids = self.tokenizer(
            caption,
            return_tensors='pt',
            padding="max_length",
            truncation=True,
            max_length=self.max_len
            ).input_ids.squeeze(0)
        return image, input_ids


In [19]:
image_path = './exotic-shorthair.jpg'
image = preprocess_image(image_path).unsqueeze(0)
text = 'Describe this image: '

## Кодировщик и декодировщик

Попробуем просто объединить ViT и GPT-2.

In [20]:
class VLM(nn.Module):
    def __init__(self):
        super().__init__()
        # определим vit
        self.vit = vit_b_32(weights=ViT_B_32_Weights.IMAGENET1K_V1)
        # уберем голову классификации для получения эмбеддингов
        self.vit.heads = nn.Identity()

        # определим токенизатор
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # определим ллм
        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
        self.max_new = 50

        # определим адаптер
        self.adapter = nn.Identity()

    def forward(self, image, input_ids):
        # получим эмбеддинги изображения
        visual_embeddings = self.vit(image) # [batch, vit.embed_dim]
        # пропустим через адаптер
        visual_embeddings = self.adapter(visual_embeddings)  # [batch, gpt.embed_dim]
        # токены текста преобразуем их в эмбеддинги
        text_embeddings = self.gpt2.transformer.wte(input_ids) 
        combined_embeddings = torch.cat([
            visual_embeddings.unsqueeze(1),  # [batch, 1, emb_dim]
            text_embeddings,  # [batch, seq_len, emb_dim]
            ], dim=1)

        # маска внимания необходима для генерации выхода gpt
        attn_mask = torch.cat([
            torch.ones((image.shape[0], 1), device=image.device), # и не забываем про изображение
            (input_ids != self.tokenizer.pad_token_id).long(),  # обращаем внимание на токены слов
        ], dim=1)

        # получим выход
        outputs = self.gpt2(
            inputs_embeds=combined_embeddings,
            attention_mask=attn_mask,
        )
        return outputs

    def generate_caption(self, image, text):
        # получим эмбеддинги изображения
        visual_embeddings = self.vit(image) # [batch, vit.embed_dim]
        # пропустим через адаптер
        visual_embeddings = self.adapter(visual_embeddings)  # [batch, gpt.embed_dim]

        # получим токены текста и преобразуем их в эмбеддинги
        input_ids = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True).input_ids
        text_embeddings = self.gpt2.transformer.wte(input_ids) 
        combined_embeddings = torch.cat([
            visual_embeddings.unsqueeze(1),  # [batch, 1, emb_dim]
            text_embeddings,  # [batch, seq_len, emb_dim]
            ], dim=1)

        # маска внимания необходима для генерации выхода gpt
        attn_mask = torch.cat([
            torch.ones((image.shape[0], 1), device=image.device), # и не забываем про изображение
            (input_ids != self.tokenizer.pad_token_id).long(),  # обращаем внимание на токены слов
        ], dim=1)

        # получим выход
        outputs = self.gpt2.generate(
            inputs_embeds=combined_embeddings,
            attention_mask=attn_mask,
            max_new_tokens=self.max_new,
            pad_token_id=self.tokenizer.eos_token_id
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    def freeze_vit(self):
        for param in self.vit.parameters():
            param.requires_grad = False
    
    def freeze_gpt(self):
        for param in self.gpt2.parameters():
            param.requires_grad = False


In [21]:
vlm = VLM()

In [22]:
caption = vlm.generate_caption(image, text)
print("Generated caption:", caption)

Generated caption:  "The first time I saw the first of the new, new, new, new, new, new, new, new, new, new, new, new, new, new, new, new, new, new, new, new


## Кодировщик, адаптер и декодировщик

### Линейный адаптер

Попробуем "подружить" ViT и GPT-2. Для этого попробуем дообучить дополнительный слой, который будет проецировать эмбеддинг изображения в пространство, понятное языковой модели.

In [23]:
vlm_lin = VLM()

vlm_lin.adapter = nn.Sequential(
    nn.Linear(vlm_lin.vit.hidden_dim, vlm_lin.vit.hidden_dim),
    nn.LeakyReLU(),
    nn.Linear(vlm_lin.vit.hidden_dim, vlm_lin.vit.hidden_dim),
    nn.LeakyReLU(),
    nn.Linear(vlm_lin.vit.hidden_dim, vlm_lin.gpt2.config.n_embd),
)

In [24]:
caption = vlm_lin.generate_caption(image, text)
print("Generated caption:", caption)

Generated caption:  "I was in the middle of a long day in the middle of the night when I saw a man in a white robe walking down the street. I was so shocked that I thought he was a man. I thought he was a man.


In [None]:
dataset = Flickr8kDataset(
    tokenizer=vlm_lin.tokenizer,
    )
optimizer = torch.optim.Adam(vlm_lin.adapter.parameters(), lr=1e-4)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
loss = nn.CrossEntropyLoss(ignore_index=vlm_lin.tokenizer.pad_token_id)

In [12]:
vlm_lin.freeze_gpt()
vlm_lin.freeze_vit()

train(
    vlm_lin,
    dataloader,
    optimizer,
    loss,
    num_epochs=5,
    device="cuda",
)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1/5 : 100%|██████████| 5057/5057 [07:29<00:00, 11.25it/s, loss=5.91]
Epoch 2/5 : 100%|██████████| 5057/5057 [07:36<00:00, 11.09it/s, loss=5.15]
Epoch 3/5 : 100%|██████████| 5057/5057 [07:40<00:00, 10.98it/s, loss=5.69]
Epoch 4/5 : 100%|██████████| 5057/5057 [07:40<00:00, 10.98it/s, loss=4.85]
Epoch 5/5 : 100%|██████████| 5057/5057 [07:43<00:00, 10.91it/s, loss=5.31]


In [None]:
vlm_lin.load_state_dict(torch.load('./vlm_lin.pt', map_location='cpu'))

  vlm_lin.load_state_dict(torch.load('./vlm_lin.pt'))


<All keys matched successfully>

In [34]:
vlm_lin.to("cpu")
caption = vlm_lin.generate_caption(image, text)
print("Generated caption:", caption)

Generated caption:  a man on a bicycle . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .


### Адаптер с Attention

Давайте немного усложним адаптер, добавив туда Attention. По сути, сделаем мини-блок трансформера.

In [35]:
class AttentionAdapter(nn.Module):
    def __init__(self, vit_dim, gpt_dim, num_heads, dropout=0.1):
        super().__init__()

        self.proj = nn.Linear(vit_dim, gpt_dim)
        # в торче уже есть написанный за вас mha
        self.cross_attention = nn.MultiheadAttention(embed_dim=gpt_dim, num_heads=num_heads, dropout=dropout)
        
        self.ffn = nn.Sequential(
            nn.Linear(gpt_dim, gpt_dim * 4),
            nn.ReLU(),
            nn.Linear(gpt_dim * 4, gpt_dim)
        )

        self.norm1 = nn.LayerNorm(gpt_dim)
        self.norm2 = nn.LayerNorm(gpt_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, visual_embeddings, text_embeddings, attention_mask=None):
        visual_embeddings = self.proj(visual_embeddings)

        # переставляем местами размерности для подачи в слой внимания
        visual_embeddings = visual_embeddings.permute(1, 0, 2)  # [1, batch_size, gpt_dim]
        text_embeddings = text_embeddings.permute(1, 0, 2)  # [seq_len, batch_size, gpt_dim]

        # используем cross-attention: в качестве query подаем эмбеддинги текста,
        # в качестве key и value подаем эмбеддинги изображения
        attn_output, _ = self.cross_attention(
            query=text_embeddings,  # [seq_len, batch_size, gpt_dim]
            key=visual_embeddings,    # [1, batch_size, gpt_dim]
            value=visual_embeddings,  # [1, batch_size, gpt_dim]
            key_padding_mask=None
        )
        text_embeddings = self.norm1(text_embeddings + self.dropout(attn_output))
        ffn_output = self.ffn(text_embeddings)
        fused_embeddings = self.norm2(text_embeddings + self.dropout(ffn_output))

        return fused_embeddings.permute(1, 0, 2)  # [batch_size, seq_len, gpt_dim]


In [None]:
class VLMAttn(nn.Module):
    def __init__(self, num_heads=8, dropout=0.1):
        super().__init__()
        self.vit = vit_b_32(weights=ViT_B_32_Weights.IMAGENET1K_V1)
        self.vit.heads = nn.Identity()

        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.adapter = AttentionAdapter(
            vit_dim=self.vit.hidden_dim,
            gpt_dim=self.gpt2.config.n_embd,
            num_heads=num_heads,
            dropout=dropout
            )

    def forward(self, images, input_ids):
        visual_embeddings = self.vit(images).unsqueeze(1)  
        text_embeddings = self.gpt2.transformer.wte(input_ids) 

        fused_embeddings = self.adapter(visual_embeddings, text_embeddings)

        outputs = self.gpt2(inputs_embeds=fused_embeddings)
        return outputs
    
    def generate_caption(self, images, text, max_new_tokens=50, device="cpu"):

        visual_features = self.vit(images).unsqueeze(1)

        input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(device) 
        text_embeddings = self.gpt2.transformer.wte(input_ids)
        fused_embeddings = self.adapter(visual_features, text_embeddings)

        attention_mask = torch.ones(fused_embeddings.size()[:-1], dtype=torch.long, device=device)
        outputs = self.gpt2.generate(
            inputs_embeds=fused_embeddings,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        generated_caption = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_caption
        
    def freeze_vit(self):
        for param in self.vit.parameters():
            param.requires_grad = False
    
    def freeze_gpt(self):
        for param in self.gpt2.parameters():
            param.requires_grad = False

In [37]:
vlm_attn = VLMAttn()

In [39]:
caption = vlm_attn.generate_caption(image, text)
print("Generated caption:", caption)

Generated caption: 







"

"
"

"


"


"



"



"


"


"


"





In [32]:
dataset = Flickr8kDataset(
    tokenizer=vlm_attn.tokenizer,
    )
optimizer = torch.optim.Adam(vlm_attn.adapter.parameters(), lr=1e-4)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
loss = nn.CrossEntropyLoss(ignore_index=vlm_attn.tokenizer.pad_token_id)

In [33]:
vlm_attn.freeze_gpt()
vlm_attn.freeze_vit()

train(
    vlm_attn,
    dataloader,
    optimizer,
    loss,
    num_epochs=5,
    device="cuda",
)

Epoch 1/5 : 100%|██████████| 5057/5057 [07:48<00:00, 10.79it/s, loss=2.89]
Epoch 2/5 : 100%|██████████| 5057/5057 [07:45<00:00, 10.87it/s, loss=3.54]
Epoch 3/5 : 100%|██████████| 5057/5057 [07:44<00:00, 10.90it/s, loss=3.18]
Epoch 4/5 : 100%|██████████| 5057/5057 [07:18<00:00, 11.53it/s, loss=3.4] 
Epoch 5/5 : 100%|██████████| 5057/5057 [07:21<00:00, 11.45it/s, loss=2.85]


In [41]:
vlm_attn.load_state_dict(torch.load('./vlm_attn.pt', map_location='cpu'))

  vlm_attn.load_state_dict(torch.load('./vlm_attn.pt', map_location='cpu'))


<All keys matched successfully>

In [43]:
vlm_attn.to("cpu")
caption = vlm_attn.generate_caption(image, text)
print("Generated caption:", caption)

Generated caption:  with a large number of people in the area.
The person who was in the area was taken to the hospital.
The person who was in the area was taken to the hospital.
The person who was in the area was taken to


## Обученные архитектуры

Посмотрим, что генерирует уже объединенные ViT и GPT-2 Воспользуемся оберткой `VisionEncoderDecoderModel` для ViT + GPT-2. Здесь авторы используют `Cross Attention`.

In [46]:
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

In [47]:
def preprocess_image_tr(image_path):
    img = Image.open(image_path).convert('RGB')
    pixel_values = feature_extractor(images=[img], return_tensors="pt").pixel_values
    return pixel_values

In [50]:
def generate_caption(pixel_values):
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_length=50)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

In [None]:
pixel_values = preprocess_image_tr(image_path)
caption = generate_caption(pixel_values)
print("Generated caption:", caption)


Generated caption: a cat sitting on a blanket on a carpet 
