In [18]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import glob
import ipywidgets as widgets
from IPython.display import display

# Default size for figures created without an explicit figsize.
plt.rcParams.update({'figure.figsize': (20, 16)})

def process_and_plot(poss_file, times_file, ax1, ax2, ax3, ax4, label, use_boundary_layer, show_bl_height):
    """Load one particle-position / time-stamp file pair and draw its series on four axes.

    Parameters
    ----------
    poss_file, times_file : str
        Paths to the particle positions and their time stamps. Both are read
        with ``np.load(..., allow_pickle=True)``, which unpickles non-``.npy``
        files and can therefore execute arbitrary code: only use trusted files.
    ax1, ax2, ax3, ax4 : matplotlib Axes
        Targets for, respectively: mean height with a +/- 1 std band; mean
        particle height (plus mean BL height); mean particle height minus mean
        BL height; counts of particles above / below the BL. ``ax3`` and
        ``ax4`` are only drawn on when ``use_boundary_layer`` is true.
    label : str
        Prefix used for every legend entry of this file pair.
    use_boundary_layer : bool
        Also read BL heights from ``poss[:, 3, :]`` and draw the BL panels.
    show_bl_height : bool
        Overlay the mean BL height on ``ax1`` / ``ax2`` (only when
        ``use_boundary_layer`` is true).

    Returns
    -------
    tuple
        ``(most_recent_date, time_diff_hours)``: the latest time stamp and each
        time step's offset before it in hours, or ``(None, None)`` if the files
        cannot be loaded or have incompatible shapes.
    """
    try:
        poss = np.asarray(np.load(poss_file, allow_pickle=True))
        # Normalise to a DatetimeIndex so `.max()` and the TimedeltaIndex
        # `.total_seconds()` below also work when the pickle held a Series or list.
        times = pd.DatetimeIndex(pd.to_datetime(np.load(times_file, allow_pickle=True)))
    except Exception as e:
        print(f"Error loading files: {str(e)}")
        return None, None

    # Validate shapes up front so a malformed pair is reported and skipped
    # instead of failing with IndexError/ValueError inside the plot calls.
    # NOTE(review): assumes poss is (particles, fields, time steps) with
    # field 2 = particle height and field 3 = BL height -- confirm upstream.
    required_fields = 4 if use_boundary_layer else 3
    if poss.ndim != 3 or poss.shape[1] < required_fields:
        print(f"Skipping {label}: expected a 3-D positions array with at least "
              f"{required_fields} entries on axis 1, got shape {poss.shape}")
        return None, None
    if poss.shape[2] != len(times):
        print(f"Skipping {label}: positions have {poss.shape[2]} time steps "
              f"but {len(times)} time stamps were loaded")
        return None, None

    heights = poss[:, 2, :]
    bl_heights = poss[:, 3, :] if use_boundary_layer else None

    # x-axis: hours before this file's own most recent time stamp.
    most_recent_date = times.max()
    time_diff_hours = (most_recent_date - times).total_seconds() / 3600

    # Report how many entries are usable (both heights finite when BL is used).
    if use_boundary_layer:
        valid_mask = ~np.isnan(heights) & ~np.isnan(bl_heights)
    else:
        valid_mask = ~np.isnan(heights)
    n_valid = np.sum(valid_mask)
    print(f"File: {label}")
    print(f"Total data points: {heights.size}")
    print(f"Valid data points: {n_valid}")
    print(f"NaN or invalid data points: {heights.size - n_valid}")

    # Reduce over axis 0, giving one value per time step; NaNs are ignored.
    avg_particle_height = np.nanmean(heights, axis=0)
    std_height = np.nanstd(heights, axis=0)
    if use_boundary_layer:
        avg_bl_height = np.nanmean(bl_heights, axis=0)
    draw_bl_line = use_boundary_layer and show_bl_height

    # Plot 1: mean height with a +/- 1 std band. fill_between draws from a
    # separate colour cycle than plot(), so pin the band (and BL line) to the
    # mean line's colour to keep each file's series visually grouped.
    mean_line, = ax1.plot(time_diff_hours, avg_particle_height, label=f'{label} Mean Height')
    color1 = mean_line.get_color()
    ax1.fill_between(time_diff_hours, avg_particle_height - std_height,
                     avg_particle_height + std_height, color=color1, alpha=0.3)
    if draw_bl_line:
        ax1.plot(time_diff_hours, avg_bl_height, linestyle='--', color=color1,
                 label=f'{label} Mean BL Height')

    # Plot 2: mean particle height, optionally with the mean BL height dashed.
    particle_line, = ax2.plot(time_diff_hours, avg_particle_height, label=f'{label} Particle Height')
    if draw_bl_line:
        ax2.plot(time_diff_hours, avg_bl_height, linestyle='--',
                 color=particle_line.get_color(), label=f'{label} BL Height')

    if use_boundary_layer:
        # Plot 3: mean particle height minus mean BL height.
        ax3.plot(time_diff_hours, avg_particle_height - avg_bl_height, label=label)

        # Plot 4: particles above / below the BL per time step. A NaN on either
        # side compares False both ways, so such particles fall in neither count.
        above_bl = np.sum(heights > bl_heights, axis=0)
        below_bl = np.sum(heights <= bl_heights, axis=0)
        above_line, = ax4.plot(time_diff_hours, above_bl, label=f'{label} Above BL')
        ax4.plot(time_diff_hours, below_bl, linestyle='--',
                 color=above_line.get_color(), label=f'{label} Below BL')

    return most_recent_date, time_diff_hours

