-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rework starting (easier now) and add REST endpoints due to a T…
…CP issue
- Loading branch information
1 parent
f984831
commit 9f7e103
Showing
50 changed files
with
1,045 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
import sys | ||
import requests | ||
|
||
from datetime import datetime | ||
|
||
import logging

# Root logging configuration: INFO and above, printing only the message text
# (no timestamp, level, or logger-name prefix).
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Module-wide logger instance.
logger = logging.getLogger('MyLogger')
|
||
class Client:
    """Minimal REST client for a remote simulation server.

    Every operation is an HTTP POST to
    ``http://{host}:{port}/v1.0/invoke/{serverId}/method/...``.
    NOTE(review): this looks like the Dapr service-invocation URL layout,
    with ``serverId`` being the target app id - confirm against the
    deployment.

    Typical usage: construct the client, call :meth:`Init` with an
    environment id (this creates a remote instance and stores its id), then
    drive the instance with :meth:`Reset` / :meth:`Step`.
    """

    def __init__(self, host, port, serverId):
        """Remember the connection details and build the base invoke URL.

        Args:
            host: Host name of the HTTP endpoint.
            port: Port of the HTTP endpoint (only interpolated into the URL).
            serverId: Identifier of the server application to invoke.
        """
        self.host = host
        self.port = port
        self.serverId = serverId
        self.url = f"http://{host}:{port}/v1.0/invoke/{serverId}"

    def _post(self, path, payload=None):
        """POST ``payload`` (as JSON, if given) to ``{url}/method/{path}``.

        Returns the raw ``requests`` response without inspecting its status.
        """
        return requests.post(f"{self.url}/method/{path}", json=payload)

    def _post_json(self, path, payload=None):
        """POST like :meth:`_post` and return the decoded JSON body.

        Raises:
            requests.HTTPError: if the server answered with a 4xx/5xx status.
                Checking first gives a clear error instead of a JSON-decode
                failure or a ``KeyError`` further down the line.
        """
        response = self._post(path, payload)
        response.raise_for_status()
        return response.json()

    def _instance_path(self, operation):
        """Return the method path of ``operation`` for the created instance."""
        return f"{self.instanceId}/{operation}"

    def Init(self, simId):
        """Store the environment id and create a remote instance for it."""
        self.simId = simId
        self.Create()

    def Create(self):
        """Create a remote instance of ``self.simId`` and store its id.

        The ``instanceId`` field of the response is kept in
        ``self.instanceId`` and used by all per-instance calls.
        """
        created = self._post_json("create", {"envId": self.simId})
        self.instanceId = created["instanceId"]

    def Reset(self):
        """Call the instance's ``reset`` method; return the decoded JSON."""
        return self._post_json(self._instance_path("reset"))

    def ActionSpaceSample(self):
        """Call ``action-space-sample``; return the decoded JSON."""
        return self._post_json(self._instance_path("action-space-sample"))

    def ActionSpaceInfo(self):
        """Call ``action-space-info``; return the decoded JSON."""
        return self._post_json(self._instance_path("action-space-info"))

    def ObservationSpaceInfo(self):
        """Call ``observation-space-info``; return the decoded JSON."""
        return self._post_json(self._instance_path("observation-space-info"))

    def Step(self, action):
        """Send ``action`` to the instance's ``step`` method.

        Args:
            action: JSON-serializable action, sent as ``{"action": action}``.

        Returns:
            The decoded JSON body of the response.
        """
        return self._post_json(self._instance_path("step"), {"action": action})

    def MonitorStart(self):
        """Call ``monitor-start``; return the raw response (status unchecked)."""
        return self._post(self._instance_path("monitor-start"))

    def MonitorStop(self):
        """Call ``monitor-stop``; return the raw response (status unchecked)."""
        return self._post(self._instance_path("monitor-stop"))
|
||
# def DebugSlow(self): | ||
# req = roadwork_messages.BaseRequest(instanceId=self.instanceId) | ||
# res = self.DaprInvoke("debug", req, roadwork_messages.BaseResponse) | ||
# return res | ||
|
||
# def DebugFast(self): | ||
# data = Any(value='ACTION 1'.encode('utf-8')) | ||
# print(f"[Client][DaprInvoke][debug] Creating Envelope {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}") | ||
# envelope = dapr_messages.InvokeServiceEnvelope(id=self.simId, method="debug", data=data) | ||
# print(f"[Client][DaprInvoke][debug] Creating Envelope 2 {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}") | ||
# res = self.client.InvokeService(envelope) | ||
# print(f"[Client][DaprInvoke][debug] Creating Envelope 3 {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}") | ||
# return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Pin the interpreter instead of tracking `latest`, so the image keeps building
# with the version-pinned packages in requirements.txt.
# NOTE(review): 3.7 was chosen to match 2019/2020-era pins (e.g. numpy==1.18.4);
# confirm against the requirements.txt that sits next to this Dockerfile.
FROM python:3.7

