In [None]:
import os
import re
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import random
import string
from pickle import dump, load
import matplotlib.pyplot as plt
import textwrap
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import (Input, Dense, Reshape, Embedding, Concatenate,
                                     Dropout, LayerNormalization, GlobalAveragePooling1D , Layer, MultiHeadAttention)
from tensorflow.keras.utils import to_categorical, Sequence, pad_sequences,get_file
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Lambda
import tensorflow_hub as hub
from tensorflow.keras.applications.resnet import ResNet50, preprocess_input
from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras.utils import plot_model
from tensorflow.keras.layers import add
from tensorflow.keras.models import Model, load_model
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from Custom_layer_model import Transformer_decoder, PositionalEmbedding, Masked_Loss, Expand_Dimension

In [None]:
# Dataset locations, relative to the notebook's working directory.
image_path = "flickr30k/flickr30k_images"
caption = pd.read_csv("flickr30k/captions.txt")

# TF-Hub ViT module used as a frozen image encoder (trainable=False);
# it is fed 224x224 RGB inputs.
vit_url = "https://tfhub.dev/sayakpaul/vit_b16_fe/1"
vit_model = hub.KerasLayer(
    vit_url,
    input_shape=(224, 224, 3),
    trainable=False,
)

In [None]:
def clean_caption(caption):
    """Normalise a raw caption and wrap it in ``startseq`` / ``endseq`` markers.

    The text is lower-cased, every character other than ``a-z`` and
    whitespace is removed, and runs of whitespace are collapsed to a single
    space. A caption that already starts with ``startseq`` and ends with
    ``endseq`` has those markers peeled off first, so cleaning an already
    cleaned caption does not stack the markers.

    Args:
        caption: Raw caption string.

    Returns:
        The string ``"startseq <cleaned text> endseq"``.
    """
    text = caption.lower().strip()
    already_wrapped = text.startswith("startseq") and text.endswith("endseq")
    body = text[len("startseq"):-len("endseq")].strip() if already_wrapped else text

    letters_only = re.sub(r'[^a-z\s]', '', body)
    collapsed = re.sub(r'\s+', ' ', letters_only).strip()
    return "startseq " + collapsed + " endseq"

In [None]:
# Normalise every caption element-wise with clean_caption and store the
# result back in the same column.
caption['comment'] = caption['comment'].map(clean_caption)

## Tokenization

In [None]:
# Build the word-level vocabulary from every cleaned caption.
captions_list = caption['comment'].tolist()

tokenizer = Tokenizer()
tokenizer.fit_on_texts(captions_list)

# Persist the fitted tokenizer so later cells / sessions can reload it.
with open("tokenizer.pkl", "wb") as tokenizer_file:
    dump(tokenizer, tokenizer_file)

In [None]:
# Reload the fitted tokenizer from disk. The context manager closes the file
# handle, which a bare open() passed straight to load() would leave open.
# NOTE: only unpickle files produced by this notebook - pickle can execute
# arbitrary code from an untrusted file.
with open("tokenizer.pkl", "rb") as tokenizer_file:
    tokenizer = load(tokenizer_file)

In [None]:
# +1 on top of the number of known words (presumably to reserve index 0
# for padding - confirm against the Tokenizer's indexing).
vocab_size = 1 + len(tokenizer.word_index)

# Sequence length is the longest caption in whitespace-separated tokens,
# capped at 40.
longest_caption = max(len(text.split()) for text in captions_list)
max_length = min(40, longest_caption)

print(f"Vocab size: {vocab_size} \n Max Caption length: {max_length}")

In [None]:
# Split on image ids rather than caption rows, so all captions of one image
# end up in the same partition.
image_ids = caption['image_name'].unique().tolist()
train_ids, test_ids = train_test_split(image_ids, test_size=0.1, random_state=42)
train_ids, val_ids = train_test_split(train_ids, test_size=0.15, random_state=42)

