In [9]:
import tensorflow as tf
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt

from pathlib import Path
from param_search import (generate_configs, 
                          train_model, 
                          export_model_outputs, 
                          plot_all_model_metrics, 
                          compare_models, 
                          plot_raw_lob_snapshot)
from lob_gan import plot_lob_snapshot, Config

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
# NOTE: to quiet TensorFlow's C++ logging, TF_CPP_MIN_LOG_LEVEL generally has to be set
# in the environment before `import tensorflow` runs, so it is not toggled from this cell.

# Reset Keras global state left over from earlier runs in this kernel.
tf.keras.backend.clear_session()

# Keep the device query as the cell's last expression so the list of visible GPUs is
# displayed instead of being silently discarded.
tf.config.list_physical_devices('GPU')

In [28]:
# Grid to generate: "baseline", "fm_only", "mbd_only", "ls_only" or "all".
GRID_TYPE = "all"
FILE_PATH = Path("BTCUSDT-lob.parq")
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(exist_ok=True)  # top-level results directory

# Load the parquet data set and build the list of configs to train.
raw_data = pd.read_parquet(path=FILE_PATH, engine="pyarrow")
configs = generate_configs(GRID_TYPE, seed=5204, size=20)

In [None]:
# Train one model per generated config, handing OUTPUT_DIR to train_model as the output location.
for config in configs:
    train_model(config=config, raw_data=raw_data, output_dir=OUTPUT_DIR)

  2%|▏         | 1/64 [00:04<04:50,  4.60s/it]


KeyboardInterrupt: 

## Finding the best model

### Baseline:

In [43]:
# Which configuration grid to generate: "baseline", "fm_only", "mbd_only", "ls_only" or "all".
GRID_TYPE = 'baseline'
FILE_PATH = Path('BTCUSDT-lob.parq')
OUTPUT_DIR = Path('output/baseline')
# parents=True so this section also runs when the top-level 'output' directory does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

raw_data = pd.read_parquet(FILE_PATH, engine='pyarrow')
configs = generate_configs(GRID_TYPE, size=20, seed=1324)

# train_model reports "results cached, skipping <key>" for configs whose results already exist.
for config in configs:
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

results cached, skipping 782266732070bf9b38ae6b61acb445ab90bf37e1bfb008c05142655118d210e8
results cached, skipping 23642808fe1e973da7d4debc13cd4c6db059b0c34addbd81bab4a512432a75f9
results cached, skipping 8d191b40138c8af9c56e2c9cf23cccc885c2b38cae1fa10e3bbf25ba725747fd
results cached, skipping b0acb1ad6379ac44dc48e7b71257b7f3c8a8d1df254a2a0668f9c6e146d346dc
results cached, skipping 674e28daf3eb5b10aa917aa4db20b0ba9ff97bc53630f4e8ecbf3a61a70378d2
results cached, skipping 873bfc820e3c8d25bdee31d6db0fc5599c58be05476b500d0b29551fa36e8f4c
results cached, skipping 97f974b70002f3d907d36edd97a6428ac2051713db4e572674696cbb020ea340
results cached, skipping a1d37f4983ff7ec89aa5d36e8b89d0871c523ea7379a4f3c1408207c1c9895c8
results cached, skipping 04d2629b0b6a9d2345621b91b5ec509d1b2734033ceefa540479e88340e0cdd7
results cached, skipping 522afcf3db34c52d20272e05157cd2b882ec9c1f5eb30e2c8f0ef59e51c3bcfc
results cached, skipping eab38fa4e20840ad2707664fa8f7c690b30469bb526045eb8a309133341f527f
results ca

In [None]:
# Export per-model artifacts for everything under OUTPUT_DIR, then plot the metrics of all models there.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)

### fm_only

In [44]:
# Which configuration grid to generate: "baseline", "fm_only", "mbd_only", "ls_only" or "all".
GRID_TYPE = 'fm_only'
FILE_PATH = Path('BTCUSDT-lob.parq')
OUTPUT_DIR = Path('output/fm_only')
# parents=True so this section also runs when the top-level 'output' directory does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

raw_data = pd.read_parquet(FILE_PATH, engine='pyarrow')
configs = generate_configs(GRID_TYPE, size=20, seed=1324)

