# Train FuXi-S2S Bias Correction (PAGASA)

This notebook **trains a statistical bias correction model** using the station observations in `data/pagasa` and FuXi-S2S forecast outputs in `output/`.

Important: the repository ships an **ONNX inference model** (`model/fuxi_s2s.onnx`). Without the original training code and weights, the deep network itself cannot be retrained here. Instead, we train a lightweight correction layer (linear calibration) on top of its outputs to improve station-level forecasts.

Outputs:
- `train_fuxi/artifacts/bias_correction_params.pkl` (pickled correction parameters)

In [None]:
# 1) Setup
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd

# Repository root. Defaults to the original author's checkout; set the
# FUXI_S2S_ROOT environment variable to run this notebook from another location.
REPO_ROOT = Path(os.environ.get('FUXI_S2S_ROOT', r"C:\Machine Learning\FuXi-S2S"))
if not REPO_ROOT.is_dir():
    raise FileNotFoundError(
        f'Repository root not found: {REPO_ROOT}. Set FUXI_S2S_ROOT to your FuXi-S2S checkout.'
    )
os.chdir(REPO_ROOT)
# Make repo-local packages (fuxis2s_model, train_fuxi) importable.
sys.path.insert(0, str(REPO_ROOT))

print('Repo root:', REPO_ROOT)
print('Python:', sys.version)

In [None]:
# 2) Load PAGASA observations
pagasa_file = REPO_ROOT / 'data' / 'pagasa' / 'CBSUA Pili, Camarines Sur Daily data.xlsx'
if not pagasa_file.exists():
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    raise FileNotFoundError(f'Missing PAGASA file: {pagasa_file}')

obs = pd.read_excel(pagasa_file)
print('Raw shape:', obs.shape)
print('Columns:', list(obs.columns))

# Strip stray whitespace from header names (e.g. 'YEAR ') so the exact-name
# lookups below and the forecast/observation mapping later still match.
obs = obs.rename(columns=lambda c: c.strip() if isinstance(c, str) else c)

# Build a Date column from YEAR/MONTH/DAY; report every missing column at once.
required = ['YEAR', 'MONTH', 'DAY']
missing = [col for col in required if col not in obs.columns]
if missing:
    raise ValueError(f'Missing required column(s) {missing} in PAGASA file. Found: {list(obs.columns)}')

obs['Date'] = pd.to_datetime(obs[required])
# Calendar-date key used later to join observations with forecast valid dates.
obs['date_only'] = obs['Date'].dt.date

print('Date range:', obs['Date'].min(), 'to', obs['Date'].max())

display(obs.head())

In [None]:
# 3) Discover available model runs in output/
from glob import glob

output_root = REPO_ROOT / 'output'
if not output_root.is_dir():
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    raise FileNotFoundError(f'Missing output/ directory ({output_root}). Run inference first.')

# Run layout: output/YYYY/YYYYMMDD/member/
member_dirs = sorted(d for d in output_root.glob('*/*/member') if d.is_dir())

# Keep only member directories that actually contain NetCDF files, either flat
# (memberNN_leadNN.nc) or inside two-digit member sub-directories (NN/*.nc).
valid_member_dirs = [
    d for d in member_dirs
    if any(d.glob('member??_lead??.nc')) or any(d.glob('[0-9][0-9]/*.nc'))
]

if not valid_member_dirs:
    raise RuntimeError('No forecast NetCDF files found in output/. Run inference first.')

# Earliest init date (YYYYMMDD) to include; runs initialised earlier are ignored.
init_date = '20000101'
init_date_int = int(init_date)

init_dates = sorted({
    d.parent.name
    for d in valid_member_dirs  # d = .../YYYY/YYYYMMDD/member
    # Accept only 8-digit date directories filed under their own year folder:
    # int() would raise on any other name, and list_members_for_init() rebuilds
    # the path as output/<YYYY>/<YYYYMMDD>, so a misfiled date would yield no members.
    if len(d.parent.name) == 8
    and d.parent.name.isdecimal()
    and d.parent.parent.name == d.parent.name[:4]
    and int(d.parent.name) >= init_date_int
})
print('Found init dates:', init_dates[:10], ('...' if len(init_dates) > 10 else ''))
print('Total init dates:', len(init_dates))

In [None]:
# 4) Build training pairs (forecast vs observed) at station location
from fuxis2s_model.core.compare import load_fuxi_output, extract_station_forecast, STATIONS

