Skip to content

Conversation

@ZimengXiong
Copy link
Contributor

@ZimengXiong ZimengXiong commented Nov 14, 2025

MPS (Apple Silicon) support and fixed bugs re multinomial sampling across demos and the Gradio app

  • Shared device/dtype selection

    • Added select_device() + backend→dtype map:
      • CUDA → torch.bfloat16
      • MPS → torch.float16
      • CPU → torch.float32
  • Guarded multinomial sampling

    • Introduced a wrapped torch.multinomial
    • Logit sanitization:
      • Replace NaN/Inf with 0
      • Clamp negatives to 0
      • Renormalize along the last dimension to sum to 1

Accepts tensors of arbitrary rank, flattens to 2D for sampling to satisfy PyTorch’s 1–2D constraint, Reshapes sampled indices back to the original shape

Fixes a runtime crash in the Gradio UI caused by HighlightedText expecting list-of-tuples but getting a string

  • Add highlight_message() helper to wrap error/status strings as a HighlightedText-compatible list of (text, color) tuples.
  • Replace bare string yields in dream_generate_with_visualization error paths with highlight_message(error_message) so gr.HighlightedText no longer throws IndexError.
  • Make UI reset functions and callbacks use an empty list for vis_output_display (instead of "") so the HighlightedText component always receives a consistent type.
  • Return [] for skip response in bot_response_generator to maintain consistent component behavior.

@jiacheng-ye
Copy link
Contributor

Thanks, Zimeng, nice work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants