In [1]:
# 0) INSTALL DEPENDENCIES
!pip install -q dash pandas plotly

# 1) IMPORTS
import pandas as pd, plotly.express as px, plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State
from google.colab import files, output as colab_out
import base64, io, numpy as np

# 2) CSV/Excel LOADER & CLEANER
def int_count_to_str(x):
    """Normalize a raw count value to the canonical ``"balls-strikes"`` string.

    Accepts either text already shaped like ``"b-s"`` or a number whose tens
    digit is the balls and whose ones digit is the strikes (``22`` -> ``"2-2"``).
    A number below 10 is read as zero balls (``1`` -> ``"0-1"``).

    Args:
        x: Raw count cell (string, int, float, or anything ``str()`` accepts).

    Returns:
        ``"b-s"`` when ``0 <= balls <= 3`` and ``0 <= strikes <= 2``; otherwise
        ``None``, including for values that cannot be parsed at all.
    """
    try:
        text = str(x).strip()
        if '-' in text:
            # already in "b-s" form
            balls_txt, strikes_txt = text.split('-', 1)
            balls, strikes = int(balls_txt), int(strikes_txt)
        else:
            # e.g. 22 -> 2-2 ; 5 -> 0-5 (rejected by the range check below)
            n = int(float(text))
            balls, strikes = divmod(n, 10) if n >= 10 else (0, n)
    except (ValueError, TypeError, OverflowError):
        # Unparseable text, NaN (ValueError) or infinity (OverflowError).
        return None

    if 0 <= balls <= 3 and 0 <= strikes <= 2:
        return f"{balls}-{strikes}"
    return None


def load_pitchlog(data_bytes):
    """Parse an uploaded pitch log (CSV or Excel) into a cleaned DataFrame.

    Args:
        data_bytes: Raw bytes of the uploaded file.

    Returns:
        DataFrame sorted by ``DateTime`` with the columns ``Pitcher``,
        ``GameDate``, ``BatterHand``, ``PitchType``, ``PitchResult``, ``Velo``,
        ``Count`` and ``DateTime``. Rows whose count cannot be normalized or
        whose pitch type is missing are dropped.

    Raises:
        KeyError: If a required column is absent from the file.
    """
    buf = io.BytesIO(data_bytes)

    # Detect Excel by file signature instead of looking for a comma in the
    # first bytes: a binary workbook header can contain 0x2C by chance, and a
    # CSV's first 100 bytes need not contain a comma.
    #   b'PK\x03\x04'       -> ZIP container (xlsx)
    #   b'\xd0\xcf\x11\xe0' -> OLE2 compound file (legacy xls)
    excel_signatures = (b'PK\x03\x04', b'\xd0\xcf\x11\xe0')
    if data_bytes.startswith(excel_signatures):
        df = pd.read_excel(buf)
    else:
        df = pd.read_csv(buf)

    column_map = {
        'BREWERS_PITCHER': 'Pitcher',
        'PITCH_TYPE': 'PitchType',
        'PITCH_CALL': 'PitchResult',
        'VELOCITY': 'Velo',
        'COUNT': 'RawCount',
    }
    df = df.rename(columns=column_map)

    # Fail early with a readable message; checked after the rename so files
    # that already use the short names keep working.
    source_name = {new: old for old, new in column_map.items()}
    required = ['DATE_TIME', *column_map.values()]
    missing = [
        f"{col} / {source_name[col]}" if col in source_name else col
        for col in required
        if col not in df.columns
    ]
    if missing:
        raise KeyError(f"missing required column(s): {', '.join(missing)}")

    df['DateTime'] = pd.to_datetime(df['DATE_TIME'], errors='coerce')
    # Stable sort: pitches sharing a timestamp keep their file order.
    df = df.sort_values('DateTime', kind='mergesort')
    df['GameDate'] = df['DateTime'].dt.date
    df['Count'] = df['RawCount'].apply(int_count_to_str)

    # Batter handedness: prefer BATTER_SIDE, then the older BATTER_HAND, then
    # an "(R)" / "(L)" suffix in a Batter name column. Unknown values become
    # 'All' so they show up under the "Both" selection.
    if 'BATTER_SIDE' in df.columns:
        df['BatterHand'] = df['BATTER_SIDE'].fillna('All').astype(str)
    elif 'BATTER_HAND' in df.columns:
        df['BatterHand'] = df['BATTER_HAND'].fillna('All').astype(str)
    elif 'Batter' in df.columns:
        df['BatterHand'] = (
            df['Batter']
            .astype(str)
            .str.extract(r'\((R|L)\)', expand=False)
            .fillna('All')
        )
    else:
        df['BatterHand'] = 'All'

    df = df.dropna(subset=['Count', 'PitchType'])
    return df[['Pitcher', 'GameDate', 'BatterHand',
               'PitchType', 'PitchResult', 'Velo', 'Count', 'DateTime']]