station_name = 'CBSUA Pili'
station = STATIONS[station_name]
# Coordinates passed to extract_station_forecast for every member's output.
station_lat, station_lon = station['lat'], station['lon']

def list_members_for_init(init_date: str, root: "Path | None" = None) -> list[int]:
    """Return the sorted ensemble member indices available for one init date.

    Looks in ``<root>/YYYY/YYYYMMDD/member`` for flat files named
    ``memberNN_leadNN.nc``. If none of those yield a member index, falls back
    to two-digit member sub-directories (``member/NN/``) that contain at least
    one ``.nc`` file.

    Args:
        init_date: Initialisation date as ``YYYYMMDD``; the year directory is
            taken from its first four characters.
        root: Forecast output root. Defaults to the notebook's ``output_root``.

    Returns:
        Sorted member indices, or an empty list when the member directory is
        missing or holds no recognisable member files.
    """
    base = output_root if root is None else Path(root)
    member_dir = base / init_date[:4] / init_date / 'member'
    if not member_dir.is_dir():
        return []

    members = set()
    for p in member_dir.glob('member??_lead??.nc'):
        # The glob's '??' also matches non-digits (e.g. 'memberab_lead01.nc');
        # only purely numeric member tokens are accepted.
        token = p.name.split('_lead')[0].replace('member', '')
        if token.isdecimal():
            members.add(int(token))

    # Fallback: sub-directory layout member/NN/*.nc. Require .nc files so an
    # empty or stray directory is not reported as a loadable member.
    if not members:
        for sub in member_dir.glob('[0-9][0-9]'):
            if sub.is_dir() and any(sub.glob('*.nc')):
                members.add(int(sub.name))

    return sorted(members)

def ensemble_mean_forecast_df(init_date: str, members: list[int]) -> pd.DataFrame:
    """Return the ensemble-mean station forecast for one initialisation date.

    Each member's output is loaded with ``load_fuxi_output`` and reduced to the
    station point with ``extract_station_forecast``. Numeric forecast columns
    are then averaged across members for each ``lead_time_days``.

    Args:
        init_date: Initialisation date as ``YYYYMMDD``.
        members: Member indices to load. Members whose files raise ``OSError``
            are reported and skipped.

    Returns:
        One row per lead time with the averaged numeric columns, the
        ``init_time``/``valid_time`` of the lowest-numbered loaded member, and an
        ``init_date`` column. If no member could be loaded, an empty frame with
        columns ``lead_time_days``, ``init_time``, ``valid_time``, ``init_date``.
    """
    group_col = 'lead_time_days'
    time_cols = ['init_time', 'valid_time']

    per_member = []
    for m in members:
        try:
            da = load_fuxi_output(str(output_root), init_date=init_date, member=m)
        except OSError as e:
            print(f"Skipping invalid NetCDF file for member {m} on init_date {init_date}: {e}")
            continue
        df = extract_station_forecast(da, lat=station_lat, lon=station_lon, init_date=init_date)
        df['member'] = m
        per_member.append(df)

    if not per_member:
        # pd.concat([]) raises ValueError; an empty frame lets the caller's
        # valid_time conversion and merge simply produce no training pairs.
        print(f"No loadable members for init_date {init_date}")
        return pd.DataFrame(columns=[group_col, *time_cols, 'init_date'])

    all_df = pd.concat(per_member, ignore_index=True)

    # Average only forecast values: groupby restores the group key itself, and
    # the mean of member indices is not a meaningful quantity.
    value_cols = [
        c for c in all_df.select_dtypes(include=[np.number]).columns
        if c not in (group_col, 'member')
    ]
    mean_df = all_df.groupby(group_col, as_index=False)[value_cols].mean()

    # Restore non-numeric time columns from the lowest-numbered member
    # (assumed identical across members for a given lead time).
    first = all_df.sort_values('member').drop_duplicates(group_col)[[group_col] + time_cols]
    mean_df = mean_df.merge(first, on=group_col, how='left')

    mean_df['init_date'] = init_date
    return mean_df

# Pair each init date's ensemble-mean forecast with the observations sharing
# its valid calendar date; init dates without members or overlap are skipped.
pairs = []
for init_date in init_dates:
    members = list_members_for_init(init_date)
    if members:
        f = ensemble_mean_forecast_df(init_date, members)
        f['date_only'] = pd.to_datetime(f['valid_time']).dt.date
        merged = f.merge(obs, how='inner', on='date_only', suffixes=('', '_obs'))
        if not merged.empty:
            pairs.append(merged)