# Materialise one caption dataframe per partition.
train_df, val_df, test_df = (
    caption[caption['image_name'].isin(ids)].reset_index(drop=True)
    for ids in (train_ids, val_ids, test_ids)
)

## Feature Extraction

In [None]:
def extract_features(df, image_dir):
    """Run every distinct image referenced in ``df`` through ``vit_model``.

    Each image is loaded at 224x224, converted to an array, divided by 255.0
    and passed to the module-level ``vit_model`` one image at a time.

    Args:
        df: DataFrame with an ``image_name`` column of file names.
        image_dir: Directory that contains the image files.

    Returns:
        dict mapping image file name to the first (only) row of the model
        output for that image.
    """
    # NOTE(review): pixels are scaled to [0, 1] here; confirm this matches
    # the input range expected by the TF-Hub module behind vit_model.
    features = {}
    unique_names = df['image_name'].unique()
    for img_name in tqdm(unique_names):
        picture = keras.preprocessing.image.load_img(
            os.path.join(image_dir, img_name), target_size=(224,224)
        )
        pixels = keras.preprocessing.image.img_to_array(picture) / 255.0
        batch_of_one = np.expand_dims(pixels, 0)
        # Drop the leading batch axis before storing the vector.
        features[img_name] = vit_model(batch_of_one).numpy()[0]
    return features

In [None]:
# Extract one feature vector per image for the whole caption table.
features = extract_features(caption, image_dir=image_path)

# Cache the features on disk. The context manager guarantees the file is
# flushed and closed, which a bare open() passed to dump() does not.
with open("features.pkl", "wb") as features_file:
    dump(features, features_file)


In [None]:
# Reload the cached features; the context manager closes the file handle.
# NOTE: only unpickle files produced by this notebook - pickle can execute
# arbitrary code from an untrusted file.
with open("features.pkl", "rb") as features_file:
    features = load(features_file)

# Remove any singleton axes from each stored vector.
features = {name: np.squeeze(vector) for name, vector in features.items()}

## Data Generator

