Merged
11 changes: 6 additions & 5 deletions transfer_queue/controller.py
@@ -400,8 +400,8 @@ def update_production_status(
max_sample_idx = max(global_indices) if global_indices else -1
required_samples = max_sample_idx + 1

# Ensure we have enough rows
with self.data_status_lock:
# Ensure we have enough rows
self.ensure_samples_capacity(required_samples)

# Register new fields if needed
@@ -415,10 +415,11 @@ def update_production_status(
with self.data_status_lock:
self.ensure_fields_capacity(required_fields)

# Update production status
if self.production_status is not None and global_indices and field_names:
field_indices = [self.field_name_mapping.get(field) for field in field_names]
self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1
with self.data_status_lock:
Comment on lines 403 to +418
Copilot AI Feb 13, 2026

The lock is released between ensuring samples capacity (line 405) and ensuring fields capacity (line 416), creating a potential race condition window. If another thread calls ensure_samples_capacity or ensure_fields_capacity during this window, it could reassign self.production_status to a new tensor object. The subsequent field indexing at line 422 would then operate on this new tensor, which might not have been properly expanded for the new fields registered in lines 411-412.

Consider holding the lock continuously from line 403 through the production status update at line 422, especially since the field registration (lines 408-412) also modifies shared state that should be protected.
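
A minimal sketch of that suggestion, not the actual `controller.py` code: `StatusTable`, `mark_produced`, and `_ensure_capacity` are hypothetical names, a list of lists stands in for the `production_status` tensor, and the use of `threading.RLock` is an assumption. It shows field registration, resizing, and the status write all happening inside one critical section, so no other thread can swap the table between the steps.

```python
import threading


class StatusTable:
    """Hypothetical, simplified stand-in for the controller's status matrix."""

    def __init__(self):
        # RLock is an assumption: it lets helpers re-acquire the lock safely.
        self.lock = threading.RLock()
        self.field_index = {}  # field name -> column index
        self.rows = []         # rows[sample][column] is 0 or 1

    def _ensure_capacity(self, n_rows, n_cols):
        # Caller must hold self.lock. Pads existing rows, then appends new ones.
        for row in self.rows:
            row.extend([0] * (n_cols - len(row)))
        while len(self.rows) < n_rows:
            self.rows.append([0] * n_cols)

    def mark_produced(self, sample_indices, field_names):
        if not sample_indices or not field_names:
            return
        # One critical section: register fields, resize, then write.
        with self.lock:
            for name in field_names:
                self.field_index.setdefault(name, len(self.field_index))
            self._ensure_capacity(max(sample_indices) + 1, len(self.field_index))
            for i in sample_indices:
                for name in field_names:
                    self.rows[i][self.field_index[name]] = 1


# Four threads each register a different field and mark two samples.
table = StatusTable()
threads = [
    threading.Thread(target=table.mark_produced, args=([k, k + 4], [f"field_{k}"]))
    for k in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The trade-off is a longer critical section; if the resize helpers are slow, that may be worth measuring before adopting the wider lock.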

# Update production status
if self.production_status is not None and global_indices and field_names:
field_indices = [self.field_name_mapping.get(field) for field in field_names]
self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1

# Update field metadata
self._update_field_metadata(global_indices, dtypes, shapes, custom_backend_meta)