
EfficientDet pretrained arch transfer learning #615

Closed

potipot opened this issue Jan 22, 2021 · 22 comments · Fixed by #630
Labels
documentation (Improvements or additions to documentation), example request, good first issue (Good for newcomers), help wanted (Extra attention is needed)

Comments

@potipot
Contributor

potipot commented Jan 22, 2021

📓 New <Tutorial/Example>

Request for an example

What is the task?
Object detection using transfer learning for the whole architecture. Are there defined methods to load a fastai model and change its head to a different number of classes, similar to this?

I was able to run the Faster-RCNN model from this example, trained on the COCO dataset, and evaluate its mAP.

The EfficientDet workflow does not seem to be ready yet. Has there been any update on that?

I was able to create an EfficientDet model with a pretrained encoder and train it myself on COCO. I'm now trying to do transfer learning for a different number of classes. Loading the model through fastai, as expected, throws an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-22-cf184f01fec3> in <module>
----> 1 learn.load('coco_local')

~/anaconda3/envs/icevision/lib/python3.8/site-packages/fastai/learner.py in load(self, file, with_opt, device, **kwargs)
    293         if self.opt is None: self.create_opt()
    294         file = join_path_file(file, self.path/self.model_dir, ext='.pth')
--> 295         load_model(file, self.model, self.opt, device=device, **kwargs)
    296         return self
    297 

~/anaconda3/envs/icevision/lib/python3.8/site-packages/fastai/learner.py in load_model(file, model, opt, with_opt, device, strict)
     47     hasopt = set(state)=={'model', 'opt'}
     48     model_state = state['model'] if hasopt else state
---> 49     get_model(model).load_state_dict(model_state, strict=strict)
     50     if hasopt and with_opt:
     51         try: opt.load_state_dict(state['opt'])

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1049 
   1050         if len(error_msgs) > 0:
-> 1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for DetBenchTrain:
	size mismatch for model.class_net.predict.conv_pw.weight: copying a param with shape torch.Size([819, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([99, 64, 1, 1]).
	size mismatch for model.class_net.predict.conv_pw.bias: copying a param with shape torch.Size([819]) from checkpoint, the shape in current model is torch.Size([99]).

Is this example for a specific model?
EfficientDet

Is this example for a specific dataset?
COCO transfer learning


Don't remove
Main issue for examples: #39

@potipot added the documentation, example request, good first issue, and help wanted labels on Jan 22, 2021
@lgvaz
Collaborator

lgvaz commented Jan 22, 2021

The EfficientDet workflow seems not to be yet ready.

Outdated comment 😅 It's going to be removed in upcoming commits

Actually, when you call efficientdet.model it already loads pretrained weights from COCO. This is what we grab by default; you can take a look there for all the sizes with pretrained weights.
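For reference, this is the call shape (same signature as in the script later in this thread; the pretrained default is an assumption here):

from icevision.all import *

# builds an EfficientDet with COCO-pretrained weights loaded by default
model = efficientdet.model(
    model_name="tf_efficientdet_lite0", num_classes=91, img_size=512
)  # 91 = 90 COCO classes + background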

@lgvaz
Collaborator

lgvaz commented Jan 22, 2021

If you want to load your own weights and change the head for fine-tuning, take a look here for insights on how to do it.

The key concept there is model.reset_head
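A minimal sketch of that flow, assuming effdet's create_model/reset_head API (the checkpoint path is hypothetical):

from effdet import create_model

# build the bench with the class count the checkpoint was trained with
bench = create_model(
    'tf_efficientdet_lite0', bench_task='train', num_classes=90,
    pretrained=False, checkpoint_path='models/coco_local.pth',
)
# then rebuild the classification/box heads for the new number of classes
bench.model.reset_head(num_classes=3)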

@potipot
Contributor Author

potipot commented Jan 22, 2021

Thanks for the update!

For now I wrote this workaround function to load pretrained weights with matching parameter names and shapes:

import torch
from fastai.vision.all import *  # provides @patch, Learner

@patch
def load_matching(self: Learner, model_name: str):
    "Copy over only the parameters whose names and shapes match the current model."
    this_model = self.model.state_dict()
    trained_model = torch.load(f'models/{model_name}.pth')['model']

    for (this_module, this_param), (loaded_module, loaded_param) in zip(this_model.items(), trained_model.items()):
        assert this_module == loaded_module, f'Models differ: {this_module}, {loaded_module}'
        if this_param.shape == loaded_param.shape:
            this_model[this_module] = loaded_param
        else:
            print(f'Weights not loaded: {this_module}: {this_param.shape=}, {loaded_param.shape=}')

    return self.model.load_state_dict(this_model)
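With the patch applied, restoring the compatible weights is then a one-liner (checkpoint name as in the traceback above):

learn.load_matching('coco_local')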

@potipot
Contributor Author

potipot commented Jan 25, 2021

The EfficientDet workflow seems not to be yet ready.

Outdated comment 😅 It's going to be removed in upcoming commits

Actually, when you call efficientdet.model it already loads pretrained weights from COCO. This is what we grab by default, you can take a look there for all the sizes with pretrained weights.

Any ETA on those commits? I was trying to reproduce the mAP results from the source you provided, but the problem is that, apart from loading the weights, the head gets modified and the parameter values don't match exactly:

class_net.predict.conv_pw.weight model_loaded.size=torch.Size([810, 64, 1, 1]) model_current.size=torch.Size([819, 64, 1, 1])
class_net.predict.conv_pw.bias model_loaded.size=torch.Size([810]) model_current.size=torch.Size([819])

I guess this is because we add background to the ClassMap: the head's pointwise conv predicts num_anchors × num_classes channels, so 9 × 90 = 810 without background vs 9 × 91 = 819 with it.

UPDATE 25.01.2021

Following your advice I forced the number of classes to be 90, and I can observe an improvement in COCOMetric (it was nearly all zeros before):
model: tf_efficientdet_lite0

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.122
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.227
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.113
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.203
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.177
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.258
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.275
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.114
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.409

This, however, is still not close to the reported value of 33.6 AP @[ IoU=0.50:0.95 | area=all ].

I've made sure that I'm using the same size and config parameters as in the original (padding color, normalization).

One thing that I'm still trying to figure out is the bbox configuration. tfms.A.Adapter uses a hardcoded Pascal VOC box ordering of xyxy, while a comment in this thread suggests the models were pretrained using yxyx.

I will try to change that parameter and re-run the validations tomorrow. (I'm running both fastai and PL validations, with the same results.)

Or perhaps you have some other suggestion on how to obtain the same mAP in validation?

@lgvaz
Collaborator

lgvaz commented Jan 26, 2021

a comment in this thread suggests the models were pretrained using yxyx.

Model-specific formatting happens inside the dataloader here (this is how we stay agnostic to any model implementation).

I've made sure that I'm using the same size and config parameters as in the original (padding color, normalization).

This is all being done on COCO, right? Can you share the code? I can take a look and we can figure out what is missing to achieve the reported results.

Actually, let's make this clear: are you training from scratch on COCO, or loading the model with the pre-trained weights and checking the mAP?

@lgvaz
Collaborator

lgvaz commented Jan 26, 2021

any ETA on those commits?

I meant a commit only to remove that comment 😅. Did you have something else in mind?

@potipot
Contributor Author

potipot commented Jan 26, 2021

Thanks for pointing to that place in the dataloader.
I'm trying to reproduce the results of a trained model to make sure I'm using transfer learning correctly.
Everything is done on the COCO dataset downloaded from the official link: 90 classes, coco.parser.

One thing that surprised me: I changed the order in the build_train_batch function (also used by build_valid_batch) from yxyx to xyxy and re-ran the evaluation, and the results are identical! The only difference is that the loss increases in the modified run.
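For reference, a hedged sketch of the kind of change described (the actual icevision dataloader code may differ; BBox exposing .yxyx/.xyxy properties is assumed):

# default: effdet expects boxes in yxyx order
targets["bbox"] = [tensor([b.yxyx for b in r["bboxes"]], dtype=torch.float) for r in records]
# experiment: feed xyxy instead
targets["bbox"] = [tensor([b.xyxy for b in r["bboxes"]], dtype=torch.float) for r in records]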

#default - yxyx
learn.validate()

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.122
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.227
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.113
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.203
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.177
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.258
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.275
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.114
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.409

(#2) [0.6418642997741699,0.12209737428903714] # ValLoss, COCOMetric
# modified - xyxy
learn.validate()

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.122
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.227
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.113
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.203
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.177
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.258
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.275
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.114
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.409

(#2) [2.570807933807373,0.12209737428903714] # ValLoss, COCOMetric

Maybe it's a problem with the COCOMetric again?
Another hypothesis is a difference in the ordering of classes. For example, passing class_map to

batch, samples = first(train_dl)
show_samples(
    samples[:6], class_map=class_map, ncols=3, denormalize_fn=denormalize_imagenet
)

throws an error: IndexError: list index out of range.
NOTE: COCO uses 1..90 numbering, while the class map default is 0..89?

My search continues...

Code:

from icevision.all import *
import icedata

from fastai.vision.all import *
from fastai.callback.wandb import *

import wandb

from imports import *
from fastcore.script import *  # Param, if not already provided by `imports`

path: Param("Training dataset path", str) = Path.home()/'Datasets/image/coco/'
bs: Param("Batch size", int) = 8
log: Param("Log to wandb", bool) = False
num_workers: Param("Number of workers to use", int) = 4
resume: Param("Link to pretrained model", str) = None
name: Param('experiment name', str) = 'coco'

class_map = icedata.coco.class_map(background=None)
path = Path(path)
coco_train = icedata.coco.parser(
    img_dir=path / 'train2017',
    annotations_file=path/'annotations/instances_train2017.json',
    mask=False)

coco_valid = icedata.coco.parser(
    img_dir=path / 'val2017',
    annotations_file=path/'annotations/instances_val2017.json',
    mask=False)

train_records, *_ = coco_train.parse(data_splitter=SingleSplitSplitter(), cache_filepath=path/'train_cache')
valid_records, *_ = coco_valid.parse(data_splitter=SingleSplitSplitter(), cache_filepath=path/'valid_cache')
show_record(train_records[1], display_label=True)

size = 512
aug_tfms = tfms.A.aug_tfms(
    size=size,
    shift_scale_rotate=tfms.A.ShiftScaleRotate(rotate_limit=(-15, 15)),
    pad=partial(tfms.A.PadIfNeeded, border_mode=0)
)
aug_tfms.append(tfms.A.Normalize())

train_tfms = tfms.A.Adapter(aug_tfms)
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])

train_ds = Dataset(train_records, train_tfms)
valid_ds = Dataset(valid_records, valid_tfms)

metrics = [COCOMetric(print_summary=True)]

train_dl = efficientdet.train_dl(train_ds, batch_size=bs, num_workers=num_workers, shuffle=True)
valid_dl = efficientdet.valid_dl(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)

model = efficientdet.model(model_name="tf_efficientdet_lite0", num_classes=len(class_map), img_size=size)

learn = efficientdet.fastai.learner(dls=[train_dl, valid_dl], model=model, metrics=metrics)
#default - yxyx
learn.validate()

@potipot
Contributor Author

potipot commented Jan 27, 2021

I'm thinking this is something specific to EffDet + the COCO dataset, because I was able to train the EffDet model using the icevision workflow and obtain good mAP results.

I calculated mAP from the preds and records accumulated in the metric (before conversion to the COCO API) using another library and got nearly the same results. Here is how I did it:

# Imports assume the package layout of the metrics library linked above
# (rafaelpadilla/review_object_detection_metrics); adjust to your install.
from src.bounding_box import BoundingBox
from src.utils.enumerators import BBType
from src.evaluators.coco_evaluator import get_coco_summary

def raw_to_odm(preds, records):
    groundtruth_bbs = []
    detected_bbs = []
    for pred, record in zip(preds, records):
        image_name = record['filepath'].name
        # ground-truth boxes from the records
        for bbox, label in zip(record['bboxes'], record['labels']):
            bb = BoundingBox(image_name=image_name, class_id=label, coordinates=bbox.xywh)
            groundtruth_bbs.append(bb)
        # predicted boxes, with confidence scores
        for score, bbox, label in zip(pred['scores'], pred['bboxes'], pred['labels']):
            bb = BoundingBox(image_name=image_name, class_id=label, coordinates=bbox.xywh,
                             bb_type=BBType.DETECTED, confidence=score)
            detected_bbs.append(bb)
    return groundtruth_bbs, detected_bbs

groundtruth_bbs, detected_bbs = raw_to_odm(preds, records)
get_coco_summary(groundtruth_bbs, detected_bbs)
{'AP': 0.11017058248779392,
 'AP50': 0.2117803965667843,
 'AP75': 0.09722278750266831,
 'APsmall': 0.005122517036826111,
 'APmedium': 0.1955153436184948,
 'APlarge': 0.26638203179285247,
 'AR1': 0.15844023950095418,
 'AR10': 0.22851812631946697,
 'AR100': 0.2428790814116496,
 'ARsmall': 0.07777300144455561,
 'ARmedium': 0.33974585757257,
 'ARlarge': 0.37168434508668985}

I'm still investigating the Effdet repo to see what kind of tweaking they do to achieve these mAP results. Any help is welcome!

@potipot
Contributor Author

potipot commented Jan 27, 2021

Bingo! I figured it out.

When passing targets to the EfficientDet model, icevision preserves the original image sizes, and those values are processed by effdet itself (the padding is not remembered). I forced those values to be 512 (matching the padding I use) and the metric now shows correct results:

Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.311
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.489
Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.328
Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.095
Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.356
Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.503
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.270
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.418
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.442
Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.177
Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.523
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.659

This also explains why I was able to train on my own Dataset with aspect_ratio 1 and no resize or padding used in the tfms.

    def forward(self, *args, **kwargs):
        # hack: force the img_size effdet receives to the padded input size (512)
        args[1]['img_size'] = torch.full_like(args[1]['img_size'], 512.0)
        return self.model(*args, **kwargs)
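For context, a minimal sketch of the wrapper this forward belongs to (hypothetical class; assumes the bench takes (images, targets) like DetBenchTrain):

import torch
from torch import nn

class FixedSizeBench(nn.Module):
    # wraps a detection bench and overrides the reported image size with the padded size
    def __init__(self, bench, size=512):
        super().__init__()
        self.model = bench
        self.size = size

    def forward(self, *args, **kwargs):
        args[1]['img_size'] = torch.full_like(args[1]['img_size'], float(self.size))
        return self.model(*args, **kwargs)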

I will try to figure out a PR on how to fix this the right way.

@lgvaz
Collaborator

lgvaz commented Jan 27, 2021

Amazing work @potipot, I was going to investigate this right now but you already solved it! hahahah

So, if I understood correctly, the error is on this line? Here we are setting the image size without padding, but instead we need to set the image size with padding? If so, we just need to take the image size from the images tensor: images.shape[2:]

@lgvaz
Collaborator

lgvaz commented Jan 27, 2021

Another question, how was your experience using this other library for metrics? Is it a good replacement for pycocotools?

@potipot
Contributor Author

potipot commented Jan 27, 2021

So, if I understood correctly, the error is on this line? Here we are setting the image size without padding, but instead we need to set the image size with padding? If so, we just need to take the image size from the images tensor: images.shape[2:]

I think this would be the place to insert it and rely directly on the input tensor; however, what EffDet does is apply box scaling and resizing to map predictions back up to the original target image shape, and only then calculate the metric.
Note the difference in parameters passed to _batch_detection in train and validation.

I'm not sure how this improves the results they get on COCO. I guess it can affect the box size and whether a box is classified as small or large, but otherwise?

@potipot
Contributor Author

potipot commented Jan 27, 2021

Another question, how was your experience using this other library for metrics? Is it a good replacement for pycocotools?

It was way easier to implement, but I think slightly slower. The API is self-explanatory, with a single call to

get_coco_summary(groundtruth_bbs: list[BoundingBox], detected_bbs: list[BoundingBox])

I could take a look later on how to speed it up. I expect there is room for improvement.

@lgvaz
Collaborator

lgvaz commented Jan 27, 2021

I could take a look later on how to speed it up. I expect there is room for improvement.

Do you think there is value in us trying to implement these metrics ourselves?

@lgvaz
Collaborator

lgvaz commented Jan 27, 2021

Note the difference in parameters passed to _batch_detection in train and validation

I still have to look deeper into this, but here is a conversation I had with Ross that can be helpful.

img_scale is used to move coordinates between the original image space and what I think of as the 'model canvas', the img_size × img_size input of the model. Images are scaled down, preserving aspect ratio, to fit in that square, anchored at the origin (upper left corner); the rest is padded if the original image is not square. img_scale stores the ratio needed to move the model's output coordinates back to the original image coordinate space for the COCO evaluator. You can just set img_scale to 1 to not use it, and the img_size values (used to crop bboxes to (img_size, img_size)) if you want to handle the image sizing, scaling, and evaluation yourself.
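A small numeric illustration of that description (the values are made up):

# a 640x480 image resized, aspect preserved, into a 512x512 model canvas
orig_w, orig_h = 640, 480
canvas = 512
resize_ratio = canvas / max(orig_w, orig_h)   # 512 / 640 = 0.8
img_scale = 1.0 / resize_ratio                # 1.25

# a box predicted on the canvas (xyxy), mapped back to original coordinates
box_canvas = [100.0, 80.0, 300.0, 240.0]
box_original = [c * img_scale for c in box_canvas]  # [125.0, 100.0, 375.0, 300.0]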

@lgvaz
Collaborator

lgvaz commented Jan 27, 2021

Btw @potipot, have you joined our forum? If not, consider joining, we have lots of interesting discussions happening there =)

https://discord.gg/JDBeZYK

@lgvaz
Collaborator

lgvaz commented Jan 28, 2021

Note the difference in parameters passed to _batch_detection in train and validation

So, if we want to handle re-scaling ourselves, we have the option to set img_size and img_scale to None; from what I'm seeing, this will then not rescale the bboxes [1] [2].

Before we continue, can you try that and see if you still get the correct results?
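For reference, a sketch of that variant against the valid-batch builder shown later in this thread (whether effdet accepts None here is exactly what needs testing):

def build_valid_batch(records, batch_tfms=None):
    (images, targets), records = build_train_batch(records=records, batch_tfms=batch_tfms)
    # skip effdet's internal rescaling entirely and handle it ourselves
    targets["img_size"] = None
    targets["img_scale"] = None
    return (images, targets), records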

I'm not sure how this improves the results they get on COCO. I guess it can affect the box size and whether a box is classified as small or large, but otherwise?

Yeaaah, I'm still not quite sure as well. What we have to be careful about is that when we call CocoMetric.accumulate, records and preds should have the same img_sizes, with bboxes scaled accordingly. When effdet scales the predictions internally, it might be messing this up.

@potipot
Contributor Author

potipot commented Jan 28, 2021

Do you think there is value in us trying to implement these metrics ourselves?

I think yes, because the current conversion to the COCO API is quite convoluted; I had trouble understanding what was going on. Maybe we could try a more OO approach? The code would be much cleaner if we used the library I linked before.

@lgvaz
Collaborator

lgvaz commented Jan 28, 2021

Maybe we could try more OO approach?

For sure! The code for the conversion was implemented in a rush; a better way to implement this functionality would be to use the visitor pattern in the RecordMixins.

I opened this issue to keep track of that


Btw, the same strategy can be used for the transforms; currently tfms.A.Adapter is also quite a mess.
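A minimal, hypothetical sketch of the visitor idea (names are illustrative, not icevision's actual API):

class COCOAnnotationVisitor:
    def __init__(self):
        self.annotations = []

    def visit_bboxes(self, bboxes, labels):
        for (x, y, w, h), label in zip(bboxes, labels):
            self.annotations.append({"bbox": [x, y, w, h], "category_id": label})

class BBoxesRecordMixin:
    def __init__(self, bboxes, labels):
        self.bboxes, self.labels = bboxes, labels

    def accept(self, visitor):
        # each mixin converts only the fields it owns
        visitor.visit_bboxes(self.bboxes, self.labels)

record = BBoxesRecordMixin(bboxes=[(10, 20, 30, 40)], labels=[1])
visitor = COCOAnnotationVisitor()
record.accept(visitor)
print(visitor.annotations)  # [{'bbox': [10, 20, 30, 40], 'category_id': 1}]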

@lgvaz
Collaborator

lgvaz commented Jan 28, 2021

The code would be much cleaner if we used the library I linked before.

The only disadvantage is that it's a bit slower, right? Probably because it's fully implemented in Python, while pycocotools uses C.

@potipot
Contributor Author

potipot commented Jan 29, 2021

COCOMetric gives correct results for validation with this change:

def build_valid_batch(records, batch_tfms=None):
    (images, targets), records = build_train_batch(
        records=records, batch_tfms=batch_tfms
    )

-   img_sizes = [(r["height"], r["width"]) for r in records]
-   targets["img_size"] = tensor(img_sizes, dtype=torch.float)
+   # passing the size of transformed image to efficientdet, necessary for its own scaling and resizing, see
+   # https://github.com/rwightman/efficientdet-pytorch/blob/645d84a6f0cd837703f98f48179a06c354902515/effdet/bench.py#L100
+   targets["img_size"] = tensor([image.shape[-2:] for image in images], dtype=torch.float)
    targets["img_scale"] = tensor([1] * len(records), dtype=torch.float)

    return (images, targets), records

@potipot
Contributor Author

potipot commented Feb 2, 2021

Effdet inference will be resolved by #630
