In [2]:
import ollama 
import os
from tqdm import tqdm
import signal
import random
import numpy as np
import json
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score


In [3]:
base_path = '/root/home/data/hateful_memes/'
images_path = os.path.join(base_path, "img")

list_of_image_names = os.listdir(images_path)

In [7]:
def check_yes_no(text):
    # Strip any leading/trailing whitespace and convert to lowercase
    text = text.strip().lower()

    # Check if the text starts with 'yes' or 'no'
    if text.startswith("yes"):
        return 1
    elif text.startswith("no"):
        return 0
    else:
        return None  
    
class TimeoutException(Exception):
    pass

def timeout_handler(signum, frame):
    raise TimeoutException


def read_jsonl_file(file_path):
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            entry = json.loads(line.strip())
            data.append(entry)
    return data

def load_image_and_label(entry, img_base_path):
    img_path = f"{img_base_path}/{entry['img']}"
    try:
        img = Image.open(img_path)
        label = entry['label']
        return img, label
    except FileNotFoundError:
        print(f"Image {img_path} not found.")
        return None, None


def load_dev_file(input_file):
    dev_data = {}
    with open(input_file, 'r') as infile:
        for line in infile:
            # Load each JSON line as a dictionary
            entry = json.loads(line.strip())
        
            # Use the image path as the key and the label as the value
            dev_data[entry["img"]] = entry["label"]
    return dev_data

dev_file = os.path.join(base_path, "dev.jsonl")
dev_data = load_dev_file(dev_file)

In [5]:
list_of_models = ['llava:7b', 
                  'llava:13b',
                  'llava:34b',
                  'llava-llama3',
                  'bakllava',
                  'moondream',
                  'minicpm-v',
                  'llava-phi3']

ollama.pull('llava:7b') #pull the desired model

{'status': 'success'}

In [11]:
# Split dev_file
hateful_data = {}
nonhateful_data = {}

for key in dev_data :
    if dev_data[key] == 1 :
        hateful_data[key] = dev_data[key]
    else  :
        nonhateful_data[key] = dev_data[key]
hateful_images = list(hateful_data.keys())
nonhateful_images = list(nonhateful_data.keys())


In [36]:
options= {  # new
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048, # must be set, otherwise slightly random output
        }
    
llava_7b_labels = {}

np.random.seed(0)

#for image_name in tqdm(list_of_image_names):
for key, value in dev_data.items():
    print(f"Image: {key}, Label: {value}")
    
    random_image_index = np.random.randint(0, len(dev_data))
    random_image_path = os.path.join(base_path, images_list[random_image_index])
    
    image_path = os.path.join(base_path, key)
    
    prompt_1 = "Is this an offensive meme? Please answer with YES or NO. DO NOT mention the reason: "
    prompt_2 = "How about this one? (Yes or No)"
#     prompt = "Is there ? Please answer with YES or NO. DO NOT mention the reason: "
#     prompt = "Is this somehow an offensive meme? Please answer with YES or NO: "
#     prompt = "describe this image: "
    
    response = ollama.generate(model='llava:7b', prompt=prompt_1, images=[random_image_path], options=options)

    response_2 = ollama.generate(model='llava:7b', prompt=prompt_2, images=[image_path], options=options, context=response['context'])
    
#     response = ollama.chat(model='llava:34b', messages=[
#       {
#         'role': 'user',
#         'content': prompt,
#       }, 
#     ], options= {  # new
#             "seed": 123,
#             "temperature": 0,
#             "num_ctx": 2048, # must be set, otherwise slightly random output
#         })
    
#     print(prompt)
    print("------------------------------------------------------")
    print(image_name)
    # print(response_2['response'])
    
    label = check_yes_no(response_2['response'])
    
#     print(label)
    
    llava_7b_labels[image_name] = label
    
    # break

Image: img/08291.png, Label: 1
------------------------------------------------------
01247.png
Image: img/46971.png, Label: 1
------------------------------------------------------
01247.png
Image: img/03745.png, Label: 1
------------------------------------------------------
01247.png
Image: img/83745.png, Label: 1
------------------------------------------------------
01247.png
Image: img/80243.png, Label: 1
------------------------------------------------------
01247.png
Image: img/05279.png, Label: 1
------------------------------------------------------
01247.png
Image: img/01796.png, Label: 1
------------------------------------------------------
01247.png
Image: img/53046.png, Label: 1
------------------------------------------------------
01247.png
Image: img/82301.png, Label: 1
------------------------------------------------------
01247.png
Image: img/31752.png, Label: 1
------------------------------------------------------
01247.png
Image: img/27635.png, Label: 1
---------

In [12]:
print(random_image_path)

/root/home/data/hateful_memes/27816.png


In [None]:
# Experiment 3 | two prompts

options= {  # new
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048, # must be set, otherwise slightly random output
        }
    
labels_2context_dict = {}

np.random.seed(0)
for model in list_of_models :
    for entry in tqdm(dev_data):
        rand_hateful_index = np.random.randint(0, len(gt_hateful))
        rand_not_hateful_index = np.random.randint(0, len(gt_not_hateful))
        
        hateful_path = os.path.join(base_path, "img", gt_hateful[rand_hateful_index])
        not_hateful_path = os.path.join(base_path, "img", gt_not_hateful[rand_not_hateful_index])
        
        image_path = os.path.join(base_path, entry['img'])
        
        prompt_1 = "This is a offensive meme."
        prompt_2 = "This is not a offensive meme. "
        prompt_3 = "Based on the two previous prompts. Is this an offensive meme? answer either yes or no?. "
    
        
        response_1 = ollama.generate(model=model, prompt=prompt_1, images=[hateful_path], options=options)
        response_2 = ollama.generate(model=model, prompt=prompt_2, images=[not_hateful_path], options=options, context=response_1['context'])
        response_3 = ollama.generate(model=model, prompt=prompt_3, images=[image_path], options=options, context=response_2['context'])
    
        label_2context = check_yes_no(response_3['response'])
    
        image_name =  os.path.basename(entry['img'])
        print(label_2context, '--', entry['label'], '--', entry['img'], '--', image_name)
    
        labels_2context_dict[image_name] = label_2context
        break

In [14]:

def compute_metrics(gt, predictions, name, output_file="metrics.csv"):
    # Ensure we only evaluate on common keys
    common_keys = set(gt.keys()).intersection(predictions.keys())
    
    # Extract lists of labels based on the common keys, filtering out None values
    y_true = []
    y_pred = []
    for key in common_keys:
        pred_label = predictions[key]
        if pred_label is not None:
            y_true.append(gt[key])
            y_pred.append(pred_label)
    
    # Check if there are valid entries left after filtering
    if y_true and y_pred:
        # Calculate metrics
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        
        # Create a DataFrame to store the results
        metrics_df = pd.DataFrame({
            "Model": [name],
            "Precision": [precision],
            "Recall": [recall],
            "F1 Score": [f1],
            "Accuracy": [accuracy]
        })
        
        # Display the table
        print(metrics_df)
        
        # Save to a file (append if file already exists)
        with open(output_file, "a") as f:
            metrics_df.to_csv(f, index=False, header=f.tell()==0)
    else:
        print(f"No valid entries to compute metrics for {name}")

# Example usage
# compute_metrics(gt_dict, predictions_dict, 'Model Metrics')
