# Lightning Flash Dataset: EDA with K-Means Clustering

## Import Libraries

In [None]:
import warnings

import duckdb as db
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import seaborn as sns; sns.set() # plot styling

# from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

%matplotlib inline
warnings.filterwarnings('ignore')

## Data Load

In [None]:
# 2. Load dataset - DuckDB connection
# Test dataset from GOES-West (noaa-goes18).
# The `with` block closes the connection as soon as both tables have been
# materialized as pandas DataFrames; later cells only use lat_df / lon_df.
# read_only=True makes a wrong path fail immediately instead of silently
# creating a new, empty database file.
with db.connect("glmFlash.db.test", read_only=True) as conn:  # path to db file
    lat_df = conn.execute("SELECT * FROM tbl_flash_lat;").df()  # latitude coordinates
    lon_df = conn.execute("SELECT * FROM tbl_flash_lon;").df()  # longitude coordinates
print(lat_df.describe(), "\n")
print(lon_df.describe())

## Preprocessing

In [None]:
# 3. Preprocess - clean up the data
# NOTE(review): assumes each flash point corresponds to one lightning discharge — confirm.
# Column position 2 holds the source file name, which is not used for clustering.
lat_df.drop(columns=lat_df.columns[2], inplace=True)
lon_df.drop(columns=lon_df.columns[2], inplace=True)
# DataFrame.info() writes its report itself and returns None; wrapping it in
# print() would emit an extra "None" line, so call it directly.
lat_df.info()
print()
lon_df.info()

In [None]:
# De-duplicate each coordinate table on its timestamp, keeping the first row
# seen for every ts_date (modifies both frames in place).
# NOTE(review): distinct flashes sharing an identical timestamp are collapsed
# here as well — confirm that is intended.
for coord_frame in (lat_df, lon_df):
    coord_frame.drop_duplicates(subset=['ts_date'], inplace=True)

print(lat_df.ts_date.value_counts(), "\n")
print(lon_df.ts_date.value_counts())

In [None]:
# Pair each longitude row with the latitude recorded at the same timestamp
# (left join: every row of lon_df is kept).
lat_by_timestamp = lat_df.set_index('ts_date')
geo_df = lon_df.join(lat_by_timestamp, on='ts_date', how='left')
# Element count (rows x columns) of the joined frame
geo_df.size

In [None]:
# Derive the time buckets used to slice the data.
geo_df['hour'] = geo_df['ts_date'].dt.hour
# ISO year-month-day labels sort chronologically as plain strings; the day list
# and the generated PNG file names are later ordered by string sort.
geo_df['day'] = geo_df['ts_date'].dt.strftime('%Y-%m-%d')
# Initial working set: flashes from hour 00 only. Copy so that later column
# assignments act on an independent frame rather than a view of geo_df.
geo_df_sm = geo_df[geo_df['hour'] == 0].copy()
# Time window covered by the dataset
start = geo_df.ts_date.min()
end = geo_df.ts_date.max()
print(f"Start: {start}; End: {end}")
geo_df_sm

In [None]:
# 3. Raw visualization: unclustered hour-00 flash locations
plt.figure(figsize=(6, 6))
plt.scatter(x=geo_df_sm['lon'], y=geo_df_sm['lat'])
plt.ylabel("Latitude")
plt.xlabel("Longitude")
plt.title("Raw Flash Events")
# Render explicitly, like the other plotting cells; this also keeps the cell
# from echoing the repr of the title Text object.
plt.show()

## Clustering

In [None]:
# 4. Implement K-Means clustering algorithm
# NOTE(review): k=7 mirrors num_clusters in the time-lapse cell; revisit after
# inspecting the silhouette/elbow evaluation. Defining k here keeps the cell
# runnable top-to-bottom (otherwise k is only bound by loops further down).
k = 7
kmeans_kwargs = {
    "init": "k-means++",
    "n_init": 10,
    "max_iter": 100,
    "random_state": 42,
}
# Cluster on coordinates only; copy so the "cluster" column added below lands
# on an independent frame (re-running the cell re-selects lon/lat first).
geo_df_sm = geo_df_sm.loc[:, ["lon", "lat"]].copy()
# KMeans requires n_clusters <= n_samples.
k = min(k, len(geo_df_sm))
kmeans = KMeans(n_clusters=k, **kmeans_kwargs)
geo_df_sm["cluster"] = kmeans.fit_predict(geo_df_sm)
geo_df_sm["cluster"] = geo_df_sm["cluster"].astype("category")
geo_df_sm.head()

