In [1]:
import sys
import os
import warnings
import glob
import re
from argparse import ArgumentParser
from tqdm.notebook import tqdm
sys.path.append(r'/home/datamake94/秒级高频策略/ML_Project')
from model_training_20240814 import *

os.environ['CUDA_VISIBLE_DEVICES'] = '8'

In [3]:
class UMPDataset_out(torch.utils.data.Dataset):  #读日内全部的数据
    def __init__(self, date, valid_dict) -> None:
        self.date = date
        self.valid_dict=valid_dict
        self.sec_list = params.sec_list_dict[future_ret]
        factor_stock_list=get_default_stock_list(date)   #所有因子固定的个股
        self.stock_list=sorted(list(set(factor_stock_list)))

        pre_date = get_pre_date(date)
        industry_slice = Dataset.industry_table.loc[pre_date]
        self.industry_dummy = torch.from_numpy(pd.get_dummies(industry_slice).reindex(self.stock_list).fillna(False).values).float()
    
    def __getitem__(self, index):
        sec = self.sec_list[index]
        if sec<100000:
            period=93000
        else:
            period=100000
        min_se,max_se,valid_ind=self.valid_dict[period]
        factor_data = load_all_data(self.date,self.stock_list, sec, model_training=True,valid_ind=valid_ind,max_se=max_se,min_se=min_se,import_se=True)
        return [factor_data,self.industry_dummy],sec

    def __len__(self):
        return len(self.sec_list)

def get_best_model(args,store_date):
    model_path=os.path.join(params._model_path,get_model_name(args, store_date))
    version=len(os.listdir(model_path))
    ckpt_all=glob.glob(model_path+'/version_%s/**/*[0-9].ckpt'%(version-1),recursive=True) #最新的模型

    result_df=pd.DataFrame([[int(re.findall('epoch=(\d+)',i)[0]),float(re.findall('val_rankic=(-?[0-9].\d*)',i)[0]),
                             float(re.findall('val_ls=(-?[0-9].\d*)',i)[0])]  for i in ckpt_all],columns=['epoch','ic','ls'])
    best_ind=(result_df['ls']*0.6+result_df['ic']*0.4).idxmax()
    #model_p='/epoch=%s-val_rankic=%.4f-val_ls=%.4f.ckpt'%(result_df.loc[best_ind,'epoch'],result_df.loc[best_ind,'ic'],result_df.loc[best_ind,'ls'])
    return ckpt_all[best_ind], model_path

def calc_args_model_dict(store_date, period_list, future_ret_list):
    args_dict = {}
    info_dict = {}
    model_dict = {}
    for future_ret in future_ret_list:
        for period in period_list:
            params.period = period
            params.future_ret = future_ret
            args = parse_args()
            best_model_path, model_file_path = get_best_model(args,store_date)
            print(future_ret, period, best_model_path)

            load_file_path = os.path.join(model_file_path, 'version_0')
            min_se = torch.load(rf"{load_file_path}/min_tensor.pt")
            max_se = torch.load(rf"{load_file_path}/max_tensor.pt")
            all_factor_list = torch.load(rf"{load_file_path}/all_factor_list.pt")
            valid_ind = torch.load(rf"{load_file_path}/valid_ind.pt")

            factor_num = valid_ind.int().sum()
            model = UMPLitModule.load_from_checkpoint(best_model_path, args=args, factor_num=factor_num)
            model.cpu().eval()
            args_dict[(future_ret, period)] = args
            info_dict[(future_ret, period)] = (min_se, max_se, valid_ind, all_factor_list, factor_num)
            model_dict[(future_ret, period)] = model

    return args_dict, info_dict, model_dict

