# Animation of the testing results with TLS as reference

In [None]:
# Imports and path setup
from __future__ import annotations
from pathlib import Path
import pandas as pd
import json
from datetime import datetime, time
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from datetime import datetime
import multiprocessing as mp

import matplotlib.image as mpimg
import cv2
from tqdm import tqdm
from PIL import Image
import numpy as np
import imageio

# Ensure project root is on sys.path for imports if needed
import sys

# The notebook is expected to sit three directory levels below the project
# root; derive the root from the working directory and make it importable.
PROJECT_ROOT = Path.cwd().parent.parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

# Paths
TESTING_CSV = PROJECT_ROOT / "data" / "01_Training_Validation_Data" / "splits" / "testing.csv"
IMAGES_DIR = PROJECT_ROOT / "data" / "01_Training_Validation_Data" / "image_data"
TLS_DIR = PROJECT_ROOT / "data" / "02_Test_Data" / "TLS" / "analysis"
# Alternative test-image locations (swap with the active ones below if needed):
#TESTING_DIR_CALATHEA = PROJECT_ROOT / "data/02_Test_Data/images/calathea_ornata"
#TESTING_DIR_MARANTA = PROJECT_ROOT / "data/02_Test_Data/images/maranta_leuconeura"
TESTING_DIR_CALATHEA = PROJECT_ROOT / "data/other/images/testing/calathea_ornata_all_images"
TESTING_DIR_MARANTA = PROJECT_ROOT / "data/other/images/testing/maranta_leuconeura_all_images"

TESTING_RESULTS_DIR = PROJECT_ROOT / "data" / "03_Model_Outputs" / "predictions" / "testing"
FIGURE_OUTPUT_DIR = PROJECT_ROOT / "data" / "other" / "figures" / "results"
ANIMATION_OUTPUT_DIR = PROJECT_ROOT / "data" / "04_Supplementary_Material"

# Colors
custom_colors = ["#C2B2B4", "#6B4E71", "#3A4454", "#53687E", "#F5DDDD"]
custom_colors2 = ["#F1F1FE", "#9492B9", "#AFBFCD", "#3A739D", "#ADC0A8"]

# Create color palette
custom_palette = sns.color_palette(custom_colors)
custom_palette2 = sns.color_palette(custom_colors2)

# Basic checks. Raise explicitly instead of using `assert`, which is
# stripped when Python runs with -O and would silently skip the checks.
for _label, _path in (
    ("Testing CSV", TESTING_CSV),
    ("Images directory", IMAGES_DIR),
    ("TLS directory", TLS_DIR),
    ("Calathea testing directory", TESTING_DIR_CALATHEA),
    ("Maranta testing directory", TESTING_DIR_MARANTA),
    ("Testing results directory", TESTING_RESULTS_DIR),
    ("Figure output directory", FIGURE_OUTPUT_DIR),
    ("Animation output directory", ANIMATION_OUTPUT_DIR),
):
    if not _path.exists():
        raise FileNotFoundError(f"{_label} not found: {_path}")

In [None]:
# Load the per-species prediction files from the testing results directory.
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale-dependent default encoding.
with open(TESTING_RESULTS_DIR / "calathea_ornata_predictions.json", "r", encoding="utf-8") as f:
    calathea_results = json.load(f)

with open(TESTING_RESULTS_DIR / "maranta_leuconeura_predictions.json", "r", encoding="utf-8") as f:
    maranta_results = json.load(f)