In [None]:
# Scatter the hour-00 flashes, coloured by their K-Means cluster label.
g = sns.relplot(
    data=geo_df_sm,
    x="lon",
    y="lat",
    hue="cluster",
    palette="tab20",
    sizes=(10, 100),
    height=6,
)
# FacetGrid.set returns the grid itself, so g stays the same object.
g = g.set(title='Clustered Flash Events', xlabel="Longitude", ylabel="Latitude")

## Evaluations

In [None]:
# 5a. Evaluate results: silhouette score for a range of cluster counts
kmeans_kwargs = dict(
    init="k-means++",
    n_init=10,
    max_iter=100,
    random_state=42,
)
# Only the coordinates are features (geo_df_sm also carries a "cluster" column).
geo_df_sil = geo_df_sm.loc[:, ["lon", "lat"]]

# One silhouette coefficient per k; the score needs at least 2 clusters.
silhouette_coefficients = []
for k in range(2, 24):
    kmeans = KMeans(n_clusters=k, **kmeans_kwargs).fit(geo_df_sil)
    silhouette_coefficients.append(silhouette_score(geo_df_sil, kmeans.labels_))

In [None]:
# Plot the silhouette coefficient computed for each candidate k.
plt.style.use("fivethirtyeight")
silhouette_ks = range(2, 24)
plt.plot(silhouette_ks, silhouette_coefficients)
plt.xticks(silhouette_ks)
plt.xlabel("Number of Clusters")
plt.ylabel("Silhouette Coefficient")
plt.show()

In [None]:
# 5b. Evaluate results: elbow method
kmeans_kwargs = {
    "init": "k-means++",
    "n_init": 10,
    "max_iter": 100,
    "random_state": 42,
}
# Features are the coordinates only. geo_df_sm also holds the "cluster" label
# column added in step 4, which must not be fed back into KMeans.
geo_df_elbow = geo_df_sm.loc[:, ["lon", "lat"]]

# Sum of squared distances to the nearest centroid (inertia) for each k
ssd = []
for k in range(1, 24):
    kmeans = KMeans(n_clusters=k, **kmeans_kwargs)
    kmeans.fit(geo_df_elbow)
    ssd.append(kmeans.inertia_)

In [None]:
# Plot inertia against the number of clusters to look for the "elbow".
plt.style.use("fivethirtyeight")
elbow_ks = range(1, 24)
plt.plot(elbow_ks, ssd)
plt.xticks(elbow_ks)
plt.xlabel("Number of Clusters")
plt.ylabel("Sum of Squared Distances")
plt.show()

## Plotting

In [None]:
# 6. Overlay clusters on a world basemap
# Natural Earth low-resolution country polygons shipped with geopandas.
# NOTE(review): geopandas.datasets was deprecated in 0.14 and removed in 1.0;
# on newer versions point read_file at a downloaded Natural Earth file instead.
world_shapefile = gpd.datasets.get_path('naturalearth_lowres')
world = gpd.read_file(world_shapefile)
print(world.crs)

In [None]:
# Convert the clustered DataFrame to a GeoDataFrame with point geometries built
# from lon/lat, declaring the coordinates as EPSG:4326 (WGS 84). Passing crs to
# the constructor replaces the deprecated {'init': 'epsg:4326'} dict form.
geo = gpd.GeoDataFrame(
    geo_df_sm,
    geometry=gpd.points_from_xy(geo_df_sm.lon, geo_df_sm.lat),
    crs="EPSG:4326",
)
geo.head()

In [None]:
# Draw the hour-00 clusters on top of the world basemap.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_aspect('equal')
# Basemap underneath (zorder 1), flash points above it (zorder 2)
world.plot(ax=ax, zorder=1, alpha=0.4, color='whitesmoke', edgecolor='black', linestyle=':')
geo.plot(ax=ax, zorder=2, column="cluster", cmap='viridis', alpha=0.7, linewidth=0.1)
map_title = f"Lightning Clusters: approx. {start.date()} Hour: 00"
plt.title(map_title)  # e.g. Start: 2023-04-19 23:59:59.601740; End: 2023-04-20 12:59:58.608381
plt.xlabel("longitude")
plt.ylabel("latitude")
plt.show()

