In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # seed the global RNG so the sampled noise is reproducible
device = "cpu"

# 1) Synthetic data with heteroscedastic noise
#    Inputs: n evenly spaced points on [-3, 3], stored as an (n, 1) column.
n = 200
x = torch.linspace(-3, 3, steps=n).reshape(-1, 1)
#    Noise scale grows with x: 0.1 plus half a sigmoid of 2x, i.e. within (0.1, 0.6).
true_sigma = 0.5 * torch.sigmoid(2 * x) + 0.1
#    Targets: sin(x) plus Gaussian noise whose std is the input-dependent scale above.
y = torch.sin(x) + true_sigma * torch.randn_like(x)
#    Move both tensors to the selected device.
x = x.to(device)
y = y.to(device)

# 2) Tiny model: 1 hidden neuron -> (mu, sigma)
class TinyProbNN(nn.Module):
    """Minimal probabilistic regressor mapping a scalar feature to (mu, sigma).

    Architecture: Linear(1 -> hidden) -> ReLU -> Linear(hidden -> 2). The two
    outputs are the predicted mean and an unconstrained scale; the scale is
    passed through softplus (plus a small constant) so sigma is positive.

    Args:
        hidden: Width of the hidden layer. Defaults to 1, which reproduces
            the original single-neuron model.

    Raises:
        ValueError: If ``hidden`` is smaller than 1.
    """

    def __init__(self, hidden: int = 1):
        super().__init__()
        if hidden < 1:
            raise ValueError(f"hidden must be >= 1, got {hidden}")
        self.h = nn.Linear(1, hidden)
        self.out = nn.Linear(hidden, 2)  # -> [mu, raw_sigma]

    def forward(self, x):
        """Return ``(mu, sigma)`` for inputs whose last dimension has size 1.

        Both outputs have the same shape as ``x``.
        """
        h = F.relu(self.h(x))
        # Split on the last (feature) axis rather than axis 1 so that inputs
        # with no batch axis or with extra leading axes are handled as well.
        mu, raw_sigma = self.out(h).chunk(2, dim=-1)
        sigma = F.softplus(raw_sigma) + 1e-6  # ensure positivity
        return mu, sigma

# Build the network, place it on the selected device, and attach an Adam
# optimizer (learning rate 0.05) to all of its parameters.
model = TinyProbNN()
model = model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=0.05)

# 3) CRPS for Gaussian outputs (closed form)
def std_norm_pdf(z):
    """Evaluate the standard normal density elementwise on tensor ``z``."""
    exponent = torch.square(z) * -0.5
    normalizer = math.sqrt(2 * math.pi)
    return torch.exp(exponent) / normalizer

def std_norm_cdf(z):
    """Evaluate the standard normal CDF elementwise on tensor ``z``.

    Uses the identity Phi(z) = (1 + erf(z / sqrt(2))) / 2.
    """
    scaled = z / math.sqrt(2.0)
    return (torch.erf(scaled) + 1.0) * 0.5

def crps_gaussian(y, mu, sigma):
    """Closed-form CRPS of a Gaussian forecast N(mu, sigma^2) at observation ``y``.

    Computes, elementwise,
        sigma * ( z * (2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi) ),  z = (y - mu) / sigma,
    where Phi and phi are the standard normal CDF and PDF.

    Args:
        y: Observed values (tensor).
        mu: Predicted means (tensor).
        sigma: Predicted standard deviations (tensor); the result divides by it.

    Returns:
        Tensor of per-element CRPS values (no reduction is applied).
    """
    standardized = (y - mu) / sigma
    cdf_term = standardized * (2 * std_norm_cdf(standardized) - 1)
    pdf_term = 2 * std_norm_pdf(standardized)
    return sigma * (cdf_term + pdf_term - 1 / math.sqrt(math.pi))

# 4) Train with CRPS
# Probe input at exactly x = 0, with the same dtype/device as the training
# inputs. The training grid has an even number of points on [-3, 3], so no
# sample sits exactly at 0; indexing its midpoint would not match the
# "μ(0)" / "σ(0)" labels in the log line below.
x_zero = torch.zeros_like(x[:1])

for epoch in range(800):
    # Full-batch step: predict (mu, sigma) and minimise the mean CRPS.
    mu, sigma = model(x)
    loss = crps_gaussian(y, mu, sigma).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    if (epoch + 1) % 200 == 0:
        # Logging only: evaluate without building an autograd graph.
        # Note the printed loss was computed before this step's update,
        # whereas the probe reflects the parameters after it.
        with torch.no_grad():
            mu0, s0 = model(x_zero)
        print(f"epoch {epoch+1:4d} | CRPS {loss.item():.4f} | μ(0)={mu0.item():+.3f}, σ(0)={s0.item():.3f}")

# 5) Quick probes
with torch.no_grad():
    # Create the probe inputs on the same device as the model; without this
    # the forward pass fails as soon as `device` is anything other than CPU.
    probe = torch.tensor([[-2.0], [0.0], [2.0]], device=device)
    mu_hat, sigma_hat = model(probe)
    print("\nProbes (CRPS-trained):")
    # One line per probe point: input, predicted mean, predicted std.
    for xi, mui, si in zip(probe.flatten(), mu_hat.flatten(), sigma_hat.flatten()):
        print(f"x={xi.item():>4.1f} -> μ={mui.item():+.3f}, σ={si.item():.3f}")
