In [1]:
import sys
sys.path.append('.')
from dataset import *
from loss import create_criterion

from model import get_pose_net
import argparse
import glob
import json
import multiprocessing
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
import re
import platform
from importlib import import_module
from pathlib import Path
# from torch.utils.tensorboard import SummaryWriter
import wandb

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR,LambdaLR
from torch.utils.data import DataLoader
from typing import Optional, Dict, Union

# from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import smplx

from torchvision.transforms.functional import to_pil_image
from collections import OrderedDict
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Collect unreachable Python objects, then ask PyTorch to release the CUDA
# memory held by its caching allocator.
import gc
gc.collect()
torch.cuda.empty_cache()

In [3]:
# Pick the compute device once; `use_cuda` is also read later to decide
# whether the DataLoaders pin host memory.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [4]:
import easydict

# Experiment configuration, accessed attribute-style (`args.lr`, `args.epochs`, ...).
args = easydict.EasyDict({
    # Run identity / schedule
    'name': 'exp',
    'seed': 42,
    'epochs': 5,
    'dataset': 'temp_dataset',
    'augmentation': 'BaseAugmentation',
    'resize': [512, 512],
    'batch_size': 20,
    'valid_batch_size': 20,
    'model': 'TempModel',
    'optimizer': 'Adam',
    'log_interval': 5,
    'lr': 0.00025,
    'val_ratio': 0.2,
    # Names handed to `create_criterion`, one per loss term.
    'criterion_2': 'depth_criterion',
    'criterion_3': 'projection_criterion',
    'criterion_4': 'cam_criterion',
    'criterion_5': 'joint_3d_criterion',
    'criterion_6': 'heatmap_criterion',
    'criterion_7': 'heatmap_proj_criterion',
    'lr_decay_step': 6000,
    # Data and model checkpoints directories
    'data_dir': '/dataset/egodataset',
    'model_dir': '/workspace/2d_to_3d/apps',
    'smpl_dir': '/workspace/2d_to_3d/model/smpl',
    # The backslash is escaped explicitly: a bare `\l` is an invalid escape
    # sequence that Python warns about. The string value itself is unchanged.
    # NOTE(review): a backslash inside an otherwise POSIX path looks
    # unintended -- confirm whether this should be 'exp71/last.pth'.
    'model_pretrained_path': '/workspace/2d_to_3d/apps/exp71\\last.pth',
})

In [5]:
# Secondary hyper-parameter dictionary.
# NOTE(review): the key 'criterion_weight' used to be written six times in this
# literal. Python keeps only one entry for a repeated key, so the dictionary
# below is exactly what that literal evaluated to. If one weight per loss term
# was intended, each needs its own key name -- confirm with the author.
config_ = {
    'criterion_weight': 1,
    'lr_decay_step': 20,
    'lr': 0.001,
    'val_ratio': 0.2,
}

In [6]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU, the current
    CUDA device and all CUDA devices), then puts cuDNN into deterministic
    mode with benchmarking switched off.

    Args:
        seed (int): value handed to each library's seeding function.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for seed_fn in torch_seeders:
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: object exposing `param_groups`, an iterable of dicts that
            each hold an 'lr' entry.

    Returns:
        The 'lr' value of the first group, or None when there are no groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)
        
def increment_path(path, exist_ok=False):
    """Return a free run directory name, appending a counter when `path` is taken.

    If `path` does not exist (or `exist_ok` is True) it is returned as-is.
    Otherwise the siblings matching `path*` are scanned for numeric suffixes
    and the next number is appended, i.e. runs/exp --> runs/exp2, runs/exp3, ...
    (numbering starts at 2 when no numbered sibling exists yet).

    Args:
        path (str or pathlib.Path): f"{model_dir}/{args.name}".
        exist_ok (bool): if True, an already existing path is reused instead
            of being incremented.

    Returns:
        str: the chosen path. Nothing is created on disk.
    """
    path = Path(path)
    if exist_ok or not path.exists():
        return str(path)

    # Escape both patterns so names containing glob or regex metacharacters
    # (e.g. "exp[a]" or "exp+v1") are treated literally.
    siblings = glob.glob(f"{glob.escape(str(path))}*")
    suffix_re = re.compile(rf"{re.escape(path.stem)}(\d+)")

    numbers = []
    for sibling in siblings:
        # Only look at the last path component; searching the full path would
        # pick up digits from a parent directory such as ".../exp7/exp".
        found = suffix_re.search(Path(sibling).name)
        if found:
            numbers.append(int(found.group(1)))

    n = max(numbers) + 1 if numbers else 2
    return f"{path}{n}"

def nan_detect_hook(module,input,output,label_info,label):
    """Abort the process as soon as `output` contains a NaN.

    The signature follows a forward hook (module, input, output) with two
    extra arguments that the commented-out call sites bind via
    `functools.partial`; `input`, `label_info` and `label` are not read here.

    Args:
        module: object whose repr is printed when a NaN is found.
        input: unused.
        output (torch.Tensor): tensor checked with `torch.isnan`.
        label_info: unused.
        label: unused.

    Raises:
        SystemExit: with exit code 1 when any element of `output` is NaN.
    """
    found_nan = torch.isnan(output).any()
    if not found_nan:
        return
    print(f'nan : in {module}')
    sys.exit(1)
        

