
About the versions of PyTorch, CUDA, and other dependencies used in the implementation of the Monarch Mixer #25

Open
yingxuanhi opened this issue Feb 20, 2024 · 20 comments


@yingxuanhi

Hello, I'm interested in applying the MLP layer of your Monarch Mixer in my research.
I'm unsure about the versions of packages used in the implementation of Monarch Mixer, including PyTorch, CUDA, torchvision, cudatoolkit, and others.
Could you please provide information on the versions of these packages used? Thank you very much for your assistance!

@DanFu09
Collaborator

DanFu09 commented Feb 26, 2024

I recommend using the NVIDIA PyTorch Docker containers: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch

We ran our experiments on version 23.05: https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html

PyTorch 2.0.0, CUDA 12.1.1.

However, the MLP layers in Monarch Mixer are not tied to any particular version of PyTorch or CUDA: the code is vanilla PyTorch and should work with any modern version.
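If you install packages yourself instead of using the container, it helps to record which versions you actually have so you can compare them with PyTorch 2.0.0 and CUDA 12.1.1 above. The sketch below reads only standard `torch` attributes. The function name `environment_report` is hypothetical.

```python
import platform

import torch


def environment_report() -> dict:
    """Collect the versions a maintainer usually asks about (hypothetical helper)."""
    cudnn_ok = torch.backends.cudnn.is_available()
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        # torch.version.cuda is the CUDA version PyTorch was built with (None for CPU-only builds)
        "cuda_build": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
        "cudnn": torch.backends.cudnn.version() if cudnn_ok else None,
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```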

@yingxuanhi
Author

Hello, if I run your experiment directly in an Ubuntu terminal without the NVIDIA PyTorch Docker containers, can I still reproduce it? Also, suppose I only want to use the Monarch Mixer's MLP layer without the provided FlashFFTConv, that is, without running the following two commands:

```shell
pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
pip install git+https://github.com/HazyResearch/flash-fft-conv.git
```

Can I still use the Monarch Mixer MLP layer as usual? I apologize for taking up your valuable time!

@DanFu09
Collaborator

DanFu09 commented Feb 27, 2024 via email

@yingxuanhi
Author

Hello, I've installed the relevant packages and am ready to start experimenting. I have a question regarding the Monarch Mixer MLP layer that I referenced in the following code:

```python
from functools import partial

import torch
from torch import nn
from src.mm.blockdiag_linear import BlockdiagLinear


class M2MLP(nn.Module):
    """Applies the MLP."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # partial() fixes nblocks so linear_cls can be called like nn.Linear
        if self.config.use_monarch_mlp:
            linear_cls = partial(BlockdiagLinear, nblocks=self.config.monarch_mlp_nblocks)
        else:
            linear_cls = nn.Linear

        self.linear = linear_cls(config.hidden_size,
                                 config.intermediate_size,
                                 bias=False)
        self.act = nn.GELU(approximate='none')
        self.wo = linear_cls(config.intermediate_size, config.hidden_size)

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].
        """
        residual_connection = hidden_states
        hidden_states = self.linear(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states
```

My input dimension dim (hidden_size) is 48, and my ffn_expansion_factor is 2.66 (normally it is 4). The intermediate_size is normally hidden_size * 4, so in my case it is hidden_size * 2.66. My question is: how should I set nblocks in this code? More generally, how should I configure this MLP layer correctly for my input? I apologize for the inconvenience and appreciate your assistance!
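One way to choose nblocks is to pick a value that divides both feature sizes, so every block has the same size. The sketch below assumes the block-diagonal weight is stored as `(nblocks, out_features // nblocks, in_features // nblocks)`, which is consistent with the `torch.Size([4, 48, 12])` printed later in this thread for dim = 48, hidden_size = 192, nblocks = 4. The helper names are hypothetical, and whether the repository pads sizes that do not divide evenly is not covered here.

```python
from math import gcd


def valid_nblocks(in_features: int, out_features: int) -> list[int]:
    """Every nblocks that splits both sizes into equal blocks (the divisors of their gcd)."""
    g = gcd(in_features, out_features)
    return [k for k in range(1, g + 1) if g % k == 0]


def blockdiag_weight_shape(in_features: int, out_features: int, nblocks: int) -> tuple[int, int, int]:
    """Assumed weight layout (nblocks, out_block, in_block) for evenly divisible sizes."""
    if in_features % nblocks or out_features % nblocks:
        raise ValueError(f"nblocks={nblocks} does not divide both {in_features} and {out_features}")
    return (nblocks, out_features // nblocks, in_features // nblocks)


if __name__ == "__main__":
    dim = 48
    hidden = dim * 4                               # 192
    print(valid_nblocks(dim, hidden))
    print(blockdiag_weight_shape(dim, hidden, 4))  # (4, 48, 12)
```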

@DanFu09
Collaborator

DanFu09 commented Feb 27, 2024 via email

@yingxuanhi
Author

Hello, I have modified the MLP layer you provided as follows:

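```python
from functools import partial

from torch import nn
from src.mm.blockdiag_linear import BlockdiagLinear


class FeedForward(nn.Module):
    """Applies the MLP."""

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()

        # Note: ffn_expansion_factor is unused; the expansion is fixed at 4 here.
        hidden_size = int(dim * 4)

        linear_cls = partial(BlockdiagLinear, nblocks=4)

        self.linear = linear_cls(dim, hidden_size, bias=bias)
        self.act = nn.GELU(approximate='none')
        self.wo = linear_cls(hidden_size, dim, bias=bias)

    def forward(self, x):
        # BlockdiagLinear mixes the LAST axis of x, so x.shape[-1] must equal dim.
        x = self.linear(x)
        x = self.act(x)
        x = self.wo(x)
        return x
```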

When running the model, I encountered the following error. Here are the relevant parts of the traceback:

```
File "/home/dcmc/Data/kyx111m/DRSformer/basicsr/models/archs/DRSformer_arch.py", line 193, in forward
    x = self.linear(x)
File "/home/dcmc/anaconda3/envs/kyxm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
File "/home/dcmc/Data/kyx111m/DRSformer/src/mm/structured_linear.py", line 70, in forward
    output = self.forward_matmul(x)
File "/home/dcmc/Data/kyx111m/DRSformer/src/mm/blockdiag_linear.py", line 49, in forward_matmul
    output = blockdiag_multiply(x, self.weight)
File "/home/dcmc/anaconda3/envs/kyxm2/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 105, in decorate_fwd
    return fwd(*args, **kwargs)
File "/home/dcmc/Data/kyx111m/DRSformer/src/mm/blockdiag_multiply.py", line 64, in forward
    assert nblocks * p == n
AssertionError
```

The input dimension dim is set to 48, and nblocks is set to 4. However, n is 128, and p is 12.

I'm sorry to bother you, but may I ask what could be causing this? Thank you!
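The failing assertion compares the size of the last axis of `x` (`n`) with `nblocks * p` taken from the weight. A pure-Python mirror of that check (the helper name is hypothetical) reproduces the values reported above (nblocks = 4, p = 12, n = 128). It assumes `x` is laid out as [batch, channels, height, width] with 48 channels, and that the weight is (4, 48, 12) for a 48 -> 192 layer.

```python
def blockdiag_shape_check(x_shape, weight_shape):
    """Mirror of the assertion in blockdiag_multiply's forward (hypothetical helper)."""
    n = x_shape[-1]                 # size of the LAST axis of x
    nblocks, q, p = weight_shape    # p: input features per block, q: output features per block
    return nblocks * p == n, nblocks * p, n


if __name__ == "__main__":
    w = (4, 48, 12)                                     # 48 -> 192 features, nblocks = 4
    print(blockdiag_shape_check((2, 48, 128, 128), w))  # channels first: (False, 48, 128)
    print(blockdiag_shape_check((2, 128, 128, 48), w))  # channels last:  (True, 48, 48)
```

With channels first, the layer sees the image width (128) as its feature axis, not the 48 channels.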

@DanFu09
Collaborator

DanFu09 commented Feb 27, 2024

Can you print out the shape of the weight parameter of linear?

@yingxuanhi
Author

Something like that?

```python
dim = 48
hidden_size = dim * 4  # = 192

self.linear = linear_cls(dim, hidden_size)
self.act = nn.GELU(approximate='none')
self.wo = linear_cls(hidden_size, dim)
```

@DanFu09
Collaborator

DanFu09 commented Feb 27, 2024 via email

@yingxuanhi
Author

Excuse me, this is the shape of my linear.weight: torch.Size([4, 48, 12])

@DanFu09
Collaborator

DanFu09 commented Feb 28, 2024 via email

@yingxuanhi
Author

yingxuanhi commented Feb 28, 2024

Hello, my deraining model architecture is based on an autoencoder, and my model structure is attached. Because each level down-samples or up-samples, the input dimensions differ between levels. Could this have any impact?
model.txt

Before I started using the Monarch Mixer MLP layer, I was using a regular FFN, and I could train the model without any issues.

@DanFu09
Collaborator

DanFu09 commented Feb 28, 2024 via email

@yingxuanhi
Author

Hello, in my case I believe that at least the linear layer at the first level should run without errors. Its input_features = dim = 48, and hidden_features defaults to input_features * 2.66 (I have also tried input_features * 4, without success). Regarding this bug, I would like to ask what nblocks, p, and n in the error message stand for, and how they are related.
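The forward snippet of `blockdiag_multiply` quoted in this thread defines these names with `nblocks, q, p = weight.shape` and `n = x.shape[-1]`. So `n` is the size of the last axis of `x`, `nblocks` is the number of diagonal blocks, and each block maps `p` input features to `q` output features, which is why the layer requires `nblocks * p == n`. Below is a readable reference for that operation using `torch.einsum`. It is a hypothetical sketch, not the repository's kernel, and it is checked against the equivalent dense block-diagonal matrix.

```python
import torch


def blockdiag_multiply_reference(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Block-diagonal matmul over the last axis of x (reference sketch, not the repo kernel).

    x:      (..., n) with n == nblocks * p
    weight: (nblocks, q, p), one (q x p) block per group of p input features
    returns (..., nblocks * q)
    """
    nblocks, q, p = weight.shape
    n = x.shape[-1]
    assert nblocks * p == n, f"last axis of x is {n}, expected nblocks * p = {nblocks * p}"
    xb = x.reshape(*x.shape[:-1], nblocks, p)          # split features into nblocks groups of p
    yb = torch.einsum("...kp,kqp->...kq", xb, weight)  # block k maps its p inputs to q outputs
    return yb.reshape(*x.shape[:-1], nblocks * q)


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(4, 48, 12)    # same shape as the linear.weight printed above (48 -> 192)
    x = torch.randn(2, 16, 48)    # channels LAST: (batch, tokens, dim)
    y = blockdiag_multiply_reference(x, w)
    dense = torch.block_diag(*w)  # the equivalent (192 x 48) dense matrix
    print(y.shape, torch.allclose(y, x @ dense.T, atol=1e-5))
```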

@DanFu09
Collaborator

DanFu09 commented Feb 28, 2024 via email

@yingxuanhi
Author

Here is the relevant information about the tensor x and some parameters:

```
x.shape: torch.Size([2, 48, 128, 128])
Shape of the weight matrix, linear.size: torch.Size([4, 32, 12])
Shape of the weight matrix, wo.size: torch.Size([4, 12, 32])
x.shape: torch.Size([2, 48, 128, 128])
weight.shape: torch.Size([4, 32, 12])
batch_shape: torch.Size([2, 48, 128])
batch_dim: 12288
weight.shape: torch.Size([4, 32, 12])
n: 128
p: 12
nblocks * p: 48
```

Here, linear and wo are my first and second Monarch Mixer MLP layers, respectively. For single-image rain removal, my x has shape [2, 48, 128, 128], corresponding to [batch size, dim, length, width]. Should I modify the following code snippet?

```python
# Excerpt from forward() in src/mm/blockdiag_multiply.py
def forward(ctx, x, weight):
    ctx.save_for_backward(x, weight)
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = np.prod(batch_shape)
    nblocks, q, p = weight.shape
    assert nblocks * p == n
```

Should I use the second dimension of x (dim = 48) instead of the last dimension (128)? Or do you recommend a better approach?
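Instead of editing the assertion in `blockdiag_multiply`, one option is to move the channel axis to the end before the MLP and restore the layout afterwards, so the last axis is dim = 48. The minimal sketch below assumes the MLP layers act on the last axis, as the `n = x.shape[-1]` line above implies. The class name `ChannelsLastMLP` is hypothetical, `nn.Linear` stands in for `BlockdiagLinear` so the example runs on its own, and the thread does not show which fix was actually adopted.

```python
import torch
from torch import nn


class ChannelsLastMLP(nn.Module):
    """Apply a feature-mixing MLP to a (B, C, H, W) image tensor (hypothetical wrapper).

    The wrapped layers act on the LAST axis, so channels are moved there first
    and moved back afterwards.
    """

    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)     # (B, C, H, W) -> (B, H, W, C)
        x = self.mlp(x)               # mixes the C channels at each pixel independently
        return x.permute(0, 3, 1, 2)  # back to (B, C, H, W)


if __name__ == "__main__":
    dim, hidden = 48, 128
    # nn.Linear stands in for BlockdiagLinear so this runs without the repo.
    mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
    block = ChannelsLastMLP(mlp)
    x = torch.randn(2, dim, 32, 32)
    print(block(x).shape)  # torch.Size([2, 48, 32, 32])
```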

@DanFu09
Collaborator

DanFu09 commented Feb 29, 2024 via email

@yingxuanhi
Author

yingxuanhi commented Mar 2, 2024

Hello, thanks to your assistance, I have been able to train the model successfully!

Here are the questions I would like to ask: my single-image deraining model originally used a depth-wise convolution feed-forward network. If I replace it with a Monarch Mixer MLP layer, what could cause a decrease in performance?

Additionally, I would like to ask about the impact of batch size on the Monarch Mixer MLP layer.

I apologize for the trouble and appreciate your help!

@DanFu09
Collaborator

DanFu09 commented Mar 3, 2024

Great!

A couple things that could be happening here:

  • A depthwise convolution only mixes along the sequence (or H/W) dimensions, and not over the channels in the image. For something like (B, H, L) it will process the B and H dimensions completely independently of each other, and only interact between the values in the L's.
  • An MLP only mixes along the hidden dimension. So for (B, L, H), it processes the B and L dimensions completely independently, and mixes information along the H dimension.

If you've completely replaced your depthwise convolutions with Monarch MLP layers, then you may not be mixing along the sequence dimensions anymore. This is like trying to predict an image by processing each pixel on its own without the context of what is around it. You'll need some mix of mixing along both sequence and hidden dimensions.
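To make this concrete, the sketch below perturbs a single input pixel and records which output pixels change. A per-pixel channel MLP only affects that pixel, a 3x3 depthwise convolution affects its 3x3 neighbourhood, and a block that chains both mixes along both the spatial and the channel dimensions. The class and function names are hypothetical, `nn.Linear` stands in for a Monarch layer, and `SpatialThenChannel` is just one possible arrangement, not the design used in Monarch Mixer or DRSformer.

```python
import torch
from torch import nn


class DepthwiseSpatialMix(nn.Module):
    """3x3 depthwise conv: mixes neighbouring pixels, never across channels."""

    def __init__(self, dim: int):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x):  # x: (B, C, H, W)
        return self.conv(x)


class ChannelMLPMix(nn.Module):
    """Per-pixel MLP: mixes channels, never across pixels."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):  # x: (B, C, H, W); the MLP runs on channels-last
        return self.mlp(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class SpatialThenChannel(nn.Module):
    """One possible block doing both kinds of mixing, each with a residual."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.spatial = DepthwiseSpatialMix(dim)
        self.channel = ChannelMLPMix(dim, hidden)

    def forward(self, x):
        x = x + self.spatial(x)
        return x + self.channel(x)


def influence_map(layer: nn.Module, dim: int = 8, size: int = 5) -> torch.Tensor:
    """Boolean (size, size) map of output pixels that change when the centre input pixel is perturbed."""
    x = torch.randn(1, dim, size, size)
    x2 = x.clone()
    c = size // 2
    x2[0, :, c, c] += 1.0
    with torch.no_grad():
        diff = (layer(x2) - layer(x)).abs().sum(dim=1)[0]  # summed over channels
    return diff > 1e-5


if __name__ == "__main__":
    torch.manual_seed(0)
    print(influence_map(ChannelMLPMix(8, 16)).int())       # only the centre pixel changes
    print(influence_map(DepthwiseSpatialMix(8)).int())     # the 3x3 neighbourhood changes
    print(influence_map(SpatialThenChannel(8, 16)).int())  # 3x3 neighbourhood, all channels mixed
```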

I haven't seen much impact of batch size on the MLP layer.

@yingxuanhi
Author

Hello, I will look further into what you mentioned above! Additionally, I would like to ask about the class BertGatedLinearUnitMLP in m2/bert/src/bert_layers.py.

What distinguishes this class from the class BertMLP?
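The thread ends without an answer here. As general background only: the name suggests a gated linear unit (GLU) style MLP. A GLU MLP typically differs from a plain MLP in that the up-projection produces two halves, and one half is passed through the activation and used to gate the other elementwise before the down-projection. The generic sketch below contrasts the two patterns. The class names are hypothetical, and this is not a transcription of bert_layers.py, so check that file for the actual projections, dropout, and normalization.

```python
import torch
from torch import nn


class PlainMLP(nn.Module):
    """up-project -> activation -> down-project."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.up = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.down = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.down(self.act(self.up(x)))


class GatedMLP(nn.Module):
    """GLU-style: one projection is split into a gate and a value, multiplied elementwise."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.hidden = hidden
        self.gated_up = nn.Linear(dim, 2 * hidden)  # gate and value computed in one matmul
        self.act = nn.GELU()
        self.down = nn.Linear(hidden, dim)

    def forward(self, x):
        h = self.gated_up(x)
        gate, value = h[..., : self.hidden], h[..., self.hidden :]
        return self.down(self.act(gate) * value)


if __name__ == "__main__":
    x = torch.randn(2, 10, 48)
    for cls in (PlainMLP, GatedMLP):
        model = cls(48, 128)
        n_params = sum(p.numel() for p in model.parameters())
        print(cls.__name__, tuple(model(x).shape), n_params)
```

With the same hidden size, the gated version carries an extra dim x hidden projection (18,736 vs 12,464 parameters for dim = 48, hidden = 128 here). This is why GLU variants are often given a smaller hidden size.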
