<a href="https://colab.research.google.com/github/KokiNiimura/study/blob/master/Training_openpose_small.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab: switch into the chapter-04 project folder (presumably on a mounted Google Drive).
# Relative paths used later (./data/, weights/, the utils package) resolve against this directory.
%cd /content/drive/My Drive/study/PyTorch_Advanced/04

/content/drive/My Drive/study/PyTorch_Advanced/04


In [None]:
import math
import os
import random
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

In [None]:
# Seed every RNG the notebook touches (torch, NumPy, Python's `random`) with one shared value
# so runs are repeatable.
SEED = 1234
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

In [None]:
from utils.dataloader import make_datapath_list, DataTransform, COCOkeypointsDataset

# Collect image / mask / annotation path lists for both COCO splits under ./data/.
(train_img_list, train_mask_list,
 val_img_list, val_mask_list,
 train_meta_list, val_meta_list) = make_datapath_list(rootpath='./data/')

batch_size = 32

# "Small" training run: the val split is deliberately used as the TRAINING set
# (presumably to keep the run short). phase="train" selects the dataset's training-time
# behaviour inside COCOkeypointsDataset/DataTransform.
# For a full run, pass the train_* lists here instead.
train_dataset = COCOkeypointsDataset(
    val_img_list, val_mask_list, val_meta_list,
    phase="train", transform=DataTransform(),
)
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation is disabled for this small run. To enable it with a separate split:
# val_dataset = COCOkeypointsDataset(
#     val_img_list, val_mask_list, val_meta_list, phase="val", transform=DataTransform())
# val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}
dataloaders_dict = {"train": train_dataloader, "val": None}

In [None]:
from utils.openpose_net import OpenPoseNet
# Build the OpenPose model. No checkpoint is loaded in this cell; train_model moves it to the
# training device.
net = OpenPoseNet()

In [None]:
class OpenPoseLoss(nn.Module):
    """Masked MSE loss summed over the OpenPose prediction stages.

    ``saved_for_loss`` alternates PAF and heatmap predictions stage by stage:
    index ``2*j`` holds stage ``j``'s PAF output, ``2*j + 1`` its heatmap output.
    """

    def __init__(self, num_stages=6):
        """
        Args:
            num_stages: number of (PAF, heatmap) prediction pairs to score.
                Defaults to 6, the value that used to be hard-coded.

        Raises:
            ValueError: if ``num_stages`` is smaller than 1 (the loss would be
                the plain int 0, which cannot be back-propagated).
        """
        super(OpenPoseLoss, self).__init__()
        if num_stages < 1:
            raise ValueError("num_stages must be >= 1, got {}".format(num_stages))
        self.num_stages = num_stages

    def forward(self, saved_for_loss, heatmap_target, heat_mask, paf_target, paf_mask):
        """Return the summed masked MSE over all stages.

        Args:
            saved_for_loss: indexable sequence of at least ``2 * num_stages``
                prediction tensors (PAF at even, heatmap at odd indices).
            heatmap_target: heatmap ground truth (cast to float).
            heat_mask: mask applied to both heatmap predictions and targets.
            paf_target: PAF ground truth (cast to float).
            paf_mask: mask applied to both PAF predictions and targets.

        Returns:
            Scalar tensor: sum over stages of
            ``mse(pred_paf * paf_mask, paf_target * paf_mask)
            + mse(pred_heat * heat_mask, heatmap_target * heat_mask)``.
            Each MSE uses reduction='mean', so it averages over every element,
            including masked-out ones.
        """
        # The masked targets are the same for every stage, so compute them once.
        gt_paf = paf_target.float() * paf_mask
        gt_heat = heatmap_target.float() * heat_mask

        total_loss = 0
        for j in range(self.num_stages):
            pred_paf = saved_for_loss[2 * j] * paf_mask
            pred_heat = saved_for_loss[2 * j + 1] * heat_mask

            total_loss += F.mse_loss(pred_paf, gt_paf, reduction='mean') + \
                F.mse_loss(pred_heat, gt_heat, reduction='mean')

        return total_loss

criterion = OpenPoseLoss()

In [None]:
# Plain SGD with momentum plus L2 weight decay over every model parameter.
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=1e-4,
)