def set_x_axis(ax, most_recent_date, max_time_diff):
    """Label and bound the shared "hours before most recent time" x-axis of *ax*.

    A bare ``...`` body here would silently shadow any earlier definition
    with a no-op and leave every panel without an x-axis label.
    NOTE(review): minimal implementation -- confirm it matches the intended
    axis styling (e.g. whether the axis should be inverted).

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to configure.
    most_recent_date : pandas.Timestamp or None
        Reference time of the plotted series; ``None`` means nothing was
        plotted, in which case the axis is left untouched.
    max_time_diff : float
        Largest offset (hours) over all plotted series; used as the upper
        x-limit when positive.
    """
    if most_recent_date is None:
        return
    ax.set_xlabel('Hours before most recent time step')
    # `nan > 0` is False, so a NaN maximum falls back to autoscaling.
    if max_time_diff > 0:
        ax.set_xlim(0, max_time_diff)

def plot_data(num_files, use_boundary_layer, show_bl_height):
    # Create a figure with four subplots in a 2x2 grid
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(20, 16))
    
    # Get all .pkl files in the 'pkl_files' directory
    pkl_dir = 'FLEXPART/pkl_files'
    poss_files = glob.glob(os.path.join(pkl_dir, '*_part_poss.pkl'))
    times_files = glob.glob(os.path.join(pkl_dir, '*_part_times.pkl'))
    
    # Sort the files to ensure matching pairs
    poss_files.sort()
    times_files.sort()
    
    print(f"Found {len(poss_files)} poss files and {len(times_files)} times files")
    
    # Process each pair of files
    max_time_diff = 0
    most_recent_date = None
    if len(poss_files) == len(times_files) and len(poss_files) > 0:
        for poss_file, times_file in zip(poss_files[:num_files], times_files[:num_files]):
            label = os.path.basename(poss_file)[:6]  # Use the first 6 letters of the file name as label
            print(f"Processing files: {poss_file} and {times_file}")
            result = process_and_plot(poss_file, times_file, ax1, ax2, ax3, ax4, label, use_boundary_layer, show_bl_height)
            if result[0] is not None:
                most_recent_date, time_diff_hours = result
                max_time_diff = max(max_time_diff, np.max(time_diff_hours))
    else:
        print("No matching pairs of files found.")
    
    # Set titles and labels for each subplot
    ax1.set_ylabel('Height')
    ax1.set_title('Mean Height with Standard Deviation')
    ax1.legend()
    ax1.grid(True)
    set_x_axis(ax1, most_recent_date, max_time_diff)
    
    ax2.set_ylabel('Height')
    ax2.set_title('Average Particle Height' + (' and Boundary Layer Height' if (use_boundary_layer and show_bl_height) else ''))
    ax2.legend()
    ax2.grid(True)
    set_x_axis(ax2, most_recent_date, max_time_diff)
    
    if use_boundary_layer:
        ax3.axhline(y=0, color='green', linestyle='--', label='y=0')
        ax3.set_ylabel('Height Difference')
        ax3.set_title('Average Particle Height minus Average Boundary Layer Height')
        ax3.legend()
        ax3.grid(True)
        set_x_axis(ax3, most_recent_date, max_time_diff)
        
        ax4.set_ylabel('Number of Particles')
        ax4.set_title('Distribution of Particles Relative to Boundary Layer')
        ax4.legend()
        ax4.grid(True)
        set_x_axis(ax4, most_recent_date, max_time_diff)
    else:
        fig.delaxes(ax3)
        fig.delaxes(ax4)
    
    # Adjust layout and show plot
    plt.tight_layout()
    plt.show()
    
    if most_recent_date is None:
        print("No data was processed. Check if the 'pkl_files' directory exists and contains the correct files.")

def get_file_count():
    pkl_dir = 'FLEXPART/pkl_files'
    poss_files = glob.glob(os.path.join(pkl_dir, '*_part_poss.pkl'))
    times_files = glob.glob(os.path.join(pkl_dir, '*_part_times.pkl'))
    return min(len(poss_files), len(times_files))

In [19]:
# Size the controls by the number of file pairs that can actually be plotted.
file_count = get_file_count()

# IntSlider rejects max < min, so keep max >= 1 even when no files exist and
# disable the button instead of crashing while building the widgets.
num_files_widget = widgets.IntSlider(min=1, max=max(file_count, 1), step=1, value=1, description='Number of files:')
use_bl_widget = widgets.Checkbox(value=True, description='Use Boundary Layer Data')
show_bl_height_widget = widgets.Checkbox(value=True, description='Show BL Height in Plots 1 & 2')
plot_button = widgets.Button(description="Generate Plot", disabled=(file_count == 0))

# Callback output (prints and figures) is routed into this Output widget so it
# appears under the controls rather than in the log console, and each click
# replaces the previous figure instead of stacking a new one below it.
plot_output = widgets.Output()

display(num_files_widget, use_bl_widget, show_bl_height_widget, plot_button, plot_output)

def on_button_clicked(b):
    """Redraw the figure for the current widget values, replacing the previous output."""
    with plot_output:
        plot_output.clear_output(wait=True)
        plot_data(num_files_widget.value, use_bl_widget.value, show_bl_height_widget.value)

plot_button.on_click(on_button_clicked)

IntSlider(value=1, description='Number of files:', max=14, min=1)

Checkbox(value=True, description='Use Boundary Layer Data')

Checkbox(value=True, description='Show BL Height in Plots 1 & 2')

Button(description='Generate Plot', style=ButtonStyle())