In [1]:
import numpy as np
import pandas as pd
import os

import torch
import nltk
import pickle

from model.concept_property_model import ConceptPropertyModel
from utils.functions import create_model
from utils.functions import load_pretrained_model
from utils.functions import read_config
from utils.functions import mcrae_dataset_and_dataloader
from utils.functions import compute_scores

from sklearn.neighbors import NearestNeighbors
from collections import Counter

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# assert os.environ["CONDA_DEFAULT_ENV"] == "gvenv", "Activate 'gvenv' conda environment"

print (f"Device Name : {device}")
# CONDA_DEFAULT_ENV is only present inside an activated conda environment, so read it
# with a fallback instead of letting a KeyError abort the very first cell.
print (f"Conda Environment Name : {os.environ.get('CONDA_DEFAULT_ENV', 'not set')}")

Device Name : cuda
Conda Environment Name : gvenv


In [2]:
# Both splits are tab-separated files read with the same three explicit column names.
mcrae_columns = ["concept", "property", "label"]

mcrae_train_df = pd.read_csv(
    "data/evaluation_data/extended_mcrae/train_mcrae.tsv",
    sep="\t",
    names=mcrae_columns,
)
mcrae_test_df = pd.read_csv(
    "data/evaluation_data/extended_mcrae/test_mcrae.tsv",
    sep="\t",
    names=mcrae_columns,
)

# Train split: shape plus the first few rows.
print ("McRae Train Df size : ", mcrae_train_df.shape)
print (mcrae_train_df.head())

print ()

# Test split: shape plus the (pandas-truncated) frame.
print ("McRae Test Df size : ", mcrae_test_df.shape)
print (mcrae_test_df)


McRae Train Df size :  (19258, 3)
      concept          property  label
0      wrench           squishy      0
1  teddy_bear           fragile      0
2  microscope  used by children      1
3      shovel  used for killing      0
4        wall    found on walls      1

McRae Test Df size :  (4813, 3)
         concept          property  label
0         onions    used for music      0
1          pizza             shiny      0
2     motorcycle   eaten in summer      0
3       sailboat              cold      0
4           lime      worn on feet      0
...          ...               ...    ...
4808   snowboard              tall      0
4809        veil           squishy      0
4810    sailboat      light weight      0
4811  skateboard  used for cooking      0
4812    mandarin             a toy      0

[4813 rows x 3 columns]


In [3]:

# NOTE(review): machine-specific absolute location of the pre-computed embedding pickles.
# It is kept in one place so it can be repointed without editing four separate paths.
embedding_dir = "/home/amitgajbhiye/cardiff_work/dot_product_model_nn_analysis/mcrae_train_test_embeddings"

train_con_file = f"{embedding_dir}/mcrae_bert_base_train_cons_embeds.pkl"
train_prop_file = f"{embedding_dir}/mcrae_bert_base_train_prop_embeds.pkl"

test_con_file = f"{embedding_dir}/mcrae_bert_base_test_cons_embeds.pkl"
test_prop_file = f"{embedding_dir}/mcrae_bert_base_test_prop_embeds.pkl"


# SECURITY: pickle.load can execute arbitrary code while unpickling; only load
# embedding files that come from a trusted source.
#
# The file handles get their own names (*_fh) so that the *_emb names are only ever
# bound to the unpickled objects. Previously a failed load could leave an *_emb name
# pointing at a closed file object.
with open(train_con_file, "rb") as train_con_fh, \
    open(train_prop_file, "rb") as train_prop_fh, \
    open(test_con_file, "rb") as test_con_fh, \
    open(test_prop_file, "rb") as test_prop_fh:

    train_con_emb = pickle.load(train_con_fh)
    train_prop_emb = pickle.load(train_prop_fh)

    test_con_emb = pickle.load(test_con_fh)
    test_prop_emb = pickle.load(test_prop_fh)


# hawk_train_con_file = "data/evaluation_data/nn_analysis/mcrae_train_concept_embedding.pkl"
# hawk_train_prop_file = "data/evaluation_data/nn_analysis/mcrae_train_properties_embedding.pkl"

