---
title: "ML Methods"
author:
  - name: Mahira Ayub
    affiliations:
      # Declared once under the id `bu`; the other authors reuse it
      # through `ref: bu`.
      - id: bu
        name: Boston University
        city: Boston
        state: MA
  - name: Ava Godsy
    affiliations:
      - ref: bu
  - name: Joshua Lawrence
    affiliations:
      - ref: bu
date: today
format:
  html:
    theme: minty
    bibliography: references.bib
    csl: csl/econometrica.csl
    toc: true
---

In [None]:
import plotly.graph_objects as go
from sklearn.cluster import KMeans
import numpy as np

# Prepare the data: keep only rows where both clustering features are present.
X = df_clean[['STATE_INDEX', 'SALARY']].dropna()
# SOC values for the surviving rows (aligned on X's index); used in hover text.
soc_labels = df_clean.loc[X.index, 'SOC']

# Define colors: one entry per cluster. The palette length is the single
# source of truth for the cluster count, so the model, the plotting loop and
# the colors cannot drift out of sync.
colors = ['#3D7C6A', '#B14E53', '#297C8A']
n_clusters = len(colors)

# Perform KMeans clustering
# NOTE(review): the features are clustered on their raw scales. If SALARY is
# numerically much larger than STATE_INDEX it will dominate the distance
# metric -- confirm whether standardizing the columns first is intended.
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
clusters = kmeans.fit_predict(X)
# Centroid columns follow X's column order: [STATE_INDEX, SALARY].
centroids = kmeans.cluster_centers_

# Create figure
fig = go.Figure()

# Add one scatter trace per cluster
for i, color in enumerate(colors):
    # Boolean mask over the rows of X assigned to cluster i.
    mask = clusters == i
    cluster_points = X.loc[mask]
    fig.add_trace(go.Scatter(
        x=cluster_points['STATE_INDEX'],
        y=cluster_points['SALARY'],
        mode='markers',
        name=f'Cluster {i + 1}',
        marker=dict(
            color=color,
            size=8,
            opacity=0.6
        ),
        # Pre-render the hover text per point so the SOC value can be shown
        # alongside the plotted coordinates.
        text=[f'SOC: {soc}<br>State Index: {si:.1f}<br>Salary: ${sal:,.0f}'
              for soc, si, sal in zip(soc_labels[mask],
                                      cluster_points['STATE_INDEX'],
                                      cluster_points['SALARY'])],
        # An empty <extra> tag hides the secondary hover box.
        hovertemplate='%{text}<extra></extra>'
    ))

# Add centroids
fig.add_trace(go.Scatter(
    x=centroids[:, 0],
    y=centroids[:, 1],
    mode='markers',
    name='Centroids',
    marker=dict(
        color='black',
        size=15,
        symbol='x',
        line=dict(width=2)
    ),
    hovertemplate='Centroid<br>State Index: %{x:.1f}<br>Salary: $%{y:,.0f}<extra></extra>'
))

# Update layout: Verdana throughout, light-grey plot area on a white page.
fig.update_layout(
    title=dict(
        text='KMeans Clustering: State Index vs Salary by SOC',
        font=dict(family='Verdana', size=18)
    ),
    xaxis=dict(
        title=dict(text='State Index', font=dict(family='Verdana', size=14)),
        tickfont=dict(family='Verdana', size=14)
    ),
    yaxis=dict(
        title=dict(text='Salary ($)', font=dict(family='Verdana', size=14)),
        tickfont=dict(family='Verdana', size=14)
    ),
    font=dict(family='Verdana', size=14),
    legend=dict(font=dict(family='Verdana', size=14)),
    hovermode='closest',
    plot_bgcolor='#f8f9fa',
    paper_bgcolor='white'
)

# Show the plot
fig.show()