In [None]:
import pickle

import numpy as np
import pandas as pd
from bokeh.io import curdoc, show, output_notebook
from bokeh.models import ColumnDataSource, Plot, LinearAxis, Grid
from bokeh.models.glyphs import HexTile
from bokeh.plotting import figure, output_file
from bokeh.util.hex import cartesian_to_axial
from matplotlib import cm, colorbar
from matplotlib.colors import to_hex

from src.utils.helper_som import get_minimal_distance_factors

# Configure Bokeh to render `show()` output inline in this notebook.
output_notebook()

In [None]:
# Load the sampled documents, their normalised document vectors and the trained SOM.
df_sample = pd.read_pickle('data/df_sample.pkl')
normed_array_doc_vec = np.load('data/doc_vec_norm_sample.npy')
# NOTE(review): unpickling can execute arbitrary code, so only load a SOM file you trust.
with open('data/som.p', mode='rb') as som_file:
    som = pickle.load(som_file)

# Derived parameters: the document-vector length, and the pair (x, y) that is used
# below as the SOM grid dimensions.
# NOTE(review): presumably get_minimal_distance_factors returns the two closest
# factors of its argument -- confirm against the helper's implementation.
n_docs = normed_array_doc_vec.shape[0]
len_vector = normed_array_doc_vec.shape[1]
x, y = get_minimal_distance_factors(n=5 * np.sqrt(n_docs))

# All samples mapped to the same neuron are treated as one cluster, so each neuron
# represents a cluster. To identify clusters more easily, the two-dimensional neuron
# indices on the SOM are flattened into one-dimensional indices.
winning_neurons = [som.winner(doc_vec) for doc_vec in normed_array_doc_vec]
winner_coordinates = np.transpose(np.array(winning_neurons))
# np.ravel_multi_index converts the two-dimensional coordinates to a flat index
cluster_index = np.ravel_multi_index(winner_coordinates, (x, y))

In [None]:
# Weights of the neural network.
weights = som.get_weights()
# Distance map of the weights.
umatrix = som.distance_map()
# Positions of the neurons on a Euclidean plane that reflects the chosen topology,
# returned as meshgrids: neuron (1, 4) sits at (xx[1, 4], yy[1, 4]).
xx, yy = som.get_euclidean_coordinates()

In [None]:
# Factor applied to every y coordinate before plotting; 2 / sqrt(3) * 3 / 4
# simplifies to sqrt(3) / 2.
# NOTE(review): presumably the row spacing of a hexagonal lattice whose column
# spacing is 1 -- confirm against the SOM topology in use.
HEX_ROW_SCALE = 2 / np.sqrt(3) * 3 / 4

# One tile per neuron: Cartesian centre (column, row) and a colour taken from the
# distance map.
tile_centres_column = []
tile_centres_row = []
colours = []
for i in range(weights.shape[0]):
    for j in range(weights.shape[1]):
        tile_centres_column.append(xx[(i, j)])
        tile_centres_row.append(yy[(i, j)] * HEX_ROW_SCALE)
        # cm.Blues returns an RGBA tuple of floats in [0, 1], whereas Bokeh reads
        # tuple colour channels on a 0-255 scale; a hex string is unambiguous.
        colours.append(to_hex(cm.Blues(umatrix[i, j])))

# One dot per document, placed at the Cartesian position of its winning neuron.
# The loop variable is deliberately not called `x`, which holds a SOM grid
# dimension defined in an earlier cell.
weight_x = []
weight_y = []
for doc_vec in normed_array_doc_vec:
    wx, wy = som.convert_map_to_euclidean(xy=som.winner(doc_vec))
    weight_x.append(wx)
    weight_y.append(wy * HEX_ROW_SCALE)

source = ColumnDataSource(dict(q=tile_centres_column, r=tile_centres_row, c=colours))

## Attempt 1
Using `HexTile()` and `add_glyph()`.

In [None]:
# Low-level Bokeh models API: build the plot, glyph, axes and grids by hand.
plot = Plot(title=None, plot_width=800, plot_height=800)

# Named `hex_glyph` rather than `hex` so the built-in hex() is not shadowed for
# the rest of the notebook session.
# NOTE(review): per the Bokeh docs, HexTile reads q/r as axial hex-grid coordinates,
# while `source` stores Cartesian tile centres in those columns -- confirm the
# tiles land where intended.
hex_glyph = HexTile(q='q', r='r',
                    size=.95 / np.sqrt(3),
                    fill_color='c',
                    fill_alpha=.4,
                    line_color='gray')
plot.add_glyph(source_or_glyph=source, glyph=hex_glyph)

# Axes below and to the left of the plot area.
xaxis = LinearAxis()
plot.add_layout(obj=xaxis, place='below')
yaxis = LinearAxis()
plot.add_layout(obj=yaxis, place='left')

# Grid lines that follow each axis' ticks.
plot.add_layout(obj=Grid(dimension=0, ticker=xaxis.ticker))
plot.add_layout(obj=Grid(dimension=1, ticker=yaxis.ticker))
curdoc().add_root(plot)
show(plot)

## Attempt 2
Using `figure()` and `hex_tile()`.

In [None]:
# `figure` comes from the notebook's import cell; nothing is imported mid-notebook.
# match_aspect=True ensures the hexagons are regular
# https://docs.bokeh.org/en/latest/docs/reference/plotting.html?highlight=figure#bokeh.plotting.figure
plot = figure(plot_width=800, plot_height=800, match_aspect=True)
# NOTE(review): per the Bokeh docs, hex_tile reads q/r as axial hex-grid coordinates,
# while the tile centres passed here are Cartesian -- confirm the tiles land where
# intended.
plot.hex_tile(q=tile_centres_column, r=tile_centres_row,
              size=(.95 / np.sqrt(3)),
              color=colours,
              fill_alpha=.4,
              line_color='black')

show(plot)

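The hexes look as if they need to be flipped vertically to match the dots (see the overlay plot in the next cell).

The hexes are also squashed together and overlap one another.

A likely cause of both symptoms: `hex_tile` interprets `q` and `r` as *axial* hex-grid coordinates (with the `r` axis pointing downwards), not as Cartesian tile centres, whereas the dots are plotted in Cartesian coordinates.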

In [None]:
# hex_tile places tiles by axial hex-grid coordinates (q, r), with r pointing
# downwards, while the winner dots are drawn in Cartesian coordinates. Converting
# the Cartesian tile centres to axial coordinates makes tiles and dots line up.
# NOTE(review): assumes the centres lie on a pointy-top hexagonal lattice with unit
# column spacing (as the row scaling in the preparation cell suggests) -- confirm
# for the SOM topology in use.
HEX_SIZE = 1 / np.sqrt(3)  # centre-to-corner radius of a pointy-top hexagon of width 1

tile_x = np.asarray(tile_centres_column, dtype=float)
tile_y = np.asarray(tile_centres_row, dtype=float)
# Anchor the lattice on the first tile so that every centre maps onto integer (q, r),
# even if the whole grid is offset by half a cell.
origin_x, origin_y = tile_x[0], tile_y[0]
tile_q, tile_r = cartesian_to_axial(tile_x - origin_x, tile_y - origin_y,
                                    size=HEX_SIZE, orientation='pointytop')

# match_aspect=True ensures the hexagons are regular (matches x-axis with y-axis)
# https://docs.bokeh.org/en/latest/docs/reference/plotting.html?highlight=figure#bokeh.plotting.figure
plot = figure(plot_width=800, plot_height=800, match_aspect=True)
# `scale` shrinks each tile around its centre (leaving a small gap) without moving
# the centres; shrinking `size` instead would also shrink the lattice spacing.
plot.hex_tile(q=tile_q, r=tile_r,
              size=HEX_SIZE,
              scale=.95,
              color=colours,
              fill_alpha=.4,
              line_color='black')
# The dots get the same origin shift as the tiles.
plot.dot(x=np.asarray(weight_x, dtype=float) - origin_x,
         y=np.asarray(weight_y, dtype=float) - origin_y,
         fill_color='black',
         size=12)

show(plot)

## Attempt within loops
Calling `hex_tile()` once per neuron.

In [None]:
# One hex_tile glyph per neuron, built directly from the SOM coordinates.
# hex_tile places tiles by axial hex-grid coordinates (q, r), so each Cartesian
# neuron centre is converted first.
# NOTE(review): assumes a pointy-top hexagonal lattice with unit column spacing --
# confirm for the SOM topology in use.
HEX_SIZE = 1 / np.sqrt(3)  # centre-to-corner radius of a pointy-top hexagon of width 1
HEX_ROW_SCALE = 2 / np.sqrt(3) * 3 / 4  # simplifies to sqrt(3) / 2

plot = figure(plot_width=800, plot_height=800, match_aspect=True)
for i in range(weights.shape[0]):
    for j in range(weights.shape[1]):
        # Centre relative to neuron (0, 0), so every tile maps onto integer (q, r).
        centre_x = np.array([xx[(i, j)] - xx[(0, 0)]], dtype=float)
        centre_y = np.array([(yy[(i, j)] - yy[(0, 0)]) * HEX_ROW_SCALE], dtype=float)
        tile_q, tile_r = cartesian_to_axial(centre_x, centre_y,
                                            size=HEX_SIZE, orientation='pointytop')
        plot.hex_tile(q=tile_q, r=tile_r,
                      size=HEX_SIZE,
                      scale=.95,  # shrink each tile around its centre to leave a gap
                      # cm.Blues yields RGBA floats in [0, 1]; Bokeh reads tuple
                      # channels on a 0-255 scale, so pass a hex string instead.
                      color=to_hex(cm.Blues(umatrix[i, j])),
                      fill_alpha=.4,
                      line_color='black')
show(plot)