# Basic Bathtub Scenario - Data Exploration

A quick look at the data generated by the `basic_bathtub` scenario.

In [None]:
import sys
from pathlib import Path

# The notebook is expected to run from notebooks/, so the parent of the
# working directory is treated as the project root and made importable.
project_root = Path.cwd().parent
project_root_entry = str(project_root)
already_importable = project_root_entry in sys.path
if not already_importable:
    # Prepend so the project root is searched before other path entries.
    sys.path.insert(0, project_root_entry)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.scenarios import (
    BasicBathtubScenario,
    generate_bathtub_data,
    summarise_cohort_costs,
)

## Generate Data

In [None]:
# Scenario overrides chosen to give more interesting data
# (the defaults have aggressive mortality and punishing economics).
scenario_overrides = {
    "scale1": 100.0,          # Milder infant mortality (default: 10)
    "scale2": 200.0,          # Slower wear-out (default: 100)
    "service_cost": 20.0,     # Cheaper service (default: 50)
    "revenue_per_time": 2.0,  # Higher revenue (default: 1)
}
scenario = BasicBathtubScenario(**scenario_overrides)

# Settings for the training cohort, generated under the baseline policy.
generation_settings = {
    "n_subjects": 500,
    "max_time": 150.0,
    "baseline_a": 15.0,  # First service around t=15-25
    "baseline_b": 10.0,  # Durability coefficient
    "scenario": scenario,
    "seed": 42,
}
df, results, costs = generate_bathtub_data(**generation_settings)

# Report how much data came back: one entry per subject in `results`,
# one row per event in `df`.
print(f"Generated {len(results)} subject journeys")
print(f"Total events: {len(df)}")

## Event DataFrame

In [None]:
# Preview the first 20 rows of the event table.
df.head(n=20)

In [None]:
# Number of rows recorded for each event type.
event_counts = df['event'].value_counts()
event_counts

## Subject Outcomes

In [None]:
def _summarise_subject(result):
    """Flatten one journey result into a row of the per-subject table."""
    service_events = [e for e in result.history if e.event_name == 'service']
    return {
        'subject_id': result.subject.id,
        'durability': result.subject.features['durability'],
        'lifetime': result.final_time,
        # Failure = terminal event that's not censoring.
        'failed': result.terminated and not result.censored,
        'truncated': result.truncated,
        'service_count': len(service_events),
    }


# One summary row per simulated subject.
subject_summary = [_summarise_subject(r) for r in results]

subject_df = pd.DataFrame(subject_summary)
subject_df.head(10)

In [None]:
# Outcome breakdown
# "Failure" means a terminal event that was not censoring -- the same rule used
# for the per-subject `failed` flag -- so censored terminal journeys are not
# reported as failures.
terminated = sum(1 for r in results if r.terminated)
failures = sum(1 for r in results if r.terminated and not r.censored)
truncated = sum(1 for r in results if r.truncated)
censored_terminal = terminated - failures

print(f"Terminated (failure): {failures} ({100*failures/len(results):.0f}%)")
if censored_terminal:
    # Only shown when some terminal journeys ended by censoring rather than failure.
    print(f"Terminated (censored): {censored_terminal} ({100*censored_terminal/len(results):.0f}%)")
print(f"Truncated (reached max_time): {truncated} ({100*truncated/len(results):.0f}%)")

## Durability vs Outcomes

In [None]:
# Three side-by-side panels: durability histogram, durability vs lifetime,
# and durability vs number of services.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
ax_hist, ax_lifetime, ax_services = axes

# Panel 1: how durability is distributed across subjects.
ax_hist.hist(subject_df['durability'], bins=30, edgecolor='black', alpha=0.7)
ax_hist.set(xlabel='Durability', ylabel='Count', title='Durability Distribution')

# Panel 2: lifetime against durability; red marks subjects flagged as failed,
# green marks everyone else.
colors = ['red' if is_failed else 'green' for is_failed in subject_df['failed']]
ax_lifetime.scatter(
    subject_df['durability'],
    subject_df['lifetime'],
    c=colors,
    alpha=0.5,
    s=20,
)
ax_lifetime.set(
    xlabel='Durability',
    ylabel='Lifetime',
    title='Durability vs Lifetime (red=failed, green=truncated)',
)

# Panel 3: number of services against durability.
ax_services.scatter(
    subject_df['durability'],
    subject_df['service_count'],
    alpha=0.5,
    s=20,
)
ax_services.set(
    xlabel='Durability',
    ylabel='Service Count',
    title='Durability vs Services',
)

plt.tight_layout()
plt.show()

## Cost Analysis

In [None]:
# Cohort-level cost statistics.
summary = summarise_cohort_costs(costs)

# Print one line of mean/std/min/max per metric, skipping any metric
# that is absent from the summary.
for key in ('lifetime', 'net_value', 'service_count', 'revenue'):
    if key not in summary:
        continue
    stats = summary[key]
    print(
        f"{key:15s}: mean={stats['mean']:8.1f}, std={stats['std']:7.1f}, "
        f"min={stats['min']:7.1f}, max={stats['max']:7.1f}"
    )

In [None]:
# Histogram of per-subject net value, with the cohort mean marked.
net_values = [entry['net_value'] for entry in costs]
mean_net_value = np.mean(net_values)

plt.figure(figsize=(10, 4))
plt.hist(net_values, bins=40, edgecolor='black', alpha=0.7)
plt.axvline(
    mean_net_value,
    color='red',
    linestyle='--',
    label=f'Mean: {mean_net_value:.1f}',
)
plt.xlabel('Net Value')
plt.ylabel('Count')
plt.title('Distribution of Net Value per Subject')
plt.legend()
plt.show()

## Example Journey

In [None]:
# Pick a subject with some services and a failure
candidates = subject_df[subject_df['failed'] & (subject_df['service_count'] > 2)]
if candidates.empty:
    # Without this check the lookup below fails with an opaque
    # "positional indexer is out-of-bounds" IndexError.
    raise ValueError(
        "No subject both failed and had more than 2 services; "
        "adjust the scenario parameters or seed to get an example journey."
    )

# Take the first matching subject and pull out all of its event rows.
example_id = candidates['subject_id'].iloc[0]
example_df = df[df['subject_id'] == example_id].copy()

print(f"Subject {example_id}:")
print(example_df[['event', 'time', 'durability']].to_string(index=False))

In [None]:
# Draw the example subject's journey as a timeline: a gray baseline with
# blue markers at service times and red markers at failure times.
plt.figure(figsize=(12, 3))

is_service = example_df['event'] == 'service'
is_failure = example_df['event'] == 'failure'
services = example_df.loc[is_service, 'time']
failures = example_df.loc[is_failure, 'time']
last_event_time = example_df['time'].max()
example_durability = example_df["durability"].iloc[0]

plt.hlines(1, 0, last_event_time, colors='gray', linewidth=2)

for t in services:
    # Only the marker at the first service time carries a legend label,
    # so the legend shows a single 'Service' entry.
    service_label = 'Service' if t == services.iloc[0] else ''
    plt.axvline(t, color='blue', linestyle='--', alpha=0.7)
    plt.scatter([t], [1], color='blue', s=100, zorder=5, label=service_label)

for t in failures:
    plt.axvline(t, color='red', linewidth=2)
    plt.scatter([t], [1], color='red', s=150, marker='X', zorder=5, label='Failure')

# Leave 5% headroom to the right of the last event.
plt.xlim(0, last_event_time * 1.05)
plt.ylim(0.5, 1.5)
plt.xlabel('Time')
plt.title(f'Journey for Subject {example_id} (durability={example_durability:.2f})')
plt.legend()
plt.yticks([])  # the y position carries no information on a timeline
plt.show()

## Baseline Policy Effect

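The baseline policy schedules services at `interval = a + b * durability`. Let's verify that this makes the first service time vary with each subject's durability, so the generated data covers a range of service intervals.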

In [None]:
# Compare the interval the baseline policy formula predicts with the time of
# each subject's first recorded service (used as an approximation of the
# actual interval).
service_rows = df[df['event'] == 'service']
first_service = service_rows.groupby('subject_id')['time'].min()

merged = subject_df.set_index('subject_id').join(first_service.rename('first_service_time'))
# Policy coefficients: baseline_a=15, baseline_b=10.
policy_intercept, policy_slope = 15.0, 10.0
merged['expected_interval'] = policy_intercept + policy_slope * merged['durability']

plt.figure(figsize=(8, 5))
plt.scatter(merged['expected_interval'], merged['first_service_time'], alpha=0.5)

# Reference diagonal where expected and actual would coincide.
max_val = max(merged['expected_interval'].max(), merged['first_service_time'].max()) + 5
diagonal = [10, max_val]
plt.plot(diagonal, diagonal, 'r--', label='Perfect match')

plt.xlabel('Expected Interval (15 + 10*durability)')
plt.ylabel('Actual First Service Time')
plt.title('Baseline Policy Creates Feature-Dependent Intervals')
plt.legend()
plt.show()