<a href="https://colab.research.google.com/github/MatiasNazareth1993-coder/Virtual-cell/blob/main/Untitled89.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# 🔺 3D Attention Visualization (animated surface over time)
# Paste this after training and after the model/env exist.
# Requires: plotly
# ============================================================

!pip install -q plotly

import numpy as np
import torch
import plotly.graph_objects as go
from torch.nn import MultiheadAttention
import time
import os

# ---------- Settings ----------
# NOTE(review): NUM_FRAMES is not referenced anywhere in the visible code; the
# rollout length is driven by CAPTURE_STEPS instead — confirm it is still needed.
NUM_FRAMES = 120    # how many time steps/frames to capture (or set to rollout length)
# Path of the standalone interactive HTML file that the figure is written to.
OUTPUT_HTML = "attention_3d_animation.html"
# Requested number of environment steps for the capture rollout; the rollout
# additionally caps this value at 3000 as a safety limit.
CAPTURE_STEPS = 120  # number of env steps to capture attention for

# ---------- Utility: try to find a MultiheadAttention inside extractor ----------
def find_mha_modules(module):
    """Return every ``MultiheadAttention`` layer found inside ``module``.

    The search walks ``module.modules()``, so ``module`` itself is included
    when it is a ``MultiheadAttention`` instance, and the result follows the
    traversal order of that iterator.

    Args:
        module: A ``torch.nn.Module`` to search.

    Returns:
        A list of the matching sub-modules; empty when there are none.
    """
    attention_layers = []
    for submodule in module.modules():
        if isinstance(submodule, MultiheadAttention):
            attention_layers.append(submodule)
    return attention_layers

# Feature extractor of the trained policy; attention layers are searched inside it.
extractor = model.policy.features_extractor  # use your trained model's extractor
mha_modules = find_mha_modules(extractor)

# Tell the user whether there is anything to visualise.
if mha_modules:
    print(f"Found {len(mha_modules)} MultiheadAttention module(s). Using the first one for visualization.")
else:
    print("‚ö†Ô∏è No MultiheadAttention modules found in extractor; cannot extract attention weights.")

# Only the first attention layer is inspected; None when the extractor has none.
# (With several layers/heads one could average across them instead.)
target_mha = next(iter(mha_modules), None)

# ---------- Capture attention weights over a rollout ----------
# One 2-D attention matrix (averaged over heads) is appended per captured step.
attn_time_series = []

# Forward-hook storage: holds the attention weights observed on the most recent
# forward pass of the hooked module, or None when nothing could be captured.
captured = {"weights": None}


def hook_fn(module, input, output):
    """Forward hook that records a module's attention weights in ``captured``.

    Lookup order:
      1. ``module.attn_output_weights`` (if present and not None),
      2. ``module.attn_weights`` (if present and not None),
      3. the second element of ``output`` when ``output`` is a tuple/list.
         ``torch.nn.MultiheadAttention.forward`` returns
         ``(attn_output, attn_output_weights)`` and does not keep the weights
         as a module attribute, so this is the path that works for the stock
         layer. The weights element is None when the caller passed
         ``need_weights=False``.

    Args:
        module: The module the hook is attached to.
        input: Positional inputs of the forward call (unused).
        output: Value returned by the module's forward call.

    Side effects:
        Sets ``captured["weights"]`` to a NumPy array of the weights, or to
        None when no weights are available.
    """
    weights = None
    # Attribute-based lookups are kept for custom layers that cache their weights.
    for attr_name in ("attn_output_weights", "attn_weights"):
        candidate = getattr(module, attr_name, None)
        if candidate is not None:
            weights = candidate
            break

    # Standard MultiheadAttention: weights are the second item of the output tuple.
    if weights is None and isinstance(output, (tuple, list)) and len(output) > 1:
        weights = output[1]

    if weights is not None:
        weights = weights.detach().cpu().numpy()
    captured["weights"] = weights

# Attach the forward hook to the selected attention layer (if there is one).
hook_handle = None
if target_mha is not None:
    try:
        hook_handle = target_mha.register_forward_hook(hook_fn)
    except Exception as e:
        print("Could not register hook on MultiheadAttention:", e)
        hook_handle = None

# Observations must live on the same device as the extractor's weights,
# otherwise the forward pass fails when the model was trained on GPU.
try:
    extractor_device = next(extractor.parameters()).device
