In [1]:
import os
from icecream import ic

#! Set the environment variables to override gpu (specically for my device)
os.environ['HSA_OVERRIDE_GFX_VERSION'] = '10.3.0'
os.environ['ROCBLAS_TENSILE_LIBRARY'] = '/home/autrio/.local/lib/python3.10/site-packages/torch/lib/rocblas/library/TensileLibrary_lazy_gfx1030.dat'



In [20]:
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor
import gymnasium as gym
from gymnasium import spaces
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from stable_baselines3 import PPO
from collections import deque
import random

import torch
import torch.nn.functional as F
from torch.func import jacrev


def get_obs(state):
    return np.array([state[0],
                    np.cos(state[1]), np.sin(state[1]),
                    state[2],
                    state[3]
                     ])


class LNNModel(torch.nn.Module):
    def __init__(self, env_name, n, obs_size, action_size, dt, a_zeros):
        super(LNNModel, self).__init__()
        self.env_name = env_name
        self.dt = dt
        self.n = n

        input_size = obs_size - self.n
        out_L = int(self.n*(self.n+1)/2)
        self.fc1_L = torch.nn.Linear(input_size, 64)
        self.fc2_L = torch.nn.Linear(64, 64)
        self.fc3_L = torch.nn.Linear(64, out_L)
        if not self.env_name == "reacher":
            self.fc1_V = torch.nn.Linear(input_size, 64)
            self.fc2_V = torch.nn.Linear(64, 64)
            self.fc3_V = torch.nn.Linear(64, 1)
        print(a_zeros.shape, n, action_size, obs_size)
        self.a_zeros = a_zeros

    def trig_transform_q(self, q):
        if self.env_name == "pendulum":
            return torch.column_stack((torch.cos(q[:, 0]), torch.sin(q[:, 0])))

        elif self.env_name == "reacher" or self.env_name == "acrobot":
            return torch.column_stack((torch.cos(q[:, 0]), torch.sin(q[:, 0]),
                                       torch.cos(q[:, 1]), torch.sin(q[:, 1])))

        elif self.env_name == "cartpole":
            return torch.column_stack((q[:, 0],
                                       torch.cos(q[:, 1]), torch.sin(q[:, 1])))

        elif self.env_name == "cart2pole":
            return torch.column_stack((q[:, 0],
                                       torch.cos(q[:, 1]), torch.sin(q[:, 1]),
                                       torch.cos(q[:, 2]), torch.sin(q[:, 2])))

        elif self.env_name == "cart3pole":
            return torch.column_stack((q[:, 0],
                                       torch.cos(q[:, 1]), torch.sin(q[:, 1]),
                                       torch.cos(q[:, 2]), torch.sin(q[:, 2]),
                                       torch.cos(q[:, 3]), torch.sin(q[:, 3])))

        elif self.env_name == "acro3bot":
            return torch.column_stack((torch.cos(q[:, 0]), torch.sin(q[:, 0]),
                                       torch.cos(q[:, 1]), torch.sin(q[:, 1]),
                                       torch.cos(q[:, 2]), torch.sin(q[:, 2])))

    def inverse_trig_transform_model(self, x):
        if self.env_name == "pendulum":
            return torch.cat((torch.atan2(x[:, 1], x[:, 0]).unsqueeze(1), x[:, 2:]), 1)

        elif self.env_name == "reacher" or self.env_name == "acrobot":
            return torch.cat((torch.atan2(x[:, 1], x[:, 0]).unsqueeze(1), torch.atan2(x[:, 3], x[:, 2]).unsqueeze(1), x[:, 4:]), 1)

        elif self.env_name == "cartpole":
            return torch.cat((x[:, 0].unsqueeze(1), torch.atan2(x[:, 2], x[:, 1]).unsqueeze(1), x[:, 3:]), 1)

        elif self.env_name == "cart2pole":
            return torch.cat((x[:, 0].unsqueeze(1), torch.atan2(x[:, 2], x[:, 1]).unsqueeze(1), torch.atan2(x[:, 4], x[:, 3]).unsqueeze(1), x[:, 5:]), 1)

        elif self.env_name == "cart3pole":
            return torch.cat((x[:, 0].unsqueeze(1), torch.atan2(x[:, 2], x[:, 1]).unsqueeze(1), torch.atan2(x[:, 4], x[:, 3]).unsqueeze(1),
                              torch.atan2(x[:, 6], x[:, 5]).unsqueeze(1), x[:, 7:]), 1)

        elif self.env_name == "acro3bot":
            return torch.cat((torch.atan2(x[:, 1], x[:, 0]).unsqueeze(1), torch.atan2(x[:, 3], x[:, 2]).unsqueeze(1), torch.atan2(x[:, 5], x[:, 4]).unsqueeze(1),
                              x[:, 6:]), 1)

    def compute_L(self, q):
        y1_L = F.softplus(self.fc1_L(q))
        y2_L = F.softplus(self.fc2_L(y1_L))
        y_L = self.fc3_L(y2_L)
        device = y_L.device
        if self.n == 1:
            L = y_L.unsqueeze(1)

        elif self.n == 2:
            L11 = y_L[:, 0].unsqueeze(1)
            L1_zeros = torch.zeros(
                L11.size(0), 1, dtype=torch.float32, device=device)

            L21 = y_L[:, 1].unsqueeze(1)
            L22 = y_L[:, 2].unsqueeze(1)

            L1 = torch.cat((L11, L1_zeros), 1)
            L2 = torch.cat((L21, L22), 1)
            L = torch.cat((L1.unsqueeze(1), L2.unsqueeze(1)), 1)

        elif self.n == 3:
            L11 = y_L[:, 0].unsqueeze(1)
            L1_zeros = torch.zeros(
                L11.size(0), 2, dtype=torch.float32, device=device)

            L21 = y_L[:, 1].unsqueeze(1)
            L22 = y_L[:, 2].unsqueeze(1)
            L2_zeros = torch.zeros(
                L21.size(0), 1, dtype=torch.float32, device=device)

            L31 = y_L[:, 3].unsqueeze(1)
            L32 = y_L[:, 4].unsqueeze(1)
            L33 = y_L[:, 5].unsqueeze(1)

            L1 = torch.cat((L11, L1_zeros), 1)
            L2 = torch.cat((L21, L22, L2_zeros), 1)
            L3 = torch.cat((L31, L32, L33), 1)
            L = torch.cat(
                (L1.unsqueeze(1), L2.unsqueeze(1), L3.unsqueeze(1)), 1)

        elif self.n == 4:
            L11 = y_L[:, 0].unsqueeze(1)
            L1_zeros = torch.zeros(
                L11.size(0), 3, dtype=torch.float32, device=device)

            L21 = y_L[:, 1].unsqueeze(1)
            L22 = y_L[:, 2].unsqueeze(1)
            L2_zeros = torch.zeros(
                L21.size(0), 2, dtype=torch.float32, device=device)

            L31 = y_L[:, 3].unsqueeze(1)
            L32 = y_L[:, 4].unsqueeze(1)
            L33 = y_L[:, 5].unsqueeze(1)
            L3_zeros = torch.zeros(
                L31.size(0), 1, dtype=torch.float32, device=device)

            L41 = y_L[:, 6].unsqueeze(1)
            L42 = y_L[:, 7].unsqueeze(1)
            L43 = y_L[:, 8].unsqueeze(1)
            L44 = y_L[:, 9].unsqueeze(1)

            L1 = torch.cat((L11, L1_zeros), 1)
            L2 = torch.cat((L21, L22, L2_zeros), 1)
            L3 = torch.cat((L31, L32, L33, L3_zeros), 1)
            L4 = torch.cat((L41, L42, L43, L44), 1)
            L = torch.cat((L1.unsqueeze(1), L2.unsqueeze(
                1), L3.unsqueeze(1), L4.unsqueeze(1)), 1)

        return L

    def get_A(self, a):
        if self.env_name == "pendulum" or self.env_name == "reacher":
            A = a

        elif self.env_name == "acrobot":
            A = torch.cat((self.a_zeros, a), 1)

        elif self.env_name == "cartpole" or self.env_name == "cart2pole":
            A = torch.cat((a, self.a_zeros), 1)

        elif self.env_name == "cart3pole" or self.env_name == "acro3bot":
            A = torch.cat((a[:, :1], self.a_zeros, a[:, 1:]), 1)

        return A

    def get_L(self, q):
        trig_q = self.trig_transform_q(q)
        L = self.compute_L(trig_q)
        return L.sum(0), L

    def get_V(self, q):
        trig_q = self.trig_transform_q(q)
        y1_V = F.softplus(self.fc1_V(trig_q))
        y2_V = F.softplus(self.fc2_V(y1_V))
        V = self.fc3_V(y2_V).squeeze()
        return V.sum()

    def get_acc(self, q, qdot, a):
        dL_dq, L = jacrev(self.get_L, has_aux=True)(q)
        term_1 = torch.einsum('blk,bijk->bijl', L, dL_dq.permute(2, 3, 0, 1))
        dM_dq = term_1 + term_1.transpose(2, 3)
        c = torch.einsum('bjik,bk,bj->bi', dM_dq, qdot, qdot) - \
            0.5 * torch.einsum('bikj,bk,bj->bi', dM_dq, qdot, qdot)
        Minv = torch.cholesky_inverse(L)
        dV_dq = 0 if self.env_name == "reacher" else jacrev(self.get_V)(q)
        qddot = torch.matmul(
            Minv, (self.get_A(a)-c-dV_dq).unsqueeze(2)).squeeze(2)
        return qddot

    def derivs(self, s, a):
        q, qdot = s[:, :self.n], s[:, self.n:]
        qddot = self.get_acc(q, qdot, a)
        return torch.cat((qdot, qddot), dim=1)

    def rk2(self, s, a):
        alpha = 2.0/3.0  # Ralston's method
        k1 = self.derivs(s, a)
        k2 = self.derivs(s + alpha * self.dt * k1, a)
        s_1 = s + self.dt * ((1.0 - 1.0/(2.0*alpha))*k1 + (1.0/(2.0*alpha))*k2)
        return s_1

    def forward(self, o, a):
        # ic(o)
        # o = get_obs(o)  # ! edit
        s_1 = self.rk2(self.inverse_trig_transform_model(o), a)
        o_1 = torch.cat((self.trig_transform_q(
            s_1[:, :self.n]), s_1[:, self.n:]), 1)
        return o_1


class RewardNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim=256):
        super(RewardNetwork, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Predicts reward
        )

    def forward(self, state):
        return self.model(state)


class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.tensor(states, dtype=torch.float32),
            torch.tensor(actions, dtype=torch.float32),
            torch.tensor(rewards, dtype=torch.float32),
            torch.tensor(next_states, dtype=torch.float32),
            torch.tensor(dones, dtype=torch.float32),
        )

    def __len__(self):
        return len(self.buffer)


class CustomGymEnvironment(gym.Env):
    """
    Custom Gym environment that uses LNN for state transitions, Reward Network for reward calculations,
    and incorporates a replay buffer.
    """

    def __init__(self, lnn_model, reward_model, state_dim, action_dim, action_space, observation_space, device, replay_buffer):
        super(CustomGymEnvironment, self).__init__()
        self.lnn_model = lnn_model
        self.reward_model = reward_model
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.replay_buffer = replay_buffer

        # Define observation and action space
        self.observation_space = observation_space
        self.action_space = action_space

        # Initialize state
        self.state = np.zeros(state_dim)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # print("<<<", self.state.shape)
        if len(self.replay_buffer) > 0:
            # Sample a random state from the replay buffer
            sample = self.replay_buffer.sample(1)
            # print(sample.shape)
            if sample:
                self.state = sample[0][0]
        else:
            # Otherwise, initialize randomly
            self.state = self.observation_space.sample()
            # print(self.state.shape)
        # print("<<<2", self.state.shape)
        return self.state, {}

    def step(self, action):
        # Convert state and action to tensors
        # print(action.shape)
        state_tensor = torch.tensor(
            self.state, dtype=torch.float32, device=self.device).unsqueeze(0)
        action_tensor = torch.tensor(
            action, dtype=torch.float32, device=self.device).unsqueeze(0)
        # print("Hii",state_tensor.shape, action_tensor.shape)
        # Use LNN to predict the next state
        next_state = self.lnn_model(state_tensor, action_tensor).squeeze(
            0).cpu().detach().numpy()

        # Use Reward Network to calculate reward
        reward = (
            self.reward_model(
                torch.tensor(next_state, dtype=torch.float32,
                             device=self.device).unsqueeze(0)
            )
            .item()
        )

        # Check if the episode is done
        done = np.abs(self.state[0]) <= 0.1 and np.abs(
            self.state[1]) <= 0.01   # Example condition

        # Store the transition in the replay buffer
        self.replay_buffer.add(self.state, action, reward, next_state, done)

        # Update state
        self.state = next_state

        return next_state, reward, done, False, {}


