# World of Supply

A simulation environment for multi-echelon supply chain optimization problems. 

In [226]:
import numpy as np
from pprint import pprint
import random as rnd
import networkx as nx
from tqdm import tqdm_notebook as tqdm

from abc import ABC
from collections import Counter
from dataclasses import dataclass
from functools import lru_cache

# Core Simulation Logic

In [265]:
class WorldCell(ABC):
    """Base class for every object that occupies one square of the world grid."""

    def __init__(self, x, y):
        # x, y grid coordinates of this cell
        self.x, self.y = x, y

    def __repr__(self):
        # Rendered as "<ConcreteClassName> (x, y)".
        kind = type(self).__name__
        return f"{kind} ({self.x}, {self.y})"
    
class Agent:
    """Interface stub for controllable entities; both hooks are no-ops here."""

    def get_SAR(self):
        # Placeholder with no implementation yet; yields None.
        return None

    def set_action(self, action):
        # Placeholder: the supplied action is ignored.
        return None

# ======= Infrastructure 
class TerrainWorldCell(WorldCell):
    """Plain terrain square; World.is_traversable treats it as impassable."""

    def __init__(self, x, y):
        super().__init__(x, y)

class RailroadWorldCell(WorldCell):
    """Square carrying a piece of railroad track."""

    def __init__(self, x, y):
        super().__init__(x, y)

        
# ======= Transportation
class Transport(Agent):
    """A vehicle that shuttles one product from its source facility to a destination.

    A trip has two legs driven by ``step``: +1 while heading to the destination,
    -1 while returning, 0 when idle at the source. ``location_pointer`` indexes
    into ``path``, the list of (x, y) squares returned by the world's path finder.
    """

    @dataclass
    class Control:
        capacity: int  # load/unload that many units

    def __init__(self, source, product_id):
        self.product_id = product_id
        self.source = source
        self.destination = None
        self.path = None
        self.location_pointer = 0
        self.step = 0
        self.payload = 0 # units

    def set_destination(self, world, destination):
        """Start a new round trip from the source to ``destination``.

        Raises:
            Exception: if the world's path finder returns None for the route.
        """
        self.destination = destination
        self.path = world.find_railroad_path(self.source.x, self.source.y,
                                             self.destination.x, self.destination.y)
        if self.path is None:
            raise Exception(f"Destination {destination} is unreachable")
        self.step = 1    # 1 - to destination, -1 - to source, 0 - finished
        self.location_pointer = 0

    def path_len(self):
        """Return the number of squares on the current path (0 when no path is set)."""
        if self.path is None:
            return 0
        return len(self.path)

    def is_enroute(self):
        """Return True while a round trip is in progress."""
        return self.step != 0

    def current_location(self):
        """Return the (x, y) square the vehicle is on; the source square when no path is set."""
        if self.path is None:
            return (self.source.x, self.source.y)
        return self.path[self.location_pointer]

    def try_loading(self, capacity):
        """Take exactly ``capacity`` units from the source storage, if available."""
        if self.source.storage.try_take_units({ self.product_id: capacity }):
            self.payload = capacity

    def try_unloading(self, capacity):
        """Deposit the whole payload into the destination storage, if it fits.

        ``capacity`` is accepted for symmetry with ``try_loading`` but is not
        used: the payload is always unloaded in one piece.
        """
        if self.destination.storage.try_add_units({ self.product_id: self.payload }):
            self.payload = 0

    def act(self, epoch, control):
        """Advance the vehicle by one simulation step.

        Note that the two ``if`` blocks are deliberately not ``elif``: on the
        step the vehicle reaches the destination it switches to the return leg
        and attempts to unload within the same call.
        """
        if self.step > 0:
            if self.location_pointer == 0 and self.payload == 0:
                self.try_loading(control.capacity)

            if self.payload > 0:   # will stay at the source until loaded
                if self.location_pointer < len(self.path) - 1:
                    self.location_pointer += self.step
                else:
                    self.step = -1 # arrived at the destination

        if self.step < 0:
            if self.location_pointer == len(self.path) - 1 and self.payload > 0:
                self.try_unloading(control.capacity)

            if self.payload == 0: # will stay at the destination until unloaded
                if self.location_pointer > 0:
                    self.location_pointer += self.step
                else:
                    self.step = 0 # arrived back at the source

                    
