In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from sklearn import datasets
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [2]:
def NNtrain(X, y, model, task, lossf, optimizer, epochs, train_size, device = 'CPU'):
    """
    X, 特征向量矩阵nxd
    y, 标签
    model, 神经网络模型
    task, 回归'R'或者分类'C'
    lossf, 损失函数回归nn.MSELoss()或者分类nn.CrossEntropyLoss()
    optimizer, 优化器'Adam' 'RMSprop' 'Adadelta' 'Adagrad' 'SGD'
    epochs, 迭代次数
    train_size, 训练集大小
    device = 'CPU', 使用CPU还是GPU
    """
    model.train()
    
    x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=1, train_size=train_size)
    
    if task == 'C':
        x_train_tr = torch.from_numpy(x_train).float()
        y_train_tr = torch.from_numpy(y_train).long()
        x_test_tr = torch.from_numpy(x_test).float()
        y_test_tr = torch.from_numpy(y_test).long()
    if task == 'R':
        x_train_tr = torch.from_numpy(x_train).float()
        y_train_tr = torch.from_numpy(y_train).float()
        x_test_tr = torch.from_numpy(x_test).float()
        y_test_tr = torch.from_numpy(y_test).float()
        
    if device =='GPU':
        model_used = model.cuda()
        x_train_tr_used = x_train_tr.cuda()
        y_train_tr_used = y_train_tr.cuda()
        x_test_tr_used = x_test_tr.cuda()
        y_test_tr_used = y_test_tr.cuda()
    if device == 'CPU':
        model_used = model
        x_train_tr_used = x_train_tr
        y_train_tr_used = y_train_tr
        x_test_tr_used = x_test_tr
        y_test_tr_used = y_test_tr
    
    optimizer = optimizer
    loss_func = lossf
    loss_val_past = 9999
    loss_past = 9999
        
    for epoch in range(epochs):
        predict = model_used(x_train_tr_used)
        loss = loss_func(predict, y_train_tr_used)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        out_val = model_used(x_test_tr_used)
        loss_val = loss_func(out_val, y_test_tr_used)
        print('Epoch: ', epoch+1, '| Train Accurancy: ', 1-float(loss), '| Validation Accurancy: ', 1-float(loss_val))

        if (float(loss_val) > float(loss_val_past) and epoch>99) or ((float(loss_past)-float(loss))<1e-7 and epoch>99):
            break
        else:
            loss_val_past = loss_val
            loss_past = loss

In [3]:
class Net(torch.nn.Module):  # 继承 torch 的 Module
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()     # 继承 __init__ 功能
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # 隐藏层线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)   # 输出层线性输出

    def forward(self, x):   # 这同时也是 Module 中的 forward 功能
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))      # 激励函数(隐藏层的线性值)
        x = self.predict(x)             # 输出值
        return x

net = Net(n_feature=4, n_hidden=9, n_output=4)

iris = datasets.load_iris()

NNtrain(
    X=iris.data, 
    y=iris.target, 
    model=net, 
    task='C', 
    lossf=nn.CrossEntropyLoss(), 
    optimizer=torch.optim.Adam(net.parameters(), lr=0.01), 
    epochs=30000, 
    train_size=0.9, 
    device = 'CPU')



Epoch:  1 | Train Accurancy:  -0.35689783096313477 | Validation Accurancy:  -0.27798378467559814
Epoch:  2 | Train Accurancy:  -0.2722439765930176 | Validation Accurancy:  -0.21257615089416504
Epoch:  3 | Train Accurancy:  -0.20994722843170166 | Validation Accurancy:  -0.1617361307144165
Epoch:  4 | Train Accurancy:  -0.15970396995544434 | Validation Accurancy:  -0.11875653266906738
Epoch:  5 | Train Accurancy:  -0.11512839794158936 | Validation Accurancy:  -0.0822974443435669
Epoch:  6 | Train Accurancy:  -0.07320058345794678 | Validation Accurancy:  -0.04879426956176758
Epoch:  7 | Train Accurancy:  -0.03544914722442627 | Validation Accurancy:  -0.017242074012756348
Epoch:  8 | Train Accurancy:  -0.001459956169128418 | Validation Accurancy:  0.012257039546966553
Epoch:  9 | Train Accurancy:  0.0308380126953125 | Validation Accurancy:  0.03996950387954712
Epoch:  10 | Train Accurancy:  0.062419772148132324 | Validation Accurancy:  0.06660968065261841
Epoch:  11 | Train Accurancy:  0.0

Epoch:  102 | Train Accurancy:  0.8525571078062057 | Validation Accurancy:  0.8593935072422028
Epoch:  103 | Train Accurancy:  0.8545165956020355 | Validation Accurancy:  0.8617157936096191
Epoch:  104 | Train Accurancy:  0.8564192354679108 | Validation Accurancy:  0.8639838248491287
Epoch:  105 | Train Accurancy:  0.8582667261362076 | Validation Accurancy:  0.8662068247795105
Epoch:  106 | Train Accurancy:  0.8600612580776215 | Validation Accurancy:  0.8683827221393585
Epoch:  107 | Train Accurancy:  0.8618043065071106 | Validation Accurancy:  0.8705040663480759
Epoch:  108 | Train Accurancy:  0.8634971976280212 | Validation Accurancy:  0.8725622445344925
Epoch:  109 | Train Accurancy:  0.865141972899437 | Validation Accurancy:  0.8745572119951248
Epoch:  110 | Train Accurancy:  0.8667402565479279 | Validation Accurancy:  0.8764985576272011
Epoch:  111 | Train Accurancy:  0.8682932555675507 | Validation Accurancy:  0.8783988505601883
Epoch:  112 | Train Accurancy:  0.8698026090860367 

Epoch:  207 | Train Accurancy:  0.9256387203931808 | Validation Accurancy:  0.9554774984717369
Epoch:  208 | Train Accurancy:  0.9258509576320648 | Validation Accurancy:  0.9558053240180016
Epoch:  209 | Train Accurancy:  0.9260601699352264 | Validation Accurancy:  0.956129115074873
Epoch:  210 | Train Accurancy:  0.9262664020061493 | Validation Accurancy:  0.9564486667513847
Epoch:  211 | Train Accurancy:  0.9264698699116707 | Validation Accurancy:  0.9567640163004398
Epoch:  212 | Train Accurancy:  0.9266705587506294 | Validation Accurancy:  0.9570759534835815
Epoch:  213 | Train Accurancy:  0.9268685728311539 | Validation Accurancy:  0.9573843330144882
Epoch:  214 | Train Accurancy:  0.927063837647438 | Validation Accurancy:  0.9576894901692867
Epoch:  215 | Train Accurancy:  0.9272564798593521 | Validation Accurancy:  0.9579911455512047
Epoch:  216 | Train Accurancy:  0.927446648478508 | Validation Accurancy:  0.958288948982954
Epoch:  217 | Train Accurancy:  0.9276341274380684 | V

Epoch:  318 | Train Accurancy:  0.9388853423297405 | Validation Accurancy:  0.9773823581635952
Epoch:  319 | Train Accurancy:  0.9389519579708576 | Validation Accurancy:  0.9775006286799908
Epoch:  320 | Train Accurancy:  0.939018078148365 | Validation Accurancy:  0.9776182491332293
Epoch:  321 | Train Accurancy:  0.9390836842358112 | Validation Accurancy:  0.9777349624782801
Epoch:  322 | Train Accurancy:  0.939148798584938 | Validation Accurancy:  0.9778508823364973
Epoch:  323 | Train Accurancy:  0.9392134472727776 | Validation Accurancy:  0.977965958416462
Epoch:  324 | Train Accurancy:  0.9392775222659111 | Validation Accurancy:  0.9780804309993982
Epoch:  325 | Train Accurancy:  0.9393412135541439 | Validation Accurancy:  0.9781939666718245
Epoch:  326 | Train Accurancy:  0.9394043534994125 | Validation Accurancy:  0.9783066119998693
Epoch:  327 | Train Accurancy:  0.9394670687615871 | Validation Accurancy:  0.9784187320619822
Epoch:  328 | Train Accurancy:  0.9395292922854424 | 