def get_out_factor_list(factor_list):
    '''实盘命名规则'''
    factor_prefix_dict= {
            'ysw_orderbook1':'ob1',
            'ysw_pv_a':'sec2',
            'ysw_pv_b':'sec2',
            'ysw_graph':'graph1',
            'ysw_orderbook2':'ob2',
            'ysw_pv2':'sec1',
            'yy_order_basic':'order_basic',
            'yy_order_ls1':'order_ls1',
            'yy_trans_basic':'trans_basic',
            'yy_trans_ls1':'trans_ls1',
            'yy_orderbook3':'ob3',
            'yy_pv4':'sec4',
            }
    static_factor=[]
    with open('/home/intern1/hft_factor_comb/backup_test/factor_list_update20231204.txt', 'r', encoding='utf-8') as file:
        # 逐行读取文件
        lines = file.readlines()  # 这会将所有行读入一个列表
        # 遍历每一行
        for line in lines:
            if int(line.strip().split('=')[1])>=764:
                static_factor.append(line.strip().split('=')[0])

    factor_list_daily = factor_list[:-len(static_factor)]
    factor_list_daily = [x.lower() for x in factor_list_daily]
    for i in range(len(factor_list_daily)):
        for key in factor_prefix_dict.keys():
            if factor_list_daily[i].startswith(key):
                factor_list_daily[i]=factor_prefix_dict[key]+factor_list_daily[i][len(key):]
    return factor_list_daily+[x.upper() for x in static_factor]

def output_jit_model(period_list, future_ret_list, store_date, save_score_name, info_dict, model_dict):
    print('正在输出jit模型和基础数据')
    save_file_path = os.path.join(jit_model_output_path, rf"{store_date}:{save_score_name}")
    if not os.path.exists(save_file_path):
        os.makedirs(save_file_path)
    for period in period_list:
        for future_ret in future_ret_list:
            min_se, max_se, valid_ind, all_factor_list, factor_num = info_dict[(future_ret, period)]
            test_data = torch.randn([5000, factor_num])
            #生成一个随机的0-1矩阵，shape为5000*200
            industry_dummy = torch.randint(0, 1, [5000, 200])
            model = model_dict[(future_ret, period)].model.eval()
            jit_model = torch.jit.trace(model, (test_data, industry_dummy))

            #测试模型结果，生成三个随机的结果要求都能对上
            for i in range(3):
                test_data = torch.randn([5000, factor_num])
                industry_dummy = torch.randint(0, 1, [5000, 200])
                score = model(test_data, industry_dummy)
                score_jit = jit_model(test_data, industry_dummy)
                assert (score == score_jit).float().mean() == 1.

            model_idx = {'1m':1, '5m':2, '15s':3}[future_ret]
            torch.jit.save(jit_model, rf"{save_file_path}/model{model_idx}_{period}.pt")

            #输出基础数据
            IT_all_factor = get_out_factor_list(all_factor_list)
            min_se_output = pd.Series(min_se.numpy(), index=IT_all_factor)
            max_se_output = pd.Series(max_se.numpy(), index=IT_all_factor)
            min_se_output.to_csv(rf"{save_file_path}/min_tensor.csv")
            max_se_output.to_csv(rf"{save_file_path}/max_tensor.csv")

            valid_factor_list = [IT_all_factor[i] for i in range(len(IT_all_factor)) if valid_ind[i] == True]
            file = open(rf'{save_file_path}/factor_list_{period}_model{model_idx}.txt','w')
            for i in range(len(valid_factor_list)):
                file.write(valid_factor_list[i]+'='+str(i)+'\n')
            file.close()

    return

In [4]:
#先生成jit模型
period_list=[93000,100000]
future_ret_list=['1m','5m','15s']
save_score_name = 'yy_industry_94calc_2'

score_output_path = r'/home/datamake94/data_nb7/sec_score_output_final'
jit_model_output_path = r'/home/datamake94/data_nb7/jit_model_output_final'
year_month_list=get_year_month(params.date_list_all)
month_list= year_month_list[year_month_list.index('202405'):year_month_list.index('202406')][::2]

