<a href="https://colab.research.google.com/github/Sylvia232/COMPSCI-675D-Final-Project/blob/main/CS_675_TARNET%2BSynthetic_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


Synthetic Data generation

In [34]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, image_shape):
        super().__init__()
        self.image_shape = image_shape
        C, H, W = image_shape
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, C * H * W),
            nn.Tanh()  # map to range[-1, 1]
        )

    def forward(self, z):
        x = self.fc(z)
        return x.view(-1, *self.image_shape)

In [35]:
class Out_func(nn.Module): #define f_t(x, u)
    def __init__(self, d_X, d_U):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_X + d_U, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x, u):
        x_flatten = x.view(x.size(0), -1) if x.dim() > 2 else x
        input = torch.cat([x_flatten, u], dim=1)
        return self.net(input).squeeze(-1)

In [42]:
def synthetic_generation(num_obs=5000, image_shape=(1, 32, 32), dz1=10, dz2=10, sig_x= 0.5, sig_u= 0.5, beta_mag=2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    C, H, W = image_shape
    d_X = C * H * W
    d_U = dz2

    # Define G_X, G_U the mapping functions from latent dim to image

    # eval mode and no_grad modes are used to prevent gradient issues
    G_X = Decoder(dz1, image_shape).to(device)
    G_X.eval()

    class G_U(nn.Module):
        def __init__(self, dz2, d_U):
            super().__init__()
            self.linear = nn.Linear(dz2, d_U)

        def forward(self, z):
            return F.relu(self.linear(z))

    G_U = G_U(dz2, d_U).to(device)
    G_U.eval()

    # Define outcome functions for each treatment
    f_0 = Out_func(d_X, d_U).to(device)
    f_1 = Out_func(d_X, d_U).to(device)
    f_0.eval()
    f_1.eval()

    # Generate data without tracking gradients
    with torch.no_grad():

        # Define latent variables
        Z_1 = torch.randn(num_obs, dz1).to(device)
        Z_2 = torch.randn(num_obs, dz2).to(device)

        # Find X, U
        X = G_X(Z_1)
        X = X + sig_x * torch.randn_like(X)
        U = G_U(Z_2)
        U = U + sig_u * torch.randn_like(U)

        # Treatment t
        alpha = 0.01 * torch.randn(d_X).to(device)
        beta = beta_mag * torch.randn(d_U).to(device)

        probs = torch.sigmoid(X.view(num_obs, -1) @ alpha + U @ beta)
        treatments = torch.bernoulli(probs).to(device)

        # Outcomes Y
        x_flat = X.view(num_obs, -1)
        Y_0 = f_0(x_flat, U) + 0.01 * torch.randn(num_obs).to(device)
        Y_1 = f_1(x_flat, U) + 0.01 * torch.randn(num_obs).to(device)
        Y = torch.where(treatments == 1, Y_1, Y_0)

    # Detach from the computation graph to prevent gradient issues
    return X.detach(), treatments.detach(), Y.detach(), Y_0.detach(), Y_1.detach(), U.detach()

TARNET Architecture

In [43]:
# CNN Feature extractor for phi(x)
class CNNFeatureExtractor(nn.Module):
    def __init__(self, in_channels, output_dim=256):
        # input shape: (B, C, H, W)
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))  # gives dim (B, C, 1, 1)
        )
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        rep = self.features(x)
        rep = rep.view(rep.size(0), -1)
        return self.fc(rep)


