In [61]:
import torch
import torch.nn as nn

class CovariateEmbedding(nn.Module):
    def __init__(self, num_covariates, embedding_dim):
        super(CovariateEmbedding, self).__init__()
        self.embedding_layers = nn.ModuleList([
            nn.Linear(1, embedding_dim) for _ in range(num_covariates)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_covariates)
        Returns: Tensor of shape (batch_size, num_covariates, embedding_dim)
        """
        embedded = [layer(x[:, i:i+1]) for i, layer in enumerate(self.embedding_layers)]
        return torch.stack(embedded, dim=1)  # Shape: (batch, num_covariates, embedding_dim)


class TreatmentEmbedding(nn.Module):
    def __init__(self, num_treatments, embedding_dim):
        super(TreatmentEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_treatments, embedding_dim)

    def forward(self, t):
        """
        t: Tensor of shape (batch_size,)
        Returns: Tensor of shape (batch_size, embedding_dim)
        """
        return self.embedding(t)


class TransformerCovariateEncoder(nn.Module):
    def __init__(self, num_covariates, embedding_dim, num_heads=4, num_layers=2):
        super(TransformerCovariateEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_covariates, embedding_dim)
        Returns: Tensor of shape (batch_size, num_covariates, embedding_dim)
        """
        x = x.permute(1, 0, 2)  # Transformers expect (seq_len, batch, dim)
        x = self.transformer(x)
        return x.permute(1, 0, 2)  # Convert back to (batch_size, num_covariates, embedding_dim)


class TreatmentCovariateCrossAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads=4, num_layers=1):
        super(TreatmentCovariateCrossAttention, self).__init__()
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, covariate_embeddings, treatment_embeddings):
        """
        covariate_embeddings: (batch_size, num_covariates, embedding_dim) -> Acts as "memory" (key & value)
        treatment_embeddings: (batch_size, embedding_dim) -> Acts as "query"
        
        Returns: (batch_size, num_covariates, embedding_dim) - Updated covariate representation
        """
        # Expand treatment embeddings to match covariates
        treatment_embeddings = treatment_embeddings.unsqueeze(1)  # Shape (batch_size, 1, embedding_dim)

        # TransformerDecoder requires (seq_len, batch, dim) format
        memory = covariate_embeddings.permute(1, 0, 2)  # (num_covariates, batch, embedding_dim)
        query = treatment_embeddings.permute(1, 0, 2)  # (1, batch, embedding_dim)

        # Apply Transformer Decoder (cross-attention)
        updated_covariates = self.transformer_decoder(query, memory)  # Shape: (1, batch, embedding_dim)

        return updated_covariates.permute(1, 0, 2)  # Convert back to (batch_size, 1, embedding_dim)


