In [1]:
import numpy as np
import math

In [2]:
# --- SOM hyperparameters -------------------------------------------------
map_size=(5,5)          # (rows, cols) of the neuron grid
input_dim=2             # length of every input sample / neuron weight vector
learning_rate=0.5       # starting learning rate; update_weights decays it linearly toward 0
num_iteration=1000      # training steps; one randomly drawn sample per step
# Starting neighbourhood radius (in grid cells): half the larger map side,
# also decayed linearly toward 0 by update_weights.
sigma=max(map_size)/2

In [3]:
def initialize_weights(map_size,input_dim):
    """Create the initial weight grid for the self-organising map.

    Draws from NumPy's legacy global RNG, so results depend on the most
    recent ``np.random.seed`` call.

    Args:
        map_size: sequence whose first two entries are the grid's
            (rows, cols); any further entries are ignored.
        input_dim: length of each neuron's weight vector.

    Returns:
        np.ndarray of shape ``(rows, cols, input_dim)`` with values drawn
        uniformly from [0, 1).
    """
    grid_shape=(map_size[0],map_size[1],input_dim)
    return np.random.random_sample(grid_shape)

In [4]:
def euclidean_distance(vec1,vec2):
    """Euclidean (L2) distance between two vectors.

    Args:
        vec1, vec2: arrays supporting element-wise subtraction.

    Returns:
        The ``np.linalg.norm`` of ``vec1 - vec2``.
    """
    difference=vec1-vec2
    return np.linalg.norm(difference)

In [5]:
def find_bmu(weights,input_vector):
    """Return the grid index of the best-matching unit (BMU).

    The BMU is the neuron whose weight vector is closest (Euclidean
    distance) to ``input_vector``. Cells are scanned in row-major order,
    and only a strictly smaller distance replaces the current best, so ties
    resolve to the first cell visited. If no distance compares below
    infinity, ``(0, 0)`` is returned.

    Args:
        weights: array indexed as ``weights[i, j]`` -> weight vector of
            neuron ``(i, j)``.
        input_vector: sample to match against the neurons.

    Returns:
        tuple[int, int]: ``(row, col)`` of the BMU.
    """
    n_rows,n_cols=weights.shape[0],weights.shape[1]
    best_dist=float('inf')
    best_cell=(0,0)
    for cell in np.ndindex(n_rows,n_cols):
        candidate=euclidean_distance(weights[cell],input_vector)
        if candidate<best_dist:
            best_cell,best_dist=cell,candidate
    return best_cell

In [6]:
def update_weights(weights,input_vector,bmu_idx,iteration,num_iterations,learning_rate,sigma):
    """Pull every neuron's weight vector toward ``input_vector`` (in place).

    Both the learning rate and the neighbourhood radius decay linearly with
    training progress::

        lr  = learning_rate * (1 - iteration / num_iterations)
        sig = sigma         * (1 - iteration / num_iterations)

    Each neuron ``(i, j)`` moves by ``lr * h * (input_vector - w_ij)``, where
    ``h = exp(-d^2 / (2 * sig^2))`` and ``d`` is the grid (not weight-space)
    distance between ``(i, j)`` and ``bmu_idx``.

    When ``sig <= 0`` the Gaussian collapses to its limit, winner-take-all:
    the BMU gets ``h = 1`` and every other neuron gets ``h = 0``. The
    previous implementation used ``h = 0`` everywhere in that case, which
    also froze the BMU.

    Args:
        weights: float array of shape ``(rows, cols, dim)``; modified in place.
        input_vector: sample of length ``dim``.
        bmu_idx: ``(row, col)`` of the best-matching unit.
        iteration: current step, expected in ``[0, num_iterations)``.
        num_iterations: total number of training steps (must be non-zero).
        learning_rate: initial learning rate.
        sigma: initial neighbourhood radius, in grid cells.

    Returns:
        None; ``weights`` is updated in place.
    """
    decay=1-iteration/num_iterations
    lr=learning_rate*decay
    sig=sigma*decay

    # Squared grid distance from every neuron to the BMU, computed for the
    # whole map at once instead of one Python-level norm per neuron.
    grid_r,grid_c=np.indices(weights.shape[:2])
    sq_dist=(grid_r-bmu_idx[0])**2+(grid_c-bmu_idx[1])**2

    if sig>0:
        influence=np.exp(-sq_dist/(2*sig**2))
    else:
        # Zero-width neighbourhood: only the BMU itself learns.
        influence=(sq_dist==0).astype(float)

    # Broadcast the (rows, cols) influence over the weight-vector axis.
    weights+=lr*influence[...,np.newaxis]*(input_vector-weights)

In [8]:
# Seed NumPy's global RNG so the toy data, the initial weights and the
# training loop's sample order are all reproducible on Restart & Run All.
np.random.seed(42)
# 100 toy samples drawn uniformly from [0, 1) in each of input_dim dimensions.
data=np.random.random_sample((100,input_dim))

In [9]:
# Random starting weight vector for every neuron; drawn from the global RNG
# right after the toy data, so it is reproducible given the seed.
weights=initialize_weights(map_size=map_size,input_dim=input_dim)

In [11]:
# Online training: each step draws one random sample, finds its BMU and
# pulls the BMU's grid neighbourhood toward it.
# NOTE: `weights` is mutated in place, so re-running only this cell keeps
# training the already-trained map; re-run from the seeding cell to
# reproduce the recorded result.
for iteration in range(num_iteration):
    sample_row=np.random.randint(len(data))
    input_vector=data[sample_row]
    bmu_idx=find_bmu(weights,input_vector)
    update_weights(
        weights,
        input_vector,
        bmu_idx,
        iteration,
        num_iteration,
        learning_rate,
        sigma,
    )

In [12]:
# Header and trained weight grid, one per line (same text as two prints).
print("Training Complete. Final weight matrix:", weights, sep="\n")

Training Complete. Final weight matrix:
[[[0.06911067 0.8891252 ]
  [0.08891785 0.6973514 ]
  [0.15167536 0.511875  ]
  [0.13601851 0.30787052]
  [0.15263778 0.1766668 ]]

 [[0.26397046 0.89202459]
  [0.26662598 0.6748733 ]
  [0.27536604 0.49428127]
  [0.29989858 0.29867383]
  [0.30364158 0.12295931]]

 [[0.51165396 0.80459907]
  [0.54676685 0.69427739]
  [0.49072331 0.47795606]
  [0.46850006 0.28046495]
  [0.44519142 0.18170196]]

 [[0.78847623 0.80778239]
  [0.72897406 0.70099813]
  [0.69969875 0.53688319]
  [0.66259183 0.3214137 ]
  [0.64563024 0.1303303 ]]

 [[0.91881107 0.81169544]
  [0.86326146 0.66024797]
  [0.836292   0.53552777]
  [0.84927399 0.33253933]
  [0.86946388 0.19965146]]]


In [16]:
# `import  as plt` was a SyntaxError (module name missing), so this cell
# could not run on a fresh kernel. The recorded output (3.8.4) is the
# matplotlib version. Bind `plt` to pyplot as is conventional, and read the
# version from the top-level package.
import matplotlib
import matplotlib.pyplot as plt

print(matplotlib.__version__)


3.8.4
