In [1]:
import sys
sys.path.append('..')

import HTorch
import torch
import numpy as np
import math
from data_utils.data_handler import DataHandler
from data_utils.relations import Relations
import matplotlib.pyplot as plt
import time
from ConeModel import *
from collections import deque
import pandas as pd


In [2]:
cpu = torch.device("cpu"); gpu = torch.device(type='cuda', index=0)
device = cpu
torch.set_default_tensor_type('torch.DoubleTensor')

### Load data and model

In [3]:
import pandas as pd
from IPython.display import clear_output
import random
import networkx as nx


In [4]:
### load basic edges
file_path = '../data_utils/data/maxn/mammal_closure.tsv.train_0percent'
df_indexed = pd.read_csv(file_path, delimiter='\t', encoding='utf-8',header=None)
### load vocab
filename = '../data_utils/data/maxn/mammal_closure.tsv.vocab'
element_index = {}

with open(filename, 'r', encoding='utf-8') as f:
    for line in f:
        index, element = line.strip().split('\t')
        element_index[int(index)] = element

In [5]:
G_load = nx.DiGraph()
G_load.add_edges_from(df_indexed.to_numpy())

In [6]:
 # best radius from validation = raidus from training

size = 1179 

In [64]:
### Load model
# load_name = '../trained_models/mammal_with_vocab'
# hyp_cone = UmbralCone(source = 'infinity', radius = 0.05, size = size, dim = 2, curvature = -1)

# load_name = '../trained_models/mammal_with_vocab_epoch_400_model_umbral_90_dim5' 
# hyp_cone = UmbralCone(source = 'infinity', radius = 0.05, size = size, dim = 5, curvature = -1)

load_name = '../trained_models/mammal_with_vocab_epoch_400_model_penumbral_90_dim2_source_origin' 
hyp_cone = UmbralCone(source = 'origin', radius = 0.1, size = size, dim = 2, curvature = -1)


hyp_cone.to(device)
hyp_cone.load_state_dict(torch.load(load_name+".pth")["model_weight"])
vocab = torch.load(load_name+".pth")["vocab"]

In [65]:
inputs = np.arange(size)
paired = sorted(zip(vocab,inputs))
vocab_sorted,inputs_sorted = zip(*paired)
new_element_index = dict(zip(inputs_sorted, [element_index[int(word)] for word in vocab_sorted]))

In [66]:
vocab = torch.load(load_name+".pth")["vocab"]
vocab = [int(i) for i in vocab]

In [67]:
### dictionary: from vocab index to embedding index
vocab_idx = np.arange(size)
emb_idx = [vocab.index(x) for x in vocab_idx]
emb_from_vocab = dict(zip(vocab_idx, emb_idx))


###

In [68]:
### Fina all source nodes
def find_source_node(G):
    for node in G.nodes():
        if G.in_degree(node) == 0:
            print(node)

find_source_node(G_load)
new_element_index[emb_from_vocab[356]]

12
356
696
1005


'placental.n.01'

In [69]:
def find_nodes_n_edges_down(G, source, n):
    if n < 0:
        return []

    visited = set()
    queue = deque([(source, 0)])
    nodes_at_level_n = []

    while queue:
        current_node, depth = queue.popleft()

        if depth == n:
            nodes_at_level_n.append(current_node)
            print(current_node, new_element_index[emb_from_vocab[current_node]][:-5])
        elif depth < n:
            for neighbor in G.successors(current_node):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, depth + 1))
                    
        
    return nodes_at_level_n

find_nodes_n_edges_down(G_load, 356, 1);




58 unguiculata
74 ungulata
127 bull
137 yearling
145 livestock
198 unguiculate
312 hyrax
365 tree_shrew
389 pachyderm
494 carnivore
517 plantigrade_mammal
530 aquatic_mammal
603 rodent
634 ungulate
664 doe
693 pangolin
710 aardvark
717 primate
737 digitigrade_mammal
754 bat
786 cow
809 proboscidean
833 lagomorph
955 flying_lemur
1003 fissipedia
1064 buck
1105 edentate
1106 insectivore


In [70]:
### Find all nodes reprensenting taxonomic orders by hand
labels = [494, 603, 634, 717, 754, 809, 833, 955, 1105, 1137]
# labels = [494, 603, 634, 717, 754, 809, 955, 1105, 1137]
for order in labels:
    # print(emb_from_vocab[order])
    print(order, new_element_index[emb_from_vocab[order]][:-5])