In [None]:
# 7. Create a simple time lapse: one clustered map per (day, hour) that has data.
# Each frame is saved as maps/lightning_clusters_<day>_<HH>.png rather than shown.
import os

num_clusters = 7
kmeans_kwargs = {
    "init": "k-means++",
    "n_init": 10,
    "max_iter": 100,
    "random_state": 42,
}
output_dir = "maps"
# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

# to_list() yields plain Python str/int values (used in titles and file names).
DayList = sorted(geo_df['day'].drop_duplicates().to_list())
hourList = sorted(geo_df['hour'].drop_duplicates().to_list())
for day in DayList:
    for hour in hourList:
        # Restrict to this specific day AND hour so each frame is one time slice.
        in_slice = (geo_df['day'] == day) & (geo_df['hour'] == hour)
        geo_df_n = geo_df.loc[in_slice, ["lon", "lat"]].copy()
        num_samples = len(geo_df_n)
        if num_samples == 0:
            # Not every day covers every hour; KMeans cannot fit zero samples.
            continue
        # KMeans requires n_clusters <= n_samples.
        k = min(num_clusters, num_samples)

        print(f"Generating clusters for {day} on hour: {hour}; k={k}...")

        kmeans = KMeans(n_clusters=k, **kmeans_kwargs)
        geo_df_n["cluster"] = kmeans.fit_predict(geo_df_n)
        geo_df_n["cluster"] = geo_df_n["cluster"].astype("category")

        # Convert to a GeoDataFrame in EPSG:4326 (lon/lat degrees).
        geo = gpd.GeoDataFrame(
            geo_df_n,
            geometry=gpd.points_from_xy(geo_df_n.lon, geo_df_n.lat),
            crs="EPSG:4326",
        )

        # Plot the map: basemap at zorder 1, clustered flashes at zorder 2.
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_aspect('equal')
        world.plot(ax=ax, alpha=0.4, color='whitesmoke', linestyle=':', edgecolor='black', zorder=1)
        geo.plot(ax=ax, column="cluster", alpha=0.7, cmap='viridis', linewidth=0.1, zorder=2)
        plt.title(f"Lightning Clusters on {day} Hour: {hour}; k={k}")
        plt.xlabel("longitude")
        plt.ylabel("latitude")

        # Zero-padded hour so file names sort chronologically (02 before 10),
        # which the GIF step relies on when it sorts the PNG paths.
        filename = f"{output_dir}/lightning_clusters_{day}_{hour:02d}.png"
        fig.savefig(filename, bbox_inches="tight", dpi=600)
        # Release the figure; otherwise one open figure accumulates per frame.
        plt.close(fig)

In [None]:
# Combine the saved PNG frames into an animated GIF
import glob
from PIL import Image
from IPython.display import HTML

filePath = "maps/sample_lightning_clusters.gif"


def png_to_gif(images_path, file_path, duration=500, size=(1200, 800)):
    """Combine the images matching a glob pattern into a looping animated GIF.

    Args:
        images_path: glob pattern selecting the frame images; frames are
            ordered by sorted file name.
        file_path: destination path of the GIF.
        duration: display time of each frame, in milliseconds.
        size: (width, height) every frame is resized to.

    Raises:
        ValueError: if no file matches ``images_path``.
    """
    paths = sorted(glob.glob(images_path))
    if not paths:
        raise ValueError(f"No images match {images_path!r}")

    frames = []
    for path in paths:
        # `with` closes the source file; resize() returns an independent image.
        with Image.open(path) as im:
            frames.append(im.resize(size, Image.Resampling.LANCZOS))

    # loop=0 repeats forever; the first frame carries the rest via append_images.
    frames[0].save(
        file_path,
        format="GIF",
        append_images=frames[1:],
        save_all=True,
        duration=duration,
        loop=0,
        optimize=True,
    )


png_to_gif(images_path="maps/*.png", file_path=filePath, duration=500)



In [None]:
# Display the generated GIF inline (HTML attributes are space-separated).
HTML(f'<img src="{filePath}" width="800" align="center">')