In [1]:
import torch
from model import *
from get_loader import get_loader
import torchvision.transforms as transforms
from PIL import Image

In [13]:
# Same image folder / captions file that `load_model_to_captioning` uses as its
# train_* defaults; they are read here only so `dataset.vocabulary` can be rebuilt.
test_image_path, test_caption_path = "archive/images", "archive/captions.txt"

In [14]:
# Inference-time preprocessing: resize, crop to the 299x299 input expected by
# the Inception-based encoder, convert to a tensor and scale channels to [-1, 1].
# CenterCrop (rather than RandomCrop) makes the crop -- and therefore the
# generated caption -- identical every time the same image is processed.
transform = transforms.Compose(
    [
        transforms.Resize((356, 356)),
        transforms.CenterCrop((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [15]:
# Only `dataset.vocabulary` is used further down; the loader is never iterated.
test_loader, dataset = get_loader(
    annotation_file=test_caption_path,
    root_folder=test_image_path,
    num_workers=6,
    transform=transform,
)

In [16]:
# Use the GPU when present; fall back to CPU so the notebook still runs on
# machines without CUDA (previously `.to("cuda")` raised there).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture hyperparameters. These presumably must match the values used to
# train "my_checkpoint.pth.tar", otherwise the strict load_state_dict below
# fails on mismatched shapes.
embedding_size = 256
hidden_size = 256
# NOTE(review): hardcoded; presumably equals the size of the vocabulary built
# from archive/captions.txt -- TODO confirm against `dataset.vocabulary`.
vocabulary_size = 2994
num_layers = 1
model = CNNtoRNNTranslator(embedding_size, hidden_size, vocabulary_size, num_layers).to(device)

In [17]:
# map_location remaps tensors saved from a GPU run onto the device the model
# was placed on, so a CUDA-trained checkpoint can also be restored on CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load("my_checkpoint.pth.tar", map_location=device)
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [18]:
# Switch to inference mode: BatchNorm layers use their running statistics and
# any dropout is disabled. `train(False)` is exactly what `eval()` does.
model.train(False)

CNNtoRNNTranslator(
  (encoderCNN): EncoderCNN(
    (inception): Inception3(
      (Conv2d_1a_3x3): BasicConv2d(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv2d_2a_3x3): BasicConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv2d_2b_3x3): BasicConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (Conv2d_3b_1x1): BasicConv2d(
        (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.1

In [35]:
# Preprocess one test image and prepend a batch dimension of size 1.
test_img1 = transform(
    Image.open("test/images/4.jpg").convert("RGB")
).unsqueeze(dim=0)

In [36]:
# Raw token list from the model, including the <SOS>/<EOS> markers.
print(
    model.image_caption(test_img1.to(device), dataset.vocabulary)
)

['<SOS>', 'a', 'man', 'in', 'a', 'red', 'shirt', 'and', 'a', 'woman', 'in', 'a', 'white', 'shirt', 'and', 'sunglasses', '.', '<EOS>']


In [37]:
def load_model_to_captioning(train_image_path="archive/images",
                             train_caption_path="archive/captions.txt",
                             checkpoint_path="my_checkpoint.pth.tar",
                             device="cuda"):
    """Rebuild the captioning model from a checkpoint, ready for inference.

    The training captions are read again only to rebuild the vocabulary that
    maps predicted indices back to words; the data loader itself is discarded.

    Args:
        train_image_path: Image folder passed to ``get_loader`` as ``root_folder``.
        train_caption_path: Captions file passed to ``get_loader`` as
            ``annotation_file``.
        checkpoint_path: Checkpoint file containing a ``"state_dict"`` entry.
        device: Device for the model. Defaults to ``"cuda"`` (the previously
            hardcoded value); a CUDA device falls back to ``"cpu"`` when CUDA is
            not available.

    Returns:
        ``(model, vocabulary)`` -- the model in eval mode and
        ``dataset.vocabulary``.
    """
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    # Deterministic preprocessing; it only affects dataset items, which are not
    # used here (only the vocabulary is returned).
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.CenterCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    _loader, dataset = get_loader(
        root_folder=train_image_path,
        annotation_file=train_caption_path,
        transform=transform,
        num_workers=6,
    )

    # Presumably must match the hyperparameters the checkpoint was trained with.
    embedding_size = 256
    hidden_size = 256
    vocabulary_size = 2994
    num_layers = 1
    model = CNNtoRNNTranslator(embedding_size, hidden_size, vocabulary_size, num_layers).to(device)

    # map_location lets a GPU-saved checkpoint load onto the chosen device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])

    # Without this the model is returned in training mode: the encoder's
    # BatchNorm layers would normalize with per-batch statistics and keep
    # updating their running stats while captioning.
    model.eval()

    return model, dataset.vocabulary

In [69]:
def create_caption(image_path, model, vocabulary, image_transform=None):
    """Generate a plain-text caption for the image stored at ``image_path``.

    Args:
        image_path: Path to an image file readable by PIL.
        model: Captioning model exposing ``image_caption(image_batch, vocabulary)``.
        vocabulary: Vocabulary whose first four ``stoi`` keys are the special
            tokens to strip (presumably ``<PAD>``, ``<SOS>``, ``<EOS>``, ``<UNK>``).
        image_transform: Optional preprocessing callable; defaults to the
            notebook-level ``transform``.

    Returns:
        The predicted words joined by single spaces, special tokens removed.
    """
    if image_transform is None:
        image_transform = transform

    # Put the input on whatever device the model's weights live on instead of
    # depending on a notebook-global `device`.
    model_device = next(model.parameters()).device

    with Image.open(image_path) as raw_image:
        image = image_transform(raw_image.convert("RGB")).unsqueeze(0)

    # Caption the image that was actually loaded from `image_path`; the
    # previous version captioned the global `test_img1` regardless of input.
    with torch.no_grad():
        word_list = model.image_caption(image.to(model_device), vocabulary)

    # Build the special-token set once instead of re-listing stoi per word.
    special_tokens = set(list(vocabulary.stoi)[:4])
    return " ".join(str(word) for word in word_list if word not in special_tokens)

In [70]:
# Caption a single test image, returning clean text without special tokens.
create_caption(
    image_path="test/images/4.jpg",
    model=model,
    vocabulary=dataset.vocabulary,
)

'a man in a red shirt and a woman in a white shirt and sunglasses .'

['<PAD>', '<SOS>', '<EOS>', '<UNK>']