# train_model reports "results cached, skipping <key>" for configs whose results already exist.
for config in configs:
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

results cached, skipping 9833c1a47b0ccd239c9f26154fe69491523a822c61ffe329f99661807c8382dc
results cached, skipping 17683ba6902ae563d678fc32dbc3430a464842109b4f4b74bab53c74a79b6336
results cached, skipping 42771ae4de8cf5906a7144f83d8801cf1c86ee8e8b84ccf57af7c5b519799905
results cached, skipping 96fb37450ef6b572802c8eceecb9104b9da6447609d6e7ed08c772d188bab442
results cached, skipping 93d3df10599567331d4db94c1640d16c4d8390497d13bc8542cbf626be3635f4
results cached, skipping 0728f000c320fe9df8084d38e56818879753419fb3b7e88eb9398bf3e598f884
results cached, skipping 122a592c27c07f3982b54face802f01e8b5fdb4c02f84ae553a204adc0c5e0e6
results cached, skipping 7dc96d8e907dac1a0ac2b57118690a9313216a04b8c736fc231f1a010df484ce
results cached, skipping d3399ec976044057d512c65c7a688877aa6fa1930d9aaed6fa99271e120f6746
results cached, skipping a7ae6b79d6b6eb57505b3ea8cad8ce30bb08b53bfa26befc980e9a578b8ee0a0
results cached, skipping f4b80fd023e96acc8a066bf8c1b3569fabccdc5802bb973d0c9ad330bab0b993
results ca

In [None]:
# Export per-model artifacts for the fm_only runs under OUTPUT_DIR, then plot their metrics.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)

### mbd_only

In [45]:
# Which configuration grid to generate: "baseline", "fm_only", "mbd_only", "ls_only" or "all".
GRID_TYPE = 'mbd_only'
FILE_PATH = Path('BTCUSDT-lob.parq')
OUTPUT_DIR = Path('output/mbd_only')
# parents=True so this section also runs when the top-level 'output' directory does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

raw_data = pd.read_parquet(FILE_PATH, engine='pyarrow')
configs = generate_configs(GRID_TYPE, size=20, seed=1324)

# train_model reports "results cached, skipping <key>" for configs whose results already exist.
for config in configs:
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

results cached, skipping 53cf051728a7959012f26df249a67fd8867a61eee064d4b9f8196437625057a8
results cached, skipping ba18a2d63c6c28d537e76037ba5247ff88f0213acd0f7b6a79b8749bba2945da
results cached, skipping d2d0ee2df0c9a11b4167c463ab64f2cd204a75c9e08a3049f4ecfb4fd3ab1171
results cached, skipping 84b633bfd3b6db0ddbc290d866ef6184dc0a7796f8b60102df1d3b611b28afa4
results cached, skipping 63b7a8c663d5701a1ceaa095421ee77536dd1869f6f45e29f1ba646e2faa11aa
results cached, skipping 7051d002354243f9ee7392dde8c7bd7587d85d135043cb5b75c550248d8141e8
results cached, skipping 93470e159db8232dc5e09e203355c20aa14116546a3a3247f0ca37911b8fa0d0
results cached, skipping ca5bc7b845bafa968304e5d3a63b68dcfcf1ba56027c3e6e0f85d922aa185974
results cached, skipping a94f89dbfdb60ce3f5f69efd78d78af3a5fb655451aa6200c867793b349efd9d
results cached, skipping dfa741d924b54b11b286f733e23e094b8c807ee416ea698be7d9679e091abdd6
results cached, skipping f2fed4f9d93e0471736f83038820036d86dec77fd814b34bbc8503eb6e66f740
results ca

In [None]:
# Export per-model artifacts for the mbd_only runs under OUTPUT_DIR, then plot their metrics.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)

### ls_only

In [46]:
# Which configuration grid to generate: "baseline", "fm_only", "mbd_only", "ls_only" or "all".
GRID_TYPE = 'ls_only'
FILE_PATH = Path('BTCUSDT-lob.parq')
OUTPUT_DIR = Path('output/ls_only')
# parents=True so this section also runs when the top-level 'output' directory does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

