In [1]:
import torch
import torch.optim as optim
import pickle
import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier

def save_model(path, model, optimizer):
    """Write a checkpoint holding the model's and optimizer's state dicts.

    The file written by ``torch.save`` contains a dict with the keys
    ``'model_state_dict'`` and ``'optimizer_state'``, which is the layout
    ``load_model`` reads back.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
    )
    torch.save(checkpoint, path)


def load_model(path, model, device, mode, optimizer=None, lr=None):
    """Restore a checkpoint written by ``save_model`` into ``model`` (and optionally ``optimizer``).

    Args:
        path: checkpoint file with keys ``'model_state_dict'`` and ``'optimizer_state'``.
        model: module whose weights are overwritten in place; it is then moved to ``device``.
        device: target device. Checkpoint tensors are mapped onto it, so a
            checkpoint saved from a GPU can be loaded on a CPU-only machine.
        mode: ``"train"`` or ``"eval"`` selects the module mode; any other
            value leaves the current mode unchanged.
        optimizer: optional optimizer whose state is restored from the checkpoint.
        lr: optional learning rate written into every param group after restoring.
    """
    # Without map_location, tensors saved on cuda:0 are deserialized onto cuda:0
    # and loading fails when CUDA is not available.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    if optimizer is not None:
        optimizer_state = checkpoint['optimizer_state']
        # Loaded after model.to(device) so the optimizer state follows the params' device.
        optimizer.load_state_dict(optimizer_state)
        if lr is not None:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print(optimizer)

    if mode == "train":
        model.train()
    elif mode == "eval":
        model.eval()


ModuleNotFoundError: No module named 'torch'

In [4]:
# When True, the training cell resumes from existing "<name>.pkl" checkpoints
# instead of starting from freshly initialised networks.
refine_flag = True
# Use the first GPU if CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
class MLP(torch.nn.Module):
    """Two-layer perceptron producing class logits (used here to score quantile sketches).

    Layout: Linear(input_size, hidden_size) -> Dropout -> ReLU ->
    Linear(hidden_size, num_classes). The positions inside ``self.linear``
    determine the state-dict keys (``linear.0.*`` and ``linear.3.*``). Keep the
    order unchanged so previously saved checkpoints still load.

    Args:
        input_size: length of the input vector.
        hidden_size: width of the hidden layer (default 512, the original value).
        num_classes: number of output logits (default 2).
        dropout: dropout probability after the first linear layer (default 0.5).
    """

    def __init__(self, input_size, hidden_size=512, num_classes=2, dropout=0.5):
        super().__init__()
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.Dropout(dropout),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) logits; the last dimension has size ``num_classes``."""
        return self.linear(x)

In [52]:

# Train (or, when refine_flag is set, fine-tune) one MLP per transformation on
# the meta-feature datasets stored in "extra_0".
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
UNARY_TRANSFORM_NAMES = ['log', 'square_root', 'square', 'frequency', 'round', 'tanh', 'sigmoid',
                         'isotonic_regression', 'zscore', 'normalize']
BINARY_TRANSFORM_NAMES = ['sum', 'subtract', 'multiply', 'divide']
N_EPOCHS = 201          # epochs 0..200
CHECKPOINT_EVERY = 10   # save "<name>.pkl" whenever epoch % CHECKPOINT_EVERY == 0
BATCH_SIZE = 128
LEARNING_RATE = 0.0001

with open("extra_0", 'rb') as f:
    datasets = pickle.load(f)

for name, dataset in datasets.items():
    # Flatten targets to 1-D so they can be used as CrossEntropyLoss class indices.
    dataset['target'] = np.array(dataset['target']).reshape(-1)
    print(name, dataset['data'].shape)
    print('-' * 50)

    data, target = dataset['data'], dataset['target']
    # Unary transforms see one 2*200-bin sketch (400 inputs); binary ones see two (800).
    if name in UNARY_TRANSFORM_NAMES:
        net = MLP(400)
    elif name in BINARY_TRANSFORM_NAMES:
        net = MLP(800)
    else:
        # Fail with a clear message instead of an AttributeError on net=None below.
        raise ValueError("unknown transformation {!r} in extra_0".format(name))

    # .to(device) rather than .cuda() so the cell also runs on CPU-only machines.
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-08,
                                 weight_decay=0.0005)

    if refine_flag is True:
        load_model("{}.pkl".format(name), net, device, mode="train", optimizer=optimizer, lr=LEARNING_RATE)
    else:
        net.to(device)

    train_set = torch.utils.data.TensorDataset(torch.tensor(data).float(), torch.tensor(target).long())
    loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
    )

    for epoch in range(N_EPOCHS):
        loss_accum = 0.0
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss = criterion(net(batch_x), batch_y)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float; summing the loss tensor itself would
            # accumulate autograd history across the whole epoch.
            loss_accum += loss.item()
        print('epoch: {} loss avg:{} accum:{}'.format(epoch, loss_accum / max(len(loader), 1), loss_accum))
        print('-' * 40)
        if epoch % CHECKPOINT_EVERY == 0:
            save_model("{}.pkl".format(name), net, optimizer)


print('Finished Training')

