Skip to content

Commit

Permalink
[Databases] allow multi exchange metadatabases
Browse files Browse the repository at this point in the history
  • Loading branch information
Guillaume De Saint Martin authored and GuillaumeDSM committed Feb 20, 2023
1 parent 27f7598 commit 281ce08
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 91 deletions.
139 changes: 139 additions & 0 deletions octobot_commons/databases/implementations/_exchange_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# pylint: disable=R0902,C0103
# Drakkar-Software OctoBot-Commons
# Copyright (c) Drakkar-Software, All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.
import os

import octobot_commons.databases.implementations.db_writer_reader as db_writer_reader


class ExchangeDatabase:
def __init__(self, meta_database, exchange):
self.meta_database = meta_database
self.run_dbs_identifier = self.meta_database.run_dbs_identifier
self.exchange = exchange
self.orders_db: db_writer_reader.DBWriterReader = None
self.trades_db: db_writer_reader.DBWriterReader = None
self.transactions_db: db_writer_reader.DBWriterReader = None
self.historical_portfolio_value_db: db_writer_reader.DBWriterReader = None
self.symbol_dbs: dict = {}

def get_orders_db(self, account_type):
"""
:return: the orders database. Opens it if not open already
"""
if self.orders_db is None:
self.orders_db = self.meta_database.get_db(
self.run_dbs_identifier.get_orders_db_identifier(
account_type,
self.exchange,
)
)
return self.orders_db

def get_trades_db(self, account_type):
"""
:return: the trades database. Opens it if not open already
"""
if self.trades_db is None:
self.trades_db = self.meta_database.get_db(
self.run_dbs_identifier.get_trades_db_identifier(
account_type,
self.exchange,
)
)
return self.trades_db

def get_transactions_db(self, account_type):
"""
:return: the transactions database. Opens it if not open already
"""
if self.transactions_db is None:
self.transactions_db = self.meta_database.get_db(
self.run_dbs_identifier.get_transactions_db_identifier(
account_type,
self.exchange,
)
)
return self.transactions_db

def get_historical_portfolio_value_db(self, account_type):
"""
:return: the historical portfolio database. Opens it if not open already
"""
if self.historical_portfolio_value_db is None:
self.historical_portfolio_value_db = self.meta_database.get_db(
self.run_dbs_identifier.get_historical_portfolio_value_db_identifier(
account_type, self.exchange
)
)
return self.historical_portfolio_value_db

def get_symbol_db(self, symbol):
"""
:return: the symbol database. Opens it if not open already
"""
key = self._get_symbol_db_key(self.exchange, symbol)
if key not in self.symbol_dbs:
self.symbol_dbs[key] = self.meta_database.get_db(
self.run_dbs_identifier.get_symbol_db_identifier(self.exchange, symbol)
)
return self.symbol_dbs[key]

async def get_all_symbol_dbs(self):
"""
:return: an iterable over each symbol database for the given exchange
"""
if self.run_dbs_identifier.database_adaptor.is_file_system_based():
return [
self.get_symbol_db(self.run_dbs_identifier.get_symbol_db_name(db.name))
for db in os.scandir(
self.run_dbs_identifier.get_exchange_based_identifier(self.exchange)
)
if self.run_dbs_identifier.is_symbol_database(db.name)
]
raise NotImplementedError(
"get_all_symbol_dbs is not implemented for non is_file_system_based databases"
)

def all_basic_run_db(self, account_type):
"""
yields the run, orders, trades and transactions databases
"""
yield self.get_orders_db(account_type)
yield self.get_trades_db(account_type)
yield self.get_transactions_db(account_type)

@staticmethod
def _get_symbol_db_key(exchange, symbol):
return f"{exchange}{symbol}"

async def close(self):
"""
Closes all the open databases
"""
# avoid asyncio.gather here as it is producing unexplained side effects (frozen thread preventing stop)
for coro in (
db.close()
for db in (
self.orders_db,
self.trades_db,
self.transactions_db,
self.historical_portfolio_value_db,
*self.symbol_dbs.values(),
)
if db is not None
):
await coro
145 changes: 54 additions & 91 deletions octobot_commons/databases/implementations/meta_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# You should have received a copy of the GNU Lesser General Public
# License along with this library.
import contextlib
import os

import octobot_commons.databases.implementations.db_writer_reader as db_writer_reader
import octobot_commons.databases.implementations._exchange_database as _exchange_database
import octobot_commons.enums as enums


Expand All @@ -28,80 +28,25 @@ def __init__(self, run_dbs_identifier, with_lock=False, cache_size=None):
self.cache_size = cache_size
self.database_adaptor = self.run_dbs_identifier.database_adaptor
self.run_db: db_writer_reader.DBWriterReader = None
self.orders_db: db_writer_reader.DBWriterReader = None
self.trades_db: db_writer_reader.DBWriterReader = None
self.transactions_db: db_writer_reader.DBWriterReader = None
self.historical_portfolio_value_db: db_writer_reader.DBWriterReader = None
self.backtesting_metadata_db: db_writer_reader.DBWriterReader = None
self.symbol_dbs: dict = {}
self.exchange_dbs = {}