raw_data = pd.read_parquet(FILE_PATH, engine='pyarrow')
configs = generate_configs(GRID_TYPE, size=20, seed=1324)

# train_model reports "results cached, skipping <key>" for configs whose results already exist.
for config in configs:
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

results cached, skipping 5df3f79dc559a2341e86b8c1e5d0df8def5835400edccbd6b58db6f8d32d6e0c
results cached, skipping b98b7563d6527434396204470cf01b1df6acb79c7a13d0095ec65cc8c5b758ca
results cached, skipping b7b8af653ba8cfb1bce11dc384732cc93e8b2076caf7681cbbf878d145f32bae
results cached, skipping e7bc4df88db9183aa842e9814cfc4b78a34be97c8d0be74a8439f3926bd4307f
results cached, skipping 85f9ed577044b6c854d6be12bf6338d051a9c0ed764e0a787cb1ac61d49635b5
results cached, skipping 376c5a336cb5a52e2d0d03403add2063db73f183e8ec37e20b7777bd5c61e940
results cached, skipping 135648a51c041e913cd6337b71888b77db60fb5c01e1f518aa305dca14b9a3df
results cached, skipping 5a453140f7914e8e47cbd596236bd54498c24e1742017cb8d413e68a7a17f518
results cached, skipping 6c37abe80194e7d1f175515d44a712965c216ee01138c1f9c6aaf960a04995f0
results cached, skipping effdf509f58b9c2c0d83230f9c6ffea3d8339f0611172a5fa5319acec92aa143
results cached, skipping d045b91d945771d7b96b21c9481c92cfa670fdbadc2625834664ade83c6ba544
results ca

In [None]:
# Export per-model artifacts for the ls_only runs under OUTPUT_DIR, then plot their metrics.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)

### All modifications (optional)

In [49]:
# Which configuration grid to generate: "baseline", "fm_only", "mbd_only", "ls_only" or "all".
GRID_TYPE = 'all'
FILE_PATH = Path('BTCUSDT-lob.parq')
OUTPUT_DIR = Path('output/all')
# parents=True so this section also runs when the top-level 'output' directory does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

raw_data = pd.read_parquet(FILE_PATH, engine='pyarrow')
configs = generate_configs(GRID_TYPE, size=20, seed=1324)

# train_model reports "results cached, skipping <key>" for configs whose results already exist.
for config in configs:
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

results cached, skipping a0b6545791e353d0ed3cc1af59cf28930ae1cf7c0ef129517bb2faaa8e18832f
results cached, skipping 8c93e40e1e0ae76d77bc66fa4d5cbbf38845515abd72e50e8c9df0275c4364c1
results cached, skipping 7b77c328115122b476937e54b59d899761cfecb9e786ca7191545989c5994b5f
results cached, skipping 31be9b6cc0383479b8acac5d2327b8a34dfbcda4a2230d6ff77c5a48905012f8
results cached, skipping a8aa596e37e2f737c5d11c83ed330a472dab336ad730e3ec1b97ec85ddf26538
results cached, skipping 843a6c06313abc3f19b1f2ddf66638ee2071ab2a6c66b9ca5d9179a703628f44
results cached, skipping d1226fbed854405821825de848358a7374ec1de93a02bc3fa00a7e7c78e5fab5


100%|██████████| 64/64 [00:11<00:00,  5.49it/s]


Training completed in 0.20 minutes
finished trainining c93f3050319e7259a7e4a6d045c5d667077f9b2392f4a050429915aaa1c82c09
results cached, skipping fbaeced835153c98577173762807d2ad7b3999a5ed5f0e80237c59b1b5734e16
results cached, skipping 1716510d4ff36a702349678198b77b3ca911138bbaa8f697d946b2c4241b7232
results cached, skipping d4d8501ea76aed63962c2e41153aed45c8d6201223078fab167fa952aaf5f559
results cached, skipping a5ad7a0eb24e878a15d505105d1460da7326b5e76178024356b17595f6122b0b
results cached, skipping 9d5357ff8f13992327ab1b03ac9f1781c0953d035ee87a68387a3829c4b156a9
results cached, skipping 052608318b21fd4c44f649431f558c5aff17d4e5475ad2d523a7894b5255c379
results cached, skipping c1d0bc955c7fcd4c7d933ce33c29ce104e0a51ed1e5b3c2b82255ad1480ab97b
results cached, skipping 0a12bfd0bac5261953048c7c4ed08d93856f1f3861ca01b606f06514ea9911cf
results cached, skipping 428d7bd109dd83e29f999002738dd53fb13c38ff8d7947c07ebe0e111db91daf
results cached, skipping bfc7fa621ae7f122a66fd4f7086820a7afcbd4e4b5a86

