Skip to content

Commit

Permalink
trim code
Browse files Browse the repository at this point in the history
  • Loading branch information
timkpaine committed Jul 2, 2019
1 parent 835495e commit bb9ffab
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 70 deletions.
4 changes: 2 additions & 2 deletions aat/exchanges/coinbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..enums import PairType, TickType, TickType_from_string
from ..exchange import Exchange
from ..structs import MarketData, Instrument
from ..utils import parse_date, str_to_currency_pair_type, str_to_side, str_to_order_type
from ..utils import parse_date, str_to_side, str_to_order_type


class CoinbaseExchange(Exchange):
Expand Down Expand Up @@ -44,7 +44,7 @@ def tickToData(self, jsn: dict) -> MarketData:
volume = float(jsn.get('size', 'nan'))
price = float(jsn.get('price', 'nan'))

currency_pair = str_to_currency_pair_type(jsn.get('product_id')) if typ != TickType.ERROR else PairType.NONE
currency_pair = PairType.from_string(jsn.get('product_id')) if typ != TickType.ERROR else PairType.NONE

instrument = Instrument(underlying=currency_pair)

Expand Down
4 changes: 2 additions & 2 deletions aat/exchanges/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..exchange import Exchange
from ..logging import log
from ..structs import MarketData, Instrument
from ..utils import str_to_side, str_to_currency_pair_type
from ..utils import str_to_side


class GeminiExchange(Exchange):
Expand Down Expand Up @@ -143,7 +143,7 @@ def tickToData(self, jsn: dict) -> MarketData:
if 'symbol' not in jsn:
return

currency_pair = str_to_currency_pair_type(jsn.get('symbol'))
currency_pair = PairType.from_string(jsn.get('symbol'))
instrument = Instrument(underlying=currency_pair)