In [None]:
class CustomDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``((image_features, caption_in), caption_out)`` batches.

    Every caption is tokenised, cut to at most ``max_length`` tokens and
    post-padded to exactly ``max_length``. The decoder input is that sequence
    without its last position and the target is the same sequence without its
    first position, i.e. the target is the input shifted left by one token.
    """

    def __init__(self, df, features, tokenizer, max_length, batch_size=32, shuffle=True,**kwargs):
        """Store the data sources and prepare the first epoch's row order.

        Args:
            df: DataFrame with ``image_name`` and ``comment`` columns.
            features: Mapping from image name to its feature vector.
            tokenizer: Object exposing ``texts_to_sequences``.
            max_length: Padded token length of every caption.
            batch_size: Number of caption rows per batch.
            shuffle: Whether to reshuffle the row order at each epoch end.
            **kwargs: Forwarded to ``keras.utils.Sequence.__init__``.
        """
        super().__init__(**kwargs)
        # Work on a private copy with a clean 0..n-1 index.
        self.df = df.copy().reset_index(drop=True)
        self.features, self.tokenizer = features, tokenizer
        self.max_length, self.batch_size = max_length, batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.df))
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last one may be short)."""
        batches = np.ceil(len(self.df) / self.batch_size)
        return int(batches)

    def on_epoch_end(self):
        """Reshuffle the row order in place when shuffling is enabled."""
        if not self.shuffle:
            return
        np.random.shuffle(self.indexes)

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Returns:
            ``((x_img, x_seq), y)`` where ``x_img`` is float32 and ``x_seq`` /
            ``y`` are int32 arrays with ``max_length - 1`` positions per row.
        """
        start = idx * self.batch_size
        rows = self.df.iloc[self.indexes[start:start + self.batch_size]]

        image_vectors, decoder_inputs, decoder_targets = [], [], []
        for name, text in zip(rows['image_name'], rows['comment']):
            # Captions longer than max_length lose their trailing tokens here.
            token_ids = self.tokenizer.texts_to_sequences([text])[0][:self.max_length]
            padded = pad_sequences([token_ids], maxlen=self.max_length, padding='post')[0]

            image_vectors.append(self.features[name])
            decoder_inputs.append(padded[:-1])   # caption input
            decoder_targets.append(padded[1:])   # target output

        inputs = (
            np.array(image_vectors, dtype=np.float32),
            np.array(decoder_inputs, dtype=np.int32),
        )
        return inputs, np.array(decoder_targets, dtype=np.int32)




In [None]:
# Both generators share the same feature store, tokenizer and sequence
# length; only the training one reshuffles its rows between epochs.
generator_args = (features, tokenizer, max_length)
train_generator = CustomDataGenerator(train_df, *generator_args, batch_size=32, shuffle=True)
val_generator = CustomDataGenerator(val_df, *generator_args, batch_size=32, shuffle=False)

## Model

In [None]:
def define_model(
    vocab_size,
    max_length,
    emb_dimension=256,
    ff_dimension=512,
    num_heads=8,
    num_layers=4,
):
    """Build the captioning model: image features + partial caption -> decoder output.

    Args:
        vocab_size: Vocabulary size passed to the decoder.
        max_length: Padded caption length; the model consumes
            ``max_length - 1`` tokens because inputs and targets are the
            caption shifted by one position.
        emb_dimension: Embedding width passed to the decoder as ``embed_dim``.
        ff_dimension: Feed-forward width passed to the decoder as ``ff_dim``.
        num_heads: Number of attention heads passed to the decoder.
        num_layers: Number of decoder layers.

    Returns:
        An uncompiled ``keras.Model`` with inputs ``[image_features,
        caption_input]``. The model summary is printed as a side effect.
    """
    img_input = Input(shape=(768,), name="image_features")
    cap_input = Input(shape=(max_length - 1,), name="caption_input")

    # Custom layer from Custom_layer_model; presumably adds a sequence axis
    # so the decoder can attend over the image vector - confirm there.
    img_emb = Expand_Dimension(name="image_context")(img_input)

    # Pass the function's own `emb_dimension` parameter; the name
    # `embed_dimension` is not defined anywhere and raises NameError.
    decoder = Transformer_decoder(
        embed_dim=emb_dimension,
        ff_dim=ff_dimension,
        num_heads=num_heads,
        vocab_size=vocab_size,
        max_len=max_length - 1,
        num_layers=num_layers,
        rate=0.1,
    )

    outputs = decoder(cap_input, img_emb)

    model = Model(inputs=[img_input, cap_input], outputs=outputs)
    model.summary()
    return model
  


In [None]:
# Build the network with the default architecture hyper-parameters.
model = define_model(vocab_size=vocab_size, max_length=max_length)

# Adam at 3e-4 with the project's masked loss.
optimizer = keras.optimizers.Adam(learning_rate=3e-4)
model.compile(optimizer=optimizer, loss=Masked_Loss)

In [None]:
# Render the architecture diagram to caption_model.png, including tensor
# shapes, layer names and the contents of nested models.
plot_options = dict(
    to_file="caption_model.png",
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True,
)
plot_model(model, **plot_options)


## Training

In [None]:
# Keep only the weights with the lowest validation loss seen so far.
check_point = ModelCheckpoint("caption_model.keras", monitor="val_loss",
                              save_best_only=True, verbose=1)
# Stop after 5 epochs without validation improvement and roll back to the
# best weights.
early_stopping = EarlyStopping(monitor="val_loss", patience=5,
                               restore_best_weights=True, verbose=1)
# Halve the learning rate after 3 epochs without validation improvement.
reducelr = ReduceLROnPlateau(monitor="val_loss", factor=0.5,
                             patience=3, verbose=1)

# The list must be closed with `]`; leaving it open is a SyntaxError.
callbacks = [
    check_point,
    reducelr,
    early_stopping,
]

In [None]:
# Train for at most 50 epochs; the callbacks checkpoint the best model,
# reduce the learning rate and stop early based on validation loss.
fit_options = dict(
    epochs=50,
    validation_data=val_generator,
    callbacks=callbacks,
    verbose=1,
)
model.fit(train_generator, **fit_options)


## Testing

In [None]:
# Reload the best checkpoint written during training. The path must match the
# ModelCheckpoint target ("caption_model.keras"), and the custom layers must
# be referenced by the names actually imported from Custom_layer_model
# (`PositionalEmbedding`, not `Positional_Embedding`). `Expand_Dimension` is
# also used by the model, so it is registered as well.
# NOTE(review): the dictionary keys are assumed to equal the class names
# defined in Custom_layer_model - confirm against that module.
model = load_model(
    "caption_model.keras",
    custom_objects={
        "Transformer_decoder": Transformer_decoder,
        "PositionalEmbedding": PositionalEmbedding,
        "Expand_Dimension": Expand_Dimension,
        "Masked_Loss": Masked_Loss,
    },
)

In [None]:
def generate_caption(image_path, model, tokenizer, max_length):
    """Greedily decode a caption for the image at ``image_path``.

    The image is loaded at 224x224, divided by 255.0 and encoded with the
    module-level ``vit_model``. Starting from ``"startseq"``, the caption so
    far is tokenised and post-padded to ``max_length``, the model is run, and
    the highest-scoring word at the position of the last generated word is
    appended. Decoding stops after ``max_length`` steps, on ``endseq``, or
    when the predicted index has no entry in ``tokenizer.index_word``.

    Args:
        image_path: Path of the image file.
        model: Captioning model taking ``[image_features, token_ids]``.
        tokenizer: Fitted tokenizer with ``texts_to_sequences`` and
            ``index_word``.
        max_length: Token length of the model's caption input.

    Returns:
        The generated caption, still prefixed with ``startseq`` and without a
        trailing ``endseq``.
    """
    picture = keras.preprocessing.image.load_img(image_path, target_size=(224,224))
    pixels = keras.preprocessing.image.img_to_array(picture) / 255.0
    feature = vit_model(np.expand_dims(pixels, 0))[0].numpy()
    # The image input is the same at every decoding step.
    image_input = feature.reshape(1, -1)

    caption = "startseq"
    for _ in range(max_length):
        token_ids = tokenizer.texts_to_sequences([caption])[0]
        decoder_input = pad_sequences([token_ids], maxlen=max_length, padding='post')
        y_pred = model.predict([image_input, decoder_input], verbose=0)

        # Read the prediction made at the position of the last word so far.
        last_position = len(caption.split()) - 1
        next_index = np.argmax(y_pred[0, last_position])
        next_word = tokenizer.index_word.get(next_index, '')
        if next_word in ('', 'endseq'):
            break
        caption = caption + ' ' + next_word
    return caption

In [None]:
def show_prediction(
    image_path,
    model,
    tokenizer,
    max_length,
    dataframe,
    save_image = False
):
    """Display an image with one reference caption and the model's prediction.

    Args:
        image_path: Path of the image file; its base name is looked up in
            ``dataframe['image_name']``.
        model: Captioning model forwarded to ``generate_caption``.
        tokenizer: Fitted tokenizer forwarded to ``generate_caption``.
        max_length: Caption input length forwarded to ``generate_caption``.
        dataframe: DataFrame with ``image_name`` and ``comment`` columns
            holding the reference captions.
        save_image: When true, also save the figure as
            ``"<image file name>_caption.png"`` in the working directory.

    Returns:
        None. The figure is shown (and optionally saved) as a side effect.
    """

    def _strip_markers(text):
        # Remove the sequence markers added during caption cleaning.
        return text.replace("startseq", "").replace("endseq", "").strip()

    image_id = os.path.basename(image_path)

    # Generate caption
    pred_caption = _strip_markers(
        generate_caption(image_path, model, tokenizer, max_length)
    )

    # The second reference caption is preferred when it exists; fall back to
    # the first one, or to a placeholder, instead of raising IndexError when
    # the image has fewer references in `dataframe`.
    references = dataframe.loc[dataframe["image_name"] == image_id, "comment"]
    if len(references) > 1:
        true_caption = _strip_markers(references.iloc[1])
    elif len(references) == 1:
        true_caption = _strip_markers(references.iloc[0])
    else:
        true_caption = "(no reference caption found)"

    plt.figure(figsize=(8, 8))
    # Close the image file once matplotlib has taken its copy of the pixels.
    with Image.open(image_path) as img:
        plt.imshow(img)
    plt.axis("off")

    title_text = (
         "\n\nActual: \n"
        +"\n".join(textwrap.wrap(true_caption, 60))
        +"\n"
        +"Predicted:\n"
        + "\n".join(textwrap.wrap(pred_caption, 60))
    )

    plt.title(title_text, fontsize=11)
    if save_image:
        save_path = f"{image_id}_caption.png"
        plt.savefig(save_path, bbox_inches="tight", dpi=200)
        print(f"Image saved at: {save_path}")
    plt.show()


In [None]:
# Build the path from the configured image directory so it agrees with the
# location used for feature extraction, and pass the test split dataframe
# (`test_df`); the name `test` is never defined in this notebook.
# NOTE(review): the reference caption is looked up in test_df, so this image
# is expected to belong to the test split - confirm, or pass `caption`.
image_url = os.path.join(image_path, "1007205537.jpg")
show_prediction(
    image_url,
    model,
    tokenizer,
    max_length - 1,  # the model's caption input is one token shorter
    test_df,
    save_image = True
)

In [None]:

def clean_tokens(tokens):
    """Return ``tokens`` without sequence markers and the padding placeholder.

    Args:
        tokens: Iterable of word tokens.

    Returns:
        A new list with every ``"startseq"``, ``"endseq"`` and ``"<pad>"``
        entry removed; the order of the remaining tokens is preserved.
    """
    special_tokens = frozenset(("startseq", "endseq", "<pad>"))
    kept = []
    for token in tokens:
        if token in special_tokens:
            continue
        kept.append(token)
    return kept

# Evaluate on a reproducible random sample of at most 500 test images.
# The test split is `test_df`; the name `test` is never defined.
test_images = test_df['image_name'].unique().tolist()
random.seed(42)
sample_images = random.sample(test_images, min(500, len(test_images)))

refs, hyps = [], []
smooth = SmoothingFunction().method1

for img_id in tqdm(sample_images):

    # Use the configured image directory so the path agrees with the one
    # used for feature extraction.
    img_loc = os.path.join(image_path, img_id)
    pred_caption = generate_caption(
        img_loc,
        model,
        tokenizer,
        max_length - 1  # the model's caption input is one token shorter
    )

    hyp_tokens = clean_tokens(pred_caption.split())
    hyps.append(hyp_tokens)

    # All reference captions of this image, tokenised the same way.
    true_caps = test_df.loc[test_df['image_name'] == img_id, 'comment'].tolist()
    ref_tokens = [clean_tokens(cap.split()) for cap in true_caps]

    refs.append(ref_tokens)

# Uniform n-gram weights for BLEU-1 .. BLEU-4; BLEU-3 uses exact thirds so
# the weights sum to 1.
print("BLEU-1:", corpus_bleu(refs, hyps, weights=(1,0,0,0), smoothing_function=smooth))
print("BLEU-2:", corpus_bleu(refs, hyps, weights=(0.5,0.5,0,0), smoothing_function=smooth))
print("BLEU-3:", corpus_bleu(refs, hyps, weights=(1/3,1/3,1/3,0), smoothing_function=smooth))
print("BLEU-4:", corpus_bleu(refs, hyps, weights=(0.25,0.25,0.25,0.25), smoothing_function=smooth))
