In [15]:

import torch
import torch.nn as nn
import torch.optim as optim

from util import AverageMeter

from cnn import CNN

from create_lmdb_dataset import createDataset

from MakeCSV import MakeCSV, CalcMeanStdFromCSV

from torch.autograd import Variable

import dataset


In [2]:
# Training hyper-parameters.
BATCH_SIZE, EPOCH = 600, 100

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Seed PyTorch's global RNG so repeated runs are comparable.
torch.manual_seed(777)

# Module-level results that `validation` overwrites through `global`:
# the last accuracy-minus-loss score and the last checkpoint path.
total_rate, filename = 0.0, ''

In [7]:
# MakeCSV('./dataset_val/number', 'dataset_val_number.csv')
# MakeCSV('./dataset_val/alphabet', 'dataset_val_alphabet.csv')

# CalcMeanStdFromCSV(dataset_val_number)
# CalcMeanStdFromCSV(dataset_val_alphabet)

# createDataset('./dataset_train_number.csv', './lmdb_data/number/train')
# createDataset('./dataset_val_number.csv', './lmdb_data/number/val')
# createDataset('./dataset_train_alphabet.csv', './lmdb_data/alphabet/train')
# createDataset('./dataset_val_alphabet.csv', './lmdb_data/alphabet/val')

In [3]:
def data_loader(train_root='./lmdb_data/number/train',
                val_root='./lmdb_data/number/val',
                val_batch_size=10):
    """Build the training and validation DataLoaders from LMDB datasets.

    The defaults point at the digit ("number") LMDBs. For the alphabet
    variant pass './lmdb_data/alphabet/train' and './lmdb_data/alphabet/val'.

    Args:
        train_root: directory of the training LMDB.
        val_root: directory of the validation LMDB.
        val_batch_size: mini-batch size used for validation.

    Returns:
        (train_loader, val_loader)

    Raises:
        ValueError: if either dataset is empty (falsy). An explicit check
            replaces `assert`, which is stripped when Python runs with -O.
    """
    train_dataset = dataset.lmdbDataset(root=train_root)
    if not train_dataset:
        raise ValueError('training dataset at {!r} is empty'.format(train_root))

    # Training: reshuffle each epoch and drop the ragged final batch so every
    # optimiser step sees exactly BATCH_SIZE samples.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )

    val_dataset = dataset.lmdbDataset(root=val_root)
    if not val_dataset:
        raise ValueError('validation dataset at {!r} is empty'.format(val_root))

    # Validation must score every sample. With drop_last=True up to
    # val_batch_size - 1 samples were skipped, and with shuffle=True a
    # different subset was skipped each epoch, making the metrics noisy.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    return train_loader, val_loader