In [None]:
# Central animation settings. The frame renderer receives this dict as its
# ``config`` argument; the time-lapse builder reads it directly.
ANIMATION_CONFIG = dict(
    # Output timing and layout
    target_duration_seconds=20,  # length of the finished video in seconds
    dpi=120,  # resolution used for every rendered frame
    figure_size=(18, 7),  # (width, height) of each frame's figure
    width_ratios=[1.5, 1],  # time-series panel : photo panel
    subplot_spacing=0.02,  # horizontal gap between the two panels

    # Visual parameters
    rolling_window=3,  # samples in the centred rolling mean (smoothed lines)
    marker_size=60,
    line_width=4,
    alpha_scatter=0.7,
    alpha_grid=0.2,

    # Font sizes
    xlabel_fontsize=14,
    ylabel_fontsize=14,
    legend_fontsize=12,
    tick_labelsize=12,
    title_fontsize=16,

    # Colors
    anglecam_color=custom_palette2[3],
    tls_color=custom_palette[1],
    night_color='#f0f0f0',  # background shading for night hours
    night_alpha=1.0,

    # Hours bounding the unshaded "day" period; earlier/later hours are shaded
    day_start_hour=8,
    day_end_hour=18,

    # GIF-specific parameters
    # NOTE(review): the time-lapse builder derives the GIF frame delay from
    # target_duration_seconds instead of this value — confirm it is still needed.
    gif_duration_ms=1,
    # NOTE(review): gif_optimize / gif_max_colors are not read by the GIF
    # writer in this notebook section — confirm before relying on them.
    gif_optimize=True,
    gif_loop=0,  # Loop forever (0 = infinite loop)
    gif_max_colors=256,
)

def extract_datetime_from_filename(filename):
    """Return the capture time encoded in an image filename.

    Expected pattern: ``<camera>_<id>_<YYYY-MM-DD>_<HH>_<MM>_...``, e.g.
    ``G5Bullet_55_2025-01-17_17_32_00_corrected.jpg``. Only the date, hour
    and minute fields are used; any later ``_``-separated fields are ignored.

    Raises:
        IndexError: if the name has fewer than five ``_``-separated fields.
        ValueError: if the date/hour/minute fields do not form a valid time.
    """
    fields = filename.split('_')
    iso_date, hh, mm = fields[2], fields[3], fields[4]
    yyyy, mo, dd = iso_date.split('-')
    return datetime.strptime(f"{dd}-{mo}-{yyyy} {hh}:{mm}", "%d-%m-%Y %H:%M")

def get_all_image_files(image_dir):
    """Return ``(capture_datetime, path)`` pairs for every ``*.jpg`` in *image_dir*.

    Files whose names cannot be parsed by ``extract_datetime_from_filename``
    are skipped. The result is sorted by capture time; images sharing a
    timestamp are ordered by path so the order does not depend on the
    filesystem's directory listing order.
    """
    image_files = []
    for img_path in Path(image_dir).glob("*.jpg"):
        try:
            dt = extract_datetime_from_filename(img_path.name)
        except (IndexError, ValueError):
            # Not a timestamped capture (e.g. a stray file) -> skip it.
            continue
        image_files.append((dt, img_path))

    # Tuples compare by datetime first, then by path as a tie-breaker.
    image_files.sort()
    return image_files

