<a href="https://colab.research.google.com/github/kmalik22/openpi/blob/kmalik2_sharding/physical_intelligence_plots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Megatron Tensor Parallel vs FSDP

In [2]:
# Requires: pandas, plotly
# pip install pandas plotly

import pandas as pd
import plotly.express as px
import ast
import math
from io import StringIO

CSV_TEXT = """seq_len,sharding_strategy,batch_size,num_shards,model_dim,hidden_dim,mean_time_ms,std_time_ms,wall_time_sec,raw_times
256,default,2,2,2048,8192,0.6266593933105469,0.011472468163698225,11.454527378082275,"[0.6499290466308594, 0.637054443359375, 0.6208419799804688, 0.6296634674072266, 0.6225109100341797, 0.6394386291503906, 0.6191730499267578, 0.6213188171386719, 0.6120204925537109, 0.6146430969238281]"
256,megatron,2,2,2048,8192,0.33152103424072266,0.007583997189761386,11.49135446548462,"[0.3345012664794922, 0.3409385681152344, 0.347137451171875, 0.32901763916015625, 0.3330707550048828, 0.331878662109375, 0.3216266632080078, 0.32639503479003906, 0.3287792205810547, 0.3218650817871094]"
512,default,2,2,2048,8192,0.7415294647216797,0.013186853803778573,11.434635400772095,"[0.7417201995849609, 0.7498264312744141, 0.7302761077880859, 0.7708072662353516, 0.7297992706298828, 0.7507801055908203, 0.7276535034179688, 0.7302761077880859, 0.7336139678955078, 0.7505416870117188]"
512,megatron,2,2,2048,8192,0.40678977966308594,0.005807633856851605,11.711997747421265,"[0.4100799560546875, 0.41556358337402344, 0.408172607421875, 0.41174888610839844, 0.4100799560546875, 0.4093647003173828, 0.3979206085205078, 0.40459632873535156, 0.40435791015625, 0.3960132598876953]"
1024,default,2,2,2048,8192,1.0421037673950195,0.06998756577108811,11.753795146942139,"[0.9958744049072266, 1.0879039764404297, 1.0995864868164062, 1.0516643524169922, 1.0235309600830078, 1.049041748046875, 1.1980533599853516, 1.0056495666503906, 0.9481906890869141, 0.9615421295166016]"
1024,megatron,2,2,2048,8192,0.5845069885253906,0.007690394444990392,11.515445470809937,"[0.5977153778076172, 0.5946159362792969, 0.5884170532226562, 0.5853176116943359, 0.5786418914794922, 0.5900859832763672, 0.5776882171630859, 0.5743503570556641, 0.5826950073242188, 0.5755424499511719]"
2048,default,2,2,2048,8192,1.4653444290161133,0.016355045328555824,11.935031652450562,"[1.4750957489013672, 1.491546630859375, 1.4629364013671875, 1.4913082122802734, 1.451730728149414, 1.470804214477539, 1.4638900756835938, 1.4591217041015625, 1.4421939849853516, 1.4448165893554688]"
2048,megatron,2,2,2048,8192,1.0007619857788086,0.005350625001740073,11.754512310028076,"[1.0063648223876953, 1.0097026824951172, 1.0030269622802734, 1.0044574737548828, 1.0008811950683594, 0.9996891021728516, 0.9989738464355469, 0.9987354278564453, 0.9894371032714844, 0.9963512420654297]"
4096,default,2,2,2048,8192,1.9257068634033203,0.00666037567100945,11.281436204910278,"[1.9392967224121094, 1.9283294677734375, 1.9354820251464844, 1.9228458404541016, 1.9159317016601562, 1.9240379333496094, 1.9252300262451172, 1.9197463989257812, 1.922607421875, 1.9235610961914062]"
4096,megatron,2,2,2048,8192,1.8126487731933594,0.004047221216950025,12.024887561798096,"[1.8203258514404297, 1.8177032470703125, 1.817464828491211, 1.8095970153808594, 1.8124580383300781, 1.8115043640136719, 1.8105506896972656, 1.8088817596435547, 1.8086433410644531, 1.8093585968017578]"
8192,default,2,2,2048,8192,3.1985044479370117,0.01313045816320869,11.958216190338135,"[3.1790733337402344, 3.184795379638672, 3.1936168670654297, 3.1898021697998047, 3.2045841217041016, 3.2203197479248047, 3.2110214233398438, 3.214120864868164, 3.201007843017578, 3.1867027282714844]"
8192,megatron,2,2,2048,8192,3.4035682678222656,0.01086948184124054,12.531583309173584,"[3.431081771850586, 3.4101009368896484, 3.398418426513672, 3.3991336822509766, 3.4034252166748047, 3.4008026123046875, 3.3898353576660156, 3.398418426513672, 3.3948421478271484, 3.4096240997314453]"
"""

