In [6]:
from snorkel.labeling import labeling_function
import json
import os
import numpy as np

In [7]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
import numpy as np

def calculate_metrics(y_true, y_pred, abstain_class=-1):
    # Filter out samples where prediction is -1
    valid_indices = y_pred != abstain_class
    y_true_filtered = y_true[valid_indices]
    y_pred_filtered = y_pred[valid_indices]

    # Compute metrics
    precision = precision_score(y_true_filtered, y_pred_filtered, average='macro')
    recall = recall_score(y_true_filtered, y_pred_filtered, average='macro')
    f1 = f1_score(y_true_filtered, y_pred_filtered, average='macro')
    accuracy = accuracy_score(y_true_filtered, y_pred_filtered)

    return {
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1,
        'Accuracy': accuracy
    }

In [8]:
POSITIVE = 1
NEGATIVE = 0
ABSTAIN = -1

@labeling_function()
def llava_7b(image_name):
    root_path = '../prompting_framework/prompting_results/traffic/'
    llava_7b_results = 'traffic-llava-7b.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data[image_name]["label"] if data[image_name]["label"] is not None else -1

# @labeling_function()
# def llava_34b(image_name):
#     root_path = '../prompting_framework/prompting_results/aircraft/'
#     llava_7b_results = 'oxford-llava_34b.json'
#     path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
#     with open(path_to_llava_7b_results, 'r') as file:
#         data = json.load(file)

#     return data.get(image_name, -1)

@labeling_function()
def llava_13b(image_name):
    root_path = '../prompting_framework/prompting_results/traffic/'
    llava_7b_results = 'traffic-llava-13b.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data[image_name]["label"] if data[image_name]["label"] is not None else -1

@labeling_function()
def bakllava(image_name):
    root_path = '../prompting_framework/prompting_results/traffic/'
    llava_7b_results = 'traffic-bakllava.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data[image_name]["label"] if data[image_name]["label"] is not None else -1

@labeling_function()
def llava_llama3(image_name):
    root_path = '../prompting_framework/prompting_results/traffic/'
    llava_7b_results = 'traffic-llava-llama3.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data[image_name]["label"] if data[image_name]["label"] is not None else -1


# @labeling_function()
# def moondream(image_name):
#     root_path = '../prompting_framework/prompting_results/traffic/'
#     llava_7b_results = 'aircraft-moondream.json'
#     path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
#     with open(path_to_llava_7b_results, 'r') as file:
#         data = json.load(file)

#     return data[image_name]["label"] if data[image_name]["label"] is not None else -1


@labeling_function()
def llava_phi3(image_name):
    root_path = '../prompting_framework/prompting_results/traffic/'
    llava_7b_results = 'traffic-llava-phi3.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data[image_name]["label"] if data[image_name]["label"] is not None else -1


@labeling_function()
def llama3_2_vision(image_name):
    root_path = '../prompting_framework/prompting_results/traffic/'
    llava_7b_results = 'traffic-llama3.2-vision.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data[image_name]["label"] if data[image_name]["label"] is not None else -1

In [9]:
train_data_json_path = '../prompting_framework/prompting_results/traffic/traffic-llama3.2-vision 11b-train-raw_info.json'
test_data_json_path = '../prompting_framework/prompting_results/traffic/traffic-llama3.2-vision 11b-test-raw_info.json'

with open(train_data_json_path, 'r') as file:
    train_data = json.load(file)


with open(test_data_json_path, 'r') as file:
    test_data = json.load(file)

# Extract and pad image names, ensuring they are 5 digits long before the '.png'
train_image_names = []
for item in train_data:
    train_image_names.append(item)


test_image_names = []
Y_test = []
for item in test_data:
    test_image_names.append(item)
    Y_test.append(test_data[item]["label"])

# with open(dev_data_json_path, 'r') as file:
#     dev_data = json.load(file)
    
# dev_image_names = []
# Y_dev = []
# for item in dev_data:
#     Y_dev.append(dev_data[item])
#     dev_image_names.append(item)

print(f"There are {len(train_image_names)} images in the Train set.")

print(f"There are {len(test_image_names)} images in the test set.")


There are 39209 images in the Train set.
There are 12630 images in the test set.


In [11]:
llama3_2_vision(test_image_names[0])

30

In [12]:
from snorkel.labeling import LFApplier

list_of_all_the_models = [
        'llava_13b',
       'llava_7b',
       'llava_llama3',
       'bakllava',
       'llama3_2_vision',
       'llava_phi3'
       ]

lfs = [llava_13b,
       llava_7b,
       llava_llama3,
       bakllava,
       llama3_2_vision,
       llava_phi3
       ]

applier = LFApplier(lfs)

In [None]:
from snorkel.labeling import LFAnalysis

L_test = applier.apply(test_image_names)
L_train = applier.apply(train_image_names)

12630it [49:46,  4.23it/s]
22710it [1:30:53,  4.09it/s]