In [1]:
import ray
import time
import numpy as np

# Start (or attach to) a local Ray instance. ignore_reinit_error makes
# re-running this cell a no-op instead of an error when Ray is already
# initialized in this kernel.
ray.init(ignore_reinit_error=True)

2023-08-05 17:13:10,552	INFO worker.py:1627 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265


0,1
Python version:,3.8.10
Ray version:,2.5.1
Dashboard:,http://127.0.0.1:8265


In [2]:
@ray.remote
class ReplayBuffer(object):
    """Ray actor holding a list of samples shared by producers and a trainer.

    Producers append with ``add_samples``. The trainer draws a random batch
    with ``get_samples`` and writes the modified values back with
    ``update_samples``. ``do_train_step`` is True between those two calls,
    so producers can poll ``get_do_train_step`` and hold off while a batch
    is being processed.
    """

    def __init__(self):
        # Per-instance state. Class-level attributes (especially a mutable
        # list) would be shared by every instance created in the same process.
        self.do_train_step = False
        self.l = []  # the stored samples

    def get_do_train_step(self):
        """Return True while a sampled batch is awaiting ``update_samples``."""
        # Must be read through ``self``: a bare ``do_train_step`` is not in
        # scope inside a method and raises NameError.
        return self.do_train_step

    def add_samples(self, samples):
        """Append every element of ``samples`` to the buffer."""
        print('add_samples')
        # extend() appends in place instead of copying the whole list each call.
        self.l.extend(samples)

    def get_samples(self, batch_size=32):
        """Draw ``batch_size`` positions uniformly at random, with replacement.

        Args:
            batch_size: number of indices to draw (default 32).

        Returns:
            ``(indices, samples)``, where ``indices`` is a NumPy int array and
            ``samples[k] == self.l[indices[k]]``. On an empty buffer, returns
            an empty array and an empty list and leaves ``do_train_step``
            unchanged. ``np.random.randint(0, 0, ...)`` would raise in that case.

        Side effect:
            Sets ``do_train_step`` to True when a non-empty batch is returned.
        """
        print('get_samples')
        if not self.l:
            return np.empty(0, dtype=int), []

        self.do_train_step = True
        indices = np.random.randint(0, len(self.l), batch_size)
        samples = [self.l[idx] for idx in indices]
        print(samples)

        return indices, samples

    def update_samples(self, indices, samples):
        """Write ``samples[k]`` to position ``indices[k]``, then clear ``do_train_step``."""
        print('update_samples')
        for idx, sample in zip(indices, samples):
            self.l[idx] = sample

        self.do_train_step = False

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.l)

@ray.remote
class Self_Play(object):
    """Ray actor that periodically pushes random integer samples into a ReplayBuffer."""

    def __init__(self, replay_buffer):
        # Actor handle of the shared ReplayBuffer.
        self.replay_buffer = replay_buffer

    def do_self_play(self, num_rounds=20, samples_per_round=4):
        """Submit ``num_rounds`` batches of ``samples_per_round`` random ints.

        Each round sleeps a random 0.25-1 s, then waits while the buffer
        reports a pending train step, then submits the batch via
        ``add_samples``.

        Args:
            num_rounds: number of batches to submit (default 20).
            samples_per_round: ints per batch, drawn from [0, 10**8) (default 4).
        """
        for _ in range(num_rounds):
            time.sleep(np.random.choice([0.25, 0.5, 0.75, 1]))
            # Integer upper bound (the original passed the float 1e8).
            samples = np.random.randint(0, 10**8, samples_per_round).tolist()
            # Poll until the trainer has written its batch back. The check and
            # the add below are separate actor calls, so they are not atomic
            # with respect to the trainer's get_samples.
            while ray.get(self.replay_buffer.get_do_train_step.remote()):
                time.sleep(0.01)

            # The returned ObjectRef is discarded, so errors raised by
            # add_samples are not observed here.
            self.replay_buffer.add_samples.remote(samples)

