Skip to content

Commit

Permalink
fixes #7, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
timkpaine committed Jun 1, 2019
1 parent 6f58e03 commit ca3196b
Show file tree
Hide file tree
Showing 21 changed files with 208 additions and 141 deletions.
4 changes: 2 additions & 2 deletions aat/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .callback import Callback
from .structs import TradeRequest, TradeResponse
from .enums import TickType
from .exceptions import CallbackException


class DataSource(metaclass=ABCMeta):
Expand Down Expand Up @@ -108,8 +109,7 @@ def onContinue(self, callback: Callback) -> None:

def registerCallback(self, callback: Callback) -> None:
if not isinstance(callback, Callback):
raise Exception('%s is not an instance of class '
'Callback' % callback)
raise CallbackException(f'{callback} is not an instance of class Callback')
for att in ['onTrade',
'onOpen',
'onFill',
Expand Down
8 changes: 8 additions & 0 deletions aat/enums.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import lru_cache
from enum import Enum


Expand Down Expand Up @@ -53,6 +54,9 @@ class CurrencyType(BaseEnum):
NONE = 'NONE' # special, dont use

USD = 'USD'
EUR = 'EUR'
GBP = 'GBP'

USDC = 'USDC'
BAT = 'BAT'
BCH = 'BCH'
Expand Down Expand Up @@ -86,7 +90,11 @@ class _PairType(BaseEnum):
def __str__(self):
return str(self.value[0].value) + '/' + str(self.value[1].value)

def __hash__(self):
return hash(str(self))

@staticmethod
@lru_cache(None)
def from_string(first, second=''):
if second:
c1 = CurrencyType(first)
Expand Down
14 changes: 13 additions & 1 deletion aat/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
class ConfigException(Exception):
class AATException(Exception):
pass


class CallbackException(AATException):
pass


class ConfigException(AATException):
pass


class QueryException(AATException):
pass
14 changes: 1 addition & 13 deletions aat/market_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import abstractmethod
from .data_source import StreamingDataSource
from .define import EXCHANGE_MARKET_DATA_ENDPOINT
from .enums import OrderType, PairType, CurrencyType, TickType
from .enums import OrderType, PairType, TickType
from .structs import MarketData
from .logging import LOG as log

Expand Down Expand Up @@ -86,18 +86,6 @@ def strToTradeType(self, s: str) -> TickType:
def tradeReqToParams(self, req) -> dict:
pass

def currencyToString(self, cur: CurrencyType) -> str:
if cur == CurrencyType.BTC:
return 'BTC'
if cur == CurrencyType.ETH:
return 'ETH'
if cur == CurrencyType.LTC:
return 'LTC'
if cur == CurrencyType.BCH:
return 'BCH'
else:
raise Exception('Pair not recognized: %s' % str(cur))

@abstractmethod
def currencyPairToString(self, cur: PairType) -> str:
pass
Expand Down
65 changes: 51 additions & 14 deletions aat/order_entry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import ccxt
import pandas as pd
from datetime import datetime
from functools import lru_cache
from typing import List
from .data_source import RestAPIDataSource
from .enums import PairType, TradingType, ExchangeType
from .structs import TradeRequest, TradeResponse, Account
from .enums import PairType, TradingType, CurrencyType
from .exceptions import AATException
from .structs import TradeRequest, TradeResponse, Account, Instrument
from .utils import (get_keys_from_environment, str_to_currency_type,
exchange_type_to_ccxt_client, tradereq_to_ccxt_order)
from .utils import elog as log
exchange_type_to_ccxt_client, tradereq_to_ccxt_order,
findpath)
# from .utils import elog as log


class OrderEntry(RestAPIDataSource):
Expand Down Expand Up @@ -56,15 +58,50 @@ def accounts(self):
self._accounts = accounts
return accounts

def lastPrice(self, cur: PairType):
try:
return self.oe_client().fetchTicker(self.currencyPairToStringCCXT(cur))
except (ccxt.ExchangeError, ValueError):
return {'last': -1.0}
@lru_cache(None)
def currencies(self) -> List[CurrencyType]:
return [CurrencyType(x) for x in self.oe_client().fetch_curencies()]

@lru_cache(None)
def currencyPairToStringCCXT(self, cur: PairType) -> str:
return cur.value[0].value + '/' + cur.value[1].value
def markets(self) -> List[Instrument]:
# TODO derivatives
return [Instrument(underlying=PairType.from_string(m['symbol'])) for m in self.oe_client().fetch_markets()]

def ticker(self,
instrument: Instrument = None,
currency: CurrencyType = None):
if instrument:
return self.oe_client().fetchTicker(str(instrument.underlying))
elif currency:
inst = Instrument(underlying=PairType.from_string(currency.value + '/USD'))

