RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2 #83

Open
erikqu opened this issue Mar 31, 2024 · 0 comments


erikqu commented Mar 31, 2024

Every batch I run fails in the forward pass with a shape mismatch that is off by one (91 vs. 90 at dimension 2). I've seen other issues opened about this, but none with a proper explanation. Could I get some help with this?

Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2

I verified that the text and wav batches both have the expected batch size (16). All wavs are padded to 84480 samples in this case.
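One thing that may be worth checking here, offered as a guess rather than a confirmed diagnosis: a U-Net that downsamples by fixed factors usually needs the input length to be a multiple of the product of those factors, otherwise a skip connection can end up one sample longer or shorter than the upsampled tensor it is merged with. The sketch below takes the `factors` list from the config further down in this post and checks the padded length of 84480 against it. The helper name `next_multiple` is hypothetical, and whether this is what causes the 91 vs. 90 mismatch is an assumption.

```python
import math

# Downsampling factors copied from the DiffusionModel config in this issue.
factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]
padded_len = 84480  # padded wav length reported in this issue


def next_multiple(length: int, multiple: int) -> int:
    """Round length up to the nearest multiple (hypothetical helper)."""
    return -(-length // multiple) * multiple


total_factor = math.prod(factors)                     # overall downsampling ratio
remainder = padded_len % total_factor                 # nonzero means not evenly divisible
target_len = next_multiple(padded_len, total_factor)  # candidate padding target

print(f"total factor: {total_factor}")
print(f"remainder:    {remainder}")
print(f"pad to:       {target_len}")
```

With these numbers the total factor is 2048, the remainder is 512, and the next multiple is 86016. If that turns out to be the cause, padding each wav to `target_len` (for example with `torch.nn.functional.pad`) or cropping to the previous multiple would be one thing to try.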

RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                                
Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                        
The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                                              
torch.Size([16, 84480]) 16                                                                                                                  
Traceback (most recent call last):                                                                                                          
  File "/mnt/nvme/programs/qTTS/train_ttv_v1.py", line 184, in train_and_evaluate                                                           
    loss_gen_all = net_g(                                                                                                                   
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward                                                                                                                                                                                 
    else self._run_ddp_forward(*inputs, **kwargs)                                                                                           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward                                                                                                                                                                        
    return self.module(*inputs, **kwargs)  # type: ignore[index]                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/models.py", line 40, in forward                                                                                                                                                                                  
    return self.diffusion(*args, **kwargs)                                                                                                  
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward                                                                                                                                                                               
    v_pred = self.net(x_noisy, sigmas, **kwargs)                                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                      
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 594, in forward                                                                                                                                                                                                  
    return net(x, features=features, **kwargs)                                                                                              
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                      
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 621, in forward                                                     
    return net(x, embedding=text_embedding, **kwargs)                                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                                                                                                                                                                   
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 552, in forward                                                                                                                                                                                                  
    return net(x, embedding=embedding, **kwargs)                                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 431, in forward                                                                                                                                                                                                    
    return self.net(x, features, embedding, channels)  # type: ignore                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 382, in forward                                                                                                                                                                                                    
    x = self.block(x, features, embedding, channels)                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 77, in forward                                                      
    x = block(x, *args)                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                              

I followed the example from the README:

  from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

  net_g = DiffusionModel(
      net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
      in_channels=1, # U-Net: number of input/output (audio) channels
      channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
      factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
      items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
      attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
      attention_heads=8, # U-Net: number of attention heads per attention item
      attention_features=64, # U-Net: number of attention features per attention item
      diffusion_t=VDiffusion, # The diffusion method used
      sampler_t=VSampler, # The diffusion sampler used
      use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
      use_embedding_cfg=True, # U-Net: enables classifier free guidance
      embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
      embedding_features=768, # U-Net: text embedding features (default for T5-base)
      cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
  )
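
To see how a one-sample mismatch could come out of a config like this, here is a rough length trace through the layers. It assumes, purely for illustration, that each layer's downsampling maps a length L to L // factor and each upsampling maps L to L * factor. The real convolution arithmetic in `a_unet` may differ, so the exact numbers in the error (91 vs. 90) are not expected to be reproduced. `trace_lengths` is a hypothetical helper, not part of any library.

```python
# Downsampling/upsampling factors copied from the config above.
factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]


def trace_lengths(length, factors):
    """Return (skip_len, upsampled_len) per layer, outermost layer first.

    Assumes floor-division downsampling and exact-multiple upsampling,
    which is a simplification of what the real layers do.
    """
    # down[i] is the length entering layer i; down[i + 1] is after its downsampling.
    down = [length]
    for f in factors:
        down.append(down[-1] // f)
    # On the way back up, layer i upsamples down[i + 1] by f and merges with down[i].
    return [(down[i], down[i + 1] * f) for i, f in enumerate(factors)]


pairs = trace_lengths(84480, factors)
for layer, (skip_len, up_len) in enumerate(pairs):
    flag = "" if skip_len == up_len else "  <-- mismatch"
    print(f"layer {layer}: skip={skip_len} upsampled={up_len}{flag}")
```

Under these assumptions, a length of 84480 stays consistent until it reaches an odd value (165) at layer 7, where the skip branch has 165 samples and the upsampled branch has 164. A length that is a multiple of 2048 (the product of the factors), such as 86016, produces no mismatch in this trace.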