# ======= Storage components
class Storage():
    """Capacity-limited stock of products, tracked per product id in a Counter."""

    def __init__(self, max_capacity, storage_cost_unit):
        self.stock_levels = Counter()
        self.storage_cost_unit = storage_cost_unit
        self.max_capacity = max_capacity

    def used_capacity(self):
        """Total number of units currently stored, across all products."""
        return sum(self.stock_levels.values())

    def available_capacity(self):
        """Units that can still be added before max_capacity is reached."""
        return self.max_capacity - self.used_capacity()

    def try_add_units(self, product_quantities):
        """Deposit {product_id: quantity}; all-or-nothing.

        Returns False (and stores nothing) when the combined quantity exceeds
        the free capacity, True otherwise.
        """
        incoming_total = sum(product_quantities.values())
        if incoming_total > self.available_capacity():
            return False
        self.stock_levels.update(product_quantities)
        return True

    def try_take_units(self, product_quantities):
        """Withdraw {product_id: quantity}; all-or-nothing.

        Returns False (and removes nothing) when any single product is short,
        True otherwise.
        """
        requested = product_quantities.items()
        if any(self.stock_levels[product] < amount for product, amount in requested):
            return False
        self.stock_levels.subtract(product_quantities)
        return True

    
# ======= Manufacturing facilities
@dataclass
class BillOfMaterials:
    """Recipe for one manufacturing cycle.

    A cycle consumes ``inputs`` and yields ``output_lot_size`` units of
    ``output_product_id``.
    """

    inputs: Counter  # (product_id -> quantity per lot)
    output_product_id: int
    output_lot_size: int = 1

    def input_units_per_lot(self):
        """Return the total number of input units consumed by one cycle."""
        return sum(quantity for quantity in self.inputs.values())

@dataclass
class ProducerSpecification:
    """Construction parameters consumed by ProducerWorldCell.__init__."""

    fleet_size: int                      # number of Transport vehicles to create
    destinations: list                   # facilities the producer can ship to
    bill_of_materials: BillOfMaterials   # recipe for the manufacturing cycle
    storage_max_capacity: int            # passed to Storage as max_capacity
    storage_cost_unit: int               # passed to Storage as storage_cost_unit
                
class ProducerWorldCell(WorldCell, Agent):
    """A facility that turns inputs into an output product and ships it by rail.

    Each step it runs up to ``production_rate`` manufacturing cycles against its
    own storage and then advances its fleet of Transport vehicles.
    """

    @dataclass
    class Control:
        production_rate: int         # lots per time step
        priority_destination: int    # index into the producer's destinations list
        transport_control: Transport.Control

    def __init__(self, x, y, world, spec):
        super().__init__(x, y)
        self.world = world
        output_product = spec.bill_of_materials.output_product_id
        self.fleet = [Transport(self, output_product) for _ in range(spec.fleet_size)]
        self.destinations = spec.destinations
        self.storage = Storage(spec.storage_max_capacity, spec.storage_cost_unit)
        self.bom = spec.bill_of_materials

    def act(self, epoch, control):
        """Run one simulation step: manufacture, then move or dispatch vehicles."""
        # manufacturing cycle
        for _ in range(control.production_rate):
            # check we have enough storage space for the output lot
            # (consumed inputs free up space, hence the subtraction)
            net_space_needed = self.bom.output_lot_size - self.bom.input_units_per_lot()
            if self.storage.available_capacity() >= net_space_needed:
                # check we have enough input materials
                if self.storage.try_take_units(self.bom.inputs):
                    self.storage.stock_levels[self.bom.output_product_id] += self.bom.output_lot_size

        # distribution cycle
        dispatched = False
        for vehicle in self.fleet:
            # can schedule up to one vehicle per time step
            if not dispatched and not vehicle.is_enroute():
                # Use the world this producer belongs to, not a module-level global.
                vehicle.set_destination(self.world, self.destinations[control.priority_destination])
                dispatched = True
            # Every vehicle gets its step, so en-route vehicles that come after
            # the dispatched one are not stalled; act() is a no-op for idle ones.
            vehicle.act(epoch, control.transport_control)
            
            
class FactoryWorldCell(ProducerWorldCell):
    """Common base for manufacturing facilities; adds nothing to ProducerWorldCell."""

    def __init__(self, x, y, world, spec):
        super().__init__(x, y, world, spec)

class SteelFactoryWorldCell(FactoryWorldCell):
    """Marker subclass for steel producers (drawn as 'S' by the ASCII renderer)."""

    def __init__(self, x, y, world, spec):
        super().__init__(x, y, world, spec)

class LamberFactoryWorldCell(FactoryWorldCell):
    """Marker subclass for 'lamber' producers (drawn as 'L' by the ASCII renderer)."""

    def __init__(self, x, y, world, spec):
        super().__init__(x, y, world, spec)

class ToyFactoryWorldCell(FactoryWorldCell):
    """Marker subclass for toy factories (drawn as 'T' by the ASCII renderer)."""

    def __init__(self, x, y, world, spec):
        super().__init__(x, y, world, spec)

class WarehouseWorldCell(ProducerWorldCell):
    """Marker subclass for warehouses (drawn as 'W' by the ASCII renderer)."""

    def __init__(self, x, y, world, spec):
        super().__init__(x, y, world, spec)

        
# ======= End consumer facilities        
@dataclass
class EndConsumerSpecification:
    """Construction parameters consumed by RetailerWorldCell.__init__."""

    storage_max_capacity: int   # passed to Storage as max_capacity
    storage_cost_unit: int      # passed to Storage as storage_cost_unit
        
class RetailerWorldCell(WorldCell, Agent):
    """End-consumer facility: a grid cell with its own storage and no active behaviour."""

    @dataclass
    class Control:
        # Retailers currently take no control inputs.
        pass

    def __init__(self, x, y, spec):
        super().__init__(x, y)
        capacity = spec.storage_max_capacity
        unit_cost = spec.storage_cost_unit
        self.storage = Storage(capacity, unit_cost)

    def act(self, epoch, control):
        # Nothing happens at a retailer on a simulation step.
        return None

    
