In [1]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from plotly.subplots import make_subplots
from sklearn.cluster import KMeans

In [2]:
# Load the drinks dataset and preview its first rows as a sanity check.
data = pd.read_csv(filepath_or_buffer="drinks.csv")
data.head(n=5)

Unnamed: 0,country,beer_servings,spirit_servings,wine_servings,total_litres_of_pure_alcohol
0,Afghanistan,0,0,0,0.0
1,Albania,89,132,54,4.9
2,Algeria,25,0,14,0.7
3,Andorra,245,138,312,12.4
4,Angola,217,57,45,5.9


In [3]:
# Inspect column dtypes and non-null counts before choosing clustering features.
data.info(buf=None)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 193 entries, 0 to 192
Data columns (total 5 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   country                       193 non-null    object 
 1   beer_servings                 193 non-null    int64  
 2   spirit_servings               193 non-null    int64  
 3   wine_servings                 193 non-null    int64  
 4   total_litres_of_pure_alcohol  193 non-null    float64
dtypes: float64(1), int64(3), object(1)
memory usage: 7.7+ KB


In [4]:
# --- Cluster countries by consumption profile -----------------------------------
# Features are selected by name (they are columns 1-4 in data.info() above), so the
# selection cannot silently shift if the CSV ever gains an extra leading column.
# NOTE(review): the features are not scaled. In the preview, the servings columns
# reach the hundreds while litres stay below ~15, so servings dominate the Euclidean
# distance KMeans uses. Consider standardising first if that is unintended.
FEATURE_COLS = ['beer_servings', 'spirit_servings', 'wine_servings', 'total_litres_of_pure_alcohol']
X = data[FEATURE_COLS].values

# n_init is pinned explicitly: scikit-learn changed its default from 10 to 'auto'
# (1.4), so leaving it implicit yields different clusterings across versions.
kmeans = KMeans(n_clusters=3, init='k-means++', n_init=10, random_state=42)
y_kmeans = kmeans.fit_predict(X)

# Attach labels to rows by position, independent of the index type.
cluster_labels = pd.Series(y_kmeans, index=data.index)

data['country_index'] = range(len(data))
data['hover_text'] = data['country'] + '<br>Cluster: ' + cluster_labels.astype(str)

cluster_colors = {0: 'red', 1: 'green', 2: 'blue'}
data['cluster_color'] = cluster_labels.map(cluster_colors)
# A missing mapping would otherwise surface as NaN colours in the plot.
assert data['cluster_color'].notna().all(), "cluster_colors needs an entry for every cluster label"

# --- One single-point trace per country -----------------------------------------
# Each country gets its own trace so the dropdown below can recolour countries
# individually: a 'restyle' with a list value applies element i to trace i.
traces = [
    go.Scatter3d(
        x=[row.beer_servings],
        y=[row.spirit_servings],
        z=[row.wine_servings],
        text=[row.hover_text],
        name=row.country,
        mode='markers',
        marker=dict(size=5, color=row.cluster_color),
        visible=True,
        hovertemplate=(
            "Country: %{text}<br>"
            "Beer: %{x}<br>"
            "Spirit: %{y}<br>"
            "Wine: %{z}<br>"
        ),
    )
    for row in data.itertuples(index=False)
]

# --- Dropdown: 'All' restores cluster colours; each country highlights one trace ---
highlight_color = 'yellow'
# Computed once and copied per button, instead of re-iterating every row per country.
base_colors = data['cluster_color'].tolist()

buttons = [dict(args=[{'marker.color': base_colors}], label='All', method='restyle')]
for i, country in enumerate(data['country']):
    colors = list(base_colors)
    colors[i] = highlight_color  # position i == trace i (same row order as traces)
    buttons.append(dict(args=[{'marker.color': colors}], label=country, method='restyle'))

# Add dropdown menu to layout
layout = go.Layout(
    title='3D Scatter Plot of Alcohol Consumption by Country',
    scene=dict(xaxis_title='Beer', yaxis_title='Spirit', zaxis_title='Wine'),
    updatemenus=[
        go.layout.Updatemenu(
            buttons=buttons,
            direction='down',
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0,
            xanchor="left",
            y=1.2,
            yanchor="top"
        )
    ]
)

fig = go.Figure(data=traces, layout=layout)
fig.show()



In [5]:
# Export the interactive figure as a standalone HTML file.
# Set OPEN_IN_BROWSER = False for headless runs (e.g. Restart & Run All on a
# remote server), where launching a browser tab is unwanted.
OUTPUT_HTML = '3d_scatter_plot1.html'
OPEN_IN_BROWSER = True
pio.write_html(fig, file=OUTPUT_HTML, auto_open=OPEN_IN_BROWSER)