In [1]:
import import_ipynb
from data import FaceData
from model import UNet
import config
import utils
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

importing Jupyter notebook from data.ipynb
importing Jupyter notebook from utils.ipynb
importing Jupyter notebook from config.ipynb
importing Jupyter notebook from model.ipynb


In [2]:
# Select matplotlib's 'ggplot' style sheet for figures created after this cell.
plt.style.use('ggplot')

In [3]:
# Collect the image and mask file paths found under 'dataset'.
# NOTE(review): assumes the two lists are parallel (imagePaths[i] pairs with
# maskPaths[i]) -- confirm against utils.get_path_list.
imagePaths, maskPaths = utils.get_path_list('dataset')

# Split into training and test subsets using the configured test fraction.
# Both lists go through ONE train_test_split call so each image is guaranteed
# to land in the same subset as its mask; the call also raises a ValueError if
# the two lists differ in length instead of silently misaligning the pairs.
# For each input array the function returns its (train, test) halves in order.
(trainImages, testImages, trainMasks, testMasks) = train_test_split(
    imagePaths, maskPaths, test_size=config.TEST_SPLIT, random_state=42)

In [4]:
# Preprocessing pipeline: convert the input to a tensor, then resize it to
# config.IMAGE_RESIZE.
# Bind the pipeline to `trns` only. Assigning it to `transforms` as well would
# shadow the torchvision.transforms module, so re-running this cell (or any
# later use of `transforms.<X>`) would fail with an AttributeError.
trns = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(config.IMAGE_RESIZE)])

In [5]:
# Build one FaceData dataset per split; both receive the same transform.
trainDS = FaceData(trainImages, trainMasks, trns)
testDS = FaceData(testImages, testMasks, trns)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")

# DataLoader options shared by both splits; batches are loaded in the main
# process (num_workers=0).
_loaderOptions = dict(
    batch_size=config.BATCH_SIZE,
    pin_memory=config.PIN_MEMORY,
    num_workers=0,
)
# Only the training data is shuffled; the test data keeps a fixed order.
trainLoader = DataLoader(trainDS, shuffle=True, **_loaderOptions)
testLoader = DataLoader(testDS, shuffle=False, **_loaderOptions)

[INFO] found 24000 examples in the training set...
[INFO] found 6000 examples in the test set...


In [6]:
# Instantiate the U-Net and move its parameters to the configured device.
unet = UNet(config.NUM_CHANNELS, config.NUM_CLASSES).to(config.DEVICE)

# Optional: resume from the most recent checkpoint. The training cell saves a
# bare state_dict (torch.save(unet.state_dict(), ...)), so it is loaded
# directly rather than through a 'model_state_dict' key.
# if os.path.exists('./output/last_model.pth'):
#     unet.load_state_dict(torch.load('./output/last_model.pth', map_location=config.DEVICE))

# Binary cross-entropy computed on the raw (pre-sigmoid) network output.
lossFunc = BCEWithLogitsLoss()
# NOTE(review): weight_decay=0.1 is a strong L2 penalty for Adam -- confirm
# this value is intentional.
opt = Adam(unet.parameters(), lr=config.INIT_LR, betas=(0.9, 0.999), weight_decay=0.1)

# Number of batches per epoch, used to average the accumulated loss.
# Taken from the loaders themselves so that a final partial batch is counted;
# len(dataset) // batch_size would drop it and bias the average upwards.
trainSteps = len(trainLoader)
testSteps = len(testLoader)
# Dictionary that tracks the per-epoch loss history.
H = {"train_loss": [], "test_loss": []}

In [None]:
print("[INFO] training the network...")

# Kept so later cells that reference these names still find them; the loss
# history itself is recorded in H.
train_loss_list = []
val_loss_list = []

# Lowest average test loss seen so far. It is initialised once, before the
# epoch loop, so best_model.pth is only overwritten on a real improvement.
best_val_loss = float("inf")

# Checkpoints are written under 'output/'; make sure the directory exists.
os.makedirs('output', exist_ok=True)

# Batches per epoch, taken from the loaders so a final partial batch counts.
# max(..., 1) avoids a ZeroDivisionError if a loader is empty.
numTrainBatches = max(len(trainLoader), 1)
numTestBatches = max(len(testLoader), 1)

startTime = time.time()
for e in range(config.NUM_EPOCHS):
    # Validation switches the model to eval mode, so switch back to train
    # mode at the start of every epoch.
    unet.train()

    # Accumulate plain Python floats (via .item()) rather than tensors so the
    # running totals do not keep autograd history alive across iterations.
    totalTrainLoss = 0.0
    totalTestLoss = 0.0

    # training
    for (images, masks) in tqdm(trainLoader, total=len(trainLoader)):
        (images, masks) = (images.to(config.DEVICE), masks.to(config.DEVICE))
        # forward pass
        pred = unet(images)
        # training loss
        loss = lossFunc(pred, masks)
        # clear old gradients, backpropagate, update the weights
        opt.zero_grad()
        loss.backward()
        opt.step()
        totalTrainLoss += loss.item()

    # validation: eval mode, no gradient tracking
    unet.eval()
    with torch.no_grad():
        for (images, masks) in tqdm(testLoader, total=len(testLoader)):
            (images, masks) = (images.to(config.DEVICE), masks.to(config.DEVICE))
            pred = unet(images)
            totalTestLoss += lossFunc(pred, masks).item()

    avgTrainLoss = totalTrainLoss / numTrainBatches
    avgTestLoss = totalTestLoss / numTestBatches
    # update training history (plain floats)
    H["train_loss"].append(avgTrainLoss)
    H["test_loss"].append(avgTestLoss)

    print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))

    # Save a separate "best" checkpoint only when the test loss improves.
    if avgTestLoss < best_val_loss:
        best_val_loss = avgTestLoss
        torch.save(unet.state_dict(), 'output/best_model.pth')

    # Always save the latest weights so training can be resumed.
    torch.save(unet.state_dict(), 'output/last_model.pth')

    # Refresh the loss plot after every epoch.
    utils.save_loss_plot(
        config.BASE_OUTPUT, H["train_loss"], H["test_loss"]
    )

endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

[INFO] training the network...


 18%|██████████████▋                                                                  | 68/375 [02:30<11:01,  2.15s/it]