# ======= The World
class World:
    """The simulation grid together with the facilities that act on it.

    ``grid`` is indexed as ``grid[x][y]`` and is populated by the caller
    (it is ``None`` right after construction).
    """

    @dataclass
    class Control:
        retailer_controls: list
        warehouse_controls: list
        factory_controls: list
        raw_material_controls: list

    def __init__(self, x, y):
        self.size_x = x
        self.size_y = y
        self.grid = None

        self.retailers = []
        self.warehouses = []
        self.factories = []
        self.raw_materials = []

        # (x1, y1, x2, y2) -> path; kept per instance so that worlds can be
        # garbage-collected and so that the cache can be dropped on grid edits.
        self._path_cache = {}

    def create_cell(self, x, y, clazz):
        """Instantiate ``clazz(x, y)`` and put it on the grid at (x, y)."""
        self.grid[x][y] = clazz(x, y)
        self._path_cache.clear()  # cached routes may no longer be valid

    def place_cell(self, *cells):
        """Put already-constructed cells on the grid at their own coordinates."""
        for c in cells:
            self.grid[c.x][c.y] = c
        self._path_cache.clear()  # cached routes may no longer be valid

    def act(self, epoch, control):
        """Advance every facility by one step, upstream first.

        Actors are paired positionally with their controls; order is raw
        materials, factories, warehouses, retailers.
        """
        stages = (
            (self.raw_materials, control.raw_material_controls),
            (self.factories, control.factory_controls),
            (self.warehouses, control.warehouse_controls),
            (self.retailers, control.retailer_controls),
        )
        for actors, controls in stages:
            for actor, ctrl in zip(actors, controls):
                actor.act(epoch, ctrl)

    def is_railroad(self, x, y):
        """Return True if the cell at (x, y) is a RailroadWorldCell."""
        return isinstance(self.grid[x][y], RailroadWorldCell)

    def is_traversable(self, x, y):
        """Return True if the cell at (x, y) is anything but terrain."""
        return not isinstance(self.grid[x][y], TerrainWorldCell)

    @staticmethod
    def c_tostring(x, y):
        """Encode a coordinate pair as bytes, used as a hashable graph node id."""
        # ndarray.tostring() is a deprecated alias of tobytes() and is gone in NumPy 2.
        return np.array([x, y]).tobytes()

    def map_to_graph(self):
        """Build an undirected graph linking 4-adjacent traversable cells.

        Only interior cells are used as edge origins, which keeps the neighbour
        lookups inside the grid.
        """
        g = nx.Graph()
        for x in range(1, self.size_x-1):
            for y in range(1, self.size_y-1):
                if not self.is_traversable(x, y):
                    continue
                here = World.c_tostring(x, y)
                for cx, cy in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
                    if self.is_traversable(cx, cy):
                        g.add_edge(here, World.c_tostring(cx, cy))
        return g

    def find_railroad_path(self, x1, y1, x2, y2):
        """Return the path from (x1, y1) to (x2, y2) as a list of (x, y) tuples.

        Returns None when the target cannot be reached, which is what
        Transport.set_destination checks for. Results are cached per world to
        speed up the simulation; the same list object is handed out on repeated
        calls, so callers must treat it as read-only.
        """
        key = (x1, y1, x2, y2)
        if key in self._path_cache:
            return self._path_cache[key]

        g = self.map_to_graph()
        try:
            path = nx.astar_path(g, source=World.c_tostring(x1, y1), target=World.c_tostring(x2, y2))
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            route = None
        else:
            # dtype=int mirrors the default integer dtype used by c_tostring
            decoded = (np.frombuffer(p, dtype=int) for p in path)
            route = [(int(p[0]), int(p[1])) for p in decoded]

        self._path_cache[key] = route
        return route
    
    
