In [2]:
import os
import datetime as dt
import time
from typing import Any, Dict, Optional, List

import requests
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import talib
import multiprocessing as mp
from requests.exceptions import ConnectionError, Timeout

%matplotlib inline
plt.style.use("fivethirtyeight")
print("the envirment is ok")

the envirment is ok


In [4]:
## 撰写自定义函数，通过API获取数据

def fetch_trochil(url: str,
                  params: Dict[str, str],
                  attempt: int = 3,
                  timeout: int = 3) -> Dict[str, Any]:
    """装饰requests.get函数"""
    for i in range(attempt):
        try:
            resp = requests.get(url, params, timeout=timeout)
            resp.raise_for_status()
            data = resp.json()["data"]
            if not data:
                raise Exception("empty dataset")
            return data
        except (ConnectionError, Timeout) as e:
            print(e)
            i += 1
            time.sleep(i * 0.5)


def fetch_cnstocks(apikey: str) -> pd.DataFrame:
    """从蜂鸟数据获取A股产品列表"""
    url = "https://api.trochil.cn/v1/cnstock/markets"
    params = {"apikey": apikey}

    res = fetch_trochil(url, params)

    return pd.DataFrame.from_records(res)


def fetch_daily_ohlc(symbol: str,
                     date_from: dt.datetime,
                     date_to: dt.datetime,
                     apikey: str) -> pd.DataFrame:
    """从蜂鸟数据获取A股日图历史K线"""
    url = "https://api.trochil.cn/v1/cnstock/history"
    params = {
        "symbol": symbol,
        "start_date": date_from.strftime("%Y-%m-%d"),
        "end_date": date_to.strftime("%Y-%m-%d"),
        "freq": "daily",
        "apikey": apikey
    }

    res = fetch_trochil(url, params)

    return pd.DataFrame.from_records(res)


def fetch_index_ohlc(symbol: str,
                     date_from: dt.datetime,
                     date_to: dt.datetime,
                     apikey: str) -> pd.DataFrame:
    """获取股指的日图历史数据"""
    url = "https://api.trochil.cn/v1/index/daily"
    params = {
        "symbol": symbol,
        "start_date": date_from.strftime("%Y-%m-%d"),
        "end_date": date_to.strftime("%Y-%m-%d"),
        "apikey": apikey
    }

    res = fetch_trochil(url, params)

    return pd.DataFrame.from_records(res)

print("the struct of the fucs is ok")

the struct of the fucs is ok


In [5]:
apikey = os.getenv("TROCHIL_API")  # use your apikey
cnstocks = fetch_cnstocks(apikey)
cnstocks

HTTPSConnectionPool(host='api.trochil.cn', port=443): Max retries exceeded with url: /v1/cnstock/markets (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x00000193DC643E20>, 'Connection to api.trochil.cn timed out. (connect timeout=3)'))
HTTPSConnectionPool(host='api.trochil.cn', port=443): Max retries exceeded with url: /v1/cnstock/markets (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x00000193DC64D9A0>, 'Connection to api.trochil.cn timed out. (connect timeout=3)'))
HTTPSConnectionPool(host='api.trochil.cn', port=443): Max retries exceeded with url: /v1/cnstock/markets (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x00000193DC64D5E0>, 'Connection to api.trochil.cn timed out. (connect timeout=3)'))


TypeError: object of type 'NoneType' has no len()

In [None]:
cnstocks_shsz = cnstocks.query("symbol.str.startswith('SH')")
cnstocks_shsz

In [None]:
%%time

# 下载2019年至今的历史数据
# 下载时剔除K线少于260个交易日的股票
date_from = dt.datetime(2019, 1, 1)
date_to = dt.datetime.today()
symbols = cnstocks_shsz.symbol.to_list()
min_klines = 260

# 逐个下载，蜂鸟数据的API没有分钟请求限制
# 先把数据存储在列表中，下载完成后再合并和清洗
ohlc_list = []
for symbol in symbols:
    try:
        ohlc = fetch_daily_ohlc(symbol, date_from, date_to, apikey)
        if ohlc is not None and len(ohlc) >= min_klines:
            ohlc.set_index("datetime", inplace=True)
            ohlc_list.append(ohlc)
    except Exception as e:
        pass

