In [1]:
from build_model import EncoderDecoder
from load_data import CaptionsLoader
from torchvision import transforms
from torch import optim
from torch import nn
import torch
import os

In [2]:
BASE_DIR = "flickr-image-dataset"
IMAGES_DIR = os.path.join("flickr30k_images", "flickr30k_images", "flickr30k_images")
CAPTIONS_DIR = "flickr30k_images"
CAPTIONS_FILE = "results.csv"
LOAD_MODEL_PATH = 'image_captioning_model.pth.tar'

In [5]:

    captions_path = os.path.join(BASE_DIR, CAPTIONS_DIR, CAPTIONS_FILE)
    images_path = os.path.join(BASE_DIR, IMAGES_DIR)
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ]
    )
    cl = CaptionsLoader(captions_path, images_path, transform,
                        batch_size=32, num_workers=2, shuffle=True,
                        pin_memory=True, max_unk_freq=2)
    loader, dataset = cl.get_loader()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EncoderDecoder(
        embedding_size=256,
        train_all=False,
        num_lstms=2,
        hidden_size=256,
        vocab_size=len(dataset.tokenizer),
        index_to_string=dataset.tokenizer.convert_idx_str
    ).to(device)

    lr = 2e-4
    optimizer = optim.Adam(model.parameters(), lr)
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.tokenizer.convert_str_idx['<PAD>'])
    
    checkpoint = torch.load(LOAD_MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

In [6]:
from PIL import Image
image = transform(Image.open('boys_football.jpg').convert("RGB")).unsqueeze(0).to(device)

In [7]:
model.eval()

EncoderDecoder(
  (encoder): ImageEncoder(
    (image_model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [8]:
model.predict(image, 50)

'a young boy in a red uniform is running on a field.'

baseball

In [9]:
image = Image.open('134206.jpg').convert("RGB")
image=transform(image).unsqueeze(0).to(device)

In [10]:
model.predict(image, 50)

'a pitcher is throwing a pitch at a baseball game.'

children painting at a table

In [11]:
image = Image.open('438106.jpg').convert("RGB")
image=transform(image).unsqueeze(0).to(device)

In [12]:
model.predict(image, 50)

'a woman in a white shirt is sitting on a bench.'

In [13]:
image = Image.open('elephant.jpg').convert("RGB")
image=transform(image).unsqueeze(0).to(device)

In [14]:
model.predict(image, 50)

'a dog is running through a field of grass.'

In [15]:
image = Image.open('women_soldiers.jpeg').convert("RGB")
image=transform(image).unsqueeze(0).to(device)

In [16]:
model.predict(image, 50)

'a man in a black shirt is standing in front of a crowd.'

In [17]:
for name in ['bus.png', 'child.jpg', 'dog.jpg', 'horse.png', 'boat.png']:
    image = Image.open(name).convert("RGB")
    image=transform(image).unsqueeze(0).to(device)
    print(name)
    print(model.predict(image, 50))
    print()

bus.png
a man in a blue shirt is standing in front of a bus stop.

child.jpg
a young boy in a blue shirt is running through a field.

dog.jpg
a dog is running on the beach.

horse.png
a man and a woman are walking on a path in the desert.

boat.png
a man is fishing off of a dock into a lake.

