Skip to content

Commit

Permalink
[Doc]: refactor docs for basedataset (open-mmlab#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
GT9505 committed Jun 21, 2022
1 parent 44538e5 commit 45f5859
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions docs/zh_cn/tutorials/basedataset.md
Expand Up @@ -112,7 +112,7 @@ data
```python
import os.path as osp

from mmengine.data import BaseDataset
from mmengine.dataset import BaseDataset


class ToyDataset(BaseDataset):
Expand All @@ -125,10 +125,10 @@ class ToyDataset(BaseDataset):
# }
def parse_data_info(self, raw_data_info):
data_info = raw_data_info
img_prefix = self.data_prefix.get('img', None)
img_prefix = self.data_prefix.get('img_path', None)
if img_prefix is not None:
data_info['img_path'] = osp.join(
img_prefix, data_info['img_path')
img_prefix, data_info['img_path'])
return data_info

```
Expand All @@ -146,7 +146,7 @@ pipeline = [

toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)
```
Expand Down Expand Up @@ -188,7 +188,7 @@ len(toy_dataset)
在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 `parse_data_info()` 的返回值为 `list[dict]` 即可:

```python
from mmengine.data import BaseDataset
from mmengine.dataset import BaseDataset


class ToyVideoDataset(BaseDataset):
Expand Down Expand Up @@ -238,7 +238,7 @@ pipeline = [

toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline,
# 在这里传入 lazy_init 变量
Expand Down Expand Up @@ -279,7 +279,7 @@ pipeline = [

toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline,
# 在这里传入 serialize_data 变量
Expand All @@ -297,7 +297,7 @@ toy_dataset = ToyDataset(
MMEngine 提供了 `ConcatDataset` 包装来拼接多个数据集,使用方法如下:

```python
from mmengine.data import ConcatDataset
from mmengine.dataset import ConcatDataset

pipeline = [
dict(type='xxx', ...),
Expand All @@ -307,13 +307,13 @@ pipeline = [

toy_dataset_1 = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)

toy_dataset_2 = ToyDataset(
data_root='data/',
data_prefix=dict(img='val/'),
data_prefix=dict(img_path='val/'),
ann_file='annotations/val.json',
pipeline=pipeline)

Expand All @@ -328,7 +328,7 @@ toy_dataset_12 = ConcatDataset(datasets=[toy_dataset_1, toy_dataset_2])
MMEngine 提供了 `RepeatDataset` 包装来重复采样某个数据集若干次,使用方法如下:

```python
from mmengine.data import RepeatDataset
from mmengine.dataset import RepeatDataset

pipeline = [
dict(type='xxx', ...),
Expand All @@ -338,7 +338,7 @@ pipeline = [

toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)

Expand All @@ -357,16 +357,16 @@ MMEngine 提供了 `ClassBalancedDataset` 包装,来基于数据集中类别
`ClassBalancedDataset` 包装假设了被包装的数据集类支持 `get_cat_ids(idx)` 方法,`get_cat_ids(idx)` 方法返回一个列表,该列表包含了 `idx` 指定的 `data_info` 包含的样本类别,使用方法如下:

```python
from mmengine.data import BaseDataset, ClassBalancedDataset
from mmengine.dataset import BaseDataset, ClassBalancedDataset

class ToyDataset(BaseDataset):

def parse_data_info(self, raw_data_info):
data_info = raw_data_info
img_prefix = self.data_prefix.get('img', None)
img_prefix = self.data_prefix.get('img_path', None)
if img_prefix is not None:
data_info['img_path'] = osp.join(
img_prefix, data_info['img_path')
img_prefix, data_info['img_path'])
return data_info

# 必须支持的方法,需要返回样本的类别
Expand All @@ -382,7 +382,7 @@ pipeline = [

toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)

Expand All @@ -404,7 +404,7 @@ from mmengine.registry import DATASETS
@DATASETS.register_module()
class ExampleDatasetWrapper:

def __init__(self, dataset, lazy_init = False, ...):
def __init__(self, dataset, lazy_init=False, ...):
# 构建原数据集(self.dataset)
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
Expand Down

0 comments on commit 45f5859

Please sign in to comment.