sum (22770, 800)
--------------------------------------------------
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0.0005
)
epoch: 0 loss avg:0.21710479259490967 accum:38.6446533203125
----------------------------------------
epoch: 1 loss avg:0.2136761099100113 accum:38.03434753417969
----------------------------------------
epoch: 2 loss avg:0.21763454377651215 accum:38.738948822021484
----------------------------------------
epoch: 3 loss avg:0.21137942373752594 accum:37.62553787231445
----------------------------------------
epoch: 4 loss avg:0.21372833847999573 accum:38.043643951416016
----------------------------------------
epoch: 5 loss avg:0.2156635969877243 accum:38.388118743896484
----------------------------------------
epoch: 6 loss avg:0.218818798661232 accum:38.949745178222656
----------------------------------------
epoch: 7 loss avg:0.21638815104961395 accum:38.51708984375
----------------------------

epoch: 155 loss avg:0.21042098104953766 accum:37.454933166503906
----------------------------------------
epoch: 156 loss avg:0.213125079870224 accum:37.93626403808594
----------------------------------------
epoch: 157 loss avg:0.21021217107772827 accum:37.41776657104492
----------------------------------------
epoch: 158 loss avg:0.20498667657375336 accum:36.48762893676758
----------------------------------------
epoch: 159 loss avg:0.20884576439857483 accum:37.17454528808594
----------------------------------------
epoch: 160 loss avg:0.21132799983024597 accum:37.61638259887695
----------------------------------------
epoch: 161 loss avg:0.2051868885755539 accum:36.52326583862305
----------------------------------------
epoch: 162 loss avg:0.2034592181444168 accum:36.21574020385742
----------------------------------------
epoch: 163 loss avg:0.21159270405769348 accum:37.66350173950195
----------------------------------------
epoch: 164 loss avg:0.20622043311595917 accum:36.707237243

epoch: 109 loss avg:0.2252580225467682 accum:39.870670318603516
----------------------------------------
epoch: 110 loss avg:0.21782264113426208 accum:38.55460739135742
----------------------------------------
epoch: 111 loss avg:0.21765057742595673 accum:38.52415084838867
----------------------------------------
epoch: 112 loss avg:0.21377429366111755 accum:37.838050842285156
----------------------------------------
epoch: 113 loss avg:0.21856458485126495 accum:38.68593215942383
----------------------------------------
epoch: 114 loss avg:0.2243329882621765 accum:39.706939697265625
----------------------------------------
epoch: 115 loss avg:0.2199409008026123 accum:38.92953872680664
----------------------------------------
epoch: 116 loss avg:0.21955667436122894 accum:38.86153030395508
----------------------------------------
epoch: 117 loss avg:0.21702134609222412 accum:38.412776947021484
----------------------------------------
epoch: 118 loss avg:0.21657182276248932 accum:38.33321

----------------------------------------
epoch: 63 loss avg:0.18801359832286835 accum:33.4664192199707
----------------------------------------
epoch: 64 loss avg:0.19710594415664673 accum:35.08485794067383
----------------------------------------
epoch: 65 loss avg:0.19667766988277435 accum:35.00862503051758
----------------------------------------
epoch: 66 loss avg:0.1955651044845581 accum:34.81058883666992
----------------------------------------
epoch: 67 loss avg:0.19574595987796783 accum:34.84278106689453
----------------------------------------
epoch: 68 loss avg:0.19132989645004272 accum:34.05672073364258
----------------------------------------
epoch: 69 loss avg:0.19160956144332886 accum:34.106502532958984
----------------------------------------
epoch: 70 loss avg:0.19404421746730804 accum:34.53987121582031
----------------------------------------
epoch: 71 loss avg:0.19944514334201813 accum:35.50123596191406
----------------------------------------
epoch: 72 loss avg:0.192

epoch: 16 loss avg:0.1415305733680725 accum:16.41754722595215
----------------------------------------
epoch: 17 loss avg:0.14350362122058868 accum:16.646419525146484
----------------------------------------
epoch: 18 loss avg:0.13800352811813354 accum:16.00840950012207
----------------------------------------
epoch: 19 loss avg:0.1357932835817337 accum:15.752021789550781
----------------------------------------
epoch: 20 loss avg:0.1371433436870575 accum:15.9086275100708
----------------------------------------
epoch: 21 loss avg:0.13807249069213867 accum:16.016408920288086
----------------------------------------
epoch: 22 loss avg:0.12841302156448364 accum:14.89591121673584
----------------------------------------
epoch: 23 loss avg:0.13504788279533386 accum:15.665555000305176
----------------------------------------
epoch: 24 loss avg:0.1318768411874771 accum:15.297714233398438
----------------------------------------
epoch: 25 loss avg:0.1375596970319748 accum:15.956924438476562
-

----------------------------------------
epoch: 172 loss avg:0.12542618811130524 accum:14.549437522888184
----------------------------------------
epoch: 173 loss avg:0.13025692105293274 accum:15.109803199768066
----------------------------------------
epoch: 174 loss avg:0.13078942894935608 accum:15.171573638916016
----------------------------------------
epoch: 175 loss avg:0.1333627998828888 accum:15.470085144042969
----------------------------------------
epoch: 176 loss avg:0.13259413838386536 accum:15.38092041015625
----------------------------------------
epoch: 177 loss avg:0.12772111594676971 accum:14.81564998626709
----------------------------------------
epoch: 178 loss avg:0.12946383655071259 accum:15.017805099487305
----------------------------------------
epoch: 179 loss avg:0.1273552030324936 accum:14.77320384979248
----------------------------------------
epoch: 180 loss avg:0.1290569007396698 accum:14.970600128173828
----------------------------------------
epoch: 181 

