# Evaluate Similarity Grouping


In [1]:
import pandas as pd
import numpy as np

from tqdm import tqdm

from models import ModelMgr
from models.embedding.SentenceTransformer import SentenceTransformerEmbeddingModel
from models.semantic_validation import LLaMAValidationModel

from db.operators import Dummy, Select
from db.criteria import SoftEqual
from db.structure import Column, Constant
import kagglehub

from evaluation.util import calculate_metrics, calc_bleu

import time

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Nico\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Nico\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
stem = SentenceTransformerEmbeddingModel(ModelMgr())
lsv = LLaMAValidationModel(ModelMgr())
# lsv = DeepSeekValidationModel(m)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
ZERO_SHOT_SYSTEM_PROMPT = "You are a validator. Respond with \"no\" and \"yes\" only!"
ZERO_SHOT_PROMPTING_TEMPLATE = 'Does "{}" describes "{}"'

In [4]:
path = kagglehub.dataset_download("uciml/zoo-animal-classification")



In [5]:
classes = pd.read_csv(f"{path}/class.csv", index_col=0)[["Class_Type"]]
classes

Unnamed: 0_level_0,Class_Type
Class_Number,Unnamed: 1_level_1
1,Mammal
2,Bird
3,Reptile
4,Fish
5,Amphibian
6,Bug
7,Invertebrate


In [6]:
transforms = {"is": "is not", "has": "has no", "does": "does not", "lays": "does not lays", "gives": "gives no"}
transform_cols = {"hair": "has", "feathers": "has", "eggs": "lays", "milk": "gives", "airborne": "is", "aquatic": "is", "predator": "is", "toothed": "is", "backbone": "has", "breathes": "does", "venomous": "is", "fins": "has", "legs": "has", "tail": "has", "domestic": "is", "catsize": "is" }

columns = ["name"] + [col for col in transform_cols]

print(columns)

zoo = pd.read_csv(f"{path}/zoo.csv", index_col=0)
for col in transform_cols:
    t = transform_cols[col]
    zoo[col] = zoo[col].apply(lambda x: f'{t if x else transforms[t]} {col}')
zoo

['name', 'hair', 'feathers', 'eggs', 'milk', 'airborne', 'aquatic', 'predator', 'toothed', 'backbone', 'breathes', 'venomous', 'fins', 'legs', 'tail', 'domestic', 'catsize']


Unnamed: 0_level_0,hair,feathers,eggs,milk,airborne,aquatic,predator,toothed,backbone,breathes,venomous,fins,legs,tail,domestic,catsize,class_type
animal_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
aardvark,has hair,has no feathers,does not lays eggs,gives milk,is not airborne,is not aquatic,is predator,is toothed,has backbone,does breathes,is not venomous,has no fins,has legs,has no tail,is not domestic,is catsize,1
antelope,has hair,has no feathers,does not lays eggs,gives milk,is not airborne,is not aquatic,is not predator,is toothed,has backbone,does breathes,is not venomous,has no fins,has legs,has tail,is not domestic,is catsize,1
bass,has no hair,has no feathers,lays eggs,gives no milk,is not airborne,is aquatic,is predator,is toothed,has backbone,does not breathes,is not venomous,has fins,has no legs,has tail,is not domestic,is not catsize,4
bear,has hair,has no feathers,does not lays eggs,gives milk,is not airborne,is not aquatic,is predator,is toothed,has backbone,does breathes,is not venomous,has no fins,has legs,has no tail,is not domestic,is catsize,1
boar,has hair,has no feathers,does not lays eggs,gives milk,is not airborne,is not aquatic,is predator,is toothed,has backbone,does breathes,is not venomous,has no fins,has legs,has tail,is not domestic,is catsize,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
wallaby,has hair,has no feathers,does not lays eggs,gives milk,is not airborne,is not aquatic,is not predator,is toothed,has backbone,does breathes,is not venomous,has no fins,has legs,has tail,is not domestic,is catsize,1
wasp,has hair,has no feathers,lays eggs,gives no milk,is airborne,is not aquatic,is not predator,is not toothed,has no backbone,does breathes,is venomous,has no fins,has legs,has no tail,is not domestic,is not catsize,6
wolf,has hair,has no feathers,does not lays eggs,gives milk,is not airborne,is not aquatic,is predator,is toothed,has backbone,does breathes,is not venomous,has no fins,has legs,has tail,is not domestic,is catsize,1
worm,has no hair,has no feathers,lays eggs,gives no milk,is not airborne,is not aquatic,is not predator,is not toothed,has no backbone,does breathes,is not venomous,has no fins,has no legs,has no tail,is not domestic,is not catsize,7


In [7]:
gt = {tuple([key] + [x[col] for col in columns[1:]] + [classes.loc[x["class_type"]].values[0]]) for key, x in zoo.iterrows()}
print(str(gt)[0:500], "...", len(gt))