# === Load the benchmark table ===
# One row per (seq_len, sharding_strategy). The quoted `raw_times` column holds
# a Python-literal list of individual timings (presumably ms, like the *_ms columns).
df = pd.read_csv(StringIO(CSV_TEXT))

# === Derive per-row statistics from the raw samples ===
# The CSV's own mean/std columns are overwritten here: std is recomputed as the
# sample standard deviation (ddof=1) and stderr = std / sqrt(n).
df["raw_times_list"] = df["raw_times"].map(ast.literal_eval)
df["n"] = df["raw_times_list"].map(len)
df["mean_time_ms"] = df["raw_times_list"].map(lambda samples: pd.Series(samples).mean())
df["std_time_ms"] = df["raw_times_list"].map(lambda samples: pd.Series(samples).std(ddof=1))
# Rows with no samples get NaN rather than dividing by sqrt(0).
df["stderr_time_ms"] = [
    std / math.sqrt(count) if count > 0 else float("nan")
    for std, count in zip(df["std_time_ms"], df["n"])
]

# === Title info ===
# The title reports the run configuration taken from the first row only; this
# assumes every row in the table was measured with the same configuration.
model_dim, hidden_dim, num_shards, batch_size = (
    df[column].iloc[0]
    for column in ("model_dim", "hidden_dim", "num_shards", "batch_size")
)

title1 = (
    "Mean Latency ± StdErr vs Seq Len "
    f"(model_dim={model_dim}, hidden_dim={hidden_dim}, "
    f"num_shards={num_shards}, batch_size={batch_size})"
)

# === First plot: mean latency per strategy, with stderr error bars ===
fig1 = px.line(
    data_frame=df,
    x="seq_len",
    y="mean_time_ms",
    color="sharding_strategy",
    error_y="stderr_time_ms",
    markers=True,
    title=title1,
    labels={"seq_len": "Sequence Length", "mean_time_ms": "Mean Latency (ms)"},
)

# Figure size is fixed in pixels; adjust width/height here.
fig1.update_layout(
    template="plotly_white",
    hovermode="x unified",
    legend_title_text="Strategy",
    width=1000,
    height=500,
)
fig1.show()

# === Second plot: Speedup (default / megatron) with propagated stderr ===
# Reshape to one row per seq_len, with flattened columns named
# "<metric>_<strategy>" (e.g. "mean_time_ms_default").
pivot = df.pivot(
    index="seq_len",
    columns="sharding_strategy",
    values=["mean_time_ms", "stderr_time_ms"],
)
pivot.columns = ["_".join(pair).strip() for pair in pivot.columns]
pivot = pivot.reset_index()

# Speedup > 1 means Megatron is faster than the default strategy.
pivot["speedup"] = pivot["mean_time_ms_default"].div(pivot["mean_time_ms_megatron"])
# Standard error of a ratio: relative errors of numerator and denominator
# combine in quadrature (treating the two measurements as independent).
pivot["speedup_stderr"] = pivot["speedup"] * (
    (pivot["stderr_time_ms_default"] / pivot["mean_time_ms_default"]) ** 2
    + (pivot["stderr_time_ms_megatron"] / pivot["mean_time_ms_megatron"]) ** 2
) ** 0.5

fig2 = px.line(
    data_frame=pivot,
    x="seq_len",
    y="speedup",
    error_y="speedup_stderr",
    markers=True,
    title="Speedup of Megatron vs Default (same params as above)",
    labels={"seq_len": "Sequence Length", "speedup": "Speedup (Default / Megatron)"},
)

# Same fixed pixel size as the first plot.
fig2.update_layout(
    template="plotly_white",
    hovermode="x unified",
    width=1000,
    height=500,
)
fig2.show()

In [8]:
# @title Megatron vs Default: clean plots (Plotly)
import numpy as np
import plotly.graph_objects as go

# ── Replace these two lists with your timing data ───────────────────────────────
MEGATRON_TIMES_MS = [3.868579864501953, 3.9048194885253906, 3.86810302734375, 3.8712024688720703, 3.8726329803466797, 3.8559436798095703, 3.864288330078125, 3.8690567016601562, 3.8619041442871094, 3.8716793060302734, 3.8759708404541016, 3.8743019104003906, 3.8666725158691406, 3.86810302734375, 3.8673877716064453, 3.869295120239258, 3.8568973541259766, 3.870248794555664, 3.859996795654297, 3.8177967071533203, 3.8323402404785156, 3.827333450317383, 3.8127899169921875, 3.827333450317383, 3.8404464721679688, 3.818511962890625, 3.8301944732666016, 3.8115978240966797, 3.8194656372070312, 3.8254261016845703, 3.8263797760009766, 3.8292407989501953, 3.8306713104248047, 3.8373470306396484, 3.823518753051758, 3.827810287475586, 3.8270950317382812, 3.8292407989501953, 3.815889358520508, 3.8301944732666016, 3.815889358520508, 3.8368701934814453, 3.8247108459472656, 3.8318634033203125, 3.8270950317382812, 3.8352012634277344, 3.8335323333740234, 3.8220882415771484, 3.8635730743408203, 3.8607120513916016]
DEFAULT_TIMES_MS  = [4.194736480712891, 3.851175308227539, 3.8580894470214844, 4.205226898193359, 4.210233688354492, 4.192829132080078, 4.207611083984375, 3.9169788360595703, 4.189729690551758, 3.9098262786865234, 3.898143768310547, 3.830432891845703, 3.8099288940429688, 3.821134567260742, 4.156589508056641, 4.156827926635742, 4.149675369262695, 3.809213638305664, 4.134416580200195, 4.173040390014648, 3.866434097290039, 3.888845443725586, 3.820657730102539, 4.159450531005859, 4.167079925537109, 4.170894622802734, 4.143238067626953, 3.874063491821289, 3.821849822998047, 4.159212112426758, 3.8886070251464844, 3.815889358520508, 4.149198532104492, 4.153966903686523, 3.884553909301758, 4.169225692749023, 3.817319869995117, 3.8268566131591797, 3.8216114044189453, 4.162073135375977, 4.166603088378906, 3.8192272186279297, 4.152059555053711, 3.8421154022216797, 3.8864612579345703, 4.149198532104492, 4.156351089477539, 4.158735275268555, 4.199504852294922, 3.9157867431640625]
# ───────────────────────────────────────────────────────────────────────────────

# ---- Tuning knobs ----
FIG_W, FIG_H = 1200, 600
FONT_SIZE    = 16
MARKER_SIZE  = 6
LINE_WIDTH   = 2
# ----------------------

import warnings


def mean_std(samples) -> tuple[float, float]:
    """Return ``(mean, sample_std)`` of *samples* as Python floats.

    The standard deviation uses ``ddof=1`` (unbiased sample estimate). It is
    undefined for a single sample, so 0.0 is returned in that case instead of NaN.

    Raises:
        ValueError: if *samples* is empty.
    """
    values = np.asarray(samples, dtype=float)
    if values.size == 0:
        raise ValueError("mean_std() needs at least one sample")
    std = float(values.std(ddof=1)) if values.size > 1 else 0.0
    return float(values.mean()), std


# Compare run i of one strategy against run i of the other, over the common prefix.
n = min(len(MEGATRON_TIMES_MS), len(DEFAULT_TIMES_MS))
if n == 0:
    raise ValueError("MEGATRON_TIMES_MS and DEFAULT_TIMES_MS must both be non-empty")
if len(MEGATRON_TIMES_MS) != len(DEFAULT_TIMES_MS):
    # Mismatched lengths usually mean a paste error; don't truncate silently.
    warnings.warn(
        f"Timing lists differ in length ({len(MEGATRON_TIMES_MS)} vs "
        f"{len(DEFAULT_TIMES_MS)}); comparing only the first {n} runs of each."
    )
