# Neural Image Caption
Original paper: https://arxiv.org/pdf/1411.4555.pdf

This is not an exact reproduction of the paper; it implements a similar pipeline with more up-to-date tools.

Basic idea: Input Image -> image encoder (a CNN in the paper; a pretrained MobileViT here) -> Image Embedding -> LSTM -> Image Caption

In [62]:
from transformers import MobileViTImageProcessor, MobileViTForImageClassification, AutoModel
import torch
from torchvision.transforms import v2
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from PIL import Image
import requests
import matplotlib.pyplot as plt
%matplotlib inline

# Workaround for matplotlib crashing the Jupyter kernel.
# Intel's OpenMP runtime (libiomp) aborts with "OMP: Error #15" when a second copy
# of the runtime is initialized in the same process (presumably one shipped with
# PyTorch and another with NumPy/MKL in this environment -- not verified).
# Setting KMP_DUPLICATE_LIB_OK=TRUE makes it continue instead of aborting. Intel's
# own error message calls this an unsafe, unsupported workaround that may crash or
# silently produce incorrect results; the real fix is an environment with a single
# OpenMP runtime. It only has an effect if set before that runtime initializes.
import os
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})



In [2]:
# Read a few sample images into one batch array

# Target (width, height) passed to PIL's resize
IMAGE_SIZE = (256, 256)

# Image paths
image_paths = ["cat1.png", "cat2.jpg", "river.jpg"]


def load_rgb_image(path, size=IMAGE_SIZE):
    """Open ``path``, convert it to 3-channel RGB and resize it to ``size``.

    Converting first matters: a PNG may be RGBA, grayscale or palette-based, and
    arrays with a different channel count would make ``np.stack`` below fail.
    The ``with`` block closes the file handle; ``convert`` returns an
    independent, fully loaded copy, so the result stays valid afterwards.
    """
    with Image.open(path) as img:
        return img.convert("RGB").resize(size)


# Open the images
images = [load_rgb_image(path) for path in image_paths]

# Convert the PIL images to NumPy arrays, each (H, W, 3)
image_arrays = [np.array(image) for image in images]

# np.stack adds a new leading batch axis, giving (N, H, W, 3)
combined_array = np.stack(image_arrays, axis=0)
print(combined_array.shape)

(3, 256, 256, 3)


In [3]:
# Some unrelated notes on Hugging Face that are good to remember
#
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
#
# MobileViTForImageClassification vs. AutoModel
# For our use case we need image embeddings, so AutoModel makes more sense: it outputs
# dense representations of the images rather than the logits that
# MobileViTForImageClassification would provide.
#
# This would be for actual image classification:
#
# model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")
# outputs = model(**inputs, output_hidden_states=True)
# logits = outputs.logits
#
# # model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])

# Pull the pre-trained MobileViT backbone and its image processor from Hugging Face
feature_extractor = MobileViTImageProcessor.from_pretrained("apple/mobilevit-xx-small")
model = AutoModel.from_pretrained("apple/mobilevit-xx-small")
model.eval()  # inference only in this cell; make the mode explicit

# Preprocess batch of images
# Images can be a PIL image, NumPy array, or torch tensor, individually or as a list
inputs = feature_extractor(images=combined_array, return_tensors="pt")

# Pure feature extraction: no_grad avoids building an autograd graph (and keeping
# activations alive) for a forward pass that is never backpropagated through.
with torch.no_grad():
    outputs = model(**inputs)
out = outputs.pooler_output  # backup plan: last_hidden_state for embeddings instead
print(out.shape)

torch.Size([3, 320])


In [6]:
# Sanity-check the embeddings: cat1 should be closer to cat2 than to the river photo.
# Rows of `out` follow image_paths: 0 = cat1, 1 = cat2, 2 = river.
with torch.no_grad():
    sim2 = torch.nn.functional.cosine_similarity(out[0], out[1], dim=0)
    sim1 = torch.nn.functional.cosine_similarity(out[0], out[2], dim=0)
print("Cat1 vs. River similarity:", sim1)
print("Cat1 vs. Cat2 similarity:", sim2)

Cat1 vs. River similarity: tensor(-0.0756)
Cat1 vs. Cat2 similarity: tensor(0.5102)


In [17]:
class ImageEncoder(torch.nn.Module):
    """Wrap a pretrained Hugging Face vision backbone and return pooled image embeddings.

    Args:
        model_name: Hub id (or local path) passed to ``AutoModel.from_pretrained``.
            Defaults to the MobileViT-XX-Small checkpoint used throughout this notebook.
    """

    def __init__(self, model_name="apple/mobilevit-xx-small"):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, images):
        """Return one pooled embedding per image.

        Args:
            images: either the mapping produced by the image processor, which is
                unpacked as keyword arguments into the backbone, or a ready
                ``pixel_values`` tensor (e.g. a batch from a DataLoader).

        Returns:
            The backbone's ``pooler_output``; for MobileViT-XX-Small this was
            observed to be shaped (batch, 320).
        """
        if isinstance(images, torch.Tensor):
            output = self.model(pixel_values=images)
        else:
            output = self.model(**images)

        # Pull embeddings
        return output.pooler_output
    
encoder = ImageEncoder()

# Shape check only: run in eval mode without autograd. A train-mode forward pass
# would update the pretrained BatchNorm running statistics (MobileViT's conv
# blocks use BatchNorm) using this tiny 3-image test batch, and would build an
# unused autograd graph.
encoder.eval()
with torch.no_grad():
    features = encoder(inputs)

# Leave the encoder in training mode, as this cell originally did
# (presumably for fine-tuning later).
encoder.train()
print(features.shape)

In [43]:
### Preparing input data ###
# Normalize captions in three steps:
#   1. lowercase;
#   2. drop every character except letters, digits, hyphens, apostrophes and spaces;
#   3. collapse the whitespace runs the removed punctuation leaves behind, and trim
#      (e.g. "A dog , running ." -> "a dog running").
df = pd.read_csv("captions.txt")
df['caption'] = (
    df['caption']
    .str.lower()
    .str.replace(r"[^a-zA-Z0-9-' ]", '', regex=True)
    .str.replace(r"\s+", ' ', regex=True)
    .str.strip()
)

In [18]:
# Create dataset class
class FlickrDataset(Dataset):
    """Image/caption pairs backed by a dataframe of (image filename, caption) rows.

    Args:
        image_dir: directory that the image filenames in the dataframe are relative to.
        dataframe: table whose first column is the image filename and whose second
            column is the caption.
        transform: optional callable applied to each RGB PIL image; when None the
            PIL image is returned unchanged.
    """

    def __init__(self, image_dir, dataframe, transform=None):
        self.image_dir = image_dir
        self.dataframe = dataframe
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for row ``idx``."""
        row = self.dataframe.iloc[idx]
        # Positional access via .iloc: plain row[0] on a Series with string labels
        # relies on a fallback that is deprecated in pandas 2.x.
        image_file = row.iloc[0]
        caption = row.iloc[1]
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy.
        with Image.open(os.path.join(self.image_dir, image_file)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, caption

    def __len__(self):
        """Number of image/caption rows."""
        return len(self.dataframe)

torch.Size([3, 320])


In [64]:
# For later: per-image transform to hand to FlickrDataset.
# Per the torchvision docs, PILToTensor turns a PIL image into a (C, H, W) tensor
# of the same dtype without rescaling values.
# TODO(review): the encoder above was fed MobileViTImageProcessor output, not raw
# PIL tensors -- confirm whether a matching resize / dtype / normalization step
# is needed here before training.
transform = v2.Compose(transforms=[
    v2.PILToTensor(),
])