In [2]:
import pandas as pd
import numpy as np
from annoy import AnnoyIndex 

from transformers import AutoTokenizer, AutoModel
import torch

# Load once (outside the function if embedding multiple times)
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
model.eval()  # optional, ensures no dropout during inference
# device = "cuda" if torch.cuda.is_available() else "mps"
# model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(31090, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
def get_scibert_embedding(text: str, pooling: str = "mean") -> torch.Tensor:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state  # [1, seq_len, 768]
    
    if pooling == "mean":
        embedding = last_hidden_state.mean(dim=1)  # Average pooling
    elif pooling == "cls":
        embedding = last_hidden_state[:, 0, :]  # [CLS] token
    else:
        raise ValueError("Pooling must be 'mean' or 'cls'")
    return embedding[0]  # shape: [1, 768]


In [65]:
import json

with open("/Users/kaushalpatil/Development/USC MS CSAI Program/Applied NLP/Automated KG Gen/My Paper/relations_extraction (1).json", 'r') as file:
    data = json.load(file)

entities = []
relations = []
for item in data['relations']:
    entities.append(item['head'])
    entities.append(item['tail'])
    relations.append(item['relation'])
    # tail = item['tail']
    # print(f"Head: {head}, Relation: {relation}, Tail: {tail}")


In [66]:
embeddings_entities_combined_dict = {}
embeddings_relations_combined_dict = {}

for i in entities:
    embeddings_entities_combined_dict[i] = get_scibert_embedding(i, pooling="mean")
    
for i in relations:
    embeddings_relations_combined_dict[i] = get_scibert_embedding(i, pooling="mean")

In [67]:
len(embeddings_entities_combined_dict), len(embeddings_relations_combined_dict)

(162, 49)

In [68]:
alternate_entities_keys = {}
start = 1
for i in embeddings_entities_combined_dict:
    alternate_entities_keys[i] = start
    start += 1

reverse_entities_alternate_keys = {}

for i, j in alternate_entities_keys.items():
    reverse_entities_alternate_keys[j] = i
    
alternate_relations_keys = {}
start = 1
for i in embeddings_relations_combined_dict:
    alternate_relations_keys[i] = start
    start += 1

reverse_relations_alternate_keys = {}

for i, j in alternate_relations_keys.items():
    reverse_relations_alternate_keys[j] = i

In [69]:
len(alternate_entities_keys), len(alternate_relations_keys), len(reverse_entities_alternate_keys), len(reverse_relations_alternate_keys)

(162, 49, 162, 49)

In [70]:
f = 768 # Number of Dimensions
t = AnnoyIndex(f)
for i, j in embeddings_entities_combined_dict.items():
    t.add_item(alternate_entities_keys[i], j)
    
t.build(f)
t.save('entity-search-tree.ann')

f = 768 # Number of Dimensions
t = AnnoyIndex(f)
for i, j in embeddings_relations_combined_dict.items():
    t.add_item(alternate_relations_keys[i], j)
    
t.build(f)
t.save('relation-search-tree.ann')

  t = AnnoyIndex(f)
  t = AnnoyIndex(f)


True

In [71]:
entity_search_space = AnnoyIndex(768)
entity_search_space.load('./entity-search-tree.ann')

relation_search_space = AnnoyIndex(768)
relation_search_space.load('./relation-search-tree.ann')

  entity_search_space = AnnoyIndex(768)
  relation_search_space = AnnoyIndex(768)


True

In [72]:
def entity_text_space_search(query: str, num : int = 10):
    query_vector = get_scibert_embedding(query) 
    ans = entity_search_space.get_nns_by_vector(query_vector, num)
    return ans

def relation_text_space_search(query: str, num : int = 10):
    query_vector = get_scibert_embedding(query) 
    ans = relation_search_space.get_nns_by_vector(query_vector, num)
    return ans

In [73]:
final_ls_entities, final_ls_relations = set(), set()

In [75]:
import torch.nn.functional as F

def cosine_similarity(a, b):
    return F.cosine_similarity(embeddings_entities_combined_dict[a].unsqueeze(0), embeddings_entities_combined_dict[b].unsqueeze(0)).item()

In [76]:
blacklist_entities = set()

