
Conversation

@oscarkey
Contributor

Currently we set the global pytorch random seed. When multi-device inference is enabled, this can result in a race condition because the global seed is shared between threads. Use a Generator instead to avoid this.

Testing:

  • Save the sampled embeddings from the old and the new versions and check they are equal.
  • Existing consistency checks, plus the additional two I have added to verify multi-device inference.
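For reference, the change described above amounts to the following pattern (a minimal sketch with placeholder function names and shapes, not the actual code in the repo):

import torch

def sample_embeddings_old(seed: int, shape: tuple, device: torch.device) -> torch.Tensor:
    # Old approach: seeds the process-wide RNG, so concurrent inference
    # threads on different devices can race on the shared global state.
    torch.manual_seed(seed)
    return torch.randn(shape, device=device)

def sample_embeddings_new(seed: int, shape: tuple, device: torch.device) -> torch.Tensor:
    # New approach: a local Generator carries its own state, so nothing
    # global is mutated and threads cannot interfere with each other.
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=generator, device=device)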

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a potential race condition in multi-device inference by replacing the global PyTorch random seed with a thread-safe torch.Generator object. The changes are well-implemented and the accompanying tests for multi-device consistency are a great addition. I've found one minor issue: a leftover debugging statement that should be removed.

Currently we set the global pytorch random seed. When multi-device
inference is enabled, this can result in a race condition because
the global seed is shared between threads. Use a Generator instead to
avoid this.

Exporting the generator is not supported by onnx, so don't use it during
tracing. This means that the random embeddings will be different to
those used during training, which may affect the performance of the
model. However, the previous method for fixing the seed was silently
ignored by onnx export, so it doesn't make things worse.

Testing:
- Save the sampled embeddings from the old and the new versions and
  check they are equal.
- Existing consistency checks, plus the additional two I have added to
  verify multi-device inference.
@oscarkey
Contributor Author

oscarkey commented Oct 8, 2025

hey @LeoGrin, I updated the implementation a bit to disable the seed during onnx export. This fixes the tests and adds a warning that the seed won't be set correctly. What do you think?
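In case it helps, the guard is shaped roughly like this (a simplified sketch only; the function name and warning text are placeholders, not the exact code):

import warnings
import torch

def maybe_make_generator(seed: int, device: torch.device):
    # torch.Generator objects can't be carried through ONNX tracing, so fall
    # back to the default RNG there and warn that the seed won't be applied.
    if torch.jit.is_tracing():
        warnings.warn("Seed for random embeddings is ignored during ONNX export/tracing.")
        return None
    return torch.Generator(device=device).manual_seed(seed)

# torch.randn accepts generator=None and then uses the default RNG:
# torch.randn(shape, generator=maybe_make_generator(seed, device), device=device)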

Collaborator

@LeoGrin left a comment


LGTM! Just one question to check I'm understanding things correctly: this should give us the exact same random output as before the change, right? (I think yes, because the consistency tests are passing?)

@oscarkey
Contributor Author

oscarkey commented Oct 9, 2025

Yes, and I saved and compared the actual embedding tensors before and after the change as well. But I'll do that one more time before merging just in case haha

@oscarkey
Contributor Author

oscarkey commented Oct 9, 2025

I did the following on macOS MPS, CPU, and CUDA:

  • checked that the value of the embeddings here was equal before and after the change, when performing classification on the iris dataset
  • used the script below
from contextlib import contextmanager
from typing import Generator
import torch

device = torch.device("cuda")


@contextmanager
def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, None]:
    torch_rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        torch_cuda_rng_state = torch.cuda.get_rng_state(device=device)
    torch.manual_seed(seed)
    try:
        yield
    finally:
        torch.set_rng_state(torch_rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(torch_cuda_rng_state, device=device)


def global_rng():
    with isolate_torch_rng(seed=0, device=device):
        a = torch.randn(10_000, 10, device=device)
    with isolate_torch_rng(seed=0, device=device):
        b = torch.randn(10_000, 10, device=device)
    with isolate_torch_rng(seed=1, device=device):
        c = torch.randn(10_000, 10, device=device)
    return a, b, c


def generator_rng():
    a = torch.randn(
        10_000,
        10,
        generator=torch.Generator(device=device).manual_seed(0),
        device=device,
    )
    b = torch.randn(
        10_000,
        10,
        generator=torch.Generator(device=device).manual_seed(0),
        device=device,
    )
    c = torch.randn(
        10_000,
        10,
        generator=torch.Generator(device=device).manual_seed(1),
        device=device,
    )
    return a, b, c


global_a, global_b, global_c = global_rng()
generator_a, generator_b, generator_c = generator_rng()

print(torch.allclose(global_a, generator_a))
print(torch.allclose(global_b, generator_b))
print(torch.allclose(global_c, generator_c))

It all passed, so seems safe to merge.

oscarkey merged commit b39d4b6 into main Oct 9, 2025
10 checks passed
oscarkey deleted the ok-local-rng branch October 9, 2025 12:36
oscarkey added a commit that referenced this pull request Nov 12, 2025
… random embeddings. (#175)

* Record copied public PR 525

* Use a Generator object to specify the seed for the random embeddings. (#525)

Currently we set the global pytorch random seed. When multi-device inference is enabled, this can result in a race condition because the global seed is shared between threads. Use a Generator instead to avoid this.

Testing:
- Existing consistency checks, plus the additional two I have added to verify multi-device inference.
- Save the sampled embeddings from the old and the new versions and check they are equal.
- Script to compare global/generator rng output: see comment on PR

(cherry picked from commit b39d4b6)

* Port change to v2 and v2.1 single-file models.

---------

Co-authored-by: mirror-bot <mirror-bot@users.noreply.github.com>
Co-authored-by: Oscar Key <oscar@priorlabs.ai>
