Add script for finding optimal anchor shapes #59

Open · wants to merge 2 commits into base: master
Conversation

mihaimartalogu

Hi @BichenWuUCB,
Many thanks for publishing the sources for your model.

I'm working on fitting it to my own dataset (with images shaped more like smartphone-camera output), and one of the things I needed to do was adapt the anchor shapes.

Here I'm sharing the script I used for finding the optimal anchor sizes.
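
(For readers new to the idea: the script's core step is clustering the ground-truth box widths and heights with k-means and taking the cluster centres as anchor shapes. Below is a minimal sketch of that idea using scikit-learn rather than the script's own code; find_anchor_shapes and its parameters are illustrative only.)

import numpy as np
from sklearn.cluster import KMeans

def find_anchor_shapes(box_sizes, num_anchors=9, seed=0):
    """box_sizes: (N, 2) array of ground-truth (width, height) pairs in pixels."""
    km = KMeans(n_clusters=num_anchors, random_state=seed, n_init=10)
    km.fit(np.asarray(box_sizes, dtype=np.float64))
    # Each cluster centre becomes one anchor shape (width, height).
    return np.round(km.cluster_centers_, 2)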

Note, however, that for the KITTI dataset I don't get the same results at all as the ones in the repository:
In config/kitti_res50_config.py:

    [[  94.,  49.], [ 225., 161.], [ 170.,  91.],
     [ 390., 181.], [  41.,  32.], [ 128.,  64.],
     [ 298., 164.], [ 232.,  99.], [  65.,  42.]])]

In config/kitti_squeezeDet_config.py:

    [[  36.,  37.], [ 366., 174.], [ 115.,  59.],
     [ 162.,  87.], [  38.,  90.], [ 258., 173.],
     [ 224., 108.], [  78., 170.], [  72.,  43.]])]

What I get instead is:

$ python scripts/kmeans_anchors.py --geometry 1248x384 --kmeans-max-iter 1000000
...
[[70.45, 41.96], [390.36, 165.54], [125.10, 66.58],
[98.62, 186.29], [29.57, 26.18], [43.72, 94.05],
[356.65, 339.55], [269.63, 170.46], [198.84, 101.97]]

[Screenshot attached: 2017-06-16 17:52:44]

Note, for example, that the "wide and short" sizes don't get an anchor of their own, but maybe that's fair... Have I missed something? Could you have a look, and consider merging it if the implementation is correct? (I'm new to the field of machine learning, so I'm not 100% confident.)

Cheers,
Mihai


@ilystsov left a comment


I've added a docstring.
Removed the comment about Python 2, since it's outdated and not relevant in modern Python.
Replaced the nonlocals dictionary with the more modern nonlocal keyword available in Python 3.
Simplified the method by removing unnecessary checks and updating the tqdm progress bars.
Used os.path.join for creating paths.
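
For reference, the pattern that the nonlocal change enables looks roughly like this (a standalone illustration, not code from this PR):

def make_counter():
    entries_done = 0

    def entry_done(_future=None):
        nonlocal entries_done  # rebind the enclosing variable directly (Python 3)
        entries_done += 1
        return entries_done

    return entry_done

count = make_counter()
count()  # -> 1
count()  # -> 2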

Comment on lines +50 to +81
def get_dataset_metadata(dataset_root, input_w, input_h, max_jobs):
    """
    Load all dataset metadata into memory. You might need to adapt this if your dataset is really huge.
    """
    nonlocals = { # Python 2 doesn't support nonlocal, using a mutable dict() instead
        'entries_done': 0,
        'metadata': dict(),
        'entries_done_pbar': None
    }
    with open(os.path.join(dataset_root, 'ImageSets', 'trainval.txt')) as f:
        dataset_entries = f.read().splitlines()
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as pool:

        for entry in tqdm(dataset_entries, desc='Scheduling jobs'):
            if nonlocals['entries_done_pbar'] is None:
                # instantiating here so that it appears after the 'Scheduling jobs' one
                nonlocals['entries_done_pbar'] = tqdm(total=len(dataset_entries), desc='Retrieving metadata')

            def entry_done(future):
                """ Record progress """
                nonlocals['entries_done'] += 1
                nonlocals['entries_done_pbar'].update(1)
                fr = future.result()
                if fr is not None:
                    local_entry, value = fr # do NOT use the entry variable from the scope!
                    nonlocals['metadata'][local_entry] = value

            future = pool.submit(get_entry_metadata, dataset_root, entry, input_w, input_h)
            future.add_done_callback(entry_done) # FIXME: doesn't work if chained directly to submit(). bug in futures? reproduce and submit report.
    nonlocals['entries_done_pbar'].close()
    assert len(nonlocals['metadata'].values()) >= 0.9 * len(dataset_entries) # catch if entry_done doesn't update the dict correctly
    return nonlocals['metadata']


Suggested change
def get_dataset_metadata(dataset_root, input_w, input_h, max_jobs):
    """
    Load all dataset metadata into memory.

    Args:
        - dataset_root (str): The root directory of the dataset.
        - input_w (int): Width of the input.
        - input_h (int): Height of the input.
        - max_jobs (int): Maximum number of concurrent processes to use.

    Returns:
        - dict: Metadata for the dataset.
    """
    entries_done = 0
    metadata = {}
    with open(os.path.join(dataset_root, 'ImageSets', 'trainval.txt')) as f:
        dataset_entries = f.read().splitlines()
    with tqdm(total=len(dataset_entries), desc='Retrieving metadata') as entries_done_pbar:
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as pool:
            futures = [pool.submit(get_entry_metadata, dataset_root, entry, input_w, input_h)
                       for entry in dataset_entries]
            for future in concurrent.futures.as_completed(futures):
                fr = future.result()
                if fr is not None:
                    local_entry, value = fr
                    metadata[local_entry] = value
                entries_done += 1
                entries_done_pbar.update(1)
    assert len(metadata) >= 0.9 * len(dataset_entries), "metadata dict was not populated correctly"
    return metadata
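
A call to the suggested function might then look like the following; the dataset path and job count are placeholder values (the 1248x384 geometry matches the command above), and get_entry_metadata is the helper already defined in the script:

# Hypothetical usage; dataset_root must contain ImageSets/trainval.txt.
metadata = get_dataset_metadata(
    dataset_root='/data/KITTI',
    input_w=1248,
    input_h=384,
    max_jobs=8,
)
print(len(metadata), 'entries with metadata loaded')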
