Skip to content

Commit

Permalink
hello, world
Browse files Browse the repository at this point in the history
  • Loading branch information
Cydia2018 committed Oct 14, 2022
1 parent 5ef69ef commit 9ecbcda
Show file tree
Hide file tree
Showing 24 changed files with 2,317 additions and 367 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
!setup.cfg
!cfg/yolov3*.cfg

pretrained_weights/
*.out
fortest.py
exp.sh

storage.googleapis.com
runs/*
data/*
Expand Down
405 changes: 46 additions & 359 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion data/coco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../datasets/coco # dataset root dir
path: data/coco # dataset root dir
train: train2017.txt # train images (relative to 'path') 118287 images
val: val2017.txt # val images (relative to 'path') 5000 images
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
Expand Down
35 changes: 35 additions & 0 deletions data/coco_person.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# COCO 2017 dataset http://cocodataset.org by Microsoft
# Example usage: python train.py --data coco.yaml
# parent
# ├── yolov5
# └── datasets
# └── coco ← downloads here (20.1 GB)


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: data/coco_person/ # dataset root dir
train: train.txt # train images (relative to 'path') 118287 images
val: valid.txt # val images (relative to 'path') 5000 images

# Classes
names:
0: person

# Download script/URL (optional)
# download: |
# from utils.general import download, Path


# # Download labels
# segments = False # segment or box labels
# dir = Path(yaml['path']) # dataset root dir
# url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
# urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
# download(urls, dir=dir.parent)

# # Download data
# urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
# 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
# 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
# download(urls, dir=dir / 'images', threads=3)
35 changes: 35 additions & 0 deletions data/hand.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# COCO 2017 dataset http://cocodataset.org by Microsoft
# Example usage: python train.py --data coco.yaml
# parent
# ├── yolov5
# └── datasets
# └── coco ← downloads here (20.1 GB)


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: data/hand/ # dataset root dir
train: train.txt # train images (relative to 'path') 118287 images
val: valid.txt # val images (relative to 'path') 5000 images

# Classes
names:
0: hand

# Download script/URL (optional)
# download: |
# from utils.general import download, Path


# # Download labels
# segments = False # segment or box labels
# dir = Path(yaml['path']) # dataset root dir
# url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
# urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
# download(urls, dir=dir.parent)

# # Download data
# urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
# 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
# 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
# download(urls, dir=dir / 'images', threads=3)
34 changes: 34 additions & 0 deletions data/hyps/hyp.scratch-convnext.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Hyperparameters for low-augmentation COCO training from scratch
# python train.py --batch 64 --cfg yolov5n6.yaml --weights '' --data coco.yaml --img 640 --epochs 300 --linear
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials

lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.9 # SGD momentum/Adam beta1
weight_decay: 0.05 # optimizer weight decay 5e-4
warmup_epochs: 20 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.01 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
146 changes: 145 additions & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import json
import math
from mimetypes import init
import platform
from turtle import forward
import warnings
from collections import OrderedDict, namedtuple
from copy import copy
Expand All @@ -18,6 +20,8 @@
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from PIL import Image
from torch.cuda import amp

Expand Down Expand Up @@ -55,6 +59,23 @@ def forward_fuse(self, x):
return self.act(self.conv(x))


class SlimConv(nn.Module):
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
default_act = nn.SiLU() # default activation

def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

def forward(self, x):
return self.act(self.bn(self.conv(x)))

def forward_fuse(self, x):
return self.act(self.conv(x))


class DWConv(Conv):
# Depth-wise convolution
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
Expand Down Expand Up @@ -107,7 +128,7 @@ class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
c_ = max(int(c2 * e), 8) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.add = shortcut and c1 == c2
Expand Down Expand Up @@ -163,6 +184,20 @@ def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class SlimC3(nn.Module):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, inr=[1.0 for _ in range(20)], shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=inr[i]) for i in range(n)))

def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class C3x(C3):
# C3 module with cross-convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
Expand Down Expand Up @@ -229,6 +264,24 @@ def forward(self, x):
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))


class SlimSPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))


class Focus(nn.Module):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
Expand Down Expand Up @@ -270,6 +323,97 @@ def forward(self, x):
return self.conv(x) + self.shortcut(x)


# -----------------------------
class ConvNeXt_Stem(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, groups=g, dilation=d)
self.ln = LayerNorm(c2, eps=1e-6, data_format="channels_first")

def forward(self, x):
return self.ln(self.conv(x))


class ConvNeXt_Downsample(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, groups=g, dilation=d)
self.ln = LayerNorm(c1, eps=1e-6, data_format="channels_first")

def forward(self, x):
return self.conv(self.ln(x))


class ConvNeXt_Inside_Block(nn.Module):
def __init__(self, dim, layer_scale_init_value=1e-6, drop_path=0.): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

x = input + self.drop_path(x)
return x


class ConvNeXt_Block(nn.Module):
def __init__(self, c1, c2, n=1, layer_scale_init_value=1e-6):
super().__init__()
self.m = nn.Sequential(*(ConvNeXt_Inside_Block(c2, layer_scale_init_value) for _ in range(n)))
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)

def forward(self, x):
return self.m(x)


class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )

def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x


class Contract(nn.Module):
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
def __init__(self, gain=2):
Expand Down
25 changes: 20 additions & 5 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _profile_one_layer(self, m, x, dt):
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
LOGGER.info('Fusing layers... ')
for m in self.model.modules():
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
if isinstance(m, (Conv, DWConv, SlimConv)) and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
Expand Down Expand Up @@ -298,7 +298,7 @@ def _from_yaml(self, cfg):

def parse_model(d, ch): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
# LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
if act:
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
Expand All @@ -316,15 +316,30 @@ def parse_model(d, ch): # model_dict, input_channels(3)
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in {
Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x,
ConvNeXt_Stem, ConvNeXt_Block, ConvNeXt_Downsample}:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)

args = [c1, c2, *args[1:]]
if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x, ConvNeXt_Block}:
args.insert(2, n) # number of repeats
n = 1
elif m in {SlimConv, SlimSPPF}:
c1, c2 = ch[f], args[0]
ratio = args[1]
if c2 != no: # if not output
c2 = max(make_divisible(c2 * gw * ratio, 8), 8)
args = [c1, c2, *args[2:]]
elif m is SlimC3:
c1, c2 = ch[f], args[0]
ratio = args[1]
if c2 != no: # if not output
c2 = max(make_divisible(c2 * gw * ratio, 8), 8)
args = [c1, c2, *args[2:]]
args.insert(2, n) # number of repeats
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
Expand All @@ -347,7 +362,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
# LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
Expand Down
Loading

0 comments on commit 9ecbcda

Please sign in to comment.