In [None]:
"""
Created by Alanna Wedum
    Last updated on 02/29/2024
    Reads in Mattingly AR ID file and plots ARs based on whether AR_data flag > 0
    Plots a circle within 200km of location of choice
    Change start and end time to plot over single or multiple time steps.
    Allows the user to hit play and watch the evolution of the AR during that time frame.
"""

import glob
import os
from datetime import datetime, timedelta

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from ipywidgets import interactive
from ipywidgets import Play, HBox, link
from ipywidgets import IntSlider
from ipywidgets import fixed
from matplotlib.animation import FuncAnimation
from matplotlib.colors import ListedColormap

# Input the latitude and longitude you're interested in looking at
summit_lat, summit_lon = 72.5796, -38.4592

# Reading Data
fp = "/data/Mattingly_ARs/ERA5_ARs_NH_3hr_updated/ERA5_ARs_NH_3hr/"  # Enter your filepath or work from current directory (".")
# os.path.join inserts the separator itself, so fp works with or without a trailing "/"
# (plain concatenation would turn "." into ".ARs_ERA5_...").
ds = xr.open_dataset(os.path.join(fp, "ARs_ERA5_NH_3hr_202210010000_202210312100.nc"))

# Convert all times to pandas datetime objects for easier indexing
all_times = pd.to_datetime(ds['time'].values)

# Time window to animate, inclusive on both ends; edit these to change the range.
start_time = pd.Timestamp('2022-10-16 00:00:00')
end_time = pd.Timestamp('2022-10-19 21:00:00')

# Positions along ds['time'] of every time step inside the window, kept as a
# plain list of Python ints (used by plot_func and to size the Play widget).
in_window = (all_times >= start_time) & (all_times <= end_time)
relevant_indices = np.flatnonzero(in_window).tolist()
if not relevant_indices:
    # Fail early with a clear message instead of an IndexError further down.
    raise ValueError(
        f"No time steps between {start_time} and {end_time} in the dataset "
        f"(dataset covers {all_times.min()} to {all_times.max()})."
    )


def haversine(lon1, lat1, lon2, lat2, radius=6371000.0):
    """
    Calculate the great-circle distance in meters between two points
    on the earth (specified in decimal degrees).

    Parameters
    ----------
    lon1, lat1 : float or array-like
        Longitude and latitude of the first point(s), in degrees.
    lon2, lat2 : float or array-like
        Longitude and latitude of the second point(s), in degrees.
        All four inputs are broadcast together by NumPy.
    radius : float, optional
        Sphere radius in meters; defaults to a mean Earth radius of
        6,371,000 m.

    Returns
    -------
    float or numpy.ndarray
        Distance along the sphere's surface, in the same units as ``radius``.
    """
    # convert decimal degrees to radians
    lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
    # haversine formula
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Floating-point rounding can push ``a`` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt/arcsin return NaN.
    c = 2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
    return radius * c

# Setting up the Plot
def _circle_points(lon, lat, radius_m, points=200, earth_radius_m=6371000.0):
    """
    Return (lons, lats) in degrees tracing a closed circle of ``radius_m``
    meters around (``lon``, ``lat``) on a sphere of radius ``earth_radius_m``.

    Each vertex is the exact spherical destination point reached by travelling
    ``radius_m`` from the centre along one of ``points`` evenly spaced bearings,
    so the ring stays circular at high latitudes (unlike a cos(lat)-scaled
    lon/lat offset). The first vertex is repeated at the end so a line plot
    closes the ring; the returned arrays therefore have ``points + 1`` entries.
    """
    phi1 = np.deg2rad(lat)
    lam1 = np.deg2rad(lon)
    delta = radius_m / earth_radius_m  # angular distance, radians
    bearings = np.linspace(0.0, 2.0 * np.pi, points, endpoint=False)
    phi2 = np.arcsin(np.sin(phi1) * np.cos(delta)
                     + np.cos(phi1) * np.sin(delta) * np.cos(bearings))
    lam2 = lam1 + np.arctan2(np.sin(bearings) * np.sin(delta) * np.cos(phi1),
                             np.cos(delta) - np.sin(phi1) * np.sin(phi2))
    lons = np.rad2deg(lam2)
    lats = np.rad2deg(phi2)
    return np.append(lons, lons[0]), np.append(lats, lats[0])


