In [1]:
# train.py

import os
import torch
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings(action='ignore')

from data_loader import load_data_1m
from feature_calculations import (
    resample_data, calculate_MA_data, calculate_ema_bollinger_bands, calculate_rsi,
    calculate_macd, calculate_stochastic_oscillator, calculate_adx, calculate_atr,
    calculate_obv, calculate_williams_r, base_feature_fn, cyclic_encode_fn, log_transform
)
from strategies import BB_fitness_fn, BB_MACD_fitness_fn
from dataset import make_dataset, replace_nan_with_zero
from train_functions import inference, fitness_fn, generation_valid, generation_test

from Prescriptor import Prescriptor, ChromosomeSelectorModel
from Evolution.crossover import UniformCrossover, WeightedSumCrossover, DifferentialEvolutionOperator
from Evolution.mutation import MultiplyNormalMutation, MultiplyUniformMutation, AddNormalMutation, AddUniformMutation, ChainMutation, FlipSignMutation
from Evolution.mutation import RandomValueMutation
from Evolution.selection import RouletteSelection, TournamentSelection, ParetoLexsortSelection
from Evolution import Evolution


import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn


In [3]:
sample_step=100
size=8
epochs=100
max_selected = 5  # Number of chromosomes to select
accumulation_steps = 8  # Number of steps to accumulate gradients
# random_index = make_sample(len(dataset_1m), sample_step=sample_step, batch_size=size)
device = 'cuda:1'
window_size=240

In [4]:
# Load Data
data_1m = load_data_1m('/root/daily/bit/data/1min_bitusdt.pkl')
data_1m = data_1m.iloc[:100000]

# Resample data to 1D
data_1d = resample_data(data_1m, '1D')
data_1d['Close time'] = data_1d.index
data_1d = data_1d.reset_index(drop=True)

# Apply Feature Calculations
# For 1D Data
data_1d, ma_cols_1d, ma_cols_rel_1d = calculate_MA_data(data_1d, 60, 'EMA', '_1d')
data_1d, bb_cols_1d, bb_cols_rel_1d = calculate_ema_bollinger_bands(data_1d, 60, extra_str='_1d')
data_1d, rsi_cols_1d = calculate_rsi(data_1d, window=20, extra_str='_1d')
data_1d, macd_cols_1d = calculate_macd(data_1d, 20, 120, 60, extra_str='_1d')
data_1d, stoch_cols_1d = calculate_stochastic_oscillator(data_1d, 60, 20, extra_str='_1d')
data_1d, adx_cols_1d = calculate_adx(data_1d, 60, extra_str='_1d')
data_1d, atr_cols_1d = calculate_atr(data_1d, 60, extra_str='_1d')
data_1d, obv_cols_1d = calculate_obv(data_1d, extra_str='_1d')
data_1d, will_cols_1d = calculate_williams_r(data_1d, 60, extra_str='_1d')
data_1d, base_feature_1d = base_feature_fn(data_1d, extra_str='_1d')
data_1d, cyclice_encoding_1d = cyclic_encode_fn(data_1d, 'Close time', 'day_of_year')

# For 1M Data
data_1m, ma_cols, ma_cols_rel = calculate_MA_data(data_1m, 240, 'EMA')
data_1m, bb_cols, bb_cols_rel = calculate_ema_bollinger_bands(data_1m, 240)
data_1m, rsi_cols = calculate_rsi(data_1m, window=60)
data_1m, macd_cols = calculate_macd(data_1m, 60, 600, 240)
data_1m, stoch_cols = calculate_stochastic_oscillator(data_1m, 240, 60)
data_1m, adx_cols = calculate_adx(data_1m, 240)
data_1m, atr_cols = calculate_atr(data_1m, 240)
data_1m, obv_cols = calculate_obv(data_1m)
data_1m, will_cols = calculate_williams_r(data_1m, 240)
data_1m, base_feature = base_feature_fn(data_1m)
data_1m, cyclice_encoding = cyclic_encode_fn(data_1m, 'Open time')

data_1m, short_ma_cols, short_ma_cols_rel = calculate_MA_data(data_1m, 60, 'EMA')
data_1m, long_ma_cols, long_ma_cols_rel = calculate_MA_data(data_1m, 180, 'EMA')

# Prepare Feature Columns
drop_column = [
    'Open time', 'Close time', 'Quote asset volume', 'Ignore',
    'Number of trades', 'Taker buy base asset volume', 'Taker buy quote asset volume'
]
feature_column = (
    ma_cols_rel + bb_cols_rel + rsi_cols + macd_cols + stoch_cols +
    adx_cols + will_cols + base_feature + cyclice_encoding  # Excluding obv and atr
)
feature_column_1d = (
    ma_cols_rel_1d + bb_cols_rel_1d + rsi_cols_1d + macd_cols_1d + stoch_cols_1d +
    adx_cols_1d + will_cols_1d + base_feature_1d + cyclice_encoding_1d
)


# Apply Log Transform
for feature in feature_column:
    data_1m[feature] = log_transform(data_1m[feature])

for feature in feature_column_1d:
    data_1d[feature] = log_transform(data_1d[feature])

# bb_entry_pos_list, patience_list, bb_entry_index_list = BB_fitness_fn(data_1m)
bb_macd_entry_pos_list, patience_list, bb_macd_entry_index_list = BB_MACD_fitness_fn(data_1m, 240, 60, 180)

# Prepare Dataset
data_tensor = make_dataset(
    data_1m, data_1d,
    using_column=feature_column, using_column_1d=feature_column_1d,
    window_size=240, window_size_1d=60,
    entry_pos_list=bb_macd_entry_pos_list, patience_list=patience_list,
    use_1d_data=True
)
entry_pos_list = np.array(bb_macd_entry_pos_list)[np.array(bb_macd_entry_pos_list) != 'hold']

