# Custom Environment

### Authors: Jakub Kot, Dawid Małecki

## Goal of the exercise
Create a custom environment for reinforcement learning.

## Problem description
The task is to create a custom environment and then add an agent that tries to maximize the environment's reward function. The agent chooses its actions based on observations from the environment. After each game, the agent learns from its experience using reinforcement learning.

## Solution

The environment we implemented is a tank game. Enemy tanks keep appearing on the board, and the agent has to destroy them as they appear to score points. When the agent is hit by an enemy tank, it loses one life. The game ends when the agent has lost all three lives.

The environment is discrete but fairly large: the board is 15x15 fields, and tanks move in steps of one sixth of a field in each direction and can fire bullets. When two bullets collide, both disappear. Tanks cannot pass through each other.
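To give a sense of the size of the state space, the sketch below counts the poses of a single tank under simplifying assumptions: a 15x15 board, the `PIXEL_SIZE` and `BLOCK_SIZE` values from the game code, no walls, and a tank that occupies exactly one field. The function name `count_tank_poses` is ours and does not appear in the game code.

```python
PIXEL_SIZE = 6                 # one movement step, in pixels (same value as in the game code)
BLOCK_SIZE = PIXEL_SIZE * 6    # one field, in pixels
BOARD_FIELDS = 15              # assumed board side length, in fields


def count_tank_poses(board_fields=BOARD_FIELDS):
    """Count (x, y, facing) combinations for one tank on an empty board."""
    board_pixels = board_fields * BLOCK_SIZE
    # the top-left corner can sit at 0, 6, 12, ... up to board_pixels - BLOCK_SIZE
    positions_per_axis = (board_pixels - BLOCK_SIZE) // PIXEL_SIZE + 1
    facings = 4  # up, down, left, right
    return positions_per_axis ** 2 * facings


print(count_tank_poses())
```

Under these assumptions a single tank already has 28,900 possible poses, before enemies and bullets are taken into account.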

![Game](https://i.imgur.com/YE0HrWK.png)

The red tank belongs to the agent, the green tanks are enemies, and the brown balls are bullets. Enemies can spawn at a dozen or so designated places near the edge of the board.

The agent's observation describes the position of the closest enemy, the tank's current rotation, the nearest obstacles, incoming bullets, and whether the gun is still reloading. Based on this information, the agent decides how to move and when to shoot.
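As a rough illustration, the observation can be laid out as a flat vector of 13 numbers, which matches the `state_size=13` used in the agent code. The sketch below is standalone and takes plain numbers as input; the helper name `build_observation` and the example values are hypothetical.

```python
def build_observation(dx, dy, facing, obstacles, bullet_distances, reloading):
    """Flatten the observed quantities into one list of 13 numbers.

    dx, dy            offset from the agent to the closest enemy, in pixels
    facing            (sign_x, sign_y) of the agent's current direction
    obstacles         four 0/1 flags: blocked up, down, left, right
    bullet_distances  four distances to incoming bullets: up, down, left, right
    reloading         True while the agent cannot fire
    """
    scale = 100.0  # same scaling constant as in the agent code
    return (
        [dx / scale, dy / scale]
        + list(facing)
        + list(obstacles)
        + [d / scale for d in bullet_distances]
        + [1.0 if reloading else 0.0]
    )


# made-up example: enemy up and to the right, wall on the left, one bullet coming from below
obs = build_observation(-72, 36, (0, 1), [0, 0, 1, 0], [-100, 48, -100, -100], False)
print(len(obs))
```

Scaling the pixel values by 100 keeps all inputs in a similar range, which usually makes training a small fully connected network easier.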

The learning method we used is Deep Q-Learning. The neural network has 2 hidden layers of 512 neurons each. The loss function is the smooth L1 (Huber) loss, `nn.SmoothL1Loss` in the code.
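For readers new to Deep Q-Learning, here is a minimal sketch of the one-step target that the network is trained towards, written in plain Python for a single transition. The discount factor 0.98 is the `gamma` from the agent code, and the Huber form mirrors what `nn.SmoothL1Loss` computes with its default settings. The function names and the sample Q-values are illustrative assumptions.

```python
def td_target(reward, next_q_values, done, gamma=0.98):
    """One-step Q-learning target: r + gamma * max_a' Q_target(s', a')."""
    if done:
        # terminal transition: there is no next state to bootstrap from
        return reward
    return reward + gamma * max(next_q_values)


def huber_loss(prediction, target, delta=1.0):
    """Quadratic for small errors, linear for large ones."""
    error = abs(prediction - target)
    if error <= delta:
        return 0.5 * error ** 2
    return delta * (error - 0.5 * delta)


# made-up Q-values of the target network for the five actions in the next state
target = td_target(reward=20.0, next_q_values=[1.0, 4.0, 2.5, 0.0, 3.0], done=False)
print(target, huber_loss(prediction=10.0, target=target))
```

The linear tail of the Huber loss limits the gradient produced by large reward spikes, which matters here because rewards range from -100 to +20.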

The strategy the agent learned turned out to be simple: it stays at its starting position, where walls protect it on two sides, and waits for an enemy to appear. The strategy is fairly effective, but if no enemy comes within firing range for long enough, the game is terminated (the episode ends once the frame count exceeds 1500 × (points + 1)). This rule was meant to force the agent to act, but in the end we did not manage to make the agent leave its position.
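The reward rules, including the inactivity timeout, can be summarized as a pure function. The sketch below is modelled on `play_step` from the game code, but the function name `step_reward` and its argument names are ours, so treat it as an approximation and not as the code that was actually run.

```python
def step_reward(took_damage, game_over, kills, frame, total_points,
                fired, reload_time, aimed_at_enemy):
    """Return (reward, episode_done) for one frame."""
    reward = 0
    if took_damage:
        reward -= 50
        if game_over:
            return reward, True
    if kills > 0:
        reward += 20
    # inactivity timeout: the frame budget grows with every point scored
    if frame > 1500 * (total_points + 1):
        return reward - 100, True
    if fired and 0 < reload_time < 79:
        reward -= 1   # pulled the trigger while still reloading
    elif fired and reload_time == 79:
        reward += 3   # the shot was actually fired this frame
    if fired and aimed_at_enemy:
        reward += 3   # bonus for firing towards the closest enemy
    return reward, game_over


print(step_reward(False, False, 0, 1501, 0, False, 0, False))
```

Note that nothing here rewards movement directly, which may be one reason the agent settled on waiting at its starting position.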

Nevertheless, the agent's results are fairly good, as the plot below shows. Even at the maximum game speed each game takes a long time, so the results are shown for a small number of games.

![Results](https://i.imgur.com/24o6jOd.png)

In most experiments the best scores were around 25 points.

## Conclusions:
- Creating a custom reinforcement-learning environment is fairly simple.
- Choosing suitable observations and a reward function for the agent is not only very important but also difficult.
- Reinforcement learning is a complex topic that often comes down to a long process of selecting and tuning hyperparameters, and to patience.
- Further improvements could involve a more sophisticated agent strategy, for example teaching the agent to move around the board in search of enemies instead of waiting for them in place.
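One of the hyperparameters that needs tuning is the exploration schedule. The sketch below reproduces the exponential epsilon decay from the agent code (`EPS_START`, `EPS_END`, `EPS_DECAY`) in plain Python; the helper name `epsilon_after` is ours.

```python
import math

EPS_START = 0.9   # initial exploration rate
EPS_END = 0.05    # floor that the rate decays towards
EPS_DECAY = 200   # decay time constant, in games


def epsilon_after(n_games):
    """Probability of picking a random action after n_games finished games."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-n_games / EPS_DECAY)


for n in (0, 200, 1000):
    print(n, round(epsilon_after(n), 3))
```

With these values the agent still picks a random action in roughly 36% of its moves after 200 games, so a short training run with few games is dominated by exploration.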

# Source code

## Game code

The code is not the cleanest, because the original idea was to create two agents that would fight each other. That turned out to be too difficult, so the game was turned into a single-player game. The source code still contains fragments that handled the second player.

In [None]:
import pygame
from enum import Enum
import random
from PIL import Image

PIXEL_SIZE = 6
BLOCK_SIZE = PIXEL_SIZE*6
BULLET_SIZE = PIXEL_SIZE*2

pygame.init()

MAP_SIZE_X = 0
MAP_SIZE_Y = 0

class Direction(Enum):
    UP = (0, -PIXEL_SIZE)
    DOWN = (0, PIXEL_SIZE)
    LEFT = (-PIXEL_SIZE, 0)
    RIGHT = (PIXEL_SIZE, 0)

class Utils:
    def init_map(map_name):
        # Open the text file with the map and read it character by character:
        #   "#"  solid obstacle
        #   "@"  obstacle that bullets pass through
        #   "*"  destructible obstacle
        #   "1"  tank 1 (the agent)
        #   "2"  tank 2 (second player, legacy two-player mode)
        #   "3"  enemy tank spawner
        #   any other character (e.g. "-") is empty space
        # Returns (obstacles, tank1, spawners), or (obstacles, tank1, tank2) when the map contains tank 2.

        with open(map_name, 'r') as file:
            global MAP_SIZE_X
            global MAP_SIZE_Y
            MAP_SIZE_X = (len(file.readline())-1) * BLOCK_SIZE
            MAP_SIZE_Y = (sum(1 for line in file)+2) * BLOCK_SIZE
            file.seek(0)
            map = []
            tank2 = None
            spawners = []
            for y, line in enumerate(file):
                for x, char in enumerate(line):
                    if char == '#':
                        map.append(Obstacle(x * BLOCK_SIZE, y * BLOCK_SIZE))
                    elif char == '@':
                        map.append(Obstacle(x * BLOCK_SIZE, y * BLOCK_SIZE, shootable=True, color=(0, 0, 255)))
                    elif char == '*':
                        map.append(Obstacle(x * BLOCK_SIZE, y * BLOCK_SIZE, destructible=True, color=(128, 0, 128)))
                    elif char == '1':
                        tank1 = Tank(x * BLOCK_SIZE//PIXEL_SIZE, y * BLOCK_SIZE//PIXEL_SIZE, (255, 0, 0))
                    elif char == '2':
                        tank2 = Tank(x * BLOCK_SIZE//PIXEL_SIZE, y * BLOCK_SIZE//PIXEL_SIZE, (0, 255, 0))
                    elif char == '3':
                        spawners += [TankBotSpawner(x * BLOCK_SIZE, y * BLOCK_SIZE, (0, 255, 0))]
        if tank2 is None:
            return map, tank1, spawners
        return map, tank1, tank2            

    def draw_hp(screen, tank1, tank2, game_state):
        font = pygame.font.Font(None, 36)

        text_surface_tank1 = font.render(f'Tank 1 HP: {tank1.health}', True, (0, 255, 0))
        text_rect_tank1 = text_surface_tank1.get_rect()
        text_rect_tank1.bottomleft = (PIXEL_SIZE*2, MAP_SIZE_Y-PIXEL_SIZE)
        screen.blit(text_surface_tank1, text_rect_tank1)

        if tank2 is not None:
            text_surface_tank2 = font.render(f'Tank 2 HP: {tank2.health}', True, (0, 255, 0))
            text_rect_tank2 = text_surface_tank2.get_rect()
            text_rect_tank2.bottomright = (MAP_SIZE_X-PIXEL_SIZE*2, MAP_SIZE_Y-PIXEL_SIZE)
            screen.blit(text_surface_tank2, text_rect_tank2) 
        else:
            text_surface_points = font.render(f'Points: {game_state.points}', True, (0, 255, 0))
            text_rect_points = text_surface_points.get_rect()
            text_rect_points.bottomright = (MAP_SIZE_X-PIXEL_SIZE*2, MAP_SIZE_Y-PIXEL_SIZE)
            screen.blit(text_surface_points, text_rect_points)
            

    def check_collision(object1, object2):
        return object1.rect.colliderect(object2.rect)
    
    def check_nearby_objects(tank, objects):
        rect_up = pygame.Rect(tank.x, tank.y - PIXEL_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        rect_down = pygame.Rect(tank.x, tank.y + PIXEL_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        rect_left = pygame.Rect(tank.x - PIXEL_SIZE, tank.y, BLOCK_SIZE, BLOCK_SIZE)
        rect_right = pygame.Rect(tank.x + PIXEL_SIZE, tank.y, BLOCK_SIZE, BLOCK_SIZE)

        nearby_objects = [0, 0, 0, 0]
        for object in objects:
            if rect_up.colliderect(object.rect):
                nearby_objects[0] = 1
            if rect_down.colliderect(object.rect):
                nearby_objects[1] = 1
            if rect_left.colliderect(object.rect):
                nearby_objects[2] = 1
            if rect_right.colliderect(object.rect):
                nearby_objects[3] = 1
        return nearby_objects
    
    def find_closest_bot(game_state):
        closest_bot = None
        closest_distance = 10000
        for bot in game_state.tank_bots:
            distance = Utils.get_distance(game_state.tank1, bot)
            if distance < closest_distance:
                closest_distance = distance
                closest_bot = bot
        return closest_bot

    def check_nearby_bullets(tank, bullets):
        rect_up = pygame.Rect(tank.x, tank.y - BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        rect_2_up = pygame.Rect(tank.x, tank.y - 2*BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        rect_down = pygame.Rect(tank.x, tank.y + BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        rect_2_down = pygame.Rect(tank.x, tank.y + 2*BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        rect_left = pygame.Rect(tank.x - BLOCK_SIZE, tank.y, BLOCK_SIZE, BLOCK_SIZE)
        rect_2_left = pygame.Rect(tank.x - 2*BLOCK_SIZE, tank.y, BLOCK_SIZE, BLOCK_SIZE)
        rect_right = pygame.Rect(tank.x + BLOCK_SIZE, tank.y, BLOCK_SIZE, BLOCK_SIZE)
        rect_2_right = pygame.Rect(tank.x + 2*BLOCK_SIZE, tank.y, BLOCK_SIZE, BLOCK_SIZE)

        # 10000 means "no bullet seen"; it is mapped to -100 at the end of this function
        nearby_bullets = [10000, 10000, 10000, 10000]
        for bullet in bullets:
            if rect_up.colliderect(bullet.rect) or rect_2_up.colliderect(bullet.rect):
                distance = abs(tank.y-bullet.y)
                nearby_bullets[0] = distance if nearby_bullets[0] > distance > 0 else nearby_bullets[0]
            if rect_down.colliderect(bullet.rect) or rect_2_down.colliderect(bullet.rect):
                distance = abs(tank.y-bullet.y)
                nearby_bullets[1] = distance if nearby_bullets[1] > distance > 0 else nearby_bullets[1]
            if rect_left.colliderect(bullet.rect) or rect_2_left.colliderect(bullet.rect):
                distance = abs(tank.x-bullet.x)
                nearby_bullets[2] = distance if nearby_bullets[2] > distance > 0 else nearby_bullets[2]
            if rect_right.colliderect(bullet.rect) or rect_2_right.colliderect(bullet.rect):
                distance = abs(tank.x-bullet.x)
                nearby_bullets[3] = distance if nearby_bullets[3] > distance > 0 else nearby_bullets[3]
        for i in range(4):
            nearby_bullets[i] = -100 if nearby_bullets[i] == 10000 else nearby_bullets[i]
        return nearby_bullets

    def get_distance(object1, object2):
        return ((object1.x - object2.x)**2 + (object1.y - object2.y)**2)**0.5

class State:
    def __init__(self, tank1, init_map, tank2=None, spawners=None):
        self.tank1 = tank1
        self.tank2 = tank2
        self.bullets = []
        self.map = init_map
        # avoid a shared mutable default argument
        self.spawners = spawners if spawners is not None else []
        self.spawn_cooldown = 60
        self.tank_bots = [self.spawn_bot_on_start()]
        self.points = 0
        self.max_enemies = 5

    def spawn_bot_on_start(self):
        spawner = random.choice(self.spawners)
        return spawner.spawn()

    def draw(self, screen, game_state):         
        screen.fill((80, 80, 80))
        self.tank1.draw(screen)
        if self.tank2 is not None:
            self.tank2.draw(screen)
        else:
            # for spawner in self.spawners:
                # spawner.draw(screen)
            for tank_bot in self.tank_bots:
                tank_bot.draw(screen)
        for obstacle in self.map:
            obstacle.draw(screen)
        
        for bullet in self.bullets:
            bullet.draw(screen)
        Utils.draw_hp(screen, self.tank1, self.tank2, game_state)

    def game_tick(self, game_state):
        points = 0
        took_damage = False
        if self.tank2 is None:
            for tank_bot in list(self.tank_bots):  # iterate over a copy: dead bots are removed inside the loop
                if tank_bot.health == 0:
                    self.tank_bots.remove(tank_bot)
                    self.points += 1
                    points += 1
                    continue
                tank_bot.move(game_state)
                tank_bot.shoot(game_state)
                tank_bot.decrease_cooldown()
        for bullet in list(self.bullets):  # iterate over a copy: bullets are removed inside the loop
            bullet.move()
            for obstacle in self.map:
                if obstacle.shootable:
                    continue
                if Utils.check_collision(bullet, obstacle):
                    if bullet in self.bullets:
                        self.bullets.remove(bullet)
                    if obstacle.destructible:
                        obstacle.hp -= 1
                        if obstacle.hp == 0:
                            self.map.remove(obstacle)
                    break
            if Utils.check_collision(bullet, self.tank1):
                self.tank1.health -= 1
                if bullet in self.bullets:
                    self.bullets.remove(bullet)
                    took_damage = True
            if self.tank2 is not None:
                if Utils.check_collision(bullet, self.tank2):
                    self.tank2.health -= 1
                    if bullet in self.bullets:
                        self.bullets.remove(bullet)
            else:
                if not bullet.is_from_bot:
                    for tank_bot in self.tank_bots:
                        if Utils.check_collision(bullet, tank_bot):
                            tank_bot.health -= 1
                            if bullet in self.bullets:
                                self.bullets.remove(bullet)
                            break
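        # find all colliding bullets first, then drop them, so the list is not mutated while iterating
        collided = set()
        for i, bullet1 in enumerate(self.bullets):
            for bullet2 in self.bullets[i+1:]:
                if Utils.check_collision(bullet1, bullet2):
                    collided.update((id(bullet1), id(bullet2)))
        self.bullets = [bullet for bullet in self.bullets if id(bullet) not in collided]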
        self.bullets = [bullet for bullet in self.bullets if 0 <= bullet.x <= MAP_SIZE_X and 0 <= bullet.y <= MAP_SIZE_Y]
        if self.tank2 is not None:
            if self.tank1.health == 0:
                # print("Player 2 wins!")
                return True, took_damage, points
            if self.tank2.health == 0:
                # print("Player 1 wins!")
                return True, took_damage, points
            self.tank1.decrease_cooldown()
            self.tank2.decrease_cooldown()
        else:
            if self.tank1.health == 0:
                # print("You lose!")
                # print(f"Final score: {self.points}")
                return True, took_damage, points
            self.tank1.decrease_cooldown()
        
            if (len(self.tank_bots) < self.max_enemies and self.spawn_cooldown == 0) or len(self.tank_bots) == 0:
                spawner = random.choice(self.spawners)
                # check if there's any tank on the spawner
                while any(Utils.check_collision(spawner, tank) for tank in self.tank_bots + [self.tank1]):
                    spawner = random.choice(self.spawners)
                self.tank_bots.append(spawner.spawn())
                self.spawn_cooldown = 120
        if self.spawn_cooldown > 0:
            self.spawn_cooldown -= 1
        return False, took_damage, points

class Tank:
    def __init__(self, x, y, color):
        # position
        self.x = x * PIXEL_SIZE
        self.y = y * PIXEL_SIZE

        # movement
        self.move_cooldown = 0
        self.direction = (0, PIXEL_SIZE)
        
        # shooting
        self.reload_time = 0
        
        # hp
        self.health = 3

        # skin
        self.color = color

        self.rect = pygame.Rect(self.x, self.y, BLOCK_SIZE, BLOCK_SIZE)

    def move(self, direction: Direction, game_state):
        dx, dy = direction.value
        self.direction = direction.value
        new_x = min(max(0, self.x + dx), MAP_SIZE_X)
        new_y = min(max(0, self.y + dy), MAP_SIZE_Y)
        old_rect = self.rect
        self.rect = pygame.Rect(new_x, new_y, BLOCK_SIZE, BLOCK_SIZE)
        for obstacle in game_state.map + [game_state.tank1] + [game_state.tank2] + game_state.tank_bots:
            if obstacle is None or obstacle == self:
                continue
            if Utils.check_collision(self, obstacle):
                self.rect = old_rect
                return
        if self.move_cooldown > 0:
            return
        self.move_cooldown = 3
        self.x = new_x
        self.y = new_y

    def shoot(self, game_state):
        if self.reload_time > 0:
            return
        self.reload_time = 80
        # spawn the bullet centred on the tank, shifted one block in the facing direction
        bullet_x = self.x + (BLOCK_SIZE ) // 2 - BULLET_SIZE // 2 + self.direction[0]//PIXEL_SIZE * BLOCK_SIZE
        bullet_y = self.y + (BLOCK_SIZE) // 2 - BULLET_SIZE // 2 + self.direction[1]//PIXEL_SIZE * BLOCK_SIZE
        game_state.bullets.append(Bullet(bullet_x, bullet_y, self.direction))

    def decrease_cooldown(self):
        if self.move_cooldown > 0:
            self.move_cooldown -= 1
        if self.reload_time > 0:
            self.reload_time -= 1

    def draw(self, screen):
        if self.color == (255, 0, 0):
            i = 1
        elif self.color == (0, 255, 0):
            i = 2
        else:
            i = 3
        match self.direction:
            case Direction.LEFT.value:
                img = pygame.image.load(f"lab4/assets/tank{i}_90.png")
            case Direction.RIGHT.value:
                img = pygame.image.load(f"lab4/assets/tank{i}_270.png")
            case Direction.UP.value:
                img = pygame.image.load(f"lab4/assets/tank{i}_0.png")
            case Direction.DOWN.value:
                img = pygame.image.load(f"lab4/assets/tank{i}_180.png")
        screen.blit(img, (self.x, self.y))

class TankBot(Tank):
    def __init__(self, x, y, color):
        super().__init__(x//PIXEL_SIZE, y//PIXEL_SIZE, color)
        self.turn_cooldown = 25
        self.health = 1

    def move(self, game_state):
        if self.move_cooldown > 0:
            return
        if self.turn_cooldown == 0:
            direction = random.choice(list(Direction)).value
            self.turn_cooldown = 25
        else:
            direction = self.direction
        self.move_cooldown = 5
        self.direction = direction
        dx, dy = self.direction
        new_x = min(max(0, self.x + dx), MAP_SIZE_X)
        new_y = min(max(0, self.y + dy), MAP_SIZE_Y)
        old_rect = self.rect
        self.rect = pygame.Rect(new_x, new_y, BLOCK_SIZE, BLOCK_SIZE)
        for obstacle in game_state.map + [game_state.tank1] + [game_state.tank2] + game_state.tank_bots:
            if obstacle is None or obstacle == self:
                continue
            if Utils.check_collision(self, obstacle):
                self.rect = old_rect
                return
        self.x = new_x
        self.y = new_y

    def shoot(self, game_state):
        if self.reload_time > 0:
            return
        self.reload_time = 180 
        # spawn the bullet centred on the tank, shifted one block in the facing direction
        bullet_x = self.x + (BLOCK_SIZE ) // 2 - BULLET_SIZE // 2 + self.direction[0]//PIXEL_SIZE * BLOCK_SIZE
        bullet_y = self.y + (BLOCK_SIZE) // 2 - BULLET_SIZE // 2 + self.direction[1]//PIXEL_SIZE * BLOCK_SIZE
        game_state.bullets.append(Bullet(bullet_x, bullet_y, self.direction, is_from_bot=True, velocity=0.2))

    def decrease_cooldown(self):
        if self.move_cooldown > 0:
            self.move_cooldown -= 1
        if self.reload_time > 0:
            self.reload_time -= 1
        if self.turn_cooldown > 0:
            self.turn_cooldown -= 1

class TankBotSpawner:
    def __init__(self, x, y, color):
        self.x = x
        self.y = y
        self.color = color
        self.rect = pygame.Rect(self.x, self.y, BLOCK_SIZE, BLOCK_SIZE)

    def spawn(self):
        return TankBot(self.x, self.y, self.color)        
    
    def draw(self, screen):
        pygame.draw.rect(screen, self.color, (self.x, self.y, BLOCK_SIZE, BLOCK_SIZE))

class Bullet:
    def __init__(self, start_x, start_y, direction, is_from_bot=False, velocity=1):
        self.x = start_x
        self.y = start_y
        self.direction = direction
        self.velocity = velocity
        self.is_from_bot = is_from_bot
        self.rect = pygame.Rect(self.x, self.y, BULLET_SIZE, BULLET_SIZE)
        
    def draw(self, screen):
        img = pygame.image.load("lab4/assets/bullet.png")
        screen.blit(img, (self.x, self.y))
        
    def move(self):
        dx, dy = self.direction
        self.x += dx * self.velocity
        self.y += dy * self.velocity
        self.rect = pygame.Rect(self.x, self.y, BULLET_SIZE, BULLET_SIZE)

class Obstacle:
    def __init__(self, x, y, shootable=False, destructible=False, color=(0, 0, 0)):
        self.x = x
        self.y = y
        self.shootable = shootable
        self.destructible = destructible
        self.hp = 3
        self.color = color
        self.rect = pygame.Rect(self.x, self.y, BLOCK_SIZE, BLOCK_SIZE)

    def draw(self, screen):
        if self.destructible:
            self.color = (255, 64*self.hp, 0)
        pygame.draw.rect(screen, self.color, (self.x, self.y, BLOCK_SIZE, BLOCK_SIZE))

def handle_key_presses(game_state: State):
    keys = pygame.key.get_pressed()
    if keys[pygame.K_w]:
        game_state.tank1.move(Direction.UP, game_state)
    elif keys[pygame.K_s]:
        game_state.tank1.move(Direction.DOWN, game_state)
    elif keys[pygame.K_a]:
        game_state.tank1.move(Direction.LEFT, game_state)
    elif keys[pygame.K_d]:
        game_state.tank1.move(Direction.RIGHT, game_state)
    if keys[pygame.K_SPACE]:
        game_state.tank1.shoot(game_state)
        closest_bot = Utils.find_closest_bot(game_state)
        # y grows downwards in pygame, so an enemy above the tank has a smaller y
        if game_state.tank1.direction == Direction.UP.value and closest_bot.y < game_state.tank1.y and abs(closest_bot.x - game_state.tank1.x) < 25 \
            or game_state.tank1.direction == Direction.DOWN.value and closest_bot.y > game_state.tank1.y and abs(closest_bot.x - game_state.tank1.x) < 25 \
            or game_state.tank1.direction == Direction.LEFT.value and closest_bot.x < game_state.tank1.x and abs(closest_bot.y - game_state.tank1.y) < 25 \
            or game_state.tank1.direction == Direction.RIGHT.value and closest_bot.x > game_state.tank1.x and abs(closest_bot.y - game_state.tank1.y) < 25:
            return 3
    return 0

def handle_key_presses_AI(game_state: State, action):
    if action[0] == 1:
        game_state.tank1.move(Direction.UP, game_state)
    if action[1] == 1:
        game_state.tank1.move(Direction.DOWN, game_state)
    if action[2] == 1:
        game_state.tank1.move(Direction.LEFT, game_state)
    if action[3] == 1:
        game_state.tank1.move(Direction.RIGHT, game_state)
    if action[4] == 1:
        game_state.tank1.shoot(game_state)
        closest_bot = Utils.find_closest_bot(game_state)
        # y grows downwards in pygame, so an enemy above the tank has a smaller y
        if game_state.tank1.direction == Direction.UP.value and closest_bot.y < game_state.tank1.y and abs(closest_bot.x - game_state.tank1.x) < 25 \
            or game_state.tank1.direction == Direction.DOWN.value and closest_bot.y > game_state.tank1.y and abs(closest_bot.x - game_state.tank1.x) < 25 \
            or game_state.tank1.direction == Direction.LEFT.value and closest_bot.x < game_state.tank1.x and abs(closest_bot.y - game_state.tank1.y) < 25 \
            or game_state.tank1.direction == Direction.RIGHT.value and closest_bot.x > game_state.tank1.x and abs(closest_bot.y - game_state.tank1.y) < 25:
            return 3
    return 0

class TankGame:
    def __init__(self, map_name, FPS=60):
        self.map_name = map_name
        init_map, tank1, spawners = Utils.init_map(map_name)
        self.game_state = State(tank1, init_map, spawners=spawners)
        self.screen = pygame.display.set_mode([MAP_SIZE_X, MAP_SIZE_Y])
        self.clock = pygame.time.Clock()
        self.FPS = FPS
        self.init_FPS = FPS
        self.frame_iteration = 0
        self.cooldown = 0
        self.human_mode = True
        self.good_shoots = 0

    def reset(self):
        self.good_shoots = 0
        init_map, tank1, spawners = Utils.init_map(self.map_name)
        self.game_state = State(tank1, init_map, spawners=spawners)
        self.frame_iteration = 0
        self.game_state.tank_bots = [self.game_state.spawn_bot_on_start()]

    def play_step(self, action):
        self.clock.tick(self.FPS)
        self.frame_iteration += 1
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False
        keys = pygame.key.get_pressed()
        if keys[pygame.K_h] and self.cooldown == 0:
            self.human_mode = not self.human_mode
            self.cooldown = 100
        if keys[pygame.K_f] and self.cooldown == 0:
            self.FPS = 5 if self.FPS == self.init_FPS else self.init_FPS
            self.cooldown = 100
        if self.cooldown > 0:
            self.cooldown -= 1
        reward = 0
        shoot_points = handle_key_presses_AI(self.game_state, action)
        is_over, took_damage, points = self.game_state.game_tick(self.game_state)
        # reward for hitting enemies and for shooting when the reload time is 0
        # penalty for taking damage and for shooting while still reloading (wasted shots)
        # the episode is lost when the player loses all health or goes too long without destroying enemies
        if took_damage:
            reward -= 50
            if is_over:
                return reward, is_over, self.game_state.points
        if points > 0:
            reward += 20
        if self.frame_iteration > 1500*(self.game_state.points+1):
            reward -= 100
            return reward, True, self.game_state.points
        if action[4] == 1 and 0 < self.game_state.tank1.reload_time < 79:
            reward -= 1
        elif action[4] == 1 and self.game_state.tank1.reload_time == 79:
            reward += 3
        reward += shoot_points
        if shoot_points > 0:
            self.good_shoots += 1
        if self.human_mode:
            self.game_state.draw(self.screen, self.game_state)
            pygame.display.flip()
        return reward, is_over, self.game_state.points

## Agent model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import random
from collections import namedtuple, deque

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.qnet = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        x = self.qnet(x)
        return x

    def save(self, file_name='model.pth'):
        model_folder_path = './model'
        if not os.path.exists(model_folder_path):
            os.makedirs(model_folder_path)
        
        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)

class QTrainer:
    def __init__(self, model, target_model, lr, gamma, batch_size, memory_capacity):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.target_model = target_model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.SmoothL1Loss()  # Huber loss
        self.memory = ReplayMemory(memory_capacity)
        self.batch_size = batch_size

    def train_step(self):
        if len(self.memory) < self.batch_size:
            return
        
        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
        non_final_next_states = torch.stack([s for s in batch.next_state if s is not None])

        state_batch = torch.stack(batch.state)
        action_batch = torch.stack(batch.action)
        reward_batch = torch.stack(batch.reward)
        
        state_action_values = self.model(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(self.batch_size)
        next_state_values[non_final_mask] = self.target_model(non_final_next_states).max(1)[0].detach()
        expected_state_action_values = reward_batch + (self.gamma * next_state_values)

        loss = self.criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def push_to_memory(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)

## Agent

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
from model import Linear_QNet, QTrainer
from game_logic_ai import TankGame, Direction, Utils, PIXEL_SIZE, BLOCK_SIZE
from helper import plot

# Constants
MAX_MEMORY = 100_000
BATCH_SIZE = 64
LR = 0.001
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200  # Decay parameter for epsilon

class Agent:
    def __init__(self, state_size, action_size):
        self.n_games = 0
        self.state_size = state_size
        self.action_size = action_size
        self.epsilon = EPS_START
        self.gamma = 0.98
        self.memory = deque(maxlen=MAX_MEMORY)
        self.model = Linear_QNet(state_size, 512, action_size)
        self.target_model = Linear_QNet(state_size, 512, action_size)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LR)
        self.criterion = nn.SmoothL1Loss()

    def get_state(self, game: TankGame):
        # tank_location = [game.game_state.tank1.x, game.game_state.tank1.y]
        # find closest bot, give relative cords
        closest_bot = Utils.find_closest_bot(game.game_state)
        bot_location = [(game.game_state.tank1.x-closest_bot.x)/100, (game.game_state.tank1.y-closest_bot.y)/100]
        rotation = [np.sign(game.game_state.tank1.direction[0]), np.sign(game.game_state.tank1.direction[1])]
        # iterate through all game objects, check if they are next to the tank in 4 main directions
        nearby_objects = Utils.check_nearby_objects(game.game_state.tank1, game.game_state.tank_bots + game.game_state.map)
        nearby_bullets = [i/100 for i in Utils.check_nearby_bullets(game.game_state.tank1, game.game_state.bullets)]
        reload_time = 1 if game.game_state.tank1.reload_time > 0 else 0
        state = bot_location + rotation + nearby_objects + nearby_bullets + [reload_time]
        # state = [reload_time] + nearby_objects 
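        # keep floats: the scaled offsets and distances are fractional and an int dtype would truncate them
        return np.array(state, dtype=np.float32)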

    def get_action(self, state):
        if np.random.rand() <= self.epsilon:
            action = random.randrange(self.action_size)
            return action
        else:
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
                action = self.model(state).max(1)[1].item()
                return action

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.tensor(states, dtype=torch.float)
        actions = torch.tensor(actions, dtype=torch.long).view(-1, 1)
        rewards = torch.tensor(rewards, dtype=torch.float).view(-1, 1)
        next_states = torch.tensor(next_states, dtype=torch.float)
        dones = torch.tensor(dones, dtype=torch.bool).view(-1, 1)

        current_q_values = self.model(states).gather(1, actions)

        # bootstrap from the target network only for non-terminal transitions
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(1)[0].unsqueeze(1)

        expected_q_values = rewards + self.gamma * next_q_values * (~dones).float()

        loss = self.criterion(current_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def decay_epsilon(self):
        self.epsilon = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * self.n_games / EPS_DECAY)


def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    agent = Agent(state_size=13, action_size=5)
    record = 0
    game = TankGame("lab4/survival.txt", FPS=600000)
    for episode in range(1000):  # Adjust the number of episodes as needed
        # Run episode
        done = False
        score = 0
        total_reward = 0
        # Initialize your game environment here
        while not done:
            state = agent.get_state(game) 
            action = agent.get_action(state)
            final_move = [0] * 5
            final_move[action] = 1
            reward, done, score = game.play_step(final_move)
            next_state = agent.get_state(game)


            agent.remember(state, action, reward, next_state, done)
            agent.train()
            agent.update_target_model()

            total_reward += reward

        # Logging
        game.reset()
        agent.n_games += 1
        agent.decay_epsilon()
        if agent.n_games % 1 == 0:
            print(f'Episode: {episode}, Total Episode Reward: {total_reward}')
        if score > record:
            record = score
            agent.model.save()
        
        print('Game', agent.n_games, 'Score', score, 'Record', record)
        
        plot_scores.append(score)
        total_score += score
        mean_score = total_score / agent.n_games
        plot_mean_scores.append(mean_score)
        plot(plot_scores, plot_mean_scores)

if __name__ == '__main__':
    train()

## Plotting helper function

In [None]:
import matplotlib.pyplot as plt
from IPython import display

plt.ion()

def plot(scores, mean_scores):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Number of games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    plt.ylim(ymin=0)
    plt.text(len(scores)-1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))

    plt.show(block=False)
    plt.pause(.1)