class OutcomePrediction(nn.Module):
    def __init__(self, embedding_dim):
        super(OutcomePrediction, self).__init__()
        self.fc = nn.Linear(embedding_dim, 1)  # Final regression layer

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, 1, embedding_dim)
        Returns: (batch_size, 1)
        """
        
        x = x.reshape(x.shape[0], -1)  # Flatten before prediction
        return self.fc(x)


class TransTEE(nn.Module):
    def __init__(self, num_covariates, embedding_dim, num_treatments):
        super(TransTEE, self).__init__()
        self.covariate_embedding = CovariateEmbedding(num_covariates, embedding_dim)
        self.treatment_embedding = TreatmentEmbedding(num_treatments, embedding_dim)
        self.covariate_encoder = TransformerCovariateEncoder(num_covariates, embedding_dim)
        self.cross_attention = TreatmentCovariateCrossAttention(embedding_dim)
        self.outcome_predictor = OutcomePrediction(embedding_dim)

    def forward(self, x, t):
        """
        x: Covariates (batch_size, num_covariates)
        t: Treatments (batch_size,)

        Returns: Estimated outcome (batch_size, 1)
        """
        x = self.covariate_embedding(x)  # Encode covariates
        t = self.treatment_embedding(t)  # Encode treatment
        x = self.covariate_encoder(x)  # Self-attention on covariates
        x = self.cross_attention(x, t)  # Treatment-covariate interactions
        y_pred = self.outcome_predictor(x)  # Final outcome prediction
        return y_pred


In [62]:
import pandas as pd

# Example dataset
data = pd.DataFrame({
    'age': [55, 40, 60],
    'blood_pressure': [140, 130, 150],
    'cholesterol': [200, 180, 220],
    'treatment': [1, 2, 3],  # Treatment as categorical variable
    'outcome': [120, 125, 145]  # Observed outcome (only needed for training)
})

print(data)

import torch

def dataframe_to_tensors(df):
    """
    Convert a Pandas DataFrame into PyTorch tensors.

    Args:
    df (pd.DataFrame): Input DataFrame with covariates, treatments, and optionally outcomes.

    Returns:
    covariates_tensor (torch.Tensor): Shape (batch_size, num_covariates)
    treatment_tensor (torch.Tensor): Shape (batch_size,)
    outcome_tensor (torch.Tensor or None): Shape (batch_size, 1) if available, else None
    """
    # Convert continuous covariates to float tensor
    covariates = torch.tensor(df.iloc[:, :-2].values, dtype=torch.float32)  # All except last 2 cols
    # Convert treatment to integer tensor
    treatment = torch.tensor(df['treatment'].values, dtype=torch.long)  # Long tensor for embedding lookup
    # Convert outcome if available
    outcome = torch.tensor(df['outcome'].values, dtype=torch.float32).unsqueeze(1) if 'outcome' in df else None
    
    return covariates, treatment, outcome

# Convert DataFrame
covariates_tensor, treatment_tensor, outcome_tensor = dataframe_to_tensors(data)

# Print shapes
print("Covariates shape:", covariates_tensor.shape)  # Expected: (batch_size, num_covariates)
print("Treatment shape:", treatment_tensor.shape)  # Expected: (batch_size,)
print("Outcome shape:", outcome_tensor.shape)  # Expected: (batch_size, 1)


   age  blood_pressure  cholesterol  treatment  outcome
0   55             140          200          1      120
1   40             130          180          2      125
2   60             150          220          3      145
Covariates shape: torch.Size([3, 3])
Treatment shape: torch.Size([3])
Outcome shape: torch.Size([3, 1])


In [63]:
# Define model parameters
num_covariates = 3  # Age, BP, Cholesterol
embedding_dim = 8   # Embedding dimension for both covariates & treatments
num_treatments = 5  # Assume 5 possible treatments

# Initialize the model
model = TransTEE(num_covariates, embedding_dim, num_treatments)

# Perform forward pass (inference)
predicted_outcome = model(covariates_tensor, treatment_tensor)

print("Predicted outcome:", predicted_outcome)


Predicted outcome: tensor([[-1.6538],
        [-0.3132],
        [ 1.1677]], grad_fn=<AddmmBackward0>)


In [64]:
# Define loss function (Mean Squared Error for regression)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop (one epoch for example)
num_epochs = 100
for epoch in range(num_epochs):
    optimizer.zero_grad()
    y_pred = model(covariates_tensor, treatment_tensor)  # Forward pass
    loss = loss_function(y_pred, outcome_tensor)  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 17089.8926
Epoch 10, Loss: 16571.9609
Epoch 20, Loss: 16542.5371
Epoch 30, Loss: 16476.1855
Epoch 40, Loss: 16453.0371
Epoch 50, Loss: 16405.7246
Epoch 60, Loss: 16397.1191
Epoch 70, Loss: 16333.0781
Epoch 80, Loss: 16326.8389
Epoch 90, Loss: 16271.6143


In [26]:
# Perform forward pass (inference)
predicted_outcome = model(covariates_tensor, treatment_tensor)

print("Predicted outcome:", predicted_outcome)

Predicted outcome: tensor([[5.7516],
        [5.7539],
        [5.6316]], grad_fn=<AddmmBackward0>)


In [65]:
# Use Dragon Net input data frame
from causalml.dataset import synthetic_data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load synthetic dataset using updated API
y, X, w, tau, b, e = synthetic_data(mode=1, n=1000, p=5, sigma=1.0, adj=0.0)

# Split into train and test sets
X_train, X_test, y_train, y_test, w_train, w_test, tau_train, tau_test = train_test_split(
    X, y, w, tau, test_size=0.2, random_state=42
)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to PyTorch tensors
X_train, X_test = torch.tensor(X_train, dtype=torch.float32), torch.tensor(X_test, dtype=torch.float32)
y_train, y_test = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1), torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
w_train, w_test = torch.tensor(w_train, dtype=torch.float32).unsqueeze(1), torch.tensor(w_test, dtype=torch.float32).unsqueeze(1)
tau_train, tau_test = torch.tensor(tau_train, dtype=torch.float32), torch.tensor(tau_test, dtype=torch.float32)

# Print dataset shapes to verify
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}, w_train shape: {w_train.shape}")
print(f"tau_train shape: {tau_train.shape}")

X_train shape: torch.Size([800, 5]), y_train shape: torch.Size([800, 1]), w_train shape: torch.Size([800, 1])
tau_train shape: torch.Size([800])


In [66]:
# Define model parameters
num_covariates = 5  # Age, BP, Cholesterol
embedding_dim = 20   # Embedding dimension for both covariates & treatments
num_treatments = 2  # Assume 2 possible treatments

# Initialize the model
model = TransTEE(num_covariates, embedding_dim, num_treatments)

# Perform forward pass (inference)
predicted_outcome = model(X_train, w_train.int().squeeze())

print("Predicted outcome:", predicted_outcome)


Predicted outcome: tensor([[ 1.0359e+00],
        [ 2.9412e-01],
        [ 2.3322e-01],
        [ 4.3096e-01],
        [ 7.5649e-01],
        [ 1.0054e+00],
        [ 9.5990e-01],
        [ 8.9026e-01],
        [ 5.0006e-01],
        [ 2.8021e-01],
        [ 1.3730e-01],
        [ 9.9433e-02],
        [ 1.0567e+00],
        [ 4.2498e-01],
        [ 3.2436e-01],
        [ 2.7496e-01],
        [ 1.5074e-01],
        [ 8.6410e-01],
        [ 3.0293e-01],
        [ 5.2863e-01],
        [ 1.1586e+00],
        [ 6.8384e-01],
        [ 8.1149e-01],
        [ 7.5807e-01],
        [ 8.1780e-01],
        [ 7.3987e-01],
        [ 2.3601e-01],
        [ 6.0317e-01],
        [ 2.1394e-01],
        [ 6.2874e-01],
        [ 2.7720e-01],
        [ 3.5176e-01],
        [ 1.5916e-01],
        [ 8.6406e-01],
        [ 2.4870e-01],
        [ 4.0038e-02],
        [ 4.1144e-01],
        [ 1.0755e+00],
        [ 4.7650e-02],
        [ 1.9686e-01],
        [ 1.4573e-01],
        [ 5.9275e-02],
        [ 6.162

In [67]:
# Training loop

loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    y_pred = model(X_train, w_train.int().squeeze())  # Forward pass
    loss = loss_function(y_pred, y_train)  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 2.5690
Epoch 10, Loss: 1.1526
Epoch 20, Loss: 1.1516
Epoch 30, Loss: 1.1349
Epoch 40, Loss: 1.1396
Epoch 50, Loss: 1.1229
Epoch 60, Loss: 1.0771
Epoch 70, Loss: 1.0431
Epoch 80, Loss: 1.0479
Epoch 90, Loss: 1.0320


In [68]:
import numpy as np
model.eval()
with torch.no_grad():
    y0_pred_test = model(X_test, torch.zeros(X_test.shape[0]).int())
    y1_pred_test = model(X_test, torch.ones(X_test.shape[0]).int())

    # Estimate Individual Treatment Effects (ITE)
    tau_hat = (y1_pred_test - y0_pred_test).squeeze().numpy()

    # Compute Mean Absolute Error
    mae = np.mean(np.abs(tau_hat - tau_test.numpy()))
    print(f"Mean Absolute Error in Treatment Effect Estimation: {mae:.4f}")


Mean Absolute Error in Treatment Effect Estimation: 0.2883


# Training transTEE on larger size of dataset leverage torch.dataset

In [69]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from causalml.dataset import synthetic_data
from sklearn.metrics import mean_absolute_error

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load synthetic dataset with larger sample size
y, X, w, tau, b, e = synthetic_data(mode=1, n=50000, p=10, sigma=1.0, adj=0.0)  # Increased dataset size

# Split into train and test sets
X_train, X_test, y_train, y_test, w_train, w_test, tau_train, tau_test = train_test_split(
    X, y, w, tau, test_size=0.2, random_state=42
)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to PyTorch tensors
X_train, X_test = torch.tensor(X_train, dtype=torch.float32), torch.tensor(X_test, dtype=torch.float32)
y_train, y_test = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1), torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
w_train, w_test = torch.tensor(w_train, dtype=torch.float32).unsqueeze(1), torch.tensor(w_test, dtype=torch.float32).unsqueeze(1)
tau_train, tau_test = torch.tensor(tau_train, dtype=torch.float32), torch.tensor(tau_test, dtype=torch.float32)

# --- PyTorch Dataset & DataLoader ---
class CausalDataset(Dataset):
    def __init__(self, X, y, w):
        self.X = X
        self.y = y
        self.w = w

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.w[idx]

# Create data loaders for mini-batch training
batch_size = 512
train_dataset = CausalDataset(X_train, y_train, w_train)
test_dataset = CausalDataset(X_test, y_test, w_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



# Initialize model
input_dim = X_train.shape[1]
model = TransTEE(num_covariates, embedding_dim, num_treatments).to(device)

# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- Train DragonNet with Mini-batch Training ---
num_epochs = 10
step = 0
for epoch in range(num_epochs):
    model.train()

    for batch_X, batch_y, batch_w in train_loader:
        batch_X, batch_y, batch_w = batch_X.to(device), batch_y.to(device), batch_w.to(device)

        optimizer.zero_grad()
        # Forward pass
        y_pred = model(batch_X, batch_w.int().squeeze())  # Forward pass
        loss = loss_function(y_pred, batch_y)  # Compute loss
    
        loss.backward()
        optimizer.step()
        step += 1

        if step % 50 == 0:
            print(f"Step {step}: Loss = {loss.item():.4f}")

# --- Evaluate transTEE ---
model.eval()
tau_hat = []

with torch.no_grad():
    for batch_X, _, _ in test_loader:
        batch_X = batch_X.to(device)
        y0_pred_test = model(batch_X, torch.zeros(batch_X.shape[0]).int())
        y1_pred_test = model(batch_X, torch.ones(batch_X.shape[0]).int())
        tau_hat.extend((y1_pred_test - y0_pred_test).cpu().numpy())

tau_hat = np.array(tau_hat).flatten()

# --- Compare with True Treatment Effects ---
mae = mean_absolute_error(tau_test, tau_hat)
print(f"\nMean Absolute Error for transTEE: {mae:.4f}")


Using device: cpu
Step 50: Loss = 1.0689
Step 100: Loss = 1.0806
Step 150: Loss = 1.1676
Step 200: Loss = 1.0960
Step 250: Loss = 1.0676
Step 300: Loss = 1.0490
Step 350: Loss = 1.1014
Step 400: Loss = 1.0645
Step 450: Loss = 1.0006
Step 500: Loss = 1.0752
Step 550: Loss = 1.0046
Step 600: Loss = 1.0622
Step 650: Loss = 0.9824
Step 700: Loss = 0.9412
Step 750: Loss = 0.9430

Mean Absolute Error for transTEE: 0.0773
