Skip to content

Commit

Permalink
feat(utils): freeze module (#1156)
Browse files Browse the repository at this point in the history
* feat(utils): freeze module
  • Loading branch information
FateScript committed Mar 6, 2022
1 parent ac379df commit 5bbfc11
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f

* [Training on custom data](docs/train_custom_data.md)
* [Manipulating training image size](docs/manipulate_training_image_size.md)
* [Freezing model](docs/freeze_module.md)

</details>

Expand Down
37 changes: 37 additions & 0 deletions docs/freeze_module.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Freeze module

This page guides users through freezing modules in YOLOX.
Exp controls everything in YOLOX, so let's start from creating an Exp object.

## 1. Create your own experiment object

We use the YOLOX-S model on the COCO dataset as an example to give a clearer guide.

Import the config you want (or write your own Exp object inheriting from `yolox.exp.BaseExp`).
```python
from yolox.exp.default.yolox_s import Exp as MyExp
```

## 2. Override `get_model` method

Here is a simple example that freezes the backbone (FPN not included) of the model.
```python
class Exp(MyExp):
    """Experiment whose model is built with the backbone (FPN excluded) frozen."""

    def get_model(self):
        """Build the model as the parent Exp does, then freeze its inner backbone."""
        # Imported lazily so the Exp file only needs yolox.utils when a model is built.
        from yolox.utils import freeze_module

        frozen_model = super().get_model()
        # `backbone.backbone` is the feature extractor without the FPN.
        freeze_module(frozen_model.backbone.backbone)
        return frozen_model
```
If you also want to freeze the FPN (neck), `freeze_module(model.backbone)` might help: it freezes both the backbone and the FPN.

## 3. Train
Suppose that the path of your Exp is `/path/to/my_exp.py`, use the following command to train your model.
```bash
python3 -m yolox.tools.train -f /path/to/my_exp.py
```
For more details of training, run the following command.
```bash
python3 -m yolox.tools.train --help
```
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
107 changes: 107 additions & 0 deletions tests/utils/test_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import unittest

import torch
from torch import nn

from yolox.utils import adjust_status, freeze_module
from yolox.exp import get_exp


class TestModelUtils(unittest.TestCase):
    """Unit tests for ``adjust_status`` and ``freeze_module``."""

    def setUp(self):
        self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()

    @staticmethod
    def _snapshot_state(module):
        """Return an independent copy of ``module.state_dict()``.

        ``state_dict()`` is a shallow copy whose tensors reference the live
        parameters/buffers of the module, so in-place updates (e.g. BatchNorm
        running statistics) would silently change a "before" snapshot too.
        Cloning every tensor makes before/after comparisons meaningful.
        """
        return {k: v.clone() for k, v in module.state_dict().items()}

    @staticmethod
    def _states_match(expected, actual):
        """Return True if every tensor in ``expected`` is close to its peer in ``actual``."""
        return all(torch.allclose(v, actual[k]) for k, v in expected.items())

    def test_model_state_adjust_status(self):
        data = torch.ones(1, 10, 10, 10)
        # use bn since bn changes state during train/val
        model = nn.BatchNorm2d(10)
        prev_state = self._snapshot_state(model)

        modes = [False, True]
        # eval mode must leave the state untouched, train mode must change it
        results = [True, False]

        # test under train/eval mode
        for mode, result in zip(modes, results):
            with adjust_status(model, training=mode):
                model(data)
            model_state = model.state_dict()
            self.assertEqual(len(model_state), len(prev_state))
            self.assertEqual(result, self._states_match(prev_state, model_state))

        # test recursive context case
        prev_state = self._snapshot_state(model)
        with adjust_status(model, training=False):
            with adjust_status(model, training=False):
                model(data)
        model_state = model.state_dict()
        self.assertEqual(len(model_state), len(prev_state))
        self.assertTrue(self._states_match(prev_state, model_state))

    def test_model_effect_adjust_status(self):
        # test context effect
        self.model.train()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)
        # all training after exit
        for module in self.model.modules():
            self.assertTrue(module.training)

        # only backbone set to eval
        self.model.backbone.eval()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)

        # the mixed train/eval state must be restored per module on exit
        for name, module in self.model.named_modules():
            if "backbone" in name:
                self.assertFalse(module.training)
            else:
                self.assertTrue(module.training)

    def test_freeze_module(self):
        model = nn.Sequential(
            nn.Conv2d(3, 10, 1),
            nn.BatchNorm2d(10),
            nn.ReLU(),
        )
        data = torch.rand(1, 3, 10, 10)
        model.train()
        assert isinstance(model[1], nn.BatchNorm2d)
        # clone the state: a plain state_dict() would alias the live buffers
        # and make the comparison below trivially true
        before_states = self._snapshot_state(model[1])
        freeze_module(model[1])
        model(data)
        after_states = model[1].state_dict()
        self.assertTrue(self._states_match(before_states, after_states))

        # yolox test
        self.model.train()
        for module in self.model.modules():
            self.assertTrue(module.training)

        freeze_module(self.model, "backbone")
        for module in self.model.backbone.modules():
            self.assertFalse(module.training)
        for p in self.model.backbone.parameters():
            self.assertFalse(p.requires_grad)

        for module in self.model.head.modules():
            self.assertTrue(module.training)
        for p in self.model.head.parameters():
            self.assertTrue(p.requires_grad)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
