In [None]:
## DEX_YCB
import os
import os.path as osp
import numpy as np
import torch
import cv2
import random
import matplotlib.pyplot as plt
import re
import copy
import torchvision.transforms as transforms
from common.logger import colorlogger
from PIL import Image
from pycocotools.coco import COCO
from common.utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation
from main.config import cfg
from common.utils.transforms import compute_mpjpe, compute_pa_mpjpe
from common.utils.skeleton_map import skeleton_map_gray
from common.codecs.keypoint_eval import keypoint_pck_accuracy
from mmpose.utils.tensor_utils import to_numpy
from common.utils.vis import ux_hon_result, ux_hon_result_final

class DEX_YCB(torch.utils.data.Dataset):
    """DexYCB (s0 split) dataset yielding cropped hand images and 2D joint targets.

    Annotations are read from a COCO-format JSON under
    ``data/DEX_YCB/annotations``.  Train items carry the 2D joints (in crop
    coordinates) and a rendered skeleton map; test items carry only the image,
    and the ground truth is looked up again from ``self.datalist`` inside
    :meth:`evaluate`.
    """

    def __init__(self, transform, data_split, log_name='cfg_logs.txt'):
        """Load the annotation list and log the run configuration.

        Args:
            transform: callable applied to the float32 image patch and to the
                skeleton map before they are divided by 255.
            data_split: ``'train'`` selects the training split; every other
                value is treated as ``'test'``.
            log_name: log file name handed to ``colorlogger``.
        """
        self.transform = transform
        self.data_split = data_split if data_split == 'train' else 'test'
        self.root_dir = osp.join('data', 'DEX_YCB')
        self.annot_path = osp.join(self.root_dir, 'annotations')
        # Per-handedness sample counters; load_data() fills them in.
        self.hand_type = {'left': 0, 'right': 0}
        self.datalist = self.load_data()
        self.root_joint_idx = 0
        if self.data_split != 'train':
            # Slot usage (see evaluate()):
            #   'unext' backbone: [pixel accuracy, PCK, MPJPE, unused]
            #   other backbones:  [PCK, unused, MPJPE, unused]
            self.eval_result = [[],[],[],[]]

        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

        # Dump the whole configuration once, from the training dataset only.
        if self.data_split == 'train':
            for i in cfg.__dict__:
                self.logger.info('{0}: {1}'.format(i, cfg.__dict__[i]))

        message = []
        message.append(f"DataList len: {len(self.datalist)}")
        message.append('left hand data: {0}, right hand data: {1}'.format(self.hand_type['left'], self.hand_type['right']))

        if cfg.simcc and cfg.SET:
            message.append(f'Start the model {cfg.backbone} with SET and with simcc')
        elif cfg.simcc:
            message.append(f'Start the model {cfg.backbone} without SET and with simcc')
        elif cfg.SET:
            message.append(f'Start the model {cfg.backbone} with SET and with regressor')
        else:
            message.append(f'Start the model {cfg.backbone} without SET and with regressor')
        for msg in message:
            self.logger.info(msg)

    def load_data(self):
        """Build the list of usable samples for the current split.

        Annotations are sub-sampled with a 1-based counter: the ``k``-th
        annotation is kept when ``k % skip == remainder`` (``cfg.train_*`` or
        ``cfg.test_*``).  Annotations whose image file is missing on disk are
        dropped.

        Returns:
            list of dicts with keys ``img_path``, ``img_shape`` (height, width),
            ``joints_coord_img``, ``bbox`` and ``hand_type``.
        """
        db = COCO(osp.join(self.annot_path, "DEX_YCB_s0_{}_data.json".format(self.data_split)))

        # self.data_split is always 'train' or 'test' (see __init__).
        if self.data_split == 'train':
            skip_mode, remainder = cfg.train_skip, cfg.train_remainder
        else:
            skip_mode, remainder = cfg.test_skip, cfg.test_remainder

        datalist = []
        for skip, ann in enumerate(db.anns.values(), start=1):
            if skip % skip_mode != remainder:
                continue

            img = db.loadImgs(ann['image_id'])[0]
            img_path = osp.join(self.root_dir, img['file_name'])
            if not osp.exists(img_path):
                continue

            img_shape = (img['height'], img['width'])
            joints_coord_img = np.array(ann['joints_img'], dtype=np.float32)
            hand_type = ann['hand_type']

            # Box around all joints (every joint marked valid), then adjusted
            # to the image by process_bbox.
            bbox = get_bbox(joints_coord_img[:,:2], np.ones_like(joints_coord_img[:,0]), expansion_factor=1.5)
            bbox = process_bbox(bbox, img['width'], img['height'], expansion_factor=1.0)

            data = {"img_path": img_path, "img_shape": img_shape, "joints_coord_img": joints_coord_img,
                    "bbox": bbox, "hand_type": hand_type}

            # Drops the sample if any field (e.g. the processed bbox) is None.
            if all(val is not None for val in data.values()):
                datalist.append(data)
                if data['hand_type'] == 'left':
                    self.hand_type['left'] += 1
                else:
                    self.hand_type['right'] += 1
        return datalist

    def __len__(self):
        """Number of samples kept by :meth:`load_data`."""
        return len(self.datalist)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for sample ``idx``.

        ``inputs`` is ``{'img': transformed patch / 255}``.  For the train split
        ``targets`` holds ``joints_img`` (crop coordinates, divided by
        ``cfg.input_img_shape`` unless ``cfg.simcc`` is set) and
        ``skeleton_map``; for the test split ``targets`` is empty.
        """
        data = copy.deepcopy(self.datalist[idx])
        img_path, bbox = data['img_path'], data['bbox']

        # Left hands are not mirrored: flipping is disabled.
        img = load_img(img_path)
        img, img2bb_trans, bb2img_trans, rot, scale = augmentation(img, bbox, self.data_split, do_flip=False)
        img = self.transform(img.astype(np.float32))/255.

        inputs = {'img': img}
        targets = {}
        if self.data_split == 'train':
            ## 2D joint coordinate, mapped from image space into the crop
            joints_img = data['joints_coord_img']
            joints_img_xy1 = np.concatenate((joints_img[:,:2], np.ones_like(joints_img[:,:1])),1)
            joints_img = np.dot(img2bb_trans, joints_img_xy1.transpose(1,0)).transpose(1,0)[:,:2]
            if not cfg.simcc:
                joints_img_copy = joints_img.copy()
                # NOTE(review): x is divided by input_img_shape[0] and y by
                # input_img_shape[1] here, while evaluate() multiplies x by
                # input_img_shape[1] and y by input_img_shape[0].  The two only
                # agree for square inputs - confirm the intended axis order.
                joints_img_copy[:,0] /= cfg.input_img_shape[0]
                joints_img_copy[:,1] /= cfg.input_img_shape[1]
                targets['joints_img'] = joints_img_copy
            else:
                targets['joints_img'] = joints_img

            skeleton_map = skeleton_map_gray((cfg.input_img_shape[0], cfg.input_img_shape[1]), joints_img)
            skeleton_map = self.transform(skeleton_map.astype(np.float32))/255.
            targets['skeleton_map'] = skeleton_map

        return inputs, targets

    @staticmethod
    def _save_visualization(img_path, cat_imgs):
        """Write ``cat_imgs`` into ``cfg.vis_dir``, named after ``img_path`` minus its first component."""
        parts = re.split(r'[\\/]', img_path)
        path = osp.join(cfg.vis_dir,'_'.join(parts[1:]))
        cv2.imwrite(path, cat_imgs)

    @staticmethod
    def _bbox_to_image_coords(bb2img_trans, keypoints):
        """Apply the affine ``bb2img_trans`` to (K, 2) crop-space keypoints and return (K, 2)."""
        keypoints_xy1 = np.concatenate((keypoints, np.ones((keypoints.shape[0], 1))), axis=1)
        restored = np.dot(bb2img_trans, keypoints_xy1.transpose(1, 0))
        return restored[:2, :].transpose(1, 0)

    def evaluate(self, outs, cur_sample_idx):
        """Score a batch of per-sample outputs and save visualisations.

        Args:
            outs: sequence of per-sample output dicts.  Reads ``keypoints`` and
                ``keypoint_scores`` when ``cfg.simcc`` is set, otherwise
                ``joints_coord_img``; additionally ``skeleton_map`` for the
                ``'unext'`` backbone.
            cur_sample_idx: index in ``self.datalist`` of the first sample of
                ``outs``.

        Results are appended to ``self.eval_result``; nothing is returned.
        """
        annots = self.datalist
        is_unext = cfg.backbone == 'unext'
        for n, out in enumerate(outs):
            annot = annots[cur_sample_idx + n]

            # Rebuild the crop to recover the image<->crop transforms.
            img = load_img(annot['img_path'])
            orig_img = copy.deepcopy(img)
            img, img2bb_trans, bb2img_trans, rot, scale = augmentation(img, annot['bbox'], self.data_split, do_flip=False)

            # Ground-truth joints in image space and in crop space.
            gt_joints_coord_img = annot['joints_coord_img']
            joints_img_xy1 = np.concatenate((gt_joints_coord_img[:,:2], np.ones_like(gt_joints_coord_img[:,:1])),1)
            joints_img = np.dot(img2bb_trans, joints_img_xy1.transpose(1,0)).transpose(1,0)[:,:2]

            if is_unext:
                gt_skeleton_map = skeleton_map_gray((cfg.input_img_shape[0], cfg.input_img_shape[1]), joints_img)
                gt_skeleton_map = gt_skeleton_map/255.

                # Raw (un-thresholded) prediction is used for the visualisation.
                pred_skeleton_map = (out['skeleton_map'].squeeze().cpu().numpy()).astype(float)

                cat_imgs = ux_hon_result(orig_img, img, pred_skeleton_map, gt_skeleton_map)
                cat_imgs = ux_hon_result_final(out, bb2img_trans, orig_img, img, cat_imgs)
                self._save_visualization(annot['img_path'], cat_imgs)

                # Thresholded prediction is used for the pixel accuracy.
                pred_skeleton_map = (out['skeleton_map'].squeeze().cpu().numpy()>0.5).astype(float)
                num_correct = (pred_skeleton_map==gt_skeleton_map).sum()
                num_pixels = cfg.input_img_shape[0] * cfg.input_img_shape[1]
            else:
                img_uint8 = cv2.resize(orig_img.astype(np.uint8), (cfg.input_img_shape[0], cfg.input_img_shape[1]))
                rgb_img_uint8 = cv2.cvtColor(img_uint8.astype(np.uint8), cv2.COLOR_BGR2RGB)
                rgb_img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                ori_imgs = np.hstack([rgb_img_uint8, rgb_img])
                cat_imgs = ux_hon_result_final(out, bb2img_trans, orig_img, img, ori_imgs)
                self._save_visualization(annot['img_path'], cat_imgs)

            if cfg.simcc:
                pred_keypoints = np.asanyarray(out['keypoints'])
                pred_keypoint_scores = out['keypoint_scores']
                keypoints_restored = self._bbox_to_image_coords(bb2img_trans, pred_keypoints)
            else:
                # Copy so the rescaling below does not modify the caller's output in place.
                pred_keypoints = out['joints_coord_img'].cpu().numpy().copy()
                pred_keypoints[:,0] *= cfg.input_img_shape[1]
                pred_keypoints[:,1] *= cfg.input_img_shape[0]
                keypoints_restored = self._bbox_to_image_coords(bb2img_trans, pred_keypoints)
                # No confidence available: a joint counts as valid unless it maps to (0, 0).
                pred_keypoint_scores = np.any(keypoints_restored, axis=1)

            # NOTE(review): norm_factor is passed as (height, width) while the
            # keypoints are (x, y) - confirm the order keypoint_pck_accuracy expects.
            _, avg_acc, _ = keypoint_pck_accuracy(
                pred=np.expand_dims(keypoints_restored, axis=0),
                gt=np.expand_dims(gt_joints_coord_img[:,:2], axis=0),
                mask=np.expand_dims(pred_keypoint_scores, axis=0) > 0,
                thr=cfg.pck_thr,
                norm_factor=np.expand_dims(annot['img_shape'], axis=0),
            )

            self.eval_result[2].append(compute_mpjpe(keypoints_restored, gt_joints_coord_img[:,:2]))

            if is_unext:
                self.eval_result[0].append(num_correct / num_pixels)
                self.eval_result[1].append(avg_acc)
            else:
                self.eval_result[0].append(avg_acc)

    def print_eval_result(self, test_epoch):
        """Format the accumulated metrics as a list of log lines.

        Args:
            test_epoch: snapshot number quoted in the first line.

        Returns:
            list of str: run/snapshot line, pixel accuracy (``'unext'`` only),
            PCK and MPJPE, each averaged over everything stored in
            ``self.eval_result``.
        """
        # Accept both '/' and '\\' separators, as evaluate() does for image paths.
        run_name = re.split(r'[\\/]', cfg.output_dir)[-1]
        message = ['Output: {0}, Model: snapshot_{1}.pth.tar'.format(run_name, test_epoch)]
        if cfg.backbone == 'unext':
            message.append('Correct/Total(One Batch) pixels: {0:.2f}'.format(np.mean(self.eval_result[0]) * 100))
            pck_values = self.eval_result[1]
        else:
            pck_values = self.eval_result[0]
        message.append('PCK@{0}: {1:.2f}'.format(cfg.pck_thr, np.mean(pck_values) * 100))
        message.append('MPJPE : %.2f' % np.mean(self.eval_result[2]))
        return message

In [None]:
## HO3D 
import os
import os.path as osp
import numpy as np
import torch
import cv2
import random
import json
import re
import copy
from common.logger import colorlogger
from common.utils.skeleton_map import skeleton_map_gray
from pycocotools.coco import COCO
from main.config import cfg
from common.utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation
from common.utils.transforms import world2cam, cam2pixel, compute_mpjpe, compute_pa_mpjpe
from common.utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton
from common.codecs.keypoint_eval import keypoint_pck_accuracy

class HO3D(torch.utils.data.Dataset):
    """HO3D dataset yielding cropped hand images and 2D joint targets.

    Annotations are read from a COCO-format JSON under
    ``data/HO3D/annotations``; 2D joints are obtained by projecting the
    camera-space joints with the per-annotation camera parameters.
    """

    def __init__(self, transform, data_split, log_name='cfg_logs.txt'):
        """Load the annotation list and log the run configuration.

        Args:
            transform: callable applied to the float32 image patch and to the
                skeleton map before they are divided by 255.
            data_split: ``'evaluation'`` selects the evaluation split; every
                other value is treated as ``'train'``.
            log_name: log file name handed to ``colorlogger``.
        """
        self.transform = transform
        # NOTE(review): Tester constructs datasets with "test", which this
        # mapping turns into 'train' - confirm that evaluating on the train
        # split is intended.
        self.data_split = data_split if data_split == 'evaluation' else 'train'
        self.root_dir = osp.join('data', 'HO3D')
        self.annot_path = osp.join(self.root_dir, 'annotations')
        self.root_joint_idx = 0
        # Every sample is counted as a right hand by load_data().
        self.hand_type = {'left': 0, 'right': 0}
        self.datalist = self.load_data()
        # Always created so evaluate() works on either split.
        # Slot usage (see evaluate()):
        #   'unext' backbone: [pixel accuracy, PCK, MPJPE, unused]
        #   other backbones:  [PCK, unused, MPJPE, unused]
        self.eval_result = [[],[],[],[]]

        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

        for i in cfg.__dict__:
            self.logger.info('{0}: {1}'.format(i, cfg.__dict__[i]))
        message = []
        message.append(f"DataList len: {len(self.datalist)}")
        message.append('left hand data: {0}, right hand data: {1}'.format(self.hand_type['left'], self.hand_type['right']))

        if cfg.simcc and cfg.SET:
            message.append(f'Start the model {cfg.backbone} with SET and simcc')
        elif cfg.simcc:
            message.append(f'Start the model {cfg.backbone} without SET and with simcc')
        elif cfg.SET:
            message.append(f'Start the model {cfg.backbone} with SET and with regressor')
        else:
            message.append(f'Start the model {cfg.backbone} without SET and simcc')
        for msg in message:
            self.logger.info(msg)

    def load_data(self):
        """Build the list of usable samples for the current split.

        Annotations are sub-sampled with a 1-based counter: the ``k``-th
        annotation is kept when ``k % skip == remainder``.  The train split
        uses ``cfg.train_skip`` / ``cfg.train_remainder``; the evaluation split
        uses ``cfg.test_skip`` / ``cfg.test_remainder``.  Annotations whose
        image file is missing on disk are dropped.

        Returns:
            list of dicts with keys ``img_path``, ``img_shape`` (height, width),
            ``joints_coord_img`` and ``bbox``.
        """
        db = COCO(osp.join(self.annot_path, "HO3D_{}_data.json".format(self.data_split)))

        # self.data_split is 'train' or 'evaluation' (never 'test'), so the
        # non-train case must be a plain else for the skip settings to be bound.
        if self.data_split == 'train':
            skip_mode, remainder = cfg.train_skip, cfg.train_remainder
        else:
            skip_mode, remainder = cfg.test_skip, cfg.test_remainder

        datalist = []
        for skip, ann in enumerate(db.anns.values(), start=1):
            if skip % skip_mode != remainder:
                continue

            img = db.loadImgs(ann['image_id'])[0]
            img_path = osp.join(self.root_dir, self.data_split, img['file_name'])
            if not osp.exists(img_path):
                continue

            img_shape = (img['height'], img['width'])
            joints_coord_cam = np.array(ann['joints_coord_cam'], dtype=np.float32) # meter
            cam_param = {k:np.array(v, dtype=np.float32) for k,v in ann['cam_param'].items()}
            joints_coord_img = cam2pixel(joints_coord_cam, cam_param['focal'], cam_param['princpt'])

            # Box around all joints (every joint marked valid), then adjusted
            # to the image by process_bbox.
            bbox = get_bbox(joints_coord_img[:,:2], np.ones_like(joints_coord_img[:,0]), expansion_factor=1.5)
            bbox = process_bbox(bbox, img['width'], img['height'], expansion_factor=1.0)
            data = {"img_path": img_path, "img_shape": img_shape, "joints_coord_img": joints_coord_img,
                    "bbox": bbox,}

            # Drops the sample if any field (e.g. the processed bbox) is None.
            if all(val is not None for val in data.values()):
                datalist.append(data)
                self.hand_type['right'] += 1
        return datalist

    def __len__(self):
        """Number of samples kept by :meth:`load_data`."""
        return len(self.datalist)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for sample ``idx``.

        ``inputs`` is ``{'img': transformed patch / 255}``.  For the train split
        ``targets`` holds ``joints_img`` (crop coordinates, divided by
        ``cfg.input_img_shape`` unless ``cfg.simcc`` is set) and
        ``skeleton_map``; otherwise ``targets`` is empty.  The crop-to-image
        transform of the last fetched sample is kept on ``self.bb2img_trans``.
        """
        data = copy.deepcopy(self.datalist[idx])
        img_path, bbox = data['img_path'], data['bbox']

        img = load_img(img_path)
        img, img2bb_trans, self.bb2img_trans, rot, scale = augmentation(img, bbox, self.data_split, do_flip=False)
        img = self.transform(img.astype(np.float32))/255.

        inputs = {'img': img}
        targets = {}
        if self.data_split == 'train':
            ## 2D joint coordinate, mapped from image space into the crop
            joints_img = data['joints_coord_img']
            joints_img_xy1 = np.concatenate((joints_img[:,:2], np.ones_like(joints_img[:,:1])),1)
            joints_img = np.dot(img2bb_trans, joints_img_xy1.transpose(1,0)).transpose(1,0)[:,:2]
            if not cfg.simcc:
                joints_img_copy = joints_img.copy()
                # NOTE(review): x is divided by input_img_shape[0] and y by
                # input_img_shape[1] here, while evaluate() multiplies x by
                # input_img_shape[1] and y by input_img_shape[0].  The two only
                # agree for square inputs - confirm the intended axis order.
                joints_img_copy[:,0] /= cfg.input_img_shape[0]
                joints_img_copy[:,1] /= cfg.input_img_shape[1]
                targets['joints_img'] = joints_img_copy
            else:
                targets['joints_img'] = joints_img

            skeleton_map = skeleton_map_gray((cfg.input_img_shape[0], cfg.input_img_shape[1]), joints_img)
            skeleton_map = self.transform(skeleton_map.astype(np.float32))/255.
            targets['skeleton_map'] = skeleton_map

        return inputs, targets

    @staticmethod
    def _save_visualization(img_path, cat_imgs):
        """Write ``cat_imgs`` into ``cfg.vis_dir``, named after ``img_path`` minus its first component."""
        parts = re.split(r'[\\/]', img_path)
        path = osp.join(cfg.vis_dir,'_'.join(parts[1:]))
        cv2.imwrite(path, cat_imgs)

    @staticmethod
    def _bbox_to_image_coords(bb2img_trans, keypoints):
        """Apply the affine ``bb2img_trans`` to (K, 2) crop-space keypoints and return (K, 2)."""
        keypoints_xy1 = np.concatenate((keypoints, np.ones((keypoints.shape[0], 1))), axis=1)
        restored = np.dot(bb2img_trans, keypoints_xy1.transpose(1, 0))
        return restored[:2, :].transpose(1, 0)

    def evaluate(self, outs, cur_sample_idx):
        """Score a batch of per-sample outputs and save visualisations.

        Args:
            outs: sequence of per-sample output dicts.  Reads ``keypoints`` and
                ``keypoint_scores`` when ``cfg.simcc`` is set, otherwise
                ``joints_coord_img``; additionally ``skeleton_map`` for the
                ``'unext'`` backbone.
            cur_sample_idx: index in ``self.datalist`` of the first sample of
                ``outs``.

        Results are appended to ``self.eval_result``; nothing is returned.
        """
        # The cell defining this class does not import these two helpers at
        # the top, so bring them into scope here.
        from common.utils.vis import ux_hon_result, ux_hon_result_final

        annots = self.datalist
        is_unext = cfg.backbone == 'unext'
        for n, out in enumerate(outs):
            annot = annots[cur_sample_idx + n]

            # Rebuild the crop to recover the image<->crop transforms.
            img = load_img(annot['img_path'])
            orig_img = copy.deepcopy(img)
            img, img2bb_trans, bb2img_trans, rot, scale = augmentation(img, annot['bbox'], self.data_split, do_flip=False)

            # Ground-truth joints in image space and in crop space.
            gt_joints_coord_img = annot['joints_coord_img']
            joints_img_xy1 = np.concatenate((gt_joints_coord_img[:,:2], np.ones_like(gt_joints_coord_img[:,:1])),1)
            joints_img = np.dot(img2bb_trans, joints_img_xy1.transpose(1,0)).transpose(1,0)[:,:2]

            if is_unext:
                gt_skeleton_map = skeleton_map_gray((cfg.input_img_shape[0], cfg.input_img_shape[1]), joints_img)
                gt_skeleton_map = gt_skeleton_map/255.

                # Raw (un-thresholded) prediction is used for the visualisation.
                pred_skeleton_map = (out['skeleton_map'].squeeze().cpu().numpy()).astype(float)

                cat_imgs = ux_hon_result(orig_img, img, pred_skeleton_map, gt_skeleton_map)
                cat_imgs = ux_hon_result_final(out, bb2img_trans, orig_img, img, cat_imgs)
                self._save_visualization(annot['img_path'], cat_imgs)

                # Thresholded prediction is used for the pixel accuracy.
                pred_skeleton_map = (out['skeleton_map'].squeeze().cpu().numpy()>0.5).astype(float)
                num_correct = (pred_skeleton_map==gt_skeleton_map).sum()
                num_pixels = cfg.input_img_shape[0] * cfg.input_img_shape[1]
            else:
                img_uint8 = cv2.resize(orig_img.astype(np.uint8), (cfg.input_img_shape[0], cfg.input_img_shape[1]))
                rgb_img_uint8 = cv2.cvtColor(img_uint8.astype(np.uint8), cv2.COLOR_BGR2RGB)
                rgb_img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                ori_imgs = np.hstack([rgb_img_uint8, rgb_img])
                cat_imgs = ux_hon_result_final(out, bb2img_trans, orig_img, img, ori_imgs)
                self._save_visualization(annot['img_path'], cat_imgs)

            if cfg.simcc:
                pred_keypoints = np.asanyarray(out['keypoints'])
                pred_keypoint_scores = out['keypoint_scores']
                keypoints_restored = self._bbox_to_image_coords(bb2img_trans, pred_keypoints)
            else:
                # Copy so the rescaling below does not modify the caller's output in place.
                pred_keypoints = out['joints_coord_img'].cpu().numpy().copy()
                pred_keypoints[:,0] *= cfg.input_img_shape[1]
                pred_keypoints[:,1] *= cfg.input_img_shape[0]
                keypoints_restored = self._bbox_to_image_coords(bb2img_trans, pred_keypoints)
                # No confidence available: a joint counts as valid unless it maps to (0, 0).
                pred_keypoint_scores = np.any(keypoints_restored, axis=1)

            # NOTE(review): norm_factor is passed as (height, width) while the
            # keypoints are (x, y) - confirm the order keypoint_pck_accuracy expects.
            _, avg_acc, _ = keypoint_pck_accuracy(
                pred=np.expand_dims(keypoints_restored, axis=0),
                gt=np.expand_dims(gt_joints_coord_img[:,:2], axis=0),
                mask=np.expand_dims(pred_keypoint_scores, axis=0) > 0,
                thr=cfg.pck_thr,
                norm_factor=np.expand_dims(annot['img_shape'], axis=0),
            )

            self.eval_result[2].append(compute_mpjpe(keypoints_restored, gt_joints_coord_img[:,:2]))

            if is_unext:
                self.eval_result[0].append(num_correct / num_pixels)
                self.eval_result[1].append(avg_acc)
            else:
                self.eval_result[0].append(avg_acc)

    def print_eval_result(self, test_epoch):
        """Format the accumulated metrics as a list of log lines.

        Args:
            test_epoch: snapshot number quoted in the first line.

        Returns:
            list of str: run/snapshot line, pixel accuracy (``'unext'`` only),
            PCK and MPJPE, each averaged over everything stored in
            ``self.eval_result``.
        """
        # Accept both '/' and '\\' separators, as evaluate() does for image paths.
        run_name = re.split(r'[\\/]', cfg.output_dir)[-1]
        message = ['Output: {0}, Model: snapshot_{1}.pth.tar'.format(run_name, test_epoch)]
        if cfg.backbone == 'unext':
            message.append('Correct/Total(One Batch) pixels: {0:.2f}'.format(np.mean(self.eval_result[0]) * 100))
            pck_values = self.eval_result[1]
        else:
            pck_values = self.eval_result[0]
        message.append('PCK@{0}: {1:.2f}'.format(cfg.pck_thr, np.mean(pck_values) * 100))
        message.append('MPJPE : %.2f' % np.mean(self.eval_result[2]))
        return message

In [None]:
## Base
import os
import os.path as osp
import math
import time
import glob
import abc
from torch.utils.data import DataLoader
import torch.optim
import torchvision.transforms as transforms
from common.timer import Timer
from common.logger import colorlogger
from torch.nn.parallel.data_parallel import DataParallel
from main.config import cfg
# dynamic model import
if cfg.backbone == 'fpn':
    from main.model import get_model

# dynamic dataset import
# exec('from ' + cfg.trainset + ' import ' + cfg.trainset)
# exec('from ' + cfg.testset + ' import ' + cfg.testset)

class Base(abc.ABC):
    """Abstract base for the trainer/tester harnesses.

    Holds the epoch counter, three ``Timer`` instances and a logger.  Subclasses
    must implement :meth:`_make_batch_generator` and :meth:`_make_model`.

    The class derives from ``abc.ABC`` (instead of setting the Python 2
    ``__metaclass__`` attribute, which Python 3 ignores) so that the
    ``@abc.abstractmethod`` markers are actually enforced at instantiation.
    """

    def __init__(self, log_name='logs.txt'):
        """Create timers and the logger.

        Args:
            log_name: log file name handed to ``colorlogger`` together with
                ``cfg.log_dir``.
        """
        self.cur_epoch = 0

        # timer
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    @abc.abstractmethod
    def _make_batch_generator(self):
        """Build the dataset and its batch generator (implemented by subclasses)."""
        return

    @abc.abstractmethod
    def _make_model(self):
        """Build or load the model (implemented by subclasses)."""
        return

