In [18]:
import torch
from scipy.stats import norm

from models.dcf import DistributionalConditionalForecast
from models.pft import ProbabilisticForecastTransformer
from models.cqv import ConditionalQuantileVAE

from loader import test_loader

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:

# Hyperparameters shared by all three architectures. They must match the values
# used when the checkpoints were saved; a mismatch will typically make
# load_state_dict raise on missing/unexpected keys or shape mismatches.
COMMON_MODEL_KWARGS = dict(
    window_size=30,
    num_series=2,
    static_dim=18,
    latent_dim=32,
    hidden_dim=128,
    dropout=0.1,
    output_dim=2,
)


def _load_checkpoint(model, path, device):
    """Move ``model`` to ``device``, load its weights from ``path`` and set eval mode.

    The checkpoint file is read as a dict and the weights are taken from its
    ``"model_state_dict"`` entry.

    Args:
        model: A freshly constructed ``torch.nn.Module``.
        path: Path to the ``.pth`` checkpoint file.
        device: Device to place the model on and to map the stored tensors to.

    Returns:
        The same ``model`` object, now on ``device`` and in eval mode.
    """
    model.to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # disable dropout so evaluation is deterministic
    return model


dcf = _load_checkpoint(
    DistributionalConditionalForecast(**COMMON_MODEL_KWARGS),
    "checkpoints/DCF.pth",
    device,
)

pft = _load_checkpoint(
    ProbabilisticForecastTransformer(
        **COMMON_MODEL_KWARGS,
        d_model=64,
        nhead=4,
        num_layers=2,
    ),
    "checkpoints/PFT.pth",
    device,
)

cqv = _load_checkpoint(
    ConditionalQuantileVAE(**COMMON_MODEL_KWARGS, num_quantiles=3),
    "checkpoints/CQV.pth",
    device,
)

print("Models loaded successfully!")

Models loaded successfully!


In [None]:
# Kiểm tra xem bao nhiêu % dữ liệu nằm trong khoảng tin cậy 95% của phân phối trả về
def evaluate(model, confidence=0.95, loader=None, device=None, num_targets=2, verbose=True):
    """Measure the empirical coverage of a model's Gaussian prediction intervals.

    For each batch, the first element returned by ``model(x_seq, x_cal)`` is
    split into a mean block ``out[:, :num_targets]`` and a log-variance block
    ``out[:, num_targets:2 * num_targets]``. The symmetric interval
    ``mu +/- z * sigma`` is formed, where ``sigma = exp(0.5 * logvar)`` and ``z``
    is the two-sided standard-normal quantile for ``confidence``. The result is
    the fraction of target entries that fall inside their interval, bounds
    inclusive. A well-calibrated model gives a value close to ``confidence``.

    NOTE(review): ``ConditionalQuantileVAE`` is built with ``num_quantiles=3``,
    so its output may hold quantiles rather than mean/log-variance. In that
    case this Gaussian reading does not apply. TODO confirm each model's
    output layout.

    Args:
        model: Callable taking ``(x_seq, x_cal)`` and returning a 3-tuple whose
            first element holds the mean / log-variance columns described above.
        confidence: Nominal two-sided coverage level, strictly between 0 and 1.
        loader: Iterable of ``(x_seq, x_cal, y)`` batches. Defaults to the
            notebook's ``test_loader``.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, or the CPU if the model has no parameters.
        num_targets: Number of target columns in ``y`` (the width of the mean
            and log-variance blocks).
        verbose: When True, print ``"Coverage: ..."`` as before.

    Returns:
        float: Fraction of target entries lying inside their interval.

    Raises:
        ValueError: If ``confidence`` is not in (0, 1), if the model output is
            too narrow, if ``y`` does not match the mean block's shape, or if
            ``loader`` yields no target values.
    """
    if not 0.0 < confidence < 1.0:
        raise ValueError(f"confidence must be in (0, 1), got {confidence}")
    if loader is None:
        loader = test_loader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Two-sided quantile: e.g. confidence=0.95 -> z ~= 1.96.
    z = norm.ppf((1 + confidence) / 2)
    inside_count = 0
    total = 0
    with torch.no_grad():
        for x_seq, x_cal, y in loader:
            x_seq, x_cal, y = x_seq.to(device), x_cal.to(device), y.to(device)
            out, _, _ = model(x_seq, x_cal)
            if out.shape[1] < 2 * num_targets:
                raise ValueError(
                    f"model output has {out.shape[1]} columns; expected at least "
                    f"{2 * num_targets} (mean + log-variance)"
                )
            mu = out[:, :num_targets]
            logvar = out[:, num_targets:2 * num_targets]
            # Broadcasting y against mu would count more entries than y holds.
            if y.shape != mu.shape:
                raise ValueError(
                    f"target shape {tuple(y.shape)} does not match predicted mean "
                    f"shape {tuple(mu.shape)}"
                )
            sigma = torch.exp(0.5 * logvar)
            lower = mu - z * sigma
            upper = mu + z * sigma
            inside = (y >= lower) & (y <= upper)
            inside_count += inside.sum().item()
            total += y.numel()

    if total == 0:
        raise ValueError("loader yielded no target values; coverage is undefined")
    coverage = inside_count / total
    if verbose:
        print(f"Coverage: {coverage:.4f}")
    return coverage

In [24]:
# Single source of truth for the level: the header and the evaluate() call
# cannot drift apart. f"{0.95:.0%}" renders as "95%".
CONFIDENCE = 0.95
print(f"With {CONFIDENCE:.0%} confidence:")
print(f"DCF: {evaluate(dcf, confidence=CONFIDENCE):.2f}")

With 95% confidence:
Coverage: 0.87


TypeError: unsupported format string passed to NoneType.__format__

In [21]:
# Empirical coverage of the PFT model's intervals at the default confidence level.
evaluate(model=pft)

Coverage: 0.95


In [22]:
# Empirical coverage of the CQV model's intervals at the default confidence level.
# NOTE(review): CQV is built with num_quantiles=3; confirm its output is really
# mean/log-variance before trusting this number.
evaluate(model=cqv)

Coverage: 1.00
