# Imports

In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from glob import glob

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from torch.utils.data import DataLoader


import cv2
from pathlib import Path
import random
import torch.nn as nn

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Path of the captions text file that is read when building the captions dictionary.
captions_file: str = 'captions.txt'
# Directory that holds the image files.
images_folder: str = 'Images'

Plot parameters

In [3]:
# Default font size for text in every matplotlib figure created below.
plt.rcParams['font.size'] = 12

# Image Captioning project

Objective - Generate natural language descriptions for images

Type of models used:

    - ResNet/AlexNet/VGG16 for extracting the image features

    - ... for Text Encoding
    
    - LSTMs for text generation

Data: Flickr 8k Dataset

# Exploring the data

In [4]:
def extract_captions(image_name, captions_text):
    """Return every caption recorded for ``image_name`` in ``captions_text``.

    Each line is read as ``<image_name>,<caption>``. Only the FIRST comma
    separates the two fields, so captions that themselves contain commas are
    kept intact instead of being silently dropped.

    Args:
        image_name: File name to look up; compared exactly against the first field.
        captions_text: Full text of the captions file, one record per line.

    Returns:
        List of caption strings in file order; empty when nothing matches.
    """
    captions = []
    for line in captions_text.splitlines():
        name, sep, caption = line.partition(',')
        # An empty separator means the line has no comma at all: not a record.
        if sep and name == image_name:
            captions.append(caption)

    return captions

def build_captions_dict(captions_text):
    """Group the captions in ``captions_text`` by image name.

    Each line is read as ``<image_name>,<caption>``. Only the FIRST comma
    separates the two fields, so captions containing commas are preserved.
    Lines without any comma are ignored. No line is treated as a header: a
    first line such as ``image,caption`` is stored like any other record.

    Args:
        captions_text: Full text of the captions file, one record per line.

    Returns:
        Dict mapping image name to the list of its captions, in file order.
    """
    captions_dict = {}
    for line in captions_text.splitlines():
        name, sep, caption = line.partition(',')
        if not sep:
            # No comma on this line, so there is no caption field to store.
            continue
        captions_dict.setdefault(name, []).append(caption)
    return captions_dict

In [8]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_images(folder_path, captions_dict, num_images_to_display=6, num_images_per_row=3, max_caption_length=40):
    """Show a random sample of images from ``folder_path`` with one caption each.

    Files ending in ``.jpg`` or ``.png`` are shuffled with NumPy's global
    random state, the first ``num_images_to_display`` are drawn in a grid, and
    the first caption found for each file is written underneath it.

    Args:
        folder_path: Directory to read image files from.
        captions_dict: Mapping from file name to a list of captions; files
            without an entry (or with an empty list) are labelled "No caption".
        num_images_to_display: Maximum number of images to draw.
        num_images_per_row: Number of grid columns.
        max_caption_length: Maximum characters per caption line; longer
            captions are wrapped at word boundaries.

    Returns:
        None. The figure is displayed with ``plt.show()``. If the folder holds
        no matching image, a message is printed and nothing is drawn.
    """
    import textwrap  # local import: only this function needs it

    file_names = [f for f in os.listdir(folder_path) if f.endswith(('.jpg', '.png'))]

    np.random.shuffle(file_names)
    file_names = file_names[:num_images_to_display]

    num_images = len(file_names)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows, so stop here.
        print(f"No .jpg/.png images found in {folder_path}")
        return

    # Ceiling division: enough rows to hold every selected image.
    num_rows = (num_images + num_images_per_row - 1) // num_images_per_row

    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(num_rows, num_images_per_row, figsize=(15, num_rows * 5), squeeze=False)
    axes = axes.flatten()

    for ax, file_name in zip(axes, file_names):
        img = plt.imread(os.path.join(folder_path, file_name))
        ax.imshow(img)
        ax.axis('off')

        captions = captions_dict.get(file_name) or ["No caption"]
        wrapped = textwrap.fill(captions[0], width=max(1, max_caption_length))
        ax.text(0.5, -0.1, wrapped, transform=ax.transAxes, fontsize=10, ha='center')

    # Remove the unused sub-plots at the end of the last row.
    for extra_ax in axes[num_images:]:
        fig.delaxes(extra_ax)

    plt.tight_layout()
    plt.show()


In [9]:
# Read the whole captions file in one go. The encoding is given explicitly so
# decoding does not depend on the platform's default locale encoding.
with open(captions_file, "r", encoding="utf-8") as file:
    captions_text = file.read()
# Group the captions by image file name.
captions_dict = build_captions_dict(captions_text)

In [10]:
# Draw up to 9 randomly chosen images in a 3-per-row grid, each with its first caption.
visualize_images(folder_path=images_folder, captions_dict=captions_dict, num_images_to_display=9, num_images_per_row=3)

# The notebook consists of three main sections, mirroring the structure of the code.

