Skip to content

Support vmap over multiple instances of QumodeCircuit#139

Merged
sansiro77 merged 10 commits intoTuringQ:mainfrom
Jooyuza:vmap
Feb 5, 2026
Merged

Support vmap over multiple instances of QumodeCircuit#139
sansiro77 merged 10 commits intoTuringQ:mainfrom
Jooyuza:vmap

Conversation

@Jooyuza
Copy link
Copy Markdown
Collaborator

@Jooyuza Jooyuza commented Feb 5, 2026

New features:

  • set_fock_basis(): allow users manually set fock basis states before forward
  • get_fock_basis(): check the output fock basis states in current settings
  • optional bool arg sort of forward could decide whether to sort dictionary of Fock basis states in the descending order of probs. If sort=False, the dictionary will be in the lexicographic order.

Miscellaneous:

  • update get_unitary: replacing in-place operations with out-of-place ops index_put to support vmap
  • support return hash value of FockState on meta devices.
  • fix typos

Example use case:

import torch
import torch.nn as nn

class BS(nn.Module):
    def __init__(self):
        super().__init__()
        cir = dq.QumodeCircuit(nmode=3, init_state=[1,1,0])
        cir.bs([0,1], encode=True)
        cir.ps([1], encode=True)
        cir.bs([1,2])
        cir.set_fock_basis() # set all fock basis before vmap
        self.cir = cir

    def forward(self, x):
        d = self.cir(x, is_prob=True, sort=False) # forward without extra sorting
        sorted_items = d.items()
        sorted_values = [v for k, v in sorted_items]
        probs = torch.cat(sorted_values,dim=1)
        return probs

num_models = 5
batch_size = 64

models = [BS() for i in range(num_models)]
data = torch.randn(batch_size, 3)

def wrapper(params, buffers, data):
    return torch.func.functional_call(models[0], (params, buffers), data)

params, buffers = torch.func.stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)

print(output.shape)

Comment thread src/deepquantum/photonic/circuit.py Outdated
Comment thread src/deepquantum/photonic/circuit.py Outdated
Comment thread src/deepquantum/photonic/circuit.py Outdated
Comment thread src/deepquantum/photonic/circuit.py Outdated
Comment thread src/deepquantum/photonic/circuit.py
Comment thread src/deepquantum/photonic/circuit.py
@sansiro77 sansiro77 added the enhancement New feature or request label Feb 5, 2026
@sansiro77 sansiro77 changed the title Support vmap over module_list of QumodeCircuit Support vmap over multiple instances of QumodeCircuit Feb 5, 2026
Copy link
Copy Markdown
Contributor

@sansiro77 sansiro77 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@sansiro77 sansiro77 merged commit 63cf47c into TuringQ:main Feb 5, 2026
@Jooyuza Jooyuza deleted the vmap branch March 4, 2026 02:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants