# MorphServe Extension — Interactive Notebook

Demonstrates that **contiguous middle-block swapping** matches or beats scattered LIS-based selection for runtime FP16↔INT4 layer morphing.

This notebook imports from the `morphserve/` package, which it expects to find in the parent directory (the first code cell adds `..` to `sys.path`). Install the dependencies before running it:
```bash
pip install -r requirements.txt
```

In [None]:
import sys
sys.path.insert(0, '..')  # lets `import morphserve` resolve from the parent directory

import torch
import numpy as np

# Seed both RNGs so repeated runs of the notebook draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

# Report the device. The name/memory queries raise when no CUDA device is
# present, so they are only issued once the availability check succeeds.
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("No CUDA device detected; later cells that allocate on 'cuda' will not run.")

## 1. Load Models and Calibration Data

In [None]:
from morphserve.models import (
    load_calibration_data,
    load_fp16_model,
    load_int4_model,
)

# The FP16 loader also hands back the tokenizer and the layer count that
# every later cell relies on; the INT4 model is loaded separately.
model_fp16, tokenizer, num_layers = load_fp16_model()
model_int4 = load_int4_model()

# Calibration inputs; the first entry is the default input for later cells.
inputs_list = load_calibration_data(tokenizer, max_length=512, n_texts=20)
inputs = inputs_list[0]

print(f"Sequence length: {inputs['input_ids'].shape[1]} tokens")

## 2. Layer Sensitivity Analysis

In [None]:
from morphserve.sensitivity import (
    compute_lis_scores,
    compute_lrs_scores,
    compute_lts_scores,
    compute_mds_scores,
)

print("Computing LTS...")
# layer_outputs is kept because the LRS computation in the next cell takes it as an argument.
lts_scores, layer_outputs = compute_lts_scores(model_fp16, inputs)
for layer_idx, layer_score in enumerate(lts_scores):
    print(f"  Layer {layer_idx:2d}: LTS = {layer_score:.6f}")
print("Higher LTS = layer transforms less = safer to swap")

In [None]:
print("Computing LRS...")
# Reuses the per-layer outputs captured by the LTS cell.
lrs_scores = compute_lrs_scores(model_fp16, model_int4, inputs, layer_outputs)
for layer_idx, layer_score in enumerate(lrs_scores):
    print(f"  Layer {layer_idx:2d}: LRS = {layer_score:.6f}")
print("Higher LRS = quantization changes layer less = safer to swap")

In [None]:
print("Computing MDS (this takes a minute)...")
# One MDS value per layer, printed in layer order.
mds_scores = compute_mds_scores(model_fp16, model_int4, inputs)
for layer_idx, layer_score in enumerate(mds_scores):
    print(f"  Layer {layer_idx:2d}: MDS = {layer_score:.6f}")
print("Higher MDS = swapping barely affects output = safer to swap")

In [None]:
# Combine the three per-layer metrics into a single LIS value per layer.
lis_scores = compute_lis_scores(lts_scores, lrs_scores, mds_scores)

# Layer indices ordered by descending LIS; this is the proposed swap order.
lis_ranking = sorted(
    range(num_layers),
    key=lambda layer: lis_scores[layer],
    reverse=True,
)
print("LIS swapping order (safest first):")
print(lis_ranking)

# One table row per layer, in ranking order, showing LIS next to its three inputs.
print(f"\n{'Rank':<6}{'Layer':<8}{'LIS':<10}{'LTS':<10}{'LRS':<10}{'MDS':<10}")
print("-" * 54)
metric_columns = (lis_scores, lts_scores, lrs_scores, mds_scores)
for rank, idx in enumerate(lis_ranking):
    score_cells = "".join(f"{column[idx]:<10.6f}" for column in metric_columns)
    print(f"{rank:<6}{idx:<8}{score_cells}")

In [None]:
from IPython.display import Image
from morphserve.visualization import plot_sensitivity

# The figure is written to disk by the plotting helper and then loaded back for inline display.
sensitivity_png = '../figures/layer_sensitivity.png'
plot_sensitivity(
    lts_scores, lrs_scores, mds_scores, lis_scores,
    save_path=sensitivity_png,
    title_suffix=' — TinyLlama 1.1B',
)

# Must stay the last expression of the cell so the notebook renders it.
Image(sensitivity_png)

## 3. Swapping Strategy Comparison

In [None]:
from morphserve.strategies import (
    compute_perplexity,
    find_best_contiguous_block,
    greedy_lis_order,
    test_ordering,
)

# Reference point: perplexity of the unmodified FP16 model on the default input.
ppl_fp16 = compute_perplexity(model_fp16, inputs['input_ids'])
print(f"FP16 baseline perplexity: {ppl_fp16:.4f}")

