In [31]:
import torch
import torch.nn.functional as F
from torch.func import jacrev

class lnn(torch.nn.Module):
    def __init__(self, env_name, n, obs_size, action_size, dt, a_zeros):
        super(lnn, self).__init__()
        self.env_name = env_name
        self.dt = dt
        self.n = n

        input_size = obs_size - self.n
        out_L = int(self.n*(self.n+1)/2)
        self.fc1_L = torch.nn.Linear(input_size, 64)
        self.fc2_L = torch.nn.Linear(64, 64)
        self.fc3_L = torch.nn.Linear(64, out_L)
        if not self.env_name == "reacher":
            self.fc1_V = torch.nn.Linear(input_size, 64)
            self.fc2_V = torch.nn.Linear(64, 64)
            self.fc3_V = torch.nn.Linear(64, 1)
        print(a_zeros.shape, n, action_size, obs_size)
        self.a_zeros = a_zeros

    def trig_transform_q(self, q):
        if self.env_name == "pendulum":
            return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0])))
        
    
        elif self.env_name == "reacher" or self.env_name == "acrobot":
            return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]),\
                                       torch.cos(q[:,1]),torch.sin(q[:,1])))
        
        elif self.env_name == "cartpole":
            return torch.column_stack((q[:,0],\
                                       torch.cos(q[:,1]),torch.sin(q[:,1])))
        
        elif self.env_name == "cart2pole":
            return torch.column_stack((q[:,0],\
                                       torch.cos(q[:,1]),torch.sin(q[:,1]),\
                                       torch.cos(q[:,2]),torch.sin(q[:,2])))

        elif self.env_name == "cart3pole":
            return torch.column_stack((q[:,0],\
                                       torch.cos(q[:,1]),torch.sin(q[:,1]),\
                                       torch.cos(q[:,2]),torch.sin(q[:,2]),\
                                       torch.cos(q[:,3]),torch.sin(q[:,3])))
        
        elif self.env_name == "acro3bot":
            return torch.column_stack((torch.cos(q[:,0]),torch.sin(q[:,0]),\
                                       torch.cos(q[:,1]),torch.sin(q[:,1]),\
                                       torch.cos(q[:,2]),torch.sin(q[:,2])))

    def inverse_trig_transform_model(self, x):
        if self.env_name == "pendulum":
            return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),x[:,2:]),1)
        
        elif self.env_name == "reacher" or self.env_name == "acrobot":
            return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),torch.atan2(x[:,3],x[:,2]).unsqueeze(1),x[:,4:]),1)
        
        elif self.env_name == "cartpole":
            return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),x[:,3:]),1)
        
        elif self.env_name == "cart2pole":
            return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),torch.atan2(x[:,4],x[:,3]).unsqueeze(1),x[:,5:]),1)

        elif self.env_name == "cart3pole":
            return torch.cat((x[:,0].unsqueeze(1),torch.atan2(x[:,2],x[:,1]).unsqueeze(1),torch.atan2(x[:,4],x[:,3]).unsqueeze(1),
                              torch.atan2(x[:,6],x[:,5]).unsqueeze(1),x[:,7:]),1)

        elif self.env_name == "acro3bot":
            return torch.cat((torch.atan2(x[:,1],x[:,0]).unsqueeze(1),torch.atan2(x[:,3],x[:,2]).unsqueeze(1),torch.atan2(x[:,5],x[:,4]).unsqueeze(1),
                              x[:,6:]),1)

    def compute_L(self, q):
        y1_L = F.softplus(self.fc1_L(q))
        y2_L = F.softplus(self.fc2_L(y1_L))
        y_L = self.fc3_L(y2_L)
        device = y_L.device
        if self.n == 1:
            L = y_L.unsqueeze(1)

        elif self.n == 2:
            L11 = y_L[:,0].unsqueeze(1)
            L1_zeros = torch.zeros(L11.size(0),1, dtype=torch.float32, device=device)

            L21 = y_L[:,1].unsqueeze(1)
            L22 = y_L[:,2].unsqueeze(1)

            L1 = torch.cat((L11,L1_zeros),1) 
            L2 = torch.cat((L21,L22),1)
            L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1)),1)
        
        elif self.n == 3:
            L11 = y_L[:,0].unsqueeze(1)
            L1_zeros = torch.zeros(L11.size(0),2, dtype=torch.float32, device=device)

            L21 = y_L[:,1].unsqueeze(1)
            L22 = y_L[:,2].unsqueeze(1)
            L2_zeros = torch.zeros(L21.size(0),1, dtype=torch.float32, device=device)

            L31 = y_L[:,3].unsqueeze(1)
            L32 = y_L[:,4].unsqueeze(1)
            L33 = y_L[:,5].unsqueeze(1)

            L1 = torch.cat((L11,L1_zeros),1) 
            L2 = torch.cat((L21,L22,L2_zeros),1)
            L3 = torch.cat((L31,L32,L33),1)
            L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1),L3.unsqueeze(1)),1)
        
        elif self.n == 4:
            L11 = y_L[:,0].unsqueeze(1)
            L1_zeros = torch.zeros(L11.size(0),3, dtype=torch.float32, device=device)

            L21 = y_L[:,1].unsqueeze(1)
            L22 = y_L[:,2].unsqueeze(1)
            L2_zeros = torch.zeros(L21.size(0),2, dtype=torch.float32, device=device)

            L31 = y_L[:,3].unsqueeze(1)
            L32 = y_L[:,4].unsqueeze(1)
            L33 = y_L[:,5].unsqueeze(1)
            L3_zeros = torch.zeros(L31.size(0),1, dtype=torch.float32, device=device)

            L41 = y_L[:,6].unsqueeze(1)
            L42 = y_L[:,7].unsqueeze(1)
            L43 = y_L[:,8].unsqueeze(1)
            L44 = y_L[:,9].unsqueeze(1)

            L1 = torch.cat((L11,L1_zeros),1) 
            L2 = torch.cat((L21,L22,L2_zeros),1)
            L3 = torch.cat((L31,L32,L33,L3_zeros),1)
            L4 = torch.cat((L41,L42,L43,L44),1)
            L = torch.cat((L1.unsqueeze(1),L2.unsqueeze(1),L3.unsqueeze(1),L4.unsqueeze(1)),1)

        return L

    def get_A(self, a):
        if self.env_name == "pendulum" or self.env_name == "reacher":
            A = a
        
        elif self.env_name == "acrobot":
            A = torch.cat((self.a_zeros,a),1)
        
        elif self.env_name == "cartpole" or self.env_name == "cart2pole":
            A = torch.cat((a,self.a_zeros),1)
        
        elif self.env_name == "cart3pole" or self.env_name == "acro3bot":
            A = torch.cat((a[:,:1],self.a_zeros,a[:,1:]),1)

        return A

    def get_L(self, q):
        trig_q = self.trig_transform_q(q)
        L = self.compute_L(trig_q)         
        return L.sum(0), L

    def get_V(self, q):
        trig_q = self.trig_transform_q(q)
        y1_V = F.softplus(self.fc1_V(trig_q))
        y2_V = F.softplus(self.fc2_V(y1_V))
        V = self.fc3_V(y2_V).squeeze()
        return V.sum()

    def get_acc(self, q, qdot, a):
        dL_dq, L = jacrev(self.get_L, has_aux=True)(q)
        term_1 = torch.einsum('blk,bijk->bijl', L, dL_dq.permute(2,3,0,1))
        dM_dq = term_1 + term_1.transpose(2,3)
        c = torch.einsum('bjik,bk,bj->bi', dM_dq, qdot, qdot) - 0.5 * torch.einsum('bikj,bk,bj->bi', dM_dq, qdot, qdot)        
        Minv = torch.cholesky_inverse(L)
        dV_dq = 0 if self.env_name == "reacher" else jacrev(self.get_V)(q)
        qddot = torch.matmul(Minv,(self.get_A(a)-c-dV_dq).unsqueeze(2)).squeeze(2)
        return qddot                                                                                                                                                                                                                                                                                                                 
                                                                                                                                                                                                           
    def derivs(self, s, a):
        q, qdot = s[:,:self.n], s[:,self.n:]
        qddot = self.get_acc(q, qdot, a)
        return torch.cat((qdot,qddot),dim=1)                                                                                                                                                               

    def rk2(self, s, a):                                                                                                                                                                                   
        alpha = 2.0/3.0 # Ralston's method                                                                                                                                                                 
        k1 = self.derivs(s, a)                                                                                                                                                                             
        k2 = self.derivs(s + alpha * self.dt * k1, a)                                                                                                                                                      
        s_1 = s + self.dt * ((1.0 - 1.0/(2.0*alpha))*k1 + (1.0/(2.0*alpha))*k2)                                                                                                                            
        return s_1

    def forward(self, o, a):
        s_1 = self.rk2(self.inverse_trig_transform_model(o), a)
        o_1 = torch.cat((self.trig_transform_q(s_1[:,:self.n]),s_1[:,self.n:]),1)
        return o_1