In [None]:
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs,
                weights_dir="weights", log_path="log_output_openpose.csv"):
    """Train ``net`` for ``num_epochs`` epochs, logging losses and saving checkpoints.

    Args:
        net: model whose forward pass returns a pair ``(output, saved_for_loss)``;
            only ``saved_for_loss`` is handed to ``criterion``.
        dataloaders_dict: mapping with a required ``"train"`` DataLoader and an
            optional ``"val"`` DataLoader. Each batch must unpack to
            ``(images, heatmap_target, heat_mask, paf_target, paf_mask)``.
            The validation phase is skipped when ``"val"`` is missing or ``None``.
        criterion: callable ``(saved_for_loss, heatmap_target, heat_mask,
            paf_target, paf_mask) -> scalar loss tensor``.
        optimizer: optimizer over ``net``'s parameters.
        num_epochs: number of epochs to run.
        weights_dir: directory that receives ``openpose_<epoch>.pth`` state dicts
            after every epoch; created if it does not exist.
        log_path: CSV file rewritten after every epoch with the columns
            ``epoch``, ``train_loss`` and ``val_loss``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    net.to(device)

    # Let cuDNN benchmark convolution algorithms; this helps when input sizes
    # stay the same from batch to batch.
    torch.backends.cudnn.benchmark = True

    num_train_imgs = len(dataloaders_dict["train"].dataset)
    batch_size = dataloaders_dict["train"].batch_size

    val_loader = dataloaders_dict.get("val")
    num_val_imgs = len(val_loader.dataset) if val_loader is not None else 0

    # Create the checkpoint directory up front so torch.save cannot fail after
    # a whole epoch of training has already been spent.
    os.makedirs(weights_dir, exist_ok=True)

    iteration = 1
    logs = []  # one dict per epoch; the whole history is rewritten to log_path each epoch

    for epoch in range(num_epochs):
        t_epoch_start = time.time()
        t_iter_start = time.time()
        epoch_train_loss = 0.0
        epoch_val_loss = 0.0

        print("---------------")
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print("---------------")

        for phase in ["train", "val"]:
            loader = dataloaders_dict.get(phase)
            if loader is None:
                continue

            if phase == "train":
                net.train()
                print("(train)")
            else:
                net.eval()
                print("---------------")
                print("(val)")

            for images, heatmap_target, heat_mask, paf_target, paf_mask in loader:
                # Single-sample batches are skipped (kept from the original code;
                # presumably for layers that need >1 sample per batch — TODO confirm).
                if images.size()[0] == 1:
                    continue

                images = images.to(device)
                heatmap_target = heatmap_target.to(device)
                heat_mask = heat_mask.to(device)
                paf_target = paf_target.to(device)
                paf_mask = paf_mask.to(device)

                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == "train"):
                    _, saved_for_loss = net(images)

                    loss = criterion(saved_for_loss, heatmap_target,
                                     heat_mask, paf_target, paf_mask)

                    # Drop the Python reference to the per-stage outputs once the loss exists.
                    del saved_for_loss

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                        if iteration % 10 == 0:
                            t_iter_finish = time.time()
                            duration = t_iter_finish - t_iter_start
                            # NOTE(review): if criterion already averages over the batch
                            # (OpenPoseLoss uses reduction='mean'), dividing by batch_size
                            # here and by the image count below shrinks the reported loss
                            # by an extra factor; kept so logged values stay comparable.
                            print('iteration {} || Loss: {:.4f} || 10iter: {:.4f} sec.'.format(
                                iteration, loss.item()/batch_size, duration))
                            # Restart the timer so each report covers only the last 10 iterations.
                            t_iter_start = time.time()

                        epoch_train_loss += loss.item()
                        iteration += 1
                    else:
                        epoch_val_loss += loss.item()

        t_epoch_finish = time.time()
        train_loss = epoch_train_loss / num_train_imgs
        val_loss = epoch_val_loss / num_val_imgs if num_val_imgs else 0

        print('--------------')
        print('epoch {} || Epoch_TRAIN_Loss: {:.4f} || Epoch_VAL_Loss: {:.4f}'.format(
            epoch+1, train_loss, val_loss))
        print('timer: {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))

        logs.append({'epoch': epoch+1, 'train_loss': train_loss, 'val_loss': val_loss})
        pd.DataFrame(logs).to_csv(log_path)

        torch.save(net.state_dict(),
                   os.path.join(weights_dir, 'openpose_' + str(epoch+1) + '.pth'))

In [None]:
# Short 2-epoch run. Each epoch saves weights/openpose_<epoch>.pth and rewrites
# log_output_openpose.csv.
num_epochs = 2
train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

device: cuda:0
---------------
Epoch 1/2
---------------
(train)
iteration 10 || Loss: 0.0092 || 10iter: 446.4787 sec.
iteration 20 || Loss: 0.0083 || 10iter: 826.9023 sec.
iteration 30 || Loss: 0.0068 || 10iter: 1186.9212 sec.
iteration 40 || Loss: 0.0057 || 10iter: 1532.4195 sec.
iteration 50 || Loss: 0.0052 || 10iter: 1839.6095 sec.
iteration 60 || Loss: 0.0042 || 10iter: 2138.0834 sec.
iteration 70 || Loss: 0.0036 || 10iter: 2398.5133 sec.
iteration 80 || Loss: 0.0034 || 10iter: 2652.2420 sec.
iteration 90 || Loss: 0.0028 || 10iter: 2931.7767 sec.
iteration 100 || Loss: 0.0025 || 10iter: 3169.8675 sec.
iteration 110 || Loss: 0.0023 || 10iter: 3423.6697 sec.
iteration 120 || Loss: 0.0020 || 10iter: 3646.1087 sec.
iteration 130 || Loss: 0.0017 || 10iter: 3880.0846 sec.
iteration 140 || Loss: 0.0018 || 10iter: 4102.8560 sec.
iteration 150 || Loss: 0.0017 || 10iter: 4328.1095 sec.
--------------


NameError: ignored