Epoch:  418 | Train Accurancy:  0.9437599405646324 | Validation Accurancy:  0.9861217821016908
Epoch:  419 | Train Accurancy:  0.9437959603965282 | Validation Accurancy:  0.9861854715272784
Epoch:  420 | Train Accurancy:  0.9438317976891994 | Validation Accurancy:  0.9862488266080618
Epoch:  421 | Train Accurancy:  0.9438674710690975 | Validation Accurancy:  0.9863117691129446
Epoch:  422 | Train Accurancy:  0.9439030028879642 | Validation Accurancy:  0.9863741397857666
Epoch:  423 | Train Accurancy:  0.9439382739365101 | Validation Accurancy:  0.9864363670349121
Epoch:  424 | Train Accurancy:  0.9439734555780888 | Validation Accurancy:  0.9864984191954136
Epoch:  425 | Train Accurancy:  0.9440084062516689 | Validation Accurancy:  0.9865599470213056
Epoch:  426 | Train Accurancy:  0.9440432600677013 | Validation Accurancy:  0.9866211572661996
Epoch:  427 | Train Accurancy:  0.9440779574215412 | Validation Accurancy:  0.9866817947477102
Epoch:  428 | Train Accurancy:  0.944112416356802 

Epoch:  575 | Train Accurancy:  0.9479745142161846 | Validation Accurancy:  0.9929293156601489
Epoch:  576 | Train Accurancy:  0.9479951076209545 | Validation Accurancy:  0.9929577829316258
Epoch:  577 | Train Accurancy:  0.9480156525969505 | Validation Accurancy:  0.9929856937378645
Epoch:  578 | Train Accurancy:  0.9480360299348831 | Validation Accurancy:  0.9930136362090707
Epoch:  579 | Train Accurancy:  0.9480564258992672 | Validation Accurancy:  0.9930414040572941
Epoch:  580 | Train Accurancy:  0.9480768218636513 | Validation Accurancy:  0.9930691877380013
Epoch:  581 | Train Accurancy:  0.9480970725417137 | Validation Accurancy:  0.993096764665097
Epoch:  582 | Train Accurancy:  0.9481172822415829 | Validation Accurancy:  0.9931241669692099
Epoch:  583 | Train Accurancy:  0.9481375031173229 | Validation Accurancy:  0.993151553440839
Epoch:  584 | Train Accurancy:  0.9481576345860958 | Validation Accurancy:  0.9931785743683577
Epoch:  585 | Train Accurancy:  0.9481776915490627 |

Epoch:  686 | Train Accurancy:  0.9499750547111034 | Validation Accurancy:  0.9953545569442213
Epoch:  687 | Train Accurancy:  0.9499908909201622 | Validation Accurancy:  0.9953709444962442
Epoch:  688 | Train Accurancy:  0.9500066414475441 | Validation Accurancy:  0.9953873474150896
Epoch:  689 | Train Accurancy:  0.9500224366784096 | Validation Accurancy:  0.9954034807160497
Epoch:  690 | Train Accurancy:  0.9500381350517273 | Validation Accurancy:  0.9954198994673789
Epoch:  691 | Train Accurancy:  0.9500537887215614 | Validation Accurancy:  0.9954360006377101
Epoch:  692 | Train Accurancy:  0.9500693753361702 | Validation Accurancy:  0.9954522610642016
Epoch:  693 | Train Accurancy:  0.9500849768519402 | Validation Accurancy:  0.9954679650254548
Epoch:  694 | Train Accurancy:  0.9501006416976452 | Validation Accurancy:  0.9954838277772069
Epoch:  695 | Train Accurancy:  0.9501161873340607 | Validation Accurancy:  0.9954996746964753
Epoch:  696 | Train Accurancy:  0.9501315727829933

