batch normalization

In [19]:
import torch
import torchvision

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1)
if device == 'cuda':
    torch.cuda.manual_seed_all(1)

In [20]:
learning_rate = 0.01
training_epochs = 10
batch_size = 32

In [21]:
mnist_train = torchvision.datasets.MNIST(root = 'MNIST_data/', train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root = 'MNIST_data/', train=False, transform=torchvision.transforms.ToTensor(), download=True)


In [22]:
data_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

In [23]:
linear1 = torch.nn.Linear(784, 32, bias=True)
linear2 = torch.nn.Linear(32, 32, bias=True)
linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)

nn_linear1 = torch.nn.Linear(784, 32, bias=True)
nn_linear2 = torch.nn.Linear(32, 32, bias=True)
nn_linear3 = torch.nn.Linear(32, 10, bias=True)

bn_model = torch.nn.Sequential(linear1, bn1, relu,
                            linear2, bn2, relu,
                            linear3).to(device)
nn_model = torch.nn.Sequential(nn_linear1, relu,
                               nn_linear2, relu,
                               nn_linear3).to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)
bn_optimizer = torch.optim.Adam(bn_model.parameters(), lr=learning_rate)
nn_optimizer = torch.optim.Adam(nn_model.parameters(), lr=learning_rate)


In [24]:
total_batch = len(data_loader)
for epoch in range(training_epochs):
    bn_model.train()
    bn_avg_cost = 0
    nn_avg_cost = 0

    for X, Y in data_loader:
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)

        bn_optimizer.zero_grad()
        bn_prediction = bn_model(X)
        bn_loss = criterion(bn_prediction, Y)
        bn_loss.backward()
        bn_optimizer.step()

        nn_optimizer.zero_grad()
        nn_prediction = nn_model(X)
        nn_loss = criterion(nn_prediction, Y)
        nn_loss.backward()
        nn_optimizer.step()

        bn_avg_cost += bn_loss.item() / total_batch
        nn_avg_cost += nn_loss.item() / total_batch
    
    print('Epoch:', '%04d' % (epoch + 1), 'cost = bn : {:.9f}, nn : {:.9f}'.format(bn_avg_cost, nn_avg_cost))
    
    with torch.no_grad():
        bn_model.eval()
        X_test = mnist_test.data.view(-1, 28 * 28).float().to(device)
        Y_test = mnist_test.targets.to(device)
        bn_prediction = bn_model(X_test)
        bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y_test
        nn_prediction = nn_model(X_test)
        nn_correct_prediction = torch.argmax(nn_prediction, 1) == Y_test

        bn_accuracy = bn_correct_prediction.float().mean()
        print('eval bn Accuracy :', bn_accuracy.item())
        nn_accuracy = nn_correct_prediction.float().mean()
        print('eval nn Accuracy :', nn_accuracy.item())

        X_train = mnist_train.data.view(-1, 28 * 28).float().to(device)
        Y_train = mnist_train.targets.to(device)
        bn_prediction = bn_model(X_train)
        bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y_train
        nn_prediction = nn_model(X_train)
        nn_correct_prediction = torch.argmax(nn_prediction, 1) == Y_train
        bn_accuracy = bn_correct_prediction.float().mean()
        print('train bn Accuracy :', bn_accuracy.item())
        nn_accuracy = nn_correct_prediction.float().mean()
        print('train nn Accuracy :', nn_accuracy.item())
    
print('Learning finished')

Epoch: 0001 cost = bn : 0.302100966, nn : 0.298617792
eval bn Accuracy : 0.9352999925613403
eval nn Accuracy : 0.9390999674797058
train bn Accuracy : 0.9410666823387146
train nn Accuracy : 0.9447333216667175
Epoch: 0002 cost = bn : 0.211172594, nn : 0.196754268
eval bn Accuracy : 0.9384999871253967
eval nn Accuracy : 0.9213999509811401
train bn Accuracy : 0.9409166574478149
train nn Accuracy : 0.9280499815940857
Epoch: 0003 cost = bn : 0.186399662, nn : 0.177613062
eval bn Accuracy : 0.942799985408783
eval nn Accuracy : 0.9473999738693237
train bn Accuracy : 0.9499666690826416
train nn Accuracy : 0.9551500082015991
Epoch: 0004 cost = bn : 0.178260933, nn : 0.171262613
eval bn Accuracy : 0.9434999823570251
eval nn Accuracy : 0.9458999633789062
train bn Accuracy : 0.9517666697502136
train nn Accuracy : 0.9498500227928162
Epoch: 0005 cost = bn : 0.172060434, nn : 0.155702868
eval bn Accuracy : 0.9407999515533447
eval nn Accuracy : 0.9497999548912048
train bn Accuracy : 0.9499666690826416
