Imports:

In [2]:
import random
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque


### **General Unit class**

In [3]:
class Unit:
  """
    General unit class.
  """
  def __init__(self, unit_name: str, unit_type: str, cost: int, star: int = 1):
    self.unit_name = unit_name
    self.unit_type = unit_type # We have three types of units: 1) "Healer", 2) "Tank", 3) "Ranger".
    self.cost = cost
    self.star = star  # Level of a unit: 1, 2, or 3.
    self.hp = self.get_hp()
    self.damage = self.get_damage()
    self.attack_range = self.get_attack_range()
    # self.ability_ready = False
    # self.attack_counter = 0 # After unit attacks 5 times - it can use ability.

  def get_hp(self):
    """
      Returns hp of a unit.
    """
    # Dictionary of hp values for 1-star units
    hp_values_1star = {
        1: 500,
        2: 650,
        3: 800,
        4: 950,
        5: 1100
    }

    # Dictionary of multipliers for increasing hp based on level
    star_multipliers = {
        1: 1.0,
        2: 1.5, # 50% increase
        3: 2 # 100% increase
    }

    # Dictionary of unit types for increasing hp based on type
    unit_type = {
        "Tank" : 1.4, # 40% increase
        "Ranger": 0.8, # 20% decrease
        "Healer": 0.9 # 10% decrease
    }

    # Combine the guidelines from above to calculate the hp value
    hp = int(hp_values_1star[self.cost] * star_multipliers[self.star] * unit_type[self.unit_type])
    #hp = int(star_multipliers[self.star] * unit_type[self.unit_type])

    return hp

  def get_damage(self):
    """
      Returns damage of a unit.
    """

    # Dictionary of damage values for 1-star units
    damage_values_1star = {
        1: 70,
        2: 80,
        3: 90,
        4: 100,
        5: 120
    }

    # Dictionary of multipliers for increasing damage based on level
    star_multipliers = {
        1: 1.0,
        2: 2, # 100% increase
        3: 3 # 200% increase
    }

    # Dictionary of unit types for increasing damage based on type
    unit_type = {
        "Tank" : 0.7, # 30% decrease
        "Ranger": 1.5, # 50% increase
        "Healer": 0.9 # 10% decrease
    }

    # Combine the guidelines from above to calculate the damage value
    damage = int(damage_values_1star[self.cost] * star_multipliers[self.star] * unit_type[self.unit_type])
    #damage = int(star_multipliers[self.star] * unit_type[self.unit_type])

    return damage

  def attack(self):
    """
      Stacks attacks counter to gain ability.
    """
    self.attack_counter += 1
    if self.attack_counter == 5:
      self.ability_ready = True

  def reset_ability_charge(self):
    """
      Resets the ability.
    """
    self.attack_counter = 0
    self.ability_ready = False
  
  def get_attack_range(self):
    """
      Returns attack range based on unit type.
    """
    ranges = {
      "Tank": 1,
      "Healer": 3,
      "Ranger": 4
    }
    return ranges[self.unit_type]

  def __repr__(self):
    """
      Returns a string representation of the unit (debugging purposes).
    """
    return f"<{self.star}★ {self.unit_name}, type:{self.unit_type} ({self.cost}-cost) - {self.hp} HP, {self.damage} damage>"

In [4]:
all_units_names_role_and_cost = {
    # 1 cost
    "Silent": ("Ranger", 1), "Flamy": ("Ranger", 1), "Cheddy": ("Ranger", 1), "Hertrude": ("Ranger", 1),
    "Brim": ("Tank", 1), "Bravos": ("Tank", 1), "Lorak": ("Tank", 1), "Kiros": ("Tank", 1),
    "Mary": ("Healer", 1), "Looney": ("Healer", 1), "Kitana": ("Healer", 1), "Miss Luis": ("Healer", 1),
    # 2 cost
    "Marko": ("Ranger", 2), "Colt": ("Ranger", 2), "Kana": ("Ranger", 2),
    "Morfus": ("Tank", 2), "Sol": ("Tank", 2), "Kemer": ("Tank", 2), "Pronto": ("Tank", 2),
    "Summer": ("Healer", 2), "Clover": ("Healer", 2), "Pishta": ("Healer", 2),
    # 3 cost
    "Bruno": ("Ranger", 3), "Tofa": ("Ranger", 3), "Monroe": ("Ranger", 3),
    "Krusty": ("Tank", 3), "Kenny": ("Tank", 3), "Kanye": ("Tank", 3),
    "Ashley": ("Healer", 3), "Bonny": ("Healer", 3),
    # 4 cost
    "Kaneki Ken": ("Ranger", 4), "Satoru Gojo": ("Ranger", 4), "Gabimaru": ("Ranger", 4),
    "Toochka": ("Tank", 4), "MnSano": ("Tank", 4),
    "Avotushenka": ("Healer", 4),
    # 5 cost
    "Keysella": ("Ranger", 5),
    "Maikeru": ("Tank", 5),
    "Militmi": ("Healer", 5)
}

all_units_list = []

for name, (role, cost) in all_units_names_role_and_cost.items():
  all_units_list.append(Unit(name, role, cost))

for unit in all_units_list:
  print(unit)

