In [1]:
import torch
from scipy.stats import norm

from models.dcf import DistributionalConditionalForecast
from models.pft import ProbabilisticForecastTransformer
from models.cqv import ConditionalQuantileVAE

from loader import test_loader

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Hyperparameters passed identically to all three model constructors.
shared_hparams = dict(
    window_size=30,
    num_series=2,
    static_dim=18,
    latent_dim=32,
    hidden_dim=128,
    dropout=0.1,
    output_dim=2,
)

# Build the models in the same order as before (DCF, PFT, CQV); only the
# model-specific keyword arguments are spelled out at each call site.
dcf = DistributionalConditionalForecast(**shared_hparams)
pft = ProbabilisticForecastTransformer(
    d_model=64,
    nhead=4,
    num_layers=2,
    **shared_hparams,
)
cqv = ConditionalQuantileVAE(num_quantiles=3, **shared_hparams)

# For each model: move it to `device`, load its checkpoint file, restore the
# weights stored under the "model_state_dict" key, then call eval().
# NOTE(review): torch.load unpickles the file; if these checkpoints hold only
# tensors/containers, consider passing weights_only=True — confirm first.
for net, ckpt_path in (
    (dcf, "checkpoints/DCF.pth"),
    (pft, "checkpoints/PFT.pth"),
    (cqv, "checkpoints/CQV.pth"),
):
    net.to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()

print("Models loaded successfully!")

Models loaded successfully!


In [4]:
def evaluate(model, confidence=0.95, loader=None, output_dim=2):
    """Print and return the empirical coverage of a Gaussian prediction interval.

    For every batch, the first element of the model's returned 3-tuple is read
    as ``[mu | logvar]``: columns ``[0, output_dim)`` are the predicted means
    and columns ``[output_dim, 2 * output_dim)`` the predicted log-variances.
    A symmetric interval ``mu +/- z * sigma`` is built, with
    ``sigma = exp(0.5 * logvar)`` and ``z`` the standard-normal quantile for
    the requested two-sided level. Coverage is the fraction of target
    elements that fall inside their interval (bounds inclusive).

    NOTE(review): this reading of the output is only meaningful for models
    whose head really emits means followed by log-variances. A model with a
    different head (e.g. one producing quantiles) would be mis-scored —
    confirm per model before trusting the number.

    Args:
        model: Callable invoked as ``model(x_seq, x_cal)`` and returning a
            3-tuple whose first item is a 2-D tensor with at least
            ``2 * output_dim`` columns.
        confidence: Nominal two-sided level of the interval, e.g. ``0.95``.
        loader: Iterable yielding ``(x_seq, x_cal, y)`` tensor batches. When
            ``None`` (default) the module-level ``test_loader`` is used, which
            is the previous behaviour.
        output_dim: Number of target columns. The default of 2 reproduces the
            previously hard-coded slices ``[:, :2]`` and ``[:, 2:4]``.

    Returns:
        The coverage as a float in ``[0, 1]``, or ``nan`` when the loader
        yields no target elements (previously a ZeroDivisionError).
    """
    batches = test_loader if loader is None else loader
    # norm.ppf returns a NumPy scalar; cast once so the tensor arithmetic
    # below only ever mixes tensors with a plain Python float.
    z = float(norm.ppf((1 + confidence) / 2))

    hits = 0
    total = 0
    with torch.no_grad():
        for x_seq, x_cal, y in batches:
            x_seq, x_cal, y = x_seq.to(device), x_cal.to(device), y.to(device)
            out, _, _ = model(x_seq, x_cal)
            mu = out[:, :output_dim]
            logvar = out[:, output_dim:2 * output_dim]
            sigma = torch.exp(0.5 * logvar)
            lower = mu - z * sigma
            upper = mu + z * sigma
            inside = (y >= lower) & (y <= upper)
            # Summing the boolean mask gives an exact integer count.
            hits += int(inside.sum().item())
            total += y.numel()

    # An empty loader has no defined coverage; report NaN instead of crashing.
    coverage = hits / total if total else float("nan")
    print(f"Coverage of {confidence * 100}% confidence interval: {coverage * 100:.2f}%")
    return coverage

In [5]:
# Report DCF interval coverage at the 95%, 90% and 80% nominal levels.
for level in (0.95, 0.90, 0.80):
    evaluate(dcf, confidence=level)

Coverage of 95.0% confidence interval: 86.73%
Coverage of 90.0% confidence interval: 81.49%
Coverage of 80.0% confidence interval: 72.77%


In [6]:
# Report PFT interval coverage at the 95%, 90% and 80% nominal levels.
for level in (0.95, 0.90, 0.80):
    evaluate(pft, confidence=level)

Coverage of 95.0% confidence interval: 95.45%
Coverage of 90.0% confidence interval: 91.19%
Coverage of 80.0% confidence interval: 83.07%


In [7]:
# Report CQV interval coverage at the 95%, 90% and 80% nominal levels.
# NOTE(review): the recorded run shows 100% coverage at every level. CQV is
# built with num_quantiles=3, so evaluate()'s [mu | logvar] reading of the
# output may not match this model's head — confirm before citing these numbers.
for level in (0.95, 0.90, 0.80):
    evaluate(cqv, confidence=level)

Coverage of 95.0% confidence interval: 100.00%
Coverage of 90.0% confidence interval: 100.00%
Coverage of 80.0% confidence interval: 100.00%


In [8]:
# Print, for each loaded model, the total element count over everything
# returned by its parameters() iterator.
for label, net in (("DCF", dcf), ("PFT", pft), ("CQV", cqv)):
    n_params = sum(tensor.numel() for tensor in net.parameters())
    print(f"{label} has {n_params} parameters")

DCF has 99140 parameters
PFT has 648196 parameters
CQV has 99398 parameters