In [None]:
## HandOccNet Trainer
class Trainer(Base):
    """Training harness: builds the train DataLoader, the model and its optimizer."""

    def __init__(self):
        super(Trainer, self).__init__(log_name = 'train_logs.txt')

    def get_optimizer(self, model):
        """Return an Adam optimizer over the trainable parameters of ``model`` with ``cfg.lr``."""
        model_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = torch.optim.Adam(model_params, lr=cfg.lr)
        return optimizer

    def save_model(self, state, epoch):
        """Save ``state`` to ``cfg.model_dir/snapshot_<epoch>.pth.tar`` and log the path."""
        file_path = osp.join(cfg.model_dir,'snapshot_{}.pth.tar'.format(str(epoch)))
        torch.save(state, file_path)
        self.logger.info("Write snapshot into {}".format(file_path))

    def load_model(self, model, optimizer):
        """Restore the network weights from the newest snapshot in ``cfg.model_dir``.

        Only the ``'network'`` entry is loaded (non-strictly); the optimizer is
        returned unchanged, its saved state is not restored.

        Returns:
            ``(start_epoch, model, optimizer)`` where ``start_epoch`` is the
            stored ``'epoch'`` plus one.

        Raises:
            ValueError: if ``cfg.model_dir`` contains no ``*.pth.tar`` file.
        """
        model_file_list = glob.glob(osp.join(cfg.model_dir,'*.pth.tar'))
        if not model_file_list:
            raise ValueError('No *.pth.tar checkpoint found in {}'.format(cfg.model_dir))

        def snapshot_epoch(path):
            # Parse the file name only, so directory names containing
            # 'snapshot_' or '.pth.tar' cannot corrupt the slice.
            name = osp.basename(path)
            return int(name[name.find('snapshot_') + 9 : name.find('.pth.tar')])

        cur_epoch = max(snapshot_epoch(path) for path in model_file_list)
        ckpt_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar')
        ckpt = torch.load(ckpt_path)
        start_epoch = ckpt['epoch'] + 1
        model.load_state_dict(ckpt['network'], strict=False)
        #optimizer.load_state_dict(ckpt['optimizer'])

        self.logger.info('Load checkpoint from {}'.format(ckpt_path))
        return start_epoch, model, optimizer

    def set_lr(self, epoch):
        """Set the step-decayed learning rate for ``epoch`` on every param group.

        The rate is ``cfg.lr * cfg.lr_dec_factor ** idx`` where ``idx`` is the
        position of the first entry of ``cfg.lr_dec_epoch`` that is greater
        than ``epoch``, or the number of entries once ``epoch`` has reached the
        last one.  An empty ``cfg.lr_dec_epoch`` leaves the rate at ``cfg.lr``.
        """
        decay_epochs = cfg.lr_dec_epoch
        idx = len(decay_epochs)
        if decay_epochs and epoch < decay_epochs[-1]:
            idx = next(i for i, e in enumerate(decay_epochs) if epoch < e)

        lr = cfg.lr * (cfg.lr_dec_factor ** idx)
        for g in self.optimizer.param_groups:
            g['lr'] = lr

    def get_lr(self):
        """Return the learning rate of the last optimizer param group."""
        for g in self.optimizer.param_groups:
            cur_lr = g['lr']
        return cur_lr

    def _make_batch_generator(self):
        """Create the training dataset named by ``cfg.trainset`` and its shuffled DataLoader."""
        self.logger.info("Creating dataset...")
        # NOTE(review): eval() on a config string executes arbitrary code;
        # only safe while cfg.trainset is trusted.
        train_dataset = eval(cfg.trainset)(transforms.ToTensor(), "train")

        self.itr_per_epoch = math.ceil(len(train_dataset) / cfg.num_gpus / cfg.train_batch_size)
        self.batch_generator = DataLoader(dataset=train_dataset, batch_size=cfg.num_gpus*cfg.train_batch_size, shuffle=True, num_workers=cfg.num_thread, pin_memory=True)

    def _make_model(self):
        """Build the model and optimizer, optionally resuming from the newest snapshot.

        Sets ``self.start_epoch``, ``self.model`` (in train mode, wrapped in
        ``DataParallel`` on CUDA) and ``self.optimizer``.
        """
        self.logger.info("Creating graph and optimizer...")
        if cfg.SET:
            self.logger.info("Creating model with SET...")
        else:
            self.logger.info("Creating model without SET...")
        model = get_model('train')

        model = DataParallel(model).cuda()
        optimizer = self.get_optimizer(model)
        if cfg.continue_train:
            start_epoch, model, optimizer = self.load_model(model, optimizer)
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.model = model
        self.optimizer = optimizer

