In [1]:
import torch
import argparse
import torch.optim as optim
from model import CNNtoRNN
from get_loader import get_loader
import torchvision.transforms as transforms
from PIL import Image

In [2]:
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step

In [3]:
train_loader, dataset = get_loader(
    root_folder="data/images",
    annotation_file="data/captions.txt",
    transform=transform,
    num_workers=2,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
model = CNNtoRNN(256, 256, 2994, 1).cuda()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
checkpoint = load_checkpoint(torch.load("checkpoint/model_checkpoint.pth.tar"), model, optimizer)
model.eval()

=> Loading checkpoint


CNNtoRNN(
  (encoderCNN): EncoderCNN(
    (inception): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU

In [42]:
test_img1 = transform(Image.open("test_images/img/dog.jpg").convert("RGB")).unsqueeze(0)
print("Example 1 CORRECT: a white dog is standing in the field")
print("Example 1 OUTPUT: "+ " ".join(model.caption_image(test_img1.to(device),dataset.vocab)))

Example 1 CORRECT: a white dog is standing in the field
Example 1 OUTPUT: <SOS> a white dog is running through a field . <EOS>


In [10]:
from utils import print_examples

In [11]:
print_examples(model, device, dataset)

Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a white dog is running through a field . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a little girl in a pink shirt is swinging on a swing . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red jacket and a white helmet is riding a bicycle in a city street . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man is standing on a rock in the ocean . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man is standing on a dirt road with his dog . <EOS>


In [38]:
parser = argparse.ArgumentParser()
parser.add_argument("--image_path", type=str, default = "test_images/img/dog.jpg")
parser.add_argument("--user_caption", type=str, default = "a white dog is standing in the field")
user_args = parser.parse_args()

usage: ipykernel_launcher.py [-h] [--image-path IMAGE_PATH]
                             [--user-caption USER_CAPTION]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/sajid/.local/share/jupyter/runtime/kernel-3789cf1d-59d5-4b7c-8545-624d760953ec.json


SystemExit: 2

In [None]:
test_img1 = transform(Image.open(user_args.image_path).convert("RGB")).unsqueeze(0)
correct_caption = user_args.user_caption
print("CORRECT USER CAPTION: "+correct_caption)
print("PREDICTED IMAGE OUTPUT: "+ " ".join(model.caption_image(test_img1.to(device),dataset.vocab)))