<1★ Silent, type:Ranger (1-cost) - 400 HP, 105 damage>
<1★ Flamy, type:Ranger (1-cost) - 400 HP, 105 damage>
<1★ Cheddy, type:Ranger (1-cost) - 400 HP, 105 damage>
<1★ Hertrude, type:Ranger (1-cost) - 400 HP, 105 damage>
<1★ Brim, type:Tank (1-cost) - 700 HP, 49 damage>
<1★ Bravos, type:Tank (1-cost) - 700 HP, 49 damage>
<1★ Lorak, type:Tank (1-cost) - 700 HP, 49 damage>
<1★ Kiros, type:Tank (1-cost) - 700 HP, 49 damage>
<1★ Mary, type:Healer (1-cost) - 450 HP, 63 damage>
<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>
<1★ Kitana, type:Healer (1-cost) - 450 HP, 63 damage>
<1★ Miss Luis, type:Healer (1-cost) - 450 HP, 63 damage>
<1★ Marko, type:Ranger (2-cost) - 520 HP, 120 damage>
<1★ Colt, type:Ranger (2-cost) - 520 HP, 120 damage>
<1★ Kana, type:Ranger (2-cost) - 520 HP, 120 damage>
<1★ Morfus, type:Tank (2-cost) - 909 HP, 56 damage>
<1★ Sol, type:Tank (2-cost) - 909 HP, 56 damage>
<1★ Kemer, type:Tank (2-cost) - 909 HP, 56 damage>
<1★ Pronto, type:Tank (2-cost) - 909 HP, 56 da

### **Player Class 😎**

In [None]:
class Player:
  """
    General player class.
  """
  def __init__(self, name: str):
    self.name = name
    # Starting gold, level, hp, and no units for every player.
    self.gold = 13
    self.level = 3
    self.hp = 100
    self.board = [[None for _ in range(8)] for _ in range(4)]
    self.units_on_board = 0
    self.bench = [None for _ in range(8)]
    self.all_units = []
    self.shop = Shop(all_units_list, self.level)
    self.won_last_fight = False

  def level_up(self):
    """
      Increase the level of a Player.
    """
    level_up_costs = {
      3: 6,
      4: 10,
      5: 20,
      6: 36,
      7: 54,
      8: 80
    }
    
    cost = level_up_costs.get(self.level, None)
    
    if cost is not None and self.gold >= cost:
      self.gold -= cost
      self.level += 1
      return True
    
    return False

  def gain_gold(self):
    """
      Gain gold - method that triggers every start of the round. Player gets:
        1) + win bonus if they won last fight;
        2) + interest rate (no more than 5);
        3) + 9 gold.
    """
    # Win bonus.
    win_bonus = 0
    if self.won_last_fight:
      win_bonus = 1

    # Interest rate.
    interest_rate = self.gold // 10
    if interest_rate > 5:
      interest_rate = 5

    # Gaining gold.
    self.gold += 9 + win_bonus + interest_rate
  
  def refresh_shop(self):
    if self.gold >= 2:
      self.gold -= 2
      self.shop.update(self.level)
      return True
    else:
      return False

  def buy_unit_from_shop(self, unit):
    """
      Buy a unit from shop.
    """
    if (unit is None) or (self.gold < unit.cost):
      return False

    for i in range(8):
      if self.bench[i] is None:
        self.bench[i] = unit
        self.all_units.append(unit)
        self.gold -= unit.cost
        return True
    return False
  
  def buy_unit(self, shop_index):
    """
      Buy_unit_from_shop() function wrapper.
    """
    unit = self.shop.units_in_shop[shop_index]
    success = self.buy_unit_from_shop(unit)
    if success:
        self.shop.remove(unit)
        self.check_and_upgrade_units()
    return success

  def check_and_upgrade_units(self):
    """
    Check and upgrade units to their 2★ or 3★ versions.
    """
    for star_level in [1, 2]:  # For every star level before 3★.
        unit_counter = {}

        # Count units
        for unit in self.all_units:
            if unit.star == star_level:
                key = (unit.unit_name, unit.unit_type, unit.cost)
                unit_counter[key] = unit_counter.get(key, 0) + 1

        for (unit_name, unit_type, cost), count in unit_counter.items():
            while count >= 3:
                # print("We found a unit to upgrade!")
                # print(unit_name, unit_type, cost, count)
                # print("-----------------------------------------------------------------")
                # Find three same units.
                to_remove = []
                locations = []

                # Find them on the board.
                if len(to_remove) < 3:
                    for row in range(4):
                        for col in range(8):
                            u = self.board[row][col]
                            if u and u.unit_type == unit_type and u.cost == cost and u.star == star_level:
                                to_remove.append(u)
                                locations.append(("board", (row, col)))
                                if len(to_remove) == 3:
                                    break
                        if len(to_remove) == 3:
                            break

                # If not enough on bench - find the rest on the bench.
                for idx, u in enumerate(self.bench):
                    if u and u.unit_type == unit_type and u.cost == cost and u.star == star_level:
                        to_remove.append(u)
                        locations.append(("bench", idx))
                        if len(to_remove) == 3:
                            break

                # Delete units we found from bench and board.
                for place, idx in locations:
                    if place == "bench":
                        self.bench[idx] = None
                    else:
                        row, col = idx
                        self.board[row][col] = None

                # Delete units we found from all_units list.
                for u in to_remove:
                    self.all_units.remove(u)

                # Add upgraded version of the unit to the list.
                new_unit = Unit(unit_name, unit_type, cost, star=star_level + 1)
                self.all_units.append(new_unit)

                # Place a new unit to the first board location (we first checked board for these units, so first should be board).
                # If there is no board location, just place it on the bench.
                place, idx = locations[0]
                if place == "bench":
                    self.bench[idx] = new_unit
                else:
                    row, col = idx
                    self.board[row][col] = new_unit
                
                count -= 3
                unit_counter[(unit_name, unit_type, cost)] -= 3
                # print("Unit has been upgraded!")

  def sell_unit(self, unit):
    """
      Sells a unit.
    """
    self.all_units.remove(unit)
    self.gold += unit.cost * unit.star
    if unit in self.bench:
      for i in range(8):
        if self.bench[i] == unit:
          self.bench[i] = None
          break
    elif unit in self.board:
      for i in range(4):
        for j in range(8):
          if self.board[i][j] == unit:
            self.board[i][j] = None
            self.units_on_board -= 1
            break

  def sell_unit_from_cell(self, from_cell):
    """ Sells a unit, given specified cell. """
    if from_cell < 8:
        unit = self.bench[from_cell]
    else:
        row = (from_cell - 8) // 8
        col = (from_cell - 8) % 8
        unit = self.board[row][col]
    
    if unit is not None:
        self.sell_unit(unit)
        return True
    
    return False

  def move_unit(self, from_cell, to_cell):
    """
    If there is a unit (not a None) in the starting cell - moves this unit from one cell to another.
    If there is also a unit (not a None) in the destination cell - switches their positions.
    """
    # from_cell and to_cell are integers [0,39]. 0-7: bench, 8-39: board (row by row from top to bottom).
    
    # Get source unit
    if from_cell < 8:
        source_unit = self.bench[from_cell]
        source_location = ("bench", from_cell)
    else:
        row = (from_cell - 8) // 8
        col = (from_cell - 8) % 8
        source_unit = self.board[row][col]
        source_location = ("board", row, col)
    
    # Get destination unit
    if to_cell < 8:
        target_unit = self.bench[to_cell]
        target_location = ("bench", to_cell)
    else:
        row = (to_cell - 8) // 8
        col = (to_cell - 8) % 8
        target_unit = self.board[row][col]
        target_location = ("board", row, col)
    
    # If source is empty, nothing to move
    if source_unit is None:
        return False
    
    # Update source location
    if source_location[0] == "bench":
        self.bench[source_location[1]] = target_unit
    else:
        self.board[source_location[1]][source_location[2]] = target_unit
        if target_unit is None:
          self.units_on_board += 1
    
    # Update target location
    if target_location[0] == "bench":
        self.bench[target_location[1]] = source_unit
        self.units_on_board -= 1
    else:
        self.board[target_location[1]][target_location[2]] = source_unit
    
    return True



