In [1]:
# train.py

import os
import torch
import numpy as np
import warnings
import pickle
warnings.filterwarnings(action='ignore')

from data_loader import load_data_1m
from feature_calculations_2 import (
    resample_data, calculate_MA_data, calculate_ema_bollinger_bands, calculate_rsi,
    calculate_macd, calculate_stochastic_oscillator, calculate_adx, calculate_atr, calculate_volume,
    calculate_obv, calculate_williams_r, base_feature_fn, calculate_volatility_features, cyclic_encode_fn, calculate_support_resistance_numba, log_transform
)
from strategies import BB_fitness_fn, BB_MACD_fitness_fn, simple_fitness_fn, BB_MACD_EMA_RSI_fitness_fn
from dataset import make_dataset, replace_nan_with_zero
from train_functions_bi_cul import inference, fitness_fn, generation_valid, generation_test

from Prescriptor import Prescriptor, CryptoModelTCN
from Evolution.crossover import UniformCrossover, WeightedSumCrossover, DifferentialEvolutionOperator, CenDE_DOBLOperator, SkipCrossover
from Evolution.mutation import MultiplyNormalMutation, MultiplyUniformMutation, AddNormalMutation, AddUniformMutation, ChainMutation, FlipSignMutation
from Evolution.mutation import RandomValueMutation
from Evolution.selection import RouletteSelection, TournamentSelection, ParetoLexsortSelection
from Evolution import Evolution

In [2]:
# Load raw data.
# `data_1m` is the raw 1-minute candle frame; it is passed directly to
# `generation_valid` in the evolution cell. The model inputs are NOT built here:
# they are read from the pickled feature cache in the next cell.
# Set REBUILD_FEATURE_CACHE = True to regenerate that cache from this frame
# (slow; overwrites FEATURE_CACHE_PATH).
RAW_DATA_PATH = '/root/daily/bit/data/1min_ethusdt.pkl'
FEATURE_CACHE_PATH = '/root/daily/bit_5/backup_feature_data/data.pkl'
REBUILD_FEATURE_CACHE = False