def plot_func(frame):
    """
    Draw one animation frame: AR cells for the selected time step, a 200 km
    ring and a star marking (summit_lon, summit_lat), plus a timestamp label.

    Parameters
    ----------
    frame : int
        Position within ``relevant_indices`` (0 .. len(relevant_indices) - 1),
        supplied by the ipywidgets slider.
    """
    projection = ccrs.LambertConformal(central_longitude=summit_lon,
                                       central_latitude=summit_lat)
    _, ax = plt.subplots(figsize=(8, 4), subplot_kw={'projection': projection})
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle='-')

    # Set the geospatial extents: [lon_min, lon_max, lat_min, lat_max] in degrees
    extents = [-90, 40, 40, 80]
    ax.set_extent(extents, crs=ccrs.PlateCarree())

    time_index = relevant_indices[frame]
    data = ds['AR_labels'][time_index, :, :]

    # Cells with a label >= 0.5 become 1.0, everything else NaN; pcolormesh
    # leaves NaN cells undrawn, so only AR cells are painted red. Thresholding
    # explicitly replaces a one-colour-per-unique-value colormap, whose colours
    # were spread linearly over [min, max] and could therefore leave some AR
    # labels in the transparent bin when label values were unevenly spaced.
    ar_cells = np.where(np.asarray(data) >= 0.5, 1.0, np.nan)
    ax.pcolormesh(ds['lon'], ds['lat'], ar_cells, transform=ccrs.PlateCarree(),
                  cmap=ListedColormap([(1, 0, 0)]), vmin=0, vmax=1)

    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

    # Circle of radius 200,000 meters (200 km) around the point of interest
    circle_lons, circle_lats = _circle_points(summit_lon, summit_lat, 200000, points=200)
    ax.plot(circle_lons, circle_lats, transform=ccrs.Geodetic(), color='blue')

    # Add a star at the point of interest
    ax.scatter(summit_lon, summit_lat, marker='*', color='black', s=50,
               transform=ccrs.PlateCarree())

    time_str = str(all_times[time_index])
    ax.text(0.02, 0.02, f'Time: {time_str}', transform=ax.transAxes, fontweight='bold')

# Sanity check: print the distinct AR_labels values in the first selected time step
print(np.unique(ds['AR_labels'][relevant_indices[0], :, :].values))

# Indices of the grid point nearest to the point of interest; not used by the
# code visible here.
# NOTE(review): assumes ds['lon'] uses the same -180..180 convention as summit_lon — confirm.
lon_index = (np.abs(ds['lon'].values - summit_lon)).argmin()
lat_index = (np.abs(ds['lat'].values - summit_lat)).argmin()

   
# Create the Play widget; it steps through frame positions 0 .. len(relevant_indices)-1
play = Play(
    value=0,
    min=0,
    max=len(relevant_indices)-1,
    step=1,
    interval=2000,  # interval is the time delay between frames in milliseconds
    description="Press play",
    disabled=False
)
    
# Create the interactive plot: ipywidgets builds an integer slider for ``frame``
# from the (min, max) tuple and re-runs plot_func whenever the value changes
interactive_plot = interactive(plot_func, frame=(0, len(relevant_indices)-1))

# Link the play widget to the frame slider so playback drives the slider.
# children[0] is taken to be the auto-generated slider for ``frame``, the only argument.
slider = interactive_plot.children[0]
link((play, 'value'), (slider, 'value'))

# Adjust the height of the output display
# (the last child of an ``interactive`` widget is taken to be its output area)
output = interactive_plot.children[-1]
output.layout.height = '650px'

# Display the play widget and interactive plot together.
# NOTE(review): ``display`` is not imported in this file; presumably it comes from
# the IPython/Jupyter namespace — this line will fail when run as a plain script.
display(HBox([play, interactive_plot]))