In [None]:
import torch
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import requests
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2TokenizerFast
from tqdm import tqdm
from IPython.display import display

# One seed shared by every RNG the notebook touches, so runs are repeatable.
SEED = 199
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Both splits come from the same dataset repository; name it once.
dataset_id = "martinsinnona/visdecode"

dataset_train = load_dataset(dataset_id, split="train")
dataset_test = load_dataset(dataset_id, split="test")

In [None]:
# Show the summary of the train split followed by the test split.
print(*(dataset_train, dataset_test))

In [None]:
# Load a fine-tuned image-captioning Transformer and its two companions.
# All three artefacts are published under the same checkpoint id.
checkpoint = "nlpconnect/vit-gpt2-image-captioning"

# ViT encoder / GPT-2 decoder model, moved to the selected device
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
model = model.to(device)

# GPT-2 tokenizer matching the decoder (used to turn token ids back into text)
tokenizer = GPT2TokenizerFast.from_pretrained(checkpoint)

# Image processor that prepares images as model inputs
image_processor = ViTImageProcessor.from_pretrained(checkpoint)

In [None]:
# Accessing images from the web
import urllib.parse as parse
import os
# Verify url
def check_url(string):
    """Tell whether ``string`` looks like a complete URL.

    A value counts as a URL only when ``urllib.parse.urlparse`` yields a
    non-empty scheme, network location *and* path; a bare host such as
    ``https://example.com`` is therefore rejected.

    Args:
        string: Candidate URL (normally a ``str``).

    Returns:
        bool: ``True`` when all three components are present, ``False``
        otherwise, including when the value cannot be parsed at all.
    """
    try:
        result = parse.urlparse(string)
    except (ValueError, AttributeError, TypeError):
        # ValueError: malformed URL (e.g. unbalanced IPv6 brackets).
        # AttributeError / TypeError: the argument is not str/bytes-like.
        # Anything else (KeyboardInterrupt, SystemExit, ...) must propagate.
        return False
    return all([result.scheme, result.netloc, result.path])

# Load an image
def load_image(image_path):
    """Open an image from a URL or from a local file path.

    Args:
        image_path: Either a URL accepted by ``check_url`` or a path to an
            existing file on disk.

    Returns:
        PIL.Image.Image: The opened image.

    Raises:
        requests.RequestException: The download failed, timed out, or the
            server answered with an HTTP error status.
        FileNotFoundError: ``image_path`` is neither a URL nor an existing
            local path.
    """
    from io import BytesIO

    if check_url(image_path):
        # A timeout keeps the cell from hanging forever on a dead server, and
        # raise_for_status() stops an HTML error page from reaching PIL.
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        # Reading the full body lets the connection be released and gives PIL
        # a seekable, already content-decoded buffer.
        return Image.open(BytesIO(response.content))

    if os.path.exists(image_path):
        return Image.open(image_path)

    # Fail loudly here instead of returning None and crashing later on.
    raise FileNotFoundError(
        f"{image_path!r} is neither a valid URL nor an existing file path"
    )

In [None]:
# Image inference
def get_caption_from_url(model, image_processor, tokenizer, image_path):
    """Load an image from a URL or local path and caption it.

    Args:
        model: Captioning model forwarded to ``get_caption``.
        image_processor: Image processor forwarded to ``get_caption``.
        tokenizer: Tokenizer forwarded to ``get_caption``.
        image_path: URL or local file path understood by ``load_image``.

    Returns:
        The caption produced by ``get_caption`` for the loaded image.

    Raises:
        FileNotFoundError: ``load_image`` could not resolve ``image_path``.
    """
    image = load_image(image_path)
    if image is None:
        raise FileNotFoundError(f"could not load an image from {image_path!r}")

    # Downloaded files may be RGBA, palette or grayscale; every other caller
    # in this notebook passes RGB to get_caption, so normalise here as well.
    return get_caption(model, image_processor, tokenizer, image.convert("RGB"))
    

