In [None]:
from skimage import io, color, util, filters
import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
# Load and binarize the image.
# Any sample image works here: point IMAGE_PATH at your own file.
IMAGE_PATH = "/content/mnistsample.png"
# Neighbours requested per white pixel. Each pixel is returned as its own
# nearest neighbour (distance 0), so at most N_NEIGHBORS - 1 real edges result.
N_NEIGHBORS = 5

image = io.imread(IMAGE_PATH)
if image.ndim == 2:
    # Duplicate the single channel to create a three-channel image
    image = np.stack((image,) * 3, axis=-1)
elif image.ndim == 3 and image.shape[-1] == 4:
    # rgb2gray expects exactly three channels, so blend the alpha channel away first.
    image = color.rgba2rgb(image)
image = color.rgb2gray(image)
image = util.invert(image)
# Otsu picks the threshold; pixels brighter than it (after inversion) are foreground.
binary_image = image > filters.threshold_otsu(image)

G = nx.Graph()

# Every pixel becomes a node keyed by its (row, col) index, tagged with its binary value.
for (x, y), value in np.ndenumerate(binary_image):
    G.add_node((x, y), binary=value)

# (row, col) coordinates of the foreground pixels, one per row of the array.
white_pixels = np.argwhere(binary_image)

# Never ask for more neighbours than there are points: kneighbors raises otherwise.
neigh = NearestNeighbors(n_neighbors=max(1, min(N_NEIGHBORS, len(white_pixels))))

if len(white_pixels) > 0:
    neigh.fit(white_pixels)

    # One batched query for all pixels instead of one kneighbors call per pixel.
    distances, indices = neigh.kneighbors(white_pixels)

    # Plain-int tuples, matching the keys used when the nodes were added above.
    pixel_nodes = [tuple(coords) for coords in white_pixels.tolist()]

    for pixel_idx, (pixel_distances, pixel_neighbors) in enumerate(zip(distances, indices)):
        for distance, neighbor_idx in zip(pixel_distances, pixel_neighbors):
            # Skip the pixel itself; connecting it would only add a self-loop.
            # Coordinates are unique, so a zero distance also means "same pixel".
            if neighbor_idx == pixel_idx or distance == 0:
                continue
            # use distance as weight (smaller distance -> larger weight)
            G.add_edge(
                pixel_nodes[pixel_idx],
                pixel_nodes[neighbor_idx],
                weight=1 / distance,
            )



In [None]:


# Show the binarized image and overlay every graph edge as a thin red segment.
fig, ax = plt.subplots()
ax.imshow(binary_image, cmap=plt.cm.gray)

# Nodes are (row, col) index pairs, whereas matplotlib wants x = column and
# y = row, so the two coordinates are swapped when each segment is drawn.
for (row_a, col_a), (row_b, col_b) in G.edges():
    ax.plot((col_a, col_b), (row_a, row_b), 'r-', linewidth=0.5)

plt.show()