In [None]:
# Colab helpers: `drive` is used to mount Google Drive, `files` is used later
# in the notebook to hand generated archives to the browser for download.
from google.colab import drive, files
# Mount Drive at /content/drive so SAVE_DIR below points at Drive storage.
drive.mount('/content/drive')

import os
import shutil

# Directory (on the mounted Drive) where checkpoints, reward logs and figures
# produced by this notebook are written.
SAVE_DIR = "/content/drive/MyDrive/highway_results/grayscale_dqn"
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(SAVE_DIR, exist_ok=True)

def check_drive_connection():
    """Check that SAVE_DIR is writable by creating and deleting a probe file.

    Returns:
        bool: True if the probe file could be written and removed, False if
        any exception occurred (the error is printed rather than re-raised).
    """
    probe_path = f"{SAVE_DIR}/.connection_test"
    connected = False
    try:
        # A real write + delete round trip is the only reliable way to know
        # the target directory is usable.
        with open(probe_path, 'w') as probe:
            probe.write('ok')
        os.remove(probe_path)
        print(f"Google Drive connected. Saving to: {SAVE_DIR}")
        connected = True
    except Exception as e:
        print(f"ERROR: Cannot write to Google Drive! {e}")
    return connected

# Stop the notebook here rather than train for hours and fail on the first save.
assert check_drive_connection(), "Fix Google Drive connection before continuing!"

In [None]:
# Install the RL stack used below (highway-env, stable-baselines3, gymnasium).
# NOTE(review): no versions are pinned here, so a fresh runtime may pull
# different releases than the ones this notebook was written against.
!pip install highway-env stable-baselines3 gymnasium

In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import glob

from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv

import highway_env

# Informational only: reports whether torch can see a CUDA device. The result
# is printed, not stored, so nothing later in the notebook reads it.
print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

In [None]:
def create_env():
    """Create a `highway-fast-v0` environment configured for grayscale image input.

    The observation is a `GrayscaleObservation` of shape (128, 64) with a
    stack of 4 frames. The environment is reset once before being returned.

    Returns:
        The environment object returned by `gym.make`.
    """
    grayscale_observation = {
        "type": "GrayscaleObservation",
        "observation_shape": (128, 64),
        "stack_size": 4,
        # Three per-channel weights; presumably the RGB -> gray conversion
        # coefficients expected by highway-env (confirm against its docs).
        "weights": [0.2989, 0.5870, 0.1140],
        "scaling": 1.75,
    }
    env = gym.make(
        'highway-fast-v0',
        render_mode='rgb_array',
        config={"observation": grayscale_observation},
    )
    env.reset()
    return env

# Smoke test: build one throwaway environment, print the observation shape
# and action space it exposes, then release it before training starts.
test_env = create_env()
obs, _ = test_env.reset()
print(f"Observation shape: {obs.shape}")
print(f"Action space: {test_env.action_space}")
test_env.close()