In [50]:
# For each model directory under OUTPUT_DIR, export its config, metrics and LOB snapshot
# images; directories missing config.json or metrics.parq are reported and skipped.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)

  ax.axhline(0, color='k', alpha=0.2)



Processing 052608318b21fd4c44f649431f558c5aff17d4e5475ad2d523a7894b5255c379...
Exported config
Exported metrics
Exported 65 LOB snapshots to images

Processing 0a12bfd0bac5261953048c7c4ed08d93856f1f3861ca01b606f06514ea9911cf...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 1716510d4ff36a702349678198b77b3ca911138bbaa8f697d946b2c4241b7232...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 31be9b6cc0383479b8acac5d2327b8a34dfbcda4a2230d6ff77c5a48905012f8...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 428d7bd109dd83e29f999002738dd53fb13c38ff8d7947c07ebe0e111db91daf...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 7b77c328115122b476937e54b59d899761cfecb9e786ca7191545989c5994b5f...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 843a6c06313abc3f19b1f2ddf66638ee2071ab2a6c66b9ca5d9179a703628f44...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 8c93e40e1e0ae76d77bc66fa4d5cbbf38845515abd72e50e8c9df0275c4364c1...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 93f3050319e7259a7e4a6d045c5d667077f9b2392f4a050429915aaa1c82c09...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing 9d5357ff8f13992327ab1b03ac9f1781c0953d035ee87a68387a3829c4b156a9...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing a0b6545791e353d0ed3cc1af59cf28930ae1cf7c0ef129517bb2faaa8e18832f...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing a5ad7a0eb24e878a15d505105d1460da7326b5e76178024356b17595f6122b0b...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing a8aa596e37e2f737c5d11c83ed330a472dab336ad730e3ec1b97ec85ddf26538...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing bfc7fa621ae7f122a66fd4f7086820a7afcbd4e4b5a8665b493a6ac9c0c2674b...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing c1d0bc955c7fcd4c7d933ce33c29ce104e0a51ed1e5b3c2b82255ad1480ab97b...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing c93f3050319e7259a7e4a6d045c5d667077f9b2392f4a050429915aaa1c82c09...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing d1226fbed854405821825de848358a7374ec1de93a02bc3fa00a7e7c78e5fab5...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing d4d8501ea76aed63962c2e41153aed45c8d6201223078fab167fa952aaf5f559...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing dcb97459c40c7f55c09d82d429a3842985852feffb0f4f5a5f8be046f934c9a7...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing e32c48909bff9571d2a17a9e379136386f25a825ba6e6a39b106b7e245a7f6a8...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing fbaeced835153c98577173762807d2ad7b3999a5ed5f0e80237c59b1b5734e16...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing image_Old...
Skipping: Missing config.json or metrics.parq

All model outputs processed.


### Best FM vs all modifications turned on

In [59]:
# Hyperparameters of the best fm_only run (non-zero fm_weight_*, no label smoothing,
# no minibatch discrimination).
d_fm = {
  "disc_lr": 0.008164257207259593,
  "epochs": 64,
  "fm_weight_e": 0.0010744630994051387,
  "fm_weight_h": 0.6169157082064088,
  "gen_lr": 0.006280181923387721,
  "label_smoothing": 0.0,
  "n_batches": 8,
  "sample_size": 20000,
  "seed": 1324,
  "use_minibatch_discrimination": False,
  "z_dim": 256
}

# Same hyperparameters as d_fm, with label smoothing and minibatch discrimination
# switched on as well. Deriving it from d_fm keeps the two configs from drifting apart.
d_all = {
  **d_fm,
  "label_smoothing": 0.0015746702525331782,
  "use_minibatch_discrimination": True,
}

