In [13]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mplfinance as mpf
import torch.nn as nn
import torch
import math

from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings(action='ignore')

import sys
sys.path.append('/root/daily/bit')

from Prescriptor import Prescriptor
from Evolution.crossover import UniformCrossover, WeightedSumCrossover
from Evolution.mutation import MultiplyNormalMutation, MultiplyUniformMutation, AddNormalMutation, AddUniformMutation, ChainMutation, FlipSignMutation
from Evolution.selection import RouletteSelection, TournamentSelection
from Evolution import Evolution

In [14]:
# Configuration: fall back to CPU so the notebook still runs without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
group = 30   # passed to Prescriptor as num_blocks
n_cate = 3   # number of categories sampled per block (upper bound of cate_x)
SEED = 0
torch.manual_seed(SEED)

# Dummy inputs. They are drawn from the CPU generator and then moved to `device`,
# so a given SEED yields identical values on CPU and GPU runs.
small_lstm_x = torch.randn(512, 240, 19, dtype=torch.float32).to(device)
large_lstm_x = torch.randn(512, 60, 19, dtype=torch.float32).to(device)
# One random row of 6 base features, shared (repeated) by all `group` blocks.
base_x = torch.randn(1, 6, dtype=torch.float32).to(device).repeat(group, 1)
cate_x = torch.randint(0, n_cate, size=(group,)).to(device)
# Per-block offsets 0, n_cate, 2*n_cate, ...: `cate_x + step` gives every
# (block, category) pair its own index.
step = torch.arange(0, group * n_cate, step=n_cate).to(device)

In [15]:
def get_n_params(model, trainable_only=False):
    """Return the total number of scalar parameters in a model.

    Args:
        model: any object exposing ``.parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
        trainable_only: if True, count only parameters with
            ``requires_grad=True``. Defaults to False, which counts all of
            them as before.

    Returns:
        int: the sum of ``numel()`` over the selected parameters.
    """
    # numel() replaces the manual product over p.size(). The old loop also
    # rebound the name `nn`, shadowing the `torch.nn` import inside the function.
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# Build the prescriptor with `group` blocks on the configured device.
# NOTE(review): after_input_dim=22 presumably equals base_output_dim (16) plus
# the 6 base_x features concatenated before after_forward; confirm.
pres = Prescriptor(
    basic_block=None,
    base_small_input_dim=19,
    base_large_input_dim=19,
    base_hidden_dim=24,
    base_output_dim=16,
    after_input_dim=22,
    after_hidden_dim=32,
    after_output_dim=6,
    num_blocks=group,
).to(device)

# Genetic operators.
# NOTE(review): elite_num=200 / parents_num=400 — confirm these fit the
# population size that Evolution derives from `pres` (num_blocks=group here).
selection = RouletteSelection(elite_num=200, parents_num=400, minimize=False)
crossover = UniformCrossover()
# Commented-out alternative mutation:
#   ChainMutation([AddNormalMutation(mut_prob=0.2),
#                  MultiplyUniformMutation(mut_prob=0.2),
#                  FlipSignMutation(0.07)])
mutation = AddNormalMutation(mut_prob=0.1)

# Tie the network and the operators together.
evolution = Evolution(
    prescriptor=pres,
    selection=selection,
    crossover=crossover,
    mutation=mutation,
)

# Total parameter count, shown as the cell output.
get_n_params(pres)

600840

In [16]:
# Reference pass of the base stage on the dummy inputs (before the chromosome
# round-trip). eval() is set here rather than only in the following cell:
# previously this pass ran in the module's default training mode while the
# post-round-trip pass ran in eval mode. That makes the two incomparable if
# Prescriptor contains dropout or batch-norm layers (not visible here).
# The former `for i in range(34): ... break` only ever ran once and was removed.
pres.eval()
with torch.no_grad():
    lstm_logit = pres.base_forward(small_lstm_x, large_lstm_x).squeeze(dim=2)
