In [1]:
import torch
import torch.nn as nn

In [2]:
# Device configuration: prefer a CUDA GPU, then Apple's MPS backend, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():  # For macOS
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}")

Using cuda


In [3]:
class PatchEmbedding(nn.Module):
    """Split a square image into non-overlapping patches and linearly project each patch.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so every output position sees exactly one patch.

    Args:
        img_size: height and width of the (square) input image.
        patch_size: side length of each square patch. If ``img_size`` is not a
            multiple of ``patch_size`` the convolution drops the leftover pixels.
        in_channels: number of input image channels.
        embedding_dim: size of the embedding produced for each patch.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embedding_dim=512):
        super().__init__()
        self.img_size = img_size
        # Patches per side (floor division, matching the conv output grid), squared.
        self.n_patches = (self.img_size // patch_size) ** 2
        self.proj_layer = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed every patch of a batch of images as one token.

        x : [n_samples, in_channels, img_size, img_size]
        output : [n_samples, n_patches, embedding_dim]
        """
        x = self.proj_layer(x)  # [n_samples, embedding_dim, sqrt(n_patches), sqrt(n_patches)]
        x = x.flatten(2)  # [n_samples, embedding_dim, n_patches]
        # Move the patch axis before the channel axis so the result is a token
        # sequence that can be concatenated with a [n_samples, 1, embedding_dim]
        # class token along dim=1.
        return x.transpose(1, 2)  # [n_samples, n_patches, embedding_dim]

