<a href="https://colab.research.google.com/github/Taedriel/ZSL-v2/blob/wordEmbedding/Hierarchical_clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install orange3 python-louvain networkx --quiet --upgrade

In [15]:
import numpy as np
import torch
import community.community_louvain as community
from Orange.clustering.hierarchical import dist_matrix_linkage, tree_from_linkage, data_clustering, leaves, WEIGHTED
from Orange.data import Table, Domain
from Orange.distance.distance import Cosine
from Orange.widgets.unsupervised.owhierarchicalclustering import clusters_at_height
from scipy.cluster.hierarchy import dendrogram
from itertools import chain
from collections import Counter
from Orange.data.variable import StringVariable

In [16]:
class EmbeddingsLoader:

    """class that load an embeddings file to perform operation on it. Base class
     for multiple operations such as matrix similarity operations.
     """

    def __init__(self, filename : str):

        self.file = filename
        self.embeddings = {}

        self._load_file()

    def _load_file(self):
        try:
            with open(self.file, "r") as f:
                lines = f.readlines()
                
            for line in lines[1:]:
                data = line.split(",")
                self.embeddings[data[0]] = torch.FloatTensor(list(map(float, data[1:])))

        except IOError as e:
            raise IOError(f"No file {self.file}")


# New Solver

In [60]:
CLUSTER_THRESOLD = 0.85
GROUP_BY = "first superclass"
MYSTERY = "TOGUESS"
SIM_THRESOLD = 0.3

def left_join(complete_table, supp_info_table, key: str = "embeddings") -> Table:
    """add all <b> metas </b> column from supp_info_table to complete_table using key as joint
    """
    assert key in list(map(lambda x : x.name, supp_info_table.domain.metas)), "embeddings name not present in additional data"
    # assert len(complete_table) == len(supp_info_table), "table don't contain the same number of line"

    name_supp_data = [i.name for i in chain(supp_info_table.domain.metas, 
                                            supp_info_table.domain.variables, 
                                            supp_info_table.attributes) if i.name != key]
                                            
    supp_list_list = [[] for i in range(len(name_supp_data))]

    for s in complete_table:
        done = False
        for d in supp_info_table:
            if s[key] == d[key]:
                for i, name in enumerate(name_supp_data):
                    supp_list_list[i].append(d[name])
                done = True
                break
        if not done:
            for i, name in enumerate(name_supp_data):
                supp_list_list[i].append("?")

    for i, name in enumerate(name_supp_data):
        # print(f"adding {name}")
        complete_table = complete_table.add_column(StringVariable(name), supp_list_list[i])

    return complete_table

def parent_of_mystery(cluster, mystery_index):
    res = None
    for branch in cluster.branches:
        if branch.is_leaf:
            if branch.value.index == mystery_index:
                return cluster
        else: 
            res = parent_of_mystery(branch, mystery_index)
            if res is not None:
                return res
    
def first_child(root):
    if root.is_leaf:
        return root
    else:
        return first_child(root.branches[0])


def closest_to(cluster, mystery_index):
    if len(cluster.branches) == 1:
        return None

    next = False
    for i, branch in enumerate(cluster.branches):
        if next:
            return first_child(branch)

        if branch.is_leaf:
            if branch.value.index == mystery_index:
                if i == 0:
                    next = True
                else:
                    return first_child(cluster.branches[i-1])

def add_to_list(cluster, list_to_add_to):
    """ decompose a cluster tree by adding the index of all children in the list
    """
    if cluster.is_leaf:
        list_to_add_to.append(cluster.value.index)

    for i, branch in enumerate(cluster.branches):
        add_to_list(branch, list_to_add_to)