----------------------------------------
epoch: 126 loss avg:0.2083653211593628 accum:5.4174981117248535
----------------------------------------
epoch: 127 loss avg:0.20781558752059937 accum:5.403204917907715
----------------------------------------
epoch: 128 loss avg:0.20401819050312042 accum:5.304472923278809
----------------------------------------
epoch: 129 loss avg:0.21434669196605682 accum:5.573013782501221
----------------------------------------
epoch: 130 loss avg:0.2126951962709427 accum:5.5300750732421875
----------------------------------------
epoch: 131 loss avg:0.20555773377418518 accum:5.34450101852417
----------------------------------------
epoch: 132 loss avg:0.2038154900074005 accum:5.2992024421691895
----------------------------------------
epoch: 133 loss avg:0.20746496319770813 accum:5.3940887451171875
----------------------------------------
epoch: 134 loss avg:0.2138417363166809 accum:5.559885025024414
----------------------------------------
epoch: 135 loss

epoch: 79 loss avg:0.26274123787879944 accum:16.0272159576416
----------------------------------------
epoch: 80 loss avg:0.2564358413219452 accum:15.642587661743164
----------------------------------------
epoch: 81 loss avg:0.26963287591934204 accum:16.447607040405273
----------------------------------------
epoch: 82 loss avg:0.2660695016384125 accum:16.230239868164062
----------------------------------------
epoch: 83 loss avg:0.27227962017059326 accum:16.609058380126953
----------------------------------------
epoch: 84 loss avg:0.2729398310184479 accum:16.649330139160156
----------------------------------------
epoch: 85 loss avg:0.2544872760772705 accum:15.523724555969238
----------------------------------------
epoch: 86 loss avg:0.251035213470459 accum:15.313149452209473
----------------------------------------
epoch: 87 loss avg:0.2552468478679657 accum:15.57005786895752
----------------------------------------
epoch: 88 loss avg:0.26351699233055115 accum:16.07453727722168
--

epoch: 32 loss avg:0.3010217547416687 accum:18.362327575683594
----------------------------------------
epoch: 33 loss avg:0.30261239409446716 accum:18.4593563079834
----------------------------------------
epoch: 34 loss avg:0.2907116115093231 accum:17.733409881591797
----------------------------------------
epoch: 35 loss avg:0.28875720500946045 accum:17.61419105529785
----------------------------------------
epoch: 36 loss avg:0.29322972893714905 accum:17.887014389038086
----------------------------------------
epoch: 37 loss avg:0.2968899607658386 accum:18.110288619995117
----------------------------------------
epoch: 38 loss avg:0.3038932681083679 accum:18.537490844726562
----------------------------------------
epoch: 39 loss avg:0.2880212366580963 accum:17.56929588317871
----------------------------------------
epoch: 40 loss avg:0.29122287034988403 accum:17.764596939086914
----------------------------------------
epoch: 41 loss avg:0.286226749420166 accum:17.4598331451416
----

----------------------------------------
epoch: 189 loss avg:0.2984920144081116 accum:18.2080135345459
----------------------------------------
epoch: 190 loss avg:0.2849973440170288 accum:17.384838104248047
----------------------------------------
epoch: 191 loss avg:0.28872036933898926 accum:17.6119441986084
----------------------------------------
epoch: 192 loss avg:0.30554094910621643 accum:18.637998580932617
----------------------------------------
epoch: 193 loss avg:0.28927260637283325 accum:17.6456298828125
----------------------------------------
epoch: 194 loss avg:0.2990575432777405 accum:18.242511749267578
----------------------------------------
epoch: 195 loss avg:0.3002774715423584 accum:18.316926956176758
----------------------------------------
epoch: 196 loss avg:0.29490119218826294 accum:17.98897361755371
----------------------------------------
epoch: 197 loss avg:0.2922423481941223 accum:17.826784133911133
----------------------------------------
epoch: 198 loss a

epoch: 142 loss avg:0.26069751381874084 accum:16.163246154785156
----------------------------------------
epoch: 143 loss avg:0.26182034611701965 accum:16.23286247253418
----------------------------------------
epoch: 144 loss avg:0.255527138710022 accum:15.842683792114258
----------------------------------------
epoch: 145 loss avg:0.271962434053421 accum:16.861671447753906
----------------------------------------
epoch: 146 loss avg:0.26904943585395813 accum:16.68106460571289
----------------------------------------
epoch: 147 loss avg:0.2634642720222473 accum:16.33478546142578
----------------------------------------
epoch: 148 loss avg:0.25808772444725037 accum:16.001440048217773
----------------------------------------
epoch: 149 loss avg:0.2688906490802765 accum:16.671220779418945
----------------------------------------
epoch: 150 loss avg:0.26885804533958435 accum:16.669198989868164
----------------------------------------
epoch: 151 loss avg:0.26258984208106995 accum:16.280570

----------------------------------------
epoch: 96 loss avg:0.276285856962204 accum:16.853437423706055
----------------------------------------
epoch: 97 loss avg:0.2811005413532257 accum:17.14713478088379
----------------------------------------
epoch: 98 loss avg:0.2777811884880066 accum:16.944652557373047
----------------------------------------
epoch: 99 loss avg:0.2836730480194092 accum:17.30405616760254
----------------------------------------
epoch: 100 loss avg:0.27876269817352295 accum:17.004526138305664
----------------------------------------
epoch: 101 loss avg:0.2727026343345642 accum:16.63486099243164
----------------------------------------
epoch: 102 loss avg:0.2749827206134796 accum:16.77394676208496
----------------------------------------
epoch: 103 loss avg:0.2776035666465759 accum:16.933818817138672
----------------------------------------
epoch: 104 loss avg:0.27810361981391907 accum:16.96432113647461
----------------------------------------
epoch: 105 loss avg:0.

epoch: 49 loss avg:0.27995583415031433 accum:16.79734992980957
----------------------------------------
epoch: 50 loss avg:0.267860472202301 accum:16.071626663208008
----------------------------------------
epoch: 51 loss avg:0.26982778310775757 accum:16.189666748046875
----------------------------------------
epoch: 52 loss avg:0.2599453926086426 accum:15.596723556518555
----------------------------------------
epoch: 53 loss avg:0.26654714345932007 accum:15.992827415466309
----------------------------------------
epoch: 54 loss avg:0.2696613371372223 accum:16.17967987060547
----------------------------------------
epoch: 55 loss avg:0.2737243175506592 accum:16.423458099365234
----------------------------------------
epoch: 56 loss avg:0.2741749882698059 accum:16.450498580932617
----------------------------------------
epoch: 57 loss avg:0.2756825387477875 accum:16.540950775146484
----------------------------------------
epoch: 58 loss avg:0.28050410747528076 accum:16.830245971679688


epoch: 2 loss avg:0.264287531375885 accum:16.121540069580078
----------------------------------------
epoch: 3 loss avg:0.2601504325866699 accum:15.869176864624023
----------------------------------------
epoch: 4 loss avg:0.26174837350845337 accum:15.96665096282959
----------------------------------------
epoch: 5 loss avg:0.26392948627471924 accum:16.099699020385742
----------------------------------------
epoch: 6 loss avg:0.2611284852027893 accum:15.928837776184082
----------------------------------------
epoch: 7 loss avg:0.25915008783340454 accum:15.808156967163086
----------------------------------------
epoch: 8 loss avg:0.26348650455474854 accum:16.072677612304688
----------------------------------------
epoch: 9 loss avg:0.26097992062568665 accum:15.91977596282959
----------------------------------------
epoch: 10 loss avg:0.2562441825866699 accum:15.630895614624023
----------------------------------------
epoch: 11 loss avg:0.24636758863925934 accum:15.028423309326172
------

----------------------------------------
epoch: 159 loss avg:0.2617719769477844 accum:15.96809196472168
----------------------------------------
epoch: 160 loss avg:0.2555026113986969 accum:15.585660934448242
----------------------------------------
epoch: 161 loss avg:0.25503313541412354 accum:15.557022094726562
----------------------------------------
epoch: 162 loss avg:0.2572946548461914 accum:15.694974899291992
----------------------------------------
epoch: 163 loss avg:0.25730881094932556 accum:15.695838928222656
----------------------------------------
epoch: 164 loss avg:0.25533822178840637 accum:15.57563304901123
----------------------------------------
epoch: 165 loss avg:0.24986453354358673 accum:15.241737365722656
----------------------------------------
epoch: 166 loss avg:0.2527601420879364 accum:15.418370246887207
----------------------------------------
epoch: 167 loss avg:0.2586359977722168 accum:15.7767972946167
----------------------------------------
epoch: 168 los

----------------------------------------
epoch: 113 loss avg:0.26589682698249817 accum:17.549190521240234
----------------------------------------
epoch: 114 loss avg:0.2716938555240631 accum:17.931793212890625
----------------------------------------
epoch: 115 loss avg:0.26277080178260803 accum:17.342872619628906
----------------------------------------
epoch: 116 loss avg:0.27485254406929016 accum:18.14026641845703
----------------------------------------
epoch: 117 loss avg:0.28263938426971436 accum:18.654199600219727
----------------------------------------
epoch: 118 loss avg:0.27097126841545105 accum:17.884103775024414
----------------------------------------
epoch: 119 loss avg:0.2771620750427246 accum:18.292695999145508
----------------------------------------
epoch: 120 loss avg:0.2611635625362396 accum:17.23679542541504
----------------------------------------
epoch: 121 loss avg:0.258059024810791 accum:17.03189468383789
----------------------------------------
epoch: 122 lo

epoch: 66 loss avg:0.254054456949234 accum:15.497322082519531
----------------------------------------
epoch: 67 loss avg:0.26091545820236206 accum:15.91584300994873
----------------------------------------
epoch: 68 loss avg:0.26630261540412903 accum:16.244461059570312
----------------------------------------
epoch: 69 loss avg:0.2574194669723511 accum:15.70258903503418
----------------------------------------
epoch: 70 loss avg:0.2588873505592346 accum:15.792128562927246
----------------------------------------
epoch: 71 loss avg:0.25554779171943665 accum:15.58841609954834
----------------------------------------
epoch: 72 loss avg:0.2596791684627533 accum:15.84043025970459
----------------------------------------
epoch: 73 loss avg:0.25930294394493103 accum:15.817480087280273
----------------------------------------
epoch: 74 loss avg:0.26157718896865845 accum:15.956209182739258
----------------------------------------
epoch: 75 loss avg:0.26580294966697693 accum:16.21398162841797
-

