# Traffic Splitting Example

This notebook demonstrates how to use the `split_traffic` function from CausalKit to assign traffic to groups for A/B tests and other experiments.

`split_traffic` divides units of traffic (users, sessions, etc.) into experiment groups and supports:

- Simple random splits with customizable ratios
- Support for multiple variants (A/B/C/...)
- Stratified splitting to maintain balanced distributions of important variables
- Reproducible results via the `random_state` parameter


In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Import the split_traffic function
from causalkit.design.traffic_splitter import split_traffic


## Creating Sample Data

First, let's create a synthetic dataset of 1,000 users, each with an age group, country, device type, and number of past purchases. The users could represent website visitors, store customers, or any other population you want to split for an experiment.


In [None]:
# Set random seed for reproducibility
np.random.seed(42)
n_users = 1000

# Generate synthetic user data
user_data = {
    'user_id': range(1, n_users + 1),
    'age_group': np.random.choice(['18-24', '25-34', '35-44', '45+'], size=n_users),
    'country': np.random.choice(['US', 'UK', 'CA', 'AU', 'DE'], size=n_users,
                                p=[0.4, 0.2, 0.2, 0.1, 0.1]),
    'device': np.random.choice(['mobile', 'desktop', 'tablet'], size=n_users,
                               p=[0.6, 0.3, 0.1]),
    'past_purchases': np.random.poisson(2, size=n_users)
}

df = pd.DataFrame(user_data)

# Display the first few rows of the dataset
print(f"Total users: {len(df)}")
df.head()


Let's examine the distribution of key variables in our dataset:


In [None]:
# Create a subplot with 3 columns
fig = make_subplots(rows=1, cols=3, subplot_titles=['Country Distribution', 'Device Distribution', 'Age Group Distribution'])

# Plot country distribution
country_counts = df['country'].value_counts(normalize=True).reset_index()
country_counts.columns = ['country', 'proportion']
fig.add_trace(
    go.Bar(x=country_counts['country'], y=country_counts['proportion'], name='Country'),
    row=1, col=1
)

# Plot device distribution
device_counts = df['device'].value_counts(normalize=True).reset_index()
device_counts.columns = ['device', 'proportion']
fig.add_trace(
    go.Bar(x=device_counts['device'], y=device_counts['proportion'], name='Device'),
    row=1, col=2
)

# Plot age group distribution
age_counts = df['age_group'].value_counts(normalize=True).reset_index()
age_counts.columns = ['age_group', 'proportion']
fig.add_trace(
    go.Bar(x=age_counts['age_group'], y=age_counts['proportion'], name='Age Group'),
    row=1, col=3
)

# Update layout
fig.update_layout(
    height=500, 
    width=1000,
    showlegend=False,
    yaxis_title='Proportion',
    yaxis2_title='Proportion',
    yaxis3_title='Proportion'
)

fig.show()


## Example 1: Simple Random Split (50/50)

The most basic use case is an equal 50/50 split into two groups, control and treatment. The cell below omits `split_ratio` and relies on the default 50/50 split.
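For intuition, a 50/50 random split can be thought of as shuffling the rows and cutting the shuffled frame in half. The sketch below is a hypothetical `half_split_sketch` helper built on NumPy's `default_rng`; it is not CausalKit's implementation, and `split_traffic` may assign rows differently (so the same seed will not reproduce the same groups).

```python
import numpy as np
import pandas as pd


def half_split_sketch(df, seed=None):
    """Hypothetical helper: shuffle row positions and cut the shuffled frame in half."""
    rng = np.random.default_rng(seed)
    positions = rng.permutation(len(df))  # a random ordering of 0..len(df)-1
    half = len(df) // 2                   # with an odd row count, treatment gets the extra row
    return df.iloc[positions[:half]], df.iloc[positions[half:]]


toy = pd.DataFrame({"user_id": range(1, 11)})
control, treatment = half_split_sketch(toy, seed=0)
print(len(control), len(treatment))  # 5 5
```

Seeding the generator is what makes the split reproducible: calling the helper again with `seed=0` returns exactly the same two groups.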


In [None]:
# Split the data into control and treatment groups
control_df, treatment_df = split_traffic(df, random_state=123)

print(f"Control group size: {len(control_df)}")
print(f"Treatment group size: {len(treatment_df)}")

# Verify that all users are accounted for
print(f"Total users after split: {len(control_df) + len(treatment_df)}")
print(f"Original total users: {len(df)}")


Let's visualize the distribution of key variables in both groups to check that they are balanced:
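Plots work well for eyeballing balance; a numeric summary makes it easier to compare many columns at once. The sketch below computes, for one categorical column, the largest absolute gap in category proportions between two groups. The helper name `max_proportion_gap` is hypothetical and not part of CausalKit; a category missing from one group counts as 0% there.

```python
import pandas as pd


def max_proportion_gap(control_df, treatment_df, column):
    """Hypothetical helper: largest |p_control - p_treatment| across the categories of `column`."""
    p_control = control_df[column].value_counts(normalize=True)
    p_treatment = treatment_df[column].value_counts(normalize=True)
    # Align on the union of categories; a category absent from one group counts as 0
    p_control, p_treatment = p_control.align(p_treatment, fill_value=0)
    return float((p_control - p_treatment).abs().max())


control = pd.DataFrame({"device": ["mobile", "mobile", "desktop", "tablet"]})
treatment = pd.DataFrame({"device": ["mobile", "desktop", "desktop", "desktop"]})
print(max_proportion_gap(control, treatment, "device"))  # 0.5
```

In this notebook it could be applied as `{col: max_proportion_gap(control_df, treatment_df, col) for col in ['country', 'device', 'age_group']}`. What counts as an acceptable gap is a judgment call that depends on sample size.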


In [None]:
def compare_distributions(control_df, treatment_df, column):
    """Helper function to compare distributions between control and treatment groups."""
    control_counts = control_df[column].value_counts(normalize=True)
    treatment_counts = treatment_df[column].value_counts(normalize=True)

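    # Combine into a DataFrame for easier plotting; categories missing from one group become 0.
    # rename_axis names the index so reset_index yields a `column` column on any pandas version.
    comparison_df = (
        pd.DataFrame({'Control': control_counts, 'Treatment': treatment_counts})
        .fillna(0)
        .rename_axis(column)
        .reset_index()
    )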

    # Reshape for plotly
    comparison_df_melted = pd.melt(
        comparison_df, 
        id_vars=[column], 
        value_vars=['Control', 'Treatment'],
        var_name='Group', 
        value_name='Proportion'
    )

    # Create a grouped bar chart with plotly
    fig = px.bar(
        comparison_df_melted, 
        x=column, 
        y='Proportion', 
        color='Group',
        barmode='group',
        title=f'{column} Distribution: Control vs Treatment',
        labels={'Proportion': 'Proportion', column: column},
        height=500
    )

    # Rotate x-axis labels if needed
    if column == 'country':
        fig.update_layout(xaxis_tickangle=-45)

    fig.show()

# Compare distributions for key variables
for column in ['country', 'device', 'age_group']:
    compare_distributions(control_df, treatment_df, column)


## Example 2: Uneven Split (80/20)

Sometimes you want to allocate more traffic to one group than the other, for example to expose only 20% of users to a new feature. Here `split_ratio=0.8` sends 80% of users to the control group and the remaining 20% to treatment.
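One way to realize an uneven split is to shuffle the rows and cut the shuffled order at `round(n * ratio)`, so control receives the first 80% and treatment the rest. The helper below, `ratio_split_sketch`, is hypothetical and not CausalKit's code; this notebook does not say whether `split_traffic` cuts at exact sizes or assigns each user independently.

```python
import numpy as np
import pandas as pd


def ratio_split_sketch(df, ratio=0.8, seed=None):
    """Hypothetical helper: first `ratio` of a shuffled order goes to control, the rest to treatment."""
    if not 0 < ratio < 1:
        raise ValueError("ratio must be strictly between 0 and 1")
    rng = np.random.default_rng(seed)
    positions = rng.permutation(len(df))
    cut = int(round(len(df) * ratio))  # exact control size, up to rounding
    return df.iloc[positions[:cut]], df.iloc[positions[cut:]]


users = pd.DataFrame({"user_id": range(1, 1001)})
control, treatment = ratio_split_sketch(users, ratio=0.8, seed=123)
print(len(control), len(treatment))  # 800 200
```

Cutting at a fixed position guarantees exact group sizes. Flipping a biased coin per user instead (e.g. `rng.random(n) < 0.2`) only hits 20% in expectation, a difference that matters less as traffic volume grows.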


In [None]:
# Split with 80% in control group and 20% in treatment group
control_df, treatment_df = split_traffic(df, split_ratio=0.8, random_state=123)

print(f"Control group size: {len(control_df)}")
print(f"Treatment group size: {len(treatment_df)}")
# Visualize the split with Plotly
fig = px.pie(
    values=[len(control_df), len(treatment_df)],
    names=['Control', 'Treatment'],  # textinfo below already shows the actual percentages
    title='80/20 Traffic Split',
    color_discrete_sequence=['#66b3ff', '#ff9999']
)

# Update layout for better appearance
fig.update_traces(textinfo='percent+label', hole=0.3)
fig.update_layout(height=500, width=700)

fig.show()


## Example 3: Multiple Variants (A/B/C Test)

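You can also split traffic into more than two groups to test several variants at once. Passing a list of ratios such as `[0.4, 0.3]` yields one group per listed ratio plus a final group that receives the remaining share (here 30%).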
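With a list of ratios, a shuffled order can be cut at cumulative boundaries. In the sketch below, `[0.4, 0.3]` produces cuts at 40% and 70% of the rows, and the final group takes the remaining 30%. `multi_split_sketch` is a hypothetical helper that only illustrates the idea; CausalKit may handle ratios and rounding differently.

```python
import numpy as np
import pandas as pd


def multi_split_sketch(df, ratios, seed=None):
    """Hypothetical helper: len(ratios) + 1 groups; the last group receives the remainder."""
    if any(r <= 0 for r in ratios) or sum(ratios) >= 1:
        raise ValueError("ratios must be positive and sum to less than 1")
    rng = np.random.default_rng(seed)
    positions = rng.permutation(len(df))
    # Cumulative cut points, e.g. [0.4, 0.3] on 1000 rows -> [400, 700]
    cuts = np.round(np.cumsum(ratios) * len(df)).astype(int)
    return [df.iloc[chunk] for chunk in np.split(positions, cuts)]


users = pd.DataFrame({"user_id": range(1, 1001)})
groups = multi_split_sketch(users, [0.4, 0.3], seed=123)
print([len(g) for g in groups])  # [400, 300, 300]
```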


In [None]:
# Split into three groups: control (40%), variant B (30%), variant C (30%)
control_df, variant_b_df, variant_c_df = split_traffic(
    df, split_ratio=[0.4, 0.3], random_state=123
)

print(f"Control group size: {len(control_df)}")
print(f"Variant B group size: {len(variant_b_df)}")
print(f"Variant C group size: {len(variant_c_df)}")

# Visualize the split with Plotly
fig = px.pie(
    values=[len(control_df), len(variant_b_df), len(variant_c_df)],
    names=['Control', 'Variant B', 'Variant C'],  # textinfo below already shows the actual percentages
    title='Multiple Variants Split',
    color_discrete_sequence=['#66b3ff', '#ff9999', '#99ff99']
)

# Update layout for better appearance
fig.update_traces(textinfo='percent+label', hole=0.3)
fig.update_layout(height=500, width=700)

fig.show()


## Example 4: Stratified Split by Country

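When a variable matters for your analysis (for example, because it is strongly related to the outcome), pure randomization can still leave it unevenly distributed across groups by chance, especially in small samples. Stratified splitting avoids this by splitting within each level of the variable, so every group receives the same distribution of it.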
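Stratified splitting can be sketched as applying the random split separately inside each level of the stratification column. The helper below, `stratified_split_sketch`, is hypothetical and uses `groupby`; CausalKit's `split_traffic(..., stratify_column=...)` may handle rounding and small strata differently.

```python
import numpy as np
import pandas as pd


def stratified_split_sketch(df, column, ratio=0.5, seed=None):
    """Hypothetical helper: shuffle and cut each stratum of `column` independently."""
    rng = np.random.default_rng(seed)
    control_parts, treatment_parts = [], []
    for _, stratum in df.groupby(column):
        positions = rng.permutation(len(stratum))
        cut = int(round(len(stratum) * ratio))  # per-stratum control size
        control_parts.append(stratum.iloc[positions[:cut]])
        treatment_parts.append(stratum.iloc[positions[cut:]])
    return pd.concat(control_parts), pd.concat(treatment_parts)


users = pd.DataFrame({
    "user_id": range(1, 13),
    "country": ["US"] * 8 + ["UK"] * 4,
})
control, treatment = stratified_split_sketch(users, "country", seed=42)
print(control["country"].value_counts().to_dict())    # {'US': 4, 'UK': 2}
print(treatment["country"].value_counts().to_dict())  # {'US': 4, 'UK': 2}
```

Because each stratum is rounded on its own, overall group sizes can deviate slightly from the target ratio when strata sizes do not divide evenly.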


In [None]:
# Stratified split by country
control_df, treatment_df = split_traffic(
    df, split_ratio=0.5, stratify_column='country', random_state=123
)

# Compare country distributions
compare_distributions(control_df, treatment_df, 'country')

# Let's also check if other variables remain roughly balanced
compare_distributions(control_df, treatment_df, 'device')
compare_distributions(control_df, treatment_df, 'age_group')


## Example 5: Stratified Split with Multiple Variables

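To stratify by several variables at once, one approach is to create a combined column whose values encode each combination (here, country and device). The number of combinations grows multiplicatively with each added variable, so some strata may contain only a few users.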
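The combined-column approach generalizes to any list of columns. The hypothetical helper below joins the chosen columns into one stratum key and reports strata with fewer than `min_size` users, which cannot be spread across every group. `combined_strata_sketch` and `min_size` are illustrative names, not CausalKit parameters.

```python
import pandas as pd


def combined_strata_sketch(df, columns, min_size=2):
    """Hypothetical helper: build a '_'-joined stratum key and list strata smaller than min_size."""
    key = df[columns].astype(str).apply("_".join, axis=1)  # one key per row
    sizes = key.value_counts()
    too_small = sorted(sizes[sizes < min_size].index)
    return key, too_small


users = pd.DataFrame({
    "country": ["US", "US", "UK", "UK", "DE"],
    "device": ["mobile", "mobile", "desktop", "desktop", "tablet"],
})
key, too_small = combined_strata_sketch(users, ["country", "device"])
print(key.tolist())  # ['US_mobile', 'US_mobile', 'UK_desktop', 'UK_desktop', 'DE_tablet']
print(too_small)     # ['DE_tablet']
```

On this notebook's data, `combined_strata_sketch(df, ['country', 'device'])` builds the same keys as the `strat_combined` column, and `min_size` would typically be at least the number of groups.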


In [None]:
# Create a combined column for country and device
df['strat_combined'] = df['country'] + '_' + df['device']

# Stratified split by the combined column
control_df, treatment_df = split_traffic(
    df, split_ratio=0.5, stratify_column='strat_combined', random_state=123
)

# Compare distributions
compare_distributions(control_df, treatment_df, 'country')
compare_distributions(control_df, treatment_df, 'device')

# Compare the share of each country-device combination side by side
combo_shares = pd.DataFrame({
    'Control': control_df['strat_combined'].value_counts(normalize=True),
    'Treatment': treatment_df['strat_combined'].value_counts(normalize=True),
}).fillna(0).sort_values('Control', ascending=False)

print("Top 5 country-device combinations (share of each group):")
print(combo_shares.head(5))