configs = [d_fm, d_all]
OUTPUT_DIR = Path('output/Best FM model vs All modifications')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Uses `raw_data` loaded by an earlier cell of this notebook.
for config_dict in configs:
    config = Config(**config_dict)
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

 70%|███████   | 45/64 [00:07<00:03,  6.22it/s]


KeyboardInterrupt: 

In [None]:
# Export artifacts for both models under OUTPUT_DIR, plot their metrics, then run the
# side-by-side comparison.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)
compare_models(output_root=OUTPUT_DIR)


Processing All Modifications...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

Processing Best FM...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

All model outputs processed.


### Best fm_only vs baseline

In [61]:
# Hyperparameters of the best fm_only run (non-zero fm_weight_*, no label smoothing,
# no minibatch discrimination).
d_fm = {
  "disc_lr": 0.008164257207259593,
  "epochs": 64,
  "fm_weight_e": 0.0010744630994051387,
  "fm_weight_h": 0.6169157082064088,
  "gen_lr": 0.006280181923387721,
  "label_smoothing": 0.0,
  "n_batches": 8,
  "sample_size": 20000,
  "seed": 1324,
  "use_minibatch_discrimination": False,
  "z_dim": 256
}

# Hyperparameters of the best baseline run (fm weights, label smoothing and minibatch
# discrimination all off).
d_baseline = {
  "disc_lr": 0.003497033397477862,
  "epochs": 64,
  "fm_weight_e": 0.0,
  "fm_weight_h": 0.0,
  "gen_lr": 0.0031274082971189366,
  "label_smoothing": 0.0,
  "n_batches": 16,
  "sample_size": 20000,
  "seed": 1324,
  "use_minibatch_discrimination": False,
  "z_dim": 256
}

configs = [d_fm, d_baseline]
OUTPUT_DIR = Path('output/Best FM vs Best Baseline')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Uses `raw_data` loaded by an earlier cell of this notebook.
for config_dict in configs:
    config = Config(**config_dict)
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

100%|██████████| 64/64 [00:09<00:00,  6.41it/s]


Training completed in 0.17 minutes
finished trainining 0c00f834e65025416dd3773c1303d1d99cab405757234944780960c762cb0023


100%|██████████| 64/64 [00:08<00:00,  7.79it/s]

Training completed in 0.14 minutes
finished trainining 1d9ba17ba53ddec69599ee323f75092822f1c0724126e4dfaa20b65fc8d261e2





In [62]:
# Export artifacts for both models under OUTPUT_DIR, plot their metrics, then export the
# comparison plots.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)
compare_models(output_root=OUTPUT_DIR)

  ax.axhline(0, color='k', alpha=0.2)



Processing 0c00f834e65025416dd3773c1303d1d99cab405757234944780960c762cb0023...
Exported config
Exported metrics
Exported 65 LOB snapshots to images

Processing 1d9ba17ba53ddec69599ee323f75092822f1c0724126e4dfaa20b65fc8d261e2...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


Exported 65 LOB snapshots to images

All model outputs processed.
Exported all comparison plots to 'data_compare'


### Best ls_only vs baseline

In [64]:
# Hyperparameters of the best ls_only run (non-zero label_smoothing, fm weights and
# minibatch discrimination off).
d_ls = {
  "disc_lr": 0.0010006778193875844,
  "epochs": 64,
  "fm_weight_e": 0.0,
  "fm_weight_h": 0.0,
  "gen_lr": 0.006341521084444159,
  "label_smoothing": 0.0015746702525331782,
  "n_batches": 32,
  "sample_size": 20000,
  "seed": 1324,
  "use_minibatch_discrimination": False,
  "z_dim": 128
}

# Hyperparameters of the best baseline run (fm weights, label smoothing and minibatch
# discrimination all off).
d_baseline = {
  "disc_lr": 0.003497033397477862,
  "epochs": 64,
  "fm_weight_e": 0.0,
  "fm_weight_h": 0.0,
  "gen_lr": 0.0031274082971189366,
  "label_smoothing": 0.0,
  "n_batches": 16,
  "sample_size": 20000,
  "seed": 1324,
  "use_minibatch_discrimination": False,
  "z_dim": 256
}

