Conversation
…r and enhance error handling in TorchrunInferenceWorker
Summary of Changes

Hello @GACLove, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refines the distributed API for parallel processing by establishing a dedicated communication channel for task distribution and bolstering error-recovery mechanisms. The changes aim to improve the stability and reliability of the distributed inference system, particularly in handling task broadcasting and ensuring proper synchronization across processes, even in the face of failures.

Highlights
Code Review
This pull request introduces significant improvements to the distributed inference capabilities by creating a dedicated 'gloo' process group for task communication, which should enhance stability and prevent deadlocks. The error handling in the worker process is also made more robust. I've identified a potential crash in an error handling path and suggested several performance optimizations for data serialization. Overall, these are great changes for improving the reliability of the parallel execution.
```python
if self.rank == 0:
    if has_error:
        return {
            "task_id": task_data.get("task_id", "unknown"),
```
If an exception occurs early in process_request (e.g., if it's called with task_data=None), task_data could be None at this point. Calling .get() on None would cause a crash within the error handling logic, which can mask the original error. This change adds a check to handle this case gracefully.
```diff
-"task_id": task_data.get("task_id", "unknown"),
+"task_id": task_data.get("task_id", "unknown") if task_data else "unknown",
```
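As a standalone illustration of why the guard matters, here is a minimal sketch; the helper name `build_error_response` is hypothetical, not the PR's actual function:

```python
def build_error_response(task_data):
    # task_data can be None when process_request fails before parsing its
    # input; guarding before .get() keeps the error path itself from crashing.
    task_id = task_data.get("task_id", "unknown") if task_data else "unknown"
    return {"task_id": task_id, "has_error": True}

print(build_error_response(None)["task_id"])             # → unknown
print(build_error_response({"task_id": "t-42"})["task_id"])  # → t-42
```

Without the conditional, `None.get(...)` would raise `AttributeError` inside the error handler, masking the original exception.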
```diff
 chunk = data_bytes[start_idx:end_idx]
-task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
-dist.broadcast(task_tensor, src=0)
+task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
```
Using torch.tensor(list(chunk)) is inefficient for converting bytes to a tensor, especially for large chunks (1MB). It creates a large intermediate list of integers. A more performant approach is to use numpy.frombuffer which creates a view on the byte buffer without copying data, and then convert that to a tensor. You will need to add import numpy at the top of the file.
```diff
-task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
+task_tensor = torch.from_numpy(numpy.frombuffer(chunk, dtype=numpy.uint8))
```
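A small self-contained sketch (NumPy only, torch omitted so it runs anywhere) showing that `frombuffer` produces the same data as the list round-trip without materializing an intermediate Python list; the 1 KB chunk is illustrative, the PR uses 1 MB chunks:

```python
import numpy as np

chunk = bytes(range(256)) * 4  # stand-in for a serialized task chunk

# Inefficient path: builds a Python list of 1024 ints before the array.
slow = np.array(list(chunk), dtype=np.uint8)

# Efficient path: zero-copy, read-only view over the bytes object.
fast = np.frombuffer(chunk, dtype=np.uint8)

assert np.array_equal(slow, fast)
print(fast.shape[0])  # → 1024
```

Note that `frombuffer` returns a read-only array, so `torch.from_numpy` will emit a warning about non-writable memory; that is harmless here because the tensor is only read for broadcasting.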
```diff
 chunk = data_bytes[-remaining:]
-task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
-dist.broadcast(task_tensor, src=0)
+task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
```
Similar to the loop above, this conversion from bytes to tensor is inefficient. Using numpy.frombuffer is recommended for better performance. You will need to add import numpy at the top of the file if you haven't already.
```diff
-task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
+task_tensor = torch.from_numpy(numpy.frombuffer(chunk, dtype=numpy.uint8))
```
```diff
-received.extend(task_tensor.cpu().numpy())
+task_tensor = torch.empty(chunk_length, dtype=torch.uint8)
+dist.broadcast(task_tensor, src=0, group=self.task_pg)
+received.extend(task_tensor.numpy())
```
bytearray.extend() with a NumPy array iterates through the array and appends each number individually, which is inefficient. Using tobytes() and += provides a more direct and performant way to append the tensor's data.
```diff
-received.extend(task_tensor.numpy())
+received += task_tensor.numpy().tobytes()
```
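The difference can be sketched with plain NumPy (the array values here are illustrative stand-ins for the received tensor's data):

```python
import numpy as np

task_bytes = np.arange(8, dtype=np.uint8)  # stand-in for task_tensor.numpy()

# Slow path: bytearray.extend iterates the array, appending ints one at a time.
received_slow = bytearray()
received_slow.extend(task_bytes)

# Fast path: tobytes() copies the buffer once, then += appends it in bulk.
received_fast = bytearray()
received_fast += task_bytes.tobytes()

assert received_slow == received_fast
print(len(received_fast))  # → 8
```

Both paths yield identical bytes; the bulk copy simply avoids per-element Python-level iteration, which matters when reassembling megabyte-sized chunks.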
No description provided.