mega = np.asarray(MEGATRON_TIMES_MS[:n], float)
deft = np.asarray(DEFAULT_TIMES_MS[:n], float)
runs = np.arange(1, n + 1)  # 1-based run numbers for the x-axis

# Previously `a, b = x, y if c else (x, 0.0)`, which bound a tuple to the std
# when n <= 1 and crashed the ":.3f" formatting in the legend labels.
mean_mega, std_mega = mean_std(mega)
mean_def, std_def = mean_std(deft)

# Per-run ratio; values > 1 mean Megatron was faster on that run.
speedup = deft / mega
arith_mean_speed = float(speedup.mean())
# Geometric mean is the natural average for ratios (symmetric under inversion).
geo_mean_speed = float(np.exp(np.log(speedup).mean()))

# Per-series colors, shared by both figures below.
MEGA_COLOR = "#1f77b4"  # blue
DEF_COLOR = "#ff7f0e"  # orange
SPD_COLOR = "#2ca02c"  # green

# Layout settings common to both figures: size and font come from the tuning
# knobs, and the legend is laid out horizontally, centered above the plot area.
layout_common = {
    "template": "simple_white",
    "hovermode": "x",
    "margin": {"l": 60, "r": 20, "t": 60, "b": 50},
    "width": FIG_W,
    "height": FIG_H,
    "font": {"size": FONT_SIZE},
    "legend": {
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.04,
        "xanchor": "center",
        "x": 0.5,
        "font": {"size": FONT_SIZE - 1},
    },
}

# ===== Plot 1: Raw times =====
# One line per strategy; each legend label carries that strategy's mean (μ)
# and sample standard deviation (σ).
fig1 = go.Figure()
fig1.add_traces([
    go.Scatter(
        x=runs,
        y=times,
        mode="lines+markers",
        name=f"{label} (μ={mu:.3f} ms, σ={sigma:.3f})",
        line=dict(color=color, width=LINE_WIDTH),
        marker=dict(size=MARKER_SIZE),
        hovertemplate="Run %{x}<br>%{y:.4f} ms<extra>" + label + "</extra>",
    )
    for label, times, mu, sigma, color in (
        ("Megatron", mega, mean_mega, std_mega, MEGA_COLOR),
        ("Default", deft, mean_def, std_def, DEF_COLOR),
    )
])

# Dashed horizontal line at each strategy's mean, in the matching color.
fig1.add_hline(y=mean_mega, line_dash="dash", line_width=2, line_color=MEGA_COLOR)
fig1.add_hline(y=mean_def, line_dash="dash", line_width=2, line_color=DEF_COLOR)

fig1.update_layout({**layout_common, "title": "Raw Forward-Pass Times (ms)"})
fig1.update_xaxes(
    title_text="Run #",
    showgrid=True,
    gridcolor="rgba(0,0,0,0.08)",
).update_yaxes(
    title_text="Time (ms)",
    showgrid=True,
    gridcolor="rgba(0,0,0,0.08)",
)
fig1.show()

# ===== Plot 2: Speedup =====
# Per-run Default/Megatron ratio. The dashed line marks the geometric-mean
# speedup; the title also reports the arithmetic mean for comparison.
subtitle = f"Geom mean {geo_mean_speed:.3f}× • Arith mean {arith_mean_speed:.3f}×"

fig2 = go.Figure(data=[
    go.Scatter(
        x=runs,
        y=speedup,
        mode="lines+markers",
        name="Speedup (Default ÷ Megatron)",
        line=dict(color=SPD_COLOR, width=LINE_WIDTH),
        marker=dict(size=MARKER_SIZE),
        hovertemplate="Run %{x}<br>%{y:.3f}×<extra>Speedup</extra>",
    )
])
fig2.add_hline(y=geo_mean_speed, line_dash="dash", line_width=2, line_color=SPD_COLOR)

fig2.update_layout({**layout_common, "title": f"Speedup per Run — {subtitle}"})
fig2.update_xaxes(
    title_text="Run #",
    showgrid=True,
    gridcolor="rgba(0,0,0,0.08)",
).update_yaxes(
    title_text="Speedup (×)",
    showgrid=True,
    gridcolor="rgba(0,0,0,0.08)",
)
fig2.show()

# TP vs FSDP