# Install prerequisites and import libraries

In [None]:
# Pre-requisites. via
# https://github.com/haotian-liu/LLaVA
# Clone the LLaVA repository, change into it, and install it in editable mode
# together with the extra packages listed below.
!git clone https://github.com/haotian-liu/LLaVA.git
%cd LLaVA
!conda create -n llava python=3.10 -y
# NOTE(review): each `!` command presumably runs in its own subshell, so this
# activation likely does not carry over to the `pip install` lines below --
# confirm which environment the packages actually land in.
!conda activate llava
!pip install --upgrade pip
!pip install fastapi kaleido python-multipart typing-extensions==4.5.0 uvicorn
!pip install -e .

In [None]:
import os
import pandas as pd
# Import libraries. via
# https://medium.com/supportvectors/comparing-image-to-text-models-blip-blip2-and-llava-ebd66d2eb6d8
import torch
from llava.serve.cli import load_image
from llava.constants import IMAGE_TOKEN_INDEX
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

# Instantiate model

In [None]:
# Instantiate the model. via
# https://github.com/haotian-liu/LLaVA
# https://medium.com/supportvectors/comparing-image-to-text-models-blip-blip2-and-llava-ebd66d2eb6d8
# Checkpoint to load; the model name handed to the builder is derived from it.
model_path = 'liuhaotian/llava-v1.5-13b'

# Prefer the GPU and fall back to the CPU only when CUDA is unavailable.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The builder's result is unpacked into the tokenizer, the model, the image
# processor and the context length, all of which are kept as notebook globals.
(tokenizer, model, image_processor, context_len) = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    device=device,
)

# Mount drive

In [None]:
# Mount Google Drive at /content/drive; the dataset CSVs and the inference
# output directory used by this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Create directory to store inferences

In [None]:
# Create the output folder for the captioned CSVs (and any missing parents);
# an already existing folder is left untouched.
os.makedirs(
    name='/content/drive/MyDrive/stance_detection_datasets/inferences',
    exist_ok=True,
)

# Import datasets

In [None]:
# Read the four dataset CSVs from the mounted Drive folder into DataFrames.
constraint22_dataset_uspolitics_test = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/stance_detection_datasets/constraint22_dataset_uspolitics/constraint22_dataset_uspolitics_test.csv',
)
constrain22_dataset_covid19_test = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/stance_detection_datasets/constrain22_dataset_covid19/constrain22_dataset_covid19_test.csv',
)
DISARM_test_all = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/stance_detection_datasets/DISARM/DISARM_test_all.csv',
)
total_defense_memes = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/stance_detection_datasets/total_defense_memes/total_defense_memes.csv',
)

# Balanced sampling

In [None]:
# Draw the same number of rows from every `role` group (250 per role for the
# US-politics set, 190 per role for the COVID-19 set). The fixed seed keeps
# the subsample reproducible, and the index is rebuilt afterwards.
# NOTE(review): sampling without replacement presumably fails when a role has
# fewer rows than requested -- confirm the group sizes in the CSVs.
constraint22_dataset_uspolitics_test = (
    constraint22_dataset_uspolitics_test
    .groupby(by='role')
    .sample(n=250, random_state=1)
    .reset_index(drop=True)
)
constrain22_dataset_covid19_test = (
    constrain22_dataset_covid19_test
    .groupby(by='role')
    .sample(n=190, random_state=1)
    .reset_index(drop=True)
)

# Select unique images from **total_defense_memes**

In [None]:
# Keep only the `image` column and drop repeated entries so that each image
# appears exactly once, then renumber the rows.
total_defense_memes = (
    total_defense_memes[['image']]
    .drop_duplicates()
    .reset_index(drop=True)
)

# Set prompts and define functions to call the model