In [44]:
class TARNet(nn.Module):
    def __init__(self, cnn_feat, rep_dim=256, hidden_dim=200):
        super().__init__()

        # Shared representation layers for phi(x)
        self.cnn_feat = cnn_feat

        # Head for predicting Y_0 the outcome under treatment t_0
        self.head_0 = nn.Sequential(
            nn.Linear(rep_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Head for predicting Y_1 the outcome under treatment t_1
        self.head_1 = nn.Sequential(
            nn.Linear(rep_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, t=None):
        rep = self.cnn_feat(x)
        y0 = self.head_0(rep).squeeze(-1)
        y1 = self.head_1(rep).squeeze(-1)

        if t is None:
            return y0, y1, rep
        else:
            return torch.where(t == 1, y1, y0)


Training the TARNet with synthetic data

In [45]:
def train_tarnet(num_epochs=10, batch_size=64, lr=1e-3):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")
    # Generate synthetic data
    print("Generating synthetic data...")
    X, treatments, Y, Y_0, Y_1, U = synthetic_generation(num_obs=5000)

    # Create model
    print("Creating model...")
    cnn_feat = CNNFeatureExtractor(in_channels=1, output_dim=256)
    model = TARNet(cnn_feat, rep_dim=256, hidden_dim=200)
    model = model.to(device)
    # Optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.MSELoss()

    # Dataset
    dataset = torch.utils.data.TensorDataset(X, treatments, Y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    model.train()
    print("Started training")
    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch_idx, (batch_X, batch_t, batch_Y) in enumerate(dataloader):
            optimizer.zero_grad()
            batch_X = batch_X.to(device)
            batch_t = batch_t.to(device)
            batch_Y = batch_Y.to(device)
            # Forward pass
            predictions = model(batch_X, batch_t)
            loss = criterion(predictions, batch_Y)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()


        avg_loss = epoch_loss / len(dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    return model

In [46]:
def evaluate_model(model, test_size=1000):
    print("Evaluating model")
    # Generate test data
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")
    X_test, t_test, Y_test, Y_0_test, Y_1_test, U_test = synthetic_generation(num_obs=test_size)

    X_test = X_test.to(device)
    t_test = t_test.to(device)
    Y_test = Y_test.to(device)
    Y_0_test = Y_0_test.to(device)
    Y_1_test = Y_1_test.to(device)

    model.eval()
    with torch.no_grad():
        # Get predictions for both potential outcomes
        y_0_pred, y_1_pred, _ = model(X_test, t=None)

        # Calculate MSE
        pred = model(X_test, t_test)
        mse = F.mse_loss(pred, Y_test).item()

        # Calculate PEHE (Precision in Estimation of Heterogeneous Effects)
        ite_true = Y_1_test - Y_0_test
        ite_pred = y_1_pred - y_0_pred
        pehe = torch.sqrt(F.mse_loss(ite_pred, ite_true)).item()

        print(f"MSE: {mse:.4f}")
        print(f"PEHE: {pehe:.4f}")


    return mse, pehe


In [48]:
# Train the model
trained_model = train_tarnet(num_epochs= 50, batch_size=64, lr=0.01)

Training on device: cuda
Generating synthetic data...
Using device: cuda
Creating model...
Started training
Epoch [1/50], Loss: 0.3774
Epoch [2/50], Loss: 0.0011
Epoch [3/50], Loss: 0.0011
Epoch [4/50], Loss: 0.0011
Epoch [5/50], Loss: 0.0012
Epoch [6/50], Loss: 0.0011
Epoch [7/50], Loss: 0.0010
Epoch [8/50], Loss: 0.0011
Epoch [9/50], Loss: 0.0010
Epoch [10/50], Loss: 0.0010
Epoch [11/50], Loss: 0.0011
Epoch [12/50], Loss: 0.0010
Epoch [13/50], Loss: 0.0011
Epoch [14/50], Loss: 0.0010
Epoch [15/50], Loss: 0.0010
Epoch [16/50], Loss: 0.0011
Epoch [17/50], Loss: 0.0043
Epoch [18/50], Loss: 0.0026
Epoch [19/50], Loss: 0.0012
Epoch [20/50], Loss: 0.0012
Epoch [21/50], Loss: 0.0011
Epoch [22/50], Loss: 0.0019
Epoch [23/50], Loss: 0.0011
Epoch [24/50], Loss: 0.0011
Epoch [25/50], Loss: 0.0010
Epoch [26/50], Loss: 0.0010
Epoch [27/50], Loss: 0.0011
Epoch [28/50], Loss: 0.0010
Epoch [29/50], Loss: 0.0011
Epoch [30/50], Loss: 0.0010
Epoch [31/50], Loss: 0.0010
Epoch [32/50], Loss: 0.0010
Epoch

In [49]:
#evaluate the model
mse_model, pehe = evaluate_model(trained_model)

Evaluating model
Training on device: cuda
Using device: cuda
Factual MSE: 0.0021
PEHE: 0.0733
