In [2]:
# This cell imports stuff, and sets up a bot instance etc

import sys
import lzma
from s2clientprotocol.sc2api_pb2 import Response, ResponseObservation
from MapAnalyzer import MapData
import pickle
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("../src")
sys.path.append("../src/ares")
from sc2.position import Point2
from sc2.client import Client
from sc2.game_data import GameData
from sc2.game_info import GameInfo
from sc2.game_state import GameState
from unittest.mock import patch
from ares import AresBot
from sc2.bot_ai import BotAI

async def build_bot_object_from_pickle_data(raw_game_data, raw_game_info, raw_observation) -> AresBot:
    # Build fresh bot object, and load the pickled data into the bot object
    bot = BotAI()
    game_data = GameData(raw_game_data.data)
    game_info = GameInfo(raw_game_info.game_info)
    game_state = GameState(raw_observation)
    bot._initialize_variables()
    client = Client(True)
    
    bot._prepare_start(client=client, player_id=1, game_info=game_info, game_data=game_data)
    with patch.object(Client, "query_available_abilities_with_tag", return_value={}):
        await bot._prepare_step(state=game_state, proto_game_info=raw_game_info)
        bot._prepare_first_step()
        # await bot.register_managers()
    return bot

BERLINGRAD = "../tests/pickle_data/BerlingradAIE.xz"
with lzma.open(BERLINGRAD, "rb") as f:
    raw_game_data, raw_game_info, raw_observation = pickle.load(f)

# initiate a BotAI and MapAnalyzer instance
bot = await build_bot_object_from_pickle_data(raw_game_data, raw_game_info, raw_observation)
data = MapData(bot)

# common variables
grid = data.get_pyastar_grid()
position = bot.enemy_start_locations[0]
units = bot.all_units

%load_ext line_profiler
%load_ext Cython

  from .autonotebook import tqdm as notebook_tqdm
