# Create Dataloader

In [None]:
# Mount Google Drive in the Colab runtime; the code and dataset paths used by the
# following cells all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# move to where the codebase is located
%cd /content/drive/MyDrive/Final_Project/code/Fusion_Network_For_Infant_Pose_Estimation

In [2]:

import torch
from torch.utils.data import Dataset,TensorDataset, DataLoader
import torchvision.transforms as transforms

import math
import sys

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import utils.utils as ut
from utils.utils_ds import *

class SMaLDataset(Dataset):
    """Torch dataset over the pre-extracted SMaL arrays of one fold and phase.

    Images are read from ``<img_dir><fold>/<phase>/<cover>.npy`` and the joint
    labels from ``joints.npy`` in the same folder.  Every labelled pose is
    served once per cover condition, so the dataset length is
    ``n_poses * len(covering)`` and consecutive indices walk through the
    covers of the same pose.
    """

    joint_num = 14

    joints_name = (
        "R_Ankle", "R_Knee", "R_Hip", "L_Hip", "L_Knee", "L_Ankle", "R_Wrist", "R_Elbow", "R_Shoulder", "L_Shoulder",
        "L_Elbow", "L_Wrist", "Thorax", "Head"
    )

    # left/right joints that trade places when a sample is mirrored
    flip_pairs_name = (
        ('R_Hip', 'L_Hip'), ('R_Knee', 'L_Knee'), ('R_Ankle', 'L_Ankle'),
        ('R_Shoulder', 'L_Shoulder'), ('R_Elbow', 'L_Elbow'), ('R_Wrist', 'L_Wrist')
    )
    # joint pairs connected by a bone when drawing the skeleton
    skels_name = (
        ('Thorax', 'Head'),
        ('Thorax', 'R_Shoulder'), ('R_Shoulder', 'R_Elbow'), ('R_Elbow', 'R_Wrist'),
        ('Thorax', 'L_Shoulder'), ('L_Shoulder', 'L_Elbow'), ('L_Elbow', 'L_Wrist'),
        ('R_Hip', 'R_Knee'), ('R_Knee', 'R_Ankle'),
        ('L_Hip', 'L_Knee'), ('L_Knee', 'L_Ankle'),
    )

    # index versions of the name tables above
    skels_idx = ut.nameToIdx(skels_name, joints_name=joints_name)
    flip_pairs = ut.nameToIdx(flip_pairs_name, joints_name)

    def __init__(self, img_dir, fold, modality='both', covering=('uncovered', 'cover1', 'cover2'), phase="train",
                 clip=False, mean=None, std=None, aug_param=None, swap_channels=False):
        """Load every requested cover array and the joint labels into memory.

        Args:
            img_dir: dataset root; ``f"{fold}/{phase}"`` is appended verbatim,
                so it must end with a path separator.
            fold: fold identifier used in the folder name.
            modality: ``'depth'`` keeps channel 0 of the patch, ``'psm'`` keeps
                channel 1, any other value keeps all channels.
            covering: names of the ``<cover>.npy`` files to load.
            phase: sub-folder name; augmentation is only sampled for ``"train"``.
            clip: when true, channel 0 of every image array is clipped to
                ``[400, 800]``.
            mean, std: per-channel normalisation statistics.  When either is
                ``None`` both are computed from the loaded arrays.
            aug_param: dict handed to ``get_aug_config_by_values`` in the
                train phase.
            swap_channels: when true (and all channels are kept) channels 0
                and 1 of the returned patch are exchanged.
        """
        # private list copy so the (shared) default is never aliased
        self.covering = list(covering)
        self.num_cover = len(self.covering)
        self.modality = modality
        self.sz_pch = (256, 256)
        self.out_shp = (64, 64)
        self.aug_param = aug_param
        self.phase = phase
        self.img_dir = img_dir + f"{fold}/{phase}"
        self.swap_channels = swap_channels

        self.img_labels = np.load(f'{self.img_dir}/joints.npy')
        # labels are scaled by the patch size
        # NOTE(review): this presumes joints.npy stores coordinates normalised to [0, 1] - confirm
        self.img_labels[:, :, 0] = self.img_labels[:, :, 0] * self.sz_pch[0]
        self.img_labels[:, :, 1] = self.img_labels[:, :, 1] * self.sz_pch[1]
        self.imgs = {}
        for cover in self.covering:
            self.imgs[cover] = np.load(f'{self.img_dir}/{cover}.npy')
            if clip:
                self.imgs[cover][:, :, :, 0] = np.clip(self.imgs[cover][:, :, :, 0], 400, 800)

        if mean is None or std is None:
            # per-channel mean averaged over covers; std is the square root of
            # the per-channel variance averaged over covers
            means = []
            variances = []
            for cover in self.covering:
                means.append(self.imgs[cover].mean(axis=(0, 1, 2)))
                variances.append(self.imgs[cover].var(axis=(0, 1, 2)))
            self.mean = np.vstack(means).mean(axis=0)
            self.std = np.sqrt(np.vstack(variances).mean(axis=0))
        else:
            self.mean = mean
            self.std = std

    def __len__(self):
        """Return the number of (pose, cover) combinations."""
        return self.img_labels.shape[0] * self.num_cover

    def __getitem__(self, idx, debug=False):
        """Build the network input, target heatmaps and bookkeeping for one item.

        Args:
            idx: flat index; ``idx // num_cover`` selects the pose and
                ``idx % num_cover`` the cover condition.
            debug: unused, kept for interface compatibility.

        Returns:
            dict with the resized source image (``original_img``), the
            normalised patch tensor (``pch``), target heatmaps (``hms``), joint
            weights (``joints_vis``), joints in patch coordinates
            (``joints_pch``), Thorax-Head distances in heatmap and label space
            (``l_std_hm`` / ``l_std_ori``), the label joints (``joints_ori``),
            the bounding box (``bb``) and the ``mean`` / ``std`` matching the
            channel layout of ``pch``.
        """
        sample_idx = idx // self.num_cover
        cover = self.covering[idx % self.num_cover]

        label = self.img_labels[sample_idx, :, :]
        img = self.imgs[cover][sample_idx, :, :, :]
        img = cv2.resize(img, dsize=self.sz_pch, interpolation=cv2.INTER_CUBIC)

        img_height, img_width, img_channel = img.shape
        bb = [0, 0, img_width, img_height]  # full image bb , make square bb
        bb = ut.adj_bb(bb, rt_xy=1)

        if self.phase == 'train':
            scale, rot, do_flip, color_scale, do_occlusion = get_aug_config_by_values(self.aug_param)
        else:
            scale, rot, do_flip, color_scale, do_occlusion = 1.0, 0.0, False, [1.0, 1.0, 1.0], False

        img_patch, trans = generate_patch_image(img, bb, do_flip, scale, rot, do_occlusion, input_shape=self.sz_pch)

        if img_patch.ndim < 3:
            img_channels = 1        # add one channel
            img_patch = img_patch[..., None]
        else:
            if "RGB" in self.modality:    # at most 3 fused channels are colour-scaled
                img_channels = 3
            else:
                img_channels = img_patch.shape[2]
        for ch in range(img_channels):
            img_patch[:, :, ch] = img_patch[:, :, ch] * color_scale[ch]
        jt_patch = label.copy()

        if do_flip:
            # mirror x, then exchange the left/right joint rows
            jt_patch[:, 0] = img_width - label[:, 0] - 1
            for pair in SMaLDataset.flip_pairs:
                jt_patch[pair[0], :], jt_patch[pair[1], :] = jt_patch[pair[1], :].copy(), jt_patch[pair[0], :].copy()

        for j in range(len(jt_patch)):  # jt trans
            jt_patch[j, 0:2] = trans_point2d(jt_patch[j, 0:2], trans)

        stride = self.sz_pch[0] / self.out_shp[1]  # jt shrink
        joints_hm = jt_patch / stride

        jt_vis = np.ones(self.joint_num)
        for j in range(len(jt_patch)):        # only check 2d here
            # a joint counts as visible only while it stays inside the patch
            jt_vis[j] *= (
                (jt_patch[j, 0] >= 0) &
                (jt_patch[j, 0] < self.sz_pch[0]) &
                (jt_patch[j, 1] >= 0) &
                (jt_patch[j, 1] < self.sz_pch[1])
            )

        hms, jt_wt = generate_target(joints_hm, jt_vis, sigma=1, sz_hm=self.out_shp)

        idx_t, idx_h = ut.nameToIdx(('Thorax', 'Head'), self.joints_name)
        l_std_hm = np.linalg.norm(joints_hm[idx_h] - joints_hm[idx_t])
        l_std_ori = np.linalg.norm(label[idx_h] - label[idx_t])

        trans_tch = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
        pch_tch = trans_tch(img_patch)

        hms_tch = torch.from_numpy(hms)
        img_mean = self.mean
        img_std = self.std

        if self.modality == 'depth':
            pch_tch = pch_tch[[0], :, :]
            img_mean = self.mean[0]
            img_std = self.std[0]
        elif self.modality == 'psm':
            pch_tch = pch_tch[[1], :, :]
            img_mean = self.mean[1]
            img_std = self.std[1]
        else:
            if self.swap_channels:
                pch_tch[[0], :, :], pch_tch[[1], :, :] = pch_tch[[1], :, :], pch_tch[[0], :, :]
                # Report statistics in the swapped channel order, but on copies:
                # swapping self.mean / self.std in place would flip them again on
                # every call and make Normalize use the wrong channel statistics
                # for every other sample.
                img_mean = np.array(self.mean, copy=True)
                img_std = np.array(self.std, copy=True)
                img_mean[[0, 1]] = img_mean[[1, 0]]
                img_std[[0, 1]] = img_std[[1, 0]]

        result = {
            'original_img': img,
            'pch': pch_tch,
            'hms': hms_tch,
            'joints_vis': jt_wt,
            'joints_pch': jt_patch.astype(np.float32),
            'l_std_hm': l_std_hm.astype(np.float32),
            'l_std_ori': l_std_ori.astype(np.float32),
            'joints_ori': label[:, :2].astype(np.float32),
            'bb': bb.astype(np.float32),
            'mean': img_mean,
            'std': img_std
        }

        return result

In [3]:
def get_aug_config_by_values(aug_param):
    """Sample one set of augmentation parameters.

    Args:
        aug_param: dict with the keys ``'scale_factor'``, ``'rot_factor'``,
            ``'color_factor'`` and ``'do_occlusion'``.

    Returns:
        Tuple ``(scale, rot, do_flip, color_scale, do_occlusion)``:
        ``scale`` lies in ``1 +/- scale_factor``; ``rot`` is ``0`` in roughly
        40% of the draws and otherwise lies in ``+/- 2 * rot_factor``;
        ``do_flip`` is true about half of the time; ``color_scale`` holds three
        gains drawn uniformly from ``1 +/- color_factor``; ``do_occlusion`` is
        passed through unchanged.
    """
    # Imported here because the notebook never imports `random` itself; it
    # was only reachable if a star-import happened to leak it.
    import random

    scale_factor = aug_param['scale_factor']
    rot_factor = aug_param['rot_factor']
    color_factor = aug_param['color_factor']
    do_occlusion = aug_param["do_occlusion"]

    scale = np.clip(np.random.randn(), -1.0, 1.0) * scale_factor + 1.0
    if random.random() <= 0.6:
        rot = np.clip(np.random.randn(), -2.0, 2.0) * rot_factor   # within +/- 2 * rot_factor
    else:
        rot = 0
    do_flip = random.random() <= 0.5
    c_up = 1.0 + color_factor
    c_low = 1.0 - color_factor
    color_scale = [random.uniform(c_low, c_up) for _ in range(3)]

    return scale, rot, do_flip, color_scale, do_occlusion

# Train Fusion Network

In [None]:
!pip install configargparse dominate yacs

In [5]:
def plotImage(ax, image, c=None):
    """Show ``image`` on the matplotlib axis ``ax``.

    The notebook used to define this function twice (with and without ``c``);
    the second definition silently replaced the first, so two-argument calls
    raised ``TypeError``.  Both call styles are served by this one function.

    Args:
        ax: object with an ``imshow`` method.
        image: tensor in channel-first layout, or anything ``np.array`` accepts.
        c: optional channel index.  For tensors, only that channel is shown;
            it is ignored for non-tensor input (as before).
    """
    if torch.is_tensor(image):
        image = image.permute(1, 2, 0)      # channel-first -> channel-last
        if c is not None:
            image = image[:, :, c]
        image = image.cpu().numpy()
    image = np.array(image)
    ax.imshow(image)

