PyTorch implementation of the paper: "EMA-Net: Efficient Multitask Affinity Learning for Dense Scene Predictions".
- Download weights from here (HRNet-W18-C-Small-v2).
- Save them to `models/pretrained_models/hrnet_w18_small_model_v2.pth`.
- Set `db_root` (dataroot) in `configs/mypath.py` to where you stored the dataset.
- Set `--storage_root` and `--config` in `train.sh`.
- Run `train.sh`.
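The steps above can be sketched as follows (the download location `~/Downloads` is just an example; adjust paths to your machine):

```shell
# Create the expected checkpoint directory and copy the downloaded
# HRNet-W18-C-Small-v2 weights into place (source path is an example).
mkdir -p models/pretrained_models
cp ~/Downloads/hrnet_w18_small_model_v2.pth models/pretrained_models/ 2>/dev/null \
  || echo "place hrnet_w18_small_model_v2.pth in models/pretrained_models/ manually"

# After editing db_root in configs/mypath.py and the flags in train.sh:
bash train.sh
```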
You can find the preprint of our paper on arXiv.
Please cite our paper if you use the code or the results of our work.
```
@article{sinodinos2024ema,
  title={EMA-Net: Efficient Multitask Affinity Learning for Dense Scene Predictions},
  author={Sinodinos, Dimitrios and Armanfard, Narges},
  journal={arXiv preprint arXiv:2401.11124},
  year={2024}
}
```
Multitask learning (MTL) has gained prominence for its ability to jointly predict multiple tasks, achieving better per-task performance while using fewer per-task model parameters than single-task learning. More recently, decoder-focused architectures have considerably improved multitask performance by refining task predictions using the features of other related tasks. However, most of these refinement methods fail to simultaneously capture local and global task-specific representations, as well as cross-task patterns in a parameter-efficient manner. In this paper, we introduce the Efficient Multitask Affinity Learning Network (EMA-Net), which is a lightweight framework that enhances the task refinement capabilities of multitask networks. EMA-Net adeptly captures local, global, and cross-task interactions using our novel Cross-Task Affinity Learning (CTAL) module. The key innovation of CTAL lies in its ability to manipulate task affinity matrices in a manner that is optimally suited to apply parameter-efficient grouped convolutions without worrying about information loss. Our results show that we achieve state-of-the-art MTL performance for CNN-based decoder-focused models while using substantially fewer model parameters.
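To illustrate the parameter savings that motivate CTAL (this is a minimal sketch of grouped versus dense convolutions over stacked per-task features, not the paper's CTAL module; `num_tasks` and `channels` are arbitrary example values):

```python
import torch
import torch.nn as nn

# Hypothetical example: features from 3 tasks, 64 channels each,
# concatenated along the channel dimension.
num_tasks, channels = 3, 64

# A dense convolution mixes all task channels with one full filter bank.
dense = nn.Conv2d(num_tasks * channels, num_tasks * channels,
                  kernel_size=3, padding=1)

# A grouped convolution (groups=num_tasks) gives each task its own
# filter bank, cutting the weight count by roughly a factor of num_tasks.
grouped = nn.Conv2d(num_tasks * channels, num_tasks * channels,
                    kernel_size=3, padding=1, groups=num_tasks)

x = torch.randn(1, num_tasks * channels, 32, 32)  # stacked task features
print(sum(p.numel() for p in dense.parameters()))    # full cross-channel mixing
print(sum(p.numel() for p in grouped.parameters()))  # ~num_tasks x fewer weights
print(grouped(x).shape)                              # spatial size is preserved
```

The savings grow with the number of tasks, which is why arranging task affinity information so that grouped convolutions apply cleanly (as CTAL does) keeps the refinement stage lightweight.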
This repo borrows several elements from Multi-Task-Learning-PyTorch and MTAN.