### **Shop Class**

In [6]:
class Shop:
  """
    General shop class.
  """
  def __init__(self, all_units_list, player_level):
    self.units_in_shop = [None] * 5 # List of the 5 units to choose from
    self.all_units_list = all_units_list # List of all available units
    self.fill_shop(player_level) # Fill the shop with units initially
  
  def update(self, player_level):
    """
    Updates (rerolls) the shop. Should be called if player pays 2 gold.
    """
    self.fill_shop(player_level)

  def fill_shop(self, player_level):
    """
      Fills the shop with units.
    """
    self.units_in_shop = [None] * 5 # Reset the shop
    probabilities = self.get_probabilities(player_level) # Get probabilities for each unit

    for i in range(5):
      roll = random.random() * 100
      cumulative_prob = 0
      selected_cost = 1

      for cost, prob in probabilities.items():
        cumulative_prob += prob
        if roll <= cumulative_prob:
          selected_cost = cost
          break
      cost_units = [unit for unit in self.all_units_list if unit.cost == selected_cost]

      # Select a random unit if available
      if cost_units:
        self.units_in_shop[i] = random.choice(cost_units)

  def get_probabilities(self, player_level):
    """
      Returns a list of probabilities for each unit in the shop.
    """
    # Probability distributions for each unit in the shop
    distributions = {
        3: {1: 75, 2: 25, 3: 0, 4: 0, 5: 0}, # Start with 3 level
        4: {1: 60, 2: 30, 3: 10, 4: 0, 5: 0},
        5: {1: 40, 2: 35, 3: 20, 4: 5, 5: 0},
        6: {1: 25, 2: 40, 3: 25, 4: 10, 5: 0},
        7: {1: 15, 2: 30, 3: 35, 4: 15, 5: 5},
        8: {1: 10, 2: 20, 3: 25, 4: 35, 5: 10},
        9: {1: 5, 2: 15, 3: 20, 4: 40, 5: 20}, # 9 levels max
    }

    return distributions[player_level]

  def remove(self, unit):
    """
      Removes a unit from the shop.
    """

    for i in range(len(self.units_in_shop)):
      if self.units_in_shop[i] == unit:
        self.units_in_shop[i] = None
        return True
    return False

  def __repr__(self):
    """
      Returns a string representation of the unit (debugging purposes).
    """
    unit_names = [str(unit) if unit else "Empty" for unit in self.units_in_shop]
    return f"Shop: {unit_names}"


In [7]:
# Testing shop working with a player.
player = Player("Alex")
player.gold = 999
print(player.gold)
print(player.shop)

