In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from osdr_validation.model_inference import cell_sampler, tissue_regression, parameter_checker
from osdr_validation.visualization import plot_inferred_portrait

## Load Simulation Data

In [None]:
# Load the post-proliferation tissue snapshots produced by notebook 2.
# The row count is reported as the number of cells; `Time_Step` records the
# simulation step each row belongs to.
post_df = pd.read_csv('../data/simulated_tissues_post_alt.csv')

n_cells_loaded = len(post_df)
print(f"Loaded {n_cells_loaded} cells")

available_steps = sorted(post_df['Time_Step'].unique())
print(f"Time steps available: {available_steps}")

## Perform Model Inference

Sample cells at t=1000 (the final time point) and fit logistic regression models for each of four sample sizes (1k, 5k, 10k, and 25k cells).

In [None]:
# Fit the inference models on cells sampled at t=1000 with a fixed seed (0).
# tissue_regression returns four sequences; as indexed below, element i of each
# holds the fit for the i-th sample size, with [0] the intercept and [1] the
# coefficient.
pplus_f, pplus_m, pminus_f, pminus_m = tissue_regression(post_df, t=1000, seed=0)

# Package parameters for plotting (plot_inferred_portrait takes this dict)
params = {"pplus_f": pplus_f, "pplus_m": pplus_m, "pminus_f": pminus_f, "pminus_m": pminus_m}

# Sample sizes in thousands of cells, in the order the fits are indexed.
# A single list drives both the header and the loop so they cannot disagree.
# NOTE(review): ordering assumed from the original labels — confirm against tissue_regression.
sample_sizes_k = [1, 5, 10, 25]

print("\nInferred parameters:")
print("Sample sizes: " + ", ".join(f"{k}k" for k in sample_sizes_k) + " cells")
# Only the pplus_* fits are printed here; pminus_* are kept for plotting.
for i, sample_size in enumerate(sample_sizes_k):
    print(f"\n{sample_size}k sample:")
    print(f"  F: intercept={params['pplus_f'][i][0]:.4f}, coef={params['pplus_f'][i][1]:.6f}")
    print(f"  M: intercept={params['pplus_m'][i][0]:.4f}, coef={params['pplus_m'][i][1]:.6f}")

## Phase Portraits of Inferred Models

Plot phase portraits for each sample size (1k, 5k, 10k, 25k cells).

In [None]:
# Draw one phase portrait per sample size. t_id selects which fit in `params`
# is plotted: index i corresponds to the i-th entry of [1, 5, 10, 25] (k cells).
banner = '=' * 50
for t_id, sample_size in enumerate([1, 5, 10, 25]):
    print(f"\n{banner}")
    print(f"Phase Portrait: {sample_size}k cell sample")
    print(banner)
    plot_inferred_portrait(params, t_id=t_id)

## Inference Across Multiple Time Steps

Examine how inference quality changes when using data from different time points (earlier vs. closer to steady state). For each available time step, only the phase portrait of the 10k-cell sample is shown.

In [None]:
# Repeat the inference at several time steps to see how the distance from
# steady state affects the fit. Only the 10k-cell fit is plotted per step.
time_steps = [100, 200, 500, 700, 1000]
t_id_10k = 2  # index of the 10k sample in the [1k, 5k, 10k, 25k] fit order

# Compute the set of steps present in the data once, not on every iteration.
steps_in_data = set(post_df["Time_Step"].unique())

print("Inference across time steps:")
print("="*60)

for t in time_steps:
    if t not in steps_in_data:
        # Report skipped steps instead of dropping them silently, so a missing
        # snapshot in the input data is visible in the output.
        print(f"\nSkipping t={t}: not present in post_df['Time_Step']")
        continue

    print(f"\n{'='*60}")
    print(f"Time step: t={t}")
    print(f"{'='*60}")

    pplus_f_t, pplus_m_t, pminus_f_t, pminus_m_t = tissue_regression(post_df, t=t, seed=0)
    params_t = {"pplus_f": pplus_f_t, "pplus_m": pplus_m_t,
                "pminus_f": pminus_f_t, "pminus_m": pminus_m_t}

    # Show phase portrait for 10k sample only
    print(f"\nPhase portrait for 10k sample at t={t}:")
    plot_inferred_portrait(params_t, t_id=t_id_10k)

## Robustness Analysis: Multiple RNG Seeds

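Test inference stability by repeating the t=1000 inference with 10 different random seeds. For each seed, check the sign of the inferred F and M coefficients at every sample size.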

In [None]:
# Re-run the t=1000 inference under 10 RNG seeds to gauge how much the fitted
# coefficients depend on which cells are sampled. For each sample size, a
# negative pplus coefficient (for F and for M) is marked ✓, otherwise ✗.
seed_range = list(range(10))

print("Testing inference with multiple seeds:")
print(f"Seeds: {seed_range}")
print("="*60)

for seed in seed_range:
    print(f"\n{'='*60}")
    print(f"Seed: {seed}")
    print(f"{'='*60}")

    # Only the regression itself is guarded. A fit that fails for one seed is
    # reported and the sweep continues. Bugs in the reporting code below are
    # not masked as inference failures.
    try:
        pplus_f_s, pplus_m_s, pminus_f_s, pminus_m_s = tissue_regression(
            post_df, t=1000, seed=seed
        )
    except Exception as e:
        print(f"Failed with error: {type(e).__name__}: {e}")
        continue

    # Check for negative coefficients (correct dynamics)
    for i, sample_size in enumerate([1, 5, 10, 25]):
        f_coef = pplus_f_s[i][1]
        m_coef = pplus_m_s[i][1]
        f_sign = "✓" if f_coef < 0 else "✗"
        m_sign = "✓" if m_coef < 0 else "✗"
        print(f"{sample_size}k: F coef={f_coef:.6f} {f_sign}, "
              f"M coef={m_coef:.6f} {m_sign}")

## Quantitative Assessment: Parameter Sign Frequency

Check how often the inferred models have correct (negative) coefficients across 20 random seeds.

In [None]:
# Tally, over 20 seeds at t=1000, how often each sample size yields correctly
# signed coefficients. parameter_checker returns the number of seeds tested and
# one value per sample size. These values are presumably fractions, since they
# are scaled by 100 and printed as percentages.
test_seeds = [*range(20)]
total, freq_good = parameter_checker(post_df, test_seeds, t=1000)

sizes_k = [1, 5, 10, 25]
rule = "=" * 60
print(f"\nParameter Sign Analysis ({total} seeds tested):")
print(rule)
for idx, freq in enumerate(freq_good):
    pct = freq * 100
    print(f"{sizes_k[idx]:>2}k sample: {pct:.0f}% with correct (negative) coefficients")

interpretation_lines = (
    "\nInterpretation:",
    "  Negative coefficients → correct density-division relationship",
    "  Positive coefficients → reversed dynamics (incorrect inference)",
)
for line in interpretation_lines:
    print(line)

## Overlapping Phase Portraits

For the 10k-cell sample, overlay the fixed points inferred from 10 different seeds on a single streamline plot.

In [None]:
# Overlay the fixed points inferred from several seeds on one phase portrait.
# t_id=2 selects the 10k-cell sample in the [1k, 5k, 10k, 25k] fit order.
# NOTE(review): unlike the other calls, this passes the raw `post_df` rather
# than a params dict. Presumably plot_inferred_portrait runs the regression
# itself for each seed in `srange`; confirm against its signature.
seed_range_plot = list(range(10))

# Derive the count from the seed list so the header cannot drift from it.
print(f"Overlapping phase portraits: {len(seed_range_plot)} seeds")
print("="*60)
plot_inferred_portrait(post_df, t_id=2, srange=seed_range_plot)

## Summary

**Key Findings:**

1. **Sample Size Matters**: 10k+ cell samples show substantially better inference quality
   - Smaller samples (1k) often fail or produce incorrect parameter signs
   - 10k and 25k samples converge more reliably to correct dynamics

2. **Time Point Effects**: 
   - Data closer to steady state (t=1000) works well for larger samples
   - Earlier time points can produce more variable results

3. **Inference Stability**:
   - Multiple RNG seeds show consistency for 10k+ samples
   - Smaller samples are highly sensitive to sampling variability

4. **Phase Portrait Quality**:
   - Good inferences show streamlines converging toward the central fixed point (16, 16)
   - Poor inferences show divergent or reversed flow patterns

**Validation Success**: The OSDR method successfully recovers the ground truth dynamics when:
- Sample size is adequate (≥10k cells)
- Data includes sufficient variation around steady state
- Multiple sampling replicates are used to assess reliability

**Next step:** Visualize model fits and logistic regression curves in notebook 4.