494 carnivore
603 rodent
634 ungulate
717 primate
754 bat
809 proboscidean
833 lagomorph
955 flying_lemur
1105 edentate
1137 cetacean


In [71]:
### Find all leaf nodes under each order
def find_leaf_nodes_in_orders(G, orders):
    leaves_in_orders = {order: [] for order in orders}
    
    for order in orders:
        # Check if the order node is in the graph
        if order in G:
            # Use BFS to find all leaf nodes under this order
            visited = set()
            queue = [order]
            while queue:
                current_node = queue.pop(0)
                visited.add(current_node)
                children = list(G.successors(current_node))
                if not children:  # If no children, it's a leaf node
                    if current_node != order:  # Exclude the order node itself
                        leaves_in_orders[order].append(current_node)
                else:
                    # Add children to queue if not visited
                    for child in children:
                        if child not in visited:
                            queue.append(child)

    return leaves_in_orders

items = find_leaf_nodes_in_orders(G_load, labels)

In [72]:
for label in labels:
    print(f"Leaves under order {new_element_index[emb_from_vocab[order]][:-5]}:")
    leaves_nodes = items[label]
    print(len(leaves_nodes))
    # for leaf in leaves_nodes:
        # print(new_element_index[emb_from_vocab[leaf]][:-5])


Leaves under order cetacean:
282
Leaves under order cetacean:
104
Leaves under order cetacean:
244
Leaves under order cetacean:
78
Leaves under order cetacean:
32
Leaves under order cetacean:
8
Leaves under order cetacean:
18
Leaves under order cetacean:
1
Leaves under order cetacean:
16
Leaves under order cetacean:
24


In [73]:
### Create item-label dataset
def create_dataset_from_dict(orders_dict):
    data = []
    for order, leaf_nodes in orders_dict.items():
        for node in leaf_nodes:
            data.append({'item': node, 'label': order})
    
    df = pd.DataFrame(data)
    return df

dataset_df = create_dataset_from_dict(items)

In [74]:
dataset_df

Unnamed: 0,item,label
0,1035,494
1,538,494
2,514,494
3,645,494
4,518,494
...,...,...
802,665,1137
803,1170,1137
804,155,1137
805,480,1137


### Distance-based label prediction

In [88]:
unique_labels = list(set(dataset_df['label']))

# Convert item and label names to embedding indices
item_indices = [emb_from_vocab[item] for item in dataset_df['item']]
label_indices = [emb_from_vocab[label] for label in unique_labels]

# # Extract embeddings
item_emb = hyp_cone.emb(torch.tensor(item_indices))
label_embs = hyp_cone.emb(torch.tensor(label_indices))

# Expand item and label embeddings to batch compute pairwise distances
expanded_item_emb = item_emb.unsqueeze(1).expand(-1, len(unique_labels), -1)  # Shape: [num_items, num_labels, emb_dim]
expanded_label_embs = label_embs.unsqueeze(0).expand(len(dataset_df), -1, -1)  # Shape: [num_items, num_labels, emb_dim]

# Calculate pairwise hyperbolic distances
# distances = expanded_item_emb.Hdist(expanded_label_embs)

# Calculate pairwise Euclidean distances in the first (n-1) dimensions of the half-space
differences = expanded_item_emb - expanded_label_embs
# differences = differences[...,:-1]
squared_differences = differences ** 2
summed_squares = torch.sum(squared_differences, dim=2)
distances = torch.sqrt(summed_squares)

# Calculate softmax based on distances
probabilities = torch.softmax(-distances, dim=1) 

# Get the predicted label (the one with the highest probability)
predicted_label_idx = torch.argmax(probabilities, dim=1)
predicted_labels = [unique_labels[idx] for idx in predicted_label_idx]



In [89]:
# Calculate prediction accuracy
def calculate_accuracy(predictions, true_labels):
    correct_predictions = sum(p == t for p, t in zip(predictions, true_labels))
    accuracy = correct_predictions / len(predictions)
    return accuracy

# Extract true labels from the dataset
true_labels = dataset_df['label'].tolist()

# Calculate accuracy
accuracy = calculate_accuracy(predicted_labels, true_labels)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 0.1512
