In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

##### 处理数据集

In [2]:
# Read the HR table from the CSV file into a DataFrame.
data = pd.read_csv(filepath_or_buffer="./dataset/HR.csv")

In [3]:
# One-hot encode the `salary` and `part` columns: for each one, append the
# indicator columns produced by get_dummies and then remove the source
# column. The columns are processed in this order so the resulting column
# layout is salary indicators first, part indicators last.
for categorical in ('salary', 'part'):
    one_hot = pd.get_dummies(data[categorical])
    data = data.join(one_hot)
    del data[categorical]

In [4]:
# Feature matrix: every column except the `left` label.
# Converting with an explicit float32 dtype matters: on pandas >= 2.0
# get_dummies yields bool columns, a frame mixing bool and numeric columns
# exposes an object-dtype array through `.values`, and torch.from_numpy
# rejects object arrays. to_numpy(dtype=...) casts every column up front.
X_data = data.drop(columns='left').to_numpy(dtype=np.float32)
X      = torch.from_numpy(X_data)  # float32 array -> torch.FloatTensor

# Label column kept 2-D (one column) so it lines up with the model's
# single-unit output. Selecting `left` by name raises KeyError if the
# column is missing instead of silently producing an empty target.
Y_data = data[['left']].to_numpy(dtype=np.float32)
Y      = torch.from_numpy(Y_data)

##### 创建模型

In [5]:
from torch import nn
import torch.nn.functional as F

In [6]:
class Model(nn.Module):
    """Fully connected binary classifier with two ReLU hidden layers.

    The network is Linear -> ReLU -> Linear -> ReLU -> Linear -> sigmoid,
    so each sample is mapped to a single value in [0, 1].

    Args:
        in_features: Number of input values per sample. Defaults to 20,
            the width that was previously hard-coded.
        hidden: Width of both hidden layers. Defaults to 64.
    """

    def __init__(self, in_features=20, hidden=64):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden)
        self.linear_2 = nn.Linear(hidden, hidden)
        self.linear_3 = nn.Linear(hidden, 1)

    def forward(self, input):
        """Run the forward pass.

        Args:
            input: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of 1, holding sigmoid-squashed outputs.
        """
        x = F.relu(self.linear_1(input))
        x = F.relu(self.linear_2(x))
        # torch.sigmoid is used rather than F.sigmoid, which emits a
        # deprecation warning on a range of PyTorch releases.
        return torch.sigmoid(self.linear_3(x))

##### 训练模型

In [7]:
def get_model(lr=0.0001):
    """Create a new ``Model`` and an Adam optimizer over its parameters.

    Args:
        lr: Learning rate passed to ``torch.optim.Adam``. Defaults to
            0.0001, the value that was previously hard-coded.

    Returns:
        tuple: ``(model, opt)`` - the freshly initialised network and the
        optimizer bound to its parameters.
    """
    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    return model, opt

In [8]:
# Network, optimizer and loss used by the training loop.
# BCELoss expects probabilities, which the model's sigmoid output provides.
model, optim = get_model()
loss_fn = nn.BCELoss()

batch_size = 64
# Ceiling division: floor division would leave the final
# len(data) % batch_size rows out of every epoch. Slicing past the end of a
# tensor simply returns the shorter tail, so the last batch may be smaller.
number_of_batches = (len(data) + batch_size - 1) // batch_size
epoches = 100

In [9]:
# Mini-batch training: `epoches` passes over X/Y in contiguous slices of
# `batch_size` rows, followed by a full-dataset loss report per epoch.
# NOTE(review): batches are taken in row order with no shuffling; if the CSV
# is sorted (e.g. by label) consider permuting the rows - confirm against
# the dataset.
for epoch in range(epoches):
    for i in range(number_of_batches):
        start = i * batch_size
        end = start + batch_size
        x = X[start:end]
        y = Y[start:end]
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()  # clear gradients accumulated by the previous step
        loss.backward()
        optim.step()
    # Evaluate once per epoch on the whole dataset without tracking
    # gradients; the result is reused for printing instead of running a
    # second identical forward pass.
    with torch.no_grad():
        epoch_loss = loss_fn(model(X), Y).item()
    print("epoch:\t", epoch, "\t\t\t", "loss:\t", epoch_loss)



epoch:	 0 			 loss:	 0.7075486779212952
epoch:	 1 			 loss:	 0.7288095355033875
epoch:	 2 			 loss:	 0.7137643694877625
epoch:	 3 			 loss:	 0.7012246251106262
epoch:	 4 			 loss:	 0.6933883428573608
epoch:	 5 			 loss:	 0.6841027736663818
epoch:	 6 			 loss:	 0.6689909100532532
epoch:	 7 			 loss:	 0.6601163744926453
epoch:	 8 			 loss:	 0.6510088443756104
epoch:	 9 			 loss:	 0.6418108344078064
epoch:	 10 			 loss:	 0.6326261162757874
epoch:	 11 			 loss:	 0.6209865212440491
epoch:	 12 			 loss:	 0.608540415763855
epoch:	 13 			 loss:	 0.6012986898422241
epoch:	 14 			 loss:	 0.594216525554657
epoch:	 15 			 loss:	 0.5816237926483154
epoch:	 16 			 loss:	 0.5761025547981262
epoch:	 17 			 loss:	 0.5701773166656494
epoch:	 18 			 loss:	 0.5663467645645142
epoch:	 19 			 loss:	 0.5637797117233276
epoch:	 20 			 loss:	 0.5613133907318115
epoch:	 21 			 loss:	 0.5603255033493042
epoch:	 22 			 loss:	 0.5841894745826721
epoch:	 23 			 loss:	 0.5597661137580872
epoch:	 24 			 loss:	 0.5598