999
Shop: ['<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>', '<1★ Kitana, type:Healer (1-cost) - 450 HP, 63 damage>', '<1★ Silent, type:Ranger (1-cost) - 400 HP, 105 damage>', '<1★ Brim, type:Tank (1-cost) - 700 HP, 49 damage>', '<1★ Mary, type:Healer (1-cost) - 450 HP, 63 damage>']


In [8]:
player.buy_unit(0)

True

In [9]:
print(player.gold)
print(player.shop)
print(player.all_units)
print(player.board)
print(player.bench)

998
Shop: ['Empty', '<1★ Kitana, type:Healer (1-cost) - 450 HP, 63 damage>', '<1★ Silent, type:Ranger (1-cost) - 400 HP, 105 damage>', '<1★ Brim, type:Tank (1-cost) - 700 HP, 49 damage>', '<1★ Mary, type:Healer (1-cost) - 450 HP, 63 damage>']
[<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>]
[[None, None, None, None, None, None, None, None], [None, None, None, None, None, None, None, None], [None, None, None, None, None, None, None, None], [None, None, None, None, None, None, None, None]]
[<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>, None, None, None, None, None, None, None]


In [10]:
player.move_unit(0,34)

True

In [11]:
player.refresh_shop()
print(player.gold)
print(player.shop)

996
Shop: ['<1★ Morfus, type:Tank (2-cost) - 909 HP, 56 damage>', '<1★ Hertrude, type:Ranger (1-cost) - 400 HP, 105 damage>', '<1★ Flamy, type:Ranger (1-cost) - 400 HP, 105 damage>', '<1★ Bravos, type:Tank (1-cost) - 700 HP, 49 damage>', '<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>']


In [12]:
player.buy_unit(4)

True

In [13]:
print(player.gold)
print(player.shop)
print(player.all_units)
print(player.board)
print(player.bench)

995
Shop: ['<1★ Morfus, type:Tank (2-cost) - 909 HP, 56 damage>', '<1★ Hertrude, type:Ranger (1-cost) - 400 HP, 105 damage>', '<1★ Flamy, type:Ranger (1-cost) - 400 HP, 105 damage>', '<1★ Bravos, type:Tank (1-cost) - 700 HP, 49 damage>', 'Empty']
[<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>, <1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>]
[[None, None, None, None, None, None, None, None], [None, None, None, None, None, None, None, None], [None, None, None, None, None, None, None, None], [None, None, <1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>, None, None, None, None, None]]
[<1★ Looney, type:Healer (1-cost) - 450 HP, 63 damage>, None, None, None, None, None, None, None]


### **Environment**