In [None]:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import torch
import numpy as np
# from models.mbrl import lnn  # Import the LNN model from the repository

from collections import deque
import random


# class LNNDynamicsWrapper:
#     def __init__(self, env, lnn_model, device, scale_factor=1.0):
#         self.env = env
#         self.lnn_model = lnn_model
#         self.device = device
#         self.scale_factor = scale_factor

#         # Handle the a_zeros parameter
#         if env.action_space.shape[0] < env.observation_space.shape[0]:
#             self.a_zeros = torch.zeros(
#                 64,
#                 env.observation_space.shape[0] - env.action_space.shape[0],
#                 dtype=torch.float64,
#                 device=device,
#             )
#         else:
#             self.a_zeros = None

#     def predict(self, state, action):
#         """
#         Predict the next state and reward using the LNN model.
#         """
#         state_tensor = torch.tensor(state, dtype=torch.float64, device=self.device).unsqueeze(0)
#         action_tensor = torch.tensor(action, dtype=torch.float64, device=self.device).unsqueeze(0) * self.scale_factor

#         # Handle a_zeros padding for LNN
#         if self.a_zeros is not None:
#             action_tensor = torch.cat([action_tensor, self.a_zeros], dim=1)

#         # Predict next state using LNN
#         with torch.no_grad():
#             next_state = self.lnn_model(state_tensor, action_tensor)
        
