In [1]:
%matplotlib inline
import numpy as np
import json
import matplotlib.pyplot as plt
# from skbio.stats.distance import mantel
from itertools import combinations
import tikzplotlib

# Apply matplotlib's 'dark_background' style sheet to every figure created after this line.
plt.style.use('dark_background')

import os

# Root of the publication checkout. The historical default is the original
# author's local path; set the ONESHOTKG_BASE_DIR environment variable to run
# the notebook from any other location without editing this cell.
BASE_DIR = os.environ.get(
    'ONESHOTKG_BASE_DIR',
    '/Users/dfilipiak/repositories/publikacje/PUB-2021-OneShotKG',
)
# Directory holding the KnowledgeGraphMatrix/ and DataSplit/ inputs read below.
PROJECT_DIR = f'{BASE_DIR}/KGTN2'
# Directory for figures (not referenced by the cells visible in this part of the notebook).
FIGURES_DIR = f'{BASE_DIR}/figures'

In [2]:
# Load the three pairwise class-distance matrices, keyed by the short name used
# in the printed tables. The key spelling 'hierachy' is kept as-is because other
# cells may look matrices up by it.
distances_dict = dict(
    hierachy=np.load(f'{PROJECT_DIR}/KnowledgeGraphMatrix/HierarchyGraph.npy'),
    glove=np.load(f'{PROJECT_DIR}/KnowledgeGraphMatrix/SemanticGraph.npy'),
    # The Wikidata matrix is the only one explicitly cast to float64 after loading.
    wiki=np.load(f'{PROJECT_DIR}/KnowledgeGraphMatrix/WikidataGraph.npy').astype(np.float64),
)

# Individual aliases (same array objects as the dict values) used by later cells.
distances_hierarchy = distances_dict['hierachy']
distances_glove = distances_dict['glove']
distances_wikidata = distances_dict['wiki']

In [3]:
# Emit a LaTeX table body (header row + one row per graph) summarising the raw distances.
print(" & Min & Avg & Max & Std \\\\")
for graph_name, distances in distances_dict.items():
    summary = (distances.min(), distances.mean(), distances.max(), distances.std())
    cells = " & ".join(f"{value:3.2f}" for value in summary)
    print(f"{graph_name:10} & {cells} \\\\")

 & Min & Avg & Max & Std \\
hierachy   & 0.00 & 9.76 & 10.00 & 1.20 \\
glove      & 0.00 & 8.52 & 14.31 & 1.29 \\
wiki       & 0.00 & 5.82 & 12.73 & 1.32 \\


In [4]:
# Base count for the per-row top-k cut-off: the processing functions keep
# roughly NUM_VALID * kg_ratio / 100 of the strongest entries per row.
# Presumably the three addends are the sizes of the class groups that form
# `valid_nodes` -- TODO confirm against label_idx.json.
NUM_VALID = sum((196, 300, 311))

# Percentage of NUM_VALID used for the top-k cut-off.
kg_ratio = 100

# Bases of the exponential distance-to-weight mapping, one per graph.
coefficient_hierarchy = 0.5
coefficient_glove = 0.4

# Defaults captured by get_ignore_ind() at definition time.
use_all_base = True
label_idx_file = "{}/DataSplit/KGTN/label_idx.json".format(PROJECT_DIR)

# NOTE(review): the functions in this cell call get_ignore_ind() without
# arguments, so its own default (testsetup=1) applies and this variable is not
# read by the code visible here.
testsetup = 1

