In [1]:
import torch 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

##### 处理数据集

In [2]:
data = pd.read_csv("dataset/HR.csv")

In [3]:
data = data.join(pd.get_dummies(data.salary))
del  data['salary']
data = data.join(pd.get_dummies(data.part))
del  data['part']

In [4]:
X_data = data[[c for c in data.columns if c != 'left']].values
X      = torch.from_numpy(X_data).type(torch.FloatTensor)

In [5]:
Y_data = data[[c for c in data.columns if c == 'left']].values
Y      = torch.from_numpy(Y_data).type(torch.FloatTensor)

##### 创建模型

In [6]:
from torch import nn

In [7]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(20,64)
        self.linear_2 = nn.Linear(64,64)
        self.linear_3 = nn.Linear(64,1)
        self.relu     = nn.ReLU()
        self.sigmoid  = nn.Sigmoid()
    
    def forward(self,input):
        x = self.linear_1(input)
        x = self.relu(x)
        x = self.linear_2(x)
        x = self.relu(x)
        x = self.linear_3(x)
        x = self.sigmoid(x)
        return x

##### 训练模型

In [8]:
def get_model():
    model = Model()
    opt   = torch.optim.Adam(model.parameters(),lr=0.0001)
    return model,opt

In [9]:
model,optim      = get_model()
loss_fn          = nn.BCELoss()

batch_size       = 64
number_of_batches= len(data)//batch_size
epochs           = 100

In [10]:
from torch.utils.data import TensorDataset

In [11]:
HRdataset = TensorDataset(X,Y)

In [12]:
for epoch in range(epochs):
    for i in range(number_of_batches):
        x,y    = HRdataset[i*batch_size:i*batch_size+batch_size]
        y_pred = model(x)
        loss   = loss_fn(y_pred,y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    with torch.no_grad():
        print("epoch:",epoch,"\t\t\t","loss:",loss_fn(model(X),Y).data.item())

epoch: 0 			 loss: 0.7335248589515686
epoch: 1 			 loss: 0.7064924240112305
epoch: 2 			 loss: 0.6973234415054321
epoch: 3 			 loss: 0.6871408820152283
epoch: 4 			 loss: 0.6662421226501465
epoch: 5 			 loss: 0.6511802077293396
epoch: 6 			 loss: 0.6426291465759277
epoch: 7 			 loss: 0.6867966055870056
epoch: 8 			 loss: 0.6579023003578186
epoch: 9 			 loss: 0.6410834193229675
epoch: 10 			 loss: 0.6280080080032349
epoch: 11 			 loss: 0.6102624535560608
epoch: 12 			 loss: 0.6004673838615417
epoch: 13 			 loss: 0.592176616191864
epoch: 14 			 loss: 0.5910249948501587
epoch: 15 			 loss: 0.5791337490081787
epoch: 16 			 loss: 0.5749177932739258
epoch: 17 			 loss: 0.5698032975196838
epoch: 18 			 loss: 0.566704273223877
epoch: 19 			 loss: 0.5640634894371033
epoch: 20 			 loss: 0.5621867179870605
epoch: 21 			 loss: 0.5609436631202698
epoch: 22 			 loss: 0.5601813197135925
epoch: 23 			 loss: 0.5597779154777527
epoch: 24 			 loss: 0.5640770792961121
epoch: 25 			 loss: 0.560239732265472