In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from torchvision import models, datasets, transforms

import numpy as np
import time
import sys
import os
import json
import tqdm
from PIL import Image

from models import utils, caption
from configuration import Config

In [2]:
MAX_DIM = 224
config = Config()

class p2Data(Dataset):
    def __init__(self, fnames, transform=None):
        self.transform = transform
        self.fnames = fnames
        self.file_list = [file for file in os.listdir(fnames) if file.endswith('.jpg')]
        self.file_list.sort()
        self.num_samples = len(self.file_list)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        fname = self.file_list[idx]
        filepath = os.path.join(self.fnames, fname)
        img = Image.open(filepath)
        img = self.transform(img)
        return img, fname
    
class to_dim:
    def __init__(self):
        self.dim = 3

    def __call__(self, x):
        if x.shape[0] == 1:
            x = x.repeat(3,1,1)
        return x

def create_caption_and_mask(start_token, max_length):
    caption_template = torch.zeros((imgs.shape[0], max_length), dtype=torch.long)
    # print('caption_template = ', caption_template)
    mask_template = torch.ones((imgs.shape[0], max_length), dtype=torch.bool)
    # print('mask_template = ', mask_template)
    caption_template[:, 0] = start_token
    mask_template[:, 0] = False

    return caption_template, mask_template

In [3]:
test_transform = transforms.Compose([
    transforms.ToTensor(),
    to_dim(),    
    transforms.Resize((MAX_DIM, MAX_DIM)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
test_image_dir = 'D:/NTU/DLCV/hw3/hw3_data/p2_data/images/val'
test_set = p2Data(test_image_dir, transform=test_transform)
test_dataloader = DataLoader(test_set, batch_size=32, shuffle=False)
device = torch.device(config.device)
print(f'Initializing Device: {device}')

Initializing Device: cuda


In [5]:
seed = config.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

In [6]:
model, criterion = caption.build_model(config)
model.to(device)
model.eval()



Caption(
  (backbone): Joiner(
    (0): Backbone(
      (body): IntermediateLayerGetter(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): FrozenBatchNorm2d()
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): FrozenBatchNorm2d()
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): FrozenBatchNorm2d()
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): FrozenBatchNorm2d()
            (relu): ReLU(inplace=True)
            (downsample): Sequential(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): FrozenBatchNorm2d()
            )
 

In [7]:
config.checkpoint

'./checkpoint.pth'

In [8]:
if os.path.exists(config.checkpoint):
    print("Loading Checkpoint...")
    # checkpoint = torch.load(config.checkpoint, map_location='cpu')
    checkpoint = torch.load(config.checkpoint)
    model.load_state_dict(checkpoint['model'])

Loading Checkpoint...


In [9]:
print(f"Valid: {len(test_set)}")

Valid: 1789


In [10]:
tokenizer = Tokenizer.from_file("D:/NTU/DLCV/hw3/hw3_data/caption_tokenizer.json")
total = len(test_dataloader)

In [11]:
start_token = 2
end_token = 3

In [12]:
max_len = 60
result_dict = {}
with tqdm.tqdm(total=total) as pbar:
    with torch.no_grad():
        for k, (imgs, fnames) in enumerate(test_dataloader):
            imgs = imgs.to(device)
    #         print(imgs.shape)

            cap, cap_mask = create_caption_and_mask(start_token, config.max_position_embeddings)
            cap = cap.to(device)
            cap_mask = cap_mask.to(device)
    #         print(cap.shape)

            for i in range(max_len):
                predictions = model(imgs, cap, cap_mask)[:, i, :]
                predicted_id = torch.argmax(predictions, axis=-1)
                for j in range(imgs.shape[0]):
                    if predicted_id[j] != 3:
                        cap[j, i + 1] = predicted_id[j]
                        cap_mask[j, i + 1] = False

            for r in range(imgs.shape[0]):
                s = tokenizer.decode(cap[r].tolist(), skip_special_tokens=True).capitalize().split('.')[0]
                s = s[:-1]+'.'
                name = fnames[r][:-4]
                print(f'{name}: {s}')
                result_dict[name] = s
            pbar.update(1)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
  2%|█▍                                                                                 | 1/56 [00:23<21:32, 23.50s/it]

000000000368: A group of people playing frisbee in a field.
000000000620: A plate with a sandwich and a fork and a fork.
000000001548: A man is standing on a surfboard in the water.
000000001999: A cat is sitting on a bed with a blanket.
000000002982: A train is on a track near a train station.
000000003461: A man and a woman are walking down a sidewalk.
000000003771: A sheep is eating grass in a field.
000000003999: A cat is laying on a bed in a room.
000000004956: An elephant is standing in the grass with a tree in the background.
000000005418: A group of giraffes standing in a field.
000000005434: A dog is sitting on a bench in front of a fence.
000000005757: A train is parked on the side of a road.
000000005811: A red and white bus is parked on the side of the road.
000000006393: A man in a black shirt and a black shirt is standing on a toilet.
000000006789: A large boat is parked on the side of a road.
000000007201: A red and white photo of a red fire hydrant.
000000008320: A grou

  2%|█▍                                                                                 | 1/56 [00:28<25:56, 28.30s/it]


KeyboardInterrupt: 

In [None]:
json_object = json.dumps(result_dict, indent=4)
with open("p2_output.json", "w") as outfile:
    outfile.write(json_object)

In [14]:
result_dict

{'000000000368': 'A man is playing frisbee in a field.',
 '000000000620': 'A man in a white shirt is standing in front of a building.',
 '000000001548': 'A man riding a wave on a surfboard in the snow.',
 '000000001999': 'A cat is sitting on a wooden bench.',
 '000000002982': 'A large clock tower with a clock on it.',
 '000000003461': 'A man is riding a skateboard on a street.',
 '000000003771': 'A dog is running in the grass.',
 '000000003999': 'A cat sitting on top of a wooden table.',
 '000000004956': 'A black dog is standing on a rock.',
 '000000005418': 'A giraffe standing in a field with a tree in the background.',
 '000000005434': 'A large clock on a pole in a room.',
 '000000005757': 'A train is parked on the tracks in front of a building.',
 '000000005811': 'A red bus is parked on the side of a street.',
 '000000006393': 'A bathroom with a toilet and a toilet with a toilet.',
 '000000006789': 'A large plane is on a runway in the water.',
 '000000007201': 'A man is riding a wav

In [15]:
len(result_dict)

1789