# hawk_test_con_file = "data/evaluation_data/nn_analysis/mcrae_test_concept_embedding.pkl"
# hawk_test_prop_file = "data/evaluation_data/nn_analysis/mcrae_test_properties_embedding.pkl"

# with open(hawk_train_con_file, "rb") as train_con_emb, \
#     open(hawk_train_prop_file, "rb") as train_prop_emb, \
#     open(hawk_test_con_file, "rb") as test_con_emb, \
#     open(hawk_test_prop_file, "rb") as test_prop_emb:
    
#     train_con_emb = pickle.load(train_con_emb)
#     train_prop_emb = pickle.load(train_prop_emb)
    
#     test_con_emb = pickle.load(test_con_emb)
#     test_prop_emb = pickle.load(test_prop_emb)


In [4]:
# Fields available in the train concept / property embedding dictionaries.
for train_emb in (train_con_emb, train_prop_emb):
    print (train_emb.keys())

# Number of train property names, followed by the names themselves.
train_prop_names = train_prop_emb.get("name_list_prop")
print ("Train Properties :", len(train_prop_names))
print ("Train Properties :", train_prop_names)

print ()

# Same inspection for the test split.
for test_emb in (test_con_emb, test_prop_emb):
    print (test_emb.keys())

test_prop_names = test_prop_emb.get("name_list_prop")
print ("Test Properties :", len(test_prop_names))
print ("Test Properties :", test_prop_names)


dict_keys(['name_list_con', 'untransformed_con_emb', 'transformed_con_emb'])
dict_keys(['name_list_prop', 'untransformed_prop_emb', 'transformed_prop_emb'])
Train Properties : 50
Train Properties : ['squishy', 'fragile', 'used by children', 'used for killing', 'found on walls', 'expensive', 'requires gasoline', 'shiny', 'has peel', 'dangerous', 'eaten in summer', 'man made', 'used for eating', 'slimy', 'tall', 'hard', 'fast', 'a toy', 'worn for warmth', 'swims', 'words on it', 'produces noise', 'an animal', 'fun', 'hot', 'heavy', 'large', 'cold', 'pairs', 'electrical', 'smelly', 'used for music', 'has shelves', 'used for holding things', 'lives in water', 'wet', 'loud', 'flies', 'used for cooking', 'sharp', 'unhealthy', 'hand held', 'used for cleaning', 'used for transportation', 'a tool', 'decorative', 'smooth', 'worn on feet', 'light weight', 'edible']

