In [18]:
!apt-get -y install libglu1-mesa-dev
!pip install stable-baselines tensorflow-cpu

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  libbsd0 libdrm-amdgpu1 libdrm-common libdrm-intel1 libdrm-nouveau2
  libdrm-radeon1 libdrm2 libedit2 libegl-dev libegl-mesa0 libegl1 libelf1
  libgbm1 libgl-dev libgl1 libgl1-mesa-dev libgl1-mesa-dri libglapi-mesa
  libgles-dev libgles1 libgles2 libglu1-mesa libglvnd-dev libglvnd0 libglx-dev
  libglx-mesa0 libglx0 libllvm12 libopengl-dev libopengl0 libpciaccess0
  libpthread-stubs0-dev libsensors-config libsensors5 libvulkan1
  libwayland-client0 libwayland-server0 libx11-6 libx11-data libx11-dev
  libx11-xcb1 libxau-dev libxau6 libxcb-dri2-0 libxcb-dri3-0 libxcb-glx0
  libxcb-present0 libxcb-randr0 libxcb-shm0 libxcb-sync1 libxcb-xfixes0
  libxcb1 libxcb1-dev libxdmcp-dev libxdmcp6 libxext6 libxfixes3 libxshmfence1
  libxxf86vm1 mesa-vulkan-drivers x11proto-core-dev x11proto-dev
  xorg-sgml-doctools xtrans-dev
Suggested packages:
  pciu

In [21]:
import gym

import numpy as np

from gym import spaces