# 3) PLINKO BUILDER
# Pitch results that count as the batter having swung.
swing_results = {
    'Foul',
    'Swinging Strike',
    'In Play, No Out',
    'In Play, Out(s)',
    'In Play, Run(s)',
    'Foul Tip',
    'Swinging Strike (Blocked)',
}


def swing_pct(df):
    """Return ``{count: share of pitches at that count that drew a swing}``."""
    def _swing_share(results):
        # Mean of a boolean mask == fraction of True values.
        return results.isin(swing_results).mean()

    results_by_count = df.groupby('Count')['PitchResult']
    return results_by_count.apply(_swing_share).to_dict()
def avg_velo(df):
    """Return ``{count: mean of the Velo column for pitches at that count}``."""
    velo_by_count = df.groupby('Count')['Velo']
    mean_velo = velo_by_count.mean()
    return mean_velo.to_dict()

def build_plinko(df, mode='frequency'):
    """Build the Plinko-style count-transition figure for a set of pitches.

    Pitches are ordered by ``DateTime`` and split into at-bats (a new at-bat
    starts at every ``0-0`` pitch). Within each at-bat, every pitch contributes
    an edge from its count to the following pitch's count, or to the terminal
    nodes ``K`` (strike with two strikes) / ``BB`` (ball with three balls). An
    at-bat stops contributing once a ball is put in play.

    Args:
        df: DataFrame with ``DateTime``, ``Count``, ``PitchResult``,
            ``PitchType`` and (for velocity mode) ``Velo`` columns.
        mode: Node colouring: ``'frequency'`` (pitches per count),
            ``'swing_percentage'`` or ``'velocity'`` (mean velo per count).

    Returns:
        A ``plotly.graph_objects.Figure`` with one line per transition (width
        equals the transition total), invisible hover markers at each edge
        midpoint listing the pitch-type breakdown, and one marker per node.

    Raises:
        KeyError: If ``mode`` is not one of the three supported values.

    NOTE(review): transitions are read from consecutive rows, so a frame that
    has been filtered (e.g. by pitch type) no longer holds consecutive pitches;
    only steps between adjacent counts are recorded to avoid drawing jumps.
    """
    # Stable sort so pitches that share a timestamp keep their incoming order.
    df = df.sort_values('DateTime', kind='mergesort').reset_index(drop=True)

    valid = [f"{b}-{s}" for b in range(4) for s in range(3)]
    terminal = ['K', 'BB']
    strike_cnt = {f"{b}-2" for b in range(4)}
    walk_cnt = {"3-0", "3-1", "3-2"}
    end_play = {'In Play, No Out', 'In Play, Out(s)', 'In Play, Run(s)'}
    strikeout_calls = {'Swinging Strike', 'Called Strike',
                       'Foul Tip', 'Swinging Strike (Blocked)'}

    # Nodes a single pitch can legally lead to from each count: one more ball,
    # one more strike, the same count (two-strike foul), or a terminal node.
    reachable = {}
    for b in range(4):
        for s in range(3):
            nxt = {f"{b + 1}-{s}", f"{b}-{s + 1}", *terminal}
            if s == 2:
                nxt.add(f"{b}-{s}")
            reachable[f"{b}-{s}"] = nxt

    df['at_bat_id'] = (df['Count'] == '0-0').cumsum()

    # (from, to) -> {'total': n, 'types': {pitch_type: n}}
    trans = {}
    for _, ab in df.groupby('at_bat_id'):
        counts = ab['Count'].tolist()
        results = ab['PitchResult'].tolist()
        pitch_types = ab['PitchType'].tolist()
        last = len(counts) - 1
        for i, (u, pr, pt) in enumerate(zip(counts, results, pitch_types)):
            if u in strike_cnt and pr in strikeout_calls:
                v = 'K'
            elif u in walk_cnt and pr == 'Ball':
                v = 'BB'
            elif pr in end_play:
                break
            elif i < last:
                v = counts[i + 1]
            else:
                # Final row without a recognised outcome: nothing to link to.
                continue
            if v in reachable.get(u, ()):
                rec = trans.setdefault((u, v), {'total': 0, 'types': {}})
                rec['total'] += 1
                rec['types'][pt] = rec['types'].get(pt, 0) + 1
            if v in terminal:
                break

    # Compute only the requested metric so a problem in an unrelated column
    # (e.g. non-numeric Velo) cannot break the other modes.
    metric_builders = {
        'frequency': lambda: {**{c: 0 for c in valid + terminal},
                              **df.groupby('Count').size().to_dict()},
        'swing_percentage': lambda: swing_pct(df),
        'velocity': lambda: avg_velo(df),
    }
    metrics = metric_builders[mode]()

    # Colour range: pinned to the observed velo range in velocity mode
    # (80-100 when no velo is available), automatic otherwise.
    cmin = cmax = None
    if mode == 'velocity':
        velos = df['Velo'].dropna()
        cmin, cmax = (velos.min(), velos.max()) if not velos.empty else (80, 100)

    # Fixed node layout: balls move right/down, strikes move left/down.
    pos = {'0-0': (0, 3), '1-0': (1, 2), '2-0': (2, 1), '3-0': (3, 0),
           '0-1': (-1, 2), '1-1': (0, 1), '2-1': (1, 0), '3-1': (2, -1),
           '0-2': (-2, 1), '1-2': (-1, 0), '2-2': (0, -1), '3-2': (1, -2),
           'K': (-1, -3), 'BB': (2, -3)}

    edge_tr, edge_hover = [], []
    for (u, v), info in trans.items():
        if u not in pos or v not in pos:
            continue
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        wt = info['total']
        by_usage = sorted(info['types'].items(), key=lambda t: t[1], reverse=True)
        body = "<br>".join(f"{pt}: {cnt}" for pt, cnt in by_usage)
        edge_tr.append(go.Scatter(
            x=[x0, x1], y=[y0, y1], mode='lines',
            line=dict(width=wt, color='rgba(150,150,150,0.8)'),
            hoverinfo='none'))
        # Lines have no hover target of their own, so a transparent marker at
        # the midpoint carries the tooltip.
        edge_hover.append(go.Scatter(
            x=[(x0 + x1) / 2], y=[(y0 + y1) / 2],
            mode='markers', marker=dict(size=25, color='rgba(0,0,0,0)'),
            hoverinfo='text',
            hovertext=f"<b>{u}→{v}</b><br>Total:{wt}<br><br>{body}"))

    node_x, node_y, node_color, node_text, node_hover = [], [], [], [], []
    for n in sorted(pos, key=lambda name: pos[name][1], reverse=True):
        x, y = pos[n]
        val = metrics.get(n, 0) or 0
        if pd.isna(val):
            # e.g. a count whose pitches all lack a velo reading
            val = 0
        if mode == 'frequency':
            txt = f"Pitches:{int(val)}"
        elif mode == 'swing_percentage':
            txt = f"Swing %:{val:.1%}"
        else:
            txt = f"Avg Velo:{val:.1f}"
        node_x.append(x)
        node_y.append(y)
        node_color.append(val)
        node_text.append(n)
        node_hover.append(f"<b>{n}</b><br>{txt}")

    node_tr = go.Scatter(
        x=node_x, y=node_y, mode='markers+text', text=node_text,
        textposition='top center', hoverinfo='text', hovertext=node_hover,
        marker=dict(size=50, colorscale='YlGnBu', showscale=True, color=node_color,
                    cmin=cmin, cmax=cmax,
                    colorbar=dict(title=mode.replace('_', ' ').title(), thickness=15)))

    fig = go.Figure(data=edge_tr + edge_hover + [node_tr])
    fig.update_layout(title="Plinko Pitch-Sequence", xaxis=dict(visible=False),
                      yaxis=dict(visible=False), plot_bgcolor='white',
                      hovermode='closest', height=700, showlegend=False)
    return fig