dataset_1m = []
dataset_1d = []
skip_data_cnt = 0
for data in data_tensor:
    if len(data[0]) == 240 and len(data[1]) == 60:
        dataset_1m.append(torch.from_numpy(data[0]).unsqueeze(dim=0))
        dataset_1d.append(torch.from_numpy(data[1]).unsqueeze(dim=0))
    else:
        skip_data_cnt += 1
dataset_1m = torch.cat(dataset_1m, dim=0)
dataset_1d = torch.cat(dataset_1d, dim=0)
dataset_1m = replace_nan_with_zero(dataset_1m)
dataset_1d = replace_nan_with_zero(dataset_1d)



100%|██████████| 100000/100000 [00:11<00:00, 8811.00it/s]


In [5]:
state_dict_path = '/root/daily/bit/generation_data_DE/generation_19_5.pt'
best_index = np.load('/root/daily/bit/generation_data_DE/performance_folder/best_of_best_index.npy')
if os.path.exists(state_dict_path):
    state_dict = torch.load(state_dict_path)
    start_gen = state_dict['generation'] + 1
    best_profit = state_dict['best_profit']
    best_chromosomes = state_dict['best_chromosomes']
    best_chromosomes = best_chromosomes[best_index]
    # prescriptor.load_state_dict(state_dict['prescriptor_state_dict'],strict=True)

group = len(best_index)
prescriptor = Prescriptor(
    basic_block=None, 
    base_small_input_dim=19, 
    base_large_input_dim=19,
    base_hidden_dim=24, 
    base_output_dim=16, 
    after_input_dim=19, 
    after_hidden_dim=32, 
    after_output_dim=6, 
    num_blocks=group,
).to(device)

inference_model = Prescriptor(
    basic_block=None, 
    base_small_input_dim=19, 
    base_large_input_dim=19,
    base_hidden_dim=24, 
    base_output_dim=16, 
    after_input_dim=19, 
    after_hidden_dim=32, 
    after_output_dim=6, 
    num_blocks=size*max_selected,
).to(device)

total_param = sum(p.numel() for p in prescriptor.parameters())
print(f"Total parameters: {total_param}")

selection = RouletteSelection(elite_num=1000, parents_num=1000, minimize=False)
crossover = DifferentialEvolutionOperator()
mutation = RandomValueMutation(mut_prob=0.02)
evolution = Evolution(
    prescriptor=prescriptor,
    selection=selection,
    crossover=crossover,
    mutation=mutation
)
inference_evolution = Evolution(
    prescriptor=inference_model,
    selection=selection,
    crossover=crossover,
    mutation=mutation
)

init_chromosomes, base_ch_shape, after_ch_shape, device = evolution.flatten_chromosomes()
device = 'cuda:1'
evolution.update_chromosomes(best_chromosomes[:group], base_ch_shape, after_ch_shape, device)

logits = inference(dataset_1m, dataset_1d, prescriptor, device)
probs = []
for logit in logits:
    logit = torch.stack(logit, dim=0)
    probs.append(logit)
probs = torch.concat(probs, dim=1)
probs = probs.squeeze(dim=2)

Total parameters: 498300


In [6]:
rl_model = prescriptor = ChromosomeSelectorModel(
    small_input_dim=19, 
    large_input_dim=19,
    hidden_dim=512,
    num_chromosomes=group,
    num_layers=6
)

# rl_model = nn.DataParallel(rl_model, device_ids = [0,1])   # 2개의 GPU를 이용할 경우
rl_model = rl_model.to(device)
optimizer = optim.Adam(rl_model.parameters(), lr=1e-4)
total_param = sum(p.numel() for p in rl_model.parameters())
print(f"Total parameters: {total_param}")

Total parameters: 23221273


In [7]:
def make_sample(len_sample, sample_step=500, batch_size=1, ):
    start_index = torch.randint(low=0, high=len_sample-sample_step, size=(batch_size, )).unsqueeze(dim=1)
    end_index = start_index + sample_step
    random_index = torch.concat([start_index, end_index], dim=1)
    return random_index

In [8]:
def loss_cut_fn(pos_list, price_list, leverage_ratio, enter_ratio, profit, curr_low, curr_high, additional_count, alpha=1., cut_percent=80.):
    
    # Positions: 'short' -> 1, 'long' -> 2, 'hold' -> 0
    short_index = torch.where(pos_list == 1)[0]
    long_index = torch.where(pos_list == 2)[0]

    # Calculate profit or loss
    short_profit = -((curr_high[short_index] - price_list[short_index]) / price_list[short_index] * 100.) * leverage_ratio[short_index]
    long_profit = ((curr_low[long_index] - price_list[long_index]) / price_list[long_index] * 100.) * leverage_ratio[long_index]
    
    # Determine positions to cut
    short_cut_index = torch.where(short_profit <= -cut_percent)[0]
    long_cut_index = torch.where(long_profit <= -cut_percent)[0]

    # Update state for short positions to be cut
    short_index = short_index[short_cut_index]
    profit[short_index] = profit[short_index] - (enter_ratio[short_index] * cut_percent * alpha) - 0.1 * leverage_ratio[short_index] * enter_ratio[short_index]
    pos_list[short_index] = 0
    price_list[short_index] = -1.
    leverage_ratio[short_index] = -1.
    enter_ratio[short_index] = -1.
    additional_count[short_index] = 0

    # Update state for long positions to be cut
    long_index = long_index[long_cut_index]
    profit[long_index] = profit[long_index] - (enter_ratio[long_index] * cut_percent * alpha) - 0.1 * leverage_ratio[long_index] * enter_ratio[long_index]
    pos_list[long_index] = 0
    price_list[long_index] = -1.
    leverage_ratio[long_index] = -1.
    enter_ratio[long_index] = -1.
    additional_count[long_index] = 0

    return pos_list, price_list, leverage_ratio, enter_ratio, additional_count, profit