{('pussycat', 'has hair', 'has no feathers', 'does not lays eggs', 'gives milk', 'is not airborne', 'is not aquatic', 'is predator', 'is toothed', 'has backbone', 'does breathes', 'is not venomous', 'has no fins', 'has legs', 'has tail', 'is domestic', 'is catsize', 'Mammal'), ('mongoose', 'has hair', 'has no feathers', 'does not lays eggs', 'gives milk', 'is not airborne', 'is not aquatic', 'is predator', 'is toothed', 'has backbone', 'does breathes', 'is not venomous', 'has no fins', 'has legs ... 101


In [8]:
data = [x[:-1] for x in gt]
print(data[0])

('pussycat', 'has hair', 'has no feathers', 'does not lays eggs', 'gives milk', 'is not airborne', 'is not aquatic', 'is predator', 'is toothed', 'has backbone', 'does breathes', 'is not venomous', 'has no fins', 'has legs', 'has tail', 'is domestic', 'is catsize')


In [9]:
overall_result = {}

def evaluate(method, threshold, system_prompt=ZERO_SHOT_SYSTEM_PROMPT, prompt_template=ZERO_SHOT_PROMPTING_TEMPLATE):
    pred = []
    runtimes = []
    for animal_type in tqdm(classes.values):
        animal_type = animal_type[0]

        d = Dummy("animals", columns, data)
        s = Select(d, SoftEqual(columns, Constant(animal_type), method=method, em=stem, sv=lsv, threshold=threshold, zfs_system_prompt=system_prompt, zfs_prompt_template = prompt_template))

        tic = time.time()
        result = s.open().fetch_all()
        toc = time.time()

        pred.extend([tuple([x[col] for col in columns] + [animal_type]) for x in result])
        runtimes.append(toc - tic)

    scores = calculate_metrics(gt, set(pred), np.mean(runtimes))

    print(method, threshold, scores["F1 Score"])

    return scores, pred

In [10]:
for thresh in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
    res, _ = evaluate("threshold", thresh)
    overall_result[("threshold", thresh)] = res
    if res["Recall"] == 0.0:
        break

100%|██████████| 7/7 [00:05<00:00,  1.34it/s]


threshold 0.1 0.2506203473945409


100%|██████████| 7/7 [00:04<00:00,  1.46it/s]


threshold 0.2 0.2717086834733894


100%|██████████| 7/7 [00:04<00:00,  1.45it/s]


threshold 0.3 0.28483920367534454


100%|██████████| 7/7 [00:04<00:00,  1.44it/s]


threshold 0.4 0.386046511627907


100%|██████████| 7/7 [00:05<00:00,  1.28it/s]


threshold 0.5 0.32653061224489793


100%|██████████| 7/7 [00:11<00:00,  1.63s/it]

threshold 0.6 0





In [11]:
res, pred = evaluate("zero-few-shot", None)
overall_result[("zero-few-shot", None)] = res

100%|██████████| 7/7 [00:23<00:00,  3.36s/it]

zero-few-shot None 0.8699999999999999





In [12]:
for thresh in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
    res, _ = evaluate("both", thresh)
    overall_result[("both", thresh)] = res
    if res["Recall"] == 0.0:
        break

100%|██████████| 7/7 [00:26<00:00,  3.72s/it]


both 0.1 0.8699999999999999


100%|██████████| 7/7 [00:23<00:00,  3.31s/it]


both 0.2 0.8527918781725888


100%|██████████| 7/7 [00:21<00:00,  3.04s/it]


both 0.3 0.8350515463917527


100%|██████████| 7/7 [00:14<00:00,  2.10s/it]


both 0.4 0.7978723404255319


100%|██████████| 7/7 [00:06<00:00,  1.13it/s]


both 0.5 0.34374999999999994


100%|██████████| 7/7 [00:04<00:00,  1.47it/s]

both 0.6 0





In [13]:
for key in tqdm(overall_result):
    scores_bleu = calc_bleu(gt, overall_result[key]["pred"])
    for score_bleu in scores_bleu:
        overall_result[key][score_bleu] = scores_bleu[score_bleu]

100%|██████████| 13/13 [01:49<00:00,  8.42s/it]


In [14]:
df_results = pd.DataFrame([{"method": k[0], "threshold": k[1]} | v for k,v in overall_result.items()]).drop(columns=["pred"])
df_results

Unnamed: 0,method,threshold,Precision,Recall,F1 Score,tp,fn,fp,runtime,bleu1,bleu2,bleu3,bleu4
0,threshold,0.1,0.143262,1.0,0.25062,101,0,604,0.746588,1.0,1.0,1.0,1.0
1,threshold,0.2,0.158238,0.960396,0.271709,97,4,516,0.685437,0.999363,0.999358,0.999353,0.999347
2,threshold,0.3,0.168478,0.920792,0.284839,93,8,459,0.690415,0.998727,0.998717,0.998706,0.998695
3,threshold,0.4,0.25228,0.821782,0.386047,83,18,246,0.69533,0.997065,0.997041,0.997015,0.99699
4,threshold,0.5,0.521739,0.237624,0.326531,24,77,22,0.780139,0.981684,0.975662,0.970362,0.965376
5,threshold,0.6,0.0,0.0,0.0,0,101,0,1.627356,-1.0,-1.0,-1.0,-1.0
6,zero-few-shot,,0.878788,0.861386,0.87,87,14,12,3.35938,0.996871,0.995339,0.994164,0.993087
7,both,0.1,0.878788,0.861386,0.87,87,14,12,3.723182,0.996871,0.995339,0.994164,0.993087
8,both,0.2,0.875,0.831683,0.852792,84,17,12,3.312671,0.996399,0.994863,0.993685,0.992603
9,both,0.3,0.870968,0.80198,0.835052,81,20,12,3.041573,0.995922,0.994383,0.9932,0.992115


In [15]:
df_results.to_csv("results/Animals_mpnetBaseV2_LLama3B.csv")