In [None]:
import nbimporter
from map import Map
import numpy as np
import pandas as pd
from organisms import Organism, Plant
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from IPython.display import clear_output

In [None]:
class World(object):
    """Container that tracks organisms per map tile and steps a simulation.

    Attributes:
        archive: ``defaultdict(list)`` mapping a tile key (the organism's
            ``location`` at insertion time) to the organisms stored there.
        map: the ``area`` object passed to the constructor. ``show`` calls
            ``area.show(ax)`` and reads ``area.map[tile]`` as an indexable
            pair of plot coordinates.
        tile_limit: population threshold above which organisms on a tile
            stop producing seeds, or ``None`` for no threshold.
        cemetery: organisms removed by ``simulate_death``, in order of death.
    """

    def __init__(self, area, organisms, tile_limit=None):
        """Place every organism on the tile named by its ``location``.

        Args:
            area: map object stored as ``self.map``.
            organisms: iterable of objects exposing a hashable ``location``.
            tile_limit: optional per-tile population threshold.
        """
        # ``list`` as the factory is equivalent to ``lambda: []`` but keeps
        # the mapping picklable and states the intent directly.
        self.archive = defaultdict(list)
        self.map = area
        self.tile_limit = tile_limit
        self.cemetery = []
        for organism in organisms:
            self.archive[organism.location].append(organism)

    def _has_room(self, population):
        """Return True when ``population`` is below ``tile_limit`` (or no limit is set)."""
        return self.tile_limit is None or len(population) < self.tile_limit

    def _settle(self, seeds):
        """Append each buffered seed list to the archive tile it is keyed by."""
        for tile, newcomers in seeds.items():
            self.archive[tile].extend(newcomers)

    def simulate(self, iterations):
        """Run reproduction-only steps, yielding each iteration index.

        Every organism's ``generate()`` is called once per iteration. A
        non-``None`` seed is kept only if the parent's tile is below
        ``tile_limit``; kept seeds are added to the tile given by
        ``seed.location`` after the whole pass, so they neither reproduce
        nor count toward the limit until the next iteration.

        NOTE(review): the limit is checked against the *parent's* tile and
        ignores seeds already buffered in the same pass, so a tile can end
        up above ``tile_limit``. Kept as-is; confirm this is intended.

        Args:
            iterations: number of steps to run.

        Yields:
            The zero-based index of the step that just completed.
        """
        for iteration in range(iterations):
            # Seeds are buffered so the archive is not mutated while it is
            # being iterated.
            seeds = defaultdict(list)
            for population in self.archive.values():
                for organism in population:
                    seed = organism.generate()
                    if seed is not None and self._has_room(population):
                        seeds[seed.location].append(seed)
            self._settle(seeds)
            yield iteration

    def simulate_death(self, iterations):
        """Run steps with natural death and reproduction, yielding each index.

        For each organism, ``natural_death(iteration)`` is consulted first.
        Organisms that die are appended to ``cemetery`` and removed from
        their tile at the end of the step; the others reproduce under the
        same rule as ``simulate``. Dead organisms remain in their tile for
        the duration of the pass, so they still count toward the tile limit
        during the step in which they die.

        Args:
            iterations: number of steps to run.

        Yields:
            The zero-based index of the step that just completed.
        """
        for iteration in range(iterations):
            seeds = defaultdict(list)
            died = []
            for tile, population in self.archive.items():
                for organism in population:
                    if organism.natural_death(iteration):
                        self.cemetery.append(organism)
                        # Remember the tile the organism is actually stored
                        # under; its ``location`` attribute may have drifted
                        # from the archive key.
                        died.append((tile, organism))
                    else:
                        seed = organism.generate()
                        if seed is not None and self._has_room(population):
                            seeds[seed.location].append(seed)
            self._settle(seeds)
            # Removal is deferred so the tile lists are not modified while
            # they are being iterated above.
            for tile, organism in died:
                self.archive[tile].remove(organism)
            yield iteration

    def show(self, ax):
        """Draw the map and one translucent marker per (tile, color) group.

        Marker area grows with the number of organisms sharing a color on a
        tile and is capped at 5000.

        Args:
            ax: plotting target passed to ``self.map.show`` and used for
                ``ax.scatter``.
        """
        self.map.show(ax)
        for tile, population in self.archive.items():
            if not population:
                continue
            position = self.map.map[tile]
            color_counts = Counter(organism.color for organism in population)
            for color, count in color_counts.most_common():
                # count * 20 * sigmoid(count): roughly 20 per organism once
                # count is more than a few, damped for very small groups.
                size = min(5000, count * (20 / (1 + np.exp(-count))))
                ax.scatter(position[0], position[1],
                           s=size, alpha=0.3, c=color)

    def stats(self):
        """Collect ``as_dict()`` of every living and dead organism.

        Living organisms come first (in archive order), followed by the
        cemetery.

        Returns:
            A ``pandas.DataFrame`` with one column per ``as_dict()`` key and
            one row per organism.
        """
        data = defaultdict(list)
        everyone = [organism
                    for population in self.archive.values()
                    for organism in population]
        everyone.extend(self.cemetery)
        for organism in everyone:
            for field, value in organism.as_dict().items():
                data[field].append(value)
        return pd.DataFrame(data)