In [None]:
## HandOccNet Tester
class Tester(Base):
    """Evaluation harness: builds the test DataLoader, loads a snapshot and scores it."""

    def __init__(self):
        super().__init__(log_name='test_logs.txt')

    def _make_batch_generator(self):
        """Instantiate the dataset named by ``cfg.testset`` and wrap it in an unshuffled DataLoader."""
        self.logger.info("Creating dataset...")
        dataset_cls = eval(cfg.testset)
        self.test_dataset = dataset_cls(transforms.ToTensor(), "test")
        self.batch_generator = DataLoader(
            dataset=self.test_dataset,
            batch_size=cfg.num_gpus * cfg.test_batch_size,
            shuffle=False,
            num_workers=cfg.num_thread,
            pin_memory=True,
        )

    def _make_model(self, test_epoch):
        """Load ``snapshot_<test_epoch>.pth.tar`` from ``cfg.model_dir`` into an eval-mode model."""
        snapshot_name = 'snapshot_%d.pth.tar' % test_epoch
        model_path = os.path.join(cfg.model_dir, snapshot_name)
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        self.logger.info('Load checkpoint from {}'.format(model_path))

        # Build the network first, then fill in the stored weights (non-strictly).
        self.logger.info("Creating graph...")
        network = DataParallel(get_model('test')).cuda()
        checkpoint = torch.load(model_path)
        network.load_state_dict(checkpoint['network'], strict=False)
        network.eval()

        self.model = network

    def _evaluate(self, outs, cur_sample_idx):
        """Forward a batch of outputs to the dataset's ``evaluate`` and return its result."""
        return self.test_dataset.evaluate(outs, cur_sample_idx)

    def _print_eval_result(self, test_epoch):
        """Log every line produced by the dataset's ``print_eval_result``."""
        for line in self.test_dataset.print_eval_result(test_epoch):
            self.logger.info(line)

