In [1]:
from snorkel.labeling import labeling_function
import json
import os
import numpy as np

In [24]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
import numpy as np

def calculate_metrics(y_true, y_pred, abstain_class=-1):
    """Score predictions against ground truth, ignoring abstained samples.

    Every sample whose prediction equals ``abstain_class`` is removed from
    both inputs before scoring, so the metrics describe only the samples that
    actually received a label.

    Args:
        y_true: Ground-truth labels (list or array).
        y_pred: Predicted labels, aligned element-wise with ``y_true``.
        abstain_class: Prediction value that means "no label"; defaults to -1.

    Returns:
        dict: ``'Precision'``, ``'Recall'`` and ``'F1 Score'`` (all
        macro-averaged) plus ``'Accuracy'``.
    """
    # Coerce to arrays so plain Python lists work too: for a list,
    # `y_pred != abstain_class` is a single bool rather than a mask.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Keep only the samples for which a label was predicted.
    valid_indices = y_pred != abstain_class
    y_true_filtered = y_true[valid_indices]
    y_pred_filtered = y_pred[valid_indices]

    # zero_division=0 yields the same value scikit-learn substitutes by
    # default for undefined per-class scores, but without emitting an
    # UndefinedMetricWarning on every call.
    precision = precision_score(
        y_true_filtered, y_pred_filtered, average='macro', zero_division=0
    )
    recall = recall_score(
        y_true_filtered, y_pred_filtered, average='macro', zero_division=0
    )
    f1 = f1_score(
        y_true_filtered, y_pred_filtered, average='macro', zero_division=0
    )
    accuracy = accuracy_score(y_true_filtered, y_pred_filtered)

    return {
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1,
        'Accuracy': accuracy
    }

In [17]:
POSITIVE = 1
NEGATIVE = 0
ABSTAIN = -1

# Directory that holds one JSON result file per prompted model.
RESULTS_ROOT = '../prompting_framework/prompting_results/aircraft/'

# Parsed result files, keyed by file name. Each file is read from disk at most
# once; re-run this cell to drop the cache after regenerating a result file.
_RESULTS_CACHE = {}


def _load_results(results_file):
    """Return the parsed JSON content of ``results_file`` under RESULTS_ROOT.

    The file is parsed on first use and served from ``_RESULTS_CACHE``
    afterwards, so applying a labeling function to thousands of images does
    not re-read the file for every image.
    """
    if results_file not in _RESULTS_CACHE:
        with open(os.path.join(RESULTS_ROOT, results_file), 'r') as file:
            _RESULTS_CACHE[results_file] = json.load(file)
    return _RESULTS_CACHE[results_file]


def _lookup_label(results_file, image_name):
    """Return the stored label for ``image_name``, or ABSTAIN if it is null.

    Raises:
        KeyError: if ``image_name`` has no entry in the result file, or the
            entry has no ``"label"`` field.
    """
    label = _load_results(results_file)[image_name]["label"]
    return ABSTAIN if label is None else label


@labeling_function()
def llava_7b(image_name):
    """Label recorded for ``image_name`` in 'aircraft-llava-7b.json'."""
    return _lookup_label('aircraft-llava-7b.json', image_name)

# Disabled labeling function, kept for reference.
# @labeling_function()
# def llava_34b(image_name):
#     root_path = '../prompting_framework/prompting_results/aircraft/'
#     llava_7b_results = 'oxford-llava_34b.json'
#     path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
#     with open(path_to_llava_7b_results, 'r') as file:
#         data = json.load(file)

#     return data.get(image_name, -1)

@labeling_function()
def llava_13b(image_name):
    """Label recorded for ``image_name`` in 'aircraft-llava-13b.json'."""
    return _lookup_label('aircraft-llava-13b.json', image_name)

@labeling_function()
def bakllava(image_name):
    """Label recorded for ``image_name`` in 'aircraft-bakllava.json'."""
    return _lookup_label('aircraft-bakllava.json', image_name)

@labeling_function()
def llava_llama3(image_name):
    """Label recorded for ``image_name`` in 'aircraft-llava-llama3.json'."""
    return _lookup_label('aircraft-llava-llama3.json', image_name)


@labeling_function()
def moondream(image_name):
    """Label recorded for ``image_name`` in 'aircraft-moondream.json'."""
    return _lookup_label('aircraft-moondream.json', image_name)