def train_with_lnn_and_ppo(
    model, env, custom_env, device, num_episodes, batch_size
):
    reward_optimizer = optim.Adam(reward_network.parameters(), lr=2e-6)
    lnn_optimizer = optim.Adam(custom_env.lnn_model.parameters(), lr=2e-6)
    reward_loss_fn = nn.MSELoss()
    lnn_loss_fn = nn.MSELoss()

    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False

        while not done:
            # Collect transitions using PPO policy
            action, _ = model.predict(state, deterministic=False)
            state1, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # print(state.shape, action.shape)
            custom_env.replay_buffer.add(state, action, reward, state1, done)

        # Train LNN and Reward Network
        if len(custom_env.replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones = custom_env.replay_buffer.sample(
                batch_size)

            # Train LNN
            lnn_preds = custom_env.lnn_model(
                states.to(device), actions.to(device))
            lnn_loss = lnn_loss_fn(lnn_preds, next_states.to(device))
            lnn_optimizer.zero_grad()
            lnn_loss.backward()
            lnn_optimizer.step()

            # Train Reward Network
            reward_preds = custom_env.reward_model(next_states.to(device))
            reward_loss = reward_loss_fn(
                reward_preds, rewards.to(device).unsqueeze(-1))
            reward_optimizer.zero_grad()
            reward_loss.backward()
            reward_optimizer.step()

            print(
                f"Episode: {episode}, LNN loss: {lnn_loss}, Rewards Loss: {reward_loss}")
        # Train PPO policy
        model.learn(total_timesteps=2048, reset_num_timesteps=False)


tmp_path = "./"
# set up logger
new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"])


# Setup
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
print("Action Dim", action_dim)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# lnn_model = LNNModel(state_dim, action_dim).to(device)
batch_size = 64
n=1
lnn_model = LNNModel(
    env_name="pendulum",
    n=n,
    obs_size=state_dim,
    action_size=action_dim,
    dt=0.02,  # Time step
    a_zeros=torch.zeros(
        batch_size, max(0, n - action_dim), dtype=torch.float32, device=device
    ) if action_dim <= n else None
).to(device)
reward_network = RewardNetwork(state_dim).to(device)
replay_buffer = ReplayBuffer(max_size=10000)

action_space = env.action_space

custom_env = CustomGymEnvironment(
    lnn_model=lnn_model,
    reward_model=reward_network,
    state_dim=state_dim,
    action_dim=action_dim,
    action_space=action_space,
    observation_space=env.observation_space,
    device=device,
    replay_buffer=replay_buffer,
)

model = PPO("MlpPolicy", custom_env, verbose=1)
model.set_logger(new_logger)

# Train
trained_model, trained_lnn, trained_reward = train_with_lnn_and_ppo(
    model,  env, custom_env, device, num_episodes=10, batch_size=4096
)

# Save models
# torch.save(custom_env.lnn_model.state_dict(), "lnn_cartpole.pth")
# torch.save(custom_env.reward_model.state_dict(), "reward_cartpole.pth")
# trained_model.save("ppo_cartpole_with_lnn")
model.save("ppo_cartpole_with_lnn")

Logging to ./


Action Dim 1
torch.Size([64, 0]) 1 1 3
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


  state_tensor = torch.tensor(


-----------------------------
| time/              |      |
|    fps             | 115  |
|    iterations      | 1    |
|    time_elapsed    | 17   |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 120         |
|    iterations           | 1           |
|    time_elapsed         | 16          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.004061522 |
|    clip_fraction        | 0.0366      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.42       |
|    explained_variance   | -0.14       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00409     |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00277    |
|    std                  | 1.01        |
|    value_loss           | 0.416       |
----------------------------------

KeyboardInterrupt: 

In [None]:
model.save("ppo_cartpole_with_lnn")


In [None]:
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("Pendulum-v1", render_mode = 'human')
# 5. Load the trained model (optional)
model = PPO.load("ppo_cartpole_with_lnn", env=env)

# 6. Evaluate the trained policy
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=3)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# 7. Run the trained agent
obs, _ = env.reset()
for i in range(1000):  # Run for a fixed number of timesteps
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    env.render()
    if terminated or truncated:
        obs, _ = env.reset()
        print("Done ", i)
    # if not i%20: print(i)

env.close()

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




Mean reward: -1080.87 +/- 308.68
Done  199
Done  399
Done  599
Done  799
Done  999
