In [40]:
## 라이브러리 추가하기
import os
import argparse
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# pytorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# sci-kit learn
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

# self-class
from model import Net
from dataset import *
from util import *

In [41]:
## Fix the random seeds
# Keeping the seed fixed makes runs comparable when hyper-parameters are changed.
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's `random`, NumPy and PyTorch
            (CPU and every CUDA device).

    Side effects:
        Sets the PYTHONHASHSEED environment variable and switches cuDNN to
        deterministic mode with auto-tuning disabled.
    """
    random.seed(seed)
    # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
    # presumably only affects child processes - confirm if hash ordering matters.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one; ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN auto-tuning may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False

seed = 0
seed_everything(seed)

In [47]:
## Build the argument parser
# Every option has a default, so the cell also runs when no arguments are given.
parser = argparse.ArgumentParser(description="Train the Net",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Optimisation hyper-parameters.
parser.add_argument("--lr", default=1e-3, type=float, dest="lr")
parser.add_argument("--batch_size", default=8, type=int, dest="batch_size")
parser.add_argument("--num_epoch", default=40, type=int, dest="num_epoch")

# Input / output directories.
parser.add_argument("--data_dir", default="./datasets", type=str, dest="data_dir")
parser.add_argument("--ckpt_dir", default="./checkpoint", type=str, dest="ckpt_dir")
parser.add_argument("--log_dir", default="./log", type=str, dest="log_dir")
parser.add_argument("--result_dir", default="./result", type=str, dest="result_dir")

# Run mode: the value 'train' selects training, anything else the test path.
# "--mode" is a plain alias; the original "--train/test_mode" spelling still works.
parser.add_argument("--mode", "--train/test_mode", default="test", type=str, dest="mode")
# "on" resumes training from the checkpoint directory.
parser.add_argument("--train_continue", default="off", type=str, dest="train_continue")

# parse_known_args collects unrecognised arguments in `unknown` instead of
# raising - presumably so the cell also runs inside a Jupyter kernel, whose own
# command line is not meant for this parser.
args, unknown = parser.parse_known_args()

In [48]:
## Set up the training parameters
# Optimisation hyper-parameters taken from the parsed arguments.
lr, batch_size, num_epoch = args.lr, args.batch_size, args.num_epoch

# Input / output directories taken from the parsed arguments.
data_dir, ckpt_dir = args.data_dir, args.ckpt_dir
log_dir, result_dir = args.log_dir, args.result_dir

# Run mode and resume switch.
mode, train_continue = args.mode, args.train_continue

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Echo the effective configuration, one setting per line.
for template, value in (
    ("learning rate: %.4e", lr),
    ("batch size: %d", batch_size),
    ("number of epoch: %d", num_epoch),
    ("data dir: %s", data_dir),
    ("ckpt dir: %s", ckpt_dir),
    ("log dir: %s", log_dir),
    ("result dir: %s", result_dir),
    ("train/test_mode: %s", mode),
):
    print(template % value)

learning rate: 1.0000e-03
batch size: 8
number of epoch: 40
data dir: ./datasets
ckpt dir: ./checkpoint
log dir: ./log
result dir: ./result
train/test_mode: test


In [49]:
## Build the datasets and data loaders
if mode != 'train':
    # Test path: no augmentation, only normalisation and tensor conversion.
    transform = transforms.Compose([
        Normalization(mean=0.5, std=0.5),
        ToTensor(),
    ])

    load_data = pd.read_csv(os.path.join(data_dir, 'test.csv'))
    dataset_test = Dataset(load_data, mode, transform=transform)
    loader_test = DataLoader(
        dataset_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

    # Bookkeeping values.
    num_data_test = len(dataset_test)

    # np.ceil rounds up, so a partially filled last batch is counted too.
    num_batch_test = np.ceil(num_data_test / batch_size)

else:
    # Training samples go through the augmentation transforms before being
    # normalised; validation samples are only normalised.
    transform_train = transforms.Compose([
        ToPILImage(),
        RandomRotation(degree=30),
        RandomAffine(degree=30),
        ToNumpy(),
        Normalization(mean=0.5, std=0.5),
        ToTensor(),
    ])
    transform_val = transforms.Compose([
        Normalization(mean=0.5, std=0.5),
        ToTensor(),
    ])

    # Split the training table into train / validation parts.
    # train_test_split returns its outputs in this order:
    #   four-way : train_input, valid_input, train_label, valid_label
    #   two-way  : train, valid
    # test_size=0.125 keeps 12.5% of the rows for validation.
    load_data = pd.read_csv(os.path.join(data_dir, 'train.csv'))
    train, val = train_test_split(load_data, test_size=0.125, random_state=seed)

    dataset_train = Dataset(train, mode, transform=transform_train)
    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )

    dataset_val = Dataset(val, mode, transform=transform_val)
    loader_val = DataLoader(
        dataset_val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

    # Bookkeeping values.
    num_data_train = len(dataset_train)
    num_data_val = len(dataset_val)

    # np.ceil rounds up (e.g. 4.2 -> 5), so a partially filled last batch is counted too.
    num_batch_train = np.ceil(num_data_train / batch_size)
    num_batch_val = np.ceil(num_data_val / batch_size)

In [50]:
## Create the network
# Instantiate the model, then move it to the selected device.
net = Net()
net = net.to(device)

## Define the loss function
fn_loss = nn.CrossEntropyLoss()
fn_loss = fn_loss.to(device)

## Configure the optimizer
optim = torch.optim.Adam(params=net.parameters(), lr=lr)

## Miscellaneous helper functions
def fn_tonumpy(x):
    """Convert a tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU first,
    because a tensor must live on the CPU to be converted to NumPy.
    """
    return x.detach().to('cpu').numpy()


def fn_denorm(x, mean, std):
    """Undo a `(x - mean) / std` normalisation by computing `x * std + mean`."""
    return x * std + mean


def fn_class(x):
    """Threshold `x` at 0.5, yielding 1.0 where `x > 0.5` and 0.0 elsewhere."""
    return (x > 0.5) * 1.0

In [51]:
## Train or evaluate the network
st_epoch = 0

# Checkpoint interval in epochs (previously the literal 40 in the save condition).
SAVE_EVERY = 40


def _predict_classes(output):
    """Return the predicted class index for every sample of a batch.

    The arg-max is taken along axis 1 of the network output, so the result has
    one entry per row actually present. A final batch that is smaller than
    `batch_size` therefore no longer raises an IndexError.
    """
    return np.argmax(fn_tonumpy(output), axis=1)


def _run_epoch(model, loader, optimizer=None):
    """Run one pass over `loader` and return `(mean_loss, accuracy)`.

    Args:
        model: Network called on each batch's 'input' entry.
        loader: Iterable of batches, each a mapping with 'input' and 'label'.
        optimizer: When given, a backward pass and an optimizer step follow
            every batch; when None the pass is forward-only.

    Returns:
        A tuple of floats: loss and accuracy, both averaged over samples.
        Both are NaN when the loader yields no samples.

    The caller is responsible for `model.train()` / `model.eval()` and for
    wrapping evaluation in `torch.no_grad()`.
    """
    loss_total = 0.0
    num_correct = 0
    num_samples = 0

    for data in loader:
        label = data['label'].to(device)
        inputs = data['input'].to(device)

        # Forward pass.
        output = model(inputs)
        loss = fn_loss(output, label)

        # Backward pass (training only).
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        label_np = fn_tonumpy(label)
        num_in_batch = len(label_np)

        # fn_loss is created as nn.CrossEntropyLoss() with default arguments,
        # i.e. a per-batch mean; weighting it by the batch size keeps a short
        # last batch from being over-represented in the epoch average.
        loss_total += loss.item() * num_in_batch
        num_correct += int(np.sum(_predict_classes(output) == label_np))
        num_samples += num_in_batch

    if num_samples == 0:
        return float('nan'), float('nan')
    return loss_total / num_samples, num_correct / num_samples


# TRAIN MODE
if mode == 'train':
    if train_continue == "on":
        net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

    # Per-epoch history; these lists are what the plotting cell draws.
    train_loss, val_loss, train_acc, val_acc = [], [], [], []

    for epoch in range(st_epoch + 1, num_epoch + 1):
        net.train()
        epoch_loss, epoch_acc = _run_epoch(net, loader_train, optimizer=optim)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        # torch.no_grad() switches autograd off; validation needs no gradients.
        net.eval()
        with torch.no_grad():
            epoch_loss, epoch_acc = _run_epoch(net, loader_val)
        val_loss.append(epoch_loss)
        val_acc.append(epoch_acc)

        # Report the metrics of this epoch.
        print("EPOCH: {}/{} | ".format(epoch, num_epoch), "TRAIN_LOSS: {:4f} | ".format(train_loss[-1]),
              "TRAIN_ACC: {:4f} | ".format(train_acc[-1]), "VAL_LOSS: {:4f} | ".format(val_loss[-1]), "VAL_ACC: {:4f}".format(val_acc[-1]))

        # Save periodically and always after the last epoch, labelled with the
        # epoch that was actually reached (previously always `num_epoch`).
        if epoch % SAVE_EVERY == 0 or epoch == num_epoch:
            save_model(ckpt_dir=ckpt_dir, net=net, optim=optim, epoch=epoch, batch=batch_size)

# TEST MODE
else:
    net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

    net.eval()
    pred = []
    batch = 0  # stays 0 when the loader is empty, so the summary print still works

    with torch.no_grad():
        for batch, data in enumerate(loader_test, 1):
            # Forward pass.
            inputs = data['input'].to(device)
            output = net(inputs)

            # Keep one list of predicted digits per batch, as before.
            pred.append(list(_predict_classes(output)))

            print("TEST: BATCH %04d / %04d" %
                  (batch, num_batch_test))

    # Write the submission once all batches are predicted. This used to happen
    # only when the number of batches was a multiple of 2560.
    save_submission(result_dir=result_dir, prediction=pred, epoch=num_epoch, batch=batch_size)

    print("AVERAGE TEST: BATCH %04d / %04d" %
          (batch, num_batch_test))

TEST: BATCH 0001 / 2560
TEST: BATCH 0002 / 2560
TEST: BATCH 0003 / 2560
TEST: BATCH 0004 / 2560
TEST: BATCH 0005 / 2560
TEST: BATCH 0006 / 2560
TEST: BATCH 0007 / 2560
TEST: BATCH 0008 / 2560
TEST: BATCH 0009 / 2560
TEST: BATCH 0010 / 2560
TEST: BATCH 0011 / 2560
TEST: BATCH 0012 / 2560
TEST: BATCH 0013 / 2560
TEST: BATCH 0014 / 2560
TEST: BATCH 0015 / 2560
TEST: BATCH 0016 / 2560
TEST: BATCH 0017 / 2560
TEST: BATCH 0018 / 2560
TEST: BATCH 0019 / 2560
TEST: BATCH 0020 / 2560
TEST: BATCH 0021 / 2560
TEST: BATCH 0022 / 2560
TEST: BATCH 0023 / 2560
TEST: BATCH 0024 / 2560
TEST: BATCH 0025 / 2560
TEST: BATCH 0026 / 2560
TEST: BATCH 0027 / 2560
TEST: BATCH 0028 / 2560
TEST: BATCH 0029 / 2560
TEST: BATCH 0030 / 2560
TEST: BATCH 0031 / 2560
TEST: BATCH 0032 / 2560
TEST: BATCH 0033 / 2560
TEST: BATCH 0034 / 2560
TEST: BATCH 0035 / 2560
TEST: BATCH 0036 / 2560
TEST: BATCH 0037 / 2560
TEST: BATCH 0038 / 2560
TEST: BATCH 0039 / 2560
TEST: BATCH 0040 / 2560
TEST: BATCH 0041 / 2560
TEST: BATCH 0042

TEST: BATCH 0441 / 2560
TEST: BATCH 0442 / 2560
TEST: BATCH 0443 / 2560
TEST: BATCH 0444 / 2560
TEST: BATCH 0445 / 2560
TEST: BATCH 0446 / 2560
TEST: BATCH 0447 / 2560
TEST: BATCH 0448 / 2560
TEST: BATCH 0449 / 2560
TEST: BATCH 0450 / 2560
TEST: BATCH 0451 / 2560
TEST: BATCH 0452 / 2560
TEST: BATCH 0453 / 2560
TEST: BATCH 0454 / 2560
TEST: BATCH 0455 / 2560
TEST: BATCH 0456 / 2560
TEST: BATCH 0457 / 2560
TEST: BATCH 0458 / 2560
TEST: BATCH 0459 / 2560
TEST: BATCH 0460 / 2560
TEST: BATCH 0461 / 2560
TEST: BATCH 0462 / 2560
TEST: BATCH 0463 / 2560
TEST: BATCH 0464 / 2560
TEST: BATCH 0465 / 2560
TEST: BATCH 0466 / 2560
TEST: BATCH 0467 / 2560
TEST: BATCH 0468 / 2560
TEST: BATCH 0469 / 2560
TEST: BATCH 0470 / 2560
TEST: BATCH 0471 / 2560
TEST: BATCH 0472 / 2560
TEST: BATCH 0473 / 2560
TEST: BATCH 0474 / 2560
TEST: BATCH 0475 / 2560
TEST: BATCH 0476 / 2560
TEST: BATCH 0477 / 2560
TEST: BATCH 0478 / 2560
TEST: BATCH 0479 / 2560
TEST: BATCH 0480 / 2560
TEST: BATCH 0481 / 2560
TEST: BATCH 0482

TEST: BATCH 0853 / 2560
TEST: BATCH 0854 / 2560
TEST: BATCH 0855 / 2560
TEST: BATCH 0856 / 2560
TEST: BATCH 0857 / 2560
TEST: BATCH 0858 / 2560
TEST: BATCH 0859 / 2560
TEST: BATCH 0860 / 2560
TEST: BATCH 0861 / 2560
TEST: BATCH 0862 / 2560
TEST: BATCH 0863 / 2560
TEST: BATCH 0864 / 2560
TEST: BATCH 0865 / 2560
TEST: BATCH 0866 / 2560
TEST: BATCH 0867 / 2560
TEST: BATCH 0868 / 2560
TEST: BATCH 0869 / 2560
TEST: BATCH 0870 / 2560
TEST: BATCH 0871 / 2560
TEST: BATCH 0872 / 2560
TEST: BATCH 0873 / 2560
TEST: BATCH 0874 / 2560
TEST: BATCH 0875 / 2560
TEST: BATCH 0876 / 2560
TEST: BATCH 0877 / 2560
TEST: BATCH 0878 / 2560
TEST: BATCH 0879 / 2560
TEST: BATCH 0880 / 2560
TEST: BATCH 0881 / 2560
TEST: BATCH 0882 / 2560
TEST: BATCH 0883 / 2560
TEST: BATCH 0884 / 2560
TEST: BATCH 0885 / 2560
TEST: BATCH 0886 / 2560
TEST: BATCH 0887 / 2560
TEST: BATCH 0888 / 2560
TEST: BATCH 0889 / 2560
TEST: BATCH 0890 / 2560
TEST: BATCH 0891 / 2560
TEST: BATCH 0892 / 2560
TEST: BATCH 0893 / 2560
TEST: BATCH 0894

TEST: BATCH 1227 / 2560
TEST: BATCH 1228 / 2560
TEST: BATCH 1229 / 2560
TEST: BATCH 1230 / 2560
TEST: BATCH 1231 / 2560
TEST: BATCH 1232 / 2560
TEST: BATCH 1233 / 2560
TEST: BATCH 1234 / 2560
TEST: BATCH 1235 / 2560
TEST: BATCH 1236 / 2560
TEST: BATCH 1237 / 2560
TEST: BATCH 1238 / 2560
TEST: BATCH 1239 / 2560
TEST: BATCH 1240 / 2560
TEST: BATCH 1241 / 2560
TEST: BATCH 1242 / 2560
TEST: BATCH 1243 / 2560
TEST: BATCH 1244 / 2560
TEST: BATCH 1245 / 2560
TEST: BATCH 1246 / 2560
TEST: BATCH 1247 / 2560
TEST: BATCH 1248 / 2560
TEST: BATCH 1249 / 2560
TEST: BATCH 1250 / 2560
TEST: BATCH 1251 / 2560
TEST: BATCH 1252 / 2560
TEST: BATCH 1253 / 2560
TEST: BATCH 1254 / 2560
TEST: BATCH 1255 / 2560
TEST: BATCH 1256 / 2560
TEST: BATCH 1257 / 2560
TEST: BATCH 1258 / 2560
TEST: BATCH 1259 / 2560
TEST: BATCH 1260 / 2560
TEST: BATCH 1261 / 2560
TEST: BATCH 1262 / 2560
TEST: BATCH 1263 / 2560
TEST: BATCH 1264 / 2560
TEST: BATCH 1265 / 2560
TEST: BATCH 1266 / 2560
TEST: BATCH 1267 / 2560
TEST: BATCH 1268

TEST: BATCH 1617 / 2560
TEST: BATCH 1618 / 2560
TEST: BATCH 1619 / 2560
TEST: BATCH 1620 / 2560
TEST: BATCH 1621 / 2560
TEST: BATCH 1622 / 2560
TEST: BATCH 1623 / 2560
TEST: BATCH 1624 / 2560
TEST: BATCH 1625 / 2560
TEST: BATCH 1626 / 2560
TEST: BATCH 1627 / 2560
TEST: BATCH 1628 / 2560
TEST: BATCH 1629 / 2560
TEST: BATCH 1630 / 2560
TEST: BATCH 1631 / 2560
TEST: BATCH 1632 / 2560
TEST: BATCH 1633 / 2560
TEST: BATCH 1634 / 2560
TEST: BATCH 1635 / 2560
TEST: BATCH 1636 / 2560
TEST: BATCH 1637 / 2560
TEST: BATCH 1638 / 2560
TEST: BATCH 1639 / 2560
TEST: BATCH 1640 / 2560
TEST: BATCH 1641 / 2560
TEST: BATCH 1642 / 2560
TEST: BATCH 1643 / 2560
TEST: BATCH 1644 / 2560
TEST: BATCH 1645 / 2560
TEST: BATCH 1646 / 2560
TEST: BATCH 1647 / 2560
TEST: BATCH 1648 / 2560
TEST: BATCH 1649 / 2560
TEST: BATCH 1650 / 2560
TEST: BATCH 1651 / 2560
TEST: BATCH 1652 / 2560
TEST: BATCH 1653 / 2560
TEST: BATCH 1654 / 2560
TEST: BATCH 1655 / 2560
TEST: BATCH 1656 / 2560
TEST: BATCH 1657 / 2560
TEST: BATCH 1658

TEST: BATCH 2026 / 2560
TEST: BATCH 2027 / 2560
TEST: BATCH 2028 / 2560
TEST: BATCH 2029 / 2560
TEST: BATCH 2030 / 2560
TEST: BATCH 2031 / 2560
TEST: BATCH 2032 / 2560
TEST: BATCH 2033 / 2560
TEST: BATCH 2034 / 2560
TEST: BATCH 2035 / 2560
TEST: BATCH 2036 / 2560
TEST: BATCH 2037 / 2560
TEST: BATCH 2038 / 2560
TEST: BATCH 2039 / 2560
TEST: BATCH 2040 / 2560
TEST: BATCH 2041 / 2560
TEST: BATCH 2042 / 2560
TEST: BATCH 2043 / 2560
TEST: BATCH 2044 / 2560
TEST: BATCH 2045 / 2560
TEST: BATCH 2046 / 2560
TEST: BATCH 2047 / 2560
TEST: BATCH 2048 / 2560
TEST: BATCH 2049 / 2560
TEST: BATCH 2050 / 2560
TEST: BATCH 2051 / 2560
TEST: BATCH 2052 / 2560
TEST: BATCH 2053 / 2560
TEST: BATCH 2054 / 2560
TEST: BATCH 2055 / 2560
TEST: BATCH 2056 / 2560
TEST: BATCH 2057 / 2560
TEST: BATCH 2058 / 2560
TEST: BATCH 2059 / 2560
TEST: BATCH 2060 / 2560
TEST: BATCH 2061 / 2560
TEST: BATCH 2062 / 2560
TEST: BATCH 2063 / 2560
TEST: BATCH 2064 / 2560
TEST: BATCH 2065 / 2560
TEST: BATCH 2066 / 2560
TEST: BATCH 2067

TEST: BATCH 2441 / 2560
TEST: BATCH 2442 / 2560
TEST: BATCH 2443 / 2560
TEST: BATCH 2444 / 2560
TEST: BATCH 2445 / 2560
TEST: BATCH 2446 / 2560
TEST: BATCH 2447 / 2560
TEST: BATCH 2448 / 2560
TEST: BATCH 2449 / 2560
TEST: BATCH 2450 / 2560
TEST: BATCH 2451 / 2560
TEST: BATCH 2452 / 2560
TEST: BATCH 2453 / 2560
TEST: BATCH 2454 / 2560
TEST: BATCH 2455 / 2560
TEST: BATCH 2456 / 2560
TEST: BATCH 2457 / 2560
TEST: BATCH 2458 / 2560
TEST: BATCH 2459 / 2560
TEST: BATCH 2460 / 2560
TEST: BATCH 2461 / 2560
TEST: BATCH 2462 / 2560
TEST: BATCH 2463 / 2560
TEST: BATCH 2464 / 2560
TEST: BATCH 2465 / 2560
TEST: BATCH 2466 / 2560
TEST: BATCH 2467 / 2560
TEST: BATCH 2468 / 2560
TEST: BATCH 2469 / 2560
TEST: BATCH 2470 / 2560
TEST: BATCH 2471 / 2560
TEST: BATCH 2472 / 2560
TEST: BATCH 2473 / 2560
TEST: BATCH 2474 / 2560
TEST: BATCH 2475 / 2560
TEST: BATCH 2476 / 2560
TEST: BATCH 2477 / 2560
TEST: BATCH 2478 / 2560
TEST: BATCH 2479 / 2560
TEST: BATCH 2480 / 2560
TEST: BATCH 2481 / 2560
TEST: BATCH 2482

In [52]:
if mode == 'train':
    %matplotlib inline
    %config InlineBackend.figure_format = 'retina'
    
    plt.plot(train_loss, label='Training loss')
    plt.plot(val_loss, label='Validation loss')
    plt.legend(frameon=False) 

    plt.figure() # 하나의 윈도우를 나타냄, 생략가능
    plt.plot(train_acc, label='Training accuracy')
    plt.plot(val_acc, label='Validation accuracy') 

    
    plt.legend(frameon=False) 