In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class Node:
    """One node of the coreset tree: a group of points plus the center representing it.

    Parameters
    ----------
    subset : array-like, shape (n_points, n_features)
        The points that belong to this node.
    q : array-like, shape (1, n_features) or (n_features,)
        The center (representative point) of this node. A flat vector is
        promoted to a single-row 2-D array.
    """

    def __init__(self, subset, q):
        self.subset = np.asarray(subset)
        # distCalc() indexes the center as q[:, np.newaxis], which only
        # broadcasts correctly for a 2-D (1, n_features) array, so normalise
        # flat vectors here. A 2-D ndarray is passed through unchanged.
        self.q = np.atleast_2d(q)
        # cost: sum of squared Euclidean distances from every point to q.
        self.cost = self.distCalc().sum()
        # weight: number of points held by this node.
        self.weight = len(self.subset)

        self.leftChild = None
        self.rightChild = None

    def _sq_dist_to(self, centers):
        """Squared Euclidean distance from every point in ``subset`` to each row of ``centers``.

        Returns an array of shape (n_centers, n_points).
        """
        diff = self.subset - centers[:, np.newaxis]
        return (diff ** 2).sum(axis=2)

    def distCalc(self):
        """Return the squared distances of all points to this node's center.

        Returns
        -------
        ndarray, shape (1, n_points)
        """
        return self._sq_dist_to(self.q)

    def subsetFinder(self):
        """Split this node's points between its center and a new, farthest center.

        The point with the largest squared distance to ``q`` becomes the new
        center ``q2``. Every point is then assigned to whichever of ``q`` /
        ``q2`` is closer; ties go to ``q``.

        Returns
        -------
        tuple
            ``((subset1, q), (subset2, q2))`` where ``subset1`` holds the
            points closer to ``q`` and ``subset2`` those closer to ``q2``.
            If every point coincides with ``q``, ``subset2`` is empty.

        Raises
        ------
        ValueError
            If the node holds no points (there is no candidate for ``q2``).
        """
        if len(self.subset) == 0:
            raise ValueError("cannot split a node whose subset is empty")

        dist = self.distCalc()
        # Column index of the globally largest distance = index of the farthest point.
        max_dist_index = np.unravel_index(np.argmax(dist, axis=None),
                                          dist.shape)[1]
        q2 = self.subset[max_dist_index][np.newaxis, :]

        # Reuse the distances to q computed above; only the distances to the
        # new center need to be evaluated.
        grou_dist = np.vstack((dist, self._sq_dist_to(q2)))
        labels = np.argmin(grou_dist, axis=0)

        subset1 = self.subset[labels == 0]
        subset2 = self.subset[labels == 1]
        return (subset1, self.q), (subset2, q2)