@labeling_function()
def llava_phi3(image_name):
    """Label recorded for ``image_name`` in 'aircraft-llava-phi3.json'."""
    return _lookup_label('aircraft-llava-phi3.json', image_name)


@labeling_function()
def llama3_2_vision(image_name):
    """Label recorded for ``image_name`` in 'aircraft-llama3.2-vision-11b.json'."""
    return _lookup_label('aircraft-llama3.2-vision-11b.json', image_name)

In [18]:
train_data_json_path = '../prompting_framework/prompting_results/aircraft/aircraft-llava-phi3-train-raw_info.json'
val_data_json_path = '../prompting_framework/prompting_results/aircraft/aircraft-llava-phi3-val-raw_info.json'
test_data_json_path = '../prompting_framework/prompting_results/aircraft/aircraft-llava-phi3-test-raw_info.json'


def _read_json(json_path):
    """Parse and return the JSON document stored at ``json_path``."""
    with open(json_path, 'r') as file:
        return json.load(file)


train_data = _read_json(train_data_json_path)
val_data = _read_json(val_data_json_path)
test_data = _read_json(test_data_json_path)

# The entries of each split file are used as the image names exactly as they
# appear (no renaming or padding is applied). Validation and test entries also
# provide the ground-truth label under the "label" field.
train_image_names = list(train_data)

val_image_names = list(val_data)
Y_val = [val_data[image_name]["label"] for image_name in val_image_names]

test_image_names = list(test_data)
Y_test = [test_data[image_name]["label"] for image_name in test_image_names]

# with open(dev_data_json_path, 'r') as file:
#     dev_data = json.load(file)
    
# dev_image_names = []
# Y_dev = []
# for item in dev_data:
#     Y_dev.append(dev_data[item])
#     dev_image_names.append(item)

print(f"There are {len(train_image_names)} images in the Train set.")
print(f"There are {len(val_image_names)} images in the val set.")
print(f"There are {len(test_image_names)} images in the test set.")


There are 2367 images in the Train set.
There are 2365 images in the val set.
There are 2368 images in the test set.


In [19]:
# Sanity check: run a single labeling function on the first training image name.
llama3_2_vision(train_image_names[0])

4

In [35]:
from snorkel.labeling import LFApplier

# Names of all candidate models.
# NOTE(review): this list has seven names, while `lfs` below contains six
# labeling functions (there is no `bakllava`) in a different order. Confirm
# whether the omission is intentional before relying on this list.
list_of_all_the_models = [
    'llava_13b',
    'llava_7b',
    'llava_llama3',
    'moondream',
    'bakllava',
    'llama3_2_vision',
    'llava_phi3',
]

# Labeling functions that are actually applied to the image names.
lfs = [
    llava_13b,
    llava_7b,
    llava_llama3,
    moondream,
    llama3_2_vision,
    llava_phi3,
]

applier = LFApplier(lfs)

In [36]:
from snorkel.labeling import LFAnalysis

# Build one label matrix per split by running every labeling function on every
# image name; the splits are processed in the order test, val, train.
L_test, L_val, L_train = (
    applier.apply(image_names)
    for image_names in (test_image_names, val_image_names, train_image_names)
)

2368it [00:53, 44.61it/s]
2365it [00:52, 44.97it/s]
2367it [00:52, 44.88it/s]


In [37]:
# Turn the validation labels into an array, then tabulate the per-LF
# statistics of the validation label matrix against them.
Y_val = np.array(Y_val)
val_lf_analysis = LFAnalysis(L_val, lfs)
val_lf_analysis.lf_summary(Y_val)

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
llava_13b,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14...",1.0,1.0,0.97759,651,1714,0.275264
llava_7b,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14...",1.0,1.0,0.97759,438,1927,0.185201
llava_llama3,2,"[0, 1, 3, 4, 6, 7, 8, 9, 10, 13, 15, 17, 18, 2...",1.0,1.0,0.97759,661,1704,0.279493
moondream,3,"[0, 1, 4, 6, 8, 15, 16, 18, 22, 23, 24, 25, 26...",0.991121,0.991121,0.970402,269,2075,0.114761
llama3_2_vision,4,"[2, 4, 5, 8, 9, 10, 11, 12, 19, 20, 21, 26]",1.0,1.0,0.97759,865,1500,0.365751
llava_phi3,5,"[0, 1, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, ...",1.0,1.0,0.97759,686,1679,0.290063