In [7]:
# Path separator used by the viewers to cut the file name out of each sample's
# `info` string. `os.sep` is '/' on every POSIX system (Linux and macOS alike)
# and '\\' on Windows, so no platform-name sniffing is needed.
split_token = os.sep
# model_folder = r'C:\Users\user\Documents\GitHub\smplx'
# model_type = 'smpl'
# plot_joints = 'true'
# use_face_contour = False
# gender = 'female'
# ext = 'npz'
# num_betas = 10
# num_expression_coeffs = 10

# Joint index tables.
# `ktree_label[j]` is the index of the joint that joint `j` is connected to
# when the viewers draw a skeleton; entry 0 is -1 and is never drawn.
# Presumably these are kinematic-tree parent indices -- TODO confirm.
# `ktree_pred` (55 entries) and `xR_2_SMPL` (24 entries) are only referenced
# from disabled code in this notebook.
ktree_pred = [
    -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19,
    20, 21, 15, 20, 25, 26, 20, 28, 29, 20, 31,
    32, 20, 34, 35, 20, 37, 38, 21, 40, 41, 21,
    43, 44, 21, 46, 47, 21, 49, 50, 21, 52, 53,
]

ktree_label = [
    -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 12, 12, 13, 14, 15, 16, 17, 18, 12,
]

xR_2_SMPL = [
    2, 31, 61, 62, 27, 57, 63, 4, 34, 64, 29, 59,
    0, 28, 58, 1, 3, 33, 5, 35, 6, 36, 11, 41,
]

# Joint indices the viewers skip when drawing skeleton segments.
skip_num = list()

"""
def joint_2d_viewer(input_images, joints, labels, infos):
	temp_input_images = input_images.clone().detach().cpu()
	temp_joints = joints.clone().detach().cpu()
	temp_labels = labels.clone().detach().cpu()
	temp_infos = infos
	fig,ax = plt.subplots(len(labels),3,figsize=(10, 70))
   
	for i,(input_image,pred_joint,label,info) in enumerate(zip(temp_input_images,temp_joints,temp_labels,temp_infos)):
		# ax = fig.add_subplot(len(input_images),1,i)
		# ax.scatter(joints[:22, 0], joints[:22, 1], -joints[:22, 2], color='r')

		pred_joint = pred_joint.numpy()
		label = label.squeeze(0).numpy()
		# # SMPL joints line plot
		for j in reversed(range(22)):
			if not j:break
			if j in skip_num : continue
			pred_joint_line_x=[pred_joint[j,0],pred_joint[ktree_pred[j],0]]
			pred_joint_line_y=[pred_joint[j,1],pred_joint[ktree_pred[j],1]]
			label_joint_line_x=[label[j,0],label[ktree_label[j],0]]
			label_joint_line_y=[label[j,1],label[ktree_label[j],1]]
			ax[i][1].plot(pred_joint_line_x,pred_joint_line_y)
			ax[i][2].plot(label_joint_line_x,label_joint_line_y)

		ax[i][0].set_aspect('equal')
		ax[i][1].set_aspect('equal')
		ax[i][2].set_aspect('equal')
		ax[i][1].view_init(-30,60,180)
		ax[i][2].view_init(-30,60,180)
		# ax[i][1].set_xlabel('x')
		# ax[i][1].set_ylabel('y')
		


		input_image=to_pil_image(input_image)
		ax[i][0].imshow(input_image)

		ax[i][0].set_title(info.split(split_token)[-1])
		ax[i][1].set_title('pred_2d_joint')
		ax[i][2].set_title('joint_GT')
	# plt.show()

	return fig
"""