Epoch:  794 | Train Accurancy:  0.9514995999634266 | Validation Accurancy:  0.9967456341255456
Epoch:  795 | Train Accurancy:  0.9515120647847652 | Validation Accurancy:  0.9967554728500545
Epoch:  796 | Train Accurancy:  0.9515246264636517 | Validation Accurancy:  0.9967652796767652
Epoch:  797 | Train Accurancy:  0.951537013053894 | Validation Accurancy:  0.9967751184012741
Epoch:  798 | Train Accurancy:  0.9515494108200073 | Validation Accurancy:  0.9967848618980497
Epoch:  799 | Train Accurancy:  0.9515618719160557 | Validation Accurancy:  0.9967945734970272
Epoch:  800 | Train Accurancy:  0.9515743106603622 | Validation Accurancy:  0.9968043963890523
Epoch:  801 | Train Accurancy:  0.9515866376459599 | Validation Accurancy:  0.9968140283599496
Epoch:  802 | Train Accurancy:  0.9515989758074284 | Validation Accurancy:  0.9968235492706299
Epoch:  803 | Train Accurancy:  0.9516112431883812 | Validation Accurancy:  0.996833038283512
Epoch:  804 | Train Accurancy:  0.9516234584152699 |

Epoch:  907 | Train Accurancy:  0.9527492858469486 | Validation Accurancy:  0.9976161003578454
Epoch:  908 | Train Accurancy:  0.9527588859200478 | Validation Accurancy:  0.9976223628036678
Epoch:  909 | Train Accurancy:  0.952768474817276 | Validation Accurancy:  0.9976281800772995
Epoch:  910 | Train Accurancy:  0.9527781121432781 | Validation Accurancy:  0.9976342201698571
Epoch:  911 | Train Accurancy:  0.9527876749634743 | Validation Accurancy:  0.997640069341287
Epoch:  912 | Train Accurancy:  0.9527972005307674 | Validation Accurancy:  0.9976459504105151
Epoch:  913 | Train Accurancy:  0.9528067335486412 | Validation Accurancy:  0.997651704121381
Epoch:  914 | Train Accurancy:  0.9528163149952888 | Validation Accurancy:  0.9976576804183424
Epoch:  915 | Train Accurancy:  0.9528257623314857 | Validation Accurancy:  0.997663434362039
Epoch:  916 | Train Accurancy:  0.9528352469205856 | Validation Accurancy:  0.997669315431267
Epoch:  917 | Train Accurancy:  0.9528447203338146 | Va

Epoch:  1016 | Train Accurancy:  0.9536688402295113 | Validation Accurancy:  0.998136043548584
Epoch:  1017 | Train Accurancy:  0.9536761306226254 | Validation Accurancy:  0.998139985371381
Epoch:  1018 | Train Accurancy:  0.9536834359169006 | Validation Accurancy:  0.9981435775989667
Epoch:  1019 | Train Accurancy:  0.9536906033754349 | Validation Accurancy:  0.9981474876403809
Epoch:  1020 | Train Accurancy:  0.9536978676915169 | Validation Accurancy:  0.9981512069934979
Epoch:  1021 | Train Accurancy:  0.9537049829959869 | Validation Accurancy:  0.9981551170349121
Epoch:  1022 | Train Accurancy:  0.9537121914327145 | Validation Accurancy:  0.9981586455833167
Epoch:  1023 | Train Accurancy:  0.9537193216383457 | Validation Accurancy:  0.9981623649364337
Epoch:  1024 | Train Accurancy:  0.9537264369428158 | Validation Accurancy:  0.9981661160709336
Epoch:  1025 | Train Accurancy:  0.9537335149943829 | Validation Accurancy:  0.9981696765171364
Epoch:  1026 | Train Accurancy:  0.9537406

Epoch:  1174 | Train Accurancy:  0.9545936733484268 | Validation Accurancy:  0.9986010233405977
Epoch:  1175 | Train Accurancy:  0.9545982182025909 | Validation Accurancy:  0.9986035346519202
Epoch:  1176 | Train Accurancy:  0.954602800309658 | Validation Accurancy:  0.9986056328052655
Epoch:  1177 | Train Accurancy:  0.9546073526144028 | Validation Accurancy:  0.998607826186344
Epoch:  1178 | Train Accurancy:  0.9546118192374706 | Validation Accurancy:  0.998610210372135
Epoch:  1179 | Train Accurancy:  0.9546163901686668 | Validation Accurancy:  0.9986123403068632
Epoch:  1180 | Train Accurancy:  0.954620786011219 | Validation Accurancy:  0.9986147244926542
Epoch:  1181 | Train Accurancy:  0.9546253494918346 | Validation Accurancy:  0.9986170451156795
Epoch:  1182 | Train Accurancy:  0.9546298198401928 | Validation Accurancy:  0.9986191431526095
Epoch:  1183 | Train Accurancy:  0.9546343050897121 | Validation Accurancy:  0.9986213366501033
Epoch:  1184 | Train Accurancy:  0.954638719