def build_feature_cache(data_1m, cache_path):
    """Compute indicator features, build the model dataset and pickle it to `cache_path`.

    This is the pipeline that produced the feature cache loaded in the next cell.
    It used to live in this cell as commented-out code; as a guarded function it is
    syntax-checked and can be re-run on demand.

    Args:
        data_1m: raw 1-minute frame returned by `load_data_1m`.
        cache_path: destination pickle file. Its directory is created if missing.

    Returns:
        `data_1m` as rebound by every feature helper below (each helper returns
        the updated frame as its first value).
    """
    # Indicator look-back windows, in 1-minute bars.
    ma_windows = [5, 20, 60]
    bb_windows = [5, 20, 60]
    macd_windows = [(60, 600, 240), (30, 300, 120), (6, 13, 4)]  # (short, long, signal)
    rsi_windows = [7, 20, 60]
    stoch_windows = [(240, 60), (120, 30), (9, 3)]  # (k_period, d_period)
    adx_windows = [60, 20, 7]
    atr_windows = [60, 20, 7]
    williams_windows = [60, 20, 7]
    sr_windows = [120, 60, 20]

    all_ma_cols_rel, all_bb_cols_rel = [], []
    all_macd_cols, all_rsi_cols, all_stoch_cols = [], [], []
    all_adx_cols, all_atr_cols, all_will_cols, all_sr_cols = [], [], [], []

    for ws in ma_windows:
        data_1m, _, ma_cols_rel = calculate_MA_data(data_1m, ws, 'MA')
        all_ma_cols_rel.extend(ma_cols_rel)
    # MA-180 is computed but not used as a model feature. It is presumably read by
    # the entry strategy below, which takes 180 as an argument -- verify.
    data_1m, _, _ = calculate_MA_data(data_1m, 180, 'MA')

    for ws in bb_windows:
        data_1m, _, bb_cols_rel = calculate_ema_bollinger_bands(data_1m, ws)
        all_bb_cols_rel.extend(bb_cols_rel)

    for short_period, long_period, signal_period in macd_windows:
        data_1m, macd_cols = calculate_macd(data_1m, short_period, long_period, signal_period)
        all_macd_cols.extend(macd_cols)

    for ws in rsi_windows:
        data_1m, rsi_cols = calculate_rsi(data_1m, window=ws)
        all_rsi_cols.extend(rsi_cols)

    for k_period, d_period in stoch_windows:
        data_1m, stoch_cols = calculate_stochastic_oscillator(data_1m, k_period, d_period)
        all_stoch_cols.extend(stoch_cols)

    for ws in adx_windows:
        data_1m, adx_cols = calculate_adx(data_1m, ws)
        all_adx_cols.extend(adx_cols)

    for ws in atr_windows:
        data_1m, atr_cols = calculate_atr(data_1m, ws)
        all_atr_cols.extend(atr_cols)

    for ws in williams_windows:
        data_1m, will_cols = calculate_williams_r(data_1m, ws)
        all_will_cols.extend(will_cols)

    for ws in sr_windows:
        data_1m, sr_col = calculate_support_resistance_numba(data_1m, window=ws)
        # The S/R helper may return either one column name or a list of them.
        if isinstance(sr_col, list):
            all_sr_cols.extend(sr_col)
        else:
            all_sr_cols.append(sr_col)

    volume_columns = ['Quote asset volume', 'Number of trades', 'Taker buy base asset volume',
                      'Taker buy quote asset volume']

    # Window-independent features: base features, volume, volatility, cyclic time encoding.
    data_1m, base_feature = base_feature_fn(data_1m, alpha=100)
    data_1m, volume_feature = calculate_volume(data_1m, window_size=240, volume_column_list=volume_columns)
    data_1m, volatility_cols = calculate_volatility_features(data_1m, window=240, alpha=100)
    data_1m, cyclic_encoding = cyclic_encode_fn(data_1m, 'Open time')

    # Presumably this order fixes the column layout of the dataset tensor, and so the
    # Prescriptor's input layout. Do not reorder without retraining.
    # NOTE(review): ATR columns are computed above but not included here -- confirm intentional.
    feature_column = (
        volume_columns
        + cyclic_encoding
        + all_ma_cols_rel
        + all_bb_cols_rel
        + all_rsi_cols
        + all_macd_cols
        + all_stoch_cols
        + all_adx_cols
        + all_will_cols
        + all_sr_cols
        + volatility_cols
        + volume_feature
        + base_feature
    )

    # Entry signals. Alternatives imported at the top: BB_fitness_fn, BB_MACD_fitness_fn,
    # simple_fitness_fn.
    bb_macd_entry_pos_list, patience_list, bb_macd_entry_index_list = BB_MACD_EMA_RSI_fitness_fn(
        data_1m, 60, 20, 60, 180)

    data_tensor = make_dataset(
        data_1m,
        using_column=feature_column,
        window_size=1,
        entry_pos_list=bb_macd_entry_pos_list,
        patience_list=patience_list,
    )
    entry_pos_array = np.array(bb_macd_entry_pos_list)
    entry_pos_list = entry_pos_array[entry_pos_array != 'hold']

    # Items whose first element is not an ndarray could not be built; count them as skipped.
    samples = []
    skip_data_cnt = 0
    for data in data_tensor:
        if isinstance(data[0], np.ndarray):
            samples.append(torch.from_numpy(data[0]).unsqueeze(dim=0))
        else:
            skip_data_cnt += 1
    dataset_1m = replace_nan_with_zero(torch.cat(samples, dim=0))

    data_to_save = {
        # Assumes window_size=1 yields a singleton dim 1, squeezed to (num_samples, num_features).
        'dataset_1m': dataset_1m.squeeze(dim=1),
        'skip_data_cnt': skip_data_cnt,
        'entry_pos_list': entry_pos_list,
        'bb_macd_entry_pos_list': bb_macd_entry_pos_list,
        'bb_macd_entry_index_list': bb_macd_entry_index_list,
    }
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(data_to_save, f)
    return data_1m


