# In [2]:
import numpy as np
from keras.datasets import mnist
import matplotlib.pyplot as plt

def normalize(data):
    """Map raw 0-255 pixel intensities onto the unit interval.

    Args:
        data: Array (or scalar) of pixel intensities.

    Returns:
        ``data`` divided by 255.0 (a new float array for NumPy input).
    """
    max_intensity = 255.0
    scaled = data / max_intensity
    return scaled

def initialize_weights(input_dim, output_dim, rng=None):
    """Create a random weight matrix for a self-organising map.

    Args:
        input_dim: Length of each input vector (784 for flattened 28x28 MNIST).
        output_dim: Number of map units (neurons).
        rng: Optional ``numpy.random.Generator`` for reproducible
            initialisation; when omitted, NumPy's global random state is used.

    Returns:
        Array of shape ``(output_dim, input_dim)`` with values drawn uniformly
        from ``[0, 1)``: one weight vector (row) per unit, so iterating over
        it yields vectors directly comparable with a single input sample.
    """
    # One row per unit. An (input_dim, output_dim) layout makes every "unit"
    # a length-output_dim vector that cannot be compared with a
    # length-input_dim sample ("operands could not be broadcast together with
    # shapes (784,) (100,)").
    if rng is None:
        return np.random.rand(output_dim, input_dim)
    return rng.random((output_dim, input_dim))

def euclidean_distance(x, y):
    """Return the straight-line (L2) distance between ``x`` and ``y``.

    Args:
        x: First array-like operand.
        y: Second operand; must support ``x - y``.

    Returns:
        ``numpy.linalg.norm`` of the difference. For inputs with more than one
        axis this is the Frobenius norm of the whole difference array.
    """
    difference = x - y
    return np.linalg.norm(difference)

def find_best_matching_unit(data_point, weights):
    """Return the index of the map unit whose weights are closest to *data_point*.

    Args:
        data_point: One input sample, a 1-D array of length ``n_features``.
        weights: Weight matrix of shape ``(n_units, n_features)``, one row per
            map unit.

    Returns:
        Integer row index of the best-matching unit (the first one on ties,
        as with ``np.argmin``).

    Raises:
        ValueError: If *data_point* is not 1-D, if *weights* does not have
            shape ``(n_units, n_features)``, or if *weights* has no rows.
    """
    data_point = np.asarray(data_point)
    weights = np.asarray(weights)
    if data_point.ndim != 1:
        raise ValueError(f"data_point must be 1-D, got shape {data_point.shape}")
    if weights.ndim != 2 or weights.shape[1] != data_point.shape[0]:
        raise ValueError(
            f"weights must have shape (n_units, {data_point.shape[0]}), "
            f"got {weights.shape}"
        )
    # A single vectorised pass over all units. This runs once per sample per
    # epoch, so a Python-level loop here dominates training time.
    distances = np.linalg.norm(weights - data_point, axis=1)
    return int(np.argmin(distances))

def _grid_coordinates(n_units, grid_shape=None):
    """Return a float array of (row, col) lattice positions, one per map unit."""
    if grid_shape is None:
        side = int(round(n_units ** 0.5))
        # A perfect-square unit count is laid out as a square map; unit i sits
        # at row i // cols, col i % cols (the row-major order plt.subplot uses).
        # Any other count falls back to a single row.
        grid_shape = (side, side) if side * side == n_units else (1, n_units)
    rows, cols = grid_shape
    if rows * cols != n_units:
        raise ValueError(f"grid_shape {grid_shape} does not hold {n_units} units")
    row_idx, col_idx = np.divmod(np.arange(n_units), cols)
    return np.column_stack((row_idx, col_idx)).astype(float)


def update_weights(data_point, weights, bmu_index, learning_rate, sigma, grid_shape=None):
    """Pull every unit's weights towards *data_point*, in place.

    Each unit ``i`` moves by ``learning_rate * h_i * (data_point - w_i)``,
    where ``h_i = exp(-d_i**2 / (2 * sigma**2))`` and ``d_i`` is the distance
    between unit ``i`` and the best-matching unit on the map grid.

    Args:
        data_point: One input sample, a 1-D array of length ``n_features``.
        weights: Float NumPy array of shape ``(n_units, n_features)``;
            modified in place.
        bmu_index: Row index of the best-matching unit.
        learning_rate: Step size for this update.
        sigma: Neighbourhood radius, in grid-cell units.
        grid_shape: Optional ``(rows, cols)`` layout of the units. Defaults to
            a square grid when ``n_units`` is a perfect square, else one row.

    Raises:
        ValueError: If *grid_shape* does not contain exactly ``n_units`` cells.
    """
    n_units = len(weights)
    if n_units == 0:
        return
    # Measure the neighbourhood between unit *positions on the map*, not
    # between weight vectors. In weight space, "neighbours" are just units that
    # happen to look alike, so no topological ordering can emerge. It also
    # breaks because the BMU's own weights move partway through an in-place
    # per-unit loop.
    coords = _grid_coordinates(n_units, grid_shape)
    sq_grid_dist = np.sum((coords - coords[bmu_index]) ** 2, axis=1)
    influence = np.exp(-sq_grid_dist / (2 * sigma ** 2))
    # Broadcast the per-unit influence over the feature axis and update all
    # units at once; `+=` writes into the caller's array.
    weights += learning_rate * influence[:, np.newaxis] * (data_point - weights)

def train_som(data, input_dim, output_dim, epochs, learning_rate_initial, sigma_initial):
    """Fit a self-organising map to *data* with online (per-sample) updates.

    At the start of every epoch, the learning rate and neighbourhood radius
    are both scaled by ``exp(-epoch / epochs)``. Each sample then pulls the
    weights towards itself around its best-matching unit.

    Args:
        data: Iterable of input vectors.
        input_dim: Length of each input vector.
        output_dim: Number of map units.
        epochs: Number of full passes over *data*.
        learning_rate_initial: Learning rate at epoch 0.
        sigma_initial: Neighbourhood radius at epoch 0.

    Returns:
        The weights created by ``initialize_weights``, after in-place
        modification by ``update_weights``.
    """
    som_weights = initialize_weights(input_dim, output_dim)
    for epoch_idx in range(epochs):
        decay = np.exp(-epoch_idx / epochs)
        current_lr = learning_rate_initial * decay
        current_sigma = sigma_initial * decay
        for sample in data:
            winner = find_best_matching_unit(sample, som_weights)
            update_weights(sample, som_weights, winner, current_lr, current_sigma)
    return som_weights

def visualize(weights, image_shape=(28, 28)):
    """Plot each unit's weight vector as a grayscale image on a grid of subplots.

    Args:
        weights: Array with one row per map unit; each row must contain
            ``image_shape[0] * image_shape[1]`` values.
        image_shape: ``(height, width)`` that each weight vector is reshaped
            to. Defaults to 28x28.
    """
    n_units = len(weights)
    # Pick a near-square grid large enough for every unit. 100 units still
    # gives a 10x10 layout; larger maps no longer overflow a fixed grid.
    n_cols = max(1, int(np.ceil(np.sqrt(n_units))))
    n_rows = max(1, int(np.ceil(n_units / n_cols)))
    plt.figure(figsize=(10, 10))
    for i, weight in enumerate(weights):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(weight.reshape(image_shape), cmap='gray')
        plt.axis('off')
    plt.show()

def main():
    """Train a 10x10 self-organising map on MNIST training images and plot it."""
    # Load and preprocess MNIST (only the training images are used).
    (x_train, _), (_, _) = mnist.load_data()
    x_train = normalize(x_train.reshape(len(x_train), -1))

    # Parameters
    grid_side = 10                      # the map is grid_side x grid_side units
    input_dim = x_train.shape[1]        # 784 = 28 * 28 pixels per image
    output_dim = grid_side * grid_side
    epochs = 100
    learning_rate_initial = 0.1
    # Start the neighbourhood at half the map's side length. output_dim / 2
    # (= 50) is far wider than a 10x10 map, so nearly every unit would get
    # close to the full update at every step and the map would not specialise.
    sigma_initial = grid_side / 2

    # Train the SOM.
    # NOTE: 60,000 samples x 100 epochs is a long run; pass a subsample such
    # as x_train[:5000] for quick experiments.
    weights = train_som(x_train, input_dim, output_dim, epochs, learning_rate_initial, sigma_initial)

    # Visualize the learned weights
    visualize(weights)

# Entry point: train and plot only when executed directly, not when imported.
if __name__ == "__main__":
    main()


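# Output of the original run:
# ValueError: operands could not be broadcast together with shapes (784,) (100,)
# (initialize_weights built the weight matrix as (input_dim, output_dim), so each
# "unit" was a 100-long row that could not be compared with a 784-pixel sample.)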