#### This file will demonstrate, and visualise the key functionalities of the Jax order book implementation. Further, it will measure the walltime speeds for these basic operations.

In [7]:
%load_ext autoreload
%autoreload 2
from functools import partial, partialmethod
from typing import OrderedDict
from jax import numpy as jnp
import jax

import sys
sys.path.append('/Users/millionaire/Desktop/UCL/Thesis/AlphaTrade-jaxV3')

#jax.config.update('jax_platform_name', 'cpu')

import gymnax_exchange.jaxob.JaxOrderBookArrays as job

import random
import time
import timeit

import gymnax_exchange




The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
def create_init_book(booksize=10,tradessize=10,pricerange=[2190000,2200000,2210000],quantrange=[0,500],timeinit=[34200,0]):
    ''' 
    Add:
    I´m a little confused that we do not care about bids that are higher than asks and vice versa ...
    Adjusted Code by Mike 6.06 at 14:35
    Original Code:
    asks.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid,traderid,times,timens])
    bids.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid+1,traderid+1,times,timens])
    '''
    qtofill=booksize//3 #fill one third of the available space
    asks=[] 
    bids=[]
    orderid=1000
    traderid=1000 
    times=timeinit[0]
    timens=timeinit[1]
    for i in range(qtofill):
        asks.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid,traderid,times,timens])
        times+=random.randint(0,1)
        timens+=random.randint(0,10000)
        bids.append([random.randint(pricerange[0],pricerange[1]),random.randint(quantrange[0],quantrange[1]),orderid+1,traderid+1,times,timens])
        times+=random.randint(0,1)
        timens+=random.randint(0,10000)
        orderid+=2
        traderid+=2
    bids=jnp.concatenate((jnp.array(bids),jnp.ones((booksize-qtofill,6),dtype=jnp.int32)*-1),axis=0)
    asks=jnp.concatenate((jnp.array(asks),jnp.ones((booksize-qtofill,6),dtype=jnp.int32)*-1),axis=0)
    trades=jnp.ones((tradessize,6),dtype=jnp.int32)*-1
    return asks,bids,trades

def create_message(type='limit',side='bid',price=2200000,quant=10,times=36000,timens=0):
    if type=='limit':
        type_num=1
    elif type =='cancel' or type == 'delete':
        type_num=2
    elif type =='market':
        type_num=4
    else:
        raise ValueError('Type is none of: limit, cancel, delete or market')

    if side=='bid':
        side_num=1
    elif side =='ask':
        side_num=-1
    else:
        raise ValueError('Side is none of: bid or ask')
    
    dict_msg={
    'side':side_num,
    'type':type_num,
    'price':price,
    'quantity':quant,
    'orderid':8888,
    'traderid':8888,
    'time':times,
    'time_ns':timens}
    array_msg=jnp.array([type_num,side_num,quant,price,8888,8888,times,timens])
    return dict_msg,array_msg

def create_message_forvmap(type='limit',side='bid',price=2200000,quant=10,times=36000,timens=0,nvmap=10):
    if type=='limit':
        type_num=1
    elif type =='cancel' or type == 'delete':
        type_num=2
    elif type =='market':
        type_num=4
    else:
        raise ValueError('Type is none of: limit, cancel, delete or market')

    if side=='bid':
        side_num=1
    elif side =='ask':
        side_num=-1
    else:
        raise ValueError('Side is none of: bid or ask')
    
    dict_msg={
    'side':jnp.array([side_num]*nvmap),
    'type':jnp.array([type_num]*nvmap),
    'price':jnp.array([price]*nvmap),
    'quantity':jnp.array([quant]*nvmap),
    'orderid':jnp.array([8888]*nvmap),
    'traderid':jnp.array([8888]*nvmap),
    'time':jnp.array([times]*nvmap),
    'time_ns':jnp.array([timens]*nvmap)}
    array_msg=jnp.array([type_num,side_num,quant,price,8888,8888,times,timens]*nvmap)
    return dict_msg,array_msg

Measuring the time for the most basic operations: Adding and order and removing an order from a given side of the book.

