# In [None]:
import pymysql
import numpy as np
import pandas as pd
import datetime as dt
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)


def create_connection():
    """Open a new pymysql connection to the factor database.

    Every setting falls back to the value that used to be hard-coded here,
    so existing callers behave exactly as before. Each setting can be
    overridden with an environment variable, which lets deployments keep
    credentials out of the source file:

    * ``FACTORDB_HOST``
    * ``FACTORDB_USER``
    * ``FACTORDB_PASSWORD``
    * ``FACTORDB_NAME``
    * ``FACTORDB_PORT``

    Returns:
        An open pymysql connection configured with ``DictCursor``. The
        caller owns the connection and must close it.
    """
    import os  # local import so this block does not depend on file-level edits

    # SECURITY NOTE(review): the fallback values below are real-looking
    # credentials committed to source. Rotate them and drop the defaults once
    # every environment supplies the FACTORDB_* variables.
    env = os.environ
    connection = pymysql.connect(
        host=env.get('FACTORDB_HOST', '124.220.177.115'),
        user=env.get('FACTORDB_USER', 'hwh'),
        password=env.get('FACTORDB_PASSWORD', 'gtja20'),
        db=env.get('FACTORDB_NAME', 'factordb'),
        port=int(env.get('FACTORDB_PORT', '3306')),
        cursorclass=pymysql.cursors.DictCursor,
    )
    return connection


def get_all_securities(types, date):
    """Return the distinct security codes seen on or before ``date``.

    Args:
        types: Sequence of filter strings. Only the first element is used;
            it is matched as a substring of ``S_INFO_WINDCODE``. An empty
            sequence applies no substring filter.
        date: Upper bound (inclusive) for ``TRADE_DT``.

    Returns:
        A DataFrame with a single ``S_INFO_WINDCODE`` column, as produced by
        ``pandas.read_sql``.
    """
    # Values are sent as bound parameters instead of being spliced into the
    # SQL text, so quotes or wildcard characters in the inputs cannot change
    # the statement. The LIKE wildcards live in the parameter, not the SQL,
    # because the driver treats '%' in the statement as a format character.
    pattern = f"%{types[0]}%" if types else "%"
    query = """
    SELECT DISTINCT S_INFO_WINDCODE
    FROM BASICFACTORS_TABLE
    WHERE TRADE_DT <= %s AND S_INFO_WINDCODE LIKE %s
    """
    connection = create_connection()
    try:
        securities_df = pd.read_sql(query, connection, params=[date, pattern])
    finally:
        connection.close()
    return securities_df


def get_price(stock_list, start_date, end_date, frequency='1d', fields=None):
    """Fetch price rows for the given securities over a date range.

    Args:
        stock_list: Iterable of ``S_INFO_WINDCODE`` values. A single code
            passed as a plain string is treated as a one-element list.
        start_date: Lower bound (inclusive) for ``TRADE_DT``.
        end_date: Upper bound (inclusive) for ``TRADE_DT``.
        frequency: Accepted for interface compatibility; not used by the
            query.
        fields: Column names to select in addition to ``TRADE_DT`` and
            ``S_INFO_WINDCODE``. Defaults to the adjusted OHLC, previous
            close and percent-change columns.

    Returns:
        A long-format DataFrame (one row per date and security) ordered by
        ``TRADE_DT``. Empty when ``stock_list`` is empty.
    """
    if fields is None:
        fields = ['S_DQ_ADJOPEN', 'S_DQ_ADJHIGH', 'S_DQ_ADJLOW',
                  'S_DQ_ADJCLOSE', 'S_DQ_ADJPRECLOSE', 'S_DQ_ADJPCTCHANGE']

    # A bare string would otherwise be iterated character by character.
    codes = [stock_list] if isinstance(stock_list, str) else list(stock_list)
    if not codes:
        # "IN ()" is not valid SQL; the answer is known to be empty anyway.
        return pd.DataFrame(columns=['TRADE_DT', 'S_INFO_WINDCODE', *fields])

    # Column names cannot be bound as parameters, so they are still
    # interpolated; callers must pass trusted identifiers in ``fields``.
    field_str = ', '.join(fields)
    placeholders = ', '.join(['%s'] * len(codes))
    query = f"""
    SELECT TRADE_DT, S_INFO_WINDCODE, {field_str}
    FROM BASICFACTORS_TABLE
    WHERE TRADE_DT BETWEEN %s AND %s AND S_INFO_WINDCODE IN ({placeholders})
    ORDER BY TRADE_DT
    """
    params = [start_date, end_date, *codes]

    connection = create_connection()
    try:
        price_df = pd.read_sql(query, connection, params=params)
    finally:
        connection.close()
    return price_df


