
Replace lambdas with partials to allow multi-gpu training #160

Closed · wants to merge 2 commits

Conversation

BramVanroy (Contributor)

Currently, the CometModel defines lambda functions. These cannot be pickled, and therefore multiprocessing (and with it multi-GPU training) is not possible.

This PR replaces the lambda functions with functools.partial objects, which should be picklable.

closes #159
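
For context, this is standard Python pickling behavior: a lambda has no importable qualified name, so `pickle` rejects it, while a `functools.partial` wrapping a module-level function serializes fine. A minimal, self-contained sketch of the pattern (the `scale` function below is illustrative, not COMET's actual code):

```python
import pickle
from functools import partial


def scale(x, factor):
    return x * factor


# A lambda cannot be pickled: it has no importable qualified name.
fn = lambda x: scale(x, factor=2)
try:
    pickle.dumps(fn)
except (pickle.PicklingError, AttributeError) as err:
    print(f"lambda is not picklable: {err}")

# A partial over a module-level function pickles fine, so it survives the
# process spawning that multi-GPU (DDP) training relies on.
fn = partial(scale, factor=2)
restored = pickle.loads(pickle.dumps(fn))
assert restored(3) == 6
```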

BramVanroy marked this pull request as draft August 8, 2023 11:48
BramVanroy (Contributor, Author) commented Aug 8, 2023

This PR is currently not ready for integration. I have done some work on making the RegressionMetrics(Metric) metric (not the model) compatible with distributed computing, and this seems to work for training. However, I can't get it to work for prediction.
As far as I can tell, the issue is that the metrics are not gathered correctly across the different processes. So in the following piece of code, RegressionMetrics should get dist_sync_on_step=True, but only in distributed settings.

```python
def init_metrics(self):
    """Initializes train/validation metrics."""
    self.train_metrics = RegressionMetrics(prefix="train")
    self.val_metrics = nn.ModuleList(
        [RegressionMetrics(prefix=d) for d in self.hparams.validation_data]
    )
```

PyTorch Lightning has so many hoops (subclasses, properties) to jump through that I lost my patience figuring out how to do something only "if multi-GPU" or "if distributed" (see the sketch below). For someone who knows PyTorch Lightning well, this is perhaps an easy fix, so feel free to chime in.
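
One possible shape for that conditional, sketched under the assumption (unverified) that `RegressionMetrics` forwards `dist_sync_on_step` to the torchmetrics `Metric` base class:

```python
import torch.distributed as dist


def init_metrics(self):
    """Initializes train/validation metrics, syncing per step only when distributed."""
    # Only request per-step synchronization when a process group actually exists.
    # Caveat: this assumes init_metrics runs after the process group is set up,
    # which may not hold in Lightning's setup order.
    sync = dist.is_available() and dist.is_initialized()
    self.train_metrics = RegressionMetrics(prefix="train", dist_sync_on_step=sync)
    self.val_metrics = nn.ModuleList(
        [
            RegressionMetrics(prefix=d, dist_sync_on_step=sync)
            for d in self.hparams.validation_data
        ]
    )
```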

A test case for the distributed scenario has also been added. Note that while developing this, I updated torchmetrics to the newest version to rule out underlying issues.
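
(Not the PR's actual test, but for illustration: the precondition that multi-GPU training imposes is that the model survives a pickle round-trip, so a minimal smoke test could look like the following, with `model` standing in for any instantiated CometModel.)

```python
import pickle


def test_model_is_picklable(model):
    """Fails while the model still holds unpicklable members such as lambdas."""
    restored = pickle.loads(pickle.dumps(model))
    assert type(restored) is type(model)
```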

BramVanroy mentioned this pull request Aug 8, 2023
ricardorei closed this Sep 18, 2023