In [9]:
import sys
from pathlib import Path
# Directory (relative to the current working directory) that is added to the
# import path; the project-local imports in the next cell presumably live here.
SRC_ROOT = "../src"
# Prepend it only when it is absent, so re-running this cell does not stack
# duplicate entries on sys.path.
if sys.path.count(str(SRC_ROOT)) == 0:
    sys.path.insert(0, str(SRC_ROOT))

In [10]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
import random
import gymnasium as gym
import itertools
import json
import os
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
from datetime import datetime
from collections import deque
from twc.twc_builder import build_twc
from twc.twc_io import mcc_obs_encoder, twc_out_2_mcc_action
from td3.td3_train import TD3Config, td3_train
from td3.td3_engine import TD3Engine
from utils import SequenceBuffer
from mlp.MLP_models import Critic
from utils.ou_noise import OUNoise

In [11]:
# --- Experiment identity ---
ENV = "MountainCarContinuous-v0"   # Gymnasium environment id handed to gym.make
SEED = 42                          # shared seed for numpy / torch / env

# --- Network shape ---
TWC_INTERNAL_STEPS = 1             # forwarded to build_twc(internal_steps=...)
CRITIC_HID_LAYERS = [400, 300]     # forwarded to Critic(size=...)

# --- TD3 engine hyper-parameters (forwarded to TD3Engine / the optimizers) ---
GAMMA = 0.99
TAU = 5e-3
ACTOR_LR = 0.00028007729801810964
CRITIC_LR = 0.004320799314236164
TARGET_POLICY_NOISE = 0.2
TARGET_POLICY_CLIP = 0.5

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# --- Training-loop configuration: start from the defaults, override a few ---
# NOTE(review): the printed config reports "device": "cpu" (the TD3Config
# default) while DEVICE above may resolve to CUDA -- confirm td3_train and
# TD3Engine agree on the device.
td3_config = TD3Config()
td3_config.max_episode = 2
td3_config.use_bptt = True
td3_config.sequence_length = 100
td3_config.burn_in_length = 10
td3_config.warmup_steps = 1000

In [12]:
# Show the effective training configuration (the overrides set above plus the
# remaining TD3Config defaults) as JSON, so the run can be reproduced later.
print(td3_config.to_json())

{
    "max_episode": 2,
    "max_time_steps": 999,
    "warmup_steps": 1000,
    "batch_size": 128,
    "num_update_loops": 2,
    "policy_delay": 1,
    "device": "cpu",
    "eval_interval_episodes": 10,
    "eval_episodes": 10,
    "sigma_start": 0.2,
    "sigma_end": 0.05,
    "sigma_decay_episodes": 100,
    "use_bptt": true,
    "sequence_length": 100,
    "burn_in_length": 10,
    "best_model_prefix": "td3_actor_best"
}


In [13]:
# Seed every RNG this notebook can reach: Python's `random`, NumPy, PyTorch,
# and the environment's reset / action-space samplers.
# `random` used to be imported but never seeded; any helper that draws from it
# (possibly the replay buffer's sampling -- TODO confirm) was therefore not
# reproducible between runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Environment and its dimensions.
env = gym.make(ENV)
env.reset(seed=SEED)
env.action_space.seed(SEED)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Actor: the TWC network, wired to the MountainCarContinuous encoder/decoder.
actor = build_twc(obs_encoder=mcc_obs_encoder,
                  action_decoder=twc_out_2_mcc_action,
                  internal_steps=TWC_INTERNAL_STEPS,
                  log_stats=True)

# Twin critics, as TD3 requires; both share the same architecture.
critic_1 = Critic(state_dim, action_dim, size=CRITIC_HID_LAYERS)
critic_2 = Critic(state_dim, action_dim, size=CRITIC_HID_LAYERS)

# Replay storage used by the BPTT training path.
buffer = SequenceBuffer(capacity=100_000)

# Exploration noise; its sigma schedule comes from the training config.
ou_noise = OUNoise(action_dimension=action_dim,
                   mu=0,
                   theta=0.15,
                   sigma=td3_config.sigma_start,
                   sigma_end=td3_config.sigma_end,
                   sigma_decay_epis=td3_config.sigma_decay_episodes)

# One optimizer for the actor, one shared optimizer over both critics.
actor_opt = torch.optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_opt = torch.optim.Adam(
    itertools.chain(critic_1.parameters(), critic_2.parameters()),
    lr=CRITIC_LR,
)