In [None]:
class RewardLoggerCallback(BaseCallback):
    """Track per-episode returns during training and persist them to disk.

    Only index 0 of the vectorized ``rewards`` / ``dones`` arrays is read, so
    the callback follows a single environment.

    Attributes are plain lists/numbers so a caller can pre-populate
    ``episode_rewards`` (e.g. when resuming a run) before training starts.

    Args:
        save_path: Directory in which ``episode_rewards.npy`` is written.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, save_path, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []   # completed-episode returns, in order
        self.current_rewards = 0    # running return of the episode in progress
        self.save_path = save_path

    def _save_rewards(self) -> None:
        """Write the full list of episode returns to ``episode_rewards.npy``."""
        np.save(f"{self.save_path}/episode_rewards.npy", self.episode_rewards)

    def _on_step(self) -> bool:
        """Accumulate the step reward; on episode end, record and maybe log/save.

        Returns:
            bool: Always True, so training is never stopped by this callback.
        """
        self.current_rewards += self.locals['rewards'][0]
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_rewards)
            self.current_rewards = 0
            # Report and checkpoint the reward history every 50 episodes.
            if len(self.episode_rewards) % 50 == 0:
                avg = np.mean(self.episode_rewards[-50:])
                print(f"Episode {len(self.episode_rewards)}: Avg reward (last 50) = {avg:.2f}")
                self._save_rewards()
        return True

    def _on_training_end(self) -> None:
        """Flush episodes recorded since the last 50-episode save.

        Without this, up to 49 trailing episodes would exist only in memory
        and be missing from the file that a resumed run reloads.
        """
        if self.episode_rewards:
            self._save_rewards()


class DownloadCallback(BaseCallback):
    """Periodically zip the save directory and offer it as a browser download.

    Args:
        save_path: Directory whose contents are archived.
        download_freq: Minimum number of timesteps between two downloads.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, save_path, download_freq=10000, verbose=0):
        super().__init__(verbose)
        self.save_path = save_path
        self.download_freq = download_freq
        self.last_download = 0  # timestep count at the most recent attempt

    def _on_step(self) -> bool:
        """Archive and download once enough timesteps have elapsed.

        Failures are reported and ignored: the download is best-effort and
        must never interrupt training.

        Returns:
            bool: Always True, so training is never stopped by this callback.
        """
        steps_since_download = self.num_timesteps - self.last_download
        if steps_since_download < self.download_freq:
            return True

        # Record the attempt up front so a failing download is not retried
        # on every subsequent step.
        self.last_download = self.num_timesteps
        try:
            zip_name = f"grayscale_dqn_checkpoint_{self.num_timesteps}"
            archive_base = f"/content/{zip_name}"
            shutil.make_archive(archive_base, 'zip', self.save_path)
            files.download(f"{archive_base}.zip")
            print(f"Checkpoint downloaded: {self.num_timesteps} timesteps")
        except Exception as e:
            print(f"Download skipped: {e}")
        return True

In [None]:
# Overall training budget and how often (in timesteps) checkpoints are written.
TOTAL_TIMESTEPS = 100000
CHECKPOINT_FREQ = 10000

print("Setting up DQN training...")
env = DummyVecEnv([create_env])

# "auto" lets Stable-Baselines3 pick CUDA when it is available and fall back
# to CPU otherwise, so the notebook also runs on a CPU-only runtime.
device = "auto"

# Resume from the newest checkpoint if one exists. Checkpoint files are named
# "checkpoint_<steps>_steps.zip", so the second-to-last "_" field is the step count.
checkpoints = glob.glob(f"{SAVE_DIR}/checkpoint_*.zip")
resumed = bool(checkpoints)
if resumed:
    latest = max(checkpoints, key=lambda x: int(x.split('_')[-2]))
    print(f"Resuming from: {latest}")
    model = DQN.load(latest, env=env, device=device)
    if os.path.exists(f"{SAVE_DIR}/episode_rewards.npy"):
        saved_rewards = np.load(f"{SAVE_DIR}/episode_rewards.npy").tolist()
    else:
        saved_rewards = []
else:
    print("Starting new DQN training")
    model = DQN(
        "CnnPolicy",
        env,
        learning_rate=5e-4,
        buffer_size=15000,
        learning_starts=200,
        batch_size=32,
        gamma=0.8,
        train_freq=1,
        gradient_steps=1,
        target_update_interval=50,
        exploration_fraction=0.7,
        verbose=1,
        device=device
    )
    saved_rewards = []

checkpoint_cb = CheckpointCallback(save_freq=CHECKPOINT_FREQ, save_path=SAVE_DIR, name_prefix="checkpoint")
reward_cb = RewardLoggerCallback(save_path=SAVE_DIR)
reward_cb.episode_rewards = saved_rewards
download_cb = DownloadCallback(save_path=SAVE_DIR, download_freq=CHECKPOINT_FREQ)
# When resuming, the step counter starts at the checkpoint's value; seed the
# download tracker with it so the very first step does not trigger a download.
download_cb.last_download = model.num_timesteps

