# Paleo AI Viewer – Holocene Temperature Prediction

This notebook loads Holocene paleotemperature data, applies an ensemble of trained neural networks, and generates an interactive scatter plot comparing predicted and true temperature values.

The plot includes a dropdown menu to color-code the points by season, proxy, or archive type.


In [None]:
#Install dependencies

### Interactive Visualization

In [None]:
!pip install -q torch==2.* plotly pandas scikit-learn joblib

# Clone the GitHub repository into ./repo (it provides the files under repo/ used below)
!git clone https://github.com/ArturStachnik/Paleo_AI_viewer.git repo

#Imports
import os, glob, torch, joblib
import pandas as pd
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split

#File paths
# Both paths point inside the "repo" directory created by the git clone above.
# MODELS_DIR is the directory load_ensemble() reads; it looks there for
# "scaler.pkl" and for every file matching "model_*.pt".
MODELS_DIR = "repo/models"
# CSV read into RAW_DF. The code below uses these columns: season, lon, lat,
# elev, age, resolution, temperature, proxy, archive, dataSetName.
CSV_PATH   = "repo/Holocene_T_clean.csv"

#Define Model and Utilities
class ProbNet(nn.Module):
    """Feed-forward network with a shared trunk and two one-unit output heads.

    The trunk is Linear(input_dim, 128) -> ReLU -> Dropout(0.1) ->
    Linear(128, 64) -> ReLU. Its 64-wide output feeds two independent linear
    heads, ``loc`` and ``log_var`` (presumably the predictive mean and the log
    of the predictive variance -- confirm against the training code).
    """

    def __init__(self, input_dim):
        """Create the trunk and both heads for ``input_dim`` input features."""
        super().__init__()
        # The position of each layer inside the Sequential becomes part of the
        # state-dict key ("net.0.*", "net.3.*"), so the order here must not
        # change or saved checkpoints will no longer load.
        trunk_layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)
        self.loc = nn.Linear(64, 1)
        self.log_var = nn.Linear(64, 1)

    def forward(self, x):
        """Return the ``(loc, log_var)`` head outputs for the batch ``x``."""
        hidden = self.net(x)
        loc_out = self.loc(hidden)
        log_var_out = self.log_var(hidden)
        return loc_out, log_var_out

def load_ensemble(model_dir, device="cpu", n_extra_features=3):
    """Load the feature scaler and every saved ensemble member from a directory.

    Args:
        model_dir: Directory holding ``scaler.pkl`` and one or more
            ``model_*.pt`` state-dict files.
        device: Device the networks are created on and the checkpoints are
            mapped to.
        n_extra_features: Number of input columns that are appended to the
            scaled ones and therefore not counted by the scaler. Defaults to
            3, matching the three one-hot season columns this notebook stacks
            next to the scaled features.

    Returns:
        A ``(models, scaler)`` tuple: the list of ``ProbNet`` instances in
        eval mode, ordered by sorted file name, and the unpickled scaler.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``model_*.pt`` file
            (or, from ``joblib.load``, if ``scaler.pkl`` is missing).
    """
    # NOTE(security): joblib.load and torch.load both unpickle their input.
    # Only point this at a model directory from a source you trust.
    scaler = joblib.load(os.path.join(model_dir, "scaler.pkl"))

    checkpoint_paths = sorted(glob.glob(os.path.join(model_dir, "model_*.pt")))
    if not checkpoint_paths:
        # Fail here rather than hand back an empty ensemble, which would only
        # surface later as NaN predictions.
        raise FileNotFoundError(
            f"no 'model_*.pt' checkpoints found in {model_dir!r}"
        )

    # The scaler only covers the scaled columns; the extra ones are added on.
    input_dim = scaler.scale_.shape[0] + n_extra_features

    models = []
    for path in checkpoint_paths:
        m = ProbNet(input_dim).to(device)
        m.load_state_dict(torch.load(path, map_location=device))
        m.eval()  # inference only: switches the Dropout layer off
        models.append(m)
    return models, scaler

def predict_mean(models, X):
    """Average the first-output predictions of all ensemble members.

    Args:
        models: Non-empty sequence of callables (normally ``nn.Module``
            instances) that return a tuple when called on a float32 tensor;
            element 0 of that tuple is taken as the prediction.
        X: Array-like of input rows, convertible to a float32 tensor.

    Returns:
        A 1-D NumPy array holding, for each input row, the mean over all
        models of the flattened first output.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        # np.mean over an empty list would only give NaN plus a warning.
        raise ValueError("predict_mean() needs at least one model")

    X_t = torch.tensor(X, dtype=torch.float32)
    mus = []
    with torch.no_grad():
        for m in models:
            # X_t is created on the CPU, but the models may have been moved to
            # a GPU; run each forward pass on the device holding that model's
            # parameters to avoid a device-mismatch error.
            params = m.parameters() if hasattr(m, "parameters") else ()
            first_param = next(iter(params), None)
            inputs = X_t if first_param is None else X_t.to(first_param.device)
            mus.append(m(inputs)[0].cpu().numpy().flatten())
    return np.mean(mus, axis=0)

#Load and prepare data
RAW_DF = pd.read_csv(CSV_PATH)

# Season is one-hot encoded; indexing with `seasons` pins the column order
# (annual, warm, cold) instead of relying on the order get_dummies produces.
seasons    = ["annual", "warm", "cold"]
season_ohe = pd.get_dummies(RAW_DF["season"])[seasons].values

# Continuous inputs (scaled further down) and the regression target.
cont = RAW_DF[["lon", "lat", "elev", "age", "resolution"]].values
y    = RAW_DF["temperature"].values

# Fixed-seed 80/20 split. The row index is split alongside the arrays so the
# held-out rows can be looked up in RAW_DF again when building plot_df.
# NOTE(review): presumably these settings mirror the split used at training
# time -- confirm before changing test_size or random_state.
Xc_tr, Xc_te, y_tr, y_te, s_tr, s_te, idx_tr, idx_te = train_test_split(
    cont,
    y,
    season_ohe,
    RAW_DF.index,
    test_size=0.2,
    random_state=0,
    shuffle=True,
)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

models, sc = load_ensemble(MODELS_DIR, device)

# Model input: scaled continuous columns followed by the one-hot season columns.
X_test  = np.concatenate([sc.transform(Xc_te), s_te], axis=1)
mu_pred = predict_mean(models, X_test)

#Create DataFrame
# One row per held-out sample: the true and predicted temperature plus the
# metadata columns copied over from the matching RAW_DF rows.
plot_df = pd.DataFrame({
    "True":      y_te,
    "Predicted": mu_pred,
    **{
        column: RAW_DF.loc[idx_te, column].values
        for column in ("season", "proxy", "archive", "dataSetName")
    },
})