In [4]:
%pip install torch highway-env

You should consider upgrading via the '/Users/iato/Code/autobots/.venv/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [5]:
import torch
import gymnasium as gym
import highway_env
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

In [6]:
env = gym.make(
    "highway-fast-v0",
    render_mode="rgb_array",
    config={
        "action": {
            "type": "DiscreteMetaAction",
        },
        "observation": {
            "type": "OccupancyGrid",
            "vehicles_count": 15,
            "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
            "features_range": {
                "x": [-100, 100],
                "y": [-100, 100],
                "vx": [-20, 20],
                "vy": [-20, 20]
            },
            "grid_size": [[-27.5, 27.5], [-27.5, 27.5]],
            "grid_step": [5, 5],
            "absolute": False
        },
        "vehicles_count": 20,
    },
)

epochs = 100
episodes = 100
epsilon = 0.2
episilon_decay = 0.99
hidden_size = 512
learning_rate = 0.05
momentum = 0.9

obs, info = env.reset()
obs

array([[[ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  1.        ,  0.        ,  0.        ,
          0.        ,  

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu")
device

device(type='mps')

In [8]:
flattened_observation_size = np.prod(obs.shape)
net = torch.nn.Sequential(
    torch.nn.Linear(flattened_observation_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(hidden_size, env.action_space.n),
).to(device)

In [9]:
optimizer = torch.optim.Adadelta(net.parameters(), lr=learning_rate, rho=momentum)
loss_fn = torch.nn.MSELoss()

In [None]:
loss_hist = []
reward_hist = []
recent_loss = 1

net.train()

for epoch in range(epochs):
    # for epoch in range(epochs):
    for episode in range(episodes):
        obs, info = env.reset(seed=episode)
        done, truncated = False, False
        reward = 0
        while not done and not truncated:
 
            x = torch.tensor(obs, dtype=torch.float32).flatten().to(device)
            y = net(x)
            y = torch.where(torch.isnan(y), torch.tensor(0), y)
            
            action = max(enumerate(y), key=lambda x: x[1])[0]
            if (np.random.rand() < epsilon) or (torch.isnan(x).any()):
                action = env.action_space.sample()

            nobs, reward, done, truncated, info = env.step(action)
            # x, y for training

            # If not yet discovered, set output to 0.5 for all actions
            ny = y.clone()
            ny[action] = reward

            # Calculate loss
            loss = loss_fn(ny, y)

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()

            # Update weights
            optimizer.step()

            # Update observation
            obs = nobs
            
            env.render()
        # Save history
        recent_loss = loss.item() * 0.05 + recent_loss * 0.95
        print(f"Epoch {epoch}\t Episode {episode}\tReward {round(reward, 2)}\tRecent Loss {recent_loss}")
        reward_hist.append(reward)
        loss_hist.append(loss.item())
        plt.plot(reward_hist)
        plt.plot(loss_hist)
    torch.save(net.state_dict(), "model.pth")
    # epsilon *= episilon_decay
    epoch += 1


2024-11-22 16:55:12.756 Python[56630:362358] +[IMKClient subclass]: chose IMKClient_Modern


Epoch 0	 Episode 0	Reward 0.09	Recent Loss 0.9502059605671093
Epoch 0	 Episode 1	Reward 0.03	Recent Loss 0.902695960236083
Epoch 0	 Episode 2	Reward 0.07	Recent Loss 0.8575612663472568
Epoch 0	 Episode 3	Reward 0.07	Recent Loss 0.8146844118992378
Epoch 0	 Episode 4	Reward 0.33	Recent Loss 0.7745619529822861
Epoch 0	 Episode 5	Reward 0.17	Recent Loss 0.7358871223220061
Epoch 0	 Episode 6	Reward 0.07	Recent Loss 0.6991217461268043
Epoch 0	 Episode 7	Reward 0.07	Recent Loss 0.6642068570154224
Epoch 0	 Episode 8	Reward 0.07	Recent Loss 0.6310504256636318
Epoch 0	 Episode 9	Reward 0.17	Recent Loss 0.5999453852732127
Epoch 0	 Episode 10	Reward 0.07	Recent Loss 0.570066311566122
Epoch 0	 Episode 11	Reward 0.11	Recent Loss 0.5416025220163745
Epoch 0	 Episode 12	Reward 0.0	Recent Loss 0.5145224873282429
Epoch 0	 Episode 13	Reward 0.33	Recent Loss 0.488904171722198
Epoch 0	 Episode 14	Reward 0.07	Recent Loss 0.46455608038314805
Epoch 0	 Episode 15	Reward 0.33	Recent Loss 0.44142181788533685
Epoc

KeyboardInterrupt: 

In [10]:
net.load_state_dict(torch.load("model.pth"))
net.eval()

num_tests = 100

reward_hist = []

for i in range(num_tests):
    obs, info = env.reset(seed=i)
    done, truncated = False, False
    reward = 0
    while not done and not truncated:
        x = torch.tensor(obs, dtype=torch.float32).flatten().to(device)
        action = max(enumerate(net(x)), key=lambda x: x[1])[0]
        obs, reward, done, truncated, info = env.step(action)
        env.render()
    reward_hist.append(reward)
    print(f"Test {i} done")

plt.plot(reward_hist)
plt.show()

  net.load_state_dict(torch.load("model.pth"))
2024-11-22 18:18:48.469 Python[68553:444413] +[IMKClient subclass]: chose IMKClient_Modern


Test 0 done
Test 1 done
Test 2 done
Test 3 done
Test 4 done
Test 5 done
Test 6 done
Test 7 done
Test 8 done
Test 9 done
Test 10 done
Test 11 done
Test 12 done
Test 13 done
Test 14 done
Test 15 done
Test 16 done
Test 17 done
Test 18 done
Test 19 done
Test 20 done
Test 21 done
Test 22 done
Test 23 done
Test 24 done
Test 25 done
Test 26 done
Test 27 done
Test 28 done
Test 29 done
Test 30 done
Test 31 done
Test 32 done
Test 33 done
Test 34 done
Test 35 done
Test 36 done
Test 37 done
Test 38 done
Test 39 done
Test 40 done
Test 41 done
Test 42 done
Test 43 done
Test 44 done
Test 45 done
Test 46 done
Test 47 done
Test 48 done
Test 49 done
Test 50 done


KeyboardInterrupt: 