def prepare_prediction_data(results, rolling_window=None):
    """Turn raw prediction JSON into a time-sorted DataFrame with smoothed series.

    Args:
        results: Mapping with a ``'predictions'`` list. Each entry must provide
            ``'filename'``, ``'predicted_mean_angle'``, ``'reference_mean_angle'``
            and a ``'datetime'`` string formatted as ``"%d-%m-%Y %H:%M"``.
            Entries missing a key or with an unparsable datetime are skipped.
        rolling_window: Size of the centred rolling-mean window used for the
            smoothed columns. Defaults to ``ANIMATION_CONFIG['rolling_window']``.

    Returns:
        DataFrame with columns ``filename``, ``predicted_mean_angle``,
        ``reference_mean_angle`` and ``timestamp``, sorted by ``timestamp``.
        When at least one row survives, ``reference_smooth`` and
        ``predicted_smooth`` are added. If no entry survives, an empty frame
        with the base columns is returned.
    """
    if rolling_window is None:
        rolling_window = ANIMATION_CONFIG['rolling_window']

    rows = []
    for result in results['predictions']:
        try:
            rows.append({
                'filename': result['filename'],
                'predicted_mean_angle': result['predicted_mean_angle'],
                'reference_mean_angle': result['reference_mean_angle'],
                'timestamp': datetime.strptime(result['datetime'], "%d-%m-%Y %H:%M"),
            })
        except (KeyError, TypeError, ValueError):
            # Malformed entry (missing key, non-string or unparsable datetime).
            continue

    # Explicit columns keep 'timestamp' present even when every entry was
    # skipped, so sort_values does not raise KeyError on an empty frame.
    # mergesort is stable: equal timestamps keep their input order.
    columns = ['filename', 'predicted_mean_angle', 'reference_mean_angle', 'timestamp']
    df = pd.DataFrame(rows, columns=columns).sort_values('timestamp', kind='mergesort')

    if not df.empty:
        # Centred window with min_periods=1 so the first/last samples are
        # smoothed over a partial window instead of becoming NaN.
        for raw_col, smooth_col in (
            ('reference_mean_angle', 'reference_smooth'),
            ('predicted_mean_angle', 'predicted_smooth'),
        ):
            df[smooth_col] = (
                df[raw_col]
                .rolling(window=rolling_window, center=True, min_periods=1)
                .mean()
            )

    return df

def _shade_night_periods(ax, time_min, time_max, config):
    """Shade night hours on *ax* for every calendar day in [time_min, time_max].

    Night is midnight -> ``config['day_start_hour']`` and
    ``config['day_end_hour']`` -> next midnight. Spans are clipped to the
    plotted range so they never widen the x-axis.
    """
    for day in pd.date_range(time_min.date(), time_max.date(), freq='D').date:
        midnight = pd.Timestamp.combine(day, time(0, 0))
        night_spans = (
            (midnight, pd.Timestamp.combine(day, time(config['day_start_hour'], 0))),
            # Run the evening span to the next midnight (not 23:59:59) so it
            # joins the following morning span without a gap.
            (pd.Timestamp.combine(day, time(config['day_end_hour'], 0)),
             midnight + pd.Timedelta(days=1)),
        )
        for start, end in night_spans:
            if end >= time_min and start <= time_max:
                ax.axvspan(
                    max(start, time_min),
                    min(end, time_max),
                    color=config['night_color'],
                    alpha=config['night_alpha'],
                    zorder=0,
                )


def _draw_prediction_series(ax, predictions, config):
    """Draw TLS (reference) and AngleCam (predicted) angles on *ax*.

    Raw values are drawn as points; the pre-computed rolling means are drawn
    as labelled lines once there are at least two samples. The legend is only
    added when those labelled lines exist, so a single-sample frame does not
    trigger matplotlib's "No artists with labels found" warning.
    """
    if predictions.empty:
        return

    draw_lines = len(predictions) > 1
    series = (
        ('reference_smooth', 'reference_mean_angle', config['tls_color'], 'TLS'),
        ('predicted_smooth', 'predicted_mean_angle', config['anglecam_color'], 'AngleCam'),
    )
    for smooth_col, raw_col, color, label in series:
        if draw_lines:
            ax.plot(
                predictions['timestamp'],
                predictions[smooth_col],
                color=color,
                linewidth=config['line_width'],
                alpha=0.9,
                label=label,
                solid_capstyle='round',
            )
        ax.scatter(
            predictions['timestamp'],
            predictions[raw_col],
            color=color,
            s=config['marker_size'],
            alpha=config['alpha_scatter'],
            zorder=3,
            edgecolors='white',
            linewidth=0.5,
        )

    if draw_lines:
        ax.legend(fontsize=config['legend_fontsize'], frameon=False, loc='upper left')