Epoch:  1287 | Train Accurancy:  0.9550282582640648 | Validation Accurancy:  0.9988246917491779
Epoch:  1288 | Train Accurancy:  0.9550314135849476 | Validation Accurancy:  0.9988264719722793
Epoch:  1289 | Train Accurancy:  0.9550345875322819 | Validation Accurancy:  0.9988280613906682
Epoch:  1290 | Train Accurancy:  0.9550378322601318 | Validation Accurancy:  0.9988297780510038
Epoch:  1291 | Train Accurancy:  0.9550410024821758 | Validation Accurancy:  0.9988315582741052
Epoch:  1292 | Train Accurancy:  0.9550440832972527 | Validation Accurancy:  0.9988333065994084
Epoch:  1293 | Train Accurancy:  0.9550472423434258 | Validation Accurancy:  0.9988349279155955
Epoch:  1294 | Train Accurancy:  0.9550503678619862 | Validation Accurancy:  0.9988367717014626
Epoch:  1295 | Train Accurancy:  0.9550535306334496 | Validation Accurancy:  0.9988385835895315
Epoch:  1296 | Train Accurancy:  0.9550565779209137 | Validation Accurancy:  0.9988401413429528
Epoch:  1297 | Train Accurancy:  0.95505

Epoch:  1396 | Train Accurancy:  0.9553194455802441 | Validation Accurancy:  0.9989927610149607
Epoch:  1397 | Train Accurancy:  0.9553216099739075 | Validation Accurancy:  0.9989941279636696
Epoch:  1398 | Train Accurancy:  0.9553238116204739 | Validation Accurancy:  0.9989954630145803
Epoch:  1399 | Train Accurancy:  0.9553260281682014 | Validation Accurancy:  0.9989969253074378
Epoch:  1400 | Train Accurancy:  0.9553281851112843 | Validation Accurancy:  0.9989982922561467
Epoch:  1401 | Train Accurancy:  0.955330353230238 | Validation Accurancy:  0.9989995320793241
Epoch:  1402 | Train Accurancy:  0.9553325287997723 | Validation Accurancy:  0.999000899028033
Epoch:  1403 | Train Accurancy:  0.9553346857428551 | Validation Accurancy:  0.9990023295395076
Epoch:  1404 | Train Accurancy:  0.9553368464112282 | Validation Accurancy:  0.9990035692462698
Epoch:  1405 | Train Accurancy:  0.9553389921784401 | Validation Accurancy:  0.9990049997577444
Epoch:  1406 | Train Accurancy:  0.9553412

Epoch:  1509 | Train Accurancy:  0.9555259272456169 | Validation Accurancy:  0.9991319338441826
Epoch:  1510 | Train Accurancy:  0.9555274173617363 | Validation Accurancy:  0.9991329192998819
Epoch:  1511 | Train Accurancy:  0.9555289261043072 | Validation Accurancy:  0.999134095502086
Epoch:  1512 | Train Accurancy:  0.9555304199457169 | Validation Accurancy:  0.9991350809577852
Epoch:  1513 | Train Accurancy:  0.9555318839848042 | Validation Accurancy:  0.9991361935972236
Epoch:  1514 | Train Accurancy:  0.9555333107709885 | Validation Accurancy:  0.9991374651435763
Epoch:  1515 | Train Accurancy:  0.9555347971618176 | Validation Accurancy:  0.9991386095643975
Epoch:  1516 | Train Accurancy:  0.9555362053215504 | Validation Accurancy:  0.9991396268014796
Epoch:  1517 | Train Accurancy:  0.9555376544594765 | Validation Accurancy:  0.9991407712223008
Epoch:  1518 | Train Accurancy:  0.9555390700697899 | Validation Accurancy:  0.9991418838617392
Epoch:  1519 | Train Accurancy:  0.955540