In [None]:
ohlc_joined = pd.concat(ohlc_list)
ohlc_joined.info()
ohlc_joined.isnull().sum()
ohlc_joined.to_csv("cnstock_daily_ohlc.csv", index=True)

In [None]:
benchmark = fetch_index_ohlc("shci", date_from, date_to, apikey)
benchmark.tail()
benchmark_ann_ret = benchmark.close.pct_change(252).iloc[-1]

In [None]:
def screen(close: pd.Series, benchmark_ann_ret: float) -> pd.Series:
    """实现MM选股模型的逻辑，评估单只股票是否满足筛选条件

    Args:
        close(pd.Series): 股票收盘价，默认时间序列索引
        benchmark_ann_ret(float): 基准指数1年收益率，用于计算相对强弱
    """
    # 计算50，150，200日均线
    ema_50 = talib.EMA(close, 50).iloc[-1]
    ema_150 = talib.EMA(close, 150).iloc[-1]
    ema_200 = talib.EMA(close, 200).iloc[-1]

    # 200日均线的20日移动平滑，用于判断200日均线是否上升
    ema_200_smooth = talib.EMA(talib.EMA(close, 200), 20).iloc[-1]

    # 收盘价的52周高点和52周低点
    high_52week = close.rolling(52 * 5).max().iloc[-1]
    low_52week = close.rolling(52 * 5).min().iloc[-1]

    # 最新收盘价
    cl = close.iloc[-1]

    # 筛选条件1：收盘价高于150日均线和200日均线
    if cl > ema_150 and cl > ema_200:
        condition_1 = True
    else:
        condition_1 = False

    # 筛选条件2：150日均线高于200日均线
    if ema_150 > ema_200:
        condition_2 = True
    else:
        condition_2 = False

    # 筛选条件3：200日均线上升1个月
    if ema_200 > ema_200_smooth:
        condition_3 = True
    else:
        condition_3 = False

    # 筛选条件4：50日均线高于150日均线和200日均线
    if ema_50 > ema_150 and ema_50 > ema_200:
        condition_4 = True
    else:
        condition_4 = False

    # 筛选条件5：收盘价高于50日均线
    if cl > ema_50:
        condition_5 = True
    else:
        condition_5 = False

    # 筛选条件6：收盘价比52周低点高30%
    if cl >= low_52week * 1.3:
        condition_6 = True
    else:
        condition_6 = False

    # 筛选条件7：收盘价在52周高点的25%以内
    if cl >= high_52week * 0.75 and cl <= high_52week * 1.25:
        condition_7 = True
    else:
        condition_7 = False

    # 筛选条件8：相对强弱指数大于等于70
    rs = close.pct_change(252).iloc[-1] / benchmark_ann_ret * 100
    if rs >= 70:
        condition_8 = True
    else:
        condition_8 = False

    # 判断股票是否符合标准
    if (condition_1 and condition_2 and condition_3 and
        condition_4 and condition_5 and condition_6 and
        condition_7 and condition_8):
        meet_criterion = True
    else:
        meet_criterion = False

    out = {
        "rs": round(rs, 2),
        "close": cl,
        "ema_50": ema_50,
        "ema_150": ema_150,
        "ema_200": ema_200,
        "high_52week": high_52week,
        "low_52week": low_52week,
        "meet_criterion": meet_criterion
    }

    return pd.Series(out)

In [None]:
# 仅仅筛选有足够历史数据的股票
symbols_to_screen = list(ohlc_joined.symbol.unique())

# 将数据框的格式从long-format转化为wide-format
ohlc_joined_wide = ohlc_joined.pivot(columns="symbol", values="close").fillna(method="ffill")

ohlc_joined_wide.head()

In [None]:
%%time

results = ohlc_joined_wide.apply(screen, benchmark_ann_ret=benchmark_ann_ret)
results = results.T

In [None]:
results.query("meet_criterion == True").sort_values("rs", ascending=False)