def get_ignore_ind(testsetup = 1, use_all_base = use_all_base, label_idx_file = label_idx_file):
    """Split the class lists stored in ``label_idx_file`` into ignored and valid ones.

    The JSON file must provide the keys ``novel_classes_1``, ``novel_classes_2``,
    ``base_classes_1`` and ``base_classes_2`` (each is concatenated with ``+``,
    so they are expected to be lists; presumably of class indices -- the
    callers in this notebook use them to index matrix rows/columns).

    Parameters
    ----------
    testsetup : any
        Checked for truthiness only. Truthy holds out the ``*_1`` novel split,
        falsy holds out the ``*_2`` novel split.
    use_all_base : bool
        When truthy, both base splits stay valid and only the held-out novel
        classes are ignored; otherwise the held-out split's base classes are
        ignored as well.
    label_idx_file : str
        Path of the JSON file to read.

    Returns
    -------
    tuple
        ``(ignore_ind, valid_nodes)``.
    """
    with open(label_idx_file, 'r') as handle:
        meta = json.load(handle)

    novel_1, novel_2 = meta['novel_classes_1'], meta['novel_classes_2']
    base_1, base_2 = meta['base_classes_1'], meta['base_classes_2']

    # Decide which split is held out and which one is kept.
    if testsetup:
        held_novel, held_base = novel_1, base_1
        kept_novel, kept_base = novel_2, base_2
    else:
        held_novel, held_base = novel_2, base_2
        kept_novel, kept_base = novel_1, base_1

    if use_all_base:
        # Both base splits remain valid; base_2 is listed before base_1 in either setup.
        return held_novel, kept_novel + base_2 + base_1
    return held_novel + held_base, kept_novel + kept_base


def process_semantic(mat, coefficient, kg_ratio, ignore_ind=None, num_valid=None):
    """Convert a pairwise distance matrix into a thresholded edge-weight matrix.

    Processing steps, in order:

    1. The diagonal is overwritten with a large sentinel (999) so that each
       row's minimum is taken over off-diagonal entries only. This assumes
       every real distance is below 999.
    2. Every row is shifted so that its smallest off-diagonal distance becomes 1.
    3. Weights are computed as ``coefficient ** (shifted_distance - 1)``, so the
       closest neighbour of each row gets weight 1.
    4. Rows and columns listed in ``ignore_ind`` are zeroed, after which the
       whole diagonal (including ignored classes) is set to 2.
    5. Per row, every weight strictly below the ``(topk + 1)``-th largest weight
       is zeroed, where ``topk = int(num_valid * kg_ratio / 100)``.

    Parameters
    ----------
    mat : array-like, 2-D
        Pairwise distances. It is copied first, so the caller's array is never
        modified (the previous implementation wrote the sentinel into it).
    coefficient : float
        Base of the exponential distance-to-weight mapping.
    kg_ratio : number
        Percentage of ``num_valid`` used as the per-row cut-off rank.
    ignore_ind : sequence of int, optional
        Row/column indices to zero out. Defaults to the first element returned
        by ``get_ignore_ind()`` (which reads the label-index JSON file).
    num_valid : int, optional
        Base count for the cut-off rank. Defaults to the module-level ``NUM_VALID``.

    Returns
    -------
    numpy.ndarray
        Weight matrix with the same shape as ``mat``.
    """
    if ignore_ind is None:
        ignore_ind, _ = get_ignore_ind()
    if num_valid is None:
        num_valid = NUM_VALID

    # Work on a private copy: the sentinel write below must not leak to the caller.
    mat = np.array(mat)
    num_classes = mat.shape[0]
    diagonal = np.arange(num_classes)

    # Mask self-distances so they cannot win the row minimum.
    mat[diagonal, diagonal] = 999
    min_mat = np.min(mat, 1)
    mat = mat - min_mat.reshape(-1, 1) + 1

    in_matrix = coefficient ** (mat - 1)
    in_matrix[:, ignore_ind] = 0
    in_matrix[ignore_ind, :] = 0
    in_matrix[diagonal, diagonal] = 2

    # Rows sorted in descending order; column `topk` holds the cut-off weight.
    topk = int(num_valid * kg_ratio / 100)
    # A rank beyond the last column means "keep everything" instead of an IndexError.
    topk = min(topk, in_matrix.shape[1] - 1)
    max_ = -np.sort(-in_matrix, 1)
    edge = max_[:, topk].reshape(-1, 1)
    in_matrix[in_matrix < edge] = 0

    return in_matrix