# Only train the steps still missing from the budget. With
# reset_num_timesteps=False the model keeps its step counter (and therefore
# its exploration schedule and checkpoint numbering) instead of restarting at 0.
# NOTE(review): checkpoints written by CheckpointCallback with default settings
# do not appear to include the replay buffer - confirm whether that matters here.
remaining_timesteps = max(TOTAL_TIMESTEPS - model.num_timesteps, 0)
if remaining_timesteps > 0:
    print(f"Training DQN for {remaining_timesteps} timesteps...")
    model.learn(
        total_timesteps=remaining_timesteps,
        callback=[checkpoint_cb, reward_cb, download_cb],
        reset_num_timesteps=not resumed,
    )
else:
    print(f"Checkpoint already has {model.num_timesteps} timesteps; no further training needed.")
model.save(f"{SAVE_DIR}/grayscale_dqn_model_3")
print("Training complete!")

In [None]:
# Learning curve: raw per-episode returns plus a 50-episode rolling mean.
plt.figure(figsize=(10, 6))
rewards = reward_cb.episode_rewards
# Shared 1-based episode axis so the raw curve and the rolling mean line up
# (the rolling value at x=k is the mean of episodes k-49..k).
episode_numbers = np.arange(1, len(rewards) + 1)
if len(rewards) >= 50:
    rolling = np.convolve(rewards, np.ones(50)/50, mode='valid')
    plt.plot(episode_numbers[49:], rolling, 'b-', linewidth=2, label='Rolling Mean (50 ep)')
plt.plot(episode_numbers, rewards, 'lightblue', alpha=0.3, label='Episode Reward')
plt.xlabel('Episode')
plt.ylabel('Mean Episodic Reward (Return)')
plt.title('Highway-v0 GrayscaleObservation DQN - Learning Curve (ID 3)')
plt.legend()
plt.grid(True, alpha=0.3)
# Save before show(): the figure is written to disk while it is still current.
plt.savefig(f"{SAVE_DIR}/highway_grayscale_learning_curve_3.png", dpi=150)
plt.show()

In [None]:
# Evaluate the trained policy greedily over 500 fresh episodes and collect
# the undiscounted return of each one in `eval_rewards`.
print("Running performance test: 500 episodes...")
eval_env = create_env()
eval_rewards = []

for ep in range(500):
    obs, _ = eval_env.reset()
    total = 0
    # Step until the environment reports termination or truncation.
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, _ = eval_env.step(action)
        total += reward
        if done or truncated:
            break
    eval_rewards.append(total)
    # Progress report with the running mean after every 100 episodes.
    if (ep + 1) % 100 == 0:
        print(f"Completed {ep + 1}/500 | Mean: {np.mean(eval_rewards):.2f}")

eval_env.close()
print(f"Final mean reward: {np.mean(eval_rewards):.2f} +/- {np.std(eval_rewards):.2f}")

In [None]:
# Violin plot of the evaluation returns, annotated with their mean and std.
plt.figure(figsize=(8, 6))
violin = plt.violinplot(eval_rewards, positions=[1], showmeans=True, showmedians=True)
for body in violin['bodies']:
    body.set_facecolor('steelblue')
    body.set_alpha(0.7)
# Anchor the annotation at the mean, to the right of the violin.
annotation_y = np.mean(eval_rewards)
plt.text(1.25, annotation_y, f'Mean: {np.mean(eval_rewards):.2f}\nStd: {np.std(eval_rewards):.2f}')
plt.xlabel('Highway GrayscaleObs (DQN)')
plt.ylabel('Episodic Reward (Return)')
plt.title('Performance Test - 500 Episodes (ID 4)')
plt.xticks([1], ['Grayscale DQN'])
plt.grid(True, alpha=0.3, axis='y')
plt.savefig(f"{SAVE_DIR}/highway_grayscale_performance_4.png", dpi=150)
plt.show()

In [None]:
# Bundle everything under SAVE_DIR into one zip and hand it to the browser.
print("Downloading final results...")
final_archive_base = "/content/grayscale_dqn_final"
shutil.make_archive(final_archive_base, 'zip', SAVE_DIR)
files.download(f"{final_archive_base}.zip")
print("Done!")