In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from sklearn import datasets
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [3]:
def NNtrain(X, y, model, task, lossf, optimizer, epochs, train_size, device = 'CPU'):
    """
    X, 特征向量矩阵nxd
    y, 标签
    model, 神经网络模型
    task, 回归'R'或者分类'C'
    lossf, 损失函数回归nn.MSELoss()或者分类nn.CrossEntropyLoss()
    optimizer, 优化器'Adam' 'RMSprop' 'Adadelta' 'Adagrad' 'SGD'
    epochs, 迭代次数
    train_size, 训练集大小
    device = 'CPU', 使用CPU还是GPU
    """
    model.train()
    
    x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=1, train_size=train_size)
    
    if task == 'C':
        x_train_tr = torch.from_numpy(x_train).float()
        y_train_tr = torch.from_numpy(y_train).long()
        x_test_tr = torch.from_numpy(x_test).float()
        y_test_tr = torch.from_numpy(y_test).long()
    if task == 'R':
        x_train_tr = torch.from_numpy(x_train).float()
        y_train_tr = torch.from_numpy(y_train).float()
        x_test_tr = torch.from_numpy(x_test).float()
        y_test_tr = torch.from_numpy(y_test).float()
        
    if device =='GPU':
        model_used = model.cuda()
        x_train_tr_used = x_train_tr.cuda()
        y_train_tr_used = y_train_tr.cuda()
        x_test_tr_used = x_test_tr.cuda()
        y_test_tr_used = y_test_tr.cuda()
    if device == 'CPU':
        model_used = model
        x_train_tr_used = x_train_tr
        y_train_tr_used = y_train_tr
        x_test_tr_used = x_test_tr
        y_test_tr_used = y_test_tr
    
    optimizer = optimizer
    loss_func = lossf
    loss_val_past = 9999
    loss_past = 9999
        
    for epoch in range(epochs):
        predict = model_used(x_train_tr_used)
        loss = loss_func(predict, y_train_tr_used)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        out_val = model_used(x_test_tr_used)
        loss_val = loss_func(out_val, y_test_tr_used)
        
        pred_y = torch.max(predict, 1)[1].data.squeeze()
        accuracy = sum(pred_y.numpy()==y_train_tr_used.numpy())/len(y_train_tr_used.numpy())
        
        pred_y_val = torch.max(out_val, 1)[1].data.squeeze()
        accuracy_val = sum(pred_y_val.numpy()==y_test_tr_used.numpy())/len(y_test_tr_used.numpy())
        
        print('Epoch: ', epoch+1, '| Train Accurancy: ', accuracy, '| Validation Accurancy: ', accuracy_val)

        if (float(loss_val) > float(loss_val_past) and epoch>99) or ((float(loss_past)-float(loss))<1e-7 and epoch>99):
            break
        else:
            loss_val_past = loss_val
            loss_past = loss

In [4]:
class Net(torch.nn.Module):  # 继承 torch 的 Module
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()     # 继承 __init__ 功能
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # 隐藏层线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)   # 输出层线性输出

    def forward(self, x):   # 这同时也是 Module 中的 forward 功能
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))      # 激励函数(隐藏层的线性值)
        x = self.predict(x)             # 输出值
        return x

net = Net(n_feature=4, n_hidden=9, n_output=4)

iris = datasets.load_iris()

NNtrain(
    X=iris.data, 
    y=iris.target, 
    model=net, 
    task='C', 
    lossf=nn.CrossEntropyLoss(), 
    optimizer=torch.optim.Adam(net.parameters(), lr=0.01), 
    epochs=30000, 
    train_size=0.9, 
    device = 'CPU')



Epoch:  1 | Train Accurancy:  0.32592592592592595 | Validation Accurancy:  0.4
Epoch:  2 | Train Accurancy:  0.32592592592592595 | Validation Accurancy:  0.4
Epoch:  3 | Train Accurancy:  0.32592592592592595 | Validation Accurancy:  0.4
Epoch:  4 | Train Accurancy:  0.32592592592592595 | Validation Accurancy:  0.4
Epoch:  5 | Train Accurancy:  0.4074074074074074 | Validation Accurancy:  0.26666666666666666
Epoch:  6 | Train Accurancy:  0.34074074074074073 | Validation Accurancy:  0.3333333333333333
Epoch:  7 | Train Accurancy:  0.45925925925925926 | Validation Accurancy:  0.4666666666666667
Epoch:  8 | Train Accurancy:  0.6074074074074074 | Validation Accurancy:  0.6
Epoch:  9 | Train Accurancy:  0.6666666666666666 | Validation Accurancy:  0.6
Epoch:  10 | Train Accurancy:  0.674074074074074 | Validation Accurancy:  0.6
Epoch:  11 | Train Accurancy:  0.674074074074074 | Validation Accurancy:  0.6
Epoch:  12 | Train Accurancy:  0.674074074074074 | Validation Accurancy:  0.6
Epoch:  13 |

Epoch:  172 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  173 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  174 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  175 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  176 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  177 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  178 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  179 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  180 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  181 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  182 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  183 | Train Accurancy:  0.9703703703703703 | Validation Accurancy:  1.0
Epoch:  184 | Train Accurancy:  0.970370

Epoch:  362 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  363 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  364 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  365 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  366 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  367 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  368 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  369 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  370 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  371 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  372 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  373 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  374 | Train Accurancy:  0.977777

Epoch:  557 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  558 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  559 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  560 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  561 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  562 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  563 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  564 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  565 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  566 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  567 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  568 | Train Accurancy:  0.9777777777777777 | Validation Accurancy:  1.0
Epoch:  569 | Train Accurancy:  0.977777

Epoch:  752 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  753 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  754 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  755 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  756 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  757 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  758 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  759 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  760 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  761 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  762 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  763 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  764 | Train Accurancy:  0.985185

Epoch:  936 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  937 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  938 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  939 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  940 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  941 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  942 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  943 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  944 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  945 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  946 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  947 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  948 | Train Accurancy:  0.985185

Epoch:  1125 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1126 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1127 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1128 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1129 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1130 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1131 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1132 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1133 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1134 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1135 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1136 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1137 | Train Accuran

Epoch:  1318 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1319 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1320 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1321 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1322 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1323 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1324 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1325 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1326 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1327 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1328 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1329 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1330 | Train Accuran

Epoch:  1512 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1513 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1514 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1515 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1516 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1517 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1518 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1519 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1520 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1521 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1522 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1523 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1524 | Train Accuran

Epoch:  1701 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1702 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1703 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1704 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1705 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1706 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1707 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1708 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1709 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1710 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1711 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1712 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1713 | Train Accuran

Epoch:  1883 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1884 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1885 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1886 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1887 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1888 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1889 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1890 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1891 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1892 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1893 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1894 | Train Accurancy:  0.9851851851851852 | Validation Accurancy:  1.0
Epoch:  1895 | Train Accuran