In [2]:
import os
import pickle

import cv2
import gymnasium as gym
import numpy as np
import zmq
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback

from process_frame import process_frame



In [3]:
class Location:
    """A point with ``x``, ``y`` and optional ``z`` coordinates (``z`` defaults to 0.0)."""

    def __init__(self, x, y, z=0.0):
        self.x, self.y, self.z = x, y, z

    def distance(self, other):
        """Return the Euclidean distance between this point and ``other``.

        ``other`` only needs ``x``, ``y`` and ``z`` attributes.
        """
        dx = self.x - other.x
        dy = self.y - other.y
        dz = self.z - other.z
        return np.sqrt(dx**2 + dy**2 + dz**2)

def quaternion_to_yaw(q):
    """Extract the yaw angle (radians) from a quaternion.

    Args:
        q: Mapping with the keys ``"w"``, ``"x"``, ``"y"`` and ``"z"``.

    Returns:
        ``arctan2(2(wz + xy), 1 - 2(y^2 + z^2))``, i.e. a value in [-pi, pi].
    """
    w, x, y, z = (q[key] for key in ("w", "x", "y", "z"))
    sin_term = 2 * (w * z + x * y)
    cos_term = 1 - 2 * (y**2 + z**2)
    return np.arctan2(sin_term, cos_term)

def normalize_angle(angle):
    """Wrap a scalar angle in radians into the range [-pi, pi].

    Values already inside [-pi, pi] are returned unchanged. Anything outside
    is wrapped in closed form, so the call always terminates; repeatedly
    subtracting 2*pi never finishes for +/-inf or for magnitudes so large that
    ``angle - 2*pi == angle`` in floating point.

    Args:
        angle: Scalar angle in radians.

    Returns:
        The equivalent angle in [-pi, pi]. Non-finite input (NaN, +/-inf) has
        no meaningful wrapped value and is returned as-is.
    """
    if not np.isfinite(angle):
        return angle
    if -np.pi <= angle <= np.pi:
        return angle
    # Shift so the target interval starts at 0, wrap with modulo, shift back.
    return (angle + np.pi) % (2 * np.pi) - np.pi

def decode_image(image_bytes):
    """Decode an encoded image held in a bytes-like buffer.

    The buffer is viewed as ``uint8`` and handed to ``cv2.imdecode`` with
    ``cv2.IMREAD_COLOR``; whatever ``imdecode`` returns is passed through.
    """
    return cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)

