# Litepose 
[Litepose][1] proposes an efficient architecture for pose estimation that is scale-invariant, reliable, and has a low computational cost. It follows a bottom-up approach: a single network performs both keypoint estimation and grouping. This implementation also relies on two other papers cited by Litepose: [HigherHRNet][2], which proposes the main architecture, and [Associative Embedding][3], which introduces a way to assign identity-free keypoints to the person they belong to.

Litepose modifies the HigherHRNet architecture, going from a multi-branch design to a single-branch one through gradual shrinking. The goal is to speed up inference so that the model can also run on devices with low computational power.

The architecture uses a MobileNet backbone with **Large Kernel Convolutions**, which have shown great results empirically. The backbone output is passed to multiple deconvolutional blocks that implement the main feature of the Litepose paper, the **Fusion Deconv Head**. This head produces scale-aware results by merging intermediate backbone features with refined features. In this way the network can exploit high-resolution features, which help to capture closely spaced joints, without requiring a multi-branch architecture.

Let $t$ be the number of convolutional blocks in the backbone and $k$ the position of the current deconvolutional block. Feature fusion is implemented by summing the features of the deconvolutional block in position $k$ with the features of the backbone block in position $t-k-1$, refined by an additional convolutional layer. Finally, for each deconvolutional layer, the merged features are passed to a final block that produces the output. The network provides its results at several scales, one for each deconvolutional layer. Hence the output is an $(n,j,s_i,s_i)$ tensor, where $n$ is the number of scales (i.e. deconv blocks), $j$ is the number of joints that we want to detect, and $s_i$ for $i\in\{1,2,...,n\}$ is the size of the $i$-th scale. The image ... clarifies the network structure.
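
The pairing rule described above can be sketched in a few lines. This is only an illustration, not the code in `lp_model`: the function names, the 0-based indexing, and the block counts (4 backbone blocks, 3 deconv blocks) are assumptions, and the refinement convolution is left out so that only the index arithmetic and the element-wise sum are shown.

```python
def fusion_pairs(num_backbone_blocks, num_deconv_blocks):
    """Pair each deconv block with the backbone block at position t - position - 1.

    Positions are assumed to be 0-based (an assumption, not stated in the text).
    """
    t = num_backbone_blocks
    pairs = []
    for deconv_index in range(num_deconv_blocks):
        backbone_index = t - deconv_index - 1
        if backbone_index < 0:
            raise ValueError("more deconv blocks than backbone blocks")
        pairs.append((deconv_index, backbone_index))
    return pairs


def fuse(deconv_features, backbone_features):
    """Element-wise sum of two equally shaped 2D feature maps (lists of rows).

    The additional refinement convolution is omitted in this sketch.
    """
    return [
        [d + b for d, b in zip(d_row, b_row)]
        for d_row, b_row in zip(deconv_features, backbone_features)
    ]


# Hypothetical sizes, chosen only for illustration.
pairs = fusion_pairs(num_backbone_blocks=4, num_deconv_blocks=3)
print(pairs)  # [(0, 3), (1, 2), (2, 1)]

fused = fuse([[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.5], [0.5, 0.5]])
print(fused)  # [[1.5, 2.5], [3.5, 4.5]]
```

In a real network the two feature maps must have the same spatial size and channel count before they can be summed, which is presumably one role of the extra convolutional layer.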

[1]:https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Lite_Pose_Efficient_Architecture_Design_for_2D_Human_Pose_Estimation_CVPR_2022_paper.pdf
[2]:https://arxiv.org/pdf/1908.10357.pdf
[3]:https://papers.nips.cc/paper/2017/file/8edd72158ccd2a879f79cb2538568fdc-Paper.pdf

In [1]:
from lp_coco_utils.lp_getDataset import getDatasetProcessed
from lp_training.lp_trainer import train
from lp_training.lp_loss import computeLoss
from lp_model.lp_litepose import LitePose
import lp_config.lp_common_config as cc
import torch
from lp_inference.lp_inference import inference, assocEmbedding
from lp_utils.lp_image_processing import drawHeatmap, drawKeypoints, normalizeImage, drawSkeleton
from lp_testing.lp_evaluation import computeOKS

This file should be seen only as an entry point that calls wrapper functions; the implementation of those functions can be found in the subdirectories of the repository.  
Every hyperparameter can be edited in `src/lp_config`.  
`lp_common_config.py` contains the general configuration for dataset loading, training and inference, while `lp_model_config.py` contains the parameters that encode the model structure. The current model configuration is taken from the Neural Architecture Search performed by the paper authors. I used the small network because of the limited computational power available; better results can be achieved simply by scaling up the network size (good parameter combinations are provided by the paper authors).

Code taken from the [official paper repository](https://github.com/mit-han-lab/litepose):
- The classes `CocoDataset` and `CocoKeypoints` are partially taken from it; I added fiftyone support, which makes the dataset setup easier, and removed unnecessary code.
- I also took the code inside `lp_generators.py` and `lp_transforms.py`, since they are dependencies of `CocoKeypoints`.

# Training

The dataset is downloaded by using fiftyone APIs and keypoint heatmaps are created for each sample. 
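
To make the heatmap step concrete, here is a minimal, dependency-free sketch of how a ground-truth heatmap for a single keypoint is commonly built: a 2D Gaussian centred on the annotated position. The function name, the `sigma` value and the map size are assumptions for illustration; the actual generator lives in `lp_generators.py` and may differ.

```python
import math


def gaussian_heatmap(size, center_x, center_y, sigma=2.0):
    """Build a size x size heatmap with a Gaussian peak at (center_x, center_y)."""
    heatmap = []
    for y in range(size):
        row = []
        for x in range(size):
            squared_distance = (x - center_x) ** 2 + (y - center_y) ** 2
            row.append(math.exp(-squared_distance / (2.0 * sigma ** 2)))
        heatmap.append(row)
    return heatmap


# Hypothetical keypoint at column 3, row 5 of an 8x8 map.
hm = gaussian_heatmap(size=8, center_x=3, center_y=5)
peak = max(max(row) for row in hm)
print(peak, hm[5][3])  # 1.0 1.0
```

One such map is produced per joint type, so a sample with several people has several peaks in the same channel.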

In [2]:
train(cc.config["batch_size"])

loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


100%|██████████| 250/250 [00:36<00:00,  6.78it/s]
100%|██████████| 500/500 [00:24<00:00, 20.83it/s]


epoch #1 

TRAINING LOSS:
total Loss = 0.08047071336954832
heatmap Loss = 0.07520170802809298
tag Loss = 0.005269005094654858

VALIDATION LOSS:
total Loss = 0.016819940937682985
heatmap Loss = 0.012100410318933427
tag Loss = 0.004719530618283898





100%|██████████| 250/250 [00:35<00:00,  6.95it/s]
100%|██████████| 500/500 [00:24<00:00, 20.77it/s]


epoch #2 

TRAINING LOSS:
total Loss = 0.01762638370320201
heatmap Loss = 0.012879698220640421
tag Loss = 0.004746685437858104

VALIDATION LOSS:
total Loss = 0.013628414099104702
heatmap Loss = 0.008918515951838344
tag Loss = 0.004709898107685149





100%|██████████| 250/250 [00:36<00:00,  6.88it/s]
100%|██████████| 500/500 [00:23<00:00, 20.92it/s]


epoch #3 

TRAINING LOSS:
total Loss = 0.014641476087272167
heatmap Loss = 0.009914935329928994
tag Loss = 0.004726540789939463

VALIDATION LOSS:
total Loss = 0.013281734415329993
heatmap Loss = 0.008567944450303911
tag Loss = 0.004713789953384549





100%|██████████| 250/250 [00:36<00:00,  6.78it/s]
100%|██████████| 500/500 [00:24<00:00, 20.40it/s]


epoch #4 

TRAINING LOSS:
total Loss = 0.014037863615900278
heatmap Loss = 0.009304489215835928
tag Loss = 0.004733374439179898

VALIDATION LOSS:
total Loss = 0.013867149045690894
heatmap Loss = 0.009160315503831952
tag Loss = 0.004706833506934345





100%|██████████| 250/250 [00:36<00:00,  6.86it/s]
100%|██████████| 500/500 [00:24<00:00, 20.45it/s]


epoch #5 

TRAINING LOSS:
total Loss = 0.013796656161546707
heatmap Loss = 0.009107654850929976
tag Loss = 0.004689001296646893

VALIDATION LOSS:
total Loss = 0.015042394181713463
heatmap Loss = 0.009902948658913373
tag Loss = 0.0051394454846158625





100%|██████████| 250/250 [00:36<00:00,  6.78it/s]
100%|██████████| 500/500 [00:23<00:00, 21.18it/s]


epoch #6 

TRAINING LOSS:
total Loss = 0.013556378792971373
heatmap Loss = 0.008887629497796297
tag Loss = 0.004668749294243753

VALIDATION LOSS:
total Loss = 0.013232948506250978
heatmap Loss = 0.008584347394295037
tag Loss = 0.004648601124528796





100%|██████████| 250/250 [00:37<00:00,  6.66it/s]
100%|██████████| 500/500 [00:24<00:00, 20.67it/s]


epoch #7 

TRAINING LOSS:
total Loss = 0.013300016470253468
heatmap Loss = 0.008642715716734528
tag Loss = 0.0046573007432743905

VALIDATION LOSS:
total Loss = 0.055786728955805304
heatmap Loss = 0.051002869371324776
tag Loss = 0.004783859477378428





100%|██████████| 250/250 [00:36<00:00,  6.84it/s]
100%|██████████| 500/500 [00:24<00:00, 20.12it/s]


epoch #8 

TRAINING LOSS:
total Loss = 0.013222670879215002
heatmap Loss = 0.008575820753350854
tag Loss = 0.004646850080229342

VALIDATION LOSS:
total Loss = 0.012928037642501295
heatmap Loss = 0.008297235233709217
tag Loss = 0.0046308024278841915





100%|██████████| 250/250 [00:36<00:00,  6.81it/s]
100%|██████████| 500/500 [00:24<00:00, 20.31it/s]


epoch #9 

TRAINING LOSS:
total Loss = 0.01302877901494503
heatmap Loss = 0.008380192602053284
tag Loss = 0.004648586425930261

VALIDATION LOSS:
total Loss = 0.01295717848651111
heatmap Loss = 0.008306454771198332
tag Loss = 0.00465072370460257





100%|██████████| 250/250 [00:36<00:00,  6.84it/s]
100%|██████████| 500/500 [00:25<00:00, 19.86it/s]


epoch #10 

TRAINING LOSS:
total Loss = 0.013029399439692496
heatmap Loss = 0.00838365208543837
tag Loss = 0.004645747371017933

VALIDATION LOSS:
total Loss = 0.012876895803026854
heatmap Loss = 0.0082467020326294
tag Loss = 0.004630193808116019





100%|██████████| 250/250 [00:36<00:00,  6.85it/s]
100%|██████████| 500/500 [00:25<00:00, 19.98it/s]


epoch #11 

TRAINING LOSS:
total Loss = 0.01295751177892089
heatmap Loss = 0.008302106786519289
tag Loss = 0.004655405006371439

VALIDATION LOSS:
total Loss = 0.012824951919727028
heatmap Loss = 0.008177054868079722
tag Loss = 0.004647897041868418





100%|██████████| 250/250 [00:36<00:00,  6.81it/s]
100%|██████████| 500/500 [00:24<00:00, 20.68it/s]


epoch #12 

TRAINING LOSS:
total Loss = 0.012882953956723214
heatmap Loss = 0.00824509616754949
tag Loss = 0.004637857780791819

VALIDATION LOSS:
total Loss = 0.012836325895041228
heatmap Loss = 0.00820368056325242
tag Loss = 0.004632645340403542





100%|██████████| 250/250 [00:37<00:00,  6.65it/s]
100%|██████████| 500/500 [00:24<00:00, 20.09it/s]


epoch #13 

TRAINING LOSS:
total Loss = 0.01296816450357437
heatmap Loss = 0.008315359145402908
tag Loss = 0.004652805358171463

VALIDATION LOSS:
total Loss = 0.01282859725691378
heatmap Loss = 0.008176085074432194
tag Loss = 0.00465251217642799





100%|██████████| 250/250 [00:39<00:00,  6.38it/s]
100%|██████████| 500/500 [00:26<00:00, 19.19it/s]


epoch #14 

TRAINING LOSS:
total Loss = 0.012867116160690785
heatmap Loss = 0.008221789864823222
tag Loss = 0.004645326280966401

VALIDATION LOSS:
total Loss = 0.012760887102223933
heatmap Loss = 0.008120832859538496
tag Loss = 0.0046400542431510984





100%|██████████| 250/250 [00:37<00:00,  6.73it/s]
100%|██████████| 500/500 [00:25<00:00, 19.84it/s]


epoch #15 

TRAINING LOSS:
total Loss = 0.012857821855694056
heatmap Loss = 0.008207674564793706
tag Loss = 0.004650147279724479

VALIDATION LOSS:
total Loss = 0.012875987044535577
heatmap Loss = 0.008235850444994867
tag Loss = 0.004640136564616114





100%|██████████| 250/250 [00:38<00:00,  6.55it/s]
100%|██████████| 500/500 [00:24<00:00, 20.03it/s]


epoch #16 

TRAINING LOSS:
total Loss = 0.012847450278699398
heatmap Loss = 0.008199847046285867
tag Loss = 0.004647603264078498

VALIDATION LOSS:
total Loss = 0.012733603549189865
heatmap Loss = 0.008076060827821493
tag Loss = 0.004657542700413615





100%|██████████| 250/250 [00:36<00:00,  6.79it/s]
100%|██████████| 500/500 [00:24<00:00, 20.23it/s]


epoch #17 

TRAINING LOSS:
total Loss = 0.012751270856708288
heatmap Loss = 0.00810851182602346
tag Loss = 0.004642759037204087

VALIDATION LOSS:
total Loss = 0.012750382878817619
heatmap Loss = 0.008107863499782979
tag Loss = 0.004642519370652735





100%|██████████| 250/250 [00:36<00:00,  6.85it/s]
100%|██████████| 500/500 [00:24<00:00, 20.43it/s]


epoch #18 

TRAINING LOSS:
total Loss = 0.012810359228402375
heatmap Loss = 0.008172775665298105
tag Loss = 0.004637583567760885

VALIDATION LOSS:
total Loss = 0.013524260600097477
heatmap Loss = 0.008881290347315371
tag Loss = 0.004642970251152292





100%|██████████| 250/250 [00:36<00:00,  6.84it/s]
100%|██████████| 500/500 [00:27<00:00, 17.91it/s]


epoch #19 

TRAINING LOSS:
total Loss = 0.013021481834352016
heatmap Loss = 0.008378675501793623
tag Loss = 0.004642806359566748

VALIDATION LOSS:
total Loss = 0.013007370192557573
heatmap Loss = 0.008359879932366312
tag Loss = 0.004647490238305181





100%|██████████| 250/250 [00:38<00:00,  6.44it/s]
100%|██████████| 500/500 [00:26<00:00, 18.92it/s]


epoch #20 

TRAINING LOSS:
total Loss = 0.01305600407719612
heatmap Loss = 0.00841228255815804
tag Loss = 0.004643721512518823

VALIDATION LOSS:
total Loss = 0.012982219656929373
heatmap Loss = 0.008344521477818488
tag Loss = 0.004637698173988611





100%|██████████| 250/250 [00:36<00:00,  6.83it/s]
100%|██████████| 500/500 [00:25<00:00, 19.82it/s]


epoch #21 

TRAINING LOSS:
total Loss = 0.013030008919537067
heatmap Loss = 0.008384572545066476
tag Loss = 0.004645436380989849

VALIDATION LOSS:
total Loss = 0.013055928810499608
heatmap Loss = 0.008404898783192038
tag Loss = 0.0046510300131049





100%|██████████| 250/250 [00:37<00:00,  6.68it/s]
100%|██████████| 500/500 [00:25<00:00, 19.62it/s]


epoch #22 

TRAINING LOSS:
total Loss = 0.012890338525176048
heatmap Loss = 0.008252503329887986
tag Loss = 0.004637835164554417

VALIDATION LOSS:
total Loss = 0.01293287575431168
heatmap Loss = 0.008272846695035696
tag Loss = 0.004660029043443501





100%|██████████| 250/250 [00:36<00:00,  6.76it/s]
100%|██████████| 500/500 [00:25<00:00, 19.30it/s]


epoch #23 

TRAINING LOSS:
total Loss = 0.012886956293135882
heatmap Loss = 0.00822942927479744
tag Loss = 0.0046575270304456354

VALIDATION LOSS:
total Loss = 0.012873668960295617
heatmap Loss = 0.008236169890500605
tag Loss = 0.004637499056756497





100%|██████████| 250/250 [00:37<00:00,  6.59it/s]
100%|██████████| 500/500 [00:24<00:00, 20.44it/s]


epoch #24 

TRAINING LOSS:
total Loss = 0.012826781220734119
heatmap Loss = 0.008184250904247166
tag Loss = 0.004642530271783471

VALIDATION LOSS:
total Loss = 0.012793014236725866
heatmap Loss = 0.008143021154683083
tag Loss = 0.004649993083439767





100%|██████████| 250/250 [00:36<00:00,  6.83it/s]
100%|██████████| 500/500 [00:24<00:00, 20.25it/s]

epoch #25 

TRAINING LOSS:
total Loss = 0.012838764283806085
heatmap Loss = 0.008203745471313596
tag Loss = 0.004635018786415457

VALIDATION LOSS:
total Loss = 0.012763793285936118
heatmap Loss = 0.00813878972362727
tag Loss = 0.004625003540422767



end training, exec time: 1549.6994819641113
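
In the log above, the total loss is approximately the sum of the heatmap loss and the tag loss. The tag loss itself is computed in `lp_loss` and is not shown here; the sketch below only illustrates the pull/push idea from the Associative Embedding paper on scalar tags. The function name, the normalisation and the exact form of the push term are assumptions.

```python
import math


def tag_loss(tags_per_person):
    """Pull/push grouping loss on scalar tags.

    tags_per_person: one non-empty list of joint tag values per person.
    Returns (pull, push).
    """
    means = [sum(tags) / len(tags) for tags in tags_per_person]
    num_people = len(means)

    # Pull term: tags of the same person should be close to that person's mean.
    pull = sum(
        sum((tag - mean) ** 2 for tag in tags) / len(tags)
        for tags, mean in zip(tags_per_person, means)
    ) / num_people

    # Push term: mean tags of different people should be far apart.
    push = 0.0
    for a in range(num_people):
        for b in range(num_people):
            if a != b:
                push += math.exp(-((means[a] - means[b]) ** 2))
    if num_people > 1:
        push /= num_people * (num_people - 1)

    return pull, push


# Two people with well separated tags: both terms are close to zero.
print(tag_loss([[0.0, 0.0], [10.0, 10.0]]))
# Two people whose mean tags coincide: the push term is at its maximum.
print(tag_loss([[0.0, 1.0], [0.5, 0.5]]))  # (0.125, 1.0)
```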





# Inference
Unfortunately, the OpenCV `imshow()` method has a well-known bug in Python notebooks.

In [2]:
import cv2
import random

import torch.nn.functional as F


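model = LitePose().to(cc.config["device"])
# map_location lets a checkpoint saved on another device load onto the configured one
model.load_state_dict(torch.load("lp_trained_models/mytag", map_location=cc.config["device"]))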

ds = getDatasetProcessed("validation")

data_loader = torch.utils.data.DataLoader(
    ds,
    batch_size=8
)

row = next(iter(data_loader))
images = row[0].to(cc.config["device"])
#img_size = 256
#images = F.interpolate(images, size = (img_size, img_size))
gthm = row[1]
output, keypoints = inference(model, images)

loading annotations into memory...
Done (t=0.20s)
creating index...
index created!
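
The `inference` call returns both the raw network output and the detected keypoints. How the keypoints are extracted is implemented in `lp_inference` and is not shown here; a common approach in bottom-up pipelines is to keep the local maxima of each joint heatmap above a confidence threshold. The sketch below shows that idea on a plain 2D list; the function name and the threshold are assumptions.

```python
def find_peaks(heatmap, threshold=0.1):
    """Return (x, y, score) for every local maximum at or above threshold.

    A pixel counts as a local maximum if no 8-connected neighbour is larger.
    """
    height = len(heatmap)
    width = len(heatmap[0])
    peaks = []
    for y in range(height):
        for x in range(width):
            value = heatmap[y][x]
            if value < threshold:
                continue
            neighbours = [
                heatmap[ny][nx]
                for ny in range(max(0, y - 1), min(height, y + 2))
                for nx in range(max(0, x - 1), min(width, x + 2))
                if (ny, nx) != (y, x)
            ]
            if all(value >= n for n in neighbours):
                peaks.append((x, y, value))
    return peaks


# Toy heatmap with two people for the same joint type.
toy_heatmap = [
    [0.0, 0.1, 0.0, 0.0],
    [0.1, 0.9, 0.1, 0.0],
    [0.0, 0.1, 0.0, 0.3],
    [0.0, 0.0, 0.2, 0.7],
]
print(find_peaks(toy_heatmap, threshold=0.5))  # [(1, 1, 0.9), (3, 3, 0.7)]
```

On tensors the same effect is usually obtained with a max-pooling based non-maximum suppression followed by a top-k selection, which is much faster than this explicit loop.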


In [3]:
embedding = assocEmbedding(keypoints)
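
`assocEmbedding` groups the identity-free keypoints into people. Its implementation is in `lp_inference`; the sketch below is only a simplified, greedy illustration of grouping by tag similarity, not the author's method. The function name, the tuple layout and the distance threshold are assumptions, and the original Associative Embedding work uses a proper matching step instead of this greedy assignment.

```python
def group_by_tag(detections, max_tag_distance=1.0):
    """Greedily group joint detections into people by tag similarity.

    detections: list of (joint_id, x, y, tag) tuples.
    Returns a list of people, each a dict joint_id -> (x, y, tag).
    """
    people = []
    # Process joint types in order, keeping the input order within a type.
    for joint_id, x, y, tag in sorted(detections, key=lambda d: d[0]):
        best_person, best_distance = None, max_tag_distance
        for person in people:
            if joint_id in person:
                continue  # this person already has a joint of that type
            mean_tag = sum(t for _, _, t in person.values()) / len(person)
            distance = abs(tag - mean_tag)
            if distance < best_distance:
                best_person, best_distance = person, distance
        if best_person is None:
            # No existing person is close enough: start a new one.
            best_person = {}
            people.append(best_person)
        best_person[joint_id] = (x, y, tag)
    return people


# Two joint types (0 and 1) detected for two people with tags near 0.1 and 2.0.
toy_detections = [
    (0, 10, 10, 0.1),
    (0, 50, 12, 2.0),
    (1, 12, 20, 0.2),
    (1, 52, 22, 2.1),
]
for person in group_by_tag(toy_detections):
    print(person)
```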

In [4]:
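# output[1] is the second output scale and [2] the third image of the batch;
# the first num_joints channels are the joint heatmaps
jointsHeatmap = output[1][2][:cc.config["num_joints"]]

# predicted heatmap and ground-truth heatmap for the same image
img, finalHm, superimposed = drawHeatmap(images[2], jointsHeatmap)
img, gtfinalHm, gtsuperimposed = drawHeatmap(images[2], gthm[1][2])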
cv2.imshow("Image", img)
cv2.imshow("Final heatmap", finalHm)
cv2.imshow("Superimposed", superimposed)

cv2.imshow("Ground Truth heatmap", gtfinalHm)
cv2.imshow("Ground Truth Superimposed", gtsuperimposed)
cv2.waitKey()
cv2.destroyAllWindows()

In [8]:
import lp_utils.lp_image_processing as ip
import numpy as np

img = images[1]

# second output scale, second image of the batch
tj = output[1][1]

# channels 14..27 hold the tag maps; show them one at a time
for t in range(14,28):
    tagJoints = tj[t]
    # scale the single-channel map to the image size
    scaled = ip.scaleImage(tagJoints.unsqueeze(0), img.shape[1]).cpu().numpy()
    scaled = scaled[0]
    scaled = ip.normalizeImage(scaled)
    finalHm = cv2.applyColorMap(np.uint8(scaled), cv2.COLORMAP_JET)
    cv2.imshow(str(t), finalHm)
    cv2.waitKey()
cv2.destroyAllWindows()


#img = img.cpu().numpy().transpose(1, 2, 0)
#img = normalizeImage(img)
#img = np.uint8(img)
#heatmaps = ip.scaleImage(tagJoints, img.shape[1]).cpu().numpy()

#finalHm = ip.mergeMultipleHeatmaps(heatmaps)
#finalHm = normalizeImage(finalHm)
#finalHm = cv2.applyColorMap(np.uint8(finalHm), cv2.COLORMAP_JET)

#cv2.imshow("Final heatmap", finalHm)


torch.Size([14, 128, 128])
torch.Size([28, 128, 128])
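
Before a map is passed to `cv2.applyColorMap`, it has to be brought into the 0 to 255 range. `normalizeImage` is defined in `lp_image_processing` and is not shown here; the sketch below assumes it behaves like a plain min-max rescaling, which is an assumption, and the function name is hypothetical.

```python
def normalize_to_uint8_range(values):
    """Linearly rescale a 2D list of floats to integers in [0, 255]."""
    flat = [v for row in values for v in row]
    lowest, highest = min(flat), max(flat)
    span = highest - lowest
    if span == 0:
        # A constant map carries no contrast; map it to all zeros.
        return [[0 for _ in row] for row in values]
    return [[int(round(255 * (v - lowest) / span)) for v in row] for row in values]


print(normalize_to_uint8_range([[0.0, 1.0], [3.0, 5.0]]))  # [[0, 51], [153, 255]]
print(normalize_to_uint8_range([[0.0, 0.0]]))              # [[0, 0]]
```

The constant-map branch matters for tag channels that are entirely zero, where a naive division by the value range would fail.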


In [3]:
tagJoints = output[1][2][cc.config["num_joints"]:]

print(tagJoints[0].max())

img, finalHm, superimposed = drawHeatmap(images[2], tagJoints)
cv2.imshow("Image", img)
cv2.imshow("Final heatmap", finalHm)
cv2.imshow("Superimposed", superimposed)

cv2.waitKey()
cv2.destroyAllWindows()

tensor(0., device='cuda:0')


In [5]:
img = drawKeypoints(images[1], keypoints[1])
cv2.imshow("Image Keypoints", img)
cv2.waitKey()
cv2.destroyAllWindows()

In [4]:
img = drawSkeleton(images[1], embedding[1])
cv2.imshow("Image Skeleton", img)
cv2.waitKey()
cv2.destroyAllWindows()