In [None]:
class TFTEnv(gym.Env):
    def __init__(self):
        super().__init__()

        # Player initialization.
        self.player1 = player1
        self.player2 = player2
        self.current_player = self.player1

        # Maximum number of steps before fight.
        self.max_steps_per_round = 50
        self.steps_this_round = 0

        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(10),  # 0-4: buy a unit, 5: sell unit, 6: reroll shop, 7: level up, 8: move unit, 9: end turn.
            "from_cell": spaces.Discrete(40),   # Only matters if action_type is 5 or 8.
            "to_cell": spaces.Discrete(40),     # Only matters if action_type is 8.
        })

        self.observation_space = spaces.Dict({
            "gold": spaces.Box(low=0, high=np.inf, shape=(), dtype=np.float32),
            "health": spaces.Discrete(101),
            "shop": spaces.MultiDiscrete([6] * 5),  # 0-5, 0 means no unit, 1-5 is cost.
            "bench": spaces.MultiDiscrete([6] * 8), # 8 slots on the bench.
            "board": spaces.MultiDiscrete([6] * (4 * 8)),   # 4x8 slots on the board.
        })

        self.done = False

    def reset(self, seed=None, options=None):
        self.player1 = Player("Our Newbie")
        self.player2 = Player("PRO GAMER 3000")
        self.current_player = self.player1
        self.steps_this_round = 0
        self.done = False

        observation = self.get_observation()
        return observation, {}

    def step(self, action):
        """
            Make an action.
        """
        if self.done:
            raise Exception("Game is over. Call reset().")

        # Count the reward
        reward = 0

        action_type = action["action_type"]

        # Actions:
        if action_type in range(5):  # Buy 1 out of 5 units from the shop.
            unit = self.current_player.shop.units_in_shop[action_type]
            if unit and self.current_player.gold >= unit.cost:
                success = self.current_player.buy_unit(action_type)  # Changed: pass action (index) instead of unit
                if success:
                    reward += 0.2
                else:
                    reward -= 0.3
        elif action_type == 5:  # Sell unit.
            from_cell = action["from_cell"]
            success = self.current_player.sell_unit_from_cell(from_cell)
            if success:
                reward += 0.1
            else:
                reward -= 0.2
        elif action_type == 6:  # Reroll the shop.
            success = self.current_player.refresh_shop()
            if success:
                reward += 0.3
            else:
                reward -= 0.4
        elif action_type == 7:  # Level up.
            success = self.current_player.level_up()
            if success:
                reward += 0.3
            else:
                reward -= 0.5
        elif action_type == 8: # Move unit.
            from_cell = action["from_cell"]
            to_cell = action["to_cell"]
            success = self.current_player.move_unit(from_cell, to_cell)
            if success:
                reward += 0.4
            else:
                reward -= 0.6
        elif action_type == 9:  # End player's turn.
            self.start_fight()
            if not self.current_player.won_last_fight:
                if self.steps_this_round <= 10:
                    reward -= 1.5
                elif self.steps_this_round <= 20:
                    reward -= 1.4
                elif self.steps_this_round <= 30:
                    reward -= 1.3
                elif self.steps_this_round <= 40:
                    reward -= 1.2
                else:
                    reward -= 1.1
            else:
                reward += 1.0
        
        # Count steps in a round.
        self.steps_this_round += 1
        # Check if the fight can start.
        if self.steps_this_round > self.max_steps_per_round:
            self.start_fight()
            if self.current_player.won_last_fight:
                reward += 1.0
            else:
                reward -= 2.0
        
        self.player1.gain_gold()
        self.player2.gain_gold()

        # Check the final health.
        if self.player1.hp <= 0 or self.player2.hp <= 0:
            self.done = True
            if self.player2.hp <= 0:
                print(f"Player {self.player1.name} won! Health left: {self.player1.hp}")
                reward += 5
            if self.player1.hp <= 0:
                print(f"Player {self.player2.name} won! Health left: {self.player2.hp}")
                reward -= 10

        observation = self.get_observation()
        return observation, reward, self.done, False, {}

    def start_fight(self):
        """
            Starts the fight.
            1. Creates the battlefield;
            2. Uses battle_step() function while both sides are "alive";
            3. Calculates who won.
        """
        battlefield = self.create_battlefield()
        
        # Fight continues while both players have living units.
        while True:
            player1_alive = any(player_id == 1 for (player_id, unit) in battlefield.values())
            player2_alive = any(player_id == 2 for (player_id, unit) in battlefield.values())
            
            # If one of the players lost all his units.
            if not player1_alive or not player2_alive:
                break
            
            # One step of a fight.
            battlefield = self.battle_step(battlefield)
        
        # Find the winner.
        # YOU CAN TURN PRINTS OFF!!!
        if player1_alive and not player2_alive:
            # Player 2 lost.
            self.player2.won_last_fight = False
            damage = sum(1 for (player_id, unit) in battlefield.values() if player_id == 1) * 2
            self.player2.hp -= damage
            
            # Player 1 won.
            self.player1.won_last_fight = True
            
            # print(f"{self.player2.name} takes {damage} damage!\n{self.player1.name}'s health: {self.player1.hp}\n{self.player2.name}'s health: {self.player2.hp}")
        elif player2_alive and not player1_alive:
            # Player 1 lost.
            self.player1.won_last_fight = False
            damage = sum(1 for (player_id, unit) in battlefield.values() if player_id == 2) * 2
            self.player1.hp -= damage
            
            # Player 2 won.
            self.player2.won_last_fight = True
            
            # print(f"{self.player1.name} takes {damage} damage!\n{self.player1.name}'s health: {self.player1.hp}\n{self.player2.name}'s health: {self.player2.hp}")
        else:
            # A draw. Both players lose 5 hp.
            damage = 5
            self.player1.hp -= damage
            self.player2.hp -= damage
            
            # print(f"A draw has occured! Both players lose 5 hp!\n{self.player1.name}'s health: {self.player1.hp}\n{self.player2.name}'s health: {self.player2.hp}")
        # print("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
    
    def create_battlefield(self):
        """
            Connects players' boards into one 8x8 battlefield by copying them and rotating player 2 board.
            Returns battlefield: dict {(row, col): (player_id, unit)}
        """
        battlefield = {}

        # Place units of the first player on the battlefield.
        for row in range(4):
            for col in range(8):
                unit = self.player1.board[row][col]
                if unit:
                    # Player 1 takes place of the lower half of the battlefield.
                    # The view on the battlefield will be the view of the player 1.
                    # Rows: 4-7
                    battlefield[(row + 4, col)] = (1, unit)

        # Place units of the second player on the battlefield, but rotate his board by 180.
        for row in range(4):
            for col in range(8):
                unit = self.player2.board[row][col]
                if unit:
                    # Rotation: (row, col) -> (3 - row, 7 - col).
                    new_row = 3 - row
                    new_col = 7 - col
                    # Player 2 takes place of the upper half of the battlefield.
                    # Rows: 0-3
                    battlefield[(new_row, new_col)] = (2, unit)

        # TURN THIS OFF IF YOU WANT
        # print(F"--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n{battlefield}")
        return battlefield
    
    def find_closest_enemy(self, my_pos, my_player_id, battlefield):
        """
            Finds the closest enemy using Mahattan's distance.
            Returns a tuple (enemy_pos, distance, enemy_unit).
            If there are no enemies - returns None.
        """
        my_row, my_col = my_pos
        _, my_unit = battlefield[my_pos]

        min_distance = float('inf')
        closest_enemy = None

        for (row, col), (player_id, unit) in battlefield.items():
            if player_id != my_player_id:  # Find the enemy using manhattans distance.
                distance = abs(my_row - row) + abs(my_col - col)

                if distance < min_distance:
                    min_distance = distance
                    closest_enemy = ((row, col), unit)

        if closest_enemy:
            enemy_pos, enemy_unit = closest_enemy
            return enemy_pos, min_distance, enemy_unit
        else:
            return None  # No enemies? (meme, hehehe)
    
    def move_towards(self, my_pos, target_pos):
        """
            Calculates the position where to move (1 step) towards the enemy.
            Moves only in one cell over the step.
        """
        my_row, my_col = my_pos
        target_row, target_col = target_pos

        if abs(my_row - target_row) > abs(my_col - target_col):
            # Move vertically.
            if my_row < target_row:
                return (my_row + 1, my_col)
            elif my_row > target_row:
                return (my_row - 1, my_col)
        else:
            # Move horizontally.
            if my_col < target_col:
                return (my_row, my_col + 1)
            elif my_col > target_col:
                return (my_row, my_col - 1)

        # Just in case unit somehow decides that he is inside the enemy.
        return my_pos
    
    def battle_step(self, battlefield):
        """
            One step of a fight.
            Every unit makes a decision: either attack or move towards the enemy.
        """
        intents = {}  # Tuple {my_pos: ("attack", enemy_pos) or ("move", new_pos)}.
        occupied_destinations = set()  # To avoid conflicts.

        # Collect the intentions of all units.
        for my_pos, (player_id, unit) in battlefield.items():
            # Find the closest enemy.
            result = self.find_closest_enemy(my_pos, player_id, battlefield)
            if result is None:
                continue  # There are no enemies left. Skip this unit.

            enemy_pos, distance, enemy_unit = result

            # If the enemy is within the attack range - unit wants to attack him.
            if distance <= unit.attack_range:
                intents[my_pos] = ("attack", enemy_pos)
            else:
                # Else - unit wants to move towards this enemy.
                new_pos = self.move_towards(my_pos, enemy_pos)
                intents[my_pos] = ("move", new_pos)

        # Next step is to perform the intents.
        new_battlefield = battlefield.copy()

        # First, we process the attacks (so that units do not attack empty places).
        for my_pos, action in intents.items():
            if action[0] == "attack":
                enemy_pos = action[1]
                if enemy_pos in new_battlefield:
                    _, enemy_unit = new_battlefield[enemy_pos]
                    _, my_unit = new_battlefield[my_pos]

                    enemy_unit.hp -= my_unit.damage

        # Next, we process the movement.
        for my_pos, action in intents.items():
            if action[0] == "move":
                new_pos = action[1]

                if new_pos not in new_battlefield and new_pos not in occupied_destinations:
                    # Unit moves to a new cell towards the enemy
                    new_battlefield[new_pos] = new_battlefield[my_pos]
                    del new_battlefield[my_pos]
                    occupied_destinations.add(new_pos)
                else:
                    # If this position has been just occupied, unfortunately, unit has no place to move.
                    # So we just skip his intention.
                    pass

        # Delete dead units.
        to_delete = []
        for pos, (player_id, unit) in new_battlefield.items():
            if unit.hp <= 0:
                to_delete.append(pos)

        for pos in to_delete:
            del new_battlefield[pos]

        # And return the new battlefield state. (YOU CAN TURN THIS OFF)
        # print(new_battlefield)
        return new_battlefield

    def get_observation(self):
        """
            Return current player's observation.
        """
        shop_obs = [0 if unit is None else unit.cost for unit in self.current_player.shop.units_in_shop]
        bench_obs = [0 if unit is None else unit.cost for unit in self.current_player.bench]
        board_obs = []

        for row in self.current_player.board:
            board_obs.extend([0 if unit is None else unit.cost for unit in row])
        
        return {
            "gold": float(self.current_player.gold),
            "hp": self.current_player.hp,
            "shop": np.array(shop_obs, dtype=np.int64),
            "bench": np.array(bench_obs, dtype=np.int64),
            "board": np.array(board_obs, dtype=np.int64),
        }


