<a href="https://colab.research.google.com/github/Taedriel/ZSL-v2/blob/wordEmbedding/Hierarchical_clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install orange3 python-louvain networkx dendropy biopython scikit-bio --quiet --upgrade

In [4]:
import numpy as np
import torch
import community.community_louvain as community
import dendropy
# if import error, launch import a second time, and it will be fine
from Orange.clustering.hierarchical import Tree, ClusterData, SingletonData, dist_matrix_linkage, tree_from_linkage, data_clustering, leaves, WEIGHTED, dist_matrix_clustering
from Orange.data import Table, Domain
from Orange.distance.distance import Cosine
from Orange.widgets.unsupervised.owhierarchicalclustering import clusters_at_height
from Orange.misc.distmatrix import DistMatrix
from Bio import Phylo
from io import StringIO
from sklearn.metrics.pairwise import cosine_similarity,cosine_distances
from scipy.cluster.hierarchy import dendrogram
from itertools import chain
from collections import Counter
from Orange.data.variable import StringVariable
from skbio import DistanceMatrix
from skbio.tree import nj
from typing import Dict, Tuple, List, Callable
from tqdm import tqdm
import sklearn.preprocessing as pp
from scipy import sparse

from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import cosine

In [5]:
class EmbeddingsLoader:

    """class that load an embeddings file to perform operation on it. Base class
     for multiple operations such as matrix similarity operations.
     """

    def __init__(self, filename : str):

        self.file = filename
        self.embeddings = {}

        self._load_file()

    def _load_file(self):
        try:
            with open(self.file, "r") as f:
                lines = f.readlines()
                
            for line in lines[1:]:
                data = line.split(",")
                self.embeddings[data[0]] = torch.FloatTensor(list(map(float, data[1:])))

        except IOError as e:
            raise IOError(f"No file {self.file}")

# New Solver

In [76]:
CLUSTER_THRESOLD = 0.85
GROUP_BY = "first superclass"
MYSTERY = "TOGUESS"
SIM_THRESOLD = 0.3

def left_join(complete_table, supp_info_table, key: str = "embeddings") -> Table:
    """add all <b> metas </b> column from supp_info_table to complete_table using key as joint
    """
    assert key in list(map(lambda x : x.name, supp_info_table.domain.metas)), "embeddings name not present in additional data"
    # assert len(complete_table) == len(supp_info_table), "table don't contain the same number of line"
    print(len(complete_table), len(supp_info_table))
    name_supp_data = [i.name for i in chain(supp_info_table.domain.metas, 
                                            supp_info_table.domain.variables, 
                                            supp_info_table.attributes) if i.name != key]
                                            
    supp_list_list = [[] for i in range(len(name_supp_data))]

    for s in complete_table:
        done = False
        for d in supp_info_table:
            if s[key] == d[key]:
                for i, name in enumerate(name_supp_data):
                    supp_list_list[i].append(d[name])
                done = True
                break
        if not done:
            for i, name in enumerate(name_supp_data):
                supp_list_list[i].append("?")

    for i, name in enumerate(name_supp_data):
        # print(f"adding {name}")
        complete_table = complete_table.add_column(StringVariable(name), supp_list_list[i])

    return complete_table

def parent_of_mystery(cluster, mystery_index):
    res = None
    for branch in cluster.branches:
        if branch.is_leaf:
            if branch.value.index == mystery_index:
                return cluster
        else: 
            res = parent_of_mystery(branch, mystery_index)
            if res is not None:
                return res
    
def first_child(root):
    if root.is_leaf:
        return root
    else:
        return first_child(root.branches[0])


def closest_to(cluster, mystery_index):
    if len(cluster.branches) == 1:
        return None

    next = False
    for i, branch in enumerate(cluster.branches):
        if next:
            return first_child(branch)

        if branch.is_leaf:
            if branch.value.index == mystery_index:
                if i == 0:
                    next = True
                else:
                    return first_child(cluster.branches[i-1])

def add_to_list(cluster, list_to_add_to):
    """ decompose a cluster tree by adding the index of all children in the list
    """
    if cluster.is_leaf:
        list_to_add_to.append(cluster.value.index)

    for i, branch in enumerate(cluster.branches):
        add_to_list(branch, list_to_add_to)