configs = [d_ls, d_baseline]
OUTPUT_DIR = Path('output/Best LS vs Best Baseline')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Uses `raw_data` loaded by an earlier cell of this notebook.
for config_dict in configs:
    config = Config(**config_dict)
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

100%|██████████| 64/64 [00:11<00:00,  5.59it/s]


Training completed in 0.19 minutes
finished trainining 376c5a336cb5a52e2d0d03403add2063db73f183e8ec37e20b7777bd5c61e940


100%|██████████| 64/64 [00:08<00:00,  7.43it/s]

Training completed in 0.15 minutes
finished trainining 1d9ba17ba53ddec69599ee323f75092822f1c0724126e4dfaa20b65fc8d261e2





In [None]:
# Export artifacts for both models under OUTPUT_DIR, plot their metrics, then run the
# side-by-side comparison.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)
compare_models(output_root=OUTPUT_DIR)

  ax.axhline(0, color='k', alpha=0.2)



Processing 1d9ba17ba53ddec69599ee323f75092822f1c0724126e4dfaa20b65fc8d261e2...
Exported config
Exported metrics
Exported 65 LOB snapshots to images

Processing 376c5a336cb5a52e2d0d03403add2063db73f183e8ec37e20b7777bd5c61e940...
Exported config
Exported metrics


  ax.axhline(0, color='k', alpha=0.2)


### Best mbd_only vs baseline

In [None]:
# Hyperparameters of the best mbd_only run (minibatch discrimination on, fm weights and
# label smoothing off).
d_mbd = {
  "disc_lr": 0.009104253947742358,
  "epochs": 64,
  "fm_weight_e": 0.0,
  "fm_weight_h": 0.0,
  "gen_lr": 0.0016711948483970965,
  "label_smoothing": 0.0,
  "n_batches": 8,
  "sample_size": 20000,
  "seed": 1324,
  # Python boolean literal; the JSON spelling `true` raises NameError here.
  "use_minibatch_discrimination": True,
  "z_dim": 64
}

# Hyperparameters of the best baseline run (fm weights, label smoothing and minibatch
# discrimination all off).
d_baseline = {
  "disc_lr": 0.003497033397477862,
  "epochs": 64,
  "fm_weight_e": 0.0,
  "fm_weight_h": 0.0,
  "gen_lr": 0.0031274082971189366,
  "label_smoothing": 0.0,
  "n_batches": 16,
  "sample_size": 20000,
  "seed": 1324,
  "use_minibatch_discrimination": False,
  "z_dim": 256
}

configs = [d_mbd, d_baseline]
OUTPUT_DIR = Path('output/Best MBD vs Best Baseline')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Uses `raw_data` loaded by an earlier cell of this notebook.
for config_dict in configs:
    config = Config(**config_dict)
    train_model(output_dir=OUTPUT_DIR, raw_data=raw_data, config=config)

In [None]:
# Export artifacts for both models under OUTPUT_DIR, plot their metrics, then run the
# side-by-side comparison.
export_model_outputs(
    plot_lob_snapshot=plot_lob_snapshot,
    output_root=OUTPUT_DIR,
)
plot_all_model_metrics(output_root=OUTPUT_DIR)
compare_models(output_root=OUTPUT_DIR)

In [None]:
FILE_PATH = Path('BTCUSDT-lob.parq')
# engine='pyarrow' matches the other loads of this file in the notebook.
raw_data = pd.read_parquet(FILE_PATH, engine='pyarrow')

output_folder = Path("raw_lob_snapshots")
output_folder.mkdir(exist_ok=True)

# Save the first rows of the raw data as individual images. The count is clamped to the
# number of rows available so a short data set does not raise IndexError on .iloc.
n_snapshots = min(10, len(raw_data))
for i in range(n_snapshots):
    fig, ax = plt.subplots(figsize=(4, 4))
    plot_raw_lob_snapshot(raw_data.iloc[i], full_df=raw_data, ax=ax)
    fig.tight_layout()
    fig.savefig(output_folder / f"raw_lob_{i:03}.png")
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)