Skip to content

Commit

Permalink
Merge pull request #1189 from FateScript/hubload
Browse files Browse the repository at this point in the history
feat(model): support hub load
  • Loading branch information
GOATmessi7 committed Mar 21, 2022
2 parents 10a04c7 + e685457 commit c4298f8
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 4 deletions.
19 changes: 19 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

"""
Usage example:
import torch
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
"""
dependencies = ["torch"]

from yolox.models import ( # isort:skip # noqa: F401, E402
yolox_tiny,
yolox_nano,
yolox_s,
yolox_m,
yolox_l,
yolox_x,
yolov3,
)
1 change: 1 addition & 0 deletions yolox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

from .build import *
from .darknet import CSPDarknet, Darknet
from .losses import IOUloss
from .yolo_fpn import YOLOFPN
Expand Down
91 changes: 91 additions & 0 deletions yolox/models/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import torch
from torch import nn
from torch.hub import load_state_dict_from_url

__all__ = [
"create_yolox_model",
"yolox_nano",
"yolox_tiny",
"yolox_s",
"yolox_m",
"yolox_l",
"yolox_x",
"yolov3",
]

_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
_CKPT_FULL_PATH = {
"yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
"yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
"yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
"yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
"yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
"yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
"yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
}


def create_yolox_model(
name: str, pretrained: bool = True, num_classes: int = 80, device=None
) -> nn.Module:
"""creates and loads a YOLOX model
Args:
name (str): name of model. for example, "yolox-s", "yolox-tiny".
pretrained (bool): load pretrained weights into the model. Default to True.
num_classes (int): number of model classes. Defalut to 80.
device (str): default device to for model. Defalut to None.
Returns:
YOLOX model (nn.Module)
"""
from yolox.exp import get_exp, Exp

if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
exp: Exp = get_exp(exp_name=name)
exp.num_classes = num_classes
yolox_model = exp.get_model()
if pretrained and num_classes == 80:
weights_url = _CKPT_FULL_PATH[name]
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
if "model" in ckpt:
ckpt = ckpt["model"]
yolox_model.load_state_dict(ckpt)

yolox_model.to(device)
return yolox_model


def yolox_nano(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-nano", pretrained, num_classes, device)


def yolox_tiny(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)


def yolox_s(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-s", pretrained, num_classes, device)


def yolox_m(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-m", pretrained, num_classes, device)


def yolox_l(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-l", pretrained, num_classes, device)


def yolox_x(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-x", pretrained, num_classes, device)


def yolov3(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
12 changes: 8 additions & 4 deletions yolox/models/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from yolox.utils import bboxes_iou
from yolox.utils import bboxes_iou, meshgrid

from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_output_and_grid(self, output, k, stride, dtype):
n_ch = 5 + self.num_classes
hsize, wsize = output.shape[-2:]
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
self.grids[k] = grid

Expand All @@ -237,7 +237,7 @@ def decode_outputs(self, outputs, dtype):
grids = []
strides = []
for (hsize, wsize), stride in zip(self.hw, self.strides):
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
Expand Down Expand Up @@ -321,7 +321,11 @@ def get_losses(
labels,
imgs,
)
except RuntimeError:
except RuntimeError as e:
# TODO: the string might change, consider a better way
if "CUDA out of memory. " not in str(e):
raise # RuntimeError might not caused by CUDA OOM

logger.error(
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
CPU mode is applied in this batch. If you want to avoid this issue, \
Expand Down
1 change: 1 addition & 0 deletions yolox/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .allreduce_norm import *
from .boxes import *
from .checkpoint import load_ckpt, save_checkpoint
from .compat import meshgrid
from .demo_utils import *
from .dist import *
from .ema import *
Expand Down
15 changes: 15 additions & 0 deletions yolox/utils/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import torch

_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]

__all__ = ["meshgrid"]


def meshgrid(*tensors):
if _TORCH_VER >= [1, 10]:
return torch.meshgrid(*tensors, indexing="ij")
else:
return torch.meshgrid(*tensors)

0 comments on commit c4298f8

Please sign in to comment.