In [14]:
## Add an order
n_runs=1000
random.seed(0)
addout=[]
for i in [10,100,1000]:
    asks,bids,trades=create_init_book(booksize=i)
    mdict,marray=create_message(type='limit',side='bid',price=2191200,quant=77)
    out=job.add_order(bids,mdict)
    addout.append(out)
    print("Add time for orderbook of size",i,":",timeit.timeit('val=job.add_order(bids,mdict); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
    
print(addout[0])

random.seed(0)
#Now do it when vmapped
nvmap=1000
for i,s in enumerate([10,100,1000]):
    asks,bids,trades=create_init_book(booksize=s)
    vmdict,marray=create_message_forvmap(type='limit',side='bid',price=2191200,quant=77,nvmap=nvmap)

    vbids=jnp.array([bids]*nvmap)

    out=job.add_order(bids,mdict)
    outv=jax.vmap(job.add_order,(0,{'orderid': 0, 'price': 0, 'quantity': 0, 'side': 0, 'time': 0, 'time_ns': 0, 'traderid': 0, 'type': 0}))(vbids,vmdict)
    
    print("VMAP add time for orderbook of size",s," \n various incoming order sizes:",timeit.timeit("val=jax.vmap(job.add_order,(0,{'orderid': 0, 'price': 0, 'quantity': 0, 'side': 0, 'time': 0, 'time_ns': 0, 'traderid': 0, 'type': 0}))(vbids,vmdict); jax.block_until_ready(val)",number=n_runs,globals=globals())/n_runs)






Add time for orderbook of size 10 : 2.2923666998394766e-05
Add time for orderbook of size 100 : 2.1344291002606043e-05
Add time for orderbook of size 1000 : 4.853370900309528e-05
[[2204242     494    1001    1001   34201     663]
 [2209558     456    1003    1003   34203   13163]
 [2204104     465    1005    1005   34203   22984]
 [2191200      77    8888    8888   36000       0]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]]
VMAP add time for orderbook of size 10  
 various incoming order sizes: 0.0006840860000011162
VMAP add time for orderbook of size 100  
 various incoming order sizes: 0.004222076082998683
VMAP add time for orderbook of size 1000  
 various incoming order sizes: 0.03733441954199952


In [15]:
## Cancel an order
n_runs=1000
random.seed(0)
cancelout=[]
for i,s in enumerate([10,100,1000]):
    bids=addout[i]
    mdict,marray=create_message(type='cancel',side='bid',price=2191200,quant=77)
    out=job.cancel_order(bids,mdict)
    cancelout.append(out)
    print("Cancel time for orderbook of size",s,":",timeit.timeit('val=job.cancel_order(bids,mdict); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
cancelout[0]

Cancel time for orderbook of size 10 : 1.402724999934435e-05
Cancel time for orderbook of size 100 : 1.421916700201109e-05
Cancel time for orderbook of size 1000 : 1.915174999885494e-05


Array([[2204242,     494,    1001,    1001,   34201,     663],
       [2209558,     456,    1003,    1003,   34203,   13163],
       [2204104,     465,    1005,    1005,   34203,   22984],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1]],      dtype=int32)

Matching a single order against an identified order from the other side of the book:

In [21]:
matchout=[]

for i,s in enumerate([10,100,1000]):
    '''
    Adjusted by Mike 06.06 13:35
    Original code was:
    matchtuple=(bids,1000,0,idx,trades,9999,36000,1)
    bids,qtm,price,idx,trades,agrid,times,timens=job.match_order(matchtuple)
    '''
    _,_,trades=create_init_book(booksize=s)

    bids=cancelout[i]
    idx=job.__get_top_bid_order_idx(bids)
    print("Time to get top bid order for order book of size ",s,":",timeit.timeit('val=job.__get_top_bid_order_idx(bids); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)

    matchtuple=(idx, bids, 1000, 0, trades, 9999, 36000, 1)
    bids,qtm,price,trades,agrid,times,timens=job.match_order(matchtuple)
    
    matchout.append((bids,qtm,trades))
    print("Match time for orderbook of size",s,":",timeit.timeit('val=job.match_order(matchtuple); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
matchout[0]

Time to get top bid order for order book of size  10 : 8.134166004310828e-06
Match time for orderbook of size 10 : 3.160449999995763e-05
Time to get top bid order for order book of size  100 : 8.639083003799897e-06
Match time for orderbook of size 100 : 2.5700958001834807e-05
Time to get top bid order for order book of size  1000 : 1.6346124997653533e-05
Match time for orderbook of size 1000 : 2.763362500263611e-05


(Array([[2204242,     494,    1001,    1001,   34201,     663],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [2204104,     465,    1005,    1005,   34203,   22984],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1]],      dtype=int32),
 Array(544, dtype=int32),
 Array([[2209558,     456,    1003,    9999,   36000,       1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,     

Match against an entire side, continuing until either the full incoming order is matched, the book is empty, or the best price is no longer acceptable to the lim-order.

In [22]:
matchout=[]

for i,s in enumerate([10,100,1000]):
    for j,q in enumerate([0,10,500,1000,10000]):
        _,_,trades=create_init_book(booksize=s,tradessize=s)

        bids=cancelout[i]

        matchtuple=(bids,q,0,trades,9999,36000,1)
        bids,qtm,price,trades=job._match_against_bid_orders(*matchtuple)
        
        matchout.append((bids,qtm,trades))
        print("Match time for orderbook of size",s," \n with an incoming order of size",q,":",timeit.timeit('val=job._match_against_bid_orders(*matchtuple); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)


#Now do it when vmapped (i.e. this skips the cond)
nvmap=1000
for i,s in enumerate([10,100,1000]):
    _,_,trades=create_init_book(booksize=s,tradessize=s)
    bids=cancelout[i]

    vbids=jnp.array([bids]*nvmap)
    vtrades=jnp.array([trades]*nvmap)
    vq=jnp.array([100,100,100,100,100]*(nvmap//5))

    matchtuple=(vbids,vq,0,vtrades,9999,36000,1)
    jax.vmap(job._match_against_bid_orders,((0,0,None,0,None,None,None)))(*matchtuple)
    
    matchout.append((bids,qtm,trades))
    print("VMAP Match time for orderbook of size",s," \n various incoming order sizes:",timeit.timeit('val=jax.vmap(job._match_against_bid_orders,((0,0,None,0,None,None,None)))(*matchtuple); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)




Match time for orderbook of size 10  
 with an incoming order of size 0 : 2.029716700053541e-05
Match time for orderbook of size 10  
 with an incoming order of size 10 : 2.1648417001415508e-05
Match time for orderbook of size 10  
 with an incoming order of size 500 : 2.284325000073295e-05
Match time for orderbook of size 10  
 with an incoming order of size 1000 : 2.1224832998996135e-05
Match time for orderbook of size 10  
 with an incoming order of size 10000 : 2.172916700510541e-05
Match time for orderbook of size 100  
 with an incoming order of size 0 : 2.286583300156053e-05
Match time for orderbook of size 100  
 with an incoming order of size 10 : 2.9515375004848464e-05
Match time for orderbook of size 100  
 with an incoming order of size 500 : 3.4804542003257665e-05
Match time for orderbook of size 100  
 with an incoming order of size 1000 : 2.7688583999406546e-05
Match time for orderbook of size 100  
 with an incoming order of size 10000 : 2.8163791001134087e-05
Match tim

Matching is what takes the longest, and increases when the while loop needs to turn for longer. But even for a single iteration, it takes roughly 1.5 times the time than a simple add order.
Next we consider the higher-level message types and include the branching logic required to direct orders across different types and sides of orders.

In [23]:
random.seed(0)
nvmap=1000000
outs=[]
for i in [10,100]:
    asks,bids,trades=create_init_book(booksize=i,tradessize=i)
    _,limitmsg=create_message(type='limit',side='bid',price=2191200,quant=77)
    _,cancelmsg=create_message(type='cancel',side='bid',price=2191200,quant=77)
    _,matchmsg=create_message(type='limit',side='ask',price=2191200,quant=100)

    out,_=job.cond_type_side((asks,bids,trades),limitmsg)
    print("Limit order for book of size ",i,":",timeit.timeit('val=job.cond_type_side((asks,bids,trades),limitmsg); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
    out,_=job.cond_type_side(out,cancelmsg)
    print("Cancel order for book of size ",i,":",timeit.timeit('val=job.cond_type_side(out,cancelmsg); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
    out,_=job.cond_type_side(out,matchmsg)
    print("Matched limit order for book of size ",i,":",timeit.timeit('val=job.cond_type_side(out,matchmsg); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
    outs.append(out)

    vasks=jnp.array([asks]*nvmap)
    vbids=jnp.array([bids]*nvmap)
    vtrades=jnp.array([trades]*nvmap)
    vlimitms=jnp.array([limitmsg]*nvmap)
    vcancelms=jnp.array([cancelmsg]*nvmap)
    vmatchms=jnp.array([matchmsg]*nvmap)

    out,_=jax.vmap(job.cond_type_side,((0,0,0),0))((vasks,vbids,vtrades),vlimitms)
    print("VMAP limit order for book of size ",i,":",timeit.timeit('val=jax.vmap(job.cond_type_side,((0,0,0),0))((vasks,vbids,vtrades),vlimitms); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
    out,_=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vcancelms)
    print("VMAP cancel order for book of size ",i,":",timeit.timeit('val=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vcancelms); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)
    out,_=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vmatchms)
    print("VMAP matched limit order for book of size ",i,":",timeit.timeit('val=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vmatchms); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)


    


Limit order for book of size  10 : 2.709624999988591e-05
Cancel order for book of size  10 : 2.660962500522146e-05
Matched limit order for book of size  10 : 2.5626583999837748e-05
VMAP limit order for book of size  10 : 1.6759693045000021


KeyboardInterrupt: 

Limit order for book of size  10 : 0.00010684770345687866
Cancel order for book of size  10 : 7.434402499347926e-05
Matched limit order for book of size  10 : 0.00016043131798505782
VMAP limit order for book of size  10 : 0.006443732594139874
VMAP cancel order for book of size  10 : 0.0064374489830806856
VMAP matched limit order for book of size  10 : 0.006422483234666288
Limit order for book of size  100 : 0.00014240943174809216
Cancel order for book of size  100 : 8.986285887658596e-05
Matched limit order for book of size  100 : 0.0002076397556811571
VMAP limit order for book of size  100 : 0.05448427036125213
VMAP cancel order for book of size  100 : 0.05465182608179748
VMAP matched limit order for book of size  100 : 0.05493127669394016