def plot2DJoints(ax, joints2D, connectedJoints, jointColours, visJoints=None):
    """Draw a 2D skeleton: one line per joint pair, one marker per joint.

    Args:
        ax: object exposing ``plot`` and ``scatter``.
        joints2D: array indexed as ``[joint, coordinate]``.
        connectedJoints: array whose rows hold the two joint indices of a bone.
        jointColours: line colour for each row of ``connectedJoints``.
        visJoints: optional per-joint flags; a bone is drawn only when both of
            its ends are flagged ``1``, and joints not flagged ``1`` are marked
            orange instead of black.
    """
    def _is_shown(joint):
        return visJoints is None or visJoints[joint] == 1

    # bones
    for bone in range(len(connectedJoints)):
        start = connectedJoints[bone, 0]
        stop = connectedJoints[bone, 1]
        if not (_is_shown(start) and _is_shown(stop)):
            continue
        xs = np.array([joints2D[start, 0], joints2D[stop, 0]])
        ys = np.array([joints2D[start, 1], joints2D[stop, 1]])
        ax.plot(xs, ys, lw=2, c=jointColours[bone])

    # joint markers
    for k in range(len(joints2D)):
        colour = "orange"
        if visJoints is None or visJoints[k] == 1:
            colour = "black"
        ax.scatter(joints2D[k, 0], joints2D[k, 1], c=colour)

def plot2DJointsPredAndTruth(ax, joints2DPrd, joints2DTruth, connectedJoints, visJoints=None):
    """Overlay the ground-truth and the predicted skeleton on one axis.

    Ground truth is drawn first (yellow ``o`` markers, ``gtjointColours``
    lines), the prediction second (blue ``d`` markers, ``predjointColours``
    lines); everything uses ``alpha=0.5``.

    Args:
        ax: object exposing ``plot`` and ``scatter``.
        joints2DPrd: predicted joints, indexed as ``[joint, coordinate]``.
        joints2DTruth: ground-truth joints, same indexing.
        connectedJoints: array whose rows hold the two joint indices of a bone.
        visJoints: optional per-joint flags; joints not flagged ``1`` are
            marked orange and bones touching them are skipped.
    """
    joints2Dlist = [joints2DTruth, joints2DPrd]
    jointColourslist = [SMaL_configs["gtjointColours"], SMaL_configs["predjointColours"]]
    scatterColourList = ["yellow", "blue"]
    markerList = ["o", "d"]

    # Plot joint coordinates
    for joints2D, baseColour, marker in zip(joints2Dlist, scatterColourList, markerList):
        for i in range(len(joints2D)):
            visible = visJoints is None or visJoints[i] == 1
            # The colour is chosen per joint.  Overwriting the base colour here
            # (as the loop used to do) turned every joint after the first
            # invisible one orange as well.
            colour = baseColour if visible else "orange"
            ax.scatter(joints2D[i, 0], joints2D[i, 1], c=colour, marker=marker, alpha=0.5)

    # Plot skeleton
    for joints2D, jointColours in zip(joints2Dlist, jointColourslist):
        for i in np.arange(len(connectedJoints)):
            joint1 = connectedJoints[i, 0]
            joint2 = connectedJoints[i, 1]
            if visJoints is None or (visJoints[joint1] == 1 and visJoints[joint2] == 1):
                x, y = [
                    np.array([joints2D[joint1, j], joints2D[joint2, j]])
                    for j in range(2)
                ]
                ax.plot(x, y, lw=2, c=jointColours[i], alpha=0.5)


def save_pred(
    input_img, pred, truth, i, sv_dir, vis_joints
):
    """Render one sample with predicted and true skeletons and save it as JPEG.

    Args:
        input_img: image handed to ``plotImage`` (channel 0 is shown).
        pred: predicted joints, forwarded to ``plot2DJointsPredAndTruth``.
        truth: ground-truth joints, forwarded likewise.
        i: identifier used as the file stem (``<sv_dir>/<i>.jpg``).
        sv_dir: output directory; created when missing.
        vis_joints: per-joint visibility flags forwarded to the plot helper.
    """
    if sv_dir:
        # savefig does not create missing folders
        os.makedirs(sv_dir, exist_ok=True)
    fname = os.path.join(sv_dir, str(i) + '.jpg')
    connectedJoints = SMaL_configs["connectedJoints"]

    ax = plt.subplot(1, 1, 1)
    ax.set_title("Output 2D")
    plotImage(ax, input_img, 0)
    plot2DJointsPredAndTruth(
        ax,
        pred,
        truth,
        connectedJoints,
        vis_joints
    )
    plt.savefig(fname)
    plt.close()

In [6]:
'''
2d pose estimation handling
'''

import utils.vis as vis
import utils.utils as ut
import numpy as np
import cv2
import torch
import json
from os import path
import os
from utils.logger import Colorlogger
from utils.utils_tch import get_model_summary
from core.loss import JointsMSELoss
from torch.utils.data import DataLoader
from torch.optim import Adam
import time
from utils.utils_ds import accuracy, flip_back
from utils.visualizer import Visualizer


def train(loader, model, criterion, optimizer, epoch, n_iter=-1, logger=None, opts=None, visualizer=None):
  '''
  Run one training epoch.

  :param loader: iterable of batch dicts with the keys 'pch', 'hms' and 'joints_vis'
      ('mean' and 'std' are read as well when a visualizer is used)
  :param model: network called as model(batch)["output"]; the result is either one
      heatmap tensor or a list of them (multi-stage), in which case the per-stage
      losses are summed and the last stage is used for the accuracy
  :param criterion: called as criterion(output, target, target_weight)
  :param optimizer: optimizer stepped once per batch
  :param epoch: epoch number, only used in log / visualizer output
  :param n_iter: stop after this many batches; values <= 0 run the whole loader
  :param logger: object with an info() method; logging is skipped when None
  :param opts: options object; print_freq is always read, update_html_freq, mod_src,
      out_shp and sz_pch only when a visualizer is given
  :param visualizer: optional object with display_current_results()
  :return: {'losses': [...], 'accs': [...]} sampled every opts.print_freq batches
  '''

  batch_time = ut.AverageMeter()
  data_time = ut.AverageMeter()
  losses = ut.AverageMeter()
  acc = ut.AverageMeter()

  # switch to train mode
  model.train()
  end = time.time()
  li_loss = []
  li_acc = []
  for i, inp_dct in enumerate(loader):
    if i >= n_iter and n_iter > 0:    # break if iter is set and i is greater than that
      break
    inputs = inp_dct['pch']
    target = inp_dct['hms']
    target_weight = inp_dct['joints_vis']     # per-joint weight, visible or not

    # measure data loading time
    data_time.update(time.time() - end)

    # compute output; the batch is handed to the model without an explicit .cuda()
    # NOTE(review): presumably the DataParallel wrapper built in main() moves it - confirm
    outputs = model(inputs)
    outputs = outputs["output"]
    target = target.cuda(non_blocking=True)
    target_weight = target_weight.cuda(non_blocking=True)

    if isinstance(outputs, list):       # list multiple stage version
      loss = criterion(outputs[0], target, target_weight)
      for stage_output in outputs[1:]:
        loss += criterion(stage_output, target, target_weight)
      # bind explicitly: with a one-element list the loop above never runs and
      # `output` used to be undefined below
      output = outputs[-1]
    else:
      output = outputs
      loss = criterion(output, target, target_weight)

    # compute gradient and do update step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # measure accuracy and record loss
    losses.update(loss.item(), inputs.size(0))
    _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
                                     target.detach().cpu().numpy())
    acc.update(avg_acc, cnt)  # keep average acc

    if visualizer and 0 == i % opts.update_html_freq:     # update current result, get vis dict
      n_jt = SMaL_configs["numJoints"]
      mod0 = opts.mod_src[0]
      mean = inp_dct['mean']
      std = inp_dct['std']
      img_patch_vis = ut.ts2cv2(inputs[0], mean, std)
      # pseudo colour for the first modality
      cm = getattr(cv2, SMaL_configs["dct_clrMap"][mod0])
      img_patch_vis = cv2.applyColorMap(img_patch_vis, cm)[..., ::-1]  # RGB

      # prediction of the first sample, rescaled from heatmap to patch size
      pred2d_patch = np.ones((n_jt, 3))  # 3rd for  vis
      pred2d_patch[:, :2] = pred[0] / opts.out_shp[0] * opts.sz_pch[1]
      img_skel = vis.vis_keypoints(img_patch_vis, pred2d_patch, SMaL_configs["connectedJoints"])

      hm_gt = target[0].cpu().detach().numpy().sum(axis=0)    # HXW
      hm_gt = ut.normImg(hm_gt)

      hm_pred = output[0].detach().cpu().numpy().sum(axis=0)
      hm_pred = ut.normImg(hm_pred)
      img_cb = vis.hconcat_resize([img_skel, hm_gt, hm_pred])
      vis_dict = {'img_cb': img_cb}
      visualizer.display_current_results(vis_dict, epoch, False)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if i % opts.print_freq == 0:
      msg = 'Epoch: [{0}][{1}/{2}]\t' \
            'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
            'Speed {speed:.1f} samples/s\t' \
            'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
            'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
            'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
        epoch, i, len(loader), batch_time=batch_time,
        speed=inputs.size(0) / batch_time.val,
        data_time=data_time, loss=losses, acc=acc)
      if logger is not None:
        logger.info(msg)
      li_loss.append(losses.val)   # the current loss
      li_acc.append(acc.val)

  return {'losses': li_loss, 'accs': li_acc}

In [7]:
def validate(loader, model, criterion, fold, n_iter=-1, logger=None, opts=None, if_svVis=False, visualizer=None):
  '''
  Evaluate the model on every batch of the loader and compute PCK curves.

  :param loader: iterable of batch dicts with the keys 'pch', 'hms', 'joints_vis',
      'joints_pch', 'bb', 'joints_ori' and 'l_std_ori'
  :param model: network called as model(batch)["output"]; for a list result the
      last entry is evaluated
  :param criterion: called as criterion(output, target, target_weight)
  :param fold: fold identifier, only used to name the saved visualisations
  :param n_iter: stop after this many batches; values <= 0 run the whole loader
  :param logger: object with an info() method; logging is skipped when None
  :param opts: options object; print_freq and out_shp are always read, vis_test_dir
      and sz_pch only when if_svVis is set
  :param if_svVis: when true, save_pred() is called for every sample
  :param visualizer: unused, kept for interface compatibility
  :return: dict of plain lists: 'preds_ori', 'joints_ori', 'l_std_ori_all',
      'err_nmd', 'pck1' (11 thresholds in [0, 1]) and 'pck05' (11 thresholds in [0, 0.5])
  :raises ValueError: from np.concatenate when no batch was processed
  '''
  batch_time = ut.AverageMeter()
  losses = ut.AverageMeter()
  acc = ut.AverageMeter()

  # switch to evaluate mode
  model.eval()

  n_jt = SMaL_configs["numJoints"]

  # to accum rst
  preds_hm = []
  bbs = []
  li_joints_ori = []
  li_joints_vis = []
  li_l_std_ori = []
  with torch.no_grad():
    end = time.time()
    for i, inp_dct in enumerate(loader):
      if i >= n_iter and n_iter > 0:     # limiting iters
        break
      inputs = inp_dct['pch']
      target = inp_dct['hms']
      target_weight = inp_dct['joints_vis']
      gt = inp_dct["joints_pch"]
      bb = inp_dct['bb']
      joints_ori = inp_dct['joints_ori']
      l_std_ori = inp_dct['l_std_ori']

      # compute output
      outputs = model(inputs)
      outputs = outputs["output"]
      if isinstance(outputs, list):
        output = outputs[-1]
      else:
        output = outputs

      target = target.cuda(non_blocking=True)
      target_weight = target_weight.cuda(non_blocking=True)
      loss = criterion(output, target, target_weight)

      num_images = inputs.size(0)
      # measure accuracy and record loss
      losses.update(loss.item(), num_images)
      _, avg_acc, cnt, pred_hm = accuracy(output.cpu().numpy(),
                                          target.cpu().numpy())
      acc.update(avg_acc, cnt)

      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()

      # keep rst
      preds_hm.append(pred_hm)        # already numpy, 2D
      bbs.append(bb.numpy())
      li_joints_ori.append(joints_ori.numpy())
      li_joints_vis.append(target_weight.cpu().numpy())
      li_l_std_ori.append(l_std_ori.numpy())

      if if_svVis:
        sv_dir = opts.vis_test_dir
        for j in range(num_images):
          idx_test = f'{fold}_{i}_{j}'  # image index
          pred2d_patch = np.ones((n_jt, 3))  # 3rd for  vis
          # rescale the prediction from heatmap to patch size
          pred2d_patch[:, :2] = pred_hm[j] / opts.out_shp[0] * opts.sz_pch[1]
          save_pred(inputs[j], pred2d_patch, gt[j], idx_test, sv_dir, target_weight[j])

      if i % opts.print_freq == 0:
        msg = 'Test: [{0}/{1}]\t' \
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
              'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
          i, len(loader), batch_time=batch_time,
          loss=losses, acc=acc)
        if logger is not None:
          logger.info(msg)

  preds_hm = np.concatenate(preds_hm, axis=0)      # N x n_jt  x 2
  bbs = np.concatenate(bbs, axis=0)
  joints_ori = np.concatenate(li_joints_ori, axis=0)
  joints_vis = np.concatenate(li_joints_vis, axis=0)
  l_std_ori_all = np.concatenate(li_l_std_ori, axis=0)

  # map heatmap predictions back through the bounding boxes, then normalise the
  # error by the per-sample Thorax-Head length
  preds_ori = ut.warp_coord_to_original(preds_hm, bbs, sz_out=opts.out_shp)
  err_nmd = ut.distNorm(preds_ori, joints_ori, l_std_ori_all)
  ticks = np.linspace(0, 1, 11)   # 11 ticks
  pck_all_1 = ut.pck(err_nmd, joints_vis, ticks=ticks)
  ticks = np.linspace(0, 0.5, 11)   # 11 ticks
  pck_all_05 = ut.pck(err_nmd, joints_vis, ticks=ticks)

  # save to plain format for easy processing
  rst = {
    'preds_ori': preds_ori.tolist(),
    'joints_ori': joints_ori.tolist(),
    'l_std_ori_all': l_std_ori_all.tolist(),
    'err_nmd': err_nmd.tolist(),
    'pck1': pck_all_1.tolist(),
    'pck05': pck_all_05.tolist()
  }

  return rst