def calculate_now_profit(pos_list, price_list, leverage_ratio, enter_ratio, curr_price):
    now_profit = torch.zeros_like(pos_list, dtype=torch.float)
    short_index = torch.where(pos_list == 1)[0]  # 'short' -> 1
    long_index = torch.where(pos_list == 2)[0]  # 'long' -> 2

    short_profit = (-((curr_price[short_index] - price_list[short_index]) / price_list[short_index] * 100.) * leverage_ratio[short_index]) - 0.1 * leverage_ratio[short_index] * enter_ratio[short_index]
    long_profit = (((curr_price[long_index] - price_list[long_index]) / price_list[long_index] * 100.) * leverage_ratio[long_index]) - 0.1 * leverage_ratio[long_index] * enter_ratio[long_index]
    short_profit = short_profit * enter_ratio[short_index]
    long_profit = long_profit * enter_ratio[long_index]
    now_profit[short_index] = short_profit
    now_profit[long_index] = long_profit

    return now_profit

def after_forward(model, prob, now_profit, leverage_ratio, enter_ratio, pos_list, max_selected, device):
    ch_size = len(now_profit)
    now_profit_tensor = now_profit.unsqueeze(dim=1)
    leverage_ratio_tensor = leverage_ratio.unsqueeze(dim=1)
    enter_ratio_tensor = enter_ratio.unsqueeze(dim=1)
    mapping = {0: 0, 1: 1, 2: 2}  # Adjusted mapping
    mapped_array = pos_list
    mapped_array = mapped_array.repeat([1, max_selected]).flatten()
    step = torch.arange(0, ch_size * 3 * max_selected, step=3, device='cpu')
    now_profit_tensor = now_profit_tensor.repeat([1, max_selected]).flatten().unsqueeze(dim=1)
    leverage_ratio_tensor = leverage_ratio_tensor.repeat([1, max_selected]).flatten().unsqueeze(dim=1)
    enter_ratio_tensor = enter_ratio_tensor.repeat([1, max_selected]).flatten().unsqueeze(dim=1)

    x = torch.cat([prob, now_profit_tensor, leverage_ratio_tensor, enter_ratio_tensor], dim=1)
    cate_x = mapped_array + step

    x = x.to(device).float()
    cate_x = cate_x.to(device).long()

    inference_model.eval()
    with torch.no_grad():
        after_output = inference_model.after_forward(x=x.squeeze(dim=0), x_cate=cate_x)
    after_output = after_output.squeeze(dim=0)
    after_output = after_output.reshape(max_selected, -1, 6)
    
    probs = torch.zeros(size=after_output.shape[1:])
    
    prob_input = after_output[:, :, :4]
    leverate_input = after_output[:, :, 4]
    enter_input = after_output[:, :, 5]

    prob_input = torch.mean(torch.softmax(prob_input, dim=-1), dim=0)
    leverate_input = torch.sigmoid(leverate_input)
    leverate_input = torch.mean(leverate_input, dim=0)
    enter_input = torch.sigmoid(enter_input)
    enter_input = torch.mean(enter_input, dim=0)
        
    probs[:, :4] = prob_input
    probs[:, 4] = leverate_input
    probs[:, 5] = enter_input
    
    return probs

def calculate_same(same_prob, pos_list, price_list, leverage_ratio, enter_ratio, profit, entry_pos, curr_close, additional_count, limit=2, cut_value=1.):
    index = torch.tensor([0, 1, 3])
    logit = torch.argmax(same_prob[:, index], dim=1)
    hold_index = torch.where(logit == 0)[0]
    enter_index = torch.where((logit == 1) & (additional_count < limit))[0]
    loss_index = torch.where(logit == 2)[0]

    # loss
    pos_list[loss_index] = 0  # 'hold' -> 0
    loss_profit = (price_list[loss_index] - curr_close[loss_index]) / price_list[loss_index] * 100
    loss_profit = loss_profit * leverage_ratio[loss_index] * enter_ratio[loss_index]

    # enter
    before_price_list = price_list[enter_index]
    before_enter_list = enter_ratio[enter_index]
    cut_enter = cut_value - before_enter_list
    
    enter_enter_ratio = same_prob[enter_index][:, 5]
    enter_enter_ratio = torch.minimum(cut_enter, enter_enter_ratio)
    after_price_list = before_price_list * (before_enter_list / (before_enter_list + enter_enter_ratio)) \
                       + curr_close[enter_index] * (enter_enter_ratio / (before_enter_list + enter_enter_ratio))
    after_enter_ratio = before_enter_list + enter_enter_ratio

    loss_entry_pos = entry_pos[loss_index]
    long_loss_entry_pos = np.where(loss_entry_pos==2)[0]
    short_loss_entry_pos = np.where(loss_entry_pos==1)[0]

    long_loss = np.intersect1d(loss_index, np.where(entry_pos==2)[0])
    short_loss = np.intersect1d(loss_index, np.where(entry_pos==1)[0])

    profit[long_loss] = profit[long_loss] - loss_profit[long_loss_entry_pos] - 0.1 * leverage_ratio[long_loss] * enter_ratio[long_loss]
    profit[short_loss] = profit[short_loss] + loss_profit[short_loss_entry_pos] - 0.1 * leverage_ratio[short_loss] * enter_ratio[short_loss]

    price_list[loss_index] = -1.
    leverage_ratio[loss_index] = -1.
    enter_ratio[loss_index] = -1.

    # Increment additional_count for allowed entries
    additional_count[enter_index] += 1
    price_list[enter_index] = after_price_list
    enter_ratio[enter_index] = after_enter_ratio

    return pos_list, price_list, leverage_ratio, enter_ratio, additional_count, profit