for month in month_list:
    store_date=get_first_date(month,params.date_list_all)
    begin_date=get_first_date(year_month_list[year_month_list.index(month)-24],params.date_list_all)
    begin_date=begin_date if begin_date>='20200430' else '20200430'
    end_date=get_first_date(year_month_list[year_month_list.index(month)+2],params.date_list_all)
    print('训练集数据起始日为{}，模型样本外预测期为{}——{}'.format(begin_date,store_date,end_date))

    date_list = get_date_list(begin_date,store_date)
    test_date_list = get_date_list(store_date, end_date)

    args_dict, info_dict, model_dict = calc_args_model_dict(store_date, period_list, future_ret_list)
    tmp_min_all = [info_dict[(future_ret, period)][0].unsqueeze(dim=1) for future_ret in future_ret_list for period in period_list]
    tmp_max_all = [info_dict[(future_ret, period)][1].unsqueeze(dim=1) for future_ret in future_ret_list for period in period_list]
    min_all, max_all = torch.cat(tmp_min_all, dim=1), torch.cat(tmp_max_all, dim=1)
    #检查min_all和max_all各行的值是否都相等
    assert (min_all.T == min_all.T[0]).float().mean() == 1.

    #1.输出jit模型和基础数据
    output_jit_model(period_list, future_ret_list, store_date, save_score_name, info_dict, model_dict)

    #2.输出因子打分
    for future_ret in future_ret_list:
        for period in period_list:
            model_dict[(future_ret, period)].to('cuda:0')
    for date in tqdm(test_date_list[:], desc=month):
        for future_ret in future_ret_list:
            save_score_path = os.path.join(score_output_path, rf"{future_ret}:{save_score_name}")
            if not os.path.exists(save_score_path):
                os.makedirs(save_score_path)
            out_all=[]
            valid_dict = {period: (info_dict[(future_ret, period)][0], info_dict[(future_ret, period)][1], info_dict[(future_ret, period)][2]) for period in period_list}
            dm = UMPDataset_out(date,valid_dict)
            factor_stock_list=dm.stock_list
            test_dataloader=DataLoaderX(dm,batch_size=1,collate_fn=lambda x:x[0],num_workers=20,shuffle=False,drop_last=False)
            for batch in tqdm(test_dataloader, desc=rf"{date},{future_ret}"):
                data,sec=batch
                if sec<100000:
                    period=93000
                else:
                    period=100000
                factor_data,industry_dummy = data
                factor_data=factor_data.to('cuda:0')
                industry_dummy=industry_dummy.to('cuda:0')
                out_list=model_dict[(future_ret, period)](factor_data,industry_dummy).cpu().detach().numpy()[:,0]
                out_series=pd.Series(out_list,index=factor_stock_list,name=sec)
                out_all.append(out_series)

            all_factor=pd.concat(out_all,axis=1).T
            all_factor.columns=all_factor.columns.astype(str)
            all_factor.index.name='second'
            all_factor.reset_index(drop=False).to_feather(save_score_path+'/%s.fea'%date)

训练集数据起始日为20220505，模型样本外预测期为20240506——20240701
1m 93000 /home/datamake94/data_nb7/sec_model_all_final/20240506-yy_test-update0828-DCN+BERT-ob_ret_1m-93000-minmax-citic_2-NN256x64-pLabelTrue-graphTrue-cross3-thre0.2-decorrTrue-noiseFalse-gpus4-maxepch15-btch1-accu_btch64-drop0-schdstep_lr-losswccc-lr0.001-lr_gamma0.95-lr_stepsz2-wd0.0001/version_0/checkpoints/epoch=13-val_rankic=0.2381-val_ls=0.2482.ckpt
1m 100000 /home/datamake94/data_nb7/sec_model_all_final/20240506-yy_test-update0828-DCN+BERT-ob_ret_1m-100000-minmax-citic_2-NN256x64-pLabelTrue-graphTrue-cross3-thre0.2-decorrTrue-noiseFalse-gpus4-maxepch15-btch1-accu_btch64-drop0-schdstep_lr-losswccc-lr0.001-lr_gamma0.95-lr_stepsz2-wd0.0001/version_0/checkpoints/epoch=11-val_rankic=0.2559-val_ls=0.1230.ckpt
5m 93000 /home/datamake94/data_nb7/sec_model_all_final/20240506-yy_test-update0828-DCN+BERT-ob_ret_5m-93000-minmax-citic_2-NN256x64-pLabelTrue-graphTrue-cross3-thre0.2-decorrTrue-noiseFalse-gpus4-maxepch15-btch1-accu_btch64-drop0-sc

202405:   0%|          | 0/39 [00:00<?, ?it/s]

