Official implementation of "Weakly-supervised positional contrastive learning: application to cirrhosis classification", accepted at MICCAI 2023 [paper].
Authors: Emma Sarfati
[1] Guerbet Research, Villepinte, France
[2] LTCI, Télécom Paris, Institut Polytechnique de Paris, France
[3] Sorbonne Université, CNRS, LIP6, Paris, France
This paper introduces a new contrastive learning method based on a generic kernel loss function that makes it possible to leverage both discrete and continuous meta-labels in medical imaging.
Let $x_i$ denote an input image (the anchor), $y_i$ its discrete meta-label and $d_i \in [0,1]$ its continuous meta-label.

In contrastive learning (CL), one wants to find a parametric function $f_\theta$ that maps images onto a representation space where positive samples are pulled together and negative samples are pushed apart.

In practice, one does not know the definition of negative and positive; this is the main difference between CL methods. In SimCLR [1], positives are two random augmentations of the anchor $x_i$; in SupCon [2], positives are the samples sharing the anchor's discrete label; in y-Aware [3], positive pairs are weighted by a kernel on a continuous meta-label. The proposed WSP loss combines a discrete kernel on $y$ with a continuous kernel on $d$:

$$
\mathcal{L}_{WSP} = -\sum_{i}\sum_{j \neq i} \frac{w(i,j)}{\sum_{k \neq i} w(i,k)} \, \log \frac{\exp\left(\mathrm{sim}(f_\theta(x_i), f_\theta(x_j))/\tau\right)}{\sum_{k \neq i} \exp\left(\mathrm{sim}(f_\theta(x_i), f_\theta(x_k))/\tau\right)}, \qquad w(i,j) = \mathbb{1}[y_i = y_j]\, K_\sigma(d_i - d_j) \quad (1)
$$

where the indices $i$, $j$ and $k$ run over the samples of the batch, $\mathrm{sim}$ is the cosine similarity, $\tau$ is a temperature and $K_\sigma$ is a kernel (e.g. RBF) of bandwidth $\sigma$.
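For intuition, here is a minimal PyTorch sketch of such a kernel-weighted InfoNCE loss. It is not the repository's exact implementation; the function name `wsp_loss` and the handling of degenerate kernel rows are illustrative.

```python
import torch
import torch.nn.functional as F

def wsp_loss(features, y, d, sigma=0.85, temperature=0.1):
    """Illustrative kernel-weighted contrastive loss in the spirit of Eq. (1).

    features: (N, D) projected embeddings, y: (N,) discrete labels,
    d: (N,) continuous positional labels in [0, 1].
    """
    z = F.normalize(features, dim=1)          # cosine similarity via dot products
    sim = z @ z.t() / temperature             # (N, N) similarity logits
    n = sim.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=sim.device)

    # Discrete kernel: 1 if the two samples share the same weak label, 0 otherwise
    w_y = (y.unsqueeze(0) == y.unsqueeze(1)).float()
    # Continuous RBF kernel on the positional coordinate d
    w_d = torch.exp(-(d.unsqueeze(0) - d.unsqueeze(1)) ** 2 / (2 * sigma ** 2))
    w = (w_y * w_d).masked_fill(self_mask, 0.0)   # exclude the anchor itself

    # log-softmax over all non-anchor samples (denominator of the InfoNCE term)
    sim = sim.masked_fill(self_mask, float('-inf'))
    log_prob = F.log_softmax(sim, dim=1)

    # Normalized kernel weights act as soft "positive" assignments
    w_norm = w / w.sum(dim=1, keepdim=True).clamp_min(1e-8)
    return -(w_norm * log_prob).sum(dim=1).mean()
```

In a batch of slices, `y` would typically carry the weak (radiological) class and `d` the normalized slice position.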
References:

[1] Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A Simple Framework for Contrastive Learning of Visual Representations (SimCLR). ICML 2020.
[2] Khosla, P., et al.: Supervised Contrastive Learning (SupCon). NeurIPS 2020.
[3] Dufumier, B., et al.: Contrastive Learning with Continuous Proxy Meta-Data for 3D MRI Classification (y-Aware). MICCAI 2021.
This repo contains the official code for WSP contrastive learning. The code is implemented with PyTorch Lightning.
All the images must be stored under the `path_to_data` path, which must contain two folders (see the example layout below):

- `/train`: training images in NIfTI format.
- `/validation`: validation images in NIfTI format.
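For instance, the expected layout looks like this (the file names are illustrative):

```
path_to_data/
├── train/
│   ├── subject_001.nii.gz
│   └── ...
└── validation/
    ├── subject_101.nii.gz
    └── ...
```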
To run the code properly, you will have to provide a pandas DataFrame with the following index and columns (an example is given below):

- Index: name of the subjects.
- Column `class`: radiological class or histological class, depending on the type of task (pretraining or classification).
- Column `label`: histological class (if available).
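As an illustration, a DataFrame with the expected structure could be built as follows (subject names and values are made up):

```python
import pandas as pd

# Hypothetical example of the expected DataFrame layout
df = pd.DataFrame(
    {
        "class": [0, 1, 2],   # radiological (or histological) class, used for pretraining
        "label": [0, 1, 1],   # histological class, if available, used for classification
    },
    index=["subject_001", "subject_002", "subject_003"],  # index = subject names
)
df.to_csv("my_dataframe.csv")
```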
We provide the DataFrame for the public TCGA-LIHC dataset that we used in our paper for evaluation (the `dataframe_lihc.csv` file). The patients in this dataset have a histological confirmation through the Ishak score.
The portal venous phase CT-scans used for each subject are indicated in a dedicated column of the `dataframe_lihc.csv` file. Also, to avoid duplicated scans, we only kept the oldest scan for each patient.
The file `main.py` can be launched in two different modes: pretraining or finetuning. Several other arguments follow, which you indicate with this convention:

```bash
python main.py --mode <put mode here> --rep_dim <put number here> --num_classes <put number here>
```

And so on. All the arguments are available in the file `config.py` and are listed below; an example command combining several of them follows the list.
```python
# Default arguments (config.py)
mode: str = 'finetuning',              # 'pretraining' or 'finetuning'
rep_dim: int = 512,
hidden_dim: int = 256,
output_dim: int = 128,
num_classes: int = 4,
encoder: str = 'tiny',
n_layer: int = 18,
lr: float = 1e-5,
weight_decay: float = 1e-5,
label_name: str = 'label',             # dataframe column used as discrete label
n_fold: int = 4,
cross_val: bool = False,
pretrained_path: str = None,
sigma: float = 0.85,                   # bandwidth of the continuous kernel
temperature: float = 0.1,              # temperature of the contrastive loss
kernel: str = 'rbf',
max_epochs: int = 40,
batch_size: int = 64,
pretrained: bool = False,
path_to_data: str = "path_to_data",    # folder containing /train and /validation
lght_dir: str = "path_to_models"
```
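For example, a pretraining run overriding a few of these defaults could be launched as follows (the flag names mirror the `config.py` fields above; the paths are placeholders):

```bash
python main.py --mode pretraining --kernel rbf --sigma 0.85 --temperature 0.1 \
    --batch_size 64 --max_epochs 40 --path_to_data /path/to/data --lght_dir /path/to/models
```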
For either the pretraining or the finetuning mode, the data fed to the models are, in this order: `data, label, subject_id, z`.

- `data`: 2D image of shape (1, 512, 512).
- `label`: discrete label corresponding to the variable $y$ in Eq. (1).
- `subject_id`: name of the subject, for convenience.
- `z`: continuous label corresponding to the variable $d$ in Eq. (1).

Please note that the normalized positional coordinate $d \in [0,1]$ (named `z` in the code) is computed automatically for each volume at the beginning of the `dataset.py` file, so you only have to provide the discrete label in the dataframe (see the sketch below). If you wish to use other labels related to the patients, you can do so by providing additional columns in the original dataframe; you will then need to change the implementation of the loss accordingly by adding discrete/continuous kernels.
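As an illustration of how such a normalized positional coordinate can be obtained per slice, here is a short sketch. It is not the repository's `dataset.py`; the use of `nibabel` and the choice of the last axis as the slice axis are assumptions.

```python
import nibabel as nib
import numpy as np

def slice_coordinates(volume_path):
    """Return the axial slices of a NIfTI volume together with their
    normalized positional coordinate z in [0, 1] (illustrative sketch)."""
    volume = nib.load(volume_path).get_fdata()       # assumed shape (H, W, n_slices)
    n_slices = volume.shape[-1]
    slices, coords = [], []
    for k in range(n_slices):
        slices.append(volume[..., k][np.newaxis])    # shape (1, H, W), matching (1, 512, 512)
        coords.append(k / max(n_slices - 1, 1))      # z = 0 for the first slice, 1 for the last
    return slices, np.array(coords)
```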
For pretraining, launch the following line of code:

```bash
python main.py --mode pretraining
```

For finetuning, launch the following line of code:

```bash
python main.py --mode finetuning
```

To add an argument, follow the convention described above.
We use TensorBoard to monitor the metrics. To access it, launch in your terminal:

```bash
tensorboard --logdir=<path to where your code is>
```
For evaluation, you can run a Jupyter Notebook and import the pretrained weights. To reproduce the evaluation of the paper with