In [87]:
import torch
import torch.nn as nn

class CovariateEmbedding(nn.Module):
    def __init__(self, num_covariates, embedding_dim):
        super(CovariateEmbedding, self).__init__()
        self.embedding_layers = nn.ModuleList([
            nn.Linear(1, embedding_dim) for _ in range(num_covariates)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_covariates)
        Returns: Tensor of shape (batch_size, num_covariates, embedding_dim)
        """
        embedded = [layer(x[:, i:i+1]) for i, layer in enumerate(self.embedding_layers)]
        return torch.stack(embedded, dim=1)  # Shape: (batch, num_covariates, embedding_dim)


class TreatmentEmbedding(nn.Module):
    def __init__(self, num_treatments, embedding_dim):
        super(TreatmentEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_treatments, embedding_dim)

    def forward(self, t):
        """
        t: Tensor of shape (batch_size,)
        Returns: Tensor of shape (batch_size, embedding_dim)
        """
        return self.embedding(t)


class TransformerCovariateEncoder(nn.Module):
    def __init__(self, num_covariates, embedding_dim, num_heads=4, num_layers=2):
        super(TransformerCovariateEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_covariates, embedding_dim)
        Returns: Tensor of shape (batch_size, num_covariates, embedding_dim)
        """
        x = x.permute(1, 0, 2)  # Transformers expect (seq_len, batch, dim)
        x = self.transformer(x)
        return x.permute(1, 0, 2)  # Convert back to (batch_size, num_covariates, embedding_dim)


class TreatmentCovariateCrossAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads=4, num_layers=1):
        super(TreatmentCovariateCrossAttention, self).__init__()
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, covariate_embeddings, treatment_embeddings):
        """
        covariate_embeddings: (batch_size, num_covariates, embedding_dim) -> Acts as "memory" (key & value)
        treatment_embeddings: (batch_size, embedding_dim) -> Acts as "query"
        
        Returns: (batch_size, num_covariates, embedding_dim) - Updated covariate representation
        """
        # Expand treatment embeddings to match covariates
        treatment_embeddings = treatment_embeddings.unsqueeze(1)  # Shape (batch_size, 1, embedding_dim)

        # TransformerDecoder requires (seq_len, batch, dim) format
        memory = covariate_embeddings.permute(1, 0, 2)  # (num_covariates, batch, embedding_dim)
        query = treatment_embeddings.permute(1, 0, 2)  # (1, batch, embedding_dim)

        # Apply Transformer Decoder (cross-attention)
        updated_covariates = self.transformer_decoder(query, memory)  # Shape: (1, batch, embedding_dim)

        return updated_covariates.permute(1, 0, 2)  # Convert back to (batch_size, 1, embedding_dim)


