In [None]:
#Plotting Functions
# Separate plotting functions were initially written for each dataset. This was fixed later, but both versions are still included.
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import os
import numpy as np
import matplotlib.colors as mcolors


def to_rgba(color_str, alpha=0.2):
    """Format a color as a CSS-style ``rgba(r, g, b, a)`` string.

    Args:
        color_str: Color specification handed to ``matplotlib.colors.to_rgba``.
        alpha: Opacity handed to ``matplotlib.colors.to_rgba``.

    Returns:
        str: ``'rgba(R, G, B, A)'`` where R, G, B are the channels scaled by
        255 and truncated to int, and A is the alpha with two decimals.
    """
    red, green, blue, opacity = mcolors.to_rgba(color_str, alpha)
    channels = ", ".join(str(int(component * 255)) for component in (red, green, blue))
    return f"rgba({channels}, {opacity:.2f})"

def plot_metrics(results_tuple, output_folder, per_level_titles=None, x_axis_labels=None, context_length=3, segment_length=9, time_resolution=30):
    """Save mean ± std line plots of MSE, MAE and PCC for every model.

    One figure is written per level (``metrics_level_<level>.png``) plus one
    for the aggregated results (``metrics_aggregated.png``). Each figure has
    three side-by-side panels, one per metric; every model is drawn as a line
    with a shaded band spanning mean - std to mean + std.

    Args:
        results_tuple: 3-tuple ``(per_level_results, aggregated_results,
            csi_far_results)``. ``per_level_results[model][level]`` and
            ``aggregated_results[model]`` are dicts that may hold
            ``"<METRIC>_mean"`` / ``"<METRIC>_std"`` sequences indexed by time
            step. The third element is unpacked but not plotted here.
        output_folder: Directory for the PNG files; created if missing.
        per_level_titles: Display names paired positionally with the level
            keys of the first model. Defaults to the level keys themselves.
        x_axis_labels: Labels for the plotted time steps. Defaults to
            ``(i + 1) * time_resolution`` for each plotted step.
        context_length: Index of the first time step that is plotted.
        segment_length: One past the index of the last plotted time step.
        time_resolution: Step size used to build the default x labels.

    Raises:
        ValueError: If ``per_level_results`` contains no models.
    """
    per_level_results, aggregated_results, _ = results_tuple
    os.makedirs(output_folder, exist_ok=True)

    model_names = list(per_level_results.keys())
    if not model_names:
        raise ValueError("results_tuple contains no models to plot")
    palette = px.colors.qualitative.Plotly
    color_map = {model: palette[i % len(palette)] for i, model in enumerate(model_names)}

    time_steps = list(range(context_length, segment_length))
    if x_axis_labels is None:
        x_axis_labels = [f"{(i+1)*time_resolution}" for i in range(len(time_steps))]
    else:
        # The band polygon below is built with `labels + labels[::-1]`, which
        # must concatenate; a list guarantees that for any sequence type.
        x_axis_labels = list(x_axis_labels)

    metrics = ["MSE", "MAE", "PCC"]
    level_keys = list(per_level_results[model_names[0]].keys())
    if per_level_titles is None:
        per_level_titles = level_keys
    title_map = dict(zip(level_keys, per_level_titles))
    legend_layout = dict(orientation="h", y=-0.2, x=0.5, xanchor="center")

    def build_figure(results_by_model, title):
        """Build the 1x3 MSE/MAE/PCC figure for one ``{model: metric dict}`` mapping."""
        fig = make_subplots(rows=1, cols=3, shared_yaxes=False, subplot_titles=metrics)
        for col, metric in enumerate(metrics, start=1):
            for model in model_names:
                mean_vals = results_by_model[model].get(metric + "_mean")
                std_vals = results_by_model[model].get(metric + "_std")
                # A model without both statistics for this metric is skipped.
                if mean_vals is None or std_vals is None:
                    continue

                y_mean = [mean_vals[t] for t in time_steps]
                y_std = [std_vals[t] for t in time_steps]
                y_upper = [m + s for m, s in zip(y_mean, y_std)]
                y_lower = [m - s for m, s in zip(y_mean, y_std)]
                color = color_map[model]

                # Shaded std band: upper edge left-to-right, then lower edge
                # right-to-left, closed into a polygon by fill='toself'.
                fig.add_trace(
                    go.Scatter(
                        x=x_axis_labels + x_axis_labels[::-1],
                        y=y_upper + y_lower[::-1],
                        fill='toself',
                        fillcolor=to_rgba(color, alpha=0.2),
                        line=dict(color='rgba(255,255,255,0)'),
                        hoverinfo="skip",
                        showlegend=False
                    ),
                    row=1, col=col
                )
                # Mean line; the legend entry is only emitted from the first panel
                # so each model appears once in the legend.
                fig.add_trace(
                    go.Scatter(
                        x=x_axis_labels,
                        y=y_mean,
                        mode="lines+markers",
                        name=model,
                        line=dict(color=color),
                        marker=dict(color=color),
                        showlegend=(col == 1)
                    ),
                    row=1, col=col
                )
            fig.update_xaxes(title_text="Time (min)", row=1, col=col)
            fig.update_yaxes(title_text=metric, row=1, col=col)

        fig.update_layout(
            title_text=title,
            title_x=0.5,
            legend=legend_layout,
            height=400, width=1200
        )
        return fig

    # ==== per-level MSE/MAE/PCC ====
    for lvl in level_keys:
        level_results = {model: per_level_results[model][lvl] for model in model_names}
        fig = build_figure(level_results, f"Metrics for Level {title_map[lvl]}")
        fig.write_image(os.path.join(output_folder, f"metrics_level_{lvl}.png"))

    # ==== aggregated MSE/MAE/PCC ====
    fig = build_figure(aggregated_results, "Aggregated Metrics")
    fig.write_image(os.path.join(output_folder, "metrics_aggregated.png"))

    print("Saved all metric plots with standard deviation shading.")

import os
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import numpy as np
import matplotlib.colors as mcolors

def to_rgba(color_str, alpha=0.2):
    """Return ``color_str`` as an ``'rgba(R, G, B, A)'`` string.

    The color and ``alpha`` are resolved through ``matplotlib.colors.to_rgba``;
    the three color channels are scaled by 255 and truncated to integers, and
    the alpha is printed with two decimals.
    """
    rgba = mcolors.to_rgba(color_str, alpha)
    r, g, b = (int(channel * 255) for channel in rgba[:3])
    return 'rgba(%d, %d, %d, %.2f)' % (r, g, b, rgba[3])

def plot_metrics_sevir(results_tuple, output_folder, per_level_titles=None, x_axis_labels=None, context_length=3, segment_length=9, time_resolution=30):
    """Save mean ± std plots of MSE/MAE/PCC and of CSI/FAR per threshold.

    Files written to ``output_folder``:
        - ``metrics_level_<level>.png``: 1x3 MSE/MAE/PCC panels per level.
        - ``metrics_aggregated.png``: 1x3 MSE/MAE/PCC panels, aggregated.
        - ``csi_2x3.png`` / ``far_2x3.png``: 2x3 grid, one panel per threshold.

    Args:
        results_tuple: 3-tuple ``(per_level_results, aggregated_results,
            csi_far_results)``. ``per_level_results[model][level]``,
            ``aggregated_results[model]`` and
            ``csi_far_results[model][threshold]`` are dicts that may hold
            ``"<METRIC>_mean"`` / ``"<METRIC>_std"`` sequences indexed by
            time step.
        output_folder: Directory for the PNG files; created if missing.
        per_level_titles: Display names paired positionally with the level
            keys of the first model. Defaults to the level keys themselves.
        x_axis_labels: Labels for the plotted time steps. Defaults to
            ``(i + 1) * time_resolution`` for each plotted step.
        context_length: Index of the first time step that is plotted.
        segment_length: One past the index of the last plotted time step.
        time_resolution: Step size used to build the default x labels.

    Raises:
        AssertionError: If the first model does not have exactly 6 thresholds.
    """
    # NOTE: there must be no `import os` inside this function. A function-level
    # import makes `os` a local name for the whole body, so the `os.makedirs`
    # call below would raise UnboundLocalError before the import is reached.
    per_level_results, aggregated_results, csi_far_results = results_tuple
    os.makedirs(output_folder, exist_ok=True)

    time_steps = list(range(context_length, segment_length))
    if x_axis_labels is None:
        x_axis_labels = [f"{(i+1)*time_resolution}" for i in range(len(time_steps))]

    model_names = list(per_level_results.keys())
    palette = px.colors.qualitative.Plotly
    color_map = {m: palette[i % len(palette)] for i, m in enumerate(model_names)}
    legend_layout = dict(orientation="h", y=-0.2, x=0.5, xanchor="center")

    levels = list(per_level_results[model_names[0]].keys())
    if per_level_titles is None:
        per_level_titles = levels
    title_map = dict(zip(levels, per_level_titles))
    metrics = ["MSE", "MAE", "PCC"]

    def add_panel(fig, results_by_model, metric, row, col, show_legend):
        """Draw every model's std band and mean line for ``metric`` into one subplot."""
        for model in model_names:
            mean_vals = results_by_model[model].get(f"{metric}_mean")
            std_vals = results_by_model[model].get(f"{metric}_std")
            # A model without both statistics for this metric is skipped.
            if mean_vals is None or std_vals is None:
                continue
            y_mean = [mean_vals[t] for t in time_steps]
            y_std = [std_vals[t] for t in time_steps]
            y_upper = [m + s for m, s in zip(y_mean, y_std)]
            y_lower = [m - s for m, s in zip(y_mean, y_std)]
            color = color_map[model]
            # Band polygon: upper edge forward, lower edge backward, filled with 'toself'.
            fig.add_trace(go.Scatter(x=x_axis_labels + x_axis_labels[::-1], y=y_upper + y_lower[::-1], fill='toself', fillcolor=to_rgba(color, 0.2), line=dict(color='rgba(255,255,255,0)'), showlegend=False), row=row, col=col)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=y_mean, mode="lines+markers", name=model, line=dict(color=color), marker=dict(color=color), showlegend=show_legend), row=row, col=col)
        fig.update_xaxes(title_text="Time (min)", row=row, col=col)
        fig.update_yaxes(title_text=metric, row=row, col=col)

    # Per-level MSE/MAE/PCC
    for lvl in levels:
        level_results = {model: per_level_results[model][lvl] for model in model_names}
        fig = make_subplots(rows=1, cols=3, shared_yaxes=False, subplot_titles=metrics)
        for col, metric in enumerate(metrics, start=1):
            # Legend entries come from the first panel only, one per model.
            add_panel(fig, level_results, metric, row=1, col=col, show_legend=(col == 1))
        fig.update_layout(title_text=f"Metrics for Level {title_map[lvl]}", title_x=0.5, legend=legend_layout, height=400, width=1200)
        fig.write_image(os.path.join(output_folder, f"metrics_level_{lvl}.png"))

    # Aggregated MSE/MAE/PCC
    fig = make_subplots(rows=1, cols=3, shared_yaxes=False, subplot_titles=metrics)
    for col, metric in enumerate(metrics, start=1):
        add_panel(fig, aggregated_results, metric, row=1, col=col, show_legend=(col == 1))
    fig.update_layout(title_text="Aggregated Metrics Across All Levels", title_x=0.5, legend=legend_layout, height=400, width=1200)
    fig.write_image(os.path.join(output_folder, "metrics_aggregated.png"))

    # CSI and FAR: 2x3 grid
    thresholds = list(csi_far_results[model_names[0]].keys())
    assert len(thresholds) == 6, "Expected 6 thresholds for 2×3 layout"
    for metric in ["CSI", "FAR"]:
        fig = make_subplots(rows=2, cols=3, shared_yaxes=False, subplot_titles=[f"{thr}" for thr in thresholds])
        for idx, thr in enumerate(thresholds):
            r, c = divmod(idx, 3)  # fill the grid row by row
            threshold_results = {model: csi_far_results[model][thr] for model in model_names}
            add_panel(fig, threshold_results, metric, row=r + 1, col=c + 1, show_legend=(idx == 0))
        fig.update_layout(title_text=f"{metric} at Various Thresholds", title_x=0.5, legend=legend_layout, height=800, width=1200)
        fig.write_image(os.path.join(output_folder, f"{metric.lower()}_2x3.png"))
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Global parameters (reuse from earlier)


def plot_metrics_relative_sevir(results_tuple, ref_model, output_folder,
                          per_level_titles=None, x_axis_labels=None,context_length=3,segment_length=9,time_resolution=30):
    """
    Plots percentage difference of each model relative to a reference model
    for MSE, MAE, PCC, CSI, and FAR. Uses the same layout and color scheme
    as plot_metrics_sevir, but with dashed lines and zero-baseline.

    The plotted value at each time step is ``(model - ref) / ref * 100``.
    Where the reference value is 0 the relative change is undefined, so that
    point is emitted as ``None`` instead of raising ZeroDivisionError.
    NOTE(review): for a negative reference value (possible for PCC) the sign
    of the percentage is flipped by the division — confirm this is intended.

    Files written to ``output_folder``:
        - ``rel_metrics_level_<level>.png`` (one per level)
        - ``rel_metrics_aggregated.png``
        - ``rel_csi_2x3.png`` and ``rel_far_2x3.png``

    Args:
        results_tuple (tuple):
            - per_level_results: dict[model][level]["MSE"/"MAE"/"PCC"] -> list over time
            - aggregated_results: dict[model]["MSE"/"MAE"/"PCC"] -> list over time
            - csi_far_results: dict[model][threshold]["CSI"/"FAR"] -> list over time
        ref_model (str): Name of the reference model to compare against.
        output_folder (str): Directory to save figures; created if missing.
        per_level_titles (list): Display names for each precipitation level.
        x_axis_labels (list): Labels (minutes) for the predicted time steps.
        context_length (int): Index of the first time step that is plotted.
        segment_length (int): One past the index of the last plotted time step.
        time_resolution (int): Step size used to build the default x labels.

    Raises:
        KeyError: If ``ref_model`` is missing from any of the result dicts.
        AssertionError: If the reference model does not have exactly 6 thresholds.
    """
    per_level_results, aggregated_results, csi_far_results = results_tuple
    os.makedirs(output_folder, exist_ok=True)

    # build color map identical to before
    model_names = list(per_level_results.keys())
    palette = px.colors.qualitative.Plotly
    color_map = {m: palette[i % len(palette)] for i, m in enumerate(model_names)}

    # time steps and labels
    time_steps = list(range(context_length, segment_length))
    n_steps = len(time_steps)
    if x_axis_labels is None:
        x_axis_labels = [f"{(i+1)*time_resolution}" for i in range(n_steps)]

    # horizontal zero baseline trace
    zero_line = dict(mode="lines", line=dict(color="black", width=1, dash="dot"), showlegend=False)

    # common legend layout
    legend_layout = dict(orientation="h", y=-0.2, x=0.5, xanchor="center")

    metrics = ["MSE", "MAE", "PCC"]
    levels = list(per_level_results[ref_model].keys())
    if per_level_titles is None:
        per_level_titles = levels
    title_map = dict(zip(levels, per_level_titles))

    def percent_change(vals, ref_vals):
        """Percentage change of ``vals`` vs ``ref_vals`` at each plotted time step."""
        changes = []
        for t in time_steps:
            ref = ref_vals[t]
            # Undefined for a zero reference: emit None rather than dividing by zero.
            changes.append(None if ref == 0 else (vals[t] - ref) / ref * 100)
        return changes

    def add_relative_panel(fig, values_by_model, metric, row, col, show_legend):
        """Draw all non-reference models plus the zero baseline into one subplot.

        ``values_by_model[model][metric]`` is the per-time-step series for
        that model (the reference model included).
        """
        ref_vals = values_by_model[ref_model][metric]
        for model in model_names:
            if model == ref_model:
                continue
            fig.add_trace(
                go.Scatter(x=x_axis_labels, y=percent_change(values_by_model[model][metric], ref_vals),
                           mode="lines+markers",
                           name=model,
                           line=dict(color=color_map[model], dash="dash"),
                           marker=dict(color=color_map[model]),
                           showlegend=show_legend
                          ),
                row=row, col=col
            )
        # add zero baseline
        fig.add_trace(go.Scatter(x=x_axis_labels, y=[0]*n_steps, **zero_line), row=row, col=col)
        fig.update_xaxes(title_text="Time (min)", row=row, col=col)
        fig.update_yaxes(title_text=f"% Δ {metric}", row=row, col=col)

    # 1) Per-level MSE/MAE/PCC relative plots
    for lvl in levels:
        level_results = {model: per_level_results[model][lvl] for model in model_names}
        fig = make_subplots(rows=1, cols=3, shared_yaxes=False, subplot_titles=metrics)
        for col, metric in enumerate(metrics, start=1):
            # Legend entries come from the first panel only, one per model.
            add_relative_panel(fig, level_results, metric, row=1, col=col, show_legend=(col == 1))
        fig.update_layout(
            title_text=f"% Difference vs {ref_model} | Level {title_map[lvl]}",
            title_x=0.5, legend=legend_layout,
            height=400, width=1200
        )
        fig.write_image(os.path.join(output_folder, f"rel_metrics_level_{lvl}.png"))

    # 2) Aggregated MSE/MAE/PCC relative
    fig = make_subplots(rows=1, cols=3, shared_yaxes=False, subplot_titles=metrics)
    for col, metric in enumerate(metrics, start=1):
        add_relative_panel(fig, aggregated_results, metric, row=1, col=col, show_legend=(col == 1))
    fig.update_layout(
        title_text=f"% Difference vs {ref_model} | Aggregated",
        title_x=0.5, legend=legend_layout,
        height=400, width=1200
    )
    fig.write_image(os.path.join(output_folder, "rel_metrics_aggregated.png"))

    # 3) + 4) CSI and FAR: one 2×3 grid each, one panel per threshold
    thresholds = list(csi_far_results[ref_model].keys())
    assert len(thresholds) == 6, "Expected 6 thresholds for 2×3 layout"
    for metric in ("CSI", "FAR"):
        fig = make_subplots(rows=2, cols=3, shared_yaxes=False,
                            subplot_titles=[f"{thr}" for thr in thresholds])
        for idx, thr in enumerate(thresholds):
            r, c = divmod(idx, 3)  # fill the grid row by row
            threshold_results = {model: csi_far_results[model][thr] for model in model_names}
            add_relative_panel(fig, threshold_results, metric, row=r + 1, col=c + 1, show_legend=(idx == 0))
        fig.update_layout(
            title_text=f"% Difference {metric} vs {ref_model}",
            title_x=0.5, legend=legend_layout,
            height=800, width=1200
        )
        fig.write_image(os.path.join(output_folder, f"rel_{metric.lower()}_2x3.png"))

# Averaging function across seeds; it also computes the standard deviation.
import copy
import numpy as np
from collections import defaultdict

def average_nested_lists_across_seeds_std(seed_results):
    """
    Combine per-seed metric results into element-wise mean and standard deviation.

    Parameters
    ----------
    seed_results : list of tuples
        One ``(level_metrics, overall_metrics, threshold_metrics)`` tuple per
        seed, all sharing the structure of the first seed:
        ``level_metrics[model][level]``, ``overall_metrics[model]`` and
        ``threshold_metrics[model][threshold]`` are dicts whose leaves are lists.

    Returns
    -------
    tuple
        ``(level_metrics, overall_metrics, threshold_metrics)`` with the same
        nesting, where every leaf key ``k`` is replaced by ``k + "_mean"`` and
        ``k + "_std"`` (``numpy`` mean / std over the seed axis, as lists).

    Raises
    ------
    ValueError
        If a leaf value in the first seed is neither a dict nor a list.
    """
    def average_leaf_lists(dicts):
        """Recursively reduce identically structured dicts to mean/std leaves."""
        first = dicts[0]
        averaged = {}
        for key in first:
            sample = first[key]
            if isinstance(sample, dict):
                # Nested mapping: descend with the matching sub-dict of every seed.
                averaged[key] = average_leaf_lists([d[key] for d in dicts])
                continue
            if not isinstance(sample, list):
                raise ValueError(f"Unsupported type at key {key}: {type(sample)}")
            # Leaf list: stack seeds along a new leading axis and reduce over it.
            stacked = np.stack([d[key] for d in dicts])  # shape (num_seeds, ...)
            averaged[key + "_mean"] = stacked.mean(axis=0).tolist()
            averaged[key + "_std"] = stacked.std(axis=0).tolist()
        return averaged

    # Split the per-seed tuples into one list per result kind.
    per_level_all = [seed[0] for seed in seed_results]
    overall_all = [seed[1] for seed in seed_results]
    threshold_all = [seed[2] for seed in seed_results]

    # Models / levels / thresholds are enumerated from the first seed.
    per_level_avg = {
        model: {
            level: average_leaf_lists([seed[model][level] for seed in per_level_all])
            for level in level_map
        }
        for model, level_map in per_level_all[0].items()
    }

    overall_avg = {
        model: average_leaf_lists([seed[model] for seed in overall_all])
        for model in overall_all[0]
    }

    threshold_avg = {
        model: {
            thresh: average_leaf_lists([seed[model][thresh] for seed in threshold_all])
            for thresh in thresh_map
        }
        for model, thresh_map in threshold_all[0].items()
    }

    return (per_level_avg, overall_avg, threshold_avg)



In [None]:
#KNMI 30 results

In [1]:
# Load the pickled KNMI-30 results for BlockGPT, Diffcast and NowcastingGPT.
# Each model has a seed-1 file and a combined "seed2seed3" file.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
import pickle 
#BlockGPT
with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_seed1.pkl','rb') as file:
    vqbatched = pickle.load(file)



with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_seed2seed3.pkl','rb') as file:
    vqbatched_s2s3 = pickle.load(file)

# Diffcast
    
with open('Results/FinalPickledResults/results_diffcast_phydnet_KNMI30_seed1.pkl','rb') as file:
    diffcast_s1 = pickle.load(file)
with open('Results/FinalPickledResults/results_diffcast_phydnet_KNMI30_seed2seed3.pkl','rb') as file:
    diffcast_s2s3 = pickle.load(file)

#NowcastingGPT
    
with open ('Results/FinalPickledResults/results_nowcastingGPT_KNMI30_seed1.pkl','rb') as file:
    nowcasting_s1 = pickle.load(file)
with open('Results/FinalPickledResults/results_nowcastingGPT_KNMI30_seed2seed3.pkl','rb') as file:
    nowcasting_s2s3 = pickle.load(file)


In [None]:

# Re-key each seed's results under one display name per model, then average
# the three seeds (mean and std) with average_nested_lists_across_seeds_std.
vqbatched_s1 = tuple({ 'BlockGPT': entry['vqgan-GPT_batched'] } for entry in vqbatched)
vqbatched_s2 = tuple({ 'BlockGPT': entry['vqgan-s1_GPT_batched'] } for entry in vqbatched_s2s3)
vqbatched_s3 = tuple({ 'BlockGPT': entry['vqgan-s2_GPT_batched'] } for entry in vqbatched_s2s3)
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([vqbatched_s1,vqbatched_s2, vqbatched_s3])
seed_results_vqbatched = (avg_per_level, avg_overall, avg_threshold)


# The re-keyed seed-1 results get their own names instead of overwriting the
# loaded `diffcast_s1` / `nowcasting_s1`. Overwriting made this cell fail with
# a KeyError when run a second time, because the original model keys were gone.
diffcast_s1_named = tuple({ 'Diffcast+Phydnet': entry['diffcast_phydnet'] } for entry in diffcast_s1)
diffcast_s2 = tuple({ 'Diffcast+Phydnet': entry['diffcast_s1_phydnet'] } for entry in diffcast_s2s3)
diffcast_s3 = tuple({ 'Diffcast+Phydnet': entry['diffcast_s2_phydnet'] } for entry in diffcast_s2s3)
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([diffcast_s1_named,diffcast_s2, diffcast_s3])
seed_results_diffcast = (avg_per_level, avg_overall, avg_threshold)


nowcasting_s1_named = tuple({ 'NowcastingGPT': entry['vqgan_s1-GPT'] } for entry in nowcasting_s1)
nowcasting_s2 = tuple({ 'NowcastingGPT': entry['vqgan_s2-GPT'] } for entry in nowcasting_s2s3)
nowcasting_s3 = tuple({ 'NowcastingGPT': entry['vqgan_s3-GPT'] } for entry in nowcasting_s2s3)
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([nowcasting_s1_named,nowcasting_s2, nowcasting_s3])
seed_results_nowcasting = (avg_per_level, avg_overall, avg_threshold)

In [None]:
# Merge the seed-averaged results of the three models element-wise, so each of
# the three result dicts is keyed by all model names, then plot them together.
combined_results_seeds= tuple(
    {**d1, **d2,**d3}  # Merge dictionaries at the same index
    for d1, d2,d3 in zip(seed_results_vqbatched,seed_results_diffcast,seed_results_nowcasting)
)


# NOTE(review): this writes under "Results/MetricsPlots/" while the later
# plotting cells write under "Results/FinalMetricPlots/" — confirm the folder.
plot_metrics(combined_results_seeds , "Results/MetricsPlots/KNMI30",["0-20","20-40","40-60","60-80","80-95","95-100"])

In [None]:
# Ablations on KNMI 30
# 16H indicates 16 heads
# block 8 represents a block size of 8, which generates one row at a time
# tt is token-by-token, rr is row-by-row and ff is frame-by-frame

In [2]:

# Load the BlockGPT seed-1 baseline again, plus the ablation results.
# The "50M_200M" files are bound to *_4H_16H names (head-count ablation) and
# the "rowbyrow" files to *_block8 names.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
import pickle
with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_seed1.pkl','rb') as file:
    vqbatched = pickle.load(file)
with open('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_50M_200M.pkl','rb') as file:
    vqbatched_4H_16H = pickle.load(file)

with open('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_50M_200M_seed2seed3.pkl','rb') as file:
    vqbatched_4H_16H_s2s3 = pickle.load(file)
with open('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_rowbyrow_seed1.pkl','rb') as file:
    vqbatched_block8 = pickle.load(file)
with open('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_rowbyrow_seed2seed3.pkl','rb') as file:
    vqbatched_block8_s2s3 = pickle.load(file)

In [None]:

# Re-key the baseline BlockGPT results as 'BlockGPT_8H' and average across seeds.
# NOTE: `vqbatched_s2s3` is not loaded by the ablation loading cell; this relies
# on the variable loaded in the earlier KNMI-30 loading cell still being in scope.
vqbatched_s1 = tuple({ 'BlockGPT_8H': entry['vqgan-GPT_batched'] } for entry in vqbatched)
vqbatched_s2 = tuple({ 'BlockGPT_8H': entry['vqgan-s1_GPT_batched'] } for entry in vqbatched_s2s3)
vqbatched_s3 = tuple({ 'BlockGPT_8H': entry['vqgan-s2_GPT_batched'] } for entry in vqbatched_s2s3)
# NOTE(review): `results_seeds` is not referenced again in the visible cells.
results_seeds = [vqbatched_s1,vqbatched_s2, vqbatched_s3]

avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([vqbatched_s1,vqbatched_s2, vqbatched_s3])
seed_results_vqbatched = (avg_per_level, avg_overall, avg_threshold)

In [None]:
#rename 
# Re-key the 4-head and 16-head ablation results under display names.
# Seed 1 comes from `vqbatched_4H_16H`; seeds 2 and 3 from `vqbatched_4H_16H_s2s3`.
vqbatched_4H_s1 = tuple({ 'BlockGPT 4H': entry['vqgan_4H-GPT_batched'] } for entry in vqbatched_4H_16H)
vqbatched_4H_s2 = tuple({ 'BlockGPT 4H': entry['vqgan_4H-s2_GPT_batched'] } for entry in vqbatched_4H_16H_s2s3)
vqbatched_4H_s3 = tuple({ 'BlockGPT 4H': entry['vqgan_4H-s3_GPT_batched'] } for entry in vqbatched_4H_16H_s2s3)
vqbatched_16H_s1 = tuple({ 'BlockGPT 16H': entry['vqgan_16H-GPT_batched'] } for entry in vqbatched_4H_16H)
vqbatched_16H_s2 = tuple({ 'BlockGPT 16H': entry['vqgan_16H-s2_GPT_batched'] } for entry in vqbatched_4H_16H_s2s3)
vqbatched_16H_s3 = tuple({ 'BlockGPT 16H': entry['vqgan_16H-s3_GPT_batched'] } for entry in vqbatched_4H_16H_s2s3)



In [None]:
#rename NowcastingGPT to token-by-token in seed_results_nowcasting
# Re-key results for the block-size comparison.
# `tt` and `ff` wrap results that are already averaged across seeds
# (seed_results_nowcasting / seed_results_vqbatched); `rr` wraps the raw
# seed-1 row-by-row results and is averaged with seeds 2 and 3 in a later cell.
tt = tuple({ 'token-by-token (NowcastingGPT)': entry['NowcastingGPT'] } for entry in seed_results_nowcasting)
rr = tuple({ 'row-by-row': entry['vqgan_block8-GPT_batched'] } for entry in vqbatched_block8)
ff = tuple({ 'frame-by-frame (BlockGPT)': entry['BlockGPT_8H'] } for entry in seed_results_vqbatched)



In [None]:
#avg across seeds
# Average the 4-head and 16-head ablations across their three seeds (mean and std).
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([vqbatched_4H_s1,vqbatched_4H_s2, vqbatched_4H_s3])
seed_results_vqbatched_4H = (avg_per_level, avg_overall, avg_threshold)

avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([vqbatched_16H_s1,vqbatched_16H_s2, vqbatched_16H_s3])
seed_results_vqbatched_16H = (avg_per_level, avg_overall, avg_threshold)

In [None]:

# Re-key seeds 2 and 3 of the row-by-row ablation and average them with seed 1 (`rr`).
rr_s2 = tuple({ 'row-by-row': entry['vqgan_block8-s2-GPT_batched'] } for entry in vqbatched_block8_s2s3)
rr_s3 = tuple({ 'row-by-row': entry['vqgan_block8-s3-GPT_batched'] } for entry in vqbatched_block8_s2s3)

avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std([rr ,rr_s2, rr_s3])
seed_results_vqbatched_rr = (avg_per_level, avg_overall, avg_threshold)

# Head-count comparison: 4H vs 8H (baseline) vs 16H.
combined_results = tuple(
    {**d1, **d2,**d3}  # Merge dictionaries at the same index
    for d1, d2,d3 in zip(seed_results_vqbatched_4H,seed_results_vqbatched,seed_results_vqbatched_16H)
)
# NOTE(review): this folder lacks the "KNMI30/Ablations/" prefix used by the
# block-size and seasonal plots — confirm the path is intentional.
plot_metrics(combined_results , "Results/FinalMetricPlots/head_size",["0-20","20-40","40-60","60-80","80-95","95-100"])

# Block-size comparison: token-by-token vs row-by-row vs frame-by-frame.
combined_results = tuple(
    {**d1, **d2,**d3}  # Merge dictionaries at the same index
    for d1, d2,d3 in zip(tt,seed_results_vqbatched_rr,ff)
)


plot_metrics(combined_results , "Results/FinalMetricPlots/KNMI30/Ablations/block_size",["0-20","20-40","40-60","60-80","80-95","95-100"])

In [None]:
#seasonal analysis

In [None]:
# Load the winter results (one file holding all three seeds) and re-key each seed.
# NOTE: relies on `pickle` having been imported by an earlier cell.
with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_winter_seeds123.pkl','rb') as file:
    vqbatched_winter = pickle.load(file)
#rename 
vqbatched_winter_s1 = tuple({ 'BlockGPT in winter': entry['vqgan-s1_GPT_batched'] } for entry in vqbatched_winter)
vqbatched_winter_s2 = tuple({ 'BlockGPT in winter': entry['vqgan-s2_GPT_batched'] } for entry in vqbatched_winter)
vqbatched_winter_s3 = tuple({ 'BlockGPT in winter': entry['vqgan-s3_GPT_batched'] } for entry in vqbatched_winter)

In [None]:
# Load the spring results (one file holding all three seeds) and re-key each seed.
# NOTE: relies on `pickle` having been imported by an earlier cell.
with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_spring_seeds123.pkl','rb') as file:
    vqbatched_spring = pickle.load(file)
#rename
vqbatched_spring_s1 = tuple({ 'BlockGPT in spring': entry['vqgan-s1_GPT_batched'] } for entry in vqbatched_spring)
vqbatched_spring_s2 = tuple({ 'BlockGPT in spring': entry['vqgan-s2_GPT_batched'] } for entry in vqbatched_spring)
vqbatched_spring_s3 = tuple({ 'BlockGPT in spring': entry['vqgan-s3_GPT_batched'] } for entry in vqbatched_spring)

In [None]:
#summer
# Load the summer results (one file holding all three seeds) and re-key each seed.
# NOTE: relies on `pickle` having been imported by an earlier cell.
with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_summer_seeds123.pkl','rb') as file:
    vqbatched_summer = pickle.load(file)
#rename
vqbatched_summer_s1 = tuple({ 'BlockGPT in summer': entry['vqgan-s1_GPT_batched'] } for entry in vqbatched_summer)
vqbatched_summer_s2 = tuple({ 'BlockGPT in summer': entry['vqgan-s2_GPT_batched'] } for entry in vqbatched_summer)
vqbatched_summer_s3 = tuple({ 'BlockGPT in summer': entry['vqgan-s3_GPT_batched'] } for entry in vqbatched_summer)

In [None]:
#fall
# Load the fall results (one file holding all three seeds) and re-key each seed.
# NOTE: relies on `pickle` having been imported by an earlier cell.
with open ('Results/FinalPickledResults/results_blockGPT_KNMI30_ablatations_fall_seeds123.pkl','rb') as file:
    vqbatched_fall = pickle.load(file)
#rename
vqbatched_fall_s1 = tuple({ 'BlockGPT in fall': entry['vqgan-s1_GPT_batched'] } for entry in vqbatched_fall)
vqbatched_fall_s2 = tuple({ 'BlockGPT in fall': entry['vqgan-s2_GPT_batched'] } for entry in vqbatched_fall)
vqbatched_fall_s3 = tuple({ 'BlockGPT in fall': entry['vqgan-s3_GPT_batched'] } for entry in vqbatched_fall)

In [None]:
#avergae across seeds
# Average each season across its three seeds (mean and std), then merge the
# four seasons so they are plotted as four "models" in the same figures.
# NOTE(review): `vq_summmer` is spelled with three m's; it is used consistently
# within this cell, so the code runs as written.
vq_summmer = average_nested_lists_across_seeds_std([vqbatched_summer_s1,vqbatched_summer_s2,vqbatched_summer_s3])
vq_winter = average_nested_lists_across_seeds_std([vqbatched_winter_s1,vqbatched_winter_s2,vqbatched_winter_s3])
vq_spring = average_nested_lists_across_seeds_std([vqbatched_spring_s1,vqbatched_spring_s2,vqbatched_spring_s3])
vq_fall = average_nested_lists_across_seeds_std([vqbatched_fall_s1,vqbatched_fall_s2,vqbatched_fall_s3])

combined_results = tuple(
    {**d1, **d2,**d3,**d4}  # Merge dictionaries at the same index
    for d1, d2,d3,d4 in zip(vq_summmer,vq_winter,vq_spring,vq_fall)
)

#plot
plot_metrics(combined_results, "Results/FinalMetricPlots/KNMI30/Ablations/seasons",["0-20","20-40","40-60","60-80","80-95","95-100"])

In [None]:
#SEVIR 30 minutes

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; only open result files you trust.

# BlockGPT results; the file name indicates seeds 2 and 3.
with open('Results/FinalPickledResults/results_blockGPT_SEVIR30_seed2seed3.pkl', 'rb') as pkl_handle:
    vqbatched_s2s3 = pickle.load(pkl_handle)

# Seed-1 file; its name suggests it holds both BlockGPT and DiffCast+PhyDNet results.
with open('Results/FinalPickledResults/results_blockGPT_diffcast_phydnet_seed1.pkl', 'rb') as pkl_handle:
    diffcast_vqbatched_s1 = pickle.load(pkl_handle)

# DiffCast+PhyDNet results; the file name indicates seeds 2 and 3.
with open('Results/FinalPickledResults/results_diffcast_phydnet_SEVIR30_seed2seed3.pkl', 'rb') as pkl_handle:
    diffcast_s2s3 = pickle.load(pkl_handle)

# NowcastingGPT results: one file for seed 1, one for seeds 2 and 3 (per the file names).
with open('Results/FinalPickledResults/results_nowcastingGPT_SEVIR30_seed1.pkl', 'rb') as pkl_handle:
    nowcasting_s1 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_nowcastingGPT_SEVIR30_seed2seed3.pkl', 'rb') as pkl_handle:
    nowcasting_s2s3 = pickle.load(pkl_handle)

# DiffCast+BlockGPT: one pickle per seed, each relabelled to a common display
# name immediately after loading.
with open('Results/FinalPickledResults/results_diffcast_blockGPT_SEVIR30_seed1.pkl', 'rb') as pkl_handle:
    diffcastBlockGPT = pickle.load(pkl_handle)
diffcastBlockGPT = tuple(
    {'diffcast+BlockGPT': part['diffcast_BlockGPT']} for part in diffcastBlockGPT
)

with open('Results/FinalPickledResults/results_diffcast_blockGPT_SEVIR30_seed2.pkl', 'rb') as pkl_handle:
    diffcastBlockGPT_seed2 = pickle.load(pkl_handle)
diffcastBlockGPT_seed2 = tuple(
    {'diffcast+BlockGPT': part['diffcast_BlockGPT']} for part in diffcastBlockGPT_seed2
)

with open('Results/FinalPickledResults/results_diffcast_blockGPT_SEVIR30_seed3.pkl', 'rb') as pkl_handle:
    diffcastBlockGPT_seed3 = pickle.load(pkl_handle)
diffcastBlockGPT_seed3 = tuple(
    {'diffcast+BlockGPT': part['diffcast_BlockGPT']} for part in diffcastBlockGPT_seed3
)

In [None]:
# Relabel every seed's results under one display name per model.

# DiffCast+PhyDNet: seed 1 comes from the shared seed-1 file, the other two from the seed2seed3 file.
diffcast_s1 = tuple(
    {'Diffcast+Phydnet': part['diffcast_phydnet']} for part in diffcast_vqbatched_s1
)
diffcast_s2 = tuple(
    {'Diffcast+Phydnet': part['diffcast_s1_phydnet']} for part in diffcast_s2s3
)
diffcast_s3 = tuple(
    {'Diffcast+Phydnet': part['diffcast_s2_phydnet']} for part in diffcast_s2s3
)

# BlockGPT: same file layout as above.
vqbatched_s1 = tuple(
    {'BlockGPT': part['vqgan-GPT_batched']} for part in diffcast_vqbatched_s1
)
vqbatched_s2 = tuple(
    {'BlockGPT': part['vqgan_s1_GPT_batched']} for part in vqbatched_s2s3
)
vqbatched_s3 = tuple(
    {'BlockGPT': part['vqgan_s2_GPT_batched']} for part in vqbatched_s2s3
)

# NowcastingGPT.
# NOTE(review): nowcasting_s1 is overwritten with its relabelled form, so running
# this cell a second time without reloading the pickle will fail on the old key.
nowcasting_s1 = tuple(
    {'NowcastingGPT': part['vqgan_s1-GPT']} for part in nowcasting_s1
)
nowcasting_s2 = tuple(
    {'NowcastingGPT': part['vqgan_s2-GPT']} for part in nowcasting_s2s3
)
nowcasting_s3 = tuple(
    {'NowcastingGPT': part['vqgan_s3-GPT']} for part in nowcasting_s2s3
)

In [None]:
# Average each model's three seeds; every call yields three parts that are
# repacked into a tuple per model.

# DiffCast+PhyDNet
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [diffcast_s1, diffcast_s2, diffcast_s3]
)
seed_results_diffcast = (avg_per_level, avg_overall, avg_threshold)

# BlockGPT
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [vqbatched_s1, vqbatched_s2, vqbatched_s3]
)
seed_results_vqbatched = (avg_per_level, avg_overall, avg_threshold)

# NowcastingGPT
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [nowcasting_s1, nowcasting_s2, nowcasting_s3]
)
seed_results_nowcasting = (avg_per_level, avg_overall, avg_threshold)

# DiffCast+BlockGPT
# NOTE(review): seed_results_diffblockgpt is not merged into any plot in the
# notebook cells visible here - confirm whether that omission is intended.
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [diffcastBlockGPT, diffcastBlockGPT_seed2, diffcastBlockGPT_seed3]
)
seed_results_diffblockgpt = (avg_per_level, avg_overall, avg_threshold)

In [None]:
# Merge the three models position by position: dicts sitting at the same index
# of each model's averaged result are folded into one dict keyed by model label.
combined_results_sevir30 = tuple(
    {label: values for model_dict in same_index_dicts for label, values in model_dict.items()}
    for same_index_dicts in zip(seed_results_vqbatched, seed_results_diffcast, seed_results_nowcasting)
)

# Plot the merged results; the list supplies the per-level subplot titles.
plot_metrics(
    combined_results_sevir30,
    "Results/MetricsPlots/SEVIR30",
    ["0-20", "20-40", "40-60", "60-80", "80-95", "95-100"],
)

In [None]:
#SEVIR 5 minutes

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; only open result files you trust.

# BlockGPT results: seeds 2 and 3 in one file, seed 1 in another (per the file names).
with open('Results/FinalPickledResults/results_blockGPT_SEVIR5_seed2seed3.pkl', 'rb') as pkl_handle:
    vqbatched_s2s3 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_blockGPT_SEVIR5_seed1.pkl', 'rb') as pkl_handle:
    vqbatched_s1 = pickle.load(pkl_handle)

# DiffCast+PhyDNet results: seeds 2 and 3 in one file, seed 1 in another (per the file names).
with open('Results/FinalPickledResults/results_diffcast_phydnet_SEVIR5_seed2seed3.pkl', 'rb') as pkl_handle:
    diffcast_s2s3 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_diffcast_phydnet_SEVIR5_seed1.pkl', 'rb') as pkl_handle:
    diffcast_s1 = pickle.load(pkl_handle)
# Seed 1 is relabelled to the shared display name right after loading.
diffcast_s1 = tuple(
    {'Diffcast+Phydnet': part['diffcast_phydnet']} for part in diffcast_s1
)

# NowcastingGPT results: one pickle per seed.
with open('Results/FinalPickledResults/results_nowcastingGPT_SEVIR5_seed1.pkl', 'rb') as pkl_handle:
    nowcasting_sevir_big_s1 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_nowcastingGPT_SEVIR5_seed2.pkl', 'rb') as pkl_handle:
    nowcasting_sevir_big_s2 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_nowcastingGPT_SEVIR5_seed3.pkl', 'rb') as pkl_handle:
    nowcasting_sevir_big_s3 = pickle.load(pkl_handle)
    


In [None]:
# --- BlockGPT: relabel each seed under one display name, then average across seeds ---
# NOTE(review): vqbatched_s1 is overwritten with its relabelled form, so this cell
# must not be re-run without first re-running the loading cell.
vqbatched_s1 = tuple({'BlockGPT': entry['vqgan-GPT_batched']} for entry in vqbatched_s1)
vqbatched_s2 = tuple({'BlockGPT': entry['vqgan_s1_GPT_batched']} for entry in vqbatched_s2s3)
vqbatched_s3 = tuple({'BlockGPT': entry['vqgan_s2_GPT_batched']} for entry in vqbatched_s2s3)
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [vqbatched_s1, vqbatched_s2, vqbatched_s3]
)
seed_results_vqbatched = (avg_per_level, avg_overall, avg_threshold)

# --- DiffCast+PhyDNet ---
# diffcast_s1 is already keyed by 'Diffcast+Phydnet': the loading cell relabels it
# right after unpickling. Relabelling it a second time here looked up the old
# 'diffcast_phydnet' key on the relabelled dicts and raised KeyError on a fresh
# run, so only seeds 2 and 3 are relabelled in this cell.
diffcast_s2 = tuple({'Diffcast+Phydnet': entry['diffcast_s1_phydnet']} for entry in diffcast_s2s3)
diffcast_s3 = tuple({'Diffcast+Phydnet': entry['diffcast_s2_phydnet']} for entry in diffcast_s2s3)
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [diffcast_s1, diffcast_s2, diffcast_s3]
)
seed_results_diffcast = (avg_per_level, avg_overall, avg_threshold)

# --- NowcastingGPT: one loaded pickle per seed ---
nowcasting_s1 = tuple({'NowcastingGPT': entry['vqgan-s1_GPT']} for entry in nowcasting_sevir_big_s1)
nowcasting_s2 = tuple({'NowcastingGPT': entry['vqgan-s2_GPT']} for entry in nowcasting_sevir_big_s2)
nowcasting_s3 = tuple({'NowcastingGPT': entry['vqgan-s3_GPT']} for entry in nowcasting_sevir_big_s3)
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [nowcasting_s1, nowcasting_s2, nowcasting_s3]
)
seed_results_nowcasting = (avg_per_level, avg_overall, avg_threshold)

# Merge the three models' dictionaries found at the same index of each averaged result.
combined_metrics_sevir = tuple(
    {**d1, **d2, **d3}
    for d1, d2, d3 in zip(seed_results_vqbatched, seed_results_diffcast, seed_results_nowcasting)
)

# NOTE(review): plot_metrics is called with its default context_length,
# segment_length and time_resolution (30) - confirm these suit the 5-minute data.
plot_metrics(
    combined_metrics_sevir,
    "Results/MetricsPlots/SEVIR5",
    ["0-20", "20-40", "40-60", "60-80", "80-95", "95-100"],
)

In [None]:
#KNMI 5 minutes

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; only open result files you trust.

# BlockGPT results; the file name indicates all three seeds are in one file.
with open('Results/FinalPickledResults/results_blockGPT_KNMI5_seeds123.pkl', 'rb') as pkl_handle:
    vqbatched_s1s2s3 = pickle.load(pkl_handle)

# DiffCast+PhyDNet results: seed 1 in one file, seeds 2 and 3 in another (per the file names).
with open('Results/FinalPickledResults/results_diffcast_phydnet_KNMI5_seed1.pkl', 'rb') as pkl_handle:
    diffcast_s1 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_diffcast_phydnet_KNMI5_seed2seed3.pkl', 'rb') as pkl_handle:
    diffcast_s2s3 = pickle.load(pkl_handle)

# NowcastingGPT results: one pickle per seed.
with open('Results/FinalPickledResults/results_nowcastingGPT_KNMI5_seed1.pkl', 'rb') as pkl_handle:
    nowcasting_s1 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_nowcastingGPT_KNMI5_seed2.pkl', 'rb') as pkl_handle:
    nowcasting_s2 = pickle.load(pkl_handle)
with open('Results/FinalPickledResults/results_nowcastingGPT_KNMI5_seed3.pkl', 'rb') as pkl_handle:
    nowcasting_s3 = pickle.load(pkl_handle)




In [None]:

# Split the combined BlockGPT pickle into one relabelled tuple per seed.
vqbatched_s1 = tuple(
    {'BlockGPT': part['vqgan_GPT_batched']} for part in vqbatched_s1s2s3
)
vqbatched_s2 = tuple(
    {'BlockGPT': part['vqgan_s2_GPT_batched']} for part in vqbatched_s1s2s3
)
vqbatched_s3 = tuple(
    {'BlockGPT': part['vqgan_s3_GPT_batched']} for part in vqbatched_s1s2s3
)

# Average BlockGPT across the three seeds and repack the three returned parts.
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [vqbatched_s1, vqbatched_s2, vqbatched_s3]
)
seed_results_vqbatched = (avg_per_level, avg_overall, avg_threshold)

In [None]:
# Relabel the DiffCast+PhyDNet results of each seed under one display name.
# NOTE(review): diffcast_s1 is overwritten with its relabelled form, so this cell
# cannot be re-run without reloading the pickle first.
diffcast_s1 = tuple(
    {'Diffcast+Phydnet': part['diffcast_phydnet']} for part in diffcast_s1
)
# NOTE(review): seed 2 is read from key 'diffcast_phydnet' of the seed2seed3 file here,
# while the SEVIR cells use 'diffcast_s1_phydnet' - confirm this matches the KNMI5 pickle.
diffcast_s2 = tuple(
    {'Diffcast+Phydnet': part['diffcast_phydnet']} for part in diffcast_s2s3
)
diffcast_s3 = tuple(
    {'Diffcast+Phydnet': part['diffcast_s2_phydnet']} for part in diffcast_s2s3
)

In [None]:

# Relabel the NowcastingGPT results of each seed under one display name.
# NOTE(review): each variable is overwritten with its relabelled form, so this cell
# cannot be re-run without reloading the pickles first.
nowcasting_s1 = tuple(
    {'NowcastingGPT': part['vqgan-GPT']} for part in nowcasting_s1
)
nowcasting_s2 = tuple(
    {'NowcastingGPT': part['vqgan-s2_GPT']} for part in nowcasting_s2
)
nowcasting_s3 = tuple(
    {'NowcastingGPT': part['vqgan-s3_GPT']} for part in nowcasting_s3
)

In [None]:
# Average the DiffCast+PhyDNet results across the three seeds.
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [diffcast_s1, diffcast_s2, diffcast_s3]
)
seed_results_diffcast = (avg_per_level, avg_overall, avg_threshold)

# Average the NowcastingGPT results across the three seeds.
avg_per_level, avg_overall, avg_threshold = average_nested_lists_across_seeds_std(
    [nowcasting_s1, nowcasting_s2, nowcasting_s3]
)
seed_results_nowcasting = (avg_per_level, avg_overall, avg_threshold)

In [None]:

# Merge the three models' dictionaries found at the same index of each averaged result.
combined_results = tuple(
    {**d1, **d2, **d3}
    for d1, d2, d3 in zip(seed_results_vqbatched, seed_results_diffcast, seed_results_nowcasting)
)

# Write the KNMI 5-minute plots to their own folder. The previous target,
# "Results/MetricsPlots/SEVIR5", was a copy-paste leftover that overwrote the
# SEVIR5 figures produced earlier in the notebook.
# NOTE(review): plot_metrics is called with its default context_length,
# segment_length and time_resolution (30) - confirm these suit the 5-minute data.
plot_metrics(
    combined_results,
    "Results/MetricsPlots/KNMI5",
    ["0-20", "20-40", "40-60", "60-80", "80-95", "95-100"],
)