In [90]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
import os
from ipywidgets import interact, interactive
import ipywidgets as widgets
print(torch.__version__)

2.5.1+cu121


In [2]:
# cheating in some variables here from some other data processing I did in a different project
MEAN_LAT = 38.28132872306216
MEAN_LNG = -91.65855796075903

In [35]:
# copying in the draw counties method for our visualizations
drawing_data = 'https://www.tyro.work/cont-us-counties-latlng-normalized.csv'
if not os.path.exists('/content/cont-us-counties-latlng-normalized.csv'):
  !wget $drawing_data
else:
  print("County Location Data has already been downloaded.")

drawing_df = pd.read_csv('/content/cont-us-counties-latlng-normalized.csv')
draw_lat_tensor = torch.tensor(drawing_df['lat'].values)
draw_lat_tensor = draw_lat_tensor.reshape(-1, 1)
draw_lng_tensor = torch.tensor(drawing_df['lng'].values)
draw_lng_tensor = draw_lng_tensor.reshape(-1, 1)
# i thought we needed to vstack these, but we actually need hstack
draw_lnglat_tensor = torch.hstack((draw_lng_tensor, draw_lat_tensor))
# print("county data shape: ", draw_lnglat_tensor.shape)
# print('county data: ', draw_lnglat_tensor)
# these constants will be useful for randomizing our centroids
lat_max = draw_lat_tensor.max().item()
lat_min = draw_lat_tensor.min().item()
lng_max = draw_lng_tensor.max().item()
lng_min = draw_lng_tensor.min().item()

def draw_counties():
  plt.figure(figsize = (8,5))
  plt.scatter(draw_lnglat_tensor[:, 0], draw_lnglat_tensor[:, 1], c='#79d2a4', marker='.', s=4, label='U.S. Counties')
  plt.xlabel('Longitude')
  plt.ylabel('Latitude')
  # Instead of converting all the data back to normal latitudes and longitudes, I'm gonna just hack the axis x_labels
  x_locs, x_labels = plt.xticks()
  x_labels = [int(item + MEAN_LNG) for item in x_locs]
  plt.xticks(x_locs, x_labels)
  y_locs, y_labels = plt.yticks()
  y_labels = [int(item + MEAN_LAT) for item in y_locs]
  plt.yticks(y_locs, y_labels)

# now we have a method to quickly draw the county locations
# draw_counties()

County Location Data has already been downloaded.


In [15]:
model_pytorch = 'https://www.tyro.work/optimized_model.pth'
if not os.path.exists('/content/optimized_model.pth'):
  !wget $model_pytorch
else:
  print("Model has already been downloaded.")

k_means_pytorch = 'https://www.tyro.work/k_means_model_data.pth'
if not os.path.exists('/content/k_means_model_data.pth'):
  !wget $k_means_pytorch
else:
  print("K Means Data has already been downloaded.")

Model has already been downloaded.
K Means Data has already been downloaded.


In [16]:
optimized_model = torch.load('/content/optimized_model.pth')

# [k_loss, k_init_centroids, k_final_centroids, k_end_centroid_count] is the consturction of the optimized model
k_loss = optimized_model[0]
k_init_centroids = optimized_model[1]
k_final_centroids = optimized_model[2]
k_end_centroid_count = optimized_model[3]

# print(k_end_centroid_count)
print('loaded model successfully')

# now need to load the test set as a model
k_means_data = torch.load('/content/k_means_model_data.pth')

# [training_set, testing_set, LAT_RANGE, LNG_RANGE, LAT_SHIFT, LNG_SHIFT]
training_set = k_means_data[0]
testing_set = k_means_data[1]
LAT_RANGE = k_means_data[2]
LNG_RANGE = k_means_data[3]
LAT_SHIFT = k_means_data[4]
LNG_SHIFT = k_means_data[5]
# print("lat range")
# print(LAT_RANGE)
# print("lng range")
# print(LNG_RANGE)
# print("lat shift")
# print(LAT_SHIFT)
# print("lng shift")
# print(LNG_SHIFT)
# print('training set')
# print(training_set)
# print('testing set')
# print(testing_set)
# print(' ')
print('loaded k means data successfully')

loaded model successfully
loaded k means data successfully


  optimized_model = torch.load('/content/optimized_model.pth')
  k_means_data = torch.load('/content/k_means_model_data.pth')


In [107]:
# optimized_model = [k_loss, k_init_centroids, k_final_centroids, k_end_centroid_count]
ks = [i for i in range(350)]
# print(len(ks), len(k_loss))

def draw_k(K = 225, Show_Initial_Centroids=False, Loss_Start=1):
  # Initialise the subplot function using number of rows and columns
  figure, axis = plt.subplots(1, 2, figsize=(20, 5))
  # draw_counties(), recreated for subplots:
  axis[0].scatter(draw_lnglat_tensor[:, 0], draw_lnglat_tensor[:, 1], c='#79d2a4', marker='.', s=4, label='U.S. Counties')
  axis[0].set_xlabel('Longitude')
  axis[0].set_ylabel('Latitude')
  axis[0].scatter(k_final_centroids[K][:, 0], k_final_centroids[K][:, 1], c='b', s=50, marker='2', label='Optimized Centroids')
  axis[0].set_title(f'K-Means Optimized Centroids For K={K}')
  axis[0].legend()

  left_bound = 1 if not Loss_Start else Loss_Start
  axis[1].plot(ks[left_bound:], k_loss[left_bound:], c='#56a0d3', marker='o', label='All Losses')
  axis[1].set_title(f'Best Loss K={K} is {k_loss[K]}')
  axis[1].plot([K], k_loss[K], c='r', marker=7, markersize=25, label='K Loss')

  # Instead of converting all the data back to normal latitudes and longitudes, I'm gonna just hack the axis x_labels
  # but for subplots need to finesse this with the .sca command
  plt.sca(axis[0])
  x_locs, x_labels = plt.xticks()
  x_labels = [int(item + MEAN_LNG) for item in x_locs]
  plt.xticks(x_locs, x_labels)
  y_locs, y_labels = plt.yticks()
  y_labels = [int(item + MEAN_LAT) for item in y_locs]
  plt.yticks(y_locs, y_labels)

  plt.show()
  return

# grid = widgets.Grid(2, 2)
slider = widgets.IntRangeSlider(
    value=[5, 7],
    min=0,
    max=10,
    step=1,
    description='Test:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)
plot = interact(draw_k, K=(1, 349), Loss_Start=(0, 345, 25))

interactive(children=(IntSlider(value=225, description='K', max=349, min=1), Checkbox(value=False, description…