if inst in self.markets():
return self.oe_client().fetchTicker(str(inst.underlying))
else:
try:
inst1, inst2, i1_inverted, i2_inverted = findpath(inst, self.markets())
except AATException:
try:
inst = Instrument(underlying=PairType.from_string(currency.value + '/USDC'))
inst1, inst2, i1_inverted, i2_inverted = findpath(inst, self.markets())
except AATException:
return {'last': 0.0}
inst1_t = self.oe_client().fetchTicker(str(inst1.underlying))
inst2_t = self.oe_client().fetchTicker(str(inst2.underlying))
if i1_inverted:
inst1_t['last'] = 1.0/inst1_t['last']
if i2_inverted:
inst2_t['last'] = 1.0/inst2_t['last']
px = inst1_t['last'] * inst2_t['last']
ret = inst1_t
for key in ret:
if key == 'info':
ret[key] = {}
elif key == 'last':
ret[key] = px
else:
ret[key] = None
return ret

def historical(self, timeframe='1m', since=None, limit=None):
'''get historical data (for backtesting)'''
Expand Down Expand Up @@ -95,9 +132,9 @@ def sell(self, req: TradeRequest) -> TradeResponse:
self.oe_client().create_order(**params)

def cancel(self, resp: TradeResponse):
params = tradereq_to_ccxt_order(req)
params = tradereq_to_ccxt_order(resp)
raise NotImplementedError()
self.oe_client().create_order(**params)
self.oe_client().cancel_order(**params)

def cancelAll(self, resp: TradeResponse):
return self.oe_client().cancel_all_orders()
6 changes: 3 additions & 3 deletions aat/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _parse_exchange(exchange, config) -> None:
def _parse_strategy(strategy, config) -> None:
strat_configs = []
if 'strategies' not in strategy:
raise Exception('No Strategies specified')
raise ConfigException('No Strategies specified')

for strat in strategy['strategies'].split('\n'):
if strat == '':
Expand Down Expand Up @@ -134,7 +134,7 @@ def _parse_options(argv, config: TradingEngineConfig) -> None:
if argv.get('exchanges'):
config.exchange_options.exchange_types = [str_to_exchange(x) for x in argv['exchanges'].split() if x]
else:
raise Exception('No exchange set!')
raise ConfigException('No exchange set!')

if argv.get('currency_pairs'):
config.exchange_options.currency_pairs = _parse_currencies(argv.get('currency_pairs'))
Expand Down Expand Up @@ -170,7 +170,7 @@ def _parse_backtest_options(argv, config) -> None:
if argv.get('exchanges'):
config.exchange_options.exchange_types = [str_to_exchange(x) for x in argv['exchanges'].split() if x]
else:
raise Exception('No exchange set!')
raise ConfigException('No exchange set!')

if argv.get('currency_pairs'):
config.exchange_options.currency_pairs = _parse_currencies(argv.get('currency_pairs'))
Expand Down
9 changes: 5 additions & 4 deletions aat/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List
from concurrent.futures import ThreadPoolExecutor
from .enums import TickType, Side, ExchangeType, CurrencyType, PairType # noqa: F401
from .exceptions import QueryException
from .structs import Instrument, MarketData, TradeRequest, TradeResponse


Expand Down Expand Up @@ -32,7 +33,7 @@ def query_instruments(self) -> List[PairType]:
def query_exchanges(self) -> List[ExchangeType]:
return self._exchanges

def _paginate(self, instrument: Instrument, lst: list, lst_sub: list, page: int = 1)-> list:
def _paginate(self, instrument: Instrument, lst: list, lst_sub: list, page: int = 1) -> list:
if page is not None:
from_ = -1*page*100
to_ = -1*(page-1)*100
Expand All @@ -56,13 +57,13 @@ def query_lastprice(self,
instrument: Instrument,
exchange: ExchangeType = None) -> MarketData:
if instrument not in self._last_price_by_asset_and_exchange:
raise Exception('Not found!')
raise QueryException('Not found!')
if exchange:
if exchange not in self._last_price_by_asset_and_exchange[instrument]:
raise Exception('Not found!')
raise QueryException('Not found!')
return self._last_price_by_asset_and_exchange[instrument][exchange]
if "ANY" not in self._last_price_by_asset_and_exchange[instrument]:
raise Exception('Not found!')
raise QueryException('Not found!')
return self._last_price_by_asset_and_exchange[instrument]["ANY"]