In [None]:
## HandOccNet train + test process

import torch
import argparse
from tqdm import tqdm
import numpy as np
import torch.backends.cudnn as cudnn
from main.config import cfg

# Train the model and evaluate every saved checkpoint on the test dataset.
cfg.set_args('0', False)
cudnn.benchmark = True

trainer = Trainer()
trainer._make_batch_generator()
trainer._make_model()

tester = Tester()
tester._make_batch_generator()

# train
for epoch in range(trainer.start_epoch, cfg.end_epoch):

    trainer.set_lr(epoch)
    trainer.tot_timer.tic()
    trainer.read_timer.tic()
    for itr, (inputs, targets) in enumerate(trainer.batch_generator):
        trainer.read_timer.toc()
        trainer.gpu_timer.tic()

        # forward: the model returns (loss, acc) only in simcc mode
        trainer.optimizer.zero_grad()
        if cfg.simcc:
            loss, acc = trainer.model(inputs, targets, 'train')
        else:
            loss = trainer.model(inputs, targets, 'train')

        loss = {k:loss[k].mean() for k in loss}

        # backward
        sum(loss[k] for k in loss).backward()
        trainer.optimizer.step()
        trainer.gpu_timer.toc()
        screen = [
            'Epoch %d/%d itr %d/%d:' % (epoch, cfg.end_epoch, itr, trainer.itr_per_epoch),
            'lr: %g' % (trainer.get_lr()),
            'speed: %.2f(gpu%.2fs r_data%.2fs)s/itr' % (
                trainer.tot_timer.average_time, trainer.gpu_timer.average_time, trainer.read_timer.average_time),
            '%.2fh/epoch' % (trainer.tot_timer.average_time / 3600. * trainer.itr_per_epoch),
            ]
        screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k,v in loss.items()]
        # `acc` only exists in simcc mode, so it is logged under the same condition.
        if cfg.simcc:
            screen += ['%s: %.4f' % ('acc_' + k, v.detach()) for k,v in acc.items()]
        trainer.logger.info(' '.join(screen))

        trainer.tot_timer.toc()
        trainer.tot_timer.tic()
        trainer.read_timer.tic()

    if (epoch+1)%cfg.ckpt_freq== 0 or epoch+1 == cfg.end_epoch:
        # The snapshot file is numbered epoch+1; the stored 'epoch' stays 0-based.
        trainer.save_model({
            'epoch': epoch,
            'network': trainer.model.state_dict(),
            'optimizer': trainer.optimizer.state_dict(),
        }, epoch+1)

        tester._make_model(epoch+1)

        # Start from empty metric lists so each checkpoint is scored on its
        # own instead of being averaged with all previously evaluated ones.
        tester.test_dataset.eval_result = [[], [], [], []]
        cur_sample_idx = 0
        for itr, (inputs, targets) in enumerate(tqdm(tester.batch_generator)):

            # forward
            with torch.no_grad():
                out = tester.model(inputs, targets, 'test')

            # split the batched output dict into one dict per sample
            for v in out.values():
                batch_size = v.shape[0]
            out = [{k: v[bid] for k,v in out.items()} for bid in range(batch_size)]

            # evaluate
            tester._evaluate(out, cur_sample_idx)
            cur_sample_idx += len(out)

        # Report the number of the snapshot that was actually evaluated.
        tester._print_eval_result(epoch+1)

In [None]:
## HandOccNet train process

import argparse
from main.config import cfg
import torch
import torch.backends.cudnn as cudnn

cfg.set_args('0', False)
cudnn.benchmark = True

trainer = Trainer()
trainer._make_batch_generator()
trainer._make_model()