In [8]:
import logging
import os

# ANSI escape sequences used to colour log output
OK = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
END = '\033[0m'

PINK = '\033[95m'
BLUE = '\033[94m'
GREEN = OK
RED = FAIL
WHITE = END
YELLOW = WARNING


class ColorloggerLocal():
    """Thin wrapper around a named ``logging`` logger that writes to a file and the console.

    The underlying logger is looked up by ``log_name``, so creating a second
    wrapper with the same name reuses the handlers (and therefore the log
    file) installed by the first one.
    """

    def __init__(self, log_dir, log_name='train_logs.txt'):
        """Attach a file handler (append mode) and a stream handler on first use.

        Args:
            log_dir: folder of the log file; created when missing.
            log_name: file name, also used as the logger name.
        """
        self._logger = logging.getLogger(log_name)
        self._logger.setLevel(logging.INFO)
        # Only this logger's own handlers matter.  hasHandlers() also reports
        # handlers of ancestor loggers, so with a configured root logger
        # nothing was ever attached here.  Returning early also avoids opening
        # (and leaking) another FileHandler on every re-instantiation.
        if self._logger.handlers:
            return

        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, log_name)
        formatter = logging.Formatter(
            "{}%(asctime)s{} %(message)s".format(GREEN, END),
            "%m-%d %H:%M:%S")
        file_log = logging.FileHandler(log_file, mode='a')
        console_log = logging.StreamHandler()
        for handler in (file_log, console_log):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            self._logger.addHandler(handler)
        # own handlers are in place; do not emit every record a second time
        # through ancestor handlers
        self._logger.propagate = False

    def debug(self, msg):
        self._logger.debug(str(msg))

    def info(self, msg):
        self._logger.info(str(msg))

    def warning(self, msg):
        self._logger.warning(WARNING + 'WRN: ' + str(msg) + END)

    def critical(self, msg):
        self._logger.critical(RED + 'CRI: ' + str(msg) + END)

    def error(self, msg):
        self._logger.error(RED + 'ERR: ' + str(msg) + END)

In [9]:
# Static description of the 14-joint SMaL skeleton plus plotting settings.
SMaL_configs = {
    "numJoints": 14,
    # index pairs (rows) of joints joined by a line when plotting
    "connectedJoints": np.array([
        [0, 1], [1, 2],
        [3, 4], [4, 5],
        [6, 7], [7, 8],
        [9, 10], [10, 11],
        [8, 12], [9, 12],
        [12, 13],
    ]),
    "headIndex": 13,
    "jointNames": (
        "R_Ankle", "R_Knee", "R_Hip",
        "L_Hip", "L_Knee", "L_Ankle",
        "R_Wrist", "R_Elbow", "R_Shoulder",
        "L_Shoulder", "L_Elbow", "L_Wrist",
        "Thorax", "Head",
    ),
    "jointColours": ["blue"] * 3 + ["red"] * 3 + ["green"] * 3 + ["blue"] * 3 + ["red"] * 2,
    "gtjointColours": ["yellow"] * 14,
    "predjointColours": ["blue"] * 14,
    # name of the cv2 colour map attribute per modality
    "dct_clrMap": {
        "depth": 'COLORMAP_BONE',
        'RGB': 'COLORMAP_BONE',
    },
    "flip_pairs": (
        ('R_Hip', 'L_Hip'), ('R_Knee', 'L_Knee'), ('R_Ankle', 'L_Ankle'),
        ('R_Shoulder', 'L_Shoulder'), ('R_Elbow', 'L_Elbow'), ('R_Wrist', 'L_Wrist'),
    ),
    # no Pelvis joint in this skeleton, hence no Pelvis links
    "skels_name": (
        ('Thorax', 'Head'),
        ('Thorax', 'R_Shoulder'), ('R_Shoulder', 'R_Elbow'), ('R_Elbow', 'R_Wrist'),
        ('Thorax', 'L_Shoulder'), ('L_Shoulder', 'L_Elbow'), ('L_Elbow', 'L_Wrist'),
        ('R_Hip', 'R_Knee'), ('R_Knee', 'R_Ankle'),
        ('L_Hip', 'L_Knee'), ('L_Knee', 'L_Ankle'),
    ),
}

Set up options

In [None]:
import opt

# Start from the project's parsed defaults, then override them for this run.
opts = opt.parseArgs()

