In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd

In [None]:
class Node:
    """A single SOM cell: a weight vector pinned to a grid position."""

    def __init__(self, num_weights, x, y):
        """Create a node at grid position (x, y) with random weights.

        Args:
            num_weights: Length of the weight vector.
            x: Grid row index of this node.
            y: Grid column index of this node.
        """
        self.x, self.y = x, y
        # Uniform random start in [0, 1), one entry per weight.
        self.weights = np.random.rand(num_weights)

    def calculate_distance(self, input_vector):
        """Return the squared Euclidean distance between the weights and ``input_vector``."""
        delta = self.weights - input_vector
        return np.sum(delta ** 2)

    def adjust_weights(self, target, learning_rate, influence):
        """Move the weights toward ``target`` in place.

        The step is the gap to ``target`` scaled by ``learning_rate * influence``.
        """
        step_scale = learning_rate * influence
        self.weights += step_scale * (target - self.weights)


In [None]:
class SOM:
    """Rectangular self-organising map trained one random sample per call.

    The map is an ``x`` by ``y`` grid of nodes. Each call to :meth:`epoch`
    picks one random sample, finds its best matching unit (BMU) and pulls the
    nodes inside the current neighbourhood radius toward that sample. Radius
    and learning rate both decay exponentially with the iteration count.

    Attributes:
        som: Row-major list of nodes; grid cell (i, j) is ``som[i * y + j]``.
        iteration_count: Number of training iterations performed so far.
        done: True once ``num_iterations`` iterations have been performed.
        colorMapping: List of ``[id, (row, col), rgb]`` entries filled in by
            :meth:`plot_grid` and consumed by :meth:`plot_map`.
    """

    def __init__(self, x, y, num_iterations, size, learning_rate):
        """Build the grid and the decay schedule.

        Args:
            x: Number of grid rows.
            y: Number of grid columns.
            num_iterations: Total number of training iterations.
            size: Length of each input vector (and of each node's weights).
            learning_rate: Learning rate used at iteration 0.
        """
        self.x = x
        self.y = y
        self.num_iterations = num_iterations
        self.size_of_input_vector = size
        self.initial_learning_rate = learning_rate

        # Row-major order: plot_grid relies on index i * y + j for cell (i, j).
        self.som = [
            Node(self.size_of_input_vector, i, j)
            for i in range(self.x)
            for j in range(self.y)
        ]

        self.map_radius = max(self.x, self.y) / 2
        if self.map_radius > 1:
            self.time_constant = num_iterations / np.log(self.map_radius)
        else:
            # log(radius) is zero or negative here, which would divide by zero
            # or make the radius grow over time. Keep the radius constant.
            self.time_constant = float('inf')
        self.iteration_count = 0
        self.done = False
        # Initialised here so plot_map is safe to call before plot_grid.
        self.colorMapping = []

    def epoch(self, data):
        """Run a single training iteration on one randomly chosen sample.

        Args:
            data: Indexable collection of input vectors, each of length
                ``size_of_input_vector``.

        Returns:
            False if the first sample has the wrong length (nothing is
            trained), True otherwise.

        Raises:
            ValueError: If ``data`` is empty.
        """
        if len(data) == 0:
            raise ValueError("data must contain at least one sample")
        if len(data[0]) != self.size_of_input_vector:
            return False

        if self.done:
            return True

        if self.iteration_count >= self.num_iterations:
            # Covers num_iterations == 0 and externally modified counters.
            self.done = True
            return True

        sample = data[np.random.randint(0, len(data))]

        # Finding BMU
        bmu = self.find_best_matching_node(sample)

        # Exponential decay of radius and learning rate
        self.neighbourhood_radius = self.map_radius * np.exp(-self.iteration_count / self.time_constant)
        self.learning_rate = self.initial_learning_rate * np.exp(-self.iteration_count / self.num_iterations)

        # Adjusting weights of every node strictly inside the neighbourhood
        radius_sq = self.neighbourhood_radius ** 2
        for node in self.som:
            dist_to_node_sq = (node.x - bmu.x) ** 2 + (node.y - bmu.y) ** 2
            if dist_to_node_sq < radius_sq:
                influence = np.exp(-dist_to_node_sq / (2 * radius_sq))
                node.adjust_weights(sample, self.learning_rate, influence)

        self.iteration_count += 1
        # Flag completion as soon as the last iteration has run, so a
        # `while not som.done` loop does not need one extra no-op call.
        if self.iteration_count >= self.num_iterations:
            self.done = True

        return True

    def find_best_matching_node(self, vec):
        """Return the node whose weights are closest to ``vec``.

        Ties go to the earliest node in ``self.som``. Returns None when no
        node has a distance below infinity (e.g. an empty map).
        """
        lowest_distance = float('inf')
        winner = None

        for node in self.som:
            dist = node.calculate_distance(vec)
            if dist < lowest_distance:
                lowest_distance = dist
                winner = node

        return winner

    # Plotting SOM grid to show clustering
    def plot_grid(self, data, uniquesIDs):
        """Render the grid as an RGB image and rebuild ``colorMapping``.

        Args:
            data: Input vectors; ``data[k]`` is labelled by ``uniquesIDs[k]``.
            uniquesIDs: One identifier per row of ``data``, in the same order.
        """
        # Per-channel scaling (r, g, b) for more spread out cluster colors
        channel_scale = (1.0, 0.9, 0.6)

        grid = np.zeros((self.x, self.y, 3))
        for i in range(self.x):
            for j in range(self.y):
                rgb = np.zeros(3)
                # Weight k feeds channel k % 3.
                for k, weight in enumerate(self.som[i * self.y + j].weights):
                    rgb[k % 3] += weight * channel_scale[k % 3]
                total = rgb.sum()
                if total != 0:  # leave all-zero cells black instead of dividing by zero
                    rgb /= total
                grid[i, j] = rgb

        # Creating mapping between ids, cluster coordinates and colors
        self.colorMapping = []
        for index, sample in enumerate(data):
            bmu = self.find_best_matching_node(sample)  # single search per sample
            self.colorMapping.append(
                [uniquesIDs[index], (bmu.x, bmu.y), grid[bmu.x, bmu.y]]
            )

        # Plot the grid
        plt.imshow(grid)
        plt.show()

    # Visualising dataset on world map
    def plot_map(self):
        """Colour each mapped country on a world map with its cluster colour.

        Uses the entries produced by the most recent :meth:`plot_grid` call;
        ids that do not match a ``name`` in the world map are skipped.
        """
        # NOTE(review): `gpd.datasets` was deprecated in GeoPandas 0.14 and
        # removed in 1.0 -- confirm the installed version still provides it.
        worldmap = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
        _, ax = plt.subplots()
        worldmap.plot(ax=ax, facecolor='white', edgecolor='black')
        countries = set(worldmap["name"].tolist())

        for entry in self.colorMapping:
            country, color = entry[0], entry[2]
            if country in countries:
                worldmap[worldmap["name"] == country].plot(color=color, ax=ax)

        plt.show()




In [None]:
# --- Configuration ---------------------------------------------------------
DATA_PATH = "/content/Q1_countrydata.csv"
GRID_ROWS, GRID_COLS = 10, 10
NUM_ITERATIONS = 5000
INITIAL_LEARNING_RATE = 0.05
PLOT_EVERY = 100  # redraw the grid every this many training iterations
SEED = 42

# Country names in the CSV -> names used by the geopandas world map, so that
# the clusters can be drawn on the map.
COUNTRY_NAME_FIXES = {
    'United States': 'United States of America',
    'Russian Federation': 'Russia',
    'Congo (Kinshasa)': 'Dem. Rep. Congo',
    'Congo (Brazzaville)': 'Congo',
    'Central African Republic': 'Central African Rep.',
    'South Sudan': 'S. Sudan',
    'Taiwan*': 'Taiwan',
    'Korea, South': 'South Korea',
    'Eswatini': 'eSwatini',
    'Equatorial Guinea': 'Eq. Guinea',
}

# Node weights and sample selection both draw from NumPy's global RNG;
# seeding it makes a Restart & Run All reproduce the same map.
np.random.seed(SEED)

# --- Load and prepare the data ---------------------------------------------
data = pd.read_csv(DATA_PATH)
data['Country_Region'] = data['Country_Region'].replace(COUNTRY_NAME_FIXES)

# One row per country; groupby sorts the rows by country name.
aggregated_data = data.groupby('Country_Region').agg({
    'Confirmed': 'sum',
    'Deaths': 'sum',
    'Recovered': 'sum'
}).reset_index()

# Min-max scale every feature to [0, 1].
normalized_data = aggregated_data.copy()
cols_to_normalize = ['Confirmed', 'Deaths', 'Recovered']
for col in cols_to_normalize:
    col_min = aggregated_data[col].min()
    col_range = aggregated_data[col].max() - col_min
    if col_range == 0:
        # A constant column would otherwise become NaN (0 / 0).
        normalized_data[col] = 0.0
    else:
        normalized_data[col] = (aggregated_data[col] - col_min) / col_range

# Extract numeric columns and convert to NumPy array
numeric_data = normalized_data[cols_to_normalize].values

# Labels must be in the same row order as numeric_data. Taking them from the
# raw frame (`data['Country_Region'].unique()`) gives order of appearance,
# which no longer matches the sorted groupby rows once names are remapped.
uniquesIDs = aggregated_data['Country_Region'].to_numpy()
assert len(uniquesIDs) == len(numeric_data)

# --- Train the SOM -----------------------------------------------------------
som = SOM(GRID_ROWS, GRID_COLS, NUM_ITERATIONS, len(cols_to_normalize), INITIAL_LEARNING_RATE)
som.plot_grid(numeric_data, uniquesIDs)  # untrained map, for comparison

while not som.done:
    iterations_before = som.iteration_count
    if not som.epoch(numeric_data):
        # epoch() returns False without ever setting `done` when the input
        # width is wrong; without this check the loop would never terminate.
        raise ValueError("Input vectors do not match the SOM's input size")
    progressed = som.iteration_count != iterations_before
    # Skip the redraw at the final iteration: the plot below covers it.
    if progressed and som.iteration_count % PLOT_EVERY == 0 and som.iteration_count < NUM_ITERATIONS:
        som.plot_grid(numeric_data, uniquesIDs)

# --- Results -----------------------------------------------------------------
som.plot_grid(numeric_data, uniquesIDs)  # final map; also refreshes the colour mapping
som.plot_map()

Reference: http://www.ai-junkie.com/ann/som/som4.html