# In [3]:
import numpy as np

# In [4]:
def BuildTreeMetric(XX, L, KC, kcenter_clustering=None):
    """Build a tree metric over the pooled supports of N empirical measures.

    The tree is grown top-down by hierarchical k-center clustering.  The root
    (vertex id 1) sits at the mean of all supports.  At each of the L levels,
    every cluster of the previous level is split into at most KC children by
    ``kcenter_clustering``.  Each child becomes a vertex at its cluster
    centre, joined to its parent by an edge whose weight is the Euclidean
    distance between the two centres.  Children whose centre coincides with
    the parent's (zero-length edge) are discarded, and their supports stop
    descending.

    Args:
        XX: sequence of N array-likes; ``XX[i]`` holds the supports of the
            i-th measure as rows of shape (n_i, dim).
        L: maximum tree height (number of clustering levels below the root).
        KC: maximum number of children per vertex.
        kcenter_clustering: routine called as ``f(dim, n, X, KC)`` with
            ``X`` of shape (dim, n).  It must return a 5-tuple whose items
            are: [0] the number of clusters K', [2] the 0-based cluster index
            of each point (length n), and [3] the cluster centres as a
            (dim, >=K') array.  Items [1] and [4] are ignored.  Defaults to
            the global ``figtreeKCenterClustering``.

    Returns:
        dict with keys:
          'nVertices'         total number of vertices (int).
          'Vertex_ParentId'   (nVertices,) 1-based parent id; 0 for the root.
          'Vertex_ChildId'    list; entry v-1 is the list of child ids of v.
          'Vertex_Pos'        (nVertices, dim) vertex positions.
          'Vertex_EdgeIdPath' list; entry v-1 is the list of edge ids on the
                              root-to-v path.  Edge ids are 0-based and index
                              the Edge_* arrays directly; edge e enters
                              vertex e + 2.
          'Edge_LowNode'      (nVertices-1,) parent-side vertex id per edge.
          'Edge_HighNode'     (nVertices-1,) child-side vertex id per edge.
          'Edge_Weight'       (nVertices-1,) edge lengths.
          'Level_sID'/'Level_eID'  (L+1,) first/last vertex id per level.
          'LevelIdArray'      always an empty list.
          'LeavesIDArray'     vertex ids at level L.
    """
    if kcenter_clustering is None:
        # Looked up only when needed, so callers that pass their own routine
        # do not need ``figtreeKCenterClustering`` to be defined.
        kcenter_clustering = figtreeKCenterClustering

    # Exact upper bound on the vertex count: 1 + KC + KC**2 + ... + KC**L.
    max_vertices = sum(KC ** i for i in range(L + 1))
    # KCarray[l]: capacity KC**(l+1) of the centre buffer for level l+1.
    KCarray = [KC ** i for i in range(1, L + 1)]

    dim = len(XX[0][0])
    # Pool every measure's supports into one (nXX, dim) matrix.
    allXX = np.concatenate(
        [np.asarray(X, dtype=float).reshape(-1, dim) for X in XX], axis=0
    )
    nXX = allXX.shape[0]

    TM = {
        'nVertices': 0,
        'Vertex_ParentId': np.zeros(max_vertices, dtype=int),
        'Vertex_ChildId': [[] for _ in range(max_vertices)],
        'Vertex_Pos': np.zeros((max_vertices, dim)),
        'Vertex_EdgeIdPath': [[] for _ in range(max_vertices)],
        'Edge_LowNode': np.zeros(max_vertices, dtype=int),
        'Edge_HighNode': np.zeros(max_vertices, dtype=int),
        'Edge_Weight': np.zeros(max_vertices),
        'Level_sID': np.zeros(L + 1, dtype=int),
        'Level_eID': np.zeros(L + 1, dtype=int),
        'LevelIdArray': [],
    }

    # Level 0: the root, placed at the mean of all supports.
    kcCenterPP = allXX.mean(axis=0)[:, np.newaxis]  # (dim, numPP) parent centres
    numPP = 1
    # Parent-cluster index of each support (-1: stopped descending).
    idZZPP = np.zeros(nXX, dtype=int)
    TM['nVertices'] = 1
    TM['Vertex_Pos'][0, :] = kcCenterPP[:, 0]
    TM['Level_sID'][0] = 1
    TM['Level_eID'][0] = 1

    for idLL in range(L):
        # Cluster index of each support at level idLL+1.  It starts at -1 so
        # supports that already stopped never rejoin cluster 0.
        idZZLL = np.full(nXX, -1, dtype=int)
        kcCenterLL = np.zeros((dim, KCarray[idLL]))
        nkcCenterLL = 0  # centres actually created at this level so far

        # Vertices of level idLL+1 follow the last vertex of level idLL.
        TM['Level_sID'][idLL + 1] = TM['Level_eID'][idLL] + 1
        TM['Level_eID'][idLL + 1] = TM['Level_eID'][idLL]

        for idCCPP in range(numPP):
            idVertexPP = int(TM['Level_sID'][idLL]) + idCCPP  # 1-based id
            if idLL == 0:
                members = np.arange(nXX)
            else:
                members = np.flatnonzero(idZZPP == idCCPP)
            if members.size == 0:
                continue

            rKC, _, labels, centers, _ = kcenter_clustering(
                dim, members.size, allXX[members, :].T, KC)
            rKC = int(rKC)
            labels = np.asarray(labels, dtype=int).reshape(-1)
            centers = np.asarray(centers, dtype=float).reshape(dim, -1)[:, :rKC]

            # Edge weight: Euclidean distance from the parent centre.
            weights = np.sqrt(
                np.sum((centers - kcCenterPP[:, [idCCPP]]) ** 2, axis=0))

            # Drop zero-length children and relabel the survivors 0..k-1,
            # keeping their order.  Supports of dropped children get -1.
            keep = weights != 0
            relabel = np.full(rKC, -1, dtype=int)
            relabel[keep] = np.arange(np.count_nonzero(keep))
            labels = relabel[labels]
            centers = centers[:, keep]
            weights = weights[keep]
            rKC = centers.shape[1]

            # Offset by the clusters already created for earlier parents.
            idZZLL[members] = np.where(labels >= 0, labels + nkcCenterLL, -1)
            if rKC == 0:
                continue

            kcCenterLL[:, nkcCenterLL:nkcCenterLL + rKC] = centers
            nkcCenterLL += rKC

            first = int(TM['Level_eID'][idLL + 1]) + 1
            idNewVertices = np.arange(first, first + rKC)  # 1-based ids
            idNewEdges = idNewVertices - 2  # edge e enters vertex e + 2

            TM['nVertices'] += rKC
            TM['Vertex_ParentId'][idNewVertices - 1] = idVertexPP
            TM['Vertex_ChildId'][idVertexPP - 1] = idNewVertices.tolist()
            TM['Vertex_Pos'][idNewVertices - 1, :] = centers.T
            TM['Edge_LowNode'][idNewEdges] = idVertexPP
            TM['Edge_HighNode'][idNewEdges] = idNewVertices
            TM['Edge_Weight'][idNewEdges] = weights
            parent_path = TM['Vertex_EdgeIdPath'][idVertexPP - 1]
            for v, e in zip(idNewVertices.tolist(), idNewEdges.tolist()):
                # Edge ids from the root down to vertex v.
                TM['Vertex_EdgeIdPath'][v - 1] = parent_path + [e]
            TM['Level_eID'][idLL + 1] += rKC

        # This level's clusters become the parents of the next level.
        idZZPP = idZZLL
        kcCenterPP = kcCenterLL[:, :nkcCenterLL]
        numPP = nkcCenterLL

    n_vertices = TM['nVertices']
    TM['LeavesIDArray'] = np.arange(TM['Level_sID'][L], TM['Level_eID'][L] + 1)
    TM['Vertex_ParentId'] = TM['Vertex_ParentId'][:n_vertices]
    TM['Vertex_ChildId'] = TM['Vertex_ChildId'][:n_vertices]
    TM['Vertex_Pos'] = TM['Vertex_Pos'][:n_vertices, :]
    TM['Vertex_EdgeIdPath'] = TM['Vertex_EdgeIdPath'][:n_vertices]
    # A tree with n vertices has exactly n - 1 edges.
    TM['Edge_LowNode'] = TM['Edge_LowNode'][:n_vertices - 1]
    TM['Edge_HighNode'] = TM['Edge_HighNode'][:n_vertices - 1]
    TM['Edge_Weight'] = TM['Edge_Weight'][:n_vertices - 1]

    return TM