# Import libraries

In [None]:
!pip install bitsandbytes>=0.39.0 accelerate>=0.20.0
!pip install transformers

In [None]:
import os
import pandas as pd
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Instantiate model

In [None]:
# Load the BLIP-2 (Flan-T5 XXL) processor and model.
# Reference: https://huggingface.co/Salesforce/blip2-flan-t5-xxl
processor = Blip2Processor.from_pretrained(
    "Salesforce/blip2-flan-t5-xxl",
)
# The model is requested with 8-bit weights and automatic device placement.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xxl",
    load_in_8bit=True,
    device_map="auto",
)

# Mount drive

In [None]:
# Mount Google Drive at /content/drive; every dataset and output path used
# in this notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Create directory to store inferences

In [None]:
# Make sure the folder that receives the captioned CSVs exists; an already
# existing folder is not an error.
os.makedirs(
    '/content/drive/MyDrive/stance_detection_datasets/inferences',
    exist_ok=True,
)

# Import datasets

In [None]:
# Read the four evaluation datasets from Google Drive into DataFrames.
# NOTE(review): "constrain22" (without the "t") presumably mirrors the
# on-disk folder/file name; verify before renaming anything.
constraint22_dataset_uspolitics_test = pd.read_csv(
    '/content/drive/MyDrive/stance_detection_datasets/constraint22_dataset_uspolitics/constraint22_dataset_uspolitics_test.csv'
)
constrain22_dataset_covid19_test = pd.read_csv(
    '/content/drive/MyDrive/stance_detection_datasets/constrain22_dataset_covid19/constrain22_dataset_covid19_test.csv'
)
DISARM_test_all = pd.read_csv(
    '/content/drive/MyDrive/stance_detection_datasets/DISARM/DISARM_test_all.csv'
)
total_defense_memes = pd.read_csv(
    '/content/drive/MyDrive/stance_detection_datasets/total_defense_memes/total_defense_memes.csv'
)

# Balanced sampling

In [None]:
# Keep the same number of rows for every `role` value (250 for US politics,
# 190 for COVID-19). The fixed random_state makes the draw repeatable, and
# the index is renumbered from 0 afterwards.
constraint22_dataset_uspolitics_test = (
    constraint22_dataset_uspolitics_test
    .groupby('role')
    .sample(n=250, random_state=1)
    .reset_index(drop=True)
)
constrain22_dataset_covid19_test = (
    constrain22_dataset_covid19_test
    .groupby('role')
    .sample(n=190, random_state=1)
    .reset_index(drop=True)
)

# Select unique images from **total_defense_memes**

In [None]:
# Reduce the frame to its `image` column with one row per distinct image,
# renumbering the index from 0.
total_defense_memes = (
    total_defense_memes[['image']]
    .drop_duplicates()
    .reset_index(drop=True)
)

# Set prompt and define a function to call the model

In [None]:
# Use BLIP-2 for the inference. Adapted from
# https://huggingface.co/Salesforce/blip2-flan-t5-xxl
def get_caption(image, entity):
    """Describe a meme with BLIP-2, steering the prompt toward one entity.

    Relies on the module-level ``processor`` and ``model`` objects.

    Args:
        image: Path or file object of the meme, as accepted by
            ``PIL.Image.open``.
        entity: Entity name interpolated into the prompt.

    Returns:
        The decoded model output with special tokens stripped.
    """
    prompt = f'What does the meme show, in particular text and entities such as {entity}? Describe in english:'
    # Open inside a ``with`` block so the file handle is released right after
    # the RGB copy is made; without it, multi-frame images (e.g. GIFs) keep
    # their file open until garbage collection, which adds up over a long run.
    with Image.open(image) as source:
        raw_image = source.convert('RGB')
    inputs = processor(raw_image, prompt, return_tensors="pt").to("cuda")
    # Generate at most 30 new tokens per caption.
    out = model.generate(**inputs, max_new_tokens=30)
    return processor.decode(out[0], skip_special_tokens=True)

In [None]:
# Use BLIP-2 for the inference. Adapted from
# https://huggingface.co/Salesforce/blip2-flan-t5-xxl
def get_caption_TDEF(image):
    """Describe a meme with BLIP-2 using a prompt that names no entity.

    Relies on the module-level ``processor`` and ``model`` objects.

    Args:
        image: Path or file object of the meme, as accepted by
            ``PIL.Image.open``.

    Returns:
        The decoded model output with special tokens stripped.
    """
    prompt = 'What does the meme show, in particular text and entities? Describe in english:'
    # Open inside a ``with`` block so the file handle is released right after
    # the RGB copy is made; without it, multi-frame images (e.g. GIFs) keep
    # their file open until garbage collection, which adds up over a long run.
    with Image.open(image) as source:
        raw_image = source.convert('RGB')
    inputs = processor(raw_image, prompt, return_tensors="pt").to("cuda")
    # Generate at most 30 new tokens per caption.
    out = model.generate(**inputs, max_new_tokens=30)
    return processor.decode(out[0], skip_special_tokens=True)

# Call the `get_caption` and `get_caption_TDEF` functions and save inferences

In [None]:
# Caption each sampled US-politics meme for its paired entity, then write
# the augmented table to the inferences folder (row index omitted).
uspolitics_test_images = constraint22_dataset_uspolitics_test['image'].values
uspolitics_test_entities = constraint22_dataset_uspolitics_test['entity'].values
constraint22_dataset_uspolitics_test['caption'] = list(
    map(get_caption, uspolitics_test_images, uspolitics_test_entities)
)
constraint22_dataset_uspolitics_test.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/constraint22_dataset_uspolitics_test_captioned_BLIP-2.csv',
    index=False,
)

In [None]:
# Caption each sampled COVID-19 meme for its paired entity, then write the
# augmented table to the inferences folder (row index omitted).
covid19_test_images = constrain22_dataset_covid19_test['image'].values
covid19_test_entities = constrain22_dataset_covid19_test['entity'].values
constrain22_dataset_covid19_test['caption'] = list(
    map(get_caption, covid19_test_images, covid19_test_entities)
)
constrain22_dataset_covid19_test.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/constrain22_dataset_covid19_test_captioned_BLIP-2.csv',
    index=False,
)

In [None]:
# Caption every DISARM test meme; here the prompt entity comes from the
# `target` column. The augmented table is saved without its row index.
DISARM_test_all_images = DISARM_test_all['image'].values
DISARM_test_all_entities = DISARM_test_all['target'].values
DISARM_test_all['caption'] = list(
    map(get_caption, DISARM_test_all_images, DISARM_test_all_entities)
)
DISARM_test_all.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/DISARM_test_all_captioned_BLIP-2.csv',
    index=False,
)

In [None]:
# Caption each unique Total Defense meme with the entity-free prompt and
# save the result without its row index.
total_defense_memes_images = total_defense_memes['image'].values
total_defense_memes['caption'] = list(
    map(get_caption_TDEF, total_defense_memes_images)
)
total_defense_memes.to_csv(
    '/content/drive/MyDrive/stance_detection_datasets/inferences/total_defense_memes_captioned_BLIP-2.csv',
    index=False,
)