### Plotting K-means Clustering with Animation

In [1]:
# import libraries
import os

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go



In [2]:
# read data -- set the MALL_CUSTOMERS_CSV environment variable to the CSV's
# location on your machine; the author's original path is only a fallback.
DATA_PATH = os.environ.get(
    'MALL_CUSTOMERS_CSV',
    r'C:\Users\VeerenTaylor\Downloads\python\jupyter notebook\github\customer data\cdata\Mall_Customers.csv',
)
df = pd.read_csv(DATA_PATH)

In [3]:
# Visualise the raw data: one panel per gender, income vs. spending score,
# with each customer coloured by age.
fig = px.scatter(
    df,
    x="Annual Income (k$)",
    y="Spending Score (1-100)",
    color="Age",
    facet_col="Gender",
    color_continuous_scale="haline",
    height=600,
    width=1000,
)
fig.show()

##### We can see around 5 clusters in the data, so we will use k = 5 below.

In [4]:
# Using pandas, define a function that initialises centroids from the data.
def get_centroid(data, k, random_state=None):
    """Pick ``k`` distinct rows of ``data`` at random as initial centroids.

    Parameters
    ----------
    data : pd.DataFrame
        Points to cluster; every column is treated as a coordinate.
    k : int
        Number of centroids to draw. Must not exceed ``len(data)``.
    random_state : optional
        Forwarded unchanged to ``DataFrame.sample``. ``None`` (the default)
        keeps the previous behaviour of drawing from NumPy's global RNG.

    Returns
    -------
    pd.DataFrame
        ``k`` rows with the same columns as ``data`` and a fresh ``0..k-1``
        index; the index values serve as the centroid labels.

    Raises
    ------
    ValueError
        From ``DataFrame.sample`` when ``k`` is larger than ``len(data)``.
    """
    # Sample without replacement in one call: drawing one row at a time could
    # select the same point twice, giving duplicate centroids and hence an
    # empty cluster. A single call also avoids growing a frame in a loop.
    return data.sample(n=k, random_state=random_state).reset_index(drop=True)

In [5]:
# Keep only the two features we cluster on. Selecting them by name (rather
# than by position) keeps this correct even if the CSV's column order changes.
data = df[['Annual Income (k$)', 'Spending Score (1-100)']]
data

Unnamed: 0,Annual Income (k$),Spending Score (1-100)
0,15,39
1,15,81
2,16,6
3,16,77
4,17,40
...,...,...
195,120,79
196,126,28
197,126,74
198,137,18


In [6]:
# Preview one random draw of five starting centroids.
centroids = get_centroid(data, k=5)
centroids

Unnamed: 0,Annual Income (k$),Spending Score (1-100)
0,78,76
1,76,87
2,81,5
3,71,75
4,75,93


In [7]:
# Create a function that assigns each point to its nearest centroid.
def get_distances(centroids, data):
    """Assign every row of ``data`` to its nearest centroid.

    Despite its name, this returns cluster labels, not the distances.

    Parameters
    ----------
    centroids : pd.DataFrame
        One row per centroid; its index values are used as the labels. It must
        contain every column of ``data``. Any extra columns (e.g. bookkeeping
        such as 'iteration' or 'centroid number') are ignored.
    data : pd.DataFrame
        Points to assign; all columns are treated as numeric coordinates.

    Returns
    -------
    pd.Series
        Indexed like ``data``. Each value is the index label of the centroid
        with the smallest Euclidean distance; ties go to the centroid that
        appears first in ``centroids``.
    """
    # Compare on the feature columns only, so annotation columns carried on
    # ``centroids`` never enter the distance arithmetic.
    coords = centroids[data.columns].to_numpy(dtype=float)
    points = data.to_numpy(dtype=float)
    # (n_points, n_centroids) matrix of Euclidean distances via broadcasting.
    dist = np.sqrt(((points[:, None, :] - coords[None, :, :]) ** 2).sum(axis=2))
    return pd.DataFrame(dist, index=data.index, columns=centroids.index).idxmin(axis=1)