def compute(lst, tips):
    # return max(lst,key=lst.count)
    weighted_lst = {elem: 0 for elem in set(lst)}
    for elem in lst:
        if elem in tips:
            weighted_lst[elem] += 3
        else:
            weighted_lst[elem] += 1

    most_common = sorted([(key, elem) for key, elem in weighted_lst.items()], key = lambda x : x[1], reverse = True)
    # print(most_common)
    return most_common

def Orange_tree_to_newick(root):

    if root.is_leaf:
        return str(root.value.index)  +  ":" + str(root.value.height)

    concat = "("
    for branch in root.branches:

        concat += Orange_tree_to_newick(branch) + ","

    concat = concat[:-1] + ")" +  ":" + str(root.value.height)
    return concat

def biotree_to_Orange_tree(tree):
    def recur_parse(root, acc, depth):
        if root.is_terminal():
            val = root.name
            leaf = Tree(SingletonData(range = range(len(acc), len(acc)+1), 
                                 height= 0.0, 
                                 index = val), ())
            acc.append(leaf)
            return leaf

        else:
            list_cla = []
            for cla in root:
                sub_tree = recur_parse(cla, acc, depth + (root.branch_length or 0))
                list_cla.append(sub_tree)
            node = Tree(ClusterData(range = range(list_cla[0]._Tree__value.range.start, list_cla[-1]._Tree__value.range.stop),
                               height = root.branch_length or 0), tuple(list_cla))
            return node

    orange_tree = recur_parse(tree.root, [], 0)
    return orange_tree

def reroot_tree(tree, void_index, format):
    newick_tree = Orange_tree_to_newick(tree)
    tree = Phylo.read(StringIO(newick_tree), "newick")
    tree.root_with_outgroup(tree.root.find_clades("NULL"))

    return biotree_to_Orange_tree(tree)

def sim2dist(mat : List[List[float]], func : Callable[[float], float] \
             = lambda x: 1 - x, hollow : bool = True) -> List[List[float]]:
    """ map the function func to each elements in the matrix

    apply the lambda function func to each element of the matrix. if hollow is set 
    to True, set the diagonal of the matrix to 0.
    Args:
        mat (List[List[float]]) : a matrix of number
        func (Callable[[float], float]) : a simple function to apply to each elem of the matrix
        hollow (bool) : whether to consider the diagonal of the matrix or not
    
    """
    inv_data = [[0 for i in range(len(mat[0]))] for j in range(len(mat))]

    for i, elem in enumerate(mat):
        for j, case in enumerate(elem):
            if i == j and hollow: 
                inv_data[i][j] = 0
            else:
                inv_data[i][j] = func(case)
                
    return inv_data

def njt(table, key : str):

    embeddings = {}
    for line in table:
        embeddings[str(line[key].value)] = list(line.attributes()) 
        

    ids = list(map(lambda x : x.replace(" ", "_").replace("-", "_").replace("'", "_"), embeddings.keys()))
    data = np.array([item for item in embeddings.values()])
    
    cos_A = sim2dist(1-pairwise_distances(data, metric="cosine"))

    def constructor(x):
        biotree = Phylo.read( StringIO(x), "newick")
        # print(biotree)
        return biotree_to_Orange_tree(biotree)

    dm = DistanceMatrix(cos_A, ids)
    tree = nj(dm, result_constructor = constructor)

    return tree



def clusterize(table : Table, thresold, key = "embeddings") -> Table:
    """clusterize a Oranga Table based on the height of THRESOLD
    """
    null = Table.from_numpy(table.domain, [np.array([1 for i in range(len(table.domain.attributes))])], Y = None, metas = np.char.asarray([["NULL", "?", "?"]]))
    table = Table.concatenate([table, null])

    for i in table[-1::-1]:
        if i[key] == MYSTERY:
            mystery_index = table.index(i)
            break
    for i in table[-1::-1]:
        if i[key] == "NULL":
            void_index = table.index(i)
            break

    # root = data_clustering(table, distance=Cosine, linkage=WEIGHTED)
    root = njt(table, key)
    root = reroot_tree(root, void_index, (len(table[0])))

    parent_cluster = parent_of_mystery(root, MYSTERY)
    if thresold is None:
        thresold = min(parent_cluster.value.height + 0.05, 1)

    cluster_tree = clusters_at_height(root, thresold)

    list_cluster = {}
    closest = None
    mystery_len_cluster = -1
    for i, cluster in enumerate(cluster_tree):
        cluster_name     = 'C' + str(i) 

        current = []
        add_to_list(cluster, current)
        if MYSTERY in current: 
            mystery_len_cluster = len(current)
            closest = closest_to(parent_cluster, MYSTERY)

        for item_index in current:
            list_cluster[item_index] = cluster_name
        # print(cluster_name, list(map(lambda x: table[x]["embeddings"].value, current)))

    table = table.add_column(StringVariable("Cluster"), [i for i in list_cluster])

    print(closest.value.index)

    return table, closest.value.index, thresold, i