# 4) APP LAYOUT
# Dash application object; the callbacks below register themselves on it.
app = Dash(__name__, external_stylesheets=['https://codepen.io/chriddyp/pen/bWLwgP.css'])
# Page layout. Component ids are referenced by the callbacks, so renaming an
# id here requires the matching Input/Output/State to change as well.
app.layout = html.Div(style={'fontFamily':'Arial','maxWidth':'1000px','margin':'20px auto'}, children=[
    html.H1("Plinko Sequencing Dashboard", style={'textAlign':'center'}),
    # File picker; its contents/filename feed the store_data callback.
    dcc.Upload(id='upload', children=html.Button("📂 Upload CSV / Excel")),
    # Status line written by store_data (row count or error message).
    html.Div(id='file-info', style={'textAlign':'center','margin':'8px 0'}),
    # Holds the cleaned pitch log as a JSON string (orient='split').
    dcc.Store(id='raw-data'),
    # Filter row 1: pitcher, game dates (multi-select), batter handedness.
    html.Div(style={'display':'flex','gap':'10px','marginBottom':'10px'}, children=[
        dcc.Dropdown(id='pitcher-dd', placeholder='Select pitcher', style={'flex':2}),
        dcc.Dropdown(id='date-dd', placeholder='Select dates', multi=True, style={'flex':3}),
        # 'All' is the sentinel the callbacks treat as "no handedness filter".
        dcc.Dropdown(id='bat-hand-dd',
                     options=[{'label':'Both','value':'All'},
                              {'label':'Left','value':'L'},
                              {'label':'Right','value':'R'}],
                     value='All', style={'flex':1})
    ]),
    # Filter row 2: optional pitch type and the node-colouring mode; the mode
    # values are the keys build_plinko accepts.
    html.Div(style={'display':'flex','gap':'10px','marginBottom':'10px'}, children=[
        dcc.Dropdown(id='ptype-dd', placeholder='Pitch type (opt)', style={'flex':3}),
        dcc.RadioItems(id='mode',
                       options=[{'label':'Count freq','value':'frequency'},
                                {'label':'Swing %','value':'swing_percentage'},
                                {'label':'Avg velo','value':'velocity'}],
                       value='frequency', labelStyle={'marginRight':'15px'})
    ]),
    # Main count-transition chart, filled by update_plinko.
    dcc.Graph(id='plinko-graph'),

    # 2x2 grid of heat maps, all filled by update_heatmaps.
    html.Div(style={'display':'grid','gridTemplateColumns':'repeat(2,1fr)','gap':'25px'}, children=[
        dcc.Graph(id='heat-freq'),
        dcc.Graph(id='heat-csw'),
        dcc.Graph(id='usage-count'),
        dcc.Graph(id='velo-count')
    ])
])

