## Base definition

In [60]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [61]:
import guides
import wave_utils
import waves
import rules

In [62]:
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

In [63]:
import basic_types
def random_walk_between_points(start_point, end_point):
    time_range = end_point.time_offset - start_point.time_offset
    if time_range == 1:
        return [start_point, end_point]
    
    new_point_list = [start_point]
    for i in range(time_range-1):
        time_left = time_range - i
        step = (end_point.price - new_point_list[-1].price) / time_left
        next_price = new_point_list[-1].price + step*random.randint(5, 15)/10
        next_point = basic_types.Point(time_offset = start_point.time_offset + i +1, price=next_price)
        new_point_list.append(next_point)
    new_point_list.append(end_point)
    return new_point_list

def fill_points_with_random_walk(point_list):
    new_point_list = []
    for index, point in enumerate(point_list[:-1]):
        target_point = point_list[index+1]
        new_point_list += random_walk_between_points(point, target_point)[:-1]
    new_point_list.append(point_list[-1])
    return new_point_list

In [64]:
import datetime
import random


def convert_point_list_to_candle(point_list, start_date = datetime.datetime(2020, 1, 1)):
    candle_dict_list = [ ]
    for index, point in enumerate(point_list):
        prev_price = point_list[index-1].price if index > 0 else point.price
        price_list = [random.randint(-1500, 1500)/1000 * (prev_price - point.price) + point.price for i in range(3)]
        candle_dict_list.append({
            "date": start_date + datetime.timedelta(days=point.time_offset),
            "open": price_list[0],
            "close": point.price,
            "high" : max(point.price, max(price_list)),
            "low": min(point.price, min(price_list))
        })
    return candle_dict_list

base_wave = wave_utils.find_wave_for_scale(wave_type=waves.ImpluseWave, 
                                           min_time=0, 
                                           max_time=60, 
                                           first_price=1000, 
                                           last_price=3000, 
                                           try_times=200, 
                                           show_chart=False)
base_wave.show_line_chart()

wave_utils.expand_sub_wave_to_points(base_wave)
base_wave.show_line_chart()

filled_points = fill_points_with_random_walk(base_wave.get_all_points())
candle_dict_list = convert_point_list_to_candle(filled_points)

import plotly.graph_objects as go
import pandas as pd
from datetime import datetime

df = pd.DataFrame(candle_dict_list)
fig = go.Figure(data=[go.Candlestick(x=df['date'],
                open=df['open'],
                high=df['high'],
                low=df['low'],
                close=df['close'])])

fig.show()

In [56]:
from brute_force_search_wave import brute_force_search_all_point_comb

result_wave_list = brute_force_search_all_point_comb(filled_points)


INFO:brute_force_search_wave:Finished search <class 'waves.ExtendImpluseWave'>, found 10129
INFO:brute_force_search_wave:Finished search <class 'waves.FlatWave'>, found 513
INFO:brute_force_search_wave:Finished search <class 'waves.ImpluseWave'>, found 10129
INFO:brute_force_search_wave:Finished search <class 'waves.ZigZagWave'>, found 4597
INFO:brute_force_search_wave:Finished search <class 'waves.ExpandingTriangleWave'>, found 1
INFO:brute_force_search_wave:Finished search <class 'waves.EndingDiagonalWave'>, found 163
INFO:brute_force_search_wave:Finished search <class 'waves.BarrierTriangleWave'>, found 257
INFO:brute_force_search_wave:Finished search <class 'waves.LeadingDiagonalWave'>, found 161
INFO:brute_force_search_wave:Finished search <class 'waves.ContractingTriangleWave'>, found 1198
INFO:brute_force_search_wave:Total search count: 288089 search time: 8.458106279373169


In [57]:
import heapq

start_time = time.time()

optimum_point_list = wave_search_utils.get_all_local_optimum(filled_points)
search_queue = []
search_count = 0
# 1. 生成所有2点的组合所对应的 index number
for init_comb in itertools.combinations(range(len(optimum_point_list)), 2):
    for wave_type in wave_utils.get_all_concrete_subclass(waves.Wave):
        if issubclass(wave_type, waves.CombinationWave):
            continue
        # 2. 将已有的组合作为根据已生成的点，初始化不同 wave type 对应的 points limit list
        heapq.heappush(search_queue, PointComb(init_comb, wave_type, optimum_point_list))
        search_count += 1

        
# 3. 更新 point limit
for comb in search_queue:
    comb.update_next_point_limit()

