In [1]:
import sys
sys.path.append('.')
from dataset import *
from loss import create_criterion

import argparse
import glob
import json
import multiprocessing
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
import re
from importlib import import_module
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter


import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from typing import Optional, Dict, Union

# from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import smplx


In [2]:
import easydict
args=easydict.EasyDict({
    # Data and model checkpoints directories
    'name':'exp',
    'seed':42,
    'epochs':5,
    'dataset':'temp_dataset',
    'augmentation':'BaseAugmentation', 
    'resize':[512,512], 
    'batch_size':10, 
    'valid_batch_size':64, 
    'model':'Resnet', 
    'optimizer':'Adam', 
    'log_interval':5,
    'lr':0.001, 
    'val_ratio':0.2,
    'criterion':'CustomLoss_joint',
    'lr_decay_step':20, 
    'data_dir':r'F:\ego_cam_dataset', 
    'model_dir':r'C:\Users\user\Documents\GitHub\2D_to_3D_ver3\apps',
    'smpl_dir':r'D:\SMPL\SMPL_python_v.1.0.0\smpl\models'
})

In [3]:
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
        
def increment_path(path, exist_ok=False):
    """ Automatically increment path, i.e. runs/exp --> runs/exp0, runs/exp1 etc.

    Args:
        path (str or pathlib.Path): f"{model_dir}/{args.name}".
        exist_ok (bool): whether increment path (increment if False).
    """
    path = Path(path)
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}*")
        matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]
        n = max(i) + 1 if i else 2
        return f"{path}{n}"

In [4]:

model_folder = r'C:\Users\user\Documents\GitHub\smplx'
model_type = 'smpl'
plot_joints = 'true'
use_face_contour = False
gender = 'female'
ext = 'npz'
num_betas = 10
num_expression_coeffs = 10

def viewer(vector_82s,image=None,plot_joints=True):
    temp_smpl=smplx.create(model_folder, model_type=model_type,
                         gender=gender, use_face_contour=use_face_contour,
                         num_betas=num_betas,
                         num_expression_coeffs=num_expression_coeffs,
                         ext=ext)
    fig = plt.figure(figsize=(6, 70))
    for i,vector_82 in enumerate(vector_82s,start=1):
        go=vector_82[:3].unsqueeze(0).float().to('cpu')
        pose=vector_82[3:72].unsqueeze(0).float().to('cpu')
        shape=vector_82[72:].unsqueeze(0).float().to('cpu')
        output = temp_smpl(betas=shape,global_orient=go,body_pose=pose,return_verts=True)
        
        joints=output.joints.detach().cpu().numpy().squeeze()
        vertices = output.vertices.detach().cpu().numpy().squeeze()
        
        ax = fig.add_subplot(len(vector_82s),1,i, projection='3d')

        mesh = Poly3DCollection(vertices[temp_smpl.faces], alpha=0.1)
        face_color = (1.0, 1.0, 0.9)
        edge_color = (0, 0, 0)
        mesh.set_edgecolor(edge_color)
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)
        ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')

        if plot_joints:
            ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], alpha=0.1)

    plt.show()

    return fig



def grid_image(np_images, gts, preds, n=16, shuffle=False):
    batch_size = np_images.shape[0]
    assert n <= batch_size

    choices = random.choices(range(batch_size), k=n) if shuffle else list(range(n))
    figure = plt.figure(figsize=(12, 18 + 2))  # cautions: hardcoded, 이미지 크기에 따라 figsize 를 조정해야 할 수 있습니다. T.T
    plt.subplots_adjust(top=0.8)  # cautions: hardcoded, 이미지 크기에 따라 top 를 조정해야 할 수 있습니다. T.T
    n_grid = int(np.ceil(n ** 0.5))
    tasks = ["mask", "gender", "age"]
    for idx, choice in enumerate(choices):
        gt = gts[choice].item()
        pred = preds[choice].item()
        image = np_images[choice]
        gt_decoded_labels = MaskBaseDataset.decode_multi_class(gt)
        pred_decoded_labels = MaskBaseDataset.decode_multi_class(pred)
        title = "\n".join([
            f"{task} - gt: {gt_label}, pred: {pred_label}"
            for gt_label, pred_label, task
            in zip(gt_decoded_labels, pred_decoded_labels, tasks)
        ])

        plt.subplot(n_grid, n_grid, idx + 1, title=title)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(image, cmap=plt.cm.binary)

    return figure