In [11]:
def train(model, train_loader, optm, criterion):
    """Run one training epoch of `model` over `train_loader`.

    The per-batch loss and accuracy are printed exactly as before. The
    sample-weighted epoch averages are now also returned; callers that
    ignore the return value are unaffected.

    Args:
        model: network to train; it is switched to train mode here.
        train_loader: iterable of (X, Y) batches. A channel axis is inserted
            into X at dim 1 before the forward pass (presumably
            (batch, H, W) -> (batch, 1, H, W); confirm against the dataset).
        optm: optimizer over `model`'s parameters.
        criterion: loss called as criterion(logits, Y).

    Returns:
        (epoch_loss, epoch_acc): averages over every sample seen this epoch,
        or (nan, nan) if `train_loader` yielded no batches.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for X, Y in train_loader:
        # Plain tensors carry autograd state. torch.autograd.Variable has
        # been a no-op wrapper since PyTorch 0.4, so it is not used here.
        X = X.unsqueeze(1).to(DEVICE)
        Y = Y.to(DEVICE)

        optm.zero_grad()
        hypothesis = model(X)
        loss = criterion(hypothesis, Y)
        loss.backward()
        optm.step()

        pred_label = torch.argmax(hypothesis, 1)
        batch_correct = (pred_label == Y).sum().item()
        batch_size = len(X)

        train_loss = loss.item()
        train_acc = batch_correct / batch_size
        print('train_loss : ', train_loss, ', ', 'train_acc : ', train_acc)

        # Assumes criterion returns the batch mean (the nn.CrossEntropyLoss
        # default), so weight by batch size for a per-sample epoch average.
        loss_sum += train_loss * batch_size
        correct += batch_correct
        seen += batch_size

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen
    

def validation(model, val_loader, criterion, epoch):
    """Evaluate `model` on `val_loader`, then TorchScript-trace and save it.

    Side effects:
      * sets the module-level globals `total_rate` (mean accuracy minus mean
        loss) and `filename` (path of the saved checkpoint);
      * creates ./model_result if needed and saves the traced model there;
      * prints the checkpoint path, accuracy and loss.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        val_loader: iterable of (X, Y) batches. A channel axis is inserted
            into X at dim 1 before the forward pass.
        criterion: loss called as criterion(logits, Y).
        epoch: epoch index, embedded in the checkpoint filename.

    Returns:
        (valid_loss, valid_acc, filename): sample-weighted averages and the
        checkpoint path. Callers that ignore the return value are unaffected.

    Raises:
        ValueError: if `val_loader` yields no batches. Previously this
            surfaced as a NameError on the unbound loop variable.
    """
    import os

    global total_rate, filename

    model.eval()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()

    last_batch = None  # index of the final batch; embedded in the filename
    with torch.no_grad():
        for last_batch, (X, Y) in enumerate(val_loader):
            X = X.unsqueeze(1).to(DEVICE)
            Y = Y.to(DEVICE)

            pred_logit = model(X)
            loss = criterion(pred_logit, Y)
            pred_label = torch.argmax(pred_logit, 1)
            acc = (pred_label == Y).sum().item() / len(X)

            loss_meter.update(loss.item(), len(X))
            acc_meter.update(acc, len(X))

    if last_batch is None:
        raise ValueError('val_loader yielded no batches; nothing to validate')

    valid_loss = loss_meter.avg
    valid_acc = acc_meter.avg
    total_rate = valid_acc - valid_loss

    out_dir = './model_result'
    os.makedirs(out_dir, exist_ok=True)
    filename = (out_dir + '/' + str(epoch) + '_' + str(last_batch) + '_'
                + 'model' + '_' + str(total_rate) + '.pt')

    # The trace example must be on the same device as the model's weights.
    # A CPU tensor fed to a CUDA model raises a device-mismatch error.
    # Assumes single-channel 28x28 inputs, matching the hard-coded shape.
    example = torch.randn(1, 1, 28, 28, device=DEVICE)
    traced = torch.jit.trace(model, example)
    traced.save(filename)
    print(filename, ', ', 'accuracy : ', valid_acc, ', ', 'loss : ', valid_loss)

    return valid_loss, valid_acc, filename
    

In [12]:
# Wire up data, model, loss and optimiser for the 10-class digit task.
# The alphabet variant uses the ./lmdb_data/alphabet/* LMDBs and CNN(26).
train_loader, val_loader = data_loader()

model = CNN(10)
model = model.to(device=DEVICE)

criterion = nn.CrossEntropyLoss()
criterion = criterion.to(DEVICE)

optm = optim.Adam(params=model.parameters(), lr=5e-4)

In [None]:
# Train for EPOCH epochs, taken from the config cell instead of a
# hard-coded literal. Validate and checkpoint after every epoch;
# `validation` updates the `total_rate` and `filename` globals printed below.
total_rate = 0.0
for epoch in range(EPOCH):
    print('epoch : ', epoch)
    train(model, train_loader, optm, criterion)
    validation(model, val_loader, criterion, epoch)
    print('total_rate : ', total_rate, 'filename : ', filename)


In [None]:
# quantized_model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8, inplace=True)
# filename = './model_result_quant' + '/' + 'quant' + '_' + str(epoch) + '_' + str(iter) + '_' + 'model' + '_' + str(total_rate) + '.pt'
# traced = torch.jit.trace(quantized_model, torch.randn(1, 1, 28, 28))
# traced.save(filename)