In [1]:
# Importing all the libraries
import pygame as pg
import sys
import threading
import numpy as np
import rect
import copy
import random



pygame 2.5.2 (SDL 2.28.3, Python 3.12.1)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
np.random.seed(42)  # For reproducible results


# Variables for simulation window
# Variables for the width and height of the simulation window
WIDTH = 1100
HEIGHT = 700

# Variable for the width and the height of the simulation inside the main window
# Must be <= HEIGHT and WIDTH
GRID_WIDTH = WIDTH - 400
GRID_HEIGHT = HEIGHT
FPS = 100

# Initialising pygame
pg.init()
pg.display.init()
pg.display.set_caption("DeepLearningWeek")
screen = pg.display.set_mode((WIDTH, HEIGHT))
clock = pg.time.Clock()
pg.font.init()

font = pg.font.SysFont('Arial', 20)

white = (255,255,255)

In [3]:
# A class for storing the Qvlaues for each grid cell
class GridCell:
    def __init__(self):
        self.qvals = [0, 0, 0, 0, 0]

In [4]:
# A class to create the ocean environment for the simulation

class OceanEnvironment:
    def __init__(self, row, col, no_row, no_col, mode):
        self.current_row = row
        self.current_col = col

        self.max_rows = no_row - 1
        self.max_cols = no_col - 1

        self.color = 0

        if mode == 0:
            self.fish_population = self.gradient_fish_generator()
        else:
            self.fish_population = self.random_fish_generator()
        self.environment_val = 0

        self.population_history = []

    def gradient_fish_generator(self):
        dist_from_shore_y = abs(self.max_rows - self.current_row)
        dist_from_shore_x = abs(self.max_cols - self.current_col)

        diag_dist = pow(pow(dist_from_shore_x, 2) + pow(dist_from_shore_y, 2), 0.5)
        max_dist = pow(pow(self.max_rows, 2) + pow(self.max_cols, 2), 0.5)
        fish_pop = int((255 / max_dist) * diag_dist)
        if fish_pop == 0:
            return 1
        return fish_pop


    def random_fish_generator(self):
        return np.random.randint(5, 100)