def calculate_diff(diff_prob, pos_list, price_list, leverage_ratio, enter_ratio, profit, entry_pos, curr_close, additional_count):
    index = torch.tensor([0, 1, 2])
    logit = torch.argmax(diff_prob[:, index], dim=1)
    hold_index = torch.where(logit == 0)[0]
    switch_index = torch.where(logit == 1)[0]
    take_index = torch.where(logit == 2)[0]

    # switch
    switch_profit = (price_list[switch_index] - curr_close[switch_index]) / price_list[switch_index] * 100
    switch_profit = switch_profit * leverage_ratio[switch_index] * enter_ratio[switch_index]
    switch_leverage = diff_prob[switch_index][:, 4] * 100.
    switch_enter_ratio = diff_prob[switch_index][:, 5]

    # take
    pos_list[take_index] = 0  # 'hold' -> 0
    take_profit = (price_list[take_index] - curr_close[take_index]) / price_list[take_index] * 100
    take_profit = take_profit * leverage_ratio[take_index] * enter_ratio[take_index]

    switch_entry_pos = entry_pos[switch_index]
    switch_long_entry_pos = np.where(switch_entry_pos == 2)[0]
    switch_short_entry_pos = np.where(switch_entry_pos == 1)[0]
    take_entry_pos = entry_pos[take_index]
    take_long_entry_pos = np.where(take_entry_pos == 2)[0]
    take_short_entry_pos = np.where(take_entry_pos == 1)[0]

    switch_long = np.intersect1d(switch_index, np.where(entry_pos==2)[0])
    switch_short = np.intersect1d(switch_index, np.where(entry_pos==1)[0])
    
    take_long = np.intersect1d(take_index, np.where(entry_pos==2)[0])
    take_short = np.intersect1d(take_index, np.where(entry_pos==1)[0])


    # take
    profit[take_long] = profit[take_long] + take_profit[take_long_entry_pos] - leverage_ratio[take_long] * 0.1 * enter_ratio[take_long]
    profit[take_short] = profit[take_short] - take_profit[take_short_entry_pos] - leverage_ratio[take_short] * 0.1 * enter_ratio[take_short]
    
    # switch
    profit[switch_long] = profit[switch_long] + switch_profit[switch_long_entry_pos] - leverage_ratio[switch_long] * 0.1 * enter_ratio[switch_long]
    pos_list[switch_long] = 2  # 'long' -> 2
    profit[switch_short] = profit[switch_short] - switch_profit[switch_short_entry_pos] - leverage_ratio[switch_short] * 0.1 * enter_ratio[switch_short]
    pos_list[switch_short] = 1  # 'short' -> 1
    

    price_list[switch_index] = curr_close[switch_index]
    leverage_ratio[switch_index] = switch_leverage
    enter_ratio[switch_index] = switch_enter_ratio

    price_list[take_index] = -1.
    leverage_ratio[take_index] = -1.
    enter_ratio[take_index] = -1.

    # Reset additional_count for switched and taken positions
    additional_count[switch_index] = 0
    additional_count[take_index] = 0

    return pos_list, price_list, leverage_ratio, enter_ratio, additional_count, profit


def calculate_hold(hold_prob, pos_list, price_list, leverage_ratio, enter_ratio, profit, entry_pos, curr_close, additional_count):
    index = torch.tensor([0, 1])
    logit = torch.argmax(hold_prob[:, index], dim=1)
    hold_index = torch.where(logit == 0)[0]
    enter_index = torch.where(logit == 1)[0]

    # enter
    enter_leverage = hold_prob[enter_index][:, 4] * 100.
    enter_enter_ratio = hold_prob[enter_index][:, 5]
    price_list[enter_index] = curr_close[enter_index]
    leverage_ratio[enter_index] = enter_leverage
    enter_ratio[enter_index] = enter_enter_ratio

    enter_long = np.intersect1d(enter_index, np.where(entry_pos==2)[0])
    enter_short = np.intersect1d(enter_index, np.where(entry_pos==1)[0])

    pos_list[enter_long] = 2  # 'long' -> 2
    pos_list[enter_short] = 1  # 'short' -> 1

    # Initialize additional_count for new positions
    additional_count[enter_index] = 0

    return pos_list, price_list, leverage_ratio, enter_ratio, additional_count, profit



def mapping_chromosome(inference_evolution, best_chromosomes, selected_indices, device):
    window_size=240
    init_chromosomes, base_ch_shape, after_ch_shape, device = inference_evolution.flatten_chromosomes()
    device = 'cuda:1'
    inference_evolution.update_chromosomes(best_chromosomes[selected_indices.flatten().cpu()], base_ch_shape, after_ch_shape, device)
    
def inference_model_infer(model, part_dataset_1m_batch, part_dataset_1d_batch, max_selected):
    with torch.no_grad():
        logits = []
        for part_dataset_1m, part_datset_1d in zip(part_dataset_1m_batch, part_dataset_1d_batch):
            logit = inference_model.base_forward(part_dataset_1m.float(), part_datset_1d.float())
            logits.append(logit)
            
        probs = []
        for index, logit in enumerate(logits):
            logit = torch.stack(logit, dim=0)
            # print(logit[index*max_selected:(index+1)*max_selected].shape)
            probs.append(logit[index*max_selected:(index+1)*max_selected])
        probs = torch.concat(probs, dim=0)
        probs = probs.squeeze(dim=2)
    return probs
    


