In [None]:
# Added train_acc / valid_acc tracking on top of the loss-only training loop.
def _count_correct(output, y):
    """Return how many argmax predictions in ``output`` match the labels ``y``."""
    pred = output.argmax(dim=1, keepdim=True)  # index of the max score per sample
    return pred.eq(y.view_as(pred)).sum().item()


def _train_one_epoch(model, device, loader, optimizer, loss_func):
    """Run one optimisation pass of ``model`` over ``loader``.

    Returns:
        (batch_losses, n_correct): the list of per-batch loss values and the
        number of correctly classified training samples.
    """
    model.train()  # prep model for training
    batch_losses = []
    n_correct = 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()          # clear gradients left from the previous step
        output = model(X)              # forward pass
        loss = loss_func(output, y)
        loss.backward()                # gradients w.r.t. the model parameters
        optimizer.step()               # parameter update
        batch_losses.append(loss.item())
        n_correct += _count_correct(output, y)
    return batch_losses, n_correct


def _evaluate(model, device, loader, loss_func):
    """Evaluate ``model`` on ``loader`` without updating it.

    Returns:
        (batch_losses, n_correct), in the same format as ``_train_one_epoch``.
    """
    model.eval()  # prep model for evaluation
    batch_losses = []
    n_correct = 0
    # Evaluation needs no autograd graph; building one only wastes memory/time.
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            output = model(X)
            batch_losses.append(loss_func(output, y).item())
            n_correct += _count_correct(output, y)
    return batch_losses, n_correct


def train_model(model, device, patience, n_epochs):
    """Train ``model`` with early stopping driven by validation accuracy.

    Uses the notebook globals ``train_loader``, ``valid_loader``, ``optimizer``
    and ``loss_func`` (defined in a later cell) and the ``EarlyStopping``
    helper defined elsewhere. The helper is expected to write the best model
    to ``checkpoint.pt`` (TODO confirm against its definition).

    Args:
        model: network to train; updated in place.
        device: torch device each batch is moved to.
        patience: forwarded to ``EarlyStopping`` (presumably the number of
            epochs without improvement tolerated before stopping).
        n_epochs: maximum number of epochs to run.

    Returns:
        (model, avg_train_losses, avg_valid_losses, train_accuracies,
        valid_accuracies): ``model`` with the checkpointed weights loaded, plus
        one entry per completed epoch in each list. Accuracies are percentages.
    """
    avg_train_losses = []   # mean training loss per epoch
    avg_valid_losses = []   # mean validation loss per epoch
    train_accuracies = []   # training accuracy (%) per epoch
    valid_accuracies = []   # validation accuracy (%) per epoch
    # Monitor validation accuracy; the helper checkpoints the model when it improves.
    early_stopping = EarlyStopping("val_acc", patience=patience, verbose=True, delta=0)

    n_train = len(train_loader.dataset)
    n_valid = len(valid_loader.dataset)
    epoch_len = len(str(n_epochs))  # width used to right-align the epoch counter

    for epoch in range(1, n_epochs + 1):
        train_losses, train_correct = _train_one_epoch(
            model, device, train_loader, optimizer, loss_func)
        valid_losses, val_correct = _evaluate(model, device, valid_loader, loss_func)

        # Epoch loss = mean of the per-batch losses (not weighted by batch size).
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # Training and validation accuracy for this epoch, in percent.
        train_acc = 100. * train_correct / n_train
        valid_acc = 100. * val_correct / n_valid
        train_accuracies.append(train_acc)
        valid_accuracies.append(valid_acc)

        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'train_accuracy: {train_correct}/{n_train} ({train_acc:.5f})% ' +
                     f'\n    valid_loss: {valid_loss:.5f} ' +
                     f'valid_accuracy: {val_correct}/{n_valid} ({valid_acc:.5f})%')
        print(print_msg)

        # EarlyStopping checks whether validation accuracy increased and, if so,
        # checkpoints the current model.
        early_stopping(valid_acc, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the best weights from the last checkpoint. map_location keeps them
    # on the training device regardless of where they were serialised.
    model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

    return model, avg_train_losses, avg_valid_losses, train_accuracies, valid_accuracies

In [None]:
# Optimiser hyper-parameters.
lr, momentum = 0.01, 0.5
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Build the network and move its parameters to the selected device (GPU/CPU).
model = CNN().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)
# The target labels are class indices, not one-hot encoded vectors.
loss_func = nn.CrossEntropyLoss()
# Wrap each data split with the notebook's `minibatch` helper
# (built in train -> eval -> test order, as before).
train_loader, valid_loader, test_loader = (
    minibatch(X_train, y_train),
    minibatch(X_eval, y_eval),
    minibatch(X_test, y_test),
)

In [None]:
# Training budget: `patience` is forwarded to EarlyStopping and `n_epochs`
# caps the number of epochs.
patience = 20
n_epochs = 50
# From here on the four names hold per-epoch lists (losses, accuracies in %).
model, train_loss, valid_loss, train_acc, valid_acc = train_model(
    model, device, patience=patience, n_epochs=n_epochs
)

In [None]:
# Visualize the per-epoch average losses recorded while the network trained.
fig = plt.figure(figsize=(10, 8))
plt.plot(range(1, len(train_loss) + 1), train_loss, label='Training Loss')
plt.plot(range(1, len(valid_loss) + 1), valid_loss, label='Validation Loss')

# Mark the epoch whose weights train_model restored. Early stopping monitors
# validation *accuracy* ("val_acc"), so the checkpoint is the best-accuracy
# epoch, not the lowest-validation-loss epoch. list.index returns the first
# maximum, which assumes EarlyStopping only checkpoints on strict improvement
# (TODO confirm). The variable keeps its historical name `minposs`.
minposs = valid_acc.index(max(valid_acc)) + 1
plt.axvline(minposs, linestyle='--', color='r', label='Early Stopping Checkpoint')

plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 0.5)  # fixed scale for comparable plots; losses above 0.5 fall outside the view
plt.xlim(0, len(train_loss) + 1)  # fixed scale
plt.grid(True)
plt.legend()
plt.tight_layout()
# Save before show(): avoids writing an empty image on backends that close
# the figure once it has been shown.
fig.savefig('loss_plot.png', bbox_inches='tight')
plt.show()

In [None]:
# Visualize the per-epoch accuracies (in %) recorded while the network trained.
fig = plt.figure(figsize=(10, 8))
plt.plot(range(1, len(train_acc) + 1), train_acc, label='Training Accuracy')
plt.plot(range(1, len(valid_acc) + 1), valid_acc, label='Validation Accuracy')

# Find the epoch with the highest validation accuracy (1-based, first maximum).
# Early stopping monitors validation accuracy, so this marks the checkpoint.
maxposs = valid_acc.index(max(valid_acc)) + 1
plt.axvline(maxposs, linestyle='--', color='r', label='Early Stopping Checkpoint')

plt.xlabel('epochs')
plt.ylabel('Accuracy')
plt.ylim(0, 100)  # accuracies are percentages
plt.xlim(0, len(train_acc) + 1)  # fixed scale
plt.grid(True)
plt.legend()
plt.tight_layout()
# Save before show(): avoids writing an empty image on backends that close
# the figure once it has been shown.
fig.savefig('Accuracy_plot.png', bbox_inches='tight')
plt.show()