# A test bench for the functional part of an MLP

In [None]:
import torch
import torch.nn as nn
import math

# Simple MLP test bench
# This notebook defines a small two-layer feed-forward MLP (Linear -> activation -> Linear)
# and runs basic functional tests (shape checks, dtype checks, finite outputs).

class SimpleMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1 -> act -> fc2``.

    Both projections are ``nn.Linear`` layers, so they act on the last
    dimension of the input; any leading dimensions are passed through.

    Args:
        input_dim: size of the input feature (last) dimension.
        hidden_dim: width of the intermediate projection.
        output_dim: size of the output feature (last) dimension.
        activation: module applied between the two projections. ``None``
            selects a freshly constructed ``nn.GELU()``.
    """

    def __init__(self, input_dim=768, hidden_dim=3072, output_dim=768, activation=None):
        super().__init__()
        chosen_act = nn.GELU() if activation is None else activation
        # Submodules are registered in the order fc1, act, fc2 so that
        # named_modules()/state_dict() ordering is stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = chosen_act
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply fc1, the activation and fc2 along the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

# Small helper for printing model summary
def model_summary(model, input_shape=None, dtype=torch.float32):
    """Run one random forward pass through ``model`` and print input/output shapes.

    Args:
        model: callable (normally an ``nn.Module``) taking a single tensor.
        input_shape: shape of the random probe input. When ``None``,
            ``(1, 50, D)`` is used, where ``D`` is ``model.fc1.in_features``
            if the model exposes it and 768 otherwise.
        dtype: dtype of the random probe input.

    Returns:
        The output shape as a tuple (the same value that is printed).
    """
    if input_shape is None:
        # Only a missing ``fc1`` / ``in_features`` triggers the 768 fallback;
        # any other error now propagates instead of being silently swallowed.
        input_dim = getattr(getattr(model, 'fc1', None), 'in_features', 768)
        input_shape = (1, 50, input_dim)

    # Create the probe on the same device as the model's weights so a model
    # that was moved (e.g. to CUDA) does not fail with a device mismatch.
    device = None
    if isinstance(model, nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    x = torch.randn(*input_shape, dtype=dtype, device=device)
    with torch.no_grad():
        out = model(x)
    print(f"Input shape: {input_shape}, Output shape: {tuple(out.shape)}, dtype: {out.dtype}")
    return tuple(out.shape)


# Load the checkpoint weights for the action_time MLP under test

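This cell loads a safetensors checkpoint (if present) and picks out any parameters related to the action_time MLP, printing their keys and shapes for inspection. It then copies the matching weights into a `SimpleMLP` and runs a sample forward pass, printing the shape and dtype of each intermediate tensor.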

In [None]:
from safetensors.torch import load_file
import os
import torch

ALL_WEIGHTS_PATH = "../weights/downloads/model.safetensors"

# Layer sizes of the action_time MLP under test (fc1: 1440 -> 720, fc2: 720 -> 720).
# NOTE(review): SimpleMLP inserts its default GELU between fc1 and fc2; confirm this
# matches the activation of the reference action_time MLP before comparing numerics.
cfg = {'input_dim': 1440, 'hidden_dim': 720, 'output_dim': 720}

if not os.path.exists(ALL_WEIGHTS_PATH):
    print(f"Checkpoint not found at {ALL_WEIGHTS_PATH}. If you have a checkpoint, set ALL_WEIGHTS_PATH to its path.")
    state = None
else:
    state = load_file(ALL_WEIGHTS_PATH)
    print(f"Loaded checkpoint from {ALL_WEIGHTS_PATH} with {len(state)} keys")

# Built once for both paths: with a checkpoint, matching weights are copied in
# below; without one, the randomly initialised layers are exercised as-is.
model = SimpleMLP(**cfg)


def _try_copy(src_key, target_tensor):
    """Copy ``state[src_key]`` into ``target_tensor`` in place if the shapes match.

    Prints what happened. Returns True on success and False on a shape mismatch.
    In-place writes into parameters that require grad must happen under
    ``torch.no_grad()``, which is how this helper is called below.
    """
    t = torch.as_tensor(state[src_key], dtype=torch.float32)
    if tuple(t.shape) != tuple(target_tensor.shape):
        print(f'Shape mismatch {src_key} {t.shape} vs target {tuple(target_tensor.shape)}')
        return False
    target_tensor.copy_(t.to(target_tensor.dtype))
    print(f'Copied {src_key} -> target shape {t.shape}')
    return True


if state is not None:
    # find keys relevant to action_time MLP
    action_keys = [k for k in state.keys() if 'action_time_mlp' in k]
    print('Found action_time-related keys:', action_keys)

    # Short names for inspection: '...action_time_mlp_in.weight' -> 'in.weight',
    # '...action_time_mlp_out.bias' -> 'out.bias'; other keys keep their full name.
    lw = {}
    for k in action_keys:
        if 'action_time_mlp_in' in k:
            short = k.split('action_time_mlp_in')[-1].lstrip('.')
            lw['in' + ('.' + short if short else '')] = state[k]
        elif 'action_time_mlp_out' in k:
            short = k.split('action_time_mlp_out')[-1].lstrip('.')
            lw['out' + ('.' + short if short else '')] = state[k]
        else:
            lw[k] = state[k]

    print('\nNormalized keys available:')
    for k in lw:
        print(k, getattr(lw[k], 'shape', None), getattr(lw[k], 'dtype', None))

    # Candidate checkpoint keys per parameter. A substring test already covers
    # keys that end with the tag, so no separate endswith() check is needed.
    possible_in_w = [k for k in state.keys() if 'action_time_mlp_in.weight' in k]
    possible_in_b = [k for k in state.keys() if 'action_time_mlp_in.bias' in k]
    possible_out_w = [k for k in state.keys() if 'action_time_mlp_out.weight' in k]
    possible_out_b = [k for k in state.keys() if 'action_time_mlp_out.bias' in k]

    # (candidate keys, destination parameter, name used in the report)
    copy_plan = [
        (possible_in_w, model.fc1.weight, 'fc1.weight'),
        (possible_in_b, model.fc1.bias, 'fc1.bias'),
        (possible_out_w, model.fc2.weight, 'fc2.weight'),
        (possible_out_b, model.fc2.bias, 'fc2.bias'),
    ]
    copied = False    # True if at least one parameter was loaded
    not_copied = []   # parameters left at their random initialisation
    with torch.no_grad():
        for candidates, param, param_name in copy_plan:
            # Only the first matching key is tried for each parameter.
            if candidates and _try_copy(candidates[0], param):
                copied = True
            else:
                not_copied.append(param_name)

    if not copied:
        print('No compatible action_time MLP parameters were copied. You may need to inspect key names or provide a different checkpoint.')
    elif not_copied:
        # Make a partial load visible; otherwise it looks like a full one.
        print(f'Warning: partial load; these parameters keep their random init: {not_copied}')

# Run one forward pass layer by layer so each intermediate shape/dtype can be printed.
x = torch.randn(1, 10, cfg['input_dim'])
with torch.no_grad():
    fc1_out = model.fc1(x)
    act_out = model.act(fc1_out)
    fc2_out = model.fc2(act_out)
    final_out = fc2_out

print('\nForward tensors:' if state is not None else 'No checkpoint loaded; sample forward tensors:')
for _label, _tensor in (
    ('input', x),
    ('fc1_out', fc1_out),
    ('act_out', act_out),
    ('fc2_out', fc2_out),
    ('final_out', final_out),
):
    print(f'  {_label}:', _tensor.shape, _tensor.dtype)
