In [1]:
import numpy as np


In [2]:

class SelfOrganizingMap:
    """Kohonen-style self-organizing map on a rectangular ``x`` by ``y`` grid.

    Every grid node ``(i, j)`` holds a weight vector of length ``input_len``.
    Training repeatedly draws one random sample, finds its best-matching unit
    (BMU), and pulls every node toward the sample, weighted by a Gaussian of
    the node's grid distance to the BMU. Both the neighbourhood radius and the
    learning rate decay as ``exp(-iteration / num_iterations)``.

    Randomness (initial weights and sample selection) comes from the global
    ``np.random`` state, so call ``np.random.seed(...)`` first for
    reproducible results.

    Parameters
    ----------
    x, y : int
        Grid dimensions.
    input_len : int
        Length of each input vector.
    learning_rate : float, default 0.5
        Initial learning rate.
    sigma : float, optional
        Initial neighbourhood radius in grid units; defaults to
        ``max(x, y) / 2``. Must be positive.
    num_iterations : int, default 1000
        Number of single-sample updates performed by :meth:`train`.

    Raises
    ------
    ValueError
        If the (given or default) ``sigma`` is not positive.
    """

    def __init__(self, x, y, input_len, learning_rate=0.5, sigma=None, num_iterations=1000):
        self.x = x
        self.y = y
        self.input_len = input_len
        self.learning_rate = learning_rate
        # `is None` rather than truthiness, so an explicit value is never silently replaced.
        self.sigma = sigma if sigma is not None else max(x, y) / 2
        if self.sigma <= 0:
            # A zero radius makes the Gaussian 0/0 and fills the weights with NaN.
            raise ValueError(f"sigma must be positive, got {self.sigma!r}")
        self.num_iterations = num_iterations
        self.weights = np.random.rand(x, y, input_len)
        # Grid coordinates of every node, in the same row-major order as
        # self.weights.reshape(x * y, input_len).
        self.locations = np.array([[i, j] for i in range(x) for j in range(y)])

    def _find_bmu(self, input_vec):
        """Return the ``(row, col)`` of the node whose weights are closest to ``input_vec``.

        Closeness is Euclidean distance. Ties go to the first node in row-major order.
        """
        distances = np.linalg.norm(self.weights - input_vec, axis=2)
        row, col = np.unravel_index(np.argmin(distances), (self.x, self.y))
        # Plain ints keep printed mappings readable (NumPy 2 reprs np.int64(...)).
        return int(row), int(col)

    def _neighborhood_function(self, bmu_idx, iteration, input_vec=None):
        """Move every node's weights toward ``input_vec`` (in place).

        Each node moves by ``lr * exp(-d**2 / (2 * sigma**2)) * (input_vec - w)``.
        Here ``d`` is the node's grid distance to ``bmu_idx``, and both ``lr``
        and ``sigma`` are decayed by ``exp(-iteration / num_iterations)``.

        ``input_vec`` defaults to ``self.input_vec`` (the sample last drawn by
        :meth:`train`) for backward compatibility.
        """
        if input_vec is None:
            input_vec = self.input_vec
        decay = np.exp(-iteration / self.num_iterations)
        sigma = self.sigma * decay
        learning_rate = self.learning_rate * decay

        # Squared grid distance of every node to the BMU, reshaped to broadcast
        # against the (x, y, input_len) weight array.
        sq_dist = np.sum((self.locations - np.asarray(bmu_idx)) ** 2, axis=1)
        influence = np.exp(-sq_dist / (2 * sigma ** 2)).reshape(self.x, self.y, 1)
        # Each node's update depends only on its own weights, so one vectorized
        # step is equivalent to updating the nodes one at a time.
        self.weights += influence * learning_rate * (input_vec - self.weights)

    def train(self, data):
        """Run ``num_iterations`` single-sample updates on randomly drawn rows of ``data``.

        Parameters
        ----------
        data : array-like of shape (n_samples, input_len)

        Raises
        ------
        ValueError
            If ``data`` is not 2-D, has no rows, or its width differs from ``input_len``.
        """
        data = np.asarray(data)
        if data.ndim != 2 or data.shape[0] == 0:
            raise ValueError(f"data must be a non-empty 2-D array, got shape {data.shape}")
        if data.shape[1] != self.input_len:
            raise ValueError(
                f"data has {data.shape[1]} features but the map expects {self.input_len}"
            )
        for iteration in range(self.num_iterations):
            # Kept as an attribute for callers that read it after training.
            self.input_vec = data[np.random.randint(0, data.shape[0])]
            bmu_idx = self._find_bmu(self.input_vec)
            self._neighborhood_function(bmu_idx, iteration, self.input_vec)

    def map_vects(self, data):
        """Return the BMU ``(row, col)`` for each vector in ``data``, as a list."""
        return [self._find_bmu(vec) for vec in data]


In [3]:
from sklearn.datasets import load_iris
# SelfOrganizingMap draws its initial weights and training samples from the
# global NumPy RNG, so seed it to make this cell reproducible under Run All.
SEED = 42
np.random.seed(SEED)

# NOTE(review): iris features range up to ~8 while the map's weights start in
# [0, 1); consider scaling the features before training.
data = load_iris().data

# Derive the input length from the data instead of hard-coding 4.
som = SelfOrganizingMap(x=10, y=10, input_len=data.shape[1], learning_rate=0.5, num_iterations=500)
som.train(data)

mapped = som.map_vects(data)
mapped[:10]


[(2, 0), (0, 1), (0, 0), (0, 1), (2, 0), (5, 1), (0, 0), (1, 1), (0, 1), (0, 1)]