In [19]:
# Test on two players with units on each side.
# YOU CAN TURN OFF THE PRINTS IN CODE CELL ABOVE FOR TRAINING!!!
# Prints are in use for now to see how things work!!!
player1 = Player("Kaya")
player2 = Player("Alex")

player1.board[0][3] = Unit("Kayas's Teddy Bear", "Tank", 4, 1)
player1.board[0][4] = Unit("Kayas's Teddy Wolf", "Tank", 2, 1)
player1.board[3][0] = Unit("Kayas's Deadly Flower", "Ranger", 3, 2)

player2.board[0][4] = Unit("Alex's Plushy Robot", "Tank", 4, 1)
player2.board[2][6] = Unit("Alex's Plushy Doctor", "Healer", 3, 1)
player2.board[3][7] = Unit("Alex's Plushy Dragon", "Ranger", 2, 3)

env = TFTEnv()

env.player1 = player1
env.player2 = player2

env.start_fight()

# obs, _ = env.reset()

# done = False
# while not done:
#     action = env.action_space.sample()  # RANDOM ACTIONS (FOR NOW).
#     obs, reward, done, truncated, info = env.step(action)
#     print(f"Action: {action}")
#     for key, info in obs.items(): print(f"{key}: {info}")
#     print("---------------------------------------------------------------------------")


In [20]:
# TEST move_unit function
# Create a test player
test_player = Player("Test Player")

# Create some test units using your existing Unit class
unit1 = Unit("A", "Tank", 1)  # 1-cost Tank
unit2 = Unit("B", "Tank", 2)  # 2-cost Tank
unit3 = Unit("C", "Ranger", 3)  # 3-cost Ranger

# Place units
test_player.bench[0] = unit1  # on bench position 0
test_player.board[0][0] = unit2  # on board position (0,0) = cell 8
test_player.bench[5] = unit3  # on bench position 5