# Dataset root; main() passes it to SMaLDataset, which appends "<fold>/<phase>".
opts.ds_fd = '/content/drive/MyDrive/Final_Project/dataset/SMaL-224/' # give your dataset folder here
opts.sz_pch=[256, 256]
opts.fc_depth = 50
# NOTE(review): the visible main() does not read cov_li (it hard-codes the cover list),
# and 'uncover' differs from the 'uncovered' name used there - confirm which is intended.
opts.cov_li = ['cover2','cover1', 'uncover']        # give the cover class you want here
opts.prep = 'jt_hm'
opts.n_thread = 5
opts.if_pinMem = False
opts.test_par = 'test'
# main() maps exactly ['depth', "psm"] to modality "both"; otherwise the first entry is used.
opts.mod_src = ['depth', "psm"]
opts.out_shp = (64, 64, -1)
opts.if_bb = True
opts.gpu_ids = [0]
opts.suffix_exp_train = 'suffix_for_current_training_execution'
# False: main() trains and then runs the final test; True: only the final test runs.
opts.if_test = False
# All outputs of this run are placed below exp_dir.
opts.exp_dir = '/content/drive/MyDrive/Final_Project/code/Fusion_Network_For_Infant_Pose_Estimation/output/' + opts.suffix_exp_train  # model output folder
opts.log_dir = opts.exp_dir + "/log"
opts.model_dir = opts.exp_dir + "/model_dump"
opts.vis_dir = opts.exp_dir + "/vis/test"
opts.rst_dir = opts.exp_dir + "/result"
opts.vis_test_dir = opts.exp_dir + "/vis/withgt"
opts.web_dir = opts.exp_dir + "/web"
# Number of input channels handed to get_pose_net().
opts.input_nc = 2
# Folder that holds 'model_best.pth'; only read by main() when fine_tune is True.
opts.bestpath_file = 'path_to_best_model_when_fin_tuning'  # fusion model trained with SLP dataset (weights to initialise the model for fine-tuning)
opts.print_freq = 5
opts.end_epoch = 20
# Stem of the result JSON written to rst_dir.
opts.nmTest = "test"
# Module below the 'model' package that provides get_pose_net.
opts.model = "HRposeFuseNetNewUnweighted_v2"
opts.modelConf = "config/HRposeFuseNetNewUnweighted_v2"
opts.fuse_stage = 2         #2 or 3
opts.fuse_type = "iAFF"         #add,concat,iAFF
opts.swap_channels = False
opts.fine_tune = False           #True if running fine-tuning
# Keys read by get_aug_config_by_values() for training-time augmentation.
opts.aug_param = {'rot_factor': 2, 'scale_factor': 0, 'do_occlusion': False,'color_factor':0}

In [None]:
# Resolve the model factory from the configured module name without exec():
# same lookup as `from model.<name> import get_pose_net`, but no code string is built.
import importlib
get_pose_net = importlib.import_module('model.{}'.format(opts.model)).get_pose_net


def optimizer_to(optim, device):
    """Move every tensor held in ``optim.state`` (and its gradient) to ``device``.

    Values of ``optim.state`` that are tensors are moved directly; values that
    are dicts have each tensor entry moved.  Anything else is left untouched.
    """
    def _relocate(tensor):
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for state in optim.state.values():
        if isinstance(state, torch.Tensor):
            _relocate(state)
        elif isinstance(state, dict):
            for entry in state.values():
                if isinstance(entry, torch.Tensor):
                    _relocate(entry)

def main():
  '''
  Train over folds 0-4 (unless opts.if_test is set) and then test every fold.

  Training starts from scratch (optionally from the fine-tuning weights in
  opts.bestpath_file) or resumes from <model_dir>/checkpoint.pth when
  opts.start_epoch is non-zero and that file exists.  After every epoch the
  checkpoint is rewritten, and copied to model_best.pth when the last 'pck1'
  entry is at least as good as the best seen so far.  The final test writes
  one result JSON per fold to opts.rst_dir.
  '''

  # get logger
  if_test = opts.if_test
  log_suffix = 'test' if if_test else 'train'
  logger = ColorloggerLocal(opts.log_dir, '{}_logs.txt'.format(log_suffix))    # avoid overwriting, will append
  # `logger.propagate = False` only set an attribute on the wrapper object.  Disable
  # propagation on the wrapped logging.Logger instead, and only when it has handlers
  # of its own so that records are never silenced.
  underlying_logger = getattr(logger, '_logger', None)
  if underlying_logger is not None and underlying_logger.handlers:
    underlying_logger.propagate = False
  opt.set_env(opts)
  opt.print_options(opts, if_sv=True)
  n_jt = SMaL_configs["numJoints"]

  # get model
  model = get_pose_net(in_ch=opts.input_nc, out_ch=n_jt, fuse_stage=opts.fuse_stage, fuse_type=opts.fuse_type, mod_src=opts.mod_src)

  # define loss function (criterion) and optimizer; drop .cuda() for a CPU-only run
  criterion = JointsMSELoss(
    use_target_weight=True
  ).cuda()

  # for visualizer
  if opts.display_id > 0:
    visualizer = Visualizer(opts)  # only plot losses here, a loss log comes with it,
  else:
    visualizer = None

  # get optimizer
  best_perf = 0.0
  last_epoch = -1
  optimizer = Adam(model.parameters(), lr=opts.lr)
  checkpoint_file = os.path.join(
    opts.model_dir, 'checkpoint.pth')
  last_fold = 0
  if 0 == opts.start_epoch or not path.exists(checkpoint_file):  # from scratch
    begin_epoch = 0     # either set or not exist all the same from 0
    if opts.fine_tune:
      checkpoint_file = os.path.join(opts.bestpath_file, 'model_best.pth')
      # pass map_location=torch.device('cpu') here for a CPU-only run
      checkpoint = torch.load(checkpoint_file)
      model.load_state_dict(checkpoint['state_dict'])
    losses = []     # for tracking model performance.
    accs = []
  else:  # get chk points
    logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
    checkpoint = torch.load(checkpoint_file)
    begin_epoch = checkpoint['epoch']
    best_perf = checkpoint['perf']
    last_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    losses = checkpoint['losses']
    accs = checkpoint['accs']
    last_fold = checkpoint['last_fold']

    optimizer.load_state_dict(checkpoint['optimizer'])
    optimizer_to(optimizer, 'cuda')
    logger.info("=> loaded checkpoint '{}' (epoch {})".format(
      checkpoint_file, checkpoint['epoch']))

  milestones = opts.lr_dec_epoch
  lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones, opts.lr_dec_factor,
    last_epoch=last_epoch
  )  # scheduler will be set to place given last from checkpoints
  if opts.epoch_step > 0:
    end_epoch = min(opts.end_epoch, opts.start_epoch + opts.epoch_step)
  else:
    end_epoch = opts.end_epoch

  dump_input = torch.rand(
    (16, opts.input_nc, opts.sz_pch[1], opts.sz_pch[0])
  )

  # drop the .cuda() calls below for a CPU-only run
  model = torch.nn.DataParallel(model, device_ids=opts.gpu_ids).cuda()

  logger.info(get_model_summary(model, dump_input.cuda()))

  n_iter = opts.trainIter  # only for test purpose     quick test
  if (opts.mod_src == ['depth', "psm"]):
    mod = "both"
  else:
    mod = opts.mod_src[0]

  if not if_test:
    for fold in range(last_fold, 5):
      logger.info('Started fold {0}'.format(fold))
      trainset = SMaLDataset(opts.ds_fd, fold=fold, covering=['uncovered', 'cover1', 'cover2'], phase="train",
                             clip=True, aug_param=opts.aug_param, modality=mod, swap_channels=opts.swap_channels)
      train_loader = DataLoader(trainset, batch_size=16, shuffle=True)
      valset = SMaLDataset(opts.ds_fd, fold=fold, clip=True, modality=mod, phase="val", swap_channels=opts.swap_channels)
      val_loader = DataLoader(valset, batch_size=300, shuffle=False)
      for epoch in range(begin_epoch, end_epoch):
        if opts.display_id > 0:
          visualizer.reset()      # clean up the vis
        # train for one epoch
        rst_trn = train(train_loader, model, criterion, optimizer, epoch, n_iter=n_iter, logger=logger, opts=opts, visualizer=visualizer)
        losses += rst_trn['losses']
        accs += rst_trn['accs']

        # evaluate on validation set
        rst_test = validate(
          val_loader, model, criterion, fold=fold,
          n_iter=n_iter, logger=logger, opts=opts)

        # the last column of each PCK table is the value at the largest threshold
        pck_all_1 = rst_test['pck1']
        perf_indicator_1 = pck_all_1[-1][-1]  # the last entry
        pckh1 = np.array(pck_all_1)[:, -1]
        pck_all_05 = rst_test['pck05']
        perf_indicator_05 = pck_all_05[-1][-1]  # the last entry
        pckh05 = np.array(pck_all_05)[:, -1]
        titles_c = list(SMaL_configs['jointNames']) + ['total']
        ut.prt_rst([pckh1], titles_c, ['pckh1'], fn_prt=logger.info)
        ut.prt_rst([pckh05], titles_c, ['pckh0.5'], fn_prt=logger.info)

        lr_scheduler.step()     # new version updating here
        if perf_indicator_1 >= best_perf:
          best_perf = perf_indicator_1
          best_model = True
        else:
          best_model = False

        logger.info('=> saving checkpoint to {}'.format(opts.model_dir))
        ckp = {
          'epoch': epoch + 1,     # epoch to next, after finish 0 this is 1
          'model': opts.model,
          'state_dict': model.module.state_dict(),
          'best_state_dict': model.module.state_dict(),
          'perf': perf_indicator_1,
          'optimizer': optimizer.state_dict(),
          'losses': losses,       # for later updating
          'accs': accs,
          'last_fold': fold
        }
        torch.save(ckp, os.path.join(opts.model_dir, 'checkpoint.pth'))
        if best_model:
          torch.save(ckp, os.path.join(opts.model_dir, 'model_best.pth'))
      # only the resumed fold starts mid-way; later folds start at epoch 0
      begin_epoch = 0
    final_model_state_file = os.path.join(
      opts.model_dir, 'final_state.pth'     # only after last iters
    )
    logger.info('=> saving final model state to {}'.format(
      final_model_state_file)
    )
    torch.save(model.module.state_dict(), final_model_state_file)

  # single test with loaded model, save the result
  logger.info('----run final test----')
  for fold in range(5):
    # use the configured dataset root instead of a second hard-coded copy of the path
    testset = SMaLDataset(opts.ds_fd, fold=fold, clip=True, modality=mod, phase="test", swap_channels=opts.swap_channels)
    test_loader = DataLoader(testset, batch_size=300, shuffle=False)
    rst_test = validate(
      test_loader, model, criterion,
      n_iter=n_iter, logger=logger, opts=opts, if_svVis=True, fold=fold)

    pck_all_1 = rst_test['pck1']
    pckh1 = np.array(pck_all_1)[:, -1]
    pck_all_05 = rst_test['pck05']
    pckh05 = np.array(pck_all_05)[:, -1]
    titles_c = list(SMaL_configs['jointNames']) + ['total']
    ut.prt_rst([pckh1], titles_c, ['pckh1'], fn_prt=logger.info)
    ut.prt_rst([pckh05], titles_c, ['pckh0.5'], fn_prt=logger.info)

    # One file per fold: writing every fold to '<nmTest>.json' kept only the last
    # fold's results.  The legacy file name is still written (it ends up holding
    # the last fold, as before) so existing consumers keep working.
    pth_rst_fold = path.join(opts.rst_dir, '{}_fold{}.json'.format(opts.nmTest, fold))
    pth_rst = path.join(opts.rst_dir, opts.nmTest + '.json')
    for pth_out in (pth_rst_fold, pth_rst):
      with open(pth_out, 'w') as f:
        json.dump(rst_test, f)

if __name__ == '__main__':
  main()

# Predict with fusion model

Load image from the dataset

In [None]:
import matplotlib.pyplot as plt

# move to dataset location
%cd /content/drive/MyDrive/Final_Project/dataset            

# Pick which covering variant and which sample index to inspect.
cover_type = "uncovered"
image_id = 23  # <300

# Every SMaL-All file sits in the same directory, so spell the directory once.
smal_all_dir = '/content/drive/MyDrive/Final_Project/dataset/SMaL-All'
depth_images = np.load(smal_all_dir + '/depth.npz', mmap_mode='r')
psm_images = np.load(smal_all_dir + '/psm.npz', mmap_mode='r')
color_images = np.load(smal_all_dir + '/color.npz', mmap_mode='r')
labels_img = np.load(smal_all_dir + '/jnts.npy', mmap_mode='r')

# Pull the selected sample out of each modality archive.
psm_image, depth_image, color_image = (
  archive[cover_type][image_id]
  for archive in (psm_images, depth_images, color_images)
)

# Show the three modalities in a 2x2 grid (bottom-right cell stays empty).
fig = plt.figure(figsize=(10, 10))
ax1, ax2, ax3 = (fig.add_subplot(2, 2, slot) for slot in (1, 2, 3))
ax1.imshow(psm_image)
ax2.imshow(depth_image)
ax3.imshow(color_image)

In [None]:
### move to where the codebase is located
# NOTE(review): the first cell of this notebook changes into
# .../code/Fusion_Network_For_Infant_Pose_Estimation, while this cell uses
# .../code/Fusion_Network — confirm which checkout the imports below should resolve against.
%cd /content/drive/MyDrive/Final_Project/code/Fusion_Network

Calculate mean and std for image normalization

In [87]:
# Average the per-sample normalisation statistics ("mean" / "std" entries returned
# by the dataset) over the test split of all five folds.
means = []
stds = []
for fold in range(5):
  testset = SMaLDataset('/content/drive/MyDrive/Final_Project/dataset/SMaL-224/',fold=fold,clip=True,modality="both",phase="test",swap_channels=False)
  # Sample order does not matter for an average; a fixed order keeps the
  # floating-point result reproducible between runs.
  test_loader = DataLoader(testset, batch_size=1, shuffle=False)
  for inp_dct in test_loader:
    # Only the statistics are needed here; the image patch itself is not read,
    # and no name is bound to `input`, so the builtin stays usable in later cells.
    means.append(inp_dct["mean"].detach().cpu().numpy())
    stds.append(inp_dct["std"].detach().cpu().numpy())
