In [39]:
import numpy as np
import matplotlib.pyplot as plt
from math import ceil

In [6]:
# Read the comma-separated integer attribute values and the animal names from disk.
# The names are loaded as plain strings, exactly as they are stored in the file.
data = np.loadtxt('data/animals.dat', dtype=int, delimiter=',')
animalNames = np.loadtxt('data/animalnames.txt', dtype=str)

In [19]:
# Arrange the loaded values as a 32 x 84 matrix and display the resulting shape.
# Presumably one row per animal (32 names appear in the prediction output) -- TODO confirm.
data = data.reshape(32, 84)
data.shape

(32, 84)

In [16]:
# Sanity check of the weight-matrix shape. The SOM constructor requires the
# training data, so it has to be passed here.
# NOTE(review): the SOM class is defined in a later cell; that cell must have
# been run before this one, otherwise this raises NameError on a fresh kernel.
np.shape(SOM(data).weights)

(100, 84)

In [133]:
class SOM():
    """One-dimensional self-organising map (a chain of nodes).

    Each row of ``weights`` is the weight vector of one node. Nodes are
    addressed by their row index, and a neighbourhood is a contiguous range
    of row indices around the winning node.

    Attributes:
        weights: Array of shape ``(n_nodes, n_features)`` initialised with
            uniform random values in ``[0, 1)``.
        data: Default training set used by ``train`` when none is passed.
    """

    def __init__(self, data, n_nodes=100, n_features=None, seed=None):
        """Create the map with random initial weights.

        Args:
            data: Default training set, one sample per row.
            n_nodes: Number of nodes in the map (default 100).
            n_features: Length of every weight vector. When omitted it is
                taken from ``data`` if that is two-dimensional; otherwise
                the previous hard-coded value 84 is used.
            seed: Optional seed for reproducible initial weights. When
                omitted, the global ``np.random`` state is used.
        """
        if n_features is None:
            data_shape = np.shape(data)
            n_features = data_shape[1] if len(data_shape) == 2 else 84

        # Both the ``np.random`` module and a ``Generator`` provide
        # ``random(size)``, so one call covers the seeded and unseeded case.
        rng = np.random if seed is None else np.random.default_rng(seed)
        self.weights = rng.random((n_nodes, n_features))
        self.data = data

    def train(self, eta, iterations, data=None, starting_fraction=None):
        """Train the map for a number of epochs.

        The neighbourhood starts at ``ceil(n_nodes / starting_fraction)``
        nodes and shrinks by one node per epoch until it reaches 1.

        Args:
            eta: Learning rate applied to every weight update.
            iterations: Number of epochs (full passes over ``data``).
            data: Training samples, one per row. Defaults to the data given
                to the constructor.
            starting_fraction: Divisor that sets the initial neighbourhood
                size. Defaults to 4.
        """
        # ``is None`` rather than ``== None``: comparing an ndarray with ``==``
        # is element-wise, and using the result in ``if`` raises ValueError.
        if data is None:
            data = self.data
        if starting_fraction is None:
            starting_fraction = 4

        area_size = max(ceil(self.weights.shape[0] / starting_fraction), 1)
        for _ in range(iterations):
            self.iterate(data, area_size, eta)
            area_size = max(area_size - 1, 1)

    def iterate(self, data, area_size, eta):
        """Run one epoch: present every sample once and update the map.

        Args:
            data: Training samples, one per row.
            area_size: Neighbourhood size (in nodes) used for this epoch.
            eta: Learning rate.
        """
        for datapoint in np.asarray(data):
            winner = self.find_winner(datapoint)
            neighbourhood = self.find_neighbourhood(winner, area_size)
            self.update_weights(neighbourhood, eta, datapoint)

    def update_weights(self, neighbourhood, eta, datapoint):
        """Move the weights of all nodes in the neighbourhood towards a sample.

        Args:
            neighbourhood: ``(lower, upper)`` half-open range of node indices.
            eta: Learning rate; each weight moves this fraction of the way
                towards ``datapoint``.
            datapoint: The sample the nodes are pulled towards.
        """
        lower, upper = neighbourhood
        region = self.weights[lower:upper]
        # Assign (rather than ``+=``) so the values are cast to the dtype of
        # ``weights`` exactly as a per-row assignment would.
        self.weights[lower:upper] = region + eta * (datapoint - region)

    def distance(self, x, y):
        """Return the squared Euclidean distance between two vectors."""
        sub_vec = np.subtract(x, y)
        return np.dot(sub_vec.T, sub_vec)

    def find_winner(self, datapoint):
        """Return the index of the node whose weights are closest to a sample.

        Closeness is the squared Euclidean distance. On ties the lowest
        index wins.
        """
        diffs = self.weights - datapoint
        squared_distances = (diffs * diffs).sum(axis=1)
        return int(np.argmin(squared_distances))

    def find_neighbourhood(self, winner, area_size):
        """Return the index range of the neighbourhood around the winner.

        The neighbourhood is a window of ``area_size`` consecutive nodes that
        always contains the winner. For odd sizes it is centred on the
        winner; for even sizes the extra node is on the lower-index side.
        The window is truncated (not shifted) at both ends of the map.

        Args:
            winner: Index of the winning node.
            area_size: Number of nodes in the neighbourhood.

        Returns:
            ``(lower, upper)``: a half-open range of node indices.
        """
        n_nodes = self.weights.shape[0]
        start = winner - area_size // 2
        # ``upper`` is derived from the unclipped start so that clipping at the
        # edges shortens the window instead of moving it away from the winner.
        lower = max(start, 0)
        upper = min(start + area_size, n_nodes)
        return (lower, upper)

    def predict(self, data, labels=None):
        """Map every sample to its winning node.

        Args:
            data: Samples, one per row.
            labels: One label per row of ``data``. Defaults to the
                module-level ``animalNames``. Raises ``IndexError`` if it
                has fewer entries than ``data`` has rows.

        Returns:
            Object array of shape ``(n_samples, 2)`` whose rows are
            ``[winner_index, label]``, sorted by winner index. Samples that
            share a winner keep their input order.
        """
        if labels is None:
            labels = animalNames

        rows = [
            [self.find_winner(datapoint), labels[index]]
            for index, datapoint in enumerate(np.asarray(data))
        ]
        # ``reshape`` keeps the result two-dimensional even for empty input.
        predictions = np.array(rows, dtype=object).reshape(-1, 2)
        # Stable sort so tied winners are reported in a deterministic order.
        order = predictions[:, 0].argsort(kind='stable')
        return predictions[order]
        
        

