Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Alpaca pricing source and update documentation
- Loading branch information
Showing
18 changed files
with
373 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
''' | ||
Duplicate builtin factor classes in zipline with IEX's USEquityPricing | ||
''' | ||
|
||
from zipline.pipeline.data import USEquityPricing as z_pricing | ||
from zipline.pipeline import factors as z_factors | ||
|
||
from .pricing import USEquityPricing as alpaca_pricing | ||
|
||
|
||
def _replace_inputs(inputs):
    """Translate zipline pricing columns in *inputs* to Alpaca-backed ones.

    Parameters
    ----------
    inputs : object
        Typically the ``inputs`` attribute of a zipline ``Factor`` subclass.
        Only ``list``/``tuple``/``set`` values are translated; any other
        value (e.g. zipline's ``NotSpecified`` sentinel) is returned as-is.

    Returns
    -------
    tuple or object
        A tuple with each zipline ``USEquityPricing`` column replaced by
        the corresponding Alpaca column, or the original object untouched.
    """
    # Renamed from `map`, which shadowed the builtin of the same name.
    column_map = {
        z_pricing.open: alpaca_pricing.open,
        z_pricing.high: alpaca_pricing.high,
        z_pricing.low: alpaca_pricing.low,
        z_pricing.close: alpaca_pricing.close,
        z_pricing.volume: alpaca_pricing.volume,
    }

    if type(inputs) not in (list, tuple, set):
        return inputs
    # Columns that are not plain pricing columns pass through unchanged.
    return tuple(column_map.get(inp, inp) for inp in inputs)
|
||
|
||
# Re-export every built-in zipline factor under the same name, subclassed so
# that its `inputs` reference Alpaca's USEquityPricing columns instead of
# zipline's. The abstract Factor base itself is skipped.
for name in dir(z_factors):
    factor = getattr(z_factors, name)
    # Check isinstance(..., type) first: dir() also yields non-class
    # attributes, and issubclass() raises TypeError on a non-class.
    if (isinstance(factor, type)
            and issubclass(factor, z_factors.Factor)
            and factor is not z_factors.Factor
            and hasattr(factor, 'inputs')):
        new_factor = type(factor.__name__, (factor,), {
            'inputs': _replace_inputs(factor.inputs)
        })
        # globals() is the documented way to define module attributes;
        # mutating locals() only happens to work at module scope because
        # it *is* globals() there.
        globals()[factor.__name__] = new_factor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from zipline.pipeline.data.dataset import Column, DataSet | ||
from zipline.utils.numpy_utils import float64_dtype | ||
|
||
from .pricing_loader import USEquityPricingLoader | ||
|
||
|
||
# In order to use it as a cache key, we have to make it singleton:
# every column of USEquityPricing resolves to this one loader instance
# via get_loader(), so engine-side caches keyed on the loader stay stable.
_loader = USEquityPricingLoader()
|
||
|
||
class USEquityPricing(DataSet):
    """
    Dataset representing daily trading prices and volumes.

    Mirrors zipline's built-in ``USEquityPricing`` columns (OHLCV) but is
    wired to the Alpaca-backed loader through :meth:`get_loader`.
    """
    # All columns use float64 — presumably so NaN can stand in for missing
    # values, including volume; TODO(review): confirm downstream factors
    # tolerate float volume.
    open = Column(float64_dtype)
    high = Column(float64_dtype)
    low = Column(float64_dtype)
    close = Column(float64_dtype)
    volume = Column(float64_dtype)

    @staticmethod
    def get_loader():
        # Always hand back the module-level singleton (see _loader above its
        # definition site) so the loader instance is a stable cache key.
        return _loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import numpy as np | ||