class WorldBuilder:
    """Factory helpers that assemble a demo World and lay railroads between facilities."""

    @staticmethod
    def create(x, y):
        """Build an x-by-y world with retailers, warehouses, toy factories and raw material sites.

        Facility columns are fixed (70, 50, 35 and 10), so ``x`` must be large
        enough to contain them.
        """
        world = World(x, y)
        world.grid = [[TerrainWorldCell(xi, yi) for yi in range(y)] for xi in range(x)]

        # parameters
        def producer_spec(destinations, bom):
            return ProducerSpecification(fleet_size = 1,
                                         destinations = destinations,
                                         bill_of_materials = bom,
                                         storage_max_capacity = 10,
                                         storage_cost_unit = 1)
        def consumer_spec():
            return EndConsumerSpecification(storage_max_capacity = 10, storage_cost_unit = 3)

        distribution_bom = BillOfMaterials(Counter({'toy_car': 1}), 'toy_car' )
        toy_bom = BillOfMaterials(Counter({'lamber': 1, 'steel': 1}), 'toy_car')
        steel_bom = BillOfMaterials(Counter(), 'steel', 1)
        lamber_bom = BillOfMaterials(Counter(), 'lamber', 1)

        map_margin = 2
        size_y_margins = world.size_y - 2*map_margin

        def spread_y(i, n):
            # row of the i-th of n facilities spread evenly between the margins
            return int(size_y_margins/(n - 1)*i + map_margin)

        # final consumers
        n_retailers = 3
        world.retailers = [RetailerWorldCell(70, spread_y(i, n_retailers), consumer_spec())
                           for i in range(n_retailers)]
        world.place_cell(*world.retailers)

        # distribution
        n_warehouses = 2
        world.warehouses = []
        for i in range(n_warehouses):
            w = WarehouseWorldCell(50, spread_y(i, n_warehouses), world,
                                   producer_spec(world.retailers, distribution_bom))
            world.warehouses.append(w)
            world.place_cell(w)
            WorldBuilder.connect_cells(world, w, *world.retailers)

        # manufacturing
        n_toy_factories = 3
        world.factories = []
        for i in range(n_toy_factories):
            f = ToyFactoryWorldCell(35, spread_y(i, n_toy_factories), world,
                                    producer_spec(world.warehouses, toy_bom))
            world.factories.append(f)
            world.place_cell(f)
            WorldBuilder.connect_cells(world, f, *world.warehouses)

        # raw materials
        steel_01 = SteelFactoryWorldCell(10, 6, world, producer_spec(world.factories, steel_bom) )
        lamber_01 = LamberFactoryWorldCell(10, 10, world, producer_spec(world.factories, lamber_bom) )

        world.raw_materials = [steel_01, lamber_01]
        world.place_cell(*world.raw_materials)
        WorldBuilder.connect_cells(world, steel_01, *world.factories)
        WorldBuilder.connect_cells(world, lamber_01, *world.factories)

        return world

    @staticmethod
    def connect_cells(world, source, *destinations):
        """Lay a railroad from ``source`` to each of ``destinations``."""
        for dest_cell in destinations:
            WorldBuilder.build_railroad(world, source.x, source.y, dest_cell.x, dest_cell.y)

    @staticmethod
    def build_railroad(world, x1, y1, x2, y2):
        """Lay track between (x1, y1) and (x2, y2), end points excluded.

        The general route is horizontal along y1, vertical at a randomly chosen
        column, then horizontal along y2. End points on the same row or column
        get a single straight segment.
        """
        step_x = int(np.sign(x2 - x1))
        step_y = int(np.sign(y2 - y1))

        if step_x == 0:
            # Same column: range() rejects a zero step, so lay a straight
            # vertical track instead (nothing to do if both ends coincide).
            if step_y != 0:
                for y in range(y1 + step_y, y2, step_y):
                    world.create_cell(x1, y, RailroadWorldCell)
            return

        if step_y == 0:
            # Same row: go all the way to the destination; stopping at a
            # bend column would leave the two facilities unconnected.
            for x in range(x1 + step_x, x2, step_x):
                world.create_cell(x, y1, RailroadWorldCell)
            return

        # make several attempts to find a route non-adjacent to existing roads
        for _ in range(5):
            xi = min(x1, x2) + int(abs(x2 - x1) * rnd.uniform(0.1, 0.9))
            if not (world.is_railroad(xi-1, y1) or world.is_railroad(xi+1, y1)):
                break

        for x in range(x1 + step_x, xi, step_x):
            world.create_cell(x, y1, RailroadWorldCell)
        for y in range(y1, y2, step_y):
            # on short routes the bend can fall on the source column;
            # never replace the source facility with track
            if (xi, y) != (x1, y1):
                world.create_cell(xi, y, RailroadWorldCell)
        for x in range(xi, x2, step_x):
            world.create_cell(x, y2, RailroadWorldCell)