#         next_state = next_state.squeeze(0).cpu().numpy()
#         reward = self.env.get_reward(next_state)  # Custom reward function based on your environment
#         return next_state, reward

def get_obs(state):
    return np.array([state[0],
                    np.cos(state[1]),np.sin(state[1]),
                    state[2],
                    state[3]
                    ])

class ReplayBuffer:
    def __init__(self, max_size):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state):
        self.buffer.append((state, action, reward, next_state))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(next_states),
        )

    def size(self):
        return len(self.buffer)

def train_lnn(lnn_model, replay_buffer, optimizer, device, batch_size=64):
    if replay_buffer.size() < batch_size:
        return  # Skip training if not enough data

    # Sample a batch of transitions
    states, actions, _, next_states = replay_buffer.sample(batch_size)

    # Convert to tensors
    states = torch.tensor(states, dtype=torch.float32, device=device)
    actions = torch.tensor(actions, dtype=torch.float32, device=device).reshape((-1,1))
    next_states = torch.tensor(next_states, dtype=torch.float32, device=device)

    # Forward pass through the LNN
    print(states.shape, actions.shape)
    predicted_next_states = lnn_model(states, actions)

    # Compute loss
    loss = torch.nn.functional.mse_loss(predicted_next_states, next_states)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lnn_model.parameters(), max_norm=1.0)  # Gradient clipping
    optimizer.step()

    return loss.item()


