
StopIteration Exception in Conformer Encoder because of next(self.parameters()) #4430

Closed
grazder opened this issue Jun 23, 2022 · 4 comments

@grazder

grazder commented Jun 23, 2022

Hi, I get a StopIteration exception here:

device = next(self.parameters()).device

I call this method from forward_for_export here:

self.set_max_audio_length(max_audio_length)

I found that self.parameters() has zero length. This happens when I use multiple GPUs in DP mode; when I run the code on one GPU, self.parameters() is non-empty.

Issue about it: pytorch/pytorch#40457
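
To illustrate (a hypothetical probe module, not NeMo code): under nn.DataParallel, the replicas that actually run forward hold plain tensors instead of registered parameters, so self.parameters() yields nothing there:

import torch
from torch import nn

class Probe(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # On DataParallel replicas this prints 0, so
        # next(self.parameters()) raises StopIteration.
        print("params visible in forward:", len(list(self.parameters())))
        return self.fc(x)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(Probe()).cuda()
    model(torch.randn(8, 4).cuda())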

Maybe it would be better if the device were passed into the method, like so:

def set_max_audio_length(self, max_audio_length, device='cpu'):
    ....

It seems strange to run into such problems just to look up a device when it could easily be passed in as a parameter.
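
A sketch of what I mean (simplified, hypothetical names; not the actual NeMo signature): forward_for_export already has the input tensor, whose .device is valid even on DataParallel replicas, so it could pass the device down:

import torch
from torch import nn

class EncoderSketch(nn.Module):
    def set_max_audio_length(self, max_audio_length, device=torch.device('cpu')):
        # Build the cached positional range directly on the requested device,
        # instead of querying next(self.parameters()).device.
        self.seq_range = torch.arange(0, max_audio_length, device=device)

    def forward_for_export(self, audio_signal, length):
        # audio_signal.device is always available, even on DP replicas.
        self.set_max_audio_length(audio_signal.size(-1), audio_signal.device)
        ...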

grazder changed the title Jun 23, 2022
ericharper assigned ericharper and VahidooX and unassigned ericharper Jun 23, 2022
@ericharper
Collaborator

@grazder could you give an example script of what you're trying to do? Also, we typically use DDP for multi-GPU.

@titu1994
Collaborator

titu1994 commented Jun 23, 2022

According to PyTorch Lightning's design, it is not advised to pass a device parameter around. It also invites more errors, since tensors created on one device may not match the device of the parameters under DDP.

Note that DDP is the only recommended way to do distributed computing in PyTorch, and NeMo ASR does not support DP at all (most of our classes are unpicklable).
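
For reference, a minimal sketch of the device-agnostic pattern this refers to (my reading of the Lightning guidance, not NeMo's code): keep cached tensors as buffers so they move with the module, and create new tensors on the inputs' device:

import torch
from torch import nn

class DeviceAgnostic(nn.Module):
    def __init__(self, max_len=32):
        super().__init__()
        # Buffers follow the module through .to() and DP/DDP replication.
        self.register_buffer("pos_ids", torch.arange(max_len), persistent=False)

    def forward(self, x):
        if x.size(1) > self.pos_ids.numel():
            # Grow the cache on the input's device; no device argument needed.
            self.pos_ids = torch.arange(x.size(1), device=x.device)
        return self.pos_ids[: x.size(1)]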

@grazder
Author

grazder commented Jun 24, 2022

Here is a repro script:

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class RandomDataset(Dataset):
    def __init__(self, length):
        self.len = length
        self.data = [torch.randn(64, 256) for _ in range(length)]
        self.lens = [torch.randint(256, (1, )).squeeze().long() 
                     for _ in range(length)]

    def __getitem__(self, index):
        return self.data[index], self.lens[index]

    def __len__(self):
        return self.len


rand_loader = DataLoader(dataset=RandomDataset(8),
                         batch_size=8, shuffle=True)

# pos_emb_max_len = 32 < seq_len = 256
model = ConformerEncoder(feat_in=64, n_layers=2, d_model=512, 
                         pos_emb_max_len=32)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model.to(device)

for data in rand_loader:
    x, xlen = data
    # ConformerEncoder returns (encoded, encoded_len), so unpack the tuple.
    output, out_len = model(audio_signal=x.to(device), length=xlen.to(device))
    print("Outside: input size", x.size(),
          "output_size", output.size())

Output exception:

File "/home/user/.local/lib/python3.9/site-packages/nemo/collections/asr/modules/conformer_encoder.py", line 230, in set_max_audio_length
    device = next(self.parameters()).device
StopIteration

torch.__version__: 1.10.1

@titu1994
Collaborator

titu1994 commented Jul 3, 2022

Yes, as stated, we don't encourage the use of DP. Please use DDP.
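
For anyone landing here, a minimal single-node DDP sketch of the repro above (assumes torchrun; a stand-in module replaces ConformerEncoder for brevity):

# repro_ddp.py -- launch with: torchrun --nproc_per_node=2 repro_ddp.py
import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    # Stand-in for ConformerEncoder; each rank holds a full replica.
    model = DDP(nn.Linear(64, 64).cuda(rank), device_ids=[rank])
    x = torch.randn(8, 64, device=rank)
    print(f"rank {rank}: output {model(x).shape}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()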

titu1994 closed this as completed Jul 3, 2022