# Stack the (1, C) rows of every sample and average column-wise -> one value per channel.
avg_mean = np.vstack(means).mean(axis=0)
avg_std = np.vstack(stds).mean(axis=0)
print(avg_mean, avg_std)


[5.8949493e+02 1.1351888e-02] [59.654747    0.07786007]


Image preprocessing

In [88]:
# Build the 2-channel network input (channel 0 = depth, channel 1 = PSM) from the
# sample loaded earlier: clip depth, resize, crop to a square patch, normalise.
img = np.stack([depth_image, psm_image], axis=2)
img[:, :, 0] = np.clip(img[:, :, 0], 400, 800)  # limit the depth channel to [400, 800]

# Normalisation statistics computed in the previous cell.
mean = avg_mean
std = avg_std

img_resized = cv2.resize(img, dsize=(256,256), interpolation=cv2.INTER_CUBIC)
img_height, img_width = img_resized.shape[:2]
bb = [0,0,img_width,img_height]  # full image bb , make square bb
bb = ut.adj_bb(bb, rt_xy=1)

# Neutral augmentation parameters: no scaling, rotation, flip, colour jitter or occlusion.
scale, rot, do_flip, color_scale, do_occlusion = 1.0, 0.0, False, [1.0, 1.0, 1.0], False
img_patch, trans = generate_patch_image(img_resized, bb, do_flip, scale, rot, do_occlusion, input_shape=(256,256))
for i in range(img_patch.shape[2]):
  img_patch[:, :, i] = img_patch[:, :, i] * color_scale[i]

trans_tch = transforms.Compose([transforms.ToTensor(),
          transforms.Normalize(mean=mean, std=std)]
        )
# Add the leading batch dimension expected by the model.
pch_tch = trans_tch(img_patch)[None,:,:,:]
pch_tch.shape

torch.Size([1, 2, 256, 256])

Initialise the network

In [None]:

from model.HRposeFuseNetNewUnweighted_v2 import get_pose_net

# Modalities fused by the network and the directory holding its trained weights.
mod_src = ["depth", "psm"]
fuse_model_path = '/content/drive/MyDrive/Final_Project/code/Fusion_Network/output/SMaL_depth_PM_2_iaff_dropout/model_dump'

# Use the GPU when the runtime has one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = get_pose_net(in_ch=2, out_ch=14, fuse_stage=2, fuse_type="iAFF", mod_src=mod_src)

checkpoint_file = os.path.join(fuse_model_path,'model_best.pth')

# map_location lets a checkpoint that was saved on a GPU load on a CPU-only runtime too.
checkpoint = torch.load(checkpoint_file, map_location=device)

model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# Switch to inference mode (last expression, so the notebook also prints the model).
model.eval()


Pose prediction and visualization

In [90]:
# Inference only: disable autograd so no graph is kept alive, and feed the patch
# on whichever device the model's weights live on.
with torch.no_grad():
  pred = model(pch_tch.to(next(model.parameters()).device))

In [91]:
from utils.utils_ds import get_max_preds

# Heat-map resolution produced by the network and patch resolution fed to it.
out_shp = (64,64)
sz_pch = (256,256)

# The model may return one heat-map tensor or a list of them; use the last one.
outputs = pred["output"]
if isinstance(outputs, list):
  output = outputs[-1]
else:
  output = outputs

# get_max_preds returns a pair; only the first element (peak coordinates) is used.
pred_hm, _ = get_max_preds(output.cpu().detach().numpy())
pred2d_patch = np.ones((SMaL_configs["numJoints"], 3))  # 3rd column: visibility flag, all 1
# Rescale heat-map coordinates to patch coordinates (first sample of the batch only).
pred2d_patch[:,:2] = pred_hm[0] / out_shp[0] * sz_pch[0]

In [None]:
# Draw the predicted skeleton twice, side by side: over the colour image (left)
# and over channel 0 of the network input patch (right).
_, axes = plt.subplots(1, 2, figsize=(15, 15))
ax1, ax2 = axes
color_image_resized = cv2.resize(color_image, dsize=(256,256), interpolation=cv2.INTER_CUBIC)

for axis, background in ((ax1, color_image_resized), (ax2, img_patch[:,:,0])):
  plotImage(axis, background, 0)
  plot2DJoints(axis, pred2d_patch, SMaL_configs["connectedJoints"], SMaL_configs["gtjointColours"])

Predicted heatmaps

In [None]:

# Show the heat-map of the first joint for the first sample in the batch.
# The model output may be a list of tensors; pick the last entry in that case,
# the same way the decoding cell does, so indexing always hits (sample, joint).
heatmaps = pred['output']
if isinstance(heatmaps, list):
  heatmaps = heatmaps[-1]
t = heatmaps[0][0]
plt.imshow(t.cpu().detach().numpy())

Visualise a prediction using the dataloader

In [None]:
# Run the model on one randomly drawn test sample of fold 4 and overlay the
# predicted joints on the input patch.
out_shp = (64,64)
sz_pch = (256,256)
testset = SMaLDataset('/content/drive/MyDrive/Final_Project/dataset/SMaL-224/',fold=4,clip=True,modality="both",phase="test",swap_channels=False)
test_loader = DataLoader(testset, batch_size=1, shuffle=True)
model_device = next(model.parameters()).device  # feed data where the weights live
with torch.no_grad():
  for inp_dct in test_loader:
    patch = inp_dct['pch']
    print(patch.shape)
    mean = inp_dct["mean"]
    std = inp_dct["std"]

    # compute output; the model may return one tensor or a list (use the last entry)
    outputs = model(patch.to(model_device))["output"]
    output = outputs[-1] if isinstance(outputs, list) else outputs

    # get_max_preds returns a pair; unpack it so pred_hm[0] is the first sample's
    # coordinates rather than the whole batch array.
    pred_hm, _ = get_max_preds(output.cpu().numpy())
    pred2d_patch = np.ones((SMaLDataset.joint_num, 3))  # 3rd for  vis
    pred2d_patch[:,:2] = pred_hm[0] / out_shp[0] * sz_pch[0]

    ax = plt.subplot(1,2,1)
    # NOTE(review): patch[0] is the full multi-channel tensor, whereas the manual
    # pipeline passes a single 2-D channel to plotImage — confirm this is intended.
    plotImage(ax, patch[0],0)
    plot2DJoints(ax, pred2d_patch,  SMaL_configs["connectedJoints"], SMaL_configs["gtjointColours"])
    break  # one sample is enough for the visualisation

print(mean, std)