def get_trade_days(start_date, end_date):
    """List the distinct trade dates between two dates, inclusive.

    Args:
        start_date: Lower bound (inclusive) for ``TRADE_DT``.
        end_date: Upper bound (inclusive) for ``TRADE_DT``.

    Returns:
        A list of ``TRADE_DT`` values in ascending order.
    """
    # Dates are bound as parameters rather than formatted into the SQL text.
    query = """
    SELECT DISTINCT TRADE_DT
    FROM BASICFACTORS_TABLE
    WHERE TRADE_DT BETWEEN %s AND %s
    ORDER BY TRADE_DT
    """
    connection = create_connection()
    try:
        days_df = pd.read_sql(query, connection, params=[start_date, end_date])
    finally:
        connection.close()
    return days_df['TRADE_DT'].tolist()


def get_bars(stock_list, count, unit='1d', fields=None, end_dt=None, df=True):
    """Fetch the most recent rows for the given securities.

    Args:
        stock_list: Iterable of ``S_INFO_WINDCODE`` values. A single code
            passed as a plain string is treated as a one-element list.
        count: Maximum number of rows to return (applied as SQL ``LIMIT``).
        unit: Accepted for interface compatibility; not used by the query,
            so no weekly aggregation happens here.
        fields: Column names to select. Defaults to ``['S_DQ_ADJCLOSE']``.
        end_dt: Upper bound (inclusive) for ``TRADE_DT``. ``None`` means no
            upper bound.
        df: When true, return a date-by-security pivot of ``fields[0]``;
            otherwise return the raw long-format rows.

    Returns:
        A DataFrame, pivoted or long-format depending on ``df``.
    """
    if fields is None:
        fields = ['S_DQ_ADJCLOSE']  # Default to adjusted close if not specified

    # A bare string would otherwise be iterated character by character.
    codes = [stock_list] if isinstance(stock_list, str) else list(stock_list)

    if codes:
        # Column names cannot be bound as parameters, so they are still
        # interpolated; callers must pass trusted identifiers in ``fields``.
        field_str = ', '.join(fields)
        placeholders = ', '.join(['%s'] * len(codes))
        conditions = [f"S_INFO_WINDCODE IN ({placeholders})"]
        params = list(codes)
        if end_dt is not None:
            # Previously a missing end_dt was rendered as the literal 'None'.
            conditions.append("TRADE_DT <= %s")
            params.append(end_dt)
        where_clause = ' AND '.join(conditions)
        params.append(int(count))

        # NOTE(review): LIMIT caps the total number of rows, not the number
        # of bars per security, so with several codes each one gets fewer
        # than ``count`` bars. Confirm the intended semantics.
        query = f"""
        SELECT TRADE_DT, S_INFO_WINDCODE, {field_str}
        FROM BASICFACTORS_TABLE
        WHERE {where_clause}
        ORDER BY TRADE_DT DESC
        LIMIT %s
        """
        connection = create_connection()
        try:
            bars_df = pd.read_sql(query, connection, params=params)
        finally:
            connection.close()
    else:
        # "IN ()" is not valid SQL; build the empty result locally instead.
        bars_df = pd.DataFrame(columns=['TRADE_DT', 'S_INFO_WINDCODE', *fields])

    if df:
        return bars_df.pivot(index='TRADE_DT', columns='S_INFO_WINDCODE', values=fields[0])
    return bars_df


