/
02_test_data.py
42 lines (33 loc) · 1.41 KB
/
02_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import itertools
import os
import sys

import torch
import paddle
import numpy as np
from PIL import Image
from reprod_log import ReprodLogger, ReprodDiffHelper

from utilities import build_paddle_data_pipeline, build_torch_data_pipeline


def test_data_pipeline(num_batches=5,
                       paddle_path="./result/data_paddle.npy",
                       torch_path="./result/data_ref.npy"):
    """Record matching samples from the Paddle and Torch data pipelines.

    Builds both pipelines, logs each dataset's length under the key
    ``"length"``, then walks the two dataloaders in lockstep. For each of
    the first ``num_batches`` batch pairs, it logs element ``[0]`` of the
    batch under ``"dataloader_{idx}"``. The Paddle records are written to
    ``paddle_path`` and the Torch (reference) records to ``torch_path`` with
    ``ReprodLogger.save``, so they can be diffed afterwards.

    Args:
        num_batches: maximum number of batch pairs to record. The default
            of 5 matches the previously hard-coded value. Fewer pairs are
            recorded if the shorter dataloader is exhausted first. Must be
            non-negative (``itertools.islice`` rejects negative stops).
        paddle_path: output file for the Paddle-side records.
        torch_path: output file for the Torch-side (reference) records.
    """
    paddle_dataset, paddle_dataloader = build_paddle_data_pipeline()
    torch_dataset, torch_dataloader = build_torch_data_pipeline()

    logger_paddle_data = ReprodLogger()
    logger_torch_data = ReprodLogger()

    logger_paddle_data.add("length", np.array(len(paddle_dataset)))
    logger_torch_data.add("length", np.array(len(torch_dataset)))

    # islice stops after exactly num_batches pairs. The previous
    # enumerate + `break` form made zip fetch one extra batch from each
    # loader before the limit check fired.
    batch_pairs = itertools.islice(
        zip(paddle_dataloader, torch_dataloader), num_batches)
    for idx, (paddle_batch, torch_batch) in enumerate(batch_pairs):
        # Only element [0] of each batch is compared. Presumably these are
        # the input images; TODO confirm against the pipeline builders.
        logger_paddle_data.add(f"dataloader_{idx}", paddle_batch[0].numpy())
        logger_torch_data.add(f"dataloader_{idx}",
                              torch_batch[0].detach().cpu().numpy())

    logger_paddle_data.save(paddle_path)
    logger_torch_data.save(torch_path)
if __name__ == "__main__":
    # Step 1: dump the dataset lengths and the first batches of both
    # pipelines to disk.
    test_data_pipeline()

    # Step 2: reload both dumps, diff them key by key and write the report.
    helper = ReprodDiffHelper()
    reference_info = helper.load_info("./result/data_ref.npy")
    candidate_info = helper.load_info("./result/data_paddle.npy")
    helper.compare_info(reference_info, candidate_info)
    helper.report(path="./result/log/data_diff.log")