测试 https://github.com/pytorch/pytorch/issues/5059 中提到的问题：DataLoader 使用多个 num_workers 时 numpy 随机数的表现

In [61]:
import numpy as np
import pandas as pd
import random
from torch.utils.data import DataLoader
import torch


class Transform(object):
    """Callable that produces two random integers, one per RNG library.

    The ``item`` argument is accepted but never read; the result depends
    only on the current state of numpy's and Python's global generators.
    """

    def __call__(self, item = None):
        """Return ``[numpy_draw, python_draw]``.

        ``numpy_draw`` comes from ``np.random.randint(10000, 20000)`` and
        ``python_draw`` from ``random.randint(20000, 30000)``.
        """
        numpy_draw = np.random.randint(10000, 20000)
        python_draw = random.randint(20000, 30000)
        return [numpy_draw, python_draw]

class RandomDataset(object):
    """Toy indexable dataset whose samples consist of random numbers.

    Every sample mixes draws from numpy's global RNG and from Python's
    ``random`` module, so the printed values reveal how each generator is
    seeded in the process that produced the sample.
    """

    def __init__(self):
        # Placeholder payload; only its length is used (by ``__len__``).
        self.data = list(range(0, 10))

    def __getitem__(self, ind):
        """Build the sample for index ``ind``.

        Returns a 1-D numpy array with seven entries, in this order:
        ``ind``, ``np.random.randint(1, 10000)``,
        ``np.random.choice([1, 2, 3])``, ``random.randint(10000, 20000)``,
        ``random.choice([1, 2, 3])``, followed by the two values returned
        by ``Transform``.
        """
        # The draw order below is kept fixed so seeded runs stay comparable.
        item = [
            ind,
            np.random.randint(1, 10000),
            np.random.choice([1, 2, 3]),
            random.randint(10000, 20000),
            random.choice([1, 2, 3]),
        ]
        tsfm = Transform()(item)
        return np.array(item + tsfm)

    def __len__(self):
        """Return the number of samples, derived from ``self.data``."""
        # Derived from the data instead of a hard-coded 10 so the reported
        # length cannot drift away from the stored payload.
        return len(self.data)


def show(num_workers):
    """Print every batch of ``RandomDataset`` for two epochs.

    The dataset is wrapped in a shuffling ``DataLoader`` with batch size 5
    and ``num_workers`` worker processes; each batch is printed as a
    pandas DataFrame with one column per random source.

    Args:
        num_workers: forwarded to ``DataLoader(num_workers=...)``.
    """
    column_names = ['idx', 'np', 'np choice', 'py', 'random.choice','np in transform', 'py in random']

    print(f"num_workers: {num_workers}")
    dataset = RandomDataset()
    # Fix torch's global seed before the loader is built
    # (presumably so the shuffle order is repeatable between runs).
    torch.manual_seed(2)
    loader = DataLoader(dataset, 5, shuffle=True, num_workers=num_workers)

    for epoch in range(2):
        print("epoch {}".format(epoch))

        for batch_idx, batch in enumerate(loader):
            print(f'Batch {batch_idx}')
            values = batch.data.numpy()
            frame = pd.DataFrame(values, columns=column_names)
            print(frame)

In [67]:
show(num_workers=1)

num_workers: 1
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    1   236          1  11266              2            15192         21671
1    0   906          2  11599              2            12895         22097
2    8  5057          1  17929              3            14225         24997
3    3  7752          2  15311              1            13462         22562
4    4  9395          1  11679              1            15374         21155
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  2963          1  12898              1            18444         29931
1    6  3563          1  11257              2            18093         26077
2    7  6543          3  11812              3            18151         21360
3    2  3050          2  16004              2            19719         26752
4    9  1889          3  13354              3            16285         26194
epoch 1
Batch 0
   idx    np  np choi

In [23]:
show(num_workers=2)

num_workers: 2
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    4  6399          3  14191              1            18848         24341
1    9  4116          2  15644              1            14743         26141
2    7  6367          1  13437              3            17860         24379
3    3  4943          1  19455              2            14707         28436
4    6  5683          2  15887              3            11056         23981
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    2  6399          3  19606              1            18848         29114
1    1  4116          2  18526              3            14743         26015
2    0  6367          1  18283              2            17860         24964
3    8  4943          1  14430              1            14707         20084
4    5  5683          2  13660              1            11056         21449
epoch 1
Batch 0
   idx    np  np choi

- 当 num_workers == 1 时，一个 epoch 中，不同 batch 出来的随机数是不一样的，但是每一个 epoch 出来的随机数是一样的
- 当 num_workers >1 时，每一个 epoch 中的每一个 batch 里由 numpy 生成的随机数（np、np choice、np in transform 三列）都是一样的，而由 Python random 生成的列（py、random.choice、py in random）并不相同

In [13]:
def worker_init_fn(worker_id):
    """Re-seed numpy's global RNG with a value that depends on the worker.

    Intended to be passed as ``DataLoader(worker_init_fn=...)``. The new
    seed is the first word of the current numpy RNG state plus
    ``worker_id``, so different worker ids yield different seeds.

    Args:
        worker_id: integer id of the worker being initialised.
    """
    # Convert to a plain Python int first so the addition cannot overflow
    # a fixed-width numpy integer.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around
    # instead of failing when the sum leaves that range.
    np.random.seed((base_seed + worker_id) % 2**32)
    
def correct_show(num_workers):
    """Print every batch for two epochs, re-seeding numpy along the way.

    Same printing loop as a plain DataLoader run, except that the loader
    is given ``worker_init_fn`` and numpy's global RNG is re-seeded at the
    start of every epoch.

    Args:
        num_workers: forwarded to ``DataLoader(num_workers=...)``.
    """
    column_names = ['idx', 'np', 'np choice', 'py', 'random.choice','np in transform', 'py in random']

    print(f"num_workers: {num_workers}")
    dataset = RandomDataset()
    loader = DataLoader(
        dataset,
        5,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        pin_memory=True,
    )

    for epoch in range(2):
        print("epoch {}".format(epoch))
        # Calling seed() without an argument re-seeds numpy's global RNG
        # before this epoch's batches are produced.
        np.random.seed()
        for batch_idx, batch in enumerate(loader):
            print(f'Batch {batch_idx}')
            values = batch.data.numpy()
            frame = pd.DataFrame(values, columns=column_names)
            print(frame)

In [16]:
correct_show(1)

num_workers: 0
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    1  8042          3  10649              1            18327         20172
1    8  7622          1  14842              3            19989         29774
2    5  8045          3  15246              2            17955         26410
3    6  7560          3  15132              2            10421         21031
4    9  3042          2  11051              2            17559         29854
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    4  3289          1  17468              1            12840         24097
1    7  3859          1  13525              3            18430         28895
2    0  7048          3  17682              3            10116         25829
3    3  4985          1  14244              1            13607         28873
4    2  5371          2  13405              2            11273         23263
epoch 1
Batch 0
   idx    np  np choi

In [60]:
correct_show(4)

num_workers: 4
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  9385          3  12768              2            16205         29849
1    7  6114          2  15404              3            18732         21038
2    0  9287          2  13753              2            19262         24205
3    8  2603          1  15010              2            12931         25615
4    4  4883          2  12705              3            14169         28304
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    2  3639          3  15206              2            17072         22890
1    3   265          3  17331              1            17994         24616
2    1  6947          3  13220              2            11242         23144
3    6  7818          2  18245              3            14338         25104
4    9  5395          3  12724              1            16545         24369
epoch 1
Batch 0
   idx    np  np choi