In [1]:
import random
import math
from collections import defaultdict
import random
import time
import csv

In [2]:
class Node:
    """A node in the Monte Carlo search tree.

    Public fields:
    state    -- the game state this node represents (opaque to the node)
    parent   -- the parent Node, or None for the root
    children -- list of child Nodes added via add_child()
    visits   -- number of times update() has been called
    value    -- running sum of the values passed to update()
    """

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0

    def add_child(self, child):
        """Append ``child`` to this node's children."""
        self.children.append(child)

    def update(self, value):
        """Record one visit that produced ``value``."""
        self.visits += 1
        self.value += value

    def ucb1(self, total_visits, exploration=2.0):
        """Return the UCB1 score: mean value plus an exploration bonus.

        ``exploration`` is the constant under the square root; the default
        2.0 gives the classic ``sqrt(2 * ln(total_visits) / visits)`` bonus.
        A node that has never been visited scores ``math.inf`` so it is
        always preferred over visited siblings, instead of dividing by zero.
        """
        if self.visits == 0:
            return math.inf
        exploitation = self.value / self.visits
        bonus = math.sqrt(exploration * math.log(total_visits) / self.visits)
        return exploitation + bonus

In [3]:
class MCTS:
    """Generic Monte Carlo Tree Search driver.

    ``game`` must provide ``is_terminal(state)``, ``is_fully_expanded(node)``,
    ``random_unexpanded_child(node)``, ``random_child(state)`` and
    ``reward(state)``.

    By default rewards are treated as single-agent: every node on the path
    from a simulated leaf to the root receives the same reward.  Pass
    ``zero_sum=True`` for two-player zero-sum games, where the sign of the
    reward alternates at each level of the tree.
    """

    def __init__(self, game, num_simulations, zero_sum=False):
        self.game = game
        self.num_simulations = num_simulations
        self.zero_sum = zero_sum

    def search(self, root_state):
        """Run ``num_simulations`` iterations from ``root_state``.

        Returns the root child with the highest UCB1 score.  If no child was
        ever expanded (the root is terminal, or ``num_simulations`` is 0),
        the root node itself is returned instead of failing on an empty max.
        """
        root = Node(root_state)
        for _ in range(self.num_simulations):
            leaf = self.select(root)
            result = self.simulate(leaf.state)
            self.backpropagate(leaf, result)
        if not root.children:
            return root
        return self.best_child(root)

    def select(self, node):
        """Descend by UCB1 until reaching a terminal node or one with an
        unexpanded move; in the latter case expand it and return the new child."""
        while not self.game.is_terminal(node.state):
            if not self.game.is_fully_expanded(node):
                return self.expand(node)
            node = self.best_child(node)
        return node

    def expand(self, node):
        """Attach and return a new child built from one unexpanded move."""
        state = self.game.random_unexpanded_child(node)
        child = Node(state, parent=node)
        node.add_child(child)
        return child

    def simulate(self, state):
        """Play random moves from ``state`` until terminal; return its reward."""
        while not self.game.is_terminal(state):
            state = self.game.random_child(state)
        return self.game.reward(state)

    def backpropagate(self, node, result):
        """Add ``result`` to ``node`` and each of its ancestors.

        The sign only flips between levels for zero-sum games.  Flipping it
        for a single-agent problem would make every other tree level prefer
        the lowest-reward moves.
        """
        while node is not None:
            node.update(result)
            if self.zero_sum:
                result = -result
            node = node.parent

    def best_child(self, node):
        """Return the child of ``node`` with the highest UCB1 score.

        Raises ValueError if ``node`` has no children.
        """
        total_visits = sum(child.visits for child in node.children)
        return max(node.children, key=lambda child: child.ucb1(total_visits))