# Two naive orderings to compare against the LIS ranking.
front_to_back = list(range(num_layers))
back_to_front = front_to_back[::-1]
swap_counts = [1, 2, 4, 8, 11, 16, 22]

print(f"\n{'N swapped':<12}{'LIS':<12}{'Front-Back':<12}{'Back-Front':<12}")
print("-" * 48)
for n in swap_counts:
    # Evaluated lazily in this order: LIS, front-to-back, back-to-front.
    ppl_lis, ppl_ftb, ppl_btf = (
        test_ordering(order, n, model_fp16, model_int4, inputs['input_ids'])
        for order in (lis_ranking, front_to_back, back_to_front)
    )
    print(f"{n:<12}{ppl_lis:<12.4f}{ppl_ftb:<12.4f}{ppl_btf:<12.4f}")

In [None]:
print("Computing greedy LIS ordering (this takes a while)...\n")

# Yields the greedy layer order together with a list of perplexities;
# both are consumed by the plotting cell that follows.
greedy_order, greedy_ppls = greedy_lis_order(model_fp16, model_int4, inputs, lts_scores, lrs_scores)

print(f"\nGreedy order: {greedy_order}")

In [None]:
from morphserve.visualization import plot_strategy_comparison

# Full perplexity curves: one evaluation per swap count from 1 up to num_layers.
swap_counts_full = list(range(1, num_layers + 1))


def _perplexity_curve(ordering):
    """Return test_ordering's result for `ordering` at every count in swap_counts_full."""
    return [
        test_ordering(ordering, n, model_fp16, model_int4, inputs['input_ids'])
        for n in swap_counts_full
    ]


ftb_ppls = _perplexity_curve(front_to_back)
static_ppls = _perplexity_curve(lis_ranking)

strategy_png = '../figures/strategy_comparison.png'
plot_strategy_comparison(
    greedy_ppls, static_ppls, ftb_ppls, ppl_fp16,
    greedy_order=greedy_order,
    save_path=strategy_png,
)

# `Image` comes from the IPython import in the sensitivity-plot cell above.
Image(strategy_png)

## 4. Contiguous Block Analysis

In [None]:
block_swap_counts = [1, 2, 4, 8, 11, 16]

print(f"{'N swapped':<12}{'Best block':<20}{'Block PPL':<14}{'LIS PPL':<14}{'FtB PPL':<14}")
print("-" * 72)

for n in block_swap_counts:
    # Best contiguous run of n layers, as reported by the search helper.
    best_start, best_ppl = find_best_contiguous_block(
        model_fp16, model_int4, n, inputs['input_ids']
    )
    # NOTE(review): the "LIS PPL" column is driven by greedy_order here, whereas the
    # earlier comparison table used lis_ranking for its "LIS" column — confirm intent.
    ppl_lis = test_ordering(greedy_order, n, model_fp16, model_int4, inputs['input_ids'])
    ppl_ftb = test_ordering(front_to_back, n, model_fp16, model_int4, inputs['input_ids'])

    # Build the "[start-end]" label first and pad it as one 20-character field so
    # the numeric columns line up with the header whatever the width of the indices.
    block_label = f"[{best_start}-{best_start + n - 1}]"
    print(f"{n:<12}{block_label:<20}{best_ppl:<14.4f}"
          f"{ppl_lis:<14.4f}{ppl_ftb:<14.4f}")

In [None]:
import torch.nn.functional as F
import numpy as np
from morphserve.strategies import swap_layers

# Middle vs Edge comparison.
# A second calibration input is used when one exists; otherwise the first is reused.
inputs2 = inputs_list[1] if len(inputs_list) > 1 else inputs
test_inputs = [inputs, inputs2]

print("Same number of layers swapped: MIDDLE vs EDGES\n")
print(f"{'N swapped':<12}{'Middle block':<18}{'Edge layers':<18}{'Difference':<12}")
print("-" * 60)

for n in [2, 4, 6, 8]:
    half = n // 2
    # n consecutive layers centred on the middle of the stack ...
    mid_start = num_layers // 2 - half
    mid_block = list(range(mid_start, mid_start + n))
    # ... versus `half` layers taken from each end.
    edge_layers = [*range(half), *range(num_layers - half, num_layers)]

    mid_ppls, edge_ppls = [], []
    for inp in test_inputs:
        # Per input: measure the middle block first, then the edge layers.
        for layer_ids, ppl_log in ((mid_block, mid_ppls), (edge_layers, edge_ppls)):
            with swap_layers(model_fp16, model_int4, layer_ids):
                ppl_log.append(compute_perplexity(model_fp16, inp['input_ids']))

    avg_mid, avg_edge = np.mean(mid_ppls), np.mean(edge_ppls)
    print(f"{n:<12}{avg_mid:<18.4f}{avg_edge:<18.4f}{avg_edge - avg_mid:+.4f}")
    print(f"{'':12}  middle: {mid_block}")
    print(f"{'':12}  edges:  {edge_layers}\n")