In [9]:
def calculate_rewards(inference_evolution, prescriptor, best_chromosomes, selected_indices, part_dataset_1m_batch, part_dataset_1d_batch, max_selected,
                      size, entry_pos_list, random_index, entry_index_list, data_1m, device):
    alpha=1.
    cut_percent=90.
    limit = 4
    
    mapping_chromosome(inference_evolution, best_chromosomes, selected_indices, device)
    probs = inference_model_infer(inference_model, part_dataset_1m_batch, part_dataset_1d_batch, max_selected)
        
    pos_list = torch.zeros(size, dtype=torch.long, device='cpu')  # 0: 'hold'
    price_list = torch.full((size,), -1., device='cpu').float()
    leverage_ratio = torch.full((size,), -1., device='cpu').float()
    enter_ratio = torch.full((size,), -1., device='cpu').float()
    profit = torch.zeros(size, device='cpu').float()
    patience_cnt = torch.zeros(size, dtype=torch.long, device='cpu')
    additional_count = torch.zeros(size, dtype=torch.long, device='cpu')
    returns_list = []
    before_index = np.zeros(shape=(size, )).astype(int)

    # Map entry positions
    entry_pos_mapping = {'hold': 0, 'short': 1, 'long': 2}
    entry_pos_list_int = [entry_pos_mapping[ep] for ep in entry_pos_list]

    data_range = np.array([np.arange(random_index[i][0], random_index[i][1]) for i in range(len(random_index))])
    data_range = data_range.T + skip_data_cnt

    # entry_index_list = bb_macd_entry_index_list
        
    for data_cnt in range(len(data_range)):
        entry_index = np.array(entry_index_list)[data_range[data_cnt]]
        entry_pos = np.array(entry_pos_list_int)[data_range[data_cnt]]
        entry_pos = torch.tensor(entry_pos).long().to('cpu')
        
        x = data_1m.iloc[entry_index]
        curr_close = torch.tensor(x['Close'].values, device='cpu').float()

        history_x = [
            data_1m.iloc[before:entry + 1] 
            for before, entry in zip(before_index, entry_index)
        ]
        history_high = []
        history_low = []
        for df in history_x:
            history_high.append(df['High'].max())
            history_low.append(df['Low'].min())
        history_high = torch.tensor(history_high, device='cpu').float()
        history_low = torch.tensor(history_low, device='cpu').float()

        pos_list, price_list, leverage_ratio, enter_ratio, additional_count, profit = loss_cut_fn(
            pos_list, price_list, leverage_ratio,
            enter_ratio, profit, history_low, history_high,
            additional_count, alpha, cut_percent
        )

        prob = torch.tensor(probs[:, data_cnt]).to('cpu')
        hold_pos = torch.where(pos_list == 0)[0]
        same_pos = torch.where(pos_list == entry_pos)[0]
        diff_pos = torch.where((pos_list != entry_pos) & (pos_list != 0))[0]
        now_profit = calculate_now_profit(pos_list, price_list, leverage_ratio, enter_ratio, curr_close)
        prob = after_forward(prescriptor, prob, now_profit, leverage_ratio, enter_ratio, pos_list, max_selected, device=device)
        prob = prob.cpu()
        
        same_prob = prob[same_pos]
        diff_prob = prob[diff_pos]
        hold_prob = prob[hold_pos]
        
        pos_list[same_pos], price_list[same_pos], leverage_ratio[same_pos], enter_ratio[same_pos], additional_count[same_pos], profit[same_pos] = calculate_same(
            same_prob, pos_list[same_pos], price_list[same_pos], leverage_ratio[same_pos], enter_ratio[same_pos], profit[same_pos],
            entry_pos[same_pos], curr_close[same_pos], additional_count[same_pos], limit
        )
                
        pos_list[diff_pos], price_list[diff_pos], leverage_ratio[diff_pos], enter_ratio[diff_pos], additional_count[diff_pos], profit[diff_pos] = calculate_diff(
            diff_prob, pos_list[diff_pos], price_list[diff_pos], leverage_ratio[diff_pos], enter_ratio[diff_pos], profit[diff_pos],
            entry_pos[diff_pos], curr_close[diff_pos], additional_count[diff_pos]
        )
                
        pos_list[hold_pos], price_list[hold_pos], leverage_ratio[hold_pos], enter_ratio[hold_pos], additional_count[hold_pos], profit[hold_pos] = calculate_hold(
            hold_prob, pos_list[hold_pos], price_list[hold_pos], leverage_ratio[hold_pos], enter_ratio[hold_pos], profit[hold_pos],
            entry_pos[hold_pos], curr_close[hold_pos], additional_count[hold_pos]
        )
        
        before_index = entry_index
        returns_list.append(profit.clone().cpu().detach().numpy())
        profit = torch.zeros(size, device='cpu').float()
    returns_list = np.stack(returns_list).T

    rewards = []
    for i in range(len(returns_list)):
        returns  = returns_list[i]
        returns = returns[returns!= 0.]
        
        total_profit = 1.
        for profit_value in returns:
            total_profit = total_profit * ((100+profit_value) / 100)
        rewards.append(total_profit)
    rewards = torch.tensor(rewards).to(device)
    
    return rewards