except StopIteration:
    extractor_device = torch.device("cpu")

# Run a rollout and run extractor on each observation to get per-step attention.
obs, _ = env.reset()
steps_to_capture = min(CAPTURE_STEPS, 3000)  # safety cap
try:
    for step in range(steps_to_capture):
        # Clear the slot so weights left over from the previous step's
        # model.predict() call cannot be mistaken for this step's capture.
        captured["weights"] = None

        # call extractor to trigger attention computation
        with torch.no_grad():
            obs_t = torch.tensor(
                obs, dtype=torch.float32, device=extractor_device
            ).unsqueeze(0)  # adds a leading batch dimension
            _ = extractor(obs_t)  # this should run the transformer -> hook captures weights

        w = captured.get("weights", None)
        if w is None:
            # Reconstructing attention manually is fragile, so a NaN matrix is
            # recorded instead to make the failed capture visible downstream.
            print(f"Step {step}: warning ‚Äî no attention weights captured, recording NaN matrix.")
            Ncells = NUM_CELLS
            attn_time_series.append(np.full((Ncells, Ncells), np.nan))
        else:
            # 3-D weights are reduced over their leading axis; 2-D weights are used as-is.
            if w.ndim == 3:
                avg = w.mean(axis=0)
            elif w.ndim == 2:
                avg = w
            else:
                raise ValueError("Unexpected attention weights shape: " + str(w.shape))
            attn_time_series.append(avg)

        # step environment with deterministic policy to evolve state
        action, _ = model.predict(obs, deterministic=True)
        # Gymnasium's step() returns separate `terminated` and `truncated`
        # flags; the episode is over when either one is set.
        obs, r, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        if done:
            break
finally:
    # Always detach the hook, even if the rollout raised, so it does not keep
    # firing on later forward passes of the model.
    if hook_handle is not None:
        hook_handle.remove()

if not attn_time_series:
    raise RuntimeError("No attention frames were collected; aborting visualization.")

# ---------- Prepare frames for Plotly ----------
# Stack the per-step matrices into a single array of shape (T, N, M):
# T frames, N rows, M columns per frame.
attn_arr = np.stack(attn_time_series, axis=0)
T, N, M = attn_arr.shape
print(f"Captured attention array: frames={T}, shape per frame={N}x{M}")

# Frames recorded as NaN (capture failed for that step) are patched with the
# value at the same (row, col) position of the first fully valid frame.
nan_mask = np.isnan(attn_arr)
if nan_mask.any():
    # indices of frames that contain no NaN at all
    valid_idx = np.where(~nan_mask.reshape(T, -1).any(axis=1))[0]
    if len(valid_idx) == 0:
        print("‚ö†Ô∏è All captured frames are NaN ‚Äî attention capture failed.")
    else:
        fill = attn_arr[valid_idx[0]]
        # np.where lists the NaN positions in the same row-major order that
        # boolean-mask assignment uses, so the two sides line up one-to-one.
        _, nan_rows, nan_cols = np.where(nan_mask)
        attn_arr[nan_mask] = fill[nan_rows, nan_cols]  # crude fill

# Normalize each frame to [0, 1] for better surface scaling; the epsilon
# guards against division by zero on constant frames.
attn_min = attn_arr.min(axis=(1, 2), keepdims=True)
attn_max = attn_arr.max(axis=(1, 2), keepdims=True)
attn_norm = (attn_arr - attn_min) / (np.maximum(attn_max - attn_min, 1e-9))

# Surface grid: columns (source cells) along x, rows (querying cells) along y.
# x must span the column count M so non-square frames still plot correctly.
x = np.arange(M)  # source cells
y = np.arange(N)  # querying cells

# ---------- Build plotly frames ----------
# One animation frame per captured step; each frame replaces trace 0 with the
# normalized attention surface for that step.
frames = [
    go.Frame(
        data=[go.Surface(z=attn_norm[t], x=x, y=y, cmin=0, cmax=1, showscale=False)],
        name=str(t),
        traces=[0],
    )
    for t in range(T)
]

# initial surface
init_z = attn_norm[0]

# Play / Pause controls for the animation.
play_button = dict(
    label="Play",
    method="animate",
    args=[None, {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}],
)
pause_button = dict(
    label="Pause",
    method="animate",
    args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
)

# Slider with one stop per frame; each stop jumps straight to the frame whose
# name matches its label.
slider_steps = [
    {
        "args": [[str(k)], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
        "label": str(k),
        "method": "animate",
    }
    for k in range(T)
]

fig = go.Figure(
    data=[go.Surface(z=init_z, x=x, y=y, cmin=0, cmax=1, colorscale="Viridis", showscale=True)],
    layout=go.Layout(
        title="3D Attention Surface (per-cell interactions) ‚Äî animated",
        scene=dict(
            xaxis=dict(title="Source Cell (attention to)"),
            yaxis=dict(title="Querying Cell (attention from)"),
            zaxis=dict(title="Normalized Attention", range=[0, 1]),
        ),
        updatemenus=[
            dict(
                type="buttons",
                buttons=[play_button, pause_button],
                pad={"r": 10, "t": 10},
                showactive=True,
                x=0.1,
                xanchor="right",
                y=0,
                yanchor="top",
            )
        ],
        sliders=[{
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": slider_steps,
        }],
    ),
    frames=frames,
)

# Save interactive HTML
fig.write_html(OUTPUT_HTML)
print(f"Saved interactive 3D attention animation to: {OUTPUT_HTML}")

# If in a notebook (e.g. Colab), also display inline. This stays best-effort —
# the saved HTML file is the primary output — but failures are reported
# instead of being silently swallowed.
try:
    from IPython.display import HTML, display
    display(HTML(fig.to_html(full_html=False, include_plotlyjs='cdn')))
except ImportError:
    print(f"IPython is not available; open {OUTPUT_HTML} to view the animation.")
except Exception as exc:
    print(f"Inline display failed ({exc!r}); open {OUTPUT_HTML} to view the animation.")

# ---------- Optional: export to static mp4/gif (requires extra tools) ----------
# To save an MP4/GIF, you'd need to render frames to images and stitch them.
# Example (Colab): pip install imageio kaleido && use fig.to_image for each frame, then imageio.mimsave
# This can be slow and sometimes limited by server support for kaleido/ffmpeg.

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from gymnasium import spaces

class CustomExtractor(BaseFeaturesExtractor):
    """Features extractor that delegates to an attention-reporting transformer.

    All feature computation is forwarded to an
    ``AttentionReportingTransformerExtractor`` built from the observation
    space; this wrapper additionally exposes that extractor's most recent
    attention weights through ``get_attention``.

    NOTE(review): ``features_dim`` is not passed to the inner extractor, so its
    output width is presumably expected to equal ``features_dim`` — confirm.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 64):
        """Build the wrapper.

        Args:
            observation_space: Observation space handed to the base class and
                to the inner extractor.
            features_dim: Feature size reported to the base class.
        """
        super().__init__(observation_space, features_dim)
        self._features_dim = features_dim
        self.extractor = AttentionReportingTransformerExtractor(observation_space)

    def forward(self, observations):
        """Return the inner extractor's output for ``observations``."""
        return self.extractor(observations)

    def get_attention(self):
        """Return the inner extractor's ``last_attention``, or None if unset.

        The original note describes the value as an (N x N) matrix of the
        latest attention weights; the shape cannot be verified from here.
        """
        inner = self.extractor
        return inner.last_attention if hasattr(inner, "last_attention") else None

In [None]:
# Plug the custom attention-reporting extractor into the policy and request
# 64-dimensional features from it.
policy_kwargs = {
    "features_extractor_class": CustomExtractor,
    "features_extractor_kwargs": {"features_dim": 64},
}

# PPO agent with an MLP policy on top of the custom features extractor.
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

In [None]:
# Roll the policy out for up to 100 steps and report the attention weights the
# features extractor stored during each prediction.
obs, _ = env.reset()
for step in range(100):
    action, _ = model.predict(obs, deterministic=True)
    # Gymnasium's step() returns separate `terminated` and `truncated` flags.
    obs, rewards, terminated, truncated, info = env.step(action)
    done = terminated or truncated

    # Retrieve per-step attention weights (N_cells x N_cells)
    attn = model.policy.features_extractor.get_attention()
    if attn is not None:
        print(f"Step {step} attention shape: {attn.shape}")

    # Stop once the episode is over; stepping a finished environment without
    # a reset is not valid.
    if done:
        break