In [2]:
import import_ipynb 
from torch import optim, nn
import torch
import os.path as osp
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from config import epochs, device, data_folder, epochs, checkpoint_folder
from dataset import create_datasets
from model import resnet18

In [3]:
def train_val(net, trainloader, valloader, criteron, epochs, device, model_name="cls"):
    best_acc = 0.0
    best_loss = 1e9
    writer = SummaryWriter("log")

    
    # 載入之前的訓練
    if osp.exists(osp.join(checkpoint_folder, model_name + ".pth")):
        net.load_state_dict(torch.load(osp.join(checkpoint_folder, model_name + ".pth")))
        print("模型已載入")
        
    for n, (num_epochs, lr) in enumerate(epochs):
        optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=5e-4, momentum=0.9)
        for epoch in range(num_epochs):
            net.train()
            epoch_loss = 0.0
            epoch_acc = 0.0
            for i, (img, label) in tqdm(enumerate(trainloader), total=len(trainloader)):
                # GPU
                img, label = img.to(device), label.to(device)
                output = net(img)
                # 清空梯度
                optimizer.zero_grad()
                # 損失計算
                loss= criteron(output, label)
                # 反向傳播
                loss.backward()
                # 更新參數
                optimizer.step()
                
                # 分類問題使用acc rate來衡量模型
                # 但回歸模型無法使用此方法計算
                if model_name == "cls":
                    pred = torch.argmax(output, dim=1)
                    acc = torch.sum(pred == label)
                    # 累計acc
                    epoch_acc += acc.item()
                epoch_loss += loss.item() * img.shape[0]
            # 計算這個epoch的avg_loss
                epoch_loss /= len(trainloader.dataset)
            if model_name == "cls":
                # 計算這個epoch的avg_acc
                epoch_acc /= len(trainloader.dataset)
                print("epoch loss: {:.8f} epoch accuracy : {:.8f}".format(epoch_loss, epoch_acc))
            
                # 將loss增加到TensorBoard中
                writer.add_scalar(
                                  "epoch_loss_{}".format(model_name), 
                                  epoch_loss, 
                                  sum([e[0] for e in epochs[:n]]) + epoch,
                                 )
                # 將acc增加到TensorBoard中
                writer.add_scalar(
                                  "epoch_acc_{}".format(model_name), 
                                  epoch_acc, 
                                  sum([e[0] for e in epochs[:n]]) + epoch,
                                 )
            else:
                print("epoch loss: {:.8f}".format(epoch_loss))
                writer.add_scalar(
                                  "epoch_loss_{}".format(model_name), 
                                  epoch_loss, 
                                  sum([e[0] for e in epochs[:n]]) + epoch,
                                 )
            # 在無梯度模式下快速驗證
            with torch.no_grad():
                # 將net設為驗證模式
                net.eval()
                val_loss = 0.0
                val_acc = 0.0
                for i, (img,label) in tqdm(
                enumerate(valloader), total=len(valloader)
                ):
                    img, label = img.to(device), label.to(device)
                    output = net(img)
                    loss = criteron(output, label)
                    if model_name == "cls":
                        pred = torch.argmax(output, dim=1)
                        acc = torch.sum(pred == label)
                        val_acc += acc.item()
                    val_loss += loss.item() * img.shape[0]
                val_loss /= len(valloader.dataset)
                val_acc /= len(valloader.dataset)
                if model_name == "cls":
                    # 驗證後的模型超越先前最好的模型
                    if val_acc > best_acc:
                        # update best_acc
                        best_acc = val_acc
                        # save model
                        torch.save(
                            net.state_dict(),
                            osp.join(checkpoint_folder, model_name + ".pth"),
                        )
                    print("validation loss: {:.8f} validation accuracy : {:.8f}".format(val_loss, val_acc))
                    # 將validation_loss加入TB中
                    writer.add_scalar(
                                  "validation_loss_{}".format(model_name), 
                                  val_loss, 
                                  sum([e[0] for e in epochs[:n]]) + epoch,
                                 )
                    writer.add_scalar(
                                  "validation_acc_{}".format(model_name), 
                                  val_acc, 
                                  sum([e[0] for e in epochs[:n]]) + epoch,
                                 )
                else:
                    # 如果得到的loss比當前的好
                    if val_loss < best_loss:
                        # update best_loss
                        best_loss = val_loss
                        # save model
                        torch.save(
                            net.state_dict(),
                            osp.join(checkpoint_folder, model_name + ".pth"),
                        )
                    print("validation loss: {:.8f}".format(val_loss))
                    writer.add_scalar(
                                  "validation_loss_{}".format(model_name), 
                                  val_loss, 
                                  sum([e[0] for e in epochs[:n]]) + epoch,
                                 )
        writer.close()

In [5]:
if __name__ == "__main__":
    trainloader, valloader = create_datasets(data_folder)
    net = resnet18().to(device)
    criteron = nn.CrossEntropyLoss()
    train_val(net, trainloader, valloader, criteron, epochs, device)

Files already downloaded and verified
Files already downloaded and verified
模型已載入


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:34<00:00, 22.63it/s]


epoch loss: 0.00000555 epoch accuracy : 0.98398000


100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:02<00:00, 73.39it/s]


validation loss: 0.37488687 validation accuracy : 0.90850000


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:31<00:00, 24.75it/s]


epoch loss: 0.00000898 epoch accuracy : 0.98226000


100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 79.48it/s]


validation loss: 0.36882058 validation accuracy : 0.90970000


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.32it/s]


epoch loss: 0.00002874 epoch accuracy : 0.98102000


 25%|████████████████████                                                             | 39/157 [00:00<00:01, 75.42it/s]


KeyboardInterrupt: 