In [5]:
# A class to create and handle the actor for Qlearning
class Boat:
    def __init__(self, grid):
        # position is the grid coordinates of the boat
        self.reset_pos = (len(grid) // 2, len(grid) - 1)

        self.pos = list(self.reset_pos)

        self.fuel_used = 0
        self.grid = grid

    def move_up(self):
        if self.pos[1] > 0:
            self.pos[1] -= 1
            self.render()
            return True
        return False

    def move_down(self):
        if self.pos[1] < len(self.grid) - 1:
            self.pos[1] += 1
            self.render()
            return True
        return False

    def move_left(self):
        if self.pos[0] > 0:
            self.pos[0] -= 1
            self.render()
            return True
        return False

    def move_right(self):
        if self.pos[0] < len(self.grid) - 1:
            self.pos[0] += 1
            self.render()
            return True
        return False

    def fish(self):
        decline = 10
        self.grid[self.pos[0]][self.pos[1]].fish_population -= decline

    def render(self):
        cell_width = GRID_WIDTH / len(self.grid[0])
        cell_height = GRID_HEIGHT / len(self.grid)
        rect = (cell_width * self.pos[0], cell_height * self.pos[1], cell_width, cell_height)
        pg.draw.rect(screen, white, rect)
        pg.display.update(pg.Rect(rect))

# The Q-learning Algorithm

Action Space

0 -> move up<br>
1 -> move down<br>
2 -> move left<br>
3 -> move right<br>
4 -> fish

In [6]:


# Epsilon greedy policy for choosing the action for Q-learning
def epsilon_greedy_policy(Qtable, state, epsilon):
    random_int = random.uniform(0, 1)
    if random_int > epsilon:
        action = Qtable[state[0]][state[1]].qvals.index(max(Qtable[state[0]][state[1]].qvals))
    else:
        action = random.randint(0, 4)
    return action

In [7]:
# A function to return the rewards for each action taken

def take_step(boat, action, environment_grid, avg_population):
    population_d = 70
    if action == 0:
        if boat.move_up():
            return -2
        return -100
    elif action == 1:
        if boat.move_down():
            return -1.5
        return -100
    elif action == 2:
        if boat.move_left():
            return -1
        return -100
    elif action == 3:
        if boat.move_right():
            return -1
        return -100
    elif action == 4:
        fish_population = environment_grid[boat.pos[1]][boat.pos[0]].fish_population
        if fish_population < avg_population:
            return -1*fish_population/(population_d*255)
        return (fish_population/population_d)/255

In [8]:
# The actual Q learninng algorithm

learning_rate = 0.5
gamma = 0.95


def train(training_episodes, decay_rate, max_steps, Qtable, environment_grid, boat, avg_population):

    max_epsilon = 1.0
    min_epsilon = 0.05

    for episode in range(training_episodes):

        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        # Reset the environment
        state = list(boat.reset_pos)
        copy_env_grid = copy.deepcopy(list(environment_grid))
        boat.pos = state

        # repeat
        for step in range(max_steps):

            action = epsilon_greedy_policy(Qtable, state, epsilon)
            reward = take_step(boat, action, copy_env_grid, avg_population)
            new_state = boat.pos

            state_val = Qtable[state[1]][state[0]].qvals
            Qtable[state[1]][state[0]].qvals[action] = state_val[action] + learning_rate * (
                        reward + gamma * max(Qtable[new_state[1]][new_state[0]].qvals) - state_val[action])

            if action == 4:
                copy_env_grid[state[1]][state[0]].fish_population -= 100

            # Our state is the new state
            state = new_state


In [9]:
# Function for generating the Q-table
def create_qtable(rows, columns):
    return [[GridCell() for j in range(columns)] for i in range(rows)]

In [10]:
# Function for generatig the environment grid
def create_env(rows, columns, mode):
    return [[OceanEnvironment(i, j, rows, columns, mode) for j in range(columns)] for i in range(rows)]

In [11]:
# Function to get the average fish population - used later for calculating rewards
def avg_fish_population(envGrid):
    n = len(envGrid) * len(envGrid[0])
    pop = 0
    for row in envGrid:
        for cell in row:
            pop += cell.fish_population
    return int(pop / n)

In [12]:
# FUnction to move the boat based on the action chosen by Q-learning algorithm
def move_boat(boat, action):
    if action == 0:
        boat.move_up()
    elif action == 1:
        boat.move_down()
    elif action == 2:
        boat.move_left()
    elif action == 3:
        boat.move_right()
    elif action == 4:
        boat.fish()

In [13]:
# Functions to increase the fish population after each month

def logistic_growth(initial_population):
    max_population = 255  # Carrying capacity -> the max limit of fish per cell
    inst_natural_mortality_rate = 0.99
    fertility = 1.43
    survival_rate = 1
    offsprings_produced = ((np.exp(-inst_natural_mortality_rate)+fertility)/(10-survival_rate))*initial_population
    if offsprings_produced+initial_population >= max_population:
        return max_population*(1-survival_rate/10)
    return offsprings_produced+(initial_population*(1-survival_rate/10))

def update_population(env_grid):
    for row in env_grid:
        for cell in row:
            if cell.fish_population>0:
                cell.fish_population = logistic_growth(cell.fish_population)
            cell.color = 0
    return env_grid

In [14]:
# Setting up the parameters for Q-learning
no_rows = 25
no_cols = 25

# 0 for gradient lke fish distribution
# 1 for randomg fish distribution
mode = 0

Qtable = create_qtable(no_rows, no_cols)
environment_grid = create_env(no_rows, no_cols, mode)

# Training parameters
training_episodes = 100

# Environment parameters
max_steps = 110

# Exploration parameters
decay_rate = 0.0005

In [15]:
# function to traing the Q-learning model
def train_model(boat_actor, env_grid):
    avg_population = avg_fish_population(environment_grid)
    thread = threading.Thread(target=train,
                              args=(training_episodes, decay_rate, max_steps, Qtable, tuple(env_grid), boat_actor,
                                    avg_population))
    thread.start()
    return thread

In [16]:
# Function to choose the best fishing location based on the final Q-table
def update_environment(Qtable, env_grid, color):
    sm_qvals = 0
    n = 0
    for row in Qtable:
        for cell in row:
            qv = cell.qvals
            ind = qv.index(max(qv))
            if ind == 4:
                val = qv[4]
                sm_qvals += qv[4]
                n += 1

    avg = sm_qvals / n

    cnt = 9
    for row_no, row in enumerate(env_grid):
        for cell_no, cell in enumerate(row):
            qt_cell = Qtable[row_no][cell_no].qvals
            ind = qt_cell.index(max(qt_cell))
            if ind == 4 and qt_cell[ind] > avg and env_grid[row_no][cell_no].color == 0:
                cell.fish_population = 10
                cell.color = color
                cnt += 1
            if cnt > 9:
                return

# Code for UI -> using pygame

In [17]:
# Rendering the grid on screen
def render_grid(grid):
    x_pos = 0
    y_pos = 0
    darker_blue = (50, 76, 168)
    no_rows = len(grid)
    no_cols = len(grid[0])
    cell_width = GRID_WIDTH / no_cols
    cell_height = GRID_HEIGHT / no_rows

    for i in range(no_rows):
        for j in range(no_cols):
            rect = cell_width * j, cell_height * i, cell_width, cell_height
            if grid[i][j].color != 0:
                pg.draw.rect(screen, grid[i][j].color, rect)
            elif (abs(grid[i][j].fish_population - 255)) > 255:
                pg.draw.rect(screen, (0, 0, 255), rect)
            else:
                pg.draw.rect(screen, (0, 0, abs(grid[i][j].fish_population - 255)), rect)

    for row in grid:
        pg.draw.rect(screen, darker_blue, rect=(0, y_pos, GRID_WIDTH, 1))
        y_pos += cell_height

    for column in grid[0]:
        pg.draw.rect(screen, darker_blue, (x_pos, 0, 1, GRID_HEIGHT))
        x_pos += cell_width

    pg.draw.rect(screen, darker_blue, rect=(0, y_pos, GRID_WIDTH, 1))
    pg.draw.rect(screen, darker_blue, (x_pos, 0, 1, GRID_HEIGHT))

    pg.display.update(pg.Rect(0, 0, GRID_WIDTH + 1, GRID_HEIGHT + 1))

# To render the gradient for choosing the colour of eac boat
def render_gradient(top_x, top_y):
    for row in range(255):
        for col in range(255):
            pg.draw.rect(screen, (col, row, 0), (top_x + col, top_y + row, 1, 1))


def render_boat_color(boat_ls):
    top_x, top_y = GRID_WIDTH + 20, 365
    for ind, boat_color in enumerate(boat_ls):
        pg.draw.rect(screen, boat_color, (top_x, top_y, 25, 25))
        top_y += 35
        if ind == 4:
            top_x += 150
            top_y = 365


color_picker = rect.Rect(screen, GRID_WIDTH + 20, 20, 255, 255)
render_gradient(GRID_WIDTH + 20, 20)

p_color_final = (100, 100, 100)
final_color = False
boat_ls = []

no_people = 10
no_assigned = 0
cnt = 0
running = True
confirm_press = False
boats_confirmed = False

next_button = rect.Rect(screen, GRID_WIDTH + 20, HEIGHT - 70, WIDTH - GRID_WIDTH - 40, 50)
next_text = "Next month"
next_text = font.render(next_text, True, (0, 0, 0))
nextRect = next_text.get_rect()
nextWidth = next_text.get_width()
nextHeight = next_text.get_height()
screen.blit(next_text, dest=((GRID_WIDTH+20)+((WIDTH-GRID_WIDTH-40)//2-nextWidth//2), ((HEIGHT-70)+(50//2)-nextHeight//2)))

add_boat = rect.Rect(screen, GRID_WIDTH + 20, 295, WIDTH - GRID_WIDTH - 40, 50)
add_text = "Add boat for simulation"
add_text = font.render(add_text, True, (0, 0, 0))
addRect = add_text.get_rect()
addWidth = add_text.get_width()
addHeight = add_text.get_height()
screen.blit(add_text, dest=((GRID_WIDTH+20)+((WIDTH-GRID_WIDTH-40)//2-addWidth//2), ((295)+(50//2)-addHeight//2)))

confirm_button = rect.Rect(screen, GRID_WIDTH + 20, HEIGHT - 140, WIDTH - GRID_WIDTH - 40, 50)
confirm_text = "Confirm and run simulation"
confirm_text = font.render(confirm_text, True, (0,0,0))
confirmWidth = confirm_text.get_width()
confirmHeight = confirm_text.get_height()
screen.blit(confirm_text, dest=((GRID_WIDTH+20)+((WIDTH-GRID_WIDTH-40)//2-confirmWidth//2), ((HEIGHT - 140)+(50//2)-confirmHeight//2)))

<rect(781, 574, 239, 23)>

In [18]:
while True:
    render_grid(environment_grid)
    mouse_pos = pg.mouse.get_pos()

    if color_picker.rect_dist(mouse_pos):
        p_color = (mouse_pos[0] - GRID_WIDTH - 20, mouse_pos[1] - 20, 0)
        if pg.mouse.get_pressed()[0]:
            p_color_final = p_color
            final_color = True
    else:
        p_color = (100, 100, 100)

    if final_color:
        p_color = p_color_final
        if add_boat.rect_dist(mouse_pos) and pg.mouse.get_pressed()[0] and len(boat_ls) < 10 and not boats_confirmed:
            boat_ls.append(p_color)
            final_color = False

    if len(boat_ls) > 0:
        render_boat_color(boat_ls)
    current_color = rect.Rect(screen, GRID_WIDTH + 295, 20, 85, 85, p_color)

    if len(boat_ls) > 0:
        if confirm_button.rect_dist(mouse_pos) and pg.mouse.get_pressed()[0]:
            if not confirm_press:
                running = False
                no_people = len(boat_ls)
                if no_people == no_assigned:
                    pass
                else:
                    no_assigned = 0
                cnt += 1
                confirm_press = True
                boats_confirmed = True
        else:
            confirm_press = False

    if not running and no_assigned < no_people:
        thread = train_model(Boat(Qtable), environment_grid)
        running = True
        no_assigned += 1
    if cnt > 0:
        if running and not thread.is_alive():
            update_environment(Qtable, environment_grid, boat_ls[no_assigned - 1])
            Qtable = create_qtable(no_rows, no_cols)
            running = False
    if not running and no_assigned>= no_people and cnt>0:
        if next_button.rect_dist(mouse_pos) and pg.mouse.get_pressed()[0]:
            no_assigned = 0
            update_population(environment_grid)

    pg.display.update()

    for event in pg.event.get():
        if event.type == pg.QUIT:
            pg.quit()
            sys.exit()

    clock.tick(FPS)  # To limit the FPS to 100
    pg.display.flip()

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