epoch: 19 loss avg:0.25203198194503784 accum:15.37395191192627
----------------------------------------
epoch: 20 loss avg:0.25428062677383423 accum:15.51111888885498
----------------------------------------
epoch: 21 loss avg:0.259798139333725 accum:15.847686767578125
----------------------------------------
epoch: 22 loss avg:0.26386165618896484 accum:16.095561981201172
----------------------------------------
epoch: 23 loss avg:0.2581333518028259 accum:15.746134757995605
----------------------------------------
epoch: 24 loss avg:0.258043110370636 accum:15.740631103515625
----------------------------------------
epoch: 25 loss avg:0.25458669662475586 accum:15.529788970947266
----------------------------------------
epoch: 26 loss avg:0.2516957223415375 accum:15.353439331054688
----------------------------------------
epoch: 27 loss avg:0.258566677570343 accum:15.772567749023438
----------------------------------------
epoch: 28 loss avg:0.25030532479286194 accum:15.268625259399414
-

----------------------------------------
epoch: 98 loss avg:0.2525618076324463 accum:15.406271934509277
----------------------------------------
epoch: 99 loss avg:0.25270983576774597 accum:15.415301322937012
----------------------------------------
epoch: 100 loss avg:0.2494983822107315 accum:15.219402313232422
----------------------------------------
epoch: 101 loss avg:0.24732163548469543 accum:15.086620330810547
----------------------------------------
epoch: 102 loss avg:0.26063793897628784 accum:15.898914337158203
----------------------------------------
epoch: 103 loss avg:0.25750958919525146 accum:15.708085060119629
----------------------------------------
epoch: 104 loss avg:0.25613394379615784 accum:15.624171257019043
----------------------------------------
epoch: 105 loss avg:0.2562674283981323 accum:15.63231372833252
----------------------------------------
epoch: 106 loss avg:0.2519638240337372 accum:15.369794845581055
----------------------------------------
epoch: 107 l

In [6]:

import torch
from sklearn.isotonic import IsotonicRegression

'''
in: 一个ndarray的2列(m×1)
out: 一个m×1 的 1 列
description: 2列每个元素对应相加
'''


def f_sum(column_1, column_2):
    """Element-wise sum of two equally shaped columns (ndarray or tensor)."""
    total = column_1 + column_2
    return total


'''
in: 一个ndarray的2列(m×1)
out: 一个m×1 的 1 列
description: 2列每个元素对应相减
'''


def f_subtract(column_1, column_2):
    """Element-wise difference ``column_1 - column_2`` of two equally shaped columns."""
    difference = column_1 - column_2
    return difference


'''
in: 一个ndarray的2列(m×1)
out: 一个m×1 的 1 列
description: 2列每个元素对应相乘
'''


def f_multiply(column_1, column_2):
    """Element-wise product of two equally shaped columns (ndarray or tensor)."""
    product = column_1 * column_2
    return product


'''
in:  一个ndarray的2列(m×1)
out: 一个m×1 的 1 列; 若除数列中存在 0 则返回 None
description: 2列每个元素对应相除 (第1列 / 第2列)
'''


def f_divide(column_1, column_2):
    """Element-wise ``column_1 / column_2``.

    Returns None when any element of ``column_2`` equals zero, so no
    infinities are produced.
    """
    if torch.is_tensor(column_2):
        has_zero_divisor = not torch.all(torch.ne(column_2, 0))
    else:
        has_zero_divisor = not np.all(column_2 != 0)
    if has_zero_divisor:
        return None
    return column_1 / column_2


class Binaries:
    """Registry of enabled binary transformations as parallel ``name`` / ``func`` lists.

    Only multiply and divide are enabled; f_sum and f_subtract exist but are
    currently not registered.
    """
    name, func = map(list, zip(
        ('multiply', f_multiply),
        ('divide', f_divide),
    ))

    def __init__(self):
        pass


def f_log(column):
    """Base-2 logarithm of each element; None unless every element is strictly positive."""
    is_tensor = torch.is_tensor(column)
    all_positive = torch.all(torch.gt(column, 0)) if is_tensor else np.all(column > 0)
    if not all_positive:
        return None
    return torch.log2(column) if is_tensor else np.log2(column)


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description: 该列每个元素取绝对值后求平方根
'''


def f_square_root(column):
    """Square root of the absolute value of each element (ndarray or tensor)."""
    if torch.is_tensor(column):
        return torch.sqrt(torch.abs(column))
    return np.sqrt(np.abs(column))


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description: 该列每个元素求带符号的平方根: 对绝对值求平方根后保留原符号, 例如 square(-9)=-3
'''


