# In [1]:
class Group():
    """Collection of valued 2-D points bucketed into a uniform grid of cells.

    Every point is stored as ``{(x, y): value}`` inside the dict of the grid
    cell that contains it, so neighbourhood queries only have to look at the
    few cells around the query position.

    The class relies on these module-level names being in scope: ``np``
    (NumPy), ``rd`` (provides ``random()``), ``plt`` (matplotlib pyplot) and
    ``dist(p, q)`` (distance between two points; assumed Euclidean -- confirm
    against its definition).
    """

    def __init__(self, dim=None, rows_cols=10):
        """Create an empty group.

        Args:
            dim: ``[width, height]`` of the area. Defaults to a fresh
                ``[100, 100]`` list for every instance.
            rows_cols: Number of cells along x and y, either one integer used
                for both axes or a two-item sequence.
        """
        # A None sentinel avoids sharing one default list between instances.
        self.dim = [100, 100] if dim is None else dim
        if isinstance(rows_cols, (int, np.integer)):
            rows_cols = [rows_cols, rows_cols]
        self.rows_cols = [int(rows_cols[0]), int(rows_cols[1])]
        self.data = self.new_data()
        # Cell sizes are derived from the normalised integer counts so they
        # always agree with the shape of ``self.data``.
        self.dx = self.dim[0] / self.rows_cols[0]
        self.dy = self.dim[1] / self.rows_cols[1]
        self.delete_at_0 = False
        self.count = 0

    def __len__(self):
        """Return the number of points tracked by ``self.count``."""
        return self.count

    def new_data(self):
        """Return an empty grid: an object array holding one dict per cell.

        The array has shape ``(rows_cols[0], rows_cols[1])`` so that it can be
        indexed as ``data[i, j]`` with ``i`` along x and ``j`` along y.
        """
        n_i, n_j = self.rows_cols
        grid = np.empty((n_i, n_j), dtype=object)
        for i in range(n_i):
            for j in range(n_j):
                grid[i, j] = {}  # a distinct dict per cell
        return grid

    def _cell(self, x, y):
        """Return the ``(i, j)`` indices of the cell that stores ``(x, y)``."""
        # The upper clamp keeps points lying exactly on the far edge inside
        # the last cell.
        i = min(int(x // self.dx), self.rows_cols[0] - 1)
        j = min(int(y // self.dy), self.rows_cols[1] - 1)
        return i, j

    def _store(self, x, y, value):
        """Write ``value`` at ``(x, y)``, counting the point only if it is new."""
        i, j = self._cell(x, y)
        cell = self.data[i, j]
        if (x, y) not in cell:
            self.count += 1
        cell[(x, y)] = value

    def add_point(self, x, y, value=0):
        """Store ``value`` at position ``(x, y)`` and return ``self``.

        Adding a position that already exists overwrites its value without
        changing the length of the group.
        """
        self._store(x, y, value)
        return self

    def fill(self, n=0, value=0, random=True):
        """Add ``n`` points at random positions inside ``dim``; return ``self``.

        ``random`` is accepted for interface compatibility and is not used.
        """
        # ``add_point`` does the counting; incrementing here as well would
        # count every point twice.
        for _ in range(n):
            x, y = rd.random() * self.dim[0], rd.random() * self.dim[1]
            self.add_point(x, y, value=value)
        return self

    def unroll(self):
        """Yield ``[(x, y), value]`` for every point, cell by cell."""
        for cell in self.data.flat:
            for key in cell:
                yield [key, cell[key]]

    def get_around(self, x0, y0, radius):
        """Yield ``[(x, y), value]`` for points within ``radius`` of ``(x0, y0)``.

        A point is kept when ``dist((x, y), (x0, y0)) <= radius``.
        """
        i0, j0 = int(x0 // self.dx), int(y0 // self.dy)
        # Round the reach up: with a floor, a neighbour sitting just across a
        # cell boundary is missed whenever radius is smaller than a cell.
        di = int(np.ceil(radius / self.dx))
        dj = int(np.ceil(radius / self.dy))
        i_range = range(max(i0 - di, 0), min(i0 + di + 1, self.rows_cols[0]))
        j_range = range(max(j0 - dj, 0), min(j0 + dj + 1, self.rows_cols[1]))
        for i in i_range:
            for j in j_range:
                cell = self.data[i, j]
                for key in cell:
                    if dist(key, (x0, y0)) <= radius:
                        yield [key, cell[key]]

    def plot(self, display=True, around=None):
        """Collect point coordinates and values, optionally drawing them.

        Args:
            display: When true, draw a scatter plot whose marker sizes are the
                point values and call ``plt.show()``.
            around: Optional ``(x0, y0, radius)``; when given, only the points
                returned by ``get_around`` are used and the centre is drawn in
                red with marker size ``radius``.

        Returns:
            Tuple ``(x, y, v)`` of coordinate and ``float`` value lists.
        """
        if around is None:
            points = self.unroll()
        else:
            points = self.get_around(around[0], around[1], around[2])
        x, y, v = [], [], []
        for (a, b), val in points:
            x.append(a)
            y.append(b)
            v.append(float(val))
        if display:
            fig = plt.figure(figsize=(8, 8))
            ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=1)
            # Use the real extent of the area rather than a fixed 100.
            ax.set_xlim(0, self.dim[0])
            ax.set_ylim(0, self.dim[1])
            plt.scatter(x, y, s=v)
            if around is not None:
                plt.scatter(around[0], around[1], s=around[2], c='r')
            plt.show()
        return (x, y, v)

    def collision(self, distance=10, other=None):
        """Yield ``(a, b, pos_a, pos_b)`` for pairs of points within ``distance``.

        With ``other`` omitted, the group is tested against itself: a point is
        never paired with itself and every unordered pair is yielded once.
        Points are identified by their position key, so equal values do not
        hide a collision.

        With ``other`` given, each point of ``self`` is paired with the nearby
        points of ``other`` whose value differs (``a != b``).
        """
        if other is None:
            # Ordered position pairs already yielded; a set keeps the reverse
            # lookup constant-time.
            seen = set()
            for pos, a in self.unroll():
                for pos_o, b in self.get_around(pos[0], pos[1], radius=distance):
                    if pos_o == pos or (pos_o, pos) in seen:
                        continue
                    seen.add((pos, pos_o))
                    yield a, b, pos, pos_o
        else:
            for pos, a in self.unroll():
                for pos_o, b in other.get_around(pos[0], pos[1], radius=distance):
                    if a != b:
                        yield a, b, pos, pos_o

    def change_value(self, pos, new_value):
        """Set the value stored at position ``pos`` and return ``self``.

        The cell is located with the same rule as ``add_point``. If ``pos`` is
        not stored yet it is added and counted.
        """
        self._store(pos[0], pos[1], new_value)
        return self

    def update_values(self, fix=None, random=True, do_if_neg=True):
        """Add an increment to the stored values and return ``self``.

        Args:
            fix: Amount to add. ``None`` means there is nothing to add, so the
                group is left unchanged.
            random: When true, each increment is ``fix * rd.random()``.
            do_if_neg: When false, only values greater than 0 are updated.
        """
        if fix is None:
            return self
        for cell in self.data.flat:
            # Only values are reassigned, so iterating the dict is safe.
            for key in cell:
                if do_if_neg or cell[key] > 0:
                    cell[key] += fix * rd.random() if random else fix
        return self

# In [2]:
class Population(Group):
    """A ``Group`` whose stored values are ``Animal`` objects.

    Each animal is keyed by its ``(x, y)`` position in the grid inherited from
    ``Group``. Relies on the module-level names ``Animal``, ``DNA_MAX``,
    ``rd``, ``np`` and ``plt``.
    """

    def __init__(self, dim=None, rows_cols=10):
        """Create an empty population.

        Args:
            dim: ``[width, height]`` of the area. Defaults to a fresh
                ``[100, 100]`` list for every instance.
            rows_cols: Number of grid cells per axis (an int or a pair).
        """
        # Resolve the default here so no list is shared between instances.
        dim = [100, 100] if dim is None else dim
        super(Population, self).__init__(dim=dim, rows_cols=rows_cols)

    def fill(self, n=0, random=True):
        """Add ``n`` animals at random positions and return ``self``.

        Each animal gets a random third pose component in ``[0, 360)`` and a
        random age in ``[0, DNA_MAX['Max_age']]``. ``random`` is accepted for
        interface compatibility and is not used.
        """
        for _ in range(n):
            x, y = rd.random() * self.dim[0], rd.random() * self.dim[1]
            animal = Animal(
                [x, y, rd.random() * 360],
                dim=self.dim,
                age=rd.randint(0, DNA_MAX['Max_age']),
            )
            self.add_point(x, y, value=animal)
        return self

    def plot(self, display=True, around=None):
        """Collect every animal's ``cin.arrow()`` and optionally draw them.

        Args:
            display: When true, draw one red arrow per animal with
                ``ax.quiver``.
            around: Unused; kept for signature compatibility with
                ``Group.plot``.

        Returns:
            ``np.array`` built from the ``cin.arrow()`` result of each animal.
            Columns 0 and 1 are used as the arrow position and column 2 is
            passed to ``np.cos`` / ``np.sin`` (presumably an angle in radians
            -- confirm against ``arrow()``).
        """
        d = np.array([entry[1].cin.arrow() for entry in self.unroll()])
        if display:
            fig = plt.figure(figsize=(8, 8))
            ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=1)
            # Use the real extent of the area rather than a fixed 100.
            ax.set_xlim(0, self.dim[0])
            ax.set_ylim(0, self.dim[1])
            # An empty population gives a 1-D array that cannot be sliced by
            # column, so there is nothing to draw in that case.
            if d.size:
                ax.quiver(d[:, 0], d[:, 1], np.cos(d[:, 2]), np.sin(d[:, 2]),
                          pivot='mid', color='r', scale=50)
        return d

    def update_pos(self):
        """Move every animal and rebuild the grid at the new positions.

        Calls ``update_pos()`` on each animal, which returns
        ``(x, y, x_new, y_new)``, and files the animal under
        ``(x_new, y_new)`` in a fresh grid. Returns ``self``.
        """
        new_data = self.new_data()
        last_i = self.rows_cols[0] - 1
        last_j = self.rows_cols[1] - 1
        for i in range(self.rows_cols[0]):
            for j in range(self.rows_cols[1]):
                for animal in list(self.data[i, j].values()):
                    _, _, x_new, y_new = animal.update_pos()
                    i_new = min(int(x_new // self.dx), last_i)
                    j_new = min(int(y_new // self.dy), last_j)
                    new_data[i_new, j_new][(x_new, y_new)] = animal
        self.data = new_data
        # Two animals landing on the same position share a key, so recount
        # what is actually stored to keep len() accurate.
        self.count = sum(len(cell) for cell in new_data.flat)
        return self

    def clean_dead(self):
        """Remove every animal whose ``alive`` flag is falsy; return ``self``."""
        for i in range(self.rows_cols[0]):
            for j in range(self.rows_cols[1]):
                cell = self.data[i, j]
                # Collect the keys first: a dict must not shrink while it is
                # being iterated.
                dead = [key for key, animal in cell.items() if not animal.alive]
                for key in dead:
                    del cell[key]
                    self.count -= 1
        return self