In [8]:
# Calculate the new centroid positions.
def get_new_centroids(data, labels):
    """Recompute centroids as the per-cluster mean of ``data``.

    Parameters
    ----------
    data : pd.DataFrame
        Points being clustered.
    labels : pd.Series
        Cluster label for each row of ``data`` (aligned on the index).

    Returns
    -------
    pd.DataFrame
        One row per label that occurs in ``labels``, in sorted label order,
        holding the column-wise mean of that cluster's rows. A label with no
        assigned rows does not appear in the result.
    """
    clusters = data.groupby(by=labels, sort=True)
    return clusters.mean()

In [9]:
#create function for plotting

def plot_animation(static_anime, centroid_anime):


    #define figure for scatter in plotly
    static_fig= px.scatter(static_anime.sort_values(by='iteration'), x='Annual Income (k$)', y='Spending Score (1-100)', 
                            animation_frame='iteration', 
                            animation_group='static index', 
                            height=800, width=800,
                            color='labels',
                            color_discrete_sequence=px.colors.qualitative.G10)
    

        #define figure for centroids in plotly
    centroids_fig = px.scatter(centroid_anime.sort_values(by='iteration'), x='Annual Income (k$)', y='Spending Score (1-100)',
                                animation_frame='iteration',
                                animation_group='centroid number',
                                height=800,
                                width=800)
    
    #update the markers and style in the centroids
    for f in range(len(centroids_fig.frames)):
        centroids_fig.frames[f]['data'][0]['marker']['symbol']='x'
        centroids_fig.frames[f]['data'][0]['marker']['color']='black'
        centroids_fig.frames[f]['data'][0]['marker']['size'] = 10
    
    
    frames = [go.Frame(data=f.data + centroids_fig.frames[i].data, name=f.name) for i, f in enumerate(static_fig.frames)]

    updmenus = [{"args": [None, {"frame": {"duration": 500}}],"label": "&#9654;","method": "animate",},
            {'args': [[None], {'frame': {'duration': 0}, 'mode': 'immediate', 'fromcurrent': False, }],
                  'label': '&#9724;', 'method': 'animate'} ]

# now can animate...
    fig = go.Figure(data=frames[0].data, frames=frames, layout=static_fig.layout).update_layout(
        updatemenus=[{"buttons":updmenus}],showlegend=False
    ).update_coloraxes(showscale=False)

    fig.show()


In [10]:
# Run K-means step by step and record every iteration so the whole process can
# be animated: points are coloured by the labels computed at each step and the
# centroids are drawn on top.
K = 5              # number of clusters, from the faceted plot above
N_ITERATIONS = 10  # centroid frames 'iteration 0' .. 'iteration 9'; 0 is the random start
SEED = 42          # DataFrame.sample() draws from NumPy's global RNG when no
                   # random_state is passed, so seeding makes the run reproducible

np.random.seed(SEED)

# The working centroids hold only the feature columns; the 'iteration' tag is
# added to copies used for plotting, so it never reaches get_distances.
centroids = get_centroid(data, K)
centroid_frames = [centroids.assign(iteration='iteration 0')]
static_frames = []

for z_iter in range(1, N_ITERATIONS):
    # Assignment step: label each customer with its nearest centroid.
    labels = get_distances(centroids, data)
    # Update step: move each centroid to the mean of its points. A centroid
    # that attracted no points is missing from the groupby result; keep its
    # previous position instead of silently losing a cluster.
    centroids = get_new_centroids(data, labels).combine_first(centroids)

    iteration = f'iteration {z_iter}'
    centroid_frames.append(centroids.assign(iteration=iteration))
    static_frames.append(df.assign(labels=labels, iteration=iteration))

# Build each animation table with one concat instead of growing it in the loop.
centroid_anime = pd.concat(centroid_frames)
centroid_anime['centroid number'] = centroid_anime.index
assert centroid_anime.groupby('iteration').size().eq(K).all(), 'a cluster was lost'

static_anime = pd.concat(static_frames)
static_anime['static index'] = static_anime.index
static_anime['labels'] = static_anime['labels'].astype(str)

plot_animation(static_anime, centroid_anime)