In [77]:
for entity in entities:
    # final_ls_entities.add(entity)
    knn_entities = [reverse_entities_alternate_keys[i] for i in entity_text_space_search(entity, 10)]
    for i in knn_entities:
        if cosine_similarity(entity, i) < 0.9:
            print(cosine_similarity(entity, i), f"{i}")
            if i not in blacklist_entities:
                final_ls_entities.add(i)
        else:
            blacklist_entities.add(i)

0.8790421485900879 Matthew W. McKown
0.8605492115020752 Bernie R. Tershy
0.6735193729400635 Conservation Metrics, Inc.
0.6374217867851257 conservation
0.6353024244308472 panels
0.627414882183075 Ashy Storm-petrel
0.6247120499610901 384-dimensional
0.6214278936386108 brown tree snake
0.6190373301506042 audio
0.757946252822876 conservation
0.7451828718185425 conservation measures program
0.7426853179931641 International Union for Conservation
0.7247264385223389 Queensland
0.7131509184837341 ecosystems
0.7069091796875 global conservation
0.7018937468528748 Matthew W. McKown
0.698165237903595 Convention for Global Biodiversity
0.6968064904212952 Great Barrier Reef Marine Park
0.8790421485900879 David J. Klein
0.8673216104507446 Bernie R. Tershy
0.7018937468528748 Conservation Metrics, Inc.
0.661203145980835 conservation
0.6463519930839539 U.S. military
0.6393545269966125 panels
0.6384798288345337 Queensland
0.627934992313385 Ashy Storm-petrel
0.6240606307983398 computational neuroscience
0

In [78]:
import torch.nn.functional as F

def relations_cosine_similarity(a, b):
    return F.cosine_similarity(embeddings_relations_combined_dict[a].unsqueeze(0), embeddings_relations_combined_dict[b].unsqueeze(0)).item()

In [79]:
blacklist_relations = set()

In [80]:
for relation in relations:
    # final_ls_entities.add(entity)
    knn_relations = [reverse_relations_alternate_keys[i] for i in relation_text_space_search(relation, 10)]
    # print(knn_relations)
    for i in knn_relations:
        if relations_cosine_similarity(relation, i) < 0.9:
            print(relations_cosine_similarity(relation, i), f"{i}")
            if i not in blacklist_relations:
                final_ls_relations.add(i)
        else:
            blacklist_relations.add(i)

0.8677421808242798 related_to
0.8637931942939758 member_of
0.8577030897140503 according_to
0.8526221513748169 trained_to
0.8496032953262329 candidate_for
0.8487582206726074 trained_for
0.842886209487915 introduced_by
0.8403897285461426 lead_to
0.8369935154914856 stored_at
0.8677421808242798 related_to
0.8637931942939758 member_of
0.8577030897140503 according_to
0.8526221513748169 trained_to
0.8496032953262329 candidate_for
0.8487582206726074 trained_for
0.842886209487915 introduced_by
0.8403897285461426 lead_to
0.8369935154914856 stored_at
0.8677421808242798 related_to
0.8637931942939758 member_of
0.8577030897140503 according_to
0.8526221513748169 trained_to
0.8496032953262329 candidate_for
0.8487582206726074 trained_for
0.842886209487915 introduced_by
0.8403897285461426 lead_to
0.8369935154914856 stored_at
0.8924670219421387 used_by
0.8822272419929504 measured_by
0.8671614527702332 used_for
0.8667038083076477 candidate_for
0.862706184387207 trained_for
0.858871579170227 introduced_by


In [82]:
len(final_ls_entities)

131

In [83]:
final_entities, final_relations, final_triplets = [], [], []

In [84]:
for i in final_ls_entities:
    final_entities.append({"Name": i})
    
for i in final_ls_relations:
    final_relations.append({"Name": i})
    
for i in data['relations']:
    final_triplets.append([i['head'], i['relation'],  i['tail']])

In [88]:
final_triplets

