In [1]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

# data prepare

In [2]:
use_cuda = torch.cuda.is_available()

root = './data'
if not os.path.exists(root):
    os.mkdir(root)

#trans = transforms.Compose([transforms.RandomCrop(32, padding=4),
#                transforms.RandomHorizontalFlip(), transforms.ToTensor(),
#                transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])])
trans = transforms.Compose([transforms.ToTensor()])

# if not exist, download mnist dataset
train_set = dset.CIFAR10(root=root, train=True, transform=trans, download=False)
test_set = dset.CIFAR10(root=root, train=False, transform=trans, download=False)

In [3]:
batch_size = 128

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

==>>> total trainning batch number: 391
==>>> total testing batch number: 79


In [None]:
inputs, _ = next(iter(train_loader))
img = inputs[4]
img = img.permute(1,2,0)

import matplotlib.pyplot as plt
plt.imshow(img, interpolation='nearest')

# MODEL

In [7]:
class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Linear(32*32*3, 500)
        self.fc2 = nn.Linear(500, 256)
        self.fc3 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = x.view(-1, 32*32*3) #flat시키기
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def name(self):
        return "MLP"

    
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(5*5*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def name(self):
        return "LeNet"

# Training

In [9]:
model = LeNet()

if use_cuda:
    model = model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 10):
    # trainning
    ave_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        if use_cuda:
            x, target = x.cuda(), target.cuda()
        x, target = Variable(x), Variable(target)
        
        out = model(x)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        if (batch_idx+1) % 100 == 0 or (batch_idx+1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx+1, loss.data))
            
    # testing
    correct_cnt = 0
    total_cnt= 0
    for batch_idx, (x, target) in enumerate(test_loader):
        with torch.no_grad():
            if use_cuda:
                x, target = x.cuda(), target.cuda()
            x, target = Variable(x), Variable(target)
            
            out = model(x)
            loss = criterion(out, target)
            _, pred_label = torch.max(out.data, 1)
            total_cnt = total_cnt + batch_size
            count = torch.sum(pred_label == target.data)
            
            correct_cnt = correct_cnt + count.item()
            if(batch_idx+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
                print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                    epoch, batch_idx+1, loss.data , correct_cnt * 1.0 / total_cnt))
                
print("Finished!")

==>>> epoch: 1, batch index: 100, train loss: 2.098930
==>>> epoch: 1, batch index: 200, train loss: 1.985222
==>>> epoch: 1, batch index: 300, train loss: 1.709409
==>>> epoch: 1, batch index: 391, train loss: 1.655666
==>>> epoch: 1, batch index: 79, test loss: 1.447818, acc: 0.399
==>>> epoch: 2, batch index: 100, train loss: 1.614856
==>>> epoch: 2, batch index: 200, train loss: 1.436003
==>>> epoch: 2, batch index: 300, train loss: 1.353874
==>>> epoch: 2, batch index: 391, train loss: 1.415918
==>>> epoch: 2, batch index: 79, test loss: 1.115165, acc: 0.499
==>>> epoch: 3, batch index: 100, train loss: 1.304854
==>>> epoch: 3, batch index: 200, train loss: 1.253309
==>>> epoch: 3, batch index: 300, train loss: 1.216772
==>>> epoch: 3, batch index: 391, train loss: 1.166130
==>>> epoch: 3, batch index: 79, test loss: 1.404466, acc: 0.548
==>>> epoch: 4, batch index: 100, train loss: 1.117727
==>>> epoch: 4, batch index: 200, train loss: 1.187654
==>>> epoch: 4, batch index: 300, t