def f_square(column):
    """Signed square root of each element: sign(x) * sqrt(|x|), e.g. -9 -> -3.

    NOTE(review): despite its name this does not square the input.
    """
    if torch.is_tensor(column):
        magnitude, sign = torch.sqrt(torch.abs(column)), torch.sign(column)
    else:
        magnitude, sign = np.sqrt(np.abs(column)), np.sign(column)
    return magnitude * sign


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description: 对应元素替换成该元素在这一列出现的频次,例:[7,7,2,3,3,4] -> [2,2,1,2,2,1]
'''


def f_frequency(column):
    """Replace each element by the number of times its value occurs in the column.

    Example: [7, 7, 2, 3, 3, 4] -> [2, 2, 1, 2, 2, 1].

    Returns a float32 tensor for tensor input, otherwise an integer ndarray.
    """
    values = np.array(column)
    # For every element, `inverse` is the index of its distinct value and
    # `counts` holds how often each distinct value occurs, so counts[inverse]
    # is a vectorised frequency lookup. This replaces the deprecated top-level
    # pd.value_counts and the per-element Python loop; NaNs are counted
    # together instead of raising KeyError.
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    freq_result = counts[inverse.reshape(-1)].reshape(values.shape)
    return torch.tensor(freq_result).float() if torch.is_tensor(column) else freq_result


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description:每个值对应四舍五入
'''


def f_round(column):
    """Round each element to the nearest integer (ties to even).

    Tensors keep their floating dtype; ndarrays are cast with ``astype('int')``.
    """
    if not torch.is_tensor(column):
        return np.round(column).astype('int')
    return torch.round(column)


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description:每个值对应双曲正切
'''


def f_tanh(column):
    """Element-wise hyperbolic tangent for an ndarray or tensor column."""
    if not torch.is_tensor(column):
        return np.tanh(column)
    return torch.tanh(column)


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description:每个值对应求 sigmoid, 即 1/(1+e^(-x))
'''


def f_sigmoid(column):
    """Element-wise logistic sigmoid 1 / (1 + exp(-x)) for an ndarray or tensor column."""
    if not torch.is_tensor(column):
        return 1 / (1 + np.exp(-column))
    return torch.sigmoid(column)


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description:以行下标为自变量, 对该列值拟合保序回归 (isotonic regression), 返回拟合后的单调不减序列
'''


def f_isotonic_regression(column):
    """Fit sklearn's IsotonicRegression of the column values against their row index.

    Returns the fitted values: a float32 tensor for tensor input, otherwise
    whatever ``IsotonicRegression().fit_transform`` returns.
    """
    positions = range(column.shape[0])
    fitted = IsotonicRegression().fit_transform(positions, column)
    if torch.is_tensor(column):
        return torch.tensor(fitted).float()
    return fitted


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description:对该列值做 z 分数标准化, 即 (x - 均值) / 标准差; 标准差为 0 时返回 None
'''


def f_zscore(column):
    """Standardise the column: (x - mean) / std, with the population std (ddof=0).

    The ndarray path (np.std) already uses ddof=0. torch.std defaults to the
    unbiased N-1 estimator, so the tensor path requests ``unbiased=False`` to
    give the same result for the same data. Returns None for a constant column.
    """
    if torch.is_tensor(column):
        mv, stv = torch.mean(column), torch.std(column, unbiased=False)
        condition = torch.all(torch.ne(stv, 0))
    else:
        mv, stv = np.mean(column), np.std(column)
        condition = np.all(stv != 0)
    if condition:
        return (column - mv) / stv
    return None


'''
in:  一个ndarray 的1列(m×1)
out: 一个m×1 的 1 列
description:将该列值线性缩放到 [-1, 1] 区间 (最小值→-1, 最大值→1); 最大值等于最小值时返回 None
'''


def f_normalize(column):
    """Linearly rescale the column to [-1, 1] (min -> -1, max -> 1).

    Returns None for a constant column, where the scale would be undefined.
    """
    is_tensor = torch.is_tensor(column)
    hi = torch.max(column) if is_tensor else np.max(column)
    lo = torch.min(column) if is_tensor else np.min(column)
    is_constant = torch.equal(hi, lo) if is_tensor else hi == lo
    if is_constant:
        return None
    return -1 + 2 / (hi - lo) * (column - lo)


class Unaries:
    """Registry of enabled unary transformations as parallel ``name`` / ``func`` lists.

    NOTE(review): f_normalize is defined but not registered here, although the
    training cell trains a 'normalize' model; confirm this is intentional.
    """
    name, func = map(list, zip(
        ('log', f_log),
        ('square_root', f_square_root),
        ('square', f_square),
        ('frequency', f_frequency),
        ('round', f_round),
        ('tanh', f_tanh),
        ('sigmoid', f_sigmoid),
        ('isotonic_regression', f_isotonic_regression),
        ('zscore', f_zscore),
    ))

    def __init__(self):
        pass

In [7]:
# Load one eval-mode MLP per registered transformation from "<name>.pkl".
# Unary models take a 400-long input, binary ones 800.
model_groups = dict()
for transform_names, input_size in ((Unaries.name, 400), (Binaries.name, 800)):
    for name in transform_names:
        model_groups[name] = MLP(input_size)
        load_model("{}.pkl".format(name), model_groups[name], device, mode="eval")



In [41]:
from itertools import permutations

# threshold: minimum predicted probability for a transformation to be applied.
# n_attempt / perm: NOTE(review) not referenced in the visible cells.
threshold, n_attempt, perm = 0.88, 1000, True