def query_trades(self,
Expand Down
17 changes: 6 additions & 11 deletions aat/strategies/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ def __init__(self, size: int) -> None:
self.profits = 0.0

def onBuy(self, res: TradeResponse) -> None:
if self._intitialvalue is None:
date = res.time
self._intitialvalue = (date, res.volume*res.price)
self._portfolio_value.append(self._intitialvalue)

self.bought = res.volume*res.price
self.bought_qty = res.volume
slog.info('d->g:bought %.2f @ %.2f for %.2f', res.volume, res.price, self.bought)
Expand All @@ -39,11 +34,6 @@ def onSell(self, res: TradeResponse) -> None:
self.bought = 0.0
self.bought_qty = 0.0

date = res.time
self._portfolio_value.append((
date,
self._portfolio_value[-1][1] + profit))

def onTrade(self, data: MarketData):
# add data to arrays
self.ticks.append(data.price)
Expand Down Expand Up @@ -104,11 +94,16 @@ def onTrade(self, data: MarketData):
def onError(self, e: MarketData):
elog.critical(e)

def onAnalyze(self, portfolio_value, requests, responses) -> None:
def onAnalyze(self, engine) -> None:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns

portfolio_value = engine.portfolio_value()
requests = engine.query().query_tradereqs()
trades = pandas.DataFrame([{'time': x.time, 'price': x.price} for x in engine.query().query_trades(instrument=requests[0].instrument, page=None)])
trades.set_index(['time'], inplace=True)

pd = pandas.DataFrame(portfolio_value, columns=['time', 'value'])
pd.set_index(['time'], inplace=True)

Expand Down
2 changes: 1 addition & 1 deletion aat/strategies/buy_and_hold.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..strategy import TradingStrategy
from ..structs import MarketData, TradeRequest, TradeResponse
from ..enums import Side, TradeResult, OrderType
from ..enums import Side, OrderType
from ..logging import STRAT as slog, ERROR as elog


Expand Down
2 changes: 1 addition & 1 deletion aat/strategies/data_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ def slippage(self, resp: TradeResponse) -> TradeResponse:
def transactionCost(self, resp: TradeResponse) -> TradeResponse:
return resp

def onAnalyze(self, _) -> None:
def onAnalyze(self, engine) -> None:
pass
44 changes: 27 additions & 17 deletions aat/strategies/sma_crosses_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def onTrade(self, data: MarketData) -> bool:
instrument=data.instrument,
order_type=OrderType.MARKET,
exchange=data.exchange,
price=data.price)
price=data.price,
time=data.time)
# slog.info("requesting buy : %s", req)
self.requestBuy(self.onBuy, req)
return True
Expand All @@ -97,7 +98,8 @@ def onTrade(self, data: MarketData) -> bool:
instrument=data.instrument,
order_type=OrderType.MARKET,
exchange=data.exchange,
price=data.price)
price=data.price,
time=data.time)
# slog.info("requesting sell : %s", req)
self.requestSell(self.onSell, req)
return True
Expand All @@ -107,28 +109,36 @@ def onTrade(self, data: MarketData) -> bool:
def onError(self, e) -> None:
elog.critical(e)

def onAnalyze(self, portfolio_value, requests, responses) -> None:
def onAnalyze(self, engine) -> None:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns
portfolio_value = engine.portfolio_value()
requests = engine.query().query_tradereqs()
responses = engine.query().query_traderesps()

pd = pandas.DataFrame(portfolio_value, columns=['time', 'value'])
pd.set_index(['time'], inplace=True)

print(self.short, self.long, pd.iloc[1].value, pd.iloc[-1].value)
sns.set_style('darkgrid')
fig, ax1 = plt.subplots()

plt.title('BTC algo 1 performance - %d-%d Momentum ' % (self.short, self.long))
ax1.plot(pd)

ax1.set_ylabel('Portfolio value($)')
ax1.set_xlabel('Date')
for xy in [portfolio_value[0]] + [portfolio_value[-1]]:
ax1.annotate('$%s' % xy[1], xy=xy, textcoords='data')
plt.show()
print(requests)
print(responses)
if len(requests) > 0:
trades = pandas.DataFrame([{'time': x.time, 'price': x.price} for x in engine.query().query_trades(instrument=requests[0].instrument, page=None)])
trades.set_index(['time'], inplace=True)

if pd.size > 0:
print(self.short, self.long, pd.iloc[1].value, pd.iloc[-1].value)
sns.set_style('darkgrid')
fig, ax1 = plt.subplots()

plt.title('BTC algo 1 performance - %d-%d Momentum ' % (self.short, self.long))
ax1.plot(pd)

ax1.set_ylabel('Portfolio value($)')
ax1.set_xlabel('Date')
for xy in [portfolio_value[0]] + [portfolio_value[-1]]:
ax1.annotate('$%s' % xy[1], xy=xy, textcoords='data')
plt.show()
print(requests)
print(responses)

def onChange(self, data: MarketData) -> None:
pass
Expand Down
7 changes: 6 additions & 1 deletion aat/strategies/test_strat.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,16 @@ def slippage(self, resp: TradeResponse) -> TradeResponse:
def transactionCost(self, resp: TradeResponse) -> TradeResponse:
return resp

def onAnalyze(self, portfolio_value, requests, responses) -> None:
def onAnalyze(self, engine) -> None:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns

portfolio_value = engine.portfolio_value()
requests = engine.query().query_tradereqs()
trades = pandas.DataFrame([{'time': x.time, 'price': x.price} for x in engine.query().query_trades(instrument=requests[0].instrument, page=None)])
trades.set_index(['time'], inplace=True)

pd = pandas.DataFrame(portfolio_value, columns=['time', 'value'])
pd.set_index(['time'], inplace=True)

Expand Down
3 changes: 3 additions & 0 deletions aat/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __eq__(self, other):
def __str__(self):
return str(self.underlying)

def __repr__(self):
return str(self.underlying)

def __hash__(self):
return hash(str(self.underlying))

Expand Down

0 comments on commit ca3196b

Please sign in to comment.