# Run the optimisation loop from the trainer's start epoch up to cfg.end_epoch.
for epoch in range(trainer.start_epoch, cfg.end_epoch):

    trainer.set_lr(epoch)
    trainer.tot_timer.tic()
    trainer.read_timer.tic()
    for itr, (inputs, targets) in enumerate(trainer.batch_generator):
        trainer.read_timer.toc()
        trainer.gpu_timer.tic()

        # Forward pass; with cfg.simcc the model also returns an accuracy dict.
        trainer.optimizer.zero_grad()
        result = trainer.model(inputs, targets, 'train')
        if cfg.simcc:
            loss, acc = result
        else:
            loss = result

        # Reduce every loss term to its mean before summing.
        loss = {name: term.mean() for name, term in loss.items()}

        # Backward pass on the summed loss, then one optimizer step.
        sum(loss.values()).backward()
        trainer.optimizer.step()
        trainer.gpu_timer.toc()

        # Assemble the per-iteration log line.
        screen = []
        screen.append('Epoch %d/%d itr %d/%d:' % (epoch, cfg.end_epoch, itr, trainer.itr_per_epoch))
        screen.append('lr: %g' % (trainer.get_lr()))
        screen.append('speed: %.2f(gpu%.2fs r_data%.2fs)s/itr' % (
            trainer.tot_timer.average_time, trainer.gpu_timer.average_time, trainer.read_timer.average_time))
        screen.append('%.2fh/epoch' % (trainer.tot_timer.average_time / 3600. * trainer.itr_per_epoch))
        for name, term in loss.items():
            screen.append('%s: %.4f' % ('loss_' + name, term.detach()))
        if cfg.simcc:
            for name, term in acc.items():
                screen.append('%s: %.4f' % ('acc_' + name, term.detach()))
        trainer.logger.info(' '.join(screen))

        trainer.tot_timer.toc()
        trainer.tot_timer.tic()
        trainer.read_timer.tic()

    # Write a snapshot every cfg.ckpt_freq epochs and after the final epoch.
    is_periodic = (epoch + 1) % cfg.ckpt_freq == 0
    is_final = epoch + 1 == cfg.end_epoch
    if is_periodic or is_final:
        snapshot = {
            'epoch': epoch,
            'network': trainer.model.state_dict(),
            'optimizer': trainer.optimizer.state_dict(),
        }
        trainer.save_model(snapshot, epoch + 1)

In [None]:
## HandOccNet test process

import torch
import argparse
from tqdm import tqdm
import numpy as np
import torch.backends.cudnn as cudnn
from main.config import cfg

# Snapshot epochs to evaluate: 10, 20, ..., 70.
test_epoch = list(range(10, 71, 10))
cfg.set_args('0', False)
cudnn.benchmark = True

tester = Tester()
tester._make_batch_generator()

for epoch in test_epoch:
    tester._make_model(epoch)

    eval_result = {}
    cur_sample_idx = 0
    for itr, (inputs, targets) in enumerate(tqdm(tester.batch_generator)):

        # Inference only: no gradients are needed.
        with torch.no_grad():
            out = tester.model(inputs, targets, 'test')

        # Split the batched output dict into one dict per sample.
        out = dict(out.items())
        for value in out.values():
            batch_size = value.shape[0]
        per_sample = []
        for bid in range(batch_size):
            per_sample.append({key: value[bid] for key, value in out.items()})
        out = per_sample

        # Hand the per-sample results to the evaluator and advance the sample cursor.
        tester._evaluate(out, cur_sample_idx)
        cur_sample_idx += len(out)

    tester._print_eval_result(epoch)

In [None]:
## UneXt + HandOccNet UH_Trainer
from main.model_UX import get_UX_model
from main.model_HON import get_HON_model

class UH_Trainer(Base):
    """Training harness for the two-stage UneXt (UX) + HandOccNet (HON) pipeline.

    Keeps one model/optimizer pair per stage (``ux_model``/``ux_optimizer`` and
    ``hon_model``/``hon_optimizer``) plus a shared training ``batch_generator``.
    Snapshots are written to and restored from ``cfg.model_dir`` as
    ``snapshot_UX_<epoch>.pth.tar`` and ``snapshot_HON_<epoch>.pth.tar``.
    """

    _SNAPSHOT_SUFFIX = '.pth.tar'
    _UX_PREFIX = 'snapshot_UX_'
    _HON_PREFIX = 'snapshot_HON_'

    def __init__(self):
        super(UH_Trainer, self).__init__(log_name = 'train_logs.txt')

    # ------------------------------------------------------------------ optimizers
    def get_UX_optimizer(self, model):
        """Return an Adam optimizer over all parameters of ``model`` at ``cfg.ux_lr``."""
        return torch.optim.Adam(model.parameters(), lr=cfg.ux_lr)

    def get_HON_optimizer(self, model):
        """Return an Adam optimizer over all parameters of ``model`` at ``cfg.hon_lr``."""
        return torch.optim.Adam(model.parameters(), lr=cfg.hon_lr)

    # ------------------------------------------------------------------ snapshots
    def save_UX_model(self, state, epoch):
        """Save ``state`` as ``snapshot_UX_<epoch>.pth.tar`` in ``cfg.model_dir``."""
        file_path = osp.join(cfg.model_dir,'snapshot_UX_{}.pth.tar'.format(str(epoch)))
        torch.save(state, file_path)
        self.logger.info("Write UX snapshot into {}".format(file_path))

    def save_HON_model(self, state, epoch):
        """Save ``state`` as ``snapshot_HON_<epoch>.pth.tar`` in ``cfg.model_dir``."""
        file_path = osp.join(cfg.model_dir,'snapshot_HON_{}.pth.tar'.format(str(epoch)))
        torch.save(state, file_path)
        self.logger.info("Write HON snapshot into {}".format(file_path))

    def _latest_snapshot_path(self, prefix):
        """Return the path of the highest-numbered ``<prefix><epoch>.pth.tar`` snapshot.

        Only files whose name starts with ``prefix`` are considered, so UX and HON
        snapshots living in the same directory do not interfere with each other.

        Raises:
            FileNotFoundError: if no matching snapshot exists in ``cfg.model_dir``.
        """
        suffix = self._SNAPSHOT_SUFFIX
        candidates = []
        for path in glob.glob(osp.join(cfg.model_dir, prefix + '*' + suffix)):
            # Strip the full prefix and suffix; what remains must be the epoch number.
            epoch_tag = osp.basename(path)[len(prefix):-len(suffix)]
            if epoch_tag.isdigit():
                candidates.append((int(epoch_tag), path))
        if not candidates:
            raise FileNotFoundError(
                'No {}<epoch>{} snapshot found in {}'.format(prefix, suffix, cfg.model_dir))
        return max(candidates)[1]

    def _load_snapshot(self, model, optimizer, prefix):
        """Restore the newest snapshot with ``prefix`` into ``model``.

        Returns:
            ``(start_epoch, model, optimizer)`` where ``start_epoch`` is the stored
            ``'epoch'`` value plus one. The optimizer is returned untouched.
        """
        ckpt_path = self._latest_snapshot_path(prefix)
        ckpt = torch.load(ckpt_path)
        start_epoch = ckpt['epoch'] + 1
        model.load_state_dict(ckpt['network'], strict=False)
        # The optimizer state was deliberately not restored in the original code:
        #optimizer.load_state_dict(ckpt['optimizer'])

        self.logger.info('Load checkpoint from {}'.format(ckpt_path))
        return start_epoch, model, optimizer

    def load_UX_model(self, model, optimizer):
        """Load the latest ``snapshot_UX_*`` checkpoint; see ``_load_snapshot``."""
        return self._load_snapshot(model, optimizer, self._UX_PREFIX)

    def load_HON_model(self, model, optimizer):
        """Load the latest ``snapshot_HON_*`` checkpoint; see ``_load_snapshot``."""
        return self._load_snapshot(model, optimizer, self._HON_PREFIX)

    # ------------------------------------------------------------------ learning rate
    @staticmethod
    def _scheduled_lr(base_lr, epoch):
        """Return ``base_lr`` decayed by ``cfg.lr_dec_factor`` per milestone passed.

        The exponent is the index of the first entry of ``cfg.lr_dec_epoch`` that is
        greater than ``epoch``; once ``epoch`` reaches the last entry every decay is
        applied. An empty milestone list leaves ``base_lr`` unchanged.
        """
        milestones = cfg.lr_dec_epoch
        if len(milestones) > 0 and epoch < milestones[-1]:
            num_decays = next(i for i, e in enumerate(milestones) if epoch < e)
        else:
            num_decays = len(milestones)
        return base_lr * (cfg.lr_dec_factor ** num_decays)

    def set_ux_lr(self, epoch):
        """Set the UX optimizer's learning rate for ``epoch`` on every param group."""
        lr = self._scheduled_lr(cfg.ux_lr, epoch)
        for g in self.ux_optimizer.param_groups:
            g['lr'] = lr

    def set_hon_lr(self, epoch):
        """Set the HON optimizer's learning rate for ``epoch`` on every param group."""
        lr = self._scheduled_lr(cfg.hon_lr, epoch)
        for g in self.hon_optimizer.param_groups:
            g['lr'] = lr

    def get_ux_lr(self):
        """Return the learning rate of the UX optimizer's last param group."""
        return self.ux_optimizer.param_groups[-1]['lr']

    def get_hon_lr(self):
        """Return the learning rate of the HON optimizer's last param group."""
        return self.hon_optimizer.param_groups[-1]['lr']

    # ------------------------------------------------------------------ data / models
    def _make_batch_generator(self):
        """Build the training dataset and its shuffled ``DataLoader``.

        Sets ``self.itr_per_epoch`` and ``self.batch_generator``.
        """
        self.logger.info("Creating dataset...")
        # NOTE(review): eval() on a config string executes arbitrary code; this is
        # only safe while cfg.trainset is a trusted dataset class name.
        train_dataset = eval(cfg.trainset)(transforms.ToTensor(), "train")

        self.itr_per_epoch = math.ceil(len(train_dataset) / cfg.num_gpus / cfg.train_batch_size)
        self.batch_generator = DataLoader(dataset=train_dataset, batch_size=cfg.num_gpus*cfg.train_batch_size, shuffle=True, num_workers=cfg.num_thread, pin_memory=True)

    def _make_UX_model(self):
        """Create the UX network and optimizer, optionally resuming from a snapshot.

        Sets ``self.start_epoch``, ``self.ux_model`` and ``self.ux_optimizer``.
        """
        self.logger.info("Creating graph and optimizer...")
        model = get_UX_model('train')

        model = DataParallel(model).cuda()
        ux_optimizer = self.get_UX_optimizer(model)
        if cfg.continue_train:
            # Use the UX-specific loader so that snapshot_UX_* files are the ones read.
            start_epoch, model, ux_optimizer = self.load_UX_model(model, ux_optimizer)
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.ux_model = model
        self.ux_optimizer = ux_optimizer

    def _make_HON_model(self):
        """Create the HON network and optimizer, optionally resuming from a snapshot.

        Sets ``self.start_epoch`` (overwriting any value set by ``_make_UX_model``),
        ``self.hon_model`` and ``self.hon_optimizer``.
        """
        self.logger.info("Creating graph and optimizer...")
        if cfg.SET:
            self.logger.info("Creating model with SET...")
        else:
            self.logger.info("Creating model without SET...")
        model = get_HON_model('train')

        model = DataParallel(model).cuda()
        hon_optimizer = self.get_HON_optimizer(model)
        if cfg.continue_train:
            # Use the HON-specific loader so that snapshot_HON_* files are the ones read.
            start_epoch, model, hon_optimizer = self.load_HON_model(model, hon_optimizer)
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.hon_model = model
        self.hon_optimizer = hon_optimizer

