# Artwork Captioning: Demo

In [None]:
! pip install -qq transformers datasets torch torchvision evaluate rouge_score onedrivedownloader

In [None]:
from onedrivedownloader import download

# Shared SharePoint/OneDrive link to the demo archive.
ln = "https://unibari-my.sharepoint.com/:u:/g/personal/n_fanelli10_studenti_uniba_it/EXeIINJMf65PqLelHAsvhvcBtOSCrRdnCRO2LGGDVE08Gw?e=cs8crd"
# Fetch the archive as file.zip and extract it (unzip=True).
# NOTE(review): presumably this provides the data/, models/ and sklearn_encoders/
# folders that later cells read from -- confirm against the archive contents.
download(ln, filename="file.zip", unzip=True)

## Imports

In [None]:
from transformers import AutoProcessor, AutoModelForCausalLM, MarianMTModel, MarianTokenizer
import torch
import truecase
import gradio as gr
import requests
from PIL import Image
import os
import sys
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
from functools import partial
from pathlib import Path

import numpy as np
from joblib import load
from transformers import ViTImageProcessor

## Set Models

In [None]:
class ViTForMultiClassification(nn.Module):
    """ViT backbone with one linear head per classification feature.

    The first token of the backbone's last hidden state (the [CLS] position
    for ViT) is passed through dropout and then through every head. Heads are
    stored in ``multiclass_fcs`` and ``multilabel_fcs`` in the insertion order
    of the corresponding dictionaries.
    """

    def __init__(
        self,
        multiclass_classifications: dict[str, int],
        multilabel_classifications: dict[str, int],
        dropout_rate: float = 0.0,
    ):
        """Build the backbone, the dropout layer and the classification heads.

        Args:
            multiclass_classifications (dict[str, int]): feature name -> number of
                classes, for single-label (multiclass) features
            multilabel_classifications (dict[str, int]): feature name -> number of
                labels, for multilabel features
            dropout_rate (float): dropout probability applied to the pooled
                embedding before the heads
        """
        super().__init__()
        self.multiclass_classifications = multiclass_classifications
        self.multilabel_classifications = multilabel_classifications
        self.n_classifications = len(multiclass_classifications) + len(
            multilabel_classifications
        )

        # Pretrained ViT backbone without its pooling layer.
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k", add_pooling_layer=False
        )

        # Dropout applied to the pooled embedding.
        self.dropout = nn.Dropout(dropout_rate)

        hidden_size = self.vit.config.hidden_size

        # One linear head per feature; a ModuleList is only registered when
        # its group has at least one feature.
        if self.multiclass_classifications:
            self.multiclass_fcs = nn.ModuleList(
                [
                    nn.Linear(hidden_size, num_classes)
                    for num_classes in multiclass_classifications.values()
                ]
            )

        if self.multilabel_classifications:
            self.multilabel_fcs = nn.ModuleList(
                [
                    nn.Linear(hidden_size, num_classes)
                    for num_classes in multilabel_classifications.values()
                ]
            )

        # One learnable scalar per head, intended as loss weights for
        # multitask training; not used in forward().
        if self.n_classifications > 1:
            self.log_vars = nn.Parameter(
                torch.zeros(self.n_classifications, dtype=torch.float32)
            )
        else:
            self.log_vars = None

    def freeze_base_model(self, freeze: bool):
        """Toggle freeze/unfreeze of the ViT backbone.

        Args:
            freeze (bool): True to stop gradient updates, False to allow them
        """
        for param in self.vit.parameters():
            param.requires_grad = not freeze

    def freeze_log_vars(self, freeze: bool):
        """Toggle freeze/unfreeze of the log vars.

        Does nothing when the model has a single head (``log_vars`` is None).

        Args:
            freeze (bool): True to stop gradient updates, False to allow them
        """
        if self.log_vars is not None:
            # Flip the flag in place so that unfreezing also works and the
            # registered parameter object stays the same.
            self.log_vars.requires_grad_(not freeze)

    def forward(self, pixel_values: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass for ViTForMultiClassification model.

        Args:
            pixel_values (torch.Tensor): pixel values of the images

        Returns:
            dict[str, torch.Tensor]: logits keyed by feature name, multiclass
            features first, then multilabel features
        """
        pooled = self.vit(pixel_values=pixel_values).last_hidden_state[:, 0]
        pooled = self.dropout(pooled)

        logits_dict = {}
        if self.multiclass_classifications:
            for feature, fc in zip(
                self.multiclass_classifications, self.multiclass_fcs
            ):
                logits_dict[feature] = fc(pooled)
        if self.multilabel_classifications:
            for feature, fc in zip(
                self.multilabel_classifications, self.multilabel_fcs
            ):
                logits_dict[feature] = fc(pooled)
        return logits_dict

In [None]:
# Features predicted as a single class each.
MULTICLASS_FEATURES = ("artist", "style", "genre")
# Features predicted as a set of labels each.
MULTILABEL_FEATURES = ("tags", "media")


def get_multiclassification_dicts():
    """Get the number of classes per feature from the saved sklearn encoders.

    Both encoder families are read from the ``sklearn_encoders`` folder, the
    same location ViTForMultiClassificationPredictor loads them from.

    Returns:
        tuple[dict, dict]: ``(multiclass_classifications,
        multilabel_classifications)``, each mapping a feature name to its
        number of classes/labels.
    """
    encoders_dir = Path("sklearn_encoders")

    # NOTE: joblib.load unpickles the file; only use encoder files from a
    # trusted source.
    multiclass_classifications = {}
    for feature in MULTICLASS_FEATURES:
        ordinal_encoder = load(
            encoders_dir / "ordinal_encoders" / f"{feature}.joblib"
        )
        multiclass_classifications[feature] = len(ordinal_encoder.categories_[0])

    multilabel_classifications = {}
    for feature in MULTILABEL_FEATURES:
        multilabel_binarizer = load(
            encoders_dir / "multilabel_binarizers" / f"{feature}.joblib"
        )
        multilabel_classifications[feature] = len(multilabel_binarizer.classes_)

    return multiclass_classifications, multilabel_classifications


class ViTForMultiClassificationPredictor:
    """Inference wrapper around a trained ViTForMultiClassification checkpoint.

    Loads the model weights, the sklearn encoders needed to map predicted
    indices back to label names, and the ViT image processor.
    """

    def __init__(self, model_path, device, batch_size=1):
        """Load the model, the label encoders and the image processor.

        Args:
            model_path: path of the checkpoint; it must contain a
                ``"model_state_dict"`` entry
            device: torch device the model and the inputs are moved to
            batch_size: stored on the instance; not used by ``predict``
        """
        self.multiclassification_dicts = get_multiclassification_dicts()
        self.model = self.load_model(model_path, device)

        encoders_dir = Path("sklearn_encoders")
        # NOTE: joblib.load unpickles the file; only use trusted encoder files.
        self.ordinal_encoders = {
            feature: load(encoders_dir / "ordinal_encoders" / f"{feature}.joblib")
            for feature in self.multiclassification_dicts[0]
        }
        self.multilabel_binarizers = {
            feature: load(
                encoders_dir / "multilabel_binarizers" / f"{feature}.joblib"
            )
            for feature in self.multiclassification_dicts[1]
        }

        self.processor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224-in21k"
        )

        self.device = device
        self.batch_size = batch_size

    def load_model(self, model_path, device):
        """Build the model, load the checkpoint weights and switch to inference mode.

        Args:
            model_path: path of the checkpoint file
            device: torch device to load the tensors onto and move the model to

        Returns:
            ViTForMultiClassification: the model on ``device`` with training
            mode disabled
        """
        model = ViTForMultiClassification(*self.multiclassification_dicts)
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only host (and vice versa).
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.train(False)
        return model

    @torch.no_grad()
    def predict(self, image):
        """Predict the multiclass and multilabel features of one image.

        Args:
            image: image accepted by the ViT image processor

        Returns:
            tuple[dict, dict]: ``(predicted_classes, predicted_labels)``.
            ``predicted_classes`` maps each multiclass feature to the first
            row returned by the ordinal encoder's ``inverse_transform``;
            ``predicted_labels`` maps each multilabel feature to the first
            entry returned by the binarizer's ``inverse_transform``.
        """
        pixel_values = self.processor(images=image, return_tensors="pt")[
            "pixel_values"
        ].to(self.device)
        outputs = self.model(pixel_values)

        multiclass_features, multilabel_features = self.multiclassification_dicts

        predicted_classes = {}
        predicted_labels = {}
        for feature, output in outputs.items():
            if feature in multiclass_features:
                # .item() requires a single-image batch.
                class_index = torch.argmax(output, dim=1).item()
                predicted_classes[feature] = self.ordinal_encoders[
                    feature
                ].inverse_transform([[class_index]])[0]
            elif feature in multilabel_features:
                # A logit above 0 marks the label as present.
                indicator = torch.where(output > 0, 1, 0).cpu().numpy()
                predicted_labels[feature] = self.multilabel_binarizers[
                    feature
                ].inverse_transform(indicator)[0]

        return predicted_classes, predicted_labels

## Config

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/git-base"
PROCESSOR = AutoProcessor.from_pretrained(MODEL_NAME)
MODEL = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
OUTPUT_DIR = "tutorial"

# Classifier used for the artist/style/genre/tags/media part of the caption.
MULTICLASSIFICATION_MODEL = ViTForMultiClassificationPredictor(
    "models/model-20230513_121917-35.pt",
    DEVICE
)

# Replace the pretrained captioning weights with the checkpoint's weights.
# map_location keeps the load working on CPU-only hosts, e.g. when the
# checkpoint was saved from a GPU run.
checkpoint = torch.load("models/git_base_gs.pt", map_location=DEVICE)
MODEL.load_state_dict(checkpoint["model_state_dict"])
MODEL.to(DEVICE)
MODEL.train(False)

# load a translation model (en-it) from hf
TRANSLATION_MODEL = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-it")
TRANSLATION_MODEL.to(DEVICE)
TRANSLATION_MODEL.train(False)
TRANSLATION_TOKENIZER = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-it")

## Utils

In [None]:
# Collect the .jpg files found in data/images; they are offered as example inputs in the demo.
paths = [
    os.path.join("data/images", filename)
    for filename in os.listdir("data/images")
    if filename.endswith(".jpg")
]

In [None]:
def capitalize_artist(s):
    """Turn a hyphen-separated name into space-separated capitalized words."""
    name_parts = s.split("-")
    return " ".join(map(str.capitalize, name_parts))

In [None]:
def multiclassification_prediction_to_caption(prediction):
    """Render a classifier prediction as English sentences.

    Args:
        prediction: pair ``(multiclass_preds, multilabel_preds)``. The first
            maps "artist", "genre" and "style" to a sequence whose first item
            is the predicted name; the second maps "tags" and "media" to
            sequences of label strings.

    Returns:
        str: one sentence on artist, genre and style, followed by a sentence
        for the tags and one for the media when those are non-empty.
    """
    multiclass_preds = prediction[0]
    multilabel_preds = prediction[1]

    # The "other" class stands for an artist the classifier cannot name.
    artist = multiclass_preds["artist"][0]
    if artist != "other":
        attributed_to = capitalize_artist(artist)
    else:
        attributed_to = "an unknown artist"

    genre = multiclass_preds["genre"][0].capitalize()
    style = multiclass_preds["style"][0].capitalize()
    sentences = [
        f"The artwork could be attributed to {attributed_to}, in the {genre} genre, showcasing the {style} style."
    ]

    if multilabel_preds["tags"]:
        sentences.append(
            f"It is associated with the following concepts: {', '.join(multilabel_preds['tags'])}."
        )

    if multilabel_preds["media"]:
        sentences.append(
            f"It is presented in the medium of {', '.join(multilabel_preds['media'])}."
        )

    return " ".join(sentences).strip()


## Demo

In [None]:
# Callback behind the demo's "Run" button: image in, English and Italian captions out.
def captioning_pipeline(image, temperature, num_beams, min_length, do_sample):
    """Caption an artwork image in English and translate the caption to Italian.

    The caption joins the classifier-based description with a sentence
    generated by the captioning model.

    Args:
        image: input image, as delivered by the Gradio image component
        temperature: sampling temperature; only used when sampling is active
        num_beams: number of beams for generation
        min_length: minimum length of the generated sequence
        do_sample: whether to sample instead of decoding deterministically

    Returns:
        tuple[str, str]: ``(caption, caption_it)``, the English caption and
        its Italian translation.
    """
    multiclassification_preds = MULTICLASSIFICATION_MODEL.predict(image)
    caption = multiclassification_prediction_to_caption(multiclassification_preds)

    temperature = float(temperature)
    # The temperature slider starts at 0.0, which transformers rejects when
    # sampling. Temperature 0 is the deterministic limit, so fall back to
    # non-sampling decoding in that case.
    sampling = bool(do_sample) and temperature > 0.0
    generation_kwargs = {
        # Sliders may deliver floats; generate() expects integers here.
        "min_length": int(min_length),
        "max_length": 100,
        "num_beams": int(num_beams),
        "no_repeat_ngram_size": 2,
        "do_sample": sampling,
    }
    if sampling:
        generation_kwargs["temperature"] = temperature

    pixel_values = PROCESSOR(images=image, return_tensors="pt").pixel_values.to(DEVICE)
    with torch.no_grad():
        generated_ids = MODEL.generate(pixel_values=pixel_values, **generation_kwargs)
    generated_caption = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Truecase the text with an "It shows" lead-in, then drop those 8
    # characters and attach the remainder after "The artwork depicts".
    # NOTE(review): presumably the lead-in keeps the first generated word
    # from being cased as a sentence start -- confirm.
    generated_caption = truecase.get_true_case("It shows " + generated_caption)
    caption += f" The artwork depicts{generated_caption[8:]}."

    # translate caption to italian
    with torch.no_grad():
        translation_inputs = TRANSLATION_TOKENIZER(
            caption, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(DEVICE)
        translated_ids = TRANSLATION_MODEL.generate(**translation_inputs)
    translated_caption = TRANSLATION_TOKENIZER.batch_decode(
        translated_ids, skip_special_tokens=True
    )[0]
    # Slicing instead of indexing keeps this safe for an empty translation.
    caption_it = f"{translated_caption[:1].upper()}{translated_caption[1:]}"

    return caption, caption_it


In [None]:
# Build the demo UI: inputs in the left column, the two captions in the right one.
# NOTE(review): an earlier comment mentioned the gstaff/whiteboard theme, but
# gr.Blocks() is created here without any theme argument.
with gr.Blocks() as demo:
    # set title and header
    demo.title = "Artwork Captioning"
    gr.Markdown("# Artwork Captioning")
    with gr.Row():
        # Left column: input image and generation settings.
        with gr.Column(scale=1):
            image = gr.components.Image()
            temperature = gr.components.Slider(
                minimum=0.0,
                maximum=2.5,
                step=0.1,
                value=1.0,
                label="Temperature (works only if sampling is activated)",
            )
            num_beams = gr.components.Slider(
                minimum=1, maximum=5, step=1, value=4, label="Number of beams"
            )
            min_length = gr.components.Slider(
                minimum=5, maximum=30, step=1, value=10, label="Minimum length"
            )
            do_sample = gr.components.Checkbox(label="Sampling")
        # Right column: English caption and its Italian translation.
        with gr.Column(scale=2):
            caption = gr.components.Textbox(
                label="English caption"
            )
            caption_it = gr.components.Textbox(
                label="Italian caption"
            )
            # Intended to increase the font size.
            # NOTE(review): these are plain attribute assignments on the
            # components -- confirm that Gradio actually honors them.
            caption.fontsize = 100
            caption_it.fontsize = 100
    # Offer every collected image path as a clickable example for the image input.
    examples = gr.Examples(
        [[path] for path in paths],
        inputs=[image],
    )
    # Run the captioning pipeline on click and fill both caption boxes.
    btn = gr.Button("Run")
    btn.click(
        captioning_pipeline,
        inputs=[image, temperature, num_beams, min_length, do_sample],
        outputs=[caption, caption_it],
    )

# Start the Gradio app.
demo.launch()

In [None]:
# Display the device the captioning model currently sits on.
MODEL.device