# 4. Custom Training Loop
def train_with_lnn_and_ppo(model, lnn_model, replay_buffer, env, device, num_episodes=1000, batch_size=64):
    lnn_optimizer = torch.optim.Adam(lnn_model.parameters(), lr=0.001)

    for episode in range(num_episodes):
        # Collect data from the environment
        obs, _ = env.reset()
        print(obs.shape)
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=False)
            next_obs, reward, terminated, truncated, _ = env.step(action)

            # Store in replay buffer
            replay_buffer.add(
                
                get_obs(obs), action, reward, get_obs(next_obs))
            obs = next_obs
            done = terminated or truncated

        # Train the LNN with data from the replay buffer
        if replay_buffer.size() >= batch_size:
            lnn_loss = train_lnn(lnn_model, replay_buffer, lnn_optimizer, device, batch_size)
            print(f"Episode {episode + 1}, LNN Loss: {lnn_loss:.4f}")

        # Train the PPO agent
        model.learn(total_timesteps=env.spec.max_episode_steps)

        print(f"Episode {episode + 1} completed.")




In [None]:
import gymnasium as gym
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy


# 1. Initialize the environment
env = gym.make("CartPole-v1")
obs_size = 5
action_size = 1  # Continuous action for CartPole (custom implementation)
n =2
batch_size= 64
# 2. Initialize the LNN model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lnn_model = lnn(
    env_name="cartpole",
    n=n,  # Number of generalized coordinates (cart position, velocity, angle, angular velocity)
    obs_size=obs_size,
    action_size=action_size,
    dt=0.02,  # Time step
    a_zeros = torch.zeros(
        batch_size, max(0, n - action_size), dtype=torch.float32, device=device
    ) if action_size < n else None
).to(device)

# 3. Initialize PPO agent
model = PPO(
    policy="MlpPolicy",
    env=env,
    learning_rate=0.0003,
    n_steps=2048,
    batch_size=batch_size,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.0,
    verbose=1,
)

# 4. Initialize replay buffer for LNN training
replay_buffer = ReplayBuffer(max_size=10000)

# 5. Train PPO and LNN
print("Starting training...")
train_with_lnn_and_ppo(
    model=model,
    lnn_model=lnn_model,
    replay_buffer=replay_buffer,
    env=env,
    device=device,
    num_episodes=1000,
    batch_size=64,
)

# 6. Evaluate the trained PPO policy
print("Evaluating the trained PPO policy...")
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# 7. Save the trained PPO model and LNN
print("Saving the models...")
model.save("ppo_cartpole_with_lnn")
torch.save(lnn_model.state_dict(), "lnn_cartpole.pth")


torch.Size([64, 1]) 2 1 5
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Starting training...
(4,)




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.5     |
|    ep_rew_mean     | 22.5     |
| time/              |          |
|    fps             | 412      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
Episode 1 completed.
(4,)
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 27.5     |
|    ep_rew_mean     | 27.5     |
| time/              |          |
|    fps             | 432      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
Episode 2 completed.
(4,)
torch.Size([64, 5]) torch.Size([64, 1])
Episode 3, LNN Loss: 0.4577
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 64.9     |
|    ep_rew_mean     | 64.9     |
| time/              |          |
|    fps             | 423    

KeyboardInterrupt: 

In [34]:

env = gym.make("CartPole-v1", render_mode = 'human')
# 5. Load the trained model (optional)
model = PPO.load("ppo_cartpole_with_lnn", env=env)

# 8. Visualize the PPO agent
print("Running the trained PPO policy...")
obs, _ = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, _, terminated, truncated, _ = env.step(action)
    env.render()
    if terminated or truncated:
        obs, _ = env.reset()

env.close()



FileNotFoundError: [Errno 2] No such file or directory: 'ppo_cartpole_with_lnn.zip'