In [10]:
from tqdm import tqdm
step_baseline = 0.
for i in tqdm(range((1000))):
    random_index = make_sample(len(dataset_1m), sample_step=sample_step, batch_size=size)
    part_dataset_1m_batch = torch.stack([dataset_1m[idx[0]:idx[1]] for idx in random_index]).to(device)
    part_dataset_1d_batch = torch.stack([dataset_1d[idx[0]:idx[1]] for idx in random_index]).to(device)
    # part_dataset_1m_batch = torch.stack([dataset_1m[idx[0]:idx[1]] for idx in random_index])
    # part_dataset_1d_batch = torch.stack([dataset_1d[idx[0]:idx[1]] for idx in random_index])
    
    scores = rl_model(part_dataset_1m_batch.float().reshape(-1, 240, 19), part_dataset_1d_batch.float().reshape(-1, 60, 19))

    scores = scores.view(size, sample_step, group)

    # Compute probabilities using softmax over the chromosome dimension
    probabilities = F.softmax(scores, dim=-1)  # Shape: (batch_size, sample_step, num_chromosomes)

    # Average probabilities over the sample_step dimension to get one distribution per batch
    avg_probabilities = probabilities.mean(dim=1)  # Shape: (batch_size, num_chromosomes)

    # Select top-k chromosomes based on average probabilities
    selected_indices = torch.multinomial(avg_probabilities, num_samples=max_selected, replacement=False)  # Shape: (batch_size, max_selected)
    # Select top-k chromosomes based on average probabilities
    # _, selected_indices = avg_probabilities.topk(k=max_selected, dim=-1)  # Shape: (batch_size, max_selected)
    
    # # For each sample in the batch, simulate the selected chromosomes
    # rewards = torch.zeros(size).to(device)
    # for i in range(size):
    #     selected_chromosomes = selected_indices[i]  # Indices of selected chromosomes
    #     # Implement 'simulate_chromosomes' function to simulate performance and compute reward
    #     reward = torch.randn(size=(size, sample_step))
    #     rewards[i] = reward.sum()
        
    rewards = calculate_rewards(inference_evolution, inference_model, best_chromosomes, selected_indices, part_dataset_1m_batch, part_dataset_1m_batch,
                                max_selected, size, entry_pos_list, random_index, bb_macd_entry_index_list, data_1m, device)
        
    baseline = rewards.mean()
    step_baseline += baseline
    # Compute loss using policy gradient (REINFORCE algorithm)
    log_probs = torch.log(avg_probabilities + 1e-8)  # Avoid log(0)
    selected_log_probs = log_probs.gather(1, selected_indices)  # Shape: (batch_size, max_selected)

    # Sum log probabilities over the selected chromosomes
    sum_log_probs = selected_log_probs.sum(dim=1)  # Shape: (batch_size,)

    # Compute the policy gradient loss
    loss = -((rewards - baseline) * sum_log_probs).mean()
            
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(step_baseline)
        step_baseline = 0.


  1%|          | 8/1000 [00:20<42:42,  2.58s/it]

tensor(8.2649, device='cuda:1', dtype=torch.float64)


  2%|▏         | 16/1000 [00:39<40:46,  2.49s/it]

tensor(7.8655, device='cuda:1', dtype=torch.float64)


  2%|▏         | 24/1000 [00:59<40:16,  2.48s/it]

tensor(8.0116, device='cuda:1', dtype=torch.float64)


  3%|▎         | 32/1000 [01:18<38:14,  2.37s/it]

tensor(8.0799, device='cuda:1', dtype=torch.float64)


  4%|▍         | 40/1000 [01:37<38:08,  2.38s/it]

tensor(8.1004, device='cuda:1', dtype=torch.float64)


  5%|▍         | 48/1000 [01:57<38:16,  2.41s/it]

tensor(7.8112, device='cuda:1', dtype=torch.float64)


  6%|▌         | 56/1000 [02:16<37:13,  2.37s/it]

tensor(8.2027, device='cuda:1', dtype=torch.float64)


  6%|▋         | 64/1000 [02:35<37:36,  2.41s/it]

tensor(8.1191, device='cuda:1', dtype=torch.float64)


  7%|▋         | 72/1000 [02:54<36:45,  2.38s/it]

tensor(8.0079, device='cuda:1', dtype=torch.float64)


  8%|▊         | 80/1000 [03:14<37:01,  2.41s/it]

tensor(7.8734, device='cuda:1', dtype=torch.float64)


  9%|▉         | 88/1000 [03:33<36:07,  2.38s/it]

tensor(7.8188, device='cuda:1', dtype=torch.float64)


 10%|▉         | 96/1000 [03:53<37:39,  2.50s/it]

tensor(7.7855, device='cuda:1', dtype=torch.float64)


 10%|█         | 104/1000 [04:12<35:28,  2.38s/it]

tensor(7.9911, device='cuda:1', dtype=torch.float64)


 11%|█         | 112/1000 [04:31<34:57,  2.36s/it]

tensor(7.8632, device='cuda:1', dtype=torch.float64)


 12%|█▏        | 120/1000 [04:51<36:29,  2.49s/it]

tensor(7.8073, device='cuda:1', dtype=torch.float64)


 13%|█▎        | 128/1000 [05:10<34:32,  2.38s/it]

tensor(7.8673, device='cuda:1', dtype=torch.float64)


 14%|█▎        | 136/1000 [05:29<34:16,  2.38s/it]

tensor(7.9540, device='cuda:1', dtype=torch.float64)


 14%|█▍        | 144/1000 [05:48<33:49,  2.37s/it]

tensor(7.9979, device='cuda:1', dtype=torch.float64)


 15%|█▌        | 152/1000 [06:08<34:00,  2.41s/it]

