In [1]:
import gym
import logging
import numpy as np
import dill as pickle
import os

from threading import Thread, Lock
from time import sleep
from client import CarDirection, Client
from env import JunctionEnvironment
from matplotlib import pyplot as plt

%load_ext autoreload
%autoreload 2

In [2]:
# Grab the root logger (calling getLogger with no name returns it) and raise
# its threshold so only WARNING-and-above records reach the notebook output.
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

In [3]:
# Team credentials passed to Client. Environment variables take precedence so
# that real keys do not have to be committed with the notebook; the literals
# are the fallback values used when the variables are unset.
team_name = os.environ.get("JUNCTION_TEAM_NAME", "ipa")
team_key = os.environ.get("JUNCTION_TEAM_KEY", "admin")

# Number of games played by the driver cell.
N_GAMES = 1

# On-disk layout for recorded trajectories.
CACHE_DIR = "./cache"
DATASET_DIR = os.path.join(CACHE_DIR, "dataset")
IMIT_LEARNING_DATASET_DIR = os.path.join(DATASET_DIR, "imitation_learning")

# exist_ok=True makes the cell idempotent and avoids the race between a
# separate isdir() check and the directory creation.
os.makedirs(IMIT_LEARNING_DATASET_DIR, exist_ok=True)

In [4]:
def megaalg(obs):
    """Placeholder policy that ignores the observation.

    Args:
        obs: Current observation; accepted for interface compatibility only.

    Returns:
        An integer action index drawn from NumPy's global RNG in [0, 5).
    """
    n_actions = 5
    return np.random.randint(low=0, high=n_actions)

In [10]:
class Runner(Thread):
    """Thread that drives one car through a game and records its trajectory.

    Every call to ``env.step`` is made while holding the shared ``lock``.
    When the environment reports ``done`` the collected observations, scores
    and actions are converted to NumPy arrays and pickled into
    ``IMIT_LEARNING_DATASET_DIR``.

    Args:
        car_id: Identifier of the car controlled by this thread.
        game_id: Identifier of the game; used in the output file names.
        env: Environment object exposing ``step(action, car_id)`` that returns
            a 4-tuple ``(obs, score, done, info)``.
        lock: Lock shared by all runners of the same environment.
        step_delay: Seconds to sleep between steps (default 0.5, the value
            that was previously hard-coded).
    """

    def __init__(self, car_id, game_id, env, lock, step_delay=0.5):
        super().__init__()
        self.car_id = car_id
        self.game_id = game_id
        self.env = env
        self.lock = lock
        self.step_delay = step_delay

        # Observation the next action is chosen from.
        self.prev_obs = None

        # Recorded trajectory; lists while running, NumPy arrays afterwards.
        self.obss = []
        self.scores = []
        self.actions = []

    def _locked_step(self, action):
        """Call ``env.step`` for this car while holding the shared lock."""
        # The context manager releases the lock only if it was acquired, so a
        # failure can never trigger "release unlocked lock".
        with self.lock:
            return self.env.step(action, self.car_id)

    def _dump(self, suffix, data):
        """Pickle ``data`` to ``game_<game_id>_car_<car_id>_<suffix>.pkl``."""
        path = os.path.join(
            IMIT_LEARNING_DATASET_DIR,
            f"game_{self.game_id}_car_{self.car_id}_{suffix}.pkl",
        )
        # pickle.dump writes into the open file; pickle.dumps would only
        # return the bytes and leave the file empty.
        with open(path, "wb") as fh:
            pickle.dump(data, fh)

    def run(self):
        """Play until the environment reports ``done``, then save the data.

        Each recorded sample pairs the observation an action was chosen from
        with that action and the score returned by the step. The final step
        (the one that returns ``done``) is not recorded.
        """
        # Need to do some initial action to fetch observations
        # NOTE(review): ``done`` from this first step is ignored, as before.
        obs, score, done, _ = self._locked_step(1)
        self.prev_obs = obs

        while True:
            try:
                action = megaalg(self.prev_obs)
                obs, score, done, _ = self._locked_step(action)
            except Exception as ex:
                # Best effort: report and retry from the same observation
                # instead of recording stale obs/score/done values.
                print(f"{self.car_id}: {ex}")
                sleep(self.step_delay)
                continue

            print(score, done)
            if done:
                break

            self.obss.append(self.prev_obs)
            self.scores.append(score)
            self.actions.append(action)

            self.prev_obs = obs
            sleep(self.step_delay)

        self.obss = np.array(self.obss)
        self.scores = np.array(self.scores)
        self.actions = np.array(self.actions)

        self._dump("obs", self.obss)
        self._dump("scores", self.scores)
        self._dump("actions", self.actions)
        print(f"{self.car_id} finished")

In [9]:
if __name__ == "__main__":
    # Driver: play N_GAMES games, one Runner thread per car, all threads
    # sharing a single environment and a single lock.
    print("In main thread")
    client = Client(team_name=team_name, team_key=team_key)
    env = JunctionEnvironment(client)

    lock = Lock()

    for i in range(N_GAMES):
        print("Running game", i)
        # Random identifier used only to label this game's output files.
        game_id = np.random.randint(0, 100000)
        env.reset()

        # Build every runner first, then start them all, then wait for all.
        processes = [Runner(car_id, game_id, env, lock) for car_id in env.car_ids]

        for runner in processes:
            runner.start()

        for runner in processes:
            runner.join()
        print(f"Game {i} finished")

In main thread
Running game 0
0 is runnning
1 is runnning
2 is runnning
-5 False
-5 False
-5 False
-3 False
-3 False
-3 False
-6 False
-6 False
-6 False
-8 False
-10 False
-10 False
-10 False
-10 False
-10 False
-13 False
-13 False
-13 False
-16 False
-16 False
-16 False
-19 False
-19 False
-19 False
-21 False
-21 False
-21 False
-24 False
-24 False
-24 False
-27 False
-27 False
-27 False
-29 False
-29 False
-29 False
-33 False
-33 False
-33 False
-35 False
-35 False
-35 False
-38 False
-38 False
-38 False
-41 False
-41 False
-41 False
-43 False
-43 False
-43 False
-45 False
-45 False
-45 False
-47 False
-47 False
-47 False
-48 False
-48 False
-48 False
-50 False
-50 False
-50 False
-53 False
-53 False
-53 False
-54 False
-54 False
-54 False
-56 False
-56 False
-56 False
-59 False
-59 False
-59 False
-62 False
-62 False
-62 False
-65 False
-65 False
-65 False
-67 False
-67 False
-67 False
-69 False
-69 False
-69 False
-71 False
-71 False
-71 False
-74 False
-76 False
-76 False
-77 Fals