In [4]:
class SOR:
    """Knapsack-style routing problem exposed through the MCTS game interface.

    ``items`` is a list of price levels.  ``item[2]`` is the level's quantity
    (compared against the remaining capacity) and ``item[1]`` is the value
    summed by reward(); in this notebook levels are
    ``(exchanger, price, quantity)`` tuples.

    A state is ``(remaining_capacity, remaining_items, dropped_items)``:
    remaining_capacity -- quantity still to be filled
    remaining_items    -- levels not considered yet
    dropped_items      -- levels considered but skipped because their
                          quantity exceeded the remaining capacity
    """

    def __init__(self, items, capacity, side):
        self.items = items
        self.capacity = capacity
        self.side = side

    def is_terminal(self, state):
        """A state is terminal once capacity is used up or no levels remain."""
        remaining_capacity, remaining_items, _ = state
        return remaining_capacity <= 0 or len(remaining_items) == 0

    def is_fully_expanded(self, node):
        """True when ``node`` has one child per remaining level."""
        return len(node.state[1]) == len(node.children)

    def _take(self, state, item):
        """Return the successor of ``state`` after considering ``item``.

        The item is removed from the remaining levels.  If its quantity fits
        in the remaining capacity it is filled (capacity decreases);
        otherwise it is recorded as dropped.  The input lists are not mutated.
        """
        remaining_capacity, remaining_items, dropped_items = state
        new_remaining_items = remaining_items.copy()
        new_remaining_items.remove(item)
        new_dropped_items = dropped_items.copy()
        if item[2] <= remaining_capacity:
            return remaining_capacity - item[2], new_remaining_items, new_dropped_items
        new_dropped_items.append(item)
        return remaining_capacity, new_remaining_items, new_dropped_items

    def random_child(self, state):
        """Consider a uniformly random remaining level (used in rollouts)."""
        return self._take(state, random.choice(state[1]))

    def random_unexpanded_child(self, node):
        """Consider a random remaining level that ``node`` has not expanded yet.

        The level each existing child was built from is the multiset
        difference between the node's remaining levels and the child's.
        Counting, rather than using sets, keeps this correct when two
        levels are identical tuples.
        """
        from collections import Counter

        base = Counter(node.state[1])
        unexpanded = base.copy()
        for child in node.children:
            unexpanded = unexpanded - (base - Counter(child.state[1]))
        item = random.choice(list(unexpanded.elements()))
        return self._take(node.state, item)

    def reward(self, state):
        """Score a state by the ``item[1]`` values of the levels actually filled.

        Filled levels are those neither still remaining nor dropped (dropped
        levels did not fit and were never executed).  For side 'buy' the
        reward is the filled sum; for any other side it is the negated sum.
        NOTE(review): on the sell side this prefers a *lower* filled sum;
        confirm that is the intended objective.
        """
        _, remaining_items, dropped_items = state
        unfilled = sum(item[1] for item in remaining_items) + sum(item[1] for item in dropped_items)
        filled = sum(item[1] for item in self.items) - unfilled
        if self.side == 'buy':
            return filled
        return -filled
    

In [5]:
def smart_order_router(excelbids, excelasks, side, symbol, qty, order_type='market', price = None):
    needed_qty = qty
    route = {}
    available_qty = 0
    if side == 'buy':
        bid_shared = []
        for shared in excelbids[symbol]:
            if order_type == 'limit' and shared[1] <= price:
                bid_shared.append(shared)
        if len(bid_shared) == 0:
            return 'No available shared'
        sor = SOR(bid_shared, needed_qty, side)
        mcts = MCTS(sor, num_simulations=500)
        root_state = (needed_qty, bid_shared, [])
        best_state = mcts.search(root_state)
        solution = best_state
        while not len(solution.children)==0:
            solution =mcts.best_child(solution)
        packed_items = [item for item in bid_shared if ((item not in solution.state[1]) and (item not in solution.state[2]))]
        for packed_item in packed_items:
            exchanger = packed_item[0]
            if exchanger not in route:
                route[exchanger] = {}
            route[exchanger][packed_item[1]] = packed_item[2]
            available_qty = available_qty + packed_item[2]
        order = {'route':route, 'leave_qty': needed_qty - available_qty}
    elif side == 'sell':
        ask_shared = []
        for shared in excelasks[symbol]:
            if order_type == 'limit' and shared[1] >= price:
                ask_shared.append(shared)
        if len(ask_shared) == 0:
            return 'No available shared'
        sor = SOR(ask_shared, needed_qty, side)
        mcts = MCTS(sor, num_simulations=500)
        root_state = (needed_qty, ask_shared, [])
        best_state = mcts.search(root_state)
        solution = best_state
        while not len(solution.children)==0:
            solution =mcts.best_child(solution)
        packed_items = [item for item in ask_shared if ((item not in solution.state[1]) and (item not in solution.state[2]))]
        for packed_item in packed_items:
            exchanger = packed_item[0]
            if exchanger not in route:
                route[exchanger] = {}
            route[exchanger][packed_item[1]] = packed_item[2]
            available_qty = available_qty + packed_item[2]
        order = {'route':route, 'leave_qty': needed_qty - available_qty}
    return order

