# CFRNet
- The imbalance (`imb`) function used by the authors: [here](https://github.com/clinicalml/cfrnet/blob/master/cfr_net_train.py#L48)
- The authors' loss function: [here](https://github.com/clinicalml/cfrnet/blob/master/cfr/cfr_net.py#L201)
- The authors' CFR imbalance (`imb`) implementation: [here](https://github.com/clinicalml/cfrnet/blob/master/cfr/util.py#L119-L140)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from math import sqrt

# CFRNet Model
class CFRNet(nn.Module):
    """Shared-representation network with separate control and treated heads.

    The input is mapped to a representation ``phi`` by a two-layer ReLU MLP;
    two independent one-hidden-layer heads then predict the outcome under
    control (``y0``) and under treatment (``y1``).

    Args:
        input_dim: Number of input features.
        representation_dim: Width of both representation layers.
        hidden_dim: Width of the hidden layer in each outcome head.
    """

    def __init__(self, input_dim, representation_dim=100, hidden_dim=100):
        super().__init__()
        # Sub-modules are created in a fixed order (representation, control,
        # treated) so seeded initialisation and state_dict keys stay stable.
        self.representation = nn.Sequential(
            nn.Linear(input_dim, representation_dim),
            nn.ReLU(),
            nn.Linear(representation_dim, representation_dim),
            nn.ReLU()
        )
        self.head_control = nn.Sequential(
            nn.Linear(representation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.head_treated = nn.Sequential(
            nn.Linear(representation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, t):
        """Predict both potential outcomes and select the factual one.

        Args:
            x: Covariates whose last dimension is ``input_dim``.
            t: Treatment indicator (1 = treated, 0 = control). May have the
                same shape as the head outputs (``(N, 1)``), lack the trailing
                singleton axis (``(N,)``), or be a scalar. Float, integer and
                bool indicators are accepted.

        Returns:
            Tuple ``(y_pred, y0, y1, phi)`` where ``y_pred`` equals ``y1`` for
            treated rows and ``y0`` for control rows, and ``phi`` is the
            shared representation.
        """
        phi = self.representation(x)
        y0 = self.head_control(phi)
        y1 = self.head_treated(phi)

        t = torch.as_tensor(t, device=y1.device)
        # `1 - t` is undefined for bool tensors; move non-float indicators to
        # the head dtype. Float indicators are left as given.
        if not t.is_floating_point():
            t = t.to(y1.dtype)
        # A (N,) indicator would otherwise broadcast against the (N, 1) head
        # outputs and silently produce an (N, N) matrix.
        if t.dim() == y1.dim() - 1:
            t = t.unsqueeze(-1)

        y_pred = t * y1 + (1 - t) * y0
        return y_pred, y0, y1, phi

# MMD Loss (modelled on the author's mmd2_rbf, weighted by treatment proportion p)
# NOTE(review): the reference implementation appears to use exp(-d^2 / sigma^2)
# and to fold the 4x scale into the function itself; this version uses
# exp(-d^2 / (2 sigma^2)) and leaves scaling to the caller. Confirm against the
# reference before claiming exact equivalence.
def compute_mmd(phi, t, p, bandwidth=1.0):
    """Weighted squared MMD between treated and control representations.

    Uses an RBF kernel ``exp(-||a - b||^2 / (2 * bandwidth**2))`` and returns

        p^2 * mean_offdiag(K_tt) + (1 - p)^2 * mean_offdiag(K_cc)
            - 2 * p * (1 - p) * mean(K_tc)

    The within-group terms exclude the diagonal, so the estimate can come out
    slightly negative when the two groups are close.

    Args:
        phi: 2-D tensor of representations, one row per unit.
        t: Treatment indicator with one entry per row of ``phi`` (1 = treated,
            0 = control); any shape that flattens to length ``phi.size(0)``.
        p: Treatment proportion used to weight the three terms.
        bandwidth: RBF kernel bandwidth (sigma).

    Returns:
        Scalar tensor. If either group has fewer than two rows the
        within-group average is undefined, and a zero that is still attached
        to ``phi``'s autograd graph is returned instead of NaN/inf.
    """
    # Flatten rather than squeeze: squeeze() turns a single-row batch into a
    # 0-d mask, which indexes `phi` incorrectly.
    flat_t = t.reshape(-1)
    treated = phi[flat_t == 1]
    control = phi[flat_t == 0]
    m = treated.size(0)
    n = control.size(0)

    if m < 2 or n < 2:
        # m * (m - 1) (or n * (n - 1)) would be zero below.
        return phi.sum() * 0.0

    def rbf_kernel(x, y, sigma):
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        cross_term = torch.mm(x, y.t())
        # The expanded form can dip below zero through floating-point
        # cancellation; clamp so the kernel never exceeds 1.
        dist = (x_norm + y_norm - 2 * cross_term).clamp_min(0)
        return torch.exp(-dist / (2 * sigma**2))

    K_tt = rbf_kernel(treated, treated, bandwidth)
    K_cc = rbf_kernel(control, control, bandwidth)
    K_tc = rbf_kernel(treated, control, bandwidth)

    # Drop the diagonal (self-similarity) from the within-group averages.
    term_tt = (K_tt.sum() - torch.diagonal(K_tt).sum()) / (m * (m - 1))
    term_cc = (K_cc.sum() - torch.diagonal(K_cc).sum()) / (n * (n - 1))
    term_tc = K_tc.mean()

    mmd = p**2 * term_tt + (1 - p)**2 * term_cc - 2 * p * (1 - p) * term_tc
    return mmd

# Synthetic Data Simulation
def simulate_data(n=1000, p=25, seed=None):
    """Draw a synthetic randomised experiment with known potential outcomes.

    Covariates are i.i.d. standard normal and treatment is Bernoulli(0.5),
    independent of the covariates. The control outcome is the sum of the
    first five covariates; the treated outcome adds ``2 + 0.5 * sin(s)``
    where ``s`` is the sum of covariates 5..9.

    Args:
        n: Number of units.
        p: Number of covariates. The outcome model reads columns 0..9;
            slices past ``p`` are silently empty and contribute zero.
        seed: Optional seed for a private ``RandomState``, giving
            reproducible draws without touching global state. When ``None``
            (the default) the global NumPy RNG is used, as before.

    Returns:
        Tuple ``(x, t, y, y0, y1)`` of float32 arrays: ``x`` is ``(n, p)``,
        the others are ``(n, 1)``. ``y`` is the observed (factual) outcome.
    """
    # The module-level np.random functions and a RandomState instance expose
    # the same `normal` / `binomial` methods, so either can serve as `rng`.
    rng = np.random if seed is None else np.random.RandomState(seed)
    x = rng.normal(0, 1, size=(n, p))
    t = rng.binomial(1, 0.5, size=(n, 1))
    y0 = np.sum(x[:, :5], axis=1, keepdims=True)
    y1 = y0 + 2 + 0.5 * np.sin(np.sum(x[:, 5:10], axis=1, keepdims=True))
    y = t * y1 + (1 - t) * y0
    return x.astype(np.float32), t.astype(np.float32), y.astype(np.float32), y0.astype(np.float32), y1.astype(np.float32)




In [16]:
# --- Data: simulate once, then hold out 20% of the units for validation ---
x, t, y, y0_true, y1_true = simulate_data()
splits = train_test_split(x, t, y, y0_true, y1_true, test_size=0.2)
(
    x_train, x_val,
    t_train, t_val,
    y_train, y_val,
    y0_train, y0_val,
    y1_train, y1_val,
) = splits

# --- NumPy arrays -> torch tensors (torch.tensor copies the data) ---
x_train, t_train, y_train = map(torch.tensor, (x_train, t_train, y_train))
x_val, t_val, y_val = map(torch.tensor, (x_val, t_val, y_val))
# True potential outcomes are only needed on the validation side (for PEHE).
y0_val, y1_val = map(torch.tensor, (y0_val, y1_val))

# --- Model and optimiser ---
model = CFRNet(input_dim=x.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
alpha = 1.0  # weight of the imbalance (MMD) penalty in the total loss

# --- Full-batch training for 20 epochs ---
model.train()
# The treated fraction of the training set does not change between epochs.
p = t_train.mean().item()
for epoch in range(20):
    optimizer.zero_grad()
    y_pred, _, _, phi = model(x_train, t_train)

    # Factual-outcome error plus the scaled imbalance penalty.
    pred_loss = F.mse_loss(y_pred, y_train)
    mmd_loss = 4.0 * compute_mmd(phi, t_train, p, bandwidth=1.0)
    total_loss = pred_loss + alpha * mmd_loss

    total_loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}, MMD: {mmd_loss.item():.4f}")

# --- Evaluation: PEHE (RMSE of the individual treatment effect) on validation ---
model.eval()
with torch.no_grad():
    _, y0_pred, y1_pred, _ = model(x_val, t_val)
    ite_pred = (y1_pred - y0_pred).squeeze().numpy()
ite_true = (y1_val - y0_val).squeeze().numpy()
pehe = sqrt(mean_squared_error(ite_true, ite_pred))
print(f"Validation PEHE: {pehe:.4f}")

Epoch 1, Loss: 6.9879, MMD: 0.0001
Epoch 2, Loss: 6.8680, MMD: 0.0000
Epoch 3, Loss: 6.7535, MMD: -0.0000
Epoch 4, Loss: 6.6413, MMD: -0.0001
Epoch 5, Loss: 6.5278, MMD: -0.0001
Epoch 6, Loss: 6.4102, MMD: -0.0000
Epoch 7, Loss: 6.2861, MMD: -0.0000
Epoch 8, Loss: 6.1534, MMD: -0.0000
Epoch 9, Loss: 6.0099, MMD: 0.0000
Epoch 10, Loss: 5.8544, MMD: 0.0001
Epoch 11, Loss: 5.6865, MMD: 0.0001
Epoch 12, Loss: 5.5069, MMD: 0.0002
Epoch 13, Loss: 5.3173, MMD: 0.0002
Epoch 14, Loss: 5.1200, MMD: 0.0002
Epoch 15, Loss: 4.9191, MMD: 0.0003
Epoch 16, Loss: 4.7192, MMD: 0.0003
Epoch 17, Loss: 4.5251, MMD: 0.0004
Epoch 18, Loss: 4.3414, MMD: 0.0004
Epoch 19, Loss: 4.1712, MMD: 0.0005
Epoch 20, Loss: 4.0129, MMD: 0.0005
Validation PEHE: 0.6381


In [7]:
# Sanity check: train_test_split was given five arrays and returns a
# train/validation pair for each, so `splits` should hold ten items.
len(splits)

10

In [13]:
# TMP
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from math import sqrt

# CFRNet Model
class CFRNet(nn.Module):
    """Shared-representation network with separate control and treated heads.

    The input is mapped to a representation ``phi`` by a two-layer ReLU MLP;
    two independent one-hidden-layer heads then predict the outcome under
    control (``y0``) and under treatment (``y1``).

    Args:
        input_dim: Number of input features.
        representation_dim: Width of both representation layers.
        hidden_dim: Width of the hidden layer in each outcome head.
    """

    def __init__(self, input_dim, representation_dim=100, hidden_dim=100):
        super().__init__()
        # Sub-modules are created in a fixed order (representation, control,
        # treated) so seeded initialisation and state_dict keys stay stable.
        self.representation = nn.Sequential(
            nn.Linear(input_dim, representation_dim),
            nn.ReLU(),
            nn.Linear(representation_dim, representation_dim),
            nn.ReLU()
        )
        self.head_control = nn.Sequential(
            nn.Linear(representation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.head_treated = nn.Sequential(
            nn.Linear(representation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, t):
        """Predict both potential outcomes and select the factual one.

        Args:
            x: Covariates whose last dimension is ``input_dim``.
            t: Treatment indicator (1 = treated, 0 = control). May have the
                same shape as the head outputs (``(N, 1)``), lack the trailing
                singleton axis (``(N,)``), or be a scalar. Float, integer and
                bool indicators are accepted.

        Returns:
            Tuple ``(y_pred, y0, y1, phi)`` where ``y_pred`` equals ``y1`` for
            treated rows and ``y0`` for control rows, and ``phi`` is the
            shared representation.
        """
        phi = self.representation(x)
        y0 = self.head_control(phi)
        y1 = self.head_treated(phi)

        t = torch.as_tensor(t, device=y1.device)
        # `1 - t` is undefined for bool tensors; move non-float indicators to
        # the head dtype. Float indicators are left as given.
        if not t.is_floating_point():
            t = t.to(y1.dtype)
        # A (N,) indicator would otherwise broadcast against the (N, 1) head
        # outputs and silently produce an (N, N) matrix.
        if t.dim() == y1.dim() - 1:
            t = t.unsqueeze(-1)

        y_pred = t * y1 + (1 - t) * y0
        return y_pred, y0, y1, phi

# MMD Loss (modelled on the author's mmd2_rbf, weighted by treatment proportion p)
# NOTE(review): the reference implementation appears to use exp(-d^2 / sigma^2)
# and to fold the 4x scale into the function itself; this version uses
# exp(-d^2 / (2 sigma^2)) and leaves scaling to the caller. Confirm against the
# reference before claiming exact equivalence.
def compute_mmd(phi, t, p, bandwidth=1.0):
    """Weighted squared MMD between treated and control representations.

    Uses an RBF kernel ``exp(-||a - b||^2 / (2 * bandwidth**2))`` and returns

        p^2 * mean_offdiag(K_tt) + (1 - p)^2 * mean_offdiag(K_cc)
            - 2 * p * (1 - p) * mean(K_tc)

    The within-group terms exclude the diagonal, so the estimate can come out
    slightly negative when the two groups are close.

    Args:
        phi: 2-D tensor of representations, one row per unit.
        t: Treatment indicator with one entry per row of ``phi`` (1 = treated,
            0 = control); any shape that flattens to length ``phi.size(0)``.
        p: Treatment proportion used to weight the three terms.
        bandwidth: RBF kernel bandwidth (sigma).

    Returns:
        Scalar tensor. If either group has fewer than two rows the
        within-group average is undefined, and a zero that is still attached
        to ``phi``'s autograd graph is returned instead of NaN/inf.
    """
    # Flatten rather than squeeze: squeeze() turns a single-row batch into a
    # 0-d mask, which indexes `phi` incorrectly.
    flat_t = t.reshape(-1)
    treated = phi[flat_t == 1]
    control = phi[flat_t == 0]
    m = treated.size(0)
    n = control.size(0)

    if m < 2 or n < 2:
        # m * (m - 1) (or n * (n - 1)) would be zero below.
        return phi.sum() * 0.0

    def rbf_kernel(x, y, sigma):
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        cross_term = torch.mm(x, y.t())
        # The expanded form can dip below zero through floating-point
        # cancellation; clamp so the kernel never exceeds 1.
        dist = (x_norm + y_norm - 2 * cross_term).clamp_min(0)
        return torch.exp(-dist / (2 * sigma**2))

    K_tt = rbf_kernel(treated, treated, bandwidth)
    K_cc = rbf_kernel(control, control, bandwidth)
    K_tc = rbf_kernel(treated, control, bandwidth)

    # Drop the diagonal (self-similarity) from the within-group averages.
    term_tt = (K_tt.sum() - torch.diagonal(K_tt).sum()) / (m * (m - 1))
    term_cc = (K_cc.sum() - torch.diagonal(K_cc).sum()) / (n * (n - 1))
    term_tc = K_tc.mean()

    mmd = p**2 * term_tt + (1 - p)**2 * term_cc - 2 * p * (1 - p) * term_tc
    return mmd

# Synthetic Data Simulation
def simulate_data(n=1000, p=25, seed=None):
    """Draw a synthetic randomised experiment with known potential outcomes.

    Covariates are i.i.d. standard normal and treatment is Bernoulli(0.5),
    independent of the covariates. The control outcome is the sum of the
    first five covariates; the treated outcome adds ``2 + 0.5 * sin(s)``
    where ``s`` is the sum of covariates 5..9.

    Args:
        n: Number of units.
        p: Number of covariates. The outcome model reads columns 0..9;
            slices past ``p`` are silently empty and contribute zero.
        seed: Optional seed for a private ``RandomState``, giving
            reproducible draws without touching global state. When ``None``
            (the default) the global NumPy RNG is used, as before.

    Returns:
        Tuple ``(x, t, y, y0, y1)`` of float32 arrays: ``x`` is ``(n, p)``,
        the others are ``(n, 1)``. ``y`` is the observed (factual) outcome.
    """
    # The module-level np.random functions and a RandomState instance expose
    # the same `normal` / `binomial` methods, so either can serve as `rng`.
    rng = np.random if seed is None else np.random.RandomState(seed)
    x = rng.normal(0, 1, size=(n, p))
    t = rng.binomial(1, 0.5, size=(n, 1))
    y0 = np.sum(x[:, :5], axis=1, keepdims=True)
    y1 = y0 + 2 + 0.5 * np.sin(np.sum(x[:, 5:10], axis=1, keepdims=True))
    y = t * y1 + (1 - t) * y0
    return x.astype(np.float32), t.astype(np.float32), y.astype(np.float32), y0.astype(np.float32), y1.astype(np.float32)

# --- Data: simulate once, then hold out 20% of the units for validation ---
x, t, y, y0_true, y1_true = simulate_data()
splits = train_test_split(x, t, y, y0_true, y1_true, test_size=0.2)
(
    x_train, x_val,
    t_train, t_val,
    y_train, y_val,
    y0_train, y0_val,
    y1_train, y1_val,
) = splits

# --- NumPy arrays -> torch tensors (torch.tensor copies the data) ---
x_train, t_train, y_train = map(torch.tensor, (x_train, t_train, y_train))
x_val, t_val, y_val = map(torch.tensor, (x_val, t_val, y_val))
# True potential outcomes are only needed on the validation side (for PEHE).
y0_val, y1_val = map(torch.tensor, (y0_val, y1_val))

# --- Model and optimiser ---
model = CFRNet(input_dim=x.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
alpha = 1.0  # weight of the imbalance (MMD) penalty in the total loss

# --- Full-batch training for 20 epochs ---
model.train()
# The treated fraction of the training set does not change between epochs.
p = t_train.mean().item()
for epoch in range(20):
    optimizer.zero_grad()
    y_pred, _, _, phi = model(x_train, t_train)

    # Factual-outcome error plus the scaled imbalance penalty.
    pred_loss = F.mse_loss(y_pred, y_train)
    mmd_loss = 4.0 * compute_mmd(phi, t_train, p, bandwidth=1.0)
    total_loss = pred_loss + alpha * mmd_loss

    total_loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}, MMD: {mmd_loss.item():.4f}")

# --- Evaluation: PEHE (RMSE of the individual treatment effect) on validation ---
model.eval()
with torch.no_grad():
    _, y0_pred, y1_pred, _ = model(x_val, t_val)
    ite_pred = (y1_pred - y0_pred).squeeze().numpy()
ite_true = (y1_val - y0_val).squeeze().numpy()
pehe = sqrt(mean_squared_error(ite_true, ite_pred))
print(f"Validation PEHE: {pehe:.4f}")


Epoch 1, Loss: 7.3874, MMD: 0.0042
Epoch 2, Loss: 7.2541, MMD: 0.0043
Epoch 3, Loss: 7.1266, MMD: 0.0043
Epoch 4, Loss: 7.0000, MMD: 0.0042
Epoch 5, Loss: 6.8711, MMD: 0.0040
Epoch 6, Loss: 6.7366, MMD: 0.0037
Epoch 7, Loss: 6.5933, MMD: 0.0034
Epoch 8, Loss: 6.4390, MMD: 0.0031
Epoch 9, Loss: 6.2721, MMD: 0.0028
Epoch 10, Loss: 6.0911, MMD: 0.0025
Epoch 11, Loss: 5.8963, MMD: 0.0022
Epoch 12, Loss: 5.6880, MMD: 0.0019
Epoch 13, Loss: 5.4673, MMD: 0.0017
Epoch 14, Loss: 5.2366, MMD: 0.0014
Epoch 15, Loss: 5.0002, MMD: 0.0013
Epoch 16, Loss: 4.7636, MMD: 0.0011
Epoch 17, Loss: 4.5330, MMD: 0.0010
Epoch 18, Loss: 4.3166, MMD: 0.0009
Epoch 19, Loss: 4.1196, MMD: 0.0008
Epoch 20, Loss: 3.9444, MMD: 0.0007
Validation PEHE: 0.6714