def fisheye_joint_2d_viewer(input_images, joints, labels, infos, feature_size = (512,512)):
    """Draw, per sample, the input image next to predicted and ground-truth 2D skeletons.

    One figure row is produced per sample with three columns: the input
    image, the predicted joints and the label joints. Skeleton segments
    connect joint `j` to joint `ktree_label[j]`; joint 0 and any index listed
    in `skip_num` are not drawn.

    Args:
        input_images (torch.Tensor): batch of images, each convertible with
            `to_pil_image`.
        joints (torch.Tensor): predicted joints; indexed as `[sample][joint, 0..1]`
            (columns 0 and 1 are used as x and y).
        labels (torch.Tensor): ground-truth joints; each sample is
            `squeeze(0)`-ed before being indexed like the prediction.
        infos (sequence of str): per-sample path-like strings; the part after
            the last `split_token` becomes the image title.
        feature_size (tuple): (height, width) used as the axis limits of both
            skeleton plots.

    Returns:
        matplotlib.figure.Figure: the caller is responsible for closing it.
    """
    images = input_images.clone().detach().cpu()
    pred_joints = joints.clone().detach().cpu()
    gt_joints = labels.clone().detach().cpu()

    # squeeze=False keeps the axes array two-dimensional, so a batch holding a
    # single sample can still be indexed as axes[row][col].
    fig, axes = plt.subplots(len(labels), 3, figsize=(10, 70), squeeze=False)

    for row, (image, pred, gt, info) in enumerate(zip(images, pred_joints, gt_joints, infos)):
        image_ax, pred_ax, gt_ax = axes[row]
        gt = gt.squeeze(0)

        # Walk from the last joint down to joint 1; joint 0 is never drawn.
        for j in range(len(pred) - 1, 0, -1):
            if j in skip_num:
                continue
            other = ktree_label[j]
            pred_ax.plot([pred[j, 0], pred[other, 0]], [pred[j, 1], pred[other, 1]])
            gt_ax.plot([gt[j, 0], gt[other, 0]], [gt[j, 1], gt[other, 1]])

        for axis in (image_ax, pred_ax, gt_ax):
            axis.set_aspect('equal')

        # Image convention: origin in the top-left corner, y growing downwards.
        for axis in (pred_ax, gt_ax):
            axis.set_xlim(0, feature_size[1])
            axis.set_ylim(0, feature_size[0])
            axis.invert_yaxis()

        image_ax.imshow(to_pil_image(image))

        image_ax.set_title(info.split(split_token)[-1])
        pred_ax.set_title('pred_fisheye_joint')
        gt_ax.set_title('fisheye_GT')

    return fig


def joint_3d_viewer(input_images, joints, labels, infos):
    """Draw, per sample, the input image next to predicted and ground-truth 3D skeletons.

    The figure has one row per sample and three columns. All axes are created
    as 3D axes; the first column is then replaced by a regular 2D axes that
    shows the input image. Skeleton segments connect joint `j` to joint
    `ktree_label[j]`; joint 0 and any index in `skip_num` are not drawn.

    Args:
        input_images (torch.Tensor): batch of images, each convertible with
            `to_pil_image`.
        joints (torch.Tensor): predicted joints, indexed as
            `[sample][joint, 0..2]` (columns used as x, y, z).
        labels (torch.Tensor): ground-truth joints, indexed the same way.
        infos (sequence of str): per-sample path-like strings; the part after
            the last `split_token` becomes the image title.

    Returns:
        matplotlib.figure.Figure: the caller is responsible for closing it.
    """
    images = input_images.clone().detach().cpu()
    gt_joints = labels.clone().detach().cpu()
    pred_joints = joints.clone().detach().cpu()

    # squeeze=False keeps the axes array two-dimensional, so a batch holding a
    # single sample can still be indexed as axes[row][col].
    fig, axes = plt.subplots(
        len(labels), 3, figsize=(10, 70),
        subplot_kw={"projection": "3d"}, squeeze=False,
    )

    # Swap the first 3D axes of every row for a 2D axes in the same grid cell.
    for row in range(len(labels)):
        n_rows, n_cols, first_cell, _last_cell = axes[row][0].get_subplotspec().get_geometry()
        axes[row][0].remove()
        # get_geometry() is 0-based, add_subplot() is 1-based.
        axes[row][0] = fig.add_subplot(n_rows, n_cols, first_cell + 1)

    for row, (image, pred, gt, info) in enumerate(zip(images, pred_joints, gt_joints, infos)):
        image_ax, pred_ax, gt_ax = axes[row]
        image = to_pil_image(image)

        # Walk from the last joint down to joint 1; joint 0 is never drawn.
        for j in range(len(pred) - 1, 0, -1):
            if j in skip_num:
                continue
            other = ktree_label[j]
            pred_ax.plot(
                [pred[j, 0], pred[other, 0]],
                [pred[j, 1], pred[other, 1]],
                [pred[j, 2], pred[other, 2]],
            )
            gt_ax.plot(
                [gt[j, 0], gt[other, 0]],
                [gt[j, 1], gt[other, 1]],
                [gt[j, 2], gt[other, 2]],
            )

        for axis in (image_ax, pred_ax, gt_ax):
            axis.set_aspect('equal')

        for axis in (pred_ax, gt_ax):
            axis.view_init(-30, 60, 180)
            axis.set_xlabel('x')
            axis.set_ylabel('y')
            axis.set_zlabel('z')

        image_ax.imshow(image)

        image_ax.set_title(info.split(split_token)[-1])
        pred_ax.set_title('pred_3d_joint')
        gt_ax.set_title('3d_joint_GT')

    return fig