def is_continuous(column, min_unique=3):
    """Heuristically classify a column as continuous.

    Args:
        column: 1-D array-like (ndarray, list, or CPU tensor).
        min_unique: the column is continuous if it has MORE than this many
            distinct values (default 3, the original threshold).

    Returns:
        bool. The result is now always a real bool: ``False`` for
        non-continuous columns rather than an implicit ``None``.
    """
    return len(np.unique(column)) > min_unique

def probability(out):
    """Return the softmax probability of class 1 from a vector of logits.

    Softmax is taken over dim 0 (assumes a 1-D vector of two logits, as
    produced by the 2-output MLP on a 1-D sketch; TODO confirm for other
    callers). If class 1 is the arg-max, its probability is returned;
    otherwise ``1 - p(argmax)`` is returned, which for two classes is again
    p(class 1).
    """
    probs = torch.nn.functional.softmax(out, dim=0)
    top_prob, top_class = torch.max(probs, 0)
    if top_class != 1:
        return 1 - top_prob
    return top_prob
    

def _rescale_sketch(sketch, low=-10.0, high=10.0):
    """Min-max rescale a histogram to [low, high]; a constant histogram maps entirely to ``low``."""
    smin, smax = torch.min(sketch), torch.max(sketch)
    span = smax - smin
    if span == 0:
        # E.g. a class absent from `labels` yields an all-zero histogram;
        # the plain formula would divide 0 by 0 and emit NaNs.
        return torch.full_like(sketch, low)
    return low + (high - low) * (sketch - smin) / span


def get_tensor_sketch(n_bins, feature, labels):
    """Build a per-class histogram sketch of ``feature``.

    Both classes are histogrammed into ``n_bins`` bins over the feature's full
    [min, max] range. Each histogram is rescaled independently to [-10, 10],
    and the two are concatenated (class 0 first), giving ``2 * n_bins`` values.

    Args:
        n_bins: number of histogram bins per class.
        feature: 1-D tensor of feature values.
        labels: tensor of 0/1 labels aligned with ``feature``.
    """
    supr, infr = torch.max(feature), torch.min(feature)
    idx0, idx1 = torch.where(labels == 0), torch.where(labels == 1)

    sketch0 = torch.histc(feature[idx0], bins=n_bins, min=float(infr), max=float(supr))
    sketch1 = torch.histc(feature[idx1], bins=n_bins, min=float(infr), max=float(supr))

    return torch.cat((_rescale_sketch(sketch0), _rescale_sketch(sketch1)), 0)

In [42]:
# Positions of the CSV columns kept as features, in this order (73 columns);
# passed as `ef` to InsuranceDataset.read_insurance_csv.
ef = [
    93, 18, 68, 35, 4, 55, 69, 3, 31, 23,
    5, 70, 65, 41, 54, 100, 45, 28, 101, 24,
    83, 84, 44, 17, 53, 103, 64, 22, 33, 94,
    39, 36, 63, 10, 77, 86, 82, 30, 75, 25,
    99, 34, 89, 2, 11, 16, 85, 13, 21, 72,
    95, 7, 90, 29, 51, 20, 40, 6, 87, 43,
    96, 42, 81, 32, 37, 9, 98, 61, 14, 46,
    8, 0, 19,
]

In [43]:
class InsuranceDataset:
    """Binary-classification table loaded from a CSV whose last column is the label.

    Attributes (filled in by ``read_insurance_csv``):
        data: float32 tensor of the selected feature columns.
        target: int64 tensor built from the CSV's last column.
        features: names of the selected feature columns, aligned with ``data``.
        cont_index: positions of columns in ``data`` that ``is_continuous`` accepts.
    """

    def __init__(self):
        self.data = None
        self.target = None
        self.features = None
        self.cont_index = list()

    def read_insurance_csv(self, path, ef=None):
        """Load ``path`` and (re)populate data, target, features and cont_index.

        Args:
            path: CSV file; its last column is used as the integer label.
            ef: optional sequence of column positions to keep as features, in
                that order. ``None`` keeps every column except the last.
        """
        dataset = pd.read_csv(path)
        # Replace +inf / -inf with the column's largest / smallest finite value
        # so later histograms and models only see finite numbers.
        for col in dataset.columns.values:
            pos_inf = dataset[col] == np.inf
            neg_inf = dataset[col] == -np.inf
            finite = ~(pos_inf | neg_inf)
            if pos_inf.any():
                dataset.loc[pos_inf, col] = dataset.loc[finite, col].max()
            if neg_inf.any():
                dataset.loc[neg_inf, col] = dataset.loc[finite, col].min()

        feature_cols = slice(None, -1) if ef is None else ef
        self.data = torch.from_numpy(dataset.values[:, feature_cols].astype("float")).float()
        self.target = torch.from_numpy(dataset.values[:, -1].astype("int")).long()
        self.features = list(dataset.columns.values[feature_cols])

        # Rebuild the list rather than appending, so repeated calls do not
        # accumulate duplicate indices.
        self.cont_index = [col for col in range(self.data.shape[1])
                           if is_continuous(self.data[:, col])]

In [44]:
from sklearn.metrics import classification_report, accuracy_score, make_scorer
# Side channel for the scorer below: every call records that fold's labels and
# predictions, so a full classification_report can be printed after
# cross_val_score finishes.
original_class = []
predicted_class = []