In [5]:
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    save_dir = increment_path(os.path.join(model_dir, args.name))
    print(save_dir)
    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # -- dataset
    dataset_module = getattr(import_module("dataset"), args.dataset)  # default: MaskBaseDataset
    dataset = dataset_module(
        dataroot=data_dir,
    )
    # num_classes = dataset.num_classes  # 18

    # -- augmentation
    transform_module = getattr(import_module("dataset"), args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        # mean=dataset.mean,
        # std=dataset.std,
    )
    dataset.set_transform(transform)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        # num_workers=multiprocessing.cpu_count() // 2,
        num_workers= 0,
        shuffle=True,
        pin_memory=use_cuda,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        batch_size=args.valid_batch_size,
        # num_workers=multiprocessing.cpu_count() // 2,
        num_workers= 0,
        shuffle=False,
        pin_memory=use_cuda,
        drop_last=True,
    )

    # -- model
    model_module = getattr(import_module("model"), args.model)  # default: BaseModel
    model = model_module(
        num_classes=82
    ).to(device)
    model = torch.nn.DataParallel(model)

    # -- loss & metric
    criterion = create_criterion(args.criterion)  # default: cross_entropy
    opt_module = getattr(import_module("torch.optim"), args.optimizer)  # default: SGD
    optimizer = opt_module(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=5e-4
    )
    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

    # -- logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0
        for idx, train_batch in enumerate(train_loader):
            ret_dict_train = train_batch
            inputs, labels = ret_dict_train['image'],ret_dict_train['joints']
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            loss = criterion(outs, labels)
   
            loss.backward()
            optimizer.step()

            loss_value += loss.item()
            # matches += (preds == labels).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss = loss_value / args.log_interval
                # train_acc = matches / args.batch_size / args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || lr {current_lr}"
                    # f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                )
                # viewer(outs)
                if (idx + 1) % 200 == 0:
                    viewer(outs)
                    torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
                # logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
                # logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)

                loss_value = 0
                matches = 0

        scheduler.step()

        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            figure = None
            for val_batch in val_loader:
                ret_dict_val = val_batch
                inputs, labels = ret_dict_val['image'],ret_dict_val['joints']
                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                loss_item = criterion(outs, labels).item()
                # acc_item = (labels == preds).sum().item()
                val_loss_items.append(loss_item)
                # val_acc_items.append(acc_item)

                # if figure is None:
                #     inputs_np = torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                #     inputs_np = dataset_module.denormalize_image(inputs_np, dataset.mean, dataset.std)
                #     figure = grid_image(
                #         inputs_np, labels, preds, n=16, shuffle=args.dataset != "MaskSplitByProfileDataset"
                #     )

            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_set)
            best_val_loss = min(best_val_loss, val_loss)
            if val_acc > best_val_acc:
                print(f"New best model for val accuracy : {val_acc:4.2%}! saving the best model..")
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_acc = val_acc
            torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
                f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
            )
            # logger.add_scalar("Val/loss", val_loss, epoch)
            # logger.add_scalar("Val/accuracy", val_acc, epoch)
            # logger.add_figure("results", figure, epoch)
            print()

In [6]:
# raise

In [None]:
data_dir = args.data_dir
model_dir = args.model_dir
print(data_dir)
train(data_dir, model_dir, args)

In [None]:
raise

RuntimeError: No active exception to reraise

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

data_dir=r'F:\ego_cam_dataset'

# -- dataset
dataset_module = getattr(import_module("dataset"), 'temp_dataset')  # default: MaskBaseDataset
dataset = dataset_module(
	dataroot=data_dir,
)
# num_classes = dataset.num_classes  # 18

# -- augmentation
transform_module = getattr(import_module("dataset"), 'BaseAugmentation')  # default: BaseAugmentation
transform = transform_module(
	resize=[512,512],
	# mean=dataset.mean,
	# std=dataset.std,
)
dataset.set_transform(transform)

# -- data_loader
train_set, val_set = dataset.split_dataset()



train_loader = DataLoader(
	train_set,
	batch_size=args.batch_size,
	# num_workers=multiprocessing.cpu_count() // 2,
	num_workers= 0,
	shuffle=True,
	pin_memory=use_cuda,
	drop_last=True,
)

for batch in train_loader:
	print(batch)
	break



{'image': tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,

In [None]:
temp_path=r'F:\ego_cam_dataset\female_001_a_a\env_001\cam_down\json\female_001_a_a_000001.json'

import json
temp_json=None
with open(temp_path,'r') as f:
	temp_json = json.loads(f.read())

torch.tensor(temp_json['pts3d_fisheye'])

tensor([[ -0.6680,  -0.6680,  -0.6098,  15.6957,  17.7588,  36.6055,  60.1680,
          71.7993,  74.8752,  77.6928,  80.2030,  72.7545,  75.6986,  78.3872,
          80.7347,  70.5293,  72.9097,  74.8153,  76.4923,  72.8526,  75.2624,
          77.4694,  79.3901,  64.4243,  67.4582,  69.8666,  71.7331,  13.2310,
           4.8682,  16.8883,  16.0242,   7.0379,  -0.6601, -17.0028, -19.9913,
         -37.9556, -61.5204, -73.6135, -76.4736, -79.1507, -81.4400, -73.3672,
         -76.4634, -79.3759, -81.9084, -73.2427, -75.1171, -76.5951, -77.8073,
         -73.2462, -75.9175, -78.4076, -80.6400, -65.6441, -68.6585, -71.3106,
         -73.5754, -15.0040,  -6.1841, -19.2568, -18.4758,  -8.2603,  -0.6197,
          -0.6315,  -0.6449],
        [-18.7340, -12.2619, -19.9708, -19.7429, -23.3717, -19.6090, -19.4217,
         -17.3664, -17.3415, -17.3188, -17.2985, -19.3198, -19.2961, -19.2743,
         -19.2553, -22.6989, -22.6797, -22.6643, -22.6508, -21.1558, -21.1363,
         -21.1185, -21