In [38]:
# Turn the test labels into an array, then tabulate the per-LF statistics of
# the test label matrix against them.
Y_test = np.array(Y_test)
test_lf_analysis = LFAnalysis(L_test, lfs)
test_lf_analysis.lf_summary(Y_test)

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
llava_13b,0,"[0, 1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 1...",1.0,1.0,0.984797,639,1729,0.269848
llava_7b,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1.0,1.0,0.984797,429,1939,0.181166
llava_llama3,2,"[0, 1, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 1...",1.0,1.0,0.984797,657,1711,0.277449
moondream,3,"[0, 1, 4, 6, 8, 13, 15, 16, 18, 19, 22, 23, 24...",0.994088,0.994088,0.97973,287,2067,0.12192
llama3_2_vision,4,"[3, 4, 5, 8, 9, 10, 11, 12, 16, 20, 21, 26]",1.0,1.0,0.984797,869,1499,0.366976
llava_phi3,5,"[0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15...",1.0,1.0,0.984797,676,1692,0.285473


# Majority Vote

In [39]:
def majority_vote_exclude_negative(labels, abstain_class=-1):
    """Row-wise majority vote that ignores abstain votes.

    For each row of ``labels`` the most frequent value other than
    ``abstain_class`` is selected. When several values are tied for the
    highest count, the smallest of them is chosen, so the result is
    deterministic. A row that contains only abstains (or no votes at all)
    yields ``abstain_class``.

    Args:
        labels: 2-D array-like of votes, one row per sample and one column
            per voter.
        abstain_class: Value that marks "no vote"; defaults to -1.

    Returns:
        np.ndarray: 1-D array with one label per row, same dtype as ``labels``.
    """
    labels = np.asarray(labels)
    # Start from "abstain" everywhere; rows with real votes are overwritten.
    result = np.full(labels.shape[0], abstain_class, dtype=labels.dtype)
    for i, row in enumerate(labels):
        votes = row[row != abstain_class]
        if votes.size == 0:
            continue  # Nothing but abstains: keep abstain_class.
        # np.unique returns the values in ascending order and np.argmax picks
        # the first maximum, so ties resolve to the smallest label.
        values, counts = np.unique(votes, return_counts=True)
        result[i] = values[np.argmax(counts)]

    return result
    
def majority_vote(labels):
    """Return the most frequent value of each row of ``labels`` as a 1-D array.

    Every value takes part in the vote, including the abstain value -1; nothing
    is filtered out before the mode is computed.
    """
    from scipy.stats import mode

    # Mode along axis=1 gives one value per row; flatten it so the result is
    # always one-dimensional.
    row_modes = mode(labels, axis=1).mode
    return np.ravel(row_modes)

# Plain majority vote: abstains (-1) are counted like any other label.
majority_labels_val = majority_vote(L_val)
majority_labels_test = majority_vote(L_test)

# Majority vote over the non-abstain votes only.
majority_labels_exclude_negative_val = majority_vote_exclude_negative(L_val)
majority_labels_exclude_negative_test = majority_vote_exclude_negative(L_test)


In [40]:
# Validation metrics of the abstain-excluding majority vote, one per line.
metrics = calculate_metrics(Y_val, majority_labels_exclude_negative_val)
print("\n".join(f"{name}: {score}" for name, score in metrics.items()))

Precision: 0.038737883051907444
Recall: 0.04128589562544011
F1 Score: 0.029212981612122126
Accuracy: 0.3014799154334038


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [41]:
# Test metrics of the abstain-excluding majority vote, one per line.
metrics = calculate_metrics(Y_test, majority_labels_exclude_negative_test)
print("\n".join(f"{name}: {score}" for name, score in metrics.items()))

Precision: 0.062406571539506885
Recall: 0.04929418057444379
F1 Score: 0.039311900141874065
Accuracy: 0.30532094594594594


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Snorkel Label Model

In [42]:
from snorkel.labeling.model import LabelModel

# NOTE(review): cardinality=30 is hard-coded; confirm that it matches the
# number of classes in the aircraft label set.
label_model = LabelModel(verbose=False, cardinality=30)

# Fit on the training label matrix only (no ground-truth labels are passed);
# the fixed seed keeps the fit repeatable.
label_model.fit(
    L_train,
    seed=12345,
    n_epochs=5000,
    log_freq=500,
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:05<00:00, 836.20epoch/s]


In [43]:
from snorkel.analysis import metric_score
from snorkel.utils import probs_to_preds