data_1m = load_data_1m(RAW_DATA_PATH)
if REBUILD_FEATURE_CACHE:
    data_1m = build_feature_cache(data_1m, FEATURE_CACHE_PATH)

In [3]:
# Restore the pre-computed feature cache written by the feature pipeline in the previous cell.
# NOTE: pickle.load can execute arbitrary code -- only load cache files this project produced.
with open('/root/daily/bit_5/backup_feature_data/data.pkl', 'rb') as cache_file:
    loaded_data = pickle.load(cache_file)

# Unpack every cached artefact into its own global; a missing key raises KeyError.
(
    dataset_1m,
    skip_data_cnt,
    entry_pos_list,
    bb_macd_entry_pos_list,
    bb_macd_entry_index_list,
) = (
    loaded_data['dataset_1m'],
    loaded_data['skip_data_cnt'],
    loaded_data['entry_pos_list'],
    loaded_data['bb_macd_entry_pos_list'],
    loaded_data['bb_macd_entry_index_list'],
)

In [4]:
# Display (num_samples, num_features) of the cached dataset.
dataset_1m.size()

torch.Size([90609, 54])

In [5]:
# Chronological split points over the cached samples: the validation segment starts
# at 60% and the test segment at 80%. The offsets add skip_data_cnt, presumably
# because the downstream entry indices also count the samples dropped while
# building the dataset -- confirm against generation_valid.
VALID_SPLIT_RATIO = 0.6
TEST_SPLIT_RATIO = 0.8
assert 0 < VALID_SPLIT_RATIO < TEST_SPLIT_RATIO < 1, 'split ratios must satisfy 0 < valid < test < 1'

valid_skip_data_cnt = int(len(dataset_1m) * VALID_SPLIT_RATIO) + skip_data_cnt
test_skip_data_cnt = int(len(dataset_1m) * TEST_SPLIT_RATIO) + skip_data_cnt


In [6]:
# Evolution setup.
# The Prescriptor's parameters are searched by the evolutionary operators below
# rather than by backprop, so autograd is disabled globally.
torch.set_grad_enabled(False)
# cuDNN autotuning speeds up fixed-shape inference but can make runs nondeterministic.
torch.backends.cudnn.benchmark = True

# Seed the global torch / numpy RNGs so selection and mutation draws are repeatable,
# assuming the Evolution operators draw from these RNGs -- TODO confirm.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

chromosomes_size = 30000  # presumably the population size
window_size = 240
EPOCH = 5                 # NOTE(review): not referenced in the visible cells
gen_loop = 50             # number of generations (logs show generation 0..49)
best_size = 30000
elite_size = 6000
parents_size = 12000
profit_init = 1
device = 'cuda:1'
group = 30000             # Prescriptor blocks evaluated in parallel; presumably must equal chromosomes_size
start_gen = 0
best_profit = None        # no previous best: start the search from scratch
best_chromosomes = None

# Derive the input width from the cached dataset instead of hard-coding it
# (54 for the current cache), so a rebuilt cache with a different feature count
# does not silently mismatch the model.
input_dim = dataset_1m.shape[-1]
prescriptor = Prescriptor(input_dim=input_dim,
                          fc_hidden_size=16,
                          output_dim=8,
                          after_input_dim=11,
                          after_hidden_dim=16,
                          after_output_dim=5,
                          num_blocks=group).to(device).eval()

total_param = sum(p.numel() for p in prescriptor.parameters())
print(f"Total parameters: {total_param}")

# Reuse the config values so the selection operator and generation_valid agree on elite size.
selection = RouletteSelection(elite_num=elite_size, parents_num=parents_size, minimize=False)
mutation = ChainMutation([RandomValueMutation(mut_prob=0.05), AddUniformMutation(mut_prob=0.1)])
crossover = DifferentialEvolutionOperator()
evolution = Evolution(
    prescriptor=prescriptor,
    selection=selection,
    crossover=crossover,
    mutation=mutation,
    group_size=group
)

