In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sunpy.map
from sunpy.visualization.colormaps import color_tables as ct
from astropy.visualization import ImageNormalize, SinhStretch
from astropy.time import Time
import astropy.units as u
from matplotlib.dates import DateFormatter
import csv
from datetime import timedelta
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import pickle
from functools import partial

# ---- Configuration ----
aia_folder = '/mnt/data2/AIA_processed_data/94'  # AIA FITS folder (512x512)
goes_csv = '/mnt/data/goes_combined/combined_g18_avg1m_20230701_20230815.csv'  # GOES CSV path
output_file = '../visualizations/aia_goes_synced.mp4'
frame_skip = 100  # Keep every frame_skip-th AIA file to speed up the animation
# Leave one CPU free and cap the pool at 50 workers, but never drop below one
# worker: cpu_count() - 1 is 0 on a single-core host and Pool(0) raises.
n_processes = max(1, min(cpu_count() - 1, 50))

window_length_minutes = 30  # Sliding window length on GOES plot (minutes)

# ---- Helper Functions for Multiprocessing ----
def load_aia_file(filepath):
    """Try to open one AIA file and report whether it is usable.

    Returns a dict with 'filepath', 'time' (the map's observation datetime,
    or None when loading failed) and 'valid'. A failed load additionally
    carries 'error' with the text of the exception that was raised.
    """
    # Start from the failure record and upgrade it only if the load succeeds.
    record = {'filepath': filepath, 'time': None, 'valid': False}
    try:
        record['time'] = sunpy.map.Map(filepath, lazy=True).date.datetime
    except Exception as exc:
        record['error'] = str(exc)
    else:
        record['valid'] = True
    return record

def preprocess_aia_data(args):
    """Load one AIA file and precompute everything its animation frame needs.

    Args:
        args: 4-tuple ``(filepath, goes_times, goes_flux,
            window_length_minutes)``, packed into one argument so the
            function can be used with ``Pool.imap``. ``goes_times`` and
            ``goes_flux`` are compared elementwise and boolean-indexed, so
            they are expected to be NumPy arrays when non-empty.

    Returns:
        On success, a dict with 'aia_data', 'aia_time', 'aia_wcs',
        'interp_flux' (GOES flux interpolated at the AIA time, 0 when there
        is no GOES data), 'windowed_times' / 'windowed_flux' (GOES samples
        inside the window centred on the AIA time), 'window_start',
        'window_end', 'vmax' (always > 0) and 'valid' = True.
        On any exception, a dict with 'filepath', 'valid' = False and
        'error' (the exception text).
    """
    filepath, goes_times, goes_flux, window_length_minutes = args

    try:
        aia_map = sunpy.map.Map(filepath)
        aia_time = aia_map.date.datetime

        # Sliding window limits on the GOES plot, centred on the AIA time.
        half_window = timedelta(minutes=window_length_minutes / 2)
        window_start = aia_time - half_window
        window_end = aia_time + half_window

        if len(goes_times) > 0:
            # Interpolate GOES flux at the current AIA time.
            interp_flux = np.interp(aia_time.timestamp(),
                                    [t.timestamp() for t in goes_times],
                                    goes_flux)
            # Keep only the GOES samples that fall inside the window.
            mask = (goes_times >= window_start) & (goes_times <= window_end)
            windowed_times = goes_times[mask]
            windowed_flux = goes_flux[mask]
        else:
            interp_flux = 0
            windowed_times = []
            windowed_flux = []

        # Upper display limit: 99th percentile of the finite pixels, falling
        # back to their maximum, and finally to 1.
        image = aia_map.data
        finite_pixels = image[np.isfinite(image)]
        vmax = 1
        if finite_pixels.size > 0:
            vmax = np.percentile(finite_pixels, 99)
            if vmax <= 0:
                vmax = np.max(finite_pixels)
            if vmax <= 0:
                # Every finite pixel is non-positive; a vmax <= vmin (0)
                # would give the normalisation an empty range.
                vmax = 1

        return {
            'aia_data': image,
            'aia_time': aia_time,
            'aia_wcs': aia_map.wcs,
            'interp_flux': interp_flux,
            'windowed_times': windowed_times,
            'windowed_flux': windowed_flux,
            'window_start': window_start,
            'window_end': window_end,
            'vmax': vmax,
            'valid': True
        }
    except Exception as e:
        return {
            'filepath': filepath,
            'valid': False,
            'error': str(e)
        }