class GoniometerEnvironment(gym.Env):
    '''
    An environment for goniometer sample rotation.

    The goniometer orientation is three angles (x, y, z) in degrees, each kept
    on a 0-359 circle. A hidden peak orientation is drawn at random; every
    action rotates one axis by STEP_SIZE degrees and the episode ends once all
    three angles are within STEP_SIZE degrees of the peak.

    Observations are simulated detector images: uint8 arrays of shape
    (IMAGE_SIZE, IMAGE_SIZE, 1). The brighter half of the image (left/right
    and top/bottom) points towards the peak along x and y, and the overall
    brightness rises slightly as z approaches the peak. The reward of a step
    is the mean pixel value of the image it produced.
    '''

    metadata = {'render.modes': ['human']}

    # Degrees moved by a single action
    STEP_SIZE = 5

    # Angles wrap around after one full turn
    FULL_CIRCLE = 360

    # Width and height of the square detector image, in pixels
    IMAGE_SIZE = 50

    def __init__(self):
        '''
        Set up the action/observation spaces and an initial goniometer state.
        '''
        super().__init__()

        # Six actions: down/up along the x/y/z axis
        # (0: x-, 1: x+, 2: y-, 3: y+, 4: z-, 5: z+)
        self.action_space = spaces.Discrete(6)

        # A 50 x 50 grayscale image is produced. The trailing channel axis
        # makes it a (height, width, channels) image, the layout image-based
        # policies expect.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.IMAGE_SIZE, self.IMAGE_SIZE, 1),
            dtype=np.uint8,
        )

        self.current_angle = [0, 0, 0]

        # Draw a peak right away so step() and make_image() also work if they
        # are called before the first reset()
        self._peak_location = self._random_peak()

    def step(self, action):
        '''
        Rotate the goniometer by one step and take a new detector image.

        Parameters
        ----------
        action : int
            Index into the six actions. Even values decrease and odd values
            increase the angle of axis ``action // 2``. Values outside the
            action space leave the angles unchanged.

        Returns
        -------
        tuple
            ``(image, reward, found, info)``: the detector image, its mean
            pixel value, whether every angle is within STEP_SIZE degrees of
            the peak, and an empty info dict.
        '''
        action = int(action)

        # Move the angle in the chosen direction, wrapping around the circle
        if 0 <= action < self.action_space.n:
            axis, increase = divmod(action, 2)
            delta = self.STEP_SIZE if increase else -self.STEP_SIZE
            self.current_angle[axis] = (self.current_angle[axis] + delta) % self.FULL_CIRCLE

        # Produce the current image
        image = self.make_image()

        # If within step distance of the true value on every axis, halt.
        # The distance is measured around the circle so that e.g. 358 and 0
        # count as two degrees apart.
        found = all(
            self._circular_distance(peak, angle) < self.STEP_SIZE
            for peak, angle in zip(self._peak_location, self.current_angle)
        )

        # Return the detector image, the reward, whether to halt, and additional parameters
        return image, float(np.mean(image)), found, {}

    def reset(self):
        '''
        Move the peak to a new random location and return the first image.

        The current goniometer angles are left where they are.
        '''
        self._peak_location = self._random_peak()
        return self.make_image()

    def render(self, mode='human'):
        '''
        Rendering is not implemented; this is a no-op.
        '''
        pass

    def close(self):
        '''
        Nothing to release; this is a no-op.
        '''
        pass

    def make_image(self):
        '''
        Simulate the detector image for the current angles.

        Returns
        -------
        numpy.ndarray
            uint8 array of shape (IMAGE_SIZE, IMAGE_SIZE, 1). Every pixel is
            uniform noise in [0, 50) plus up to 1 for closeness along z; the
            half of the image facing the peak gets +50 along each of x and y.
        '''
        size = self.IMAGE_SIZE
        half = size // 2

        # Determine whether the peak is right/left of and above/below current
        right_higher = self._increase_is_shorter(0)
        top_higher = self._increase_is_shorter(1)

        # Measure how far off the peak is along Z (0 to half a turn) and turn
        # it into a closeness score: 1 on the peak, 0 on the opposite side
        z_distance = self._circular_distance(self._peak_location[2], self.current_angle[2])
        closeness = 1.0 - z_distance / (self.FULL_CIRCLE / 2)

        # The image is random noise, but brighter the closer Z is
        image = np.random.uniform(0, 50, size=(size, size)) + closeness

        # Make the image brighter on the side pointing towards the peak:
        # rows for top/bottom, columns for left/right
        if top_higher:
            image[:half, :] += 50
        else:
            image[half:, :] += 50

        if right_higher:
            image[:, half:] += 50
        else:
            image[:, :half] += 50

        # Pixel values stay below 151, so the cast to uint8 cannot overflow
        return image.astype(np.uint8)[:, :, np.newaxis]

    def _random_peak(self):
        '''
        Draw a random peak location: one integer angle in [0, 360) per axis.
        '''
        return [int(np.random.randint(0, self.FULL_CIRCLE)) for _ in range(3)]

    def _circular_distance(self, first, second):
        '''
        Shortest distance in degrees between two angles on the circle.
        '''
        difference = (first - second) % self.FULL_CIRCLE
        return min(difference, self.FULL_CIRCLE - difference)

    def _increase_is_shorter(self, axis):
        '''
        Whether increasing the angle of ``axis`` is the shorter way to the peak.
        '''
        ahead = (self._peak_location[axis] - self.current_angle[axis]) % self.FULL_CIRCLE
        return ahead <= self.FULL_CIRCLE / 2
            

In [20]:
from stable_baselines import ACER

# Build the goniometer environment the agent will act on
instrument = GoniometerEnvironment()

# Set up an Actor-Critic with Experience Replay agent using the CNN policy...
model = ACER('CnnPolicy', instrument)

# ...and train it for 100 timesteps, keeping whatever learn() returns as the model
model = model.learn(total_timesteps=100)

ImportError: libGL.so.1: cannot open shared object file: No such file or directory

In [None]:
# Set the goniometer position. The angles belong to the environment, not to the
# trained model, so they are set on `instrument`.
instrument.current_angle = [180, 180, 180]

# Create a uniform test image with the shape and dtype the environment declares
# (np.ones takes the shape as a single tuple, not as separate arguments)
image = np.ones(instrument.observation_space.shape, dtype=instrument.observation_space.dtype)

# Get the model's decision on which direction to turn
model.predict(image)