# Swap the first two axes. Presumably (num_blocks, batch, out) -> (batch, num_blocks, out);
# TODO confirm against Prescriptor.base_forward.
lstm_logit = lstm_logit.permute([1, 0, 2])

In [17]:
# Reference pass of the "after" stage. The input is the first slice of
# lstm_logit (index 0) concatenated column-wise with the shared base features.
# Tensors go to the configured `device` instead of a hard-coded .cuda(), and
# the former `for i in range(17000): ... break` (which ran once) was removed.
pres.eval()
pres.to(device)
with torch.no_grad():
    after_input = torch.concat([lstm_logit[0].to(device), base_x.to(device)], dim=1)
    # cate_x + step: a distinct index per (block, category) pair.
    after_output = pres.after_forward(x=after_input, x_cate=cate_x + step)

In [18]:
# Flatten the prescriptor weights into chromosomes and write them straight back.
# NOTE(review): presumably a sanity check that flatten -> update leaves the
# network unchanged; compare after_output with after_output_2 to confirm.
# The returned device is bound to `chrom_device`, not `device`: rebinding the
# global config variable would silently change what later cells move tensors to.
# Disabled evolution step:
#   fitness = torch.sum(after_output.squeeze(dim=0), dim=1).cpu()
#   evolution.evolve(fitness)
chromosomes, base_ch_shape, after_ch_shape, chrom_device = evolution.flatten_chromosomes()
evolution.update_chromosomes(chromosomes, base_ch_shape, after_ch_shape, chrom_device)

In [19]:
# Base-stage pass after the chromosome round-trip. eval() is set explicitly so
# this cell does not depend on an earlier cell having switched the mode.
# The former `for i in range(34): ... break` only ever ran once and was removed.
pres.eval()
with torch.no_grad():
    lstm_logit_2 = pres.base_forward(small_lstm_x, large_lstm_x).squeeze(dim=2)
# Same axis swap as the reference pass, so the two results line up.
lstm_logit_2 = lstm_logit_2.permute([1, 0, 2])

In [20]:
# "After" stage pass following the chromosome round-trip, built exactly like
# the reference pass so after_output_2 can be compared with after_output
# (e.g. torch.testing.assert_close) to confirm the round-trip is lossless.
# Uses the configured `device` instead of .cuda(); the former
# `for i in range(17000): ... break` ran once and was removed.
pres.eval()
pres.to(device)
with torch.no_grad():
    after_input_2 = torch.concat([lstm_logit_2[0].to(device), base_x.to(device)], dim=1)
    after_output_2 = pres.after_forward(x=after_input_2, x_cate=cate_x + step)

gen	nevals	avg     	std     	min     	max     
0  	50    	0.478552	0.220469	0.243774	0.759626
1  	37    	0.654953	0.176367	0.225834	0.776958
2  	42    	0.677962	0.154839	0.2532  	0.774083
3  	43    	0.711727	0.121391	0.352406	0.78221 
4  	36    	0.751184	0.0610176	0.334909	0.78221 
5  	36    	0.756362	0.0708008	0.266108	0.78221 
6  	44    	0.738317	0.0999812	0.290931	0.78221 
7  	39    	0.746673	0.100134 	0.25953 	0.78221 
8  	41    	0.753659	0.0788583	0.245736	0.78221 
9  	37    	0.771179	0.0126803	0.715897	0.78221 
10 	42    	0.747074	0.107691 	0.253614	0.78221 
11 	41    	0.762735	0.0743486	0.272632	0.78221 
12 	42    	0.766622	0.0724094	0.272632	0.78221 
13 	42    	0.752404	0.104297 	0.25378 	0.78221 
14 	41    	0.757296	0.103593 	0.255826	0.78221 
15 	40    	0.766976	0.073535 	0.260304	0.78221 
16 	36    	0.778491	0.0138444	0.695746	0.78221 
17 	42    	0.770279	0.068797 	0.291152	0.78221 
18 	45    	0.743999	0.124993 	0.250712	0.78221 
19 	40    	0.758999	0.0782086	0.282224	0.7822