-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
239 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Freeze module | ||
|
||
This page guide users to freeze module in YOLOX. | ||
Exp controls everything in YOLOX, so let's start from creating an Exp object. | ||
|
||
## 1. Create your own expermiment object | ||
|
||
We take an example of YOLOX-S model on COCO dataset to give a more clear guide. | ||
|
||
Import the config you want (or write your own Exp object inherit from `yolox.exp.BaseExp`). | ||
```python | ||
from yolox.exp.default.yolox_s import Exp as MyExp | ||
``` | ||
|
||
## 2. Override `get_model` method | ||
|
||
Here is a simple code to freeze backbone (FPN not included) of module. | ||
```python | ||
class Exp(MyExp): | ||
|
||
def get_model(self): | ||
from yolox.utils import freeze_module | ||
model = super().get_model() | ||
freeze_module(model.backbone.backbone) | ||
return model | ||
``` | ||
if you only want to freeze FPN, `freeze_module(model.backbone)` might help. | ||
|
||
## 3. Train | ||
Suppose that the path of your Exp is `/path/to/my_exp.py`, use the following command to train your model. | ||
```bash | ||
python3 -m yolox.tools.train -f /path/to/my_exp.py | ||
``` | ||
For more details of training, run the following command. | ||
```bash | ||
python3 -m yolox.tools.train --help | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding:utf-8 -*- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding:utf-8 -*- | ||
# Copyright (c) Megvii, Inc. and its affiliates. | ||
|
||
import unittest | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from yolox.utils import adjust_status, freeze_module | ||
from yolox.exp import get_exp | ||
|
||
|
||
class TestModelUtils(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.model: nn.Module = get_exp(exp_name="yolox-s").get_model() | ||
|
||
def test_model_state_adjust_status(self): | ||
data = torch.ones(1, 10, 10, 10) | ||
# use bn since bn changes state during train/val | ||
model = nn.BatchNorm2d(10) | ||
prev_state = model.state_dict() | ||
|
||
modes = [False, True] | ||
results = [True, False] | ||
|
||
# test under train/eval mode | ||
for mode, result in zip(modes, results): | ||
with adjust_status(model, training=mode): | ||
model(data) | ||
model_state = model.state_dict() | ||
self.assertTrue(len(model_state) == len(prev_state)) | ||
self.assertEqual( | ||
result, | ||
all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()]) | ||
) | ||
|
||
# test recurrsive context case | ||
prev_state = model.state_dict() | ||
with adjust_status(model, training=False): | ||
with adjust_status(model, training=False): | ||
model(data) | ||
model_state = model.state_dict() | ||
self.assertTrue(len(model_state) == len(prev_state)) | ||
self.assertTrue( | ||
all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()]) | ||
) | ||
|
||
def test_model_effect_adjust_status(self): | ||
# test context effect | ||
self.model.train() | ||
with adjust_status(self.model, training=False): | ||
for module in self.model.modules(): | ||
self.assertFalse(module.training) | ||
# all training after exit | ||
for module in self.model.modules(): | ||
self.assertTrue(module.training) | ||
|
||
# only backbone set to eval | ||
self.model.backbone.eval() | ||
with adjust_status(self.model, training=False): | ||
for module in self.model.modules(): | ||
self.assertFalse(module.training) | ||
|
||
for name, module in self.model.named_modules(): | ||
if "backbone" in name: | ||
self.assertFalse(module.training) | ||
else: | ||
self.assertTrue(module.training) | ||
|
||
def test_freeze_module(self): | ||
model = nn.Sequential( | ||
nn.Conv2d(3, 10, 1), | ||
nn.BatchNorm2d(10), | ||
nn.ReLU(), | ||
) | ||
data = torch.rand(1, 3, 10, 10) | ||
model.train() | ||
assert isinstance(model[1], nn.BatchNorm2d) | ||
before_states = model[1].state_dict() | ||
freeze_module(model[1]) | ||
model(data) | ||
after_states = model[1].state_dict() | ||
self.assertTrue( | ||
all([torch.allclose(v, after_states[k]) for k, v in before_states.items()]) | ||
) | ||
|
||
# yolox test | ||
self.model.train() | ||
for module in self.model.modules(): | ||
self.assertTrue(module.training) | ||
|
||
freeze_module(self.model, "backbone") | ||
for module in self.model.backbone.modules(): | ||
self.assertFalse(module.training) | ||
for p in self.model.backbone.parameters(): | ||
self.assertFalse(p.requires_grad) | ||
|
||
for module in self.model.head.modules(): | ||
self.assertTrue(module.training) | ||
for p in self.model.head.parameters(): | ||
self.assertTrue(p.requires_grad) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters