# nlpconnect ViT-GPT2 Image-Captioning Baseline

In [36]:
import dataset as ds
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as tt
import numpy as np

In [37]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [38]:
# The model, the image preprocessor and the tokenizer are all loaded from the
# same Hub checkpoint id, so it is spelled out once.
CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(CHECKPOINT)

# Load the weights first, then move the module to the selected device.
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
model = model.to(device)

In [39]:
# Image pipeline handed to the dataset: only the PIL -> tensor conversion.
# `scale` is defined here but is not part of the composed transform.
scale = tt.Resize((336, 336))
tensor = tt.PILToTensor()
image_composed = tt.Compose([tensor])

# One sample per batch, in dataset order; test() reads element 0 / .item()
# from every batch it receives.
test_set = ds.VisualWSDDataset(mode="test", image_transform=image_composed)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

In [40]:
# Decoding settings that test() forwards to model.generate().
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

In [43]:
def test():
    """Caption every candidate image of every batch in ``test_loader``.

    For each batch the label text and the correct index are printed, every
    entry of ``data["imgs"]`` is converted to a PIL image and captioned with
    ``model.generate(**gen_kwargs)``, and the stripped captions are printed
    one per line.

    Relies on the notebook globals ``test_loader``, ``feature_extractor``,
    ``model``, ``tokenizer``, ``gen_kwargs`` and ``device``.

    Returns:
        list[list[str]]: one inner list of captions per batch, in loader order.
    """
    results = []
    # Loop-invariant, so it is computed once instead of on every iteration.
    total_batches = len(test_loader)

    # Inference only: no autograd graph is needed while generating captions.
    with torch.no_grad():
        for batch_no, data in enumerate(test_loader, start=1):
            images = data["imgs"]
            # Element 0 / .item(): written for one sample per batch
            # (the loader in this notebook is built with batch_size=1).
            text = data["label_context"][0]
            correct_idx = data["correct_idx"].item()

            print("----------------------------")
            print(f"batch: {batch_no}/{total_batches}")
            print(f"label: {text}")
            print(f"correct index: {correct_idx}")

            # Strip only the leading batch axis. A bare torch.squeeze() removes
            # *every* size-1 axis, which would also collapse a height or width
            # of 1 and hand to_pil_image a tensor of the wrong rank.
            pil_images = [
                tt.functional.to_pil_image(img.squeeze(0), mode="RGB")
                for img in images
            ]

            encoded = feature_extractor(images=pil_images, return_tensors="pt")
            pixel_values = encoded.pixel_values.to(device)

            output_ids = model.generate(pixel_values, **gen_kwargs)

            decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            preds = [caption.strip() for caption in decoded]
            results.append(preds)

            for caption in preds:
                print(caption)

    return results

In [44]:
# Run the captioning baseline over test_loader. The output recorded for this
# cell ends in a KeyboardInterrupt, i.e. the run was stopped before finishing.
test()

----------------------------
batch: 1/463
label: football goal
correct index: 8
a field with a fence and a stadium
a man in a blue shirt with blue eyes
a man holding a soccer ball in his right hand
a man holding a tennis racquet in front of a building
a man in a suit and tie
a man kicking a soccer ball on a field
a man in a baseball uniform holding a baseball bat
a large stadium with a large number of fans on it
a grassy field with a soccer ball in it
a blue and white sign on a white wall
----------------------------
batch: 2/463
label: mustard seed
correct index: 0
a close up picture of some food on a table
a garden filled with lots of different types of vegetables
a row of yellow and white flowers in a field
a black and white bird sitting on top of a wood block
a yellow bird sitting on top of a branch
a single branch of a plant in the middle of a sunny day
a field with a barn and some trees
a spoon in a bowl with a spoon in it
a man riding on the back of a brown horse
a green plant g

KeyboardInterrupt: 