-
Notifications
You must be signed in to change notification settings - Fork 20
/
torch_metrics.py
28 lines (24 loc) · 922 Bytes
/
torch_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# third party
import torch
def sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:
    """
    Square root of the Precision in Estimation of Heterogeneous Effect (PyTorch version).

    PEHE reflects the ability to capture individual variation in treatment effects.
    The true per-unit effect is taken as ``po[:, 1] - po[:, 0]`` and the result is
    ``sqrt(mean((true_te - hat_te) ** 2))``.

    Args:
        po: 2-D array-like of potential outcomes, one row per unit. The true
            effect is column 1 minus column 0 (presumably treated minus
            control — confirm against callers). Tensors of any dtype/device,
            NumPy arrays and nested lists are accepted.
        hat_te: estimated per-unit treatment effects, shape ``(N,)`` or
            ``(N, 1)`` (it is flattened before comparison), or a scalar.

    Returns:
        0-dim floating tensor holding the square root of PEHE, on ``po``'s device.
    """
    # torch.as_tensor keeps an existing tensor's dtype/device. The legacy
    # torch.Tensor(...) constructor only accepts float32 CPU tensors.
    po = torch.as_tensor(po)
    if not po.is_floating_point():
        # torch.mean is undefined for integer dtypes.
        po = po.to(torch.get_default_dtype())
    hat_te = torch.as_tensor(hat_te, device=po.device)
    if not hat_te.is_floating_point():
        hat_te = hat_te.to(po.dtype)
    # Flatten so an (N, 1) column cannot broadcast against the (N,) true
    # effect into an (N, N) matrix, which would silently give a wrong value.
    hat_te = hat_te.reshape(-1)
    true_te = po[:, 1] - po[:, 0]
    return torch.sqrt(torch.mean((true_te - hat_te) ** 2))
def abs_error_ATE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:
    """
    Absolute error of the Average Treatment Effect estimate (PyTorch version).

    ATE measures the expected causal effect of the treatment across all
    individuals in the population. The true ATE is ``mean(po[:, 1] - po[:, 0])``
    and the estimate is ``mean(hat_te)``; this returns ``|true - estimate|``.

    Args:
        po: 2-D array-like of potential outcomes, one row per unit. The true
            effect is column 1 minus column 0 (presumably treated minus
            control — confirm against callers). Tensors of any dtype/device,
            NumPy arrays and nested lists are accepted.
        hat_te: estimated per-unit treatment effects (any shape; only its mean
            is used).

    Returns:
        0-dim floating tensor with the absolute ATE error, on ``po``'s device.
    """
    # torch.as_tensor keeps an existing tensor's dtype/device. The legacy
    # torch.Tensor(...) constructor only accepts float32 CPU tensors.
    po = torch.as_tensor(po)
    if not po.is_floating_point():
        # torch.mean is undefined for integer dtypes.
        po = po.to(torch.get_default_dtype())
    hat_te = torch.as_tensor(hat_te, device=po.device)
    if not hat_te.is_floating_point():
        hat_te = hat_te.to(po.dtype)
    true_ate = torch.mean(po[:, 1] - po[:, 0])
    return torch.abs(true_ate - torch.mean(hat_te))