You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
importnumpyasnpimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfromtorch.utils.dataimportDataset, BatchSampler, DataLoaderBATCH_NUM=20BATCH_SIZE=16EPOCH_NUM=4IMAGE_SIZE=784CLASS_NUM=10# define a random datasetclassRandomDataset(Dataset):
def__init__(self, num_samples):
self.num_samples=num_samplesdef__getitem__(self, idx):
# image = np.random.random([IMAGE_SIZE]).astype('float32')# label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')image="This is a string format data"label="Also a string format data"returnimage, labeldef__len__(self):
returnself.num_samplesdataset=RandomDataset(BATCH_NUM*BATCH_SIZE)
classSimpleNet(nn.Module):
def__init__(self):
super().__init__()
self.fc=nn.Linear(IMAGE_SIZE, CLASS_NUM)
defforward(self, image, label=None):
returnself.fc(image)
simple_net=SimpleNet()
opt=torch.optim.SGD(lr=1e-3,params=simple_net.parameters())
loader=DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
foreinrange(EPOCH_NUM):
importpdbpdb.set_trace()
fori, (image, label) inenumerate(loader):
out=simple_net(image)
loss=F.cross_entropy(out, label)
avg_loss=paddle.mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
输出:
(exp) coco@coco:~$ pythontorch_mytest.py>/home/coco/torch_mytest.py(52)<module>()
->fori, (image, label) inenumerate(loader):
(Pdb) n>/home/coco/torch_mytest.py(53)<module>()
->out=simple_net(image)
(Pdb) l4849foreinrange(EPOCH_NUM):
50importpdb51pdb.set_trace()
52fori, (image, label) inenumerate(loader):
53->out=simple_net(image)
54loss=F.cross_entropy(out, label)
55avg_loss=paddle.mean(loss)
56avg_loss.backward()
57opt.minimize(avg_loss)
58simple_net.clear_gradients()
(Pdb) image
('This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data', 'This is a string format data')
(Pdb) label
('Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data', 'Also a string format data')
(Pdb)
Pytorch可以正常load进数据。
上面的复现过程中仅针对数据load的部分,不考虑后续计算
但是如果把数据部分改成这样:
def__getitem__(self, idx):
# image = np.random.random([IMAGE_SIZE]).astype('float32')label=np.random.randint(0, CLASS_NUM-1, (1, )).astype('int64')
image="This is a string format data"# label = "Also a string format data"returnimage, label
bug描述 Describe the Bug
Paddle的 Dataloader 遇到纯文本数据时报错 StopIteration
看上去Paddle的Dataloader并不支持纯str类型的数据,而Pytorch则不然
下面是Dataloader文档中的样例代码,后面代码都是基于这份代码来做复现。
当dataset中的数据不是tensor,而是str类型时,无法成功load进数据
在内层 for 去读取 dataloader 的数据的时候,报错:
单纯把上面 paddle 的例子中相关api换成 torch 中的 api
输出:
Pytorch可以正常load进数据。
但是如果把数据部分改成这样:
也就是有文本,也有ndarray,又可以正常读到数据。
所以看上去Paddle的dataloader遇到纯str类型数据会有有问题,而且输出仅仅是一个
StopIteration
,难定位问题,对用户不友好。感觉会是这里的原因吗?:
Paddle/python/paddle/io/dataloader/dataloader_iter.py
Lines 219 to 277 in 0359685
中的
_flatten_batch
函数里面:Paddle/python/paddle/io/dataloader/flat.py
Lines 37 to 43 in 0359685
这里用了自己递归的方式,对最内层数据
(np.ndarray, paddle.Tensor, paddle.base.core.eager.Tensor)
这些类型提取并存在了flat_batch
,也就是说,遇到例如纯str类型的数据时,而structure
则保存了tensor和ndarray之外的信息以及给tensor和ndarray的值留了占位符,但是此时为纯文本,导致flat_batch
一直都会为空,不知道是不是这里有影响emm在一些LLM模型中,它可能在Pytorch实现过程中没有很规范地去写processor或者collator(尤其是tokenize的过程),直接把这些模块放在了组网的过程中,这就意味着dataloader迭代得到数据会是纯文本,这就导致Paddle的dataloader出现上述问题,Paddle中必须写collator把纯文本做一下tokenize,让input_ids、attention_mask等最内层都是tensor或者np array,这样才能顺利被迭代读取。
但是至于为什么会抛出
StopIteration
,我还是比较困惑,希望后续如果可以成功复现,定位问题后可以在这个issue下告知呀~非常感谢!!!其他补充信息 Additional Supplementary Information
No response
The text was updated successfully, but these errors were encountered: