This notebook reproduces the main results (tables and figures) from the paper:
# ForecastPFN: Synthetically-Trained Zero-Shot Forecasting
By: Samuel Dooley, Gurnoor Singh Khurana, Chirag Mohapatra, Siddartha Naidu, Colin White

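The CSVs used to reproduce the tables and figures can be downloaded [here](https://drive.google.com/file/d/1oa9FlY6WQojlN4nx8ZGmo8Zbfc5QG-bd/view?usp=sharing). Place them in the same directory as this notebook, because the cells below read them by relative path.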

In [None]:
import os
import numpy as np
import pandas as pd
import glob

In [None]:
import warnings
warnings.filterwarnings("ignore")

# Produce Table 1

In [None]:
df = pd.read_csv('benchmark.csv')

# Table 1, data budget = 50: mean MSE per (Model, Dataset), with models as rows
# and datasets as columns. The best (lowest) model in each dataset column is bolded.
# Selecting 'mse' *before* aggregating matters: averaging the whole group frame
# also tries to average the non-numeric columns (e.g. 'Model'), which raises a
# TypeError on pandas >= 2.0 where numeric_only defaults to False. unstack() also
# yields NaN for a missing (Model, Dataset) pair instead of collapsing the result
# into a Series (which has no .style).
fifty = (
    df[df['Train Budget'] == 50]
    .groupby(['Model', 'Dataset'])['mse']
    .mean()
    .unstack('Dataset')
    .round(3)
    .style.apply(lambda col: ['font-weight:bold' if x == col.min() else '' for x in col])
)
# LaTeX export: convert_css=True turns the CSS 'font-weight:bold' into \bfseries,
# so bold cells stay well-formed (no unbalanced \textbf{ braces).
# print(fifty.to_latex(convert_css=True))
fifty

In [None]:
# Table 1, data budget = 500: same layout as the budget-50 table. Models are rows,
# datasets are columns, and the lowest mean MSE per dataset is bolded.
# Aggregate only the 'mse' column so pandas >= 2.0 does not try (and fail) to
# average non-numeric columns such as 'Model'.
fhund = (
    df[df['Train Budget'] == 500]
    .groupby(['Model', 'Dataset'])['mse']
    .mean()
    .unstack('Dataset')
    .round(3)
    .style.apply(lambda col: ['font-weight:bold' if x == col.min() else '' for x in col])
)
# LaTeX export: convert_css=True renders the bold cells as \bfseries.
# print(fhund.to_latex(convert_css=True))
fhund

# Produce Figures 3 and 4
This section also produces the corresponding figures for the other metrics.

In [None]:
import plotnine
from plotnine import *
fpfn_theme = plotnine.themes.theme(
    legend_position="bottom",
    legend_box_spacing=.55,
    axis_text_x=element_text(rotation=45, hjust=1)) + theme_minimal()

%matplotlib inline

colors = {'Informer': '#7fc97f',
          'Arima': '#beaed4',
          'Meta-N-BEATS': '#bb2299',
          'ForecastPFN': 'Black',
          'ForecastPFN_20230502-140223': 'Pink',
          'Autoformer': '#386cb0',
          'FEDformer': '#f0270f',
          'Prophet': '#bf5b16',
          'Transformer': '#fdc086',
          'Mean': '#287068',
          'Last': '#C12FDD',
          'SeasonalNaive': '#008000'
         }


In [None]:
# Wins vs. prediction length: one figure per metric, faceted by data budget.
metric_wins_pred_agg = pd.read_csv('metric_wins_pred_agg.csv', index_col=0).astype(
    {'Prediction Length': 'int', 'Train Budget': 'int'})

for metric_name, subset in metric_wins_pred_agg.groupby('metric'):
    # Shaded band between the Wins_low and Wins_upper columns, hidden from the legend.
    band = geom_ribbon(aes(ymin='Wins_low', ymax='Wins_upper', fill='Model'),
                       alpha=.15, outline_type='upper', show_legend=False)
    # Horizontal, untitled legend inside the figure, plus larger text everywhere.
    layout = theme(
        legend_direction='horizontal',
        legend_position=(.5, .04),
        legend_title=element_blank(),
        legend_box_spacing=.4,
        plot_title=element_text(hjust=0.5, size=20),
        axis_text_x=element_text(size=14),
        axis_text_y=element_text(size=14),
        axis_title_x=element_text(size=16),
        axis_title_y=element_text(size=16),
        strip_text_x=element_text(size=14),
        legend_text=element_text(size=14),
    )
    chart = (
        ggplot(subset)
        + aes(x='Prediction Length', y='Wins', color='Model')
        + geom_line()
        + band
        + facet_grid('~Train Budget')
        + scale_color_manual(values=colors)
        + scale_fill_manual(values=colors)
        + fpfn_theme
        + labs(title=f'{metric_name} Wins per Prediction Length By Data Budget (50 to 500)')
        + ylab(f'{metric_name} Wins')
        + geom_point(aes(shape='Model'))
        + theme(figure_size=(15, 5))
        + layout
        + guides(fill=guide_legend(nrow=1), color=guide_legend(nrow=1))
    )
    print(chart)
    # chart.save(f'figures/{metric_name}_wins_predlen_legend_error.pdf')


In [None]:
# Mean rank vs. prediction length: one figure per metric, faceted by data budget.
rank_pred_agg = pd.read_csv('rank_pred_agg.csv', index_col=0).astype(
    {'Prediction Length': 'int', 'Train Budget': 'int'})

for metric_name, subset in rank_pred_agg.groupby('metric'):
    # Shaded band between the Rank_low and Rank_upper columns, hidden from the legend.
    band = geom_ribbon(aes(ymin='Rank_low', ymax='Rank_upper', fill='Model'),
                       alpha=.15, outline_type='upper', show_legend=False)
    # Horizontal, untitled legend inside the figure, plus larger text everywhere.
    layout = theme(
        legend_direction='horizontal',
        legend_position=(.5, .04),
        legend_title=element_blank(),
        legend_box_spacing=.4,
        plot_title=element_text(hjust=0.5, size=20),
        axis_text_x=element_text(size=14),
        axis_text_y=element_text(size=14),
        axis_title_x=element_text(size=16),
        axis_title_y=element_text(size=16),
        strip_text_x=element_text(size=14),
        legend_text=element_text(size=14),
    )
    chart = (
        ggplot(subset)
        + aes(x='Prediction Length', y='Rank', color='Model')
        + geom_line()
        + band
        + facet_grid('~Train Budget')
        + scale_color_manual(values=colors)
        + scale_fill_manual(values=colors)
        + fpfn_theme
        + labs(title=f'Mean {metric_name} Rank per Prediction Length By Data Budget (50 to 500)')
        + ylab(f'Mean {metric_name} Rank')
        + geom_point(aes(shape='Model'))
        + theme(figure_size=(15, 5))
        + layout
        + guides(fill=guide_legend(nrow=1), color=guide_legend(nrow=1))
    )
    print(chart)
    # chart.save(f'figures/mean_{metric_name}_rank_predlen_legend_error.pdf')


In [None]:
# Total wins vs. data budget (log10 x axis): one figure per metric.
metric_wins_train_agg = pd.read_csv('metric_wins_train_agg.csv', index_col=0)

for metric_name, subset in metric_wins_train_agg.groupby('metric'):
    # Shaded band between the Wins_low and Wins_upper columns, hidden from the legend.
    band = geom_ribbon(aes(ymin='Wins_low', ymax='Wins_upper', fill='Model'),
                       alpha=.15, outline_type='upper', show_legend=False)
    chart = (
        ggplot(subset)
        + aes(x='Train Budget', y='Wins', color='Model')
        + geom_line()
        + band
        + scale_color_manual(values=colors)
        + scale_fill_manual(values=colors)
        + fpfn_theme
        + labs(title=f'Number of total {metric_name} Wins per Data Budget',
               x="Data Budget")
        + ylab(f'{metric_name} Wins')
        + geom_point(aes(shape='Model'))
        + scale_x_continuous(trans='log10')
    )
    print(chart)
    # chart.save(f'figures/{metric_name}_wins_trainbudget_error.pdf')


In [None]:
# Total wins vs. time budget (log10 x axis): one figure per metric.
# Compact paper variant: add theme(legend_position="none", figure_size=(4.8, 4.8)).
metric_wins_time_agg = pd.read_csv('metric_wins_time_agg.csv', index_col=0)

for metric_name, subset in metric_wins_time_agg.groupby('metric'):
    # Shaded band between the Wins_low and Wins_upper columns, hidden from the legend.
    band = geom_ribbon(aes(ymin='Wins_low', ymax='Wins_upper', fill='Model'),
                       alpha=.15, outline_type='upper', show_legend=False)
    chart = (
        ggplot(subset)
        + aes(x='Time Budget', y='Wins', color='Model')
        + geom_line()
        + band
        + scale_color_manual(values=colors)
        + scale_fill_manual(values=colors)
        + fpfn_theme
        + labs(title=f'Number of total {metric_name} Wins per Time Budget')
        + ylab(f'{metric_name} Wins')
        + geom_point(aes(shape='Model'))
        + theme(plot_title=element_text(hjust=0.5))
        + scale_x_continuous(trans='log10')
    )
    print(chart)
    # chart.save(f'figures/{metric_name}_wins_timebudget_error.pdf')


In [None]:
# Mean rank vs. data budget (log10 x axis): one small, legend-free figure per metric.
rank_train_agg = pd.read_csv('rank_train_agg.csv', index_col=0)

for metric_name, subset in rank_train_agg.groupby('metric'):
    # Shaded band between the Rank_low and Rank_upper columns, hidden from the legend.
    band = geom_ribbon(aes(ymin='Rank_low', ymax='Rank_upper', fill='Model'),
                       alpha=.15, outline_type='upper', show_legend=False)
    chart = (
        ggplot(subset)
        + aes(x='Train Budget', y='Rank', color='Model')
        + geom_line()
        + band
        + scale_color_manual(values=colors)
        + scale_fill_manual(values=colors)
        + fpfn_theme
        + labs(title=f'Mean {metric_name} Rank per Data Budget', x="Data Budget")
        + ylab(f'Mean {metric_name} Rank')
        + geom_point(aes(shape='Model'))
        + theme(legend_position="none", plot_title=element_text(hjust=0.5))
        + scale_x_continuous(trans='log10')
        + theme(figure_size=(4.8, 4.8))
    )
    print(chart)
    # chart.save(f'figures/mean_{metric_name}_rank_trainbudget_error.pdf')


In [None]:
# Mean rank vs. time budget (log10 x axis): one small, legend-free figure per metric.
rank_time_agg = pd.read_csv('rank_time_agg.csv', index_col=0)

for metric_name, subset in rank_time_agg.groupby('metric'):
    # Shaded band between the Rank_low and Rank_upper columns, hidden from the legend.
    band = geom_ribbon(aes(ymin='Rank_low', ymax='Rank_upper', fill='Model'),
                       alpha=.15, outline_type='upper', show_legend=False)
    chart = (
        ggplot(subset)
        + aes(x='Time Budget', y='Rank', color='Model')
        + geom_line()
        + band
        + scale_color_manual(values=colors)
        + scale_fill_manual(values=colors)
        + fpfn_theme
        + labs(title=f'Mean {metric_name} Rank per Time Budget')
        + ylab(f'Mean {metric_name} Rank')
        + geom_point(aes(shape='Model'))
        + theme(legend_position="none", plot_title=element_text(hjust=0.5))
        + scale_x_continuous(trans='log10')
        + theme(figure_size=(4.8, 4.8))
    )
    print(chart)
    # chart.save(f'figures/mean_{metric_name}_rank_timebudget_error.pdf')
