# In [None]:
from scipy.sparse import csc_matrix
from scipy.sparse.csgraph import csgraph_from_masked


class AffinityMatrix:
    """Sparse symmetric distance graph over the minimum-dimension points.

    Workflow: call :meth:`fit` with the data, then :meth:`predict` to obtain an
    ``(n, n)`` ``csc_matrix`` whose stored entries are Euclidean distances
    between pairs of points linked by the neighbor-selection rule described
    in :meth:`predict`.

    Parameters
    ----------
    n_neighbor : int
        ``n_neighbors`` passed to ``NearestNeighbors`` in :meth:`fit`.
    max_neighbor : int
        Highest position in the :meth:`get_real_neighbors` list that
        :meth:`predict` scans; clamped to the available list length.
    dim_neighbor :
        Forwarded unchanged as the third argument of ``get_dimension``.
    """

    def __init__(self, n_neighbor, max_neighbor, dim_neighbor):
        self.distances = None    # distances returned by kneighbors in fit()
        self.neighbors = None    # neighbor indices returned by kneighbors, one row per point
        self.points = None       # the fitted data as a NumPy array
        self.dims = None         # per-point value returned by get_dimension
        self.components_ = None  # NOTE(review): never assigned in this class; kept for compatibility
        self.min_dim = None      # smallest entry of self.dims
        self.n_neighbor = n_neighbor
        self.max_neighbor = max_neighbor
        self.dim_neighbor = dim_neighbor

    def get_real_neighbors(self, x):
        """Return ``[x, *neighbors[x]]`` as an ``int32`` array.

        NOTE: ``kneighbors`` queried with the training data normally lists
        each point as its own first neighbor, so ``x`` usually appears at both
        position 0 and position 1. :meth:`predict` skips such self-matches.
        """
        return np.asarray([x, *self.neighbors[x]], dtype=np.int32)

    def get_real_neighbors_points(self, x):
        """Return the coordinates of ``get_real_neighbors(x)``, one row each."""
        # Fancy indexing yields a fresh array with one row per index.
        return self.points[self.get_real_neighbors(x)]

    def fit(self, X):
        """Compute nearest neighbors and per-point ``get_dimension`` values.

        ``X`` may be a pandas DataFrame (converted with ``to_numpy()``) or any
        array-like accepted by ``np.asarray`` and ``NearestNeighbors``.

        Returns
        -------
        self
        """
        nbrs = NearestNeighbors(n_neighbors=self.n_neighbor).fit(X)
        self.distances, self.neighbors = nbrs.kneighbors(X)
        self.points = X.to_numpy() if hasattr(X, "to_numpy") else np.asarray(X)
        self.dims = np.zeros(shape=(len(X)))
        for v in range(len(X)):
            # get_dimension is defined elsewhere; presumably a local
            # intrinsic-dimension estimate of the neighborhood -- confirm.
            self.dims[v] = get_dimension(
                self.get_real_neighbors_points(v), 0.000001, self.dim_neighbor
            )
        self.min_dim = min(self.dims)
        return self

    def is_connected(self, i, j):
        """Return True when at least one of ``i``, ``j`` has ``dims == min_dim``."""
        # min_dim is the minimum of dims, so equality holds iff either endpoint
        # attains it.
        return bool(min(self.dims[i], self.dims[j]) == self.min_dim)

    def get_connection_weight(self, i, j):
        """Euclidean distance between points ``i`` and ``j``, or ``np.inf``
        when :meth:`is_connected` rejects the pair."""
        if not self.is_connected(i, j):
            return np.inf
        return np.linalg.norm(self.points[i] - self.points[j])

    def predict(self):
        """Build the symmetric sparse distance graph.

        For every point ``i`` with ``dims[i] == min_dim``, the positions
        ``max_neighbor, max_neighbor - 1, ..., 1`` of
        :meth:`get_real_neighbors` are scanned (farthest first, since
        ``kneighbors`` sorts by distance). Up to ``2 * dims[i] + 2`` connected
        neighbors ``j != i`` are linked in both directions with weight
        ``||points[i] - points[j]||``.

        Each directed entry is stored exactly once (a dict keyed by
        ``(row, col)``). An edge found from both endpoints is therefore not
        summed twice when ``csc_matrix`` merges duplicate coordinates.

        Returns
        -------
        scipy.sparse.csc_matrix of shape ``(n, n)``.
        """
        n_points = len(self.points)
        edges = {}  # (row, col) -> weight; keys de-duplicate mutual edges
        for i in range(n_points):
            if self.dims[i] != self.min_dim:
                continue
            r_n = self.get_real_neighbors(i)
            budget = 2 * self.dims[i] + 2
            # Position 0 is i itself; never index past the end of the list.
            last = min(self.max_neighbor, len(r_n) - 1)
            cnt = 0
            for jj in range(last, 0, -1):
                # '>=' so a non-integral budget still caps the count.
                if cnt >= budget:
                    break
                j = int(r_n[jj])
                if j == i:
                    # Self-match from kneighbors: not an edge, not counted.
                    continue
                if self.is_connected(i, j):
                    cnt += 1
                    edges[(i, j)] = self.get_connection_weight(j, i)
                    edges[(j, i)] = self.get_connection_weight(i, j)
        rows = np.array([r for r, _ in edges], dtype=np.intp)
        cols = np.array([c for _, c in edges], dtype=np.intp)
        data = np.array(list(edges.values()), dtype=float)
        return csc_matrix((data, (rows, cols)), shape=(n_points, n_points))
