# JAXLOB Introduction Notebook

#### This file will demonstrate, and visualise the key functionalities of the JAXLOB order book implementation. Further, it will measure the walltime speeds for these basic operations.

In [27]:
%load_ext autoreload
%autoreload 2
from functools import partial, partialmethod
from typing import OrderedDict
from jax import numpy as jnp
import jax
import numpy as np
import random
import time
import timeit

import sys
# ******** INSERT PATH HERE ********
sys.path.append('/data1/sascha/AlphaTrade/')
import gymnax_exchange
import gymnax_exchange.jaxob.JaxOrderBookArrays as job





The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Helper functions to create (fictional) initial orderbook states and to easily create messages in the LOBSTER format for the book to process. 

In [28]:
def create_init_book(cfg:job.Configuration,order_capacity=10,trade_capacity=10,pricerange=[2190000,2200000,2210000],quantrange=[0,500],timeinit=[34200,0]):
    """
    Generates a random orderbook state for a given maximum capacity for orders and trades. 
    Random prices/quantities generated by uniform sampling in pricerange/quantrange. 
    """
    qtofill=order_capacity//3 #fill one third of the available space
    asks=[]
    bids=[]
    orderid=cfg.init_id
    traderid=cfg.init_id
    times=timeinit[0]
    timens=timeinit[1]
    for i in range(qtofill):
        asks.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid,traderid,times,timens])
        times+=random.randint(0,1)
        timens+=random.randint(0,10000)
        bids.append([random.randint(pricerange[1],pricerange[2]),random.randint(quantrange[0],quantrange[1]),orderid-1,traderid-1,times,timens])
        times+=random.randint(0,1)
        timens+=random.randint(0,10000)
        orderid-=2
        traderid-=2
    bids=jnp.concatenate((jnp.array(bids),jnp.ones((order_capacity-qtofill,6),dtype=jnp.int32)*-1),axis=0)
    asks=jnp.concatenate((jnp.array(asks),jnp.ones((order_capacity-qtofill,6),dtype=jnp.int32)*-1),axis=0)
    trades=jnp.ones((trade_capacity,6),dtype=jnp.int32)*-1
    return asks,bids,trades

def create_message(type='limit',side='bid',price=2200000,quant=10,times=36000,timens=0,id=8888):
    """
    Generates a specific message (based on human-editable inputs)
    Outputs both the 'dictionary' format 
    """
    if type=='limit':
        type_num=1
    elif type =='cancel' or type == 'delete':
        type_num=2
    elif type =='market':
        type_num=4
    else:
        raise ValueError('Type is none of: limit, cancel, delete or market')

    if side=='bid':
        side_num=1
    elif side =='ask':
        side_num=-1
    else:
        raise ValueError('Side is none of: bid or ask')
    
    dict_msg={
        'side':side_num,
        'type':type_num,
        'price':price,
        'quantity':quant,
        'orderid':id,
        'traderid':id,
        'time':times,
        'time_ns':timens}
    array_msg=jnp.array([type_num,side_num,quant,price,8888,8888,times,timens])
    return dict_msg,array_msg

def create_message_forvmap(type='limit',side='bid',price=2200000,quant=10,times=36000,timens=0,nvmap=10):
    if type=='limit':
        type_num=1
    elif type =='cancel' or type == 'delete':
        type_num=2
    elif type =='market':
        type_num=4
    else:
        raise ValueError('Type is none of: limit, cancel, delete or market')

    if side=='bid':
        side_num=1
    elif side =='ask':
        side_num=-1
    else:
        raise ValueError('Side is none of: bid or ask')
    
    dict_msg={
    'side':jnp.array([side_num]*nvmap),
    'type':jnp.array([type_num]*nvmap),
    'price':jnp.array([price]*nvmap),
    'quantity':jnp.array([quant]*nvmap),
    'orderid':jnp.array([8888]*nvmap),
    'traderid':jnp.array([8888]*nvmap),
    'time':jnp.array([times]*nvmap),
    'time_ns':jnp.array([timens]*nvmap)}
    array_msg=jnp.array([type_num,side_num,quant,price,8888,8888,times,timens]*nvmap)
    return dict_msg,array_msg

Global variables used for timing operations. Under 'timeit' function, the target code is run and timed for n_runs, then this operation is repeated for n_repeats. 

The vector of capacities is used to illustrate the effect of the maximum number of orders and trades which can be stored in their respective arrays.

nvmap illustrates the number of book that are processed in parallel when using the vmap operator. 

In [29]:
import gymnax_exchange.jaxob
import gymnax_exchange.jaxob.jaxob_config


n_runs=1000
n_repeats=10
order_and_trade_capacities=[10]
vmap_tests=False
nvmap=1000

job_config=gymnax_exchange.jaxob.jaxob_config.Configuration()
job_config.cancel_mode.value

2

## ADD operation

The add operation (add_order) is called when a limit order is submitted and has some remaining unmatched quantitiy after matching. 

In [30]:
## Add an order