# Select the stocks that satisfy the box-breakout criteria on a given date
def get_feasible_stocks(date):
    """Screen the security universe through a chain of breakout filters.

    Each stage narrows ``stock_list``; a stage is skipped once the list is
    empty.

    NOTE(review): the ``get_price`` / ``get_bars`` calls below pass keyword
    arguments (``count=``, ``include_now=``), omit ``start_date`` and use
    field names (``'volume'``, ``'money'``, ``'close'``, ``'date'``) that do
    not match the helpers of the same name defined in this file. They look
    like calls to a different data API whose result is indexed by field
    name and then date-by-security. Those calls are left untouched here and
    must be reconciled with the helpers before this function can run.

    Args:
        date: The trading date to screen.

    Returns:
        A list of security codes that pass every filter.
    """
    # Universe of securities. The helper returns the codes in the
    # S_INFO_WINDCODE column; its index is just a positional RangeIndex.
    stock_list = get_all_securities(['stock'], date=date)['S_INFO_WINDCODE'].tolist()

    # Latest daily volume is more than twice the mean of the previous 5 days
    if stock_list:
        volume = get_price(stock_list, end_date=date, frequency='1d', fields=['volume'], count=6)['volume']
        indice = volume.iloc[-1, :] > 2 * volume[:-1].mean()
        stock_list = indice[indice == 1].index.tolist()

    # Latest daily turnover is above 200 million
    if stock_list:
        money = get_price(stock_list, end_date=date, frequency='1d', fields=['money'], count=1)['money'].iloc[-1, :]
        stock_list = money[money > 200000000].index.tolist()

    # Close is above the 60-day moving average
    if stock_list:
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=60)['close']
        indice = close.iloc[-1, :] > close.mean()
        stock_list = indice[indice == 1].index.tolist()

    # Over the last 40 days, the close left the (0.8, 1.05) x N-day-mean band
    # fewer than 5 times
    if stock_list:
        N = 60
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=100)['close']
        MA_N = close[-N:].mean()
        indice = ((close[-40:] > 1.05 * MA_N) | (close[-40:] < 0.8 * MA_N)).sum()
        stock_list = indice[indice < 5].index.tolist()

    # Exclude stocks at a relative high (latest close must rank below the
    # 90th percentile of roughly one year of closes)
    if stock_list:
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=250)['close']
        rank = close.rank(pct=True).iloc[-1, :]
        stock_list = rank[rank < 0.9].index.tolist()

    # Long-term average is turning up (250-day moving average is rising).
    # The original computed a filter here but discarded the result, and it
    # tested the level of the average (> 0) rather than its direction; keep
    # the stocks whose 250-day average increased on the last bar.
    if stock_list:
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=300)['close']
        MA_250 = close.rolling(window=250).mean()
        slope = MA_250.diff().iloc[-1, :]
        stock_list = slope[slope > 0].index.tolist()

    # Consolidation spans at least N=60 trading days (excluding the last 3),
    # with the highest close less than 30% above the lowest close
    if stock_list:
        N = 60
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=N+3)['close'][:-3]
        rate = close.max() / close.min() - 1
        stock_list = rate[rate < 0.3].index.tolist()
        # Remember the top of the consolidation range for the later stages
        high_ser = close[stock_list].max()

    # Latest close is more than 2% above the top of the range
    if stock_list:
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=1)['close'].iloc[-1, :]
        rate = close / high_ser - 1
        stock_list = rate[rate > 0.02].index.tolist()

    # Several consecutive up days break through the range
    if stock_list:
        K = 2
        close = get_price(stock_list, end_date=date, frequency='1d', fields=['close'], count=K+1)['close']
        # K consecutive rising closes
        con1 = (close.diff().dropna(how='all') > 0).sum()
        # Each of those closes is above the top of the range
        con2 = (close[1:] > high_ser[stock_list]).sum()
        stock_list = con1[(con1 == K) & (con2 == K)].index.tolist()

    # Weekly bars must confirm the entry signal as well
    if stock_list:
        # Weekly data
        close = get_bars(stock_list, count=80, unit='1w',fields=['date','close'], include_now=True, end_dt=date, df=True).reset_index()
        close = pd.pivot_table(close, values='close', index='date', columns='level_0')
        # Weekly close is above the 60-week moving average
        indice = close.iloc[-1, :] > close[-60:].mean()
        stock_list = indice[indice == 1].index.tolist()

    return stock_list


# Test window (both ends inclusive)
start_date, end_date = '2024-01-01', '2024-05-29'
# Every trading day inside the window
date_list = get_trade_days(start_date=start_date, end_date=end_date)
# Run the screen day by day and print that day's picks
for date in date_list:
    feasible_stocks = get_feasible_stocks(date)
    print('%s选股: ' % date)
    print(', '.join(feasible_stocks) + ' \n')
    print('123')