In [3]:
class CoreSetTree:
    """Build a coreset by repeatedly splitting the costliest leaf of a binary tree.

    The tree starts as a single ``Node`` holding all points with one randomly
    chosen center. Each step picks the leaf with the largest cost, splits it
    with ``Node.subsetFinder`` and records the newly created center in
    ``coreset``.

    Parameters
    ----------
    X : array-like, shape (n_points, n_features)
        The full data set.
    m : int
        Maximum number of splits. Together with the initial center the
        coreset therefore holds at most ``m + 1`` points.
    seed : int or None, default 42
        Seed for the private random generator that draws the first center.
        The default selects the same first center as seeding numpy's global
        generator with 42 would.
    """

    def __init__(self, X, m, seed=42):
        self.wholeSet = np.asarray(X)
        self.m = m
        self.seed = seed
        self.coreset = np.array([])
        self.Root = None
        self.MaxNode = None
        self.MaxNodeCost = 0

    def fit(self):
        """Construct the tree and fill ``self.coreset``.

        Splitting stops early once the costliest leaf has zero cost, because
        splitting such a leaf can only yield a duplicate center and an empty
        child.

        Raises
        ------
        ValueError
            If the data set is not a non-empty 2-D array.
        """
        if self.wholeSet.ndim != 2 or self.wholeSet.shape[0] == 0:
            raise ValueError("X must be a non-empty 2-D array of shape "
                             "(n_points, n_features)")

        # A private generator keeps fit() from reseeding numpy's global RNG.
        rng = np.random.RandomState(self.seed)
        q1 = self.wholeSet[rng.choice(self.wholeSet.shape[0], 1, replace=False)]

        # Start from scratch on every call (so refitting does not accumulate
        # centers) and keep the data's own number of columns. The coreset is
        # stored with at least float64 precision.
        self.coreset = q1.astype(np.result_type(q1.dtype, np.float64))
        self.Root = Node(self.wholeSet, q1)
        self.MaxNode = None
        self.MaxNodeCost = 0

        for _ in range(self.m):
            self.MaxNode = None
            self.MaxNodeCost = 0
            self.visit(self.Root)
            if self.MaxNode is None or self.MaxNodeCost <= 0:
                # Every leaf is already represented exactly by its center.
                break
            self.addChild(self.Root)
            self.propagateUp(self.Root)

    def _split(self, leaf):
        """Give ``leaf`` two children and append the new center to the coreset."""
        (near_pts, near_center), (far_pts, far_center) = leaf.subsetFinder()
        self.coreset = np.vstack((self.coreset, far_center))
        leaf.leftChild = Node(near_pts, near_center)
        leaf.rightChild = Node(far_pts, far_center)

    def addChild(self, node):
        """Split the leaf stored in ``self.MaxNode`` if it lies below ``node``.

        The leaf is located by identity rather than by comparing costs, so
        other leaves that happen to have the same cost are left untouched.
        """
        target = self.MaxNode
        if node is None or target is None:
            return
        stack = [node]
        while stack:
            current = stack.pop()
            if current is target:
                if current.leftChild is None and current.rightChild is None:
                    self._split(current)
                return
            for child in (current.leftChild, current.rightChild):
                if child is not None:
                    stack.append(child)

    def visit(self, node):
        """Record the costliest leaf below ``node`` in ``MaxNode`` / ``MaxNodeCost``.

        Leaves are examined from left to right and a leaf replaces the current
        best when its cost is greater than or equal to ``MaxNodeCost``, so the
        right-most leaf wins a tie.
        """
        if node is None:
            return
        stack = [node]
        while stack:
            current = stack.pop()
            left, right = current.leftChild, current.rightChild
            if left is None and right is None:
                if current.cost >= self.MaxNodeCost:
                    self.MaxNode = current
                    self.MaxNodeCost = current.cost
                continue
            # Push right first so the left subtree is processed first.
            if right is not None:
                stack.append(right)
            if left is not None:
                stack.append(left)

    def propagateUp(self, node):
        """Recompute each inner node's cost as the sum of its two children's costs."""
        if node is None:
            return
        # Collect nodes parent-before-children, then walk the list backwards so
        # every child is updated before its parent.
        ordered = []
        stack = [node]
        while stack:
            current = stack.pop()
            ordered.append(current)
            for child in (current.leftChild, current.rightChild):
                if child is not None:
                    stack.append(child)
        for current in reversed(ordered):
            if current.leftChild is not None and current.rightChild is not None:
                current.cost = current.leftChild.cost + current.rightChild.cost
    

In [None]:
if __name__ == "__main__":
    # Example usage: draw three Gaussian blobs, build a coreset and plot it
    # on top of the raw data.
    np.random.seed(42)
    N = 1000
    K = 3
    NUM_SPLITS = 230  # value passed as `m` to CoreSetTree
    means = np.array([[0, 0], [5, 5], [10, 0]])
    cov = np.eye(2)
    # Each of the K clusters contributes N // K samples, so X has shape
    # (K * (N // K), 2) = (999, 2).
    X = np.vstack([np.random.multivariate_normal(mean, cov, N // K) for mean in means])
    gg = CoreSetTree(X, NUM_SPLITS)
    gg.fit()

    plt.scatter(X[:, 0], X[:, 1], label="data")
    plt.scatter(gg.coreset[:, 0], gg.coreset[:, 1], marker='x', s=200,
                label="coreset")
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.title("Coreset points over the sampled data")
    plt.legend()
    plt.show()
    