td3 = TD3Engine(gamma=GAMMA,
                tau=TAU,
                observation_space=env.observation_space,
                action_space=env.action_space,
                actor=actor,
                critic_1=critic_1,
                critic_2=critic_2,
                actor_optimizer=actor_opt,
                critic_optimizer=critic_opt,
                policy_delay=td3_config.policy_delay,
                target_policy_noise=TARGET_POLICY_NOISE,
                target_noise_clip=TARGET_POLICY_CLIP,
                device=DEVICE)


In [14]:
# Inspect the freshly built actor: prints every entry of its state dict
# (thresholds, decays, connection weights and masks -- see the output below).
print(actor.state_dict())

OrderedDict([('in_layer.threshold', tensor([0., 0., 0., 0.])), ('in_layer.decay', tensor([0.1000, 0.1000, 0.1000, 0.1000])), ('hid_layer.threshold', tensor([0., 0., 0., 0., 0.])), ('hid_layer.decay', tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000])), ('out_layer.threshold', tensor([0., 0.])), ('out_layer.decay', tensor([0.1000, 0.1000])), ('in2hid_IN.w', tensor([[ 0.8375,  0.9092, -0.2566,  1.0063, -0.2400],
        [ 0.2211, -0.5333,  0.6433,  0.9657, -0.8036],
        [ 0.9522,  0.2050,  0.8093,  0.1484,  0.5282],
        [-0.1547,  0.8445,  0.1619, -0.5114,  0.2792]])), ('in2hid_IN.w_mask', tensor([[1., 0., 1., 1., 0.],
        [1., 1., 0., 1., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 1., 0., 0.]])), ('in2hid_GJ.gj_w', tensor([ 0.4679, -0.2049])), ('in2hid_GJ.gj_idx', tensor([[1, 2],
        [2, 1]])), ('hid_IN.w', tensor([[-0.5047, -0.1285, -0.4449,  0.7267, -0.8647],
        [-0.5050, -0.3093, -0.6587,  0.1034, -1.0819],
        [ 0.9893, -0.9305,  0.8457,  0.1823, -0.3

In [15]:
# Pre-fill the replay buffer with transitions from a uniformly random policy:
# actions come from env.action_space.sample(), never from the actor.
obs, _ = env.reset()
actor.reset() # Reset stateful actor
for _ in range(10_000):
    action = env.action_space.sample()
    next_obs, reward, terminated, truncated, _ = env.step(action)
    buffer.store(obs, action, reward, next_obs, terminated, truncated)
    if not (terminated or truncated):
        # Episode continues: the next transition starts where this one ended.
        obs = next_obs
        continue
    # Episode finished: start a new one and clear the actor's internal state.
    obs, _ = env.reset()
    actor.reset()

In [None]:
# --- Parameter & gradient inspection helpers ---
def snapshot_params(module):
    """Copy every named parameter of ``module`` into an ordered mapping.

    Each value is a detached CPU clone, so later in-place updates to the
    module do not alter the snapshot.

    Returns:
        OrderedDict mapping parameter name -> tensor copy, in the order
        yielded by ``module.named_parameters()``.
    """
    frozen = OrderedDict()
    for param_name, param in module.named_parameters():
        frozen[param_name] = param.detach().cpu().clone()
    return frozen

def param_norm(t):
    """Return the norm of tensor ``t`` (``t.norm()``) as a plain Python number."""
    norm_tensor = t.norm()
    return norm_tensor.item()

def print_param_changes(before, after, grads, top_k=20):
    """Print a table showing how far each parameter moved between snapshots.

    Args:
        before: mapping name -> tensor with the values prior to the update.
        after: mapping name -> tensor with the values after the update; it
            must contain every key of ``before``.
        grads: mapping name -> gradient tensor (or None). Names that are
            absent or map to None are shown with a gradient norm of "None".
        top_k: number of rows printed, largest ``delta_norm`` first.

    The closing summary line adds up the per-tensor norms; it is not the L2
    norm of all parameters concatenated.
    """
    rows = []
    for name, old in before.items():
        delta_norm = (after[name] - old).norm().item()
        old_norm = old.norm().item()
        # The epsilon keeps the ratio finite for all-zero tensors.
        ratio = delta_norm / (old_norm + 1e-12)
        grad = grads.get(name)
        grad_norm = grad.norm().item() if grad is not None else None
        rows.append((name, old_norm, grad_norm, delta_norm, ratio))

    # Largest absolute update first.
    rows.sort(key=lambda r: r[3], reverse=True)

    print(f"{'param':60s} {'param_norm':>12s} {'grad_norm':>12s} {'delta_norm':>12s} {'update_ratio':>12s}")
    for name, pn, gn, dn, r in rows[:top_k]:
        gn_s = f"{gn:.3e}" if gn is not None else "None"
        # Right-align the gradient cell so it lines up with its header and
        # with the other numeric columns (it used to be left-aligned).
        print(f"{name:60s} {pn:12.3e} {gn_s:>12s} {dn:12.3e} {r:12.3e}")

    # Summary over all rows, not just the printed top_k.
    total_param = sum(r[1] for r in rows)
    total_delta = sum(r[3] for r in rows)
    print(f"\nTotal param norm: {total_param:.3e}, total update norm: {total_delta:.3e}, avg update ratio: {(total_delta/(total_param+1e-12)):.3e}")

# --- Register hooks and snapshot before update ---
# Device of the actor's first parameter (CPU when the actor has none).
_first_param = next(actor.parameters(), None)
actor_device = _first_param.device if _first_param is not None else torch.device("cpu")

# Clear any existing grads
for p in actor.parameters():
    if p.grad is not None:
        p.grad.detach_()
        p.grad.zero_()

before = snapshot_params(actor)

# name -> detached CPU copy of the most recent gradient seen for that
# parameter; every backward pass overwrites the previous entry.
grad_captures = {}
hook_handles = []


def make_hook(n):
    """Build a tensor hook that stores the incoming gradient under name ``n``."""
    def hook(g):
        # store a detached cpu clone of the grad tensor
        grad_captures[n] = g.detach().cpu().clone() if g is not None else None
    return hook


for name, p in actor.named_parameters():
    # register_hook raises on tensors that do not require grad; such
    # parameters never receive a gradient anyway, so skip them.
    if not p.requires_grad:
        continue
    h = p.register_hook(make_hook(name))
    hook_handles.append(h)

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f"twc_td3_bptt_test_{timestamp}"
log_dir = os.path.join("out", "runs", run_name)
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

# --- Run training as configured by td3_config ---
# This is the whole td3_train loop, not a single update: `before`/`after`
# therefore span every update performed, while grad_captures only keeps the
# last gradient seen per parameter.
# NOTE(review): `writer` is not closed in this cell -- confirm td3_train
# flushes/closes it.
try:
    td3_train(env=env,
              engine=td3,
              replay_buf=buffer,
              ou_noise=ou_noise,
              writer=writer,
              config=td3_config,
              timestamp=timestamp)
finally:
    # Always detach the hooks, even when training raises or is interrupted;
    # otherwise re-running this cell would stack duplicate hooks on the actor.
    for h in hook_handles:
        h.remove()

# --- After the update: snapshot & report ---
after = snapshot_params(actor)

# Print top parameter changes and gradient norms
print_param_changes(before, after, grad_captures, top_k=50)

  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Show the actor's module tree (layers and connection modules; see output).
print(actor)

TWC(
  (in_layer): FIURIModule()
  (hid_layer): FIURIModule()
  (out_layer): FIURIModule()
  (in2hid_IN): FiuriDenseConn()
  (in2hid_GJ): FiuriSparseGJConn()
  (hid_IN): FiuriDenseConn()
  (hid_EX): FiuriDenseConn()
  (hid2out): FiuriDenseConn()
)


In [None]:
# Trace activations through the network using one random observation:
# a batch of size 1 shaped like the environment's observation space.
obs_shape = env.observation_space.shape
test_obs = torch.randn(1, *obs_shape).to(DEVICE)
# Clear the actor's internal state before the probe forward pass.
actor.reset()

# Enable activation tracking
# Filled by the forward hooks: label -> captured module output.
activations = {}


def get_activation(name):
    """Build a forward hook that records a module's output under ``name``.

    When the module returns a tuple only its first element is kept; any other
    output is stored as-is. The stored value is not detached or copied.
    """
    def hook(model, input, output):
        captured = output[0] if isinstance(output, tuple) else output
        activations[name] = captured
    return hook

# Register forward hooks on every sub-module whose qualified name contains
# 'in2hid_IN', run one forward pass, and always detach the hooks afterwards.
hooks = []
try:
    for name, module in actor.named_modules():
        if 'in2hid_IN' in name:  # Track the specific module we care about
            h = module.register_forward_hook(get_activation(name))
            hooks.append(h)

    # Run a forward pass with autograd explicitly enabled so that `out`
    # carries a grad_fn for the path analysis further down.
    with torch.set_grad_enabled(True):
        out = actor(test_obs)
        if isinstance(out, tuple):
            out = out[0]
finally:
    # Clean up hooks even if the forward pass raised; leaked hooks would keep
    # firing on later forward calls and pile up when the cell is re-run.
    for h in hooks:
        h.remove()

# Summarise each captured activation: shape, basic statistics, and the
# share of entries that are exactly zero.
print("\n5. Activation Analysis:")
for name, act in activations.items():
    if not isinstance(act, torch.Tensor):
        # Only tensors have the statistics printed below.
        continue
    print(f"\n{name} activation stats:")
    print(f"shape: {act.shape}")
    for label, reducer in (("mean", act.mean), ("std", act.std), ("min", act.min), ("max", act.max)):
        print(f"{label}: {reducer().item():.2e}")
    zero_share = (act == 0).float().mean().item()
    print(f"fraction of zeros: {zero_share:.2%}")

# Describe the module(s) whose qualified name contains 'in2hid_IN': the class
# name plus shape, mean/std and requires_grad flag of each parameter.
print("\n6. Module Structure:")
for name, module in actor.named_modules():
    if 'in2hid_IN' not in name:
        continue
    print(f"\n{name}:")
    print(f"Module type: {type(module).__name__}")
    for pname, param in module.named_parameters():
        mean_val = param.mean().item()
        std_val = param.std().item()
        print(f"  Parameter '{pname}':")
        print(f"    shape: {param.shape}")
        print(f"    stats: mean={mean_val:.2e}, std={std_val:.2e}")
        print(f"    requires_grad: {param.requires_grad}")

# Examine the computational path
print("\n7. Network Path Analysis:")
def has_path_to_output(tensor, target, visited=None):
    """Report whether ``target`` takes part in the autograd graph of ``tensor``.

    The walk starts at ``tensor.grad_fn`` and follows ``next_functions``
    backwards through the graph. A leaf ``target`` (e.g. a parameter) is
    recognised through the ``variable`` attribute of its AccumulateGrad node;
    a non-leaf ``target`` is recognised through its own ``grad_fn`` node.

    The previous implementation compared graph nodes with the target tensor by
    identity and stopped at the first node (nodes have no ``grad_fn``
    attribute), so it returned False for every parameter.

    Args:
        tensor: the output tensor to start from (an autograd node is also
            accepted, for backward compatibility with the recursive form).
        target: the tensor to look for.
        visited: optional set of already-explored nodes; it is updated in
            place when supplied.

    Returns:
        True if ``target`` is ``tensor`` itself or is reachable from it,
        otherwise False.
    """
    if tensor is target:
        return True

    start = tensor.grad_fn if isinstance(tensor, torch.Tensor) else tensor
    target_fn = getattr(target, "grad_fn", None)
    seen = set() if visited is None else visited

    # Iterative walk: long BPTT graphs could exhaust the recursion limit.
    pending = [start]
    while pending:
        node = pending.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        if node is target_fn or getattr(node, "variable", None) is target:
            return True
        pending.extend(fn for fn, _ in getattr(node, "next_functions", ()))
    return False

# For every actor parameter whose name contains 'in2hid_IN', report whether
# has_path_to_output links it to the probe output `out`, together with the
# type of the output's grad_fn.
found = False
for name, p in actor.named_parameters():
    if 'in2hid_IN' not in name:
        continue
    found = True
    print(f"\nChecking computational path for {name}:")
    reachable = has_path_to_output(out, p)
    print(f"Has path to output: {reachable}")
    grad_fn_label = type(out.grad_fn).__name__ if out.grad_fn else None
    print(f"Grad function type: {grad_fn_label}")

if not found:
    print("Could not find in2hid_IN parameter!")


5. Activation Analysis:

in2hid_IN activation stats:
shape: torch.Size([1, 5])
mean: -3.96e+00
std: 5.62e+00
min: -1.20e+01
max: -0.00e+00
fraction of zeros: 60.00%

6. Module Structure:

in2hid_IN:
Module type: FiuriDenseConn
  Parameter 'w':
    shape: torch.Size([4, 5])
    stats: mean=3.01e-01, std=5.71e-01
    requires_grad: True

7. Network Path Analysis:

Checking computational path for in2hid_IN.w:
Has path to output: False
Grad function type: UnsqueezeBackward0