# Print initial state
print("Initial state:")
print(f"Bench[0]: {test_player.bench[0]}")
print(f"Board[0][0]: {test_player.board[0][0]}")
print(f"Bench[5]: {test_player.bench[5]}")

# Moving from bench to board
test_player.move_unit(0, 8)  # Move A from bench[0] to board[0][0]

# Print state after first move
print("\nAfter moving bench[0] to board[0][0]:")
print(f"Bench[0]: {test_player.bench[0]}")
print(f"Board[0][0]: {test_player.board[0][0]}")

# Test moving between bench positions
test_player.move_unit(5, 3)  # Move C from bench[5] to bench[3]

# Print state after second move
print("\nAfter moving bench[5] to bench[3]:")
print(f"Bench[3]: {test_player.bench[3]}")
print(f"Bench[5]: {test_player.bench[5]}")

# Moving from board to bench
test_player.move_unit(8, 0)  # Move A from board[0][0] back to bench[0]

# Print final state
print("\nAfter moving board[0][0] to bench[0]:")
print(f"Bench[0]: {test_player.bench[0]}")
print(f"Board[0][0]: {test_player.board[0][0]}")

Initial state:
Bench[0]: <1★ A, type:Tank (1-cost) - 700 HP, 49 damage>
Board[0][0]: <1★ B, type:Tank (2-cost) - 909 HP, 56 damage>
Bench[5]: <1★ C, type:Ranger (3-cost) - 640 HP, 135 damage>

After moving bench[0] to board[0][0]:
Bench[0]: <1★ B, type:Tank (2-cost) - 909 HP, 56 damage>
Board[0][0]: <1★ A, type:Tank (1-cost) - 700 HP, 49 damage>

After moving bench[5] to bench[3]:
Bench[3]: <1★ C, type:Ranger (3-cost) - 640 HP, 135 damage>
Bench[5]: None

After moving board[0][0] to bench[0]:
Bench[0]: <1★ A, type:Tank (1-cost) - 700 HP, 49 damage>
Board[0][0]: <1★ B, type:Tank (2-cost) - 909 HP, 56 damage>


### DQN

