# CFRNet
- The imbalance function used by the authors: [here](https://github.com/clinicalml/cfrnet/blob/master/cfr_net_train.py#L48)
- The authors' loss function: [here](https://github.com/clinicalml/cfrnet/blob/master/cfr/cfr_net.py#L201)
- The authors' implementation of the CFR imbalance penalty: [here](https://github.com/clinicalml/cfrnet/blob/master/cfr/util.py#L119-L140)

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from math import sqrt

# CFRNet model: one shared representation feeding two outcome heads.
class CFRNet(nn.Module):
    """Shared-representation network with separate control/treated outcome heads.

    Architecture (fully connected layers, ReLU activations):

        x   -> Linear -> ReLU -> Linear -> ReLU  = phi  (representation)
        phi -> Linear -> ReLU -> Linear          = y0   (control head)
        phi -> Linear -> ReLU -> Linear          = y1   (treated head)

    Args:
        input_dim: number of input covariates.
        representation_dim: width of both representation layers.
        hidden_dim: width of the hidden layer inside each outcome head.
    """

    @staticmethod
    def _outcome_head(in_features, hidden_features):
        # One-hidden-layer MLP producing a single scalar outcome per row.
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, 1),
        )

    def __init__(self, input_dim, representation_dim=100, hidden_dim=100):
        super().__init__()
        # Creation order (representation, control head, treated head) fixes the
        # order in which seeded weight initialisation consumes the RNG, and the
        # resulting state_dict keys.
        self.representation = nn.Sequential(
            nn.Linear(input_dim, representation_dim),
            nn.ReLU(),
            nn.Linear(representation_dim, representation_dim),
            nn.ReLU(),
        )
        self.head_control = self._outcome_head(representation_dim, hidden_dim)
        self.head_treated = self._outcome_head(representation_dim, hidden_dim)

    def forward(self, x, t):
        """Return ``(y_pred, y0, y1, phi)``.

        ``y_pred`` mixes the two heads as ``t * y1 + (1 - t) * y0``. For a 0/1
        treatment indicator this selects, per row, the head matching that row's
        treatment (assumes t is 0/1 and broadcastable to (batch, 1) — confirm).
        """
        phi = self.representation(x)
        mu0, mu1 = self.head_control(phi), self.head_treated(phi)
        factual = (1 - t) * mu0 + t * mu1
        return factual, mu0, mu1, phi

# MMD loss: squared RBF-kernel MMD between treated and control representations,
# weighted by the treated proportion p. Follows the structure of the authors'
# mmd2_rbf (unbiased within-group terms, p**2 / (1 - p)**2 weights, -2p(1-p)
# cross term) but is NOT numerically identical: this kernel is
# exp(-d / (2 * sigma**2)), and no factor of 4 is applied here (the caller
# scales the result).
# NOTE(review): the authors' mmd2_rbf appears to use exp(-d / sigma**2) and to
# apply the 4.0 factor internally — confirm against cfr/util.py.
def compute_mmd(phi, t, p, bandwidth=1.0):
    """Return the p-weighted squared MMD between treated and control rows of phi.

    Args:
        phi: 2-D tensor of representations, one row per unit.
        t: treatment indicator with one entry per row of phi (any shape that
            flattens to len(phi)). Rows with t == 1 are treated and rows with
            t == 0 are control; any other value belongs to neither group.
        p: treated proportion used to weight the terms.
        bandwidth: RBF kernel width sigma.

    Returns:
        Scalar tensor. It is 0 when either group has fewer than two rows,
        because the unbiased within-group estimates are undefined there.
    """
    # reshape(-1) rather than squeeze(): squeeze() turns a single-row (1, 1)
    # indicator into a 0-d tensor, which breaks the boolean row indexing.
    t_flat = t.reshape(-1)
    treated = phi[t_flat == 1]
    control = phi[t_flat == 0]
    m = treated.size(0)
    n = control.size(0)

    # The within-group terms divide by m*(m-1) and n*(n-1). With fewer than
    # two units in a group this is 0/0 -> NaN, and a NaN loss corrupts every
    # weight on backward(). Contribute no imbalance penalty for such a batch.
    if m < 2 or n < 2:
        return phi.new_zeros(())

    def rbf_kernel(x, y, sigma):
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        cross_term = torch.mm(x, y.t())
        # ||x||^2 + ||y||^2 - 2<x, y> can dip slightly below 0 through round-off.
        dist = (x_norm + y_norm - 2 * cross_term).clamp_min(0.0)
        return torch.exp(-dist / (2 * sigma ** 2))

    K_tt = rbf_kernel(treated, treated, bandwidth)
    K_cc = rbf_kernel(control, control, bandwidth)
    K_tc = rbf_kernel(treated, control, bandwidth)

    # Unbiased within-group means: drop the self-similarity diagonal.
    term_tt = (K_tt.sum() - torch.diagonal(K_tt).sum()) / (m * (m - 1))
    term_cc = (K_cc.sum() - torch.diagonal(K_cc).sum()) / (n * (n - 1))
    term_tc = K_tc.mean()

    mmd = p ** 2 * term_tt + (1 - p) ** 2 * term_cc - 2 * p * (1 - p) * term_tc
    return mmd

# Synthetic Data Simulation
def simulate_data(n=1000, p=25, rng=None):
    """Simulate a randomized experiment with known potential outcomes.

    Covariates ``x ~ N(0, 1)`` with shape (n, p). Treatment ``t ~ Bernoulli(0.5)``
    with shape (n, 1), drawn independently of x. Potential outcomes:

        y0 = x[:, 0:5].sum(1)
        y1 = y0 + 2 + 0.5 * sin(x[:, 5:10].sum(1))

    The individual treatment effect is therefore ``2 + 0.5 * sin(x[:, 5:10].sum(1))``.
    The observed (factual) outcome is ``y = t * y1 + (1 - t) * y0``.

    Args:
        n: number of units.
        p: number of covariates. Must be at least 10 because the outcome model
            reads columns 0-9.
        rng: optional ``numpy.random.Generator`` (or ``RandomState``) used for
            reproducible draws. When None, the global ``numpy.random`` state is
            used, exactly as before this parameter existed.

    Returns:
        Tuple ``(x, t, y, y0, y1)`` of float32 arrays with shapes
        (n, p), (n, 1), (n, 1), (n, 1), (n, 1).

    Raises:
        ValueError: if ``p < 10``. Fewer columns would silently shrink the
            slices below and change the data-generating process.
    """
    if p < 10:
        raise ValueError(f"simulate_data needs p >= 10 covariates, got p={p}")
    if rng is None:
        rng = np.random
    x = rng.normal(0, 1, size=(n, p))
    t = rng.binomial(1, 0.5, size=(n, 1))
    y0 = np.sum(x[:, :5], axis=1, keepdims=True)
    y1 = y0 + 2 + 0.5 * np.sin(np.sum(x[:, 5:10], axis=1, keepdims=True))
    y = t * y1 + (1 - t) * y0
    return tuple(a.astype(np.float32) for a in (x, t, y, y0, y1))

In [26]:
# Prepare data
# Seed every RNG involved so reruns are reproducible: NumPy drives
# simulate_data, train_test_split is given an explicit random_state, and torch
# drives the network's weight initialisation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

x, t, y, y0_true, y1_true = simulate_data()
splits = train_test_split(x, t, y, y0_true, y1_true, test_size=0.2, random_state=SEED)
x_train, x_val, t_train, t_val, y_train, y_val, y0_train, y0_val, y1_train, y1_val = splits

# Torch tensors (only the arrays used below are converted)
x_train = torch.tensor(x_train)
t_train = torch.tensor(t_train)
y_train = torch.tensor(y_train)
x_val = torch.tensor(x_val)
t_val = torch.tensor(t_val)
y_val = torch.tensor(y_val)
y0_val = torch.tensor(y0_val)
y1_val = torch.tensor(y1_val)

# Model setup
model = CFRNet(input_dim=x.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
alpha = 1.0       # weight of the imbalance (MMD) penalty in the total loss
MMD_SCALE = 4.0   # fixed multiplier on compute_mmd's output
N_EPOCHS = 20

# Treated proportion of the training set. Training is full-batch, so this is
# constant across epochs; compute it once instead of every iteration.
p = t_train.mean().item()

# Training loop (full batch: every epoch is a single optimizer step)
for epoch in range(N_EPOCHS):
    model.train()
    optimizer.zero_grad()
    y_pred, _, _, phi = model(x_train, t_train)

    pred_loss = F.mse_loss(y_pred, y_train)
    mmd_loss = MMD_SCALE * compute_mmd(phi, t_train, p, bandwidth=1.0)

    total_loss = pred_loss + alpha * mmd_loss
    total_loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}, MMD: {mmd_loss.item():.4f}")

# Evaluation: PEHE = root mean squared error between the predicted individual
# treatment effect (y1_pred - y0_pred) and the true simulated one (y1 - y0) on
# the validation split. The true effect is only available because the data is
# synthetic.
model.eval()
with torch.no_grad():
    _, y0_pred, y1_pred, _ = model(x_val, t_val)
    # reshape(-1) rather than squeeze(): it stays 1-D even for a single row.
    ite_pred = (y1_pred - y0_pred).reshape(-1).numpy()
    ite_true = (y1_val - y0_val).reshape(-1).numpy()
    pehe = sqrt(mean_squared_error(ite_true, ite_pred))
    print(f"Validation PEHE: {pehe:.4f}")

Epoch 1, Loss: 7.2561, MMD: 0.0065
Epoch 2, Loss: 7.1113, MMD: 0.0065
Epoch 3, Loss: 6.9736, MMD: 0.0063
Epoch 4, Loss: 6.8384, MMD: 0.0060
Epoch 5, Loss: 6.7011, MMD: 0.0056
Epoch 6, Loss: 6.5594, MMD: 0.0052
Epoch 7, Loss: 6.4089, MMD: 0.0047
Epoch 8, Loss: 6.2464, MMD: 0.0041
Epoch 9, Loss: 6.0697, MMD: 0.0036
Epoch 10, Loss: 5.8769, MMD: 0.0031
Epoch 11, Loss: 5.6673, MMD: 0.0027
Epoch 12, Loss: 5.4420, MMD: 0.0023
Epoch 13, Loss: 5.2049, MMD: 0.0019
Epoch 14, Loss: 4.9601, MMD: 0.0016
Epoch 15, Loss: 4.7149, MMD: 0.0014
Epoch 16, Loss: 4.4781, MMD: 0.0012
Epoch 17, Loss: 4.2600, MMD: 0.0010
Epoch 18, Loss: 4.0690, MMD: 0.0009
Epoch 19, Loss: 3.9110, MMD: 0.0008
Epoch 20, Loss: 3.7843, MMD: 0.0007
Validation PEHE: 0.5527
