In [None]:
import numpy as np
import pandas as pd
import torch
from skyfield.api import load, wgs84
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation


# === Skyfield Setup ===
# Shared Skyfield objects used by the sub-point helpers below.
ts = load.timescale()      # turns calendar fields into Skyfield Time objects
eph = load('de421.bsp')    # JPL DE421 ephemeris
# NOTE(review): DE421 covers only a limited span of dates (roughly 1900-2050);
# confirm every entry of all_epochs falls inside it.
earth = eph['earth']
moon = eph['moon']
sun = eph['sun']

def _sub_body_lonlat(dt, body):
    """Return ``(lon, lat)`` in degrees of the ground point directly beneath *body* at *dt*.

    Shared implementation of the sub-solar / sub-lunar lookups.

    Parameters
    ----------
    dt : datetime-like
        Epoch (``datetime.datetime`` or ``pandas.Timestamp``). Naive values are
        interpreted as UTC. Timezone-aware values are converted to UTC via
        ``ts.from_datetime`` instead of having their local wall-clock fields
        misread as UTC.
    body :
        Target from the loaded ephemeris (``sun`` or ``moon``).

    Returns
    -------
    tuple
        ``(lon, lat)``: longitude wrapped to [-180, 180) and latitude, in degrees.
    """
    if dt.tzinfo is not None and dt.utcoffset() is not None:
        t = ts.from_datetime(dt)
    else:
        # ts.utc accepts a fractional second, so sub-second precision is kept.
        t = ts.utc(dt.year, dt.month, dt.day, dt.hour, dt.minute,
                   dt.second + dt.microsecond / 1e6)
    ground_point = wgs84.subpoint(earth.at(t).observe(body))
    lon = (ground_point.longitude.degrees + 180) % 360 - 180
    lat = ground_point.latitude.degrees
    return lon, lat


def get_subsolar_lonlat(dt):
    """Return ``(lon, lat)`` in degrees of the subsolar point at *dt*."""
    return _sub_body_lonlat(dt, sun)


def get_sublunar_lonlat(dt):
    """Return ``(lon, lat)`` in degrees of the sublunar point at *dt*."""
    return _sub_body_lonlat(dt, moon)

# Compute solar & lunar sub-point coordinates (lon, lat in degrees) for each epoch
solar_lonlat = np.array([get_subsolar_lonlat(dt) for dt in all_epochs])
lunar_lonlat = np.array([get_sublunar_lonlat(dt) for dt in all_epochs])


def _angular_distance_deg(lon_grid, lat_grid, lon0, lat0):
    """Return the great-circle angular distance (degrees) from ``(lon0, lat0)`` to each grid point.

    Uses the haversine formula. Unlike a Euclidean distance in (lon, lat)
    degree space, the result is:

    * continuous across the antimeridian;
    * independent of the grid's longitude convention ([-180, 180) vs [0, 360));
    * not stretched toward the poles.

    Values lie in [0, 180].
    """
    phi = np.radians(lat_grid)
    phi0 = np.radians(lat0)
    half_dphi = (phi - phi0) / 2
    half_dlam = np.radians(lon_grid - lon0) / 2
    hav = np.sin(half_dphi) ** 2 + np.cos(phi) * np.cos(phi0) * np.sin(half_dlam) ** 2
    # Clip protects sqrt/arcsin from tiny floating-point overshoot outside [0, 1].
    return np.degrees(2 * np.arcsin(np.sqrt(np.clip(hav, 0.0, 1.0))))


# === Create ML-ready dataset ===
# Layout: (time, channel, lat, lon)
#   channel 0 = distance to the subsolar point
#   channel 1 = distance to the sublunar point
# NOTE(review): assumes lats/lons are 1-D coordinate arrays in degrees.
time_steps = len(all_epochs)
lat_len, lon_len = len(lats), len(lons)
ml_data = np.zeros((time_steps, 2, lat_len, lon_len), dtype=np.float32)

lon_grid, lat_grid = np.meshgrid(lons, lats)  # each (lat_len, lon_len)

# Great-circle distances on the lat/lon grid to the subsolar & sublunar points
for idx, (sun_pt, moon_pt) in enumerate(zip(solar_lonlat, lunar_lonlat)):
    ml_data[idx, 0] = _angular_distance_deg(lon_grid, lat_grid, *sun_pt)
    ml_data[idx, 1] = _angular_distance_deg(lon_grid, lat_grid, *moon_pt)

# Save tensor for ML training (from_numpy shares memory, avoiding an extra copy)
torch.save(torch.from_numpy(ml_data), 'solar_lunar_positions_tensor.pt')
print("Saved solar_lunar_positions_tensor.pt")

# === Visualization for Sanity Check ===
# The grid and markers are plain lon/lat degrees, so PlateCarree is the data CRS.
data_crs = ccrs.PlateCarree()
fig, ax = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={'projection': data_crs})

# Fix colour limits over the whole time series so frames are comparable;
# per-frame autoscaling would silently rescale the colours every frame.
sun_norm = plt.Normalize(vmin=float(ml_data[:, 0].min()), vmax=float(ml_data[:, 0].max()))
moon_norm = plt.Normalize(vmin=float(ml_data[:, 1].min()), vmax=float(ml_data[:, 1].max()))

# Colorbars live in their own axes, so the per-frame ax.clear() leaves them intact.
fig.colorbar(plt.cm.ScalarMappable(norm=sun_norm, cmap='plasma_r'), ax=ax[0],
             shrink=0.7, label='distance (deg)')
fig.colorbar(plt.cm.ScalarMappable(norm=moon_norm, cmap='magma_r'), ax=ax[1],
             shrink=0.7, label='distance (deg)')


def update(frame_idx):
    """Redraw both panels for time step *frame_idx* (FuncAnimation callback)."""
    for a in ax:
        a.clear()
        a.coastlines()
        a.gridlines(draw_labels=True)

    epoch = all_epochs[frame_idx]

    ax[0].pcolormesh(lons, lats, ml_data[frame_idx, 0], cmap='plasma_r',
                     norm=sun_norm, transform=data_crs)
    ax[0].plot(solar_lonlat[frame_idx, 0], solar_lonlat[frame_idx, 1], 'o',
               color='red', markersize=10, transform=data_crs)
    ax[0].set_title(f"Subsolar Distance\n{epoch}")

    ax[1].pcolormesh(lons, lats, ml_data[frame_idx, 1], cmap='magma_r',
                     norm=moon_norm, transform=data_crs)
    ax[1].plot(lunar_lonlat[frame_idx, 0], lunar_lonlat[frame_idx, 1], 'o',
               color='yellow', markersize=10, transform=data_crs)
    ax[1].set_title("Sublunar Distance")


ani = animation.FuncAnimation(fig, update, frames=time_steps, interval=250)
ani.save('solar_lunar_positions_check.gif', writer='pillow', fps=4)
# Close the figure so the notebook does not also render a stray static copy.
plt.close(fig)
print("Saved solar_lunar_positions_check.gif")