def create_animation_frame_optimized(args):
    """Render one animation frame (time series + photo) and save it to disk.

    Designed to be mapped over a ``multiprocessing.Pool``, hence the single
    packed argument tuple:
    ``(frame_idx, image_files, prediction_data_full, output_path, config,
    plant_name, time_range, y_range)``.

    - ``image_files``: list of ``(datetime, Path)`` pairs; ``frame_idx``
      selects the photo shown on the right.
    - ``prediction_data_full``: DataFrame from ``prepare_prediction_data``;
      only rows with ``timestamp`` <= the photo's capture time are drawn.
    - ``time_range`` / ``y_range``: fixed axis limits shared by all frames so
      the axes do not jump between frames.
    - ``plant_name`` is accepted for interface compatibility but not used.
    """
    (frame_idx, image_files, prediction_data_full, output_path, config,
     plant_name, time_range, y_range) = args

    current_dt, current_image_path = image_files[frame_idx]
    # Only predictions up to this photo, so the series grows as the animation plays.
    current_predictions = prediction_data_full[prediction_data_full['timestamp'] <= current_dt]
    time_min, time_max = time_range
    y_min, y_max = y_range

    # plt.ioff() as a context manager (matplotlib >= 3.4) restores the
    # caller's interactive-mode state on exit instead of forcing it on.
    with plt.ioff():
        fig, (ax1, ax2) = plt.subplots(
            1, 2,
            figsize=config['figure_size'],
            gridspec_kw={"width_ratios": config['width_ratios'], "wspace": config['subplot_spacing']},
            dpi=config['dpi'],
        )
        try:
            # LEFT PLOT: evolving time series
            _shade_night_periods(ax1, time_min, time_max, config)
            _draw_prediction_series(ax1, current_predictions, config)

            ax1.set_ylim(y_min, y_max)
            ax1.set_xlim(time_min, time_max)

            # Current time indicator
            ax1.axvline(current_dt, color='red', linestyle='--', alpha=0.8, linewidth=2, zorder=4)

            ax1.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
            ax1.xaxis.set_major_locator(mdates.HourLocator(byhour=[0, 6, 12, 18]))
            ax1.grid(True, alpha=config['alpha_grid'])
            ax1.spines["top"].set_visible(False)
            ax1.spines["right"].set_visible(False)
            ax1.set_xlabel("Time of Day", fontsize=config['xlabel_fontsize'])
            ax1.set_ylabel("Leaf Angle (°)", fontsize=config['ylabel_fontsize'])
            ax1.tick_params(axis="both", which="major", labelsize=config['tick_labelsize'])

            # RIGHT PLOT: photo for this frame
            if current_image_path.exists():
                ax2.imshow(mpimg.imread(current_image_path))
                ax2.axis('off')
                ax2.set_title(current_dt.strftime("%d-%m-%Y %H:%M"),
                              fontsize=config['title_fontsize'], pad=20)
            else:
                ax2.text(0.5, 0.5, 'Image not found', ha='center', va='center',
                         transform=ax2.transAxes, fontsize=12)
                ax2.set_xlim(0, 1)
                ax2.set_ylim(0, 1)

            fig.subplots_adjust(left=0.06, right=0.98, top=0.92, bottom=0.12,
                                wspace=config['subplot_spacing'])
            fig.savefig(output_path, dpi=config['dpi'], facecolor='white', edgecolor='none')
        finally:
            # Always release the figure, even if rendering failed; pool workers
            # render many frames and would otherwise keep open figures alive.
            plt.close(fig)

def _compute_axis_ranges(prediction_data):
    """Return ``(time_range, y_range)`` spanning every prediction.

    The y-range covers both predicted and reference angles with a 10 % margin.
    If all angles are identical, a margin of 1 is used so the limits differ.
    """
    time_range = (prediction_data['timestamp'].min(), prediction_data['timestamp'].max())
    y_min = min(prediction_data['predicted_mean_angle'].min(),
                prediction_data['reference_mean_angle'].min())
    y_max = max(prediction_data['predicted_mean_angle'].max(),
                prediction_data['reference_mean_angle'].max())
    y_margin = 0.1 * (y_max - y_min)
    if y_margin == 0:
        y_margin = 1.0
    return time_range, (y_min - y_margin, y_max + y_margin)


def _write_video(frame_paths, video_path, fps):
    """Encode *frame_paths* (in order) into an MP4 at *fps* frames per second."""
    first_frame = cv2.imread(str(frame_paths[0]))
    if first_frame is None:
        raise FileNotFoundError(f"Could not read first frame: {frame_paths[0]}")
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(video_path), fourcc, fps, (width, height))
    try:
        for frame_path in tqdm(frame_paths, desc="Writing video"):
            video_writer.write(cv2.imread(str(frame_path)))
    finally:
        # Release even on error so the container is finalised and the handle freed.
        video_writer.release()

    # Best-effort: headless OpenCV builds raise cv2.error here.
    try:
        cv2.destroyAllWindows()
    except cv2.error:
        pass


def _write_gif(frame_paths, gif_path, target_seconds, loop):
    """Write a subsampled GIF of *frame_paths* lasting about *target_seconds*."""
    print(f"Creating GIF: {gif_path}")

    # Diagnostic: confirm consecutive frames actually differ.
    if len(frame_paths) >= 2:
        img1 = imageio.imread(frame_paths[0])
        img2 = imageio.imread(frame_paths[1])
        print(f"Frame 0 shape: {img1.shape}, Frame 1 shape: {img2.shape}")
        diff = np.mean(np.abs(img1.astype(float) - img2.astype(float)))
        print(f"Average pixel difference between first two frames: {diff:.2f}")

    print("Creating GIF with imageio...")
    # Keep every n-th frame so the GIF has roughly 100-200 frames.
    step = max(1, len(frame_paths) // 100)
    selected_paths = frame_paths[::step]
    print(f"Using every {step} frame(s), total: {len(selected_paths)} frames for GIF")

    duration_per_frame = target_seconds / len(selected_paths)
    print(f"GIF duration per frame: {duration_per_frame:.3f}s")

    # NOTE(review): imageio's legacy GIF writer takes `duration` in seconds, but
    # newer imageio releases that route GIFs through the Pillow v3 plugin
    # interpret it in milliseconds — confirm against the installed version.
    with imageio.get_writer(gif_path, mode='I', duration=duration_per_frame, loop=loop) as writer:
        for frame_path in tqdm(selected_paths, desc="Writing GIF frames"):
            if frame_path.exists():
                writer.append_data(imageio.imread(frame_path))

    if gif_path.exists():
        gif_size_mb = gif_path.stat().st_size / (1024 * 1024)
        print(f"GIF saved: {gif_path}")
        print(f"GIF size: {gif_size_mb:.1f} MB")
        print(f"Total frames in GIF: {len(selected_paths)}")
        print(f"Total animation duration: {duration_per_frame * len(selected_paths):.1f}s")
    else:
        print("ERROR: GIF file was not created!")


def create_timelapse_animation(plant_name, image_dir, prediction_results, output_dir):
    """Render frames for one plant and assemble them into an MP4 and a GIF.

    Args:
        plant_name: Used in progress messages and output file names.
        image_dir: Directory of timestamped ``*.jpg`` photos (one frame each).
        prediction_results: Raw prediction JSON (see ``prepare_prediction_data``).
        output_dir: Frames go to ``<output_dir>/frames``, videos to ``<output_dir>/video``.

    Returns:
        ``(video_path, gif_path, frames_dir)``.

    Raises:
        ValueError: if no timestamped images or no usable predictions are found.
    """
    print(f"Creating animation for {plant_name}...")

    frames_dir = Path(output_dir) / "frames"
    video_dir = Path(output_dir) / "video"
    frames_dir.mkdir(parents=True, exist_ok=True)
    video_dir.mkdir(parents=True, exist_ok=True)

    image_files = get_all_image_files(image_dir)
    prediction_data = prepare_prediction_data(prediction_results)
    print(f"Found {len(image_files)} images and {len(prediction_data)} predictions")

    # Fail early with a clear message: no images would make the FPS 0 and
    # crash later; no predictions would give NaT/NaN axis limits.
    if not image_files:
        raise ValueError(f"No timestamped *.jpg images found in {image_dir}")
    if len(prediction_data) == 0:
        raise ValueError(f"No usable predictions for {plant_name}")

    img_start, img_end = image_files[0][0], image_files[-1][0]
    print(f"Images time range: {img_start.strftime('%d-%m-%Y %H:%M')} to {img_end.strftime('%d-%m-%Y %H:%M')}")
    pred_start = prediction_data['timestamp'].min()
    pred_end = prediction_data['timestamp'].max()
    print(f"Predictions time range: {pred_start.strftime('%d-%m-%Y %H:%M')} to {pred_end.strftime('%d-%m-%Y %H:%M')}")

    # One frame per image, played back over the target duration.
    target_seconds = ANIMATION_CONFIG['target_duration_seconds']
    fps = len(image_files) / target_seconds
    print(f"Using {fps:.2f} FPS for {target_seconds}s video")
    print(f"This will create a {len(image_files) / fps:.1f}s duration video")

    time_range, y_range = _compute_axis_ranges(prediction_data)

    frame_paths = [frames_dir / f"frame_{i:06d}.png" for i in range(len(image_files))]
    frame_args = [
        (i, image_files, prediction_data, frame_path, ANIMATION_CONFIG,
         plant_name, time_range, y_range)
        for i, frame_path in enumerate(frame_paths)
    ]

    print("Generating frames with multiprocessing...")
    # Leave one core free, cap at 30 workers and at the number of frames, and
    # never go below 1 (mp.Pool(processes=0) raises on single-core machines).
    num_processes = max(1, min(mp.cpu_count() - 1, 30, len(frame_args)))
    with mp.Pool(processes=num_processes) as pool:
        for _ in tqdm(
            pool.imap(create_animation_frame_optimized, frame_args),
            total=len(frame_args),
            desc=f"Generating frames ({num_processes} processes)",
        ):
            pass

    video_path = video_dir / f"{plant_name}_timelapse.mp4"
    print(f"Creating video: {video_path}")
    _write_video(frame_paths, video_path, fps)
    print(f"Video saved: {video_path}")
    print(f"Actual duration: {len(image_files) / fps:.1f}s at {fps:.2f} FPS")

    gif_path = video_dir / f"{plant_name}_timelapse.gif"
    _write_gif(frame_paths, gif_path, target_seconds, ANIMATION_CONFIG['gif_loop'])

    print(f"Frames saved in: {frames_dir}")

    return video_path, gif_path, frames_dir

# Create animations for both plants, each in its own sub-folder of the
# supplementary-material output directory.
CALATHEA_ANIMATION_OUTPUT_DIR = ANIMATION_OUTPUT_DIR / "calathea"
MARANTA_ANIMATION_OUTPUT_DIR = ANIMATION_OUTPUT_DIR / "maranta"

# Create Calathea animation
calathea_video_path, calathea_gif_path, calathea_frames_dir = create_timelapse_animation(
    "calathea",
    TESTING_DIR_CALATHEA,
    calathea_results,
    CALATHEA_ANIMATION_OUTPUT_DIR,
)

# Create Maranta animation
maranta_video_path, maranta_gif_path, maranta_frames_dir = create_timelapse_animation(
    "maranta",
    TESTING_DIR_MARANTA,
    maranta_results,
    MARANTA_ANIMATION_OUTPUT_DIR,
)

# Summarise the outputs for both plants (previously only Maranta was listed).
print("\nAnimation creation complete!")
print(f"Calathea MP4: {calathea_video_path}")
print(f"Calathea GIF: {calathea_gif_path}")
print(f"Maranta MP4: {maranta_video_path}")
print(f"Maranta GIF: {maranta_gif_path}")