In [28]:
import pandas as pd
import numpy as np
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
import glob 

from skimage.util.shape import view_as_windows
from ipywidgets import interact, interactive, fixed, interact_manual

In [2]:
from sklearn.cluster import KMeans
from sklearn import datasets

In [None]:
%%html
<style>

/* Let scrollable output areas expand to their full content height instead of
   being clipped with a scrollbar.
   NOTE(review): `output_scroll` is the classic Notebook output class name; this
   rule likely has no effect in JupyterLab / Notebook 7 — confirm target UI. */
div.output_scroll {
    height : auto;
}
</style>

In [3]:
# Interactive (nbagg) matplotlib backend so the heightmap figures can be zoomed/panned.
# NOTE(review): this backend targets the classic Notebook; in JupyterLab / Notebook 7
# the equivalent is `%matplotlib widget` (requires ipympl) — confirm target UI.
%matplotlib notebook

In [39]:
class FlatGroundSpawnStrategy():
    """Pick spawn positions on a heightmap where the surrounding ground is level.

    The heightmap image is loaded once and stored as a single-channel
    (grayscale) array in ``self.hm``. Calling the instance scans the map with
    square sliding windows, keeps the windows that pass the flatness test,
    reduces them to at most ``k`` representatives with K-means, plots them and
    returns them.
    """

    def __init__(self, hm_path):
        """Load the heightmap at ``hm_path`` and convert it to grayscale.

        Raises:
            OSError: if OpenCV cannot read ``hm_path``. ``cv2.imread`` signals
                failure by returning ``None`` rather than raising, which would
                otherwise surface as an opaque error inside ``cvtColor``.
        """
        image = cv2.imread(hm_path)
        if image is None:
            raise OSError(f'Could not read heightmap image: {hm_path}')
        self.hm = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def center_patch(self, patch):
        """Return ``patch`` as float32 with its centre value subtracted.

        Args:
            patch: 2-D array (one sliding window of the heightmap).

        Returns:
            A new float32 array whose centre element is 0. The input is not
            modified (``astype`` copies), which matters because windows from
            ``view_as_windows`` are views into ``self.hm``.
        """
        rows, cols = patch.shape
        centered = patch.astype(np.float32)
        # Scalar indexing yields a copy, so it is safe to subtract in place.
        centered -= centered[rows // 2, cols // 2]
        return centered

    def show_spawn_pos(self, positions, size):
        """Plot ``positions`` as red markers over a heatmap of the heightmap.

        Args:
            positions: iterable of ``(x, y)`` top-left window corners, as
                produced by ``find_spawn_points``.
            size: window side length; markers are drawn at the window centre.

        Returns:
            The matplotlib ``Axes`` the plot was drawn on.
        """
        _, ax = plt.subplots()
        sns.heatmap(self.hm, ax=ax)
        half = size // 2
        points = np.asarray(positions).reshape(-1, 2)
        # One plot call for all markers instead of one Line2D per point.
        ax.plot(points[:, 0] + half, points[:, 1] + half,
                marker='o', color='r', ls='', label='spawn')
        ax.set_title('Spawn positions')
        return ax

    def find_spawn_points(self, size=40, step=2, tol=1e-3):
        """Return top-left corners ``(x, y)`` of candidate spawn windows.

        The heightmap is scanned with ``size`` x ``size`` windows every
        ``step`` pixels. A window is kept when ``|mean(window - centre)| < tol``,
        with ``tol`` in heightmap intensity units.

        NOTE(review): this only checks that the centre height is close to the
        window's mean height. A window whose heights are spread around the
        centre value (e.g. a uniform slope) can also pass. A stricter flatness
        test would bound e.g. ``np.abs(patch).max()`` instead.

        Returns:
            List of ``(x, y)`` int tuples (``x`` = column, ``y`` = row) in
            row-major scan order.
        """
        positions = []
        windows = view_as_windows(self.hm, (size, size), step)
        for row_idx, row in enumerate(windows):
            for col_idx, patch in enumerate(row):
                if np.abs(self.center_patch(patch).mean()) < tol:
                    positions.append((col_idx * step, row_idx * step))
        return positions

    def reduce_positions_by_clustering(self, positions, k=100, random_state=None):
        """Reduce ``positions`` to at most ``k`` representatives via K-means.

        Args:
            positions: sequence of ``(x, y)`` points.
            k: maximum number of clusters; clipped to ``len(positions)``.
            random_state: forwarded to ``KMeans`` for reproducible clustering.
                ``None`` keeps sklearn's default seeding.

        Returns:
            One point per non-empty cluster: the member at the middle index of
            that cluster, with members kept in input order. This is always a
            real input point, never a centroid. Empty input returns ``[]``.
            Otherwise the fitted model is stored in ``self.estimator``.
        """
        if len(positions) == 0:
            # KMeans(n_clusters=0) raises, so short-circuit the empty case.
            return []
        points = np.asarray(positions)
        self.estimator = KMeans(n_clusters=min(k, len(points)), random_state=random_state)
        self.estimator.fit(points)

        new_positions = []
        for label in range(self.estimator.n_clusters):
            members = points[self.estimator.labels_ == label]
            # Guard against a label with no members (possible e.g. with duplicate points).
            if len(members):
                new_positions.append(members[len(members) // 2])
        return new_positions

    def __call__(self, k=100, size=40, *args, **kwargs):
        """Find, cluster and plot spawn positions; return the reduced set.

        Args:
            k: maximum number of spawn positions to return.
            size: sliding-window side length.
            *args, **kwargs: forwarded to ``find_spawn_points`` after ``size``
                (i.e. ``step`` and ``tol``).

        Returns:
            The list returned by ``reduce_positions_by_clustering``.
        """
        # Pass size positionally so extra positional args bind to step/tol.
        # Passing size= by keyword made any positional extra collide with it.
        positions = self.find_spawn_points(size, *args, **kwargs)
        print(f'{len(positions)} candidate spawn positions')
        new_positions = self.reduce_positions_by_clustering(positions, k=k)
        self.show_spawn_pos(new_positions, size)
        return new_positions

In [40]:
import os


def interact2strategy(hm_path, k=100, tol=1e-1):
    """Run FlatGroundSpawnStrategy on one heightmap and plot its spawn points.

    Used as the ``ipywidgets.interact`` callback: ``hm_path`` comes from the
    dropdown, while ``k`` and ``tol`` are pinned with ``fixed`` below so no
    extra sliders appear. Returns ``None`` so interact does not echo the
    position list.
    """
    strat = FlatGroundSpawnStrategy(hm_path)
    strat(k=k, tol=tol)


# Directory of training heightmaps; override via the MAPS_DIR environment variable.
MAPS_DIR = os.environ.get('MAPS_DIR', '/home/francesco/Documents/Master-Thesis/core/maps/train')
# glob order is filesystem-dependent; sort for a stable dropdown.
map_paths = sorted(glob.glob(os.path.join(MAPS_DIR, '*.png')))
if not map_paths:
    raise FileNotFoundError(f'No .png heightmaps found in {MAPS_DIR}')

interact(interact2strategy, hm_path=map_paths, k=fixed(100), tol=fixed(1e-1));

interactive(children=(Dropdown(description='hm_path', options=('/home/francesco/Documents/Master-Thesis/core/m…

<function __main__.interact2strategy(hm_path)>