Skip to content

Commit

Permalink
add buy and hold strat, fix backtest bug, add portfolio calculations #9
Browse files Browse the repository at this point in the history
  • Loading branch information
timkpaine committed May 29, 2019
1 parent 19c7f01 commit 52cf350
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 79 deletions.
8 changes: 4 additions & 4 deletions aat/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

def line_to_data(record):
data = MarketData(time=record.name[0],
price=record.volume,
volume=record.close,
volume=record.volume,
price=record.close,
type=TickType.TRADE,
instrument=Instrument(underlying=PairType.from_string(record.name[1])),
exchange=ExchangeType(record.exchange),
Expand All @@ -31,15 +31,15 @@ def run(self, engine) -> None:
for index, row in data.iterrows():
self.receive(line_to_data(row))
log.info('Backtest done, running analysis.')
self.callback(TickType.ANALYZE, None)

self.callback(TickType.ANALYZE, engine.portfolio_value(), engine.query().query_tradereqs(), engine.query().query_traderesps())
log.info('Analysis completed.')

def receive(self, data: MarketData) -> None:
# TODO allow if market data for bid/ask
if data.type == TickType.TRADE:
self.callback(TickType.TRADE, data)
dlog.info(data)

else:
self.callback(TickType.ERROR, data)

Expand Down
4 changes: 2 additions & 2 deletions aat/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def seqnum(self, number: int):
def receive(self):
'''receive data and call callbacks'''

def callback(self, field: str, data) -> None:
def callback(self, field: str, data, *args, **kwargs) -> None:
for cb in self._callbacks[field]:
cb(data)
cb(data, *args, **kwargs)

# Data functions
@abstractmethod
Expand Down
21 changes: 9 additions & 12 deletions aat/order_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,9 @@
from .data_source import RestAPIDataSource
from .enums import PairType, TradingType, ExchangeType
from .structs import TradeRequest, TradeResponse, Account
from .utils import get_keys_from_environment, str_to_currency_type


def exchange_type_to_ccxt_client(exchange_type):
if exchange_type == ExchangeType.COINBASE:
return ccxt.coinbasepro
elif exchange_type == ExchangeType.GEMINI:
return ccxt.gemini
elif exchange_type == ExchangeType.KRAKEN:
return ccxt.kraken
elif exchange_type == ExchangeType.POLONIEX:
return ccxt.poloniex
from .utils import (get_keys_from_environment, str_to_currency_type,
exchange_type_to_ccxt_client, tradereq_to_ccxt_order)
from .utils import elog as log


class OrderEntry(RestAPIDataSource):
Expand Down Expand Up @@ -93,14 +84,20 @@ def orderBook(self, level=1):

def buy(self, req: TradeRequest) -> TradeResponse:
'''execute a buy order'''
params = tradereq_to_ccxt_order(req)
raise NotImplementedError()
self.oe_client().create_order(**params)

def sell(self, req: TradeRequest) -> TradeResponse:
'''execute a sell order'''
params = tradereq_to_ccxt_order(req)
raise NotImplementedError()
self.oe_client().create_order(**params)

def cancel(self, resp: TradeResponse):
params = tradereq_to_ccxt_order(req)
raise NotImplementedError()
self.oe_client().create_order(**params)

def cancelAll(self, resp: TradeResponse):
return self.oe_client().cancel_all_orders()
23 changes: 19 additions & 4 deletions aat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self,
self._instruments = instruments
self._exchanges = exchanges

self._last_price_by_exchange = {}
self._last_price_by_asset_and_exchange = {}

self._trade_reqs = []
self._trade_resps = []
Expand All @@ -42,6 +42,19 @@ def _paginate(self, instrument: Instrument, lst: list, lst_sub: list, page: int
return lst[from_:to_] \
if page > 1 else lst[from_:]

def query_lastprice(self,
instrument: Instrument,
exchange: ExchangeType = None) -> MarketData:
if instrument not in self._last_price_by_asset_and_exchange:
raise Exception('Not found!')
if exchange:
if exchange not in self._last_price_by_asset_and_exchange[instrument]:
raise Exception('Not found!')
return self._last_price_by_asset_and_exchange[instrument][exchange]
if "ANY" not in self._last_price_by_asset_and_exchange[instrument]:
raise Exception('Not found!')
return self._last_price_by_asset_and_exchange[instrument]["ANY"]

def query_trades(self,
instrument: Instrument = None,
page: int = 1) -> List[MarketData]:
Expand Down Expand Up @@ -74,9 +87,11 @@ def push(self, data: MarketData) -> None:
if data.instrument not in self._trades_by_instrument:
self._trades_by_instrument[data.instrument] = []
self._trades_by_instrument[data.instrument].append(data)
if data.exchange not in self._last_price_by_exchange:
self._last_price_by_exchange[data.exchange] = []
self._last_price_by_exchange[data.exchange].append(data)
if data.instrument not in self._last_price_by_asset_and_exchange:
self._last_price_by_asset_and_exchange[data.instrument] = {}
self._last_price_by_asset_and_exchange[data.instrument][data.exchange] = data
self._last_price_by_asset_and_exchange[data.instrument]['ANY'] = data
print("here", self._last_price_by_asset_and_exchange[data.instrument][data.exchange])

def push_tradereq(self, req: TradeRequest) -> None:
self._trade_reqs.append(req)
Expand Down
3 changes: 0 additions & 3 deletions aat/strategies/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ def __init__(self, size: int) -> None:
self.bought_qty = 0.0
self.profits = 0.0

self._intitialvalue = None
self._portfolio_value = []

def onBuy(self, res: TradeResponse) -> None:
if self._intitialvalue is None:
date = res.time
Expand Down
98 changes: 98 additions & 0 deletions aat/strategies/buy_and_hold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from ..strategy import TradingStrategy
from ..structs import MarketData, TradeRequest, TradeResponse
from ..enums import Side, TradeResult, OrderType
from ..logging import STRAT as slog, ERROR as elog


class BuyAndHoldStrategy(TradingStrategy):
def __init__(self) -> None:
super(BuyAndHoldStrategy, self).__init__()
self.bought = None

def onBuy(self, res: TradeResponse) -> None:
self.bought = res
slog.info('d->g:bought %.2f @ %.2f' % (res.volume, res.price))

def onSell(self, res: TradeResponse) -> None:
pass

def onTrade(self, data: MarketData) -> bool:
# add data to arrays
if self.bought is None:
req = TradeRequest(side=Side.BUY,
volume=1.0,
instrument=data.instrument,
order_type=OrderType.MARKET,
exchange=data.exchange,
price=data.price)
slog.info("requesting buy : %s", req)
self.requestBuy(self.onBuy, req)
return True
return False

def onError(self, e) -> None:
elog.critical(e)

def onAnalyze(self, portfolio_value, requests, responses) -> None:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns

pd = pandas.DataFrame(portfolio_value, columns=['time', 'value'])
pd.set_index(['time'], inplace=True)

sns.set_style('darkgrid')
fig, ax1 = plt.subplots()

plt.title('BTC algo 1 performance')
ax1.plot(pd)

ax1.set_ylabel('Portfolio value($)')
ax1.set_xlabel('Date')
for xy in [portfolio_value[0]] + [portfolio_value[-1]]:
ax1.annotate('$%s' % xy[1], xy=xy, textcoords='data')
plt.show()
print(requests)
print(responses)

def onChange(self, data: MarketData) -> None:
pass

def onContinue(self, data: MarketData) -> None:
pass

def onFill(self, data: MarketData) -> None:
pass

def onCancel(self, data: MarketData) -> None:
pass

def onHalt(self, data: MarketData) -> None:
pass

def onOpen(self, data: MarketData) -> None:
pass

def slippage(self, resp: TradeResponse) -> TradeResponse:
slippage = resp.price * .0001 # .01% price impact
if resp.side == Side.BUY:
# price moves against (up)
resp.slippage = slippage
resp.price += slippage
else:
# price moves against (down)
resp.slippage = -slippage
resp.price -= slippage
return resp

def transactionCost(self, resp: TradeResponse) -> TradeResponse:
txncost = resp.price * resp.volume * .0025 # gdax is 0.0025 max fee
if resp.side == Side.BUY:
# price moves against (up)
resp.transaction_cost = txncost
resp.price += txncost
else:
# price moves against (down)
resp.transaction_cost = -txncost
resp.price -= txncost
return resp
36 changes: 5 additions & 31 deletions aat/strategies/sma_crosses_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,11 @@ def __init__(self, size_short: int, size_long: int) -> None:
self.bought_qty = 0.0
self.profits = 0.0

self._intitialvalue = None
self._portfolio_value = []

def onBuy(self, res: TradeResponse) -> None:
if not res.status == TradeResult.FILLED:
slog.info('order failure: %s' % res)
return

if self._intitialvalue is None:
date = res.time
self._intitialvalue = (date, res.volume*res.price)
self._portfolio_value.append(self._intitialvalue)

self.bought = res.volume*res.price
self.bought_qty = res.volume
slog.info('d->g:bought %.2f @ %.2f for %.2f ---- %.2f %.2f' % (res.volume, res.price, self.bought, self.short_av, self.long_av))
Expand All @@ -51,11 +43,6 @@ def onSell(self, res: TradeResponse) -> None:
self.bought = 0.0
self.bought_qty = 0.0

date = res.time
self._portfolio_value.append((
date,
self._portfolio_value[-1][1] + profit))

def onTrade(self, data: MarketData) -> bool:
# add data to arrays
self.shorts.append(data.price)
Expand Down Expand Up @@ -120,25 +107,15 @@ def onTrade(self, data: MarketData) -> bool:
def onError(self, e) -> None:
elog.critical(e)

def onAnalyze(self, _) -> None:
def onAnalyze(self, portfolio_value, requests, responses) -> None:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns

# pd = pandas.DataFrame(self._actions,
# columns=['time', 'action', 'price'])

pd = pandas.DataFrame(self._portfolio_value, columns=['time', 'value'])
pd = pandas.DataFrame(portfolio_value, columns=['time', 'value'])
pd.set_index(['time'], inplace=True)

print(self.short, self.long, pd.iloc[1].value, pd.iloc[-1].value)
# sp500 = pandas.DataFrame()
# tmp = pandas.read_csv('./data/sp/sp500_v_kraken.csv')
# sp500['Date'] = pandas.to_datetime(tmp['Date'])
# sp500['Close'] = tmp['Close']
# sp500.set_index(['Date'], inplace=True)
# print(sp500)

sns.set_style('darkgrid')
fig, ax1 = plt.subplots()

Expand All @@ -147,14 +124,11 @@ def onAnalyze(self, _) -> None:

ax1.set_ylabel('Portfolio value($)')
ax1.set_xlabel('Date')
for xy in [self._portfolio_value[0]] + [self._portfolio_value[-1]]:
for xy in [portfolio_value[0]] + [portfolio_value[-1]]:
ax1.annotate('$%s' % xy[1], xy=xy, textcoords='data')

# ax2 = ax1.twinx()
# ax2.plot(sp500, 'r')
# ax2.set_ylabel('S&P500 ($)')

plt.show()
print(requests)
print(responses)

def onChange(self, data: MarketData) -> None:
pass
Expand Down
24 changes: 5 additions & 19 deletions aat/strategies/test_strat.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,38 +98,24 @@ def slippage(self, resp: TradeResponse) -> TradeResponse:
def transactionCost(self, resp: TradeResponse) -> TradeResponse:
return resp

def onAnalyze(self, _) -> None:
def onAnalyze(self, portfolio_value, requests, responses) -> None:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns

# pd = pandas.DataFrame(self._actions,
# columns=['time', 'action', 'price'])

pd = pandas.DataFrame(self._portfolio_value, columns=['time', 'value'])
pd = pandas.DataFrame(portfolio_value, columns=['time', 'value'])
pd.set_index(['time'], inplace=True)

print(self.short, self.long, pd.iloc[1].value, pd.iloc[-1].value)
# sp500 = pandas.DataFrame()
# tmp = pandas.read_csv('./data/sp/sp500_v_kraken.csv')
# sp500['Date'] = pandas.to_datetime(tmp['Date'])
# sp500['Close'] = tmp['Close']
# sp500.set_index(['Date'], inplace=True)
# print(sp500)
print(pd.iloc[1].value, pd.iloc[-1].value)

sns.set_style('darkgrid')
fig, ax1 = plt.subplots()

plt.title('BTC algo 1 performance - %d-%d Momentum ' % (self.short, self.long))
plt.title('BTC algo 1 performance')
ax1.plot(pd)

ax1.set_ylabel('Portfolio value($)')
ax1.set_xlabel('Date')
for xy in [self._portfolio_value[0]] + [self._portfolio_value[-1]]:
for xy in [portfolio_value[0]] + [portfolio_value[-1]]:
ax1.annotate('$%s' % xy[1], xy=xy, textcoords='data')

# ax2 = ax1.twinx()
# ax2.plot(sp500, 'r')
# ax2.set_ylabel('S&P500 ($)')

plt.show()

0 comments on commit 52cf350

Please sign in to comment.