[['David J. Klein', 'affiliated_with', 'Conservation Metrics, Inc.'],
 ['Matthew W. McKown', 'affiliated_with', 'Conservation Metrics, Inc.'],
 ['Bernie R. Tershy', 'affiliated_with', 'Conservation Metrics, Inc.'],
 ['Conservation Metrics, Inc.', 'located_in', 'Santa Cruz, CA 95060'],
 ['deep learning', 'used_by', 'system'],
 ['system', 'monitors', 'endangered species'],
 ['system', 'monitors', 'ecosystems'],
 ['microphones', 'component_of', 'system'],
 ['cameras', 'component_of', 'system'],
 ['nature', 'component_of', 'Ecosystem services'],
 ['Convention for Global Biodiversity', 'affiliated_with', 'UN'],
 ['conservation', 'produces', 'conservation outcomes'],
 ['wildlife monitoring techniques',
  'component_of',
  'conservation measures program'],
 ['storms', 'component_of', 'natural systems'],
 ['droughts', 'component_of', 'natural systems'],
 ['diseases', 'component_of', 'natural systems'],
 ['biological surveys', 'part_of', 'monitoring programs'],
 ['conservation actions', 'part_o

In [99]:
len(final_triplets)

132

In [94]:
rag_triplets = [
        ["Ecosystem services", "are", "the contribution of nature to human well-being"]
        ,["sensor networks", "used_for", "conservation"]
        ,["analysts", "analyzed_by", "sensor data"]
        ,["sensor data", "produced_by", "sensors"]
        ,["sensors", "include", "microphones"]
        ,["sensors", "include", "cameras"]
        ,["cameras", "type_of", "visual"]
        ,["cameras", "type_of", "thermal"]
        ,["cameras", "type_of", "hyperspectral"]
        ,["data", "flows_to", "data center"]
        ,["data", "analyzed_by", "algorithms"]
        ,["Artificial Intelligence", "used_for", "biodiversity monitoring"]
        , ["machine learning", "used_for", "biodiversity monitoring"]
        , ["deep learning", "used_for", "biodiversity monitoring"]
        , ["Deep Learning", "subfield_of", "Artificial Intelligence"]
        , ["Deep Learning", "subfield_of", "machine learning"]
        , ["sensor networks", "to_collect", "sensor data"]
        , ["mobile phones", "gather", "sensor data"]
        , ["data", "stored_at", "Amazon Web Services"]
        , ["data", "stored_at", "third-party data centers"]
        , ["satellites", "to_collect", "images"]
        , ["satellites", "to_collect", "data"]
        , ["un-manned aerial vehicles", "to_collect", "images"]
        , ["observers", "to_collect", "data"]
        , ["AI algorithm", "to_collect", "sensor data"]
        , ["audio", "analyzed_by", "spectrograms"]
        , ["brown tree snake", "native_to", "Papua New Guinea"]
        , ["brown tree snake", "introduced_by", "U.S. military"]
        , ["Guam", "part_of", "Micronesia"]
        , ["Puffinus bryani", "is_a", "species"]
        , ["Ashy Storm-petrel", "is_a", "species"]
        , ["Conservation Metrics, Inc.", "developed", "biodiversity monitoring application"]
        , ["data clustering", "used_for", "view and select data samples"]
        , ["data clustering", "example_of", "t-SNE"]
        , ["DL algorithms", "trained_to", "classify events of interest"]
        , ["DL algorithms", "type_of", "CNNs"]
        , ["DL algorithms", "type_of", "DNNs"]
        , ["environmental sensor data", "measured_by", "temperature"]
        , ["global positioning systems", "used_to", "improve wildlife monitoring techniques"]
        , ["technology", "used_to", "improve conservation monitoring"]
        , ["technology", "used_for", "data analysis"]
    ]

In [95]:
len(final_triplets)

132

In [96]:
len(rag_triplets)

41

In [92]:
triplets = [
    [
      "David J. Klein",
      "affiliated_with",
      "Conservation Metrics, Inc."
    ],
    [
      "Matthew W. McKown",
      "affiliated_with",
      "Conservation Metrics, Inc."
    ],
    [
      "Bernie R. Tershy",
      "affiliated_with",
      "Conservation Metrics, Inc."
    ],
    [
      "Conservation Metrics, Inc.",
      "located_in",
      "site"
    ],
    [
      "DL",
      "used_by",
      "models"
    ],
    [
      "models",
      "monitors",
      "species"
    ],
    [
      "models",
      "monitors",
      "ecosystems"
    ],
    [
      "audio",
      "component_of",
      "sensors"
    ],
    [
      "cameras",
      "component_of",
      "sensors"
    ],
    [
      "natural systems",
      "component_of",
      "Ecosystem services"
    ],
    [
      "Convention for Global Biodiversity",
      "affiliated_with",
      "UN"
    ],
    [
      "conservation",
      "produces",
      "conservation outcomes"
    ],
    [
      "wildlife monitoring techniques",
      "component_of",
      "conservation measures program"
    ],
    [
      "natural systems",
      "component_of",
      "natural systems"
    ],
    [
      "droughts",
      "component_of",
      "natural systems"
    ],
    [
      "natural systems",
      "component_of",
      "natural systems"
    ],
    [
      "biological surveys",
      "part_of",
      "monitoring programs"
    ],
    [
      "conservation actions",
      "part_of",
      "monitoring programs"
    ],
    [
      "conservation actions",
      "measured_by",
      "data"
    ],
    [
      "wildlife monitoring techniques",
      "improve",
      "inference"
    ],
    [
      "wildlife monitoring techniques",
      "drive",
      "adaptive management"
    ],
    [
      "adaptive management",
      "of",
      "conservation projects"
    ],
    [
      "biodiversity monitoring",
      "involves",
      "sending observers"
    ],
    [
      "survey sites",
      "to_collect",
      "data"
    ],
    [
      "traditional surveys",
      "to_meet",
      "global conservation"
    ],
    [
      "human observers",
      "lead_to",
      "negative ecological impacts"
    ],
    [
      "sensors",
      "component_of",
      "monitoring"
    ],
    [
      "audio",
      "part_of",
      "sensors"
    ],
    [
      "cameras",
      "part_of",
      "sensors"
    ],
    [
      "gradients",
      "part_of",
      "sensors"
    ],
    [
      "visual",
      "type_of",
      "cameras"
    ],
    [
      "thermal",
      "type_of",
      "cameras"
    ],
    [
      "hyperspectral",
      "type_of",
      "cameras"
    ],
    [
      "New York City",
      "located_in",
      "New York City"
    ],
    [
      "technology",
      "component_of",
      "technology"
    ],
    [
      "global positioning systems",
      "component_of",
      "technology"
    ],
    [
      "cellular networks",
      "component_of",
      "technology"
    ],
    [
      "satellites",
      "component_of",
      "technology"
    ],
    [
      "airplanes",
      "component_of",
      "technology"
    ],
    [
      "un-manned aerial vehicles",
      "component_of",
      "technology"
    ],
    [
      "DL",
      "used_for",
      "biodiversity"
    ],
    [
      "sensor data",
      "measured_by",
      "sensors"
    ],
    [
      "audio",
      "part_of",
      "datasets"
    ],
    [
      "image",
      "part_of",
      "datasets"
    ],
    [
      "temperature",
      "part_of",
      "environmental sensor data"
    ],
    [
      "site",
      "part_of",
      "monitoring project"
    ],
    [
      "region",
      "part_of",
      "monitoring project"
    ],
    [
      "cloud computing and storage",
      "component_of",
      "internet"
    ],
    [
      "internet",
      "depends_on",
      "cloud computing and storage"
    ],
    [
      "mobile phones",
      "used_for",
      "gather sensor data"
    ],
    [
      "sensors",
      "component_of",
      "communication network"
    ],
    [
      "data",
      "flows_to",
      "base stations"
    ],
    [
      "base stations",
      "transmits_via",
      "internet"
    ],
    [
      "base stations",
      "transmits_via",
      "satellite"
    ],
    [
      "base stations",
      "transmits_via",
      "microwave"
    ],
    [
      "base stations",
      "transmits_via",
      "cellular networks"
    ],
    [
      "data",
      "stored_at",
      "data center"
    ],
    [
      "data center",
      "uses",
      "Spark"
    ],
    [
      "Amazon Web Services",
      "is_a",
      "third-party data centers"
    ],
    [
      "DL",
      "component_of",
      "machine learning"
    ],
    [
      "models",
      "produced_by",
      "labeled datasets"
    ],
    [
      "models",
      "trained_to",
      "classify events of interest"
    ],
    [
      "data",
      "according_to",
      "date ranges"
    ],
    [
      "data",
      "according_to",
      "time of day"
    ],
    [
      "data",
      "according_to",
      "site location"
    ],
    [
      "audio signals",
      "frequency_ranges",
      "data"
    ],
    [
      "models",
      "used_for",
      "species"
    ],
    [
      "analysts",
      "uses",
      "exploration tools"
    ],
    [
      "sensor data",
      "analyzed_by",
      "analysts"
    ],
    [
      "exploration tools",
      "visualizes",
      "data"
    ],
    [
      "t-SNE",
      "used_for",
      "data clustering"
    ],
    [
      "t-SNE",
      "used_for",
      "view and select data samples"
    ],
    [
      "feature vector",
      "is",
      "384-dimensional"
    ],
    [
      "models",
      "classify",
      "species"
    ],
    [
      "models",
      "classify",
      "events"
    ],
    [
      "keyboard shortcuts",
      "uses",
      "panels"
    ],
    [
      "spectrograms",
      "represents",
      "audio"
    ],
    [
      "Artificial Intelligence",
      "abbreviation_of",
      "AI"
    ],
    [
      "machine learning",
      "component_of",
      "Artificial Intelligence"
    ],
    [
      "AI",
      "uses",
      "data"
    ],
    [
      "machine learning",
      "uses",
      "data"
    ],
    [
      "sensor networks",
      "produces",
      "petabytes"
    ],
    [
      "sensor networks",
      "produces",
      "exabytes"
    ],
    [
      "analysts",
      "uses",
      "keyboard shortcuts"
    ],
    [
      "Deep Learning",
      "abbreviation_of",
      "DL"
    ],
    [
      "DL",
      "subfield_of",
      "machine learning"
    ],
    [
      "DL",
      "grew_out_of",
      "representation learning"
    ],
    [
      "DL",
      "grew_out_of",
      "neural networks"
    ],
    [
      "DL",
      "grew_out_of",
      "computational neuroscience"
    ],
    [
      "edges",
      "component_of",
      "image"
    ],
    [
      "gradients",
      "component_of",
      "image"
    ],
    [
      "corners",
      "component_of",
      "image"
    ],
    [
      "textures",
      "component_of",
      "image"
    ],
    [
      "visual object recognition",
      "application_of",
      "DL"
    ],
    [
      "speech recognition",
      "application_of",
      "DL"
    ],
    [
      "genomics",
      "application_of",
      "DL"
    ],
    [
      "visual object recognition",
      "application_of",
      "AI algorithm"
    ],
    [
      "speech recognition",
      "application_of",
      "AI algorithm"
    ],
    [
      "genomics",
      "application_of",
      "AI algorithm"
    ],
    [
      "unconstrained image recognition",
      "example_of",
      "difficult problems"
    ],
    [
      "DL algorithm",
      "is_a",
      "ML algorithm"
    ],
    [
      "biodiversity monitoring application",
      "candidate_for",
      "DL"
    ],
    [
      "CNNs",
      "part_of",
      "DL"
    ],
    [
      "DNNs",
      "part_of",
      "DL"
    ],
    [
      "DL algorithms",
      "applied_to",
      "biodiversity monitoring"
    ],
    [
      "models",
      "trained_for",
      "species"
    ],
    [
      "species",
      "related_to",
      "species"
    ],
    [
      "audio",
      "component_of",
      "sensors"
    ],
    [
      "image",
      "component_of",
      "sensors"
    ],
    [
      "International Union for Conservation",
      "part_of",
      "RedList"
    ],
    [
      "Channel Islands National Park",
      "located_in",
      "Anacapa Island"
    ],
    [
      "Ashy Storm-petrel",
      "scientific_name",
      "Puffinus"
    ],
    [
      "Bryans Shearwater",
      "scientific_name",
      "Puffinus"
    ],
    [
      "Bryans Shearwater",
      "member_of",
      "Puffinus"
    ],
    [
      "Queensland",
      "located_in",
      "Australia"
    ],
    [
      "Great Barrier Reef Marine Park",
      "located_in",
      "Australia"
    ],
    [
      "California Coastal National Monument",
      "located_in",
      "U.S. Geological Survey"
    ],
    [
      "Island of Guam",
      "located_in",
      "Guam"
    ],
    [
      "brown tree snake",
      "scientific_name",
      "Boiga irregularis"
    ],
    [
      "brown tree snake",
      "native_to",
      "Papua New Guinea"
    ],
    [
      "brown tree snake",
      "introduced_by",
      "U.S. military"
    ],
    [
      "U.S. Geological Survey",
      "located_in",
      "Guam"
    ],
    [
      "summer of 2012",
      "precedes",
      "summer of 2014"
    ]]

In [100]:
len(triplets)

123

In [101]:
len(rag_triplets)

41