In [None]:
import gettelemetry as client
import gamepad as gp
import window as gwd
import wandb
import os

import torch
import torch.nn.functional as F
import mss
import cv2
import time
import numpy as np

import pywinctl as gw
import vgamepad as vg
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.policies import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import CallbackList

import torchvision.models as models
from torchvision import transforms

In [None]:
# Training configuration: total number of environment timesteps SAC trains for.
steps = 100_000

In [None]:
# Authenticate with Weights & Biases and open a run for this training session.
# Recording the training budget in the run config makes runs comparable in the UI.
wandb.login()
wandb.init(project="trackmania_sac", config={"total_timesteps": steps})

In [None]:
class WindowCap():
    """Grab grayscale, downscaled screenshots of one on-screen window.

    The capture region is fixed to the window's position and size at
    construction time; later moves/resizes of the window are not tracked.
    """

    def __init__(self, window_name, resize=128):
        """Locate the window titled ``window_name`` and open an mss grabber.

        Args:
            window_name: Title passed to ``pywinctl.getWindowsWithTitle``; the
                first matching window is used.
            resize: Side length in pixels of the square frames returned by
                :meth:`capture` (default 128).

        Raises:
            Exception: if no window with that title exists.
        """
        self.window_name = window_name
        matches = gw.getWindowsWithTitle(window_name)
        if not matches:
            raise Exception(f"Window with name '{window_name}' not found.")
        self.window = matches[0]
        self.top = self.window.top
        self.left = self.window.left
        self.width = self.window.width
        self.height = self.window.height
        self.monitor = {"top": self.top, "left": self.left, "width": self.width, "height": self.height}
        self.sct = mss.mss()
        self.resize = resize

    def capture(self):
        """Return the window contents as a 2-D ``(resize, resize)`` float array in [0, 1]."""
        img = np.array(self.sct.grab(self.monitor))  # 4-channel BGRA pixels
        img = cv2.resize(img, (self.resize, self.resize))
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
        return img / 255.0  # normalize uint8 -> [0, 1]

    def close(self):
        """Release the mss handle. Safe to call more than once."""
        sct = getattr(self, "sct", None)
        if sct is not None:
            sct.close()
            self.sct = None

    def __del__(self):
        # __init__ may have raised before self.sct was assigned; close() tolerates that.
        self.close()

