In [1]:
import time
import numpy as np
import pickle
import sys
import os
import pybullet as p
from init import *
from utils.tools import *
from modules.RRTstar import *
import pybullet_data
from modules.simple_control import *
from modules.dwa import *
import copy
from modules.roboEnv import *
import os

In [2]:
# connect pybullet and initialize
physicsClient = p.connect(p.GUI)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
root_dir = os.path.join(os.path.dirname(os.path.realpath('file')),"../")
obstacle_ids = init_scene(p)
robotId = load_robot(p)
dy_obs = p.loadURDF("cube.urdf", [-1, 0.5, 0], globalScaling = 0.2)
p.setTimeStep(1.0 / 240.0)
# p.setRealTimeSimulation(1)

In [3]:
right_front_wheel_joint = 2
right_back_wheel_joint = 3
left_front_wheel_joint = 6
left_back_wheel_joint = 7
base_wheels = [2,3,6,7]
left_wheels = [left_back_wheel_joint, left_front_wheel_joint]
right_wheels = [right_front_wheel_joint, right_back_wheel_joint]
xmin, xmax = -1.5, 3.5  # x轴范围
ymin, ymax = -4.5, 0.5  # y轴范围
obstacle_aabbs = [p.getAABB(obstacle_id) for obstacle_id in obstacle_ids]
obstacle_aabbs.append(((0.2, -3.8, 0.0),(1.1, -2.7,1.)))
target_body_id = None
path_node_id = []
dy_obstacle_ids = copy.deepcopy(obstacle_ids)
# dy_obstacle_ids.append(dy_obs)
dy_obs_path = [(-1., 0.5), (-1., -1.5), (0.5, -1.5), (0.5, -2.5), (-1.5, -2.5), (-0.5, -4.),(0., -4.5), (-0.5, -4.),(-1.5, -2.5),(0.5, -2.5),(0.5, -1.5), (-1., -1.5)]

In [4]:
available_position = []
for x in range(-8, 17, 1):
    for y in range(-24, 5, 1):
        target_point = [x*0.25, y*0.25]        
        is_valid = True
        for aabb_min, aabb_max in obstacle_aabbs:
            if is_point_in_aabb(target_point, aabb_min, aabb_max, clearance=0.2):
                is_valid = False
                break
        if is_valid:
            available_position.append((x,y))

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )
    def forward(self, state):
        return self.fc(state)

class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
    
    def forward(self, state):
        return self.fc(state)


In [6]:
def test_model(env, policy_net):
    states, actions, rewards, log_probs, dones = [], [], [], [], []
    state = env.reset(reset_target=False)
    done = False
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32)
        action_probs = policy_net(state_tensor)
        action_dist = torch.distributions.Categorical(action_probs)
        action = action_dist.sample()
        
        next_state, reward, done = env.step(action.item())
        
        states.append(state)
        actions.append(action.item())
        rewards.append(reward)
        log_probs.append(action_dist.log_prob(action).item())
        dones.append(done)
        
        state = next_state
        if done:
            break

In [7]:
# init env and network
env = RobotEnv(robotId, obstacle_ids, dy_obs, target_position=(0.5, -2, 0), available_position=available_position)
state_dim = 8
action_dim = 8
policy_net = PolicyNetwork(state_dim, action_dim)
value_net = ValueNetwork(state_dim)
policy_net.load_state_dict(torch.load(os.path.join("models", "policy_net_600.pth")))
value_net.load_state_dict(torch.load(os.path.join("models", "value_net_600.pth")))

  policy_net.load_state_dict(torch.load(os.path.join("models", "policy_net_600.pth")))
  value_net.load_state_dict(torch.load(os.path.join("models", "value_net_600.pth")))


<All keys matched successfully>

In [8]:
# test
env_test = RobotEnvTest(robotId, obstacle_ids, dy_obs, target_position=(0.5, -2, 0), available_position=available_position)
test_model(env_test, policy_net)