batch normalization

In [99]:
import torch
import torchvision

# Seed the default generator so every run draws the same random numbers.
torch.manual_seed(1)

# Use the GPU when one is available (seeding every CUDA device as well);
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed_all(1)
else:
    device = 'cpu'

In [100]:
# Hyperparameters used by the data loader, the optimizers and the epoch loop.
batch_size = 32        # samples per mini-batch
training_epochs = 10   # number of passes over the training loader
learning_rate = 0.01   # learning rate given to both Adam optimizers

In [101]:
# MNIST training split, stored under MNIST_data/ (download=True fetches it if needed).
mnist_train = torchvision.datasets.MNIST(
    root='MNIST_data/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Matching test split; identical configuration apart from `train=False`.
mnist_test = torchvision.datasets.MNIST(
    root='MNIST_data/',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)


In [102]:
# Mini-batch iterator over the training split: reshuffled each epoch, and
# drop_last=True discards a trailing batch smaller than batch_size.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [103]:
# Two MLPs with the same layer sizes (784 -> 32 -> 32 -> 10). Only the first one
# places a BatchNorm1d layer after each hidden Linear layer, so the two can be
# compared side by side.
# NOTE: the Linear layers are created in the same order as before so the seeded
# generator produces the same initial weights.

# Layers of the batch-normalised network.
linear1 = torch.nn.Linear(in_features=784, out_features=32, bias=True)
linear2 = torch.nn.Linear(in_features=32, out_features=32, bias=True)
linear3 = torch.nn.Linear(in_features=32, out_features=10, bias=True)
bn1 = torch.nn.BatchNorm1d(num_features=32)
bn2 = torch.nn.BatchNorm1d(num_features=32)

# Layers of the plain baseline network.
nn_linear1 = torch.nn.Linear(in_features=784, out_features=32, bias=True)
nn_linear2 = torch.nn.Linear(in_features=32, out_features=32, bias=True)
nn_linear3 = torch.nn.Linear(in_features=32, out_features=10, bias=True)

# One ReLU module instance is reused at every activation site in both networks.
relu = torch.nn.ReLU()

# Network with batch normalisation: Linear -> BatchNorm -> ReLU, twice, then the output layer.
bn_model = torch.nn.Sequential(
    linear1, bn1, relu,
    linear2, bn2, relu,
    linear3,
).to(device)

# Baseline network: Linear -> ReLU, twice, then the output layer.
nn_model = torch.nn.Sequential(
    nn_linear1, relu,
    nn_linear2, relu,
    nn_linear3,
).to(device)

# Shared loss; each network gets its own Adam optimizer with the same learning rate.
criterion = torch.nn.CrossEntropyLoss().to(device)
nn_optimizer = torch.optim.Adam(nn_model.parameters(), lr=learning_rate)
bn_optimizer = torch.optim.Adam(bn_model.parameters(), lr=learning_rate)


In [104]:
def batch_accuracy(model, inputs, targets):
    """Return the fraction of rows in `inputs` that `model` classifies correctly.

    Args:
        model: callable mapping a batch of flattened images to per-class scores.
        inputs: tensor of flattened images, on the same device as `model`.
        targets: tensor of integer class labels, one per row of `inputs`.

    Returns:
        A 0-dim float tensor: the mean of the per-sample correct/incorrect mask,
        where a sample is correct when the arg-max over dim 1 equals its label.
    """
    return (torch.argmax(model(inputs), 1) == targets).float().mean()


# The evaluation tensors are identical in every epoch, so build them once
# instead of re-creating (and re-uploading) them each time round the loop.
# These come straight from the datasets' `.data` tensors rather than through the
# DataLoader, so the pixels are divided by 255 by hand here.
X_test = (mnist_test.data.view(-1, 28 * 28) / 255).float().to(device)
Y_test = mnist_test.targets.to(device)
X_train = (mnist_train.data.view(-1, 28 * 28) / 255).float().to(device)
Y_train = mnist_train.targets.to(device)

train_total_batch = len(train_loader)
for epoch in range(training_epochs):
    # Training mode: BatchNorm normalises with the statistics of the current batch.
    # nn_model has no mode-dependent layers; it is toggled only for consistency.
    bn_model.train()
    nn_model.train()

    # Running per-epoch mean of the mini-batch losses.
    bn_avg_cost = 0
    nn_avg_cost = 0

    for X, Y in train_loader:
        # Flatten each 28x28 image into a 784-element row.
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        # One optimisation step for the batch-normalised network.
        bn_optimizer.zero_grad()
        bn_prediction = bn_model(X)
        bn_loss = criterion(bn_prediction, Y)
        bn_loss.backward()
        bn_optimizer.step()

        # The same step for the baseline network, on the same mini-batch.
        nn_optimizer.zero_grad()
        nn_prediction = nn_model(X)
        nn_loss = criterion(nn_prediction, Y)
        nn_loss.backward()
        nn_optimizer.step()

        bn_avg_cost += bn_loss.item() / train_total_batch
        nn_avg_cost += nn_loss.item() / train_total_batch

    print('Epoch:', '%04d' % (epoch + 1), 'cost = bn : {:.9f}, nn : {:.9f}'.format(bn_avg_cost, nn_avg_cost))

    # Evaluation mode: BatchNorm switches to its accumulated running statistics.
    bn_model.eval()
    nn_model.eval()
    with torch.no_grad():
        # Accuracy on the held-out test split.
        bn_accuracy = batch_accuracy(bn_model, X_test, Y_test)
        nn_accuracy = batch_accuracy(nn_model, X_test, Y_test)
        print('eval bn Accuracy :', bn_accuracy.item())
        print('eval nn Accuracy :', nn_accuracy.item())

        # Accuracy on the full training split, computed in a single forward pass.
        bn_accuracy = batch_accuracy(bn_model, X_train, Y_train)
        nn_accuracy = batch_accuracy(nn_model, X_train, Y_train)
        print('train bn Accuracy :', bn_accuracy.item())
        print('train nn Accuracy :', nn_accuracy.item())

print('Learning finished')

Epoch: 0001 cost = bn : 0.288282830, nn : 0.298617792
eval bn Accuracy : 0.9541999697685242
eval nn Accuracy : 0.9386999607086182
train bn Accuracy : 0.9593333601951599
train nn Accuracy : 0.9438499808311462
Epoch: 0002 cost = bn : 0.176860546, nn : 0.196754268
eval bn Accuracy : 0.9639999866485596
eval nn Accuracy : 0.9287999868392944
train bn Accuracy : 0.970550000667572
train nn Accuracy : 0.934249997138977
Epoch: 0003 cost = bn : 0.143006775, nn : 0.177613062
eval bn Accuracy : 0.9657999873161316
eval nn Accuracy : 0.9490999579429626
train bn Accuracy : 0.9712666869163513
train nn Accuracy : 0.958216667175293
Epoch: 0004 cost = bn : 0.129210275, nn : 0.171262613
eval bn Accuracy : 0.9668999910354614
eval nn Accuracy : 0.9519999623298645
train bn Accuracy : 0.9757833480834961
train nn Accuracy : 0.9574833512306213
Epoch: 0005 cost = bn : 0.120301018, nn : 0.155702868
eval bn Accuracy : 0.9706999659538269
eval nn Accuracy : 0.957099974155426
train bn Accuracy : 0.9800500273704529
tra

In [105]:
# The transform is only applied when samples are loaded through the DataLoader, so in this case - where accuracy is computed directly from mnist_test / mnist_train - the division by 255 is done by hand.