#### This file will demonstrate, and visualise the key functionalities of the Jax order book implementation. Further, it will measure the walltime speeds for these basic operations.

In [1]:
%load_ext autoreload
%autoreload 2
from functools import partial, partialmethod
from typing import OrderedDict
from jax import numpy as jnp
import jax
import numpy as np

#jax.config.update('jax_platform_name', 'cpu')

import gymnax_exchange.jaxob.JaxOrderBookArrays as job



import random
import time
import timeit

import sys
sys.path.append('/Users/sasrey/AlphaTrade')
import gymnax_exchange




In [2]:
def create_init_book(booksize=10,tradessize=10,pricerange=[2190000,2200000,2210000],quantrange=[0,500],timeinit=[34200,0]):
    qtofill=booksize//3 #fill one third of the available space
    asks=[]
    bids=[]
    orderid=1000
    traderid=1000
    times=timeinit[0]
    timens=timeinit[1]
    for i in range(qtofill):
        asks.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid,traderid,times,timens])
        times+=random.randint(0,1)
        timens+=random.randint(0,10000)
        bids.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid+1,traderid+1,times,timens])
        times+=random.randint(0,1)
        timens+=random.randint(0,10000)
        orderid+=2
        traderid+=2
    bids=jnp.concatenate((jnp.array(bids),jnp.ones((booksize-qtofill,6),dtype=jnp.int32)*-1),axis=0)
    asks=jnp.concatenate((jnp.array(asks),jnp.ones((booksize-qtofill,6),dtype=jnp.int32)*-1),axis=0)
    trades=jnp.ones((tradessize,6),dtype=jnp.int32)*-1
    return asks,bids,trades

def create_message(type='limit',side='bid',price=2200000,quant=10,times=36000,timens=0):
    if type=='limit':
        type_num=1
    elif type =='cancel' or type == 'delete':
        type_num=2
    elif type =='market':
        type_num=4
    else:
        raise ValueError('Type is none of: limit, cancel, delete or market')

    if side=='bid':
        side_num=1
    elif side =='ask':
        side_num=-1
    else:
        raise ValueError('Side is none of: bid or ask')
    
    dict_msg={
    'side':side_num,
    'type':type_num,
    'price':price,
    'quantity':quant,
    'orderid':8888,
    'traderid':8888,
    'time':times,
    'time_ns':timens}
    array_msg=jnp.array([type_num,side_num,quant,price,8888,8888,times,timens])
    return dict_msg,array_msg

def create_message_forvmap(type='limit',side='bid',price=2200000,quant=10,times=36000,timens=0,nvmap=10):
    if type=='limit':
        type_num=1
    elif type =='cancel' or type == 'delete':
        type_num=2
    elif type =='market':
        type_num=4
    else:
        raise ValueError('Type is none of: limit, cancel, delete or market')

    if side=='bid':
        side_num=1
    elif side =='ask':
        side_num=-1
    else:
        raise ValueError('Side is none of: bid or ask')
    
    dict_msg={
    'side':jnp.array([side_num]*nvmap),
    'type':jnp.array([type_num]*nvmap),
    'price':jnp.array([price]*nvmap),
    'quantity':jnp.array([quant]*nvmap),
    'orderid':jnp.array([8888]*nvmap),
    'traderid':jnp.array([8888]*nvmap),
    'time':jnp.array([times]*nvmap),
    'time_ns':jnp.array([timens]*nvmap)}
    array_msg=jnp.array([type_num,side_num,quant,price,8888,8888,times,timens]*nvmap)
    return dict_msg,array_msg

In [3]:
n_runs=1000
n_repeats=10

Measuring the time for the most basic operations: Adding and order and removing an order from a given side of the book.

In [4]:
## Add an order

