In [42]:
!pip install duckdb

Defaulting to user installation because normal site-packages is not writeable


In [43]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import duckdb

In [44]:
# Common prefix of the study's parquet files; the loader appends
# cage_id=<id>/date=<date>/<file_type>.parquet to it.
S3_BASE = "s3://jax-envision-public-data/study_1001/2025v3.3/tabular"

# One DuckDB connection shared by every query in this notebook.
con = duckdb.connect()
con.execute("SET s3_region='us-east-1'")

In [45]:
# Cage id -> dose group, replicate 1. Insertion order is the order in which
# the cages are listed (and therefore loaded) for an injection event.
DOSE_MAPPING_REP1 = {
    4917: '5 mg/kg',
    4918: 'Vehicle',
    4919: '25 mg/kg',
    4920: '25 mg/kg',
    4921: '5 mg/kg',
    4922: 'Vehicle',
    4923: 'Vehicle',
    4924: '25 mg/kg',
    4925: '5 mg/kg',
}

# Cage id -> dose group, replicate 2.
DOSE_MAPPING_REP2 = {
    4926: '25 mg/kg',
    4927: '5 mg/kg',
    4928: 'Vehicle',
    4929: 'Vehicle',
    4930: '25 mg/kg',
    4931: '5 mg/kg',
    4932: '5 mg/kg',
    4933: '25 mg/kg',
    4934: 'Vehicle',
}


In [46]:
# Injection events. The listed dates are deliberately generous so that the
# whole analysis window around each injection is covered.
# Each table row: (name, short_name, injection time UTC, dates to load, dose mapping)
INJECTION_EVENTS = [
    {
        'name': name,
        'short_name': short_name,
        'injection_time_utc': pd.Timestamp(injected_at),
        'dates_to_load': list(dates),
        # Fresh list per event, in the mapping's insertion order.
        'cages': list(mapping.keys()),
        'dose_mapping': mapping,
    }
    for name, short_name, injected_at, dates, mapping in [
        ('Replicate 1, Dose 1', 'Rep1_Dose1', '2025-01-14 11:00:00',
         ('2025-01-13', '2025-01-14', '2025-01-15'), DOSE_MAPPING_REP1),
        ('Replicate 1, Dose 2', 'Rep1_Dose2', '2025-01-17 22:00:00',
         ('2025-01-17', '2025-01-18', '2025-01-19'), DOSE_MAPPING_REP1),
        ('Replicate 2, Dose 1', 'Rep2_Dose1', '2025-01-28 22:00:00',
         ('2025-01-28', '2025-01-29', '2025-01-30'), DOSE_MAPPING_REP2),
        ('Replicate 2, Dose 2', 'Rep2_Dose2', '2025-01-31 11:00:00',
         ('2025-01-30', '2025-01-31', '2025-02-01'), DOSE_MAPPING_REP2),
    ]
]

In [47]:
# Analysis windows as (start, end) in minutes relative to injection.
# Rows are assigned with start <= minutes_from_injection < end.
# Note that 'decline_late' and 'post_6hr' share minutes 360-420.
TIME_WINDOWS = dict(
    baseline=(-180, -60),
    immediate=(0, 30),
    peak_early=(30, 90),
    peak_sustained=(90, 180),
    decline_early=(180, 300),
    decline_late=(300, 420),
    post_6hr=(360, 540),
    post_12hr=(720, 900),
    next_day=(1380, 1560),
)

In [48]:
# Left-to-right column order of the post-injection windows in the heatmaps
# (the baseline window is not plotted).
WINDOW_ORDER = [
    'immediate',
    'peak_early',
    'peak_sustained',
    'decline_early',
    'decline_late',
    'post_6hr',
    'post_12hr',
    'next_day',
]


In [49]:
def load_cage_data(cage_id, date_str, dose_mapping, file_type='animal_tsdb_mvp', resolution=60):
    """Load social/distance feature rows for one cage on one date.

    Reads ``{S3_BASE}/cage_id=<cage_id>/date=<date_str>/<file_type>.parquet``
    through the shared DuckDB connection ``con`` and keeps only rows whose
    ``name`` contains ``social`` or ``distance`` at the requested resolution.

    Parameters
    ----------
    cage_id : key of ``dose_mapping``; also used to build the file path.
    date_str : date partition to read, e.g. ``'2025-01-14'``.
    dose_mapping : dict mapping cage id -> dose group label.
    file_type : parquet file stem inside the partition folder.
    resolution : value matched against the file's ``resolution`` column
        (default 60, the value that used to be hard-coded in the query).

    Returns
    -------
    pandas.DataFrame
        The matching rows with ``cage_id`` and ``dose_group`` columns added.
        If anything fails (missing file, query error, cage id absent from
        ``dose_mapping``) the error is printed and an empty DataFrame is
        returned, so one bad partition does not abort a multi-cage load.
    """
    path = f"{S3_BASE}/cage_id={cage_id}/date={date_str}/{file_type}.parquet"

    # Social and distance features only
    query = f"""
    SELECT * FROM read_parquet('{path}')
    WHERE resolution = {resolution}
    AND (name LIKE '%social%' OR name LIKE '%distance%')
    """

    try:
        df = con.execute(query).fetchdf()
        df['cage_id'] = cage_id
        df['dose_group'] = dose_mapping[cage_id]
        return df
    except Exception as e:
        # Deliberately best-effort: report and let the caller skip this partition.
        print(f"    ✗ Cage {cage_id}, Date {date_str}: {e}")
        return pd.DataFrame()

In [50]:
def load_injection_event_data(event, resolution=60):
    """Load the rows for one injection event and align them to injection time.

    For every date in ``event['dates_to_load']`` and every cage in
    ``event['cages']`` this calls ``load_cage_data`` (with its default
    arguments), concatenates the non-empty results, keeps rows whose
    ``resolution`` column equals ``resolution``, and restricts them to the
    span covered by ``TIME_WINDOWS``.

    Parameters
    ----------
    event : dict with keys ``name``, ``short_name``, ``injection_time_utc``,
        ``dates_to_load``, ``cages`` and ``dose_mapping``.
    resolution : value the ``resolution`` column must equal (post-filter on
        whatever ``load_cage_data`` returned).

    Returns
    -------
    pandas.DataFrame
        Loaded rows plus ``minutes_from_injection`` (float minutes relative to
        ``event['injection_time_utc']``) and ``event`` (the short name).
        An empty, column-less DataFrame if nothing could be loaded.
    """
    print(f"\n{'='*60}")
    print(f"Loading: {event['name']}")
    print(f"  Injection time (UTC): {event['injection_time_utc']}")
    print(f"  Dates to load: {event['dates_to_load']}")
    print(f"{'='*60}")

    dfs = []

    for date_str in event['dates_to_load']:
        print(f"\n  Loading date: {date_str}")
        for cage_id in event['cages']:
            df = load_cage_data(cage_id, date_str, event['dose_mapping'])
            if not df.empty:
                print(f"    ✓ Cage {cage_id} ({event['dose_mapping'][cage_id]}): {len(df):,} rows")
                dfs.append(df)

    if not dfs:
        print("  ⚠ No data loaded!")
        return pd.DataFrame()

    df_all = pd.concat(dfs, ignore_index=True)
    print(f"\n  Total rows loaded: {len(df_all):,}")

    # Filter to desired resolution
    df_filtered = df_all[df_all['resolution'] == resolution].copy()
    if df_filtered.empty:
        # Make a resolution mismatch visible instead of silently returning nothing.
        print(f"  ⚠ No rows at resolution={resolution} among the loaded data")

    # Convert time to datetime
    df_filtered['time'] = pd.to_datetime(df_filtered['time'])

    # A tz-aware time column cannot be subtracted from a naive timestamp.
    # The event key and the banner above label the injection time as UTC,
    # so a naive injection time is localized to UTC in that case.
    injection_time = pd.Timestamp(event['injection_time_utc'])
    if df_filtered['time'].dt.tz is not None and injection_time.tzinfo is None:
        injection_time = injection_time.tz_localize('UTC')

    # Compute minutes from injection
    df_filtered['minutes_from_injection'] = (
        df_filtered['time'] - injection_time
    ).dt.total_seconds() / 60

    # Keep only rows inside the span covered by TIME_WINDOWS (earliest start
    # to latest end), so the loader and the window definitions cannot drift apart.
    window_start = min(start for start, _ in TIME_WINDOWS.values())
    window_end = max(end for _, end in TIME_WINDOWS.values())
    in_window = (
        (df_filtered['minutes_from_injection'] >= window_start) &
        (df_filtered['minutes_from_injection'] <= window_end)
    )
    # .copy() so the column assignment below acts on an independent frame
    # rather than on a slice (avoids SettingWithCopyWarning).
    df_filtered = df_filtered[in_window].copy()

    df_filtered['event'] = event['short_name']

    print(f"  Filtered data: {len(df_filtered):,} rows")
    print(f"  Features: {sorted(df_filtered['name'].unique())}")

    return df_filtered


In [51]:
def load_all_events(resolution=60):
    """Load every event in INJECTION_EVENTS and stack the non-empty results.

    Returns one DataFrame with all events concatenated (index reset), or an
    empty DataFrame when no event produced any rows. A summary banner with
    row, feature and animal counts is printed on success.
    """
    loaded = [
        load_injection_event_data(evt, resolution=resolution)
        for evt in INJECTION_EVENTS
    ]
    frames = [frame for frame in loaded if not frame.empty]

    if len(frames) == 0:
        return pd.DataFrame()

    combined = pd.concat(frames, ignore_index=True)

    banner = '=' * 60
    print(f"\n{banner}")
    print(f"TOTAL: {len(combined):,} rows, {combined['name'].nunique()} features, {combined['animal_id'].nunique()} animals")
    print(f"{banner}")

    return combined

## Percent Change from Baseline Calculation

### Goal
For each animal, compare their behavior *after* injection to their own *baseline* behavior before injection, expressed as a percent change.

### Step-by-Step Process

**Step 1: Filter to one feature**

We isolate just one behavioral metric at a time (e.g., `animal_bouts.locomotion`). This gets repeated for each feature in the dataset.

**Step 2: Define baseline window**

The baseline window runs from 180 to 60 minutes before injection (−180 to −60 min relative to injection time), i.e. 2 hours of pre-injection behavior. We skip the final 60 minutes before injection (−60 to 0 min) because handling the animal for injection might affect behavior.

**Step 3: Loop through each animal × event combination**

Each animal gets their own baseline. This is critical—we're comparing each mouse to itself, not to group averages. This controls for individual differences in baseline activity levels.

**Step 4: Calculate that animal's baseline mean**

For this specific animal during this specific injection event, what was their average value for the feature during the 2-hour baseline window?

**Step 5: Skip animals with zero or missing baseline**

If an animal has no data or zero baseline activity, we can't compute percent change (division by zero), so we skip them.

**Step 6: Calculate percent change for each post-injection window**

For each time window (immediate, peak_early, peak_sustained, etc.), we calculate how much this animal's behavior changed compared to their own baseline using the formula:

$$\text{Percent Change} = \frac{\text{window mean} - \text{baseline mean}}{\text{baseline mean}} \times 100$$

**Example:** If baseline locomotion was 0.10 and post-injection was 0.25:
$$\frac{0.25 - 0.10}{0.10} \times 100 = 150\% \text{ increase}$$

**Step 7: Store and aggregate results**

Results are stored per animal, then averaged across all animals within each dose group for visualization in the heatmaps.

### Why This Approach Matters

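By normalizing each animal to their own baseline, we control for individual differences. A naturally hyperactive mouse and a naturally sedentary mouse can both show "50% increase" even though their absolute activity levels differ. This puts animals and dose groups on a common relative scale, which makes them easier to compare. One caveat: when an animal's baseline is very close to zero, even a small absolute change produces a very large percent change, so extreme values should be checked against the stored `baseline_mean`.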

In [52]:
def compute_percent_change_single_feature(df, feature_name, time_windows=None):
    """Compute per-animal percent change from baseline for one feature.

    For every (animal_id, dose_group, event) group the mean of ``value``
    inside the ``'baseline'`` window is the reference; every other window's
    mean is expressed as ``(window_mean - baseline_mean) / baseline_mean * 100``.
    Windows are half-open: ``start <= minutes_from_injection < end``.

    Parameters
    ----------
    df : DataFrame with columns ``name``, ``animal_id``, ``dose_group``,
        ``event``, ``minutes_from_injection`` and ``value``.
    feature_name : value of ``name`` to analyse.
    time_windows : optional dict ``{window_name: (start, end)}`` that must
        contain a ``'baseline'`` entry. Defaults to the module-level
        ``TIME_WINDOWS``.

    Returns
    -------
    pandas.DataFrame
        One row per animal x event x window with columns ``animal_id``,
        ``dose_group``, ``event``, ``feature``, ``window``, ``baseline_mean``,
        ``window_mean`` and ``pct_change``. Empty (no columns) when the
        feature is absent or no group yields a usable baseline.

    Groups are skipped when the baseline has no rows, is zero, or is NaN
    (e.g. every baseline value is missing); a window with no rows is skipped.
    """
    if time_windows is None:
        time_windows = TIME_WINDOWS

    df_feat = df[df['name'] == feature_name]

    if df_feat.empty:
        return pd.DataFrame()

    baseline_start, baseline_end = time_windows['baseline']
    results = []

    # Group by animal, dose, and event
    for (animal_id, dose_group, event), animal_df in df_feat.groupby(['animal_id', 'dose_group', 'event']):
        minutes = animal_df['minutes_from_injection']
        values = animal_df['value']

        # Calculate baseline mean
        baseline_data = values[(minutes >= baseline_start) & (minutes < baseline_end)]
        if baseline_data.empty:
            continue

        baseline_mean = baseline_data.mean()
        # An all-NaN baseline has a NaN mean, which is not == 0; without the
        # isna check it would leak NaN percent changes into the results.
        if pd.isna(baseline_mean) or baseline_mean == 0:
            continue

        # Calculate percent change for each post-injection window
        for window_name, (win_start, win_end) in time_windows.items():
            if window_name == 'baseline':
                continue

            window_data = values[(minutes >= win_start) & (minutes < win_end)]
            if window_data.empty:
                continue

            window_mean = window_data.mean()
            pct_change = ((window_mean - baseline_mean) / baseline_mean) * 100

            results.append({
                'animal_id': animal_id,
                'dose_group': dose_group,
                'event': event,
                'feature': feature_name,
                'window': window_name,
                'baseline_mean': baseline_mean,
                'window_mean': window_mean,
                'pct_change': pct_change
            })

    return pd.DataFrame(results)

In [53]:
def compute_percent_change_all_features(df):
    """Run the single-feature percent-change computation for every feature.

    Iterates over the distinct values of ``df['name']`` in order of first
    appearance, printing progress, and concatenates the non-empty per-feature
    results. Returns an empty DataFrame when no feature produced any rows.
    """
    feature_names = df['name'].unique()
    n_features = len(feature_names)
    per_feature = []

    print(f"\nComputing percent change for {n_features} features...")

    for position, feature in enumerate(feature_names, start=1):
        print(f"  [{position}/{n_features}] {feature}")
        feature_result = compute_percent_change_single_feature(df, feature)
        if feature_result.empty:
            continue
        per_feature.append(feature_result)

    if len(per_feature) == 0:
        return pd.DataFrame()

    results_df = pd.concat(per_feature, ignore_index=True)
    print(f"\nResults: {len(results_df):,} rows")

    return results_df

## Visualization: Creating Heatmaps

### Goal
Visualize the percent change from baseline for all features across all time windows, with one heatmap per dose group (Vehicle, 5 mg/kg, 25 mg/kg).

### Step-by-Step Process

**Step 1: Filter to one dose group**
```python
dose_df = results_df[results_df['dose_group'] == dose_group]
```

We create separate heatmaps for each dose group so we can visually compare how different doses affect behavior.

**Step 2: Average across animals and events**
```python
pivot_df = dose_df.groupby(['feature', 'window'])['pct_change'].mean().reset_index()
```

For each feature × time window combination, we calculate the mean percent change across all animals and injection events in that dose group (each animal × event contributes one value). This gives us a single value per cell in the heatmap.

**Step 3: Pivot to matrix format**
```python
heatmap_data = pivot_df.pivot(index='feature', columns='window', values='pct_change')
```

We reshape the data so that:
- **Rows** = behavioral features (locomotion, drinking, inactive, etc.)
- **Columns** = time windows (immediate, peak_early, peak_sustained, etc.)
- **Cell values** = mean percent change from baseline

**Step 4: Reorder columns chronologically**
```python
heatmap_data = heatmap_data[[w for w in WINDOW_ORDER if w in heatmap_data.columns]]
```

Time windows are ordered from earliest (immediate: 0-30 min) to latest (next_day: 1380-1560 min) so the temporal progression reads left-to-right.

**Step 5: Clean feature names for readability**
```python
heatmap_data.index = (heatmap_data.index
                      .str.replace('animal_bouts.', '')
                      .str.replace('distance_travelled.animal.cm_s', 'distance')
                      ...)
```

We remove verbose prefixes like `animal_bouts.` so the y-axis labels are cleaner (e.g., "locomotion" instead of "animal_bouts.locomotion").

**Step 6: Create the heatmap with seaborn**
```python
sns.heatmap(heatmap_data, ax=ax, cmap='RdYlGn_r', center=0,
            vmin=vmin, vmax=vmax, annot=True, fmt='.1f',
            cbar_kws={'label': '% Change from Baseline'})
```

Key parameters:
- `cmap='RdYlGn_r'`: Red-Yellow-Green colormap (reversed), so red = increase, green = decrease
- `center=0`: Zero percent change is the neutral color (yellow/white)
- `vmin, vmax`: Color scale bounds (wider for 25 mg/kg due to larger effects)
- `annot=True, fmt='.1f'`: Display numeric values in each cell with 1 decimal place

**Step 7: Create all three heatmaps side-by-side**
```python
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
doses = ['Vehicle', '5 mg/kg', '25 mg/kg']
vmaxes = [150, 150, 400]  # Wider scale for 25 mg/kg
```

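We use different color scales because 25 mg/kg produces much larger effects (up to 400-500% change) than Vehicle or 5 mg/kg. Using the same scale would wash out the smaller effects in the lower dose groups. Because the scales differ, the same color does not mean the same percent change in every panel: compare the annotated numbers, not the colors, when reading across dose groups.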

### How to Interpret the Heatmaps

| Color | Meaning |
|-------|---------|
| Dark Red | Large increase from baseline |
| Light Red/Orange | Moderate increase |
| Yellow/White | No change from baseline |
| Light Green | Moderate decrease |
| Dark Green | Large decrease from baseline |

### What to Look For

1. **Dose-response**: Does the effect get stronger from Vehicle → 5 mg/kg → 25 mg/kg?
2. **Time-dependence**: Does the effect peak at a certain window and then decline?
3. **Feature specificity**: Which features show the strongest response?
4. **Biphasic patterns**: Does the direction of effect reverse over time (e.g., early suppression followed by late activation)?

In [54]:
def create_heatmap(results_df, dose_group, ax=None, vmin=-100, vmax=100):
    """Draw a feature x time-window heatmap of mean percent change for one dose group.

    Parameters
    ----------
    results_df : DataFrame with columns ``dose_group``, ``feature``,
        ``window`` and ``pct_change``.
    dose_group : dose label to plot.
    ax : matplotlib Axes to draw on; a new 12x8 figure is created when None.
    vmin, vmax : colour-scale limits passed to seaborn.

    Returns
    -------
    pandas.DataFrame
        The plotted matrix (rows: cleaned feature names, columns: windows in
        ``WINDOW_ORDER``). Empty when there is nothing to plot for this dose
        group; the axes then only show the title and a "No data" note.
    """
    title = f'{dose_group} - Temporal Response Profile'

    # An empty results frame may have no columns at all, so do not index it.
    if results_df.empty:
        dose_df = results_df
    else:
        dose_df = results_df[results_df['dose_group'] == dose_group]

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 8))

    if dose_df.empty:
        # Plotting an empty matrix would fail; show a placeholder instead.
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        return pd.DataFrame()

    # Average across animals and events
    pivot_df = dose_df.groupby(['feature', 'window'])['pct_change'].mean().reset_index()
    heatmap_data = pivot_df.pivot(index='feature', columns='window', values='pct_change')

    # Reorder columns
    heatmap_data = heatmap_data[[w for w in WINDOW_ORDER if w in heatmap_data.columns]]

    # Clean feature names for display. regex=False makes every pattern a
    # literal substring: under the old regex default of str.replace the '.'
    # characters would match any character. Order matters: the full
    # 'distance_travelled...' name is rewritten before the shorter
    # '.animal.cm' suffix is stripped.
    name_replacements = [
        ('animal_bouts.', ''),
        ('distance_travelled.animal.cm_s', 'distance'),
        ('animal.respiration_rate_lucas_kanade_psd', 'respiration'),
        ('.animal.cm', ''),
        ('social.', 'social_'),
    ]
    labels = heatmap_data.index
    for old, new in name_replacements:
        labels = labels.str.replace(old, new, regex=False)
    heatmap_data.index = labels

    sns.heatmap(heatmap_data, ax=ax, cmap='RdYlGn_r', center=0,
                vmin=vmin, vmax=vmax, annot=True, fmt='.1f',
                cbar_kws={'label': '% Change from Baseline'})

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Time Window After Injection')
    ax.set_ylabel('Behavioral Feature')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    return heatmap_data

In [55]:
def create_all_heatmaps(results_df, save_path=None):
    """Draw the Vehicle, 5 mg/kg and 25 mg/kg heatmaps side by side.

    Each panel uses a symmetric colour range (-limit, +limit). When
    ``save_path`` is given the figure is also written there at 150 dpi.
    The figure is shown with ``plt.show()`` and then returned.
    """
    # Dose group -> symmetric colour limit (wider scale for 25 mg/kg).
    panel_limits = {'Vehicle': 150, '5 mg/kg': 150, '25 mg/kg': 400}

    fig, axes = plt.subplots(1, 3, figsize=(20, 10))

    for panel_ax, (dose, limit) in zip(axes, panel_limits.items()):
        create_heatmap(results_df, dose, ax=panel_ax, vmin=-limit, vmax=limit)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"\nHeatmaps saved to {save_path}")

    plt.show()
    return fig

## Analysis

In [56]:
# Analyse a single injection event: the first entry of INJECTION_EVENTS
# (short_name 'Rep1_Dose1').
event = INJECTION_EVENTS[0]  # Rep1_Dose1

In [None]:
# Fetch the selected event's rows and align them to injection time
# (adds the `minutes_from_injection` and `event` columns).
df = load_injection_event_data(event, resolution=60)



Loading: Replicate 1, Dose 1
  Injection time (UTC): 2025-01-14 11:00:00
  Dates to load: ['2025-01-13', '2025-01-14', '2025-01-15']

  Loading date: 2025-01-13


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4917 (5 mg/kg): 39,675 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4918 (Vehicle): 39,083 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4919 (25 mg/kg): 39,508 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4920 (25 mg/kg): 39,277 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4921 (5 mg/kg): 39,406 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4922 (Vehicle): 38,395 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

    ✓ Cage 4923 (Vehicle): 39,182 rows


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [None]:
# Sanity check of the loaded data: size, rows per feature, groups and events.
print(f"Shape: {df.shape}")
print("\nFeatures available:")
print(df['name'].value_counts())
print(f"\nDose groups: {df['dose_group'].unique()}")
print(f"Events: {df['event'].unique()}")

In [None]:
# Preview the first rows of the loaded data.
df.head()

In [None]:
# Per-animal percent change from baseline for every feature present in `df`.
results = compute_percent_change_all_features(df)


In [None]:
# Size of the result table, followed by a preview of its first ten rows.
print(f"Results shape: {results.shape}")
print("\nSample of results:")
results.head(n=10)

In [None]:
# create_all_heatmaps() already calls plt.show() and also returns the Figure;
# the trailing semicolon stops the notebook from echoing the returned Figure
# as a second copy of the same plot.
create_all_heatmaps(results);


In [None]:
# Persist the per-animal results next to the notebook (index column omitted).
output_csv = 'percent_change_results.csv'
results.to_csv(output_csv, index=False)
print(f"Results saved to {output_csv}")