In [None]:
## UneXt + HandOccNet UH_Tester
from main.model_UX import get_UX_model
from main.model_HON import get_HON_model

class UH_Tester(Base):
    """Evaluation harness holding the UX and HON networks and the test data loader."""

    def __init__(self):
        super(UH_Tester, self).__init__(log_name = 'test_logs.txt')

    def _make_batch_generator(self):
        """Create ``self.test_dataset`` and an unshuffled ``self.batch_generator``."""
        self.logger.info("Creating dataset...")
        dataset_cls = eval(cfg.testset)
        self.test_dataset = dataset_cls(transforms.ToTensor(), "test")
        self.batch_generator = DataLoader(
            dataset=self.test_dataset,
            batch_size=cfg.num_gpus*cfg.test_batch_size,
            shuffle=False,
            num_workers=cfg.num_thread,
            pin_memory=True,
        )

    def _restore_for_eval(self, build_fn, model_path):
        """Build a network with ``build_fn('test')`` and load weights from ``model_path``.

        The network is wrapped in ``DataParallel``, moved to CUDA, loaded
        non-strictly from the checkpoint's ``'network'`` entry and put in eval mode.
        """
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        self.logger.info('Load checkpoint from {}'.format(model_path))

        self.logger.info("Creating graph...")
        net = DataParallel(build_fn('test')).cuda()
        state = torch.load(model_path)
        net.load_state_dict(state['network'], strict=False)
        net.eval()
        return net

    def _make_UX_model(self, test_epoch):
        """Load ``snapshot_UX_<test_epoch>.pth.tar`` into ``self.ux_model``."""
        model_path = os.path.join(cfg.model_dir, 'snapshot_UX_%d.pth.tar' % test_epoch)
        self.ux_model = self._restore_for_eval(get_UX_model, model_path)

    def _make_HON_model(self, test_epoch):
        """Load ``snapshot_HON_<test_epoch>.pth.tar`` into ``self.hon_model``."""
        model_path = os.path.join(cfg.model_dir, 'snapshot_HON_%d.pth.tar' % test_epoch)
        self.hon_model = self._restore_for_eval(get_HON_model, model_path)

    def _evaluate(self, outs, cur_sample_idx):
        """Forward per-sample outputs to the dataset's ``evaluate`` and return its result."""
        return self.test_dataset.evaluate(outs, cur_sample_idx)

    def _print_eval_result(self, test_epoch):
        """Log every message returned by the dataset's ``print_eval_result``."""
        for line in self.test_dataset.print_eval_result(test_epoch):
            self.logger.info(line)

In [None]:
## Unext + HandOccNet train + test process

import torch
import argparse
from tqdm import tqdm
import numpy as np
import torch.backends.cudnn as cudnn
from main.config import cfg

cfg.set_args('0', False)
cudnn.benchmark = True

trainer = UH_Trainer()
trainer._make_batch_generator()
trainer._make_UX_model()
trainer._make_HON_model()

tester = UH_Tester()
tester._make_batch_generator()