# 4. 根据 points limit，选取下一个随机点
valid_wave_map = {}
valid_wave_list = []
while len(search_queue) != 0:
    search_comb = heapq.heappop(search_queue)
    if len(search_comb.point_index_list) == search_comb.target_wave_type.min_point_num:
        wave = search_comb.get_wave()
        if wave.is_valid():
            valid_wave_map[search_comb.target_wave_type] = valid_wave_map.get(search_comb.target_wave_type, 0) + 1
            valid_wave_list.append(wave)
        else:
            logger.debug(f"Invalid wave: {wave}")
            for rule in wave.get_not_valid_rule():
                logger.debug("{0} {1}".format(rule, rule.desp))
        continue
        
    for next_index in search_comb.get_next_available_points():
        index_list = search_comb.point_index_list + (next_index,)
        new_comb = PointComb(index_list, 
                             search_comb.target_wave_type, 
                             search_comb.original_points)
        new_comb.update_next_point_limit()
        heapq.heappush(search_queue, new_comb)
        search_count += 1
    if search_count % 10000 == 0:
        logger.debug(f"Already searched {search_count} combs")
logger.info(f"Total searched {search_count} combs, used time:" + str(time.time() - start_time))
logger.info(f"Found valid wave {len(valid_wave_list)}")

INFO:__main__:Total searched 55091 combs, used time:1.3962891101837158
INFO:__main__:Found valid wave 20400


In [65]:
from collections import defaultdict
from itertools import product
import copy
sub_wave_match_map = defaultdict(list)

def get_subwave_key(start_point, end_point, wave_type):
    # Wavetype - startpoint - endpoint - wave list
    key = "{0}-{1}-{2}".format(start_point.time_offset, end_point.time_offset, wave_type.__name__)
    return key
    
total_wave_num = len(valid_wave_list)
logger.info(f"Process valid wave num: {total_wave_num}")
for wave in valid_wave_list:
    wave_key = get_subwave_key(wave.point_list[0], wave.point_list[-1], type(wave))
    sub_wave_match_map[wave_key].append(wave)
logger.info("Added all wave to sub_wave_match_map")
    
for wave in valid_wave_list:
    final_wave_list = []
    # search possible subwave
    sub_wave_limit = type(wave).get_sub_wave_type_limit()
    # A two dimension list, first dim is subwave num, second dim is valid subwave
    subwave_selection = []
    for subwave_num in range(len(sub_wave_limit)):
        subwave_selection.append([])
        for abstract_type in sub_wave_limit[subwave_num]:
            concrete_subwave_type = wave_utils.get_all_concrete_subclass(abstract_type)
            for subwave_type in concrete_subwave_type:
                subwave_key = get_subwave_key(wave.point_list[subwave_num], 
                                              wave.point_list[subwave_num+1],
                                              subwave_type)
                if subwave_key not in sub_wave_match_map:
                    continue
                subwave_selection[subwave_num] += sub_wave_match_map[subwave_key]
    if issubclass(type(wave), waves.CombinationWave) and not all(subwave_selection):
            continue
    
    if any(subwave_selection):
        # Fill the empty selection with None
        for sub_wave_list in subwave_selection:
            if len(sub_wave_list) == 0:
                sub_wave_list.append(None)
            
    sub_wave_combinations = list(product(*subwave_selection))
    for sub_wave in sub_wave_combinations:
        wave.sub_wave = sub_wave
        if wave.is_valid():
            final_wave_list.append(copy.copy(wave))
        else:
            print(wave.get_not_valid_rule())
            assert(False)
    if len(final_wave_list) == 0:
        continue
        
    logger.debug(f"Pick from {len(final_wave_list)} final list {type(wave)}")
    
    # pick highest score wave
    max_score = -1
    max_score_wave = None
    for final_wave in final_wave_list:
        score = final_wave.get_score()/len(final_wave.guide_dict)
        score += sum([subwave.get_score()/len(subwave.guide_dict) for subwave in final_wave.sub_wave if subwave is not None])
        if score > max_score:
            max_score = score
            max_score_wave = final_wave
    wave.sub_wave = max_score_wave.sub_wave

logger.info("Matched all subwave")

INFO:__main__:Process valid wave num: 20400
INFO:__main__:Added all wave to sub_wave_match_map


[<class 'rules.Rule33'>]


AssertionError: 

In [None]:
def find_wave_in_wave_list(target_wave, wave_list):
    target_points = target_wave.to_dict()["points"]
    for check_wave in wave_list:
        if type(check_wave) != type(target_wave):
            continue
        if check_wave.to_dict()["points"] != target_points:
            continue
        return check_wave

def get_max_score_wave_in_list(wave_list):
    max_score = -1
    max_score_wave = None
    for check_wave in wave_list:
        score = check_wave.get_score()/len(check_wave.guide_dict)
        score += sum([subwave.get_score()/len(subwave.guide_dict) for subwave in check_wave.sub_wave if subwave is not None])
        if score > max_score:
            max_score = score
            max_score_wave = check_wave
    print(max_score)
    return max_score_wave

max_score_wave = get_max_score_wave_in_list(valid_wave_list)
max_score_wave.to_dict()

In [None]:
find_wave_in_wave_list(base_wave, valid_wave_list).to_dict()

In [None]:
max_score_wave.show_line_chart()

In [None]:
base_wave.to_dict()

In [None]:
base_wave.show_line_chart()

In [59]:
waves.ZigZagTripleCombinationWave.rule_list

[rules.PointNumberRule, rules.TimeDifferentRule, rules.Rule0, rules.Rule26]