## 5. CUDA Benchmark (Overlap Proof)

In [None]:
from morphserve.benchmark import stats, benchmark_overlap, benchmark_scattered_vs_block

# Layer 10 provides both the matmul weight for the synthetic workload and the
# parameters that fill the swap buffer below.
layer = model_fp16.model.layers[10]
ref_weight = layer.self_attn.q_proj.weight.data

def simulate_inference(n_steps=50):
    """Run `n_steps` chained matmuls against `ref_weight` as a stand-in compute load.

    The input width, dtype and device are taken from `ref_weight` rather than
    hard-coded, so the function keeps working if the reference layer's size,
    precision or placement changes. The activation is renormalised after each
    step. Returns the final activation tensor.
    """
    x = torch.randn(1, ref_weight.shape[1], dtype=ref_weight.dtype, device=ref_weight.device)
    for _ in range(n_steps):
        # Repeated application assumes a square weight (output width == input width).
        x = x @ ref_weight.T
        x = x / x.norm()
    return x

# Warmup
for _ in range(10):
    simulate_inference()
torch.cuda.synchronize()

# Prepare swap buffer: all parameters of the layer flattened into one tensor,
# a pinned host copy of it, and a separate GPU destination buffer.
flat = torch.cat([p.data.flatten() for p in layer.parameters()])
host_pinned = torch.empty_like(flat, device='cpu', pin_memory=True)
host_pinned.copy_(flat.cpu())
gpu_buf = flat.clone().cuda()

# Baseline without any swap, then the same workload with a non-blocking
# host-to-device copy passed in as the swap callable.
no_swap = benchmark_overlap(simulate_inference, None, n_iter=20)
with_swap = benchmark_overlap(
    simulate_inference,
    lambda: gpu_buf.copy_(host_pinned, non_blocking=True),
    n_iter=20
)

print(f"Inference (no swap):          {stats(no_swap['compute_times'])}")
print(f"Inference (overlapped swap):  {stats(with_swap['compute_times'])}")
print(f"Swap time (separate stream):  {stats(with_swap['swap_times'])}")

In [None]:
# Four non-adjacent layers versus four adjacent ones (layers 8 through 11).
scattered_layer_ids = [5, 10, 11, 17]
block_layer_ids = list(range(8, 12))

result = benchmark_scattered_vs_block(
    model_fp16,
    scattered_indices=scattered_layer_ids,
    block_indices=block_layer_ids,
    n_iter=50,
)

# `stats` is imported in the overlap-benchmark cell above.
scattered_summary = stats(result['scattered_times'])
block_summary = stats(result['block_times'])
print(f"Scattered: {scattered_summary}")
print(f"Block:     {block_summary}")
print(f"Speedup:   {result['speedup']:.2f}x")
print(f"Jitter reduction: {result['jitter_reduction']:.1f}x")

## 6. Burst Serving Simulation

In [None]:
from morphserve.simulation import MinimalServingSimulator
from morphserve.visualization import plot_burst_results

# Warmup: five throwaway forward passes with gradients disabled.
with torch.no_grad():
    for _ in range(5):
        model_fp16(tokenizer("Hello", return_tensors="pt").input_ids.to("cuda"))
torch.cuda.synchronize()

# A fresh simulator is built for every swap policy; results are keyed by policy name.
results = {}
for policy in ['none', 'scattered', 'block']:
    sim = MinimalServingSimulator(model_fp16, model_int4, tokenizer, max_kv_blocks=50)
    ttfts, tpots, phases, schedule = sim.run_burst_experiment(
        swap_policy=policy,
        decode_steps=10,
    )
    results[policy] = dict(ttfts=ttfts, tpots=tpots, phases=phases, schedule=schedule)

    # Median and 95th-percentile latency for this policy.
    print(f"Policy: {policy}")
    print(f"  TTFT p50: {np.percentile(ttfts, 50):.2f}ms  p95: {np.percentile(ttfts, 95):.2f}ms")
    print(f"  TPOT p50: {np.percentile(tpots, 50):.2f}ms  p95: {np.percentile(tpots, 95):.2f}ms")

burst_png = '../figures/burst_serving_results.png'
plot_burst_results(results, save_path=burst_png)

# `Image` comes from the IPython import in the sensitivity-plot cell; last expression renders inline.
Image(burst_png)