In [1]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fix the random seed so weight initialisation and batch shuffling are reproducible.
torch.manual_seed(777)
if device == 'cuda':
    # Seed every visible CUDA device as well.
    torch.cuda.manual_seed_all(777)

# Fetch the MNIST train and test splits into MNIST_data/ (downloading if needed);
# each image is converted to a float tensor by ToTensor().
mnist_train, mnist_test = (
    dsets.MNIST(root="MNIST_data/", train=is_train, transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)


In [2]:
# Training hyperparameters (referenced by the loader, optimizer and training cells)
learning_rate = 1e-3    # Adam step size
training_epochs = 15    # number of full passes over the training set
batch_size = 100        # samples per mini-batch

In [3]:
# Mini-batch loader over the training split. Samples are reshuffled every epoch, and the
# trailing partial batch is dropped so every batch holds exactly `batch_size` samples.
from torch.utils.data import DataLoader

data_loader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [4]:
# model: a 784 -> 256 -> 256 -> 10 fully connected network with ReLU activations.
# The final layer outputs raw logits; CrossEntropyLoss applies the softmax itself.
linear1 = torch.nn.Linear(28*28, 256, bias = True).to(device)
linear2 = torch.nn.Linear(256, 256, bias = True).to(device)
linear3 = torch.nn.Linear(256, 10, bias = True).to(device)

# Re-initialise each weight matrix in place with normal_'s defaults (mean=0, std=1),
# replacing nn.Linear's default initialisation. Biases keep their default init.
# Every entry is drawn independently; no per-row standardisation takes place.
# (linear1.weight is 256x784, linear2.weight 256x256, linear3.weight 10x256.)
# NOTE(review): std=1 yields very large pre-activations for these fan-ins; this is the
# likely cause of the ~130 first-epoch cost in the training output. Compare against the
# default init or xavier_uniform_ before assuming this init helps.
for layer in (linear1, linear2, linear3):
    torch.nn.init.normal_(layer.weight)

# ReLU is stateless, so one instance can be shared between the two hidden layers.
relu = torch.nn.ReLU()

# model
model = torch.nn.Sequential(linear1, relu, linear2, relu, linear3).to(device)


# define cost/loss & optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [6]:
# train: one pass over `data_loader` per epoch. Each epoch reports the mean mini-batch cost
# and the running training accuracy. Accuracy is measured on each batch's predictions taken
# before that batch's optimizer step, so it is not held-out (test) accuracy.
total_batch = len(data_loader)

for epoch in range(training_epochs):
    avg_cost = 0.0
    correct = 0
    seen = 0

    for X, Y in data_loader:
        # Flatten each 28x28 image into a 784-vector. Move inputs AND labels to the model's
        # device: CrossEntropyLoss requires logits and targets on the same device.
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)

        hypothesis = model(X)
        cost = criterion(hypothesis, Y)

        # backward
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Accumulate plain Python numbers. Adding the `cost` tensor itself would chain every
        # batch's autograd graph into avg_cost for the rest of the epoch.
        avg_cost += cost.item() / total_batch
        correct += (hypothesis.argmax(dim=1) == Y).sum().item()
        seen += Y.size(0)  # equals total_batch * batch_size while drop_last=True

    print("Epoch: {:04d} cost = {:.4f} Acc = {:.4f} %".format(epoch, avg_cost, correct / seen * 100))
        

Epoch: 0000 cost = 129.7920 Acc = 74.7150 %
Epoch: 0001 cost = 36.0880 Acc = 88.8100 %
Epoch: 0002 cost = 23.1344 Acc = 91.5000 %
Epoch: 0003 cost = 16.0726 Acc = 93.1300 %
Epoch: 0004 cost = 11.8583 Acc = 94.2667 %
Epoch: 0005 cost = 8.6889 Acc = 95.1683 %
Epoch: 0006 cost = 6.4800 Acc = 95.8317 %
Epoch: 0007 cost = 4.7156 Acc = 96.6033 %
Epoch: 0008 cost = 3.6496 Acc = 97.0367 %
Epoch: 0009 cost = 2.8156 Acc = 97.4167 %
Epoch: 0010 cost = 2.1127 Acc = 97.8483 %
Epoch: 0011 cost = 1.6560 Acc = 98.1117 %
Epoch: 0012 cost = 1.1724 Acc = 98.4917 %
Epoch: 0013 cost = 1.0253 Acc = 98.5450 %
Epoch: 0014 cost = 0.7893 Acc = 98.7967 %