def process_wordnet(mat, coefficient, kg_ratio, ignore_ind=None, num_valid=None):
    """Convert a pairwise distance matrix into a thresholded edge-weight matrix.

    Weights are ``coefficient ** (distance - 1)`` computed directly on the given
    distances (no per-row shift). The diagonal is set to 2 *before* the rows and
    columns in ``ignore_ind`` are zeroed, so -- unlike ``process_semantic`` --
    ignored classes end up with a zero diagonal entry. Finally, per row, every
    weight strictly below the ``(topk + 1)``-th largest weight is zeroed, where
    ``topk = int(num_valid * kg_ratio / 100)``.

    Parameters
    ----------
    mat : array-like, 2-D
        Pairwise distances; not modified.
    coefficient : float
        Base of the exponential distance-to-weight mapping.
    kg_ratio : number
        Percentage of ``num_valid`` used as the per-row cut-off rank.
    ignore_ind : sequence of int, optional
        Row/column indices to zero out. Defaults to the first element returned
        by ``get_ignore_ind()`` (which reads the label-index JSON file).
    num_valid : int, optional
        Base count for the cut-off rank. Defaults to the module-level ``NUM_VALID``.

    Returns
    -------
    numpy.ndarray
        Weight matrix with the same shape as ``mat``.
    """
    if ignore_ind is None:
        ignore_ind, _ = get_ignore_ind()
    if num_valid is None:
        num_valid = NUM_VALID

    mat = np.asarray(mat)
    num_classes = mat.shape[0]
    diagonal = np.arange(num_classes)

    in_matrix = coefficient ** (mat - 1)
    in_matrix[diagonal, diagonal] = 2
    # Zeroing happens after the diagonal write, so ignored classes lose their self-weight too.
    in_matrix[ignore_ind, :] = 0
    in_matrix[:, ignore_ind] = 0

    # Rows sorted in descending order; column `topk` holds the cut-off weight.
    topk = int(num_valid * kg_ratio / 100)
    # A rank beyond the last column means "keep everything" instead of an IndexError.
    topk = min(topk, in_matrix.shape[1] - 1)
    max_ = -np.sort(-in_matrix, 1)
    edge = max_[:, topk].reshape(-1, 1)
    in_matrix[in_matrix < edge] = 0
    return in_matrix

# One row per graph: (key, processing function, raw distances, exponential base).
# Each function receives a copy of the raw matrix, so the loaded arrays stay intact.
# The Wikidata base (.32) has no named constant in this notebook.
distances_dict_processed = {
    graph_name: transform(np.copy(raw_distances), coefficient=base, kg_ratio=kg_ratio)
    for graph_name, transform, raw_distances, base in (
        ('hierachy', process_wordnet, distances_hierarchy, coefficient_hierarchy),
        ('glove', process_semantic, distances_glove, coefficient_glove),
        ('wiki', process_semantic, distances_wikidata, .32),
    )
}


In [5]:
# Emit a LaTeX table body (header row + one row per graph) summarising the processed weights.
print(" & Min & Avg & Max & Std \\\\")
row_template = "{:10} & {:3.2f} & {:3.2f} & {:3.2f} & {:3.2f} \\\\"
for graph_name, weights in distances_dict_processed.items():
    print(row_template.format(graph_name, weights.min(), weights.mean(), weights.max(), weights.std()))

 & Min & Avg & Max & Std \\
hierachy   & 0.00 & 0.01 & 2.00 & 0.07 \\
glove      & 0.00 & 0.05 & 2.00 & 0.11 \\
wiki       & 0.00 & 0.08 & 2.00 & 0.14 \\


In [8]:
# Inspect the shape of the processed GloVe weight matrix (the recorded output is (1000, 1000)).
distances_dict_processed['glove'].shape

(1000, 1000)