# In [None]:
import numpy as np
import plotly.graph_objects as go

class WoundViz:
    """Interactive 3-D surface of expected wounds over a Strength x Toughness grid.

    Every surface value is ``dice * P(hit) * P(wound) * P(save fails) *
    P(FNP fails)``.  All (hit, save, FNP) combinations are pre-computed in
    ``self.cache`` at construction; the figure's controls only switch between
    these cached surfaces.
    """
    # ------------------------------------------------------------------ #
    # INITIALISE
    # ------------------------------------------------------------------ #
    def __init__(self, dice: int = 10, S_rng=(3, 11), T_rng=(3, 11)):
        """Build the axes and pre-compute every surface.

        dice         : number of attacks; scales every surface linearly.
        S_rng, T_rng : ``(start, stop)`` passed to ``np.arange`` (stop is
                       exclusive) to build the Strength / Toughness axes.
        """
        self.dice = dice
        self.S = np.arange(*S_rng)         # Strength axis
        self.T = np.arange(*T_rng)         # Toughness axis
        # default 'xy' meshgrid: rows follow T, columns follow S (x=S, y=T in the plot)
        self.Sg, self.Tg = np.meshgrid(self.S, self.T)

        # discrete values the controls toggle through
        self.save_vals = [2, 3, 4, 5, 6]           # armour save
        self.fnp_vals = [None, 6, 5, 4]            # None / 6+ / 5+ / 4+
        self.hit_vals = [2, 3, 4, 5, 6]            # 2+ … 6+

        # The wound roll depends only on S and T, so evaluate it once rather
        # than once per cached combination.
        self._wounds = self.wound_prob(self.Sg, self.Tg)

        # pre-compute every combination once; the figure only looks them up
        self.cache = {}
        for hit in self.hit_vals:
            h_prob = self.val_to_prob(hit)
            for sv in self.save_vals:
                sv_fail = self.fail_prob(sv)
                for fnp in self.fnp_vals:
                    self.cache[(hit, sv, fnp)] = self._grid(h_prob, sv_fail, self.fail_prob(fnp))

    # ------------------------------------------------------------------ #
    #   DICE MATH HELPERS
    # ------------------------------------------------------------------ #
    @staticmethod
    def val_to_prob(v: int) -> float:           # e.g. 3  → 0.667
        """Return the probability of rolling ``v`` or higher on a d6."""
        return (7 - v) / 6

    def fail_prob(self, v) -> float:            # None → 1.0 (always fail)
        """Return the probability that a ``v``+ roll fails; ``None`` (no roll) always fails."""
        return 1.0 if v is None else 1 - self.val_to_prob(v)

    @staticmethod
    def wound_prob(S, T):
        """Return the wound-roll success probability for Strength ``S`` vs Toughness ``T``.

        2+ if S >= 2T, 3+ if S > T, 4+ if S == T, 5+ if 2S > T, otherwise 6+.
        Integer comparisons are used instead of the float ratio ``S/T``.
        Scalars give a Python float; arrays are evaluated element-wise and
        return an array of the broadcast shape (empty arrays are fine).
        """
        S, T = np.asarray(S), np.asarray(T)
        p = np.where(S >= 2 * T, 5 / 6,
            np.where(S > T, 4 / 6,
            np.where(S == T, 3 / 6,
            np.where(2 * S > T, 2 / 6, 1 / 6))))
        return float(p) if p.ndim == 0 else p

    def _grid(self, hit_p, sv_fail, fnp_fail):
        """Return the expected-wounds surface on the S x T grid for one setting."""
        return self.dice * hit_p * self._wounds * sv_fail * fnp_fail

    # ------------------------------------------------------------------ #
    #   PLOT CONSTRUCTION
    # ------------------------------------------------------------------ #
    def _controls(self, hit0, sv0, fnp0):
        """Build slider / button variants that keep all three controls in sync.

        Each step or button carries a fixed argument payload. A control that
        only swapped ``z`` would therefore plot baked-in values for the other
        two settings while their widgets still showed something else.

        To avoid that, one variant of each control is built for every
        combination of the other two settings, and exactly one variant per
        control is visible. Every step or button plots the new cached surface,
        reveals the matching variants of the other two controls, and sets
        their ``active`` index. The visible widgets therefore always describe
        the plotted surface.

        Layout indices (nh, ns, nf = number of hit, save, FNP values):
          sliders[hi*nf + fi]             save slider for (hit hi, FNP fi)
          sliders[nh*nf + hi*ns + si]     FNP slider  for (hit hi, save si)
          updatemenus[si*nf + fi]         hit buttons for (save si, FNP fi)

        Returns ``(sliders, updatemenus)`` as lists of plain dicts.
        Raises ValueError if a start value is not in the matching ``*_vals``.
        """
        hits, saves, fnps = self.hit_vals, self.save_vals, self.fnp_vals
        nh, ns, nf = len(hits), len(saves), len(fnps)
        h0, s0, f0 = hits.index(hit0), saves.index(sv0), fnps.index(fnp0)

        def save_sl(hi, fi):        # slider index of the save slider for (hit, FNP)
            return hi * nf + fi

        def fnp_sl(hi, si):         # slider index of the FNP slider for (hit, save)
            return nh * nf + hi * ns + si

        def hit_mn(si, fi):         # updatemenu index of the hit buttons for (save, FNP)
            return si * nf + fi

        def z_of(hi, si, fi):       # trace patch showing one cached surface
            return {"z": [self.cache[(hits[hi], saves[si], fnps[fi])]]}

        def reveal(kind, family, chosen, active):
            # hide every variant in `family` except `chosen`, then sync its active index
            patch = {f"{kind}[{i}].visible": i == chosen for i in family}
            patch[f"{kind}[{chosen}].active"] = active
            return patch

        sliders = []
        # save sliders: moving one keeps (hit, FNP) and changes the save
        for hi in range(nh):
            for fi in range(nf):
                steps = []
                for si, sv in enumerate(saves):
                    layout = reveal("sliders", [fnp_sl(hi, s) for s in range(ns)],
                                    fnp_sl(hi, si), fi)
                    layout.update(reveal("updatemenus", [hit_mn(s, fi) for s in range(ns)],
                                         hit_mn(si, fi), hi))
                    steps.append(dict(method="update", args=[z_of(hi, si, fi), layout],
                                      label=f"{sv}+"))
                sliders.append(dict(
                    active=s0, steps=steps, visible=(hi, fi) == (h0, f0),
                    y=0.55, len=0.18, x=0.82, pad=dict(t=20),
                    currentvalue=dict(prefix="Save ")))

        # FNP sliders: moving one keeps (hit, save) and changes the FNP
        for hi in range(nh):
            for si in range(ns):
                steps = []
                for fi, fnp in enumerate(fnps):
                    layout = reveal("sliders", [save_sl(hi, f) for f in range(nf)],
                                    save_sl(hi, fi), si)
                    layout.update(reveal("updatemenus", [hit_mn(si, f) for f in range(nf)],
                                         hit_mn(si, fi), hi))
                    label = "None" if fnp is None else f"{fnp}+"
                    steps.append(dict(method="update", args=[z_of(hi, si, fi), layout],
                                      label=label))
                sliders.append(dict(
                    active=f0, steps=steps, visible=(hi, si) == (h0, s0),
                    y=0.10, len=0.18, x=0.82, pad=dict(t=20),
                    currentvalue=dict(prefix="FNP ")))

        # hit buttons: pressing one keeps (save, FNP) and changes the hit roll
        menus = []
        for si in range(ns):
            for fi in range(nf):
                buttons = []
                for hi, h in enumerate(hits):
                    layout = reveal("sliders", [save_sl(k, fi) for k in range(nh)],
                                    save_sl(hi, fi), si)
                    layout.update(reveal("sliders", [fnp_sl(k, si) for k in range(nh)],
                                         fnp_sl(hi, si), fi))
                    buttons.append(dict(label=f"Hits on {h}+", method="update",
                                        args=[z_of(hi, si, fi), layout]))
                menus.append(dict(
                    type="buttons", buttons=buttons,
                    direction="down", showactive=True,
                    active=h0,                     # highlight the hit value actually plotted
                    visible=(si, fi) == (s0, f0),
                    x=0.82, y=0.95, xanchor="left"))

        return sliders, menus

    def figure(self, hit=3, save=3, fnp=None):
        """Return the interactive plotly figure.

        hit, save, fnp : settings shown initially (defaults: hits on 3+,
                         3+ save, no FNP). Each must be one of the
                         corresponding ``*_vals`` lists.
        Raises ValueError if a start value is not one of those values.
        """
        # build (and validate) the controls before touching the cache
        sliders, menus = self._controls(hit, save, fnp)

        fig = go.Figure(data=[go.Surface(
            z=self.cache[(hit, save, fnp)], x=self.S, y=self.T,
            colorscale='Viridis', showscale=False,
            hovertemplate="Strength: %{x}<br>Toughness: %{y}<br>Wounds: %{z:.2f}<extra></extra>",
        )])

        # the scene takes the left 80 % of the width; the controls sit in the
        # right 20 % (x=0.82, len=0.18 keeps the sliders inside the canvas)
        fig.update_layout(
            scene=dict(
                xaxis_title='Strength', yaxis_title='Toughness',
                zaxis_title='Expected Wounds',
                domain=dict(x=[0.0, 0.8]),
                aspectratio=dict(x=.75, y=.75, z=1),
            ),
            margin=dict(l=10, r=10, b=10, t=30),
            width=900, height=700,
            title="Expected wounds for {} attacks".format(self.dice),
            sliders=sliders,
            updatemenus=menus,
        )
        return fig

    # shortcut
    def run(self):
        """Return ``self.figure()`` with its default starting settings."""
        return self.figure()

# In [None]:
# Build the wound surface for a pool of 10 attacks and render it.
fig = WoundViz(dice=10).figure()
fig.show()