def depth_viewer(input_images, features, labels, infos):
    """Show, per sample, the input image, the predicted depth map and the depth label.

    Args:
        input_images (torch.Tensor): batch of images, each convertible with
            `to_pil_image`.
        features (torch.Tensor): predicted depth maps, each convertible with
            `to_pil_image`.
        labels (torch.Tensor): ground-truth depth maps, each convertible with
            `to_pil_image`.
        infos (sequence of str): per-sample path-like strings; the part after
            the last `split_token` becomes the image title.

    Returns:
        matplotlib.figure.Figure: one row per sample, three columns. The
        caller is responsible for closing it.
    """
    images = input_images.clone().detach().cpu()
    pred_depths = features.clone().detach().cpu()
    gt_depths = labels.clone().detach().cpu()

    # squeeze=False keeps the axes array two-dimensional, so a batch holding a
    # single sample can still be indexed as axes[row][col].
    fig, axes = plt.subplots(len(labels), 3, figsize=(10, 70), squeeze=False)

    for row, (image, pred, gt, info) in enumerate(zip(images, pred_depths, gt_depths, infos)):
        image_ax, pred_ax, gt_ax = axes[row]

        image_ax.set_title(info.split(split_token)[-1])
        image_ax.imshow(to_pil_image(image))
        pred_ax.set_title('pred_depth_feature')
        pred_ax.imshow(to_pil_image(pred))
        gt_ax.set_title('depth_GT')
        gt_ax.imshow(to_pil_image(gt))

    return fig

def heatmap_viewer(input_images, heatmaps, labels, infos):
    """Show, per sample, the input image and the channel-summed predicted / label heatmaps.

    Each sample's heatmap stack is summed over its first dimension so that
    all channels appear in a single image.

    Args:
        input_images (torch.Tensor): batch of images, each convertible with
            `to_pil_image`.
        heatmaps (torch.Tensor): predicted heatmaps, indexed as
            `[sample][channel, :, :]`.
        labels (torch.Tensor): ground-truth heatmaps, indexed the same way.
        infos (sequence of str): per-sample path-like strings; the part after
            the last `split_token` becomes the image title.

    Returns:
        matplotlib.figure.Figure: one row per sample, three columns. The
        caller is responsible for closing it.
    """
    # Moved to the CPU like the other viewers do before converting to PIL.
    images = input_images.clone().detach().cpu()
    pred_heatmaps = heatmaps.clone().detach().cpu()
    gt_heatmaps = labels.clone().detach().cpu()

    # squeeze=False keeps the axes array two-dimensional, so a batch holding a
    # single sample can still be indexed as axes[row][col].
    fig, axes = plt.subplots(len(labels), 3, figsize=(10, 70), squeeze=False)

    for row, (image, pred, gt, info) in enumerate(zip(images, pred_heatmaps, gt_heatmaps, infos)):
        image_ax, pred_ax, gt_ax = axes[row]

        # Collapse the channel dimension. Each stack is summed on its own, so
        # prediction and label may differ in spatial size. float32 matches the
        # dtype of the accumulator buffer this replaces.
        total_heatmap = pred.sum(dim=0, dtype=torch.float32)
        total_label = gt.sum(dim=0, dtype=torch.float32)

        image_ax.set_title(info.split(split_token)[-1])
        image_ax.imshow(to_pil_image(image))
        pred_ax.set_title('pred_heatmap')
        pred_ax.imshow(total_heatmap)
        gt_ax.set_title('heatmap_GT')
        gt_ax.imshow(total_label)

    return fig

In [8]:
# Number of batches between figure dumps (same cadence for training and validation).
_FIGURE_INTERVAL = 300


def _split_batch(batch):
    """Move one dataset batch to `device` and split it into model inputs, labels and infos."""
    inputs = {
        'image': batch['image'].to(device),
        'depth': batch['depth'].to(device),
        'heatmap': batch['heatmap'].to(device),
        # Handed to the model as delivered by the DataLoader (not moved here).
        'camera_info': batch['camera_info'],
    }
    cam_trans, cam_rot = batch['camera_info']
    labels = {
        'joint_3d': batch['joints_3d_cam'].to(device),
        'depth': batch['depth'].to(device),
        'cam_trans': cam_trans.to(device),
        'cam_rot': cam_rot.to(device),
        'fisheye': batch['fisheye_joints_2d'].to(device),
        'heatmap': batch['heatmap'].to(device),
        'heatmap_proj': batch['heatmap_1'].to(device),
    }
    return inputs, labels, batch['info']


def _weighted_losses(pred_dict, labels, criteria):
    """Evaluate every loss term and return them, already scaled, keyed by name."""
    regressor = pred_dict['regressor2_dict']
    losses = {}
    losses['depth_loss'] = criteria['depth'](pred_dict['depth_feature'], labels['depth']) * 100
    losses['heatmap_loss'] = criteria['heatmap'](pred_dict['heatmap'], labels['heatmap']) * 1000

    cam_loss_trans = criteria['cam'](regressor['pred_trans'], labels['cam_trans'])
    cam_loss_rot = criteria['cam'](regressor['pred_rot'], labels['cam_rot'])
    losses['cam_loss'] = ((cam_loss_trans * 0.01) + (cam_loss_rot * 1)) / 2

    losses['projection_2d_loss'] = criteria['projection'](regressor['fisheye_kp_2d'], labels['fisheye']) * 0.1
    losses['heatmap_projection_loss'] = (
        criteria['heatmap_proj'](regressor['pred_heatmap_smpl'], labels['heatmap_proj']) * 1000
    )
    losses['joint_3d_loss'] = criteria['joint_3d'](regressor['kp_3d_cam'], labels['joint_3d']) * 0.01
    return losses


