测试 https://github.com/pytorch/pytorch/issues/5059 中提到的问题：DataLoader 使用多个 worker（num_workers）时 numpy 随机数的表现

In [31]:
import numpy as np
import pandas as pd
import random
from torch.utils.data import DataLoader
import torch


class Transform:
    """Callable that produces a pair of freshly drawn random integers.

    One value is drawn from NumPy's global generator and the other from the
    standard-library ``random`` module, so the two sources can be compared
    side by side.
    """

    def __init__(self):
        """Nothing to configure; the transform keeps no state."""

    def __call__(self, item=None):
        """Draw the two random values.

        Args:
            item: Accepted so the object can be called with a sample; it is
                not read.

        Returns:
            A two-element list: the result of
            ``np.random.randint(10000, 20000)`` followed by the result of
            ``random.randint(20000, 30000)``.
        """
        # NumPy is sampled before the stdlib generator.
        numpy_draw = np.random.randint(10000, 20000)
        stdlib_draw = random.randint(20000, 30000)
        return [numpy_draw, stdlib_draw]

class RandomDataset:
    """Ten-element indexable dataset whose samples are mostly random numbers.

    Every sample mixes values from NumPy's global generator with values from
    the standard-library ``random`` module, plus the two values produced by
    ``Transform``.
    """

    def __init__(self):
        """Store the integers 0..9 in ``self.data``."""
        self.data = [*range(10)]

    def __getitem__(self, ind):
        """Build the sample for index ``ind``.

        Args:
            ind: Index of the requested sample; it is copied into the first
                slot of the result and is not used to look anything up.

        Returns:
            A NumPy array holding, in order: ``ind``, a NumPy random
            integer, a NumPy random choice from ``[1, 2, 3]``, a stdlib
            random integer, a stdlib random choice from ``[1, 2, 3]``, and
            the two values returned by ``Transform``.
        """
        # The draws happen in this exact order: two NumPy, two stdlib, and
        # only then the transform's own draws.
        numpy_int = np.random.randint(1, 10000)
        numpy_pick = np.random.choice([1, 2, 3])
        stdlib_int = random.randint(10000, 20000)
        stdlib_pick = random.choice([1, 2, 3])
        item = [ind, numpy_int, numpy_pick, stdlib_int, stdlib_pick]
        return np.array(item + Transform()(item))

    def __len__(self):
        """Return the fixed dataset size of 10."""
        return 10


def show(num_workers):
    """Print two epochs of batches from ``RandomDataset`` without a worker seed hook.

    A ``DataLoader`` with batch size 5 and shuffling is built after calling
    ``torch.manual_seed(2)``; every batch is printed as a pandas DataFrame
    with one column per value in the sample.

    Args:
        num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
    """
    column_names = [
        'idx',
        'np',
        'np choice',
        'py',
        'random.choice',
        'np in transform',
        'py in random',
    ]

    print(f"num_workers: {num_workers}")
    dataset = RandomDataset()
    torch.manual_seed(2)
    loader = DataLoader(dataset, batch_size=5, shuffle=True, num_workers=num_workers)

    for epoch in range(2):
        print(f"epoch {epoch}")

        for batch_no, batch in enumerate(loader):
            print(f'Batch {batch_no}')
            frame = pd.DataFrame(batch.data.numpy(), columns=column_names)
            print(frame)

In [18]:
# Run the demo without a worker_init_fn, using a single DataLoader worker.
show(num_workers=1)

num_workers: 1
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    1  1997          3  11266              2            18214         21671
1    0   245          1  11599              2            17290         22097
2    8  7721          1  17929              3            10315         24997
3    3  8776          2  15311              1            18892         22562
4    4  1533          3  11679              1            12073         21155
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  7432          1  12898              1            18095         29931
1    6  8259          1  11257              2            13308         26077
2    7  5583          1  11812              3            14231         21360
3    2  4178          3  16004              2            15067         26752
4    9  7532          2  13354              3            19838         26194
epoch 1
Batch 0
   idx    np  np choi

In [17]:
# Run the demo without a worker_init_fn, using two DataLoader workers.
show(num_workers=2)

num_workers: 2
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    1  1997          3  11266              2            18214         21671
1    0   245          1  11599              2            17290         22097
2    8  7721          1  17929              3            10315         24997
3    3  8776          2  15311              1            18892         22562
4    4  1533          3  11679              1            12073         21155
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  1997          3  12182              1            18214         24722
1    6   245          1  19078              3            17290         29963
2    7  7721          1  19968              1            10315         20393
3    2  8776          2  12542              2            18892         29081
4    9  1533          3  16491              3            12073         20872
epoch 1
Batch 0
   idx    np  np choi

- 当 num_workers == 1 时，同一个 epoch 中不同 batch 得到的 numpy 随机数是不一样的，但是每一个 epoch 得到的 numpy 随机数序列是一样的
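- 当 num_workers > 1 时，每一个 epoch 中各个 batch 得到的 numpy 随机数都是一样的（见上面输出中的 np、np choice、np in transform 三列）；而 Python 内置 random 生成的列（py、random.choice、py in random）在不同 batch 之间并不相同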

In [32]:
def worker_init_fn(worker_id):
    """Reseed NumPy's global generator with a worker-specific seed.

    The seed is the first word of the current NumPy global RNG state plus
    ``worker_id``, so different worker ids derive different seeds from the
    same starting state.

    Args:
        worker_id: Integer worker index; ``DataLoader`` passes it when this
            function is installed as its ``worker_init_fn``.
    """
    # Convert to a plain Python int so the addition cannot overflow or be
    # affected by NumPy's scalar promotion rules.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around so a
    # state word close to the upper bound still yields a valid seed.
    np.random.seed((base_seed + worker_id) % 2**32)
    
def correct_show(num_workers):
    """Print two epochs of batches with ``worker_init_fn`` installed.

    Mirrors ``show`` but builds the ``DataLoader`` with
    ``worker_init_fn=worker_init_fn`` and ``pin_memory=True``, and calls
    ``np.random.seed()`` with no argument at the start of every epoch.

    Args:
        num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
    """
    column_names = [
        'idx',
        'np',
        'np choice',
        'py',
        'random.choice',
        'np in transform',
        'py in random',
    ]

    print(f"num_workers: {num_workers}")
    dataset = RandomDataset()
    torch.manual_seed(2)
    loader = DataLoader(
        dataset,
        batch_size=5,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        pin_memory=True,
    )

    for epoch in range(2):
        print(f"epoch {epoch}")
        # Reseed the parent's NumPy generator before each pass over the
        # loader. NOTE(review): presumably so that each epoch's workers start
        # from a different NumPy state -- confirm.
        np.random.seed()
        for batch_no, batch in enumerate(loader):
            print(f'Batch {batch_no}')
            frame = pd.DataFrame(batch.data.numpy(), columns=column_names)
            print(frame)

In [42]:
# Run the variant that installs worker_init_fn, using a single DataLoader worker.
correct_show(1)

num_workers: 1
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    1  3646          1  11266              2            11342         21671
1    0  6080          1  11599              2            15887         22097
2    8  5933          3  17929              3            11302         24997
3    3  5299          1  15311              1            13479         22562
4    4  9998          3  11679              1            11202         21155
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5   710          2  12898              1            10128         29931
1    6  6762          2  11257              2            11496         26077
2    7  6215          3  11812              3            19242         21360
3    2  4534          2  16004              2            14940         26752
4    9  5746          1  13354              3            14040         26194
epoch 1
Batch 0
   idx    np  np choi

In [60]:
# Run the variant that installs worker_init_fn, using four DataLoader workers.
correct_show(4)

num_workers: 4
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  9385          3  12768              2            16205         29849
1    7  6114          2  15404              3            18732         21038
2    0  9287          2  13753              2            19262         24205
3    8  2603          1  15010              2            12931         25615
4    4  4883          2  12705              3            14169         28304
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    2  3639          3  15206              2            17072         22890
1    3   265          3  17331              1            17994         24616
2    1  6947          3  13220              2            11242         23144
3    6  7818          2  18245              3            14338         25104
4    9  5395          3  12724              1            16545         24369
epoch 1
Batch 0
   idx    np  np choi