tensor(7.9459, device='cuda:1', dtype=torch.float64)


 16%|█▌        | 160/1000 [06:27<33:12,  2.37s/it]

tensor(7.8187, device='cuda:1', dtype=torch.float64)


 17%|█▋        | 168/1000 [06:46<33:01,  2.38s/it]

tensor(7.8345, device='cuda:1', dtype=torch.float64)


 18%|█▊        | 176/1000 [07:05<32:40,  2.38s/it]

tensor(7.9629, device='cuda:1', dtype=torch.float64)


 18%|█▊        | 184/1000 [07:25<32:56,  2.42s/it]

tensor(7.8431, device='cuda:1', dtype=torch.float64)


 19%|█▉        | 192/1000 [07:44<32:11,  2.39s/it]

tensor(7.9130, device='cuda:1', dtype=torch.float64)


 20%|██        | 200/1000 [08:03<31:46,  2.38s/it]

tensor(7.9152, device='cuda:1', dtype=torch.float64)


 21%|██        | 208/1000 [08:22<31:14,  2.37s/it]

tensor(8.1930, device='cuda:1', dtype=torch.float64)


 22%|██▏       | 216/1000 [08:41<31:01,  2.37s/it]

tensor(7.6092, device='cuda:1', dtype=torch.float64)


 22%|██▏       | 224/1000 [09:02<31:25,  2.43s/it]

tensor(7.9846, device='cuda:1', dtype=torch.float64)


 23%|██▎       | 232/1000 [09:21<30:20,  2.37s/it]

tensor(7.9281, device='cuda:1', dtype=torch.float64)


 24%|██▍       | 240/1000 [09:39<29:55,  2.36s/it]

tensor(7.8941, device='cuda:1', dtype=torch.float64)


 25%|██▍       | 248/1000 [09:58<29:43,  2.37s/it]

tensor(7.9434, device='cuda:1', dtype=torch.float64)


 26%|██▌       | 256/1000 [10:17<29:23,  2.37s/it]

tensor(7.9829, device='cuda:1', dtype=torch.float64)


 26%|██▋       | 264/1000 [10:36<29:04,  2.37s/it]

tensor(7.9463, device='cuda:1', dtype=torch.float64)


 27%|██▋       | 272/1000 [10:57<29:33,  2.44s/it]

tensor(7.7639, device='cuda:1', dtype=torch.float64)


 28%|██▊       | 280/1000 [11:16<28:29,  2.37s/it]

tensor(7.8710, device='cuda:1', dtype=torch.float64)


 29%|██▉       | 288/1000 [11:35<28:06,  2.37s/it]

tensor(8.0009, device='cuda:1', dtype=torch.float64)


 30%|██▉       | 296/1000 [11:54<27:58,  2.38s/it]

tensor(7.9486, device='cuda:1', dtype=torch.float64)


 30%|███       | 304/1000 [12:13<27:40,  2.39s/it]

tensor(8.0671, device='cuda:1', dtype=torch.float64)


 31%|███       | 312/1000 [12:32<27:24,  2.39s/it]

tensor(7.8892, device='cuda:1', dtype=torch.float64)


 32%|███▏      | 320/1000 [12:53<32:59,  2.91s/it]

tensor(8.0628, device='cuda:1', dtype=torch.float64)


 33%|███▎      | 328/1000 [13:12<27:04,  2.42s/it]

tensor(7.6260, device='cuda:1', dtype=torch.float64)


 34%|███▎      | 336/1000 [13:31<26:31,  2.40s/it]

tensor(7.9451, device='cuda:1', dtype=torch.float64)


 34%|███▍      | 344/1000 [13:50<26:11,  2.39s/it]

tensor(7.7334, device='cuda:1', dtype=torch.float64)


 35%|███▌      | 352/1000 [14:09<25:38,  2.37s/it]

tensor(8.0653, device='cuda:1', dtype=torch.float64)


 36%|███▌      | 360/1000 [14:28<25:10,  2.36s/it]

tensor(7.9624, device='cuda:1', dtype=torch.float64)


 37%|███▋      | 368/1000 [14:47<25:05,  2.38s/it]

tensor(7.8557, device='cuda:1', dtype=torch.float64)


 38%|███▊      | 376/1000 [15:06<24:29,  2.36s/it]

tensor(7.9336, device='cuda:1', dtype=torch.float64)


 38%|███▊      | 384/1000 [15:25<24:27,  2.38s/it]

tensor(7.9464, device='cuda:1', dtype=torch.float64)


 39%|███▉      | 392/1000 [15:46<24:46,  2.45s/it]

tensor(7.5741, device='cuda:1', dtype=torch.float64)


 40%|████      | 400/1000 [16:05<23:44,  2.37s/it]

tensor(7.8023, device='cuda:1', dtype=torch.float64)


 41%|████      | 408/1000 [16:24<23:26,  2.38s/it]

tensor(8.0095, device='cuda:1', dtype=torch.float64)


 42%|████▏     | 416/1000 [16:43<23:07,  2.38s/it]

tensor(7.8154, device='cuda:1', dtype=torch.float64)


 42%|████▏     | 424/1000 [17:02<22:40,  2.36s/it]

tensor(8.0488, device='cuda:1')


 43%|████▎     | 432/1000 [17:21<22:19,  2.36s/it]

tensor(7.8399, device='cuda:1', dtype=torch.float64)


 44%|████▍     | 440/1000 [17:40<22:12,  2.38s/it]

tensor(7.8621, device='cuda:1', dtype=torch.float64)


 45%|████▍     | 448/1000 [17:59<21:49,  2.37s/it]