class OutcomePrediction(nn.Module):
    def __init__(self, embedding_dim):
        super(OutcomePrediction, self).__init__()
        self.fc = nn.Linear(embedding_dim, 1)  # Final regression layer

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, 1, embedding_dim)
        Returns: (batch_size, 1)
        """
        
        x = x.reshape(x.shape[0], -1)  # Flatten before prediction
        return self.fc(x)


class TransTEE(nn.Module):
    def __init__(self, num_covariates, embedding_dim, num_treatments):
        super(TransTEE, self).__init__()
        self.covariate_embedding = CovariateEmbedding(num_covariates, embedding_dim)
        self.treatment_embedding = TreatmentEmbedding(num_treatments, embedding_dim)
        self.covariate_encoder = TransformerCovariateEncoder(num_covariates, embedding_dim)
        self.cross_attention = TreatmentCovariateCrossAttention(embedding_dim)
        self.outcome_predictor = OutcomePrediction(embedding_dim)

        # Treatment prediction head (Propensity Score)
        self.propensity_head = nn.Sequential(
            nn.Linear(embedding_dim, 1),
            nn.Sigmoid()
        )

        # Learnable epsilon (initialized small)
        self.epsilon = nn.Parameter(torch.tensor(1e-6))

    def forward(self, x, t):
        """
        x: Covariates (batch_size, num_covariates)
        t: Treatments (batch_size,)

        Returns: Estimated outcome (batch_size, 1)
        """
        x = self.covariate_embedding(x)  # Encode covariates
        e_x = self.propensity_head(torch.mean(x, dim=1))  # Propensity scores

        
        t = self.treatment_embedding(t)  # Encode treatment
        x = self.covariate_encoder(x)  # Self-attention on covariates
        x = self.cross_attention(x, t)  # Treatment-covariate interactions
        y_pred = self.outcome_predictor(x)  # Final outcome prediction

        
        return y_pred, e_x

def make_regression_loss(y_0_pred, y_1_pred, y_true, t_true):
    
    loss0 = (1 - t_true) * torch.square(y_0_pred - y_true)
    loss1 = t_true * torch.square(y_1_pred - y_true)
    loss = loss0 + loss1
    return torch.mean(loss)

def make_binary_classification_loss(t_pred, t_true):
    return nn.BCELoss()(t_pred, t_true)

def make_targeted_regularization_loss(e_x, y0_pred, y1_pred, Y, T, epsilon):
    """ Computes the doubly robust loss """
    
    # Compute predicted outcome based on treatment
    y_pred = T * y1_pred + (1 - T) * y0_pred

    
    # Compute inverse probability weights
    e_x = torch.clamp(e_x, 1e-6, 1 - 1e-6)  # Avoid division by zero
    weight = (T - e_x) / (e_x * (1 - e_x))


    # Compute y_pred_tilde (corrected y_pred with propsensity scores)
    y_pred_tilde = y_pred + epsilon * weight
    
    # Targeted regularization loss
    t_loss = torch.mean((Y-y_pred_tilde) ** 2)
    
    return t_loss



In [72]:
import pandas as pd

# Example dataset
data = pd.DataFrame({
    'age': [55, 40, 60],
    'blood_pressure': [140, 130, 150],
    'cholesterol': [200, 180, 220],
    'treatment': [1, 2, 3],  # Treatment as categorical variable
    'outcome': [120, 125, 145]  # Observed outcome (only needed for training)
})

print(data)

import torch

def dataframe_to_tensors(df):
    """
    Convert a Pandas DataFrame into PyTorch tensors.

    Args:
    df (pd.DataFrame): Input DataFrame with covariates, treatments, and optionally outcomes.

    Returns:
    covariates_tensor (torch.Tensor): Shape (batch_size, num_covariates)
    treatment_tensor (torch.Tensor): Shape (batch_size,)
    outcome_tensor (torch.Tensor or None): Shape (batch_size, 1) if available, else None
    """
    # Convert continuous covariates to float tensor
    covariates = torch.tensor(df.iloc[:, :-2].values, dtype=torch.float32)  # All except last 2 cols
    # Convert treatment to integer tensor
    treatment = torch.tensor(df['treatment'].values, dtype=torch.long)  # Long tensor for embedding lookup
    # Convert outcome if available
    outcome = torch.tensor(df['outcome'].values, dtype=torch.float32).unsqueeze(1) if 'outcome' in df else None
    
    return covariates, treatment, outcome

# Convert DataFrame
covariates_tensor, treatment_tensor, outcome_tensor = dataframe_to_tensors(data)

# Print shapes
print("Covariates shape:", covariates_tensor.shape)  # Expected: (batch_size, num_covariates)
print("Treatment shape:", treatment_tensor.shape)  # Expected: (batch_size,)
print("Outcome shape:", outcome_tensor.shape)  # Expected: (batch_size, 1)


   age  blood_pressure  cholesterol  treatment  outcome
0   55             140          200          1      120
1   40             130          180          2      125
2   60             150          220          3      145
Covariates shape: torch.Size([3, 3])
Treatment shape: torch.Size([3])
Outcome shape: torch.Size([3, 1])


In [73]:
# Define model parameters
num_covariates = 3  # Age, BP, Cholesterol
embedding_dim = 8   # Embedding dimension for both covariates & treatments
num_treatments = 5  # Assume 5 possible treatments

# Initialize the model
model = TransTEE(num_covariates, embedding_dim, num_treatments)

# Perform forward pass (inference)
predicted_outcome, e_x = model(covariates_tensor, treatment_tensor)

print("Predicted outcome:", predicted_outcome)
print("Predicted e_x:", e_x)


Predicted outcome: tensor([[ 0.2193],
        [ 0.2970],
        [-0.4642]], grad_fn=<AddmmBackward0>)
Predicted e_x: tensor([[0.0504],
        [0.0319],
        [0.0689]], grad_fn=<SigmoidBackward0>)


In [74]:
# Define loss function (Mean Squared Error for regression)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop (one epoch for example)
num_epochs = 100
for epoch in range(num_epochs):
    optimizer.zero_grad()
    y_pred, e_x = model(covariates_tensor, treatment_tensor)  # Forward pass
    regressionloss = loss_function(y_pred, outcome_tensor)  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 17047.5039
Epoch 10, Loss: 16614.7246
Epoch 20, Loss: 16596.6367
Epoch 30, Loss: 16568.8438
Epoch 40, Loss: 16508.1270
Epoch 50, Loss: 16507.9863
Epoch 60, Loss: 16463.8730
Epoch 70, Loss: 16421.8418
Epoch 80, Loss: 16422.6426
Epoch 90, Loss: 16358.5273


In [75]:
# Perform forward pass (inference)
predicted_outcome = model(covariates_tensor, treatment_tensor)

print("Predicted outcome:", predicted_outcome)

Predicted outcome: (tensor([[2.5095],
        [2.8087],
        [2.6630]], grad_fn=<AddmmBackward0>), tensor([[0.1285],
        [0.0761],
        [0.1836]], grad_fn=<SigmoidBackward0>))


In [88]:
# Use Dragon Net input data frame
from causalml.dataset import synthetic_data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load synthetic dataset using updated API
y, X, w, tau, b, e = synthetic_data(mode=1, n=1000, p=5, sigma=1.0, adj=0.0)

# Split into train and test sets
X_train, X_test, y_train, y_test, w_train, w_test, tau_train, tau_test = train_test_split(
    X, y, w, tau, test_size=0.2, random_state=42
)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to PyTorch tensors
X_train, X_test = torch.tensor(X_train, dtype=torch.float32), torch.tensor(X_test, dtype=torch.float32)
y_train, y_test = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1), torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
w_train, w_test = torch.tensor(w_train, dtype=torch.float32).unsqueeze(1), torch.tensor(w_test, dtype=torch.float32).unsqueeze(1)
tau_train, tau_test = torch.tensor(tau_train, dtype=torch.float32), torch.tensor(tau_test, dtype=torch.float32)

# Print dataset shapes to verify
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}, w_train shape: {w_train.shape}")
print(f"tau_train shape: {tau_train.shape}")

X_train shape: torch.Size([800, 5]), y_train shape: torch.Size([800, 1]), w_train shape: torch.Size([800, 1])
tau_train shape: torch.Size([800])


In [94]:
# Define model parameters
num_covariates = 5  # Age, BP, Cholesterol
embedding_dim = 40   # Embedding dimension for both covariates & treatments
num_treatments = 2  # Assume 2 possible treatments

# Initialize the model
model = TransTEE(num_covariates, embedding_dim, num_treatments)

# Perform forward pass (inference)
predicted_outcome, e_x = model(X_train, w_train.int().squeeze())

print("Predicted outcome:", predicted_outcome)
print("Predicted e_x:", e_x)


Predicted outcome: tensor([[ 4.1619e-02],
        [-4.7174e-02],
        [ 3.7158e-02],
        [ 2.6077e-01],
        [-2.1781e-01],
        [-3.0013e-02],
        [ 1.4985e-01],
        [ 5.4474e-01],
        [ 1.3593e-01],
        [ 2.0516e-01],
        [ 2.9578e-02],
        [ 1.1050e-01],
        [-8.9443e-02],
        [-1.0838e-01],
        [-2.7468e-02],
        [ 3.1280e-01],
        [ 3.7207e-01],
        [ 3.6446e-01],
        [-2.7050e-02],
        [ 8.9254e-02],
        [ 4.6575e-01],
        [-1.3428e-01],
        [-1.8785e-01],
        [ 1.5875e-01],
        [ 2.0064e-02],
        [ 6.3650e-01],
        [-3.4367e-01],
        [-7.2917e-02],
        [ 1.6352e-01],
        [ 4.2504e-01],
        [ 2.7069e-01],
        [ 1.5387e-02],
        [ 4.9464e-01],
        [ 7.5986e-02],
        [ 1.3178e-02],
        [ 2.5009e-01],
        [-7.4835e-02],
        [-1.4526e-01],
        [ 8.5950e-02],
        [ 2.1924e-02],
        [-2.4607e-02],
        [ 9.6194e-02],
        [ 3.845

In [95]:
# Training loop

# hyper-praameter
alpha = 0.1
beta = 0.1

loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    
    y_pred, e_x = model(X_train, w_train.int().squeeze())  # Forward pass

    y_0_pred, e_x = model(X_train, torch.zeros(X_train.shape[0]).int())
    y_1_pred, e_x = model(X_train, torch.ones(X_train.shape[0]).int())
    
    regression_loss = loss_function(y_pred, y_train)  # Compute loss
    bce_loss = make_binary_classification_loss(e_x, w_train)
    vanila_loss = regression_loss + alpha * bce_loss
    
    t_loss = make_targeted_regularization_loss(e_x, y_0_pred, y_1_pred, y_train, w_train, model.epsilon)
    
    loss = vanila_loss + beta * t_loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f},  \
        regression loss: {regression_loss.item():.4f}, \
        bce loss: {bce_loss.item():.4f}, \
        t_loss: {t_loss.item():.4f}, \
        epsilon: {model.epsilon.item():.4f}, \
        ")


Epoch 0: Loss = 3.9088,          regression loss: 3.4922,         bce loss: 0.6812,         t_loss: 3.4845,         epsilon: 0.0010,         
Epoch 10: Loss = 1.4173,          regression loss: 1.2264,         bce loss: 0.6696,         t_loss: 1.2395,         epsilon: 0.0043,         
Epoch 20: Loss = 1.3869,          regression loss: 1.2010,         bce loss: 0.6582,         t_loss: 1.2009,         epsilon: 0.0034,         
Epoch 30: Loss = 1.3797,          regression loss: 1.1949,         bce loss: 0.6475,         t_loss: 1.2004,         epsilon: 0.0025,         
Epoch 40: Loss = 1.3806,          regression loss: 1.1981,         bce loss: 0.6375,         t_loss: 1.1880,         epsilon: 0.0003,         
Epoch 50: Loss = 1.3708,          regression loss: 1.1897,         bce loss: 0.6282,         t_loss: 1.1828,         epsilon: -0.0025,         
Epoch 60: Loss = 1.3357,          regression loss: 1.1568,         bce loss: 0.6197,         t_loss: 1.1691,         epsilon: -0.0053,        

In [97]:
import numpy as np
model.eval()
with torch.no_grad():
    y0_pred_test, e_x = model(X_test, torch.zeros(X_test.shape[0]).int())
    y1_pred_test, e_x = model(X_test, torch.ones(X_test.shape[0]).int())

    # Estimate Individual Treatment Effects (ITE)
    tau_hat = (y1_pred_test - y0_pred_test).squeeze().numpy()

    # Compute Mean Absolute Error
    mae = np.mean(np.abs(tau_hat - tau_test.numpy()))
    print(f"Mean Absolute Error in Treatment Effect Estimation: {mae:.4f}")


Mean Absolute Error in Treatment Effect Estimation: 0.1844


# Training transTEE on larger size of dataset leverage torch.dataset

In [101]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from causalml.dataset import synthetic_data
from sklearn.metrics import mean_absolute_error

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load synthetic dataset with larger sample size
y, X, w, tau, b, e = synthetic_data(mode=1, n=50000, p=10, sigma=1.0, adj=0.0)  # Increased dataset size

# Split into train and test sets
X_train, X_test, y_train, y_test, w_train, w_test, tau_train, tau_test = train_test_split(
    X, y, w, tau, test_size=0.2, random_state=42
)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to PyTorch tensors
X_train, X_test = torch.tensor(X_train, dtype=torch.float32), torch.tensor(X_test, dtype=torch.float32)
y_train, y_test = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1), torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
w_train, w_test = torch.tensor(w_train, dtype=torch.float32).unsqueeze(1), torch.tensor(w_test, dtype=torch.float32).unsqueeze(1)
tau_train, tau_test = torch.tensor(tau_train, dtype=torch.float32), torch.tensor(tau_test, dtype=torch.float32)

# --- PyTorch Dataset & DataLoader ---
class CausalDataset(Dataset):
    def __init__(self, X, y, w):
        self.X = X
        self.y = y
        self.w = w

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.w[idx]

# Create data loaders for mini-batch training
batch_size = 512
train_dataset = CausalDataset(X_train, y_train, w_train)
test_dataset = CausalDataset(X_test, y_test, w_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# Initialize model
input_dim = X_train.shape[1]
model = TransTEE(num_covariates, embedding_dim, num_treatments).to(device)

# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- Train DragonNet with Mini-batch Training ---
num_epochs = 10
step = 0
for epoch in range(num_epochs):
    model.train()

    for batch_X, batch_y, batch_w in train_loader:
        batch_X, batch_y, batch_w = batch_X.to(device), batch_y.to(device), batch_w.to(device)

        optimizer.zero_grad()
        # Forward pass
        y_pred, e_x = model(batch_X, batch_w.int().squeeze())  # Forward pass
        y_0_pred, e_x = model(batch_X, torch.zeros(batch_X.shape[0]).int())
        y_1_pred, e_x = model(batch_X, torch.ones(batch_X.shape[0]).int())
        
        regression_loss = loss_function(y_pred, batch_y)  # Compute loss
        bce_loss = make_binary_classification_loss(e_x, batch_w)
        vanila_loss = regression_loss + alpha * bce_loss

        t_loss = make_targeted_regularization_loss(e_x, y_0_pred, y_1_pred, batch_y, batch_w, model.epsilon)
    
        loss = vanila_loss + beta * t_loss
    
        loss.backward()
        optimizer.step()
        step += 1

        if step % 50 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f},  \
            regression loss: {regression_loss.item():.4f}, \
            bce loss: {bce_loss.item():.4f}, \
            t_loss: {t_loss.item():.4f}, \
            epsilon: {model.epsilon.item():.4f}, \
            ")


# --- Evaluate transTEE ---
model.eval()
tau_hat = []

with torch.no_grad():
    for batch_X, _, _ in test_loader:
        batch_X = batch_X.to(device)
        y0_pred_test, e_x = model(batch_X, torch.zeros(batch_X.shape[0]).int())
        y1_pred_test, e_x = model(batch_X, torch.ones(batch_X.shape[0]).int())
        tau_hat.extend((y1_pred_test - y0_pred_test).cpu().numpy())

tau_hat = np.array(tau_hat).flatten()

# --- Compare with True Treatment Effects ---
mae = mean_absolute_error(tau_test, tau_hat)
print(f"\nMean Absolute Error for transTEE: {mae:.4f}")


Using device: cpu
Epoch 0: Loss = 1.3980,              regression loss: 1.2113,             bce loss: 0.6565,             t_loss: 1.2113,             epsilon: 0.0107,             
Epoch 1: Loss = 1.1742,              regression loss: 1.0114,             bce loss: 0.6237,             t_loss: 1.0050,             epsilon: 0.0058,             
Epoch 1: Loss = 1.1413,              regression loss: 0.9819,             bce loss: 0.6151,             t_loss: 0.9787,             epsilon: 0.0046,             
Epoch 2: Loss = 1.0814,              regression loss: 0.9330,             bce loss: 0.5637,             t_loss: 0.9204,             epsilon: 0.0035,             
Epoch 3: Loss = 1.0034,              regression loss: 0.8621,             bce loss: 0.5449,             t_loss: 0.8678,             epsilon: 0.0062,             
Epoch 3: Loss = 1.3048,              regression loss: 1.1374,             bce loss: 0.5457,             t_loss: 1.1275,             epsilon: 0.0050,             
Epoch 4: L

In [102]:
print(sum(p.numel() for p in model.parameters()))

525147