from typing import List
def one_pass(table, toguess_table, keep_cluster_line : bool = False, cluster_thresold : float = CLUSTER_THRESOLD, sim_thresold : float = SIM_THRESOLD, tips : List[str] = []):
    assert GROUP_BY in list(map(lambda x: x.name, chain(table.domain.metas, 
                                                        table.domain.variables, 
                                                        table.domain.attributes))), "Group by not in the Table !"
    supp_data = {
        "sim_thresold"       : sim_thresold,
        "keep_cluster_line"  : keep_cluster_line,
    }

    supp_data["format_at_beginning"] = (len(table), "x", len(table.domain.attributes))
    table, closest, thresold, nb_cluster = clusterize(table, cluster_thresold)
    supp_data["cluster_thresold"]    = thresold
    supp_data["closest_to_myster"]   = closest # table[]["embeddings"].value if closest is not True else None,
    supp_data["number_of_cluster"]   = nb_cluster
    #===========================================================================
    # Cluster split
    toguess_cluster = [d["Cluster"] for d in table if d["embeddings"] == MYSTERY][0]

    in_cluster_table  = Table.from_list(table.domain, [d for d in table if d["Cluster"].value == toguess_cluster])
    out_cluster_table = Table.from_list(table.domain, [d for d in table if d["Cluster"].value != toguess_cluster])
    #===========================================================================
    # Group by computation
    supp_data["cluster_size"] = len(in_cluster_table)
    if len(in_cluster_table) <= 1: return [], supp_data
    
    main_superclass_count_list = compute([row[GROUP_BY].value for row in in_cluster_table], tips)
    #equality case with "?", take the second
    ind = 1 if main_superclass_count_list[0][0] == "?" and len(main_superclass_count_list) > 1 else 0
    main_superclass = main_superclass_count_list[ind][0]
    supp_data["cluster_name"] = main_superclass

    nb_dimension = len(list(in_cluster_table.domain.attributes))
    average_cluster = Table.from_list(in_cluster_table.domain, [
        [sum([line[i] for line in in_cluster_table]) / nb_dimension  for i in in_cluster_table.domain.attributes] + ["cluster_average"]
    ])

    # main_superclass_table = Table.from_list(superclass_embeddings.domain, [i for i in superclass_embeddings if i["embeddings"] == main_superclass])
    main_superclass_table = Table.concatenate([in_cluster_table, Table.from_table(in_cluster_table.domain, average_cluster)])
    #===========================================================================
    # thresold computation
    to_copy_row_instance = [d for d in main_superclass_table if d["embeddings"] == MYSTERY][0]
    to_copy = list(to_copy_row_instance.attributes())

    to_compare_row_instance = [d for d in main_superclass_table if d["Cluster"] == "?"][0]
    to_compare = list(to_compare_row_instance.attributes())

    dead_row = [k for k, (i, j) in enumerate(zip(to_copy, to_compare)) if abs(i - j) <= sim_thresold]
    supp_data["removed_col"] = len(dead_row) 
    #===========================================================================
    # reconstruct the table filtering dead row and cluster. Remove used cluster row if 
    # keep_cluster_line is set to False
    new_domain = Domain(attributes = [i for i in out_cluster_table.domain.attributes], # if int(i.name) not in dead_row], 
                        metas      = [i for i in out_cluster_table.domain.metas if i.name != "Cluster"])

    # do the same on the data
    data_attr, data_meta = [], []
    whole_data = list(out_cluster_table) + list(toguess_table)
    if keep_cluster_line: whole_data += list(in_cluster_table)

    for rowinstance in whole_data:
        # data_attr.append([rowinstance[k] for k, i in enumerate(out_cluster_table.domain.attributes) if int(i.name) not in dead_row])
        data_attr.append([rowinstance[k] for k, i in enumerate(out_cluster_table.domain.attributes)])
        data_meta.append([rowinstance.metas[k] for k, i in enumerate(out_cluster_table.domain.metas) if i.name != "Cluster"])


    return Table.from_numpy(new_domain, X = data_attr, metas = data_meta), supp_data