best_chromosomes, best_profit = generation_valid(
    data_1m=data_1m,
    dataset_1m=dataset_1m,
    prescriptor=prescriptor,
    evolution=evolution,
    skip_data_cnt=skip_data_cnt,
    valid_skip_data_cnt=valid_skip_data_cnt,
    test_skip_data_cnt=test_skip_data_cnt,
    chromosomes_size=chromosomes_size,
    window_size=window_size,
    gen_loop=gen_loop,
    best_size=best_size,
    elite_size=elite_size,
    profit_init=profit_init,
    entry_index_list=bb_macd_entry_index_list,
    entry_pos_list=entry_pos_list,
    best_profit=best_profit,
    best_chromosomes=best_chromosomes,
    start_gen=start_gen,
    device=device
)

Total parameters: 197370000
generation  0: 


Inference Progress: 100%|██████████| 708/708 [02:42<00:00,  4.36it/s]
 60%|██████    | 54472/90716 [12:25<08:16, 73.07it/s]


generation  1: 


Inference Progress: 100%|██████████| 708/708 [02:18<00:00,  5.10it/s]
 60%|██████    | 54472/90716 [12:24<08:15, 73.13it/s] 


generation  2: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.09it/s]
 60%|██████    | 54472/90716 [12:09<08:05, 74.65it/s]


generation  3: 


Inference Progress: 100%|██████████| 708/708 [02:18<00:00,  5.10it/s]
 60%|██████    | 54472/90716 [12:21<08:13, 73.49it/s]


generation  4: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [12:15<08:09, 74.04it/s]  


generation  5: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [12:02<08:00, 75.42it/s]  
 80%|████████  | 72594/90716 [03:01<00:45, 400.20it/s]   


generation  6: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:50<07:52, 76.66it/s]  
 80%|████████  | 72594/90716 [02:51<00:42, 423.76it/s]   


generation  7: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.08it/s]
 60%|██████    | 54472/90716 [11:36<07:43, 78.23it/s] 
 80%|████████  | 72594/90716 [02:45<00:41, 438.70it/s]   


generation  8: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.08it/s]
 60%|██████    | 54472/90716 [11:22<07:34, 79.76it/s] 
 80%|████████  | 72594/90716 [02:43<00:40, 443.77it/s]   


generation  9: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:24<07:35, 79.62it/s]  
 80%|████████  | 72594/90716 [02:44<00:40, 442.15it/s]   


generation  10: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [11:15<07:29, 80.65it/s]
 80%|████████  | 72594/90716 [02:42<00:40, 445.88it/s]   


generation  11: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:11<07:26, 81.18it/s]  
 80%|████████  | 72594/90716 [02:40<00:40, 451.46it/s]   


generation  12: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.08it/s]
 60%|██████    | 54472/90716 [10:56<07:16, 82.94it/s] 
 80%|████████  | 72594/90716 [02:40<00:40, 452.88it/s]   


generation  13: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:11<07:26, 81.11it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 448.93it/s]   


generation  14: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:12<07:27, 81.01it/s] 
 80%|████████  | 72594/90716 [02:40<00:39, 453.23it/s]   


generation  15: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:01<07:20, 82.33it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 454.38it/s]   


generation  16: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:01<07:19, 82.40it/s] 
 80%|████████  | 72594/90716 [02:40<00:40, 452.46it/s]   


generation  17: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:12<07:27, 81.05it/s]  
 80%|████████  | 72594/90716 [02:42<00:40, 447.80it/s]   


generation  18: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:03<07:21, 82.11it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 449.36it/s]   


generation  19: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [10:54<07:15, 83.18it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 456.47it/s]   


generation  20: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:06<07:23, 81.74it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 449.77it/s]   


generation  21: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:18<07:31, 80.33it/s]  
 80%|████████  | 72594/90716 [02:42<00:40, 447.69it/s]   


generation  22: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:12<07:27, 81.04it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 449.68it/s]   


generation  23: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.08it/s]
 60%|██████    | 54472/90716 [11:00<07:19, 82.52it/s] 
 80%|████████  | 72594/90716 [02:40<00:39, 453.69it/s]   


