In [5]:
import cv2
import zmq
import time
import torch
import pickle
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO

In [6]:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://localhost:5555")

<SocketContext(connect='tcp://localhost:5555')>

In [7]:
def send_action(steer: float, throttle: float):
    action = {
        "action": [steer, throttle]
    }
    socket.send(pickle.dumps(action))  # Serialize & send
    response = socket.recv()           # Receive serialized reply
    data = pickle.loads(response)      # Deserialize

    return data

# ==== Decode image from bytes ====
def decode_image(image_bytes):
    np_arr = np.frombuffer(image_bytes, dtype=np.uint8)
    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
    return img

# # ==== Main loop ====
# for step in range(200):
#     steer = 0.0               # Replace with your RL output
#     throttle = 0.5            # Forward motion
#     obs = send_action(steer, throttle)

#     # Extract observation
#     try:
#         image = decode_image(obs["image"])
#         gnss = obs["gnss"]
#         timestamp = obs["timestamp"]

#         print(f"[Step {step}] GNSS: {gnss} | Time: {timestamp}")

#         # Optional: show image
#         cv2.imshow("Camera", image)
#         cv2.waitKey(1)

#     except Exception as e:
#         print(f"[ERROR] Failed to process observation: {e}")
#         break

#     time.sleep(0.05)  # slow it down for now
# cv2.destroyAllWindows()

# Reset Carla at end (optional)
# reset_command = {"command": "reset"}
# socket.send(pickle.dumps(reset_command))
# _ = socket.recv()

In [15]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io

# Simulated RL output
steer = 0.0
throttle = 0.0

# Simulated environment image response
obs = send_action(steer, throttle)  # Should return a dict with "image" in bytes

# Load checkpoint
checkpoint = torch.load('culane_res34.pth', map_location='cpu')
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

# Sanitize keys
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith("module.model."):
        new_k = k.replace("module.model.", "")
    elif k.startswith("model."):
        new_k = k.replace("model.", "")
    else:
        new_k = k
    new_state_dict[new_k] = v

# Inspect keys to understand the architecture
print("Keys in state dict:")
for k in list(new_state_dict.keys())[:10]:
    print(k)

# WARNING: You CANNOT load these weights into torchvision.models.resnet34 directly
# unless you confirm the model was trained with that structure.

# Image decode and preprocessing
def decode_image(image_bytes):
    return Image.open(io.BytesIO(image_bytes)).convert('RGB')

transform = transforms.Compose([
    transforms.Resize((288, 800)),  # CULane input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.3598, 0.3653, 0.3662], std=[0.2573, 0.2663, 0.2756])  # CULane-specific
])

image = decode_image(obs["image"])
input_tensor = transform(image).unsqueeze(0)

# You need the **EXACT MODEL DEFINITION** for lane detection used during training
# Example (not working yet):
# model = CustomLaneNet()

# model.load_state_dict(new_state_dict)
# model.eval()

# with torch.no_grad():
#     output = model(input_tensor)


Keys in state dict:
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean


In [17]:
import torchvision
model = torchvision.models.resnet34()
model.load_state_dict(new_state_dict, strict=False)


_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=['module.cls.0.weight', 'module.cls.0.bias', 'module.cls.1.weight', 'module.cls.1.bias', 'module.cls.3.weight', 'module.cls.3.bias', 'module.pool.weight', 'module.pool.bias'])

In [18]:
with torch.no_grad():
    output = model(input_tensor)
