# Reinforcement Learning Agent and Environment

## Import and configure dependencies

In [1]:
import gymnasium as gym
from gymnasium import Env
from gymnasium.spaces import Box
import numpy as np
import random
import os
from PIL import Image
from PIL import ImageEnhance

import torch
import sys
sys.modules["gym"] = gym
from stable_baselines3 import SAC

In [3]:
# Directory of distorted training images (relative to the notebook's working
# directory). DistortionEnv samples episode images from here; the reward in
# DistortionEnv.step also opens a same-named, undistorted file in the parent
# directory "../../GTSDB/images".
train_image_path = "../../GTSDB/images/distorted"

## Build Environment

Action: The agent uses continuous values to adjust each image parameter (sharpness, brightness, contrast, color). For now, the action space is simplified to a single brightness factor in the range [0, 2].

In [29]:
class DistortionEnv(Env):
    """Gymnasium environment in which an agent learns to undo image distortions.

    Each episode draws a random image from the module-level ``train_image_path``
    directory. The observation is that image resized to 128x128 RGB
    (``uint8`` array of shape ``(128, 128, 3)``). The action is a single
    brightness factor in ``[0.0, 2.0]`` applied with
    ``PIL.ImageEnhance.Brightness``. The reward is ``+1`` if the enhanced image
    has a lower grayscale MSE to the undistorted original than the distorted
    image has, otherwise ``-1``.
    """

    # Directory with the undistorted originals. Files share their names with
    # the distorted copies in ``train_image_path``.
    original_image_path = "../../GTSDB/images"

    @staticmethod
    def _load_and_convert_image(image_path):
        """Load an image file as a ``uint8`` RGB array of shape (128, 128, 3).

        Any non-RGB mode (RGBA, grayscale, palette, ...) is converted to RGB so
        the result always matches ``observation_space``.
        """
        # ``with`` closes the file handle PIL keeps open for lazy loading;
        # ``resize`` returns a new, fully loaded image detached from the file.
        with Image.open(image_path) as source:
            image = source.resize((128, 128))
        if image.mode != "RGB":
            image = image.convert("RGB")
        return np.asarray(image)

    @staticmethod
    def _calculate_mse(image1, image2):
        """Return the mean squared error between two PIL images.

        Both images are resized to 128x128 and converted to 8-bit grayscale
        ("L") before the pixel-wise comparison.

        Args:
            image1: First ``PIL.Image.Image``.
            image2: Second ``PIL.Image.Image``.

        Returns:
            Mean of the squared per-pixel grayscale differences, as a float.
        """
        arr1 = np.asarray(image1.resize((128, 128)).convert("L"), dtype=np.float64)
        arr2 = np.asarray(image2.resize((128, 128)).convert("L"), dtype=np.float64)
        # Work in float: on the raw uint8 arrays ``arr1 - arr2`` (and its
        # square) wraps modulo 256, which makes the MSE meaningless.
        return float(np.mean((arr1 - arr2) ** 2))

    def __init__(self):
        """Load the detector, define the action/observation spaces and index images.

        Raises:
            ValueError: If ``train_image_path`` contains no files.
        """
        # YOLOv5 detector intended for a future detection-based reward (see the
        # TODO in ``step``). Kept on the instance instead of a discarded local
        # so the (slow) torch.hub load is not wasted.
        self.detector = torch.hub.load('ultralytics/yolov5', 'custom', path='./models/YOLOv5_best_1000ep.pt')

        # Enhancement factor bounds: per PIL.ImageEnhance, 1.0 leaves the image
        # unchanged, 0.0 gives the fully degenerate image.
        brightness_bounds = (0.0, 2.0)

        # TODO: Simplified to only learn adjusting the brightness. Later extend
        # to sharpness, contrast, brightness and color (shape=(4,)) with the
        # same (0.0, 2.0) bounds each.
        self.action_space = Box(low=brightness_bounds[0],
                                high=brightness_bounds[1],
                                shape=(1,),
                                dtype=np.float32)

        # Observation: the current image as (height, width, channels) uint8.
        self.observation_space = Box(low=0, high=255, shape=(128, 128, 3), dtype=np.uint8)

        # Sorted so that sampling with a seeded RNG is reproducible regardless
        # of the order the filesystem lists entries; subdirectories skipped.
        self.train_images = [
            name for name in sorted(os.listdir(train_image_path))
            if os.path.isfile(os.path.join(train_image_path, name))
        ]
        if not self.train_images:
            raise ValueError(f"No training images found in {train_image_path!r}")

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a randomly chosen distorted image.

        Args:
            seed: Optional seed for ``self.np_random`` (gymnasium convention).
            options: Unused; accepted for gymnasium API compatibility.

        Returns:
            ``(observation, info)`` with the (128, 128, 3) uint8 image and an
            empty info dict.
        """
        # Seeds self.np_random when a seed is given (gymnasium Env contract).
        super().reset(seed=seed)

        # TODO: Set duration? e.g. 10 consecutive actions possible, maybe should start with only 1
        self.remaining_actions = 1

        # Sample with the env's RNG so ``reset(seed=...)`` is reproducible.
        index = int(self.np_random.integers(len(self.train_images)))
        self.image_name = self.train_images[index]
        image_path = os.path.join(train_image_path, self.image_name)
        self.state = self._load_and_convert_image(image_path)

        return self.state, {}

    def step(self, action):
        """Apply a brightness factor to the current image and score the result.

        Args:
            action: Array-like holding one brightness factor (shape ``(1,)`` as
                produced by ``action_space``); a bare scalar is accepted too.
                Values are clipped to the action-space bounds.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` per the
            gymnasium API. ``reward`` is +1 if the enhanced image is closer
            (grayscale MSE) to the original than the distorted image, else -1.
            ``truncated`` is always False and ``info`` is empty.
        """
        self.remaining_actions -= 1

        # PIL needs a Python float; converting a shape-(1,) array to a scalar
        # implicitly is deprecated in NumPy.
        factor = float(np.clip(np.asarray(action, dtype=np.float64).reshape(-1)[0],
                               self.action_space.low[0],
                               self.action_space.high[0]))
        enhanced_image = ImageEnhance.Brightness(Image.fromarray(self.state)).enhance(factor)
        self.state = np.asarray(enhanced_image)

        # TODO: Reward calculation other than MSE! only for test purpose
        # Later according to YOLOv5 network results
        original_path = os.path.join(self.original_image_path, self.image_name)
        distorted_path = os.path.join(train_image_path, self.image_name)
        with Image.open(original_path) as original_image:
            with Image.open(distorted_path) as distorted_image:
                mse_distorted = self._calculate_mse(original_image, distorted_image)
            mse_enhanced = self._calculate_mse(original_image, enhanced_image)

        reward = 1 if mse_enhanced < mse_distorted else -1
        terminated = self.remaining_actions <= 0

        return self.state, reward, terminated, False, {}

    def render(self):
        """Rendering is not implemented; this is a no-op."""
        pass

In [30]:
# Instantiating the environment loads the YOLOv5 checkpoint via torch.hub and
# lists the training images in train_image_path, so this cell may be slow.
env = DistortionEnv()

In [31]:
# Sanity check: run a few episodes with uniformly random actions to confirm
# reset/step work end to end and to get a baseline score for comparison.
episodes = 10
for episode in range(1, episodes + 1):
    # gymnasium's reset returns (observation, info); the observation is unused here.
    state, info = env.reset()
    done = False
    score = 0

    while not done:
        action = env.action_space.sample()
        n_state, reward, terminated, truncated, info = env.step(action)
        # Per the gymnasium API an episode ends on termination OR truncation.
        done = terminated or truncated
        score += reward
    print(f"Episode:{episode} Score:{score}")

Episode:1 Score:-1
Episode:2 Score:-1
Episode:3 Score:1
Episode:4 Score:1
Episode:5 Score:1
Episode:6 Score:-1
Episode:7 Score:1
Episode:8 Score:-1
Episode:9 Score:-1
Episode:10 Score:-1


In [32]:
# Train a Soft Actor-Critic agent on the distortion environment.
# NOTE(review): observations are 128x128x3 images; SB3 recommends "CnnPolicy"
# for image inputs — "MlpPolicy" flattens them. Confirm this is intentional.
model = SAC(
    policy="MlpPolicy",
    env=env,
    buffer_size=10000,
    verbose=1,
)
model.learn(log_interval=4, total_timesteps=10000)
# NOTE(review): filename looks left over from the SB3 Pendulum example; this
# checkpoint is the distortion agent — consider renaming.
model.save("sac_pendulum")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -0.5     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 6        |
|    time_elapsed    | 0        |
|    total_timesteps | 4        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | 0        |
| time/              |          |
|    episodes        | 8        |
|    fps             | 6        |
|    time_elapsed    | 1        |
|    total_timesteps | 8        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | 0.167    |
| time/              |          |
|    episodes  

KeyboardInterrupt: 