
Model weights are being reset to zero on Windows #19537

Closed · BramVanroy opened this issue Feb 27, 2024 · 8 comments
Labels: 3rd party (Related to a 3rd-party), bug (Something isn't working), data handling (Generic data-related topic)

Comments

@BramVanroy commented Feb 27, 2024

Bug description

In COMET, the highly popular evaluation metric for machine translation, an issue has been raised reporting that predictions on Windows are always zero. I can confirm this with a fresh installation of the library (pip install unbabel-comet) in a new environment on Windows and the following snippet:

from comet import download_model, load_from_checkpoint

def main():
    # Choose your model from Hugging Face Hub
    model_path = download_model("Unbabel/wmt22-comet-da")

    # Load the model checkpoint:
    model = load_from_checkpoint(model_path)

    # Data must be in the following format:
    data = [
        {
            "src": "10 到 15 分钟可以送到吗",
            "mt": "Can I receive my food in 10 to 15 minutes?",
            "ref": "Can it be delivered between 10 to 15 minutes?"
        },
        {
            "src": "Pode ser entregue dentro de 10 a 15 minutos?",
            "mt": "Can you send it for 10 to 15 minutes?",
            "ref": "Can it be delivered between 10 to 15 minutes?"
        }
    ]
    # Call predict method:
    model_output = model.predict(data, batch_size=8, gpus=1)
    print(model_output)
    print(model_output.scores) # sentence-level scores
    print(model_output.system_score) # system-level score


# Not all COMET models return metadata with detected errors.
if __name__ == '__main__':
    main()

The output will be all zeroes:

Prediction([('scores', [0.0, 0.0]), ('system_score', 0.0)])
[0.0, 0.0]
0.0

I also get the following PyTorch warning but I am not sure if it is relevant:

[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

When trying to debug this, I went down a LONG rabbit hole and found that during the prediction loop, the model's weights seem to get set to zero. I do not fully understand how exactly (no time to look into this further at the moment), but it occurs in the following lines, when iter is called on the data fetcher:

combined_loader.limits = self.max_batches
data_fetcher.setup(combined_loader)
iter(data_fetcher) # creates the iterator inside the fetcher

To verify, replace those lines with the following:

# set the per-dataloader limits
combined_loader.limits = self.max_batches
data_fetcher.setup(combined_loader)


print("Before iter")
for name, param in self.trainer.model.named_parameters():
    if "layer.23.output.dense.weight" in name:
        print(param)

iter(data_fetcher)  # creates the iterator inside the fetcher

print("After iter")
for name, param in self.trainer.model.named_parameters():
    if "layer.23.output.dense.weight" in name:
        print(param)

Then execute the script above: the first print shows the normal weights of the final layer, but the second print, after iter(data_fetcher), is all zeroes.
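As a side note, an equivalent spot check that does not depend on a specific layer name is to checksum all parameters at once. This is just a small helper I'm sketching here, not part of COMET or Lightning:

from comet import download_model  # noqa: F401  (only to show the context; the helper below is standalone)
import torch

def param_checksum(module: torch.nn.Module) -> float:
    """Sum of absolute parameter values; drops to 0.0 if the weights get zeroed."""
    with torch.no_grad():
        return sum(p.abs().sum().item() for p in module.parameters())

# Usage inside the patched Lightning code, around the iter() call:
#   print("before iter:", param_checksum(self.trainer.model))
#   iter(data_fetcher)
#   print("after iter:", param_checksum(self.trainer.model))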

I am stumped as to why this would happen when setting up a data fetcher, but PL has so many interconnected moving parts that it is very hard for me to debug this further. Again, this only seems to be reported on Windows.

Running on Windows 10, Python 3.10, PL 2.2.0.post0.

cc @justusschock @awaelchli

@BramVanroy added the bug and needs triage labels Feb 27, 2024
@awaelchli self-assigned this Feb 28, 2024
@awaelchli removed the needs triage label Feb 28, 2024
@awaelchli (Member)

@BramVanroy Thanks for reporting. I am not familiar with the COMET library, but if I can reproduce it (hopefully without needing Windows) I can probably help debug this.

Is pip-installing the package and copying this code snippet all I need to reproduce this?

@BramVanroy (Author)

Yes, pip install and the code snippet should be enough, but I don't think this occurs on non-Windows systems, so I fear you won't be able to reproduce it on another OS. (Which might explain why it has flown under the radar.)

@awaelchli (Member) commented Feb 28, 2024

Thanks. Could you also share the PyTorch version that was installed when you ran into this issue?

@BramVanroy (Author)

Sure, I ran it with torch==2.2.1+cu118.

@awaelchli (Member)

Ok no luck trying to reproduce on Linux with

conda create -n comet python=3.10
conda activate comet
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install unbabel-comet
git checkout tags/2.2.0.post0
pip install -e .
python repro.py

I get

Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.09it/s]
Prediction([('scores', [0.8417138457298279, 0.7745391130447388]), ('system_score', 0.8081264793872833)])
[0.8417138457298279, 0.7745391130447388]
0.8081264793872833

It will be a day or two before I can check on a Windows machine.

@awaelchli (Member)

Hey @BramVanroy

I confirm your observations on Windows.
The issue is that COMET uses the model inside the dataloader workers through the collate function here:
https://github.com/Unbabel/COMET/blob/fd3c2d9f72b69ed9035cf778f76721f6996efb35/comet/models/base.py#L604
If we follow the trail:
https://github.com/Unbabel/COMET/blob/fd3c2d9f72b69ed9035cf778f76721f6996efb35/comet/models/base.py#L536
and then here:
https://github.com/Unbabel/COMET/blob/fd3c2d9f72b69ed9035cf778f76721f6996efb35/comet/models/regression/regression_metric.py#L181

How the dataloader workers get created (which happens when iter() is called on the dataloader for the first time) differs between Windows and Linux:

  • On Linux, the "fork" start method is used to create the worker processes. The workers get an identical view of the main process memory, so your loaded weights are accessible in them.
  • On Windows, "fork" is not supported. PyTorch DataLoaders use the "spawn" start method instead, which essentially pickles the dataset, collate_fn, and other arguments from your main process and unpickles them in the subprocesses. In this case that appears to end up with random weights instead of the weights loaded in the main process.

There is a small section about this in the PyTorch DataLoader docs:

On Windows or MacOS, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

Since this is an operating-system limitation and a PyTorch DataLoader implementation detail, Lightning can't really do anything here. The same would happen if the code were written in plain PyTorch.
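To make the mechanism concrete, here is a plain-PyTorch toy (not the COMET code path, just an illustration of the start-method difference): a collate function that references a module-level model only sees the weights set up by the main process if the workers are forked.

import os
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx

# Module-level model, standing in for a model that the collate function uses.
model = torch.nn.Linear(2, 2)

def collate(batch):
    # Runs inside a dataloader worker process.
    print(f"collate pid={os.getpid()}, weight sum={model.weight.sum().item():.4f}")
    return torch.tensor(batch)

if __name__ == "__main__":
    # Simulate "loading a checkpoint" after import: overwrite the weights,
    # but only in the main process.
    with torch.no_grad():
        model.weight.fill_(1.0)
    print(f"main pid={os.getpid()}, weight sum={model.weight.sum().item():.4f}")  # 4.0

    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2, collate_fn=collate)
    for _ in loader:
        pass
    # With "fork" (Linux default) the workers also print 4.0, because they inherit
    # the main process memory. With "spawn" (Windows default) the script is
    # re-imported in each worker, so `model` is re-created there with fresh random
    # weights and the checkpoint-like update from the main process is never seen.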

A possible workaround for you is to set

num_workers=0

in this line here:
https://github.com/Unbabel/COMET/blob/fd3c2d9f72b69ed9035cf778f76721f6996efb35/comet/models/base.py#L605

It would be slower in general, but in your specific code example the difference should not be noticeable.
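For illustration only (this is not COMET's actual code; the function and argument names below are hypothetical stand-ins for the DataLoader construction at the linked line), the shape of the workaround is simply:

from torch.utils.data import DataLoader

def build_prediction_dataloader(dataset, batch_size, collate_fn):
    # Hypothetical stand-in for the DataLoader construction in comet/models/base.py.
    # Forcing num_workers=0 keeps the collate function (which calls the model)
    # in the main process, where the checkpoint weights were actually loaded.
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=0,
    )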

For COMET in general, if they want to support prediction on Windows, maybe the code could be changed to either not run the model in the collate function (a rough sketch of that option follows below), or to load the model checkpoint inside the dataloader workers (inside the collate). Or maybe there is a memory-sharing trick that avoids this problem, I don't know.
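A sketch of the first option would keep the collate cheap and picklable and run the model only in predict_step. Everything below is illustrative; the class and attribute names (e.g. tokenizer) are assumptions, not COMET's actual API:

import pytorch_lightning as pl

class MetricModel(pl.LightningModule):
    # Illustrative sketch, not COMET's implementation.

    def collate_fn(self, samples):
        # Only cheap, picklable work (tokenization) happens in the dataloader
        # workers; the model weights are never needed here.
        return self.tokenizer(
            [s["src"] for s in samples], padding=True, return_tensors="pt"
        )

    def predict_step(self, batch, batch_idx):
        # The forward pass runs in the main process (or on the GPU) with the
        # weights that were actually loaded from the checkpoint.
        return self(**batch)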

I hope this makes things somewhat clearer and that the workaround is useful to you. I'm closing the issue because there is no action that can be taken in PyTorch Lightning at this moment (that I am aware of).

@awaelchli closed this as not planned Mar 2, 2024
@awaelchli added the 3rd party and data handling labels Mar 2, 2024
@BramVanroy (Author)

Thanks a lot @awaelchli! I could swear I had tested that, but apparently my memory does not serve me well. Thanks for going further down the rabbit hole. The linked PR should hopefully fix the default behavior on Windows.

Thanks again!

@awaelchli (Member)

Great, no problem! The PR you created looks very good!