# 5) CALLBACKS
@app.callback(
    Output('raw-data', 'data'),
    Output('file-info', 'children'),
    Input('upload', 'contents'),
    State('upload', 'filename'),
)
def store_data(content, fname):
    """Decode an uploaded file, clean it, and stash it in the data store.

    Returns a ``(json_payload, status_message)`` pair; the payload is ``None``
    when nothing was uploaded or the file could not be loaded.
    """
    if not content:
        return None, "No file uploaded."

    try:
        # The upload arrives as "<header>,<base64 payload>".
        encoded = content.split(',')[1]
        frame = load_pitchlog(base64.b64decode(encoded))
        payload = frame.to_json(date_format='iso', orient='split')
        status = f"Loaded {len(frame):,} pitches from {fname}"
    except Exception as err:
        # Callback boundary: surface any load problem in the status line.
        return None, f"Error loading file: {err}"

    return payload, status

@app.callback(Output('pitcher-dd','options'), Input('raw-data','data'))
def set_pitchers(json_df):
    """Populate the pitcher dropdown from the stored pitch log.

    Args:
        json_df: JSON string (orient='split') from the ``raw-data`` store, or
            a falsy value when nothing has been loaded.

    Returns:
        Dropdown options, one per distinct pitcher, sorted by name.
    """
    if not json_df:
        return []
    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(json_df), orient='split')
    # Drop missing names so sorted() never has to compare NaN with strings.
    pitchers = sorted(df['Pitcher'].dropna().unique())
    return [{'label': p, 'value': p} for p in pitchers]

