# In [None]:
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from matplotlib.animation import FFMpegWriter
from matplotlib.colors import ListedColormap

def _solar_colormap():
    """Build the theme lookup table and the discrete colormap for thematic maps.

    Returns:
        A tuple ``(theme_index, n_colors, cmap)``: ``theme_index`` maps each
        theme name to the integer label stored in the thematic maps,
        ``n_colors`` is the number of colortable entries (largest label + 1),
        and ``cmap`` is a ``ListedColormap`` whose i-th color is the color of
        label i.
    """
    # Single source of truth for (theme name, label) pairs; no theme uses label 2
    solar_classes = [('unlabeled', 0),
                     ('outer_space', 1),
                     ('bright_region', 3),
                     ('filament', 4),
                     ('prominence', 5),
                     ('coronal_hole', 6),
                     ('quiet_sun', 7),
                     ('limb', 8),
                     ('flare', 9)]
    # Plot color for each theme
    solar_colors = {"unlabeled": "white",
                    "outer_space": "black",
                    "bright_region": "#F0E442",
                    "filament": "#D55E00",
                    "prominence": "#E69F00",
                    "coronal_hole": "#009E73",
                    "quiet_sun": "#0072B2",
                    "limb": "#56B4E9",
                    "flare": "#CC79A7"}
    theme_index = dict(solar_classes)
    class_name = {number: theme for theme, number in solar_classes}
    # Labels without a theme are padded with white so colortable[i] is always label i's color
    colortable = [solar_colors[class_name[i]] if i in class_name else 'white'
                  for i in range(max(class_name) + 1)]
    return theme_index, len(colortable), ListedColormap(colortable)


def _theme_overlay_mask(thmap, labels):
    """Return a float copy of ``thmap`` keeping only pixels whose label is in ``labels``.

    Every other pixel is NaN. ``imshow`` draws NaN with the colormap's "bad"
    color (transparent by default), so only the selected themes are visible.
    """
    thmap = np.asarray(thmap)
    mask = np.full(thmap.shape, np.nan)
    keep = np.isin(thmap, labels)
    mask[keep] = thmap[keep]
    return mask


def generate_spades_movie(movie_save_path, fits_thmaps, fits_composites=None,
                          composite_plot=False, composite_overlay=False, channel=None,
                          themes=None, fps=5, dpi=100):
    """Render a time series of thematic maps into a movie with ffmpeg.

    One frame is written per thematic map, titled with the map's observation
    date. Three layouts are supported, and ``composite_plot`` takes precedence
    over ``composite_overlay``:

    1. ``composite_plot=True``: the composite image selected by ``channel``
       and the thematic map, side by side.
    2. ``composite_overlay=True``: the composite image selected by ``channel``
       with the pixels of every theme in ``themes`` overlaid semi-transparently.
    3. Otherwise: the thematic map alone.

    Args:
        movie_save_path: output path handed to the ffmpeg writer.
        fits_thmaps: thematic-map FITS files, already sorted by time. The map
            is read from the primary HDU data and the date from the primary
            header's ``DATE_OBS`` keyword.
        fits_composites: composite FITS files, index-aligned with
            ``fits_thmaps``. Required for layouts 1 and 2.
        composite_plot: select layout 1.
        composite_overlay: select layout 2.
        channel: index into the composite's primary-HDU data selecting the
            image to display. Required for layouts 1 and 2.
        themes: theme names to overlay, e.g. ``['filament', 'coronal_hole']``.
            Required for layout 2.
        fps: movie frame rate.
        dpi: resolution at which frames are rendered.

    Returns:
        None; the movie is written to ``movie_save_path``.

    Raises:
        ValueError: if inputs required by the selected layout are missing, or
            there are fewer composite files than thematic maps.
        KeyError: if ``themes`` contains an unknown theme name.
    """
    # Materialize so the input can be length-checked even if a generator is passed
    fits_thmaps = list(fits_thmaps)
    use_composite = composite_plot or composite_overlay
    overlay = composite_overlay and not composite_plot

    # Validate before the writer is opened so a bad call fails before any frame is written
    if use_composite:
        if fits_composites is None or channel is None:
            raise ValueError("composite_plot/composite_overlay require both "
                             "fits_composites and channel")
        if len(fits_composites) < len(fits_thmaps):
            raise ValueError(f"fits_composites has {len(fits_composites)} entries but "
                             f"fits_thmaps has {len(fits_thmaps)}; one composite per "
                             f"thematic map is required")
    if overlay and themes is None:
        raise ValueError("composite_overlay requires themes")

    theme_index, n_colors, cmap = _solar_colormap()
    # Loop-invariant: resolve theme names once (unknown names raise KeyError here)
    overlay_labels = [theme_index[theme] for theme in themes] if overlay else None
    # With vmin=-1 and vmax=n_colors, label v normalizes to (v + 1) / (n_colors + 1) and the
    # colormap picks entry floor(n_colors * (v + 1) / (n_colors + 1)) == v, so every integer
    # label 0..n_colors-1 is drawn in its own colortable color
    map_kw = dict(cmap=cmap, vmin=-1, vmax=n_colors)

    writer = FFMpegWriter(fps=fps, metadata=dict(title='Thematic Map Movie'))
    fig = plt.figure(figsize=(15, 15))
    try:
        with writer.saving(fig, movie_save_path, dpi):
            for i, fits_thmap in enumerate(fits_thmaps):
                with fits.open(fits_thmap) as hdul:
                    thmap = hdul[0].data  # TODO(review): confirm the map is always in HDU 0
                    # NOTE(review): the FITS-standard keyword is 'DATE-OBS'; confirm these
                    # files really store the date under 'DATE_OBS'
                    title = str(hdul[0].header['DATE_OBS'])

                if use_composite:
                    with fits.open(fits_composites[i]) as hdul:
                        composite = hdul[0].data[channel]
                    # Display range 0..99th percentile: the brightest 1% of pixels saturate
                    composite_kw = dict(origin='lower', cmap='Greys_r', vmin=0,
                                        vmax=np.nanpercentile(composite, 99))

                if composite_plot:
                    # Layout 1: composite and thematic map side by side
                    ax_img = fig.add_subplot(121)
                    ax_img.imshow(composite, **composite_kw)
                    ax_img.axis('off')
                    ax_map = fig.add_subplot(122)
                    ax_map.imshow(thmap, origin='lower', **map_kw)
                    ax_map.axis('off')
                    fig.tight_layout()
                    fig.suptitle(title, fontsize=20)
                elif composite_overlay:
                    # Layout 2: selected themes overlaid on the composite
                    ax = fig.add_subplot(111)
                    ax.imshow(composite, **composite_kw)
                    # origin='lower' to match the composite: the default origin='upper' would
                    # re-set the y-limits top-down and flip these frames upside down
                    ax.imshow(_theme_overlay_mask(thmap, overlay_labels), origin='lower',
                              alpha=0.6, **map_kw)
                    ax.set_title(title, fontsize=15)
                    ax.axis('off')
                else:
                    # Layout 3: thematic map only
                    ax = fig.add_subplot(111)
                    ax.imshow(thmap, origin='lower', **map_kw)
                    ax.set_title(title, fontsize=15)
                    ax.axis('off')

                writer.grab_frame()
                # Start the next frame from an empty figure
                fig.clear()
    finally:
        # Release the figure so it is not left open in pyplot after the movie is written
        plt.close(fig)