In [6]:
class CarlaEnv(gym.Env):
    """Gymnasium environment that drives a simulated vehicle over a ZeroMQ REQ socket.

    Every ``reset``/``step`` sends one pickled request to
    ``tcp://localhost:<port>`` and blocks until the pickled reply arrives.

    Action space: ``Box(-1, 1, (2,))``; element 0 is treated as the steering
    command by the reward function.

    Observation: ``[dist, angle, speed, lane_dev, obs_prox]`` as float32, where
    ``dist`` and ``speed`` are divided by ``DISTANCE_SCALE`` / ``SPEED_SCALE``.

    An episode terminates when the route is completed or a collision is
    reported, and is truncated after ``max_steps`` steps. The caller (e.g. the
    training loop) is responsible for calling ``reset()`` afterwards.
    """

    DISTANCE_SCALE = 100.0   # divisor applied to the waypoint distance in the state
    SPEED_SCALE = 20.0       # divisor applied to the estimated speed in the state
    WAYPOINT_RADIUS = 2.0    # distance at which an intermediate waypoint counts as reached
    GOAL_RADIUS = 5.0        # distance at which the final waypoint counts as reached

    def __init__(self, port=6501):
        """Connect to the simulator bridge on ``port`` and declare the spaces."""
        super().__init__()
        self.port = port
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(f"tcp://localhost:{self.port}")

        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        # [dist, angle, speed, lane_dev, obs_prox]
        # Bounds are built as float32 so Box does not have to down-cast them.
        self.observation_space = spaces.Box(
            low=np.array([-np.inf, -np.pi, 0, -1, 0], dtype=np.float32),
            high=np.array([np.inf, np.pi, np.inf, 1, 1], dtype=np.float32),
            dtype=np.float32,
        )

        self.route = None
        self.wp_index = 0

        self.previous_location = None
        self.previous_timestamp = None
        self.previous_distance = None
        self.previous_steer = 0

        # Most recent processed camera frame, kept for optional visualisation.
        self.last_frame = None

        self.max_steps = 2000
        self.step_count = 0

    def _request(self, message):
        """Send one pickled request and return the unpickled reply."""
        self.socket.send(pickle.dumps(message))
        # SECURITY: pickle.loads can execute arbitrary code from the peer.
        # Only connect this environment to a trusted simulator bridge.
        return pickle.loads(self.socket.recv())

    def _remember(self, obs):
        """Store the position, timestamp and waypoint distance of ``obs`` for the next step."""
        self.previous_location = self.get_current_location(obs)
        self.previous_timestamp = obs["timestamp"]
        self.previous_distance = self.get_distance_to_waypoint(obs)

    def _speed(self, obs):
        """Estimate speed as distance travelled since the previous observation over elapsed time.

        Returns 0.0 when there is no previous observation or no time has elapsed.
        """
        if self.previous_location is None or self.previous_timestamp is None:
            return 0.0
        time_elapsed = obs["timestamp"] - self.previous_timestamp
        if time_elapsed <= 0:
            return 0.0
        return self.get_current_location(obs).distance(self.previous_location) / time_elapsed

    def reset(self, seed=None, options=None):
        """Ask the server for a fresh episode and return ``(state, info)``.

        Raises:
            RuntimeError: if the server does not answer with status ``"reset done"``.
        """
        super().reset(seed=seed)

        response = self._request({"command": "reset"})
        if response["status"] != "reset done":
            raise RuntimeError(f"Server reset failed check port cur port {self.port}")

        obs = response["observation"]
        self.route = [Location(x, y) for [x, y] in response["route"]]
        self.wp_index = 0
        self.step_count = 0
        self.previous_steer = 0

        # Clear the per-step history BEFORE building the state; otherwise the
        # first speed reading of the new episode would be measured against the
        # last position of the previous episode.
        self.previous_location = None
        self.previous_timestamp = None
        self.previous_distance = None

        image = decode_image(obs["image"])
        self.last_frame, lane_deviation, obstacle_proximity = process_frame(image)
        state = self.get_state(obs, lane_deviation, obstacle_proximity)

        self._remember(obs)
        return state, {}

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, terminated, truncated, info)``.

        A reported collision ends the episode (``terminated=True``) instead of
        resetting from inside ``step``; the collision penalty is applied once,
        in ``compute_reward``.
        """
        action = np.asarray(action)
        steer = action[0]
        obs = self._request({"action": action.tolist()})

        image = decode_image(obs["image"])
        self.last_frame, lane_deviation, obstacle_proximity = process_frame(image)
        state = self.get_state(obs, lane_deviation, obstacle_proximity)
        reward = self.compute_reward(obs, lane_deviation, obstacle_proximity, steer)

        # Count the step before checking truncation so an episode lasts exactly max_steps.
        self.step_count += 1

        collided = obs["collision"] is not None
        if collided:
            print("Collision detected, ending episode")

        # is_terminated may advance wp_index, so it runs before _remember(),
        # which measures the distance to the (possibly new) target waypoint.
        route_done = self.is_terminated(obs)
        terminated = collided or route_done
        truncated = self.is_truncated()

        self._remember(obs)
        self.previous_steer = steer

        return state, reward, terminated, truncated, {}

    def get_current_location(self, obs):
        """Convert the GNSS latitude/longitude in ``obs`` to a planar ``Location``.

        The mapping is a fixed affine transform.
        NOTE(review): the coefficients look fitted to one specific map - confirm
        before using this environment with another map.
        """
        lat, lon = obs["gnss"]["latitude"], obs["gnss"]["longitude"]
        x = -14418.6285 * lat + 111279.5690 * lon - 3.19252014
        y = -109660.6210 * lat + 4.33686914 * lon + 0.367254638
        return Location(x, y)

    def get_state(self, obs, lane_deviation, obstacle_proximity):
        """Build the float32 state ``[dist, angle, speed, lane_dev, obs_prox]``.

        ``dist`` and ``angle`` refer to the current target waypoint and are 0
        once the route is exhausted.
        """
        current_loc = self.get_current_location(obs)

        if self.wp_index < len(self.route):
            next_wp_loc = self.route[self.wp_index]
            distance = current_loc.distance(next_wp_loc) / self.DISTANCE_SCALE
            vehicle_yaw = quaternion_to_yaw(obs["imu"]["orientation"])
            desired_yaw = np.arctan2(next_wp_loc.y - current_loc.y,
                                     next_wp_loc.x - current_loc.x)
            angle_to_way_point = normalize_angle(desired_yaw - vehicle_yaw)
        else:
            distance = 0
            angle_to_way_point = 0

        speed = self._speed(obs) / self.SPEED_SCALE

        return np.array(
            [distance, angle_to_way_point, speed, lane_deviation, obstacle_proximity],
            dtype=np.float32,
        )

    def get_distance_to_waypoint(self, obs):
        """Return the unscaled distance to the current target waypoint (0 if none is left)."""
        if self.wp_index < len(self.route):
            return self.get_current_location(obs).distance(self.route[self.wp_index])
        return 0

    def compute_reward(self, obs, lane_deviation, obstacle_proximity, steer):
        """Compute the reward for the transition that produced ``obs``.

        Terms: progress towards the waypoint, lane-deviation penalty, obstacle
        proximity penalty above 0.5, collision (-100), lane invasion (-10),
        a +/-0.1 speed-band term (5..10) and a penalty on steering changes
        larger than 0.2.
        """
        current_distance = self.get_distance_to_waypoint(obs)
        reward = self.previous_distance - current_distance

        reward -= 0.1 * abs(lane_deviation)

        if obstacle_proximity > 0.5:
            reward -= (obstacle_proximity - 0.5) * 20

        if obs["collision"] is not None:
            reward -= 100

        if obs["lane_invaded"]["violated"]:
            reward -= 10

        speed = self._speed(obs)
        if speed < 5 or speed > 10:
            reward -= 0.1
        else:
            reward += 0.1

        steer_change = abs(steer - self.previous_steer)
        if steer_change > 0.2:
            reward -= steer_change * 5

        return reward

    def is_terminated(self, obs):
        """Return True when the route is finished; advance ``wp_index`` when a waypoint is reached."""
        if self.wp_index >= len(self.route):
            return True

        current_distance = self.get_distance_to_waypoint(obs)
        if self.wp_index == len(self.route) - 1:
            if current_distance < self.GOAL_RADIUS:
                print("Near final waypoint, success")
                return True
        elif current_distance < self.WAYPOINT_RADIUS:
            # Not the final waypoint, so the index stays within the route.
            self.wp_index += 1

        return False

    def is_truncated(self):
        """Return True once ``max_steps`` steps have been taken in this episode."""
        return self.step_count >= self.max_steps

    def close(self):
        """Release the ZeroMQ socket and context; safe to call more than once."""
        if self.socket is not None:
            # linger=0 drops any unsent request so term() below cannot block.
            self.socket.close(linger=0)
            self.socket = None
        if self.context is not None:
            self.context.term()
            self.context = None
        super().close()

In [None]:
if __name__ == "__main__":
    env = CarlaEnv(port=6501)
    checkpoint_path = 'checkpoints/carla_rl_model_50000_steps.zip'

    # Resume from the checkpoint when it exists, otherwise start from scratch.
    resuming = os.path.exists(checkpoint_path)
    if resuming:
        model = PPO.load(checkpoint_path, env=env)
        print(f"Resuming training from checkpoint: {checkpoint_path}")
    else:
        model = PPO(
            "MlpPolicy",
            env,
            verbose=1,
            n_steps=2048,
            batch_size=64,
            n_epochs=10,
            gamma=0.99,
            learning_rate=3e-4,
            gae_lambda=0.95,
            ent_coef=0.01
        )
        print("Starting new training.")

    check_point_callback = CheckpointCallback(save_freq=50000, save_path='./checkpoints/', name_prefix='carla_rl_model')

    total_timesteps = 5000000

    # When resuming, keep the loaded step counter. With the default
    # reset_num_timesteps=True the counter restarts at 0 and the checkpoint
    # callback would overwrite carla_rl_model_50000_steps.zip, the very file
    # training was resumed from.
    model.learn(
        total_timesteps=total_timesteps,
        callback=check_point_callback,
        progress_bar=False,
        reset_num_timesteps=not resuming,
    )

    model.save("carla_rl_agent")

# import ray
# from ray.tune.registry import register_env
# from ray.rllib.algorithms.ppo import PPO as RayPPO
# import gymnasium as gym
# import torch
# import os


# def env_creator(env_config):
#     return CarlaEnv(port=env_config["port"])

# if __name__ == "__main__":
#     if torch.cuda.is_available():
#         print(f"Using GPU: {torch.cuda.get_device_name(0)}")
#     else:
#         raise RuntimeError("CUDA not available, GPU required for multi-PC training")

#     ray.init(address="auto" if "RAY_HEAD_IP" not in os.environ else f"{os.environ['RAY_HEAD_IP']}:6379")

#     register_env("carla_env", env_creator)

#     config = {
#         "env": "carla_env",
#         "num_workers": 8,  # 4 per PC
#         "num_gpus": 2,     # 1 per PC
#         "num_cpus": 32,    # 16 per PC
#         "framework": "torch",
#         "train_batch_size": 2048 * 8,  # Scaled for 8 envs
#         "sgd_minibatch_size": 64,
#         "num_sgd_iter": 10,
#         "lr": 3e-4,
#         "gamma": 0.99,
#         "lambda": 0.95,
#         "entropy_coeff": 0.01,
#         "env_config": {"port": ray.tune.grid_search([6501, 6502, 6503, 6504, 6505, 6506, 6507, 6508])},
#         "model": {
#             "custom_model": None,  # Use default MLP
#         },
#     }

#     algo = RayPPO(config=config)

#     total_timesteps = 5000000
#     steps_trained = 0
#     while steps_trained < total_timesteps:
#         result = algo.train()
#         steps_trained = result["timesteps_total"]
#         print(f"Steps: {steps_trained}, Reward: {result['episode_reward_mean']}")
#         if steps_trained % 50000 == 0:
#             checkpoint = algo.save("./ray_checkpoints")
#             print(f"Checkpoint saved: {checkpoint}")

#     algo.save("./ray_final_model")
#     ray.shutdown()

In [None]:
# Roll out the trained policy for up to 1000 steps, logging every transition.
#
# env.step() returns the state vector, not the server's observation dict, so
# obs["image"] cannot be decoded here. A frame is displayed only when the
# environment object exposes one through a `last_frame` attribute; otherwise
# the loop just prints.
obs, _ = env.reset()
try:
    for _step in range(1000):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        print(f"Action: {action}, Reward: {reward}, Terminated: {terminated}, Truncated: {truncated}")

        frame = getattr(env, "last_frame", None)
        if frame is not None:
            cv2.imshow("Processed Image", frame)
            # Press 'q' in the image window to stop early.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        if terminated or truncated:
            obs, _ = env.reset()
finally:
    # Close the window even if a step raises.
    cv2.destroyAllWindows()

In [1]:
import torch

# Report the installed PyTorch build and whether a CUDA device is usable.
cuda_ok = torch.cuda.is_available()

print("PyTorch version:", torch.__version__)
print("CUDA available:", cuda_ok)

gpu_report = ("GPU:", torch.cuda.get_device_name(0)) if cuda_ok else ("CUDA not available",)
print(*gpu_report)


PyTorch version: 2.6.0+cu124
CUDA available: True
GPU: NVIDIA RTX 5000 Ada Generation


In [None]:
# Stand-alone REQ socket for talking to the bridge on port 6501 by hand.
# `socket` is a module-level name used by send_action() and the reset cell below.
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://localhost:6501")

def send_action(steer: float, throttle: float):
    """Send one ``[steer, throttle]`` action over the module-level ``socket``.

    The action is pickled as ``{"action": [steer, throttle]}``; the call blocks
    until the reply arrives and returns the unpickled reply.
    """
    socket.send(pickle.dumps({"action": [steer, throttle]}))
    return pickle.loads(socket.recv())



def decode_image(image_bytes):
    """Decode an encoded image buffer with ``cv2.imdecode`` (``cv2.IMREAD_COLOR``).

    NOTE(review): this repeats the ``decode_image`` defined in the helper cell
    near the top of the notebook; the later definition shadows the earlier one.
    """
    buffer = np.frombuffer(image_bytes, dtype=np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)

# Send a constant action for up to 200 steps and show the processed camera frame.
steer = 0
throttle = -1

try:
    for step in range(200):
        obs = send_action(steer, throttle)

        # Unpack the observation; these names stay available for inspection afterwards.
        image = decode_image(obs["image"])
        gnss = obs["gnss"]
        collision = obs["collision"]
        imu = obs["imu"]
        timestamp = obs["timestamp"]
        lane_invaded = obs["lane_invaded"]

        # process_frame's result is unpacked as three values everywhere else in
        # this notebook; only the first one (the image) can be passed to imshow.
        processed_image, _, _ = process_frame(image)

        # Debug aid:
        # print(f"[Step {step}] GNSS: {gnss} | Time: {timestamp} | Collision: {collision} | imu: {imu} | lane: {lane_invaded}")

        cv2.imshow("Camera", processed_image)
        # A single waitKey per frame both refreshes the window and polls for 'q'.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Close the window even if a step raises.
    cv2.destroyAllWindows()



In [12]:
# Ask the bridge to reset. The reply is read to complete the request/reply
# round trip on the REQ socket and then discarded.
reset_command = dict(command="reset")
socket.send(pickle.dumps(reset_command))
_ = socket.recv()