# Neural Image Caption
Original paper: https://arxiv.org/pdf/1411.4555.pdf

Not exactly the same but implementing something similar with more up to date tools

Basic idea: Input Image -> CNN -> Image Embedding -> LSTM -> Image Caption

In [97]:
from transformers import MobileViTImageProcessor, MobileViTForImageClassification, AutoModel, AutoConfig
from sentence_transformers import SentenceTransformer
import torch
from torchvision.transforms import v2
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from PIL import Image
import requests
import matplotlib.pyplot as plt
%matplotlib inline

# Fixes matplotlib crashing jupyter kernel issue (should find out what this actually does, something to do with OpenMP?)
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [4]:
# Read images into array

# Image paths
image_paths = ["cat1.png", "cat2.jpg", "river.jpg"]

# Open the images
images = [Image.open(path).resize((256,256)) for path in image_paths]

# Convert the PIL images to NumPy arrays
image_arrays = [np.array(image) for image in images]

# Combine the arrays into a single array (stack vertically)
combined_array = np.stack(image_arrays, axis=0)    
print(combined_array.shape)

(3, 256, 256, 3)


In [121]:
# Some unrelated notes on hugging face that's good to remember
#
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
#
# MobileViTForImageClassification vs. AutoModel
# For our use case we need to obtain image embeddings so using AutoModel makes more sense as it outputs the dense 
# representations of the images and not the logits, which are what MobileViTForImageClassification would have provided
#
# This would be for actual image classification:
#
# model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")
# outputs = model(**inputs, output_hidden_states=True)
# logits = outputs.logits
#
# # model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])

# Pulling pre-trained image classification model from hugging face
feature_extractor = MobileViTImageProcessor.from_pretrained("apple/mobilevit-xx-small")
model = AutoModel.from_pretrained("apple/mobilevit-xx-small")

# Preprocess batch of images
# Images can be PIL image, numpy array, or torch tensor individually or in list form
inputs = feature_extractor(images=combined_array, return_tensors="pt")
outputs = model(**inputs)
out = outputs.pooler_output # backup plan: last_hidden_state for embeddings instead
print(out.shape)

torch.Size([3, 320])


In [6]:
# Testing embeddings
with torch.no_grad():
    sim1 = torch.nn.functional.cosine_similarity(out[0], out[2], dim=0)
    sim2 = torch.nn.functional.cosine_similarity(out[0], out[1], dim=0)
    print("Cat1 vs. River similarity:", sim1)
    print("Cat1 vs. Cat2 similarity:", sim2)

Cat1 vs. River similarity: tensor(-0.0756)
Cat1 vs. Cat2 similarity: tensor(0.5102)


In [102]:
class ImageEncoder(torch.nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        # Load pre-trained model
        self.model = AutoModel.from_pretrained("apple/mobilevit-xx-small")
        
        # Freeze layers of pre-trained model
        for param in self.model.parameters():
            param.requires_grad = False
        
        # 320 is the size the pooler_output that I'm using as the last layer of the pre-trained.
        # Passed through this liner layer to get to the same size as the word embeddings so I can pass it to the LSTM.
        # Future: Maybe look into using different layer from pre-trained model as the output
        self.linear = torch.nn.Linear(320, 384)
        
    def forward(self, images):
        output = self.model(**images)
        
        # Pull image embeddings
        embeddings = output.pooler_output
        
        # Pass image embeddings through linear layer to reach desired size of LSTM input
        image_embeddings = self.linear(embeddings)
        
        return image_embeddings


In [103]:
# ImageEncoder test
encoder = ImageEncoder()
encoder.train()
features = encoder(inputs)
print(features.shape)

torch.Size([3, 384])


In [39]:
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/build_vocab.py
import nltk
import pickle
from collections import Counter

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

def build_vocab(dataframe, threshold=4):
    """Build a simple vocabulary wrapper."""
    counter = Counter()

    for caption in dataframe['caption']:
        tokens = nltk.tokenize.word_tokenize(caption)
        counter.update(tokens)

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab


vocab = build_vocab(df, threshold=4)
vocab_path = "./vocab.pkl"
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved the vocabulary wrapper to '{}'".format(vocab_path))

Total vocabulary size: 3435
Saved the vocabulary wrapper to './vocab.pkl'


In [109]:
# Create dataset class
class FlickrDataset(Dataset):
    
    def __init__(self, image_dir, dataframe, vocab, transform):
        
        self.image_dir = image_dir
        self.dataframe = dataframe
        self.vocab = vocab
        self.transform = transform
    
    def __getitem__(self, idx):
        
        row = self.dataframe.iloc[idx]
        image_file = row.iloc[0]
        caption = row.iloc[1]
        
        # Convert caption (string) to word ids
        tokens = nltk.tokenize.word_tokenize(caption)
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        
        image = Image.open(os.path.join(self.image_dir, image_file)).convert('RGB')
        image = self.transform(image)
            
        return image, target
    
    def viewImage(self, idx):
    
        row = self.dataframe.iloc[idx]
        image_file = row.iloc[0]
        caption = row.iloc[1]
        image = Image.open(os.path.join(self.image_dir, image_file)).convert('RGB')
            
        return image, caption
    
    def __len__(self):
        return len(self.dataframe)
    

In [112]:
# FlickrDataset test
df = pd.read_csv("captions.txt")
df['caption'] = df['caption'].str.lower()
df['caption'] = df['caption'].str.replace(r"[^a-zA-Z0-9-' ]", '', regex=True)

transform = v2.Compose([ 
    v2.PILToTensor() 
])

# Define embedding model
sent_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-V2')

with open("./vocab.pkl", 'rb') as f:
    vocab = pickle.load(f)

data = FlickrDataset("./images", df, vocab, transform)

In [114]:
im, cap = data[0]
print(cap)

tensor([ 1.,  4.,  5.,  6.,  4.,  7.,  8.,  9., 10., 11.,  4., 12., 13., 14.,
         6., 15.,  3., 16.,  2.])


In [116]:
%%time
# Generate embeddings for vocab
sent_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-V2')
embeddings_lst = []

def embedding(texts):
    embs = sent_model.encode(texts)
    return embs

for i in range(len(vocab)):
    word = vocab.idx2word[i]
    embeddings_lst.append(embedding(word))
    if i % 500 == 0:
        print(i)

0
500
1000
1500
2000
2500
3000
CPU times: total: 1min 30s
Wall time: 45.7 s


In [118]:
np_arr = np.array(embeddings_lst)
final_embeddings = torch.Tensor(np_arr)
print(final_embeddings.shape)

torch.Size([3435, 384])


In [120]:
torch.save(final_embeddings, "embeddings_table.pt")

In [None]:
class Decoder(torch.nn.Module):
    def __init__(self, word_embeddings):
        super(Decoder, self).__init__()
        # Load pre-trained embeddings
        self.embedding = torch.nn.Embedding.from_pretrained(word_embeddings, freeze=True) # Future: Try unfreezing?
        
        
    def forward(self, images):
        