diff --git a/PaddleCV/tracking/README.md b/PaddleCV/tracking/README.md index 10e830eb10..1c9d8c5d50 100644 --- a/PaddleCV/tracking/README.md +++ b/PaddleCV/tracking/README.md @@ -48,15 +48,20 @@ pytracking 包含跟踪代码 主流的训练数据集有: - [VID](http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz) +- [DET](http://image-net.org/challenges/LSVRC/2015/) - [Microsoft COCO 2014](http://cocodataset.org/#download) +- [Microsoft COCO 2017](http://cocodataset.org/#download) +- [Youtube-VOS](https://youtube-vos.org/) - [LaSOT](https://drive.google.com/file/d/1O2DLxPP8M4Pn4-XCttCJUW3A29tDIeNa/view) - [GOT-10K](http://got-10k.aitestunion.com/downloads_dataset/full_data) 下载并解压后的数据集的组织方式为: ``` /Datasets/ - └─ ILSVRC2015_VID/ - └─ train2014/ + └─ ILSVRC2015/ + └─ ILSVRC2015_DET/ + └─ COCO/ + └─ YoutubeVOS/ └─ GOT-10K/ └─ LaSOTBenchmark/ @@ -71,16 +76,16 @@ Datasets是数据集保存的路径。 tracking的工作环境: - Linux - python3 -- PaddlePaddle1.7 +- PaddlePaddle1.8 > 注意:如果遇到cmath无法import的问题,建议切换Python版本,建议使用python3.6.8, python3.7.0 。另外, > tracking暂不支持在window上运行,如果开发者有需求在window上运行tracking,请在issue中提出需求。 ### 安装依赖 -1. 安装paddle,需要安装1.7版本的Paddle,如低于这个版本,请升级到Paddle 1.7. +1. 安装paddle,需要安装1.8版本的Paddle,如低于这个版本,请升级到Paddle 1.8. ```bash -pip install paddlepaddle-gpu==1.7.0 +pip install paddlepaddle-gpu==1.8.0 ``` 2. 安装第三方库,建议使用anaconda @@ -114,10 +119,12 @@ pip install python-prctl └─ atom_resnet18.pdparams └─ atom_resnet50.pdparams └─ backbone + └─ AlexNet.pdparams └─ ResNet18.pdparams └─ ResNet50.pdparams + └─ ResNet50_dilated.pdparams ``` -其中/pretrained_models/backbone/文件夹包含,ResNet18、ResNet50在Imagenet上的预训练模型。 +其中/pretrained_models/backbone/文件夹包含,AlexNet、ResNet18、ResNet50在Imagenet上的预训练模型。 ### 设置训练参数 @@ -154,7 +161,7 @@ python -c "from ltr.admin.environment import create_default_local_file; create_d ```bash self.workspace_dir = './checkpoints' self.lasot_dir = '/Datasets/LaSOTBenchmark/' - self.coco_dir = '/Datasets/train2014/' + self.coco_dir = '/Datasets/COCO/' self.got10k_dir = '/Datasets/GOT-10k/train' self.imagenet_dir = '/Datasets/ILSVRC2015/' ``` @@ -164,6 +171,16 @@ cd ltr/data_specs/ wget https://paddlemodels.cdn.bcebos.com/paddle_track/vot/got10k_lasot_split.tar tar xvf got10k_lasot_split.tar ``` +训练SiamRPN、SiamMask时,需要配置 workspace_dir,以及imagenet、coco、imagenetdet、youtubevos、lasot、got10k的数据集路径,如下: +```bash + self.workspace_dir = './checkpoints' + self.imagenet_dir = '/Datasets/ILSVRC2015/' + self.coco_dir = '/Datasets/COCO/' + self.imagenetdet_dir = '/Datasets/ILSVRC2015_DET/' + self.youtubevos_dir = '/Datasets/YoutubeVOS/' + self.lasot_dir = '/Datasets/LaSOTBenchmark/' + self.got10k_dir = '/Datasets/GOT-10k/train' +``` ### 启动训练 @@ -180,6 +197,15 @@ python run_training.py bbreg atom_res50_vid_lasot_coco # 训练 SiamFC python run_training.py siamfc siamfc_alexnet_vid + +# 训练 SiamRPN AlexNet +python run_training.py siamrpn siamrpn_alexnet + +# 训练 SiamMask-Base ResNet50 +python run_training.py siammask siammask_res50_base + +# 训练 SiamMask-Refine ResNet50,需要配置settings.base_model为最优的SiamMask-Base模型 +python run_training.py siammask siammask_res50_sharp ``` @@ -242,6 +268,19 @@ python eval_benchmark.py -d VOT2018 -tr bbreg.atom_res18_vid_lasot_coco -te atom python eval_benchmark.py -d VOT2018 -tr siamfc.siamfc_alexnet_vid -te siamfc.default -e 'range(1, 50, 1)' ``` +测试SiamRPN +``` +python eval_benchmark.py -d OTB100 -tr siamrpn.siamrpn_alexnet -te siamrpn.default_otb -e 'range(1, 40, 1)' +``` + +测试SiamMask +```bash +# 在VOT2018上测试SiamMask-Base +python eval_benchmark.py -d VOT2018 -tr siammask.siammask_res50_base -te siammask.base_default -e 
'range(1, 20, 1)' +# 在VOT2018上测试SiamMask-Sharp +python eval_benchmark.py -d VOT2018 -tr siammask.siammask_res50_sharp -te siammask.sharp_default_vot -e 'range(1, 20, 1)' +``` + ## 跟踪结果可视化 @@ -265,7 +304,9 @@ jupyter notebook --ip 0.0.0.0 --port 8888 | 数据集 | 模型 | Backbone | 论文结果 | 训练结果 | 模型| | :-------: | :-------: | :---: | :---: | :---------: |:---------: | |VOT2018| ATOM | Res18 | EAO: 0.401 | 0.399 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/ATOM.tar) | +|VOT2018| SiamMask | Res50 | EAO: 0.380 | 0.379 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/SiamMask.tar) | |VOT2018| SiamFC | AlexNet | EAO: 0.188 | 0.211 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/SiamFC.tar) | +|OTB100| SiamRPN | AlexNet | Succ: 0.637, Prcn: 0.851 | Succ: 0.644, Prcn: 0.848 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/SiamRPN.tar) | ## 引用与参考 @@ -280,6 +321,26 @@ SiamFC **[[Paper]](https://arxiv.org/pdf/1811.07628.pdf) [[Code]](https://www.ro organization={Springer} } +SiamRPN **[[Paper]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf) [[Code]](https://github.com/STVIR/pysot)** + + @inproceedings{li2018high, + title={High performance visual tracking with siamese region proposal network}, + author={Li, Bo and Yan, Junjie and Wu, Wei and Zhu, Zheng and Hu, Xiaolin}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={8971--8980}, + year={2018} + } + +SiamMask **[[Paper]](https://arxiv.org/pdf/1812.05050.pdf) [[Code]](https://github.com/foolwood/SiamMask)** + + @inproceedings{wang2019fast, + title={Fast online object tracking and segmentation: A unifying approach}, + author={Wang, Qiang and Zhang, Li and Bertinetto, Luca and Hu, Weiming and Torr, Philip HS}, + booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, + pages={1328--1338}, + year={2019} + } + ATOM **[[Paper]](https://arxiv.org/pdf/1811.07628.pdf) [[Raw results]](https://drive.google.com/drive/folders/1MdJtsgr34iJesAgL7Y_VelP8RvQm_IG_) [[Models]](https://drive.google.com/open?id=1EsNSQr25qfXHYLqjZaVZElbGdUg-nyzd) [[Training Code]](https://github.com/visionml/pytracking/blob/master/ltr/README.md#ATOM) [[Tracker Code]](https://github.com/visionml/pytracking/blob/master/pytracking/README.md#ATOM)** @inproceedings{danelljan2019atom, diff --git a/PaddleCV/tracking/ltr/actors/__init__.py b/PaddleCV/tracking/ltr/actors/__init__.py index 9b89e0f607..4f308032c1 100644 --- a/PaddleCV/tracking/ltr/actors/__init__.py +++ b/PaddleCV/tracking/ltr/actors/__init__.py @@ -1,3 +1,4 @@ from .base_actor import BaseActor from .bbreg import AtomActor from .siamfc import SiamFCActor +from .siam import SiamActor diff --git a/PaddleCV/tracking/ltr/actors/siam.py b/PaddleCV/tracking/ltr/actors/siam.py new file mode 100644 index 0000000000..e284f2985d --- /dev/null +++ b/PaddleCV/tracking/ltr/actors/siam.py @@ -0,0 +1,50 @@ +from . 
import BaseActor +import paddle.fluid as fluid +import numpy as np + +class SiamActor(BaseActor): + """ Actor for training the SiamRPN/SiamMask""" + + def __call__(self, data): + # Run network to obtain predictiion + pred = self.net(data['train_images'], data['test_images']) + + # Compute loss + label_cls = fluid.layers.cast(x=data['label_cls'], dtype=np.int64) + cls_loss = self.objective['cls'](pred['cls'], label_cls) + loc_loss = self.objective['loc'](pred['loc'], data['label_loc'], data['label_loc_weight']) + + loss = {} + loss['cls'] = cls_loss + loss['loc'] = loc_loss + + # Return training stats + stats = {} + stats['Loss/cls'] = cls_loss.numpy() + stats['Loss/loc'] = loc_loss.numpy() + + # Compute mask loss if necessary + if 'mask' in pred: + mask_loss, iou_m, iou_5, iou_7 = self.objective['mask']( + pred['mask'], + data['label_mask'], + data['label_mask_weight']) + loss['mask'] = mask_loss + + stats['Loss/mask'] = mask_loss.numpy() + stats['Accuracy/mask_iou_mean'] = iou_m.numpy() + stats['Accuracy/mask_at_5'] = iou_5.numpy() + stats['Accuracy/mask_at_7'] = iou_7.numpy() + + # Use scale loss if exists + scale_loss = getattr(self.net, "scale_loss", None) + if callable(scale_loss): + total_loss = scale_loss(loss) + else: + total_loss = 0 + for k, v in loss.items(): + total_loss += v + + stats['Loss/total'] = total_loss.numpy() + + return total_loss, stats diff --git a/PaddleCV/tracking/ltr/admin/environment.py b/PaddleCV/tracking/ltr/admin/environment.py index 590d56d878..14a1e98505 100644 --- a/PaddleCV/tracking/ltr/admin/environment.py +++ b/PaddleCV/tracking/ltr/admin/environment.py @@ -16,7 +16,8 @@ def create_default_local_file(): 'trackingnet_dir': empty_str, 'coco_dir': empty_str, 'imagenet_dir': empty_str, - 'imagenetdet_dir': empty_str + 'imagenetdet_dir': empty_str, + 'youtubevos_dir': empty_str }) comment = { diff --git a/PaddleCV/tracking/ltr/admin/local.py b/PaddleCV/tracking/ltr/admin/local.py index f598f81482..9387a808bd 100644 --- a/PaddleCV/tracking/ltr/admin/local.py +++ b/PaddleCV/tracking/ltr/admin/local.py @@ -9,3 +9,4 @@ def __init__(self): self.coco_dir = '' self.imagenet_dir = '' self.imagenetdet_dir = '' + self.youtubevos_dir = '' diff --git a/PaddleCV/tracking/ltr/data/anchor.py b/PaddleCV/tracking/ltr/data/anchor.py new file mode 100644 index 0000000000..b943095e46 --- /dev/null +++ b/PaddleCV/tracking/ltr/data/anchor.py @@ -0,0 +1,261 @@ +import math +import numpy as np +from collections import namedtuple + +Corner = namedtuple('Corner', 'x1 y1 x2 y2') + +# alias +BBox = Corner +Center = namedtuple('Center', 'x y w h') + +def topleft2corner(topleft): + """ convert (x, y, w, h) to (x1, y1, x2, y2) + Args: + center: np.array (4 * N) + Return: + np.array (4 * N) + """ + x, y, w, h = topleft[0], topleft[1], topleft[2], topleft[3] + x1 = x + y1 = y + x2 = x + w + y2 = y + h + return x1, y1, x2, y2 + +def corner2center(corner): + """ convert (x1, y1, x2, y2) to (cx, cy, w, h) + Args: + conrner: Corner or np.array (4*N) + Return: + Center or np.array (4 * N) + """ + if isinstance(corner, Corner): + x1, y1, x2, y2 = corner + return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) + else: + x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] + x = (x1 + x2) * 0.5 + y = (y1 + y2) * 0.5 + w = x2 - x1 + h = y2 - y1 + return x, y, w, h + + +def center2corner(center): + """ convert (cx, cy, w, h) to (x1, y1, x2, y2) + Args: + center: Center or np.array (4 * N) + Return: + center or np.array (4 * N) + """ + if isinstance(center, Center): + x, y, w, h 
= center + return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) + else: + x, y, w, h = center[0], center[1], center[2], center[3] + x1 = x - w * 0.5 + y1 = y - h * 0.5 + x2 = x + w * 0.5 + y2 = y + h * 0.5 + return x1, y1, x2, y2 + + +def IoU(rect1, rect2): + """ caculate interection over union + Args: + rect1: (x1, y1, x2, y2) + rect2: (x1, y1, x2, y2) + Returns: + iou + """ + # overlap + x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3] + tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3] + + xx1 = np.maximum(tx1, x1) + yy1 = np.maximum(ty1, y1) + xx2 = np.minimum(tx2, x2) + yy2 = np.minimum(ty2, y2) + + ww = np.maximum(0, xx2 - xx1) + hh = np.maximum(0, yy2 - yy1) + + area = (x2 - x1) * (y2 - y1) + target_a = (tx2 - tx1) * (ty2 - ty1) + inter = ww * hh + iou = inter / (area + target_a - inter) + return iou + + +class Anchors: + """ + This class generate anchors. + """ + + def __init__(self, stride, ratios, scales, image_center=0, size=0): + self.stride = stride + self.ratios = ratios + self.scales = scales + self.image_center = 0 + self.size = 0 + + self.anchor_num = len(self.scales) * len(self.ratios) + + self.anchors = None + + self.generate_anchors() + + def generate_anchors(self): + """ + generate anchors based on predefined configuration + """ + self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32) + size = self.stride * self.stride + count = 0 + for r in self.ratios: + ws = int(math.sqrt(size * 1. / r)) + hs = int(ws * r) + + for s in self.scales: + w = ws * s + h = hs * s + self.anchors[count][:] = [-w * 0.5, -h * 0.5, w * 0.5, h * 0.5][:] + count += 1 + + def generate_all_anchors(self, im_c, size): + """ + im_c: image center + size: image size + """ + if self.image_center == im_c and self.size == size: + return False + self.image_center = im_c + self.size = size + + a0x = im_c - size // 2 * self.stride + ori = np.array([a0x] * 4, dtype=np.float32) + zero_anchors = self.anchors + ori + + x1 = zero_anchors[:, 0] + y1 = zero_anchors[:, 1] + x2 = zero_anchors[:, 2] + y2 = zero_anchors[:, 3] + + x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), + [x1, y1, x2, y2]) + cx, cy, w, h = corner2center([x1, y1, x2, y2]) + + disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride + disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride + + cx = cx + disp_x + cy = cy + disp_y + + # broadcast + zero = np.zeros((self.anchor_num, size, size), dtype=np.float32) + cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h]) + x1, y1, x2, y2 = center2corner([cx, cy, w, h]) + + self.all_anchors = (np.stack([x1, y1, x2, y2]).astype(np.float32), + np.stack([cx, cy, w, h]).astype(np.float32)) + return True + + +class AnchorTarget: + def __init__(self, + search_size, + output_size, + stride, + ratios, + scales, + num_pos, + num_neg, + num_total, + thr_high, + thr_low): + self.search_size = search_size + self.output_size = output_size + self.anchor_stride = stride + self.anchor_ratios = ratios + self.anchor_scales = scales + self.num_pos = num_pos + self.num_neg = num_neg + self.num_total = num_total + self.thr_high = thr_high + self.thr_low = thr_low + + self.anchors = Anchors(stride, + ratios, + scales) + + self.anchors.generate_all_anchors(im_c=search_size // 2, + size=output_size) + + def __call__(self, target, size, neg=False): + anchor_num = len(self.anchor_ratios) * len(self.anchor_scales) + + # -1 ignore 0 negative 1 positive + cls = -1 * np.ones((anchor_num, size, size), dtype=np.int64) + delta = np.zeros((4, anchor_num, size, size), 
dtype=np.float32) + delta_weight = np.zeros((anchor_num, size, size), dtype=np.float32) + + def select(position, keep_num=16): + num = position[0].shape[0] + if num <= keep_num: + return position, num + slt = np.arange(num) + np.random.shuffle(slt) + slt = slt[:keep_num] + return tuple(p[slt] for p in position), keep_num + + tcx, tcy, tw, th = corner2center(target) + + if neg: + # l = size // 2 - 3 + # r = size // 2 + 3 + 1 + # cls[:, l:r, l:r] = 0 + + cx = size // 2 + cy = size // 2 + cx += int(np.ceil((tcx - self.search_size // 2) / + self.anchor_stride + 0.5)) + cy += int(np.ceil((tcy - self.search_size // 2) / + self.anchor_stride + 0.5)) + l = max(0, cx - 3) + r = min(size, cx + 4) + u = max(0, cy - 3) + d = min(size, cy + 4) + cls[:, u:d, l:r] = 0 + + neg, neg_num = select(np.where(cls == 0), self.num_neg) + cls[:] = -1 + cls[neg] = 0 + + overlap = np.zeros((anchor_num, size, size), dtype=np.float32) + return cls, delta, delta_weight, overlap + + anchor_box = self.anchors.all_anchors[0] + anchor_center = self.anchors.all_anchors[1] + x1, y1, x2, y2 = anchor_box[0], anchor_box[1], \ + anchor_box[2], anchor_box[3] + cx, cy, w, h = anchor_center[0], anchor_center[1], \ + anchor_center[2], anchor_center[3] + + delta[0] = (tcx - cx) / w + delta[1] = (tcy - cy) / h + delta[2] = np.log(tw / w) + delta[3] = np.log(th / h) + + overlap = IoU([x1, y1, x2, y2], target) + + pos = np.where(overlap > self.thr_high) + neg = np.where(overlap < self.thr_low) + + pos, pos_num = select(pos, self.num_pos) + neg, neg_num = select(neg, self.num_total - self.num_pos) + + cls[pos] = 1 + delta_weight[pos] = 1. / (pos_num + 1e-6) + + cls[neg] = 0 + return cls, delta, delta_weight, overlap diff --git a/PaddleCV/tracking/ltr/data/image_loader.py b/PaddleCV/tracking/ltr/data/image_loader.py index 12bda4ee44..8352b471c1 100644 --- a/PaddleCV/tracking/ltr/data/image_loader.py +++ b/PaddleCV/tracking/ltr/data/image_loader.py @@ -12,7 +12,7 @@ def default_image_loader(path): im = jpeg4py_loader(path) if im is None: default_image_loader.use_jpeg4py = False - print('Using opencv_loader instead.') + print('Jpeg4py is not available. Using OpenCV instead.') else: default_image_loader.use_jpeg4py = True return im @@ -29,9 +29,9 @@ def jpeg4py_loader(path): try: return jpeg4py.JPEG(path).decode() except Exception as e: - print('ERROR: Could not read image "{}"'.format(path)) + print('ERROR: Jpeg4py could not read image "{}". 
Using OpenCV instead.'.format(path)) print(e) - return None + return opencv_loader(path) def opencv_loader(path): @@ -41,7 +41,7 @@ def opencv_loader(path): # convert to rgb and return return cv.cvtColor(im, cv.COLOR_BGR2RGB) except Exception as e: - print('ERROR: Could not read image "{}"'.format(path)) + print('ERROR: OpenCV could not read image "{}"'.format(path)) print(e) return None @@ -55,7 +55,7 @@ def lmdb_loader(path, lmdb_path=None): img_buffer = np.frombuffer(img_buffer, np.uint8) return cv.imdecode(img_buffer, cv.IMREAD_COLOR) except Exception as e: - print('ERROR: Could not read image "{}"'.format(path)) + print('ERROR: Lmdb could not read image "{}"'.format(path)) print(e) return None diff --git a/PaddleCV/tracking/ltr/data/loader.py b/PaddleCV/tracking/ltr/data/loader.py index 2f6bd48816..0d7c8f11c9 100644 --- a/PaddleCV/tracking/ltr/data/loader.py +++ b/PaddleCV/tracking/ltr/data/loader.py @@ -1,10 +1,27 @@ import os +import signal import sys import dataflow as df import numpy as np +# handle terminate reader process, do not print stack frame +def _reader_quit(signum, frame): + print("Reader process exit.") + sys.exit() + + +def _term_group(sig_num, frame): + print('pid {} terminated, terminate group ' + '{}...'.format(os.getpid(), os.getpgrp())) + os.killpg(os.getpgid(os.getpid()), signal.SIGKILL) + + +signal.signal(signal.SIGTERM, _reader_quit) +signal.signal(signal.SIGINT, _term_group) + + class LTRLoader(df.DataFlow): """ Data loader. Combines a dataset and a sampler, and provides diff --git a/PaddleCV/tracking/ltr/data/processing.py b/PaddleCV/tracking/ltr/data/processing.py index ab207da002..5c6bca92be 100644 --- a/PaddleCV/tracking/ltr/data/processing.py +++ b/PaddleCV/tracking/ltr/data/processing.py @@ -2,6 +2,7 @@ from ltr.data import transforms import ltr.data.processing_utils as prutils +from ltr.data.anchor import AnchorTarget from pytracking.libs import TensorDict @@ -113,6 +114,148 @@ def __call__(self, data: TensorDict, rng=None): return data +class SiamProcessing(BaseProcessing): + def __init__(self, + search_area_factor, + output_sz, + center_jitter_factor, + scale_jitter_factor, + label_params, + mode='pair', + scale_type='context', + border_type='meanpad', + *args, + **kwargs): + self._init_transform(*args, **kwargs) + self.search_area_factor = search_area_factor + self.output_sz = output_sz + self.center_jitter_factor = center_jitter_factor + self.scale_jitter_factor = scale_jitter_factor + self.mode = mode + self.scale_type = scale_type + self.border_type = border_type + self.label_params = label_params + self.anchor_target = AnchorTarget( + label_params['search_size'], + label_params['output_size'], + label_params['anchor_stride'], + label_params['anchor_ratios'], + label_params['anchor_scales'], + label_params['num_pos'], + label_params['num_neg'], + label_params['num_total'], + label_params['thr_high'], + label_params['thr_low']) + + def _init_transform(self, + transform=transforms.ToArray(), + train_transform=None, + test_transform=None, + train_mask_transform=None, + test_mask_transform=None, + joint_transform=None): + self.transform = {'train': transform if train_transform is None else train_transform, + 'test': transform if test_transform is None else test_transform, + 'joint': joint_transform} + super().__init__( + transform=transform, + train_transform=train_transform, + test_transform=test_transform, + joint_transform=joint_transform) + self.transform['train_mask'] = self.transform['train'] if train_mask_transform is None \ + else 
train_mask_transform + self.transform['test_mask'] = self.transform['test'] if test_mask_transform is None \ + else test_mask_transform + + def _get_jittered_box(self, box, mode, rng): + jittered_size = box[2:4] * (1 + (2 * rng.rand(2) - 1) * self.scale_jitter_factor[mode]) + max_offset = (np.sqrt(jittered_size.prod()) * self.center_jitter_factor[mode]) + jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (rng.rand(2) - 0.5) + + return np.concatenate((jittered_center - 0.5 * jittered_size, jittered_size), axis=0) + + def _get_label(self, target_bb, neg): + return self.anchor_target(target_bb, self.label_params['output_size'], neg) + + def __call__(self, data: TensorDict, rng=None): + neg = data['neg'] + + # Apply joint transforms + if self.transform['joint'] is not None: + num_train_images = len(data['train_images']) + all_images = data['train_images'] + data['test_images'] + all_images_trans = self.transform['joint'](*all_images) + + data['train_images'] = all_images_trans[:num_train_images] + data['test_images'] = all_images_trans[num_train_images:] + + for s in ['train', 'test']: + assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \ + "In pair mode, num train/test frames must be 1" + + # Add a uniform noise to the center pos + jittered_anno = [self._get_jittered_box(a, s, rng) for a in data[s + '_anno']] + + # Crop image region centered at jittered_anno box + try: + crops, boxes = prutils.jittered_center_crop( + data[s + '_images'], + jittered_anno, + data[s + '_anno'], + self.search_area_factor[s], + self.output_sz[s], + scale_type=self.scale_type, + border_type=self.border_type) + mask_crops, _ = prutils.jittered_center_crop( + data[s + '_masks'], + jittered_anno, + data[s + '_anno'], + self.search_area_factor[s], + self.output_sz[s], + scale_type=self.scale_type, + border_type='zeropad') + except Exception as e: + print('{}, anno: {}'.format(data['dataset'], data[s + '_anno'])) + raise e + + # Apply transforms + data[s + '_images'] = [self.transform[s](x) for x in crops] + data[s + '_anno'] = boxes + data[s + '_masks'] = [self.transform[s + '_mask'](x) for x in mask_crops] + + # Prepare output + if self.mode == 'sequence': + data = data.apply(prutils.stack_tensors) + else: + data = data.apply(lambda x: x[0] if isinstance(x, list) else x) + + # Get labels + if self.label_params is not None: + assert data['test_anno'].shape[0] == 1 + gt_box = data['test_anno'][0] + gt_box[2:] += gt_box[:2] + cls, delta, delta_weight, overlap = self._get_label(gt_box, neg) + + mask = data['test_masks'][0] + if np.sum(mask) > 0: + mask_weight = cls.max(axis=0, keepdims=True) + else: + mask_weight = np.zeros([1, cls.shape[1], cls.shape[2]], dtype=np.float32) + mask = (mask > 0.5) * 2. - 1. + + data['label_cls'] = cls + data['label_loc'] = delta + data['label_loc_weight'] = delta_weight + data['label_mask'] = mask + data['label_mask_weight'] = mask_weight + data.pop('train_anno') + data.pop('test_anno') + data.pop('train_masks') + data.pop('test_masks') + + return data + + class ATOMProcessing(BaseProcessing): """ The processing class used for training ATOM. The images are processed in the following way. First, the target bounding box is jittered by adding some noise. 
Next, a square region (called search region ) diff --git a/PaddleCV/tracking/ltr/data/sampler.py b/PaddleCV/tracking/ltr/data/sampler.py index 064c604dfe..cdee571add 100644 --- a/PaddleCV/tracking/ltr/data/sampler.py +++ b/PaddleCV/tracking/ltr/data/sampler.py @@ -1,3 +1,4 @@ +import random import numpy as np import dataflow as df from pytracking.libs import TensorDict @@ -178,3 +179,267 @@ def __iter__(self): # Send for processing yield self.processing(data, rng=self.rng) + + +class MaskSampler(df.RNGDataFlow): + """ Class responsible for sampling frames from training sequences to form batches. Each training sample is a + tuple consisting of i) a train frame, used to obtain the modulation vector, and ii) a set of test frames on which + the IoU prediction loss is calculated. + + The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected + from that dataset. A 'train frame' is then sampled randomly from the sequence. Next, depending on the + frame_sample_mode, the required number of test frames are sampled randomly, either from the range + [train_frame_id - max_gap, train_frame_id + max_gap] in the 'default' mode, or from [train_frame_id, train_frame_id + max_gap] + in the 'causal' mode. Only the frames in which the target is visible are sampled, and if enough visible frames are + not found, the 'max_gap' is incremented. + + The sampled frames are then passed through the input 'processing' function for the necessary processing- + """ + + def __init__(self, + datasets, + p_datasets, + samples_per_epoch, + max_gap, + num_test_frames=1, + processing=no_processing, + frame_sample_mode='default', + neg=0): + """ + args: + datasets - List of datasets to be used for training + p_datasets - List containing the probabilities by which each dataset will be sampled + samples_per_epoch - Number of training samples per epoch + max_gap - Maximum gap, in frame numbers, between the train (reference) frame and the test frames. + num_test_frames - Number of test frames used for calculating the rpn/mask prediction loss. + processing - An instance of Processing class which performs the necessary processing of the data. + frame_sample_mode - Either 'default' or 'causal'. If 'causal', then the test frames are sampled in a causal + manner. + neg - Probability of sampling a negative sample pair. + """ + self.datasets = datasets + + # If p not provided, sample uniformly from all videos + if p_datasets is None: + p_datasets = [1 for d in self.datasets] + + # Normalize + p_total = sum(p_datasets) + self.p_datasets = [x / p_total for x in p_datasets] + + self.samples_per_epoch = samples_per_epoch + self.max_gap = max_gap + self.num_test_frames = num_test_frames + self.num_train_frames = 1 # Only a single train frame allowed + self.processing = processing + self.frame_sample_mode = frame_sample_mode + self.neg = neg + + def __len__(self): + return self.samples_per_epoch + + def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None): + """ Samples num_ids frames between min_id and max_id for which target is visible + + args: + visible - 1d Tensor indicating whether target is visible for each frame + num_ids - number of frames to be samples + min_id - Minimum allowed frame number + max_id - Maximum allowed frame number + + returns: + list - List of sampled frame numbers. None if not sufficient visible frames could be found. 
+ """ + if min_id is None or min_id < 0: + min_id = 0 + if max_id is None or max_id > len(visible): + max_id = len(visible) + + valid_ids = [i for i in range(min_id, max_id) if visible[i]] + + # No visible ids + if len(valid_ids) == 0: + return None + + inds = self.rng.choice(range(len(valid_ids)), size=num_ids, replace=True) + ids = [valid_ids[ii] for ii in inds] + # return random.choices(valid_ids, k=num_ids) + return ids + + def has_mask(self, dataset): + return dataset.get_name() in ['coco', 'youtubevos'] + + def _get_positive_pair(self, dataset): + is_video_dataset = dataset.is_video_sequence() + + min_visible_frames = 2 * (self.num_test_frames + self.num_train_frames) + enough_visible_frames = False + + # Sample a sequence with enough visible frames and get anno for the same + while not enough_visible_frames: + seq_id = self.rng.randint(0, dataset.get_num_sequences() - 1) + anno, visible = dataset.get_sequence_info(seq_id) + num_visible = np.sum(visible.astype('int64')) + enough_visible_frames = not is_video_dataset or ( + num_visible > min_visible_frames and len(visible) >= 20) + + if is_video_dataset: + train_frame_ids = None + test_frame_ids = None + gap_increase = 0 + if self.frame_sample_mode == 'default': + # Sample frame numbers + while test_frame_ids is None: + train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames) + test_frame_ids = self._sample_visible_ids( + visible, + min_id=train_frame_ids[0] - self.max_gap - gap_increase, + max_id=train_frame_ids[0] + self.max_gap + gap_increase, + num_ids=self.num_test_frames) + gap_increase += 5 # Increase gap until a frame is found + elif self.frame_sample_mode == 'causal': + # Sample frame numbers in a causal manner, i.e. test_frame_ids > train_frame_ids + while test_frame_ids is None: + base_frame_id = self._sample_visible_ids( + visible, + num_ids=1, + min_id=self.num_train_frames - 1, + max_id=len(visible) - self.num_test_frames) + prev_frame_ids = self._sample_visible_ids( + visible, num_ids=self.num_train_frames - 1, + min_id=base_frame_id[0] - self.max_gap - gap_increase, + max_id=base_frame_id[0]) + if prev_frame_ids is None: + gap_increase += 5 + continue + train_frame_ids = base_frame_id + prev_frame_ids + test_frame_ids = self._sample_visible_ids( + visible, min_id=train_frame_ids[0] + 1, + max_id=train_frame_ids[0] + self.max_gap + gap_increase, + num_ids=self.num_test_frames) + gap_increase += 5 # Increase gap until a frame is found + else: + raise ValueError('Unknown frame_sample_mode.') + else: + train_frame_ids = [1] * self.num_train_frames + test_frame_ids = [1] * self.num_test_frames + + return seq_id, train_frame_ids, test_frame_ids, anno + + def _get_random_pair(self, train_dataset, test_dataset): + is_video_dataset = train_dataset.is_video_sequence() + + min_visible_frames = self.num_train_frames + enough_visible_frames = False + + # Sample a sequence with enough visible frames and get anno for the same + while not enough_visible_frames: + train_seq_id = self.rng.randint(0, train_dataset.get_num_sequences() - 1) + train_anno, visible = train_dataset.get_sequence_info(train_seq_id) + num_visible = np.sum(visible.astype('int64')) + enough_visible_frames = not is_video_dataset or ( + num_visible > min_visible_frames and len(visible) >= 20) + + if is_video_dataset: + # Sample frame numbers + train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames) + else: + train_frame_ids = [1] * self.num_train_frames + + is_video_dataset = test_dataset.is_video_sequence() + + 
min_visible_frames = self.num_test_frames + enough_visible_frames = False + + # Sample a sequence with enough visible frames and get anno for the same + while not enough_visible_frames: + test_seq_id = self.rng.randint(0, test_dataset.get_num_sequences() - 1) + test_anno, visible = test_dataset.get_sequence_info(test_seq_id) + num_visible = np.sum(visible.astype('int64')) + enough_visible_frames = not is_video_dataset or ( + num_visible > min_visible_frames and len(visible) >= 20) + + if is_video_dataset: + # Sample frame numbers + test_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_test_frames) + else: + test_frame_ids = [1] * self.num_test_frames + + return train_seq_id, test_seq_id, train_frame_ids, test_frame_ids, train_anno, test_anno + + def __iter__(self): + """ + args: + index (int): Index (Ignored since we sample randomly) + + returns: + TensorDict - dict containing all the data blocks + """ + + neg = self.neg and self.neg > random.random() + + # Select a dataset + if neg: + dataset_idx = self.rng.choice( + range(len(self.datasets)), + p=self.p_datasets, + replace=False) + train_dataset = self.datasets[dataset_idx] + + dataset_idx = self.rng.choice( + range(len(self.datasets)), + p=self.p_datasets, + replace=False) + test_dataset = self.datasets[dataset_idx] + train_seq_id, test_seq_id, train_frame_ids, test_frame_ids, train_anno, test_anno = \ + self._get_random_pair(train_dataset, test_dataset) + + # Get frames + train_frames, train_anno, _ = train_dataset.get_frames( + train_seq_id, + train_frame_ids, + train_anno) + train_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32) + for frame in train_frames] + test_frames, test_anno, _ = test_dataset.get_frames( + test_seq_id, + test_frame_ids, + test_anno) + test_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32) + for frame in test_frames] + else: + dataset_idx = self.rng.choice( + range(len(self.datasets)), + p=self.p_datasets, + replace=False) + dataset = self.datasets[dataset_idx] + seq_id, train_frame_ids, test_frame_ids, anno = self._get_positive_pair(dataset) + + # Get frames + if self.has_mask(dataset): + train_frames, train_anno, train_masks, _ = dataset.get_frames_mask( + seq_id, train_frame_ids, anno) + test_frames, test_anno, test_masks, _ = dataset.get_frames_mask( + seq_id, test_frame_ids, anno) + else: + train_frames, train_anno, _ = dataset.get_frames( + seq_id, train_frame_ids, anno) + train_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32) + for frame in train_frames] + test_frames, test_anno, _ = dataset.get_frames(seq_id, test_frame_ids, anno) + test_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32) + for frame in test_frames] + + # Prepare data + data = TensorDict({ + 'train_images': train_frames, + 'train_anno': train_anno, + 'train_masks': train_masks, + 'test_images': test_frames, + 'test_anno': test_anno, + 'test_masks': test_masks, + 'neg': neg + }) + + # Send for processing + yield self.processing(data, rng=self.rng) diff --git a/PaddleCV/tracking/ltr/data/transforms.py b/PaddleCV/tracking/ltr/data/transforms.py index 83c6e36111..a60aaa1fa6 100644 --- a/PaddleCV/tracking/ltr/data/transforms.py +++ b/PaddleCV/tracking/ltr/data/transforms.py @@ -80,6 +80,19 @@ def __call__(self, tensor): return (tensor - self.mean) / self.std +class Transpose(Transform): + """ Transpose image.""" + + def __call__(self, img): + if len(img.shape) == 3: + img = img.transpose((2, 0, 1)) + elif len(img.shape) == 2: + img = 
np.expand_dims(img, axis=0) + else: + raise NotImplementedError + return img.astype('float32') + + class ToArray(Transform): """ Transpose image and jitter brightness""" @@ -146,3 +159,53 @@ def transform(self, img, do_flip): return layers.reverse(img, 2) return np.fliplr(img).copy() return img + + +class Blur(Transform): + """ Blur the image by applying a random kernel.""" + + def __init__(self, probability=0.5): + self.probability = probability + + def roll(self): + return random.random() < self.probability + + def transform(self, img, do_blur): + def rand_kernel(): + sizes = np.arange(5, 46, 2) + size = np.random.choice(sizes) + kernel = np.zeros((size, size)) + c = int(size/2) + wx = np.random.random() + kernel[:, c] += 1. / size * wx + kernel[c, :] += 1. / size * (1-wx) + return kernel + + if do_blur: + kernel = rand_kernel() + img = cv.filter2D(img, -1, kernel) + return img + + +class Color(Transform): + """ Blur the image by applying a random kernel.""" + + def __init__(self, probability=1): + self.probability = probability + self.rgbVar = np.array( + [ + [-0.55919361, 0.98062831, - 0.41940627], + [1.72091413, 0.19879334, - 1.82968581], + [4.64467907, 4.73710203, 4.88324118] + ], + dtype=np.float32) + + def roll(self): + return random.random() < self.probability + + def transform(self, img, do_color_aug): + if do_color_aug: + offset = np.dot(self.rgbVar, np.random.randn(3, 1)) + offset = offset.reshape(3) + img = img - offset + return img diff --git a/PaddleCV/tracking/ltr/dataset/__init__.py b/PaddleCV/tracking/ltr/dataset/__init__.py index 330cd163c6..1504dc793a 100644 --- a/PaddleCV/tracking/ltr/dataset/__init__.py +++ b/PaddleCV/tracking/ltr/dataset/__init__.py @@ -2,7 +2,8 @@ from .got10k import Got10k from .tracking_net import TrackingNet from .imagenetvid import ImagenetVID +from .imagenetdet import ImagenetDET from .coco_seq import MSCOCOSeq from .vot import VOT -from .youtube_vos import VOS +from .youtube_vos import YoutubeVOS from .youtube_bb import YoutubeBB diff --git a/PaddleCV/tracking/ltr/dataset/coco_seq.py b/PaddleCV/tracking/ltr/dataset/coco_seq.py index d55442944d..690f5cb304 100644 --- a/PaddleCV/tracking/ltr/dataset/coco_seq.py +++ b/PaddleCV/tracking/ltr/dataset/coco_seq.py @@ -85,12 +85,19 @@ def _get_anno(self, seq_id): anno = self.coco_set.anns[self.sequence_list[seq_id]]['bbox'] return np.reshape(np.array(anno), (1, 4)) - def _get_frames(self, seq_id): + def _get_frames(self, seq_id, mask=False): path = self.coco_set.loadImgs( [self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0][ 'file_name'] img = self.image_loader(os.path.join(self.img_pth, path)) - return img + + if mask: + ann = self.coco_set.anns[self.sequence_list[seq_id]] + im_mask = (self.coco_set.annToMask(ann).astype(np.float32) > 0.5).astype(np.float32) + im_mask = np.expand_dims(im_mask, axis=2) + return img, im_mask + else: + return img def get_meta_info(self, seq_id): try: @@ -128,3 +135,21 @@ def get_frames(self, seq_id=None, frame_ids=None, anno=None): object_meta = self.get_meta_info(seq_id) return frame_list, anno_frames, object_meta + + def get_frames_mask(self, seq_id=None, frame_ids=None, anno=None): + # COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a + # list containing these replicated images. 
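+        # The mask is decoded from the COCO segmentation annotation:
+        # _get_frames(seq_id, mask=True) calls coco_set.annToMask and returns an
+        # H x W x 1 float32 binary mask alongside the image.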
+ frame, mask = self._get_frames(seq_id, mask=True) + + frame_list = [frame.copy() for _ in frame_ids] + + mask_list = [mask.copy() for _ in frame_ids] + + if anno is None: + anno = self._get_anno(seq_id) + + anno_frames = [anno.copy()[0, :] for _ in frame_ids] + + object_meta = self.get_meta_info(seq_id) + + return frame_list, anno_frames, mask_list, object_meta diff --git a/PaddleCV/tracking/ltr/dataset/imagenetdet.py b/PaddleCV/tracking/ltr/dataset/imagenetdet.py new file mode 100644 index 0000000000..297c3817fa --- /dev/null +++ b/PaddleCV/tracking/ltr/dataset/imagenetdet.py @@ -0,0 +1,143 @@ +import os +import numpy as np +from .base_dataset import BaseDataset +from ltr.data.image_loader import default_image_loader +import xml.etree.ElementTree as ET +import glob +import json +from collections import OrderedDict +import nltk +from nltk.corpus import wordnet +from ltr.admin.environment import env_settings + + +class ImagenetDET(BaseDataset): + """ Imagenet DET dataset. + + Publication: + ImageNet Large Scale Visual Recognition Challenge + Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, + Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei + IJCV, 2015 + https://arxiv.org/pdf/1409.0575.pdf + + Download the dataset from http://image-net.org/ + """ + + def __init__(self, root=None, filter=None, image_loader=default_image_loader): + """ + args: + root - path to the imagenet det dataset. + image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) + is used by default. + """ + root = env_settings().imagenetdet_dir if root is None else root + super().__init__(root, image_loader) + self.filter = filter + + self.set_list = ['ILSVRC2013_train', 'ILSVRC2014_train_0000', + 'ILSVRC2014_train_0001', 'ILSVRC2014_train_0002', + 'ILSVRC2014_train_0003', 'ILSVRC2014_train_0004', + 'ILSVRC2014_train_0005', 'ILSVRC2014_train_0006'] + + cache_file = os.path.join(root, 'cache.json') + if os.path.isfile(cache_file): + # If available, load the pre-processed cache file containing meta-info for each sequence + with open(cache_file, 'r') as f: + sequence_list_dict = json.load(f) + + self.sequence_list = sequence_list_dict + else: + # Else process the imagenet annotations and generate the cache file + self.sequence_list = self._process_anno(root) + + with open(cache_file, 'w') as f: + json.dump(self.sequence_list, f) + + def is_video_sequence(self): + return False + + def get_name(self): + return 'imagenetdet' + + def get_num_sequences(self): + return len(self.sequence_list) + + def get_sequence_info(self, seq_id): + anno = self._get_anno(seq_id) + target_visible = (anno[:, 2] > 0) & (anno[:, 3] > 0) + if self.filter: + target_large = (anno[:, 2] * anno[:, 3] > 30 * 30) + ratio = anno[:, 2] / anno[:, 3] + target_reasonable_ratio = (10 > ratio) & (ratio > 0.1) + target_visible = target_visible & target_reasonable_ratio & target_large + return anno, target_visible + + def _get_anno(self, seq_id): + anno = self.sequence_list[seq_id]['anno'] + return np.reshape(np.array(anno), (1, 4)) + + def _get_frames(self, seq_id): + set_name = self.set_list[self.sequence_list[seq_id]['set_id']] + folder = self.sequence_list[seq_id]['folder'] + if folder == set_name: + folder = '' + filename = self.sequence_list[seq_id]['filename'] + + frame_path = os.path.join(self.root, 'Data', 'DET', 'train', set_name, folder, + '{:s}.JPEG'.format(filename)) + return self.image_loader(frame_path) + + def 
get_frames(self, seq_id, frame_ids, anno=None): + # ImageNet DET is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a + # list containing these replicated images. + frame = self._get_frames(seq_id) + + frame_list = [frame.copy() for _ in frame_ids] + + if anno is None: + anno = self._get_anno(seq_id) + + anno_frames = [anno.copy()[0, :] for _ in frame_ids] + + object_meta = OrderedDict({'object_class': self.sequence_list[seq_id]['class_name'], + 'motion_class': None, + 'major_class': None, + 'root_class': None, + 'motion_adverb': None}) + + return frame_list, anno_frames, object_meta + + def _process_anno(self, root): + # Builds individual tracklets + base_det_anno_path = os.path.join(root, 'Annotations', 'DET', 'train') + + all_sequences = [] + for set_id, set in enumerate(self.set_list): + if set_id == 0: + xmls = sorted(glob.glob(os.path.join(base_det_anno_path, set, '*', '*.xml'))) + else: + xmls = sorted(glob.glob(os.path.join(base_det_anno_path, set, '*.xml'))) + for xml in xmls: + xmltree = ET.parse(xml) + folder = xmltree.find('folder').text + filename = xmltree.find('filename').text + image_size = [int(xmltree.find('size/width').text), int(xmltree.find('size/height').text)] + objects = xmltree.findall('object') + # Find all objects + for id, object_iter in enumerate(objects): + bndbox = object_iter.find('bndbox') + x1 = int(bndbox.find('xmin').text) + y1 = int(bndbox.find('ymin').text) + x2 = int(bndbox.find('xmax').text) + y2 = int(bndbox.find('ymax').text) + object_anno = [x1, y1, x2 - x1, y2 - y1] + class_name = None + if x2 <= x1 or y2 <= y1: + continue + + new_sequence = {'set_id': set_id, 'folder': folder, 'filename': filename, + 'class_name': class_name, 'anno': object_anno, 'image_size': image_size} + all_sequences.append(new_sequence) + + return all_sequences diff --git a/PaddleCV/tracking/ltr/dataset/youtube_vos.py b/PaddleCV/tracking/ltr/dataset/youtube_vos.py index f884272f4b..45474d1cce 100644 --- a/PaddleCV/tracking/ltr/dataset/youtube_vos.py +++ b/PaddleCV/tracking/ltr/dataset/youtube_vos.py @@ -2,151 +2,244 @@ from .base_dataset import BaseDataset from ltr.data.image_loader import default_image_loader import numpy as np -import cv2 as cv import json +import cv2 from collections import OrderedDict from ltr.admin.environment import env_settings -def get_axis_aligned_bbox(region): - region = np.array(region) - if len(region.shape) == 3: - # region (1,4,2) - region = np.array([ - region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1], - region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1] - ]) - - cx = np.mean(region[0::2]) - cy = np.mean(region[1::2]) - x1 = min(region[0::2]) - - x2 = max(region[0::2]) - y1 = min(region[1::2]) - y2 = max(region[1::2]) - - A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[ - 2:4] - region[4:6]) - A2 = (x2 - x1) * (y2 - y1) - s = np.sqrt(A1 / A2) - if s is np.nan: - x11, y11, w, h = 0, 0, 0, 0 - else: - w = s * (x2 - x1) + 1 - h = s * (y2 - y1) + 1 +def get_target_to_image_ratio(seq): + anno = np.array(seq['anno']) + img_sz = np.array(seq['image_size']) + return np.sqrt(anno[0, 2:4].prod() / (img_sz.prod())) + + +class Instance(object): + instID = 0 + pixelCount = 0 + + def __init__(self, imgNp, instID): + if (instID ==0 ): + return + self.instID = int(instID) + self.pixelCount = int(self.getInstancePixels(imgNp, instID)) + + def getInstancePixels(self, imgNp, instLabel): + return (imgNp == instLabel).sum() - x11 = cx - w // 2 - y11 = cy - h 
// 2 - return x11, y11, w, h + def toDict(self): + buildDict = {} + buildDict["instID"] = self.instID + buildDict["pixelCount"] = self.pixelCount + return buildDict + def __str__(self): + return "("+str(self.instID)+")" -class VOS(BaseDataset): - def __init__(self, root=None, image_loader=default_image_loader): - # root = env_settings().vot_dir if root is None else root - assert root is not None + +def xyxy_to_xywh(xyxy): + """Convert [x1 y1 x2 y2] box format to [x1 y1 w h] format.""" + if isinstance(xyxy, (list, tuple)): + # Single box given as a list of coordinates + assert len(xyxy) == 4 + x1, y1 = xyxy[0], xyxy[1] + w = xyxy[2] - x1 + 1 + h = xyxy[3] - y1 + 1 + return (x1, y1, w, h) + elif isinstance(xyxy, np.ndarray): + # Multiple boxes given as a 2D ndarray + return np.hstack((xyxy[:, 0:2], xyxy[:, 2:4] - xyxy[:, 0:2] + 1)) + else: + raise TypeError('Argument xyxy must be a list, tuple, or numpy array.') + + +def polys_to_boxes(polys): + """Convert a list of polygons into an array of tight bounding boxes.""" + boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32) + for i in range(len(polys)): + poly = polys[i] + x0 = min(min(p[::2]) for p in poly) + x1 = max(max(p[::2]) for p in poly) + y0 = min(min(p[1::2]) for p in poly) + y1 = max(max(p[1::2]) for p in poly) + boxes_from_polys[i, :] = [x0, y0, x1, y1] + return boxes_from_polys + + +class YoutubeVOS(BaseDataset): + """ Youtube-VOS dataset. + + Publication: + + https://arxiv.org/pdf/ + + Download the dataset from https://youtube-vos.org/dataset/download + """ + + def __init__(self, root=None, filter=None, image_loader=default_image_loader, min_length=1, max_target_area=1): + """ + args: + root - path to the youtube-vos dataset. + image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) + is used by default. + min_length - Minimum allowed sequence length. + max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets + which cover complete image. 
+ """ + root = env_settings().youtubevos_dir if root is None else root super().__init__(root, image_loader) - with open(os.path.join(self.root, 'meta.json')) as f: - self.meta = json.load(f)['videos'] - - self.sequence_list = self._get_sequence_list() - self.ann = self._get_annotations() - - def _get_sequence_list(self): - seq_list = [] - videos = self.meta.keys() - for v in videos: - objs = self.meta[v]['objects'].keys() - for o in objs: - if "rotate_box" in self.meta[v]['objects'][o]: - seq_list.append((v, o)) - assert len(seq_list) > 0 - return seq_list - - def _get_annotations(self): - ann = {} - for seq in self.sequence_list: - ann[seq] = {'bbox': [], 'rbb': []} - polygons = self.meta[seq[0]]['objects'][seq[1]]['rotate_box'] - for vs in polygons: - if len(vs) == 4: - polys = [ - vs[0], vs[1] + vs[3] - 1, vs[0], vs[1], - vs[0] + vs[2] - 1, vs[1], vs[0] + vs[2] - 1, - vs[1] + vs[3] - 1 - ] - else: - polys = vs - if not np.all(polys == 0): - box = get_axis_aligned_bbox(polys) - rbb = cv.minAreaRect( - np.int0(np.array(polys).reshape((-1, 2)))) - else: - box = np.array([0, 0, 0, 0]) - rbb = ((0, 0), (0, 0), 0) - if box[2] * box[3] > 500 * 500: - print(box) - # assume small rotation angle, switch height, width - if rbb[2] < -45: - angle = rbb[2] + 90 - height = rbb[1][0] - width = rbb[1][1] - else: - angle = rbb[2] - height = rbb[1][1] - width = rbb[1][0] - rbb = [rbb[0][0], rbb[0][1], width, height, angle] - ann[seq]['bbox'].append(box) - ann[seq]['rbb'].append(rbb) - return ann - - def is_video_sequence(self): - return True + cache_file = os.path.join(root, 'cache.json') + if os.path.isfile(cache_file): + # If available, load the pre-processed cache file containing meta-info for each sequence + with open(cache_file, 'r') as f: + sequence_list_dict = json.load(f) + + self.sequence_list = sequence_list_dict + else: + # Else process the youtube-vos annotations and generate the cache file + print('processing the youtube-vos annotations...') + self.sequence_list = self._process_anno(root) + + with open(cache_file, 'w') as f: + json.dump(self.sequence_list, f) + print('cache file generated!') + + # Filter the sequences based on min_length and max_target_area in the first frame + self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and + get_target_to_image_ratio(x) < max_target_area] + self.filter = filter def get_name(self): - return 'vot' + return 'youtubevos' def get_num_sequences(self): return len(self.sequence_list) def get_sequence_info(self, seq_id): - anno = self._get_anno(seq_id) + anno = np.array(self.sequence_list[seq_id]['anno']) target_visible = (anno[:, 2] > 0) & (anno[:, 3] > 0) - target_large = (anno[:, 2] * anno[:, 3] > 30 * 30) - target_resonable = (anno[:, 2] * anno[:, 3] < 500 * 500) - return anno, target_visible & target_large & target_resonable - - def _get_anno(self, seq_id): - anno = self.ann[self.sequence_list[seq_id]]['bbox'] - return np.reshape(np.array(anno), (-1, 4)) - - def get_meta_info(self, seq_id): - object_meta = OrderedDict({ - 'object_class': None, - 'motion_class': None, - 'major_class': None, - 'root_class': None, - 'motion_adverb': None - }) - return object_meta - - def _get_frame_path(self, seq_id, frame_id): - v, o = self.sequence_list[seq_id] - frame_name = self.meta[v]['objects'][o]['frames'][frame_id] - return os.path.join(self.root, 'JPEGImages', v, - '{}.jpg'.format(frame_name)) # frames start from 1 - - def _get_frame(self, seq_id, frame_id): - return self.image_loader(self._get_frame_path(seq_id, frame_id)) - - def 
get_frames(self, seq_id=None, frame_ids=None, anno=None): - frame_list = [self._get_frame(seq_id, f_id) for f_id in frame_ids] + if self.filter is not None: + target_large = (anno[:, 2] * anno[:, 3] > 30 * 30) + ratio = anno[:, 2] / anno[:, 3] + target_reasonable_ratio = (10 > ratio) & (ratio > 0.1) + target_visible = target_visible & target_reasonable_ratio & target_large + return anno, target_visible + + def _get_frame(self, sequence, frame_id): + vid_name = sequence['video'] + frame_number = sequence['frames'][frame_id] + + frame_path = os.path.join(self.root, 'train', 'JPEGImages', vid_name, + '{:05d}.jpg'.format(frame_number)) + return self.image_loader(frame_path) + + def _get_mask(self, sequence, frame_id): + vid_name = sequence['video'] + frame_number = sequence['frames'][frame_id] + id = sequence['id'] + + mask_path = os.path.join(self.root, 'train', 'Annotations', vid_name, + '{:05d}.png'.format(frame_number)) + mask = cv2.imread(mask_path, 0) + mask = (mask == id).astype(np.float32) + mask = np.expand_dims(mask, axis=2) + return mask + + def get_frames(self, seq_id, frame_ids, anno=None): + sequence = self.sequence_list[seq_id] + + frame_list = [self._get_frame(sequence, f) for f in frame_ids] if anno is None: - anno = self._get_anno(seq_id) + anno = sequence['anno'] + # Return as list of tensors anno_frames = [anno[f_id, :] for f_id in frame_ids] - object_meta = self.get_meta_info(seq_id) + # added the class info to the meta info + object_meta = OrderedDict({'object_class': sequence['class_name'], + 'motion_class': None, + 'major_class': None, + 'root_class': None, + 'motion_adverb': None}) return frame_list, anno_frames, object_meta + + def get_frames_mask(self, seq_id, frame_ids, anno=None): + sequence = self.sequence_list[seq_id] + + frame_list = [self._get_frame(sequence, f) for f in frame_ids] + mask_list = [self._get_mask(sequence, f) for f in frame_ids] + + if anno is None: + anno = sequence['anno'] + + # Return as list of tensors + anno_frames = [anno[f_id, :] for f_id in frame_ids] + + # added the class info to the meta info + object_meta = OrderedDict({'object_class': sequence['class_name'], + 'motion_class': None, + 'major_class': None, + 'root_class': None, + 'motion_adverb': None}) + + return frame_list, anno_frames, mask_list, object_meta + + def _process_anno(self, root): + # Builds individual tracklets + base_anno_path = os.path.join(root, 'train', 'Annotations') + + num_obj = 0 + num_ann = 0 + all_sequences = [] + meta = json.load(open(os.path.join(base_anno_path, '../meta.json'))) + for vid_id, video in enumerate(meta['videos']): + v = meta['videos'][video] + frames = [] + objects = dict() + for obj in v['objects']: + o = v['objects'][obj] + frames.extend(o['frames']) + frames = sorted(set(frames)) + + for frame in frames: + file_name = os.path.join(video, frame) + img = cv2.imread(os.path.join(base_anno_path, file_name+'.png'), 0) + h, w = img.shape[:2] + image_size = [w, h] + + for instanceId in np.unique(img): + if instanceId == 0: + continue + instanceObj = Instance(img, instanceId) + instanceObj_dict = instanceObj.toDict() + mask = (img == instanceId).astype(np.uint8) + if cv2.__version__[0] == '3': + _, contour, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + else: + contour, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + polygons = [c.reshape(-1).tolist() for c in contour] + instanceObj_dict['contours'] = [p for p in polygons if len(p) > 4] + if len(instanceObj_dict['contours']) and 
instanceObj_dict['pixelCount'] > 1000: + len_p = [len(p) for p in instanceObj_dict['contours']] + if min(len_p) <= 4: + print('Warning: invalid contours.') + continue # skip non-instance categories + + bbox = xyxy_to_xywh( + polys_to_boxes([instanceObj_dict['contours']])).tolist()[0] + if instanceId not in objects: + objects[instanceId] = \ + {'anno': [], 'frames': [], 'image_size': image_size} + objects[instanceId]['anno'].append(bbox) + objects[instanceId]['frames'].append(int(frame)) + + for obj in objects: + new_sequence = {'video': video, 'id': int(obj), 'class_name': None, + 'frames': objects[obj]['frames'], 'anno': objects[obj]['anno'], + 'image_size': image_size} + all_sequences.append(new_sequence) + print('Youtube-VOS: ', len(all_sequences)) + return all_sequences diff --git a/PaddleCV/tracking/ltr/models/backbone/alexnet.py b/PaddleCV/tracking/ltr/models/backbone/alexnet.py new file mode 100644 index 0000000000..e70a70a572 --- /dev/null +++ b/PaddleCV/tracking/ltr/models/backbone/alexnet.py @@ -0,0 +1,178 @@ +import os + +from paddle import fluid +from paddle.fluid.dygraph import nn +from ltr.admin.environment import env_settings + +CURRENT_DIR = os.path.dirname(__file__) + + +class alexnet(fluid.dygraph.Layer): + def __init__(self, name, is_test, output_layers): + super(alexnet, self).__init__() + + self.is_test = is_test + self.layer_init() + self.output_layers = output_layers + + def layer_init(self): + # for conv1 + self.conv1 = nn.Conv2D( + num_channels=3, + num_filters=96, + filter_size=11, + stride=2, + padding=0, + groups=1, + param_attr=self.weight_init(), + bias_attr=self.bias_init()) + self.bn1 = nn.BatchNorm( + num_channels=96, + is_test=self.is_test, + param_attr=self.norm_weight_init(), + bias_attr=self.bias_init(), + use_global_stats=self.is_test) + self.pool1 = nn.Pool2D( + pool_size=3, pool_type="max", pool_stride=2, pool_padding=0) + # for conv2 + self.conv2 = nn.Conv2D( + num_channels=96, + num_filters=256, + filter_size=5, + stride=1, + padding=0, + groups=1, + param_attr=self.weight_init(), + bias_attr=self.bias_init()) + self.bn2 = nn.BatchNorm( + num_channels=256, + is_test=self.is_test, + param_attr=self.norm_weight_init(), + bias_attr=self.bias_init(), + use_global_stats=self.is_test) + self.pool2 = nn.Pool2D( + pool_size=3, pool_type="max", pool_stride=2, pool_padding=0) + # for conv3 + self.conv3 = nn.Conv2D( + num_channels=256, + num_filters=384, + filter_size=3, + stride=1, + padding=0, + groups=1, + param_attr=self.weight_init(), + bias_attr=self.bias_init()) + self.bn3 = nn.BatchNorm( + num_channels=384, + is_test=self.is_test, + param_attr=self.norm_weight_init(), + bias_attr=self.bias_init(), + use_global_stats=self.is_test) + # for conv4 + self.conv4 = nn.Conv2D( + num_channels=384, + num_filters=384, + filter_size=3, + stride=1, + padding=0, + groups=1, + param_attr=self.weight_init(), + bias_attr=self.bias_init()) + self.bn4 = nn.BatchNorm( + num_channels=384, + is_test=self.is_test, + param_attr=self.norm_weight_init(), + bias_attr=self.bias_init(), + use_global_stats=self.is_test) + # for conv5 + self.conv5 = nn.Conv2D( + num_channels=384, + num_filters=256, + filter_size=3, + stride=1, + padding=0, + groups=1, + param_attr=self.weight_init(), + bias_attr=self.bias_init()) + self.bn5 = nn.BatchNorm( + num_channels=256, + is_test=self.is_test, + param_attr=self.norm_weight_init(), + bias_attr=self.bias_init(), + use_global_stats=self.is_test) + + def _add_output_and_check(self, name, x, outputs): + if name in self.output_layers: + 
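+            # 'name' was requested via output_layers, so record this activation;
+            # the boolean returned below tells forward() whether every requested
+            # layer has now been collected, allowing it to return early.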
outputs.append(x) + return len(self.output_layers) == len(outputs) + + @fluid.dygraph.no_grad + def forward(self, inputs): + outputs = [] + + out1 = self.conv1(inputs) + out1 = self.bn1(out1) + out1 = fluid.layers.relu(out1) + if self._add_output_and_check('conv1', out1, outputs): + outputs[-1].stop_gradient = True if self.is_test else False + return outputs[0] if len(outputs) == 1 else outputs + + out1 = self.pool1(out1) + + out2 = self.conv2(out1) + out2 = self.bn2(out2) + out2 = fluid.layers.relu(out2) + if self._add_output_and_check('conv2', out2, outputs): + outputs[-1].stop_gradient = True if self.is_test else False + return outputs[0] if len(outputs) == 1 else outputs + + out2 = self.pool2(out2) + + out3 = self.conv3(out2) + out3 = self.bn3(out3) + out3 = fluid.layers.relu(out3) + if self._add_output_and_check('conv3', out3, outputs): + outputs[-1].stop_gradient = True if self.is_test else False + return outputs[0] if len(outputs) == 1 else outputs + + out4 = self.conv4(out3) + out4 = self.bn4(out4) + out4 = fluid.layers.relu(out4) + if self._add_output_and_check('conv4', out4, outputs): + outputs[-1].stop_gradient = True if self.is_test else False + return outputs[0] if len(outputs) == 1 else outputs + + out5 = self.conv5(out4) + out5 = self.bn5(out5) + if self._add_output_and_check('conv5', out5, outputs): + outputs[-1].stop_gradient = True if self.is_test else False + return outputs[0] if len(outputs) == 1 else outputs + + outputs[-1].stop_gradient = True if self.is_test else False + return outputs[0] if len(outputs) == 1 else outputs + + def norm_weight_init(self): + init = fluid.initializer.ConstantInitializer(1.0) + param = fluid.ParamAttr(initializer=init) + return param + + def weight_init(self): + init = fluid.initializer.MSRAInitializer(uniform=False) + param = fluid.ParamAttr(initializer=init) + return param + + def bias_init(self): + init = fluid.initializer.ConstantInitializer(value=0.) + param = fluid.ParamAttr(initializer=init) + return param + + +def AlexNet(name, is_test, output_layers, pretrained=False): + net = alexnet(name, is_test=is_test, output_layers=output_layers) + if pretrained: + params_path = os.path.join(env_settings().backbone_dir, 'AlexNet') + print("=> loading backbone model from '{}'".format(params_path)) + params, _ = fluid.load_dygraph(params_path) + net.load_dict(params) + print("Done") + return net diff --git a/PaddleCV/tracking/ltr/models/backbone/resnet_dilated.py b/PaddleCV/tracking/ltr/models/backbone/resnet_dilated.py new file mode 100644 index 0000000000..52240f5370 --- /dev/null +++ b/PaddleCV/tracking/ltr/models/backbone/resnet_dilated.py @@ -0,0 +1,320 @@ +import os + +import paddle.fluid as fluid +import paddle.fluid.dygraph.nn as nn +from ltr.admin.environment import env_settings + +CURRENT_DIR = os.path.dirname(__file__) + + +def weight_init(): + init = fluid.initializer.MSRAInitializer(uniform=False) + param = fluid.ParamAttr(initializer=init) + return param + + +def norm_weight_init(constant=1.0): + init = fluid.initializer.ConstantInitializer(constant) + param = fluid.ParamAttr(initializer=init) + return param + + +def norm_bias_init(): + init = fluid.initializer.ConstantInitializer(value=0.) 
+ param = fluid.ParamAttr(initializer=init) + return param + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + in_channels, + out_channels, + filter_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bn_init_constant=1.0): + super(ConvBNLayer, self).__init__() + + self.conv = nn.Conv2D( + num_channels=in_channels, + filter_size=filter_size, + num_filters=out_channels, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + param_attr=weight_init(), + bias_attr=False) + self.bn = nn.BatchNorm( + out_channels, + param_attr=norm_weight_init(bn_init_constant), + bias_attr=norm_bias_init(), + act=None, + momentum=0.1, + use_global_stats=True) + + def forward(self, inputs): + res = self.conv(inputs) + self.conv_res = res + res = self.bn(res) + return res + + +class BasicBlock(fluid.dygraph.Layer): + expansion = 1 + + def __init__(self, + in_channels, + out_channels, + stride=1, + is_downsample=None): + + super(BasicBlock, self).__init__() + + self.expansion = 1 + + self.conv_bn1 = ConvBNLayer( + num_channels=in_channels, + out_channels=out_channels, + filter_size=3, + stride=stride, + groups=1) + self.conv_bn2 = ConvBNLayer( + out_channels=out_channels, + filter_size=3, + stride=1, + groups=1) + + self.is_downsample = is_downsample + if self.is_downsample: + self.downsample = ConvBNLayer( + num_channels=in_channels, + out_channels=out_channels, + filter_size=1, + stride=stride) + + self.stride = stride + + def forward(self, inputs): + identity = inputs + res = self.conv_bn1(inputs) + res = fluid.layers.relu(res) + + res = self.conv_bn2(res) + + if self.is_downsample: + identity = self.downsample(identity) + + res += identity + res = fluid.layers.relu(res) + return res + + +class Bottleneck(fluid.dygraph.Layer): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1, + downsample=None, + base_width=64, + dilation=1, + groups=1): + super(Bottleneck, self).__init__() + + width = int(out_channels*(base_width / 64.))*groups + + self.conv_bn1 = ConvBNLayer( + in_channels=in_channels, + filter_size=1, + out_channels=width, + groups=1) + + padding = 2 - stride + if downsample is not None and dilation > 1: + dilation = dilation // 2 + padding = dilation + + assert stride == 1 or dilation == 1, \ + "stride and dilation must have one equals to zero at least" + + if dilation > 1: + padding = dilation + + self.conv_bn2 = ConvBNLayer( + in_channels=width, + filter_size=3, + out_channels=width, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + self.conv_bn3 = ConvBNLayer( + in_channels=width, + filter_size=1, + out_channels=out_channels*self.expansion, + bn_init_constant=0.) 
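+        # The final BatchNorm of this residual branch is constant-initialized
+        # to 0 (bn_init_constant=0.), so every bottleneck initially acts as an
+        # identity mapping; this matches the common zero-gamma residual
+        # initialization trick (our reading of the intent here).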
+ + self.downsample = downsample + self.stride = stride + + def forward(self, inputs): + identify = inputs + + out = self.conv_bn1(inputs) + out = fluid.layers.relu(out) + + out = self.conv_bn2(out) + out = fluid.layers.relu(out) + + out = self.conv_bn3(out) + + if self.downsample is not None: + identify = self.downsample(inputs) + + out += identify + out = fluid.layers.relu(out) + return out + + +class ResNet(fluid.dygraph.Layer): + def __init__(self, name, Block, layers, output_layers, is_test=False): + """ + + :param name: str, namescope + :param layers: int, the layer of defined network + :param output_layers: list of int, the layers for output + """ + super(ResNet, self).__init__(name_scope=name) + + support_layers = [50] + assert layers in support_layers, \ + "support layer can only be one of [50, ]" + self.layers = layers + self.feat_layers = ['block{}'.format(i) for i in output_layers] + output_depth = max(output_layers) + 1 + self.is_test = is_test + + if layers == 18: + depths = [2, 2, 2, 2] + elif layers == 50 or layers == 34: + depths = [3, 4, 6, 3] + elif layers == 101: + depths = [3, 4, 23, 3] + elif layers == 152: + depths = [3, 8, 36, 3] + + strides = [1, 2, 1, 1] + num_filters = [64, 128, 256, 512] + dilations = [1, 1, 2, 4] + + self.in_channels = 64 + self.dilation = 1 + + self.conv_bn_init = ConvBNLayer( + in_channels=3, + out_channels=self.in_channels, + filter_size=7, + stride=2) + + self.maxpool = nn.Pool2D( + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type="max") + + block_collect = [] + downsample = None + for i in range(min(len(depths), output_depth)): + # collect layers in each block + _block = [] + + stride = strides[i] + out_channel = num_filters[i] + dilation = dilations[i] + + if stride != 1 or self.in_channels != self.in_channels*Block.expansion: + if stride == 1 and dilation == 1: + downsample = ConvBNLayer( + in_channels=self.in_channels, + out_channels=out_channel*Block.expansion, + filter_size=1, + stride=stride) + else: + if dilation > 1: + dd = dilation // 2 + padding = dd + else: + dd = 1 + padding = 0 + downsample = ConvBNLayer( + in_channels=self.in_channels, + out_channels=out_channel*Block.expansion, + filter_size=3, + stride=stride, + padding=padding, + dilation=dd) + + bottleneck_block = self.add_sublayer( + "block{}_0".format(i), + Block( + in_channels=self.in_channels, + out_channels=out_channel, + stride=stride, + dilation=dilation, + downsample=downsample)) + + _block.append(bottleneck_block) + + self.in_channels = num_filters[i]*Block.expansion + + for j in range(1, depths[i]): + bottleneck_block = self.add_sublayer( + "block{}_{}".format(i, j), + Block(self.in_channels, out_channel, dilation=dilation)) + _block.append(bottleneck_block) + + # collect blocks + block_collect.append(_block) + + self.block_collect = block_collect + + @fluid.dygraph.no_grad + def forward(self, inputs): + out = [] + res = self.conv_bn_init(inputs) + res = fluid.layers.relu(res) + out.append(res) + res = self.maxpool(res) + for i in range(len(self.block_collect)): + + for layer in self.block_collect[i]: + res = layer(res) + + name = 'block{}'.format(i) + if name in self.feat_layers: + out.append(res) + if (len(out) - 1) == len(self.feat_layers): + out[-1].stop_gradient = True if self.is_test else False + if len(out) == 1: + return out[0] + else: + return out + + out[-1].stop_gradient = True if self.is_test else False + return out + + +def resnet50(name, pretrained=False, **kwargs): + net = ResNet(name, Block=Bottleneck, layers=50, **kwargs) + if 
pretrained: + params_path = os.path.join(env_settings().backbone_dir, 'ResNet50_dilated') + print("=> loading backbone model from '{}'".format(params_path)) + params, _ = fluid.load_dygraph(params_path) + net.load_dict(params) + print("Done") + + return net diff --git a/PaddleCV/tracking/ltr/models/loss.py b/PaddleCV/tracking/ltr/models/loss.py new file mode 100644 index 0000000000..e28b003887 --- /dev/null +++ b/PaddleCV/tracking/ltr/models/loss.py @@ -0,0 +1,121 @@ +import paddle.fluid as fluid +import numpy as np + + +def get_cls_loss(pred, label, select): + if select.shape[0] == 0: + return fluid.layers.reduce_sum(pred) * 0 + pred = fluid.layers.gather(pred, select) + label = fluid.layers.gather(label, select) + label = fluid.layers.reshape(label, [-1, 1]) + loss = fluid.layers.softmax_with_cross_entropy( + logits = pred, + label = label) + return fluid.layers.mean(loss) + + +def select_softmax_with_cross_entropy_loss(pred, label): + b, c, h, w = pred.shape + pred = fluid.layers.reshape(pred, [b, 2, -1, h, w]) + pred = fluid.layers.transpose(pred, [0, 2, 3, 4, 1]) + pred = fluid.layers.reshape(pred, [-1, 2]) + label = fluid.layers.reshape(label, [-1]) + pos = fluid.layers.where(label == 1) + neg = fluid.layers.where(label == 0) + loss_pos = get_cls_loss(pred, label, pos) + loss_neg = get_cls_loss(pred, label, neg) + return loss_pos * 0.5 + loss_neg * 0.5 + + +def weight_l1_loss(pred_loc, label_loc, loss_weight): + b, c, h, w = pred_loc.shape + pred_loc = fluid.layers.reshape(pred_loc, [b, 4, -1, h, w]) + loss = fluid.layers.abs(pred_loc - label_loc) + loss = fluid.layers.reduce_sum(loss, dim=1) + loss = loss * loss_weight + return fluid.layers.reduce_sum(loss) / b + + +def soft_margin_loss(pred, label): + #loss = fluid.layers.elementwise_mul(pred, label) + loss = fluid.layers.exp(-1 * pred * label) + loss = fluid.layers.log(1 + loss) + return fluid.layers.reduce_mean(loss) + + +def iou_measure(pred, label): + pred = fluid.layers.cast(pred >= 0, 'float32') + pred = fluid.layers.cast(pred == 1, 'float32') + label = fluid.layers.cast(label == 1, 'float32') + mask_sum = pred + label + intxn = fluid.layers.reduce_sum( + fluid.layers.cast(mask_sum == 2, 'float32'), dim=1) + union = fluid.layers.reduce_sum( + fluid.layers.cast(mask_sum > 0, 'float32'), dim=1) + iou = intxn / union + iou_m = fluid.layers.reduce_mean(iou) + iou_5 = fluid.layers.cast(iou > 0.5, 'float32') + iou_5 = fluid.layers.reduce_sum(iou_5) / iou.shape[0] + iou_7 = fluid.layers.cast(iou > 0.7, 'float32') + iou_7 = fluid.layers.reduce_sum(iou_7) / iou.shape[0] + return iou_m, iou_5, iou_7 + + +def select_mask_logistic_loss(pred_mask, label_mask, loss_weight, out_size=63, gt_size=127): + loss_weight = fluid.layers.reshape(loss_weight, [-1]) + pos = loss_weight == 1 + if np.sum(pos.numpy()) == 0: + return fluid.layers.reduce_sum(pred_mask) * 0, fluid.layers.reduce_sum(pred_mask) * 0, fluid.layers.reduce_sum(pred_mask) * 0, fluid.layers.reduce_sum(pred_mask) * 0 + pos = fluid.layers.where(pos) + if len(pred_mask.shape) == 4: + pred_mask = fluid.layers.transpose(pred_mask, [0, 2, 3, 1]) + pred_mask = fluid.layers.reshape(pred_mask, [-1, 1, out_size, out_size]) + pred_mask = fluid.layers.gather(pred_mask, pos) + pred_mask = fluid.layers.resize_bilinear(pred_mask, out_shape=[gt_size, gt_size]); + pred_mask = fluid.layers.reshape(pred_mask, [-1, gt_size * gt_size]) + label_mask_uf = fluid.layers.unfold(label_mask, [gt_size, gt_size], 8, 32) + else: + pred_mask = fluid.layers.gather(pred_mask, pos) + label_mask_uf = 
fluid.layers.unfold(label_mask, [gt_size, gt_size], 8, 0) + + label_mask_uf = fluid.layers.transpose(label_mask_uf, [0, 2, 1]) + label_mask_uf = fluid.layers.reshape(label_mask_uf, [-1, gt_size * gt_size]) + + label_mask_uf = fluid.layers.gather(label_mask_uf, pos) + loss = soft_margin_loss(pred_mask, label_mask_uf) + if np.isnan(loss.numpy()): + return fluid.layers.reduce_sum(pred_mask) * 0, fluid.layers.reduce_sum(pred_mask) * 0, fluid.layers.reduce_sum(pred_mask) * 0, fluid.layers.reduce_sum(pred_mask) * 0 + iou_m, iou_5, iou_7 = iou_measure(pred_mask, label_mask_uf) + return loss, iou_m, iou_5, iou_7 + + +if __name__ == "__main__": + import numpy as np + pred_mask = np.random.randn(4, 63*63, 25, 25) + weight_mask = np.random.randn(4, 1, 25, 25) > 0.9 + label_mask = np.random.randint(-1, 1, (4, 1, 255, 255)) + + pred_loc = np.random.randn(3, 32, 17, 17) + weight_loc = np.random.randn(3, 8, 17, 17) + label_loc = np.random.randn(3, 4, 8, 17, 17) + + pred_cls = np.random.randn(3, 16, 17, 17) + label_cls = np.random.randint(0, 2, (3, 8, 17, 17)) + + with fluid.dygraph.guard(): + pred_mask = fluid.dygraph.to_variable(pred_mask) + weight_mask = fluid.dygraph.to_variable(weight_mask.astype('float32')) + label_mask = fluid.dygraph.to_variable(label_mask.astype('float32')) + loss = select_mask_logistic_loss(pred_mask, label_mask, weight_mask) + print("loss_mask = ", loss) + + pred_loc = fluid.dygraph.to_variable(pred_loc) + weight_loc = fluid.dygraph.to_variable(weight_loc) + label_loc = fluid.dygraph.to_variable(label_loc) + loss = weight_l1_loss(pred_loc, label_loc, weight_loc) + print("loss_loc = ", loss) + + pred_cls = fluid.dygraph.to_variable(pred_cls) + label_cls = fluid.dygraph.to_variable(label_cls) + loss = select_softmax_with_cross_entropy_loss(pred_cls, label_cls) + print("loss_cls = ", loss) diff --git a/PaddleCV/tracking/ltr/models/siam/head.py b/PaddleCV/tracking/ltr/models/siam/head.py new file mode 100644 index 0000000000..ed34a09ac0 --- /dev/null +++ b/PaddleCV/tracking/ltr/models/siam/head.py @@ -0,0 +1,303 @@ +import paddle.fluid as fluid +import paddle.fluid.dygraph.nn as nn +import os.path as osp +import sys + +from ltr.models.siam.xcorr import xcorr, xcorr_depthwise + +CURRENT_DIR = osp.dirname(__file__) +sys.path.append(osp.join(CURRENT_DIR, '..', '..', '..')) + + +def weight_init(): + init = fluid.initializer.MSRAInitializer(uniform=False) + param = fluid.ParamAttr(initializer=init) + return param + + +def bias_init(): + init = fluid.initializer.ConstantInitializer(value=0.) + param = fluid.ParamAttr(initializer=init) + return param + + +def norm_weight_init(): + init = fluid.initializer.Uniform(low=0., high=1.) + param = fluid.ParamAttr(initializer=init) + return param + + +def norm_bias_init(): + init = fluid.initializer.ConstantInitializer(value=0.) 
+ param = fluid.ParamAttr(initializer=init) + return param + + +class RPN(fluid.dygraph.Layer): + def __init__(self): + super(RPN, self).__init__() + + def forward(self, z_f, x_f): + raise NotImplementedError + + +class DepthwiseXCorr(fluid.dygraph.Layer): + def __init__(self, + in_channels, + hidden, + out_channels, + filter_size=3, + is_test=False): + super(DepthwiseXCorr, self).__init__() + self.kernel_conv1 = nn.Conv2D( + num_channels=in_channels, + num_filters=hidden, + filter_size=filter_size, + stride=1, + padding=0, + groups=1, + param_attr=weight_init(), + bias_attr=False) + self.kernel_bn1 = nn.BatchNorm( + num_channels=hidden, + act='relu', + param_attr=norm_weight_init(), + bias_attr=norm_bias_init(), + momentum=0.9, + use_global_stats=is_test) + + self.search_conv1 = nn.Conv2D( + num_channels=in_channels, + num_filters=hidden, + filter_size=filter_size, + stride=1, + padding=0, + groups=1, + param_attr=weight_init(), + bias_attr=False) + self.search_bn1 = nn.BatchNorm( + num_channels=hidden, + act='relu', + param_attr=norm_weight_init(), + bias_attr=norm_bias_init(), + momentum=0.9, + use_global_stats=is_test) + + self.head_conv1 = nn.Conv2D( + num_channels=hidden, + num_filters=hidden, + filter_size=1, + stride=1, + padding=0, + groups=1, + param_attr=weight_init(), + bias_attr=False) + self.head_bn1 = nn.BatchNorm( + num_channels=hidden, + act='relu', + param_attr=norm_weight_init(), + bias_attr=norm_bias_init(), + momentum=0.9, + use_global_stats=is_test) + self.head_conv2 = nn.Conv2D( + num_channels=hidden, + num_filters=out_channels, + filter_size=1, + stride=1, + padding=0, + groups=1, + param_attr=weight_init()) + + + def forward(self, kernel, search): + kernel = self.kernel_conv1(kernel) + kernel = self.kernel_bn1(kernel) + + search = self.search_conv1(search) + search = self.search_bn1(search) + + feature = xcorr_depthwise(search, kernel) + out = self.head_conv1(feature) + out = self.head_bn1(out) + out = self.head_conv2(out) + return out + + +class DepthwiseRPN(RPN): + def __init__(self, anchor_num=5, in_channels=256, out_channels=256, is_test=False): + super(DepthwiseRPN, self).__init__() + self.cls = DepthwiseXCorr(in_channels, out_channels, 2 * anchor_num, is_test=is_test) + self.loc = DepthwiseXCorr(in_channels, out_channels, 4 * anchor_num, is_test=is_test) + + def forward(self, z_f, x_f): + cls = self.cls(z_f, x_f) + loc = self.loc(z_f, x_f) + return cls, loc + + +class MaskCorr(DepthwiseXCorr): + def __init__(self, + in_channels, + hidden, + out_channels, + filter_size=3, + hidden_filter_size=5, + is_test=False): + super(MaskCorr, self).__init__( + in_channels, + hidden, + out_channels, + filter_size, + is_test) + + def forward(self, kernel, search): + kernel = self.kernel_conv1(kernel) + kernel = self.kernel_bn1(kernel) + + search = self.search_conv1(search) + search = self.search_bn1(search) + + feature = xcorr_depthwise(search, kernel) + out = self.head_conv1(feature) + out = self.head_bn1(out) + out = self.head_conv2(out) + return out, feature + + +class RefineModule(fluid.dygraph.Layer): + def __init__(self, + in_channels, + hidden1, + hidden2, + out_channels, + out_shape, + filter_size=3, + padding=1): + super(RefineModule, self).__init__() + self.v_conv0 = nn.Conv2D( + num_channels=in_channels, + num_filters=hidden1, + filter_size=filter_size, + stride=1, + padding=padding, + groups=1, + param_attr=weight_init()) + self.v_conv1 = nn.Conv2D( + num_channels=hidden1, + num_filters=hidden2, + filter_size=filter_size, + stride=1, + padding=padding, + 
groups=1, + param_attr=weight_init()) + self.h_conv0 = nn.Conv2D( + num_channels=hidden2, + num_filters=hidden2, + filter_size=filter_size, + stride=1, + padding=padding, + groups=1, + param_attr=weight_init()) + self.h_conv1 = nn.Conv2D( + num_channels=hidden2, + num_filters=hidden2, + filter_size=filter_size, + stride=1, + padding=padding, + groups=1, + param_attr=weight_init()) + + self.out_shape = out_shape + self.post = nn.Conv2D( + num_channels=hidden2, + num_filters=out_channels, + filter_size=filter_size, + stride=1, + padding=padding, + groups=1, + param_attr=weight_init()) + + def forward(self, xh, xv): + yh = self.h_conv0(xh) + yh = fluid.layers.relu(yh) + yh = self.h_conv1(yh) + yh = fluid.layers.relu(yh) + + yv = self.v_conv0(xv) + yv = fluid.layers.relu(yv) + yv = self.v_conv1(yv) + yv = fluid.layers.relu(yv) + + out = yh + yv + out = fluid.layers.resize_nearest(out, out_shape=self.out_shape, align_corners=False) + out = self.post(out) + return out + +class Refine(fluid.dygraph.Layer): + def __init__(self): + super(Refine, self).__init__() + self.U4 = RefineModule( + in_channels=64, + hidden1=16, + hidden2=4, + out_channels=1, + filter_size=3, + padding=1, + out_shape=[127, 127]) + + self.U3 = RefineModule( + in_channels=256, + hidden1=64, + hidden2=16, + out_channels=4, + filter_size=3, + padding=1, + out_shape=[61, 61]) + + self.U2 = RefineModule( + in_channels=512, + hidden1=128, + hidden2=32, + out_channels=16, + filter_size=3, + padding=1, + out_shape=[31, 31]) + + self.deconv = nn.Conv2DTranspose( + num_channels=256, + num_filters=32, + filter_size=15, + padding=0, + stride=15) + + + def forward(self, xf, corr_feature, pos=None, test=False): + if test: + p0 = fluid.layers.pad2d(xf[0], [16, 16, 16, 16]) + p0 = p0[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61] + p1 = fluid.layers.pad2d(xf[1], [8, 8, 8, 8]) + p1 = p1[:, :, 2*pos[0]:2*pos[0]+31, 2*pos[1]:2*pos[1]+31] + p2 = fluid.layers.pad2d(xf[2], [4, 4, 4, 4]) + p2 = p2[:, :, pos[0]:pos[0]+15, pos[1]:pos[1]+15] + p3 = corr_feature[:, :, pos[0], pos[1]] + p3 = fluid.layers.reshape(p3, [-1, 256, 1, 1]) + else: + p0 = fluid.layers.unfold(xf[0], [61, 61], 4, 0) + p0 = fluid.layers.transpose(p0, [0, 2, 1]) + p0 = fluid.layers.reshape(p0, [-1, 64, 61, 61]) + p1 = fluid.layers.unfold(xf[1], [31, 31], 2, 0) + p1 = fluid.layers.transpose(p1, [0, 2, 1]) + p1 = fluid.layers.reshape(p1, [-1, 256, 31, 31]) + p2 = fluid.layers.unfold(xf[2], [15, 15], 1, 0) + p2 = fluid.layers.transpose(p2, [0, 2, 1]) + p2 = fluid.layers.reshape(p2, [-1, 512, 15, 15]) + p3 = fluid.layers.transpose(corr_feature, [0, 2, 3, 1]) + p3 = fluid.layers.reshape(p3, [-1, 256, 1, 1]) + + out = self.deconv(p3) + out = self.U2(out, p2) + out = self.U3(out, p1) + out = self.U4(out, p0) + out = fluid.layers.reshape(out, [-1, 127*127]) + + return out diff --git a/PaddleCV/tracking/ltr/models/siam/neck.py b/PaddleCV/tracking/ltr/models/siam/neck.py new file mode 100644 index 0000000000..e59183123e --- /dev/null +++ b/PaddleCV/tracking/ltr/models/siam/neck.py @@ -0,0 +1,83 @@ +import paddle.fluid as fluid +import paddle.fluid.dygraph.nn as nn +import os.path as osp +import sys + +CURRENT_DIR = osp.dirname(__file__) +sys.path.append(osp.join(CURRENT_DIR, '..', '..', '..')) + + +def weight_init(): + init = fluid.initializer.MSRAInitializer(uniform=False) + param = fluid.ParamAttr(initializer=init) + return param + + +def bias_init(): + init = fluid.initializer.ConstantInitializer(value=0.) 
+ param = fluid.ParamAttr(initializer=init) + return param + + +def norm_weight_init(): + init = fluid.initializer.Uniform(low=0., high=1.) + param = fluid.ParamAttr(initializer=init) + return param + + +def norm_bias_init(): + init = fluid.initializer.ConstantInitializer(value=0.) + param = fluid.ParamAttr(initializer=init) + return param + + +class AdjustLayer(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters, is_test=False): + super(AdjustLayer, self).__init__() + self.conv = nn.Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + param_attr=weight_init(), + bias_attr=False) + self.bn = nn.BatchNorm( + num_channels=num_filters, + param_attr=norm_weight_init(), + bias_attr=norm_bias_init(), + momentum=0.9, + act=None, + use_global_stats=is_test) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if x.shape[3] < 20: + l = 4 + r = -4 + x = x[:, :, l:r, l:r] + return x + + +class AdjustAllLayer(fluid.dygraph.Layer): + def __init__(self, in_channels, out_channels, is_test=False): + super(AdjustAllLayer, self).__init__('') + self.num = len(out_channels) + self.sub_layer_list = [] + if self.num == 1: + self.downsample = AdjustLayer(in_channels[0], out_channels[0], is_test) + else: + for i in range(self.num): + Build_Adjust_Layer = self.add_sublayer( + 'downsample'+str(i+2), + AdjustLayer(in_channels[i], out_channels[i], is_test)) + self.sub_layer_list.append(Build_Adjust_Layer) + + def forward(self, features): + if self.num == 1: + return self.downsample(features) + else: + out = [] + for i in range(self.num): + build_adjust_layer_i = sub_layer_list[i] + out.append(build_adjust_layer_i(features[i])) + return out diff --git a/PaddleCV/tracking/ltr/models/siam/siam.py b/PaddleCV/tracking/ltr/models/siam/siam.py new file mode 100644 index 0000000000..629c361c98 --- /dev/null +++ b/PaddleCV/tracking/ltr/models/siam/siam.py @@ -0,0 +1,213 @@ +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +import os.path as osp +import sys + +CURRENT_DIR = osp.dirname(__file__) +sys.path.append(osp.join(CURRENT_DIR, '..', '..', '..')) + +from ltr.models.backbone.resnet_dilated import resnet50 +from ltr.models.backbone.alexnet import AlexNet +from ltr.models.siam.head import DepthwiseRPN, MaskCorr, Refine +from ltr.models.siam.neck import AdjustAllLayer + + +class Siamnet(dygraph.layers.Layer): + def __init__(self, + feature_extractor, + rpn_head, + neck=None, + mask_head=None, + refine_head=None, + scale_loss=None): + + super(Siamnet, self).__init__() + + self.feature_extractor = feature_extractor + self.rpn_head = rpn_head + self.neck = neck + self.mask_head = mask_head + self.refine_head = refine_head + self.scale_loss = scale_loss + + def forward(self, template, search): + # get feature + if len(template.shape) == 5: + template = fluid.layers.reshape(template, [-1, *list(template.shape)[-3:]]) + search = fluid.layers.reshape(search, [-1, *list(search.shape)[-3:]]) + + zf = self.feature_extractor(template) + xf = self.feature_extractor(search) + if not self.mask_head is None: + zf = zf[-1] + xf_refine = xf[:-1] + xf = xf[-1] + if isinstance(zf, list): + zf = zf[-1] + if isinstance(xf, list): + xf = xf[-1] + if not self.neck is None: + zf = self.neck(zf) + xf = self.neck(xf) + cls, loc = self.rpn_head(zf, xf) + + if not self.mask_head is None: + if not self.refine_head is None: + _, mask_corr_feature = self.mask_head(zf, xf) + mask = self.refine_head(xf_refine, mask_corr_feature) + else: + mask, mask_corr_feature = 
self.mask_head(zf, xf) + return {'cls': cls, + 'loc': loc, + 'mask': mask} + else: + return {'cls': cls, + 'loc': loc} + + def extract_backbone_features(self, im): + return self.feature_extractor(im) + + def template(self, template): + zf = self.feature_extractor(template) + if not self.mask_head is None: + zf = zf[-1] + if isinstance(zf, list): + zf = zf[-1] + if not self.neck is None: + zf = self.neck(zf) + self.zf = zf + + def track(self, search): + xf = self.feature_extractor(search) + if not self.mask_head is None: + self.xf = xf[:-1] + xf = xf[-1] + if isinstance(xf, list): + xf = xf[-1] + if not self.neck is None: + xf = self.neck(xf) + cls, loc = self.rpn_head(self.zf, xf) + + if not self.mask_head is None: + mask, self.mask_corr_feature = self.mask_head(self.zf, xf) + return {'cls': cls, + 'loc': loc, + 'mask': mask} + else: + return {'cls': cls, + 'loc': loc} + + def mask_refine(self, pos): + return self.refine_head(self.xf, self.mask_corr_feature, pos, test=True) + + +def SiamRPN_AlexNet(backbone_pretrained=True, + backbone_is_test=True, + is_test=False, + scale_loss=None): + backbone = AlexNet( + 'AlexNet', + is_test=backbone_is_test, + output_layers=['conv5'], + pretrained=backbone_pretrained) + + rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=is_test) + + model = Siamnet( + feature_extractor=backbone, + rpn_head=rpn_head, + scale_loss=scale_loss) + return model + + +def SiamRPN_ResNet50(backbone_pretrained=True, + backbone_is_test=True, + is_test=False, + scale_loss=None): + backbone = resnet50( + 'ResNet50', + pretrained=backbone_pretrained, + output_layers=[2], + is_test=backbone_is_test) + + neck = AdjustAllLayer(in_channels=[1024], out_channels=[256], is_test=is_test) + + rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=is_test) + + model = Siamnet( + feature_extractor=backbone, + neck=neck, + rpn_head=rpn_head, + scale_loss=scale_loss) + return model + + +def SiamMask_ResNet50_base(backbone_pretrained=True, + backbone_is_test=True, + is_test=False, + scale_loss=None): + backbone = resnet50( + 'ResNet50', + pretrained=backbone_pretrained, + output_layers=[0,1,2], + is_test=backbone_is_test) + + neck = AdjustAllLayer(in_channels=[1024], out_channels=[256], is_test=is_test) + + rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=is_test) + + mask_head = MaskCorr(in_channels=256, hidden=256, out_channels=3969, is_test=is_test) + + model = Siamnet( + feature_extractor=backbone, + neck=neck, + rpn_head=rpn_head, + mask_head=mask_head, + scale_loss=scale_loss) + return model + + +def SiamMask_ResNet50_sharp(backbone_pretrained=False, + backbone_is_test=True, + is_test=False, + scale_loss=None): + backbone = resnet50( + 'ResNet50', + pretrained=backbone_pretrained, + output_layers=[0,1,2], + is_test=backbone_is_test) + + neck = AdjustAllLayer(in_channels=[1024], out_channels=[256], is_test=True) + + rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=True) + + mask_head = MaskCorr(in_channels=256, hidden=256, out_channels=3969, is_test=is_test) + + refine_head = Refine() + + model = Siamnet( + feature_extractor=backbone, + neck=neck, + rpn_head=rpn_head, + mask_head=mask_head, + refine_head=refine_head, + scale_loss=scale_loss) + return model + + +if __name__ == '__main__': + import numpy as np + + search = np.random.uniform(-1, 1, [1, 3, 255, 255]).astype(np.float32) + template = np.random.uniform(-1, 1, [1, 3, 127, 127]).astype(np.float32) + with 
fluid.dygraph.guard():
+        search = fluid.dygraph.to_variable(search)
+        template = fluid.dygraph.to_variable(template)
+
+        # smoke test: build the base SiamMask model without pretrained backbone weights
+        model = SiamMask_ResNet50_base(backbone_pretrained=False)
+
+        res = model(template, search)
+        params = model.state_dict()
+        for v in params:
+            print(v)
diff --git a/PaddleCV/tracking/ltr/models/siam/xcorr.py b/PaddleCV/tracking/ltr/models/siam/xcorr.py
new file mode 100644
index 0000000000..a6295e8f34
--- /dev/null
+++ b/PaddleCV/tracking/ltr/models/siam/xcorr.py
@@ -0,0 +1,29 @@
+import paddle.fluid as fluid
+import paddle.fluid.dygraph.nn as nn
+
+from pytracking.libs.Fconv2d import FConv2D
+
+
+def xcorr(x, kernel):
+    """Grouped conv2d that computes the cross-correlation between x and kernel.
+    """
+    batch = kernel.shape[0]
+    px = fluid.layers.reshape(x, [1, -1, x.shape[2], x.shape[3]])
+    pk = fluid.layers.reshape(kernel, [-1, x.shape[1], kernel.shape[2], kernel.shape[3]])
+    scores_map = FConv2D(px, pk, stride=1, padding=0, dilation=1, groups=batch)
+    scores_map = fluid.layers.reshape(
+        scores_map, [batch, -1, scores_map.shape[2], scores_map.shape[3]])
+    return scores_map
+
+
+def xcorr_depthwise(x, kernel):
+    """Depth-wise cross-correlation between x and kernel.
+    """
+    batch = kernel.shape[0]
+    channel = kernel.shape[1]
+    px = fluid.layers.reshape(x, [1, -1, x.shape[2], x.shape[3]])
+    pk = fluid.layers.reshape(kernel, [-1, 1, kernel.shape[2], kernel.shape[3]])
+    scores_map = FConv2D(px, pk, stride=1, padding=0, dilation=1, groups=batch * channel)
+    scores_map = fluid.layers.reshape(
+        scores_map, [batch, -1, scores_map.shape[2], scores_map.shape[3]])
+    return scores_map
diff --git a/PaddleCV/tracking/ltr/models/siamese/target_estimator_net.py b/PaddleCV/tracking/ltr/models/siamese/target_estimator_net.py
index 25c8637b18..6ff0542394 100644
--- a/PaddleCV/tracking/ltr/models/siamese/target_estimator_net.py
+++ b/PaddleCV/tracking/ltr/models/siamese/target_estimator_net.py
@@ -2,7 +2,6 @@
 from paddle.fluid import dygraph
 from paddle.fluid.dygraph import nn
 
-from pytracking.libs.Fconv2d import Conv2D
 from pytracking.libs.Fconv2d import FConv2D
 
 
diff --git a/PaddleCV/tracking/ltr/train_settings/siammask/siammask_res50_base.py b/PaddleCV/tracking/ltr/train_settings/siammask/siammask_res50_base.py
new file mode 100644
index 0000000000..faf031bd58
--- /dev/null
+++ b/PaddleCV/tracking/ltr/train_settings/siammask/siammask_res50_base.py
@@ -0,0 +1,179 @@
+import paddle.fluid as fluid
+import paddle.fluid.dygraph as dygraph
+
+import ltr.actors as actors
+import ltr.data.transforms as dltransforms
+from ltr.data import processing, sampler, loader
+from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS, Lasot, Got10k
+from ltr.models.siam.siam import SiamMask_ResNet50_base
+from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss, select_mask_logistic_loss
+from ltr.trainers import LTRTrainer
+from ltr.trainers.learning_rate_scheduler import LinearLrWarmup
+import numpy as np
+import cv2 as cv
+from PIL import Image, ImageEnhance
+
+
+def run(settings):
+    # Most common settings are assigned in the settings struct
+    settings.description = 'SiamMask_base with ResNet-50 backbone.'
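+    # Note on the dict-valued settings below: the 'train' key refers to the
+    # template (exemplar) crop and 'test' to the search crop of each sampled
+    # pair, giving a 127px template and a 255px search image, and label_params
+    # lays out the anchors on a 25x25 score map with stride 8. This reading of
+    # the keys is an assumption based on the SiamProcessing pipeline.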
+ settings.print_interval = 100 # How often to print loss and other info + settings.batch_size = 64 # Batch size + settings.samples_per_epoch = 600000 # Number of training pairs per epoch + settings.num_workers = 4 # Number of workers for image loading + settings.search_area_factor = {'train': 1.0, 'test': 2.0} + settings.output_sz = {'train': 127, 'test': 255} + settings.scale_type = 'context' + settings.border_type = 'meanpad' + + # Settings for the image sample and label generation + settings.center_jitter_factor = {'train': 0.125, 'test': 2.0} + settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18} + settings.label_params = { + 'search_size': 255, + 'output_size': 25, + 'anchor_stride': 8, + 'anchor_ratios': [0.33, 0.5, 1, 2, 3], + 'anchor_scales': [8], + 'num_pos': 16, + 'num_neg': 16, + 'num_total': 64, + 'thr_high': 0.6, + 'thr_low': 0.3 + } + settings.loss_weights = {'cls': 1., 'loc': 1.2, 'mask':36} + settings.neg = 0.2 + + # Train datasets + vos_train = YoutubeVOS() + vid_train = ImagenetVID() + coco_train = MSCOCOSeq() + det_train = ImagenetDET() + lasot_train = Lasot(split='train') + got10k_train = Got10k(split='train') + + # Validation datasets + vid_val = ImagenetVID() + + # The joint augmentation transform, that is applied to the pairs jointly + transform_joint = dltransforms.ToGrayscale(probability=0.25) + + # The augmentation transform applied to the training set (individually to each image in the pair) + transform_exemplar = dltransforms.Transpose() + transform_instance = dltransforms.Compose( + [ + dltransforms.Color(probability=1.0), + dltransforms.Blur(probability=0.18), + dltransforms.Transpose() + ]) + transform_instance_mask = dltransforms.Transpose() + + # Data processing to do on the training pairs + data_processing_train = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + train_transform=transform_exemplar, + test_transform=transform_instance, + test_mask_transform=transform_instance_mask, + joint_transform=transform_joint) + + # Data processing to do on the validation pairs + data_processing_val = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + transform=transform_exemplar, + joint_transform=transform_joint) + + nums_per_epoch = settings.samples_per_epoch // settings.batch_size + # The sampler for training + dataset_train = sampler.MaskSampler( + [vid_train, coco_train, det_train, vos_train, lasot_train, got10k_train], + [2, 1, 1, 2, 1, 1], + samples_per_epoch=nums_per_epoch * settings.batch_size, + max_gap=100, + processing=data_processing_train, + neg=settings.neg) + + # The loader for training + train_loader = loader.LTRLoader( + 'train', + dataset_train, + training=True, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # The sampler for validation + dataset_val = sampler.MaskSampler( + [vid_val], + [1, ], + samples_per_epoch=100 * settings.batch_size, + max_gap=100, + processing=data_processing_val) + + # The loader for validation + 
val_loader = loader.LTRLoader( + 'val', + dataset_val, + training=False, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # creat network, set objective, creat optimizer, learning rate scheduler, trainer + with dygraph.guard(): + # Create network + + def scale_loss(loss): + total_loss = 0 + for k in settings.loss_weights: + total_loss += loss[k] * settings.loss_weights[k] + return total_loss + + net = SiamMask_ResNet50_base(scale_loss=scale_loss) + + # Define objective + objective = { + 'cls': select_softmax_with_cross_entropy_loss, + 'loc': weight_l1_loss, + 'mask': select_mask_logistic_loss + } + + # Create actor, which wraps network and objective + actor = actors.SiamActor(net=net, objective=objective) + + # Set to training mode + actor.train() + + # Define optimizer and learning rate + decayed_lr = fluid.layers.exponential_decay( + learning_rate=0.005, + decay_steps=nums_per_epoch, + decay_rate=0.9642, + staircase=True) + lr_scheduler = LinearLrWarmup( + learning_rate=decayed_lr, + warmup_steps=5*nums_per_epoch, + start_lr=0.001, + end_lr=0.005) + + optimizer = fluid.optimizer.Adam( + parameter_list=net.rpn_head.parameters() + + net.neck.parameters() + + net.mask_head.parameters(), + learning_rate=lr_scheduler) + + trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler) + trainer.train(20, load_latest=False, fail_safe=False) diff --git a/PaddleCV/tracking/ltr/train_settings/siammask/siammask_res50_sharp.py b/PaddleCV/tracking/ltr/train_settings/siammask/siammask_res50_sharp.py new file mode 100644 index 0000000000..dde3c0b354 --- /dev/null +++ b/PaddleCV/tracking/ltr/train_settings/siammask/siammask_res50_sharp.py @@ -0,0 +1,189 @@ +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph + +import ltr.actors as actors +import ltr.data.transforms as dltransforms +from ltr.data import processing, sampler, loader +from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS +from ltr.models.siam.siam import SiamMask_ResNet50_sharp +from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss, select_mask_logistic_loss +from ltr.trainers import LTRTrainer +from ltr.trainers.learning_rate_scheduler import LinearLrWarmup +import numpy as np +import cv2 as cv +from PIL import Image, ImageEnhance + + +def run(settings): + # Most common settings are assigned in the settings struct + settings.base_model = '' + settings.description = 'SiamMask_sharp with ResNet-50 backbone.' 
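+    # Note: this sharp stage only refines the mask branch of a trained
+    # SiamMask-Base model: settings.base_model must point to that checkpoint
+    # (an exception is raised below if it is left empty), the cls/loc loss
+    # weights are set to 0, and only mask_head/refine_head parameters are
+    # handed to the optimizer.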
+ settings.print_interval = 100 # How often to print loss and other info + settings.batch_size = 64 # Batch size + settings.samples_per_epoch = 600000 # Number of training pairs per epoch + settings.num_workers = 8 # Number of workers for image loading + settings.search_area_factor = {'train': 1.0, 'test': 143./127.} + settings.output_sz = {'train': 127, 'test': 143} + settings.scale_type = 'context' + settings.border_type = 'meanpad' + + # Settings for the image sample and label generation + settings.center_jitter_factor = {'train': 0.2, 'test': 0.4} + settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18} + settings.label_params = { + 'search_size': 143, + 'output_size': 3, + 'anchor_stride': 8, + 'anchor_ratios': [0.33, 0.5, 1, 2, 3], + 'anchor_scales': [8], + 'num_pos': 16, + 'num_neg': 16, + 'num_total': 64, + 'thr_high': 0.6, + 'thr_low': 0.3 + } + settings.loss_weights = {'cls': 0., 'loc': 0., 'mask':1} + settings.neg = 0 + + # Train datasets + vos_train = YoutubeVOS() + coco_train = MSCOCOSeq() + + # Validation datasets + vos_val = vos_train + + # The joint augmentation transform, that is applied to the pairs jointly + transform_joint = dltransforms.ToGrayscale(probability=0.25) + + # The augmentation transform applied to the training set (individually to each image in the pair) + transform_exemplar = dltransforms.Transpose() + transform_instance = dltransforms.Compose( + [ + dltransforms.Color(probability=1.0), + dltransforms.Blur(probability=0.18), + dltransforms.Transpose() + ]) + transform_instance_mask = dltransforms.Transpose() + + # Data processing to do on the training pairs + data_processing_train = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + train_transform=transform_exemplar, + test_transform=transform_instance, + test_mask_transform=transform_instance_mask, + joint_transform=transform_joint) + + # Data processing to do on the validation pairs + data_processing_val = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + transform=transform_exemplar, + joint_transform=transform_joint) + + nums_per_epoch = settings.samples_per_epoch // settings.batch_size + # The sampler for training + dataset_train = sampler.MaskSampler( + [coco_train, vos_train], + [1 ,1], + samples_per_epoch=nums_per_epoch * settings.batch_size, + max_gap=100, + processing=data_processing_train, + neg=settings.neg) + + # The loader for training + train_loader = loader.LTRLoader( + 'train', + dataset_train, + training=True, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # The sampler for validation + dataset_val = sampler.MaskSampler( + [vos_val], + [1, ], + samples_per_epoch=100 * settings.batch_size, + max_gap=100, + processing=data_processing_val) + + # The loader for validation + val_loader = loader.LTRLoader( + 'val', + dataset_val, + training=False, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # creat network, set objective, 
creat optimizer, learning rate scheduler, trainer + with dygraph.guard(): + # Create network + + def scale_loss(loss): + total_loss = 0 + for k in settings.loss_weights: + total_loss += loss[k] * settings.loss_weights[k] + return total_loss + + net = SiamMask_ResNet50_sharp(scale_loss=scale_loss) + + # Load parameters from the best_base_model + if settings.base_model == '': + raise Exception( + 'The base_model path is not setup. Check settings.base_model in "ltr/train_settings/siammask/siammask_res50_sharp.py".' + ) + para_dict, _ = fluid.load_dygraph(settings.base_model) + model_dict = net.state_dict() + + for key in model_dict.keys(): + if key in para_dict.keys(): + model_dict[key] = para_dict[key] + + net.set_dict(model_dict) + + # Define objective + objective = { + 'cls': select_softmax_with_cross_entropy_loss, + 'loc': weight_l1_loss, + 'mask': select_mask_logistic_loss + } + + # Create actor, which wraps network and objective + actor = actors.SiamActor(net=net, objective=objective) + + # Set to training mode + actor.train() + + # Define optimizer and learning rate + decayed_lr = fluid.layers.exponential_decay( + learning_rate=0.0005, + decay_steps=nums_per_epoch, + decay_rate=0.9, + staircase=True) + lr_scheduler = LinearLrWarmup( + learning_rate=decayed_lr, + warmup_steps=5*nums_per_epoch, + start_lr=0.0001, + end_lr=0.0005) + + optimizer = fluid.optimizer.Adam( + parameter_list=net.mask_head.parameters() + + net.refine_head.parameters(), + learning_rate=lr_scheduler) + + trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler) + trainer.train(20, load_latest=False, fail_safe=False) diff --git a/PaddleCV/tracking/ltr/train_settings/siamrpn/siamrpn_alexnet.py b/PaddleCV/tracking/ltr/train_settings/siamrpn/siamrpn_alexnet.py new file mode 100644 index 0000000000..13b877973b --- /dev/null +++ b/PaddleCV/tracking/ltr/train_settings/siamrpn/siamrpn_alexnet.py @@ -0,0 +1,172 @@ +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph + +import ltr.actors as actors +import ltr.data.transforms as dltransforms +from ltr.data import processing, sampler, loader +from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS, Lasot, Got10k +from ltr.models.siam.siam import SiamRPN_AlexNet +from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss +from ltr.trainers import LTRTrainer +from ltr.trainers.learning_rate_scheduler import LinearLrWarmup +import numpy as np +import cv2 as cv +from PIL import Image, ImageEnhance + + +def run(settings): + # Most common settings are assigned in the settings struct + settings.description = 'SiamRPN with AlexNet backbone.' 
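+    # Note: the AlexNet backbone is loaded pretrained and its forward pass
+    # runs under no_grad in eval mode, so only the RPN head parameters are
+    # optimized below; the learning rate warms up from 0.005 to 0.01 over the
+    # first 5 epochs and then decays exponentially.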
+ settings.print_interval = 100 # How often to print loss and other info + settings.batch_size = 512 # Batch size + settings.samples_per_epoch = 600000 # Number of training pairs per epoch + settings.num_workers = 8 # Number of workers for image loading + settings.search_area_factor = {'train': 1.0, 'test': 2.0} + settings.output_sz = {'train': 127, 'test': 255} + settings.scale_type = 'context' + settings.border_type = 'meanpad' + + # Settings for the image sample and label generation + settings.center_jitter_factor = {'train': 0.125, 'test': 2.0} + settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18} + settings.label_params = { + 'search_size': 255, + 'output_size': 17, + 'anchor_stride': 8, + 'anchor_ratios': [0.33, 0.5, 1, 2, 3], + 'anchor_scales': [8], + 'num_pos': 16, + 'num_neg': 16, + 'num_total': 64, + 'thr_high': 0.6, + 'thr_low': 0.3 + } + settings.loss_weights = {'cls': 1., 'loc': 1.2} + settings.neg = 0.2 + + # Train datasets + vos_train = YoutubeVOS() + vid_train = ImagenetVID() + coco_train = MSCOCOSeq() + det_train = ImagenetDET() + #lasot_train = Lasot(split='train') + #got10k_train = Got10k(split='train') + + # Validation datasets + vid_val = ImagenetVID() + + # The joint augmentation transform, that is applied to the pairs jointly + transform_joint = dltransforms.ToGrayscale(probability=0.25) + + # The augmentation transform applied to the training set (individually to each image in the pair) + transform_exemplar = dltransforms.Transpose() + transform_instance = dltransforms.Compose( + [ + dltransforms.Color(probability=1.0), + dltransforms.Blur(probability=0.18), + dltransforms.Transpose() + ]) + transform_instance_mask = dltransforms.Transpose() + + # Data processing to do on the training pairs + data_processing_train = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + train_transform=transform_exemplar, + test_transform=transform_instance, + test_mask_transform=transform_instance_mask, + joint_transform=transform_joint) + + # Data processing to do on the validation pairs + data_processing_val = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + transform=transform_exemplar, + joint_transform=transform_joint) + + nums_per_epoch = settings.samples_per_epoch // settings.batch_size + # The sampler for training + dataset_train = sampler.MaskSampler( + [vid_train, coco_train, det_train, vos_train], + [2, 1, 1, 2], + samples_per_epoch=nums_per_epoch * settings.batch_size, + max_gap=100, + processing=data_processing_train, + neg=settings.neg) + + # The loader for training + train_loader = loader.LTRLoader( + 'train', + dataset_train, + training=True, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # The sampler for validation + dataset_val = sampler.MaskSampler( + [vid_val], + [1, ], + samples_per_epoch=100 * settings.batch_size, + max_gap=100, + processing=data_processing_val) + + # The loader for validation + val_loader = loader.LTRLoader( + 'val', + 
dataset_val, + training=False, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # creat network, set objective, creat optimizer, learning rate scheduler, trainer + with dygraph.guard(): + # Create network + + def scale_loss(loss): + total_loss = 0 + for k in settings.loss_weights: + total_loss += loss[k] * settings.loss_weights[k] + return total_loss + + net = SiamRPN_AlexNet(scale_loss=scale_loss) + + # Define objective + objective = { + 'cls': select_softmax_with_cross_entropy_loss, + 'loc': weight_l1_loss, + } + + # Create actor, which wraps network and objective + actor = actors.SiamActor(net=net, objective=objective) + + # Define optimizer and learning rate + decayed_lr = fluid.layers.exponential_decay( + learning_rate=0.01, + decay_steps=nums_per_epoch, + decay_rate=0.9407, + staircase=True) + lr_scheduler = LinearLrWarmup( + learning_rate=decayed_lr, + warmup_steps=5*nums_per_epoch, + start_lr=0.005, + end_lr=0.01) + optimizer = fluid.optimizer.Adam( + parameter_list=net.rpn_head.parameters(), + learning_rate=lr_scheduler) + + trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler) + trainer.train(50, load_latest=False, fail_safe=False) diff --git a/PaddleCV/tracking/ltr/train_settings/siamrpn/siamrpn_res50.py b/PaddleCV/tracking/ltr/train_settings/siamrpn/siamrpn_res50.py new file mode 100644 index 0000000000..f27ecfb87e --- /dev/null +++ b/PaddleCV/tracking/ltr/train_settings/siamrpn/siamrpn_res50.py @@ -0,0 +1,159 @@ +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph + +import ltr.actors as actors +import ltr.data.transforms as dltransforms +from ltr.data import processing, sampler, loader +from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS +from ltr.models.siam.siam import SiamRPN_ResNet50 +from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss +from ltr.trainers import LTRTrainer +import numpy as np +import cv2 as cv +from PIL import Image, ImageEnhance + + +def run(settings): + # Most common settings are assigned in the settings struct + settings.description = 'SiamRPN with ResNet-50 backbone.' 
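+    # Note: this ResNet-50 variant optimizes the RPN head together with the
+    # neck (adjust) layer, and reuses the COCO training set as the validation
+    # split since the ImageNet-VID validation dataset is commented out below.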
+ settings.print_interval = 100 # How often to print loss and other info + settings.batch_size = 32 # Batch size + settings.num_workers = 4 # Number of workers for image loading + settings.search_area_factor = {'train': 1.0, 'test': 255./127.} + settings.output_sz = {'train': 127, 'test': 255} + settings.scale_type = 'context' + settings.border_type = 'meanpad' + + # Settings for the image sample and label generation + settings.center_jitter_factor = {'train': 0.1, 'test': 1.5} + settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18} + settings.label_params = { + 'search_size': 255, + 'output_size': 25, + 'anchor_stride': 8, + 'anchor_ratios': [0.33, 0.5, 1, 2, 3], + 'anchor_scales': [8], + 'num_pos': 16, + 'num_neg': 16, + 'num_total': 64, + 'thr_high': 0.6, + 'thr_low': 0.3 + } + settings.loss_weights = {'cls': 1., 'loc': 1.2} + settings.neg = 0.2 + + # Train datasets + vos_train = YoutubeVOS() + vid_train = ImagenetVID() + coco_train = MSCOCOSeq() + det_train = ImagenetDET() + + # Validation datasets + #vid_val = ImagenetVID() + vid_val = coco_train + + # The joint augmentation transform, that is applied to the pairs jointly + transform_joint = dltransforms.ToGrayscale(probability=0.25) + + # The augmentation transform applied to the training set (individually to each image in the pair) + transform_exemplar = dltransforms.Transpose() + transform_instance = dltransforms.Transpose() + + # Data processing to do on the training pairs + data_processing_train = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + train_transform=transform_exemplar, + test_transform=transform_instance, + joint_transform=transform_joint) + + # Data processing to do on the validation pairs + data_processing_val = processing.SiamProcessing( + search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + scale_type=settings.scale_type, + border_type=settings.border_type, + mode='sequence', + label_params=settings.label_params, + transform=transform_exemplar, + joint_transform=transform_joint) + + # The sampler for training + dataset_train = sampler.MaskSampler( + [vid_train, coco_train, det_train, vos_train], + [2, 1 ,1, 2], + samples_per_epoch=5000 * settings.batch_size, + max_gap=100, + processing=data_processing_train, + neg=settings.neg) + + # The loader for training + train_loader = loader.LTRLoader( + 'train', + dataset_train, + training=True, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # The sampler for validation + dataset_val = sampler.MaskSampler( + [vid_val], + [1, ], + samples_per_epoch=100 * settings.batch_size, + max_gap=100, + processing=data_processing_val) + + # The loader for validation + val_loader = loader.LTRLoader( + 'val', + dataset_val, + training=False, + batch_size=settings.batch_size, + num_workers=settings.num_workers, + stack_dim=0) + + # creat network, set objective, creat optimizer, learning rate scheduler, trainer + with dygraph.guard(): + # Create network + + def scale_loss(loss): + total_loss = 0 + for k in settings.loss_weights: + total_loss += loss[k] * settings.loss_weights[k] + return total_loss + + net = 
SiamRPN_ResNet50(scale_loss=scale_loss) + + # Define objective + objective = { + 'cls': select_softmax_with_cross_entropy_loss, + 'loc': weight_l1_loss, + } + + # Create actor, which wraps network and objective + actor = actors.SiamActor(net=net, objective=objective) + + # Set to training mode + actor.train() + + # Define optimizer and learning rate + lr_scheduler = fluid.layers.exponential_decay( + learning_rate=0.005, + decay_steps=5000, + decay_rate=0.9659, + staircase=True) + optimizer = fluid.optimizer.Adam( + parameter_list=net.rpn_head.parameters() + net.neck.parameters(), + learning_rate=lr_scheduler) + + trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler) + trainer.train(50, load_latest=False, fail_safe=False) diff --git a/PaddleCV/tracking/ltr/trainers/base_trainer.py b/PaddleCV/tracking/ltr/trainers/base_trainer.py index 99a206d96a..5a4be1223b 100644 --- a/PaddleCV/tracking/ltr/trainers/base_trainer.py +++ b/PaddleCV/tracking/ltr/trainers/base_trainer.py @@ -123,7 +123,7 @@ def load_checkpoint(self, checkpoint=None): self.settings.project_path, net_type))) if checkpoint_list: - checkpoint_path = checkpoint_list[-1].split('.')[0] + checkpoint_path = os.path.splitext(checkpoint_list[-1])[0] else: print('No matching checkpoint file found') return @@ -144,13 +144,13 @@ def load_checkpoint(self, checkpoint=None): self.optimizer.set_dict(opt_params) # paddle load state - state_path = '{}/{}/custom_state.pickle'.format( - self._checkpoint_dir, self.settings.project_path) current_state = pickle.load( - open(os.path.join(state_path, 'custom_state.pickle'), 'rb')) + open(os.path.join(checkpoint_path, '_custom_state.pickle'), 'rb')) print("\nload checkpoint done !! Current states are as follows:") - for key, value in enumerate(current_state): + for key, value in current_state.items(): print(key, value) + self.epoch = current_state['epoch'] + self.stats = current_state['stats'] return True diff --git a/PaddleCV/tracking/ltr/trainers/learning_rate_scheduler.py b/PaddleCV/tracking/ltr/trainers/learning_rate_scheduler.py new file mode 100644 index 0000000000..eff90564aa --- /dev/null +++ b/PaddleCV/tracking/ltr/trainers/learning_rate_scheduler.py @@ -0,0 +1,106 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import math +from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay + + +class LinearLrWarmup(LearningRateDecay): + """ + This operator use the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling. + For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks `_ + + When global_step < warmup_steps, learning rate is updated as: + + .. 
code-block:: text + + linear_step = end_lr - start_lr + lr = start_lr + linear_step * (global_step / warmup_steps) + + where start_lr is the initial learning rate, and end_lr is the final learning rate; + + When global_step >= warmup_steps, learning rate is updated as: + + .. code-block:: text + + lr = learning_rate + + where lr is the learning_rate after warm-up. + + Args: + learning_rate (Variable|float): Learning_rate after warm-up, it could be 1D-Tensor or single value with the data type of float32. + warmup_steps (int): Steps for warm up. + start_lr (float): Initial learning rate of warm up. + end_lr (float): Final learning rate of warm up. + begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0. + step(int, optional): The step size used to calculate the new global_step in the description above. + The default value is 1. + dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as + 'float32', 'float64'. The default value is 'float32'. + + Returns: + Variable: Warm-up learning rate with the same data type as learning_rate. + + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + + learning_rate = 0.1 + warmup_steps = 50 + start_lr = 1. / 3. + end_lr = 0.1 + + with fluid.dygraph.guard(): + lr_decay = fluid.dygraph.LinearLrWarmup( learning_rate, warmup_steps, start_lr, end_lr) + + + """ + + def __init__(self, + learning_rate, + warmup_steps, + start_lr, + end_lr, + begin=1, + step=1, + dtype='float32'): + super(LinearLrWarmup, self).__init__(begin, step, dtype) + type_check = isinstance(learning_rate, float) or isinstance( + learning_rate, int) or isinstance(learning_rate, LearningRateDecay) + if not type_check: + raise TypeError( + "the type of learning_rate should be [int, float or LearningRateDecay], the current type is {}". 
+ format(learning_rate)) + self.learning_rate = learning_rate + self.warmup_steps = warmup_steps + assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format( + end_lr, start_lr) + self.lr_ratio_before_warmup = ( + float(end_lr) - float(start_lr)) / float(warmup_steps) + self.start_lr = start_lr + + def step(self): + base_lr = self.learning_rate + if isinstance(self.learning_rate, LearningRateDecay): + base_lr = base_lr() + + if self.step_num < self.warmup_steps: + return self.start_lr + self.lr_ratio_before_warmup * self.step_num + else: + return base_lr diff --git a/PaddleCV/tracking/ltr/trainers/ltr_trainer.py b/PaddleCV/tracking/ltr/trainers/ltr_trainer.py index d928f7e9b1..52b62b1e2e 100644 --- a/PaddleCV/tracking/ltr/trainers/ltr_trainer.py +++ b/PaddleCV/tracking/ltr/trainers/ltr_trainer.py @@ -162,6 +162,8 @@ def _print_stats(self, i, loader, batch_size): print_str += '%s: %.5f , ' % (name, val.avg) print_str += '%s: %.5f , ' % ("time", batch_size / batch_fps * self.settings.print_interval) + if loader.training: + print_str += '%s: %f , ' % ("lr", self.optimizer.current_step_lr()) print(print_str[:-5]) def _stats_new_epoch(self): diff --git a/PaddleCV/tracking/pytracking/eval_benchmark.py b/PaddleCV/tracking/pytracking/eval_benchmark.py index cf9536eb29..28e84756a2 100644 --- a/PaddleCV/tracking/pytracking/eval_benchmark.py +++ b/PaddleCV/tracking/pytracking/eval_benchmark.py @@ -3,6 +3,7 @@ from __future__ import print_function from __future__ import unicode_literals +import paddle.fluid import argparse import importlib import os @@ -172,8 +173,12 @@ def run_one_sequence(video, params, tracker=None): if isinstance(res, int): outputs.append('{}'.format(res)) else: - outputs.append('{},{},{},{}'.format(res[0], res[1], res[ - 2], res[3])) + if len(res) is 8: + outputs.append('{},{},{},{},{},{},{},{}'.format( + res[0], res[1], res[2], res[3], res[4], res[5], res[6], res[7])) + else: + outputs.append('{},{},{},{}'.format( + res[0], res[1], res[2], res[3])) f.write('\n'.join(outputs)) else: os.makedirs(save_dir, exist_ok=True) diff --git a/PaddleCV/tracking/pytracking/features/augmentation.py b/PaddleCV/tracking/pytracking/features/augmentation.py index 41272bb183..189cd55aa9 100644 --- a/PaddleCV/tracking/pytracking/features/augmentation.py +++ b/PaddleCV/tracking/pytracking/features/augmentation.py @@ -1,205 +1,205 @@ -import numpy as np -import math - -from paddle.fluid import layers - -import cv2 as cv - -from pytracking.features.preprocessing import numpy_to_paddle, paddle_to_numpy -from pytracking.libs.Fconv2d import FConv2D -from pytracking.libs.paddle_utils import PTensor, _padding, n2p - - -class Transform: - """Base data augmentation transform class.""" - - def __init__(self, output_sz=None, shift=None): - self.output_sz = output_sz - self.shift = (0, 0) if shift is None else shift - - def __call__(self, image): - raise NotImplementedError - - def crop_to_output(self, image, shift=None): - if isinstance(image, PTensor): - imsz = image.shape[2:] - else: - imsz = image.shape[:2] - - if self.output_sz is None: - pad_h = 0 - pad_w = 0 - else: - pad_h = (self.output_sz[0] - imsz[0]) / 2 - pad_w = (self.output_sz[1] - imsz[1]) / 2 - if shift is None: - shift = self.shift - pad_left = math.floor(pad_w) + shift[1] - pad_right = math.ceil(pad_w) - shift[1] - pad_top = math.floor(pad_h) + shift[0] - pad_bottom = math.ceil(pad_h) - shift[0] - - if isinstance(image, PTensor): - return _padding( - image, (pad_left, pad_right, pad_top, pad_bottom), - mode='replicate') - 
else: - return _padding( - image, (0, 0, pad_left, pad_right, pad_top, pad_bottom), - mode='replicate') - - -class Identity(Transform): - """Identity transformation.""" - - def __call__(self, image): - return self.crop_to_output(image) - - -class FlipHorizontal(Transform): - """Flip along horizontal axis.""" - - def __call__(self, image): - if isinstance(image, PTensor): - return self.crop_to_output(layers.reverse(image, 3)) - else: - return self.crop_to_output(np.fliplr(image)) - - -class FlipVertical(Transform): - """Flip along vertical axis.""" - - def __call__(self, image: PTensor): - if isinstance(image, PTensor): - return self.crop_to_output(layers.reverse(image, 2)) - else: - return self.crop_to_output(np.flipud(image)) - - -class Translation(Transform): - """Translate.""" - - def __init__(self, translation, output_sz=None, shift=None): - super().__init__(output_sz, shift) - self.shift = (self.shift[0] + translation[0], - self.shift[1] + translation[1]) - - def __call__(self, image): - return self.crop_to_output(image) - - -class Scale(Transform): - """Scale.""" - - def __init__(self, scale_factor, output_sz=None, shift=None): - super().__init__(output_sz, shift) - self.scale_factor = scale_factor - - def __call__(self, image): - # Calculate new size. Ensure that it is even so that crop/pad becomes easier - h_orig, w_orig = image.shape[2:] - - if h_orig != w_orig: - raise NotImplementedError - - h_new = round(h_orig / self.scale_factor) - h_new += (h_new - h_orig) % 2 - w_new = round(w_orig / self.scale_factor) - w_new += (w_new - w_orig) % 2 - - if isinstance(image, PTensor): - image_resized = layers.resize_bilinear( - image, [h_new, w_new], align_corners=False) - else: - image_resized = cv.resize( - image, (w_new, h_new), interpolation=cv.INTER_LINEAR) - return self.crop_to_output(image_resized) - - -class Affine(Transform): - """Affine transformation.""" - - def __init__(self, transform_matrix, output_sz=None, shift=None): - super().__init__(output_sz, shift) - self.transform_matrix = transform_matrix - - def __call__(self, image, crop=True): - if isinstance(image, PTensor): - return self.crop_to_output( - numpy_to_paddle(self( - paddle_to_numpy(image), crop=False))) - else: - warp = cv.warpAffine( - image, - self.transform_matrix, - image.shape[1::-1], - borderMode=cv.BORDER_REPLICATE) - if crop: - return self.crop_to_output(warp) - else: - return warp - - -class Rotate(Transform): - """Rotate with given angle.""" - - def __init__(self, angle, output_sz=None, shift=None): - super().__init__(output_sz, shift) - self.angle = math.pi * angle / 180 - - def __call__(self, image, crop=True): - if isinstance(image, PTensor): - return self.crop_to_output( - numpy_to_paddle(self( - paddle_to_numpy(image), crop=False))) - else: - c = (np.expand_dims(np.array(image.shape[:2]), 1) - 1) / 2 - R = np.array([[math.cos(self.angle), math.sin(self.angle)], - [-math.sin(self.angle), math.cos(self.angle)]]) - H = np.concatenate([R, c - R @c], 1) - warp = cv.warpAffine( - image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE) - if crop: - return self.crop_to_output(warp) - else: - return warp - - -class Blur(Transform): - """Blur with given sigma (can be axis dependent).""" - - def __init__(self, sigma, output_sz=None, shift=None): - super().__init__(output_sz, shift) - if isinstance(sigma, (float, int)): - sigma = (sigma, sigma) - self.sigma = sigma - self.filter_size = [math.ceil(2 * s) for s in self.sigma] - - x_coord = [ - np.arange( - -sz, sz + 1, 1, dtype='float32') for sz in 
self.filter_size - ] - self.filter_np = [ - np.exp(0 - (x * x) / (2 * s**2)) - for x, s in zip(x_coord, self.sigma) - ] - self.filter_np[0] = np.reshape( - self.filter_np[0], [1, 1, -1, 1]) / np.sum(self.filter_np[0]) - self.filter_np[1] = np.reshape( - self.filter_np[1], [1, 1, 1, -1]) / np.sum(self.filter_np[1]) - - def __call__(self, image): - if isinstance(image, PTensor): - sz = image.shape[2:] - filter = [n2p(f) for f in self.filter_np] - im1 = FConv2D( - layers.reshape(image, [-1, 1, sz[0], sz[1]]), - filter[0], - padding=(self.filter_size[0], 0)) - return self.crop_to_output( - layers.reshape( - FConv2D( - im1, filter[1], padding=(0, self.filter_size[1])), - [1, -1, sz[0], sz[1]])) - else: - return paddle_to_numpy(self(numpy_to_paddle(image))) +import numpy as np +import math + +from paddle.fluid import layers + +import cv2 as cv + +from pytracking.features.preprocessing import numpy_to_paddle, paddle_to_numpy +from pytracking.libs.Fconv2d import FConv2D +from pytracking.libs.paddle_utils import PTensor, _padding, n2p + + +class Transform: + """Base data augmentation transform class.""" + + def __init__(self, output_sz=None, shift=None): + self.output_sz = output_sz + self.shift = (0, 0) if shift is None else shift + + def __call__(self, image): + raise NotImplementedError + + def crop_to_output(self, image, shift=None): + if isinstance(image, PTensor): + imsz = image.shape[2:] + else: + imsz = image.shape[:2] + + if self.output_sz is None: + pad_h = 0 + pad_w = 0 + else: + pad_h = (self.output_sz[0] - imsz[0]) / 2 + pad_w = (self.output_sz[1] - imsz[1]) / 2 + if shift is None: + shift = self.shift + pad_left = math.floor(pad_w) + shift[1] + pad_right = math.ceil(pad_w) - shift[1] + pad_top = math.floor(pad_h) + shift[0] + pad_bottom = math.ceil(pad_h) - shift[0] + + if isinstance(image, PTensor): + return _padding( + image, (pad_left, pad_right, pad_top, pad_bottom), + mode='replicate') + else: + return _padding( + image, (0, 0, pad_left, pad_right, pad_top, pad_bottom), + mode='replicate') + + +class Identity(Transform): + """Identity transformation.""" + + def __call__(self, image): + return self.crop_to_output(image) + + +class FlipHorizontal(Transform): + """Flip along horizontal axis.""" + + def __call__(self, image): + if isinstance(image, PTensor): + return self.crop_to_output(layers.reverse(image, 3)) + else: + return self.crop_to_output(np.fliplr(image)) + + +class FlipVertical(Transform): + """Flip along vertical axis.""" + + def __call__(self, image: PTensor): + if isinstance(image, PTensor): + return self.crop_to_output(layers.reverse(image, 2)) + else: + return self.crop_to_output(np.flipud(image)) + + +class Translation(Transform): + """Translate.""" + + def __init__(self, translation, output_sz=None, shift=None): + super().__init__(output_sz, shift) + self.shift = (self.shift[0] + translation[0], + self.shift[1] + translation[1]) + + def __call__(self, image): + return self.crop_to_output(image) + + +class Scale(Transform): + """Scale.""" + + def __init__(self, scale_factor, output_sz=None, shift=None): + super().__init__(output_sz, shift) + self.scale_factor = scale_factor + + def __call__(self, image): + # Calculate new size. 
Ensure that it is even so that crop/pad becomes easier + h_orig, w_orig = image.shape[2:] + + if h_orig != w_orig: + raise NotImplementedError + + h_new = round(h_orig / self.scale_factor) + h_new += (h_new - h_orig) % 2 + w_new = round(w_orig / self.scale_factor) + w_new += (w_new - w_orig) % 2 + + if isinstance(image, PTensor): + image_resized = layers.resize_bilinear( + image, [h_new, w_new], align_corners=False) + else: + image_resized = cv.resize( + image, (w_new, h_new), interpolation=cv.INTER_LINEAR) + return self.crop_to_output(image_resized) + + +class Affine(Transform): + """Affine transformation.""" + + def __init__(self, transform_matrix, output_sz=None, shift=None): + super().__init__(output_sz, shift) + self.transform_matrix = transform_matrix + + def __call__(self, image, crop=True): + if isinstance(image, PTensor): + return self.crop_to_output( + numpy_to_paddle(self( + paddle_to_numpy(image), crop=False))) + else: + warp = cv.warpAffine( + image, + self.transform_matrix, + image.shape[1::-1], + borderMode=cv.BORDER_REPLICATE) + if crop: + return self.crop_to_output(warp) + else: + return warp + + +class Rotate(Transform): + """Rotate with given angle.""" + + def __init__(self, angle, output_sz=None, shift=None): + super().__init__(output_sz, shift) + self.angle = math.pi * angle / 180 + + def __call__(self, image, crop=True): + if isinstance(image, PTensor): + return self.crop_to_output( + numpy_to_paddle(self( + paddle_to_numpy(image), crop=False))) + else: + c = (np.expand_dims(np.array(image.shape[:2]), 1) - 1) / 2 + R = np.array([[math.cos(self.angle), math.sin(self.angle)], + [-math.sin(self.angle), math.cos(self.angle)]]) + H = np.concatenate([R, c - R @c], 1) + warp = cv.warpAffine( + image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE) + if crop: + return self.crop_to_output(warp) + else: + return warp + + +class Blur(Transform): + """Blur with given sigma (can be axis dependent).""" + + def __init__(self, sigma, output_sz=None, shift=None): + super().__init__(output_sz, shift) + if isinstance(sigma, (float, int)): + sigma = (sigma, sigma) + self.sigma = sigma + self.filter_size = [math.ceil(2 * s) for s in self.sigma] + + x_coord = [ + np.arange( + -sz, sz + 1, 1, dtype='float32') for sz in self.filter_size + ] + self.filter_np = [ + np.exp(0 - (x * x) / (2 * s**2)) + for x, s in zip(x_coord, self.sigma) + ] + self.filter_np[0] = np.reshape( + self.filter_np[0], [1, 1, -1, 1]) / np.sum(self.filter_np[0]) + self.filter_np[1] = np.reshape( + self.filter_np[1], [1, 1, 1, -1]) / np.sum(self.filter_np[1]) + + def __call__(self, image): + if isinstance(image, PTensor): + sz = image.shape[2:] + filter = [n2p(f) for f in self.filter_np] + im1 = FConv2D( + layers.reshape(image, [-1, 1, sz[0], sz[1]]), + filter[0], + padding=(self.filter_size[0], 0)) + return self.crop_to_output( + layers.reshape( + FConv2D( + im1, filter[1], padding=(0, self.filter_size[1])), + [1, -1, sz[0], sz[1]])) + else: + return paddle_to_numpy(self(numpy_to_paddle(image))) diff --git a/PaddleCV/tracking/pytracking/features/deep.py b/PaddleCV/tracking/pytracking/features/deep.py index 376acf6d07..86a72a8040 100644 --- a/PaddleCV/tracking/pytracking/features/deep.py +++ b/PaddleCV/tracking/pytracking/features/deep.py @@ -5,6 +5,7 @@ from ltr.models.bbreg.atom import atom_resnet50, atom_resnet18 from ltr.models.siamese.siam import siamfc_alexnet +from ltr.models.siam.siam import SiamRPN_AlexNet, SiamMask_ResNet50_sharp, SiamMask_ResNet50_base from pytracking.admin.environment import 
env_settings from pytracking.features.featurebase import MultiFeatureBase from pytracking.libs import TensorList @@ -347,3 +348,147 @@ def extract(self, im: np.ndarray, debug_save_name=None): output_features[layer].numpy() for layer in self.output_layers ]) return output + + +class SRPNAlexNet(MultiFeatureBase): + """Alexnet feature. + args: + output_layers: List of layers to output. + net_path: Relative or absolute net path (default should be fine). + use_gpu: Use GPU or CPU. + """ + + def __init__(self, + net_path='estimator', + use_gpu=True, + *args, **kwargs): + super().__init__(*args, **kwargs) + + self.use_gpu = use_gpu + self.net_path = net_path + + def initialize(self): + with fluid.dygraph.guard(): + if os.path.isabs(self.net_path): + net_path_full = self.net_path + else: + net_path_full = os.path.join(env_settings().network_path, self.net_path) + + self.net = SiamRPN_AlexNet(backbone_pretrained=False, is_test=True) + + state_dict, _ = fluid.load_dygraph(net_path_full) + self.net.load_dict(state_dict) + self.net.eval() + + def free_memory(self): + if hasattr(self, 'net'): + del self.net + + def extract(self, im: np.ndarray, debug_save_name=None): + with fluid.dygraph.guard(): + if debug_save_name is not None: + np.savez(debug_save_name, im) + + im = n2p(im) + + output_features = self.net.extract_backbone_features(im) + + # Store the raw backbone features which are input to estimator + output = TensorList([layer.numpy() for layer in output_features]) + return output + + +class SMaskResNet50_base(MultiFeatureBase): + """Resnet50-dilated feature. + args: + output_layers: List of layers to output. + net_path: Relative or absolute net path (default should be fine). + use_gpu: Use GPU or CPU. + """ + + def __init__(self, + net_path='estimator', + use_gpu=True, + *args, **kwargs): + super().__init__(*args, **kwargs) + + self.use_gpu = use_gpu + self.net_path = net_path + + def initialize(self): + with fluid.dygraph.guard(): + if os.path.isabs(self.net_path): + net_path_full = self.net_path + else: + net_path_full = os.path.join(env_settings().network_path, self.net_path) + + self.net = SiamMask_ResNet50_base(backbone_pretrained=False, is_test=True) + + state_dict, _ = fluid.load_dygraph(net_path_full) + self.net.load_dict(state_dict) + self.net.eval() + + def free_memory(self): + if hasattr(self, 'net'): + del self.net + + def extract(self, im: np.ndarray, debug_save_name=None): + with fluid.dygraph.guard(): + if debug_save_name is not None: + np.savez(debug_save_name, im) + + im = n2p(im) + + output_features = self.net.extract_backbone_features(im) + + # Store the raw backbone features which are input to estimator + output = TensorList([layer.numpy() for layer in output_features]) + return output + + +class SMaskResNet50_sharp(MultiFeatureBase): + """Resnet50-dilated feature. + args: + output_layers: List of layers to output. + net_path: Relative or absolute net path (default should be fine). + use_gpu: Use GPU or CPU. 
+ """ + + def __init__(self, + net_path='estimator', + use_gpu=True, + *args, **kwargs): + super().__init__(*args, **kwargs) + + self.use_gpu = use_gpu + self.net_path = net_path + + def initialize(self): + with fluid.dygraph.guard(): + if os.path.isabs(self.net_path): + net_path_full = self.net_path + else: + net_path_full = os.path.join(env_settings().network_path, self.net_path) + + self.net = SiamMask_ResNet50_sharp(backbone_pretrained=False, is_test=True) + + state_dict, _ = fluid.load_dygraph(net_path_full) + self.net.load_dict(state_dict) + self.net.eval() + + def free_memory(self): + if hasattr(self, 'net'): + del self.net + + def extract(self, im: np.ndarray, debug_save_name=None): + with fluid.dygraph.guard(): + if debug_save_name is not None: + np.savez(debug_save_name, im) + + im = n2p(im) + + output_features = self.net.extract_backbone_features(im) + + # Store the raw backbone features which are input to estimator + output = TensorList([layer.numpy() for layer in output_features]) + return output diff --git a/PaddleCV/tracking/pytracking/libs/Fconv2d.py b/PaddleCV/tracking/pytracking/libs/Fconv2d.py index 36cafcb65a..8880d97141 100644 --- a/PaddleCV/tracking/pytracking/libs/Fconv2d.py +++ b/PaddleCV/tracking/pytracking/libs/Fconv2d.py @@ -1,189 +1,189 @@ -from __future__ import print_function - -import numpy as np -from paddle.fluid.framework import Variable, in_dygraph_mode -from paddle.fluid import core, dygraph_utils -from paddle.fluid.layers import nn, utils -from paddle.fluid.data_feeder import check_variable_and_dtype -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.layer_helper import LayerHelper - - -def _is_list_or_tuple(input): - return isinstance(input, (list, tuple)) - - -def _zero_padding_in_batch_and_channel(padding, channel_last): - if channel_last: - return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0] - else: - return list(padding[0]) == [0, 0] and list(padding[1]) == [0, 0] - - -def _exclude_padding_in_batch_and_channel(padding, channel_last): - padding_ = padding[1:-1] if channel_last else padding[2:] - padding_ = [elem for pad_a_dim in padding_ for elem in pad_a_dim] - return padding_ - - -def _update_padding_nd(padding, channel_last, num_dims): - if isinstance(padding, str): - padding = padding.upper() - if padding not in ["SAME", "VALID"]: - raise ValueError( - "Unknown padding: '{}'. It can only be 'SAME' or 'VALID'.". - format(padding)) - if padding == "VALID": - padding_algorithm = "VALID" - padding = [0] * num_dims - else: - padding_algorithm = "SAME" - padding = [0] * num_dims - elif _is_list_or_tuple(padding): - # for padding like - # [(pad_before, pad_after), (pad_before, pad_after), ...] - # padding for batch_dim and channel_dim included - if len(padding) == 2 + num_dims and _is_list_or_tuple(padding[0]): - if not _zero_padding_in_batch_and_channel(padding, channel_last): - raise ValueError( - "Non-zero padding({}) in the batch or channel dimensions " - "is not supported.".format(padding)) - padding_algorithm = "EXPLICIT" - padding = _exclude_padding_in_batch_and_channel(padding, - channel_last) - if utils._is_symmetric_padding(padding, num_dims): - padding = padding[0::2] - # for padding like [pad_before, pad_after, pad_before, pad_after, ...] 
- elif len(padding) == 2 * num_dims and isinstance(padding[0], int): - padding_algorithm = "EXPLICIT" - padding = utils.convert_to_list(padding, 2 * num_dims, 'padding') - if utils._is_symmetric_padding(padding, num_dims): - padding = padding[0::2] - # for padding like [pad_d1, pad_d2, ...] - elif len(padding) == num_dims and isinstance(padding[0], int): - padding_algorithm = "EXPLICIT" - padding = utils.convert_to_list(padding, num_dims, 'padding') - else: - raise ValueError("In valid padding: {}".format(padding)) - # for integer padding - else: - padding_algorithm = "EXPLICIT" - padding = utils.convert_to_list(padding, num_dims, 'padding') - return padding, padding_algorithm - - -def FConv2D(input, - weight, - bias=None, - padding=0, - stride=1, - dilation=1, - groups=1, - use_cudnn=True, - act=None, - data_format="NCHW", - name=None): - # entry checks - if not isinstance(use_cudnn, bool): - raise ValueError("Attr(use_cudnn) should be True or False. " - "Received Attr(use_cudnn): {}.".format(use_cudnn)) - if data_format not in ["NCHW", "NHWC"]: - raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. " - "Received Attr(data_format): {}.".format(data_format)) - - channel_last = (data_format == "NHWC") - channel_dim = -1 if channel_last else 1 - num_channels = input.shape[channel_dim] - num_filters = weight.shape[0] - if num_channels < 0: - raise ValueError("The channel dimmention of the input({}) " - "should be defined. Received: {}.".format( - input.shape, num_channels)) - if num_channels % groups != 0: - raise ValueError( - "the channel of input must be divisible by groups," - "received: the channel of input is {}, the shape of input is {}" - ", the groups is {}".format(num_channels, input.shape, groups)) - if num_filters % groups != 0: - raise ValueError( - "the number of filters must be divisible by groups," - "received: the number of filters is {}, the shape of weight is {}" - ", the groups is {}".format(num_filters, weight.shape, groups)) - - # update attrs - padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2) - stride = utils.convert_to_list(stride, 2, 'stride') - dilation = utils.convert_to_list(dilation, 2, 'dilation') - - l_type = "conv2d" - if (num_channels == groups and num_filters % num_channels == 0 and - not use_cudnn): - l_type = 'depthwise_conv2d' - - inputs = {'Input': [input], 'Filter': [weight]} - attrs = { - 'strides': stride, - 'paddings': padding, - 'dilations': dilation, - 'groups': groups, - 'use_cudnn': use_cudnn, - 'use_mkldnn': False, - 'fuse_relu_before_depthwise_conv': False, - "padding_algorithm": padding_algorithm, - "data_format": data_format - } - - if in_dygraph_mode(): - attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, - 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False, - 'fuse_relu_before_depthwise_conv', False, "padding_algorithm", - padding_algorithm, "data_format", data_format) - pre_bias = getattr(core.ops, l_type)(input, weight, *attrs) - if bias is not None: - pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim) - else: - pre_act = pre_bias - out = dygraph_utils._append_activation_in_dygraph( - pre_act, act, use_cudnn=use_cudnn) - else: - inputs = {'Input': [input], 'Filter': [weight]} - attrs = { - 'strides': stride, - 'paddings': padding, - 'dilations': dilation, - 'groups': groups, - 'use_cudnn': use_cudnn, - 'use_mkldnn': False, - 'fuse_relu_before_depthwise_conv': False, - "padding_algorithm": padding_algorithm, - "data_format": data_format - } - 
check_variable_and_dtype(input, 'input', - ['float16', 'float32', 'float64'], 'conv2d') - helper = LayerHelper(l_type, **locals()) - dtype = helper.input_dtype() - pre_bias = helper.create_variable_for_type_inference(dtype) - outputs = {"Output": [pre_bias]} - helper.append_op( - type=l_type, inputs=inputs, outputs=outputs, attrs=attrs) - if bias is not None: - pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim) - else: - pre_act = pre_bias - out = helper.append_activation(pre_act) - return out - - -def test_conv2d_with_filter(): - - import paddle.fluid.dygraph as dygraph - import numpy as np - - exemplar = np.random.random((8, 4, 6, 6)).astype(np.float32) - instance = np.random.random((8, 4, 22, 22)).astype(np.float32) - - with dygraph.guard(): - exem = dygraph.to_variable(exemplar) - inst = dygraph.to_variable(instance) - res = FConv2D(inst, exem, groups=1) +from __future__ import print_function + +import numpy as np +from paddle.fluid.framework import Variable, in_dygraph_mode +from paddle.fluid import core, dygraph_utils +from paddle.fluid.layers import nn, utils +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper + + +def _is_list_or_tuple(input): + return isinstance(input, (list, tuple)) + + +def _zero_padding_in_batch_and_channel(padding, channel_last): + if channel_last: + return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0] + else: + return list(padding[0]) == [0, 0] and list(padding[1]) == [0, 0] + + +def _exclude_padding_in_batch_and_channel(padding, channel_last): + padding_ = padding[1:-1] if channel_last else padding[2:] + padding_ = [elem for pad_a_dim in padding_ for elem in pad_a_dim] + return padding_ + + +def _update_padding_nd(padding, channel_last, num_dims): + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown padding: '{}'. It can only be 'SAME' or 'VALID'.". + format(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0] * num_dims + else: + padding_algorithm = "SAME" + padding = [0] * num_dims + elif _is_list_or_tuple(padding): + # for padding like + # [(pad_before, pad_after), (pad_before, pad_after), ...] + # padding for batch_dim and channel_dim included + if len(padding) == 2 + num_dims and _is_list_or_tuple(padding[0]): + if not _zero_padding_in_batch_and_channel(padding, channel_last): + raise ValueError( + "Non-zero padding({}) in the batch or channel dimensions " + "is not supported.".format(padding)) + padding_algorithm = "EXPLICIT" + padding = _exclude_padding_in_batch_and_channel(padding, + channel_last) + if utils._is_symmetric_padding(padding, num_dims): + padding = padding[0::2] + # for padding like [pad_before, pad_after, pad_before, pad_after, ...] + elif len(padding) == 2 * num_dims and isinstance(padding[0], int): + padding_algorithm = "EXPLICIT" + padding = utils.convert_to_list(padding, 2 * num_dims, 'padding') + if utils._is_symmetric_padding(padding, num_dims): + padding = padding[0::2] + # for padding like [pad_d1, pad_d2, ...] 
+ elif len(padding) == num_dims and isinstance(padding[0], int): + padding_algorithm = "EXPLICIT" + padding = utils.convert_to_list(padding, num_dims, 'padding') + else: + raise ValueError("In valid padding: {}".format(padding)) + # for integer padding + else: + padding_algorithm = "EXPLICIT" + padding = utils.convert_to_list(padding, num_dims, 'padding') + return padding, padding_algorithm + + +def FConv2D(input, + weight, + bias=None, + padding=0, + stride=1, + dilation=1, + groups=1, + use_cudnn=True, + act=None, + data_format="NCHW", + name=None): + # entry checks + if not isinstance(use_cudnn, bool): + raise ValueError("Attr(use_cudnn) should be True or False. " + "Received Attr(use_cudnn): {}.".format(use_cudnn)) + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. " + "Received Attr(data_format): {}.".format(data_format)) + + channel_last = (data_format == "NHWC") + channel_dim = -1 if channel_last else 1 + num_channels = input.shape[channel_dim] + num_filters = weight.shape[0] + if num_channels < 0: + raise ValueError("The channel dimmention of the input({}) " + "should be defined. Received: {}.".format( + input.shape, num_channels)) + if num_channels % groups != 0: + raise ValueError( + "the channel of input must be divisible by groups," + "received: the channel of input is {}, the shape of input is {}" + ", the groups is {}".format(num_channels, input.shape, groups)) + if num_filters % groups != 0: + raise ValueError( + "the number of filters must be divisible by groups," + "received: the number of filters is {}, the shape of weight is {}" + ", the groups is {}".format(num_filters, weight.shape, groups)) + + # update attrs + padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2) + stride = utils.convert_to_list(stride, 2, 'stride') + dilation = utils.convert_to_list(dilation, 2, 'dilation') + + l_type = "conv2d" + if (num_channels == groups and num_filters % num_channels == 0 and + not use_cudnn): + l_type = 'depthwise_conv2d' + + inputs = {'Input': [input], 'Filter': [weight]} + attrs = { + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'groups': groups, + 'use_cudnn': use_cudnn, + 'use_mkldnn': False, + 'fuse_relu_before_depthwise_conv': False, + "padding_algorithm": padding_algorithm, + "data_format": data_format + } + + if in_dygraph_mode(): + attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, + 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False, + 'fuse_relu_before_depthwise_conv', False, "padding_algorithm", + padding_algorithm, "data_format", data_format) + pre_bias = getattr(core.ops, l_type)(input, weight, *attrs) + if bias is not None: + pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim) + else: + pre_act = pre_bias + out = dygraph_utils._append_activation_in_dygraph( + pre_act, act, use_cudnn=use_cudnn) + else: + inputs = {'Input': [input], 'Filter': [weight]} + attrs = { + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'groups': groups, + 'use_cudnn': use_cudnn, + 'use_mkldnn': False, + 'fuse_relu_before_depthwise_conv': False, + "padding_algorithm": padding_algorithm, + "data_format": data_format + } + check_variable_and_dtype(input, 'input', + ['float16', 'float32', 'float64'], 'conv2d') + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pre_bias = helper.create_variable_for_type_inference(dtype) + outputs = {"Output": [pre_bias]} + helper.append_op( + type=l_type, inputs=inputs, 
outputs=outputs, attrs=attrs) + if bias is not None: + pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim) + else: + pre_act = pre_bias + out = helper.append_activation(pre_act) + return out + + +def test_conv2d_with_filter(): + + import paddle.fluid.dygraph as dygraph + import numpy as np + + exemplar = np.random.random((8, 4, 6, 6)).astype(np.float32) + instance = np.random.random((8, 4, 22, 22)).astype(np.float32) + + with dygraph.guard(): + exem = dygraph.to_variable(exemplar) + inst = dygraph.to_variable(instance) + res = FConv2D(inst, exem, groups=1) print(res.shape) \ No newline at end of file diff --git a/PaddleCV/tracking/pytracking/libs/operation.py b/PaddleCV/tracking/pytracking/libs/operation.py index f9a095a2cf..94bf2f4a0e 100644 --- a/PaddleCV/tracking/pytracking/libs/operation.py +++ b/PaddleCV/tracking/pytracking/libs/operation.py @@ -1,59 +1,59 @@ -from paddle import fluid -from paddle.fluid import layers -from pytracking.libs.Fconv2d import FConv2D -from pytracking.libs.tensorlist import tensor_operation, TensorList -from paddle.fluid.framework import Variable as PTensor - - -@tensor_operation -def conv2d(input: PTensor, - weight: PTensor, - bias: PTensor=None, - stride=1, - padding=0, - dilation=1, - groups=1, - mode=None): - """Standard conv2d. Returns the input if weight=None.""" - - if weight is None: - return input - - ind = None - if mode is not None: - if padding != 0: - raise ValueError('Cannot input both padding and mode.') - if mode == 'same': - padding = (weight.shape[2] // 2, weight.shape[3] // 2) - if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0: - ind = (slice(-1) - if weight.shape[2] % 2 == 0 else slice(None), slice(-1) - if weight.shape[3] % 2 == 0 else slice(None)) - elif mode == 'valid': - padding = (0, 0) - elif mode == 'full': - padding = (weight.shape[2] - 1, weight.shape[3] - 1) - else: - raise ValueError('Unknown mode for padding.') - - assert bias is None - out = FConv2D( - input, - weight, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups) - if ind is None: - return out - return out[:, :, ind[0], ind[1]] - - -@tensor_operation -def conv1x1(input: PTensor, weight: PTensor): - """Do a convolution with a 1x1 kernel weights. Implemented with matmul, which can be faster than using conv.""" - - if weight is None: - return input - - return FConv2D(input, weight) +from paddle import fluid +from paddle.fluid import layers +from pytracking.libs.Fconv2d import FConv2D +from pytracking.libs.tensorlist import tensor_operation, TensorList +from paddle.fluid.framework import Variable as PTensor + + +@tensor_operation +def conv2d(input: PTensor, + weight: PTensor, + bias: PTensor=None, + stride=1, + padding=0, + dilation=1, + groups=1, + mode=None): + """Standard conv2d. 
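
`FConv2D` above accepts integer, per-dimension, or string padding; `conv2d` in `operation.py` maps its `'same'`/`'valid'`/`'full'` modes onto explicit pad sizes before calling it, and slices the result with `ind` when an even kernel makes `'same'` one pixel too large. A quick stride-1 output-size check in plain Python (a sanity check, independent of Paddle):

```python
# Output height/width of a stride-1 convolution for the three padding modes
# handled by conv2d().
def out_size(in_sz, kernel, mode):
    pad = {'valid': 0, 'same': kernel // 2, 'full': kernel - 1}[mode]
    return in_sz + 2 * pad - kernel + 1

for mode in ('valid', 'same', 'full'):
    print(mode, out_size(22, 5, mode), out_size(22, 6, mode))
# valid: 18 17   same: 22 23 (even kernel overshoots, hence the `ind` slicing)   full: 26 27
```
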
Returns the input if weight=None.""" + + if weight is None: + return input + + ind = None + if mode is not None: + if padding != 0: + raise ValueError('Cannot input both padding and mode.') + if mode == 'same': + padding = (weight.shape[2] // 2, weight.shape[3] // 2) + if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0: + ind = (slice(-1) + if weight.shape[2] % 2 == 0 else slice(None), slice(-1) + if weight.shape[3] % 2 == 0 else slice(None)) + elif mode == 'valid': + padding = (0, 0) + elif mode == 'full': + padding = (weight.shape[2] - 1, weight.shape[3] - 1) + else: + raise ValueError('Unknown mode for padding.') + + assert bias is None + out = FConv2D( + input, + weight, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + if ind is None: + return out + return out[:, :, ind[0], ind[1]] + + +@tensor_operation +def conv1x1(input: PTensor, weight: PTensor): + """Do a convolution with a 1x1 kernel weights. Implemented with matmul, which can be faster than using conv.""" + + if weight is None: + return input + + return FConv2D(input, weight) diff --git a/PaddleCV/tracking/pytracking/parameter/siammask/base_default.py b/PaddleCV/tracking/pytracking/parameter/siammask/base_default.py new file mode 100644 index 0000000000..32ebd4ec6c --- /dev/null +++ b/PaddleCV/tracking/pytracking/parameter/siammask/base_default.py @@ -0,0 +1,43 @@ +import numpy as np + +from pytracking.features import deep +from pytracking.features.extractor import MultiResolutionExtractor +from pytracking.utils import TrackerParams, FeatureParams + + +def parameters(): + params = TrackerParams() + + # These are usually set from outside + params.debug = 0 # Debug level + params.visualization = False # Do visualization + + # Use GPU or not (IoUNet requires this to be True) + params.use_gpu = True + + # Feature specific parameters + deep_params = TrackerParams() + + # Patch sampling parameters + params.exemplar_size = 127 + params.instance_size = 255 + params.base_size = 8 + params.context_amount = 0.5 + + # Anchor parameters + params.anchor_stride = 8 + params.anchor_ratios = [0.33, 0.5, 1, 2, 3] + params.anchor_scales = [8] + + # Tracking parameters + params.penalty_k = 0.1 + params.window_influence = 0.41 + params.lr = 0.32 + params.mask_threshold = 0.15 + + # Setup the feature extractor + deep_fparams = FeatureParams(feature_params=[deep_params]) + deep_feat = deep.SMaskResNet50_base(fparams=deep_fparams) + params.features = MultiResolutionExtractor([deep_feat]) + + return params diff --git a/PaddleCV/tracking/pytracking/parameter/siammask/sharp_default_otb.py b/PaddleCV/tracking/pytracking/parameter/siammask/sharp_default_otb.py new file mode 100644 index 0000000000..d880ae9f8a --- /dev/null +++ b/PaddleCV/tracking/pytracking/parameter/siammask/sharp_default_otb.py @@ -0,0 +1,47 @@ +import numpy as np + +from pytracking.features import deep +from pytracking.features.extractor import MultiResolutionExtractor +from pytracking.utils import TrackerParams, FeatureParams + + +def parameters(): + params = TrackerParams() + + # These are usually set from outside + params.debug = 0 # Debug level + params.visualization = False # Do visualization + + # Use GPU or not (IoUNet requires this to be True) + params.use_gpu = True + + # Feature specific parameters + deep_params = TrackerParams() + + # Patch sampling parameters + params.exemplar_size = 127 + params.instance_size = 255 + params.base_size = 8 + params.context_amount = 0.5 + params.mask_output_size = 127 + + # Anchor parameters + 
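
The parameter files in this patch fix `exemplar_size`, `instance_size`, `base_size`, and the anchor settings; the SiamMask/SiamRPN trackers added later derive the score-map size and the number of anchors from them. The arithmetic for the values used here, as a quick stand-alone check (not part of the patch):

```python
# Score-map size and anchor count derived from the tracker parameters above.
def score_size(instance_size, exemplar_size, anchor_stride, base_size):
    return (instance_size - exemplar_size) // anchor_stride + 1 + base_size

anchor_ratios = [0.33, 0.5, 1, 2, 3]
anchor_scales = [8]
anchor_num = len(anchor_ratios) * len(anchor_scales)   # 5 anchors per location

siammask_sz = score_size(255, 127, 8, 8)   # 25
siamrpn_sz = score_size(287, 127, 8, 0)    # 21
print(anchor_num, siammask_sz, siamrpn_sz)
print(anchor_num * siammask_sz ** 2)       # 3125 candidate boxes ranked per frame
```
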
params.anchor_stride = 8 + params.anchor_ratios = [0.33, 0.5, 1, 2, 3] + params.anchor_scales = [8] + + # Tracking parameters + params.penalty_k = 0.04 + params.window_influence = 0.42 + params.lr = 0.25 + params.mask_threshold = 0.30 + + # output rect result + params.polygon = False + + # Setup the feature extractor + deep_fparams = FeatureParams(feature_params=[deep_params]) + deep_feat = deep.SMaskResNet50_sharp(fparams=deep_fparams) + params.features = MultiResolutionExtractor([deep_feat]) + + return params diff --git a/PaddleCV/tracking/pytracking/parameter/siammask/sharp_default_vot.py b/PaddleCV/tracking/pytracking/parameter/siammask/sharp_default_vot.py new file mode 100644 index 0000000000..509e948b93 --- /dev/null +++ b/PaddleCV/tracking/pytracking/parameter/siammask/sharp_default_vot.py @@ -0,0 +1,47 @@ +import numpy as np + +from pytracking.features import deep +from pytracking.features.extractor import MultiResolutionExtractor +from pytracking.utils import TrackerParams, FeatureParams + + +def parameters(): + params = TrackerParams() + + # These are usually set from outside + params.debug = 0 # Debug level + params.visualization = False # Do visualization + + # Use GPU or not (IoUNet requires this to be True) + params.use_gpu = True + + # Feature specific parameters + deep_params = TrackerParams() + + # Patch sampling parameters + params.exemplar_size = 127 + params.instance_size = 255 + params.base_size = 8 + params.context_amount = 0.5 + params.mask_output_size = 127 + + # Anchor parameters + params.anchor_stride = 8 + params.anchor_ratios = [0.33, 0.5, 1, 2, 3] + params.anchor_scales = [8] + + # Tracking parameters + params.penalty_k = 0.20 + params.window_influence = 0.41 + params.lr = 0.30 + params.mask_threshold = 0.30 + + # output polygon result + params.polygon = True + + # Setup the feature extractor + deep_fparams = FeatureParams(feature_params=[deep_params]) + deep_feat = deep.SMaskResNet50_sharp(fparams=deep_fparams) + params.features = MultiResolutionExtractor([deep_feat]) + + return params diff --git a/PaddleCV/tracking/pytracking/parameter/siamrpn/default_otb.py b/PaddleCV/tracking/pytracking/parameter/siamrpn/default_otb.py new file mode 100644 index 0000000000..e088d8bb57 --- /dev/null +++ b/PaddleCV/tracking/pytracking/parameter/siamrpn/default_otb.py @@ -0,0 +1,42 @@ +import numpy as np + +from pytracking.features import deep +from pytracking.features.extractor import MultiResolutionExtractor +from pytracking.utils import TrackerParams, FeatureParams + + +def parameters(): + params = TrackerParams() + + # These are usually set from outside + params.debug = 0 # Debug level + params.visualization = False # Do visualization + + # Use GPU or not (IoUNet requires this to be True) + params.use_gpu = True + + # Feature specific parameters + deep_params = TrackerParams() + + # Patch sampling parameters + params.exemplar_size = 127 + params.instance_size = 287 + params.base_size = 0 + params.context_amount = 0.5 + + # Anchor parameters + params.anchor_stride = 8 + params.anchor_ratios = [0.33, 0.5, 1, 2, 3] + params.anchor_scales = [8] + + # Tracking parameters + params.penalty_k = 0.18 + params.window_influence = 0.41 + params.lr = 0.05 + + # Setup the feature extractor + deep_fparams = FeatureParams(feature_params=[deep_params]) + deep_feat = deep.SRPNAlexNet(fparams=deep_fparams) + params.features = MultiResolutionExtractor([deep_feat]) + + return params diff --git a/PaddleCV/tracking/pytracking/tracker/siammask/__init__.py 
b/PaddleCV/tracking/pytracking/tracker/siammask/__init__.py new file mode 100755 index 0000000000..1a5145b57a --- /dev/null +++ b/PaddleCV/tracking/pytracking/tracker/siammask/__init__.py @@ -0,0 +1,4 @@ +from .siammask import SiamMask + +def get_tracker_class(): + return SiamMask diff --git a/PaddleCV/tracking/pytracking/tracker/siammask/siammask.py b/PaddleCV/tracking/pytracking/tracker/siammask/siammask.py new file mode 100755 index 0000000000..7b0c73e8ab --- /dev/null +++ b/PaddleCV/tracking/pytracking/tracker/siammask/siammask.py @@ -0,0 +1,293 @@ +import time +import math +import cv2 +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid import dygraph + +from pytracking.tracker.base.basetracker import BaseTracker +from ltr.data.anchor import Anchors + + +class SiamMask(BaseTracker): + def initialize_features(self): + if not getattr(self, 'features_initialized', False): + self.params.features.initialize() + self.features_initialized = True + + def initialize(self, image, state, *args, **kwargs): + # Initialize some stuff + self.frame_num = 1 + + # Initialize features + self.initialize_features() + + self.time = 0 + tic = time.time() + + # Get position and size + # self.pos: target center (y, x) + self.pos = np.array( + [ + state[1] + state[3] // 2, + state[0] + state[2] // 2 + ], + dtype=np.float32) + self.target_sz = np.array([state[3], state[2]], dtype=np.float32) + + # Set search area + context = self.params.context_amount * np.sum(self.target_sz) + self.z_sz = np.sqrt(np.prod(self.target_sz + context)) + self.x_sz = round(self.z_sz * (self.params.instance_size / self.params.exemplar_size)) + + self.score_size = (self.params.instance_size - self.params.exemplar_size) // \ + self.params.anchor_stride + 1 + self.params.base_size + self.anchor_num = len(self.params.anchor_ratios) * len(self.params.anchor_scales) + hanning = np.hanning(self.score_size) + window = np.outer(hanning, hanning) + self.window = np.tile(window.flatten(), self.anchor_num) + self.anchors = self.generate_anchor(self.score_size) + + # Convert image + self.avg_color = np.mean(image, axis=(0, 1)) + with dygraph.guard(): + exemplar_image = self._crop_and_resize( + image, + self.pos, + self.z_sz, + out_size=self.params.exemplar_size, + pad_color=self.avg_color) + + # get template + self.params.features.features[0].net.template(exemplar_image) + + self.time += time.time() - tic + + def track(self, image): + self.frame_num += 1 + + # Convert image + image = np.asarray(image) + + with dygraph.guard(): + # search images + instance_image = self._crop_and_resize( + image, + self.pos, + self.x_sz, + out_size=self.params.instance_size, + pad_color=self.avg_color) + instance_box = [ + self.pos[1] - self.x_sz / 2, + self.pos[0] - self.x_sz / 2, + self.x_sz, + self.x_sz] + # predict + output = self.params.features.features[0].net.track(instance_image) + score = self._convert_score(output['cls']) + pred_bbox = self._convert_bbox(output['loc'], self.anchors) + + def change(r): + return np.maximum(r, 1. 
/ r) + + def sz(w, h): + pad = (w + h) * 0.5 + return np.sqrt((w + pad) * (h + pad)) + + # scale penalty + scale_z = self.params.exemplar_size / self.z_sz + s_c = change(sz(pred_bbox[2, :], pred_bbox[3, :]) / + (sz(self.target_sz[1]*scale_z, self.target_sz[0]*scale_z))) + + # aspect ratio penalty + r_c = change((self.target_sz[1]/self.target_sz[0]) / + (pred_bbox[2, :]/pred_bbox[3, :])) + penalty = np.exp(-(r_c * s_c - 1) * self.params.penalty_k) + pscore = penalty * score + + # window penalty + pscore = pscore * (1 - self.params.window_influence) + \ + self.window * self.params.window_influence + best_idx = np.argmax(pscore) + + bbox = pred_bbox[:, best_idx] / scale_z + lr = penalty[best_idx] * score[best_idx] * self.params.lr + + cx = bbox[0] + self.pos[1] + cy = bbox[1] + self.pos[0] + + # smooth bbox + width = self.target_sz[1] * (1 - lr) + bbox[2] * lr + height = self.target_sz[0] * (1 - lr) + bbox[3] * lr + + # clip boundary + cx, cy, width, height = self._bbox_clip(cx, cy, width, height, image.shape[:2]) + + # update state + self.pos = np.array([cy, cx]) + self.target_sz = np.array([height, width]) + context = self.params.context_amount * np.sum(self.target_sz) + self.z_sz = np.sqrt(np.prod(self.target_sz + context)) + self.x_sz = round(self.z_sz * (self.params.instance_size / self.params.exemplar_size)) + + if self.params.features.features[0].net.refine_head is None or not self.params.polygon: + # Return new state + yx = self.pos - self.target_sz / 2 + new_state = np.array([yx[1], yx[0], self.target_sz[1], self.target_sz[0]], 'float32') + return new_state.tolist() + + # processing mask + pos = np.unravel_index(best_idx, (5, self.score_size, self.score_size)) + delta_x, delta_y = int(pos[2]), int(pos[1]) + with dygraph.guard(): + mask = self.params.features.features[0].net.mask_refine((delta_y, delta_x)) + mask = fluid.layers.sigmoid(mask) + mask = fluid.layers.reshape(mask, [-1]) + out_size = self.params.mask_output_size + mask = fluid.layers.reshape(mask,[out_size, out_size]).numpy() + + s = instance_box[2] / self.params.instance_size + base_size = self.params.base_size + stride = self.params.anchor_stride + sub_box = [instance_box[0] + (delta_x - base_size/2) * stride * s, + instance_box[1] + (delta_y - base_size/2) * stride * s, + s * self.params.exemplar_size, + s * self.params.exemplar_size] + s = out_size / sub_box[2] + + im_h, im_w = image.shape[:2] + back_box = [-sub_box[0] * s, -sub_box[1] * s, im_w*s, im_h*s] + mask_in_img = self._crop_back(mask, back_box, (im_w, im_h)) + polygon = self._mask_post_processing(mask_in_img) + # Return new state + new_state = polygon.flatten() + + return new_state.tolist() + + def generate_anchor(self, score_size): + anchors = Anchors( + self.params.anchor_stride, + self.params.anchor_ratios, + self.params.anchor_scales) + anchor = anchors.anchors + x1, y1, x2, y2 = anchor[:, 0], anchor[:, 1], anchor[:, 2], anchor[:, 3] + anchor = np.stack([(x1+x2)*0.5, (y1+y2)*0.5, x2-x1, y2-y1], 1) + total_stride = anchors.stride + anchor_num = anchor.shape[0] + anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4)) + ori = - (score_size // 2) * total_stride + xx, yy = np.meshgrid( + [ori + total_stride * dx for dx in range(score_size)], + [ori + total_stride * dy for dy in range(score_size)]) + xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \ + np.tile(yy.flatten(), (anchor_num, 1)).flatten() + anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32) + return anchor + + def _crop_and_resize(self, image, center, 
size, out_size, pad_color): + # convert box to corners (0-indexed) + size = round(size) + corners = np.concatenate( + ( + np.floor(center - (size + 1) / 2 + 0.5), + np.floor(center - (size + 1) / 2 + 0.5) + size + )) + corners = np.round(corners).astype(int) + + # pad image if necessary + pads = np.concatenate((-corners[:2], corners[2:] - image.shape[:2])) + npad = max(0, int(pads.max())) + if npad > 0: + image = cv2.copyMakeBorder( + image, + npad, + npad, + npad, + npad, + cv2.BORDER_CONSTANT, + value=pad_color) + + # crop image patch + corners = (corners + npad).astype(int) + patch = image[corners[0]:corners[2], corners[1]:corners[3]] + + # resize to out_size + patch = cv2.resize(patch, (out_size, out_size)) + + patch = patch.transpose(2, 0, 1) + patch = patch[np.newaxis, :, :, :] + patch = patch.astype(np.float32) + patch = fluid.dygraph.to_variable(patch) + return patch + + def _convert_bbox(self, delta, anchor): + delta = fluid.layers.transpose(delta, [1, 2, 3, 0]) + delta = fluid.layers.reshape(delta, [4, -1]).numpy() + + delta[0, :] = delta[0, :] * anchor[:, 2] + anchor[:, 0] + delta[1, :] = delta[1, :] * anchor[:, 3] + anchor[:, 1] + delta[2, :] = np.exp(delta[2, :]) * anchor[:, 2] + delta[3, :] = np.exp(delta[3, :]) * anchor[:, 3] + return delta + + def _convert_score(self, score): + score = fluid.layers.transpose(score, [1, 2, 3, 0]) + score = fluid.layers.reshape(score, [2, -1]) + score = fluid.layers.transpose(score, [1, 0]) + score = fluid.layers.softmax(score, axis=1)[:, 1].numpy() + return score + + def _bbox_clip(self, cx, cy, width, height, boundary): + cx = max(0, min(cx, boundary[1])) + cy = max(0, min(cy, boundary[0])) + width = max(10, min(width, boundary[1])) + height = max(10, min(height, boundary[0])) + return cx, cy, width, height + + def _crop_back(self, image, bbox, out_sz, padding=0): + a = (out_sz[0] - 1) / bbox[2] + b = (out_sz[1] - 1) / bbox[3] + c = -a * bbox[0] + d = -b * bbox[1] + mapping = np.array([[a, 0, c], [0, b, d]]).astype(np.float) + crop = cv2.warpAffine( + image, + mapping, + (out_sz[0], out_sz[1]), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=padding) + return crop + + def _mask_post_processing(self, mask): + target_mask = (mask > self.params.mask_threshold) + target_mask = target_mask.astype(np.uint8) + if cv2.__version__[-5] == '4': + contours, _ = cv2.findContours( + target_mask, + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + else: + _, contours, _ = cv2.findContours( + target_mask, + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + cnt_area = [cv2.contourArea(cnt) for cnt in contours] + if len(contours) != 0 and np.max(cnt_area) > 100: + contour = contours[np.argmax(cnt_area)] + polygon = contour.reshape(-1, 2) + prbox = cv2.boxPoints(cv2.minAreaRect(polygon)) + rbox_in_img = prbox + else: # empty mask + yx = self.pos - self.target_sz / 2 + location = np.array([yx[1], yx[0], self.target_sz[1], self.target_sz[0]], 'float32') + rbox_in_img = np.array( + [ + [location[0], location[1]], + [location[0] + location[2], location[1]], + [location[0] + location[2], location[1] + location[3]], + [location[0], location[1] + location[3]] + ]) + return rbox_in_img + diff --git a/PaddleCV/tracking/pytracking/tracker/siamrpn/__init__.py b/PaddleCV/tracking/pytracking/tracker/siamrpn/__init__.py new file mode 100755 index 0000000000..0c98e68614 --- /dev/null +++ b/PaddleCV/tracking/pytracking/tracker/siamrpn/__init__.py @@ -0,0 +1,4 @@ +from .siamrpn import SiamRPN + +def get_tracker_class(): + return SiamRPN diff --git 
a/PaddleCV/tracking/pytracking/tracker/siamrpn/siamrpn.py b/PaddleCV/tracking/pytracking/tracker/siamrpn/siamrpn.py new file mode 100755 index 0000000000..5b9385750b --- /dev/null +++ b/PaddleCV/tracking/pytracking/tracker/siamrpn/siamrpn.py @@ -0,0 +1,214 @@ +import time +import math +import cv2 +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid import dygraph + +from pytracking.tracker.base.basetracker import BaseTracker +from ltr.data.anchor import Anchors + + +class SiamRPN(BaseTracker): + def initialize_features(self): + if not getattr(self, 'features_initialized', False): + self.params.features.initialize() + self.features_initialized = True + + def initialize(self, image, state, *args, **kwargs): + # Initialize some stuff + self.frame_num = 1 + + # Initialize features + self.initialize_features() + + self.time = 0 + tic = time.time() + + # Get position and size + # self.pos: target center (y, x) + self.pos = np.array( + [ + state[1] + state[3] // 2, + state[0] + state[2] // 2 + ], + dtype=np.float32) + self.target_sz = np.array([state[3], state[2]], dtype=np.float32) + + # Set search area + context = self.params.context_amount * np.sum(self.target_sz) + self.z_sz = np.sqrt(np.prod(self.target_sz + context)) + self.x_sz = round(self.z_sz * (self.params.instance_size / self.params.exemplar_size)) + + self.score_size = (self.params.instance_size - self.params.exemplar_size) // \ + self.params.anchor_stride + 1 + self.params.base_size + self.anchor_num = len(self.params.anchor_ratios) * len(self.params.anchor_scales) + hanning = np.hanning(self.score_size) + window = np.outer(hanning, hanning) + self.window = np.tile(window.flatten(), self.anchor_num) + self.anchors = self.generate_anchor(self.score_size) + + # Convert image + self.avg_color = np.mean(image, axis=(0, 1)) + with dygraph.guard(): + exemplar_image = self._crop_and_resize( + image, + self.pos, + self.z_sz, + out_size=self.params.exemplar_size, + pad_color=self.avg_color) + + # get template + self.params.features.features[0].net.template(exemplar_image) + + self.time += time.time() - tic + + def track(self, image): + self.frame_num += 1 + + # Convert image + image = np.asarray(image) + + with dygraph.guard(): + # search images + instance_image = self._crop_and_resize( + image, + self.pos, + self.x_sz, + out_size=self.params.instance_size, + pad_color=self.avg_color) + + # predict + output = self.params.features.features[0].net.track(instance_image) + score = self._convert_score(output['cls']) + pred_bbox = self._convert_bbox(output['loc'], self.anchors) + + def change(r): + return np.maximum(r, 1. 
/ r) + + def sz(w, h): + pad = (w + h) * 0.5 + return np.sqrt((w + pad) * (h + pad)) + + # scale penalty + scale_z = self.params.exemplar_size / self.z_sz + s_c = change(sz(pred_bbox[2, :], pred_bbox[3, :]) / + (sz(self.target_sz[1]*scale_z, self.target_sz[0]*scale_z))) + + # aspect ratio penalty + r_c = change((self.target_sz[1]/self.target_sz[0]) / + (pred_bbox[2, :]/pred_bbox[3, :])) + penalty = np.exp(-(r_c * s_c - 1) * self.params.penalty_k) + pscore = penalty * score + + # window penalty + pscore = pscore * (1 - self.params.window_influence) + \ + self.window * self.params.window_influence + best_idx = np.argmax(pscore) + + bbox = pred_bbox[:, best_idx] / scale_z + lr = penalty[best_idx] * score[best_idx] * self.params.lr + + cx = bbox[0] + self.pos[1] + cy = bbox[1] + self.pos[0] + + # smooth bbox + width = self.target_sz[1] * (1 - lr) + bbox[2] * lr + height = self.target_sz[0] * (1 - lr) + bbox[3] * lr + + # clip boundary + cx, cy, width, height = self._bbox_clip(cx, cy, width, height, image.shape[:2]) + + # update state + self.pos = np.array([cy, cx]) + self.target_sz = np.array([height, width]) + context = self.params.context_amount * np.sum(self.target_sz) + self.z_sz = np.sqrt(np.prod(self.target_sz + context)) + self.x_sz = round(self.z_sz * (self.params.instance_size / self.params.exemplar_size)) + + # Return new state + yx = self.pos - self.target_sz / 2 + new_state = np.array([yx[1], yx[0], self.target_sz[1], self.target_sz[0]], 'float32') + + return new_state.tolist() + + def generate_anchor(self, score_size): + anchors = Anchors( + self.params.anchor_stride, + self.params.anchor_ratios, + self.params.anchor_scales) + anchor = anchors.anchors + x1, y1, x2, y2 = anchor[:, 0], anchor[:, 1], anchor[:, 2], anchor[:, 3] + anchor = np.stack([(x1+x2)*0.5, (y1+y2)*0.5, x2-x1, y2-y1], 1) + total_stride = anchors.stride + anchor_num = anchor.shape[0] + anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4)) + ori = - (score_size // 2) * total_stride + xx, yy = np.meshgrid( + [ori + total_stride * dx for dx in range(score_size)], + [ori + total_stride * dy for dy in range(score_size)]) + xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \ + np.tile(yy.flatten(), (anchor_num, 1)).flatten() + anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32) + return anchor + + def _crop_and_resize(self, image, center, size, out_size, pad_color): + # convert box to corners (0-indexed) + size = round(size) + corners = np.concatenate( + ( + np.floor(center - (size + 1) / 2 + 0.5), + np.floor(center - (size + 1) / 2 + 0.5) + size + )) + corners = np.round(corners).astype(int) + + # pad image if necessary + pads = np.concatenate((-corners[:2], corners[2:] - image.shape[:2])) + npad = max(0, int(pads.max())) + if npad > 0: + image = cv2.copyMakeBorder( + image, + npad, + npad, + npad, + npad, + cv2.BORDER_CONSTANT, + value=pad_color) + + # crop image patch + corners = (corners + npad).astype(int) + patch = image[corners[0]:corners[2], corners[1]:corners[3]] + + # resize to out_size + patch = cv2.resize(patch, (out_size, out_size)) + + patch = patch.transpose(2, 0, 1) + patch = patch[np.newaxis, :, :, :] + patch = patch.astype(np.float32) + patch = fluid.dygraph.to_variable(patch) + return patch + + def _convert_bbox(self, delta, anchor): + delta = fluid.layers.transpose(delta, [1, 2, 3, 0]) + delta = fluid.layers.reshape(delta, [4, -1]).numpy() + + delta[0, :] = delta[0, :] * anchor[:, 2] + anchor[:, 0] + delta[1, :] = delta[1, :] * anchor[:, 3] + anchor[:, 1] + 
delta[2, :] = np.exp(delta[2, :]) * anchor[:, 2] + delta[3, :] = np.exp(delta[3, :]) * anchor[:, 3] + return delta + + def _convert_score(self, score): + score = fluid.layers.transpose(score, [1, 2, 3, 0]) + score = fluid.layers.reshape(score, [2, -1]) + score = fluid.layers.transpose(score, [1, 0]) + score = fluid.layers.softmax(score, axis=1)[:, 1].numpy() + return score + + def _bbox_clip(self, cx, cy, width, height, boundary): + cx = max(0, min(cx, boundary[1])) + cy = max(0, min(cy, boundary[0])) + width = max(10, min(width, boundary[1])) + height = max(10, min(height, boundary[0])) + return cx, cy, width, height
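
For reference, the core of `track()` above — decoding anchor deltas, penalizing scale and aspect-ratio changes, mixing in the cosine window, and smoothing the box update — can be reproduced in a few lines of NumPy. This is a stand-alone sketch with random scores and a hypothetical previous target size, not a drop-in replacement for the tracker:

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical setup mirroring the SiamRPN OTB parameters above.
penalty_k, window_influence, lr_cfg = 0.18, 0.41, 0.05
score_size, anchor_num = 21, 5
num = anchor_num * score_size * score_size

anchors = np.stack([rng.uniform(-80, 80, num),      # anchor center x offsets
                    rng.uniform(-80, 80, num),      # anchor center y offsets
                    rng.uniform(30, 120, num),      # anchor widths
                    rng.uniform(30, 120, num)], 0)  # anchor heights
delta = rng.randn(4, num) * 0.1                     # raw regression output
score = rng.rand(num)                               # softmax foreground scores

# Decode deltas exactly as in _convert_bbox().
pred = np.empty_like(delta)
pred[0] = delta[0] * anchors[2] + anchors[0]
pred[1] = delta[1] * anchors[3] + anchors[1]
pred[2] = np.exp(delta[2]) * anchors[2]
pred[3] = np.exp(delta[3]) * anchors[3]

def change(r):
    return np.maximum(r, 1. / r)

def sz(w, h):
    pad = (w + h) * 0.5
    return np.sqrt((w + pad) * (h + pad))

target_w, target_h, scale_z = 64.0, 48.0, 1.0       # hypothetical previous state
s_c = change(sz(pred[2], pred[3]) / sz(target_w * scale_z, target_h * scale_z))
r_c = change((target_w / target_h) / (pred[2] / pred[3]))
penalty = np.exp(-(r_c * s_c - 1) * penalty_k)

hanning = np.hanning(score_size)
window = np.tile(np.outer(hanning, hanning).flatten(), anchor_num)

# Penalized score re-ranked with the cosine window, as in track().
pscore = penalty * score * (1 - window_influence) + window * window_influence
best = np.argmax(pscore)

# Smooth the size update with the penalized score of the winning candidate.
lr = penalty[best] * score[best] * lr_cfg
new_w = target_w * (1 - lr) + pred[2, best] * lr
print(best, round(lr, 4), round(new_w, 2))
```
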