if len(pairs) == 0:
    raise RuntimeError('No overlapping forecast/observation dates found. Try running inference for dates covered by PAGASA data.')

training_df = pd.concat(pairs, ignore_index=True)
print('Training pairs:', len(training_df))
print('Columns:', list(training_df.columns)[:25], '...')

display(training_df.head())

In [None]:
# 5) Train and save bias corrector
from train_fuxi.bias_correction import BiasCorrector

# Choose observation columns available in your Excel file.
# Common expected names: TMAX, TMIN, RAINFALL, WINDSPEED
available = set(training_df.columns)
print('Available obs-like columns:', [c for c in ['TMAX','TMIN','RAINFALL','WINDSPEED'] if c in available])

# (forecast column, observed column) pairs to calibrate. A pair is used only
# when both columns are present in training_df; dict order follows this list.
candidate_pairs = [
    ('t2m_celsius', 'TMAX'),
    ('tp', 'RAINFALL'),
    ('wind_speed', 'WINDSPEED'),
]
mapping = {fc: oc for fc, oc in candidate_pairs if fc in available and oc in available}

if not mapping:
    raise RuntimeError('No mapping could be built. Check your PAGASA column names and forecast columns.')

corrector = BiasCorrector().fit(training_df, mapping=mapping)
save_path = REPO_ROOT / 'train_fuxi' / 'artifacts' / 'bias_correction_params.pkl'
# Create the artifacts directory on a fresh checkout so saving does not depend
# on it already existing.
save_path.parent.mkdir(parents=True, exist_ok=True)
saved = corrector.save(save_path)

print('Trained models for:', list(corrector.models.keys()))
print('Saved to:', saved)

In [None]:
# 6) Quick evaluation (RMSE before vs after)
def rmse(a, b):
    """Root-mean-square error between two array-likes, ignoring non-finite pairs.

    Positions where either input is NaN or +/-inf are dropped before averaging.
    Returns ``np.nan`` when no finite pair remains, otherwise a Python float.
    """
    x = np.asarray(a, dtype=float)
    y = np.asarray(b, dtype=float)
    valid = np.isfinite(x) & np.isfinite(y)
    if not valid.any():
        return np.nan
    diff = x[valid] - y[valid]
    return float(np.sqrt(np.mean(np.square(diff))))

# In-sample check: `corrector` was fitted on these same rows, so this
# overstates the benefit of the correction.
corrected = corrector.transform(training_df)

print('In-sample (training rows):')
for forecast_col, obs_col in mapping.items():
    raw = rmse(training_df[forecast_col], training_df[obs_col])
    cor = rmse(corrected[f'{forecast_col}_corrected'], training_df[obs_col])
    print(f'{forecast_col} -> {obs_col}: RMSE raw={raw:.3f}, corrected={cor:.3f}')

# Out-of-sample check: refit a separate corrector on the earlier init dates and
# score the most recent ~20% of init dates (YYYYMMDD strings sort chronologically).
# The saved `corrector` is left untouched.
unique_inits = sorted(training_df['init_date'].unique())
if len(unique_inits) < 2:
    print('Holdout evaluation skipped: need at least 2 init dates, found', len(unique_inits))