def classification_report_with_accuracy_score(y_true, y_pred):
    """Accuracy scorer that also appends y_true / y_pred to the module-level lists.

    NOTE(review): relies on cross_val_score running in this process; with
    parallel workers the appends would land in other processes' copies.
    """
    for sink, values in ((original_class, y_true), (predicted_class, y_pred)):
        sink.extend(values)
    return accuracy_score(y_true, y_pred)

In [45]:
from xgboost import XGBClassifier
from xgboost import plot_importance
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import classification_report


# Read the dataset, score a seeded baseline, add the transformations the
# meta-models recommend, then re-score with the same seeded classifier.
N_SKETCH_BINS = 200      # sketch length 2*200 = 400 per column (unary MLPs), 800 for pairs
CV_FOLDS = 10
CV_RANDOM_STATE = 10     # shared by the baseline and the new run so the scores are comparable


def _cross_validate(dataset, score_label):
    """Cross-validate a seeded GradientBoostingClassifier on ``dataset`` and print a report.

    Clears the module-level ``original_class`` / ``predicted_class`` lists that
    the scorer appends to, so the printed report covers only this run.
    Returns ``(classifier, mean_accuracy)``.
    """
    original_class.clear()
    predicted_class.clear()
    classifier = GradientBoostingClassifier(random_state=CV_RANDOM_STATE)
    score = cross_val_score(
        classifier, dataset.data, dataset.target, cv=CV_FOLDS,
        scoring=make_scorer(classification_report_with_accuracy_score),
    ).mean()
    print('later feature num:', dataset.data.shape)
    print(classification_report(original_class, predicted_class))
    print("{}:".format(score_label), score)
    return classifier, score


insurance_dataset = InsuranceDataset()
insurance_dataset.read_insurance_csv("fengxian.csv", ef)

# Ordered pairs and single columns drawn from the continuous features.
binary_candidates = list(permutations(insurance_dataset.cont_index, 2))
unary_candidates = list(permutations(insurance_dataset.cont_index, 1))
print("binary_candidates:", len(binary_candidates))
print("unary_candidates:", len(unary_candidates))

classifier, score = _cross_validate(insurance_dataset, "benchscore")

new_feature_names, new_feature_columns = [], []


def _accept_feature(feature_name, prob, new_feature):
    """Log and queue one generated feature; ``None`` (transformation not applicable) is skipped."""
    if new_feature is None:
        return
    print(feature_name)
    print(prob)
    print("-" * 30)
    new_feature_names.append(feature_name)
    new_feature_columns.append(torch.unsqueeze(new_feature, 1))


# Inference only: no_grad avoids building autograd graphs for every forward pass.
with torch.no_grad():
    base_data, labels = insurance_dataset.data, insurance_dataset.target
    # A sketch depends only on its column and the labels, so compute each once
    # instead of once per candidate pair.
    sketches = {idx: get_tensor_sketch(N_SKETCH_BINS, base_data[:, idx], labels).to(device)
                for idx in insurance_dataset.cont_index}

    for idx1, idx2 in binary_candidates:
        qsa = torch.cat((sketches[idx1], sketches[idx2]), 0)
        f1, f2 = base_data[:, idx1], base_data[:, idx2]
        for name, func in zip(Binaries.name, Binaries.func):
            prob = probability(model_groups[name](qsa))
            if prob > threshold:
                feature_name = "{} {} {}".format(insurance_dataset.features[idx1], name,
                                                 insurance_dataset.features[idx2])
                _accept_feature(feature_name, prob, func(f1, f2))

    for (idx,) in unary_candidates:
        f = base_data[:, idx]
        for name, func in zip(Unaries.name, Unaries.func):
            prob = probability(model_groups[name](sketches[idx]))
            if prob > threshold:
                _accept_feature("{} {}".format(name, insurance_dataset.features[idx]), prob, func(f))

# Append all accepted features at once; the candidates only read original
# columns, so deferring the concatenation does not change any result.
if new_feature_columns:
    insurance_dataset.features.extend(new_feature_names)
    insurance_dataset.data = torch.cat([insurance_dataset.data] + new_feature_columns, 1)

classifier, score = _cross_validate(insurance_dataset, "newscore")

binary_candidates: 600
unary_candidates: 25
later feature num: torch.Size([1510, 73])
              precision    recall  f1-score   support

           0       0.51      0.61      0.55       300
           1       0.90      0.85      0.88      1210

    accuracy                           0.80      1510
   macro avg       0.70      0.73      0.71      1510
weighted avg       0.82      0.80      0.81      1510

benchscore: 0.8046357615894039
beizhixing_jine multiply zongzichan_baochoulv
tensor(0.8804, device='cuda:0', grad_fn=<MaxBackward0>)
------------------------------
shixin_beizhixing_xinxishuliang multiply zongzichan_baochoulv
tensor(0.8806, device='cuda:0', grad_fn=<MaxBackward0>)
------------------------------
yingshouzhangkuan/yingyeshouru multiply zhuce_ziben
tensor(0.9062, device='cuda:0', grad_fn=<MaxBackward0>)
------------------------------
zichanzongji_tongbi multiply zongzichan_baochoulv
tensor(0.8802, device='cuda:0', grad_fn=<MaxBackward0>)
-----------------------------