# Class probabilities from the fitted label model, converted to hard labels.
probs_val = label_model.predict_proba(L_val)
preds_val = probs_to_preds(probs_val)

probs_test = label_model.predict_proba(L_test)
preds_test = probs_to_preds(probs_test)

# Report the same metric set for both evaluation splits.
for split_title, split_labels, split_preds in (
    ("Validation:", Y_val, preds_val),
    ("Test:", Y_test, preds_test),
):
    print(split_title)
    metrics = calculate_metrics(split_labels, split_preds)
    for metric, value in metrics.items():
        print(f"{metric}: {value}")

Validation:
Precision: 0.10078015425144125
Recall: 0.054599945384629894
F1 Score: 0.05644668583644148
Accuracy: 0.21733615221987315
Test:
Precision: 0.059306018253837016
Recall: 0.05732430296328835
F1 Score: 0.05509352861126802
Accuracy: 0.20777027027027026


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Hyper Label Model

In [44]:
from hyperlm import HyperLabelModel
# Create a HyperLabelModel with its default constructor arguments.
hlm = HyperLabelModel()

  checkpoint = torch.load(checkpoint_path, map_location=torch.device(self.device))


In [45]:
# Infer labels directly from each label matrix; no ground-truth labels are
# passed to the model here. `[:, :]` is a full view of the matrix.
hyper_pred_val = hlm.infer(L_val[:, :])
hyper_pred_test = hlm.infer(L_test[:, :])
hyper_pred_train = hlm.infer(L_train)

# Report the same metric set for both evaluation splits.
for split_title, split_labels, split_preds in (
    ("Validation:", Y_val, hyper_pred_val),
    ("Test:", Y_test, hyper_pred_test),
):
    print(split_title)
    metrics = calculate_metrics(split_labels, split_preds)
    for metric, value in metrics.items():
        print(f"{metric}: {value}")

Validation:
Precision: 0.13212226112175626
Recall: 0.07916468727442784
F1 Score: 0.08939258076412132
Accuracy: 0.27568710359408033
Test:
Precision: 0.12510123602945414
Recall: 0.06955357425663299
F1 Score: 0.07978793487838003
Accuracy: 0.2690033783783784


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Fine-tuned Hyper Label Model on the Validation Set

In [67]:
from hyperlm import HyperLabelModel
hlm = HyperLabelModel()

# Stack validation on top of test so one `infer` call covers both splits.
# `val_test_labels` is not used further in this cell.
val_test_labels = np.concatenate((Y_val, Y_test))
L_val_test_agg = np.concatenate((L_val, L_test), axis=0)

# The validation rows come first in the stacked matrix, so their row indices
# are 0 .. n_val - 1.
n_val = Y_val.shape[0]
val_indices = list(range(n_val))

# NOTE(review): the validation labels are handed to `infer` here, so the
# "Validation" metrics printed below are computed on samples whose labels the
# model was given; presumably only the "Test" metrics are a held-out
# estimate - confirm.
hyper_pred_val_test = hlm.infer(L_val_test_agg, y_indices=val_indices, y_vals=Y_val)

# Split the joint predictions back into the two splits.
ft_hyper_pred_val = hyper_pred_val_test[:n_val]
ft_hyper_pred_test = hyper_pred_val_test[n_val:]

for split_title, split_labels, split_preds in (
    ("Validation:", Y_val, ft_hyper_pred_val),
    ("Test:", Y_test, ft_hyper_pred_test),
):
    print(split_title)
    metrics = calculate_metrics(split_labels, split_preds)
    for metric, value in metrics.items():
        print(f"{metric}: {value}")



  checkpoint = torch.load(checkpoint_path, map_location=torch.device(self.device))
  self.checkpoint = torch.load(


Validation:
Precision: 0.05136417662039886
Recall: 0.09770175527827062
F1 Score: 0.06387304061867036
Accuracy: 0.35095137420718814
Test:
Precision: 0.05076994330402457
Recall: 0.09945761435787251
F1 Score: 0.06405820829838553
Accuracy: 0.3450168918918919


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [51]:
# Shape of the concatenated validation + test ground-truth labels.
val_test_labels.shape

(4733,)

In [52]:
# Shape of the validation ground-truth label array.
Y_val.shape

(2365,)

In [64]:
# Shape of the validation part of the joint predictions (the first
# Y_val.shape[0] entries).
hyper_pred_val_test[:Y_val.shape[0]].shape

(2365,)

  self.checkpoint = torch.load(


(4733, 6)