generation  24: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:03<07:21, 82.10it/s] 
 80%|████████  | 72594/90716 [02:40<00:39, 453.47it/s]   


generation  25: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:03<07:21, 82.11it/s]
 80%|████████  | 72594/90716 [02:39<00:39, 454.39it/s]   


generation  26: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:08<07:24, 81.54it/s]  
 80%|████████  | 72594/90716 [02:40<00:39, 453.55it/s]   


generation  27: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [11:05<07:23, 81.81it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 450.86it/s]   


generation  28: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [11:00<07:19, 82.43it/s]
 80%|████████  | 72594/90716 [02:39<00:39, 456.34it/s]   


generation  29: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:02<07:20, 82.19it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 455.55it/s]   


generation  30: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [11:00<07:19, 82.46it/s]
 80%|████████  | 72594/90716 [02:39<00:39, 455.27it/s]   


generation  31: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [11:09<07:25, 81.33it/s]
 80%|████████  | 72594/90716 [02:39<00:39, 454.72it/s]   


generation  32: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.08it/s]
 60%|██████    | 54472/90716 [11:12<07:27, 81.03it/s]  
 80%|████████  | 72594/90716 [02:39<00:39, 454.67it/s]   


generation  33: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:02<07:20, 82.24it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 455.76it/s]   


generation  34: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.04it/s]
 60%|██████    | 54472/90716 [11:04<07:22, 81.96it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 453.82it/s]   


generation  35: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.04it/s]
 60%|██████    | 54472/90716 [11:02<07:20, 82.23it/s]
 80%|████████  | 72594/90716 [02:38<00:39, 456.68it/s]   


generation  36: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.06it/s]
 60%|██████    | 54472/90716 [10:57<07:17, 82.84it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 454.85it/s]   


generation  37: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [10:58<07:17, 82.75it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 450.23it/s]   


generation  38: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.03it/s]
 60%|██████    | 54472/90716 [11:06<07:23, 81.79it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 449.86it/s]   


generation  39: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.03it/s]
 60%|██████    | 54472/90716 [11:03<07:21, 82.14it/s] 
 80%|████████  | 72594/90716 [02:38<00:39, 457.43it/s]   


generation  40: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.04it/s]
 60%|██████    | 54472/90716 [11:09<07:25, 81.34it/s]  
 80%|████████  | 72594/90716 [02:40<00:39, 453.68it/s]   


generation  41: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:04<07:22, 81.98it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 455.72it/s]   


generation  42: 


Inference Progress: 100%|██████████| 708/708 [02:19<00:00,  5.07it/s]
 60%|██████    | 54472/90716 [11:05<07:22, 81.88it/s] 
 80%|████████  | 72594/90716 [02:41<00:40, 449.30it/s]   


generation  43: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.04it/s]
 60%|██████    | 54472/90716 [10:59<07:19, 82.56it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 456.05it/s]   


generation  44: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.04it/s]
 60%|██████    | 54472/90716 [11:05<07:22, 81.85it/s]
 80%|████████  | 72594/90716 [02:40<00:40, 452.58it/s]   


generation  45: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [11:03<07:21, 82.14it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 453.77it/s]   


generation  46: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.03it/s]
 60%|██████    | 54472/90716 [10:59<07:19, 82.55it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 456.38it/s]   


generation  47: 


Inference Progress: 100%|██████████| 708/708 [02:20<00:00,  5.05it/s]
 60%|██████    | 54472/90716 [10:55<07:16, 83.09it/s] 
 80%|████████  | 72594/90716 [02:39<00:39, 453.95it/s]   


generation  48: 


Inference Progress: 100%|██████████| 708/708 [02:21<00:00,  5.02it/s]
 60%|██████    | 54472/90716 [11:06<07:23, 81.78it/s] 
 80%|████████  | 72594/90716 [02:42<00:40, 447.81it/s]   


generation  49: 


Inference Progress: 100%|██████████| 708/708 [02:21<00:00,  5.01it/s]
 60%|██████    | 54472/90716 [11:09<07:25, 81.30it/s]
 80%|████████  | 72594/90716 [02:53<00:43, 417.77it/s]   