WORKDIR /app

# Download dependencies (copied on their own first so this layer stays cached
# until requirements.txt changes). COPY is preferred over ADD for local files.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy Source Code
COPY . .

# Main Entry
CMD [ "python", "main.py" ]
30 changes: 30 additions & 0 deletions
30
src-rest/Clients/python/experiments/cartpole/kubernetes.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# A job will run and automatically create & delete a pod | ||
apiVersion: batch/v1 | ||
kind: Job | ||
metadata: | ||
name: rw-client-python-cartpole | ||
spec: | ||
ttlSecondsAfterFinished: 120 # How long do we keep the pod after finishing? | ||
backoffLimit: 1 # Restart once on failure | ||
template: | ||
metadata: | ||
labels: # Labels for the POD | ||
app: rw-client-python-cartpole # This way we can filter with kubectl get pods -l app=... | ||
name: rw-client-python-cartpole # pod-roadwork-... | ||
annotations: | ||
dapr.io/enabled: "true" # Do we inject a sidecar to this deployment? | ||
dapr.io/id: "id-rw-client-python-cartpole" # Unique ID or Name for Dapr App (so we can communicate with it) | ||
dapr.io/protocol: "grpc" | ||
spec: | ||
restartPolicy: OnFailure # OnFailure since sometimes we fail to identify the GRPC Channel | ||
containers: | ||
- name: client # Name of our container, e.g. `kubectl logs -c c-rw-...` | ||
image: roadwork.io/rw-client-python-cartpole:latest | ||
imagePullPolicy: IfNotPresent # Production: use Always (valid values: Always, IfNotPresent, Never) | ||
env: | ||
- name: PYTHONUNBUFFERED | ||
value: "1" | ||
# - name: GRPC_VERBOSITY | ||
# value: "DEBUG" | ||
# - name: GRPC_TRACE | ||
# value: "api,channel,call_error,connectivity_state,http,server_channel" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# import gym | ||
import numpy as np | ||
import math | ||
from collections import deque | ||
|
||
import sys | ||
import os | ||
from Client import Client | ||
import time | ||
|
||
# Sidecar ports, taken from the environment when set (then they are strings),
# otherwise the integer defaults below.
DAPR_HTTP_PORT = os.environ.get("DAPR_HTTP_PORT", 3500)
DAPR_GRPC_PORT = os.environ.get("DAPR_GRPC_PORT", 50001)

