Refactor effdet dataloaders wrapper and fix the image sizes passed to effdet #630

Merged · 11 commits · Feb 10, 2021
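For orientation, `train_dl`, `valid_dl`, and `infer_dl` below are the public wrappers this PR touches. A minimal usage sketch of how they are typically driven (the `train_ds`/`valid_ds` `Dataset` objects and the dataloader kwargs are assumptions for illustration, not part of this diff):

```python
from icevision.models.ross.efficientdet.dataloaders import train_dl, valid_dl

# train_ds / valid_ds are assumed to be icevision Datasets built elsewhere
train_loader = train_dl(train_ds, batch_size=8, num_workers=4, shuffle=True)
valid_loader = valid_dl(valid_ds, batch_size=8, num_workers=4, shuffle=False)

# every batch comes out as ((batch_images, targets), records)
(batch_images, targets), records = next(iter(train_loader))
print(batch_images.shape)  # e.g. torch.Size([8, 3, 512, 512])
print(targets.keys())      # dict_keys(['bbox', 'cls'])
```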
102 changes: 62 additions & 40 deletions icevision/models/ross/efficientdet/dataloaders.py
@@ -27,7 +27,7 @@ def train_dl(dataset, batch_tfms=None, **dataloader_kwargs) -> DataLoader:
         dataset=dataset,
         build_batch=build_train_batch,
         batch_tfms=batch_tfms,
-        **dataloader_kwargs
+        **dataloader_kwargs,
     )


@@ -47,7 +47,7 @@ def valid_dl(dataset, batch_tfms=None, **dataloader_kwargs) -> DataLoader:
         dataset=dataset,
         build_batch=build_valid_batch,
         batch_tfms=batch_tfms,
-        **dataloader_kwargs
+        **dataloader_kwargs,
     )


@@ -67,7 +67,7 @@ def infer_dl(dataset, batch_tfms=None, **dataloader_kwargs) -> DataLoader:
         dataset=dataset,
         build_batch=build_infer_batch,
         batch_tfms=batch_tfms,
-        **dataloader_kwargs
+        **dataloader_kwargs,
    )


@@ -92,26 +92,22 @@ def build_train_batch(records, batch_tfms=None):
     ```
     """
     records = common_build_batch(records, batch_tfms=batch_tfms)
+    batch_images, batch_bboxes, batch_classes = zip(
+        *(process_train_record(record) for record in records)
+    )

-    images = []
-    targets = {"bbox": [], "cls": []}
-    for record in records:
-        image = im2tensor(record["img"])
-        images.append(image)
-
-        if len(record["labels"]) == 0:
-            targets["cls"].append(tensor([0], dtype=torch.float))
-            targets["bbox"].append(tensor([[0, 0, 0, 0]], dtype=torch.float))
-        else:
-            labels = tensor(record["labels"], dtype=torch.float)
-            targets["cls"].append(labels)
-
-            bboxes = tensor([bbox.yxyx for bbox in record["bboxes"]], dtype=torch.float)
-            targets["bbox"].append(bboxes)
+    # convert to tensors
+    batch_images = torch.stack(batch_images)
+    batch_bboxes = [tensor(bboxes, dtype=torch.float32) for bboxes in batch_bboxes]
Collaborator review comment on the explicit `dtype=torch.float32` conversion above: “It's a bit more verbose, but I like the fact that we're passing dtype explicitly here, especially because, as you pointed out, effdet complains if the dtype is not what it expects.”

+    batch_classes = [tensor(classes, dtype=torch.float32) for classes in batch_classes]

-    images = torch.stack(images)
+    # convert to EffDet interface
+    targets = dict(
+        bbox=batch_bboxes,
+        cls=batch_classes,
+    )

-    return (images, targets), records
+    return (batch_images, targets), records
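The reviewer's point about passing `dtype` explicitly is easy to reproduce outside effdet: box coordinates often arrive as NumPy `float64`, `tensor(...)` preserves that dtype, and mixing it with the model's `float32` outputs fails. A minimal sketch of the failure mode (the matmul simply stands in for whatever op inside effdet first mixes the two dtypes):

```python
import numpy as np
import torch

boxes_np = np.array([[10, 20, 50, 60]], dtype=np.float64)  # a yxyx box as it may come from a record

implicit = torch.tensor(boxes_np)                       # stays float64
explicit = torch.tensor(boxes_np, dtype=torch.float32)  # what the dataloader now forces

print(implicit.dtype, explicit.dtype)  # torch.float64 torch.float32

preds = torch.rand(1, 4)   # float32, like the model's outputs
preds @ explicit.T         # works
try:
    preds @ implicit.T     # mixed float32/float64 matmul
except RuntimeError as err:
    print("dtype mismatch:", err)
```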


def build_valid_batch(records, batch_tfms=None):
@@ -122,7 +118,7 @@ def build_valid_batch(records, batch_tfms=None):
         batch_tfms: Transforms to be applied at the batch level.

     # Returns
-        A tuple with two items. The first will be a tuple like `(images, targets)`,
+        A tuple with two items. The first will be a tuple like `(batch_images, targets)`,
         in the input format required by the model. The second will be an updated list
         of the input records with `batch_tfms` applied.

@@ -134,19 +130,19 @@
     outs = model(*batch)
     ```
     """
-    (images, targets), records = build_train_batch(
-        records=records, batch_tfms=batch_tfms
-    )
-
-    img_sizes = [(r["height"], r["width"]) for r in records]
-    targets["img_size"] = tensor(img_sizes, dtype=torch.float)
-    targets["img_scale"] = tensor([1] * len(records), dtype=torch.float)
+    (batch_images, targets), records = build_train_batch(records, batch_tfms)

+    # convert to EffDet interface; when not training, dummy img_size and img_scale entries are required
+    targets = dict(
+        img_size=None,
+        img_scale=None,
+        **targets,
+    )

-    return (images, targets), records
+    return (batch_images, targets), records


-def build_infer_batch(dataset, batch_tfms=None):
+def build_infer_batch(records, batch_tfms=None):
     """Builds a batch in the format required by the model when doing inference.

     # Arguments
@@ -163,16 +159,42 @@ def build_infer_batch(dataset, batch_tfms=None):
     outs = model(*batch)
     ```
     """
-    samples = common_build_batch(dataset, batch_tfms=batch_tfms)
-
-    tensor_imgs, img_sizes = [], []
-    for record in samples:
-        tensor_imgs.append(im2tensor(record["img"]))
-        img_sizes.append((record["height"], record["width"]))
-
-    tensor_imgs = torch.stack(tensor_imgs)
-    tensor_sizes = tensor(img_sizes, dtype=torch.float)
-    tensor_scales = tensor([1] * len(samples), dtype=torch.float)
-    img_info = {"img_size": tensor_sizes, "img_scale": tensor_scales}
-
-    return (tensor_imgs, img_info), samples
+    records = common_build_batch(records, batch_tfms=batch_tfms)
+    batch_images, batch_sizes, batch_scales = zip(
+        *(process_infer_record(record) for record in records)
+    )
+
+    # convert to tensors
+    batch_images = torch.stack(batch_images)
+    batch_sizes = tensor(batch_sizes, dtype=torch.float32)
+    batch_scales = tensor(batch_scales, dtype=torch.float32)
+
+    # convert to EffDet interface
+    targets = dict(
+        img_size=batch_sizes,
+        img_scale=batch_scales,
+    )
+
+    return (batch_images, targets), records
+
+
+def process_train_record(record) -> tuple:
+    """Extracts information from a record and prepares it in the format required for EffDet training"""
+    image = im2tensor(record["img"])
+    # background class and a dummy bbox if the record has no labels
+    classes = record["labels"] if record["labels"] else [0]
+    bboxes = (
+        [bbox.yxyx for bbox in record["bboxes"]]
+        if len(record["labels"]) > 0
+        else [[0, 0, 0, 0]]
+    )
+    return image, bboxes, classes
+
+
+def process_infer_record(record) -> tuple:
+    """Extracts information from a record and prepares it in the format required for EffDet inference"""
+    image = im2tensor(record["img"])
+    image_size = image.shape[-2:]
+    image_scale = 1.0
+
+    return image, image_size, image_scale
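Taken together, the refactor gives the three builders a uniform shape. A rough mock of the structures they hand to effdet for a two-record batch, written with plain tensors so it runs without icevision (image sizes and box values are made up; the key structure — `bbox`/`cls` always present, `img_size`/`img_scale` set to `None` for validation and to real tensors for inference — mirrors the code above):

```python
import torch
from torch import tensor

# two fake images standing in for im2tensor(record["img"])
batch_images = torch.stack([torch.rand(3, 64, 64) for _ in range(2)])

# build_train_batch: per-image yxyx boxes and class ids; record 2 had no labels,
# so it gets the background/dummy entries produced by process_train_record
train_targets = dict(
    bbox=[tensor([[10.0, 10.0, 40.0, 40.0]]), tensor([[0.0, 0.0, 0.0, 0.0]])],
    cls=[tensor([1.0]), tensor([0.0])],
)

# build_valid_batch: same targets plus dummy size/scale entries
valid_targets = dict(img_size=None, img_scale=None, **train_targets)

# build_infer_batch: real sizes taken from image.shape[-2:] and unit scales
infer_targets = dict(
    img_size=tensor([[64.0, 64.0], [64.0, 64.0]]),
    img_scale=tensor([1.0, 1.0]),
)

print(sorted(valid_targets))            # ['bbox', 'cls', 'img_scale', 'img_size']
print(infer_targets["img_size"].dtype)  # torch.float32
```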
9 changes: 4 additions & 5 deletions tests/models/efficient_det/test_dataloaders.py
@@ -55,11 +55,8 @@ def _test_batch_valid(images, targets):

     assert set(targets.keys()) == {"cls", "bbox", "img_size", "img_scale"}

-    assert targets["img_scale"].dtype == torch.float
-    assert torch.all(targets["img_scale"] == tensor([1, 1]))
-
-    assert targets["img_size"].dtype == torch.float
-    assert torch.all(targets["img_size"] == tensor([[4, 4], [4, 4]]))
+    assert targets["img_scale"] is None
+    assert targets["img_size"] is None


def test_efficient_det_build_train_batch(records):
@@ -99,6 +96,8 @@ def test_efficient_det_build_infer_batch(img, batch_tfms):
     img_info = {"img_size": img_sizes, "img_scale": img_scales}

     batch_imgs, batch_info = batch
+
+    assert set(batch_info.keys()) == {"img_size", "img_scale"}
     assert torch.equal(batch_imgs, tensor_img)
     assert torch.equal(batch_info["img_size"], img_info["img_size"])
     assert torch.equal(batch_info["img_scale"], img_info["img_scale"])