In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torch.utils.data as torchdata
import torchvision.transforms as transforms
import torchvision.datasets as vdatasets
import torchvision.utils as vutils
# Seed PyTorch's default random number generator so repeated runs start from the same random state.
torch.manual_seed(1)

<torch._C.Generator at 0x7fc67838c050>

### Baseline Instruction

1. MNIST 데이터를 Train, Test 둘 다 로딩하고 데이터 로더를 만든다. (데이터의 path는 "../../data/MNIST")
2. FFN(Feed forward network)를 모델링하는데
  - 데이터의 Input 차원은 784(28*28), Output 차원은 10
  - 히든 레이어를 2개를 사용하는데 각각의 차원은 512, 512
  - Activation function은 tanh를 사용
  - Xavier_normal을 이용해서 weight를 초기화하고, bias는 0.1로 초기화
  - Dropout을 마지막 아웃풋을 제외하고 적용한다. (drop probability는 0.3)
3. Optimizer는 SGD를 사용하고 learning rate는 각자 적절한 값으로 설정한다. 또한 0.00001만큼의 weight decay를 준다.
4. EPOCH = 10, BATCH_SIZE = 32
5. test 데이터셋을 사용해서 accuracy를 측정해본다
6. 5번에서 찍어본 accuracy보다 더 좋은 accuracy를 얻기 위해 위의 설정들을 바꿔본다 (더 높은 accuracy를 얻는다면 성공)

### TODO

In [4]:
# Mini-batch size shared by the train and test DataLoaders.
BATCH_SIZE = 32

In [5]:
# MNIST training split, downloaded on first use; ToTensor converts each image to a tensor.
train_dataset = vdatasets.MNIST(root='../../data/MNIST/',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)

# Training batches are reshuffled every epoch. drop_last discards an incomplete
# final batch (original author's note: keeps the running average from spiking).
train_loader = torchdata.DataLoader(dataset=train_dataset,
                                    batch_size=BATCH_SIZE,
                                    shuffle=True,
                                    num_workers=2,
                                    drop_last=True)

# MNIST test split with the same preprocessing as the training split.
test_dataset = vdatasets.MNIST(root='../../data/MNIST/',
                               train=False,
                               transform=transforms.ToTensor(),
                               download=True)

# Evaluation does not depend on sample order, so the test loader is not shuffled;
# every test sample is still visited exactly once (no drop_last).
test_loader = torchdata.DataLoader(dataset=test_dataset,
                                   batch_size=BATCH_SIZE,
                                   shuffle=False,
                                   num_workers=2)

In [15]:
class FFN(nn.Module):
    """Feed-forward classifier mapping 784 input features to 10 output scores.

    Two tanh hidden layers of equal width, each followed by dropout with
    drop probability 0.3. The output layer returns raw scores: no activation
    and no dropout are applied to it.

    Args:
        hidden_size: width of both hidden layers.
    """

    def __init__(self, hidden_size):
        super(FFN, self).__init__()

        self.l1 = nn.Linear(784, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, 10)
        self.activation = nn.Tanh()
        self.dropout = nn.Dropout(0.3)

    def init_weight(self):
        """Re-initialise all parameters in place.

        Parameters whose name contains ``'weight'`` get Xavier-normal values;
        every other parameter (the biases of the Linear layers) is set to 0.1.
        Must be called explicitly after construction; ``__init__`` does not
        call it.
        """
        for name, param in self.named_parameters():
            if 'weight' in name:
                # In-place initialiser; the non-underscore spelling is a deprecated alias.
                nn.init.xavier_normal_(param)
            else:
                nn.init.constant_(param, 0.1)

    def forward(self, inputs):
        """Compute output scores for a batch of flattened inputs.

        Args:
            inputs: tensor whose last dimension is 784.

        Returns:
            Tensor of raw scores whose last dimension is 10.
        """
        outputs = self.activation(self.l1(inputs))
        outputs = self.dropout(outputs)
        outputs = self.activation(self.l2(outputs))
        outputs = self.dropout(outputs)

        return self.l3(outputs)

In [17]:
# Training hyperparameters.
EPOCH = 10      # number of passes over the training loader
LR = 1e-3       # SGD learning rate
LAMDA = 1e-5    # weight-decay coefficient handed to the optimizer

# Network with 512 units in each hidden layer, then the custom initialisation.
model = FFN(hidden_size=512)
model.init_weight()

# Cross-entropy on the model's raw output scores, optimised with SGD + weight decay.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=LR,
    weight_decay=LAMDA,
)

In [23]:
# Training mode: dropout layers are active.
model.train()
for epoch in range(EPOCH):
    losses = []  # per-batch loss values for this epoch
    for inputs, targets in train_loader:
        # Flatten each batch to (batch, 784) to match the first Linear layer.
        inputs = inputs.view(-1, 784)

        # Clear gradients left over from the previous step before the new backward pass.
        optimizer.zero_grad()
        preds = model(inputs)
        loss = loss_function(preds, targets)
        # The loss is a 0-dim tensor; .item() extracts its Python number
        # (indexing it with [0] raises IndexError on current PyTorch).
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

    print("[%d/%d] mean_loss : %.3f" % (epoch,EPOCH,np.mean(losses)))

[0/10] mean_loss : 1.323
[1/10] mean_loss : 0.673
[2/10] mean_loss : 0.536
[3/10] mean_loss : 0.473
[4/10] mean_loss : 0.436
[5/10] mean_loss : 0.413
[6/10] mean_loss : 0.396
[7/10] mean_loss : 0.382
[8/10] mean_loss : 0.373
[9/10] mean_loss : 0.365


### Accuracy 측정 

In [24]:
# Evaluation mode: dropout layers are disabled.
model.eval()
num_hit = 0  # running count of correctly classified test samples
# Scoring needs no gradients, so skip autograd bookkeeping.
with torch.no_grad():
    for inputs, targets in test_loader:
        # Flatten each batch to (batch, 784) to match the first Linear layer.
        inputs = inputs.view(-1, 784)
        # max(1)[1] is the index of the largest score along dim 1, i.e. the predicted class.
        preds = model(inputs).max(1)[1]
        # .item() converts the 0-dim count tensor to a Python int
        # (indexing it with [0] raises IndexError on current PyTorch).
        num_hit += torch.eq(preds, targets).sum().item()

print("accuracy : ",num_hit / len(test_dataset))

accuracy :  0.9101