# Train both stages; after each checkpoint, evaluate the snapshots just written.
for epoch in range(trainer.start_epoch, cfg.end_epoch):

    trainer.set_ux_lr(epoch)
    trainer.set_hon_lr(epoch)
    trainer.tot_timer.tic()
    trainer.read_timer.tic()
    for itr, (inputs, targets) in enumerate(trainer.batch_generator):
        trainer.read_timer.toc()
        trainer.gpu_timer.tic()

        # ---- UX stage: forward ----
        ux_loss, ux_acc, ux_outs = trainer.ux_model(inputs, targets, 'train', itr)

        # Reduce each term to a scalar, then sum the *reduced* terms so that
        # backward() always receives a scalar (the raw terms may carry one entry
        # per DataParallel replica).
        ux_loss_dic = {k: v.mean() for k, v in ux_loss.items()}
        ux_loss = sum(ux_loss_dic.values())

        # ---- UX stage: backward + step ----
        trainer.ux_optimizer.zero_grad()
        # NOTE(review): the UX graph is retained and the UX weights are stepped
        # before the HON backward below. If ux_outs is not detached inside the
        # models, hon_loss.backward() will traverse that graph - confirm intent.
        ux_loss.backward(retain_graph=True)
        trainer.ux_optimizer.step()

        # ---- HON stage: forward ----
        # The HON model is always called as (inputs, targets, ux_outs, mode), matching
        # the evaluation call below; only the return shape depends on cfg.simcc.
        if cfg.simcc:
            hon_loss, hon_acc = trainer.hon_model(inputs, targets, ux_outs, 'train')
        else:
            hon_loss = trainer.hon_model(inputs, targets, ux_outs, 'train')

        hon_loss_dic = {k: v.mean() for k, v in hon_loss.items()}
        hon_loss = sum(hon_loss_dic.values())

        # ---- HON stage: backward + step ----
        trainer.hon_optimizer.zero_grad()
        hon_loss.backward()
        trainer.hon_optimizer.step()

        trainer.gpu_timer.toc()
        screen = [
            'Epoch %d/%d itr %d/%d:' % (epoch, cfg.end_epoch, itr, trainer.itr_per_epoch),
            'ux_lr: %g' % (trainer.get_ux_lr()),
            'hon_lr: %g' % (trainer.get_hon_lr()),
            'speed: %.2f(gpu%.2fs r_data%.2fs)s/itr' % (
                trainer.tot_timer.average_time, trainer.gpu_timer.average_time, trainer.read_timer.average_time),
            '%.2fh/epoch' % (trainer.tot_timer.average_time / 3600. * trainer.itr_per_epoch),
            ]
        screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k,v in ux_loss_dic.items()]
        screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k,v in hon_loss_dic.items()]

        screen += ['%s: %.4f' % ('acc_' + k, v.detach()) for k,v in ux_acc.items()]
        if cfg.simcc:
            screen += ['%s: %.4f' % ('acc_' + k, v.detach()) for k,v in hon_acc.items()]
        trainer.logger.info(' '.join(screen))

        trainer.tot_timer.toc()
        trainer.tot_timer.tic()
        trainer.read_timer.tic()

    if (epoch+1)%cfg.ckpt_freq== 0 or epoch+1 == cfg.end_epoch:
        trainer.save_UX_model({
            'epoch': epoch,
            'network': trainer.ux_model.state_dict(),
            'optimizer': trainer.ux_optimizer.state_dict(),
        }, epoch+1)

        trainer.save_HON_model({
            'epoch': epoch,
            'network': trainer.hon_model.state_dict(),
            'optimizer': trainer.hon_optimizer.state_dict(),
        }, epoch+1)

        # Evaluate the snapshots that were just written (numbered epoch+1).
        tester._make_UX_model(epoch+1)
        tester._make_HON_model(epoch+1)

        cur_sample_idx = 0
        for itr, (inputs, targets) in enumerate(tqdm(tester.batch_generator)):

            # forward (inference only)
            with torch.no_grad():
                ux_out = tester.ux_model(inputs, targets, 'test', itr)
                hon_out = tester.hon_model(inputs, targets, ux_out, 'test')

            # Split the batched outputs into one dict per sample. HON entries
            # override UX entries that share a key, and the batch size is read
            # from HON's last output, as before.
            merged = {**ux_out, **hon_out}
            batch_size = list(hon_out.values())[-1].shape[0]
            out = [{k: v[bid] for k, v in merged.items()} for bid in range(batch_size)]

            # evaluate
            tester._evaluate(out, cur_sample_idx)
            cur_sample_idx += len(out)

        # NOTE(review): results are labelled with `epoch` while the evaluated
        # snapshots are numbered epoch+1 - confirm the intended label.
        tester._print_eval_result(epoch)

In [None]:
## Unext + HandOccNet train process

import torch
import argparse
from tqdm import tqdm
import numpy as np
import torch.backends.cudnn as cudnn
from main.config import cfg

cfg.set_args('0', False)
cudnn.benchmark = True

trainer = UH_Trainer()
trainer._make_batch_generator()
trainer._make_UX_model()
trainer._make_HON_model()

# Train both stages, writing a snapshot pair at every checkpoint epoch.
for epoch in range(trainer.start_epoch, cfg.end_epoch):

    trainer.set_ux_lr(epoch)
    trainer.set_hon_lr(epoch)
    trainer.tot_timer.tic()
    trainer.read_timer.tic()
    for itr, (inputs, targets) in enumerate(trainer.batch_generator):
        trainer.read_timer.toc()
        trainer.gpu_timer.tic()

        # ---- UX stage: forward ----
        ux_loss, ux_acc, ux_outs = trainer.ux_model(inputs, targets, 'train', itr)

        # Reduce each term to a scalar, then sum the *reduced* terms so that
        # backward() always receives a scalar (the raw terms may carry one entry
        # per DataParallel replica).
        ux_loss_dic = {k: v.mean() for k, v in ux_loss.items()}
        ux_loss = sum(ux_loss_dic.values())

        # ---- UX stage: backward + step ----
        trainer.ux_optimizer.zero_grad()
        # NOTE(review): the UX graph is retained and the UX weights are stepped
        # before the HON backward below. If ux_outs is not detached inside the
        # models, hon_loss.backward() will traverse that graph - confirm intent.
        ux_loss.backward(retain_graph=True)
        trainer.ux_optimizer.step()

        # ---- HON stage: forward ----
        # The HON model is always called as (inputs, targets, ux_outs, mode); only
        # the return shape depends on cfg.simcc.
        if cfg.simcc:
            hon_loss, hon_acc = trainer.hon_model(inputs, targets, ux_outs, 'train')
        else:
            hon_loss = trainer.hon_model(inputs, targets, ux_outs, 'train')

        hon_loss_dic = {k: v.mean() for k, v in hon_loss.items()}
        hon_loss = sum(hon_loss_dic.values())

        # ---- HON stage: backward + step ----
        trainer.hon_optimizer.zero_grad()
        hon_loss.backward()
        trainer.hon_optimizer.step()

        trainer.gpu_timer.toc()
        screen = [
            'Epoch %d/%d itr %d/%d:' % (epoch, cfg.end_epoch, itr, trainer.itr_per_epoch),
            'ux_lr: %g' % (trainer.get_ux_lr()),
            'hon_lr: %g' % (trainer.get_hon_lr()),
            'speed: %.2f(gpu%.2fs r_data%.2fs)s/itr' % (
                trainer.tot_timer.average_time, trainer.gpu_timer.average_time, trainer.read_timer.average_time),
            '%.2fh/epoch' % (trainer.tot_timer.average_time / 3600. * trainer.itr_per_epoch),
            ]
        screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k,v in ux_loss_dic.items()]
        screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k,v in hon_loss_dic.items()]

        screen += ['%s: %.4f' % ('acc_' + k, v.detach()) for k,v in ux_acc.items()]
        if cfg.simcc:
            screen += ['%s: %.4f' % ('acc_' + k, v.detach()) for k,v in hon_acc.items()]
        trainer.logger.info(' '.join(screen))

        trainer.tot_timer.toc()
        trainer.tot_timer.tic()
        trainer.read_timer.tic()

    # Write a snapshot pair every cfg.ckpt_freq epochs and after the final epoch.
    if (epoch+1)%cfg.ckpt_freq== 0 or epoch+1 == cfg.end_epoch:
        trainer.save_UX_model({
            'epoch': epoch,
            'network': trainer.ux_model.state_dict(),
            'optimizer': trainer.ux_optimizer.state_dict(),
        }, epoch+1)

        trainer.save_HON_model({
            'epoch': epoch,
            'network': trainer.hon_model.state_dict(),
            'optimizer': trainer.hon_optimizer.state_dict(),
        }, epoch+1)

In [None]:
## Unext + HandOccNet test process

import torch
import argparse
from tqdm import tqdm
import numpy as np
import torch.backends.cudnn as cudnn
from main.config import cfg

cfg.set_args('0', False)
cudnn.benchmark = True

tester = UH_Tester()
tester._make_batch_generator()

# Evaluate every 10th snapshot pair: snapshots 10, 20, ... up to cfg.end_epoch.
for epoch in range(9, cfg.end_epoch, 10):
    tester._make_UX_model(epoch+1)
    tester._make_HON_model(epoch+1)

    cur_sample_idx = 0
    for itr, (inputs, targets) in enumerate(tqdm(tester.batch_generator)):

        # forward (inference only)
        with torch.no_grad():
            ux_out = tester.ux_model(inputs, targets, 'test', itr)
            hon_out = tester.hon_model(inputs, targets, ux_out, 'test')

        # Split the batched outputs into one dict per sample. HON entries override
        # UX entries that share a key, and the batch size is read from HON's last
        # output, as before.
        merged = {**ux_out, **hon_out}
        batch_size = list(hon_out.values())[-1].shape[0]
        out = [{k: v[bid] for k, v in merged.items()} for bid in range(batch_size)]

        # evaluate
        tester._evaluate(out, cur_sample_idx)
        cur_sample_idx += len(out)

    # NOTE(review): results are labelled with `epoch` while the evaluated
    # snapshots are numbered epoch+1 - confirm the intended label.
    tester._print_eval_result(epoch)