# PyTorch Essentials: Code Templates for Loading Data
### Learning Objectives
* Learn how to load data and feed it to a model for training.

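### Key Points
* The `Dataset` class (or the `IterableDataset` class) stores the samples and their corresponding labels.
    - Map-style datasets (the mainstream approach)
    - Iterable-style datasets (a supplementary approach for very large data, which avoids the out-of-memory (OOM) errors that loading everything at once can cause)
* The `DataLoader` class iterates over the samples in a `Dataset` (or `IterableDataset`).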
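
The difference between the two dataset styles comes down to which methods a class implements. The following is a minimal pure-Python sketch of the two protocols, not PyTorch's own classes: the names `ToyMapDataset` and `ToyIterableDataset` and the tab-separated line format are hypothetical. A map-style dataset exposes `__len__` and `__getitem__`, while an iterable-style dataset exposes only `__iter__` and can therefore produce samples lazily.

```python
class ToyMapDataset:
    """Map-style: random access by index and a known length."""

    def __init__(self, texts, labels):
        assert len(texts) == len(labels)
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {"text": self.texts[idx], "label": self.labels[idx]}


class ToyIterableDataset:
    """Iterable-style: sequential access only, samples built on demand."""

    def __init__(self, lines):
        # `lines` may be any iterable, e.g. an open file read line by line.
        self.lines = lines

    def __iter__(self):
        for line in self.lines:
            # Hypothetical format: "<text>\t<label>"
            text, label = line.rsplit("\t", 1)
            yield {"text": text, "label": int(label)}


map_ds = ToyMapDataset(["good hotel", "bad service"], [1, 0])
iter_ds = ToyIterableDataset(["good hotel\t1", "bad service\t0"])

print(len(map_ds), map_ds[1])  # index-based access
print(list(iter_ds))           # sequential access only
```

In real code these classes would subclass `torch.utils.data.Dataset` and `torch.utils.data.IterableDataset` respectively, so that a `DataLoader` can consume them.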

### Dataset
As an example, we load a text classification dataset and see how to create a custom map-style dataset:

In [None]:
import torch

from datasets import load_dataset


# Define the dataset so the model can later read batches of data from it.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, path, data_type):
        self.data = self.load_data(path, data_type)
    
    def load_data(self, path, data_type):
        tmp_dataset = load_dataset(path, split=data_type)
        Data = {}
        for idx, line in enumerate(tmp_dataset):
            sample = line
            Data[idx] = sample
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


train_data = Dataset(path='seamew/ChnSentiCorp', data_type='train')
valid_data = Dataset(path='seamew/ChnSentiCorp', data_type='validation')
test_data = Dataset(path='seamew/ChnSentiCorp', data_type='test')

In [2]:
print(valid_data[0])

{'text': '這間酒店環境和服務態度亦算不錯,但房間空間太小~~不宣容納太大件行李~~且房間格調還可以~~ 中餐廳的廣東點心不太好吃~~要改善之~~~~但算價錢平宜~~可接受~~ 西餐廳格調都很好~~但吃的味道一般且令人等得太耐了~~要改善之~~', 'label': 1}


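### DataLoaders
* `batch_size`: the number of samples in each batch
* `shuffle`: whether to shuffle the dataset
* `collate_fn`: the batching function, used to process the samples drawn for a batch (for example, the padding operation mentioned earlier)

The batching function separates the texts from the labels, converts each Chinese text into token IDs, applies padding and truncation, and converts the resulting values into PyTorch tensors.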
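
To make the padding and truncation steps concrete, here is a minimal pure-Python sketch of what happens to a batch of token-ID lists. It is an illustration only, not the tokenizer's actual implementation: the helper `pad_and_truncate`, the `max_length` value, and the toy IDs are hypothetical, and a pad ID of 0 is assumed.

```python
def pad_and_truncate(batch_ids, max_length, pad_id=0):
    """Pad a batch of token-ID lists to a common length (assumed pad_id=0)."""
    # Cut every sequence to at most max_length tokens.
    # Note: a real tokenizer keeps the trailing special token when it
    # truncates; this sketch simply cuts the list.
    truncated = [ids[:max_length] for ids in batch_ids]

    # Pad to the longest sequence in this batch (dynamic padding).
    longest = max(len(ids) for ids in truncated)

    input_ids, attention_mask = [], []
    for ids in truncated:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_and_truncate(
    [[101, 7, 8, 102], [101, 9, 102], [101, 1, 2, 3, 4, 5, 102]],
    max_length=5,
)
for row, mask in zip(batch["input_ids"], batch["attention_mask"]):
    print(row, mask)
```

Because padding is computed per batch, the padded length can differ from one batch to the next, which is usually cheaper than padding every sample to a global maximum.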

In [3]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

checkpoint = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def collate_fn(batch_samples):
    # Split each sample into its text and its integer label.
    batch_text = []
    batch_label = []
    for sample in batch_samples:
        batch_text.append(sample['text'])
        batch_label.append(int(sample['label']))
    # Tokenize the whole batch: pad to the longest text, truncate overly long ones.
    X = tokenizer(
        batch_text, 
        padding=True, 
        truncation=True, 
        return_tensors="pt"
    )
    y = torch.tensor(batch_label)
    return X, y

train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_data, batch_size=4, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=4, shuffle=False, collate_fn=collate_fn)

batch_X, batch_y = next(iter(train_dataloader))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print('batch_y shape:', batch_y.shape)
print(batch_X)
print(batch_y)

batch_X shape: {'input_ids': torch.Size([4, 50]), 'token_type_ids': torch.Size([4, 50]), 'attention_mask': torch.Size([4, 50])}
batch_y shape: torch.Size([4])
{'input_ids': tensor([[ 101, 3403, 1114, 7313, 6772, 1920, 8024, 2970, 4991, 3302, 1218,  679,
         7231,  511, 2600, 1378, 3302, 1218,  679, 1922, 1962, 8024, 6842, 2791,
         5310, 2362, 1400, 8024, 1762, 2600, 1378, 2802,  671, 2356, 6413, 8024,
         1440, 4761, 8038, 7444,  123, 6235, 8013,  102,    0,    0,    0,    0,
            0,    0],
        [ 101, 4212, 4275,  680,  752, 2141,  679, 5016, 8024, 2791, 7313, 1922,
         2207, 8024, 3204, 2255, 2141, 7354,  677, 3300, 2523, 1914, 6983, 2421,
         3683, 6821,  702, 1962, 8024,  817, 3419,  738, 3291,  912, 2139, 8024,
         3025, 4923, 4638,  817, 3419, 1762, 6421, 6983, 2421, 3187,  831, 1232,
          511,  102],
        [ 101, 1912, 6225, 2523, 4023,  778, 8024, 3683, 6772, 6768, 8024, 8024,
         6821, 3416, 4638,  817, 3419, 5543,  743, 116

In [4]:
from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

ds = MyIterableDataset(start=3, end=7) # [3, 4, 5, 6]
# Single-process loading
print(list(DataLoader(ds, num_workers=0)))
# Directly doing multi-process loading
print(list(DataLoader(ds, num_workers=2)))

[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor
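
The duplicated values in the multi-process output arise because each worker process receives its own copy of the dataset and iterates over the full range. A rough pure-Python simulation shows the effect. It starts no real processes; the function name `simulate_workers` is hypothetical, and the round-robin interleaving is a simplification of how `DataLoader` collects results from its workers.

```python
from itertools import chain, zip_longest


def simulate_workers(start, end, num_workers):
    # Every simulated worker holds an identical copy of the full range.
    per_worker_streams = [list(range(start, end)) for _ in range(num_workers)]

    # Interleave the streams round-robin, skipping exhausted workers.
    sentinel = object()
    interleaved = chain.from_iterable(
        zip_longest(*per_worker_streams, fillvalue=sentinel)
    )
    return [x for x in interleaved if x is not sentinel]


result = simulate_workers(3, 7, num_workers=2)
print(result)
```

Each value appears once per worker, so an iterable-style dataset has to split its own workload when `num_workers` is greater than zero.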

In [6]:
from torch.utils.data import get_worker_info
import math
def worker_init_fn(worker_id):
    worker_info = get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
    
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
# With even more workers
print(list(DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))

[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
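
The splitting arithmetic inside `worker_init_fn` can be checked without starting any worker processes. The sketch below reuses the same formula in a standalone helper (`shard_range` is a hypothetical name) and lists the sub-range each worker would receive.

```python
import math


def shard_range(start, end, num_workers):
    """Return the list of values assigned to each worker."""
    # Same formula as in worker_init_fn: ceil of the range size per worker.
    per_worker = int(math.ceil((end - start) / float(num_workers)))
    shards = []
    for worker_id in range(num_workers):
        w_start = start + worker_id * per_worker
        w_end = min(w_start + per_worker, end)
        # range() is empty when w_start >= w_end, so surplus workers get nothing.
        shards.append(list(range(w_start, w_end)))
    return shards


two = shard_range(3, 7, num_workers=2)
twenty = shard_range(3, 7, num_workers=20)
print(two)
print([s for s in twenty if s])
```

With 20 workers, only the first four receive a value and the remaining 16 get empty ranges, which is why each element still appears exactly once in the output.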