tensor(7.9389, device='cuda:1', dtype=torch.float64)


 46%|████▌     | 456/1000 [18:18<21:27,  2.37s/it]

tensor(7.6987, device='cuda:1', dtype=torch.float64)


 46%|████▋     | 464/1000 [18:37<21:10,  2.37s/it]

tensor(7.9687, device='cuda:1')


 47%|████▋     | 472/1000 [18:59<22:03,  2.51s/it]

tensor(8.1454, device='cuda:1', dtype=torch.float64)


 48%|████▊     | 480/1000 [19:18<20:51,  2.41s/it]

tensor(8.2731, device='cuda:1', dtype=torch.float64)


 49%|████▉     | 488/1000 [19:37<20:19,  2.38s/it]

tensor(7.8927, device='cuda:1', dtype=torch.float64)


 50%|████▉     | 496/1000 [19:56<20:01,  2.38s/it]

tensor(7.9917, device='cuda:1', dtype=torch.float64)


 50%|█████     | 504/1000 [20:15<19:48,  2.40s/it]

tensor(7.9526, device='cuda:1', dtype=torch.float64)


 51%|█████     | 512/1000 [20:34<19:24,  2.39s/it]

tensor(8.2122, device='cuda:1', dtype=torch.float64)


 52%|█████▏    | 520/1000 [20:53<18:59,  2.37s/it]

tensor(7.9852, device='cuda:1', dtype=torch.float64)


 53%|█████▎    | 528/1000 [21:12<18:41,  2.38s/it]

tensor(7.9509, device='cuda:1', dtype=torch.float64)


 54%|█████▎    | 536/1000 [21:31<18:15,  2.36s/it]

tensor(7.9137, device='cuda:1', dtype=torch.float64)


 54%|█████▍    | 544/1000 [21:50<18:07,  2.38s/it]

tensor(8.0533, device='cuda:1', dtype=torch.float64)


 55%|█████▌    | 552/1000 [22:09<17:50,  2.39s/it]

tensor(7.6474, device='cuda:1', dtype=torch.float64)


 56%|█████▌    | 560/1000 [22:31<23:30,  3.21s/it]

tensor(7.7266, device='cuda:1', dtype=torch.float64)


 57%|█████▋    | 568/1000 [22:50<17:25,  2.42s/it]

tensor(7.9475, device='cuda:1', dtype=torch.float64)


 58%|█████▊    | 576/1000 [23:09<16:44,  2.37s/it]

tensor(8.1410, device='cuda:1', dtype=torch.float64)


 58%|█████▊    | 584/1000 [23:28<16:30,  2.38s/it]

tensor(7.9698, device='cuda:1')


 59%|█████▉    | 592/1000 [23:47<16:14,  2.39s/it]

tensor(7.7410, device='cuda:1', dtype=torch.float64)


 60%|██████    | 600/1000 [24:06<15:53,  2.38s/it]

tensor(7.9976, device='cuda:1', dtype=torch.float64)


 61%|██████    | 608/1000 [24:25<15:40,  2.40s/it]

tensor(8.0116, device='cuda:1', dtype=torch.float64)


 62%|██████▏   | 616/1000 [24:44<15:10,  2.37s/it]

tensor(7.9351, device='cuda:1', dtype=torch.float64)


 62%|██████▏   | 624/1000 [25:03<14:50,  2.37s/it]

tensor(7.7834, device='cuda:1', dtype=torch.float64)


 63%|██████▎   | 632/1000 [25:22<14:30,  2.37s/it]

tensor(7.7428, device='cuda:1', dtype=torch.float64)


 64%|██████▍   | 640/1000 [25:41<14:10,  2.36s/it]

tensor(7.8215, device='cuda:1', dtype=torch.float64)


 65%|██████▍   | 648/1000 [26:00<13:54,  2.37s/it]

tensor(8.0138, device='cuda:1', dtype=torch.float64)


 66%|██████▌   | 656/1000 [26:18<13:30,  2.36s/it]

tensor(7.9699, device='cuda:1', dtype=torch.float64)


 66%|██████▋   | 664/1000 [26:37<13:15,  2.37s/it]

tensor(7.7373, device='cuda:1', dtype=torch.float64)


 67%|██████▋   | 672/1000 [27:00<16:42,  3.06s/it]

tensor(7.8832, device='cuda:1', dtype=torch.float64)


 68%|██████▊   | 680/1000 [27:19<12:47,  2.40s/it]

tensor(8.1306, device='cuda:1', dtype=torch.float64)


 69%|██████▉   | 688/1000 [27:38<12:22,  2.38s/it]

tensor(7.9650, device='cuda:1', dtype=torch.float64)


 70%|██████▉   | 696/1000 [27:56<12:01,  2.37s/it]

tensor(7.9877, device='cuda:1', dtype=torch.float64)


 70%|███████   | 704/1000 [28:15<11:40,  2.37s/it]

tensor(7.9769, device='cuda:1', dtype=torch.float64)


 71%|███████   | 712/1000 [28:35<11:25,  2.38s/it]

tensor(8.0857, device='cuda:1', dtype=torch.float64)


 72%|███████▏  | 720/1000 [28:53<11:05,  2.38s/it]

tensor(7.9805, device='cuda:1', dtype=torch.float64)


 73%|███████▎  | 728/1000 [29:13<10:50,  2.39s/it]

tensor(7.8677, device='cuda:1', dtype=torch.float64)


 74%|███████▎  | 736/1000 [29:31<10:21,  2.35s/it]

tensor(8.0000, device='cuda:1', dtype=torch.float64)


 74%|███████▍  | 738/1000 [29:36<10:19,  2.37s/it]