Skip to content

Commit

Permalink
Merge pull request #156 from rafiberlin/fix_image_retrieval_pipeline
Browse files Browse the repository at this point in the history
Makes the image retrieval part work without too much effort. (should address issues #64 and #109)
  • Loading branch information
KaihuaTang committed Mar 17, 2022
2 parents 18c16b6 + 6c4d1bb commit 634a6e2
Show file tree
Hide file tree
Showing 19 changed files with 1,053 additions and 289 deletions.
6 changes: 6 additions & 0 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ python setup.py build_ext install
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex

# WARNING: if you use an older version of PyTorch (anything below 1.7), you need the hard reset below,
# because newer versions of apex require newer PyTorch versions. Skip the hard reset otherwise.
git reset --hard 3fe10b5597ba14a748ebb271a6ab97c09c5701ac

python setup.py install --cuda_ext --cpp_ext


# install PyTorch Detection
cd $INSTALL_DIR
git clone https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch.git
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ MOTIFS-PredCls-none | 59.64 | 66.11 | 67.96 | 11.46 | 14.60 | 15.84 | 5.79 | 11.
MOTIFS-PredCls-TDE | 33.38 | 45.88 | 51.25 | 17.85 | 24.75 | 28.70 | 8.28 | 14.31 | 18.04

## SGDet on Custom Images
Note that evaluation on custum images is only applicable for SGDet model, because PredCls and SGCls model requires additional ground-truth bounding boxes information. To detect scene graphs into a json file on your own images, you need to turn on the switch TEST.CUSTUM_EVAL and give a folder path that contains the custom images to TEST.CUSTUM_PATH. Only JPG files are allowed. The output will be saved as custom_prediction.json in the given DETECTED_SGG_DIR.
Note that evaluation on custom images is only applicable to the SGDet model, because the PredCls and SGCls models require additional ground-truth bounding box information. To detect scene graphs on your own images and write them to a json file, you need to turn on the switch TEST.CUSTUM_EVAL and set TEST.CUSTUM_PATH to either a folder that contains the custom images or a json file containing a list of image paths. Only JPG files are allowed. The output will be saved as custom_prediction.json in the given DETECTED_SGG_DIR.

Test Example 1 : (SGDet, **Causal TDE**, MOTIFS Model, SUM Fusion) [(checkpoint)](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21781947&authkey=AF_EM-rkbMyT3gs)
```bash
Expand Down
1 change: 1 addition & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
_C.DATASETS.VAL = ()
# List of the dataset names for testing, as present in paths_catalog.py
_C.DATASETS.TEST = ()
_C.DATASETS.TO_TEST = None

# -----------------------------------------------------------------------------
# DataLoader
Expand Down
4 changes: 3 additions & 1 deletion maskrcnn_benchmark/config/paths_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@


class DatasetCatalog(object):
DATA_DIR = "datasets"
#DATA_DIR = "/home/users/alatif/data/ImageCorpora/"
DATA_DIR = "/media/rafi/Samsung_T5/_DATASETS/"
DATASETS = {
"coco_2017_train": {
"img_dir": "coco/train2017",
Expand Down Expand Up @@ -116,6 +117,7 @@ class DatasetCatalog(object):
"roidb_file": "vg/VG-SGG-with-attri.h5",
"dict_file": "vg/VG-SGG-dicts-with-attri.json",
"image_file": "vg/image_data.json",
"capgraphs_file": "vg/vg_capgraphs_anno.json",
},
}

Expand Down
22 changes: 19 additions & 3 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ def get_dataset_statistics(cfg):
logger.info('Loading data statistics from: ' + str(save_file))
logger.info('-'*100)
return torch.load(save_file, map_location=torch.device("cpu"))
else:
logger.info('Unable to load data statistics from: ' + str(save_file))

statistics = []
for dataset_name in dataset_names:
data = DatasetCatalog.get(dataset_name, cfg)
factory = getattr(D, data["factory"])
args = data["args"]
# Remove it because not part of the original repo (factory cant deal with additional parameters...).
if "capgraphs_file" in args.keys():
del args["capgraphs_file"]
dataset = factory(**args)
statistics.append(dataset.get_statistics())
logger.info('finish')
Expand Down Expand Up @@ -89,6 +94,11 @@ def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True)
if data["factory"] == "PascalVOCDataset":
args["use_difficult"] = not is_train
args["transforms"] = transforms

#Remove it because not part of the original repo (factory cant deal with additional parameters...).
if "capgraphs_file" in args.keys():
del args["capgraphs_file"]

# make dataset from factory
dataset = factory(**args)
datasets.append(dataset)
Expand Down Expand Up @@ -153,8 +163,14 @@ def make_batch_data_sampler(
return batch_sampler


def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0):
def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0, dataset_to_test=None):
assert mode in {'train', 'val', 'test'}
assert dataset_to_test in {'train', 'val', 'test', None}
# this variable enable to run a test on any data split, even on the training dataset
# without actually flagging it for training....
if dataset_to_test is None:
dataset_to_test = mode

num_gpus = get_world_size()
is_train = mode == 'train'
if is_train:
Expand Down Expand Up @@ -199,9 +215,9 @@ def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0):
"maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
)
DatasetCatalog = paths_catalog.DatasetCatalog
if mode == 'train':
if dataset_to_test == 'train':
dataset_list = cfg.DATASETS.TRAIN
elif mode == 'val':
elif dataset_to_test == 'val':
dataset_list = cfg.DATASETS.VAL
else:
dataset_list = cfg.DATASETS.TEST
Expand Down
15 changes: 11 additions & 4 deletions maskrcnn_benchmark/data/datasets/visual_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,17 @@ def get_statistics(self):
def get_custom_imgs(self, path):
self.custom_files = []
self.img_info = []
for file_name in os.listdir(path):
self.custom_files.append(os.path.join(path, file_name))
img = Image.open(os.path.join(path, file_name)).convert("RGB")
self.img_info.append({'width':int(img.width), 'height':int(img.height)})
if os.path.isdir(path):
for file_name in tqdm(os.listdir(path)):
self.custom_files.append(os.path.join(path, file_name))
img = Image.open(os.path.join(path, file_name)).convert("RGB")
self.img_info.append({'width':int(img.width), 'height':int(img.height)})
# Expecting a list of paths in a json file
if os.path.isfile(path):
file_list = json.load(open(path))
for file in tqdm(file_list):
img = Image.open(file).convert("RGB")
self.img_info.append({'width': int(img.width), 'height': int(img.height)})

def get_img_info(self, index):
# WARNING: original image_file.json has several pictures with false image size
Expand Down
60 changes: 56 additions & 4 deletions maskrcnn_benchmark/image_retrieval/S2G-RETRIEVAL.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,70 @@
# Sentence-to-Graph Retrieval (S2G)

Forgive me, this part of code is ugly and less organized.
Warning: this part of the code is less organized than the rest of the repository.

## Preprocessing

Run the ```maskrcnn_benchmark/image_retrieval/preprocessing.py``` to process the annotations and checkpoints, where ```detected_path``` should be set to the corresponding checkpoints you want to use, ```vg_data, vg_dict, vg_info``` should have already downloaded if you followed DATASET.md, ```cap_graph``` is the ground-truth captions and generated sentence graphs (you can download it from [here](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21779999&authkey=AGW0Wxjb1JSDFnc)). We use [SceneGraphParser](https://github.com/vacancy/SceneGraphParser) to generate these sentence graphs.
Prerequisite: ```vg_data, vg_dict, vg_info``` should already have been downloaded if you followed DATASET.md.

You also need to set the ```cap_graph``` PATH and ```vg_dict``` PATH in ```maskrcnn_benchmark/image_retrieval/dataloader.py``` manually.
You will also need a pre-trained SGDet model, for example from [here](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21781947&authkey=AF_EM-rkbMyT3gs). This is the SGDet model that is described in the main `README.md`.

Download the ground-truth captions and generated sentence graphs from [here](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21779999&authkey=AGW0Wxjb1JSDFnc).

Please note that the path to this file needs to be configured properly in `maskrcnn_benchmark/config/paths_catalog.py`: see `DATASETS`, entry `VG_stanford_filtered_with_attribute`, key `capgraphs_file`.

We used [SceneGraphParser](https://github.com/vacancy/SceneGraphParser) to generate these sentence graphs.
The script ```maskrcnn_benchmark/image_retrieval/sentence_to_graph_processing.py``` partially shows how the text scene graphs were generated (see the key `vg_coco_id_to_capgraphs` in the downloaded generated sentence graphs file).


Create the test results of the SGDet model for the training and test datasets with:

```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE TDE MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/causal-motifs-sgdet OUTPUT_DIR /home/kaihua/checkpoints/causal-motifs-sgdet DATASETS.TO_TEST train
```

```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE TDE MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/causal-motifs-sgdet OUTPUT_DIR /home/kaihua/checkpoints/causal-motifs-sgdet DATASETS.TO_TEST test
```

These commands create the directories `VG_stanford_filtered_with_attribute_train` and `VG_stanford_filtered_with_attribute_test`, containing the saved results, under `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/`.

Now run ```maskrcnn_benchmark/image_retrieval/preprocessing.py --test-results-path your-result-path --output-file-name outfile.json``` once for the training results and once for the testing results produced in the previous step.

You should obtain two files:

`/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_train/sg_of_causal_sgdet_ctx_only.json`

and

`/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_test/sg_of_causal_sgdet_ctx_only.json`

## Training and Evaluation

You need to manually set ```sg_train_path```, ```sg_val_path``` and ```sg_test_path``` in ```tools/image_retrieval_main.py``` to `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_train/sg_of_causal_sgdet_ctx_only.json`
, `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_val/sg_of_causal_sgdet_ctx_only.json`
and

`/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_test/sg_of_causal_sgdet_ctx_only.json` respectively.


If you use your own pretrained model, keep in mind that you need to evaluate your model on the **training, validation and testing sets** to get the generated crude scene graphs. Our evaluation code will automatically save the crude SGGs into ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_with_attribute_test/```, ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_with_attribute_train/```
or ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_with_attribute_val/```.



Run the ```tools/image_retrieval_main.py``` for both training and evaluation.

To load the generated scene graphs of the given SGG checkpoints, you need to manually set ```sg_train_path``` and ```sg_test_path``` in ```tools/image_retrieval_main.py```, which means you need to evaluate your model on **both training and testing set** to get the generated crude scene graphs. Our evaluation code will automatically saves the crude SGGs into ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_wth_attribute_test/``` or ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_wth_attribute_train/```, which will be further processed to generate the input of ```sg_train_path``` and ```sg_test_path``` by our preprocessing code ```maskrcnn_benchmark/image_retrieval/preprocessing.py```.
For example, you can train it with:

```tools/image_retrieval_main.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 32 SOLVER.PRE_VAL True SOLVER.SCHEDULE.TYPE WarmupMultiStepLR SOLVER.MAX_ITER 18 SOLVER.CHECKPOINT_PERIOD 3 OUTPUT_DIR /media/rafi/Samsung_T5/_DATASETS/vg/model/ SOLVER.VAL_PERIOD 3```

You can also run an evaluation on any set (parameter `DATASETS.TO_TEST`) with:

```tools/image_retrieval_test.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 32 MODEL.PRETRAINED_DETECTOR_CKPT /media/rafi/Samsung_T5/_DATASETS/vg/model/[your_model_name].pytorch OUTPUT_DIR /media/rafi/Samsung_T5/_DATASETS/vg/model/results DATASETS.TO_TEST test```

Please note that the calculation logic differs from the one used in ```tools/image_retrieval_main.py```.
Details of the calculation can be found under ```Test Cases Metrics.pdf```, under the Type Fei Fei.


## Results

Expand Down
Binary file not shown.
12 changes: 10 additions & 2 deletions maskrcnn_benchmark/image_retrieval/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tqdm import tqdm

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
Expand All @@ -30,13 +31,20 @@
from maskrcnn_benchmark.utils.logger import setup_logger, debug_print
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
import os

class SGEncoding(data.Dataset):
""" SGEncoding dataset """
def __init__(self, train_ids, test_ids, sg_data, test_on=False, val_on=False, num_test=5000, num_val=5000):
super(SGEncoding, self).__init__()
cap_graph = json.load(open('/data1/vg_capgraphs_anno.json'))
vg_dict = json.load(open('/home/kaihua/projects/maskrcnn-benchmark/datasets/vg/VG-SGG-dicts-with-attri.json'))

data_dir = DatasetCatalog.DATA_DIR
attrs = DatasetCatalog.DATASETS["VG_stanford_filtered_with_attribute"]
cap_graph_file = os.path.join(data_dir, attrs["capgraphs_file"])
vg_dict_file = os.path.join(data_dir, attrs["dict_file"])

cap_graph = json.load(open(cap_graph_file))
vg_dict = json.load(open(vg_dict_file))
self.img_txt_sg = sg_data
self.key_list = list(self.img_txt_sg.keys())
self.key_list.sort()
Expand Down
Loading

0 comments on commit 634a6e2

Please sign in to comment.