def _make_figures(inputs, pred_dict, labels, infos):
    """Render the five diagnostic figures for one batch, keyed by their log name."""
    regressor = pred_dict['regressor2_dict']
    image = inputs['image']
    return {
        'depth_feature_fig': depth_viewer(image, pred_dict['depth_feature'], labels['depth'], infos),
        'fisheye_2d_fig': fisheye_joint_2d_viewer(image, regressor['fisheye_kp_2d'], labels['fisheye'], infos),
        'heatmap_fig': heatmap_viewer(image, pred_dict['heatmap'], labels['heatmap'], infos),
        'joint_3d_fig': joint_3d_viewer(image, regressor['kp_3d_cam'], labels['joint_3d'], infos),
        'heatmap_smpl_fig': heatmap_viewer(image, regressor['pred_heatmap_smpl'], labels['heatmap_proj'], infos),
    }


def _publish_figures(figures, prefix, logging):
    """Send the figures to wandb under `<prefix>/<name>` (if logging) and close them all."""
    if logging:
        wandb.log({f"{prefix}/{name}": [wandb.Image(fig)] for name, fig in figures.items()})
    # Always close: each figure is large and matplotlib keeps open figures alive.
    for fig in figures.values():
        plt.close(fig)


def train(data_dir, model_dir, args, logging=True):
    """Train the model named by `args.model` and checkpoint it once per epoch.

    Steps, in order: seed all RNGs, optionally start a wandb run, create a
    fresh run directory under `model_dir`, build the dataset and its
    train/validation loaders, build the model (wrapped in `DataParallel`),
    then alternate one training pass and one validation pass per epoch.

    After every validation pass the weights are written to `last.pth`; they
    are additionally written to `best.pth` whenever the mean validation loss
    over the whole pass improves on the best value seen so far.

    Args:
        data_dir (str): passed to the dataset class as `dataroot`.
        model_dir (str): parent directory for the run directory
            (`increment_path(model_dir/args.name)`).
        args: configuration object read attribute-style. Fields used here:
            seed, name, dataset, model, batch_size, valid_batch_size, lr,
            lr_decay_step, epochs, log_interval and criterion_2 .. criterion_7.
        logging (bool): when True, metrics and figures are sent to wandb.

    Returns:
        None
    """
    seed_everything(args.seed)
    if logging:
        # The configuration is attached to the run at creation time.
        wandb.init(project="2d to 3d", entity="vhehduatks", config=vars(args))

    save_dir = increment_path(os.path.join(model_dir, args.name))
    print(save_dir)
    os.makedirs(save_dir)

    # -- dataset
    dataset_module = getattr(import_module("dataset"), args.dataset)
    dataset = dataset_module(dataroot=data_dir)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()
    num_workers = multiprocessing.cpu_count() // 2

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=use_cuda,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.valid_batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=use_cuda,
        drop_last=True,
    )

    # -- model
    model_module = getattr(import_module("model"), args.model)
    model = model_module().to(device)
    # Use GPUs 0 and 1 when both exist; fall back to whatever is available
    # (a single GPU, or plain CPU execution) instead of failing.
    gpu_ids = list(range(min(2, torch.cuda.device_count())))
    model = torch.nn.DataParallel(model, device_ids=gpu_ids or None)

    # -- loss & metric (criterion names come from args; see `create_criterion`)
    criteria = {
        'depth': create_criterion(args.criterion_2),
        'projection': create_criterion(args.criterion_3),
        'cam': create_criterion(args.criterion_4),
        'joint_3d': create_criterion(args.criterion_5),
        'heatmap': create_criterion(args.criterion_6),
        'heatmap_proj': create_criterion(args.criterion_7),
    }

    # NOTE(review): args.optimizer is not consulted; Adam is always used.
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr,
        weight_decay=5e-4,
    )
    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

    best_val_loss = np.inf

    for epoch in range(args.epochs):
        # ---- train loop
        model.train()
        for idx, train_batch in enumerate(train_loader):
            inputs, labels, infos = _split_batch(train_batch)

            pred_dict = model(inputs, is_train=True, epoch=epoch)
            total_loss = _weighted_losses(pred_dict, labels, criteria)
            loss = torch.stack(list(total_loss.values())).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Stepped once per batch, so args.lr_decay_step counts batches, not epochs.
            scheduler.step()

            if (idx + 1) % args.log_interval == 0:
                current_lr = get_lr(optimizer)
                print("=======================================================================")
                metrics = {"train/lr": current_lr, "train/Epoch": epoch}
                for loss_name, val in total_loss.items():
                    val = float(val)
                    print(
                        f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                        f"training loss : {loss_name} : {val:4.4} || lr {current_lr}"
                    )
                    metrics['train/' + loss_name] = val
                if logging:
                    # One call per interval keeps all loss terms on the same wandb step.
                    wandb.log(metrics)

            if (idx + 1) % _FIGURE_INTERVAL == 0:
                _publish_figures(_make_figures(inputs, pred_dict, labels, infos), "train", logging)

        # ---- val loop
        val_loss_sum = 0.0
        val_batches = 0
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            for val_idx, val_batch in enumerate(val_loader):
                inputs, labels, infos = _split_batch(val_batch)

                pred_dict = model(inputs, is_train=False, epoch=epoch)
                total_loss = _weighted_losses(pred_dict, labels, criteria)
                loss = torch.stack(list(total_loss.values())).sum()

                val_loss_sum += loss.item()
                val_batches += 1

                if (val_idx + 1) % args.log_interval == 0:
                    current_lr = get_lr(optimizer)
                    print("=======================================================================")
                    metrics = {"val/lr": current_lr, "val/Epoch": epoch}
                    for loss_name, value in total_loss.items():
                        loss_name = 'val_' + loss_name
                        value = float(value)
                        print(
                            f"val_Epoch[{epoch}/{args.epochs}]({val_idx + 1}/{len(val_loader)}) || "
                            f"val_training loss : {loss_name} : {value:4.4} || lr {current_lr}"
                        )
                        metrics['val/' + loss_name] = value
                    if logging:
                        wandb.log(metrics)

                if (val_idx + 1) % _FIGURE_INTERVAL == 0:
                    _publish_figures(_make_figures(inputs, pred_dict, labels, infos), "val", logging)

        # ---- checkpoints: once per epoch, independent of the validation set size
        torch.save(model.module.state_dict(), os.path.join(save_dir, "last.pth"))
        if val_batches:
            mean_val_loss = val_loss_sum / val_batches
            print(f"val_Epoch[{epoch}/{args.epochs}] || mean validation loss : {mean_val_loss:4.4}")
            if mean_val_loss < best_val_loss:
                torch.save(model.module.state_dict(), os.path.join(save_dir, "best.pth"))
                best_val_loss = mean_val_loss

    if logging:
        wandb.finish()