#  ======= Baseline control policies               
class SimpleControlPolicy:
    """Baseline policy: fixed production rate and transport capacity for every producer."""

    def __init__(self, destination_optimization = 'const'):
        # destination optimization:
        #   'const' - always choose the first destination from the list
        #   'min_stock' - choose a destination with the minimal stock of the produced item
        self.destination_optimization = destination_optimization

    def get_control(self, epoch, world):
        """Return a World.Control holding one control record per facility in ``world``."""
        # one Transport.Control instance is shared by all producers
        transport_control = Transport.Control(capacity = 5)

        def facility_control(source):
            return ProducerWorldCell.Control(production_rate = 10,
                                             priority_destination = self.optimize_destination(source),
                                             transport_control = transport_control)

        return World.Control(
            retailer_controls = [ RetailerWorldCell.Control() ] * len(world.retailers),
            warehouse_controls = [ facility_control(f) for f in world.warehouses ],
            factory_controls = [ facility_control(f) for f in world.factories ],
            raw_material_controls = [ facility_control(f) for f in world.raw_materials ],
        )

    def optimize_destination(self, source):
        """Return the index (into ``source.destinations``) of the destination to serve.

        'const' always yields 0. 'min_stock' yields the destination whose
        storage holds the fewest units of the product ``source`` makes (the
        first one on ties). Any other strategy value falls back to 0.
        """
        if self.destination_optimization != 'min_stock':
            return 0

        destinations = source.destinations
        if not destinations:
            return 0
        product_id = source.bom.output_product_id

        def stock_at(index):
            return destinations[index].storage.stock_levels.get(product_id, 0)

        return min(range(len(destinations)), key=stock_at)
        

# ======= Measure the simulation rate, steps/sec
# Build an 80x16 world and step it 100000 times under the baseline policy.
# The loop is wrapped in tqdm so the progress bar shows how fast steps run.
# NOTE: `world` and `policy` are notebook globals that later cells may read.
world = WorldBuilder.create(80, 16)
policy = SimpleControlPolicy()
for i in tqdm(range(100000)):
    # a fresh control record is requested from the policy on every step
    world.act(i, policy.get_control(i, world))

HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))






# Simple ASCII Renderer

In [268]:
import PIL
from PIL import Image, ImageFont, ImageDraw
import numpy as np
from IPython.display import Image
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, HTML
import yaml
from multipledispatch import dispatch

class Utils:
    """Small formatting helpers."""

    @staticmethod
    def ascii_progress_bar(done, limit, bar_lenght_char = 15):
        """Render ``done`` out of ``limit`` as a fixed-width text bar.

        The result is ``bar_lenght_char`` characters of '=' (done) and '-'
        (remaining) followed by " done/limit". The filled part is clamped to
        the bar width, so the bar keeps its length even when ``done`` exceeds
        ``limit`` or the ratio is negative. A ``limit`` of 0 gives an empty bar.
        """
        if limit == 0:
            done_chars = 0
        else:
            done_chars = round(done/limit*bar_lenght_char)
        done_chars = max(0, min(bar_lenght_char, done_chars))
        return '=' * done_chars + '-' * (bar_lenght_char - done_chars) + f" {done}/{limit}"

class WorldRenderer:
    """Base class for renderers; provides notebook animation playback."""

    @staticmethod
    def plot_sequence_images(image_array):
        ''' Display images sequence as an animation in jupyter notebook

        Args:
        image_array(numpy.ndarray): image_array.shape equal to (num_images, height, width, num_channels)
        '''
        dpi = 72.0
        height_px, width_px = image_array[0].shape[:2]
        fig = plt.figure(figsize=(width_px/dpi, height_px/dpi), dpi=dpi)
        im = plt.figimage(image_array[0])

        def animate(i):
            im.set_array(image_array[i])
            return (im,)

        anim = animation.FuncAnimation(fig, animate, frames=len(image_array), interval=200, repeat_delay=1, repeat=True)
        display(HTML(anim.to_html5_video()))
        # The video has been embedded; close the figure so it is neither kept
        # alive by pyplot nor echoed as an empty figure in the cell output.
        plt.close(fig)