In [6]:
def test_order_number(output_path='./mcts_order.csv', seed=None):
    """Benchmark smart_order_router against the number of orders routed.

    Builds one synthetic 'Microsofts' book (50 exchangers x 5 levels on each
    side).  Then, for every batch size 1..20, it times how long routing that
    many random limit sell orders takes.  Each ``[batch_size, seconds]`` pair
    is written as a row of ``output_path``.  ``seed``, when given, seeds
    ``random`` so the run is repeatable.
    """
    if seed is not None:
        random.seed(seed)
    excelbids = defaultdict(list)
    excelasks = defaultdict(list)
    for exchanger_idx in range(50):
        exchanger_name = "Exchanger" + str(exchanger_idx)
        for _ in range(5):
            excelbids['Microsofts'].append((exchanger_name, round(random.uniform(183.5, 184.5), 2), random.randint(100, 500)))
            excelasks['Microsofts'].append((exchanger_name, round(random.uniform(183.5, 184.5), 2), random.randint(100, 500)))
    results = []
    for order_counts in range(1, 21):
        print(order_counts)
        start_time = time.perf_counter()
        for _ in range(order_counts):
            qty = random.randint(1000, 5000)
            limit_price = round(random.uniform(183.5, 184.5), 2)
            # The books are passed swapped, so these 'sell' orders are
            # routed against the generated excelbids levels.
            smart_order_router(excelasks, excelbids, 'sell', 'Microsofts', qty, 'limit', limit_price)
        results.append([order_counts, time.perf_counter() - start_time])
    # The csv module requires newline='' so its \r\n row endings are not
    # doubled on platforms that translate newlines.
    with open(output_path, 'w', newline='') as f:
        csv.writer(f).writerows(results)

In [7]:
def test_exchanger_number(output_path='./mcts_exchanger.csv', seed=None):
    """Benchmark smart_order_router against the number of exchangers.

    For exchanger counts 2, 10, 20, ..., 100 it builds a fresh synthetic
    'Microsofts' book (5 levels per exchanger on each side) and times how
    long routing 20 random limit sell orders takes.  Each
    ``[exchanger_count, seconds]`` pair is written as a row of
    ``output_path``.  ``seed``, when given, seeds ``random`` so the run is
    repeatable.
    """
    if seed is not None:
        random.seed(seed)
    results = []
    for step in range(0, 110, 10):
        # An empty book would leave no levels to route against, so the
        # first measurement uses 2 exchangers instead of 0.
        exchanger_num = step if step != 0 else 2
        print(exchanger_num)
        excelbids = defaultdict(list)
        excelasks = defaultdict(list)
        for exchanger_idx in range(exchanger_num):
            exchanger_name = "Exchanger" + str(exchanger_idx)
            for _ in range(5):
                excelbids['Microsofts'].append((exchanger_name, round(random.uniform(183.5, 184.5), 2), random.randint(100, 500)))
                excelasks['Microsofts'].append((exchanger_name, round(random.uniform(183.5, 184.5), 2), random.randint(100, 500)))
        start_time = time.perf_counter()
        for _ in range(20):
            qty = random.randint(1000, 5000)
            limit_price = round(random.uniform(183.5, 184.5), 2)
            # The books are passed swapped, so these 'sell' orders are
            # routed against the generated excelbids levels.
            smart_order_router(excelasks, excelbids, 'sell', 'Microsofts', qty, 'limit', limit_price)
        results.append([exchanger_num, time.perf_counter() - start_time])
    print('finished')
    # The csv module requires newline='' so its \r\n row endings are not
    # doubled on platforms that translate newlines.
    with open(output_path, 'w', newline='') as f:
        csv.writer(f).writerows(results)

In [16]:
# Benchmark routing time vs. number of orders (batches of 1..20): prints each
# batch size and writes the timings to ./mcts_order.csv.
test_order_number()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20


In [8]:
# Benchmark routing time vs. number of exchangers (2, 10, 20, ..., 100): prints
# each exchanger count, then 'finished', and writes ./mcts_exchanger.csv.
test_exchanger_number()

2
10
20
30
40
50
60
70
80
90
100
finished