In [None]:
# Use llava-v1.5-13b for the inference. via
# https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/cli.py
# https://medium.com/supportvectors/comparing-image-to-text-models-blip-blip2-and-llava-ebd66d2eb6d8
def get_caption(image, entity):
    """Ask the loaded LLaVA model to describe a meme, focusing on an entity.

    Args:
        image: Image reference handed to ``load_image`` (presumably a file
            path or URL -- confirm against ``llava.serve.cli.load_image``).
        entity: Text interpolated into the prompt so the description pays
            particular attention to that entity.

    Returns:
        The decoded generation with surrounding whitespace removed and with a
        trailing ``</s>`` end-of-sequence marker removed when one is present.
    """
    eos_marker = "</s>"

    # Pre-process the picture into the pixel tensor the model consumes.
    image = load_image(image)
    image_tensor = image_processor([image], return_tensors='pt')['pixel_values']
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Build the prompt and tokenize it; unsqueeze(0) adds the batch dimension.
    prompt = f'USER: <image>\nWhat does the meme show, in particular text and entities such as {entity}? Describe in english:\nASSISTANT:'
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stopping_criteria = KeywordsStoppingCriteria([eos_marker], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    # NOTE(review): this slice assumes `generate` returns the prompt tokens
    # followed by the new ones; confirm this holds for the installed LLaVA
    # version, otherwise the start of the answer is cut off.
    output = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

    # Remove the end-of-sequence marker only when it is really there: a
    # generation that stops at max_new_tokens carries no "</s>", and blindly
    # slicing off four characters would eat the end of the caption.
    if output.endswith(eos_marker):
        output = output[:-len(eos_marker)]
    return output

In [None]:
# Use llava-v1.5-13b for the inference. via
# https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/cli.py
# https://medium.com/supportvectors/comparing-image-to-text-models-blip-blip2-and-llava-ebd66d2eb6d8
def get_caption_TDEF(image):
    """Ask the loaded LLaVA model to describe a meme without a target entity.

    Args:
        image: Image reference handed to ``load_image`` (presumably a file
            path or URL -- confirm against ``llava.serve.cli.load_image``).

    Returns:
        The decoded generation with surrounding whitespace removed and with a
        trailing ``</s>`` end-of-sequence marker removed when one is present.
    """
    eos_marker = "</s>"

    # Pre-process the picture into the pixel tensor the model consumes.
    image = load_image(image)
    image_tensor = image_processor([image], return_tensors='pt')['pixel_values']
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Tokenize the fixed prompt; unsqueeze(0) adds the batch dimension.
    prompt = 'USER: <image>\nWhat does the meme show, in particular text and entities? Describe in english:\nASSISTANT:'
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stopping_criteria = KeywordsStoppingCriteria([eos_marker], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    # NOTE(review): this slice assumes `generate` returns the prompt tokens
    # followed by the new ones; confirm this holds for the installed LLaVA
    # version, otherwise the start of the answer is cut off.
    output = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

    # Remove the end-of-sequence marker only when it is really there: a
    # generation that stops at max_new_tokens carries no "</s>", and blindly
    # slicing off four characters would eat the end of the caption.
    if output.endswith(eos_marker):
        output = output[:-len(eos_marker)]
    return output

# Call the `get_caption` and `get_caption_TDEF` functions and save inferences

In [None]:
# Caption every row of the US-politics set with its paired entity, then write
# the frame, including the new `caption` column, to the inferences folder.
uspolitics_test_images = constraint22_dataset_uspolitics_test['image'].values
uspolitics_test_entities = constraint22_dataset_uspolitics_test['entity'].values
constraint22_dataset_uspolitics_test['caption'] = list(
    map(get_caption, uspolitics_test_images, uspolitics_test_entities)
)
constraint22_dataset_uspolitics_test.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/constraint22_dataset_uspolitics_test_captioned_llava-v1.5-13b.csv',
    index=False,
)

In [None]:
# Caption every row of the COVID-19 set with its paired entity, then write
# the frame, including the new `caption` column, to the inferences folder.
covid19_test_images = constrain22_dataset_covid19_test['image'].values
covid19_test_entities = constrain22_dataset_covid19_test['entity'].values
constrain22_dataset_covid19_test['caption'] = list(
    map(get_caption, covid19_test_images, covid19_test_entities)
)
constrain22_dataset_covid19_test.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/constrain22_dataset_covid19_test_captioned_llava-v1.5-13b.csv',
    index=False,
)

In [None]:
# Caption every DISARM row, using its `target` column as the entity, then
# write the frame, including the new `caption` column, to the inferences folder.
DISARM_test_all_images = DISARM_test_all['image'].values
DISARM_test_all_entities = DISARM_test_all['target'].values
DISARM_test_all['caption'] = list(
    map(get_caption, DISARM_test_all_images, DISARM_test_all_entities)
)
DISARM_test_all.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/DISARM_test_all_captioned_llava-v1.5-13b.csv',
    index=False,
)

In [None]:
# Caption every image of the total-defense set with the entity-free prompt,
# then write the frame, including the new `caption` column, to the inferences
# folder.
total_defense_memes_images = total_defense_memes['image'].values
total_defense_memes['caption'] = list(
    map(get_caption_TDEF, total_defense_memes_images)
)
total_defense_memes.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/total_defense_memes_captioned_llava-v1.5-13b.csv',
    index=False,
)