# ---- Load GOES Data ----
# Read the 'time' and 'xrsb_flux' columns of the GOES CSV into parallel
# arrays. Rows that cannot be parsed are skipped (best effort), but they are
# counted and reported so a wrong column name or malformed file is visible
# instead of silently producing an empty light curve.
print("Loading GOES data...")
goes_times = []
goes_flux = []
skipped_rows = 0
# newline='' is what the csv module expects for files it reads.
with open(goes_csv, 'r', newline='') as f:
    reader = csv.DictReader(f)
    for row in tqdm(reader, desc="Reading GOES CSV"):
        try:
            t = Time(row['time']).to_datetime()
            fval = float(row['xrsb_flux'])
        except Exception:
            skipped_rows += 1
            continue
        # Append only after both fields parsed, keeping the lists aligned.
        goes_times.append(t)
        goes_flux.append(fval)

goes_times = np.array(goes_times)
goes_flux = np.array(goes_flux)
print(f'Loaded {len(goes_times)} GOES data points')
if skipped_rows:
    print(f'Skipped {skipped_rows} GOES rows that could not be parsed')

# ---- Gather and Validate AIA Files ----
# Collect every *.fits file in the AIA folder, in sorted path order.
print("Gathering AIA files...")
fits_pattern = os.path.join(aia_folder, '*.fits')
all_aia_files = glob.glob(fits_pattern)
all_aia_files.sort()
print(f"Found {len(all_aia_files)} AIA files")

# Nothing to animate without input files.
if not all_aia_files:
    print("No AIA files found. Please check the folder path.")
    exit()

# Validate files using multiprocessing
# Each worker opens one file via load_aia_file; imap preserves input order.
print("Validating AIA files using multiprocessing...")
with Pool(n_processes) as pool:
    validation_stream = pool.imap(load_aia_file, all_aia_files)
    validation_results = list(tqdm(
        validation_stream,
        total=len(all_aia_files),
        desc="Validating files"
    ))

# Filter valid files
# valid_files and aia_times are parallel lists (same index, same file).
valid_files = []
aia_times = []
for result in validation_results:
    if not result['valid']:
        print(f"Skipping invalid file: {result['filepath']}")
        continue
    valid_files.append(result['filepath'])
    aia_times.append(result['time'])

print(f'{len(valid_files)} valid AIA files found.')

# Check if we have valid files
if not valid_files:
    print("No valid AIA files found. Please check the folder path.")
    exit()

# Select files for animation (apply frame_skip)
# The same stride is applied to both parallel lists so they stay aligned.
frame_stride = slice(None, None, frame_skip)
animation_files = valid_files[frame_stride]
animation_times = aia_times[frame_stride]
frame_count = len(animation_files)

print(f"Will create animation with {frame_count} frames")

# ---- Preprocess Animation Data ----
print("Preprocessing animation data...")
# preprocess_aia_data takes a single tuple, so the shared GOES inputs are
# bundled with each file path.
shared_goes_args = (goes_times, goes_flux, window_length_minutes)
preprocess_args = [(path, *shared_goes_args) for path in animation_files]

with Pool(n_processes) as pool:
    frame_stream = pool.imap(preprocess_aia_data, preprocess_args)
    frame_data = list(tqdm(
        frame_stream,
        total=len(preprocess_args),
        desc="Preprocessing frames"
    ))

# Filter out invalid frames
valid_frame_data = list(filter(lambda fd: fd.get('valid', False), frame_data))
print(f"Successfully preprocessed {len(valid_frame_data)} frames")

if not valid_frame_data:
    print("No valid frames to animate.")
    exit()

# ---- Set up Plot ----
print("Setting up animation...")
cmap = ct.aia_color_table(94 * u.angstrom)

fig = plt.figure(figsize=(12, 6))

# Left panel: GOES light curve. Both artists start empty and are filled in
# by the animation's update function.
ax1 = fig.add_subplot(1, 2, 1)
line_goes, = ax1.plot([], [], color='orange', label='GOES SXR')
goes_marker, = ax1.plot([], [], 'ro', markersize=8, label='Current Time')
ax1.set(title='GOES SXR Light Curve', ylabel='Flux (W/m²)', xlabel='Time')
ax1.legend()
ax1.grid(True)
ax1.xaxis.set_major_formatter(DateFormatter('%H:%M'))

# Right panel: AIA image. Use the first frame's WCS as the projection and
# fall back to a plain axes if that fails.
try:
    first_frame = valid_frame_data[0]
    ax2 = fig.add_subplot(1, 2, 2, projection=first_frame['aia_wcs'])
except Exception as e:
    print(f"Error setting up WCS projection: {e}")
    ax2 = fig.add_subplot(1, 2, 2)