@ray.remote
class Trainer(object):
    """Ray actor that repeatedly samples a batch from a ReplayBuffer and writes back -1 for every entry."""

    def __init__(self, replay_buffer):
        # Actor handle of the shared ReplayBuffer.
        self.replay_buffer = replay_buffer

    def train(self, num_steps=20):
        """Run ``num_steps`` sample/overwrite/write-back cycles.

        Each step sleeps a random 0.25-1 s and fetches a batch with
        ``get_samples``. It then sleeps 0.1 s as stand-in training work and
        sends ``-1`` for every sampled position back through ``update_samples``.

        Args:
            num_steps: number of training steps (default 20).
        """
        for _ in range(num_steps):
            time.sleep(np.random.choice([0.25, 0.5, 0.75, 1]))
            indices, samples = ray.get(self.replay_buffer.get_samples.remote())
            time.sleep(0.1)
            # Replace every sampled value with -1 (no loop variable that
            # shadows the outer loop index).
            updated = [-1] * len(samples)
            self.replay_buffer.update_samples.remote(indices, updated)

replay_buffer = ReplayBuffer.remote()
trainer = Trainer.remote(replay_buffer)

# Start two producers. Keep the ObjectRefs so any exception raised inside the
# actors is re-raised by ray.get below instead of being silently dropped.
self_plays = [Self_Play.remote(replay_buffer) for _ in range(2)]
self_play_refs = [x.do_self_play.remote() for x in self_plays]

# Give the producers a head start before training begins: get_samples needs
# samples in the buffer to draw from.
time.sleep(2)
train_ref = trainer.train.remote()

# Report the buffer size once a second while the actors run.
for _ in range(10):
    time.sleep(1)
    print(ray.get(replay_buffer.__len__.remote()))

# Block until every actor task has finished; raises if any of them failed.
ray.get(self_play_refs + [train_ref]);

[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m get_samples
[2m[36m(ReplayBuffer pid=944133)[0m [9696367, 11486194, 1064619, 11486194, 11486194, 17617759, 73285423, 49197508, 39293829, 9696367, 1064619, 11486194, 6789660, 85345263, 6050671, 84890281, 73285423, 49197508, 84890281, 73285423, 49197508, 84890281, 17617759, 31871502, 39293829, 85345263, 9696367, 31871502, 85345263, 15445187, 85345263, 39293829]
[2m[36m(ReplayBuffer pid=944133)[0m update_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m get_samples
[2m[36m(ReplayBuffer pid=944133)[0m [-1, -1, -1, 11701001, 33039464, -1, 59863985, 33039464, 59863985, -1, -1, -1, -1, -1, 51597109, -1, -1, -1, 25376950, -1, 54031411, -1, 25376950,

In [3]:
# Stop the local Ray instance started in the first cell. Left commented out
# because the actors launched above may still be running when this cell is
# reached (their log output continues below). Run manually once they finish.
# ray.shutdown()

[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m get_samples
[2m[36m(ReplayBuffer pid=944133)[0m [80046589, -1, 3269778, 37767792, -1, -1, 43482455, -1, -1, 44866545, 97039299, 44735992, -1, -1, -1, -1, 43482455, -1, -1, -1, -1, -1, -1, -1, -1, 43482455, -1, 37767792, -1, -1, -1, -1]
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m update_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer pid=944133)[0m get_samples
[2m[36m(ReplayBuffer pid=944133)[0m [-1, -1, -1, -1, 12085194, 59920966, -1, -1, -1, -1, 65627367, 59920966, -1, -1, -1, -1, 40960733, 80163132, -1, -1, -1, 71640239, -1, -1, -1, 76663627, -1, -1, 86937213, 91669807, 59920966, -1]
[2m[36m(ReplayBuffer pid=944133)[0m update_samples
[2m[36m(ReplayBuffer pid=944133)[0m add_samples
[2m[36m(ReplayBuffer 