In [134]:
# Build a map that uses the loaded data matrix as its default training set.
x = SOM(data=data)

In [135]:
# Train on the map's default data set: learning rate 0.1 for 100 epochs.
x.train(eta=0.1, iterations=100)

In [136]:
# Map every row of the data matrix to its winning node.
pred = x.predict(data=data)

In [137]:
# Display the [winner index, animal name] pairs, ordered by winner index.
pred

array([[0, "'moskito'"],
       [1, "'dragonfly'"],
       [4, "'grasshopper'"],
       [5, "'beetle'"],
       [7, "'butterfly'"],
       [11, "'housefly'"],
       [15, "'spider'"],
       [19, "'pelican'"],
       [19, "'duck'"],
       [22, "'penguin'"],
       [23, "'ostrich'"],
       [26, "'frog'"],
       [28, "'seaturtle'"],
       [30, "'crocodile'"],
       [33, "'walrus'"],
       [36, "'bear'"],
       [37, "'dog'"],
       [39, "'hyena'"],
       [42, "'skunk'"],
       [44, "'ape'"],
       [47, "'lion'"],
       [48, "'cat'"],
       [51, "'elephant'"],
       [52, "'bat'"],
       [55, "'rat'"],
       [59, "'horse'"],
       [62, "'pig'"],
       [64, "'camel'"],
       [65, "'giraffe'"],
       [70, "'kangaroo'"],
       [73, "'antelop'"],
       [77, "'rabbit'"]], dtype=object)