# ----------------------------------------------------------------
# 1. Feature extraction from images

We will start with applying some transformations to the input images.

Precomputed mean and standard deviation (one value per RGB channel):

Mean: [0.45802852 0.4460975  0.40391668]

Standard Deviation: [0.24219123 0.2332004  0.23719894]

In [None]:
# Normalisation statistics, one value per colour channel of the RGB input.
mean = [0.45802852, 0.4460975, 0.40391668]
std = [0.24219123, 0.2332004, 0.23719894]

# Preprocessing pipeline; the steps run in the order listed:
#   1. resize to the 224 x 224 input size used with the ResNet model,
#   2. convert the PIL image to a tensor,
#   3. normalise: image = (image - mean) / std.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

Visualising the applied transformation

In [None]:
# Pick one .jpg from the image folder at random and load it as RGB.
file_names = [name for name in os.listdir(images_folder) if name.endswith('.jpg')]
random_file = random.choice(file_names)

image_path = os.path.join(images_folder, random_file)
original_image = Image.open(image_path).convert('RGB')

# Apply the preprocessing pipeline, then clip the values to [0, 1] before plotting.
normalized_image = preprocess(original_image).clamp(0, 1)

In [None]:
fig, axs = plt.subplots(1, 2)

# One (image, title) pair per panel, left to right. permute(1, 2, 0) moves the
# tensor's first axis to the end before it is handed to imshow.
panels = [
    (original_image, "Original Image"),
    (normalized_image.permute(1, 2, 0), "Processed Image"),
]
for ax, (panel_image, panel_title) in zip(axs, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis('off')

plt.tight_layout()
plt.show()

### Extracting features from the preprocessed images

Using pretrained weights seems like a good idea for now, especially since they come from ImageNet, a large and diverse dataset that contains a wide variety of images.

In [None]:
# Load the feature extraction model

def load_model(model_name='resnet18'):
    """Load a pretrained torchvision model and put it in evaluation mode.

    The constructor is looked up by name on ``torchvision.models`` (imported
    as ``models``) and called with ``pretrained=True``, which requests the
    library's pretrained weights: for ResNet-18, by default, weights trained
    on the ImageNet dataset.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``.

    Returns:
        The model, moved to CUDA when available (CPU otherwise), in eval mode.

    Raises:
        ValueError: If ``model_name`` is not a callable attribute of
            ``torchvision.models``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A missing attribute makes getattr raise AttributeError, not KeyError, so
    # a None default is used and checked explicitly to report unknown names.
    model_func = getattr(models, model_name, None)
    if not callable(model_func):
        raise ValueError(f'{model_name} is not a valid model name.')

    model = model_func(pretrained=True)
    model = model.to(device)
    model.eval()

    return model

# Load the pretrained ResNet-18 into the module-level name `model`.
model = load_model('resnet18')

In [None]:
def extract_image_features(images_path, model_name, batch_size=64):
    """Run every ``.jpg`` in ``images_path`` through a pretrained model.

    Images are opened, converted to RGB, passed through ``preprocess`` and fed
    to the model in batches with gradients disabled. Files that cannot be
    loaded are reported and skipped.

    Args:
        images_path: Directory containing the image files.
        model_name: Model name forwarded to ``load_model``.
        batch_size: Number of images per forward pass.

    Returns:
        Dict mapping image file name to a flattened 1-D NumPy array holding
        the model output for that image.
    """
    # Load the model once; reloading it for every batch only repeats work.
    model = load_model(model_name)
    # Put the inputs wherever the model's weights are, so the two always
    # share a device regardless of which backend is available.
    device = next(model.parameters()).device

    image_files = [name for name in os.listdir(images_path) if name.endswith('.jpg')]

    features = {}
    for start in tqdm(range(0, len(image_files), batch_size)):
        loaded_names = []
        tensors = []
        for img_name in image_files[start:start + batch_size]:
            try:
                with Image.open(os.path.join(images_path, img_name)) as img:
                    tensor = preprocess(img.convert('RGB'))
            except Exception as e:
                print(f"Skipping image {img_name} due to error: {e}")
            else:
                # Names are recorded only for images that loaded, keeping
                # names and tensors aligned when a file is skipped.
                loaded_names.append(img_name)
                tensors.append(tensor)

        if not tensors:
            # Every file in this slice failed to load; nothing to run.
            continue

        batch_images = torch.stack(tensors).to(device)
        with torch.no_grad():
            # NOTE(review): this stores the output of the complete model as
            # returned by load_model; confirm that is the intended "feature".
            batch_features = model(batch_images)

        for img_name, feature in zip(loaded_names, batch_features.cpu()):
            features[img_name] = torch.flatten(feature).numpy()

    return features

In [None]:
# Compute the ResNet-18 outputs for the images in the folder; the result is keyed by file name.
features = extract_image_features(images_folder, model_name = 'resnet18')
