diff --git a/src/deepquantum/photonic/circuit.py b/src/deepquantum/photonic/circuit.py index 1816e3e3..ff0531d1 100644 --- a/src/deepquantum/photonic/circuit.py +++ b/src/deepquantum/photonic/circuit.py @@ -154,24 +154,25 @@ def forward( self, data: Optional[torch.Tensor] = None, state: Any = None, - is_prob: bool = False, + is_prob: Optional[bool] = None, detector: Optional[str] = None, stepwise: bool = False ) -> Union[torch.Tensor, Dict, List[torch.Tensor]]: - """Perform a forward pass of the photonic quantum circuit and return the final state. + """Perform a forward pass of the photonic quantum circuit and return the final-state-related result. Args: data (torch.Tensor or None, optional): The input data for the ``encoders``. Default: ``None`` state (Any, optional): The initial state for the photonic quantum circuit. Default: ``None`` - is_prob (bool, optional): Whether to return probabilities for Fock basis states or Gaussian backend. - Default: ``False`` + is_prob (bool or None, optional): For Fock backend, whether to return probabilities or amplitudes. + For Gaussian backend, whether to return probabilities or the final Gaussian state. + For Fock backend with ``basis=True``, set ``None`` to return the unitary matrix. Default: ``None`` detector (str or None, optional): For Gaussian backend, use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``None`` stepwise (bool, optional): Whether to use the forward function of each operator for Gaussian backend. Default: ``False`` Returns: - Union[torch.Tensor, Dict, List[torch.Tensor]]: The final state of the photonic quantum circuit after + Union[torch.Tensor, Dict, List[torch.Tensor]]: The result of the photonic quantum circuit after applying the ``operators``. """ if self.backend == 'fock': @@ -183,19 +184,22 @@ def _forward_fock( self, data: Optional[torch.Tensor] = None, state: Any = None, - is_prob: bool = False - ) -> Union[torch.Tensor, List[torch.Tensor], Dict]: + is_prob: Optional[bool] = None + ) -> Union[torch.Tensor, Dict, List[torch.Tensor]]: """Perform a forward pass based on the Fock backend. Args: data (torch.Tensor or None, optional): The input data for the ``encoders``. Default: ``None`` state (Any, optional): The initial state for the photonic quantum circuit. Default: ``None`` - is_prob (bool, optional): Whether to return probabilities for Fock basis states. Default: ``False`` + is_prob (bool or None, optional): Whether to return probabilities or amplitudes. + When ``basis=True``, set ``None`` to return the unitary matrix. Default: ``None`` Returns: - Union[torch.Tensor, Dict]: The final state of the photonic quantum circuit after - applying the ``operators``. + Union[torch.Tensor, Dict, List[torch.Tensor]]: Unitary matrix, Fock state tensor, + a dictionary of probabilities or amplitudes, or a list of tensors for MPS. """ + if self.mps: + assert not is_prob if state is None: state = self.init_state if isinstance(state, MatrixProductState): @@ -207,10 +211,9 @@ def _forward_fock( state = FockState(state=state, nmode=self.nmode, cutoff=self.cutoff, basis=self.basis).state if data is None: if self.basis: - state_dict = self._forward_helper_basis(state=state, is_prob=is_prob) - self.state = sort_dict_fock_basis(state_dict) + self.state = self._forward_helper_basis(state=state, is_prob=is_prob) else: - self.state = self._forward_helper_tensor(state=state) + self.state = self._forward_helper_tensor(state=state, is_prob=is_prob) if not self.mps and self.state.ndim == self.nmode: self.state = self.state.unsqueeze(0) else: @@ -218,51 +221,56 @@ def _forward_fock( data = data.unsqueeze(0) assert data.ndim == 2 if self.basis: - state_dict = vmap(self._forward_helper_basis, in_dims=(0, None, None))(data, state, is_prob) - self.state = sort_dict_fock_basis(state_dict) + self.state = vmap(self._forward_helper_basis, in_dims=(0, None, None))(data, state, is_prob) else: if self.mps: assert state[0].ndim in (3, 4) if state[0].ndim == 3: - self.state = vmap(self._forward_helper_tensor, in_dims=(0, None))(data, state) + self.state = vmap(self._forward_helper_tensor, in_dims=(0, None, None))(data, state, is_prob) elif state[0].ndim == 4: - self.state = vmap(self._forward_helper_tensor)(data, state) + self.state = vmap(self._forward_helper_tensor)(data, state, is_prob) else: if state.shape[0] == 1: - self.state = vmap(self._forward_helper_tensor, in_dims=(0, None))(data, state) + self.state = vmap(self._forward_helper_tensor, in_dims=(0, None, None))(data, state, is_prob) else: - self.state = vmap(self._forward_helper_tensor)(data, state) + self.state = vmap(self._forward_helper_tensor)(data, state, is_prob) # for plotting the last data self.encode(data[-1]) + if self.basis and is_prob is not None: + self.state = sort_dict_fock_basis(self.state) return self.state def _forward_helper_basis( self, data: Optional[torch.Tensor] = None, state: Optional[torch.Tensor] = None, - is_prob: bool = False - ) -> Dict: + is_prob: Optional[bool] = None + ) -> Union[torch.Tensor, Dict]: """Perform a forward pass for one sample if the input is a Fock basis state.""" self.encode(data) - if state is None: - state = self.init_state.state - out_dict = {} - final_states = self._get_all_fock_basis(state) - sub_mats = self._get_sub_matrices(state, final_states) - per_norms = self._get_permanent_norms(state, final_states) - if is_prob: - rst = vmap(self._get_prob_fock_vmap)(sub_mats, per_norms) + if is_prob is None: + return self.get_unitary() else: - rst = vmap(self._get_amplitude_fock_vmap)(sub_mats, per_norms) - for i in range(len(final_states)): - final_state = FockState(state=final_states[i], nmode=self.nmode, cutoff=self.cutoff, basis=self.basis) - out_dict[final_state] = rst[i] - return out_dict + if state is None: + state = self.init_state.state + out_dict = {} + final_states = self._get_all_fock_basis(state) + sub_mats = self._get_sub_matrices(state, final_states) + per_norms = self._get_permanent_norms(state, final_states) + if is_prob: + rst = vmap(self._get_prob_fock_vmap)(sub_mats, per_norms) + else: + rst = vmap(self._get_amplitude_fock_vmap)(sub_mats, per_norms) + for i in range(len(final_states)): + final_state = FockState(state=final_states[i], nmode=self.nmode, cutoff=self.cutoff, basis=self.basis) + out_dict[final_state] = rst[i] + return out_dict def _forward_helper_tensor( self, data: Optional[torch.Tensor] = None, - state: Union[torch.Tensor, List[torch.Tensor], MatrixProductState, None] = None + state: Union[torch.Tensor, List[torch.Tensor], None] = None, + is_prob: Optional[bool] = None ) -> Union[torch.Tensor, List[torch.Tensor]]: """Perform a forward pass for one sample if the input is a Fock state tensor.""" self.encode(data) @@ -273,33 +281,37 @@ def _forward_helper_tensor( state = MatrixProductState(nsite=self.nmode, state=state, chi=self.chi, qudit=self.cutoff, normalize=self.init_state.normalize) return self.operators(state).tensors - if isinstance(state, FockState): - state = state.state - x = self.operators(self.tensor_rep(state)).squeeze(0) - return x + else: + if isinstance(state, FockState): + state = state.state + x = self.operators(self.tensor_rep(state)).squeeze(0) + if is_prob: + x = abs(x) ** 2 + return x def _forward_gaussian( self, data: Optional[torch.Tensor] = None, state: Any = None, - is_prob: bool = False, + is_prob: Optional[bool] = None, detector: Optional[str] = None, stepwise: bool = False - ) -> List[torch.Tensor]: + ) -> Union[List[torch.Tensor], Dict]: """Perform a forward pass based on the Gaussian backend. Args: data (torch.Tensor or None, optional): The input data for the ``encoders``. Default: ``None`` state (Any, optional): The initial state for the photonic quantum circuit. Default: ``None`` - is_prob (bool, optional): Whether to return probabilities. Default: ``False`` + is_prob (bool or None, optional): Whether to return probabilities or the final Gaussian state. + Default: ``None`` detector (str or None, optional): Use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Only valid when ``is_prob`` is ``True``. Default: ``None`` stepwise (bool, optional): Whether to use the forward function of each operator. Default: ``False`` Returns: - List[torch.Tensor]: The covariance matrix and displacement vector of the final state - of the photonic quantum circuit after applying the ``operators``. + Union[List[torch.Tensor], Dict]: The covariance matrix and displacement vector of the final state + or a dictionary of probabilities. """ if state is None: state = self.init_state @@ -322,9 +334,8 @@ def _forward_gaussian( self.state = vmap(self._forward_helper_gaussian, in_dims=(0, 0, None))(data, state, stepwise) self.encode(data[-1]) if is_prob: - return self._forward_gaussian_prob(detector) - else: - return self.state + self.state = self._forward_gaussian_prob(self.state[0], self.state[1], detector) + return self.state def _forward_helper_gaussian( self, @@ -340,22 +351,22 @@ def _forward_helper_gaussian( else: cov, mean = state if stepwise: - self.state = self.operators([cov, mean]) + return self.operators([cov, mean]) else: sp_mat = self.get_symplectic() cov = sp_mat @ cov @ sp_mat.mT mean = self.get_displacement(mean) - self.state = [cov.squeeze(0), mean.squeeze(0)] - return self.state + return [cov.squeeze(0), mean.squeeze(0)] - def _forward_gaussian_prob(self, detector: Optional[str] = None) -> Dict: + def _forward_gaussian_prob(self, cov: torch.Tensor, mean: torch.Tensor, detector: Optional[str] = None) -> Dict: """Get the probabilities of all possible final states for Gaussian backend by different detectors. Args: + cov (torch.Tensor): The covariance matrix of the Gaussian state. + mean (torch.Tensor): The displacement vector of the Gaussian state. detector (str or None, optional): Use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``None`` """ - cov, mean = self.state batch = cov.shape[0] if detector is None: detector = self.detector @@ -634,7 +645,8 @@ def measure( shots: int = 1024, with_prob: bool = False, wires: Union[int, List[int], None] = None, - detector: Optional[str] = None + detector: Optional[str] = None, + mcmc: bool = False ) -> Union[Dict, List[Dict], None]: """Measure the final state. @@ -647,72 +659,200 @@ def measure( Default: ``None`` (which means all wires are measured) detector (str or None, optional): For Gaussian backend, use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``None`` + mcmc (bool, optional): Whether to use MCMC sampling method. Default: ``False`` + + See https://arxiv.org/pdf/2108.01622 for MCMC. """ assert not self.mps, 'Currently NOT supported.' if self.state is None: return if self.backend == 'fock': - return self._measure_fock(shots, with_prob, wires) + return self._measure_fock(shots, with_prob, wires, mcmc) elif self.backend == 'gaussian': return self._measure_gaussian(shots, with_prob, detector) + def _prob_dict_to_measure_result(self, prob_dict: Dict, shots: int, with_prob: bool) -> Dict: + """Get the measurement result from the dictionary of probabilities.""" + samples = random.choices(list(prob_dict.keys()), list(prob_dict.values()), k=shots) + results = dict(Counter(samples)) + if with_prob: + for k in results: + results[k] = results[k], prob_dict[k] + return results + def _measure_fock( self, shots: int = 1024, with_prob: bool = False, - wires: Union[int, List[int], None] = None + wires: Union[int, List[int], None] = None, + mcmc: bool = False ) -> Union[Dict, List[Dict]]: """Measure the final state for Fock backend.""" + if isinstance(self.state, torch.Tensor): + if self.basis: + return self._measure_fock_unitary(shots, with_prob, wires, mcmc) + else: + assert not mcmc, "Final states have been calculated, we don't need mcmc!" + return self._measure_fock_tensor(shots, with_prob, wires) + elif isinstance(self.state, dict): + assert not mcmc, "Final states have been calculated, we don't need mcmc!" + return self._measure_fock_dict(shots, with_prob, wires) + else: + assert False, 'Check your forward function or input!' + + def _measure_fock_unitary( + self, + shots: int = 1024, + with_prob: bool = False, + wires: Union[int, List[int], None] = None, + mcmc: bool = False + ) -> Union[Dict, List[Dict]]: + """Measure the final state according to the unitary matrix for Fock backend.""" if wires is None: wires = self.wires wires = sorted(self._convert_indices(wires)) - amp_dis = self.state + if self.state.ndim == 2: + self.state = self.state.unsqueeze(0) + batch = self.state.shape[0] all_results = [] - if self.basis: - batch = len(amp_dis[list(amp_dis.keys())[0]]) + if mcmc: for i in range(batch): + samples_i = self._sample_mcmc_fock(shots=shots, unitary=self.state[i], num_chain=5) + keys = list(map(FockState, samples_i.keys())) + results = dict(zip(keys, samples_i.values())) + if with_prob: + for k in results: + prob = self._prob_func_fock_unitary(k.state) + results[k] = results[k], prob + all_results.append(results) + else: + state = self.init_state.state + final_states = self._get_all_fock_basis(state) + sub_mats = [] + u = self.state + for fstate in final_states: + sub_mats.append(vmap(sub_matrix, in_dims=(0, None, None))(u, state, fstate)) + sub_mats = torch.stack(sub_mats, dim=1) + per_norms = self._get_permanent_norms(state, final_states) + for j in range(batch): + rst = vmap(self._get_prob_fock_vmap)(sub_mats[j], per_norms) + state_dict = {} prob_dict = defaultdict(list) - for key in amp_dis.keys(): + for i in range(len(final_states)): + final_state = FockState(state=final_states[i]) + state_dict[final_state] = rst[i] + for key in state_dict.keys(): state_b = key.state[wires] state_b = FockState(state=state_b) - prob_dict[state_b].append(abs(amp_dis[key][i]) ** 2) + prob_dict[state_b].append(state_dict[key]) for key in prob_dict.keys(): prob_dict[key] = sum(prob_dict[key]) - samples = random.choices(list(prob_dict.keys()), list(prob_dict.values()), k=shots) - results = dict(Counter(samples)) - if with_prob: - for k in results: - results[k] = results[k], prob_dict[k] + results = self._prob_dict_to_measure_result(prob_dict, shots, with_prob) all_results.append(results) + if batch == 1: + return all_results[0] else: - state_tensor = self.tensor_rep(amp_dis) - batch = state_tensor.shape[0] - combi = list(itertools.product(range(self.cutoff), repeat=len(wires))) - for i in range(batch): - prob_dict = {} - state = state_tensor[i] - probs = abs(state) ** 2 - if wires == self.wires: - ptrace_probs = probs + return all_results + + def _measure_fock_dict( + self, + shots: int = 1024, + with_prob: bool = False, + wires: Union[int, List[int], None] = None + ) -> Union[Dict, List[Dict]]: + """Measure the final state according to the dictionary of amplitudes or probabilities for Fock backend.""" + if wires is None: + wires = self.wires + wires = sorted(self._convert_indices(wires)) + all_results = [] + batch = len(self.state[list(self.state.keys())[0]]) + if any(value.dtype.is_complex for value in self.state.values()): + is_prob = False + else: + is_prob = True + for i in range(batch): + prob_dict = defaultdict(list) + for key in self.state.keys(): + state_b = key.state[wires] + state_b = FockState(state=state_b) + if is_prob: + prob_dict[state_b].append(self.state[key][i]) else: - sum_idx = list(range(self.nmode)) - for idx in wires: - sum_idx.remove(idx) - ptrace_probs = probs.sum(dim=sum_idx) - for p_state in combi: - p_state_b = FockState(list(p_state)) - prob_dict[p_state_b] = ptrace_probs[p_state] - samples = random.choices(list(prob_dict.keys()), list(prob_dict.values()), k=shots) - results = dict(Counter(samples)) - if with_prob: - for k in results: - results[k] = results[k], prob_dict[k] - all_results.append(results) + prob_dict[state_b].append(abs(self.state[key][i]) ** 2) + for key in prob_dict.keys(): + prob_dict[key] = sum(prob_dict[key]) + results = self._prob_dict_to_measure_result(prob_dict, shots, with_prob) + all_results.append(results) if batch == 1: return all_results[0] else: return all_results + def _measure_fock_tensor( + self, + shots: int = 1024, + with_prob: bool = False, + wires: Union[int, List[int], None] = None + ) -> Union[Dict, List[Dict]]: + """Measure the final state according to Fock state tensor for Fock backend.""" + if wires is None: + wires = self.wires + wires = sorted(self._convert_indices(wires)) + all_results = [] + if self.state.is_complex(): + state_tensor = self.tensor_rep(abs(self.state) ** 2) + else: + state_tensor = self.tensor_rep(self.state) + batch = state_tensor.shape[0] + combi = list(itertools.product(range(self.cutoff), repeat=len(wires))) + for i in range(batch): + prob_dict = {} + probs = state_tensor[i] + if wires == self.wires: + ptrace_probs = probs + else: + sum_idx = list(range(self.nmode)) + for idx in wires: + sum_idx.remove(idx) + ptrace_probs = probs.sum(dim=sum_idx) + for p_state in combi: + p_state_b = FockState(list(p_state)) + prob_dict[p_state_b] = ptrace_probs[p_state] + results = self._prob_dict_to_measure_result(prob_dict, shots, with_prob) + all_results.append(results) + if batch == 1: + return all_results[0] + else: + return all_results + + def _sample_mcmc_fock(self, shots: int, unitary: torch.Tensor, num_chain: int): + """Sample the output states for Fock backend via SC-MCMC method.""" + self._unitary = unitary + merged_samples = sample_sc_mcmc(prob_func=self._prob_func_fock_unitary, + proposal_sampler=self._proposal_sampler, + shots=shots, + num_chain=num_chain) + return merged_samples + + def _prob_func_fock_unitary(self, final_state: torch.Tensor, init_state: Optional[FockState] = None) -> torch.Tensor: + """Get the probability of the final state according to the unitary matrix for Fock backend. + + Args: + final_state (torch.Tensor): The final Fock basis state. + init_state (FockState or None, optional): The initial Fock basis state. Default: ``None`` + """ + if init_state is None: + init_state = self.init_state + sub_mat = sub_matrix(self._unitary, init_state.state, final_state) + nphoton = sum(init_state.state) + if nphoton == 0: + amp = torch.tensor(1.) + else: + per = permanent(sub_mat) + amp = per / self._get_permanent_norms(init_state.state, final_state).to(per.dtype).to(per.device) + prob = torch.abs(amp) ** 2 + return prob + def _measure_gaussian(self, shots: int = 1024, with_prob: bool = False, detector: Optional[str] = None) -> Dict: """Measure the final state for Gaussian backend. @@ -742,8 +882,8 @@ def _measure_gaussian(self, shots: int = 1024, with_prob: bool = False, detector def _sample_mcmc_gaussian(self, shots: int, cov: torch.Tensor, mean: torch.Tensor, detector: str, num_chain: int): """Sample the output states for Gaussian backend via SC-MCMC method.""" - self.cov = cov - self.mean = mean + self._cov = cov + self._mean = mean self.detector = detector if detector == 'threshold' and not torch.allclose(mean, torch.zeros_like(mean)): # For the displaced state, aggregate PNRD detector samples to derive threshold detector results @@ -769,12 +909,20 @@ def _prob_func_gaussian(self, state: Any) -> torch.Tensor: """Get the probability of the state for Gaussian backend.""" if not isinstance(state, torch.Tensor): state = torch.tensor(state, dtype=torch.int) - prob = self._get_probs_gaussian_helper(state, cov=self.cov, mean=self.mean, detector=self.detector)[0] + prob = self._get_probs_gaussian_helper(state, cov=self._cov, mean=self._mean, detector=self.detector)[0] return prob def _proposal_sampler(self): """The proposal sampler for MCMC sampling.""" - sample = self._generate_rand_sample(self.detector) + if self.backend == 'fock': + assert self.basis, 'Currently NOT supported.' + if self.basis: + all_fock_basis = self._get_all_fock_basis(self.init_state.state) + # else: + # all_fock_basis = self._get_all_fock_basis(self.init_state.state[0]) + sample = all_fock_basis[torch.randint(0, len(all_fock_basis), (1,))][0] + elif self.backend == 'gaussian': + sample = self._generate_rand_sample(self.detector) return sample def _generate_rand_sample(self, detector: str = 'pnrd'): @@ -782,7 +930,7 @@ def _generate_rand_sample(self, detector: str = 'pnrd'): if detector == 'threshold': sample = torch.randint(0, 2, [self.nmode]) elif detector == 'pnrd': - if torch.allclose(self.mean, torch.zeros_like(self.mean)): + if torch.allclose(self._mean, torch.zeros_like(self._mean)): while True: sample = torch.randint(0, self.cutoff, [self.nmode]) if sample.sum() % 2 == 0: diff --git a/src/deepquantum/photonic/qmath.py b/src/deepquantum/photonic/qmath.py index e053d6ae..8a7a21a6 100644 --- a/src/deepquantum/photonic/qmath.py +++ b/src/deepquantum/photonic/qmath.py @@ -312,7 +312,7 @@ def sample_sc_mcmc(prob_func: Callable, else: prob_i = prob_func(sample_i) cache_prob[tuple(sample_i.tolist())] = prob_i - rand_num = torch.rand(1) + rand_num = torch.rand(1, device= prob_i.device) samples.append(sample_i) # MCMC transfer to new state if prob_i / prob_max > rand_num: