# CausalKit Data Module Examples

This notebook demonstrates how to use the functions and classes in the `causalkit.data` module. The data module provides tools for:

1. **Generating synthetic data** for causal inference tasks
2. **Managing causal data** with the `causaldata` class

We'll explore both of these capabilities with practical examples.


## 1. Data Generation Functions

The `causalkit.data` module provides functions to generate synthetic data for causal inference tasks. These functions are useful for:

- Testing causal inference methods
- Demonstrating causal inference concepts
- Benchmarking different approaches

Let's explore the available data generation functions.


In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set plotting style
sns.set_style('whitegrid')

# Import data generation functions
from causalkit.data import generate_rct_data, generate_obs_data


### 1.1 Generating Randomized Controlled Trial (RCT) Data

The `generate_rct_data` function creates synthetic data that mimics a randomized controlled trial. In an RCT, treatment assignment is random and independent of any covariates.
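To make this independence concrete, here is a minimal simulation of randomized assignment. It uses a made-up data-generating process (an `age` covariate and a true effect of 2.0) and is not how `generate_rct_data` works internally. It shows two things: when assignment ignores covariates, the arms are balanced, and a simple difference in means recovers the effect.

```python
import numpy as np
import pandas as pd

# Hypothetical data-generating process for illustration only (not causalkit's internals).
rng = np.random.default_rng(42)
n = 10_000

age = rng.integers(18, 70, size=n)
# Randomized assignment: every unit is treated with probability 0.5, regardless of age.
treatment = rng.binomial(1, 0.5, size=n)
# Outcome depends on age and on treatment; the true treatment effect is 2.0.
target = 0.1 * age + 2.0 * treatment + rng.normal(0, 1, size=n)

sim_rct = pd.DataFrame({"age": age, "treatment": treatment, "target": target})

# Covariate balance: mean age should be nearly identical in both arms.
mean_age = sim_rct.groupby("treatment")["age"].mean()
print(mean_age)

# With random assignment, the difference in group means estimates the treatment effect.
means = sim_rct.groupby("treatment")["target"].mean()
naive_ate = means.loc[1] - means.loc[0]
print(f"Difference in means: {naive_ate:.3f}")
```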


In [None]:
# Generate RCT data with default parameters
rct_df = generate_rct_data(random_state=42)

# Display the first few rows
print(f"Generated RCT data with {len(rct_df)} rows and {len(rct_df.columns)} columns")
rct_df.head()


Let's look at summary statistics for the generated data:


In [None]:
# Display summary statistics
rct_df.describe()


In [None]:
# Check the distribution of treatment assignment
treatment_counts = rct_df['treatment'].value_counts(normalize=True)
print("Treatment distribution:")
print(treatment_counts)

# Visualize the treatment distribution
plt.figure(figsize=(8, 5))
sns.countplot(x='treatment', data=rct_df)
plt.title('Distribution of Treatment Assignment')
plt.xlabel('Treatment')
plt.ylabel('Count')
plt.show()


Let's also look at the relationship between treatment and outcome:


In [None]:
# Calculate average outcome by treatment group
avg_outcome = rct_df.groupby('treatment')['target'].mean()
print("Average outcome by treatment group:")
print(avg_outcome)

# Visualize the relationship
plt.figure(figsize=(8, 5))
sns.barplot(x='treatment', y='target', data=rct_df)
plt.title('Average Outcome by Treatment Group')
plt.xlabel('Treatment')
plt.ylabel('Average Outcome')
plt.show()


### 1.2 Customizing RCT Data Generation

`generate_rct_data` also accepts parameters that control the sample size, the treatment split, the outcome type, and the random seed:


In [None]:
# Generate RCT data with custom parameters
custom_rct_df = generate_rct_data(
    n_users=10000,           # Number of users
    split=0.7,               # 70% in control, 30% in treatment
    target_type="continuous", # Continuous outcome variable
    random_state=42          # For reproducibility
)

# Display the first few rows
print(f"Generated custom RCT data with {len(custom_rct_df)} rows")
custom_rct_df.head()


In [None]:
# Check the distribution of treatment assignment
custom_treatment_counts = custom_rct_df['treatment'].value_counts(normalize=True)
print("Treatment distribution:")
print(custom_treatment_counts)

# Visualize the treatment distribution
plt.figure(figsize=(8, 5))
sns.countplot(x='treatment', data=custom_rct_df)
plt.title('Distribution of Treatment Assignment (Custom Parameters)')
plt.xlabel('Treatment')
plt.ylabel('Count')
plt.show()
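One way to confirm that the realized split matches the intended one is a sample ratio mismatch check: the observed treated share should lie within a few binomial standard errors of the expected share. The helper below is a hypothetical sketch that uses a normal approximation and a 3-standard-error threshold, an arbitrary but common choice.

```python
import math


def share_within_tolerance(n_treated: int, n_total: int, expected_share: float, z: float = 3.0) -> bool:
    """Return True if the observed treated share is within z binomial standard errors of the expected share."""
    observed = n_treated / n_total
    # Standard error of a sample proportion under the expected share.
    se = math.sqrt(expected_share * (1 - expected_share) / n_total)
    return abs(observed - expected_share) <= z * se


# 10,000 users with 30% expected in treatment.
print(share_within_tolerance(3_020, 10_000, 0.3))  # True
print(share_within_tolerance(3_500, 10_000, 0.3))  # False
```

Applied to the data above, a call such as `share_within_tolerance(int(custom_rct_df['treatment'].sum()), len(custom_rct_df), 0.3)` would check the 30% treated share stated in the `split=0.7` comment, assuming `treatment` is coded 0/1.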


### 1.3 Generating Observational Data

The `generate_obs_data` function creates synthetic data that mimics observational studies. In observational data, treatment assignment is not random and may depend on covariates.
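The consequence of covariate-dependent assignment can be sketched with a small simulation. The setup is hypothetical (a uniform `age` covariate, a logistic propensity, a true effect of 1.0) and does not reproduce `generate_obs_data`. Because older units are both more likely to be treated and have higher outcomes, the naive difference in means is biased upward, while a linear regression that adjusts for `age` recovers the effect. The adjustment works here because the simulated outcome model is linear.

```python
import numpy as np
import pandas as pd

# Hypothetical confounded data-generating process (not generate_obs_data's internals).
rng = np.random.default_rng(0)
n = 20_000
true_effect = 1.0

age = rng.uniform(20, 60, size=n)
# Propensity rises with age (logistic curve centered at 40), so assignment is confounded.
propensity = 1 / (1 + np.exp(-(age - 40) / 5))
treatment = rng.binomial(1, propensity)
# Age also raises the outcome, creating the backdoor path treatment <- age -> target.
target = 0.2 * age + true_effect * treatment + rng.normal(0, 1, size=n)

sim_obs = pd.DataFrame({"age": age, "treatment": treatment, "target": target})

# Naive comparison mixes the treatment effect with the age gap between groups.
means = sim_obs.groupby("treatment")["target"].mean()
naive_ate = means.loc[1] - means.loc[0]

# Regression adjustment: OLS of target on [intercept, treatment, age].
X = np.column_stack([np.ones(n), treatment, age])
coef, *_ = np.linalg.lstsq(X, target, rcond=None)
adjusted_ate = coef[1]

print(f"Naive difference in means: {naive_ate:.2f}")
print(f"Age-adjusted estimate:     {adjusted_ate:.2f}  (true effect: {true_effect})")
```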


In [None]:
# Generate observational data
obs_df = generate_obs_data(random_state=42)

# Display the first few rows
print(f"Generated observational data with {len(obs_df)} rows and {len(obs_df.columns)} columns")
obs_df.head()


In [None]:
# Check the distribution of treatment assignment
obs_treatment_counts = obs_df['treatment'].value_counts(normalize=True)
print("Treatment distribution:")
print(obs_treatment_counts)

# Visualize the treatment distribution
plt.figure(figsize=(8, 5))
sns.countplot(x='treatment', data=obs_df)
plt.title('Distribution of Treatment Assignment (Observational Data)')
plt.xlabel('Treatment')
plt.ylabel('Count')
plt.show()


Because treatment assignment in observational data can depend on covariates, let's check whether the treatment groups differ in age:


In [None]:
# Examine relationship between age and treatment
plt.figure(figsize=(10, 6))
sns.boxplot(x='treatment', y='age', data=obs_df)
plt.title('Relationship Between Age and Treatment Assignment')
plt.xlabel('Treatment')
plt.ylabel('Age')
plt.show()
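A boxplot gives a visual impression; the standardized mean difference (SMD) puts covariate imbalance on a unit-free scale so it can be compared across covariates. The function below is a sketch that assumes a 0/1 treatment column and uses the pooled-standard-deviation definition; its name and defaults are illustrative assumptions.

```python
import numpy as np
import pandas as pd


def standardized_mean_difference(df: pd.DataFrame, covariate: str, treatment: str = "treatment") -> float:
    """(mean_treated - mean_control) / sqrt((var_treated + var_control) / 2) for a 0/1 treatment."""
    treated = df.loc[df[treatment] == 1, covariate]
    control = df.loc[df[treatment] == 0, covariate]
    # pandas .var() is the sample variance (ddof=1).
    pooled_sd = np.sqrt((treated.var() + control.var()) / 2)
    return float((treated.mean() - control.mean()) / pooled_sd)


toy = pd.DataFrame({"age": [30, 40, 50, 60], "treatment": [0, 0, 1, 1]})
smd = standardized_mean_difference(toy, "age")
print(f"{smd:.3f}")  # 2.828
```

On the notebook's data, `standardized_mean_difference(obs_df, 'age')` would quantify the age gap plotted above. A common rule of thumb treats an absolute SMD above 0.1 as meaningful imbalance.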


## 2. The causaldata Class

The `causaldata` class provides a convenient way to manage data for causal inference tasks. It wraps a pandas DataFrame and records which columns play which role in the analysis: the target (outcome), the treatment, and the confounders (passed as the `cofounders` argument).
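The idea of attaching role metadata to a DataFrame can be illustrated with a small dataclass. `RoleFrame` and its field names are hypothetical and not causalkit's implementation: the sketch stores the column name for each role, validates that those columns exist when the object is created, and exposes properties that return the corresponding data.

```python
from dataclasses import dataclass, field
from typing import List

import pandas as pd


@dataclass
class RoleFrame:
    """Hypothetical minimal wrapper: a DataFrame plus the column name for each causal role."""

    df: pd.DataFrame
    target_col: str
    treatment_col: str
    confounder_cols: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Validate eagerly so that a typo in a role name fails at construction time.
        declared = [self.target_col, self.treatment_col, *self.confounder_cols]
        missing = [c for c in declared if c not in self.df.columns]
        if missing:
            raise ValueError(f"Columns not found in DataFrame: {missing}")

    @property
    def target(self) -> pd.Series:
        return self.df[self.target_col]

    @property
    def treatment(self) -> pd.Series:
        return self.df[self.treatment_col]

    @property
    def confounders(self) -> pd.DataFrame:
        # Indexing with a list keeps the result a DataFrame even for a single column.
        return self.df[self.confounder_cols]


toy = pd.DataFrame({"age": [25, 40], "treatment": [0, 1], "target": [1.0, 2.5]})
rf = RoleFrame(toy, target_col="target", treatment_col="treatment", confounder_cols=["age"])
print(rf.confounders.shape)  # (2, 1)

error_message = ""
try:
    RoleFrame(toy, target_col="target", treatment_col="treatment", confounder_cols=["income"])
except ValueError as err:
    error_message = str(err)
print(error_message)  # Columns not found in DataFrame: ['income']
```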


In [None]:
# Import the causaldata class
from causalkit.data import causaldata


### 2.1 Creating a causaldata Object

Let's create a causaldata object using the RCT data we generated earlier:


In [None]:
# Create a causaldata object
ck_data = causaldata(
    df=rct_df,
    target='target',
    cofounders=['age', 'invited_friend'],
    treatment='treatment'
)

# Display the object
print(ck_data)


### 2.2 Accessing Data by Role

The `causaldata` class provides properties that return the columns for each causal role:


In [None]:
# Access the target variable
print("Target variable:")
print(ck_data.target.head())

# Access the treatment variable
print("\nTreatment variable:")
print(ck_data.treatment.head())

# Access the confounders (exposed as the `cofounders` property)
print("\nConfounders:")
print(ck_data.cofounders.head())


### 2.3 Using the get_df Method

The `get_df` method returns a DataFrame containing all columns, a chosen list of columns, the columns for selected roles, or a combination of these:


In [None]:
# Get the entire DataFrame
full_df = ck_data.get_df()
print("Full DataFrame shape:", full_df.shape)
print("Full DataFrame columns:", list(full_df.columns))
print("\nFirst few rows:")
full_df.head()


In [None]:
# Get specific columns
specific_cols_df = ck_data.get_df(columns=['user_id', 'gender'])
print("Specific columns DataFrame shape:", specific_cols_df.shape)
print("Specific columns DataFrame columns:", list(specific_cols_df.columns))
print("\nFirst few rows:")
specific_cols_df.head()


In [None]:
# Get target and treatment columns
target_treatment_df = ck_data.get_df(include_target=True, include_treatment=True)
print("Target and treatment DataFrame shape:", target_treatment_df.shape)
print("Target and treatment DataFrame columns:", list(target_treatment_df.columns))
print("\nFirst few rows:")
target_treatment_df.head()


In [None]:
# Combine an explicitly named column with the confounder columns
mixed_df = ck_data.get_df(columns=['user_id'], include_cofounders=True)
print("Mixed DataFrame shape:", mixed_df.shape)
print("Mixed DataFrame columns:", list(mixed_df.columns))
print("\nFirst few rows:")
mixed_df.head()


### 2.4 Error Handling

Requesting a column that does not exist makes `get_df` raise a `ValueError`, which we can catch:


In [None]:
# Try to get a non-existent column
try:
    error_df = ck_data.get_df(columns=['non_existent_column'])
except ValueError as e:
    print(f"Error: {e}")
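Role-based selection with validation, as `get_df` appears to provide, can be sketched as a plain function. `select_columns` is a hypothetical helper, and its behavior (explicit columns first, then the requested roles, duplicates dropped, a `ValueError` for unknown names) is an assumption modeled on the calls above rather than causalkit's documented behavior.

```python
from typing import Iterable, List, Optional

import pandas as pd


def select_columns(
    df: pd.DataFrame,
    target: str,
    treatment: str,
    confounders: List[str],
    columns: Optional[Iterable[str]] = None,
    include_target: bool = False,
    include_treatment: bool = False,
    include_confounders: bool = False,
) -> pd.DataFrame:
    """Hypothetical get_df-style selector: explicit columns first, then requested roles."""
    requested: List[str] = list(columns) if columns is not None else []
    if include_target:
        requested.append(target)
    if include_treatment:
        requested.append(treatment)
    if include_confounders:
        requested.extend(confounders)

    # Nothing requested: return a copy of the whole frame.
    if not requested:
        return df.copy()

    # Drop duplicates while keeping first-seen order.
    ordered = list(dict.fromkeys(requested))
    missing = [c for c in ordered if c not in df.columns]
    if missing:
        raise ValueError(f"Column(s) {missing} do not exist in the DataFrame.")
    return df[ordered].copy()


toy = pd.DataFrame({"user_id": [1, 2], "age": [25, 40], "treatment": [0, 1], "target": [0, 1]})
subset = select_columns(toy, "target", "treatment", ["age"], columns=["user_id"], include_confounders=True)
print(list(subset.columns))  # ['user_id', 'age']

error_message = ""
try:
    select_columns(toy, "target", "treatment", ["age"], columns=["non_existent_column"])
except ValueError as err:
    error_message = str(err)
print(error_message)
```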


## 3. Practical Example: Analyzing Treatment Effects

Let's put everything together in a practical example where we analyze treatment effects using the causaldata class:


In [None]:
# Generate new RCT data
analysis_df = generate_rct_data(n_users=5000, random_state=123)

# Create a causaldata object
analysis_data = causaldata(
    df=analysis_df,
    target='target',
    cofounders=['age', 'invited_friend', 'gender'],
    treatment='treatment'
)

# Get the data we need for analysis
analysis_subset = analysis_data.get_df(include_target=True, include_treatment=True, include_cofounders=True)
print("Analysis subset shape:", analysis_subset.shape)
print("Analysis subset columns:", list(analysis_subset.columns))
analysis_subset.head()


In [None]:
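# Estimate the average treatment effect as the difference in group means
# (an unbiased estimator here because treatment is randomized)
treatment_effect = analysis_subset.groupby('treatment')['target'].mean()
ate = treatment_effect.loc[1] - treatment_effect.loc[0]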

print("Average outcome by treatment group:")
print(treatment_effect)
print(f"\nAverage Treatment Effect (ATE): {ate:.4f}")

# Visualize the treatment effect
plt.figure(figsize=(10, 6))
sns.barplot(x='treatment', y='target', data=analysis_subset)
plt.title('Average Outcome by Treatment Group')
plt.xlabel('Treatment')
plt.ylabel('Average Outcome')
plt.annotate(f"ATE = {ate:.4f}", xy=(0.5, max(treatment_effect) - 0.05), 
             xytext=(0.5, max(treatment_effect) + 0.05),
             ha='center', va='center',
             arrowprops=dict(arrowstyle='->', lw=1.5))
plt.show()
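A point estimate alone says nothing about precision. The sketch below adds a standard error and a normal-approximation 95% confidence interval to the difference in means, using the unequal-variance (Neyman) formula $\sqrt{s_1^2/n_1 + s_0^2/n_0}$. The function name and defaults are assumptions for illustration.

```python
import numpy as np
import pandas as pd


def difference_in_means(df: pd.DataFrame, outcome: str = "target", treatment: str = "treatment", z: float = 1.96):
    """Return (ate, se, (ci_low, ci_high)) for a 0/1 treatment column."""
    treated = df.loc[df[treatment] == 1, outcome]
    control = df.loc[df[treatment] == 0, outcome]
    ate = float(treated.mean() - control.mean())
    # Unequal-variance standard error; pandas .var() uses the sample variance (ddof=1).
    se = float(np.sqrt(treated.var() / len(treated) + control.var() / len(control)))
    return ate, se, (ate - z * se, ate + z * se)


toy = pd.DataFrame({"treatment": [1, 1, 0, 0], "target": [2.0, 4.0, 1.0, 3.0]})
ate, se, ci = difference_in_means(toy)
print(f"ATE={ate:.3f}, SE={se:.3f}")  # ATE=1.000, SE=1.414
```

With the notebook's data, `difference_in_means(analysis_subset)` would return the same ATE as the cell above together with its standard error and interval.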


### 3.1 Heterogeneous Treatment Effects

Let's check whether the treatment effect varies across subgroups, here defined by gender:


In [None]:
# Calculate treatment effect by gender
gender_effects = analysis_subset.groupby(['gender', 'treatment'])['target'].mean().unstack()
gender_ate = gender_effects[1] - gender_effects[0]

print("Average outcome by gender and treatment group:")
print(gender_effects)
print("\nAverage Treatment Effect by gender:")
print(gender_ate)

# Visualize heterogeneous treatment effects
plt.figure(figsize=(12, 6))
sns.barplot(x='gender', y='target', hue='treatment', data=analysis_subset)
plt.title('Heterogeneous Treatment Effects by Gender')
plt.xlabel('Gender')
plt.ylabel('Average Outcome')
plt.legend(title='Treatment')
plt.show()
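Differences between subgroup bars can arise from noise alone, especially in small subgroups. The sketch below (hypothetical helper names, 0/1 treatment assumed) computes a difference in means and its standard error within each subgroup, plus a z-statistic for whether two subgroup effects differ.

```python
import numpy as np
import pandas as pd


def subgroup_effects(df: pd.DataFrame, group: str, outcome: str = "target", treatment: str = "treatment") -> pd.DataFrame:
    """Per-subgroup difference in means with an unequal-variance standard error."""
    rows = []
    for level, sub in df.groupby(group):
        treated = sub.loc[sub[treatment] == 1, outcome]
        control = sub.loc[sub[treatment] == 0, outcome]
        ate = treated.mean() - control.mean()
        se = np.sqrt(treated.var() / len(treated) + control.var() / len(control))
        rows.append({group: level, "ate": ate, "se": se, "n": len(sub)})
    return pd.DataFrame(rows).set_index(group)


def effect_difference_z(effects: pd.DataFrame, a, b) -> float:
    """z-statistic for the difference between the effects in subgroups a and b (independent samples)."""
    diff = effects.loc[a, "ate"] - effects.loc[b, "ate"]
    return float(diff / np.sqrt(effects.loc[a, "se"] ** 2 + effects.loc[b, "se"] ** 2))


toy = pd.DataFrame({
    "gender": ["f", "f", "f", "f", "m", "m", "m", "m"],
    "treatment": [1, 1, 0, 0, 1, 1, 0, 0],
    "target": [2.0, 4.0, 1.0, 3.0, 5.0, 7.0, 1.0, 3.0],
})
effects = subgroup_effects(toy, "gender")
print(effects)
print(f"z for m vs f: {effect_difference_z(effects, 'm', 'f'):.2f}")  # 1.50
```

For example, `subgroup_effects(analysis_subset, 'gender')` would tabulate the per-gender effects printed above with their uncertainty. When many subgroups are compared, some will differ by chance, so such z-statistics should be read with multiple comparisons in mind.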


## 4. Conclusion

In this notebook, we've explored the `causalkit.data` module, which provides tools for generating synthetic data and managing causal data. We've seen how to:

1. Generate synthetic data for randomized controlled trials and observational studies
2. Create `causaldata` objects to manage data for causal inference
3. Access data by role (target, treatment, confounders)
4. Use the `get_df` method to retrieve all columns, selected columns, or columns by role
5. Estimate average and subgroup-level treatment effects from the resulting data

These tools make it easier to work with causal inference data and focus on the analysis rather than data management.