In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.optim as optim

In [None]:
# Hyperparameters shared by the cells below.
batch_size = 16        # mini-batch size used by both DataLoaders
learning_rate = 0.002  # SGD step size
num_epoch = 200        # number of passes over the training loader

# Seed PyTorch's RNGs (CPU and CUDA) so weight initialisation and DataLoader
# shuffling are repeatable across Restart & Run All.
seed = 0
torch.manual_seed(seed)


In [None]:
class myCNN(nn.Module):
  """Small CNN classifier for 1-channel 28x28 images.

  Three 5x5 convolutions (padding=2 keeps the spatial size) with ReLU, two
  2x2 max-pools (28 -> 14 -> 7), then a two-layer MLP producing 10 logits.

  Input:  float tensor of shape (N, 1, 28, 28).
  Output: tensor of shape (N, 10) with unnormalised class scores (no softmax;
          pair with nn.CrossEntropyLoss).
  """

  def __init__(self):
    super(myCNN, self).__init__()
    # Stage 1: convolutional feature extractor.
    self.cnn_layer = nn.Sequential(
        # conv + relu: 1 -> 16 channels, 5x5 kernel, padding=2 keeps 28x28
        nn.Conv2d(1, 16, 5, padding=2),   # (N, 1, 28, 28) -> (N, 16, 28, 28)
        nn.ReLU(),

        # conv + relu
        nn.Conv2d(16, 32, 5, padding=2),  # (N, 16, 28, 28) -> (N, 32, 28, 28)
        nn.ReLU(),
        # pool: 28x28 -> 14x14
        nn.MaxPool2d(2, 2),
        # conv + relu
        nn.Conv2d(32, 64, 5, padding=2),  # (N, 32, 14, 14) -> (N, 64, 14, 14)
        nn.ReLU(),
        # pool: 14x14 -> 7x7
        nn.MaxPool2d(2, 2),
    )  # cnn_layer output: (N, 64, 7, 7)
    # Stage 2: fully connected classifier on the flattened features.
    self.fc_layer = nn.Sequential(
        nn.Linear(64 * 7 * 7, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )

  def forward(self, x):
    """Return class logits of shape (N, 10) for a batch x of shape (N, 1, 28, 28)."""
    out = self.cnn_layer(x)          # (N, 64, 7, 7)
    # Flatten with the batch size of *this* input rather than the notebook-global
    # batch_size, so partial batches and single images work.
    out = out.view(out.size(0), -1)  # (N, 64*7*7) 2-D tensor
    out = self.fc_layer(out)         # (N, 10)
    return out

In [None]:
# Load the MNIST train and test splits from the parent directory, downloading
# them on first use. Both splits share one ToTensor transform (it is stateless)
# and leave the labels untouched (target_transform=None).
mnist_root = "../"
to_tensor = transforms.ToTensor()

mnist_train = dset.MNIST(
    mnist_root,
    train=True,
    transform=to_tensor,
    target_transform=None,
    download=True,
)
mnist_test = dset.MNIST(
    mnist_root,
    train=False,
    transform=to_tensor,
    target_transform=None,
    download=True,
)

In [None]:
# Train on the first batch_size*100 training images (clamped to the dataset
# size, matching list-slice semantics). Subset indexes the dataset lazily;
# list(mnist_train)[...] decoded every training image just to keep a prefix.
train_subset = torch.utils.data.Subset(
    mnist_train, range(min(batch_size * 100, len(mnist_train))))
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)

# NOTE(review): drop_last=True also discards the final partial *test* batch, so
# up to batch_size-1 test images are never scored when len(mnist_test) is not a
# multiple of batch_size. It is kept because a partial batch needs a model
# whose forward pass does not assume a full batch.
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True)


In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Build the network directly on the chosen device.
model = myCNN().to(device)

# CrossEntropyLoss consumes raw logits plus class-index targets.
loss_func = nn.CrossEntropyLoss()
# Plain SGD over every model parameter.
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
def EstimateAccuracy(dloader, imodel):
  """Compute, print, and return the classification accuracy of imodel.

  Args:
    dloader: iterable of (images, labels) batches, e.g. a DataLoader; labels
      are class-index tensors.
    imodel: nn.Module returning one row of class scores per input.

  Returns:
    Accuracy in percent (0-100) as a Python float.

  Raises:
    ValueError: if dloader yields no samples.
  """
  # Evaluate on whatever device the model's parameters live on, instead of
  # relying on a notebook-global `device`.
  first_param = next(imodel.parameters(), None)
  eval_device = first_param.device if first_param is not None else torch.device("cpu")

  was_training = imodel.training
  imodel.eval()
  correct = 0
  total = 0
  try:
    # torch.no_grad() replaces the removed Variable(..., volatile=True) idiom,
    # which is a no-op in modern PyTorch and still built autograd graphs here.
    with torch.no_grad():
      for image, label in dloader:
        x = image.to(eval_device)
        y = label.to(eval_device)

        # Call the module (not .forward) so registered hooks also run.
        y_hat = imodel(x)
        _, y_hat_index = torch.max(y_hat, 1)

        total += y.size(0)
        correct += int((y_hat_index == y).sum().item())
  finally:
    # Restore the caller's mode; the training loop may not switch it back.
    imodel.train(was_training)

  if total == 0:
    raise ValueError("EstimateAccuracy: dloader yielded no samples")

  accuracy = 100.0 * correct / total
  print("Accuracy: {}".format(accuracy))
  return accuracy

In [None]:
# Histories sampled every 5 epochs: last mini-batch loss and test accuracy.
loss_arr = []
accu_arr = []

for i in range(num_epoch):
  # Ensure training mode at the start of each epoch (evaluation may change it).
  model.train()
  for image, label in train_loader:
    x = image.to(device)
    y = label.to(device)

    optimizer.zero_grad()
    # Call the module (not .forward) so registered hooks also run.
    y_hat = model(x)
    loss = loss_func(y_hat, y)
    loss.backward()
    optimizer.step()

  if i % 5 == 0:
    # .item() stores a plain float instead of a device tensor that still
    # references its autograd node; floats are also directly plottable.
    loss_value = loss.item()
    print(i, loss_value)
    accu = EstimateAccuracy(test_loader, model)
    loss_arr.append(loss_value)
    accu_arr.append(accu)


0 tensor(6.7234e-05, device='cuda:0', grad_fn=<NllLossBackward0>)


  x = Variable(image, volatile=True).to(device)


Accuracy: 94.23999786376953
5 tensor(0.0005, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.20999908447266
10 tensor(0.0006, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22999572753906
15 tensor(0.0004, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22999572753906
20 tensor(0.0003, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22000122070312
25 tensor(0.0009, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22999572753906
30 tensor(0.0005, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22000122070312
35 tensor(0.0003, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22999572753906
40 tensor(0.0001, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.23999786376953
45 tensor(4.1608e-05, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.23999786376953
50 tensor(0.0004, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy: 94.22999572753906
55 tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward

KeyboardInterrupt: 