20240506,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240506,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240506,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240507,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240507,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240507,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240508,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240508,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240508,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240509,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240509,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240509,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240510,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240510,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240510,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240513,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240513,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240513,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240514,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240514,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240514,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240515,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240515,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240515,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240516,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240516,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240516,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240517,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240517,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240517,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240520,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240520,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240520,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240521,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240521,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240521,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240522,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240522,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240522,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240523,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240523,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240523,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240524,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240524,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240524,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240527,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240527,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240527,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240528,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240528,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240528,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240529,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240529,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240529,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240530,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240530,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240530,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240531,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240531,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240531,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240603,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240603,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240603,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240604,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240604,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240604,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

20240605,1m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240605,5m:   0%|          | 0/2736 [00:00<?, ?it/s]

20240605,15s:   0%|          | 0/2736 [00:00<?, ?it/s]

In [5]:
date = '20240506'
#         if date<'20240510':continue
future_ret = '5m'
valid_dict = {period: (info_dict[(future_ret, period)][0], info_dict[(future_ret, period)][1], info_dict[(future_ret, period)][2]) for period in period_list}
out_all=[]
factor_stock_list=get_default_stock_list(date) 
test_dataloader=UMPDataset_out(date,valid_dict)
test = test_dataloader[0]

In [6]:
test[0][0] - 0.5

tensor([[-0.0085, -0.0881, -0.0552,  ..., -0.5000,  0.3966, -0.1000],
        [-0.4988,  0.0726,  0.0608,  ..., -0.5000,  0.3621, -0.1000],
        [ 0.0099,  0.1244,  0.1722,  ..., -0.5000,  0.1897, -0.1000],
        ...,
        [-0.2046, -0.0056,  0.0082,  ..., -0.5000,  0.2241,  0.5000],
        [ 0.1572, -0.0461, -0.0273,  ..., -0.5000,  0.1552,  0.5000],
        [ 0.1071, -0.0707, -0.0373,  ..., -0.5000, -0.3621,  0.5000]])

In [20]:
sec = 93000
if sec<100000:
    period=93000
else:
    period=100000
    
valid_ind=test_dataloader.valid_dict[period]
factor_data = load_all_data(date,test_dataloader.stock_list, sec, model_training=True,valid_ind=valid_ind)