In [9]:
# A bare `raise` with no active exception fails with
# "RuntimeError: No active exception to reraise".
# NOTE(review): presumably a deliberate stop so that "Run All" halts before the
# training cell below -- confirm, and consider removing before sharing.
raise

RuntimeError: No active exception to reraise

In [10]:
# Launch training with the directories configured in `args`.
data_dir, model_dir = args.data_dir, args.model_dir
print(data_dir)

train(data_dir, model_dir, args)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


/dataset/egodataset


[34m[1mwandb[0m: Currently logged in as: [33mvhehduatks[0m. Use [1m`wandb login --relogin`[0m to force relogin


/workspace/2d_to_3d/apps/exp370


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch[0/5](5/12058) || training loss : depth_loss : 0.1737 || lr 0.00025
Epoch[0/5](5/12058) || training loss : heatmap_loss : 0.3192 || lr 0.00025
Epoch[0/5](5/12058) || training loss : cam_loss : 3.833 || lr 0.00025
Epoch[0/5](5/12058) || training loss : projection_2d_loss : 1.755 || lr 0.00025
Epoch[0/5](5/12058) || training loss : heatmap_projection_loss : 1.576 || lr 0.00025
Epoch[0/5](5/12058) || training loss : joint_3d_loss : 0.4404 || lr 0.00025
Epoch[0/5](10/12058) || training loss : depth_loss : 0.126 || lr 0.00025
Epoch[0/5](10/12058) || training loss : heatmap_loss : 0.3092 || lr 0.00025
Epoch[0/5](10/12058) || training loss : cam_loss : 4.942 || lr 0.00025
Epoch[0/5](10/12058) || training loss : projection_2d_loss : 1.987 || lr 0.00025
Epoch[0/5](10/12058) || training loss : heatmap_projection_loss : 1.475 || lr 0.00025
Epoch[0/5](10/12058) || training loss : joint_3d_loss : 0.5518 || lr 0.00025
Epoch[0/5](15/12058) || training loss : depth_loss : 0.1781 || lr 0.00025
Epo



val_Epoch[0/5](5/3014) || val_training loss : val_depth_loss : 0.03703 || lr 6.25e-05
val_Epoch[0/5](5/3014) || val_training loss : val_heatmap_loss : 0.1605 || lr 6.25e-05
val_Epoch[0/5](5/3014) || val_training loss : val_cam_loss : 3.214 || lr 6.25e-05
val_Epoch[0/5](5/3014) || val_training loss : val_projection_2d_loss : 1.77 || lr 6.25e-05
val_Epoch[0/5](5/3014) || val_training loss : val_heatmap_projection_loss : 1.03 || lr 6.25e-05
val_Epoch[0/5](5/3014) || val_training loss : val_joint_3d_loss : 0.4884 || lr 6.25e-05
val_Epoch[0/5](10/3014) || val_training loss : val_depth_loss : 0.03639 || lr 6.25e-05
val_Epoch[0/5](10/3014) || val_training loss : val_heatmap_loss : 0.1315 || lr 6.25e-05
val_Epoch[0/5](10/3014) || val_training loss : val_cam_loss : 3.52 || lr 6.25e-05
val_Epoch[0/5](10/3014) || val_training loss : val_projection_2d_loss : 1.69 || lr 6.25e-05
val_Epoch[0/5](10/3014) || val_training loss : val_heatmap_projection_loss : 1.054 || lr 6.25e-05
val_Epoch[0/5](10/3014)



Epoch[1/5](5/12058) || training loss : depth_loss : 0.03047 || lr 6.25e-05
Epoch[1/5](5/12058) || training loss : heatmap_loss : 0.07494 || lr 6.25e-05
Epoch[1/5](5/12058) || training loss : cam_loss : 3.679 || lr 6.25e-05
Epoch[1/5](5/12058) || training loss : projection_2d_loss : 1.47 || lr 6.25e-05
Epoch[1/5](5/12058) || training loss : heatmap_projection_loss : 1.294 || lr 6.25e-05
Epoch[1/5](5/12058) || training loss : joint_3d_loss : 0.3886 || lr 6.25e-05
Epoch[1/5](10/12058) || training loss : depth_loss : 0.04598 || lr 6.25e-05
Epoch[1/5](10/12058) || training loss : heatmap_loss : 0.1884 || lr 6.25e-05
Epoch[1/5](10/12058) || training loss : cam_loss : 5.89 || lr 6.25e-05
Epoch[1/5](10/12058) || training loss : projection_2d_loss : 2.206 || lr 6.25e-05
Epoch[1/5](10/12058) || training loss : heatmap_projection_loss : 1.765 || lr 6.25e-05
Epoch[1/5](10/12058) || training loss : joint_3d_loss : 1.017 || lr 6.25e-05
Epoch[1/5](15/12058) || training loss : depth_loss : 0.03066 || 



val_Epoch[1/5](5/3014) || val_training loss : val_depth_loss : 0.02572 || lr 1.5625e-05
val_Epoch[1/5](5/3014) || val_training loss : val_heatmap_loss : 0.08147 || lr 1.5625e-05
val_Epoch[1/5](5/3014) || val_training loss : val_cam_loss : 1.027 || lr 1.5625e-05
val_Epoch[1/5](5/3014) || val_training loss : val_projection_2d_loss : 1.672 || lr 1.5625e-05
val_Epoch[1/5](5/3014) || val_training loss : val_heatmap_projection_loss : 0.9446 || lr 1.5625e-05
val_Epoch[1/5](5/3014) || val_training loss : val_joint_3d_loss : 0.4805 || lr 1.5625e-05
val_Epoch[1/5](10/3014) || val_training loss : val_depth_loss : 0.02451 || lr 1.5625e-05
val_Epoch[1/5](10/3014) || val_training loss : val_heatmap_loss : 0.07661 || lr 1.5625e-05
val_Epoch[1/5](10/3014) || val_training loss : val_cam_loss : 2.234 || lr 1.5625e-05
val_Epoch[1/5](10/3014) || val_training loss : val_projection_2d_loss : 1.495 || lr 1.5625e-05
val_Epoch[1/5](10/3014) || val_training loss : val_heatmap_projection_loss : 1.053 || lr 1.562



Epoch[2/5](5/12058) || training loss : depth_loss : 0.03093 || lr 1.5625e-05
Epoch[2/5](5/12058) || training loss : heatmap_loss : 0.1182 || lr 1.5625e-05
Epoch[2/5](5/12058) || training loss : cam_loss : 1.825 || lr 1.5625e-05
Epoch[2/5](5/12058) || training loss : projection_2d_loss : 1.932 || lr 1.5625e-05
Epoch[2/5](5/12058) || training loss : heatmap_projection_loss : 1.49 || lr 1.5625e-05
Epoch[2/5](5/12058) || training loss : joint_3d_loss : 0.4632 || lr 1.5625e-05
Epoch[2/5](10/12058) || training loss : depth_loss : 0.02638 || lr 1.5625e-05
Epoch[2/5](10/12058) || training loss : heatmap_loss : 0.1179 || lr 1.5625e-05
Epoch[2/5](10/12058) || training loss : cam_loss : 2.988 || lr 1.5625e-05
Epoch[2/5](10/12058) || training loss : projection_2d_loss : 1.99 || lr 1.5625e-05
Epoch[2/5](10/12058) || training loss : heatmap_projection_loss :  1.7 || lr 1.5625e-05
Epoch[2/5](10/12058) || training loss : joint_3d_loss : 0.9221 || lr 1.5625e-05
Epoch[2/5](15/12058) || training loss : d



val_Epoch[2/5](5/3014) || val_training loss : val_depth_loss : 0.08098 || lr 3.90625e-06
val_Epoch[2/5](5/3014) || val_training loss : val_heatmap_loss : 0.0966 || lr 3.90625e-06
val_Epoch[2/5](5/3014) || val_training loss : val_cam_loss : 0.8614 || lr 3.90625e-06
val_Epoch[2/5](5/3014) || val_training loss : val_projection_2d_loss : 1.724 || lr 3.90625e-06
val_Epoch[2/5](5/3014) || val_training loss : val_heatmap_projection_loss : 0.9582 || lr 3.90625e-06
val_Epoch[2/5](5/3014) || val_training loss : val_joint_3d_loss : 0.662 || lr 3.90625e-06
val_Epoch[2/5](10/3014) || val_training loss : val_depth_loss : 0.06339 || lr 3.90625e-06
val_Epoch[2/5](10/3014) || val_training loss : val_heatmap_loss : 0.1237 || lr 3.90625e-06
val_Epoch[2/5](10/3014) || val_training loss : val_cam_loss : 2.101 || lr 3.90625e-06
val_Epoch[2/5](10/3014) || val_training loss : val_projection_2d_loss :  1.7 || lr 3.90625e-06
val_Epoch[2/5](10/3014) || val_training loss : val_heatmap_projection_loss : 1.079 || l



Epoch[3/5](5/12058) || training loss : depth_loss : 0.02108 || lr 3.90625e-06
Epoch[3/5](5/12058) || training loss : heatmap_loss : 0.05369 || lr 3.90625e-06
Epoch[3/5](5/12058) || training loss : cam_loss : 3.364 || lr 3.90625e-06
Epoch[3/5](5/12058) || training loss : projection_2d_loss : 1.629 || lr 3.90625e-06
Epoch[3/5](5/12058) || training loss : heatmap_projection_loss : 1.628 || lr 3.90625e-06
Epoch[3/5](5/12058) || training loss : joint_3d_loss : 0.6626 || lr 3.90625e-06
Epoch[3/5](10/12058) || training loss : depth_loss : 0.01633 || lr 3.90625e-06
Epoch[3/5](10/12058) || training loss : heatmap_loss : 0.07303 || lr 3.90625e-06
Epoch[3/5](10/12058) || training loss : cam_loss : 2.279 || lr 3.90625e-06
Epoch[3/5](10/12058) || training loss : projection_2d_loss : 1.733 || lr 3.90625e-06
Epoch[3/5](10/12058) || training loss : heatmap_projection_loss : 1.449 || lr 3.90625e-06
Epoch[3/5](10/12058) || training loss : joint_3d_loss : 0.588 || lr 3.90625e-06
Epoch[3/5](15/12058) || t



val_Epoch[3/5](5/3014) || val_training loss : val_depth_loss : 0.02253 || lr 9.765625e-07
val_Epoch[3/5](5/3014) || val_training loss : val_heatmap_loss : 0.05874 || lr 9.765625e-07
val_Epoch[3/5](5/3014) || val_training loss : val_cam_loss : 0.8734 || lr 9.765625e-07
val_Epoch[3/5](5/3014) || val_training loss : val_projection_2d_loss : 1.361 || lr 9.765625e-07
val_Epoch[3/5](5/3014) || val_training loss : val_heatmap_projection_loss : 0.8839 || lr 9.765625e-07
val_Epoch[3/5](5/3014) || val_training loss : val_joint_3d_loss : 0.5744 || lr 9.765625e-07
val_Epoch[3/5](10/3014) || val_training loss : val_depth_loss : 0.01863 || lr 9.765625e-07
val_Epoch[3/5](10/3014) || val_training loss : val_heatmap_loss : 0.06359 || lr 9.765625e-07
val_Epoch[3/5](10/3014) || val_training loss : val_cam_loss : 2.063 || lr 9.765625e-07
val_Epoch[3/5](10/3014) || val_training loss : val_projection_2d_loss : 1.507 || lr 9.765625e-07
val_Epoch[3/5](10/3014) || val_training loss : val_heatmap_projection_los



Epoch[4/5](5/12058) || training loss : depth_loss : 0.01943 || lr 9.765625e-07
Epoch[4/5](5/12058) || training loss : heatmap_loss : 0.1138 || lr 9.765625e-07
Epoch[4/5](5/12058) || training loss : cam_loss : 3.029 || lr 9.765625e-07
Epoch[4/5](5/12058) || training loss : projection_2d_loss : 2.034 || lr 9.765625e-07
Epoch[4/5](5/12058) || training loss : heatmap_projection_loss : 1.81 || lr 9.765625e-07
Epoch[4/5](5/12058) || training loss : joint_3d_loss : 0.625 || lr 9.765625e-07
Epoch[4/5](10/12058) || training loss : depth_loss : 0.02126 || lr 9.765625e-07
Epoch[4/5](10/12058) || training loss : heatmap_loss : 0.09119 || lr 9.765625e-07
Epoch[4/5](10/12058) || training loss : cam_loss : 1.781 || lr 9.765625e-07
Epoch[4/5](10/12058) || training loss : projection_2d_loss : 1.742 || lr 9.765625e-07
Epoch[4/5](10/12058) || training loss : heatmap_projection_loss : 1.383 || lr 9.765625e-07
Epoch[4/5](10/12058) || training loss : joint_3d_loss : 0.4403 || lr 9.765625e-07
Epoch[4/5](15/1

: 

: 