# Evaluation Framework for Text-to-Image Models

This framework evaluates text-to-image models. Given an input text prompt and the image generated from it, it returns a score from 0 to 5, where 0 is the worst. The framework has three components: object detection, image captioning, and sentence analysis. This notebook uses Mask R-CNN, ViT-GPT2, and LaBSE as an example combination and builds the framework step by step.

## 1. Dependencies
In this part we first install the requirements and import the packages we need.

In [None]:
# transformers and flax
!pip install transformers
!pip install -U sentence-transformers
!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install flax
!pip install --upgrade git+https://github.com/google/flax.git
!apt-cache policy libcudnn8
!apt install --allow-change-held-packages libcudnn8=8.4.1.50-1+cuda11.6
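# Note: each `!` command runs in its own subshell, so these exports do not persist to later cells or to the notebook kernel
!export PATH=/usr/local/cuda-11.4/bin${PATH:+:${PATH}}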
!export LD_LIBRARY_PATH=/usr/local/cuda-11.4/lib64:$LD_LIBRARY_PATH
!export LD_LIBRARY_PATH=/usr/local/cuda-11.4/include:$LD_LIBRARY_PATH
!export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64

In [None]:
# Install detectron2 
!python -m pip install pyyaml==5.1
import sys, os, distutils.core
# Note: This is a faster way to install detectron2 in Colab, but it does not include all functionalities.
# See https://detectron2.readthedocs.io/tutorials/install.html for full installation instructions
!git clone 'https://github.com/facebookresearch/detectron2'
dist = distutils.core.run_setup("./detectron2/setup.py")
!python -m pip install {' '.join([f"'{x}'" for x in dist.install_requires])}
sys.path.insert(0, os.path.abspath('./detectron2'))

### After installation, it is time to do some setup

In [None]:
# Some basic setup:
import requests
import transformers
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
from sentence_transformers import SentenceTransformer

import torch, detectron2

# Setup detectron2 logger
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

## 2. Process

### 2.1 Data setup

In [None]:
# Path to the image generated by the text-to-image model
input_image_path = ""

# Input text prompt that was given to the text-to-image model
input_text = ""

### 2.2 Image captioning

In this part we use the Hugging Face transformers Vision Encoder-Decoder framework, with a Vision Transformer (ViT) encoder and a GPT-2 decoder, to generate a caption for the image.

In [None]:
# Define image captioning model
loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
image_cap_model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)

In [None]:
# Generate a caption for the image stored at image_path
def generate_caption(image_path, model):
  with Image.open(image_path) as img:
    # preprocess the image into the pixel array the encoder expects
    pixel_values = feature_extractor(images=img, return_tensors="np").pixel_values
  output_ids = model.generate(pixel_values, max_length=16, num_beams=4).sequences
  preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
  preds = [pred.strip() for pred in preds]

  return preds[0]

image_caption = generate_caption(input_image_path, image_cap_model)

print("Image caption is:", image_caption)

### 2.3 Object Detection

In this part we use Mask R-CNN to count the objects in the image when the input text specifies a quantity. The image caption may not include this attribute, so the framework needs a second check.

First we create a detectron2 config and a detectron2 `DefaultPredictor` to run inference.

In [None]:
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)

Then count the number of detected instances of each class.

In [None]:
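def count_number_objects(img_path, predictor):
  img = cv2.imread(img_path)
  outputs = predictor(img)
  # class names of the dataset the model was trained on (COCO for this config)
  classes = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
  instance_dict={}
  for class_id in outputs["instances"].pred_classes:
    class_name = classes[int(class_id)]
    instance_dict[class_name]=instance_dict.get(class_name, 0)+1

  return instance_dict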

instance_count = count_number_objects(input_image_path, predictor)

print("The number of each class is: ", instance_count)
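
The counting logic can be checked without running the detector. The following is a minimal sketch, assuming a hypothetical list of predicted class ids and a three-entry class list; in the notebook these come from `outputs["instances"].pred_classes` and the model's class metadata. The helper name `count_by_class` is illustrative only.

```python
from collections import Counter

# Hypothetical stand-ins for the detector output and its class names.
toy_classes = ["person", "dog", "cat"]
toy_pred_classes = [1, 1, 0, 2, 1]  # class id of each detected instance

def count_by_class(pred_classes, class_names):
  """Map each detected class name to its number of instances."""
  return dict(Counter(class_names[int(i)] for i in pred_classes))

toy_counts = count_by_class(toy_pred_classes, toy_classes)
print(toy_counts)  # {'dog': 3, 'person': 1, 'cat': 1}
```

`Counter` does the same bookkeeping as the `instance_dict.get(class_name, 0)+1` pattern, so either form gives the same dictionary.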

### 2.4 Sentence analysis

In this part we analyze the input sentence to extract the objects it mentions and how many of each it asks for.

In [None]:
import spacy

def get_word_dict(text):
  # load the English language model
  nlp = spacy.load('en_core_web_sm',disable=['ner','textcat'])

  # parse the text into a spaCy Doc
  doc = nlp(text)

  word_dict={}
  for token in doc:
    # for each noun, look for a numeric modifier (the default count is 1)
    if (token.pos_=='NOUN'):
        num=1
        for child in token.children:
          if child.dep_=='nummod':
            try:
              num=int(child.text)
            except ValueError:
              # spelled-out numbers such as "two" are not parsed here; keep the default
              pass
        word_dict[token.lemma_]=num

  return word_dict

word_dict = get_word_dict(input_text)

print("The number of each object required: ", word_dict)
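
Prompts often spell quantities out ("two dogs"), and `int()` only accepts digit strings. Below is a minimal sketch of a helper that handles both cases. `parse_count` and its small word table are hypothetical and not part of spaCy; the table would need extending for larger numbers.

```python
# Hypothetical lookup table; extend it if prompts use larger numbers.
NUMBER_WORDS = {
  "a": 1, "an": 1, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
  "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
}

def parse_count(text, default=1):
  """Return the integer for a digit string or number word, else `default`."""
  text = text.strip().lower()
  try:
    return int(text)
  except ValueError:
    return NUMBER_WORDS.get(text, default)

print(parse_count("3"), parse_count("Two"), parse_count("several"))  # 3 2 1
```

Inside `get_word_dict`, the `int(child.text)` call could be swapped for `parse_count(child.text)` if spelled-out quantities matter for your prompts.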

Then we also need the embedding vectors of both the input text prompt and the image caption, which we compute with LaBSE.

In [None]:
# Define the sentence embed model
sentence_model = SentenceTransformer('sentence-transformers/LaBSE')

# encode sentences to vectors
input_embed = sentence_model.encode(input_text)
caption_embed = sentence_model.encode(image_caption)
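
These two embeddings are compared with cosine similarity in the scoring step. As a reminder of what that measure computes, here is a minimal sketch on short hypothetical vectors; real sentence embeddings are much longer, and the function name `cosine_similarity` is defined here only for illustration.

```python
import math

def cosine_similarity(u, v):
  """Cosine of the angle between two equal-length vectors."""
  dot = sum(a * b for a, b in zip(u, v))
  norm_u = math.sqrt(sum(a * a for a in u))
  norm_v = math.sqrt(sum(b * b for b in v))
  return dot / (norm_u * norm_v)

# Hypothetical toy vectors standing in for sentence embeddings.
same_direction = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
opposite = cosine_similarity([1.0, 0.0], [-1.0, 0.0])
print(same_direction, orthogonal, opposite)
```

Vectors pointing the same way score close to 1, unrelated (orthogonal) vectors score 0, and opposite vectors score -1, so the raw value is not guaranteed to lie in the 0 to 1 range.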

## 3. Score

Now, based on all the information we have, we can compute the final score.

In [None]:
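# evaluation on the number of objects
def score_objects_num(word_dict, instance_dict):
  class_scores=[]
  for class_name in instance_dict.keys():
    if class_name in word_dict:
      error = abs(instance_dict[class_name]-word_dict[class_name])/word_dict[class_name]
      # clamp at 0 so a large miscount cannot produce a negative score
      class_scores.append(max(0.0, 1-error))

  # np.mean of an empty list is NaN; returning 0 in that case is an assumption
  return float(np.mean(class_scores)) if class_scores else 0.0


# evaluation on sentence vector similarity
def score_sentence_similarity(input_embed, caption_embed):
  similarity = torch.cosine_similarity(torch.tensor(input_embed), torch.tensor(caption_embed), dim=0)

  # return a plain Python float instead of a 0-d tensor
  return similarity.item()

# map a score in [0, 1] to an integer in [0, 5]
# (named map_score to avoid shadowing the built-in map)
def map_score(score):
  # cosine similarity can be negative, so clamp before scaling
  return round(min(max(score, 0.0), 1.0)*5)

score_1 = score_objects_num(word_dict, instance_count)
score_2 = score_sentence_similarity(input_embed, caption_embed)

# average the two component scores, then map the result to the 0-5 range
score = map_score(np.mean([score_1, score_2]))

print("The Final Score is: ", score)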
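
To make the arithmetic concrete, here is a worked example with hypothetical values: the prompt asks for two dogs and one cat, the detector finds three dogs, one cat, and one person, and the caption similarity is assumed to be 0.8. The helper name `count_score` is illustrative; it mirrors the per-class formula above without depending on the models.

```python
# Hypothetical inputs: counts required by the prompt vs. counts detected.
required = {"dog": 2, "cat": 1}
detected = {"dog": 3, "cat": 1, "person": 1}
similarity = 0.8  # assumed cosine similarity between prompt and caption

def count_score(required, detected):
  """Mean of 1 - relative count error over classes present in both dicts."""
  scores = [
    1 - abs(detected[name] - required[name]) / required[name]
    for name in detected
    if name in required
  ]
  return sum(scores) / len(scores)

object_score = count_score(required, detected)  # (0.5 + 1.0) / 2 = 0.75
final_score = round((object_score + similarity) / 2 * 5)
print(object_score, final_score)  # 0.75 4
```

The extra dog costs half of the dog score, the cat count is exact, and the undetected-in-prompt "person" class is ignored, so the two components average to 0.775 and map to a final score of 4.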