In [1]:
import os
import sys

pwd = os.path.sep.join(os.getcwd().split(sep=os.path.sep)[:-1])
sys.path.append(pwd)

import dataset.dataset as dataset
import model.model as model
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader
import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- 데이터 셋 및 데이터 로더 정의 

In [2]:
mask_dataset = dataset.MaskImageDataset(dataset.csv_path, 
                                        dataset.data_path,
                                       dataset.data_transform)

data_loader = DataLoader(mask_dataset,
                        shuffle=True,
                        batch_size=10, 
                        num_workers=4)

- 모델 정의 및 옵티마이저, 손실 함수 정의

In [3]:
test_model = model.classificationModel().to(device)
test_model.train()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(test_model.parameters(), lr=0.0001)

- 학습

In [7]:
for epoch in range(5):
    
    running_loss = 0.0
    for i, data in enumerate(tqdm(data_loader), 0):
        
        inputs, labels = data
        
        optimizer.zero_grad()
        
        labels = torch.flatten(torch.argmax(labels, dim=2))

        outputs = test_model(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

[1,   100] loss: 0.025
[1,   200] loss: 0.024
[1,   300] loss: 0.023
[1,   400] loss: 0.025
[1,   500] loss: 0.023
[1,   600] loss: 0.026
[1,   700] loss: 0.028
[1,   800] loss: 0.025
[1,   900] loss: 0.028
[1,  1000] loss: 0.022
[1,  1100] loss: 0.022
[1,  1200] loss: 0.027
[1,  1300] loss: 0.021
[1,  1400] loss: 0.024
[1,  1500] loss: 0.022
[1,  1600] loss: 0.023
[1,  1700] loss: 0.024
[1,  1800] loss: 0.024
Finished Training


In [8]:
torch.save(test_model.state_dict(), 'test_model.ckpt')