In [1]:

from model.ImageCaptioningModel_transformer import ImageCaptionModel
from model.VIT import VIT
import os.path

from processing.image_processing import process_image
from settings import BASE_DIR
from train.helper import load_checkpoint

# Path of the saved checkpoint:
# <BASE_DIR>/train/model/cnn_transformer/best_model.pth
model_path = os.path.join(
    BASE_DIR,
    "train",
    "model",
    "cnn_transformer",
    "best_model.pth",
)
# load_checkpoint returns three values; only the first one (the model) is
# kept, the other two are discarded.
# NOTE(review): 1e-4 is presumably a learning rate forwarded to whatever
# load_checkpoint rebuilds alongside the model -- confirm in train.helper.
model, _, _ = load_checkpoint(ImageCaptionModel, model_path, 1e-4)
# Bare last expression so the notebook renders the module tree.
model


ImageCaptionModel(
  (cnn_model): CNNEncoder(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=

In [3]:
from settings import DEVICE
import torch

# Keep a notebook-level handle on the vocabulary object attached to the
# loaded model, then print its `vocab_size` attribute.
vocab = model.vocab
print(vocab.vocab_size)
def generate_caption(model, image, vocab, max_len=30, device=None):
    """Greedily decode a caption for a single image.

    The image is encoded once with ``model.encoder_vit``; the caption is then
    grown one token at a time by feeding every token produced so far to
    ``model.decoder_vit`` and taking the arg-max of the logits at the last
    position. Decoding stops when ``<END>`` is predicted or after ``max_len``
    steps, whichever comes first.

    NOTE(review): the cells visible in this notebook call
    ``model.generate_caption(...)`` rather than this helper, and the printed
    model repr lists a ``cnn_model`` submodule -- confirm that the model
    passed here really exposes ``encoder_vit`` / ``decoder_vit``.

    Args:
        model: Object providing ``eval()``, ``encoder_vit(image_batch)`` and
            ``decoder_vit(token_ids, encoder_output)``. The decoder output is
            indexed as ``output[0, -1, :]``.
        image: Tensor for one un-batched image; a leading batch dimension is
            added here (the original comment gives the batched shape as
            (1, 3, 224, 224)).
        vocab: Object with ``w2i`` (token -> id, must contain ``"<START>"``
            and ``"<END>"``) and ``i2w`` (id -> token) mappings.
        max_len: Maximum number of decoding steps.
        device: Device on which to run decoding. ``None`` (the default) keeps
            the previous behaviour of using the project-wide ``DEVICE``.

    Returns:
        The generated tokens joined by single spaces, without the
        ``<START>`` / ``<END>`` markers.
    """
    # Only fall back to the global when the caller did not choose a device.
    target_device = DEVICE if device is None else device

    # Loop-invariant lookups, resolved once instead of on every step.
    start_id = vocab.w2i["<START>"]
    end_id = vocab.w2i["<END>"]

    model.eval()
    with torch.no_grad():
        # Add the batch dimension expected by the encoder.
        image = image.to(target_device).unsqueeze(0)

        # Encode the image once; the result is reused at every decoding step.
        encoder_output = model.encoder_vit(image)

        # Seed the decoder input with the <START> token.
        input_ids = [start_id]

        for _ in range(max_len):
            input_tensor = torch.tensor([input_ids], device=target_device)

            # Predict the next token from the logits at the last position.
            output = model.decoder_vit(input_tensor, encoder_output)
            next_token = torch.argmax(output[0, -1, :]).item()

            if next_token == end_id:
                break

            input_ids.append(next_token)

    # Map token ids back to words, skipping the leading <START>.
    return " ".join(vocab.i2w[idx] for idx in input_ids[1:])
from PIL import Image
from processing.image_processing import process_image
# Single-image check: caption one Flickr30k picture and print the result.
# NOTE(review): this is a machine-specific absolute path; the notebook only
# runs where this exact file exists.
sample_image_path = "E:\\python_prj\\dataset\\flickr30k\\flickr30k-images\\1007129816.jpg"
image = Image.open(sample_image_path).convert("RGB")
# Opens the picture in the system's external image viewer.
image.show()
# Preprocess with the project helper, then let the model caption the result.
processed_image = process_image(image)
caption = model.generate_caption(processed_image)
print(caption)

7472
a man in a blue hat and a hat is holding a hat


In [5]:
import time

import pandas as pd

# Generated captions and their reference captions, index-aligned: one pair is
# appended per row of test.csv.
predictions = []
references = []

test_df = pd.read_csv(os.path.join(BASE_DIR, "dataset" , "flickr30k", "test.csv"))
# Fail fast with a clear message instead of a ZeroDivisionError when the
# average time is computed below.
if test_df.empty:
    raise ValueError("test.csv contains no rows; nothing to evaluate")

# NOTE(review): one caption is generated per CSV row. If test.csv holds
# several rows for the same image_path, that image is captioned repeatedly and
# each prediction is scored against a single reference -- confirm the file
# layout before comparing these metrics with published multi-reference scores.

# perf_counter is the clock intended for measuring elapsed intervals.
start_time = time.perf_counter()
# Iterate over the two needed columns directly; this avoids building a Series
# per row as iterrows() does.
for image_path, reference in zip(test_df["image_path"], test_df["caption"]):
    # convert() returns a new image, so the opened file can be released as
    # soon as the with-block exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    processed_image = process_image(image)
    caption = model.generate_caption(processed_image)
    predictions.append(caption)
    references.append(reference)
end_time = time.perf_counter()
# The measured interval covers image loading and preprocessing as well as
# caption generation.
print(f"Time taken to generate caption: {(end_time - start_time)/len(predictions)} seconds")



Time taken to generate caption: 0.08689800748825073 seconds


In [6]:
import evaluate

# Load the BLEU metric through the `evaluate` library.
bleu = evaluate.load("bleu")

# === Compute ===
bleu_result = bleu.compute(
    references=references,
    predictions=predictions,
)

# === Print the result ===
print("BLEU:", bleu_result)



BLEU: {'bleu': 0.05825142298260339, 'precisions': [0.3062422853112326, 0.09437246180622703, 0.03585955898094626, 0.015775593382881804], 'brevity_penalty': 0.9160754964746836, 'length_ratio': 0.9194079214020525, 'translation_length': 56710, 'reference_length': 61681}


In [7]:
# METEOR over the same prediction/reference pairs collected above.
meteor = evaluate.load("meteor")
meteor_result = meteor.compute(
    references=references,
    predictions=predictions,
)
print("METEOR:", meteor_result)

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\huynh\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\huynh\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\huynh\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


METEOR: {'meteor': 0.21028330381881552}


In [8]:
# ROUGE over the same prediction/reference pairs collected above.
rouge = evaluate.load("rouge")
rouge_result = rouge.compute(
    references=references,
    predictions=predictions,
)
print("ROUGE:", rouge_result)

ROUGE: {'rouge1': 0.29220539708642773, 'rouge2': 0.08895517227881751, 'rougeL': 0.269766067777997, 'rougeLsum': 0.26985112498941766}
