In [29]:
%pip install torch highway-env


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [30]:
import torch
import gymnasium as gym
import highway_env
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

In [31]:
# Highway driving environment: discrete meta-actions, lidar observations.
env = gym.make(
    "highway-fast-v0",
    render_mode="rgb_array",
    config={
        "action": {
            "type": "DiscreteMetaAction",
        },
        "observation": {
            "type": "LidarObservation",
            "cells": 128,
        },
        "vehicles_count": 20,
    },
)

# Define Variables
epochs = 100  # outer passes over the episode range in the training cell
episodes = 100  # episodes per pass; the episode index is also used as the reset seed
epsilon = 0.2  # probability of sampling a random action instead of the greedy one
episilon_decay = 0.99  # (sic) not referenced by any cell visible in this notebook
hidden_size = 512  # width of the linear layers and number of conv channels
learning_rate = 0.05  # passed to both Adadelta optimizers
momentum = 0.9  # not referenced by any cell visible in this notebook
depth = 3  # number of stacked observations fed to the action network

# Initial observation; later cells read its shape to size the networks.
obs, info = env.reset()

# Prefer CUDA, then Apple MPS, then CPU. `torch.backends.mps.is_available()` is
# used instead of `torch.mps.is_available()` because the latter is missing from
# older PyTorch releases and would raise AttributeError there.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
device

  interval_1 = [(a - r) @ u / rqu, (b - r) @ u / rqu]
  interval_2 = [(a - r) @ v / rqv, (d - r) @ v / rqv]
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


device(type='mps')

In [32]:
# Number of scalar entries in one observation (product of its dimensions);
# used below as the input/output width of the networks.
flattened_observation_size = np.prod(obs.shape)