def standardize_first(table):
    values = table[0]
    mean = np.mean(values)
    std  = np.std(values)

    for v in range(len(values)):
        values[v] = (values[v] - mean) / std

    return Table.from_numpy(table.domain, [values], None, table.metas)

In [31]:
def solve_mystery(complete_table, mystery, cluster_thresold_lambda, sim_thresold_lambda, tips):

    toguess_table = Table.from_numpy(complete_table.domain, [np.array(mystery)], Y = None, metas = np.char.asarray([[MYSTERY, "?", "?"]]))
    # toguess_table = standardize_first(toguess_table)

    table = Table.concatenate([complete_table, toguess_table])
    old_table = table

    advancement = []
    for i in range(5):
        old_table = table
        table, data = one_pass(table, toguess_table,
                                      keep_cluster_line = False, 
                                      cluster_thresold  = cluster_thresold_lambda(i), 
                                      sim_thresold      = sim_thresold_lambda(i),
                                      tips = tips)
        advancement.append(data)

        if len(table) <= 1 or data["cluster_size"] < 10:
            break
    return advancement

            # print("no result, trying to upper cluster thresold")
            # current_cluster_thresold = 0.55 + 0.05
            # while current_cluster_thresold < 1 and len(current_table) == 0:
            #     current_table, data = one_pass(old_table, keep_cluster_line = False, cluster_thresold = current_cluster_thresold, sim_thresold = 0.3 + 0.05 * i)
            #     print(data)
            #     current_cluster_thresold += 0.05
            # if len(current_table) == 0:
            #     print("no suitable thresold...")
            #     break
            # print("find a suitable thresold, resuming")

In [8]:
generic_table = Table("/content/ResNet50-average.csv")
supp_info_table = Table("/content/class_map_imagenet.csv")
generic_table = left_join(generic_table, supp_info_table)

print(len(generic_table))

# superclass_embeddings = Table("/content/custom-wikipedia2vec-300_superclass.csv")

996 999
996


In [9]:
myster_file = EmbeddingsLoader("/content/mystery_CNN.csv")

In [77]:
def format_result(list_dict):
    superclass_list = []
    for dic in list_dict:
        if type(dic) == type(dict()) and "cluster_name" in dic.keys():
            superclass_list.append(f"{dic['cluster_name']}[{round(dic['cluster_size'] / len(generic_table) * 100, 1)}%]({dic['closest_to_myster'][0]})")
    
    return superclass_list

# cluster_thresold_lambda = lambda x : 0.30 + 0.20 * x
cluster_thresold_lambda     = lambda x : None
sim_thresold_lambda     = lambda x : 0.3

tips = [["bear"],
        ["bear"],
        ["bear"],
        ["monotreme"],
        ["monotreme"],
        ["monotreme"],
        ["cat"],
        ["rodent"],
        ["dog"],
        ["bear"],
        ["bear"],
        ["bear"],
        ["bear"],
        ["bear"],
        ["bear"]]

for k, (i, embeddings) in enumerate(myster_file.embeddings.items()):

    result = solve_mystery(generic_table, embeddings, cluster_thresold_lambda, sim_thresold_lambda, tips[k])
    print(f"{i}\t{' + '.join(format_result(result)): <80}\t\t{result}")


jinrikisha
bear1	                                                                                		[{'sim_thresold': 0.3, 'keep_cluster_line': False, 'format_at_beginning': (997, 'x', 2048), 'cluster_thresold': 0.09395700000000001, 'closest_to_myster': 'jinrikisha', 'number_of_cluster': 0, 'cluster_size': 1}]
schipperke
bear2	                                                                                		[{'sim_thresold': 0.3, 'keep_cluster_line': False, 'format_at_beginning': (997, 'x', 2048), 'cluster_thresold': 0.08696100000000001, 'closest_to_myster': 'schipperke', 'number_of_cluster': 0, 'cluster_size': 1}]
jinrikisha
bear3	                                                                                		[{'sim_thresold': 0.3, 'keep_cluster_line': False, 'format_at_beginning': (997, 'x', 2048), 'cluster_thresold': 0.10772999999999999, 'closest_to_myster': 'jinrikisha', 'number_of_cluster': 0, 'cluster_size': 1}]
platypus
platypus1	                                                