# Ablation experiment

**Goal of the experiment:** silence (set to zero) or randomize one of the two coordinate sets (Cartesian or polar) in the trained agent's input, and measure how this affects its behavior.

Potential metrics:
- performance histogram
- % correct
- shift in behavior
- number of steps per episode

In [28]:
from pathlib import Path
import numpy as np
import torch
from utils import make_deterministic, random_choice
from agent import EpsilonGreedy, neural_network
import utils
from environment import CONTEXTS_LABELS, Actions, Cues, DuplicatedCoordsEnv

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Root directory holding the saved training runs.
save_path = Path("..", "save")
save_path.exists()

True

In [4]:
# Training run to analyse. The LeftRight run is stored in
# "2025-03-08_01-47-50_LeftRight_save-all-agents" if you want to switch.
data_dir = save_path.joinpath("2025-03-08_01-44-12_EastWest_save-all-agents")
data_dir.exists()

True

In [5]:
# Archive with the objects saved alongside the run.
data_path = data_dir.joinpath("data.tar")
data_path.exists()

True

In [6]:
# State dict of agent 0 (presumably one of several agents saved by this run).
model_path = data_dir.joinpath("trained-agent-state-0.pt")
model_path.exists()

True

In [34]:
# `weights_only=False` is needed because the archive holds full Python objects
# (e.g. the `net` module), not just tensors. This unpickles arbitrary code,
# so only load archives you produced yourself.
data_dict = torch.load(data_path, weights_only=False, map_location=DEVICE)

# Unpack the run parameters (`p`), the environment and the network.
p, env, net = data_dict["p"], data_dict["env"], data_dict["net"]


## Inference loop

In [38]:
# Overwrite the network's parameters with the trained agent's weights.
trained_state_dict = torch.load(model_path, weights_only=True, map_location=DEVICE)
net.load_state_dict(trained_state_dict)
net

DQN(
  (mlp): Sequential(
    (0): Linear(in_features=19, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=512, bias=True)
    (6): ReLU()
    (7): Linear(in_features=512, out_features=3, bias=True)
  )
)

In [55]:
def ablate_state(state, keep, silence=True):
    """Return a copy of ``state`` with one coordinate block ablated.

    Parameters
    ----------
    state : torch.Tensor
        Observation whose last dimension has 19 features (the network's input
        size). Index 0 is never touched. Indices 1-8 form the block kept by
        ``keep="cartesian"`` and indices 9-18 the block kept by
        ``keep="polar"``. Batched input of shape ``(..., 19)`` is accepted.
    keep : {"cartesian", "polar"}
        Coordinate set to keep. The *other* block is ablated.
    silence : bool, default True
        If True, the ablated block is set to 0. Otherwise it is replaced with
        uniform noise in [0, 1) from ``torch.rand``.

    Returns
    -------
    torch.Tensor
        A new tensor. ``state`` itself is left unmodified.

    Raises
    ------
    ValueError
        If ``keep`` is neither "cartesian" nor "polar".
    """
    if keep == "cartesian":
        dropped = slice(9, 19)  # drop the polar block
    elif keep == "polar":
        dropped = slice(1, 9)  # drop the Cartesian block
    else:
        raise ValueError("The state to keep can only be either 'polar' or 'cartesian'")

    # Work on a copy: assigning into `state` directly would silently corrupt
    # the caller's tensor (e.g. the un-ablated state used to step the env).
    new_state = state.clone()
    if silence:
        new_state[..., dropped] = 0
    else:
        noise_shape = new_state[..., dropped].shape
        new_state[..., dropped] = torch.rand(noise_shape, device=state.device)
    return new_state

In [85]:
# Start a fresh episode and format the observation like the network's input.
state = env.reset().clone().float().detach().to(DEVICE)
state

tensor([ 0.0000,  2.0000,  2.0000, -0.0000, -1.0000,  2.0000,  2.0000,  0.0000,
         1.0000,  2.8284,  0.7071,  0.7071, -0.7071, -0.7071,  2.8284,  0.7071,
         0.7071,  0.7071,  0.7071])

In [86]:
# Coordinate set the network is allowed to see: "cartesian" or "polar".
keep = "cartesian"

In [87]:
# Inspect what the network would receive for the initial state.
state_ablated = ablate_state(state, keep, silence=True)
state_ablated

tensor([ 0.,  2.,  2., -0., -1.,  2.,  2.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.])

In [None]:
# Evaluate the trained agent on ablated observations. No learning happens here:
# the network's weights are left untouched.
# The network only sees the ablated state. The environment is still stepped
# with the true state, so the task itself is unchanged.
SILENCE = True  # True: zero the dropped coordinates; False: replace them with noise
N_EPISODES = p.total_episodes
# Safety cap: a greedy policy fed degraded input can get stuck and never reach
# a terminal state. Raise this if legitimate episodes can be longer.
MAX_STEPS = 1_000

net.eval()
rewards = np.zeros(N_EPISODES)
steps = np.zeros(N_EPISODES, dtype=int)
truncated = np.zeros(N_EPISODES, dtype=bool)  # episodes stopped by MAX_STEPS
all_states = [[] for _ in range(N_EPISODES)]
all_actions = [[] for _ in range(N_EPISODES)]

with torch.no_grad():
    for episode in range(N_EPISODES):
        state = env.reset().clone().float().detach().to(DEVICE)
        total_reward = 0.0
        step_count = 0
        done = False

        while not done and step_count < MAX_STEPS:
            # Clone so the true `state` is never altered by the ablation.
            state_ablated = ablate_state(
                state=state.clone(), keep=keep, silence=SILENCE
            )
            # Greedy action: this is evaluation, so no epsilon exploration.
            action = net(state_ablated).argmax().item()

            all_states[episode].append(state.cpu())
            all_actions[episode].append(Actions(action).name)

            next_state, reward, done = env.step(action=action, current_state=state)
            total_reward += float(reward)
            step_count += 1
            state = next_state

        all_states[episode].append(state.cpu())
        rewards[episode] = total_reward
        steps[episode] = step_count
        truncated[episode] = not done

{
    "mean_reward": rewards.mean(),
    "mean_steps": steps.mean(),
    "n_truncated": int(truncated.sum()),
}