@app.callback(Output('date-dd','options'),
              Input('pitcher-dd','value'),
              State('raw-data','data'))
def set_dates(pitcher, json_df):
    """Populate the date dropdown with the selected pitcher's game dates.

    Args:
        pitcher: Currently selected pitcher, or ``None``.
        json_df: JSON string (orient='split') from the ``raw-data`` store.

    Returns:
        Dropdown options whose label and value are ISO dates, ascending.
    """
    if not (pitcher and json_df):
        return []
    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(json_df), orient='split')
    pitcher_dates = pd.to_datetime(df.loc[df['Pitcher'] == pitcher, 'GameDate'])
    # Rows whose timestamp failed to parse carry no date; leave them out so
    # sorting and isoformat() only ever see real dates.
    dates = sorted(pitcher_dates.dropna().dt.date.unique())
    return [{'label': d.isoformat(), 'value': d.isoformat()} for d in dates]

@app.callback(Output('ptype-dd','options'),
              Input('raw-data','data'),
              Input('pitcher-dd','value'),
              Input('date-dd','value'))
def set_ptypes(json_df, pitcher, dates):
    """Populate the pitch-type dropdown for the current pitcher/date filter.

    Args:
        json_df: JSON string (orient='split') from the ``raw-data`` store.
        pitcher: Selected pitcher, or ``None`` for no pitcher filter.
        dates: Selected ISO date strings, or empty/``None`` for no date filter.

    Returns:
        Dropdown options, one per distinct pitch type, sorted.
    """
    if not json_df:
        return []
    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(json_df), orient='split')
    df['GameDate'] = pd.to_datetime(df['GameDate']).dt.date
    if pitcher:
        df = df[df['Pitcher'] == pitcher]
    if dates:
        sel = [pd.to_datetime(d).date() for d in dates]
        df = df[df['GameDate'].isin(sel)]
    pitch_types = sorted(df['PitchType'].dropna().unique())
    return [{'label': pt, 'value': pt} for pt in pitch_types]

@app.callback(Output('plinko-graph','figure'),
              Input('raw-data','data'),
              Input('pitcher-dd','value'),
              Input('date-dd','value'),
              Input('bat-hand-dd','value'),
              Input('ptype-dd','value'),
              Input('mode','value'))
def update_plinko(json_df, pitcher, dates, hand, ptype, mode):
    """Redraw the Plinko chart for the current filter selection.

    Args:
        json_df: JSON string (orient='split') from the ``raw-data`` store.
        pitcher: Selected pitcher (required).
        dates: Selected ISO date strings (at least one required).
        hand: ``'All'`` for no handedness filter, otherwise the value to match
            against ``BatterHand``.
        ptype: Optional pitch type to restrict to.
        mode: Node-colouring mode forwarded to ``build_plinko``.

    Returns:
        The Plinko figure, or an empty figure until data, pitcher and dates
        are all set.
    """
    if not (json_df and pitcher and dates):
        return go.Figure()
    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(json_df), orient='split')
    df['GameDate'] = pd.to_datetime(df['GameDate']).dt.date
    selected = [pd.to_datetime(d).date() for d in dates]
    df = df[(df['Pitcher'] == pitcher) & (df['GameDate'].isin(selected))]
    if hand != 'All':
        df = df[df['BatterHand'] == hand]
    # NOTE(review): filtering by pitch type removes rows from the middle of
    # at-bats before the transitions are derived — confirm this is intended.
    if ptype:
        df = df[df['PitchType'] == ptype]
    return build_plinko(df, mode)

