In [None]:
# Probabilistic NN with one hidden layer of 5 ReLU units (Gaussian output)
# ------------------------------------------------------------------------
# - predicts mean μ(x) and std σ(x)
# - trained with Gaussian NLL
# - synthetic data has heteroscedastic noise (σ varies with x)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import norm
import matplotlib.pyplot as plt

torch.manual_seed(0)  # fixes both the synthetic data and the weight initialisation below
device = "cpu"

# 1) Synthetic regression data: y = sin(x) + heteroscedastic Gaussian noise
n = 50
# n inputs drawn uniformly between -3 and 3, sorted ascending so line plots join points in order
x = torch.rand(n, 1).sort(dim=0).values * 6 - 3  # shape (n, 1)
# noise std grows smoothly with x: about 0.1 at x=-3, about 0.6 at x=3
true_sigma = 0.1 + 0.5 * torch.sigmoid(2 * x)
y = torch.sin(x) + true_sigma * torch.randn_like(x)

x = x.to(device)
y = y.to(device)

# 2) Tiny model: 1 input -> 5 hidden ReLU neurons -> 2 outputs (mu, raw_sigma)
class TinyProbNN(nn.Module):
    """Small MLP that outputs a Gaussian predictive distribution for each input.

    Architecture: Linear(1 -> hidden) -> ReLU -> Linear(hidden -> 2). The two
    output columns are split into the mean ``mu`` and an unconstrained
    ``raw_sigma``, which is mapped to a strictly positive standard deviation.

    Args:
        hidden: number of hidden ReLU units (default 5, the original width).
        min_sigma: constant added after softplus so sigma never reaches 0
            (default 1e-6, the original floor).
    """

    def __init__(self, hidden=5, min_sigma=1e-6):
        super().__init__()
        # Create `h` before `out` (as originally) so the seeded weight init is reproducible.
        self.h = nn.Linear(1, hidden)      # 1 input -> `hidden` neurons
        self.out = nn.Linear(hidden, 2)    # outputs: [mu, raw_sigma]
        self.min_sigma = min_sigma

    def forward(self, x):
        """Return ``(mu, sigma)`` for inputs ``x``.

        ``x`` is expected to be 2-D with one feature column, i.e. shape
        ``(batch, 1)``; ``mu`` and ``sigma`` then each have shape ``(batch, 1)``.
        """
        h = F.relu(self.h(x))
        mu, raw_sigma = self.out(h).chunk(2, dim=1)
        # softplus maps R -> (0, inf) smoothly; the floor keeps sigma away from 0,
        # where the NLL's log(sigma^2) and division by sigma^2 would blow up.
        sigma = F.softplus(raw_sigma) + self.min_sigma
        return mu, sigma

# Build the network, move it to the chosen device, and attach Adam to all its parameters.
model = TinyProbNN()
model = model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=0.05)

# 3) Training loop with Gaussian NLL
def gaussian_nll(y, mu, sigma):
    """Element-wise negative log-likelihood of ``y`` under N(mu, sigma**2).

    NLL = 0.5 * log(2*pi*sigma^2) + (y - mu)^2 / (2*sigma^2)

    Args:
        y: observed targets.
        mu: predicted means, broadcastable against ``y``.
        sigma: predicted standard deviations (should be > 0), broadcastable against ``y``.

    Returns:
        Tensor of per-element NLL values; no reduction is applied.
    """
    variance = sigma**2
    log_norm_term = 0.5*torch.log(2*math.pi*variance)
    squared_error_term = 0.5*((y - mu)**2)/variance
    return log_norm_term + squared_error_term

# Full-batch training: each epoch is one Adam step on the mean NLL over all n points.
n_epochs = 1000
log_every = 200
x0 = torch.zeros(1, 1, device=device)  # probe input x=0, matching the μ(0)/σ(0) log label
for epoch in range(n_epochs):
    mu, sigma = model(x)
    loss = gaussian_nll(y, mu, sigma).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    if (epoch + 1) % log_every == 0:
        # One forward pass at x=0 with no autograd graph. Note `loss` is from the
        # pre-step weights, while μ(0)/σ(0) reflect the weights after this step.
        with torch.no_grad():
            mu0, sigma0 = model(x0)
        print(f"epoch {epoch+1:4d} | loss {loss.item():.4f} | "
              f"μ(0)={mu0.item():.3f}, σ(0)={sigma0.item():.3f}")
# NOTE: after the loop, `mu`/`sigma` hold the forward pass computed *before* the
# final opt.step(), and they still carry autograd history.

# 4) Quick sanity check at a few x locations
with torch.no_grad():
    # Evaluate the trained model at three fixed inputs. Build them on `device`
    # (like x and y) so this keeps working when device is not "cpu".
    probe = torch.tensor([[-2.0], [0.0], [2.0]], device=device)
    mu_hat, sigma_hat = model(probe)
    print("\nProbes:")
    for xi, mui, si in zip(probe.flatten(), mu_hat.flatten(), sigma_hat.flatten()):
        print(f"x={xi.item():>4.1f} -> μ={mui.item():+.3f}, σ={si.item():.3f}")


# Plot predictions from the *final* weights. The `mu`/`sigma` left over from the
# training loop were computed before the last optimizer step and still track
# gradients, so recompute them here without autograd.
with torch.no_grad():
    mu_fit, sigma_fit = model(x)

x_np = x.cpu().numpy().flatten()
mu_np = mu_fit.cpu().numpy().flatten()
sigma_np = sigma_fit.cpu().numpy().flatten()

# Central predictive interval of the fitted Gaussian: μ ± z·σ with
# z = Φ⁻¹(0.5 + coverage/2). coverage = 0.50 gives z ≈ 0.674 (not 2σ).
coverage = 0.50
level = norm.ppf(0.5 + coverage / 2)

upper = mu_np + level * sigma_np
lower = mu_np - level * sigma_np

plt.figure(figsize=(8, 5))
# Scatter plot of raw data
plt.scatter(x_np, y.cpu().numpy().flatten(), s=10, color='gray', alpha=0.6, label='Raw data')

# Predicted mean (x was sorted at generation time, so the line runs left to right)
plt.plot(x_np, mu_np, color='blue', label='Predicted μ(x)')

# One vertical bar per data point spanning the interval; label derived from `level`/`coverage`
plt.vlines(x_np, lower, upper, color='red', alpha=0.7, linewidth=1,
           label=f'μ(x) ± {level:.2f}σ(x) ({coverage:.0%} interval)')

plt.xlabel('x')
plt.ylabel('y')
plt.title('Probabilistic NN Output')
plt.legend()
plt.tight_layout()
plt.show()



RuntimeError: mat1 and mat2 shapes cannot be multiplied (50x5 and 2x2)