In [14]:
import pandas as pd
import orjson
import ujson
from glob import glob
from aiohttp import ClientSession

In [15]:

import logging
import pandas as pd
from functools import partial
from aiohttp import ClientSession

# Notebook-wide logging: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class LimitExceededError(Exception):
    """Raised when Tushare's error message contains 每分钟最多访问 (per-minute quota used up)."""


class DailyLimitExceededError(Exception):
    """Raised when Tushare's error message contains 每天最多访问 (daily quota used up).

    The message is ``"<api>, <tushare message>"``; the sync loop recovers the API
    with ``e.args[0].split(',')[0]``.
    """


class ArgumentError(Exception):
    """Raised when Tushare's error message contains 服务器错误 (server error).

    The message embeds ``api:<api>,`` which the sync loop extracts with a regex
    before dropping the offending request parameter.
    """


class TushareClient:
    """Minimal asyncio client for the Tushare Pro HTTP API.

    Any unknown *public* attribute becomes an API call, so
    ``await client.daily(params={...})`` is ``await client.query('daily', params={...})``.
    """

    __token = ''
    __http_url = 'http://api.waditu.com'

    def __init__(self, token: str, http_session: "ClientSession", timeout=30):
        """
        Parameters
        ----------
        token: str
            API token used for user authentication.
        http_session: aiohttp.ClientSession
            Session used for every request. This client never closes it.
        timeout:
            Passed unchanged as ``timeout`` to ``http_session.post``.
        """
        self.__token = token
        self.__http_session = http_session
        self.__timeout = timeout
        self.__token_param = {"token": token}
        # Extra HTTP headers sent with every request; filled by set_token(how='header').
        self.__header = {}

    def set_token(self, token: str, how: str = "param",
                  token_key: str = 'token', header_field: str = 'Authorization'):
        """Register ``token`` as a request-body parameter or as an HTTP header.

        Parameters
        ----------
        token: str
            The token value.
        how: str
            ``'param'`` stores it in the JSON body under ``token_key``;
            ``'header'`` sends it in the ``header_field`` HTTP header.
        token_key: str
            Body key used when ``how == 'param'``.
        header_field: str
            Header name used when ``how == 'header'``. ``Authorization`` gets a
            ``Bearer`` prefix; any other header carries the raw token.

        Raises
        ------
        ValueError
            If ``how`` is neither ``'param'`` nor ``'header'``.
        """
        if how == 'header':
            value = f"Bearer {token}" if header_field == 'Authorization' else token
            self.__header[header_field] = value
        elif how == 'param':
            self.__token_param[token_key] = token
        else:
            raise ValueError(f"Unknown authorization method: {how}")

    async def query(self, api: str, **kwargs):
        """POST one Tushare request and return the result as a DataFrame.

        Parameters
        ----------
        api: str
            Tushare ``api_name``.
        params: dict, optional (keyword)
            Request parameters.
        use_columns: list of str, optional (keyword)
            Fields to return; sent as a comma-separated string.

        Returns
        -------
        pandas.DataFrame
            Built from ``data.items`` / ``data.fields``; empty if the response body is empty.

        Raises
        ------
        LimitExceededError, DailyLimitExceededError, ArgumentError, Exception
            When Tushare answers with a non-zero ``code``. Transport errors propagate as-is.
        """
        log = logging.getLogger(__name__)
        params = kwargs.pop("params", None) or {}
        use_columns = kwargs.pop("use_columns", None) or []
        if kwargs:
            # A misspelt keyword (e.g. ``oarams``) used to be dropped without any trace.
            log.warning("query(%r) ignores unknown keyword arguments: %s", api, sorted(kwargs))

        req_params = {
            'api_name': api,
            'params': params,
            # Tushare expects the field list as one comma-separated string.
            'fields': ",".join(use_columns),
            **self.__token_param,
        }

        try:
            async with self.__http_session.post(self.__http_url, json=req_params,
                                                headers=self.__header,
                                                timeout=self.__timeout) as response:
                result = await response.json()
            return self._to_frame(api, params, result)
        except Exception as e:
            log.warning("Exception occurred while requesting %s: %s", api, e)
            raise

    @staticmethod
    def _to_frame(api, params, result):
        """Turn a decoded Tushare response into a DataFrame, raising on error codes."""
        if not result:
            return pd.DataFrame()
        if result.get('code') != 0:
            msg = result.get('msg') or ''
            if "每分钟最多访问" in msg:
                raise LimitExceededError(msg)
            if '每天最多访问' in msg:
                raise DailyLimitExceededError(f"{api}, {msg}")
            if '服务器错误' in msg:
                raise ArgumentError(f"Error occurred while requesting api:{api}, param: {params}, message: {msg}")
            raise Exception(f"Error occurred while requesting {api}, param: {params}, message: {msg}")
        data = result['data']
        return pd.DataFrame(data['items'], columns=data['fields'])

    def __getattr__(self, name):
        """Map an unknown public attribute to a bound API call: ``client.daily(...)``."""
        # Private/dunder lookups (IPython's _repr_*_ probes, copy/pickle hooks, missing
        # name-mangled fields) must fail normally instead of turning into API calls.
        if name.startswith('_'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return partial(self.query, name)


In [16]:
import os

# Read the Tushare token from the environment instead of hard-coding it in the notebook.
# The token previously written here is exposed and should be rotated.
tushare_token = os.environ.get("TUSHARE_TOKEN")
if not tushare_token:
    raise RuntimeError("Set the TUSHARE_TOKEN environment variable to your Tushare API token.")

# TushareClient never closes this session; run `await http_client.close()` when finished,
# otherwise aiohttp logs "Unclosed client session" when it is garbage-collected.
http_client = ClientSession(json_serialize=ujson.dumps)
ts_client = TushareClient(http_session=http_client, token=tushare_token)

2022-07-02 14:09:58,019 - asyncio - ERROR: Unclosed client session
client_session: <aiohttp.client.ClientSession object at 0x0000028C507A38E0>


In [17]:
from itertools import islice

def window(seq, n=2):
    """Yield successive overlapping tuples of width ``n`` taken from ``seq``.

    ``s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...``

    An input with fewer than ``n`` items produces no windows.
    """
    source = iter(seq)
    frame = tuple(islice(source, n))
    if len(frame) < n:
        return
    yield frame
    for item in source:
        frame = (*frame[1:], item)
        yield frame

In [18]:
from concurrent.futures import ThreadPoolExecutor

# Worker pool for blocking work (the sync loop hands parquet writes to it via run_in_executor).
threadpool = ThreadPoolExecutor(8)

In [19]:
import asyncio

# Event loop the notebook cells run in; the sync loop uses it for loop.run_in_executor.
# asyncio.get_running_loop() raises RuntimeError when no loop is running, so this only
# works in a kernel that executes cells inside an asyncio loop (as the top-level
# `await`s in this notebook also require).
loop = asyncio.get_running_loop()

In [20]:
# Read raw bytes and let orjson decode them as UTF-8. Text mode would decode with the
# platform locale encoding (e.g. GBK on a Chinese-locale Windows install), which can
# mangle or reject non-ASCII content.
with open('./mock_db/sync_config.json', 'rb') as f:
    sync_settings = orjson.loads(f.read())

ts_config = sync_settings[0]

In [21]:
from glob import glob

import os


def remove_bundles_after(pattern: str, cutoff: int) -> list:
    """Delete bundle files matched by ``pattern`` whose trailing date stamp is after ``cutoff``.

    File names are expected to end in ``_<YYYYMMDD>.<ext>``. Files whose last
    ``_``-separated token is not all digits are left alone.

    Parameters
    ----------
    pattern: str
        Glob pattern (non-recursive: ``**`` behaves like ``*`` and spans one directory level).
    cutoff: int
        Date as ``YYYYMMDD``; strictly later stamps are removed.

    Returns
    -------
    list of str
        The deleted paths, in sorted order.
    """
    removed = []
    for bundle_path in sorted(glob(pattern)):
        stem = os.path.splitext(os.path.basename(bundle_path))[0]
        stamp = stem.rsplit("_", 1)[-1]
        if stamp.isdigit() and int(stamp) > cutoff:
            os.remove(bundle_path)
            removed.append(bundle_path)
    return removed


# Destructive: drops every synced bundle dated after the cutoff so that it is fetched again.
REMOVE_AFTER = 20220626
removed_bundles = remove_bundles_after('./bundles/**/*2022062*.parquet', REMOVE_AFTER)

In [22]:
from typing import Set, Iterable, Any

import os

path = os.path.join(os.getcwd(), "bundles")


def filter_unsynced(complete_val_range: Iterable[Any], bundle_root: str, api: str, file_ext: str) -> Set[str]:
    """Return the values of ``complete_val_range`` that have no bundle file yet.

    Bundle files are looked up as ``<bundle_root>/<api>/*.<file_ext>``. The part of
    each file name between ``"<api>_"`` and ``".<file_ext>"`` counts as a synced value.

    Parameters
    ----------
    complete_val_range: iterable
        Every value that should eventually be synced.
    bundle_root: str
        Root folder holding one sub-folder per API.
    api: str
        API name; used as sub-folder name and file-name prefix.
    file_ext: str
        File extension without the dot, e.g. ``"parquet"``.

    Returns
    -------
    set
        ``set(complete_val_range)`` minus the synced values.
    """
    from glob import escape

    # Escape so characters such as '[' in the working directory are matched literally.
    pattern = os.path.join(escape(bundle_root), escape(api), f"*.{file_ext}")
    prefix_len = len(api) + 1        # strip "<api>_"
    suffix_len = len(file_ext) + 1   # strip ".<file_ext>"
    # basename() is portable; splitting on '\\' only worked for Windows paths.
    synced = {os.path.basename(p)[prefix_len:-suffix_len] for p in glob(pattern)}
    return set(complete_val_range) - synced

In [23]:
from itertools import product
import re
from datetime import datetime

import os

# Sync window (YYYYMMDD, inclusive) and root folder of the per-API bundle files.
start_date = '20050101'
end_date = '20220701'
bundle_root = os.path.join(os.getcwd(), 'bundles')

api_params = {}
arg_vals = pd.read_excel('./tushare_api_args.xlsx')
stock_code = pd.read_json('./shared_data/stock_code.json').ts_code.values
index_code = pd.read_csv('./shared_data/tushare_index_basic.csv').ts_code.values
trade_date = [str(d) for d in pd.read_json('./shared_data/trade_cal.json').cal_date.values
              if int(start_date) <= d <= int(end_date)]
item_code = [str(i) for i in range(1, 66)]
samples = arg_vals.api_name.unique()


def _date_strings(freq='D', fmt='%Y%m%d'):
    """Format every timestamp of ``pd.date_range(start_date, end_date, freq=freq)`` with ``fmt``."""
    # DatetimeIndex.strftime replaces the deprecated Index.format(formatter=...).
    return pd.date_range(start_date, end_date, freq=freq).strftime(fmt).tolist()


months = _date_strings('M')
month_ym = _date_strings('M', '%Y%m')
month_intervals = list(window(months))
quarters = _date_strings('Q')
weeks = _date_strings('W')
dates = _date_strings()
days = _date_strings(fmt='%Y-%m-%d %H:%M:%S')
dt_interval = list(window(days))


def _file_key(value) -> str:
    """Return ``value`` as it appears in a bundle file name.

    The sync loop saves ``<api>_<values joined by '_'>.parquet`` with every ':' replaced
    by '_', so datetime values must be compared in that form or they never match.
    """
    return str(value).replace(":", "_")


def _interval_key(interval) -> str:
    """File-name key of a ``(start, end)`` interval."""
    return _file_key("_".join(interval))


def _unsynced(candidates, api, key=_file_key):
    """Return the candidates, in their original order, that have no bundle file yet."""
    by_key = {key(candidate): candidate for candidate in candidates}
    missing = filter_unsynced(by_key, bundle_root, api, 'parquet')
    return [candidate for k, candidate in by_key.items() if k in missing]


# (regex, values) pairs tested against `from_shared_data`; the first match wins.
shared_sources = [
    ("quarters|quarter", quarters),
    ("months", months),
    ("month_ym", month_ym),
    ("weeks", weeks),
    ("dates", dates),
    ("days", days),
    ("trade_date", trade_date),
    ("ts_code", stock_code),
    ("index_code", index_code),
    ("item_code", item_code),
]

for api in samples:
    args = arg_vals.loc[arg_vals.api_name.isin([api])]
    required_args = args.loc[args.required]
    dynamic_args = args.loc[args.is_dynamic]
    shared = dynamic_args.from_shared_data.str
    dynamic_names = dynamic_args.arg_name.str

    api_params[api] = []
    if shared.contains("start_dt").any() and shared.contains("end_dt").any():
        api_params[api] = [{'start_date': s, 'end_date': e}
                           for s, e in _unsynced(dt_interval, api, key=_interval_key)]
    elif dynamic_names.contains("start_date").any() and dynamic_names.contains("end_date").any():
        api_params[api] = [{'start_date': s, 'end_date': e}
                           for s, e in _unsynced(month_intervals, api, key=_interval_key)]
    else:
        use_data = next((values for pattern, values in shared_sources
                         if shared.contains(pattern).any()), None)
        if use_data is not None:
            arg_name = dynamic_args.arg_name.values[0]
            api_params[api] = [{arg_name: value} for value in _unsynced(use_data, api)]

    if len(required_args):
        arg_names = list(required_args.arg_name.values)
        # Allowed values are the single-quoted tokens of each value_range cell;
        # a blank (non-string) cell contributes no values instead of crashing.
        value_ranges = [re.findall(r"'(\w+)'", vrange) if isinstance(vrange, str) else []
                        for vrange in required_args.value_range.values]
        # One request per combination, carrying every required argument.
        for combination in product(*value_ranges):
            api_params[api].append(dict(zip(arg_names, combination)))

In [24]:
# Inspect the pending request parameters computed above for the 指数月线行情 API.
api_params['指数月线行情']

[]

In [25]:
import os

def save_df_as_csv(df, path, filename):
    """Write ``df`` to ``<path>/<filename>`` as CSV, creating ``path`` if needed."""
    # exist_ok avoids the check-then-create race between concurrent writers.
    os.makedirs(path, exist_ok=True)
    df.to_csv(os.path.join(path, filename))


def save_df_as_parquet(df, path, filename, overwrite=False):
    """Write ``df`` to ``<path>/<filename>`` as Parquet, creating ``path`` if needed.

    An existing file is left untouched unless ``overwrite`` is true.
    """
    # Saves run concurrently in a thread pool; exists()+makedirs() could raise
    # FileExistsError when two threads create the same folder.
    os.makedirs(path, exist_ok=True)
    fpath = os.path.join(path, filename)
    if overwrite or not os.path.isfile(fpath):
        df.to_parquet(fpath)

In [26]:
from datetime import datetime, timedelta

index_monthly_paths = glob("./bundles/指数月线行情/*.parquet")
dates = [m.split("_")[-1][:-8] for m in index_monthly_paths]
months = []

for m in index_monthly_paths:
    if os.path.getsize(m) < 7000:
        date = datetime.strptime(m.split("_")[-1][:-8], "%Y%m%d")
        months.append({"trade_date": (date - timedelta(days=1)).strftime("%Y%m%d")})

In [27]:
# Replace the pending parameters of 指数月线行情 with the retry list built in the previous cell.
api_params.update({'指数月线行情': months})

In [13]:
# api_params.pop("电影日度票房")
# api_params.pop("影院每日票房")
# api_params.pop("新闻联播")

In [28]:
from asyncio import gather, sleep
from datetime import datetime, timedelta
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

apis_to_sync = ['指数月线行情']
addresses = {api: arg_vals.loc[arg_vals.api_name == api].address.values[0] for api in apis_to_sync if len(arg_vals.loc[arg_vals.api_name == api].address.values)}
batch_size = len(apis_to_sync)
test_batch = apis_to_sync

max_request_per_minute = 350

for i in tqdm(range(0, len(test_batch), batch_size)):
    current_task_batch = test_batch[i:i+batch_size]
    uncompleted_batch = set(current_task_batch)
    
    start_time = datetime.now()
    next_refresh_time = start_time + timedelta(minutes=1) - timedelta(seconds=start_time.second)
    request_cnt = max_request_per_minute
    
    for task in current_task_batch:
        has_dynamic_arg = len(arg_vals.loc[(arg_vals.api_name == task) & (arg_vals.is_dynamic.any())]) > 0
        if has_dynamic_arg and len(api_params[task]) == 0:
            # logger.info(f"Task {task} need no sync.")
            uncompleted_batch.remove(task)
    
    logger.info(f"Syncing: {uncompleted_batch}")
    logger.info(f"next_refresh_time: {next_refresh_time}")
 
    while len(uncompleted_batch) > 0:
        ts_requests = []
        
        current_time = datetime.now()
        
        while current_time <= next_refresh_time and request_cnt <= 0:
            await sleep(1)
            current_time = datetime.now()
            # logger.info(f"sleeping, now: {current_time}, refresh_time: {next_refresh_time}")
            
            if current_time.second % 20 == 0:
                logger.info(f"sleeping, now: {current_time}, refresh_time: {next_refresh_time}")
        
        if current_time > next_refresh_time:
            next_refresh_time = datetime.now() + timedelta(minutes=1)
            request_cnt = max_request_per_minute
        
        for task_name in uncompleted_batch:
            address = addresses[task_name]
            params = {}
            if len(api_params[task_name]) > 0:
                params = api_params[task_name][0]
            ts_requests.append((task_name, ts_client.query(address, params=params)))
        
        try:
            results = await gather(*[task[1] for task in ts_requests])
            request_cnt -= 1
        
            if request_cnt % 100 == 0:
                logger.info(f"request_cnt: {request_cnt}, result size: {len(results)}, request size: {len(ts_requests)}")
            
            # check completed and save result
            save_tasks = []
            for req, result in zip(ts_requests, results):
                current_task_name, _ = req

                if result is not None:
                    params = ''             
                    if len(api_params[current_task_name]) > 0:
                        params = '_'.join(api_params[current_task_name].pop(0).values())
                        
                    if len(api_params[current_task_name]) == 0 and current_task_name in uncompleted_batch:
                        uncompleted_batch.remove(current_task_name)
                    
                    filename = f'{current_task_name}_{params.replace(":", "_")}.parquet' if len(params) > 0 else f'{current_task_name}.parquet'
                    save_path = os.path.join(os.getcwd(), "bundles", current_task_name)
                    save_tasks.append(loop.run_in_executor(threadpool, save_df_as_parquet, result, save_path, filename, True))
            
            await asyncio.wait(save_tasks)
        except ArgumentError as e:
            api_address = re.findall(r'.*?api:(.*),.*', e.args[0])[0].split(',')[0]
            api_name = arg_vals.loc[(arg_vals.address == api_address)].api_name.values[0]
            
            if len(api_params[api]) > 0:
                arg = api_params[api_name].pop(0)
                logger.info(f"去除导致问题的参数: {arg}")
            else:
                uncompleted_batch.remove(api_name)
                logger.info(f"暂时不同步出错的api: {api_name}。请检查数据源。")
        except LimitExceededError:
            logger.info("Limit exceed... Sleep for 10 seconds")
            await sleep(60)
        except DailyLimitExceededError as e:
            api_address = e.args[0].split(',')[0]
            api = arg_vals.loc[(arg_vals.address == api_address)].api_name.values[0]
            uncompleted_batch.remove(api)
            logger.info(f"{api} 已达到日上限。。。明天继续")
        except TimeoutError:
            logger.info("Timeout Error... Retry...")
        except Exception as e:
            logger.info(e)

  0%|          | 0/1 [00:00<?, ?it/s]2022-07-02 14:10:35,974 - __main__ - INFO: Syncing: {'指数月线行情'}
2022-07-02 14:10:35,974 - __main__ - INFO: next_refresh_time: 2022-07-02 14:11:00.973942
2022-07-02 14:10:41,736 - __main__ - INFO: request_cnt: 300, result size: 1, request size: 1
100%|██████████| 1/1 [00:08<00:00,  8.20s/it]


In [15]:
# Scratch check: DailyLimitExceededError messages look like "<api>, <msg>", and the sync
# loop recovers the api with e.args[0].split(',')[0] (which gives "a" here).
e = DailyLimitExceededError("a,b")

In [16]:
# Scratch check: every batch start offset is a multiple of batch_size.
for start in range(0, len(test_batch), batch_size):
    print(start % batch_size)

0
0


In [17]:
stock_list = await ts_client.query('stock_basic', oarams={"limit": 10000, "offset": 4999})

In [18]:
# Display the stock_basic result fetched in the previous cell.
stock_list

Unnamed: 0,ts_code,symbol,name,area,industry,market,list_date
0,000001.SZ,000001,平安银行,深圳,银行,主板,19910403
1,000002.SZ,000002,万科A,深圳,全国地产,主板,19910129
2,000004.SZ,000004,ST国华,深圳,软件服务,主板,19910114
3,000005.SZ,000005,ST星源,深圳,环境保护,主板,19901210
4,000006.SZ,000006,深振业A,深圳,区域地产,主板,19920427
...,...,...,...,...,...,...,...
4816,871981.BJ,871981,晶赛科技,,,北交所,20211115
4817,872925.BJ,872925,锦好医疗,,,北交所,20211025
4818,873169.BJ,873169,七丰精工,,,北交所,20220415
4819,873223.BJ,873223,荣亿精密,,,北交所,20220609