class AsciiWorldStatusPrinter():
    # Builds nested lists of strings describing the state of a world and its
    # facilities. All overloads are named `status` and are registered through
    # multipledispatch's @dispatch, keyed on the type of the single argument.
    # They are called through the class (AsciiWorldStatusPrinter.status(obj))
    # and take no `self`.
    
    @dispatch(World)
    def status(world: World) -> list:
        # One status entry per facility, in the order: retailers, warehouses,
        # factories, raw material sites.
        status = []
        for r in world.retailers:
            status.append(AsciiWorldStatusPrinter.status(r))    
        for w in world.warehouses:
            status.append(AsciiWorldStatusPrinter.status(w))    
        for f in world.factories:
            status.append(AsciiWorldStatusPrinter.status(f)) 
        for r in world.raw_materials:
            status.append(AsciiWorldStatusPrinter.status(r))
            
        return status  
    
    def cell_status(cell):
        # Header line shared by all facility reports: "<ClassName> (x, y)".
        return [f"{cell.__class__.__name__} ({cell.x}, {cell.y})"]
    
    @dispatch(RetailerWorldCell)
    def status(retailer: RetailerWorldCell) -> list:
        # Header followed by a one-element list holding the storage report.
        status = AsciiWorldStatusPrinter.cell_status(retailer)
        storage_status = ["Storage:", AsciiWorldStatusPrinter.status(retailer.storage) ]
        status.append([storage_status])
        return status
    
    @dispatch(ProducerWorldCell)
    def status(producer: ProducerWorldCell) -> list:
        # Header followed by [fleet report, storage report]. Each vehicle is
        # shown as a progress bar of its position along its path
        # (location_pointer out of path_len()-1) plus its payload.
        status = AsciiWorldStatusPrinter.cell_status(producer)
        fleet_status = ["Fleet:", 
                        [f"{v.__class__.__name__} {Utils.ascii_progress_bar(v.location_pointer, v.path_len()-1)}, payload: {v.payload}" 
                         for v in producer.fleet]
                       ]
        storage_status = ["Storage:", AsciiWorldStatusPrinter.status(producer.storage) ]
        status.append([fleet_status, storage_status])
        return status
    
    @dispatch(Storage)
    def status(storage: Storage) -> list:
        # Three lines: capacity usage bar, cost per unit, and the raw inventory
        # converted from a Counter to a plain dict.
        return [f"Usage: {Utils.ascii_progress_bar(storage.used_capacity(), storage.max_capacity)}",
                f"Storage cost/unit: {storage.storage_cost_unit}",
                f"Inventory: {dict(storage.stock_levels)}"]

    