else:
    n_holdout = max(1, len(unique_inits) // 5)
    is_holdout = training_df['init_date'].isin(set(unique_inits[-n_holdout:]))
    train_part = training_df.loc[~is_holdout].copy()
    test_part = training_df.loc[is_holdout].copy()

    holdout_corrector = BiasCorrector().fit(train_part, mapping=mapping)
    test_corrected = holdout_corrector.transform(test_part)

    print(f'Holdout ({n_holdout} latest init date(s), {len(test_part)} rows):')
    for forecast_col, obs_col in mapping.items():
        raw = rmse(test_part[forecast_col], test_part[obs_col])
        cor = rmse(test_corrected[f'{forecast_col}_corrected'], test_part[obs_col])
        print(f'{forecast_col} -> {obs_col}: RMSE raw={raw:.3f}, corrected={cor:.3f}')

In [None]:
# 7) Visualization and summary of all variables
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# Summary statistics for every column of the paired training data
summary = training_df.describe(include='all')
display(summary)

# Forecast columns are the mapping's keys; observed columns are its values
forecast_vars = list(mapping.keys())
obs_vars = list(mapping.values())
all_vars = forecast_vars + obs_vars

# One histogram (with KDE) per mapped variable
for var in all_vars:
    if var not in training_df.columns:
        continue
    plt.figure(figsize=(8, 4))
    if var in forecast_vars:
        label_type = 'Forecasted'
    elif var in obs_vars:
        label_type = 'Observed'
    else:
        label_type = ''
    sns.histplot(training_df[var].dropna(), kde=True, bins=30)
    plt.title(f'Distribution of {var} ({label_type})')
    plt.xlabel(f'{var} value')
    plt.ylabel('Frequency')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

# Pairwise correlation between the mapped forecast and observed columns
plt.figure(figsize=(10, 8))
corr = training_df[all_vars].corr()
sns.heatmap(
    corr,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    cbar_kws={'label': 'Correlation coefficient'},
)
plt.title('Correlation Heatmap of Forecast and Observed Variables')
plt.xlabel('Variables')
plt.ylabel('Variables')
plt.tight_layout()
plt.show()

In [None]:
# 8) Key variable visualizations: Humidity, Rainfall, Min/Max Temp, Wind (observed vs forecasted over valid time).
# training_df has one row per (init date, lead time), so several forecasts can share a valid
# time and each observation is repeated for every forecast verifying on its date. Rows are
# therefore sorted by valid time, forecasts are drawn as unconnected markers, and observations
# are de-duplicated per valid time before being drawn as a line. Joining rows in their stored
# order would zig-zag back and forth in time.
key_pairs = [
    # (forecast_col, observed_col, label, y_label)
    ('tp', 'RAINFALL', 'Rainfall', 'Rainfall (mm)'),
    ('t2m_celsius', 'TMAX', 'Max Temperature', 'Temperature (°C)'),
    ('wind_speed', 'WINDSPEED', 'Wind Speed', 'Wind Speed (m/s)'),
    ('wind_direction', 'WIND DIRECTION', 'Wind Direction', 'Wind Direction (deg)'),
    ('t2m_celsius', 'TMIN', 'Min Temperature', 'Temperature (°C)'),
]

has_valid_time = 'valid_time' in training_df.columns
if has_valid_time:
    training_df['valid_time'] = pd.to_datetime(training_df['valid_time'])
    plot_df = training_df.sort_values('valid_time', kind='stable')
else:
    plot_df = training_df

for forecast_col, observed_col, label, y_label in key_pairs:
    # A time-series comparison needs valid_time plus both columns.
    if not has_valid_time or forecast_col not in plot_df.columns or observed_col not in plot_df.columns:
        continue
    observed = plot_df.drop_duplicates('valid_time')
    plt.figure(figsize=(12, 5))
    # Label the forecast by its source column: t2m_celsius is compared with
    # both TMAX and TMIN, so calling it "Min/Max Temperature" would mislead.
    plt.plot(plot_df['valid_time'], plot_df[forecast_col], marker='o', linestyle='none',
             alpha=0.5, label=f'Forecasted ({forecast_col})', color='tab:blue')
    plt.plot(observed['valid_time'], observed[observed_col], marker='s', linestyle='--',
             label=f'Observed ({observed_col})', color='tab:orange')
    plt.xlabel('Valid Time')
    plt.xticks(rotation=45)
    plt.ylabel(y_label)
    plt.title(f'{label}: Observed vs Forecasted Over Time')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

# Forecast-only variables (no matching observation column)
single_vars = [
    ('specific_humidity', 'Specific Humidity (Forecasted)'),
]
for var, label in single_vars:
    if var not in plot_df.columns:
        continue
    plt.figure(figsize=(12, 5))
    if has_valid_time:
        plt.plot(plot_df['valid_time'], plot_df[var], marker='o', linestyle='none',
                 alpha=0.5, label=label)
        plt.xlabel('Valid Time')
        plt.xticks(rotation=45)
        plt.ylabel(label)
        plt.title(f'{label} Over Time')
        plt.legend()
    else:
        # Without a time axis, fall back to the value distribution.
        sns.histplot(plot_df[var].dropna(), kde=True, bins=30)
        plt.xlabel(label)
        plt.ylabel('Frequency')
        plt.title(f'{label} Distribution')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

## Next steps
- To apply this correction to forecasts before they are stored in MongoDB, wire `BiasCorrector.load(...)` into the store pipeline.
- For stronger corrections, we can extend the model to include lead-time buckets or quantile mapping.