In [None]:
import random
import torch
from torch import nn
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import keyboard
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import gymnasium as gym
from torch.distributions.normal import Normal
from tqdm import tqdm
from tqdm import trange
from maglev_env import MagneticEnv, DT

In [None]:
class Policy_Network(nn.Module):
    """Parametrized Gaussian policy network.

    A shared MLP trunk feeds two linear heads: one predicts the per-dimension
    action mean, the other the (pre-softplus) standard deviation.
    """

    def __init__(self, obs_space_dims: int, action_space_dims: int,
                 hidden_dims: tuple[int, ...] = (32, 32, 32)):
        """Initializes a neural network that estimates the mean and standard deviation
        of a normal distribution from which an action is sampled.

        Args:
            obs_space_dims: Dimension of the observation space
            action_space_dims: Dimension of the action space
            hidden_dims: Widths of the shared trunk's hidden layers, each followed by
                a ReLU. The default reproduces the original 3 x 32 architecture
                (same layer order and state_dict keys).
        """
        super().__init__()

        # Shared Network: Linear -> ReLU for every hidden width.
        layers = []
        in_dim = obs_space_dims
        for width in hidden_dims:
            layers += [nn.Linear(in_dim, width), nn.ReLU()]
            in_dim = width
        self.shared_net = nn.Sequential(*layers)

        # Policy Mean specific Linear Layer
        self.policy_mean_net = nn.Sequential(
            nn.Linear(in_dim, action_space_dims)
        )

        # Policy Std Dev specific Linear Layer.
        # No ReLU here: ReLU followed by softplus floored the std-dev at
        # log(2) ~= 0.69 (the policy could never become more precise) and zeroed
        # the gradient for negative pre-activations. Softplus alone already
        # guarantees a strictly positive std-dev.
        self.policy_stddev_net = nn.Sequential(
            nn.Linear(in_dim, action_space_dims)
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Conditioned on the observation, returns the mean and standard deviation
        for each normal distribution from which an action is sampled.

        Args:
            x: Observation(s) from the environment, last dimension = obs_space_dims.
               Cast to float32 before use.

        Returns:
            action_means: predicted means of the action space's normal distribution
            action_stddevs: predicted (strictly positive) standard deviations
        """
        shared_features = self.shared_net(x.float())

        action_means = self.policy_mean_net(shared_features)
        # softplus(z) = log(1 + exp(z)), but computed without overflowing to inf
        # for large z (the explicit exp() form returned inf above z ~ 88 in float32).
        action_stddevs = nn.functional.softplus(
            self.policy_stddev_net(shared_features)
        )

        return action_means, action_stddevs

class Policy:
    """REINFORCE (Monte-Carlo policy-gradient) agent with a Gaussian policy.

    Per environment step, ``sample_action`` samples an action and records its
    log-probabilities in ``self.probs``; the caller appends the matching reward to
    ``self.rewards``. ``update`` then performs one gradient step on the recorded
    episode and clears both buffers.
    """

    def __init__(self, obs_space_dims: int, action_space_dims: int, electromagnets, device="cpu"):
        """Initializes an agent that learns a policy via the REINFORCE algorithm.

        Args:
            obs_space_dims: Dimension of the *augmented* observation fed to the
                network (raw observation + 1 + 3 * len(electromagnets); see
                ``augment_obs``).
            action_space_dims: Dimension of the action space
            electromagnets: Sequence of electromagnet XYZ positions (array-like,
                length 3 each).
            device: Torch device for the network and the tensors it consumes.
        """
        self.action_space_dims = action_space_dims

        # Hyperparameters
        self.learning_rate = 1e-3  # Learning rate for policy optimization
        self.gamma = 0.99  # Discount factor
        self.eps = 1e-6  # added to the std-dev so the Normal scale is strictly positive

        self.probs = []  # One log-prob tensor (shape: action dims) per step
        self.rewards = []  # One reward per step, aligned 1:1 with self.probs
        self.device = device
        self.electromagnets = electromagnets
        self.net = Policy_Network(obs_space_dims, action_space_dims).to(device)
        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.learning_rate)

    def augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
        """Appends engineered features to a batch of raw observations.

        Raw layout per row: XYZ ball position (cols 0:3), ball velocity (cols 3:6),
        desired position (cols 6:9). Appended columns, in order:
          * 1 column: Euclidean distance between desired and current position.
          * 3 columns per electromagnet: offset (electromagnet position - ball position).

        Args:
            obs: Tensor of shape (batch, n_raw) with n_raw >= 9.

        Returns:
            Tensor of shape (batch, n_raw + 1 + 3 * len(self.electromagnets)).
        """
        position = obs[:, 0:3]
        desired = obs[:, 6:9]
        features = [obs, torch.linalg.norm(desired - position, dim=1, keepdim=True)]
        for magnet in self.electromagnets:
            # Previously written via aug_obs[10+3i:13+3i], which sliced *rows* of a
            # single-row tensor, so these columns were silently left at zero.
            magnet_pos = torch.as_tensor(magnet, dtype=obs.dtype, device=obs.device)
            features.append(magnet_pos.reshape(1, 3) - position)
        return torch.cat(features, dim=1)

    def sample_action(self, state: np.ndarray) -> np.ndarray:
        """Samples one action from the current Gaussian policy.

        Records the per-dimension log-probabilities of the sampled action as a
        single tensor in ``self.probs``, so ``self.probs`` stays aligned one-to-one
        with ``self.rewards`` regardless of the action dimension.

        Args:
            state: Raw observation from the environment, shape (n_raw,).
        Returns:
            action: float64 array of shape (action_space_dims,).
        """
        state_t = torch.as_tensor(np.array([state]), dtype=torch.float32, device=self.device)
        action_means, action_stddevs = self.net(self.augment_obs(state_t))

        # Independent Normal per action dimension; eps keeps the scale strictly positive.
        distrib = Normal(action_means.flatten(), action_stddevs.flatten() + self.eps)
        action = distrib.sample()
        self.probs.append(distrib.log_prob(action))

        return action.detach().cpu().numpy().astype(np.float64)

    def update(self):
        """Performs one REINFORCE gradient step on the recorded episode, then clears it.

        loss = -sum_t log pi(a_t | s_t) * G_t, where G_t is the discounted return
        from step t and log pi is the joint log-probability (sum over action dims).
        An episode with no recorded steps is a no-op.

        Raises:
            ValueError: if the number of recorded log-probabilities and rewards differ
                (they must be appended once per environment step).
        """
        if len(self.probs) != len(self.rewards):
            raise ValueError(
                f"probs/rewards misaligned: {len(self.probs)} log-probs vs "
                f"{len(self.rewards)} rewards"
            )
        if not self.rewards:
            return  # nothing recorded: no gradient to compute

        # Discounted return, accumulated backwards over the episode.
        returns = []
        running_g = 0.0
        for reward in reversed(self.rewards):
            running_g = reward + self.gamma * running_g
            returns.append(running_g)
        returns.reverse()

        log_probs = torch.stack([lp.sum() for lp in self.probs])
        returns_t = torch.as_tensor(returns, dtype=log_probs.dtype, device=log_probs.device)
        # minimize -1 * log_prob * return obtained
        loss = -(log_probs * returns_t).sum()

        # Update the policy network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Empty / zero out all episode-centric/related variables
        self.probs = []
        self.rewards = []

In [None]:
# --- Configuration: device, seeds, environment, and derived network sizes ---
# CPU is used on purpose: for this small network it was found faster than CUDA.
# To try the GPU instead: device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Using {device} device") 

RANDOM_SEED = 42
# Seed every RNG so environment construction below is reproducible
# (the training cell re-seeds each run with its own seed).
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Electromagnet XYZ positions, plus per-axis (min, max) ranges passed to
# env.reset(options=...) -- presumably the ball spawn region and the target
# region respectively; confirm against maglev_env.
mag_coords = [np.array([0, 0, 4])]
spawn_range = ((-0.1, 0.1), (-0.1, 0.1), (0, 1))
desired_range = ((0, 0), (0, 0), (0.5, 1.5))

# Create and wrap the environment; the wrapper records episode returns in
# `return_queue`, keeping the most recent EPISODE_STATS_WINDOW of them.
EPISODE_STATS_WINDOW = 50
env = MagneticEnv(mag_coords, DT)
wrapped_env = gym.wrappers.RecordEpisodeStatistics(env, EPISODE_STATS_WINDOW)

electro_positions = [e.position for e in env.electromagnets]

# Policy input = raw observation + 1 (position-error norm) + 3 per electromagnet,
# matching the features appended by Policy.augment_obs.
obs_space_dims = env.observation_space.shape[0] + 1 + 3 * len(electro_positions)
action_space_dims = env.action_space.shape[0]

rewards_over_seeds = []  # one list of per-episode returns per training run

In [None]:
# Attach a 3D canvas to the environment before training so that rendering can
# draw into env.fig / env.ax (assumed to be what MagneticEnv.render uses -- confirm).
env.fig = plt.figure()
env.ax = env.fig.add_subplot(projection="3d")  # no position args == single (1, 1, 1) subplot
env.fig.tight_layout()

In [None]:
# --- Training: one fresh REINFORCE agent per seed, one policy update per episode ---
DO_RENDER = True  # live 3D rendering every step; very slow -- set False for fast training
total_num_episodes = int(5e3)  # Total number of episodes per seed
LOG_EVERY = 200  # episodes between average-reward log lines

for seed in [43]:  # single run; add seeds here to average results across runs
    # Seed all RNGs so each run is reproducible.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Reinitialize agent every seed
    agent = Policy(obs_space_dims, action_space_dims, electro_positions, device)
    reward_over_episodes = []

    # A fresh progress bar per seed: a tqdm bar is closed after one full pass.
    progress_bar = trange(total_num_episodes, desc='Training...', leave=False, disable=False)
    for episode in progress_bar:
        # Seed the environment only on the first reset of a run. Passing the same
        # seed on every reset replayed the identical start/target state each episode.
        obs, info = wrapped_env.reset(
            seed=seed if episode == 0 else None,
            options=(spawn_range, desired_range),
        )
        done = False

        while not done:
            action = agent.sample_action(obs)
            obs, reward, terminated, truncated, info = wrapped_env.step(action)
            agent.rewards.append(reward)
            if DO_RENDER:
                wrapped_env.env.render()
            done = terminated or truncated

        reward_val = wrapped_env.return_queue[-1]
        progress_bar.set_postfix({"Episode Reward:": f"Epoch: {episode}, Reward {reward_val}"})
        if episode % LOG_EVERY == 0:
            avg_reward = int(np.mean(wrapped_env.return_queue))
            # tqdm.write prints above the bar instead of corrupting it like print().
            tqdm.write(f"Episode: {episode} Average Reward: {avg_reward}")

        agent.update()
        reward_over_episodes.append(reward_val)
    rewards_over_seeds.append(reward_over_episodes)

In [None]:
# --- Results: per-episode return with a linear trendline ---
assert rewards_over_seeds, "no training runs recorded -- run the training cell first"

# The theme must be set before the figure is created, otherwise it does not
# apply to this plot.
sns.set_theme(style="darkgrid", context="talk", palette="rainbow")

# Rows = training runs (seeds), columns = episodes. reshape also flattens any
# per-episode returns stored as 1-element arrays.
rewards_to_plot = np.asarray(rewards_over_seeds, dtype=float).reshape(len(rewards_over_seeds), -1)

# Long format: "Episodes" is the episode index; with several seeds seaborn
# averages them per episode (with a confidence band) instead of concatenating runs.
df1 = (
    pd.DataFrame(rewards_to_plot)
    .melt()
    .rename(columns={"variable": "Episodes", "value": "Reward"})
)

# Explicit axes so the plot never lands on the environment's 3D render figure.
fig, ax = plt.subplots()
sns.lineplot(x="Episodes", y="Reward", data=df1, ax=ax)
sns.regplot(x="Episodes", y="Reward", data=df1, ax=ax, scatter=False, color='red')  # trendline
ax.set_ylim(-5000, 10000)  # fixed window for this reward scale; adjust if the reward changes
ax.set(title="REINFORCE for Levitation Learn (Attempt#1)")
plt.show()

In [None]:
# Undo the seaborn theme applied for the reward plot so later cells start from
# default styling. The order matters: matplotlib's rc-file defaults first, then
# seaborn's reset to its original settings, then clear the current figure.
matplotlib.rc_file_defaults()
sns.reset_orig()
plt.clf()