def clusterize(table : Table, thresold, key = "embeddings") -> Table:
    """clusterize a Oranga Table based on the height of THRESOLD
    """
    for i in table[-1::-1]:
        if i[key] == MYSTERY:
            mystery_index = table.index(i)
            break

    root = data_clustering(table, distance=Cosine, linkage=WEIGHTED)
    parent_cluster = parent_of_mystery(root, mystery_index)
    if thresold is None:
        thresold = min(parent_cluster.value.height + 0.001, 1)

    cluster_tree = clusters_at_height(root, thresold)

    list_cluster = {}
    closest = None
    mystery_len_cluster = -1
    for i, cluster in enumerate(cluster_tree):
        cluster_name     = 'C' + str(i) 

        current = []
        add_to_list(cluster, current)
        if mystery_index in current: 
            mystery_len_cluster = len(current)
            closest = closest_to(parent_cluster, mystery_index)

        for item_index in current:
            list_cluster[item_index] = cluster_name
        # print(cluster_name, list(map(lambda x: table[x]["embeddings"].value, current)))

    # print(f"last cluster: {i} ({mystery_len_cluster})")
    table = table.add_column(StringVariable("Cluster"), [list_cluster[i] for i in range(len(table))])

    return table, closest.value.index, thresold, i

def compute(lst):
    # return max(lst,key=lst.count)
    counter = Counter(lst)
    return counter.most_common(len(lst))

def one_pass(table, toguess_table, keep_cluster_line : bool = False, cluster_thresold : float = CLUSTER_THRESOLD, sim_thresold : float = SIM_THRESOLD):
    assert GROUP_BY in list(map(lambda x: x.name, chain(table.domain.metas, 
                                                        table.domain.variables, 
                                                        table.domain.attributes))), "Group by not in the Table !"
    
    format = (len(table), "x", len(table.domain.attributes))
    table, closest, thresold, nb_cluster = clusterize(table, cluster_thresold)
    #===========================================================================
    # Cluster split
    toguess_cluster = [d["Cluster"] for d in table if d["embeddings"] == MYSTERY][0]

    in_cluster_table  = Table.from_list(table.domain, [d for d in table if d["Cluster"].value == toguess_cluster])
    out_cluster_table = Table.from_list(table.domain, [d for d in table if d["Cluster"].value != toguess_cluster])
    #===========================================================================
    # Group by computation
    main_superclass_count_list = compute([row[GROUP_BY].value for row in in_cluster_table])
    #equality case with "?", take the second
    ind = 1 if main_superclass_count_list[0][0] == "?" and len(main_superclass_count_list) > 1 else 0

    main_superclass = main_superclass_count_list[ind][0]
    main_superclass_table = Table.from_list(superclass_embeddings.domain, 
                                            [i for i in superclass_embeddings if i["embeddings"] == main_superclass])
    if len(main_superclass_table) == 0: return [], "cluster is empty"

    main_superclass_table = Table.concatenate([in_cluster_table, Table.from_table(out_cluster_table.domain, 
                                                                                  main_superclass_table)])
    #===========================================================================
    # thresold computation
    to_copy_row_instance = [d for d in main_superclass_table if d["embeddings"] == MYSTERY][0]
    to_copy = list(to_copy_row_instance.attributes())

    to_compare_row_instance = [d for d in main_superclass_table if d["Cluster"] == "?"][0]
    to_compare = list(to_compare_row_instance.attributes())

    dead_row = [k for k, (i, j) in enumerate(zip(to_copy, to_compare)) if abs(i - j) <= sim_thresold]
    #===========================================================================
    # reconstruct the table filtering dead row and cluster. Remove used cluster row if 
    # keep_cluster_line is set to False
    new_domain = Domain(attributes = [i for i in out_cluster_table.domain.attributes if int(i.name) not in dead_row], 
                        metas      = [i for i in out_cluster_table.domain.metas if i.name != "Cluster"])

    # do the same on the data
    data_attr, data_meta = [], []
    whole_data = list(out_cluster_table) + list(toguess_table)
    if keep_cluster_line: whole_data += list(in_cluster_table)

    for rowinstance in whole_data:
        data_attr.append([rowinstance[k] for k, i in enumerate(out_cluster_table.domain.attributes) if int(i.name) not in dead_row])
        data_meta.append([rowinstance.metas[k] for k, i in enumerate(out_cluster_table.domain.metas) if i.name != "Cluster"])

    return Table.from_numpy(new_domain, X = data_attr, metas = data_meta), \
        { "cluster" : {
                "name" : main_superclass,
                "size" : len(in_cluster_table) - 1,
                "thresold": thresold,
                "closest_to_myster" : table[closest]["embeddings"].value if closest is not True else None
            },
            "number_of_cluster" : nb_cluster,
           "format_at_beginning": format,
           "keep_cluster_line"  : keep_cluster_line,
           "sim_thresold"       : sim_thresold,
           "removed_col"        : len(dead_row) 
        }