(tensor([-4.0576e-01, -6.1621e-01, -6.4307e-01, -9.3213e-01, -4.4727e+00,
         -4.4727e+00, -2.3450e-01, -3.3813e-01, -5.7520e-01, -8.5498e-01,
         -7.7461e+00, -7.7461e+00, -1.6809e-01, -2.3816e-01, -5.2588e-01,
         -7.6318e-01, -1.0945e+01, -1.0953e+01, -1.3354e-01, -1.8668e-01,
         -5.1416e-01, -6.9336e-01, -1.5000e+01, -1.5133e+01, -9.0137e-01,
         -6.7062e-03, -1.5135e-03, -4.4727e+00, -8.9453e-01, -6.3515e-03,
         -1.9045e-03, -6.0078e+00, -8.8770e-01, -5.9891e-03, -2.2106e-03,
         -6.4961e+00, -8.7646e-01, -5.5618e-03, -2.5101e-03, -7.0938e+00,
         -1.1279e+00, -1.1719e+00, -1.2703e+01, -1.0293e+00, -1.1631e+00,
         -1.3516e+01, -9.2822e-01, -1.1182e+00, -1.4031e+01, -8.6279e-01,
         -1.0859e+00, -1.4336e+01, -1.6055e+00, -1.6553e+00, -3.6912e+00,
         -1.2734e+00, -1.2988e+00, -1.3691e+00, -1.4404e+00, -3.4375e+00,
         -1.2373e+00, -1.2861e+00, -1.1338e+00, -1.2119e+00, -3.1191e+00,
         -1.2070e+00, -1.2656e+00, -9.

In [11]:
score1 = pd.read_feather(r'/home/datamake94/data_nb7/sec_score_output_final/5m:yy_industry_test/20240506.fea').set_index('second')
score2 = pd.read_feather(r'/home/datamake94/data_nb7/sec_score_output_final/5m:test2/20240506.fea').set_index('second')
score1.T.corrwith(score2.T).describe()

count    2736.000000
mean        0.999993
std         0.000004
min         0.999983
25%         0.999990
50%         0.999992
75%         0.999995
max         1.000003
dtype: float64

In [9]:
pd.read_feather(r'/home/datamake94/data_nb7/sec_score_output_final/5m:test2/20240506.fea').set_index('second')

Unnamed: 0_level_0,1,2,4,6,7,8,9,10,11,12,...,688787,688788,688789,688793,688798,688799,688800,688819,688981,689009
second,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
93200,0.232280,0.695141,-0.804502,0.334179,-0.422321,-0.096884,0.753620,0.875685,0.255046,-0.560255,...,-0.817130,-0.404219,-0.641658,-0.817475,0.002067,-1.337756,-0.040318,0.078400,-0.903496,0.093498
93205,0.139143,1.107282,-0.816226,0.409033,0.016141,0.471355,0.517374,1.221308,0.305733,-0.018300,...,-0.740883,0.084808,-0.337551,-0.722563,0.129618,-0.232381,0.431011,0.887340,-0.829795,0.676262
93210,0.179563,0.860346,-1.049079,0.013632,-0.166536,-0.014878,-0.279685,0.886646,0.011624,-0.460517,...,-1.147584,-1.085817,-0.700030,-0.634275,-0.989253,-0.239495,-0.228274,0.825831,-1.140662,-0.009144
93215,0.127347,0.729137,-0.691090,0.872204,0.360820,0.065106,-0.100862,1.063529,0.075046,-0.456464,...,-0.881877,-1.172132,-0.684748,-0.579740,-0.562186,-0.013179,-0.273139,0.546679,-0.698112,-0.140301
93220,0.207856,0.687744,-0.067666,1.438583,0.298061,-0.148229,1.022386,0.621247,0.657221,-0.606849,...,-0.708593,-0.545180,-0.989314,-0.466911,-0.144159,0.434528,-0.733792,0.284400,-0.660327,0.621168
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
144935,-0.133726,-0.134836,-0.091653,-0.123519,-0.062173,-0.092499,-0.128157,-0.077646,-0.138444,-0.137554,...,-0.140876,-0.062219,-0.185193,-0.098493,-0.111295,0.010911,-0.036594,-0.112648,-0.212443,-0.043415
144940,-0.104626,-0.117472,-0.075552,-0.121455,-0.043340,-0.038483,-0.150223,-0.070940,-0.123162,-0.129446,...,-0.249141,-0.051889,-0.253959,-0.090285,-0.050867,0.267559,-0.041019,-0.105701,-0.232709,0.044137
144945,-0.198401,-0.094569,-0.048705,-0.128285,-0.016693,-0.088326,-0.152827,-0.076875,-0.179344,-0.134270,...,-0.119257,-0.046546,-0.321520,-0.087829,-0.075130,0.138740,-0.019911,-0.174609,-0.261208,-0.002931
144950,-0.168417,-0.119674,-0.083158,-0.126767,-0.058644,-0.087461,-0.133529,-0.056690,-0.206757,-0.100528,...,-0.033395,-0.039048,-0.425324,-0.068994,0.085918,0.226589,-0.042723,-0.117599,-0.256661,-0.041448


In [38]:
jit_model = torch.jit.load(r'/home/datamake94/决策库/trade_strategy/秒频_yy实盘20240427_t0/实盘模型_测试/model3_100000.pt')
jit_model(factor_data.cpu(),industry_dummy.cpu()).detach().numpy()[:,0][:10]

array([ 0.25251997,  0.8065849 ,  0.08614143, -1.4189508 ,  0.08776984,
       -0.9006066 ,  0.23149104,  1.6277778 , -1.6405615 ,  5.7395678 ],
      dtype=float32)

In [37]:
model = model_dict[(future_ret, period)]
model.to('cuda:0')
model(factor_data,industry_dummy).cpu().detach().numpy()[:,0][:10]

array([ 0.2524215 ,  0.80681264,  0.08562268, -1.4194442 ,  0.08741993,
       -0.90110993,  0.23194721,  1.626204  , -1.6410867 ,  5.739192  ],
      dtype=float32)

In [22]:
params.factor_num

tensor(533)