2023-06-03 13:20:42.524 | INFO     | MapAnalyzer.MapData:__init__:122 - dev Compiling Berlingrad AIE [32m
[32m Version dev Map Compilation Progress [37m: 0.4it [00:00,  1.76it/s]


# Cythonizing commonly used functions in a python-sc2 bot

* [Converting is_position_safe](#Python-version-of-is_position_safe) **6.57x speedup**
* [Alternative to python-sc2's units.closest_to](#Alternative-to-python-sc2's-units.closest_to) **6.7x speedup**
* [Speeding up `Units.center`](#Speeding-up-Units.center) **2.01x speedup**
* [Distance to / `unit.distance_to`](#Distance-to-/-unit.distance_to) **2.47x speedup**

# Converting `is_position_safe` to cython

## Python version of `is_position_safe`

In [3]:
def is_position_safe(
    grid: np.ndarray,
    position: Point2,
    weight_safety_limit: float = 1.0,
) -> bool:
    weight: float = grid[position.rounded]
    # np.inf check if drone is pathing near a spore crawler
    return weight == np.inf or weight <= weight_safety_limit

In [4]:
%timeit is_position_safe(grid, position)

4.04 µs ± 23.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [5]:
%lprun -f is_position_safe is_position_safe(grid, position)

Timer unit: 1e-07 s

Total time: 4.63e-05 s
File: C:\Users\Tom\AppData\Local\Temp\ipykernel_14304\2777119042.py
Function: is_position_safe at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def is_position_safe(
     2                                               grid: np.ndarray,
     3                                               position: Point2,
     4                                               weight_safety_limit: float = 1.0,
     5                                           ) -> bool:
     6         1        150.0    150.0     32.4      weight: float = grid[position.rounded]
     7                                               # np.inf check if drone is pathing near a spore crawler
     8         1        313.0    313.0     67.6      return weight == np.inf or weight <= weight_safety_limit

## Cython version of `is_position_safe`

In [6]:
%%cython
import numpy as np
cimport numpy as cnp
from cython cimport boundscheck, wraparound
@boundscheck(False)
@wraparound(False)
cpdef bint is_position_safe(
    cnp.ndarray[cnp.npy_float32, ndim=2] grid,
    (unsigned int, unsigned int) position,
    double weight_safety_limit = 1.0,
):
    cdef double weight = 0.0
    weight = grid[position[0], position[1]]
    # np.inf check if drone is pathing near a spore crawler
    return weight == np.inf or weight <= weight_safety_limit

In [7]:
%timeit is_position_safe(grid, position.rounded)

609 ns ± 2.83 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


# Alternative to `python-sc2`'s `units.closest_to`

In [8]:
units = bot.all_units
position = bot.enemy_start_locations[0]
unit = units[0]

In [9]:
# slower using closest_to(Point2)
%timeit units.closest_to(position)

197 µs ± 812 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
# this is faster since distance between all units is cached
%timeit units.closest_to(unit)

97.2 µs ± 448 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
%%cython
from cython cimport boundscheck, wraparound

cdef double euclidean_distance_squared(
        (float, float) p1,
        (float, float) p2
):
    return (p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2

@boundscheck(False)
@wraparound(False)
cpdef object closest_to((float, float) position, object units):
    cdef:
        object closest = units[0]
        double closest_dist = 999.9
        double dist = 0.0
        unsigned int len_units = len(units)
        (float, float) pos
        
    for i in range(len_units):
        unit = units[i]
        pos = unit.position
        dist = euclidean_distance_squared((pos[0], pos[1]), (position[0], position[1]))
        if dist < closest_dist:
            closest_dist = dist
            closest = unit
            
    return closest


In [12]:
%timeit closest_to(position, units)

14.5 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


# Speeding up `Units.center`

Similar to `units.center` in `python-sc2`, but tweaked to work in this notebook

In [32]:
from sc2.position import Point2
def center(units) -> Point2:
    """Returns the central position of all units."""
    assert units, f"Units object is empty"
    amount = units.amount
    return Point2(
        (
            sum(unit._proto.pos.x for unit in units) / amount,
            sum(unit._proto.pos.y for unit in units) / amount,
        )
    )

In [33]:
%timeit center(units)

107 µs ± 775 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Convert `Units.center` to cython

In [34]:
%%cython

cimport cython
from sc2.position import Point2

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef (double, double) cy_center(object units):
    """Returns the central position of all units."""
    cdef:
        unsigned int i = 0
        unsigned int num_units = len(units)
        double sum_x, sum_y = 0.0
        (double, double) position
        object unit

    for i in range(num_units):
        pos = units[i]._proto.pos
        position = (pos.x, pos.y)
        sum_x += position[0]
        sum_y += position[1]

    return (sum_x / num_units, sum_y / num_units)

In [36]:
%timeit Point2(cy_center(units))

53.2 µs ± 93.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Distance to / `unit.distance_to`

Check and profile distances between 2 units / 1 unit and a Point2 / two Point2's using `python-sc2` implementation of `distance_to`. 

Then see if cython does it faster

In [18]:
unit1 = bot.workers[0]
unit2 = bot.workers[4]
position1 = bot.game_info.map_center
position2 = bot.main_base_ramp.top_center

In [19]:
unit1.distance_to(unit2)

5.0

In [20]:
%timeit unit1.distance_to(unit2)

581 ns ± 11.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [21]:
unit1.distance_to(position1)

73.81395532011545

In [22]:
%timeit unit1.distance_to(position1)

988 ns ± 9.49 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [23]:
position1.distance_to(position2)

53.766904318548974

In [24]:
%timeit position1.distance_to(position2)

381 ns ± 3.13 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


## Convert `distance_to` to cython

In [25]:
%%cython

cpdef double cy_distance_to(
        (float, float) p1,
        (float, float) p2
):
    return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5

In [26]:
cy_distance_to(unit1.position, unit2.position)

5.0

In [27]:
%timeit cy_distance_to(unit1.position, unit2.position)

199 ns ± 2.34 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [28]:
cy_distance_to(unit1.position, position1)

73.81395532011545

In [29]:
%timeit cy_distance_to(unit1.position, position1)

175 ns ± 3.51 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [30]:
cy_distance_to(position1, position2)

53.766907769498424

In [31]:
%timeit cy_distance_to(position1, position2)

154 ns ± 0.476 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