def standardize_first(table):
    values = table[0]
    mean = np.mean(values)
    std  = np.std(values)

    for v in range(len(values)):
        values[v] = (values[v] - mean) / std

    return Table.from_numpy(table.domain, [values], None, table.metas)

In [8]:
generic_table = Table("/content/imagenet-wikipedia2vec-300.csv")
supp_info_table = Table("/content/class_map_imagenet.csv")
generic_table = left_join(generic_table, supp_info_table)

superclass_embeddings = Table("/content/custom-wikipedia2vec-300_superclass.csv")

In [61]:
def solve_mystery(complete_table, mystery, cluster_thresold_lambda, sim_thresold_lambda):

    toguess_table = Table.from_numpy(complete_table.domain, [np.array(mystery)], Y = None, metas = np.char.asarray([[MYSTERY, "?", "?"]]))
    toguess_table = standardize_first(toguess_table)

    table = Table.concatenate([complete_table, toguess_table])
    old_table = table

    advancement = []
    for i in range(5):
        old_table = table
        table, data = one_pass(table, toguess_table,
                                      keep_cluster_line = False, 
                                      cluster_thresold  = cluster_thresold_lambda(i), 
                                      sim_thresold      = sim_thresold_lambda(i))
        advancement.append(data)

        if len(table) <= 1:
            break
    return advancement

            # print("no result, trying to upper cluster thresold")
            # current_cluster_thresold = 0.55 + 0.05
            # while current_cluster_thresold < 1 and len(current_table) == 0:
            #     current_table, data = one_pass(old_table, keep_cluster_line = False, cluster_thresold = current_cluster_thresold, sim_thresold = 0.3 + 0.05 * i)
            #     print(data)
            #     current_cluster_thresold += 0.05
            # if len(current_table) == 0:
            #     print("no suitable thresold...")
            #     break
            # print("find a suitable thresold, resuming")

In [19]:
myster_file = EmbeddingsLoader("/content/mystery.csv")

In [62]:
def format_result(list_dict):
    superclass_list = []
    for dic in list_dict:
        if type(dic) == type(dict()):
            superclass_list.append((dic["cluster"]["name"], dic["cluster"]["closest_to_myster"], dic["cluster"]["thresold"]))
    
    return superclass_list

# cluster_thresold_lambda = lambda x : 0.65 + 0.05 * x
cluster_thresold_lambda     = lambda x : None
sim_thresold_lambda     = lambda x : 0.1 + 0.05 * x

for i, embeddings in myster_file.embeddings.items():

    result = solve_mystery(generic_table, embeddings, cluster_thresold_lambda, sim_thresold_lambda)
    print(i, *result, sep = ",")


0,{'cluster': {'name': 'dog', 'size': 1, 'thresold': 0.3481095649353513, 'closest_to_myster': 'affenpinscher'}, 'number_of_cluster': 936, 'format_at_beginning': (1000, 'x', 300), 'keep_cluster_line': False, 'sim_thresold': 0.1, 'removed_col': 22},{'cluster': {'name': 'dog', 'size': 1, 'thresold': 0.30940756879306053, 'closest_to_myster': 'otterhound'}, 'number_of_cluster': 957, 'format_at_beginning': (999, 'x', 278), 'keep_cluster_line': False, 'sim_thresold': 0.15000000000000002, 'removed_col': 9},{'cluster': {'name': 'dog', 'size': 997, 'thresold': 0.9317622545442632, 'closest_to_myster': 'carbonara'}, 'number_of_cluster': 0, 'format_at_beginning': (998, 'x', 269), 'keep_cluster_line': False, 'sim_thresold': 0.2, 'removed_col': 39}
1,{'cluster': {'name': 'dog', 'size': 6, 'thresold': 0.30113508623620966, 'closest_to_myster': 'curly-coated retriever'}, 'number_of_cluster': 966, 'format_at_beginning': (1000, 'x', 300), 'keep_cluster_line': False, 'sim_thresold': 0.1, 'removed_col': 32}