In [1]:
# Load the Drive helper and mount
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/Documents/Fulbright\ Application\ 2020-2021/Courses/'Spring Semester 2024'/'Deep Decision and Reinforcement Learning'/'Self-Driving-Car-Project'
%ls demos

Mounted at /content/drive
/content/drive/MyDrive/Documents/Fulbright Application 2020-2021/Courses/Spring Semester 2024/Deep Decision and Reinforcement Learning/Self-Driving-Car-Project
circle_clock.json         never_seen.json  recover_3.json  recover_6.json
circle_counterclock.json  recover_1.json   recover_4.json  snake_2.json
eight.json                recover_2.json   recover_5.json  snake.json


In [2]:
import os
import json
from enum import Enum

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from src.ppo_agent import PPOModel, Constants
from src.dataset import Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

obs_list = torch.tensor([])
action_list = torch.tensor([])

for file in os.listdir("./demos"):
    if file.startswith("*") or file.startswith("."):
        continue
    with open(f"./demos/{file}", "r") as f:
        data = json.load(f)
        for episode in data:
            min_length = min(len(episode[0]), len(episode[1]))
            obs = episode[0][:min_length]
            action = episode[1][:min_length]

            if len(obs) == 0 or len(action) == 0:
                continue

            obs = torch.tensor(obs, dtype=torch.float32)
            action = torch.tensor(action, dtype=torch.float32)
            obs_list = torch.cat([obs_list, obs])
            action_list = torch.cat([action_list, action])

dataset = Dataset(Constants.INPUT_SIZE.value, obs_list, action_list, Constants.NUM_HISTORY.value)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model = PPOModel(
    Constants.NUM_HISTORY.value,
    Constants.INPUT_SIZE.value,
    Constants.OUTPUT_SIZE.value
).to(device)

# create a loss function
loss_fn = nn.MSELoss().to(device)

# create an optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=Constants.lr.value)
scheduler = ReduceLROnPlateau(optimizer, "min", patience=2)
# train the model
iterator = tqdm(range(1, Constants.EPOCHS.value + 1), total=Constants.EPOCHS.value, desc="Training")

for epoch in iterator:
    model.train()
    iterator.set_description("Training")
    for obs, action in train_dataloader:
        optimizer.zero_grad()
        obs = obs.to(device)
        action = action.to(device)
        # pred = model(obs)
        dist = model(obs)
        log_prob = dist.log_prob(action).sum()
        loss = -log_prob.mean()
        loss.backward()
        optimizer.step()

    # evaluate the model
    iterator.set_description("Evaluating")
    model.eval()
    with torch.no_grad():
        test_loss = 0
        for obs, action in test_dataloader:
            obs = obs.to(device)
            action = action.to(device)
            # pred = model(obs)
            dist = model(obs)
            log_prob = dist.log_prob(action).sum()
            loss = -log_prob.mean()
            test_loss += loss.item()
        test_loss /= len(test_dataloader)
    print(f"Epoch: {epoch}. Loss: {test_loss:0.4f}")
    iterator.set_postfix(epoch=epoch, loss=test_loss)
    scheduler.step(test_loss)

# save the model
torch.save(model.state_dict(), f"pretrained_model_dict_{device}.pt")

cuda
PPOModel2(
  (policy): Sequential(
    (0): Linear(in_features=25, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=2, bias=True)
    (7): Tanh()
  )
)


Training:   5%|▌         | 1/20 [00:11<03:31, 11.11s/it, epoch=1, loss=691]  

Epoch: 1. Loss: 690.8305


Training:  10%|█         | 2/20 [00:21<03:09, 10.55s/it, epoch=2, loss=620]  

Epoch: 2. Loss: 619.7026


Training:  15%|█▌        | 3/20 [00:31<02:58, 10.49s/it, epoch=3, loss=625]  

Epoch: 3. Loss: 624.6833


Training:  20%|██        | 4/20 [00:42<02:48, 10.51s/it, epoch=4, loss=620]  

Epoch: 4. Loss: 620.3424


Training:  25%|██▌       | 5/20 [00:52<02:36, 10.46s/it, epoch=5, loss=619]  

Epoch: 5. Loss: 619.2061


Training:  30%|███       | 6/20 [01:02<02:23, 10.25s/it, epoch=6, loss=619]  

Epoch: 6. Loss: 619.2915


Training:  35%|███▌      | 7/20 [01:13<02:17, 10.56s/it, epoch=7, loss=622]  

Epoch: 7. Loss: 621.8585


Training:  40%|████      | 8/20 [01:23<02:05, 10.44s/it, epoch=8, loss=618]  

Epoch: 8. Loss: 618.4229


Training:  45%|████▌     | 9/20 [01:33<01:52, 10.22s/it, epoch=9, loss=619]  

Epoch: 9. Loss: 619.0679


Training:  50%|█████     | 10/20 [01:43<01:41, 10.12s/it, epoch=10, loss=621]  

Epoch: 10. Loss: 620.8302


Training:  55%|█████▌    | 11/20 [01:53<01:31, 10.19s/it, epoch=11, loss=626]  

Epoch: 11. Loss: 626.3526


Training:  60%|██████    | 12/20 [02:03<01:21, 10.17s/it, epoch=12, loss=618]  

Epoch: 12. Loss: 618.3317


Training:  65%|██████▌   | 13/20 [02:14<01:11, 10.15s/it, epoch=13, loss=619]  

Epoch: 13. Loss: 618.5053


Training:  70%|███████   | 14/20 [02:23<00:59,  9.86s/it, epoch=14, loss=618]  

Epoch: 14. Loss: 618.3738


Training:  75%|███████▌  | 15/20 [02:33<00:49,  9.94s/it, epoch=15, loss=618]  

Epoch: 15. Loss: 618.3583


Training:  80%|████████  | 16/20 [02:43<00:40, 10.03s/it, epoch=16, loss=618]  

Epoch: 16. Loss: 618.3048


Training:  85%|████████▌ | 17/20 [02:54<00:30, 10.14s/it, epoch=17, loss=618]  

Epoch: 17. Loss: 618.2884


Training:  90%|█████████ | 18/20 [03:03<00:19,  9.96s/it, epoch=18, loss=618]  

Epoch: 18. Loss: 618.2787


Training:  95%|█████████▌| 19/20 [03:13<00:10, 10.01s/it, epoch=19, loss=618]  

Epoch: 19. Loss: 618.2788


Evaluating: 100%|██████████| 20/20 [03:23<00:00, 10.19s/it, epoch=20, loss=618]

Epoch: 20. Loss: 618.2788