In [None]:
class TrackmaniaEnv(gym.Env):
    """Gymnasium environment that drives Trackmania through a virtual gamepad.

    Observations are grayscale screenshots of the game window shaped
    ``(1, H, W)`` with dtype ``uint8`` in [0, 255], matching
    ``observation_space``. Actions are ``[steer, throttle, brake]`` with
    steer in [-1, 1] and throttle/brake in [0, 1]. Rewards and termination
    are derived from the telemetry dict returned by
    ``TMClient.retrieve_data()``.
    """

    def __init__(self, window_name="Trackmania", total_steps=None, log_every=100,
                 reset_delay=1.5, speed_window=50):
        """Attach to the game window, telemetry client and virtual gamepad.

        Args:
            window_name: Title of the game window to capture and focus.
            total_steps: Training budget shown in the progress print. When
                None, a module-level ``steps`` is used if one is defined.
            log_every: Print progress every this many steps (0 disables it).
            reset_delay: Seconds :meth:`reset` sleeps after resetting the
                gamepad, before grabbing the first frame.
            speed_window: Number of recent speed samples averaged for the reward.
        """
        super().__init__()
        self.window = WindowCap(window_name)
        self.client = client.TMClient()
        self.gamepad = gp.GamepadHandler()
        self.total_steps = total_steps
        self.log_every = log_every
        self.reset_delay = reset_delay
        self.speed_window = speed_window

        self.action_space = gym.spaces.Box(
            low=np.array([-1.0, 0.0, 0.0], dtype=np.float32),  # steer left..right, throttle, brake
            high=np.array([1.0, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32)
        # Derive the frame size from the capturer so the two cannot drift apart.
        size = self.window.resize
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(1, size, size), dtype=np.uint8)
        self.reward_range = (-np.inf, np.inf)
        self.metadata = {'render_modes': ['human']}
        self.spec = None
        self.terminated = False
        self.truncated = False
        self.reward = 0
        self.prev_action = np.array([0, 0, 0])
        self.prev_obs = np.zeros(self.observation_space.shape, dtype=np.uint8)

        # Focus the same window that is being captured.
        self.id = gwd.get_window_id(window_name)
        self.focus = gwd.focus_window(self.id)
        self.speed_buffer = []

        self.steps = 0
        self.episode_reward = 0
        self.time_start = 0

    def _get_obs(self):
        """Capture a frame and convert it to the ``observation_space`` layout.

        ``WindowCap.capture`` returns a 2-D float image scaled to [0, 1];
        the observation space is ``(1, H, W)`` uint8 in [0, 255]. Returning the
        float image directly gets truncated to ~0 when stored in SB3's uint8
        buffers, so rescale, round, clip and add the channel axis here.
        """
        frame = np.asarray(self.window.capture())
        if np.issubdtype(frame.dtype, np.floating):
            frame = np.rint(frame * 255.0)
        frame = np.clip(frame, 0, 255).astype(np.uint8)
        return frame.reshape(self.observation_space.shape)

    def reset(self, seed=None, options=None):
        """Reset the gamepad, wait for the game, and return ``(first_obs, {})``.

        Args:
            seed: Optional seed forwarded to gymnasium and to ``np.random``.
            options: Unused; accepted for gymnasium API compatibility.
        """
        super().reset(seed=seed)
        self.gamepad.reset()
        self.speed_buffer = []
        time.sleep(self.reset_delay)
        if seed is not None:
            self.seed(seed)
        self.terminated = False
        self.truncated = False
        self.reward = 0
        self.episode_reward = 0
        self.prev_action = np.array([0, 0, 0])
        obs = self._get_obs()
        self.prev_obs = obs
        self.time_start = time.time()
        return obs, {}

    def seed(self, seed=None):
        """Seed numpy's global RNG (legacy helper kept for callers that use it)."""
        np.random.seed(seed)

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        Reward: 0.1 * mean speed over the last ``speed_window`` steps, -0.1
        when that mean is below 20, +20 when telemetry reports a non-zero
        checkpoint, +100 on finish. The episode terminates on finish, or when
        the car is stalled (mean speed < 2, acceleration < 0.1) more than 5 s
        after reset.
        """
        self.steps += 1
        self.gamepad.send_action(action)
        self.prev_action = action
        obs = self._get_obs()
        self.prev_obs = obs
        telemetry = self.client.retrieve_data()

        checkpoint = telemetry['checkpoint']
        lap = telemetry['lap']
        speed = telemetry['speed']
        position = telemetry['position']
        finished = telemetry['finished']
        acceleration = telemetry['acceleration']

        # Smooth single-frame telemetry spikes with a rolling mean of recent speeds.
        self.speed_buffer.append(speed)
        if len(self.speed_buffer) > self.speed_window:
            self.speed_buffer.pop(0)
        speed_av = sum(self.speed_buffer) / len(self.speed_buffer)

        reward = speed_av * 0.1
        if speed_av < 20:
            reward -= 0.1

        # NOTE(review): this pays +20 on every step where 'checkpoint' is
        # non-zero. If telemetry reports a running count rather than a one-step
        # event flag, the bonus repeats every step after the first checkpoint
        # — confirm against TMClient.
        if checkpoint != 0:
            reward += 20

        if finished:
            reward += 100  # bonus for finishing the race
            self.gamepad.press_a()  # press "improve" to restart
            self.terminated = True

        elapsed = time.time() - self.time_start
        # Stalled (stuck against a wall / not moving) after the start grace period.
        if speed_av < 2 and acceleration < 0.1 and elapsed > 5:
            reward -= 0.1
            self.terminated = True

        if self.log_every and self.steps % self.log_every == 0:
            total = self.total_steps if self.total_steps is not None else globals().get("steps", "?")
            print(f"step: {self.steps} / {total}")

        self.reward = reward
        self.episode_reward += reward
        truncated = False
        terminated = self.terminated
        info = {
            'speed': speed,
            'position': position,
            'checkpoint': checkpoint,
            'lap': lap,
            'episode_duration': elapsed,
        }
        return obs, reward, terminated, truncated, info

    @staticmethod
    def make_env(**kwargs):
        """Return a zero-argument factory building a fresh ``TrackmaniaEnv(**kwargs)``.

        Suitable for SB3 vectorized-env constructors such as ``DummyVecEnv``.
        """
        def _init():
            return TrackmaniaEnv(**kwargs)
        return _init

    def close(self):
        """Release the gamepad, screen grabber and telemetry client. Safe to call twice."""
        gamepad = getattr(self, "gamepad", None)
        if gamepad is not None:
            gamepad.reset()  # same call reset() makes before an episode
        if getattr(self, "window", None) is not None:
            # Drop this env's reference so WindowCap.__del__ can close its mss handle.
            self.window = None
        tm_client = getattr(self, "client", None)
        if tm_client is not None:
            tm_client.close()  # close the TMClient connection
            self.client = None

In [None]:
class CustomWandbCallback(BaseCallback):
    """SB3 callback that logs step reward and env telemetry to Weights & Biases.

    Telemetry is read from the first environment's ``info`` dict only.
    """

    # Keys copied from the step ``info`` dict when the env provides them.
    INFO_KEYS = ("speed", "checkpoint", "lap", "episode_duration")

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        """Log one wandb row per env step; always returns True (never stops training)."""
        metrics = {'timesteps': self.num_timesteps}

        rewards = np.asarray(self.locals.get('rewards', []), dtype=np.float64).ravel()
        if rewards.size:
            # Log a scalar (mean over envs) rather than the raw reward array.
            metrics['reward'] = float(rewards.mean())

        infos = self.locals.get('infos') or [{}]
        info = infos[0] or {}
        for key in self.INFO_KEYS:
            if key in info:
                metrics[key] = info[key]

        # A single wandb.log call per step keeps all metrics on the same wandb step;
        # two calls would advance wandb's internal step twice per env step.
        wandb.log(metrics)
        return True

In [None]:
# Create the environment. The factory builds the env inside DummyVecEnv (no
# late-binding `lambda: env` on a name that is rebound below); Monitor records
# per-episode return/length into info["episode"]. SB3 resets the env itself at
# the start of learn(), so no manual reset is needed here.
env = DummyVecEnv([lambda: Monitor(TrackmaniaEnv())])

# Create the SAC model. SB3's replay buffer pre-allocates `buffer_size`
# transitions of image observations; it can never hold more than `steps`
# transitions in this run, so cap it there to avoid reserving unused memory.
model = SAC('CnnPolicy',
            env,
            verbose=1,
            buffer_size=min(400_000, steps))

# Train, and always persist the model and release resources — even if training
# is interrupted (e.g. KeyboardInterrupt) or the env raises.
try:
    model.learn(total_timesteps=steps, callback=CustomWandbCallback())
finally:
    os.makedirs("models", exist_ok=True)
    model.save("models/trackmania_sac")
    env.close()
    wandb.finish()