def get_caption(model, image_processor, tokenizer, image):
    """Generate a caption for a single image.

    Args:
        model: Model exposing ``device``, ``training``, ``eval()``,
            ``train(mode)`` and ``generate(**inputs)``.
        image_processor: Callable that turns ``image`` into model inputs
            (called with ``return_tensors="pt"``).
        tokenizer: Tokenizer whose ``batch_decode`` maps generated ids to text.
        image: The image to caption.

    Returns:
        The first decoded caption, with special tokens removed.
    """
    # Follow the model's own device rather than a notebook-level global, so
    # the inputs always land on the same device as the weights.
    inputs = image_processor(image, return_tensors="pt").to(model.device)

    # Generation must run in eval mode (dropout off). The training cell leaves
    # the model in train mode, so switch here and restore the caller's mode
    # afterwards, even if generation raises.
    was_training = model.training
    model.eval()
    try:
        output = model.generate(**inputs)
    finally:
        model.train(was_training)

    # Decode the generated token ids; only the first sequence is returned.
    caption = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    return caption

In [None]:
# Loading URLs
url = "https://images.pexels.com/photos/101667/pexels-photo-101667.jpeg?auto=compress&cs=tinysrgb&w=600"

# Download the picture once and reuse it for both the preview and the caption
# (previously the same URL was fetched twice).
demo_image = load_image(url)

# Display Image
display(demo_image)

# Display Caption
get_caption(model, image_processor, tokenizer, demo_image.convert("RGB"))

In [None]:
# Reload the pretrained checkpoint so fine-tuning starts from fresh weights.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = model.to(device)

# AdamW over every model parameter with a small learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [None]:
# Fine-tune the captioning model one sample at a time and record the mean
# loss of every epoch in `losses`.
epochs = 10
losses = []

# Offsets used to cut the target caption out of each sample's `text` field.
# NOTE(review): these presumably match a fixed-length markup prefix in the
# dataset's text format; confirm against the data before changing them.
FIELD_SEARCH_START = 37
TARGET_START = 64 + 7

model.to(device)
model.train()

for epoch in range(epochs):
    # Reset at the *start* of each epoch so an interrupted earlier run of this
    # cell cannot leak a partial sum into the next mean.
    epoch_loss_sum = 0.0

    for sample in tqdm(dataset_train, desc=f"epoch {epoch}"):

        # Preprocessing the image
        pixels = image_processor(sample['image'].convert("RGB"), return_tensors="pt").to(device)

        # Target text: the slice between TARGET_START and the first "</field>"
        # found at or after FIELD_SEARCH_START, followed by the end-of-text marker.
        # NOTE(review): if "</field>" is missing, find() returns -1 and the
        # slice silently drops the last character; confirm every sample has it.
        text = sample['text']
        stop = text.find("</field>", FIELD_SEARCH_START)
        target_text = text[TARGET_START:stop] + "<|endoftext|>"

        target_sequence = tokenizer(target_text, return_tensors="pt", padding=True).input_ids.to(device)

        # Clear stale gradients before the forward/backward pass.
        optimizer.zero_grad()

        # Forward pass; passing `labels` makes the model return the loss.
        output = model(pixel_values=pixels['pixel_values'], labels=target_sequence)

        loss = output.loss
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()

    batch_loss = epoch_loss_sum / len(dataset_train)

    print("Epoch: ", epoch, " | batch mean loss:", batch_loss)
    losses.append(batch_loss)

# Leave the model in inference mode so the captioning cells that follow do not
# run generation with dropout enabled.
model.eval()

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel="epoch", ylabel="mean training loss")
plt.show()

In [None]:
# The training cell switches the model to train mode; generation needs
# dropout disabled, so put the model in eval mode before captioning.
model.eval()

# Print one generated caption per test image.
for data in dataset_test:
    print(get_caption(model, image_processor, tokenizer, data['image'].convert("RGB")))

In [None]:
# Sanity check on the first training sample: print its generated caption,
# then show the image itself as the cell's output.
first_train_sample = dataset_train[0]
first_train_image = first_train_sample['image']

print(get_caption(model, image_processor, tokenizer, first_train_image.convert("RGB")))
first_train_image