In [132]:
import inspect
import os

import folium
import h3
import numpy as np
import pandas as pd
from colour import Color
from sklearn.cluster import AgglomerativeClustering

# Connection URI for the Postgres instance this notebook reads from.
# The password is deliberately NOT embedded here: libpq-based drivers (e.g. psycopg2,
# SQLAlchemy's default for postgresql://) fall back to the PGPASSWORD environment
# variable or ~/.pgpass when the URI omits it. Set DB_URI to override the whole URI.
db_uri = os.environ.get('DB_URI', 'postgresql://postgres@host.docker.internal:5432/postgres')

In [None]:
def visualize_hexagons(hexagons, colors, folium_map=None):
    """
    Draw each H3 cell in ``hexagons`` as a filled polygon on a folium map.

    Parameters
    ----------
    hexagons : iterable of str
        H3 cell indexes, one polygon per cell, e.g. ``[hex1, hex2, hex3]``.
    colors : iterable of str
        Colour for each cell, matched to ``hexagons`` by position. The pairing uses
        ``zip``, so surplus entries on either side are silently ignored.
    folium_map : folium.Map, optional
        Map to draw on. When omitted, a new map is created, centred on the mean of
        all boundary coordinates of the cells.

    Returns
    -------
    folium.Map
        ``folium_map`` if given, otherwise the newly created map.

    Raises
    ------
    ValueError
        If ``hexagons`` is empty and no ``folium_map`` is given (there is nothing
        to centre a new map on).
    """
    polylines = []
    lat = []
    lng = []
    for h3_cell in hexagons:
        polygons = h3.h3_set_to_multi_polygon([h3_cell], geo_json=False)
        # Use the first loop of the first polygon as the cell outline, and close
        # the ring by repeating its first vertex.
        outline = list(polygons[0][0])
        polyline = outline + [outline[0]]
        # Vertices are (lat, lng) pairs (geo_json=False).
        lat.extend(vertex[0] for vertex in polyline)
        lng.extend(vertex[1] for vertex in polyline)
        polylines.append(polyline)

    if folium_map is None:
        if not lat:
            raise ValueError('visualize_hexagons needs at least one hexagon to centre a new map')
        m = folium.Map(
            location=[sum(lat) / len(lat), sum(lng) / len(lng)],
            zoom_start=8,
            tiles='cartodbpositron',
        )
    else:
        m = folium_map
    for polyline, color in zip(polylines, colors):
        m.add_child(folium.Polygon(locations=polyline, color=color, fill=True))
    return m

In [36]:
# Load the taxonomy and coordinates of every row whose _order is 'Pinales' from the
# stage.ma_winter table; longitude/latitude are aliased to lon/lat.
data = pd.read_sql_query(
    sql='''
    select 
        kingdom,
        phylum,
        class,
        _order,
        family,
        genus,
        species,
        decimallongitude as lon,
        decimallatitude as lat
    from 
        stage.ma_winter
    where
        _order = 'Pinales'
    ''',
    con=db_uri,
)
data.shape

(1659, 9)

In [37]:
data.head()

Unnamed: 0,kingdom,phylum,class,_order,family,genus,species,lon,lat
0,Plantae,Tracheophyta,Pinopsida,Pinales,Pinaceae,Tsuga,Tsuga canadensis,-73.451506,41.853281
1,Plantae,Tracheophyta,Pinopsida,Pinales,Pinaceae,Pinus,Pinus strobus,-72.698434,41.899412
2,Plantae,Tracheophyta,Pinopsida,Pinales,Cupressaceae,Juniperus,Juniperus virginiana,-72.450091,42.302807
3,Plantae,Tracheophyta,Pinopsida,Pinales,Pinaceae,Pinus,Pinus rigida,-72.258005,42.216925
4,Plantae,Tracheophyta,Pinopsida,Pinales,Pinaceae,Pinus,Pinus strobus,-71.128597,42.608101


In [38]:
# Build a nested taxonomy: kingdom -> phylum -> class -> _order -> family -> genus -> {species}.
# Every rank except the last one in `columns` maps a taxon name to a dict of sub-taxa;
# the last rank maps a taxon name to the set of species observed under it.
columns = ['kingdom', 'phylum', 'class', '_order', 'family', 'genus']
taxa_hierarchy = {}
# itertuples(name=None) yields plain tuples; namedtuples would mangle the
# 'class' and '_order' column names.
for *ranks, species in data[columns + ['species']].itertuples(index=False, name=None):
    level = taxa_hierarchy
    for depth, taxon in enumerate(ranks):
        is_leaf_rank = depth == len(columns) - 1
        level = level.setdefault(taxon, set() if is_leaf_rank else {})
    level.add(species)

In [39]:
def assign_scores(level, current_score=1.):
    """
    Split ``current_score`` evenly down a taxonomy tree and return per-species weights.

    At each node the incoming score is divided equally among its children, so a
    species' weight is ``current_score`` times the product of ``1 / branching factor``
    along its path. Species in sparsely populated branches therefore receive more
    weight than species in crowded ones.

    Parameters
    ----------
    level : dict or set/frozenset
        A tree node: a mapping of taxon name -> sub-level, or a set of species
        names at the leaves.
    current_score : float
        Score mass assigned to this node.

    Returns
    -------
    dict
        species name -> weight. For a tree with no empty nodes and unique species
        names the weights sum to ``current_score``. An empty node returns ``{}``,
        and a species name repeated in two branches keeps only the weight from the
        branch visited last.
    """
    if not level:
        # Nothing to share the score with (and dividing by len() would fail).
        return {}
    share = current_score / len(level)
    if isinstance(level, (set, frozenset)):
        return dict.fromkeys(level, share)
    scores = {}
    for sub_level in level.values():
        scores.update(assign_scores(sub_level, share))
    return scores