random.seed(0)
addout=[]
bid_inits=[]
for s in order_and_trade_capacities:
    asks,bids,trades=create_init_book(job_config,order_capacity=s)
    bid_inits.append(bids)
    mdict,marray=create_message(type='limit',side='bid',price=2191200,quant=77)
    out=job.add_order(bids,mdict)
    addout.append(out)
    #res=np.array(timeit.repeat('val=job.add_order(bids,mdict); jax.block_until_ready(val)',repeat=n_repeats,number=n_runs,globals=globals()))
    #mu=np.mean(res/n_runs)
    #sigma=np.std(res/n_runs)
    #mi=np.min(res/n_runs)
    #print("Add time for orderbook of size",s,":",mu,"Stdev: ", sigma," Min: ",mi)

print("\n Bid side after initialisation")
print(bid_inits[0])
print("\n Bid side state after adding an order.")
print(addout[0])





if vmap_tests:
    random.seed(0)
    for i,s in enumerate(order_and_trade_capacities):
        asks,bids,trades=create_init_book(order_capacity=s)
        vmdict,marray=create_message_forvmap(type='limit',side='bid',price=2191200,quant=77,nvmap=nvmap)
        vbids=jnp.array([bids]*nvmap)
        out=job.add_order(bids,mdict)
        outv=jax.vmap(job.add_order,(0,{'orderid': 0, 'price': 0, 'quantity': 0, 'side': 0, 'time': 0, 'time_ns': 0, 'traderid': 0, 'type': 0}))(vbids,vmdict)
        res=np.array(timeit.repeat("val=jax.vmap(job.add_order,(0,{'orderid': 0, 'price': 0, 'quantity': 0, 'side': 0, 'time': 0, 'time_ns': 0, 'traderid': 0, 'type': 0}))(vbids,vmdict); jax.block_until_ready(val)",number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        mi=np.min(res/n_runs)
        print("VMAP add time for orderbook of size",s," \n various incoming order sizes:",mu,"Stdev: ", sigma," Min: ",mi)







 Bid side after initialisation
[[2204242     494 -900001 -900001   34201     663]
 [2209558     456 -900003 -900003   34203   13163]
 [2204104     465 -900005 -900005   34203   22984]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]]

 Bid side state after adding an order.
[[2204242     494 -900001 -900001   34201     663]
 [2209558     456 -900003 -900003   34203   13163]
 [2204104     465 -900005 -900005   34203   22984]
 [2191200      77    8888    8888   36000       0]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1

# CANCEL operation

Removes the desired quantity at the given price level. If the quantity is totally removed, removes the order and frees up the space. 

In [31]:
random.seed(0)
cancelout=[]
for i,s in enumerate(order_and_trade_capacities):
    bids=addout[i]
    mdict,marray=create_message(type='cancel',side='bid',price=2191200,quant=66,id=8688)
    out=job.cancel_order(job_config,bids,mdict)
    cancelout.append(out)
    #res=np.array(timeit.repeat('val=job.cancel_order(job_config,bids,mdict); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    #mu=np.mean(res/n_runs)
    #sigma=np.std(res/n_runs)
    #print("Cancel time for orderbook of size",s,":",mu,"Stdev: ", sigma)

print("\n Bid side state after cancellation")
print(cancelout[0])


 Bid side state after cancellation
[[2204242     494 -900001 -900001   34201     663]
 [2209558     456 -900003 -900003   34203   13163]
 [2204104     465 -900005 -900005   34203   22984]
 [2191200      11    8888    8888   36000       0]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]
 [     -1      -1      -1      -1      -1      -1]]


In [None]:
import gymnax_exchange.jaxob.jaxob_constants


job_config.cancel_mode.value==gymnax_exchange.jaxob.jaxob_constants.CancelMode.CANCEL_UNIFORM.value

# MATCH operation
Matching a single very large order against an identified order (and price) from the other side of the book. There will still be a remaining quantity to match after this matching operation.

In [None]:
matchout=[]
for i,s in enumerate(order_and_trade_capacities):
    #Doesn't actually get used: just initialisation of empty list of trades. 
    _,_,trades=create_init_book(order_capacity=s)

    bids=cancelout[i]
    idx=job._get_top_bid_order_idx(bids)
    print("Time to get top bid order for order book of size ",s,":",timeit.timeit('val=job._get_top_bid_order_idx(bids); jax.block_until_ready(val)',number=n_runs,globals=globals())/n_runs)

    matchtuple=(idx,bids,1000,0,trades,9999,36000,1)
    bids,qtm,price,trades,agrid,times,timens=job.match_order(matchtuple)
    
    matchout.append((bids,qtm,trades))
    res=np.array(timeit.repeat('val=job.match_order(matchtuple); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Match time for orderbook of size",s,":",mu,"Stdev: ", sigma)


print("\n Bid side state after matching the order.")
print(matchout[0][0])
print("\n Trades logged after the matching operation")
print(matchout[0][2])
print("\n Quantity still to match (or add) from the incoming order after matching this order")
print(matchout[0][1])

Match against an entire side, by using a price of 0 (inf if matching against ask orders). 

In this case the incoming quantity will have an effect on the time to match. 

In [None]:
matchout=[]

for i,s in enumerate(order_and_trade_capacities):
    for j,q in enumerate([0,10,500,1000,10000]):
        _,_,trades=create_init_book(order_capacity=s,trade_capacity=s)

        bids=cancelout[i]

        matchtuple=(bids,q,0,trades,9999,36000,1)
        bids,qtm,price,trades=job._match_against_bid_orders(*matchtuple)
        
        matchout.append((bids,qtm,trades))
        res=np.array(timeit.repeat('val=job._match_against_bid_orders(*matchtuple); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        print("Match time for orderbook of size",s," \n with an incoming order of size",q,":",mu,"Stdev: ", sigma)

if vmap_tests:
    for i,s in enumerate(order_and_trade_capacities):
        _,_,trades=create_init_book(order_capacity=s,trade_capacity=s)
        bids=cancelout[i]

        vbids=jnp.array([bids]*nvmap)
        vtrades=jnp.array([trades]*nvmap)
        vq=jnp.array([100,100,100,100,100]*(nvmap//5))

        matchtuple=(vbids,vq,0,vtrades,9999,36000,1)
        jax.vmap(job._match_against_bid_orders,((0,0,None,0,None,None,None)))(*matchtuple)
        
        matchout.append((bids,qtm,trades))
        res=np.array(timeit.repeat('val=jax.vmap(job._match_against_bid_orders,((0,0,None,0,None,None,None)))(*matchtuple); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        
        print("VMAP Match time for orderbook of size",s," \n various incoming order sizes:",mu,"Stdev: ", sigma)




# High - Level Message Types

## Passive Limit Orders
## Cancel Orders
## Aggressive Limit Orders



In [None]:
random.seed(0)
outs=[]
for i in order_and_trade_capacities:
    asks,bids,trades=create_init_book(order_capacity=i,trade_capacity=i)
    _,limitmsg=create_message(type='limit',side='bid',price=2191200,quant=77)
    _,cancelmsg=create_message(type='cancel',side='bid',price=2191200,quant=77)
    _,matchmsg=create_message(type='limit',side='ask',price=2191200,quant=100)

    bid_limit,_=job.cond_type_side((asks,bids,trades),limitmsg)
    res=np.array(timeit.repeat('val=job.cond_type_side((asks,bids,trades),limitmsg); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Limit order for book of size ",i,":",mu,"Stdev: ", sigma)
    bid_cancel,_=job.cond_type_side(bid_limit,cancelmsg)
    res=np.array(timeit.repeat('val=job.cond_type_side(out,cancelmsg); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Cancel order for book of size ",i,":",mu,"Stdev: ", sigma)
    bid_cross_limit,_=job.cond_type_side(bid_cancel,matchmsg)
    res=np.array(timeit.repeat('val=job.cond_type_side(out,matchmsg); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
    mu=np.mean(res/n_runs)
    sigma=np.std(res/n_runs)
    print("Matched limit order for book of size ",i,":",mu,"Stdev: ", sigma)
    outs.append((bids,bid_limit[1],bid_cancel[1],bid_cross_limit[1]))

    if vmap_tests:
        vasks=jnp.array([asks]*nvmap)
        vbids=jnp.array([bids]*nvmap)
        vtrades=jnp.array([trades]*nvmap)
        vlimitms=jnp.array([limitmsg]*nvmap)
        vcancelms=jnp.array([cancelmsg]*nvmap)
        vmatchms=jnp.array([matchmsg]*nvmap)

        out,_=jax.vmap(job.cond_type_side,((0,0,0),0))((vasks,vbids,vtrades),vlimitms)
        res=np.array(timeit.repeat('val=jax.vmap(job.cond_type_side,((0,0,0),0))((vasks,vbids,vtrades),vlimitms); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        print("VMAP limit order for book of size ",i,":",mu,"Stdev: ", sigma)
        out,_=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vcancelms)
        res=np.array(timeit.repeat('val=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vcancelms); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        print("VMAP cancel order for book of size ",i,":",mu,"Stdev: ", sigma)
        out,_=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vmatchms)
        res=np.array(timeit.repeat('val=jax.vmap(job.cond_type_side,((0,0,0),0))(out,vmatchms); jax.block_until_ready(val)',number=n_runs,repeat=n_repeats,globals=globals()))
        mu=np.mean(res/n_runs)
        sigma=np.std(res/n_runs)
        print("VMAP matched limit order for book of size ",i,":",mu,"Stdev: ", sigma)

print("\n Bid side initial state.")
print(outs[0][0])
print("\n Bid side after adding an order.")
print(outs[0][1])
print("\n Bid side after cancellation")
print(outs[0][2])
print("\n Bid side after incoming sell limit order ")
print(outs[0][3])
    