In [25]:
class DQN(nn.Module):
    def __init__(self, obs_space, action_space):
        super().__init__()
        
        # Dimensions from the observation space
        self.gold_dim = 1
        self.hp_dim = 1
        self.shop_dim = len(obs_space["shop"]) 
        self.bench_dim = len(obs_space["bench"])
        self.board_dim = len(obs_space["board"])
        
        # Action space dimensions
        self.action_type_dim = action_space["action_type"].n
        
        # Embedding layers
        self.shop_embedding = nn.Embedding(6, 8)
        self.bench_embedding = nn.Embedding(6, 8)
        self.board_embedding = nn.Embedding(6, 8)
        
        # Feature processing layers (3 linear)
        self.shop_linear = nn.Sequential(
            nn.Linear(self.shop_dim * 8, 64),
            nn.ReLU())
        
        self.bench_linear = nn.Sequential(
            nn.Linear(self.bench_dim * 8, 64),
            nn.ReLU())
        
        self.board_linear = nn.Sequential(
            nn.Linear(self.board_dim * 8, 128),
            nn.ReLU())
        
        # Feature combination layer
        self.combine = nn.Sequential(
            nn.Linear(self.gold_dim + self.hp_dim + 64 + 64 + 128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU())
        
        # Action outputs
        self.action_type_head = nn.Linear(128, self.action_type_dim)
        
    def forward(self, observation):
        # Extract features from observation dictionary
        gold = torch.tensor([[observation["gold"]]], dtype=torch.float32)
        hp = torch.tensor([[float(observation["hp"])]], dtype=torch.float32)
        shop = torch.tensor(observation["shop"], dtype=torch.long).unsqueeze(0)
        bench = torch.tensor(observation["bench"], dtype=torch.long).unsqueeze(0)
        board = torch.tensor(observation["board"], dtype=torch.long).unsqueeze(0)
        
        # Flatten embeddings
        shop_emb = self.shop_embedding(shop).flatten(1)
        bench_emb = self.bench_embedding(bench).flatten(1)
        board_emb = self.board_embedding(board).flatten(1)
        
        # Apply linear layers
        shop_features = self.shop_linear(shop_emb)
        bench_features = self.bench_linear(bench_emb)
        board_features = self.board_linear(board_emb)
        
        # Combine all features
        combined = torch.cat([gold, hp, shop_features, bench_features, board_features], dim=1)
        features = self.combine(combined)
        
        # Output action logits
        action_type_logits = self.action_type_head(features)
        
        return action_type_logits


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
        
    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
        
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    def __init__(self, env, buffer_size=10000, batch_size=64, gamma=0.99, lr=1e-4, eps_start=1.0, eps_end=0.1, eps_decay=0.995):
        # Hyperparameters
        self.env = env
        self.batch_size = batch_size
        self.gamma = gamma
        self.eps = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        
        # Initialize networks
        self.policy_net = DQN(env.observation_space, env.action_space)
        self.target_net = DQN(env.observation_space, env.action_space)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        
        # Initialize optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        
        # Initialize replay buffer
        self.replay_buffer = ReplayBuffer(buffer_size)
        
        # For tracking training progress
        self.steps_done = 0
        self.update_frequency = 1000  # Update target network every 1000 steps
        
        # Action types
        self.safe_actions = [0, 1, 2, 3, 4, 6, 7, 9]
        
    def select_action(self, state):
        """ Select an action using epsilon-greedy """
        if random.random() > self.eps:
            with torch.no_grad():
                # Get Q-values for action type
                action_type_q = self.policy_net(state)
                
                # Select highest Q-value action type
                action_type = torch.argmax(action_type_q, dim=1).item()
                
                return {
                    "action_type": action_type,
                    "from_cell": 0,  # Fixed value
                    "to_cell": 0     # Fixed value
                }
        else:
            # Random action from safe actions
            action_type = random.choice(self.safe_actions)
            return {
                "action_type": action_type,
                "from_cell": 0,  # Fixed value
                "to_cell": 0     # Fixed value
            }
    
    def update_epsilon(self):
        """ Decay epsilon over time """
        self.eps = max(self.eps_end, self.eps * self.eps_decay)
    
    def optimize_model(self):
        """ Perform optimization step """
        if len(self.replay_buffer) < self.batch_size:
            return
        
        # mini-batch from replay buffer
        transitions = self.replay_buffer.sample(self.batch_size)
        
        # Separate the batch components
        batch_states = []
        batch_actions = []
        batch_rewards = []
        batch_next_states = []
        batch_done = []
        
        for transition in transitions:
            state, action, reward, next_state, done = transition
            batch_states.append(state)
            batch_actions.append(action["action_type"])
            batch_rewards.append(reward)
            batch_next_states.append(next_state)
            batch_done.append(done)
        
        # Process states
        current_q_values = []
        target_q_values = []
        
        for i in range(self.batch_size):
            # Get current Q values
            current_q = self.policy_net(batch_states[i])
            action_idx = batch_actions[i]
            current_q_value = current_q[0, action_idx]
            current_q_values.append(current_q_value)
            
            # Get next state Q values
            if not batch_done[i]:
                with torch.no_grad():
                    next_q = self.target_net(batch_next_states[i])
                    next_q_value = next_q.max()
                    target = batch_rewards[i] + self.gamma * next_q_value
            else:
                target = batch_rewards[i]
                
            target_q_values.append(target)
        
        # Convert lists to tensors
        current_q_tensor = torch.stack(current_q_values)
        target_q_tensor = torch.tensor(target_q_values, dtype=torch.float32)  # Convert to tensor

        # Calculate loss
        loss = nn.MSELoss()(current_q_tensor, target_q_tensor)
        
        # Optimize, back propogate
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Update target network periodically
        self.steps_done += 1
        if self.steps_done % self.update_frequency == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
    
    def train(self, num_episodes, max_steps=1000):
        """ Train the agent for a number of episodes """
        rewards_history = []
        
        for episode in range(num_episodes):
            state, _ = self.env.reset()
            episode_reward = 0
            
            for step in range(max_steps):
                # Select and perform action
                action = self.select_action(state)
                
                # avoid actions that might cause errors
                # try:
                next_state, reward, done, truncated, _ = self.env.step(action)
                episode_reward += reward
                    
                # Add a custom reward for actions that don't error out
                reward += 0.1
                    
                # Store transition in replay buffer
                self.replay_buffer.add(state, action, reward, next_state, done)
                    
                # Move to next state
                state = next_state
                # except Exception as e:
                #     print(f"Error at episode {episode}, step {step}: {e}")
                #     reward = -1.0 # Negative reward if error
                    
                #     # Store transition with negative reward
                #     self.replay_buffer.add(state, action, reward, state, True)  # Use same state as next state
                    
                #     # Restart episode
                #     state, _ = self.env.reset()
                #     break
                
                # Perform optimization step
                self.optimize_model()
                
                if done or truncated:
                    break
            
            # Update epsilon after each episode
            self.update_epsilon()
            
            # Track rewards
            rewards_history.append(episode_reward)
            print(f"Episode {episode + 1}, Reward: {episode_reward}, Epsilon: {self.eps}")
        
        return rewards_history

player1 = Player("DQNAgent")
player2 = Player("Random Agent")

env = TFTEnv()

agent = DQNAgent(env)
rewards = agent.train(num_episodes=100)  # Start with fewer episodes for testing

Episode 1, Reward: -35.50000000000001, Epsilon: 0.995
Episode 2, Reward: -37.9, Epsilon: 0.990025
Episode 3, Reward: -38.1, Epsilon: 0.985074875
Episode 4, Reward: -35.50000000000001, Epsilon: 0.9801495006250001
Episode 5, Reward: -37.800000000000004, Epsilon: 0.9752487531218751
Episode 6, Reward: -35.1, Epsilon: 0.9703725093562657
Episode 7, Reward: -32.7, Epsilon: 0.9655206468094844
Episode 8, Reward: -40.099999999999994, Epsilon: 0.960693043575437
Episode 9, Reward: -32.6, Epsilon: 0.9558895783575597
Episode 10, Reward: -41.800000000000004, Epsilon: 0.9511101304657719
Episode 11, Reward: -33.4, Epsilon: 0.946354579813443
Episode 12, Reward: -40.1, Epsilon: 0.9416228069143757
We found a unit to upgrade!
Mary Healer 1 3
-----------------------------------------------------------------
Unit has been upgraded!
Episode 13, Reward: -34.2, Epsilon: 0.9369146928798039
Episode 14, Reward: -36.10000000000001, Epsilon: 0.9322301194154049
Episode 15, Reward: -32.20000000000001, Epsilon: 0.92756