# Per-species weights. By construction they sum to 1, unless a species name appears
# under more than one branch (assign_scores then keeps only one of its weights).
scores = assign_scores(taxa_hierarchy)
assert not scores or np.isclose(sum(scores.values()), 1.0), (
    'species weights do not sum to 1 - is a species listed under two genera?'
)

In [57]:
# H3 resolution used to bin occurrences (0 = coarsest ... 15 = finest); despite the
# name this is not a map zoom level.
ZOOM = 4
# A comprehension over the two coordinate columns avoids DataFrame.apply(axis=1)'s
# per-row Series construction. It also works when `data` has no rows, whereas
# apply(axis=1) returns a DataFrame for an empty frame, which cannot be assigned
# to a single column.
data['h3_index'] = [
    h3.geo_to_h3(lat, lon, ZOOM)
    for lat, lon in zip(data['lat'], data['lon'])
]

In [58]:
# Give each species a vector position in order of first appearance in `data`
# (dict.fromkeys de-duplicates while preserving that order), then look up each
# position's taxonomic weight.
species_indices = {
    species: index
    for index, species in enumerate(dict.fromkeys(data['species']))
}
scores_by_index = {
    index: scores[species]
    for species, index in species_indices.items()
}

In [112]:
# Row order of the presence matrix: distinct H3 cells, sorted for a stable layout.
hexes = sorted(data['h3_index'].unique())
hex_indices = {h: i for i, h in enumerate(hexes)}

# vectors[r][c] == 1 iff species c (per species_indices) was observed in hex r.
vectors = [[0] * len(species_indices) for _ in hexes]
# Iterate only the two needed columns instead of building a Series per row.
for h3_cell, species in zip(data['h3_index'], data['species']):
    vectors[hex_indices[h3_cell]][species_indices[species]] = 1

# Square hex-by-hex distance matrix, filled by the pairwise loop that follows.
linkage = [[0] * len(vectors) for _ in range(len(vectors))]

def get_difference(v1, v2, scores_by_index):
    """
    Weighted mismatch between two presence vectors.

    Sums the weight of every position on which ``v1`` and ``v2`` agree (both
    present *or* both absent) and returns one minus that total. Identical vectors
    therefore score ``1 - sum(weights)``, and vectors that disagree everywhere
    score 1.
    """
    agreement = sum(
        weight
        for position, weight in scores_by_index.items()
        if v1[position] == v2[position]
    )
    return 1 - agreement

# Fill the symmetric distance matrix from its lower triangle (diagonal included):
# each unordered pair is computed once and mirrored.
for i, v1 in enumerate(vectors):
    for j in range(i + 1):
        difference = get_difference(v1, vectors[j], scores_by_index)
        linkage[i][j] = difference
        linkage[j][i] = difference

# End with the matrix as an expression so the cell actually displays it on a
# fresh run, instead of relying on an output left over from an earlier version.
np.array(linkage)

array([[0.        , 0.47708333, 0.42708333, ..., 0.45625   , 0.40625   ,
        0.48958333],
       [0.47708333, 0.        , 0.09166667, ..., 0.12083333, 0.07083333,
        0.1125    ],
       [0.42708333, 0.09166667, 0.        , ..., 0.02916667, 0.02083333,
        0.0625    ],
       ...,
       [0.45625   , 0.12083333, 0.02916667, ..., 0.        , 0.05      ,
        0.09166667],
       [0.40625   , 0.07083333, 0.02083333, ..., 0.05      , 0.        ,
        0.08333333],
       [0.48958333, 0.1125    , 0.0625    , ..., 0.09166667, 0.08333333,
        0.        ]])

In [153]:
N = 10
# scikit-learn 1.2 renamed `affinity` to `metric`, and 1.4 removed `affinity`.
# Pick whichever keyword the installed version accepts.
_distance_kw = (
    'metric'
    if 'metric' in inspect.signature(AgglomerativeClustering).parameters
    else 'affinity'
)
# `linkage` (the precomputed distance matrix) is the input to fit_predict;
# linkage='average' is the merge criterion.
clustering = AgglomerativeClustering(
    n_clusters=N, linkage='average', **{_distance_kw: 'precomputed'}
)
clusters = clustering.fit_predict(linkage)
clusters

array([5, 9, 0, 1, 0, 8, 0, 0, 0, 2, 2, 0, 1, 1, 1, 1, 0, 2, 2, 0, 2, 7,
       1, 0, 1, 1, 6, 3, 0, 6, 0, 0, 4, 4, 0, 2, 1, 2, 0, 0, 0])

In [154]:
start, end = 'blue', 'orange'
# N colours interpolated from `start` to `end`, indexed by cluster label.
color_pallette = list(Color(start).range_to(Color(end), N))
visualize_hexagons(
    hexes,
    [color_pallette[clusters[position]].hex for position, _ in enumerate(hexes)],
)