dict_keys(['name_list_con', 'untransformed_con_emb', 'transformed_con_emb'])
dict_keys(['name_list_prop', 'untransformed_prop_emb', 

In [5]:
# Properties whose names occur in both the train and the test property lists.
train_prop_set = set(train_prop_emb.get("name_list_prop"))
test_prop_set = set(test_prop_emb.get("name_list_prop"))
inter = train_prop_set & test_prop_set

print (len(inter))
print (inter)

50
{'dangerous', 'flies', 'squishy', 'unhealthy', 'used for cleaning', 'fast', 'used for music', 'a toy', 'cold', 'has peel', 'used for killing', 'decorative', 'light weight', 'an animal', 'eaten in summer', 'shiny', 'used for holding things', 'hot', 'used for eating', 'produces noise', 'requires gasoline', 'electrical', 'words on it', 'smooth', 'loud', 'used for transportation', 'hand held', 'fragile', 'found on walls', 'pairs', 'lives in water', 'wet', 'smelly', 'tall', 'sharp', 'man made', 'expensive', 'swims', 'has shelves', 'edible', 'worn for warmth', 'large', 'heavy', 'worn on feet', 'slimy', 'used by children', 'a tool', 'used for cooking', 'fun', 'hard'}


In [6]:
# Number of transformed concept embeddings in the train split, then in the test split.
train_con_embeds = train_con_emb.get("transformed_con_emb")
print (len(train_con_embeds))

test_con_embeds = test_con_emb.get("transformed_con_emb")
print (len(test_con_embeds))

411
103


In [7]:
# Learning Nearest Neighbours
num_nearest_neighbours = 5

# Fit a brute-force neighbour index on the transformed train-concept embeddings.
# No `metric` argument is passed, so scikit-learn's default distance is used.
train_con_matrix = np.array(train_con_emb.get("transformed_con_emb"))
train_con_nbrs = NearestNeighbors(n_neighbors=num_nearest_neighbours, algorithm='brute')
train_con_nbrs.fit(train_con_matrix)

# For every test concept, look up its nearest train concepts: one row of distances
# and one row of train-side positions per test concept.
test_con_matrix = np.array(test_con_emb.get("transformed_con_emb"))
con_test_distances, con_test_indices = train_con_nbrs.kneighbors(test_con_matrix)

In [8]:
# Display which fields the train-concept embedding dictionary provides.
train_con_emb.keys()

dict_keys(['name_list_con', 'untransformed_con_emb', 'transformed_con_emb'])

In [9]:
# print (con_test_indices.shape)
# print (con_test_indices)

In [10]:
# Map each test concept name to the names of its nearest train concepts.
# Each row of `con_test_indices` holds positions into the train concept name list
# and is paired, in order, with the corresponding test concept name.
train_con_names = train_con_emb.get('name_list_con')

train_cons_similar_to_test = {
    test_con_name: [train_con_names[train_pos] for train_pos in neighbour_positions]
    for neighbour_positions, test_con_name in zip(con_test_indices, test_con_emb.get("name_list_con"))
}
    

In [11]:
# How many test concepts received a neighbour list, and which ones.
similar_concepts_count = len(train_cons_similar_to_test)
print (similar_concepts_count)
print (train_cons_similar_to_test.keys())

103
dict_keys(['onions', 'pizza', 'motorcycle', 'sailboat', 'lime', 'cannon', 'eagle', 'mandarin', 'spoon', 'oven', 'buckle', 'comb', 'worm', 'jet', 'cart', 'beaver', 'magazine', 'nylons', 'helicopter', 'crowbar', 'cushion', 'tack', 'elk', 'cow', 'book', 'inn', 'toy', 'deer', 'pearl', 'crab', 'baseball_glove', 'pepper', 'flamingo', 'cheese', 'hawk', 'octopus', 'brush', 'platypus', 'whale', 'saxophone', 'elephant', 'bucket', 'axe', 'housefly', 'banjo', 'lobster', 'bicycle', 'iguana', 'cauliflower', 'stone', 'chair', 'hut', 'train', 'sofa', 'mittens', 'cucumber', 'ox', 'tomato', 'armour', 'skateboard', 'sardine', 'dress', 'raisin', 'laptop', 'seaweed', 'corn', 'rocker', 'vest', 'pie', 'cabbage', 'backpack', 'doll', 'rope', 'peach', 'zebra', 'pistol', 'jacket', 'sink', 'screws', 'gown', 'gate', 'wand', 'sandwich', 'ostrich', 'submarine', 'peg', 'house', 'pyramid', 'plate', 'rake', 'building', 'bathtub', 'shrimp', 'telephone', 'cupboard', 'lantern', 'chicken', 'rock', 'snowboard', 'pen', '

In [12]:
# Nearest-neighbour vote for every test (concept, property) pair.
#
# For a test concept, collect the positively labelled training properties of its
# nearest train concepts and count how often each property occurs. The pair is
# predicted positive (1) only when the test property is among the properties with the
# highest count; otherwise it is predicted negative (0).

# Positive properties of every train concept, gathered in a single pass so the training
# frame is not re-filtered for every test row and every neighbour.
positive_props_by_train_con = {}
for train_con, train_prop, train_label in zip(
    mcrae_train_df["concept"], mcrae_train_df["property"], mcrae_train_df["label"]
):
    if train_label == 1:
        positive_props_by_train_con.setdefault(train_con, []).append(train_prop)

# The set of top-voted properties depends only on the test concept, so it is computed
# once per concept and reused for all of that concept's rows.
top_props_by_test_con = {}

preds = []

for index, test_con, test_prop in zip(
    mcrae_test_df.index, mcrae_test_df["concept"], mcrae_test_df["property"]
):
    print ("Index :", index)

    if test_con not in top_props_by_test_con:
        train_similar_concepts = train_cons_similar_to_test.get(test_con)
        assert train_similar_concepts is not None, "No Train Similar Concepts for the Test Concept"

        prop_counts = Counter(
            prop
            for train_con in train_similar_concepts
            for prop in positive_props_by_train_con.get(train_con, [])
        )

        # default=0 covers neighbours with no positive training property at all: the
        # vote is then empty and every property of this concept is predicted 0, instead
        # of max() raising ValueError on an empty sequence.
        max_prop_count = max(prop_counts.values(), default=0)

        top_props_by_test_con[test_con] = {
            prop for prop, count in prop_counts.items() if count == max_prop_count
        }

    test_pred = 1 if test_prop in top_props_by_test_con[test_con] else 0
    preds.append(test_pred)
    

Index : 0
Index : 1
Index : 2
Index : 3
Index : 4
Index : 5
Index : 6
Index : 7
Index : 8
Index : 9
Index : 10
Index : 11
Index : 12
Index : 13
Index : 14
Index : 15
Index : 16
Index : 17
Index : 18
Index : 19
Index : 20
Index : 21
Index : 22
Index : 23
Index : 24
Index : 25
Index : 26
Index : 27
Index : 28
Index : 29
Index : 30
Index : 31
Index : 32
Index : 33
Index : 34
Index : 35
Index : 36
Index : 37
Index : 38
Index : 39
Index : 40
Index : 41
Index : 42
Index : 43
Index : 44
Index : 45
Index : 46
Index : 47
Index : 48
Index : 49
Index : 50
Index : 51
Index : 52
Index : 53
Index : 54
Index : 55
Index : 56
Index : 57
Index : 58
Index : 59
Index : 60
Index : 61
Index : 62
Index : 63
Index : 64
Index : 65
Index : 66
Index : 67
Index : 68
Index : 69
Index : 70
Index : 71
Index : 72
Index : 73
Index : 74
Index : 75
Index : 76
Index : 77
Index : 78
Index : 79
Index : 80
Index : 81
Index : 82
Index : 83
Index : 84
Index : 85
Index : 86
Index : 87
Index : 88
Index : 89
Index : 90
Index : 9

In [13]:
# Gold labels of the test pairs, in the row order of mcrae_test_df (the same order in
# which `preds` was filled).
gold_labels = mcrae_test_df["label"].values

In [14]:
# Display the gold-label array as a quick visual check.
gold_labels

array([0, 0, 0, ..., 0, 0, 0])

In [15]:
# Every test row must have exactly one prediction before scoring.
assert len(preds) == len(gold_labels)

In [16]:
# Class distribution of the predictions, then of the gold labels.
pred_label_counts = Counter(preds)
print (pred_label_counts)

gold_label_counts = Counter(gold_labels)
print (gold_label_counts)

Counter({0: 4565, 1: 248})
Counter({0: 3954, 1: 859})


In [17]:
# Score the nearest-neighbour predictions against the gold labels with the project's
# compute_scores helper (imported from utils.functions). The returned object is iterated
# with .items() when the report is printed, so it is presumably a dict of metric name
# to value — confirm against the helper.
results = compute_scores(gold_labels, preds)

In [18]:
# Report header: the blank lines before and after come from the empty strings.
print (
    "",
    "Concept Split",
    "NN Classifier with pretrained BERT Base Embedding pretrained on MSCG+PREFIX+GKB Data",
    f"Nearest Neighbours Considered : {num_nearest_neighbours}",
    "",
    sep="\n",
)

# One line per entry returned by compute_scores.
for metric_name, metric_value in results.items():
    print (metric_name, metric_value)


Concept Split
NN Classifier with pretrained BERT Base Embedding pretrained on MSCG+PREFIX+GKB Data
Nearest Neighbours Considered : 5

binary_f1 0.439
micro_f1 0.871
macro_f1 0.6831
weighted_f1 0.84
accuracy 0.871
classification report               precision    recall  f1-score   support

           0       0.87      1.00      0.93      3954
           1       0.98      0.28      0.44       859

    accuracy                           0.87      4813
   macro avg       0.92      0.64      0.68      4813
weighted avg       0.89      0.87      0.84      4813

confusion matrix [[3949    5]
 [ 616  243]]