# Startup banner showing which ports were picked up.
print(
    "==================================================\n"
    f"DAPR_PORT_GRPC: {DAPR_GRPC_PORT}; DAPR_PORT_HTTP: {DAPR_HTTP_PORT}\n"
    "=================================================="
)
|
||
class QCartPoleSolver():
    """Tabular Q-learning agent for CartPole, driven through the REST ``Client``.

    Continuous observations are bucketed into a small discrete grid
    (``buckets``) and a Q-table of shape ``buckets + (n_actions,)`` is learned
    with an epsilon-greedy policy whose learning and exploration rates decay
    logarithmically with the episode number.
    """

    def __init__(self, buckets=(1, 1, 6, 12,), n_episodes=1000, n_win_ticks=195, min_alpha=0.1, min_epsilon=0.1, gamma=1.0, ada_divisor=25, max_env_steps=None, quiet=False, monitor=False):
        """Create the remote environment and an all-zero Q-table.

        Args:
            buckets: Number of discrete buckets per observation dimension.
            n_episodes: Maximum number of training episodes.
            n_win_ticks: Mean episode length (over the last 100 episodes)
                required to count as solved.
            min_alpha: Lower bound of the learning rate.
            min_epsilon: Lower bound of the exploration rate.
            gamma: Discount factor.
            ada_divisor: Divisor in the decay schedules (development knob).
            max_env_steps: If given, stored on the client object as
                ``_max_episode_steps``. Nothing in this class sends it to the
                server.
            quiet: Suppress progress output when true.
            monitor: Accepted for interface compatibility; not used here.
        """
        self.buckets = buckets  # down-scaling feature space to discrete range
        self.n_episodes = n_episodes  # training episodes
        self.n_win_ticks = n_win_ticks  # average ticks over 100 episodes required for win
        self.min_alpha = min_alpha  # learning rate
        self.min_epsilon = min_epsilon  # exploration rate
        self.gamma = gamma  # discount factor
        self.ada_divisor = ada_divisor  # only for development purposes
        self.quiet = quiet

        self.env = Client("localhost", DAPR_HTTP_PORT, "id-rw-server-openai")
        self.env.Init("CartPole-v0")

        if max_env_steps is not None:
            self.env._max_episode_steps = max_env_steps

        action_space = self.env.ActionSpaceInfo()
        action_space_name = action_space["name"]
        action_count = action_space["n"]

        print(f"Action Space (Name= {action_space_name}; N={action_count})")
        self.Q = np.zeros(self.buckets + (action_count,))

        # Observation-space description, fetched lazily on the first
        # discretize() call and then reused.
        self._obs_space = None

    def discretize(self, obs):
        """Map a continuous observation onto a tuple of bucket indices.

        The observation-space info is requested from the server once and
        cached; previously this cost one HTTP round trip per environment step.
        Assumes the space description does not change during the lifetime of
        the instance.

        Args:
            obs: Sequence of observation values, indexed per dimension.

        Returns:
            Tuple of ints, one per dimension, each within
            ``[0, buckets[i] - 1]``.
        """
        if self._obs_space is None:
            self._obs_space = self.env.ObservationSpaceInfo()
        space = self._obs_space

        # Dimensions 1 and 3 use hand-picked limits instead of the reported
        # bounds; dimensions 0 and 2 use the bounds the server reports.
        upper_bounds = [space["high"][0], 0.5, space["high"][2], math.radians(50)]
        lower_bounds = [space["low"][0], -0.5, space["low"][2], -math.radians(50)]

        state = []
        for i in range(space["shape"][0]):
            ratio = (obs[i] + abs(lower_bounds[i])) / (upper_bounds[i] - lower_bounds[i])
            bucket = int(round((self.buckets[i] - 1) * ratio))
            # Clamp values outside the bounds into the first/last bucket.
            state.append(min(self.buckets[i] - 1, max(0, bucket)))
        return tuple(state)

    def choose_action(self, state, epsilon):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a sample is requested from the server;
        otherwise the greedy action from the Q-table is returned as a plain
        ``int``. ``np.argmax`` yields a NumPy integer, which the JSON encoder
        used when the action is posted to the server rejects.
        """
        if np.random.random() <= epsilon:
            return self.env.ActionSpaceSample()
        return int(np.argmax(self.Q[state]))

    def update_q(self, state_old, action, reward, state_new, alpha):
        """Apply one Q-learning update for the observed transition."""
        target = reward + self.gamma * np.max(self.Q[state_new])
        self.Q[state_old][action] += alpha * (target - self.Q[state_old][action])

    def get_epsilon(self, t):
        """Exploration rate for episode ``t``: log decay clamped to [min_epsilon, 1]."""
        return max(self.min_epsilon, min(1, 1.0 - math.log10((t + 1) / self.ada_divisor)))

    def get_alpha(self, t):
        """Learning rate for episode ``t``: log decay clamped to [min_alpha, 1.0]."""
        return max(self.min_alpha, min(1.0, 1.0 - math.log10((t + 1) / self.ada_divisor)))

    def run(self):
        """Train until solved or until ``n_episodes`` episodes have run.

        Returns:
            ``e - 100`` when the mean length of the last 100 episodes reaches
            ``n_win_ticks`` at episode index ``e >= 100``; otherwise the index
            of the last episode run (0 if no episode ran).
        """
        scores = deque(maxlen=100)
        e = 0  # keeps the final message/return defined when n_episodes == 0
        self.env.MonitorStart()

        # try/finally so the monitor is stopped on the early "solved" return
        # and on errors, not only when every episode has been used up.
        try:
            for e in range(self.n_episodes):
                current_state = self.discretize(self.env.Reset())

                alpha = self.get_alpha(e)
                epsilon = self.get_epsilon(e)
                done = False
                i = 0

                while not done:
                    action = self.choose_action(current_state, epsilon)

                    stepResponse = self.env.Step(action)
                    obs = stepResponse["obs"]
                    reward = stepResponse["reward"]
                    done = stepResponse["isDone"]

                    new_state = self.discretize(obs)
                    self.update_q(current_state, action, reward, new_state, alpha)
                    current_state = new_state
                    i += 1

                scores.append(i)
                mean_score = np.mean(scores)

                if mean_score >= self.n_win_ticks and e >= 100:
                    if not self.quiet:
                        print('Ran {} episodes. Solved after {} trials ✔'.format(e, e - 100))
                    return e - 100

                if e % 10 == 0 and not self.quiet:
                    print('[Episode {}] - Mean survival time over last 100 episodes was {} ticks.'.format(e, mean_score))

            if not self.quiet:
                print('Did not solve after {} episodes 😞'.format(e))
        finally:
            self.env.MonitorStop()

        return e
|
||
if __name__ == "__main__":
    # Script entry point: build a solver with default hyper-parameters and
    # train it against the remote environment.
    QCartPoleSolver().run()
    # gym.upload('tmp/cartpole-1', api_key='')
5 changes: 5 additions & 0 deletions
5
src-rest/Clients/python/experiments/cartpole/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
grpcio==1.25.0 | ||
grpcio-tools==1.25.0 | ||
protobuf==3.10.0 | ||
requests==2.23.0 | ||
numpy==1.18.4 |
Oops, something went wrong.