测试 https://github.com/pytorch/pytorch/issues/5059 中提到的问题：DataLoader 使用多个 worker 进程（num_workers > 0）时，numpy 随机数的表现

In [53]:
import numpy as np
import pandas as pd
import random
from torch.utils.data import DataLoader

class Transform(object):
    """Callable that produces two fresh random integers on every call.

    One value comes from NumPy's global RNG and the other from the
    standard-library ``random`` module, so the two sources can be compared
    side by side.
    """

    def __init__(self):
        # Stateless: there is nothing to configure.
        pass

    def __call__(self, item = None):
        """Return ``[numpy_draw, stdlib_draw]``.

        ``item`` is accepted but never read.
        """
        # np.random.randint excludes the upper bound; random.randint includes it.
        numpy_draw = np.random.randint(10000, 20000)
        stdlib_draw = random.randint(20000, 30000)
        return [numpy_draw, stdlib_draw]

class RandomDataset(object):
    """Ten-item indexable dataset whose rows are filled with random draws.

    Each row mixes the requested index with values drawn from NumPy's global
    RNG and from the standard-library ``random`` module, followed by the two
    values produced by ``Transform``.
    """

    def __init__(self):
        # No state is kept; every row is generated on demand.
        pass

    def __len__(self):
        return 10

    def __getitem__(self, ind):
        """Build the row for ``ind`` as a 7-element NumPy array."""
        # Draw order within each RNG stream is kept stable on purpose.
        np_int = np.random.randint(1, 10000)
        np_pick = np.random.choice([1, 2, 3])
        py_int = random.randint(10000, 20000)
        py_pick = random.choice([1,2,3])
        row = [ind, np_int, np_pick, py_int, py_pick]
        extra = Transform()(row)
        return np.array(row + extra)


def show(num_workers):
    """Print every batch of ``RandomDataset`` for two epochs.

    The dataset is wrapped in a shuffled ``DataLoader`` with batch size 5 and
    the given number of worker processes. Each batch is printed as a labelled
    DataFrame so the random columns can be compared across batches and epochs.

    Args:
        num_workers: forwarded to ``DataLoader(num_workers=...)``.
    """
    print(f"num_workers: {num_workers}")
    loader = DataLoader(RandomDataset(), 5, shuffle=True, num_workers=num_workers)
    # One label per element of the row returned by RandomDataset.__getitem__.
    column_names = ['idx', 'np', 'np choice', 'py', 'random.choice','np in transform', 'py in random']
    for epoch in range(2):
        print(f"epoch {epoch}")
        for batch_no, batch in enumerate(loader):
            print(f'Batch {batch_no}')
            frame = pd.DataFrame(batch.data.numpy(), columns=column_names)
            print(frame)

In [54]:
show(num_workers=1)

num_workers: 1
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    7  9175          1  16283              2            13638         26768
1    2  2651          3  10798              2            15572         28957
2    6  9679          3  13213              2            18532         29854
3    3  9015          1  19733              3            13714         24339
4    8  2908          1  11373              3            11637         26764
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    0  5858          3  13115              1            15378         24542
1    1  1075          2  19764              2            16264         26791
2    9  1992          1  14330              3            18887         22242
3    5  8411          2  11212              1            15608         24245
4    4  7574          2  10514              3            10982         21878
epoch 1
Batch 0
   idx    np  np choi

In [55]:
show(num_workers=2)

num_workers: 2
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    2  9175          1  18896              3            13638         25388
1    8  2651          3  16625              3            15572         29109
2    0  9679          3  14599              3            18532         25992
3    5  9015          1  14301              1            13714         21669
4    3  2908          1  17056              2            11637         26344
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    6  9175          1  18179              2            13638         25369
1    9  2651          3  17896              2            15572         22372
2    7  9679          3  16729              2            18532         25718
3    1  9015          1  17658              2            13714         27335
4    4  2908          1  17633              2            11637         27306
epoch 1
Batch 0
   idx    np  np choi

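- 当 num_workers == 1 时：同一个 epoch 内，不同 batch 中由 numpy 生成的随机数各不相同；但不同 epoch 之间，numpy 生成的随机数是一样的
- 当 num_workers > 1 时：每一个 epoch 中，各个 batch 里由 numpy 生成的随机数都是一样的（从上面的输出可以看到，Python 自带的 random 生成的数在各个 batch 之间是不同的）
- 原因：各个 worker 进程启动时继承了主进程中同一份 numpy 随机状态，因此会产生相同的 numpy 随机序列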

In [62]:
def worker_init_fn(worker_id):
    """Reseed NumPy's global RNG so that each worker gets a distinct seed.

    The new seed is the first word of the current NumPy RNG state plus
    ``worker_id``, reduced modulo ``2**32``.

    Args:
        worker_id: integer id of the worker; added to the base seed.
    """
    # The state word is a uint32 scalar. Convert it to a Python int before
    # adding, so the sum can neither wrap silently nor be promoted to a value
    # outside the range np.random.seed accepts (0 .. 2**32 - 1).
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)

def correct_show(num_workers):
    """Print every batch of ``RandomDataset`` for two epochs, reseeding NumPy.

    Same display as ``show``, but the ``DataLoader`` is unshuffled and is given
    ``worker_init_fn`` so each worker reseeds NumPy. NumPy's global RNG is also
    reseeded (with no explicit seed) before each epoch is iterated.

    Args:
        num_workers: forwarded to ``DataLoader(num_workers=...)``.
    """
    print(f"num_workers: {num_workers}")
    loader = DataLoader(
        RandomDataset(),
        5,
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
    )
    # One label per element of the row returned by RandomDataset.__getitem__.
    column_names = ['idx', 'np', 'np choice', 'py', 'random.choice','np in transform', 'py in random']

    for epoch in range(2):
        print(f"epoch {epoch}")
        # Reseed before iterating; presumably so the state that worker_init_fn
        # reads differs from one epoch to the next.
        np.random.seed()
        for batch_no, batch in enumerate(loader):
            print(f'Batch {batch_no}')
            frame = pd.DataFrame(batch.data.numpy(), columns=column_names)
            print(frame)

In [60]:
correct_show(1)

num_workers: 1
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    0  1235          2  14038              3            12585         26387
1    1  9441          3  18036              1            12344         29636
2    2  5774          2  10963              3            11575         23882
3    3    69          1  17130              1            12692         24516
4    4  2966          3  12893              1            15722         23180
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  5015          1  19762              2            14940         24579
1    6  7538          1  17960              2            12027         26896
2    7  7012          2  18341              2            17019         21691
3    8  5564          2  14875              1            14288         29571
4    9  5259          2  12681              3            12178         27491
epoch 1
Batch 0
   idx    np  np choi

In [63]:
correct_show(4)

num_workers: 4
epoch 0
Batch 0
   idx    np  np choice     py  random.choice  np in transform  py in random
0    0   671          1  18971              2            16543         27174
1    1  9181          3  16275              3            16969         24915
2    2  3242          3  17358              3            17795         20894
3    3  5899          3  10655              1            19936         23698
4    4  2391          2  12706              3            14225         24237
Batch 1
   idx    np  np choice     py  random.choice  np in transform  py in random
0    5  4294          2  15085              1            14995         22693
1    6  5885          2  19017              1            15147         29581
2    7  9285          2  17500              2            14946         21699
3    8  9263          3  11823              1            14660         27230
4    9   783          1  13118              1            14107         20160
epoch 1
Batch 0
   idx    np  np choi