random.seed(0)
addout=[]
for i in [10,100,1000]:
    asks,bids,trades=create_init_book(booksize=i)
    mdict,marray=create_message(type='limit',side='bid',price=2191200,quant=77)
    out=job.add_order(bids,mdict)
    addout.append(out)
    res=np.array(timeit.repeat('val=job.add_order(bids,mdict); jax.block_until_ready(val)',repeat=n_repeats,number=n_runs,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    mi=np.min(res/n_runs)
    print("Add time for orderbook of size",i,":",mu,"Stdev: ", sigma," Min: ",mi)
    
print(addout[0])

random.seed(0)
#Now do it when vmapped
nvmap=1000
for i,s in enumerate([10,100,1000]):
    asks,bids,trades=create_init_book(booksize=s)
    vmdict,marray=create_message_forvmap(type='limit',side='bid',price=2191200,quant=77,nvmap=nvmap)

    vbids=jnp.array([bids]*nvmap)

    out=job.add_order(bids,mdict)
    outv=jax.vmap(job.add_order,(0,{'orderid': 0, 'price': 0, 'quantity': 0, 'side': 0, 'time': 0, 'time_ns': 0, 'traderid': 0, 'type': 0}))(vbids,vmdict)
    res=np.array(timeit.repeat("val=jax.vmap(job.add_order,(0,{'orderid': 0, 'price': 0, 'quantity': 0, 'side': 0, 'time': 0, 'time_ns': 0, 'traderid': 0, 'type': 0}))(vbids,vmdict); jax.block_until_ready(val)",number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    mi=np.min(res/n_runs)
    print("VMAP add time for orderbook of size",s," \n various incoming order sizes:",mu,"Stdev: ", sigma," Min: ",mi)






I0000 00:00:1697729713.564935 2255205 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


Add time for orderbook of size 10 : 0.00011665769610553982 Stdev:  2.82336258091503e-06  Min:  0.00011339719220995903
Add time for orderbook of size 100 : 0.00015359500423073766 Stdev:  2.388907738788754e-05  Min:  0.00013932540826499463
Add time for orderbook of size 1000 : 0.00017484180349856615 Stdev:  3.544141469062844e-06  Min:  0.00017145046405494214
[[2204242     494    1001    1001   34201     663]
 [2209558     456    1003    1003   34203   13163]
 [2204104     465    1005    1005   34203   22984]
 [2191200      77    8888    8888   36000       0]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]]
VMAP add time for orderbook of size 10  
 various incoming order sizes: 0.0008546348931267858 Stdev:  1.6007305871247412e-05  Min:

In [5]:
## Cancel an order
n_runs=1000
random.seed(0)
cancelout=[]
for i,s in enumerate([10,100,1000]):
    bids=addout[i]
    mdict,marray=create_message(type='cancel',side='bid',price=2191200,quant=77)
    out=job.cancel_order(bids,mdict)
    cancelout.append(out)
    res=np.array(timeit.repeat('val=job.cancel_order(bids,mdict); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Cancel time for orderbook of size",s,":",mu,"Stdev: ", sigma)
cancelout[0]

Cancel time for orderbook of size 10 : 8.66373972967267e-05 Stdev:  5.483749225949132e-06
Cancel time for orderbook of size 100 : 0.00012360328752547504 Stdev:  5.272897726920175e-06
Cancel time for orderbook of size 1000 : 0.00013759225104004145 Stdev:  5.456113120927571e-07


Array([[2204242,     494,    1001,    1001,   34201,     663],
       [2209558,     456,    1003,    1003,   34203,   13163],
       [2204104,     465,    1005,    1005,   34203,   22984],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1],
       [     -1,      -1,      -1,      -1,      -1,      -1]],      dtype=int32)

Matching a single order against an identified order from the other side of the book:

In [15]:
matchout=[]

for i,s in enumerate([10,100,1000]):
    _,_,trades=create_init_book(booksize=s)

    bids=cancelout[i]
    idx=job.__get_top_bid_order_idx(bids)
    print("Time to get top bid order for order book of size ",s,":",timeit.timeit('val=job.__get_top_bid_order_idx(bids); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)

    matchtuple=(idx,bids,1000,0,trades,9999,36000,1)
    
    bids,qtm,price,trades,agrid,times,timens=job.match_order(matchtuple)
    
    matchout.append((bids,qtm,trades))
    res=np.array(timeit.repeat('val=job.match_order(matchtuple); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Match time for orderbook of size",s,":",mu,"Stdev: ", sigma)
matchout[0]

Time to get top bid order for order book of size  10 : 4.2800363153219224e-05
Match time for orderbook of size 10 : 0.00017796710841357712 Stdev:  3.5799693920809505e-06
Time to get top bid order for order book of size  100 : 8.830972760915756e-05
Match time for orderbook of size 100 : 0.0001710767026990652 Stdev:  3.729168730023052e-06
Time to get top bid order for order book of size  1000 : 9.932580031454563e-05
Match time for orderbook of size 1000 : 0.00016701713688671587 Stdev:  6.257017796367845e-07


(Array([[2204242,     494,    1001,    1001,   34201,     663],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [2204104,     465,    1005,    1005,   34203,   22984],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1]],      dtype=int32),
 Array(544, dtype=int32),
 Array([[2209558,     456,    1003,    9999,   36000,       1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,      -1],
        [     -1,      -1,      -1,      -1,      -1,     

Match against an entire side, continuing until either the full incoming order is matched, the book is empty, or the best price is no longer acceptable to the lim-order.

In [16]:
matchout=[]

for i,s in enumerate([10,100,1000]):
    for j,q in enumerate([0,10,500,1000,10000]):
        _,_,trades=create_init_book(booksize=s,tradessize=s)

        bids=cancelout[i]

        matchtuple=(bids,q,0,trades,9999,36000,1)
        bids,qtm,price,trades=job._match_against_bid_orders(*matchtuple)
        
        matchout.append((bids,qtm,trades))
        res=np.array(timeit.repeat('val=job._match_against_bid_orders(*matchtuple); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        print("Match time for orderbook of size",s," \n with an incoming order of size",q,":",mu,"Stdev: ", sigma)


#Now do it when vmapped (i.e. this skips the cond)
nvmap=1000
for i,s in enumerate([10,100,1000]):
    _,_,trades=create_init_book(booksize=s,tradessize=s)
    bids=cancelout[i]

    vbids=jnp.array([bids]*nvmap)
    vtrades=jnp.array([trades]*nvmap)
    vq=jnp.array([100,100,100,100,100]*(nvmap//5))

    matchtuple=(vbids,vq,0,vtrades,9999,36000,1)
    jax.vmap(job._match_against_bid_orders,((0,0,None,0,None,None,None)))(*matchtuple)
    
    matchout.append((bids,qtm,trades))
    res=np.array(timeit.repeat('val=jax.vmap(job._match_against_bid_orders,((0,0,None,0,None,None,None)))(*matchtuple); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    
    print("VMAP Match time for orderbook of size",s," \n various incoming order sizes:",mu,"Stdev: ", sigma)




Match time for orderbook of size 10  
 with an incoming order of size 0 : 0.0001599755199626088 Stdev:  1.4136404594710928e-05
Match time for orderbook of size 10  
 with an incoming order of size 10 : 0.00023149909507483242 Stdev:  8.236890279002566e-07
Match time for orderbook of size 10  
 with an incoming order of size 500 : 0.00023494059592485433 Stdev:  3.7953098015473232e-06
Match time for orderbook of size 10  
 with an incoming order of size 1000 : 0.00024334206711500885 Stdev:  1.356343510487945e-05
Match time for orderbook of size 10  
 with an incoming order of size 10000 : 0.00025096572581678625 Stdev:  1.0152672097264163e-05
Match time for orderbook of size 100  
 with an incoming order of size 0 : 0.0002072071246802807 Stdev:  2.8212552781626246e-06
Match time for orderbook of size 100  
 with an incoming order of size 10 : 0.0003378016976639629 Stdev:  1.546980475443614e-06
Match time for orderbook of size 100  
 with an incoming order of size 500 : 0.000333969629369676

Matching is what takes the longest, and increases when the while loop needs to turn for longer. But even for a single iteration, it takes roughly 1.5 times the time than a simple add order.
Next we consider the higher-level message types and include the branching logic required to direct orders across different types and sides of orders.

In [7]:
random.seed(0)
nvmap=1000
outs=[]
for i in [1000]:
    asks,bids,trades=create_init_book(booksize=i,tradessize=i)
    _,limitmsg=create_message(type='limit',side='bid',price=2191200,quant=77)
    _,cancelmsg=create_message(type='cancel',side='bid',price=2191200,quant=77)
    _,matchmsg=create_message(type='limit',side='ask',price=2191200,quant=100)

    out,_=job.cond_type_side((asks,bids,trades),limitmsg)
    res=np.array(timeit.repeat('val=job.cond_type_side((asks,bids,trades),limitmsg); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Limit order for book of size ",i,":",mu,"Stdev: ", sigma)
    out,_=job.cond_type_side(out,cancelmsg)
    res=np.array(timeit.repeat('val=job.cond_type_side(out,cancelmsg); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Cancel order for book of size ",i,":",mu,"Stdev: ", sigma)
    out,_=job.cond_type_side(out,matchmsg)
    res=np.array(timeit.repeat('val=job.cond_type_side(out,matchmsg); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Matched limit order for book of size ",i,":",mu,"Stdev: ", sigma)
    outs.append(out)

    vasks=jnp.array([asks]*nvmap)
    vbids=jnp.array([bids]*nvmap)
    vtrades=jnp.array([trades]*nvmap)
    vlimitms=jnp.array([limitmsg]*nvmap)
    vcancelms=jnp.array([cancelmsg]*nvmap)
    vmatchms=jnp.array([matchmsg]*nvmap)

    out,_=jax.vmap(job.cond_type_side,((0,0,0),0))((vasks,vbids,vtrades),vlimitms)
    res=np.array(timeit.repeat('val=jax.vmap(job.cond_type_side,((0,0,0),0))((vasks,vbids,vtrades),vlimitms); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("VMAP limit order for book of size ",i,":",mu,"Stdev: ", sigma)
    out,_=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vcancelms)
    res=np.array(timeit.repeat('val=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vcancelms); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("VMAP cancel order for book of size ",i,":",mu,"Stdev: ", sigma)
    out,_=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vmatchms)
    res=np.array(timeit.repeat('val=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vmatchms); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("VMAP matched limit order for book of size ",i,":",mu,"Stdev: ", sigma)


    


Limit order for book of size  1000 : 0.0002546505982056259 Stdev:  1.0699935274656523e-05
Cancel order for book of size  1000 : 0.00017423695195466278 Stdev:  5.416616068377795e-07
Matched limit order for book of size  1000 : 0.00037279701139777895 Stdev:  8.050058718516285e-06
VMAP limit order for book of size  1000 : 0.010609133225679399 Stdev:  5.5580927006953906e-05
VMAP cancel order for book of size  1000 : 0.010422131344489754 Stdev:  7.548130142640579e-05
VMAP matched limit order for book of size  1000 : 0.010352472906559705 Stdev:  5.344761249890838e-05


Limit order for book of size  10 : 0.00010684770345687866
Cancel order for book of size  10 : 7.434402499347926e-05
Matched limit order for book of size  10 : 0.00016043131798505782
VMAP limit order for book of size  10 : 0.006443732594139874
VMAP cancel order for book of size  10 : 0.0064374489830806856
VMAP matched limit order for book of size  10 : 0.006422483234666288
Limit order for book of size  100 : 0.00014240943174809216
Cancel order for book of size  100 : 8.986285887658596e-05
Matched limit order for book of size  100 : 0.0002076397556811571
VMAP limit order for book of size  100 : 0.05448427036125213
VMAP cancel order for book of size  100 : 0.05465182608179748
VMAP matched limit order for book of size  100 : 0.05493127669394016