class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x + MHSA(LN1(x))`` followed by ``h + MLP(LN2(h))``.

    Args:
        dim: embedding size of each token; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of ``dim``.
        p_dropout: dropout probability used inside attention and the MLP.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4, p_dropout=0.5):
        super(EncoderBlock, self).__init__()

        self.dim = dim
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.mlp_ratio = mlp_ratio
        # Separate LayerNorms before attention and before the MLP so each
        # sub-layer has its own learned scale and shift.
        self.norm1 = nn.LayerNorm(self.dim)
        self.norm2 = nn.LayerNorm(self.dim)
        # batch_first=True: inputs are [n_samples, seq_len, dim], as documented in forward.
        self.attention = nn.MultiheadAttention(self.dim, self.n_heads, dropout=self.p_dropout, batch_first=True)
        self.MLP = nn.Sequential(
            nn.Linear(self.dim, self.dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(self.p_dropout),
            nn.Linear(self.dim * mlp_ratio, self.dim),
            nn.Dropout(self.p_dropout),
        )

    def forward(self, x):
        """
        x : [n_samples, n_patches + 1, embedding_dim]
        output : [n_samples, n_patches + 1, embedding_dim]
        """
        first_norm = self.norm1(x)
        # nn.MultiheadAttention returns (attn_output, attn_weights); only the output is used.
        attention_out, _ = self.attention(first_norm, first_norm, first_norm, need_weights=False)
        first_added = attention_out + x
        mlp_out = self.MLP(self.norm2(first_added))
        return mlp_out + first_added

class ViT(nn.Module):
    """Vision Transformer encoder: patch embedding, class token, learned positions, encoder blocks.

    Args:
        img_size: side length of the square input images.
        patch_size: side length of each square patch. NOTE(review): with the
            default of 9 a 224-pixel image yields 224 // 9 = 24 patches per side,
            and the last 8 pixel rows/columns are not covered by any patch.
        in_channels: number of image channels.
        embedding_dim: token embedding size; must be divisible by ``n_heads``.
        depth: number of stacked ``EncoderBlock``s.
        n_heads: attention heads per block. The default of 8 divides the default
            ``embedding_dim`` of 512, which ``nn.MultiheadAttention`` requires.
        mlp_ratio: MLP hidden size as a multiple of ``embedding_dim``.
        p_dropout: dropout probability passed to every encoder block.
    """

    def __init__(self, img_size, patch_size=9, in_channels=3, embedding_dim=512, depth=6, n_heads=8, mlp_ratio=4, p_dropout=0.5):
        super().__init__()

        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embedding_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
        # One learned position vector for the class token plus one per patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embedding.n_patches, embedding_dim))

        self.encoder_blocks = nn.ModuleList([EncoderBlock(embedding_dim, n_heads, mlp_ratio, p_dropout) for _ in range(depth)])
        # The encoder blocks normalise only the inputs of their sub-layers, so
        # the residual stream leaving the last block is normalised here.
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Encode a batch of images into a sequence of token embeddings.

        x : [n_samples, in_channels, img_size, img_size]
        output : [n_samples, 1 + n_patches, embedding_dim]; output[:, 0] is the class-token embedding.
        """
        n_samples = x.shape[0]
        x = self.patch_embedding(x)  # expected: [n_samples, n_patches, embedding_dim]

        cls_token = self.cls_token.expand(n_samples, -1, -1)  # [n_samples, 1, embedding_dim]
        x = torch.cat((cls_token, x), dim=1)  # [n_samples, 1 + n_patches, embedding_dim]

        x = x + self.pos_embed  # broadcast over the batch axis

        for enc_block in self.encoder_blocks:
            x = enc_block(x)

        return self.norm(x)  # [n_samples, 1 + n_patches, embedding_dim]

In [4]:
!pip install transformers



In [5]:
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor

image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decoder_model = "gpt2"

# Pair the pretrained image encoder with the pretrained text decoder. The
# decoder's cross-attention weights are not in the GPT-2 checkpoint, so they
# start freshly initialised (see the warning in this cell's output).
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model,
    text_decoder_model,
)

# Image preprocessing comes from the encoder checkpoint ...
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# ... and tokenisation from the decoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(text_decoder_model)
# GPT-2's tokenizer has no dedicated pad token, so EOS doubles as padding.
tokenizer.pad_token = tokenizer.eos_token

# Special-token ids the encoder-decoder wrapper needs for building decoder
# inputs and for generation.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.5.ln_cross_attn.bias', 'h.2.crossattention.q_attn.weight', 'h.5.crossattention.c_proj.bias', 'h.5.crossattention.q_attn.bias', 'h.9.crossattention.c_attn.bias', 'h.6.crossattention.c_attn.bias', 'h.8.crossattention.c_proj.weight', 'h.6.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.bias', 'h.10.ln_cross_attn.bias', 'h.8.ln_cross_attn.weight', 'h.1.crossattention.c_attn.weight', 'h.4.crossattention.c_attn.weight', 'h.6.crossattention.q_attn.weight', 'h.11.ln_cross_attn.weight', 'h.10.crossattention.q_attn.weight', 'h.11.crossattention.c_proj.bias', 'h.3.crossattention.c_attn.bias', 'h.11.crossattention.c_proj.weight', 'h.8.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.9.crossattention.c_proj.weight', 'h.3.crossattention.c_proj.bias', 'h.4.crossattention.c_proj.bias', 'h.11.ln_cross_attn.bias', 'h.5.crossattention.c_attn.bias', 'h.8.crossat

In [6]:
out_dir = "model"

# Write the model and its image preprocessor configuration into out_dir ...
for artifact in (model, feature_extractor):
    artifact.save_pretrained(out_dir)
# ... then the tokenizer; as the cell's last expression its list of saved files is displayed.
tokenizer.save_pretrained(out_dir)



('model/tokenizer_config.json',
 'model/special_tokens_map.json',
 'model/vocab.json',
 'model/merges.txt',
 'model/added_tokens.json',
 'model/tokenizer.json')

In [7]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from PIL import Image

In [8]:
directory = 'Images'

# Torchvision pipeline: resize to 224x224, convert to a float tensor in [0, 1],
# then normalise each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [9]:
# Load captions from the text file. The encoding is explicit so the result does
# not depend on the platform's default text encoding.
with open(os.path.join('./', 'captions.txt'), 'r', encoding='utf-8') as f:
    next(f, None)  # skip the header row; an empty file simply yields no captions
    captions_doc = f.read()

In [10]:
# Create mapping of image id -> list of its captions, plus the ordered list of
# unique image ids (one entry per image, not per caption line).
mapping = {}
img_ids = []

for line in tqdm(captions_doc.split('\n')):
    if len(line) < 2:  # skip blank / near-empty lines such as the trailing newline
        continue
    # Split on the first comma only: the caption text itself may contain commas.
    image_id, _, caption = line.partition(',')
    image_id = image_id.split('.')[0]  # drop the file extension
    if image_id not in mapping:
        mapping[image_id] = []
        img_ids.append(image_id)
    mapping[image_id].append(caption)

  0%|          | 0/40456 [00:00<?, ?it/s]

In [11]:
# text preprocessing step
def mask_label_padding(input_ids, attention_mask, ignore_index=-100):
    """Return a copy of one padded label sequence with its padding masked out.

    In this notebook the tokenizer pads with its EOS token
    (``tokenizer.pad_token = tokenizer.eos_token``). The first padded position is
    therefore kept: it is the end-of-caption target the decoder must learn to
    emit. Every later padded position becomes ``ignore_index`` so it does not
    contribute to the loss. Assumes right-side padding (real tokens first).

    Args:
        input_ids: token ids of one sequence.
        attention_mask: same length as ``input_ids``; truthy for real tokens,
            falsy for padding.
        ignore_index: value written over the masked padding positions.

    Returns:
        list: the masked sequence, same length as ``input_ids``.
    """
    masked = []
    eos_kept = False
    for token_id, is_real in zip(input_ids, attention_mask):
        if is_real:
            masked.append(token_id)
        elif not eos_kept:
            masked.append(token_id)
            eos_kept = True
        else:
            masked.append(ignore_index)
    return masked


def tokenization_caption(img_id, max_caption_length, ignore_index=-100):
    """Tokenize every caption of ``img_id`` into fixed-length decoder label sequences.

    Args:
        img_id: key into the global ``mapping`` (image id -> list of captions).
        max_caption_length: each sequence is truncated/padded to this many tokens.
        ignore_index: value written over padding positions after the first one,
            so the loss skips them (-100 is the value the Hugging Face
            encoder-decoder loss ignores). Pass ``None`` to get the raw padded ids.

    Returns:
        list[list[int]]: one label sequence per caption of ``img_id``.
    """
    encoded = tokenizer(mapping[img_id],
                        truncation=True,
                        padding="max_length",
                        max_length=max_caption_length)
    if ignore_index is None:
        return encoded.input_ids
    return [
        mask_label_padding(ids, mask, ignore_index)
        for ids, mask in zip(encoded.input_ids, encoded.attention_mask)
    ]

def feature_extraction(img_id, image_dir=None):
    """Load one image from disk and turn it into the encoder's ``pixel_values``.

    Args:
        img_id: image file name without extension (as stored in ``mapping``).
        image_dir: folder containing the ``.jpg`` files; defaults to the global
            ``directory``.

    Returns:
        numpy.ndarray: the preprocessed pixel values of this single image with
        the leading batch axis removed, so a DataLoader can stack samples into
        [batch, channels, height, width].
    """
    if image_dir is None:
        image_dir = directory
    path = os.path.join(image_dir, img_id + '.jpg')
    # The context manager closes the file handle. convert('RGB') handles
    # grayscale/RGBA images and loads the pixels before the file is closed.
    with Image.open(path) as image:
        rgb_image = image.convert('RGB')

    # Pass the raw PIL image to the checkpoint's own preprocessor, which resizes
    # and normalises it itself. Running the torchvision `transform` first would
    # normalise twice.
    encoder_input = feature_extractor(images=rgb_image, return_tensors="np")

    return encoder_input.pixel_values[0]


def get_model_input(img_id, max_caption_length):
    """Bundle one image's decoder labels and encoder pixel values into a dict.

    Returns a dict with keys ``'labels'`` (from ``tokenization_caption``) and
    ``'pixel_values'`` (from ``feature_extraction``).
    """
    labels = tokenization_caption(img_id, max_caption_length)
    pixel_values = feature_extraction(img_id)
    return {'labels': labels, 'pixel_values': pixel_values}

In [12]:
class CaptionDataset(Dataset):
    """Map-style dataset yielding ``(features, labels)`` for each image id.

    Args:
        data_keys: sequence of image ids; index ``i`` selects ``data_keys[i]``.
        max_caption: token length every caption is truncated/padded to.
    """

    def __init__(self, data_keys, max_caption):
        self.data_keys = data_keys
        self.max_caption = max_caption

    def __len__(self):
        return len(self.data_keys)

    def __getitem__(self, index):
        image_key = self.data_keys[index]
        # Image features are computed before the caption labels.
        return (
            feature_extraction(image_key),
            tokenization_caption(image_key, self.max_caption),
        )

In [13]:
max_length = 35  # token length each caption is truncated/padded to

# Split by image rather than by caption line. Deduplicating the ids (order
# preserved) ensures no image lands in both train and test and each image is
# a single sample.
unique_img_ids = list(dict.fromkeys(img_ids))
split = int(len(unique_img_ids) * 0.75)
train_ids = unique_img_ids[:split]
test_ids = unique_img_ids[split:]

In [14]:
batch_size = 32
train_dataset = CaptionDataset(train_ids, max_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Evaluation data is not shuffled, so batches and metrics are reproducible across runs.
test_dataset = CaptionDataset(test_ids, max_length)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [15]:
import dataclasses

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# and later dropped the old name. The install cell does not pin a version, so
# use whichever field the installed Seq2SeqTrainingArguments defines.
_training_arg_names = {field.name for field in dataclasses.fields(Seq2SeqTrainingArguments)}
_eval_strategy_key = "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,  # evaluate by generating captions with model.generate
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
    **{_eval_strategy_key: "epoch"},
)