def get_run_db(self):
"""
:return: the run database. Opens it if not open already
"""
if self.run_db is None:
self.run_db = self._get_db(
self.run_db = self.get_db(
self.run_dbs_identifier.get_run_data_db_identifier()
)
return self.run_db

def get_orders_db(self, account_type, exchange=None):
"""
:return: the orders database. Opens it if not open already
"""
if self.orders_db is None:
self.orders_db = self._get_db(
self.run_dbs_identifier.get_orders_db_identifier(
account_type,
exchange or self.run_dbs_identifier.context.exchange_name,
)
)
return self.orders_db

def get_trades_db(self, account_type, exchange=None):
"""
:return: the trades database. Opens it if not open already
"""
if self.trades_db is None:
self.trades_db = self._get_db(
self.run_dbs_identifier.get_trades_db_identifier(
account_type,
exchange or self.run_dbs_identifier.context.exchange_name,
)
)
return self.trades_db

def get_transactions_db(self, account_type, exchange=None):
"""
:return: the transactions database. Opens it if not open already
"""
if self.transactions_db is None:
self.transactions_db = self._get_db(
self.run_dbs_identifier.get_transactions_db_identifier(
account_type,
exchange or self.run_dbs_identifier.context.exchange_name,
)
)
return self.transactions_db

def get_historical_portfolio_value_db(self, account_type, exchange):
"""
:return: the historical portfolio database. Opens it if not open already
"""
if self.historical_portfolio_value_db is None:
self.historical_portfolio_value_db = self._get_db(
self.run_dbs_identifier.get_historical_portfolio_value_db_identifier(
account_type, exchange
)
)
return self.historical_portfolio_value_db

def get_backtesting_metadata_db(self):
"""
:return: the backtesting metadata database. Opens it if not open already
"""
if self.backtesting_metadata_db is None:
self.backtesting_metadata_db = self._get_db(
self.backtesting_metadata_db = self.get_db(
self.run_dbs_identifier.get_backtesting_metadata_identifier()
)
return self.backtesting_metadata_db
Expand All @@ -118,49 +63,70 @@ async def get_backtesting_metadata_from_run(self):
)
)[0]

def _get_exchange_db(self, exchange=None):
"""
:return: the ExchangeDatabase associated to the given exchange
"""
exchange = exchange or self.run_dbs_identifier.context.exchange_name
try:
return self.exchange_dbs[exchange]
except KeyError:
self.exchange_dbs[exchange] = _exchange_database.ExchangeDatabase(
self, exchange
)
return self.exchange_dbs[exchange]

def get_orders_db(self, account_type, exchange=None):
"""
:return: the orders database. Opens it if not open already
"""
return self._get_exchange_db(exchange).get_orders_db(account_type)

def get_trades_db(self, account_type, exchange=None):
"""
:return: the trades database. Opens it if not open already
"""
return self._get_exchange_db(exchange).get_trades_db(account_type)

def get_transactions_db(self, account_type, exchange=None):
"""
:return: the transactions database. Opens it if not open already
"""
return self._get_exchange_db(exchange).get_transactions_db(account_type)

def get_historical_portfolio_value_db(self, account_type, exchange):
"""
:return: the historical portfolio database. Opens it if not open already
"""
return self._get_exchange_db(exchange).get_historical_portfolio_value_db(
account_type
)

def get_symbol_db(self, exchange, symbol):
"""
:return: the symbol database. Opens it if not open already
"""
key = self._get_symbol_db_key(exchange, symbol)
if key not in self.symbol_dbs:
self.symbol_dbs[key] = self._get_db(
self.run_dbs_identifier.get_symbol_db_identifier(exchange, symbol)
)
return self.symbol_dbs[key]
return self._get_exchange_db(exchange).get_symbol_db(symbol)

async def get_all_symbol_dbs(self, exchange):
"""
:return: an iterable over each symbol database for the given exchange
"""
if self.run_dbs_identifier.database_adaptor.is_file_system_based():
return [
self.get_symbol_db(
exchange, self.run_dbs_identifier.get_symbol_db_name(db.name)
)
for db in os.scandir(
self.run_dbs_identifier.get_exchange_based_identifier(exchange)
)
if self.run_dbs_identifier.is_symbol_database(db.name)
]
raise NotImplementedError(
"get_all_symbol_dbs is not implemented for non is_file_system_based databases"
)
return self._get_exchange_db(exchange).get_all_symbol_dbs()

def all_basic_run_db(self, account_type, exchange=None):
"""
yields the run, orders, trades and transactions databases
"""
yield self.get_run_db()
yield self.get_orders_db(account_type, exchange)
yield self.get_trades_db(account_type, exchange)
yield self.get_transactions_db(account_type, exchange)
exchange = exchange or self.run_dbs_identifier.context.exchange_name
for db in self.exchange_dbs[exchange].all_basic_run_db(account_type):
yield db

@staticmethod
def _get_symbol_db_key(exchange, symbol):
return f"{exchange}{symbol}"

def _get_db(self, db_identifier):
def get_db(self, db_identifier):
"""
:return: the database associated to the given identifier
"""
return db_writer_reader.DBWriterReader(
db_identifier,
with_lock=self.with_lock,
Expand All @@ -178,16 +144,13 @@ async def close(self):
db.close()
for db in (
self.run_db,
self.orders_db,
self.trades_db,
self.transactions_db,
self.historical_portfolio_value_db,
self.backtesting_metadata_db,
*self.symbol_dbs.values(),
)
if db is not None
):
await coro
for exchange_db in self.exchange_dbs.values():
await exchange_db.close()

@classmethod
@contextlib.asynccontextmanager
Expand Down

0 comments on commit 281ce08

Please sign in to comment.