# -------------------------------------------------------------------
# ---------- SECTION 6  —  HEATMAPS  (with aesthetic tweaks) ---------
# -------------------------------------------------------------------
def _matrix_fig(percent_df: pd.DataFrame, count_df: pd.DataFrame, title: str, scale: str):
    """Return a heat-map figure annotated with a value and a sample size.

    Each cell shows ``<value>%`` on the first line and ``(n)`` on the second;
    cells with no observations (``n`` is 0 or missing) or no value are blank.

    Args:
        percent_df: Matrix of values that drive the colours and first line.
        count_df: Matrix of sample sizes with the same row/column labels.
        title: Figure title.
        scale: Name of the continuous colour scale.

    Returns:
        The heat-map figure, or an empty titled figure when ``percent_df`` has
        no cells.
    """
    if percent_df.empty:
        # Nothing to draw; skip px.imshow rather than hand it an empty matrix.
        return go.Figure(layout=dict(title=title))

    # Align the counts to the value matrix so the cell-by-cell walk is safe.
    counts = count_df.reindex(index=percent_df.index,
                              columns=percent_df.columns,
                              fill_value=0)

    # Build the annotations as plain strings in their own nested list; writing
    # text into a copy of the float matrix relies on deprecated dtype upcasting.
    annot = [
        [
            f"{pct:.1f}%<br>({int(n)})" if n and not pd.isna(n) and not pd.isna(pct) else ""
            for pct, n in zip(pct_row, cnt_row)
        ]
        for pct_row, cnt_row in zip(percent_df.to_numpy(), counts.to_numpy())
    ]

    fig = px.imshow(
        percent_df.round(1),
        text_auto=False,
        color_continuous_scale=scale,
        aspect="auto",
    )
    fig.update_traces(
        text=annot,
        texttemplate="%{text}",
        textfont_size=12,
    )
    fig.update_layout(
        title=title,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    # Requests "outside left" tick-label placement.
    # NOTE(review): Plotly's reference says left/right has no effect on y axes
    # — confirm whether this setting does anything here.
    fig.update_yaxes(
        autorange="reversed",
        ticklabelposition="outside left",
        tickfont=dict(size=12),
    )
    return fig


@app.callback(
    Output("heat-freq", "figure"),
    Output("heat-csw", "figure"),
    Output("usage-count", "figure"),
    Output("velo-count", "figure"),
    Input("raw-data", "data"),
    Input("pitcher-dd", "value"),
    Input("date-dd", "value"),
    Input("bat-hand-dd", "value"),
    Input("ptype-dd", "value"),
)
def update_heatmaps(json_df, pitcher, dates, hand, ptype):
    """Rebuild the four heat maps for the current filter selection.

    Args:
        json_df: JSON string (orient='split') from the ``raw-data`` store.
        pitcher: Selected pitcher (required).
        dates: Selected ISO date strings (at least one required).
        hand: ``'All'`` for no handedness filter, otherwise the value to match
            against ``BatterHand``.
        ptype: Optional pitch type; applied to the velocity matrix only.

    Returns:
        Four figures in output order: pitch-to-pitch transition %, pitch-to-
        pitch CSW %, usage % by count bucket, average velo by count bucket.
        A figure is empty when its underlying selection has no rows.
    """
    if not (json_df and pitcher and dates):
        return [go.Figure()] * 4

    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(json_df), orient="split")
    # .copy() so the column assignment below does not write into a slice.
    df = df[df["Pitcher"] == pitcher].copy()
    df["GameDate"] = pd.to_datetime(df["GameDate"]).dt.date
    df = df[df["GameDate"].isin([pd.to_datetime(d).date() for d in dates])]
    if hand != "All":
        df = df[df["BatterHand"] == hand]
    if df.empty:
        return [go.Figure()] * 4

    # Stable sort so pitches sharing a timestamp keep their stored order.
    df = df.sort_values("DateTime", kind="mergesort")
    # NOTE(review): "previous pitch" is simply the preceding row, so pairs can
    # span at-bats and game dates — confirm this is intended.
    df["prev"] = df["PitchType"].shift()
    trans = df.dropna(subset=["prev"]).copy()

    # ----------------------------------------------------------------
    # A)  FREQ & CSW matrices  ---------------------------------------
    if trans.empty:
        freq_fig, csw_fig = go.Figure(), go.Figure()
    else:
        trans["csw"] = trans["PitchResult"].isin(
            {"Swinging Strike", "Called Strike", "Foul Tip", "Swinging Strike (Blocked)"}
        ).astype(int)
        pairs = trans.groupby(["prev", "PitchType"])

        freq_cnt = pairs.size().unstack(fill_value=0)
        # Every cell is a share of ALL transitions, not of its row.
        freq_pct = (freq_cnt / freq_cnt.values.sum()) * 100
        freq_fig = _matrix_fig(freq_pct, freq_cnt, 'Pitch-to-Pitch Transition %', 'Blues')

        csw_pct = pairs["csw"].mean().unstack(fill_value=0) * 100
        csw_success = pairs["csw"].sum().unstack(fill_value=0)

        # Cell text: "<pct>%" over "(successes/total)"; blank when no pitches.
        csw_text = []
        for r in csw_pct.index:
            row = []
            for c in csw_pct.columns:
                tot = int(freq_cnt.loc[r, c])
                suc = int(csw_success.loc[r, c])
                row.append(f"{csw_pct.loc[r, c]:.1f}%<br>({suc}/{tot})" if tot else "")
            csw_text.append(row)

        csw_fig = px.imshow(
            csw_pct,
            text_auto=False,
            color_continuous_scale='Purples',
            aspect='auto'
        )
        csw_fig.update_traces(
            text=csw_text,
            texttemplate='%{text}',
            textfont_size=12
        )
        csw_fig.update_layout(
            title='Pitch-to-Pitch CSW %',
            margin=dict(l=0, r=0, t=40, b=0)
        )
        csw_fig.update_yaxes(
            autorange='reversed',
            ticklabelposition='outside left',
            tickfont=dict(size=12),
        )

    # ----------------------------------------------------------------
    # B)  USAGE matrix  ----------------------------------------------
    hitters_cts = {"1-0", "2-0", "2-1", "3-0", "3-1"}

    def bucket(c):
        """Map a count to its display bucket; None drops the row."""
        if c == "0-0":
            return "0-0"
        if c == "1-1":
            return "1-1"
        if isinstance(c, str) and c.endswith("-2"):
            return "2K"
        if c in hitters_cts:
            return "Hitter's"
        return None

    df["Bucket"] = df["Count"].map(bucket)
    df = df.dropna(subset=["Bucket"])               # remove rows that mapped to None

    if df.empty:
        usage_fig = go.Figure()
    else:
        usage_cnt = df.groupby(["Bucket", "PitchType"]).size().unstack(fill_value=0)
        usage_pct = usage_cnt.div(usage_cnt.sum(axis=1), axis=0) * 100
        usage_fig = _matrix_fig(usage_pct, usage_cnt, "Pitch Usage % by Bucket", 'YlOrBr')

    # ----------------------------------------------------------------
    # C)  VELO matrix  -----------------------------------------------
    df_v = df[df["PitchType"] == ptype] if ptype else df
    if df_v.empty:
        velo_fig = go.Figure()
    else:
        by_bucket = df_v.groupby(["Bucket", "PitchType"])
        velo_cnt = by_bucket.size().unstack(fill_value=0)
        velo_val = by_bucket["Velo"].mean().unstack()
        velo_fig = _matrix_fig(velo_val, velo_cnt, "Avg Velo by Bucket", 'Viridis')

        # The shared helper labels every value with "%"; velocities are not
        # percentages, so replace the cell text with the plain number.
        aligned_cnt = velo_cnt.reindex(index=velo_val.index,
                                       columns=velo_val.columns,
                                       fill_value=0)
        velo_text = [
            [
                f"{v:.1f}<br>({int(n)})" if n and not pd.isna(v) else ""
                for v, n in zip(val_row, cnt_row)
            ]
            for val_row, cnt_row in zip(velo_val.to_numpy(), aligned_cnt.to_numpy())
        ]
        velo_fig.update_traces(text=velo_text)

    return (freq_fig, csw_fig, usage_fig, velo_fig)


# 7) RUN  ────────────────────────────────────────────────────────────
# Presumably clears this cell's earlier output before the app is shown — confirm.
colab_out.clear()
# Start the app on port 8050.
# NOTE(review): `mode="inline"` is the keyword used by the older JupyterDash
# run API; current Dash documents `jupyter_mode=` on `app.run` — confirm
# against the installed Dash version.
app.run(mode="inline", port=8050)

<IPython.core.display.Javascript object>