def update(i):
    """Draw animation frame ``i`` on both panels.

    Moves the current-time marker, replaces the GOES curve with the frame's
    windowed samples, rescales the GOES axes, and redraws the AIA image.
    An out-of-range ``i`` leaves the figure untouched.

    Args:
        i: Index into ``valid_frame_data``.

    Returns:
        Tuple ``(line_goes, goes_marker)`` of the GOES artists.
    """
    if i >= len(valid_frame_data):
        return line_goes, goes_marker

    frame = valid_frame_data[i]

    # Update GOES plot
    goes_marker.set_data([frame['aia_time']], [frame['interp_flux']])
    ax1.set_xlim(frame['window_start'], frame['window_end'])

    if len(frame['windowed_times']) > 0:
        line_goes.set_data(frame['windowed_times'], frame['windowed_flux'])
        # Scale the y-axis from finite samples only: NaN/Inf flux values
        # would otherwise propagate into set_ylim, which rejects them.
        flux = np.asarray(frame['windowed_flux'], dtype=float)
        finite_flux = flux[np.isfinite(flux)]
        if finite_flux.size > 0:
            ax1.set_ylim(np.min(finite_flux) * 0.9,
                         np.max(finite_flux) * 1.1)
    else:
        line_goes.set_data([], [])

    # Clear and update AIA plot
    ax2.cla()

    # Plot AIA data
    if frame['aia_data'].size > 0:
        norm = ImageNormalize(vmin=0, vmax=frame['vmax'], stretch=SinhStretch())

        try:
            # Create a temporary map for plotting
            temp_map = sunpy.map.Map(frame['aia_data'], frame['aia_wcs'])
            temp_map.plot(axes=ax2, cmap=cmap, norm=norm)
        except Exception as e:
            print(f"Error plotting AIA map: {e}")
            # Fallback to simple imshow
            ax2.imshow(frame['aia_data'], cmap=cmap, norm=norm, origin='lower')

    # Title and grid are reset every frame because cla() wiped them.
    ax2.set_title(f'AIA 94 Å\n{frame["aia_time"].strftime("%Y-%m-%d %H:%M:%S")}')
    ax2.grid(False)

    return line_goes, goes_marker

# ---- Create Animation ----
n_animation_frames = len(valid_frame_data)
print(f"Creating animation with {n_animation_frames} frames...")

# Progress bar ticked once per rendered frame (closed after saving).
pbar = tqdm(total=n_animation_frames, desc="Creating animation")


def animate_with_progress(frame_num):
    """Tick the progress bar, then draw frame ``frame_num`` via update()."""
    pbar.update(1)
    artists = update(frame_num)
    return artists


# interval is the delay between frames in milliseconds.
ani = animation.FuncAnimation(
    fig,
    animate_with_progress,
    frames=n_animation_frames,
    interval=100,
    blit=False,
    repeat=False,
)

# ---- Save animation ----
# Primary path: encode output_file with ffmpeg. If that fails for any reason
# (e.g. ffmpeg not installed), fall back to the pillow writer. Pillow cannot
# write .mp4, so the fallback goes to a .gif next to the requested file.
print(f"Saving animation to {output_file}...")

# Progress bar for saving; created up front so it can always be closed.
save_pbar = tqdm(total=len(valid_frame_data), desc="Saving animation")


def progress_callback(current_frame, total_frames):
    """Advance the save progress bar by one frame."""
    save_pbar.update(1)


try:
    # Make sure the destination folder exists before any writer opens it.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Use different writer options for better compatibility
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=10, metadata=dict(artist='SunPy'), bitrate=1800)

    ani.save(output_file, writer=writer, dpi=150, progress_callback=progress_callback)
    print(f'Animation saved to {output_file}')
except Exception as e:
    print(f"Error saving animation: {e}")
    print("Trying alternative save method...")
    fallback_file = os.path.splitext(output_file)[0] + '.gif'
    try:
        ani.save(fallback_file, writer='pillow', fps=10)
        print(f'Animation saved to {fallback_file} using pillow writer')
    except Exception as e2:
        print(f"Failed to save animation: {e2}")
finally:
    save_pbar.close()

pbar.close()
print("Animation creation complete!")
plt.show()  # Display the final frame

Loading GOES data...


Reading GOES CSV: 66240it [00:16, 3998.36it/s]


Loaded 66240 GOES data points
Gathering AIA files...
Found 61995 AIA files
Validating AIA files using multiprocessing...


Validating files: 100%|██████████| 61995/61995 [00:52<00:00, 1172.70it/s]


61995 valid AIA files found.
Will create animation with 620 frames
Preprocessing animation data...


Preprocessing frames:  14%|█▍        | 87/620 [00:12<01:19,  6.74it/s]