In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

##### 处理数据集

In [2]:
data = pd.read_csv("dataset/HR.csv")

In [3]:
data = data.join(pd.get_dummies(data.salary))
del data['salary']
data = data.join(pd.get_dummies(data.part))
del data['part']

In [4]:
X_data = data[[c for c in data.columns if c != 'left']].values
X      = torch.from_numpy(X_data).type(torch.FloatTensor)

In [5]:
Y_data = data[[c for c in data.columns if c == 'left']].values
Y      = torch.from_numpy(Y_data).type(torch.FloatTensor)

##### 创建模型

In [6]:
from torch import nn

In [7]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(20,64)
        self.linear_2 = nn.Linear(64,64)
        self.linear_3 = nn.Linear(64,1)
        self.relu     = nn.ReLU()
        self.sigmoid  = nn.Sigmoid()
    
    def forward(self,input):
        x = self.linear_1(input)
        x = self.relu(x)
        x = self.linear_2(x)
        x = self.relu(x)
        x = self.linear_3(x)
        x = self.sigmoid(x)
        return x

##### 训练模型

In [8]:
def get_model():
    model = Model()
    opt   = torch.optim.Adam(model.parameters(),lr=0.0001)
    return model,opt

In [9]:
model,optim        = get_model()
loss_fn            = nn.BCELoss()

batch_size         = 64
numnber_of_batches = len(data)//batch_size
epoches            = 100

In [10]:
for epoch in range(epoches):
    for i in range(numnber_of_batches):
        start = i*batch_size
        end   = i*batch_size + batch_size
        x     = X[start:end]
        y     = Y[start:end]
        y_pred= model(x)
        loss  = loss_fn(y_pred,y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    with torch.no_grad():
        loss_fn(model(X),Y)
        print("epoch:\t",epoch,"\t\t\t","loss:\t",loss_fn(model(X),Y).data.item())

epoch:	 0 			 loss:	 0.6028925776481628
epoch:	 1 			 loss:	 0.6054073572158813
epoch:	 2 			 loss:	 0.6053191423416138
epoch:	 3 			 loss:	 0.6011462807655334
epoch:	 4 			 loss:	 0.5955274701118469
epoch:	 5 			 loss:	 0.5895217061042786
epoch:	 6 			 loss:	 0.5837771892547607
epoch:	 7 			 loss:	 0.5786375403404236
epoch:	 8 			 loss:	 0.5741863250732422
epoch:	 9 			 loss:	 0.5705020427703857
epoch:	 10 			 loss:	 0.5675674676895142
epoch:	 11 			 loss:	 0.5652454495429993
epoch:	 12 			 loss:	 0.5636212825775146
epoch:	 13 			 loss:	 0.5628132820129395
epoch:	 14 			 loss:	 0.5624322295188904
epoch:	 15 			 loss:	 0.5627555251121521
epoch:	 16 			 loss:	 0.5668768286705017
epoch:	 17 			 loss:	 0.5667158961296082
epoch:	 18 			 loss:	 0.5668954849243164
epoch:	 19 			 loss:	 0.5671303272247314
epoch:	 20 			 loss:	 0.5672408938407898
epoch:	 21 			 loss:	 0.5674843788146973
epoch:	 22 			 loss:	 0.5675179362297058
epoch:	 23 			 loss:	 0.5674824118614197
epoch:	 24 			 loss:	 0.56