ret = MarketData(order_id=order_id,
Expand Down
23 changes: 12 additions & 11 deletions aat/order_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from functools import lru_cache
from typing import List
from .data_source import RestAPIDataSource
from .enums import PairType, TradingType, CurrencyType, ExchangeType, ExchangeType_to_string
from .enums import PairType, TradingType, CurrencyType, ExchangeType, ExchangeType_to_string, TradeResult
from .exceptions import AATException
from .structs import TradeRequest, TradeResponse, Account, Instrument
from .utils import (get_keys_from_environment, str_to_currency_type, str_to_side,
from .utils import (get_keys_from_environment, str_to_side,
exchange_type_to_ccxt_client, tradereq_to_ccxt_order,
str_to_trade_result, parse_date, findpath)
parse_date, findpath)
# from .utils import elog as log


Expand Down Expand Up @@ -45,7 +45,7 @@ def accounts(self):
else:
currency = jsn['currency']

currency = str_to_currency_type(currency)
currency = CurrencyType(currency)
if 'balance' in jsn:
balance = float(jsn['balance'])
elif 'amount' in jsn:
Expand Down Expand Up @@ -142,14 +142,15 @@ def _extract_fields(self, order, exchange):
original = float(order.get('info', {}).get('original_amount', 0.0))
is_cancelled = order.get('info', {}).get('is_cancelled', False)
if is_cancelled:
status = 'REJECTED'

status = TradeResult.REJECTED
if filled == original or remaining <= 0:
status = 'FILLED'
status = TradeResult.FILLED
elif remaining < original and remaining > 0:
status = 'PARTIAL'
status = TradeResult.PARTIAL
elif remaining == original:
status = 'PENDING'
status = TradeResult.PENDING
elif status in ('OPEN',):
status = TradeResult.PENDING
return side, filled, price, datetime, status, cost, remaining

def buy(self, req: TradeRequest) -> TradeResponse:
Expand All @@ -164,7 +165,7 @@ def buy(self, req: TradeRequest) -> TradeResponse:
price=float(price),
instrument=req.instrument,
time=parse_date(datetime),
status=str_to_trade_result(status),
status=status,
order_id=order['id'],
slippage=float(price) - req.price,
transaction_cost=cost,
Expand All @@ -183,7 +184,7 @@ def sell(self, req: TradeRequest) -> TradeResponse:
price=float(price),
instrument=req.instrument,
time=parse_date(datetime),
status=str_to_trade_result(status),
status=status,
order_id=order['id'],
slippage=float(price) - req.price,
transaction_cost=cost,
Expand Down
6 changes: 3 additions & 3 deletions aat/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from configparser import ConfigParser
from pydoc import locate
from .config import TradingEngineConfig, BacktestConfig, StrategyConfig, SyntheticExchangeConfig
from .enums import TradingType, InstrumentType, ExchangeType
from .enums import TradingType, InstrumentType, ExchangeType, PairType
from .exceptions import ConfigException
from .structs import Instrument
from .utils import str_to_exchange, set_verbose, str_to_currency_pair_type
from .utils import str_to_exchange, set_verbose
from .logging import log


Expand Down Expand Up @@ -130,7 +130,7 @@ def _parse_currencies(currencies):
splits = [x.strip().upper().replace('-', '') for x in currencies.split(',')]
ret = []
for s in splits:
ret.append(str_to_currency_pair_type(s))
ret.append(PairType.from_string(s))
return ret


Expand Down
25 changes: 0 additions & 25 deletions aat/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ def test_get_keys_from_environment(self):
assert(two == 'test')
assert(three == 'test')

def test_str_to_currency_type(self):
from ..utils import str_to_currency_type
from ..enums import CurrencyType
assert(str_to_currency_type('BTC') == CurrencyType.BTC)
assert(str_to_currency_type('ETH') == CurrencyType.ETH)
assert(str_to_currency_type('LTC') == CurrencyType.LTC)
assert(str_to_currency_type('USD') == CurrencyType.USD)
assert(str_to_currency_type('ZRX') == CurrencyType.ZRX)

def test_str_to_side(self):
from ..utils import str_to_side
from ..enums import Side
Expand All @@ -70,22 +61,6 @@ def test_str_to_exchange(self):
assert(str_to_exchange('kraken') == ExchangeType.KRAKEN)
assert(str_to_exchange('poloniex') == ExchangeType.POLONIEX)

def test_str_to_currency_pair_type(self):
from ..utils import str_to_currency_pair_type
from ..enums import PairType, CurrencyType

for c1, v1 in CurrencyType.__members__.items():
for c2, v2 in CurrencyType.__members__.items():
if c1 == c2 or \
c1 == 'USD' or \
c2 == 'USD' or \
c1 == 'NONE' or \
c2 == 'NONE':
continue
assert str_to_currency_pair_type(c1 + '/' + c2) == PairType.from_string(c1 + '/' + c2)
assert str_to_currency_pair_type(c1 + '-' + c2) == PairType.from_string(c1, c2)
assert str_to_currency_pair_type(c1 + c2) == PairType.from_string(c1, c2)

def test_trade_req_to_params(self):
from ..utils import trade_req_to_params
from ..structs import TradeRequest, Instrument, ExchangeType
Expand Down
37 changes: 10 additions & 27 deletions aat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,23 @@ def ex_type_to_ex(ex: ExchangeType):

@lru_cache(None)
def get_keys_from_environment(prefix: str) -> tuple:
'''get exchange keys from environment variables
Args:
prefix (str): exchange name e.g. COINBASE
Returns
key (str): private key
secret (str): private secret
passphrase (str): optional passphrase
'''

prefix = prefix.upper()
key = os.environ[prefix + '_API_KEY']
secret = os.environ[prefix + '_API_SECRET']
passphrase = os.environ[prefix + '_API_PASS']
return key, secret, passphrase


@lru_cache(None)
def str_to_currency_type(s: str) -> CurrencyType:
s = s.upper()
if s not in CurrencyType.members():
raise AATException(f'CurrencyType not recognized {s}')
return CurrencyType(s)


@lru_cache(None)
def str_to_currency_pair_type(s: str) -> PairType:
return PairType.from_string(s)


@lru_cache(None)
def str_currency_to_currency_pair_type(s: str, base: str = 'USD') -> PairType:
return PairType.from_string(s, base)


@lru_cache(None)
def str_to_side(s: str) -> Side:
s = s.upper()
Expand All @@ -99,14 +91,6 @@ def str_to_order_type(s: str) -> OrderType:
return OrderType.NONE


@lru_cache(None)
def str_to_trade_result(s: str) -> TradeResult:
s = s.upper()
if s in ('OPEN',):
s = 'PENDING'
return TradeResult(s)


@lru_cache(None)
def str_to_exchange(exchange: str) -> ExchangeType:
if exchange.upper() not in ExchangeTypes:
Expand All @@ -132,7 +116,6 @@ def trade_req_to_params(req) -> dict:

if ret['type'] == 'limit':
ret['price'] = req.price

return ret


Expand Down

0 comments on commit bb9ffab

Please sign in to comment.