class AsciiWorldRenderer(WorldRenderer):
    """Renders a World as a PIL image: a box-drawing ASCII map plus a textual status panel."""

    def render(self, world):
        """Draw the current state of ``world`` and return it as a PIL image.

        Reads fonts and the logo from the ``resources/`` directory.
        """
        ascii_canvas = []

        # print infrastructure (background)
        for y in range(world.size_y):
            row = []
            for x in range(world.size_x):
                c = world.grid[x][y]
                if isinstance(c, RailroadWorldCell):
                    row.append(self.railroad_sprite(x, y, world.grid))
                else:
                    row.append(' ')
            ascii_canvas.append(row)

        # print vehicles
        for y in range(world.size_y):
            for x in range(world.size_x):
                c = world.grid[x][y]
                if isinstance(c, ProducerWorldCell):
                    for vehicle in c.fleet:
                        if vehicle.is_enroute():
                            loc_x, loc_y = vehicle.current_location()
                            ascii_canvas[loc_y][loc_x] = '*'

        # print facilities (foreground); checked in order, the last match wins
        facility_glyphs = (
            (SteelFactoryWorldCell, 'S'),
            (LamberFactoryWorldCell, 'L'),
            (ToyFactoryWorldCell, 'T'),
            (WarehouseWorldCell, 'W'),
            (RetailerWorldCell, 'R'),
        )
        for y in range(world.size_y):
            for x in range(world.size_x):
                c = world.grid[x][y]
                for facility_class, glyph in facility_glyphs:
                    if isinstance(c, facility_class):
                        ascii_canvas[y][x] = glyph

        # print ascii on canvas
        margin_px = 10
        text = "\n".join(''.join(row) for row in ascii_canvas)
        font = ImageFont.truetype("resources/FiraMono-Bold.ttf", 24)
        if hasattr(font, "getbbox"):
            # getsize() was removed in Pillow 10; the right/bottom edges of
            # the bounding box are the same numbers getsize() used to return
            _, _, font_x, font_y = font.getbbox('╬')
        else:
            (font_x, font_y) = font.getsize('╬')
        img_w = font_x * world.size_x + 2 * margin_px
        img_h = font_y * world.size_y * 2

        img = PIL.Image.new('RGB', (img_w, img_h), color='#1F2605')
        canvas = ImageDraw.Draw(img)
        canvas.text((margin_px, margin_px), text, font=font, fill='#D6Ce15')

        # print logo
        with PIL.Image.open('resources/world-of-supply-logo.png', 'r') as logo_file:
            logo = logo_file.convert("RGBA")
        # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10)
        logo.thumbnail((img_w // 5, img_h // 10), PIL.Image.LANCZOS)
        img.paste(logo, (int(img_w/2 - img_w/10), 0), mask=logo)

        # print status
        font = ImageFont.truetype("resources/FiraMono-Regular.ttf", 12)
        status = AsciiWorldStatusPrinter.status(world)
        n_row = 5  # facilities per column
        col_wide = 30
        for i_col, start in enumerate(range(0, len(status), n_row)):
            canvas.text((50 + font_x * col_wide * i_col, font_y * world.size_y * 0.9),
                        yaml.dump(status[start : start+n_row], default_style=None), font=font, fill='#BBBBBB')

        return img

    def railroad_sprite(self, x, y, grid):
        """Pick the box-drawing character for the rail cell at (x, y).

        The choice depends on which of the four neighbours in ``grid`` are
        rail cells; neighbours outside the grid count as "no rail". A rail
        cell with no rail neighbours is drawn as a horizontal piece.
        """
        def is_rail(cx, cy):
            if cx < 0 or cy < 0 or cx >= len(grid) or cy >= len(grid[cx]):
                return False
            return isinstance(grid[cx][cy], RailroadWorldCell)

        top = is_rail(x, y-1)
        bottom = is_rail(x, y+1)
        left = is_rail(x-1, y)
        right = is_rail(x+1, y)

        # Sprites: ╔╗╚╝╠╣╦╩╬═║
        # key: (top, bottom, left, right)
        sprites = {
            (True,  False, False, False): '║',
            (False, True,  False, False): '║',
            (True,  True,  False, False): '║',
            (False, False, True,  False): '═',
            (False, False, False, True ): '═',
            (False, False, True,  True ): '═',
            (True,  False, False, True ): '╚',
            (True,  False, True,  False): '╝',
            (False, True,  False, True ): '╔',
            (False, True,  True,  False): '╗',
            (True,  True,  True,  False): '╣',
            (True,  True,  False, True ): '╠',
            (True,  False, True,  True ): '╩',
            (False, True,  True,  True ): '╦',
            (True,  True,  True,  True ): '╬',
        }
        # An isolated rail cell has no entry; return a string so that
        # ''.join() on the row does not fail.
        return sprites.get((top, bottom, left, right), '═')
        
# Render 100 epochs of a fresh 80x16 world under the baseline policy and play
# the frames back as an animation. Each frame is captured before the world is
# stepped, so frame N shows the state at the start of epoch N.
# NOTE: these are notebook globals; `world` in particular is also read by
# AsciiWorldRenderer.railroad_sprite.
renderer = AsciiWorldRenderer()
frame_seq = []
world = WorldBuilder.create(80, 16)
policy = SimpleControlPolicy()
for epoch in range(100):
    if epoch % 20 == 0:
        # progress message every 20 epochs
        print(f"Rendering epoch {epoch}")
    frame = renderer.render(world)
    # stored as a numpy array, which plot_sequence_images reads via .shape
    frame_seq.append(np.asarray(frame))
    world.act(epoch, policy.get_control(epoch, world))

AsciiWorldRenderer.plot_sequence_images(frame_seq)

Rendering epoch 0




Rendering epoch 20
Rendering epoch 40
Rendering epoch 60
Rendering epoch 80


<Figure size 1300x1120 with 0 Axes>