class OBS_Model(torch.nn.Module):
    """Fully connected network mapping a flattened observation to a vector of the same size.

    The training cell fits it to predict the next flattened observation from
    the previous one.
    """

    def __init__(self):
        super().__init__()
        # Layer names are part of the saved state_dict keys; keep them stable.
        self.fc1 = torch.nn.Linear(flattened_observation_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, flattened_observation_size)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers followed by a linear output layer."""
        relu = torch.nn.functional.relu
        hidden = relu(self.fc1(x))
        hidden = relu(self.fc2(hidden))
        return self.fc3(hidden)

class ACT_Model(torch.nn.Module):
    """Conv1d + MLP network that outputs one value per discrete action.

    Input is a stack of ``depth`` flattened observations with shape
    ``(depth, flattened_observation_size)``; a leading batch dimension is also
    accepted, in which case the output has shape ``(batch, n_actions)``.
    """

    def __init__(self):
        super(ACT_Model, self).__init__()
        conv_kernel = 2
        pool_kernel = 8
        self.conv1 = torch.nn.Conv1d(depth, hidden_size, conv_kernel)
        self.pool1 = torch.nn.MaxPool1d(pool_kernel)

        # Derive the flattened feature width instead of hard-coding it, so the
        # model stays valid when hidden_size or the observation size changes.
        # With hidden_size=512 and 256 inputs this evaluates to 15872.
        conv_len = int(flattened_observation_size) - conv_kernel + 1
        pooled_len = (conv_len - pool_kernel) // pool_kernel + 1
        self.fc1 = torch.nn.Linear(hidden_size * pooled_len, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, env.action_space.n)

    def forward(self, x):
        """Return the per-action values for a stack (or batch of stacks) of observations."""
        x = torch.nn.functional.relu(self.conv1(x))
        x = self.pool1(x)
        # Merge only the (channels, length) dims so an optional batch dim survives;
        # for unbatched input this is the same as a full flatten.
        x = x.flatten(start_dim=-2)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate both networks and move their parameters to the selected device.
obs_net, act_net = OBS_Model().to(device), ACT_Model().to(device)

In [33]:
class STMemory:
    """Bounded short-term buffer that retains the most recently pushed items."""

    def __init__(self, size):
        """Create an empty buffer; ``size`` is the maximum number of items kept by ``push``."""
        self.memory = []
        self.size = size

    def push(self, obj):
        """Append ``obj`` as the newest item, dropping the oldest one if the buffer overflows."""
        self.memory.append(obj)
        overflowed = len(self.memory) > self.size
        if overflowed:
            del self.memory[0]

    def pop(self):
        """Remove and return the newest item."""
        return self.memory.pop()

    def peek(self, index=-1):
        """Return the item at ``index`` without removing it (newest by default)."""
        return self.memory[index]

    def as_list(self):
        """Return the underlying list, ordered oldest to newest (not a copy)."""
        return self.memory

    def __len__(self):
        """Number of items currently stored."""
        return len(self.memory)

In [34]:
# One Adadelta optimizer per network (observation net first, then action net),
# both using the shared learning rate.
obs_optimizer, act_optimizer = (
    torch.optim.Adadelta(net.parameters(), lr=learning_rate)
    for net in (obs_net, act_net)
)
# Smooth L1 loss is used for both the action-value and the observation targets.
loss_fn = torch.nn.SmoothL1Loss()

In [None]:
# Online training: every environment step triggers one gradient update of the
# action network and one of the observation network.
loss_hist = []
reward_hist = []
# Exponential moving averages (new value weighted 0.1), all seeded at 1.
recent_act_loss = 1
recent_obs_loss = 1
recent_reward = 1

obs_net.train()
act_net.train()

for epoch in range(epochs):
    for episode in range(episodes):
        obs, info = env.reset(seed=episode)
        done, truncated = False, False
        reward_sum = 0

        obs_memory = STMemory(depth)
        act_memory = STMemory(depth)

        # Fill memory with initial observation
        while len(obs_memory) < depth:
            obs_memory.push(obs.flatten())

        while not done and not truncated:
            # Action Selection (epsilon-greedy; epsilon stays constant in this loop)
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                obs_net.eval()

                # Action selection needs no gradients, so skip graph construction.
                with torch.no_grad():
                    obs_x = torch.tensor(obs, dtype=torch.float32).flatten().to(device)

                    # Predict N depth observations by feeding each prediction
                    # back into the observation network (one forward pass per step).
                    pred_obss = []
                    for _ in range(depth):
                        obs_x = obs_net(obs_x)
                        pred_obss.append(obs_x)

                    # Stack on-device instead of round-tripping through NumPy.
                    action_rew = act_net(torch.stack(pred_obss))
                    action = torch.argmax(action_rew).item()

                obs_net.train()

            # Step
            obs, reward, done, truncated, info = env.step(action)

            # Loss Calculation
            # np.stack first: building a tensor from a list of ndarrays is slow.
            history = torch.tensor(
                np.stack(obs_memory.as_list()), dtype=torch.float32
            ).to(device)
            orig_act_values = act_net(history)
            # The target equals the current prediction except for the taken
            # action, whose entry is replaced by the observed reward. It is
            # detached so it is treated as a constant.
            updated_act_values = orig_act_values.detach().clone()
            updated_act_values[action] = reward

            act_loss = loss_fn(orig_act_values, updated_act_values)

            # The newest stored observation is the one seen before this step;
            # the observation network is fit to predict the one after it.
            pred_obs = obs_net(history[-1])
            obs_loss = loss_fn(pred_obs, torch.tensor(obs.flatten(), dtype=torch.float32).to(device))

            # Backpropagation
            act_optimizer.zero_grad()
            act_loss.backward()
            act_optimizer.step()

            obs_optimizer.zero_grad()
            obs_loss.backward()
            obs_optimizer.step()

            # Closing Sequence
            act_memory.push(action)
            obs_memory.push(obs.flatten())
            reward_sum += reward
            # env.render()

        # Update the moving averages with the last step's losses and the
        # episode return, then record the updated average (recording before the
        # update would lag by one episode and start with the placeholder 1).
        recent_act_loss = act_loss.item() * 0.1 + recent_act_loss * 0.9
        recent_obs_loss = obs_loss.item() * 0.1 + recent_obs_loss * 0.9
        recent_reward = reward_sum * 0.1 + recent_reward * 0.9
        reward_hist.append(recent_reward)
        print(f"Epoch: {epoch}\t Episode: {episode}\t Reward: {reward_sum}\t Recent Reward: {recent_reward}\t Act Loss: {recent_act_loss * 100}\t Obs Loss: {recent_obs_loss * 100}")

        # Save (checkpoint after every episode so an interrupted run keeps its weights)
        torch.save(obs_net.state_dict(), "obs_net.pth")
        torch.save(act_net.state_dict(), "act_net.pth")

  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


Epoch: 0	 Episode: 0	 Reward: 22.153554502160564	 Recent Reward: 3.1153554502160565	 Act Loss: 90.00000000018146	 Obs Loss: 93.38616520166397
Epoch: 0	 Episode: 1	 Reward: 20.918677275900926	 Recent Reward: 4.8956876327845436	 Act Loss: 81.0284581563206	 Obs Loss: 84.84360522031784
Epoch: 0	 Episode: 2	 Reward: 21.786887835702995	 Recent Reward: 6.584807653076389	 Act Loss: 72.92561234120156	 Obs Loss: 76.39482004251332
Epoch: 0	 Episode: 3	 Reward: 17.31311228363944	 Recent Reward: 7.657638116132695	 Act Loss: 66.04635858014487	 Obs Loss: 69.66568365221843
Epoch: 0	 Episode: 4	 Reward: 8.579174136873906	 Recent Reward: 7.749791718206816	 Act Loss: 60.232806594448654	 Obs Loss: 64.47402984969878
Epoch: 0	 Episode: 5	 Reward: 6.464219101614234	 Recent Reward: 7.621234456547558	 Act Loss: 54.83760667557481	 Obs Loss: 58.672751793918586
Epoch: 0	 Episode: 6	 Reward: 2.79937727612443	 Recent Reward: 7.139048738505245	 Act Loss: 49.802183913126264	 Obs Loss: 53.177092128685935
Epoch: 0	 Epi

  interval_1 = [(a - r) @ u / rqu, (b - r) @ u / rqu]


Epoch: 0	 Episode: 18	 Reward: 4.046451578715396	 Recent Reward: 8.204966802492418	 Act Loss: 17.439653606753307	 Obs Loss: 20.522547671317305
Epoch: 0	 Episode: 19	 Reward: 10.912489881749778	 Recent Reward: 8.475719110418154	 Act Loss: 16.56828623849817	 Obs Loss: 19.440585666107456
Epoch: 0	 Episode: 20	 Reward: 16.53333333333334	 Recent Reward: 9.281480532709672	 Act Loss: 15.522673444497597	 Obs Loss: 17.9276486127442
Epoch: 0	 Episode: 21	 Reward: 4.279830289487111	 Recent Reward: 8.781315508387417	 Act Loss: 14.302844859553588	 Obs Loss: 17.857824554043937
Epoch: 0	 Episode: 22	 Reward: 7.313112166125755	 Recent Reward: 8.63449517416125	 Act Loss: 13.278392828160419	 Obs Loss: 16.70292313992792
Epoch: 0	 Episode: 23	 Reward: 7.446445499459089	 Recent Reward: 8.515690206691033	 Act Loss: 12.823081651318574	 Obs Loss: 16.41036834806037
Epoch: 0	 Episode: 24	 Reward: 8.946834601513842	 Recent Reward: 8.558804646173314	 Act Loss: 11.930280526568149	 Obs Loss: 15.118153279928851
Epoc

KeyboardInterrupt: 

In [None]:
# Rollouts without parameter updates: same control loop as the training cell,
# but no backward pass or optimizer step, and each step is rendered.
loss_hist = []
reward_hist = []
# Exponential moving averages (new value weighted 0.1), all seeded at 1.
recent_act_loss = 1
recent_obs_loss = 1
recent_reward = 1

# Put both networks in inference mode once, up front.
obs_net.eval()
act_net.eval()

# Nothing is optimized here, so disable gradient tracking for the whole cell.
with torch.no_grad():
    for episode in range(episodes):
        obs, info = env.reset(seed=episode)
        done, truncated = False, False
        reward_sum = 0
        obs_memory = STMemory(depth)
        act_memory = STMemory(depth)
        # Fill memory with initial observation
        while len(obs_memory) < depth:
            obs_memory.push(obs.flatten())
        while not done and not truncated:
            # Action Selection (still epsilon-greedy, as in the training cell)
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                obs_x = torch.tensor(obs, dtype=torch.float32).flatten().to(device)
                # Predict N depth observations by feeding each prediction back
                # into the observation network (one forward pass per step).
                pred_obss = []
                for _ in range(depth):
                    obs_x = obs_net(obs_x)
                    pred_obss.append(obs_x)
                action_rew = act_net(torch.stack(pred_obss))
                action = torch.argmax(action_rew).item()
            # Step
            obs, reward, done, truncated, info = env.step(action)
            # Loss Calculation (reported only; never back-propagated)
            history = torch.tensor(
                np.stack(obs_memory.as_list()), dtype=torch.float32
            ).to(device)
            orig_act_values = act_net(history)
            updated_act_values = orig_act_values.clone()
            updated_act_values[action] = reward
            act_loss = loss_fn(orig_act_values, updated_act_values)
            pred_obs = obs_net(history[-1])
            obs_loss = loss_fn(
                pred_obs, torch.tensor(obs.flatten(), dtype=torch.float32).to(device)
            )
            # Closing Sequence
            act_memory.push(action)
            obs_memory.push(obs.flatten())
            reward_sum += reward
            env.render()
        # Update the moving averages first, then record the updated value.
        recent_act_loss = act_loss.item() * 0.1 + recent_act_loss * 0.9
        recent_obs_loss = obs_loss.item() * 0.1 + recent_obs_loss * 0.9
        recent_reward = reward_sum * 0.1 + recent_reward * 0.9
        reward_hist.append(recent_reward)
        # No epoch loop exists in this cell, so the message no longer reads the
        # `epoch` variable left behind by the training cell.
        print(
            f"Episode: {episode}\t Reward: {reward_sum}\t Recent Reward: {recent_reward}\t Act Loss: {recent_act_loss * 100}\t Obs Loss: {recent_obs_loss * 100}"
        )

Epoch: 5	 Episode: 0	 Reward: 13.697509064989958	 Recent Reward: 2.269750906498996	 Act Loss: 90.10457827709615	 Obs Loss: 90.78876577317715
Epoch: 5	 Episode: 1	 Reward: 9.765819162156687	 Recent Reward: 3.019357732064765	 Act Loss: 81.22703134734184	 Obs Loss: 82.05658929795027
Epoch: 5	 Episode: 2	 Reward: 4.832044433857514	 Recent Reward: 3.20062640224404	 Act Loss: 73.21876394310966	 Obs Loss: 74.41696684584022
Epoch: 5	 Episode: 3	 Reward: 2.432044434686411	 Recent Reward: 3.123768205488277	 Act Loss: 66.06302031555214	 Obs Loss: 67.33115779779851
Epoch: 5	 Episode: 4	 Reward: 21.751063609128884	 Recent Reward: 4.986497745852338	 Act Loss: 59.45842796421541	 Obs Loss: 60.78995440831035
Epoch: 5	 Episode: 5	 Reward: 3.94445233201402	 Recent Reward: 4.882293204468507	 Act Loss: 53.59280003702486	 Obs Loss: 55.37037103858814
Epoch: 5	 Episode: 6	 Reward: 7.199377276124429	 Recent Reward: 5.114001611634099	 Act Loss: 48.35333016563217	 Obs Loss: 50.21615162033343
Epoch: 5	 Episode: 7