In [79]:
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import pandas as pd
import plotly.graph_objects as go

In [220]:
# Clustering configuration shared by the cells below.
N = 3              # number of clusters requested from every clusterer
BATCH_SIZE = 100   # number of shuffled points taken as the probing batch

# Thresholds used when judging the clusters found on the batch.
MIN_DIST = 1.5     # is_close(): centers nearer than this are "too close"
MIN_CLUSTER = 0.2  # is_enough(): smallest/largest cluster size ratio must exceed this

In [73]:
def get_centers(df, n=N):
    """Return the mean (x, y) position of every class in ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``'x'``, ``'y'`` and ``'class'`` columns.
    n : int, optional
        Number of classes; the labels ``0 .. n-1`` are looked up in the
        ``'class'`` column. Defaults to ``N``.

    Returns
    -------
    numpy.ndarray
        One ``[mean_x, mean_y]`` row per label, in label order.
    """
    rows = []
    for label in range(n):
        # Select the members of this class once and reuse them for both axes.
        members = df[df['class'] == label]
        rows.append([np.mean(members['x']), np.mean(members['y'])])
    return np.array(rows)

In [74]:
def is_enough(df, n=N):
    """Tell whether the classes in ``df`` are reasonably balanced in size.

    Counts the rows carrying each label ``0 .. n-1`` in the ``'class'``
    column and compares the smallest count with the largest one.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``'class'`` column.
    n : int, optional
        Number of classes to count. Defaults to ``N``.

    Returns
    -------
    bool
        True when ``smallest / largest`` is strictly greater than
        ``MIN_CLUSTER``.
    """
    sizes = np.array([(df['class'] == label).sum() for label in range(n)])
    smallest, largest = sizes.min(), sizes.max()
    return smallest / largest > MIN_CLUSTER

In [75]:
def dist(a, b):
    """Return the Euclidean distance between two 2-D points.

    Only the first two components (index 0 and 1) of ``a`` and ``b`` are used.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    squared = dx ** 2 + dy ** 2
    return squared ** .5

In [179]:
def is_close(centers, n=N):
    """Tell whether any two of the first ``n`` centers are too close together.

    Parameters
    ----------
    centers : indexable of 2-D points
        Cluster centers; ``centers[i]`` is passed to ``dist``.
    n : int, optional
        Number of centers to compare pairwise. Defaults to ``N``.

    Returns
    -------
    bool
        True as soon as one pair is closer than ``MIN_DIST``, else False.
    """
    # Every unordered pair (first < second) is checked once; any() stops at
    # the first pair that falls under the threshold.
    return any(
        dist(centers[first], centers[second]) < MIN_DIST
        for first in range(n)
        for second in range(first + 1, n)
    )

In [221]:
# Hybrid clustering: probe a random batch of the data with single-linkage
# agglomerative clustering, then decide how to cluster the full data set.
#
# The shuffle uses a locally seeded generator so that the batch, the branch
# taken below and every derived result are identical on each
# Restart & Run All (the global NumPy RNG was previously used unseeded).
SEED = 0
rng = np.random.default_rng(SEED)

X = pd.read_csv('test6.csv')
x = np.array([np.array(X['x']), np.array(X['y'])]).T  # one (x, y) row per point
rng.shuffle(x)  # in place, along the first axis (rows)
x_batch = x[:BATCH_SIZE]

# Cluster only the batch and summarise each batch cluster by its mean point.
batch_aggl = AgglomerativeClustering(n_clusters=N, linkage="single").fit(x_batch)
df = pd.DataFrame({'x': x_batch[:, 0], 'y': x_batch[:, 1], 'class': batch_aggl.labels_})
centers = get_centers(df, N)

# If the batch centers are far enough apart and the batch clusters are not
# too unbalanced, use the centers as the initial centroids of K-means on all
# points (n_init=1 because the initialisation is given explicitly).
# Otherwise run single-linkage agglomerative clustering on all points.
if not is_close(centers) and is_enough(df, N):
    print('yo')
    kmeans = KMeans(n_clusters=N, init=centers, n_init=1).fit(x)
    labels = kmeans.labels_
else:
    print('yoyo')
    aggl = AgglomerativeClustering(n_clusters=N, linkage="single").fit(x)
    labels = aggl.labels_

# All points together with the label assigned by whichever branch ran.
df_pred = pd.DataFrame({'x': x[:, 0], 'y': x[:, 1], 'class': labels})

yo


In [222]:
# Plain K-means on all points with the library's default initialisation.
# Note: this rebinds `kmeans`, `labels` and `df_pred`, replacing any values
# those names held before this cell ran.
# random_state pins the random centroid initialisation so the labels (and
# therefore the colours in any plot of `df_pred`) are the same on every run.
kmeans = KMeans(n_clusters=N, random_state=0).fit(x)
labels = kmeans.labels_
df_pred = pd.DataFrame({'x': x[:, 0], 'y': x[:, 1], 'class': labels})

In [225]:
# Scatter plot of the labelled points in `df_pred` (one trace per class) with
# `centers` overlaid as white, black-outlined markers; the figure is written
# to a PNG file and then displayed.
#
# The figure is bound to `fig` rather than `plt`, so the
# `import matplotlib.pyplot as plt` at the top of the notebook is no longer
# shadowed by a plotly Figure.
fig = go.Figure()

# Both axes get the same [mn - .2, mx + .2] range, taken over the x and y
# columns of the raw data together.
mn = min(X['x'].min(), X['y'].min())
mx = max(X['x'].max(), X['y'].max())
fig.update_yaxes(range=[mn - .2, mx + .2])
fig.update_xaxes(range=[mn - .2, mx + .2])

# One marker trace per class label 0 .. N-1.
for cls in range(N):
    df_pred_i = df_pred[df_pred['class'] == cls]
    fig.add_trace(
        go.Scatter(
            x=df_pred_i['x'],
            y=df_pred_i['y'],
            mode='markers',
        )
    )

# Cluster centers drawn last so they are added after all point traces.
fig.add_trace(
    go.Scatter(
        x=centers[:, 0],
        y=centers[:, 1],
        mode='markers',
        marker=dict(
            color='White',
            size=10,
            line=dict(
                color='Black',
                width=1
            )
        )
    )
)
fig.update_layout(
    showlegend=False,
    autosize=False,
    width=600,
    height=600
)
fig.write_image('test6_pred_good.png', width=1024, height=1024, scale=4)
fig.show()