[Reference](https://medium.com/@yocissms/lets-replicate-a-paper-lightning-flash-clustering-and-analysis-2304e82feafd)

In [1]:
import numpy as np
from typing import List, Dict, Tuple
import gzip
import datetime

# Annotation alias for the float64 NumPy arrays used in the function signatures below.
array_type = np.ndarray[tuple[float], np.dtype[np.float64]]

def load_dat(gz_file_path: str, min_stations: int=7, max_chi_squared: float=1.0, max_altitude: float=20e3) -> Tuple[array_type, datetime.datetime]:

    with gzip.open(gz_file_path, 'rt') as f:
        lines = f.readlines()

    start_time_str = ""
    for l in lines:
        if "Data start time:" in l:
            start_time_str = l.replace("Data start time:", "").strip()

    start_time = datetime.datetime.strptime(start_time_str, "%m/%d/%y %H:%M:%S")
    idx = 1 + lines.index("*** data ***\n")
    lines = lines[idx:]
    data = np.zeros((len(lines), 7))

    for line_idx, l in enumerate(lines):
        splt = l.strip().split()
        for j in range(6):
            data[line_idx,j] = float(splt[j])
        data[line_idx,6] = float(int(splt[6], 0).bit_count())

    # from https://github.com/deeplycloudy/lmatools/blob/8d55e11dfbbe040f58f9a393f83e33e2a4b84b4c/examples/flashsort/clustertests/lma.py#L144
    return data[(data[:,6] >= min_stations) & (data[:,4] <= max_chi_squared) & (data[:,4] < max_altitude)], start_time

In [2]:
def to_ecef(spatial_data: array_type) -> array_type:
    """Convert geodetic coordinates to Earth-centred, Earth-fixed x, y, z.

    Args:
        spatial_data: 2-D array whose first three columns are latitude
            (degrees), longitude (degrees) and altitude (same length unit as
            the ellipsoid axes, i.e. metres).

    Returns:
        A new array with the same shape as ``spatial_data`` whose first three
        columns hold x, y and z. The input is not modified.
    """
    rad_lat = np.deg2rad(spatial_data[:, 0])
    rad_lon = np.deg2rad(spatial_data[:, 1])
    altitudes = spatial_data[:, 2]

    # WGS-84 semi-major / semi-minor axes and squared first eccentricity.
    a = 6378137.0
    b = 6356752.314245
    e2 = 1 - b**2 / a**2

    sin_lat = np.sin(rad_lat)
    cos_lat = np.cos(rad_lat)

    # Prime-vertical radius of curvature. 1 / (1 + cot^2) equals sin^2, and
    # using sin^2 directly avoids the division by zero at latitude 0.
    n_phi = a / np.sqrt(1 - e2 * np.square(sin_lat))

    transformed_data = np.zeros(spatial_data.shape)
    transformed_data[:, 0] = (n_phi + altitudes) * cos_lat * np.cos(rad_lon)
    transformed_data[:, 1] = (n_phi + altitudes) * cos_lat * np.sin(rad_lon)
    transformed_data[:, 2] = ((1 - e2) * n_phi + altitudes) * sin_lat

    return transformed_data

In [6]:
import numpy as np
import os

# Geodetic reference point (lat, lon, alt) and its ECEF position; source
# positions below are expressed relative to it.
center_geo = np.array([40.4463980, -104.6368130, 1000.00]) # COLMA center
lma_center = to_ecef(np.expand_dims(center_geo, axis=0))

# Square lat/lon grid of n_grid_points cells per side, centred on center_geo.
n_grid_points, grid_spacing = 18, 0.2559
grid_start_lat = center_geo[0] - (n_grid_points / 2) * grid_spacing
grid_start_lon = center_geo[1] - (n_grid_points / 2) * grid_spacing

data, start_time = load_dat(os.path.join(data_dir, fname), min_stations=min_stations, max_chi_squared=max_chi_squared, max_altitude=max_altitude)

# Grid cell index of every source; values outside [0, n_grid_points) are off-grid.
data_grid_lat = np.floor( (data[:,1] - grid_start_lat) / grid_spacing )
data_grid_lon = np.floor( (data[:,2] - grid_start_lon) / grid_spacing )

# Subtract the ECEF centre computed above. The original referenced an
# undefined name `center`; lma_center has shape (1, 3) and broadcasts over rows.
spatial_data = to_ecef(data[:,1:4]) - lma_center

# sources[time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt]
sources = np.zeros((data.shape[0], 10))
sources[:,0] = data[:,0]
sources[:,1:4] = spatial_data
sources[:,4] = data[:,5]
sources[:,5] = data_grid_lat
sources[:,6] = data_grid_lon
sources[:,7:] = data[:,1:4] # retain initial geodetic coordinates

In [7]:
from sklearn.cluster import DBSCAN

# sources[time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt]
def cluster(sources: array_type, xyz_scale: float, t_scale: float, grid_max: int, min_samples: int=10, epsilon: float=1.0, max_duration: float=3.0) -> Tuple[List[array_type], int]:
    """Group sources into flashes with DBSCAN over overlapping time windows.

    The time axis is walked in steps of ``max_duration``. At each step the
    sources inside a window of length ``2 * max_duration`` are clustered on
    (x, y, z, t) after dividing space by ``xyz_scale`` and time by ``t_scale``.
    Every cluster that has at least one source in the first half of the window
    is finalised (passed to ``get_cluster_list``) and its sources are removed;
    clusters lying entirely in the second half are left for the next step.

    Args:
        sources: Array with columns
            ``[time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt]``.
            The caller's array is not modified.
        xyz_scale: Divisor applied to columns 1-3 before clustering.
        t_scale: Divisor applied to column 0 before clustering.
        grid_max: Exclusive upper bound for grid indices, forwarded to
            ``get_cluster_list``.
        min_samples: DBSCAN ``min_samples``.
        epsilon: DBSCAN ``eps``.
        max_duration: Half-length of the clustering window, in the units of
            column 0.

    Returns:
        ``(flashes, n_removed)``: the list of per-flash source arrays and the
        number of flashes rejected by ``get_cluster_list``.
    """
    all_flashes: List[array_type] = []
    total_removed = 0

    # np.min / np.max raise on an empty array; no sources means no flashes.
    if sources.shape[0] == 0:
        return all_flashes, total_removed

    time_start = np.min(sources[:,0])
    max_t = np.max(sources[:,0])

    algorithm = DBSCAN(eps=epsilon, min_samples=min_samples)

    while time_start <= max_t:
        if sources.shape[0] == 0:
            # Everything has been assigned already; nothing left to cluster.
            break

        times = sources[:,0]
        # Absolute row positions of the window, so later indexing does not
        # depend on the window starting at row 0 or on the rows being sorted.
        window_idx = np.flatnonzero((times >= time_start) & (times < time_start + max_duration*2))

        if window_idx.size > 0:
            window = sources[window_idx]
            in_first_half = window[:,0] < time_start + max_duration

            dbscan_data = np.zeros((window.shape[0], 4))
            dbscan_data[:,:3] = window[:,1:4] / xyz_scale
            dbscan_data[:,3] = window[:,0] / t_scale

            labels = algorithm.fit(dbscan_data).labels_

            # Real clusters touching the first half. The noise label (-1) is
            # excluded on purpose: otherwise every noise point of the second
            # half would be discarded before the next window gets the chance
            # to cluster it with later sources.
            first_half_labels = np.unique(labels[in_first_half])
            first_half_labels = first_half_labels[first_half_labels != -1]

            # First-half sources are always consumed (as flash members or as
            # noise); second-half sources only when their cluster reaches back.
            finalized = in_first_half | np.isin(labels, first_half_labels)

            cluster_info, n_removed = get_cluster_list(window[finalized], labels[finalized], grid_max)

            all_flashes += cluster_info
            total_removed += n_removed

            keep = np.ones(len(sources), bool)
            keep[window_idx[finalized]] = False
            sources = sources[keep]

        time_start += max_duration

    return all_flashes, total_removed

def get_cluster_list(sources: array_type, cluster_ids: np.ndarray[int, np.dtype[np.int64]], grid_max: int) -> Tuple[List[array_type], int]:
    """Split labelled sources into per-cluster arrays and filter them.

    Label ``-1`` is skipped. A cluster is rejected (and counted) when any of
    its grid indices in columns 5-6 lies outside ``(-1, grid_max)``; an
    in-bounds cluster is returned only if it has at least five sources.

    Args:
        sources: Source rows, one per entry of ``cluster_ids``.
        cluster_ids: Cluster label of every row of ``sources``.
        grid_max: Exclusive upper bound for the grid indices.

    Returns:
        ``(clusters, n_removed)``: the accepted per-cluster source arrays and
        the number of clusters rejected for being out of bounds.
    """
    min_sources = 5
    kept: List[array_type] = []
    rejected = 0

    for label in np.unique(cluster_ids):
        if label == -1:
            continue

        members = sources[cluster_ids == label]
        grid_cells = members[:,5:7]

        # remove flashes with out-of-bounds sources
        inside = (grid_cells > -1) & (grid_cells < grid_max)
        if not inside.all():
            rejected += 1
        elif members.shape[0] >= min_sources:
            kept.append(members)

    return kept, rejected

In [8]:
def merge_flashes(flashes: List[array_type], t_threshold: float=0.15, xyz_threshold: float=3000.0) -> List[array_type]:
    """Merge flashes that start close, in time and space, to an earlier flash.

    Flashes are ordered by the time of their first source. A later ("branch")
    flash is merged into an earlier ("base") flash when its first source
    starts no more than ``t_threshold`` after the base's last source and lies
    within ``xyz_threshold`` of at least one base source. If several bases
    qualify, the one with the smallest such distance wins. A branch whose
    chosen base was itself merged is attached to that base's own target, so
    chains collapse into a single flash.

    Args:
        flashes: Per-flash source arrays with time in column 0 and x, y, z in
            columns 1-3. Rows are expected in time order. Inputs are not
            modified.
        t_threshold: Maximum gap between base end and branch start.
        xyz_threshold: Maximum distance between the branch's first source and
            the nearest base source.

    Returns:
        The merged flashes, ordered by start time, each with rows sorted by
        time.
    """
    flashes_sorted = sorted(flashes, key=lambda x: x[0,0])
    n_flashes = len(flashes_sorted)

    # For every flash: index of the flash it is merged into (-1 = none) and
    # the distance that justified that choice.
    merge_target = np.full(n_flashes, -1, dtype=np.int64)
    merge_dist = np.full(n_flashes, np.inf)

    for j in range(1, n_flashes):
        branch_start = flashes_sorted[j][0]

        for i in range(j):
            base_flash = flashes_sorted[i]

            if branch_start[0] - base_flash[-1,0] > t_threshold:
                continue

            min_dist = np.min(np.linalg.norm(base_flash[:,1:4] - branch_start[1:4], axis=1))

            if min_dist <= xyz_threshold and min_dist < merge_dist[j]:
                # Record the decision for the *branch* j. The original wrote it
                # to the base's row, which made the base point at itself and
                # get deleted while the branch was never merged.
                merge_target[j] = merge_target[i] if merge_target[i] > -1 else i
                merge_dist[j] = min_dist

    flashes_merged: List[array_type | None] = list(flashes_sorted)

    for idx in range(n_flashes):
        target = int(merge_target[idx])

        if target > -1:
            # Append to the accumulated target so that several branches
            # merging into the same base are all retained.
            merged_sources = np.concatenate((flashes_merged[target], flashes_sorted[idx]))
            # Reorder whole rows by time; np.sort(axis=0) would sort every
            # column independently and scramble the sources.
            order = np.argsort(merged_sources[:,0], kind="stable")
            flashes_merged[target] = merged_sources[order]
            flashes_merged[idx] = None

    return [elem for elem in flashes_merged if elem is not None]

In [9]:
from scipy.spatial import ConvexHull

def _hull_area_m2(geodetic: array_type) -> float:
    """Return the 2-D convex-hull area of geodetic points on their local tangent plane."""
    origin = np.mean(geodetic, axis=0)
    lat0 = np.deg2rad(origin[0])
    lon0 = np.deg2rad(origin[1])

    # ECEF offsets from the centroid, rotated into local east / north axes.
    offsets = to_ecef(geodetic) - to_ecef(origin[np.newaxis, :])
    east = -np.sin(lon0) * offsets[:,0] + np.cos(lon0) * offsets[:,1]
    north = (
        -np.sin(lat0) * np.cos(lon0) * offsets[:,0]
        - np.sin(lat0) * np.sin(lon0) * offsets[:,1]
        + np.cos(lat0) * offsets[:,2]
    )

    hull = ConvexHull(np.column_stack((east, north)))
    # For 2-D input scipy reports the enclosed area as `volume`
    # (`area` would be the perimeter).
    return hull.volume


def get_flash_params(flashes: List[array_type], start_time: datetime.datetime) -> List[Dict]:
    """Compute summary parameters for every flash.

    Args:
        flashes: Per-flash source arrays with columns
            ``[time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt]``,
            rows in time order. The arrays are not modified.
        start_time: Data start time of the file; its date is combined with
            the flash's seconds-of-day time to build ``init_time``.

    Returns:
        One dict per flash with the keys ``n_sources``, ``duration``,
        ``mean_power``, ``grid_points``, ``init_alt``, ``init_time``,
        ``hull_area`` (square km) and ``dist_from_center`` (km).
    """
    flash_params: List[Dict] = []

    # Seconds-of-day of the data start; identical for every flash.
    start_seconds = start_time.hour * 3600 + start_time.minute * 60 + start_time.second

    for flash in flashes:
        grid_points = np.unique(flash[:,5:7], axis=0).astype('int64')
        n_sources = len(flash)
        duration = flash[-1,0] - flash[0,0]
        mean_power = np.mean(flash[:,4])

        # 2d hull area: see https://github.com/deeplycloudy/lmatools/blob/8d55e11dfbbe040f58f9a393f83e33e2a4b84b4c/lmatools/flashsort/flash_stats.py#L112
        # Computed from the slice without modifying it: the original shifted
        # these columns in place, corrupting the flash (and init_alt below).
        hull_area_m2 = _hull_area_m2(flash[:,7:10])

        # Columns 1-3 are ECEF offsets from the LMA centre, so all three
        # components are needed for the distance.
        dist_from_center = np.linalg.norm(flash[0,1:4]) / 1000.0 # km

        init_time = flash[0,0]
        seconds_from_start = init_time - start_seconds
        init_datetime = start_time + datetime.timedelta(seconds=seconds_from_start)

        flash_params.append({'n_sources': n_sources, 'duration': duration, 'mean_power': mean_power, 'grid_points': grid_points.tolist(),
            'init_alt': flash[0,9], 'init_time': init_datetime.strftime("%m/%d/%y %H:%M:%S"), 'hull_area': hull_area_m2*1e-6,
            'dist_from_center': dist_from_center})

    return flash_params


In [12]:
import numpy as np
import os
from datetime import date, timedelta
import json
from pathlib import Path

# Process every file of data_dir: cluster its sources into flashes, merge
# them, and write the per-flash parameters (and optionally the sources).
all_files = os.listdir(data_dir)

for f_idx, fname in enumerate(all_files):
    # Drop all extensions (e.g. ".dat.gz"). str.rstrip() takes a *set of
    # characters*, so it would also eat trailing stem characters that happen
    # to occur in the extensions ("data.dat.gz" -> ""); removesuffix() removes
    # exactly the suffix string.
    file_stem = str(Path(fname)).removesuffix(''.join(Path(fname).suffixes))
    print(f"processing {file_stem} ({f_idx + 1} / {len(all_files)})")

    # load data
    # ...

    flashes, n_removed = cluster(sources, xyz_scale, t_scale, n_grid_points, min_samples=min_samples, epsilon=epsilon, max_duration=max_duration)

    if len(flashes) == 0:
        continue

    flashes_merged = merge_flashes(flashes, t_threshold=merge_t_threshold, xyz_threshold=merge_xyz_threshold)
    merge_count = len(flashes) - len(flashes_merged)

    if save_flash_sources:
        # One row per source; the extra last column holds the flash index.
        total_rows = sum(elem.shape[0] for elem in flashes_merged)
        flashes_merged_npy = np.zeros((total_rows, 1+flashes_merged[0].shape[1]))
        counter = 0
        for flash_idx, flash in enumerate(flashes_merged):
            flashes_merged_npy[counter:counter+flash.shape[0],:-1] = flash
            flashes_merged_npy[counter:counter+flash.shape[0],-1] = flash_idx
            counter += flash.shape[0]

        with open(os.path.join(out_dir, f'{file_stem}_sources.npy'), 'wb') as f:
            np.save(f, flashes_merged_npy)

    flash_params = get_flash_params(flashes_merged, start_time)

    with open(os.path.join(out_dir, f'{file_stem}.json'), 'w') as f:
        json.dump({'data_start_time': start_time.strftime("%m/%d/%y %H:%M:%S"), 'n_removed': n_removed, 'merge_count': merge_count, 'flash_params': flash_params}, f)


In [13]:
import matplotlib.pyplot as plt

# Collect per-flash parameters into distance buckets of width bucket_size
# and draw one box plot per parameter.
n_buckets = int(max_dist // bucket_size)
param_names = ('duration', 'hull_area', 'mean_power', 'n_sources')
param_units = ('(ms)', '(square km)', '(dBW)', '')
stats_buckets = [{param: [] for param in param_names} for i in range(n_buckets)]

for fname in os.listdir(json_dir):
    with open(os.path.join(json_dir, fname), 'r') as f:
        data = json.load(f)

    flashes = data['flash_params']
    # NOTE(review): total_flash_count is not initialised in this cell —
    # confirm it is set to 0 before this loop runs.
    total_flash_count += len(flashes)

    for flash in flashes:
        bucket_idx = int(flash['dist_from_center'] // bucket_size)
        grid_x_vals, grid_y_vals = [pt[0] for pt in flash['grid_points']], [pt[1] for pt in flash['grid_points']]

        # Keep only flashes lying entirely inside the selected grid window.
        if not ( (grid_min_x <= min(grid_x_vals) and grid_max_x > max(grid_x_vals)) and (grid_min_y <= min(grid_y_vals) and grid_max_y > max(grid_y_vals)) ):
            continue

        # Valid bucket indices are 0 .. len-1, so an index equal to the length
        # must be skipped as well (">" alone raised IndexError for it).
        if bucket_idx >= len(stats_buckets):
            continue
        for param in param_names:
            stats_buckets[bucket_idx][param].append(flash[param])

percentiles = (5, 50, 95)
all_flash_durations = []
for elem in stats_buckets:
    all_flash_durations += elem['duration']
all_flash_durations = np.array(all_flash_durations)

for p in percentiles:
    d = np.percentile(all_flash_durations, p) * 1000
    print(f'percentile {p} of flash durations (milliseconds): {round(d)}')

plt.rc('xtick', labelsize=8)
fig, axs = plt.subplots(2, 2)

for param_idx, param in enumerate(param_names):
    scale_val = 1e3 if param == 'duration' else 1.0 # convert duration to ms
    param_vals = [np.array(elem[param]) * scale_val for elem in stats_buckets]
    ax = axs[param_idx//2, param_idx % 2]

    if param != 'mean_power':
        ax.set_yscale('log')

    # Each box is labelled with the upper edge of its distance bucket.
    ax.boxplot(param_vals, tick_labels=[str(i) for i in range(bucket_size, bucket_size*(len(param_vals)+1), bucket_size)], showfliers=False)
    ax.set(xlabel='distance from LMA center (km)', ylabel=f'{param} {param_units[param_idx]}')
    ax.set_title(param)

plt.tight_layout()
plt.savefig(os.path.join(plot_dir, 'flash_params.png'))
plt.close()

In [14]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import datetime

# Counter is used below but not imported in the visible part of this file;
# importing it here keeps the cell self-contained.
from collections import Counter, defaultdict

# Per grid cell: hull areas of the flashes touching it, and the hour
# ("mm/dd/yy HH") in which each of those flashes started.
all_flash_areas = defaultdict(list)
all_flash_hours = defaultdict(list)

for fname in os.listdir(json_dir):
    with open(os.path.join(json_dir, fname), 'r') as f:
        data = json.load(f)

    flashes = data['flash_params']

    for flash in flashes:
        # The hour key depends only on the flash, so parse it once per flash
        # instead of once per grid cell.
        dt = datetime.datetime.strptime(flash['init_time'], "%m/%d/%y %H:%M:%S")
        hour_key = dt.strftime("%m/%d/%y %H")

        for grid_point in flash['grid_points']:
            cell = tuple(grid_point)
            all_flash_areas[cell].append(flash['hull_area'])
            all_flash_hours[cell].append(hour_key)

area_grid = np.zeros((grid_size, grid_size))
lh_grid = np.zeros((grid_size, grid_size))
fph_grid = np.zeros((grid_size, grid_size))

for k, v in all_flash_areas.items():
    # np.median averages the two middle values for an even count; picking
    # element len // 2 of the sorted list returned the upper one instead.
    area_grid[k[0],k[1]] = np.median(v)

for k, v in all_flash_hours.items():
    cntr = Counter(v)
    # Number of distinct hours with lightning, and mean flashes per such hour.
    lh_grid[k[0],k[1]] = len(cntr)
    fph_grid[k[0],k[1]] = np.mean(np.array(list(cntr.values())))

grids = (area_grid, lh_grid, fph_grid)

fig, axs = plt.subplots(1, 3, figsize=(9, 3), subplot_kw={'projection': ccrs.PlateCarree()})

for param_idx, param in enumerate(('flash_area', 'lightning_hours', 'flashes_per_lightning_hour')):
    grid = grids[param_idx]
    vmax = 100.0 if param == 'flash_area' else np.max(grid)
    ax = axs[param_idx]

    ax.add_feature(cfeature.STATES, linewidth=1.4, edgecolor='white')
    plot_extent = [*grid_longitude_range, *grid_latitude_range]
    ax.set_extent(plot_extent,crs=ccrs.PlateCarree())
    # NOTE(review): the grids are indexed [grid_lat, grid_lon]; if grid_lat
    # index 0 is the southernmost row, origin='upper' draws the map flipped
    # north-south — confirm against grid_latitude_range.
    ax.imshow(grid, origin='upper', cmap='jet', extent=plot_extent, vmax=vmax, transform=ccrs.PlateCarree())

    for landmark_idx in range(len(landmark_names)):
        landmark_coords = landmark_coordinates[landmark_idx]

        ax.plot(landmark_coords[1], landmark_coords[0], 'wo', markersize=7, transform=ccrs.PlateCarree())
        ax.text(landmark_coords[1]-0.001, landmark_coords[0]+0.2, landmark_names[landmark_idx], color='black', size=14, path_effects=[pe.withStroke(linewidth=2, foreground="white")], transform=ccrs.PlateCarree())
    ax.set_title(param)

plt.savefig(os.path.join(plot_dir, 'grid_stats.png'))