import logbook | ||
import pandas as pd | ||
|
||
from zipline.lib.adjusted_array import AdjustedArray | ||
from zipline.pipeline.loaders.base import PipelineLoader | ||
from zipline.utils.calendars import get_calendar | ||
from zipline.errors import NoFurtherDataError | ||
|
||
from pipeline_live.data.sources import alpaca | ||
|
||
|
||
log = logbook.Logger(__name__) | ||
|
||
|
||
class USEquityPricingLoader(PipelineLoader):
    """
    PipelineLoader for US Equity Pricing data

    Serves OHLCV columns for US equities from Alpaca's bar data, aligned
    to the NYSE trading calendar.
    """

    def __init__(self):
        # NYSE sessions define the set of valid trading days used both for
        # date shifting and for reindexing the fetched price frames.
        cal = get_calendar('NYSE')

        self._all_sessions = cal.all_sessions

    def load_adjusted_array(self, columns, dates, symbols, mask):
        # load_adjusted_array is called with dates on which the user's algo
        # will be shown data, which means we need to return the data that would
        # be known at the start of each date. We assume that the latest data
        # known on day N is the data from day (N - 1), so we shift all query
        # dates back by a day.
        start_date, end_date = _shift_dates(
            self._all_sessions, dates[0], dates[-1], shift=1,
        )

        # Restrict the calendar to the (shifted) query window; this index
        # becomes the row axis of every per-symbol DataFrame below.
        sessions = self._all_sessions
        sessions = sessions[(sessions >= start_date) & (sessions <= end_date)]

        # Ask Alpaca for enough calendar days to cover the window; days are
        # counted from now back to start_date, so weekends/holidays inflate
        # the count harmlessly.
        timedelta = pd.Timestamp.utcnow() - start_date
        chart_range = timedelta.days + 1
        log.info('chart_range={}'.format(chart_range))
        prices = alpaca.get_stockprices(chart_range)

        dfs = []
        for symbol in symbols:
            if symbol not in prices:
                # Unknown symbol: a frame of each column's missing_value,
                # shaped to the session index like a real frame.
                df = pd.DataFrame(
                    {c.name: c.missing_value for c in columns},
                    index=sessions
                )
            else:
                df = prices[symbol]
                # Align to trading sessions, forward-filling gaps (e.g.
                # halts or sparse bars) with the last known value.
                df = df.reindex(sessions, method='ffill')
            dfs.append(df)

        # Stack per-symbol 1-D columns into (dates, symbols) 2-D arrays —
        # symbols on the last axis, matching dfs' order (= `symbols` order).
        raw_arrays = {}
        for c in columns:
            colname = c.name
            raw_arrays[colname] = np.stack([
                df[colname].values for df in dfs
            ], axis=-1)
        out = {}
        for c in columns:
            c_raw = raw_arrays[c.name]
            # No adjustments (splits/dividends) are applied here — the
            # second argument (adjustments mapping) is always empty.
            out[c] = AdjustedArray(
                c_raw.astype(c.dtype),
                {},
                c.missing_value
            )
        return out
|
||
|
||
def _shift_dates(dates, start_date, end_date, shift): | ||
try: | ||
start = dates.get_loc(start_date) | ||
except KeyError: | ||
if start_date < dates[0]: | ||
raise NoFurtherDataError( | ||
msg=( | ||
"Pipeline Query requested data starting on {query_start}, " | ||
"but first known date is {calendar_start}" | ||
).format( | ||
query_start=str(start_date), | ||
calendar_start=str(dates[0]), | ||
) | ||
) | ||
else: | ||
raise ValueError("Query start %s not in calendar" % start_date) | ||
|
||
# Make sure that shifting doesn't push us out of the calendar. | ||
if start < shift: | ||
raise NoFurtherDataError( | ||
msg=( | ||
"Pipeline Query requested data from {shift}" | ||
" days before {query_start}, but first known date is only " | ||
"{start} days earlier." | ||
).format(shift=shift, query_start=start_date, start=start), | ||
) | ||
|
||
try: | ||
end = dates.get_loc(end_date) | ||
except KeyError: | ||
if end_date > dates[-1]: | ||
raise NoFurtherDataError( | ||
msg=( | ||
"Pipeline Query requesting data up to {query_end}, " | ||
"but last known date is {calendar_end}" | ||
).format( | ||
query_end=end_date, | ||
calendar_end=dates[-1], | ||
) | ||
) | ||
else: | ||
raise ValueError("Query end %s not in calendar" % end_date) | ||
return dates[start - shift], dates[end - shift] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import alpaca_trade_api as tradeapi | ||
|
||
from .util import ( | ||
daily_cache, parallelize | ||
) | ||
|
||
|
||
def list_symbols():
    """Return the symbols of every active, tradable asset on Alpaca."""
    symbols = []
    for asset in tradeapi.REST().list_assets():
        if asset.tradable and asset.status == 'active':
            symbols.append(asset.symbol)
    return symbols
|
||
|
||
def get_stockprices(limit=365, timespan='day'):
    """Fetch price DataFrames for all tradable symbols, cached daily.

    Parameters
    ----------
    limit : int
        Number of bars to request per symbol.
    timespan : str
        Bar frequency understood by Alpaca (e.g. ``'day'``).

    Returns
    -------
    dict
        Mapping of symbol -> price DataFrame, as produced by
        ``_get_stockprices``.
    """
    all_symbols = list_symbols()

    # Bug fix: the cache key previously included only `limit`, so two calls
    # with the same limit but different timespans shared one cache file and
    # the second call silently got bars of the wrong frequency.
    @daily_cache(filename='alpaca_chart_{}_{}'.format(limit, timespan))
    def get_stockprices_cached(all_symbols):
        return _get_stockprices(all_symbols, limit, timespan)

    return get_stockprices_cached(all_symbols)
|
||
|
||
def _get_stockprices(symbols, limit=365, timespan='day'):
    '''Get stock data (key stats and previous) from Alpaca.

    Works around Alpaca's 200-symbols-per-request limit by fetching in
    batches of 199 via ``parallelize`` and merging the results.
    '''

    def fetch(symbol_batch):
        api = tradeapi.REST()
        barset = api.get_barset(symbol_batch, timespan, limit)

        result = {}
        for sym in barset:
            frame = barset[sym].df
            # Update the index format for comparison with the trading calendar
            frame.index = frame.index.tz_convert('UTC').normalize()
            result[sym] = frame.asfreq('C')

        return result

    return parallelize(fetch, splitlen=199)(symbols)
Oops, something went wrong.