In [1]:
import os
import sys

pwd = os.path.sep.join(os.getcwd().split(sep=os.path.sep)[:-2])
sys.path.append(pwd)

import dataset
import model
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- 데이터 셋 및 데이터 로더 정의 

In [7]:
mask_dataset = dataset.MaskImageDataset(dataset.csv_path, 
                                        dataset.data_path,
                                       dataset.data_transform)

data_loader = DataLoader(mask_dataset,
                        shuffle=True,
                        batch_size=10, 
                        num_workers=4)

- 모델 정의 및 옵티마이저, 손실 함수 정의

In [8]:
model_state_dict = torch.load('test_model.ckpt')

test_model = model.classificationModel().to(device)
test_model.train()
test_model.load_state_dict(model_state_dict)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(test_model.parameters(), lr=0.0001)

- 학습

In [9]:
for epoch in range():
    
    running_loss = 0.0
    for i, data in enumerate(data_loader, 0):
        
        inputs, labels = data
        
        optimizer.zero_grad()
        
        labels = torch.flatten(torch.argmax(labels, dim=2))

        outputs = test_model(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
torch.save(test_model.state_dict(), 'test_model.ckpt')

[1,   100] loss: 0.007
[1,   200] loss: 0.004
[1,   300] loss: 0.004
[1,   400] loss: 0.006
[1,   500] loss: 0.005
[1,   600] loss: 0.008
[1,   700] loss: 0.010
[1,   800] loss: 0.008
[1,   900] loss: 0.007
[1,  1000] loss: 0.007
[1,  1100] loss: 0.005
[1,  1200] loss: 0.006
[1,  1300] loss: 0.006
[1,  1400] loss: 0.004
[1,  1500] loss: 0.006
[1,  1600] loss: 0.005
[1,  1700] loss: 0.006
[1,  1800] loss: 0.006
[2,   100] loss: 0.004
[2,   200] loss: 0.005
[2,   300] loss: 0.004
[2,   400] loss: 0.005
[2,   500] loss: 0.006
[2,   600] loss: 0.005
[2,   700] loss: 0.007
[2,   800] loss: 0.007
[2,   900] loss: 0.006
[2,  1000] loss: 0.006
[2,  1100] loss: 0.006
[2,  1200] loss: 0.005
[2,  1300] loss: 0.006
[2,  1400] loss: 0.006
[2,  1500] loss: 0.006
[2,  1600] loss: 0.006
[2,  1700] loss: 0.006
[2,  1800] loss: 0.005
[3,   100] loss: 0.004
[3,   200] loss: 0.004
[3,   300] loss: 0.005
[3,   400] loss: 0.005
[3,   500] loss: 0.006
[3,   600] loss: 0.004
[3,   700] loss: 0.005
[3,   800] 