11 changes: 6 additions & 5 deletions yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MeterBuffer,
ModelEMA,
WandbLogger,
adjust_status,
all_reduce_norm,
get_local_rank,
get_model_info,
Expand Down Expand Up @@ -169,7 +170,6 @@ def before_train(self):
self.ema_model.updates = self.max_iter * self.start_epoch

self.model = model
self.model.train()

self.evaluator = self.exp.get_evaluator(
batch_size=self.args.batch_size, is_distributed=self.is_distributed
Expand Down Expand Up @@ -320,13 +320,14 @@ def evaluate_and_save_model(self):
if is_parallel(evalmodel):
evalmodel = evalmodel.module

ap50_95, ap50, summary = self.exp.eval(
evalmodel, self.evaluator, self.is_distributed
)
with adjust_status(evalmodel, training=False):
ap50_95, ap50, summary = self.exp.eval(
evalmodel, self.evaluator, self.is_distributed
)

update_best_ckpt = ap50_95 > self.best_ap
self.best_ap = max(self.best_ap, ap50_95)

self.model.train()
if self.rank == 0:
if self.args.logger == "tensorboard":
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
Expand Down
1 change: 1 addition & 0 deletions yolox/exp/yolox_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def init_yolo(M):

self.model.apply(init_yolo)
self.model.head.initialize_biases(1e-2)
self.model.train()
return self.model

def get_data_loader(
Expand Down
91 changes: 85 additions & 6 deletions yolox/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import contextlib
from copy import deepcopy
from typing import Iterator, Optional, Sequence

import torch
import torch.nn as nn
Expand All @@ -13,11 +15,12 @@
"fuse_model",
"get_model_info",
"replace_module",
"freeze_module",
"adjust_status",
]


def get_model_info(model, tsize):

def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
stride = 64
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
Expand All @@ -28,8 +31,18 @@ def get_model_info(model, tsize):
return info


def fuse_conv_and_bn(conv, bn):
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
"""
Fuse convolution and batchnorm layers.
check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/
Args:
conv (nn.Conv2d): convolution to fuse.
bn (nn.BatchNorm2d): batchnorm to fuse.
Returns:
nn.Conv2d: fused convolution behaves the same as the input conv and bn.
"""
fusedconv = (
nn.Conv2d(
conv.in_channels,
Expand Down Expand Up @@ -63,7 +76,15 @@ def fuse_conv_and_bn(conv, bn):
return fusedconv


def fuse_model(model):
def fuse_model(model: nn.Module) -> nn.Module:
"""fuse conv and bn in model
Args:
model (nn.Module): model to fuse
Returns:
nn.Module: fused model
"""
from yolox.models.network_blocks import BaseConv

for m in model.modules():
Expand All @@ -74,7 +95,7 @@ def fuse_model(model):
return model


def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
"""
Replace given type in module to a new type. mostly used in deploy.
Expand Down Expand Up @@ -104,3 +125,61 @@ def default_replace_func(replaced_module_type, new_module_type):
model.add_module(name, new_child)

return model


def freeze_module(module: nn.Module, name: Optional[str] = None) -> nn.Module:
    """Freeze a module in place.

    Freezing disables gradients for the selected parameters and switches the
    selected sub-modules to eval mode, so that layers such as BatchNorm and
    Dropout stop updating / randomizing. Note that eval mode is an ordinary
    module flag: a later ``.train()`` call on the module (or on a parent)
    puts the sub-modules back into train mode.

    Args:
        module (nn.Module): module to freeze.
        name (str, optional): name to freeze. If not given, freeze the whole module.
            Selection is a plain substring test against the dotted names reported by
            ``named_parameters()`` / ``named_modules()``; wildcards and patterns
            are not supported. Defaults to None.

    Returns:
        nn.Module: the input ``module`` itself (modified in place).

    Examples:
        freeze the backbone of model

        >>> freeze_module(model.backbone)

        or freeze the backbone of model by name

        >>> freeze_module(model, name="backbone")
    """
    for param_name, parameter in module.named_parameters():
        if name is None or name in param_name:
            parameter.requires_grad = False

    # ensure modules like BN and dropout are frozen
    for module_name, sub_module in module.named_modules():
        # eval() already recurses into children, so calling it on every matched
        # sub_module is redundant but harmless
        if name is None or name in module_name:
            sub_module.eval()

    return module


@contextlib.contextmanager
def adjust_status(module: nn.Module, training: bool = False) -> Iterator[nn.Module]:
    """Adjust module to training/eval mode temporarily.

    The ``training`` flag of ``module`` and of every sub-module is recorded on
    entry and restored on exit, so a model in a mixed train/eval state gets its
    exact per-module state back. Restoration also happens when the body of the
    ``with`` block raises.

    Args:
        module (nn.Module): module to adjust status.
        training (bool): training mode to set. True for train mode, False for eval mode.

    Yields:
        nn.Module: the same ``module``, with the requested mode applied.

    Examples:
        >>> with adjust_status(model, training=False):
        ...     model(data)
    """
    # Snapshot the flag of each sub-module individually: a single bool for the
    # root would lose information when only part of the model is in eval mode.
    status = {m: m.training for m in module.modules()}
    for m in status:
        m.training = training

    try:
        yield module
    finally:
        # Restore from the snapshot (not from a fresh traversal) so the exit path
        # cannot fail if the module tree was modified inside the context.
        for m, was_training in status.items():
            m.training = was_training

0 comments on commit 5bbfc11

Please sign in to comment.