[UnitaryHack] Added Rotosolve optimizer #93

Closed · wants to merge 3 commits (changes shown from 1 commit)

6 changes: 4 additions & 2 deletions lambeq/__init__.py
@@ -85,6 +85,7 @@

    'Optimizer',
    'SPSAOptimizer',
+    'RotosolveOptimizer',
Contributor: re-order (Rotosolve before SPSA)


    'Model',
    'NumpyModel',
@@ -126,8 +127,9 @@
stairs_reader, word_sequence_reader)
from lambeq.tokeniser import Tokeniser, SpacyTokeniser
from lambeq.training import (Checkpoint, Dataset, Optimizer, SPSAOptimizer,
-                             Model, NumpyModel, PennyLaneModel, PytorchModel,
-                             QuantumModel, TketModel, Trainer, PytorchTrainer,
+                             RotosolveOptimizer, Model, NumpyModel,
Contributor: re-order

+                             PennyLaneModel, PytorchModel, QuantumModel,
+                             TketModel, Trainer, PytorchTrainer,
                              QuantumTrainer, BinaryCrossEntropyLoss,
                              CrossEntropyLoss, LossFunction, MSELoss)
from lambeq.version import (version as __version__,
2 changes: 2 additions & 0 deletions lambeq/training/__init__.py
@@ -24,6 +24,7 @@

    'Optimizer',
    'SPSAOptimizer',
+    'RotosolveOptimizer',
Contributor: re-order


    'Trainer',
    'PytorchTrainer',
@@ -47,6 +48,7 @@

from lambeq.training.optimizer import Optimizer
from lambeq.training.spsa_optimizer import SPSAOptimizer
+from lambeq.training.rotosolve_optimizer import RotosolveOptimizer

from lambeq.training.trainer import Trainer
from lambeq.training.pytorch_trainer import PytorchTrainer
171 changes: 171 additions & 0 deletions lambeq/training/rotosolve_optimizer.py
@@ -0,0 +1,171 @@
# Copyright 2021-2023 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RotosolveOptimizer
=============
Contributor: add some =

Module implementing the Rotosolve optimizer.

"""
from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping
from typing import Any

import numpy as np
from numpy.typing import ArrayLike

from lambeq.training.optimizer import Optimizer
from lambeq.training.quantum_model import QuantumModel


class RotosolveOptimizer(Optimizer):
    """An Optimizer using the Rotosolve algorithm.

    See https://quantum-journal.org/papers/q-2021-01-28-391/pdf/ for details.

    """

    model: QuantumModel

    def __init__(self, model: QuantumModel,
                 hyperparams: dict[str, float],
                 loss_fn: Callable[[Any, Any], float],
                 bounds: ArrayLike | None = None) -> None:
        """Initialise the Rotosolve optimizer.

        Parameters
        ----------
        model : :py:class:`.QuantumModel`
            A lambeq quantum model.
        hyperparams : dict of str to float
            A dictionary containing the model's hyperparameters.
        loss_fn : Callable
            A loss function of the form `loss(prediction, labels)`.
        bounds : ArrayLike, optional
            The range of each of the model parameters.

        Raises
        ------
        ValueError
            If the length of `bounds` does not match the number
            of the model parameters.

        """
        if bounds is None:
            bounds = [[-np.pi, np.pi]]*len(model.weights)

        super().__init__(model, hyperparams, loss_fn, bounds)

        self.project: Callable[[np.ndarray], np.ndarray]

        bds = np.asarray(bounds)
        if len(bds) != len(self.model.weights):
            raise ValueError('Length of `bounds` must be the same as the '
                             'number of the model parameters')
        self.project = lambda x: x.clip(bds[:, 0], bds[:, 1])

    def backward(
            self,
            batch: tuple[Iterable[Any], np.ndarray]) -> float:
        """Calculate the gradients of the loss function.

        The gradients are calculated with respect to the model
        parameters.

        Parameters
        ----------
        batch : tuple of Iterable and numpy.ndarray
            Current batch. Contains an Iterable of diagrams at index 0,
            and the targets at index 1.

        Returns
        -------
        float
            The calculated loss.

        """
        diagrams, targets = batch

        # The new model weights
        self.gradient = np.copy(self.model.weights)

        old_model_weights = self.model.weights

        for i, _ in enumerate(self.gradient):
            # Let phi be 0

            # M_phi
            self.gradient[i] = 0.0
            self.model.weights = self.gradient
            m_phi = self.model(diagrams)

            # M_phi + pi/2
            self.gradient[i] = np.pi / 2
            self.model.weights = self.gradient
            m_phi_plus = self.model(diagrams)

            # M_phi - pi/2
            self.gradient[i] = -np.pi / 2
            self.model.weights = self.gradient
            m_phi_minus = self.model(diagrams)

            # Update weight
            self.gradient[i] = -(np.pi / 2) - np.arctan2(
                2*m_phi - m_phi_plus - m_phi_minus,
                m_phi_plus - m_phi_minus
            )

        # Calculate loss
        self.model.weights = self.gradient
        y1 = self.model(diagrams)
        loss = self.loss_fn(y1, targets)

        self.model.weights = old_model_weights

        return loss

    def step(self) -> None:
        """Perform optimisation step."""
        self.model.weights = self.gradient
        self.model.weights = self.project(self.model.weights)

        self.update_hyper_params()
Contributor: no need to call this here

        self.zero_grad()

    def update_hyper_params(self) -> None:
Contributor: no need to define this method

"""Update the hyperparameters of the Rotosolve algorithm."""
return

    def state_dict(self) -> dict[str, Any]:
        """Return optimizer states as dictionary.

        Returns
        -------
        dict
            A dictionary containing the current state of the optimizer.

        """
        raise NotImplementedError
Contributor: return empty dict


    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Load state of the optimizer from the state dictionary.

        Parameters
        ----------
        state_dict : dict
            A dictionary containing a snapshot of the optimizer state.

        """
        raise NotImplementedError
Contributor: just do a pass here
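
For context, the "Update weight" step in `backward` above is the closed-form minimiser from the Rotosolve paper: for a cost that is sinusoidal in a single parameter, the optimum is theta* = phi - pi/2 - arctan2(2*M(phi) - M(phi + pi/2) - M(phi - pi/2), M(phi + pi/2) - M(phi - pi/2)), evaluated here with phi = 0. A minimal self-contained sketch of that update on a toy cost function (plain NumPy; the constants A, B, C are made up for illustration, and this is not lambeq code):

import numpy as np

# Toy single-parameter cost with the sinusoidal form Rotosolve assumes:
# M(theta) = A*sin(theta + B) + C  (hypothetical constants).
A, B, C = 2.0, 0.7, 1.5
cost = lambda theta: A * np.sin(theta + B) + C

phi = 0.0
m_phi = cost(phi)
m_phi_plus = cost(phi + np.pi / 2)
m_phi_minus = cost(phi - np.pi / 2)

# The same closed-form update as in backward() above (with phi = 0):
theta_star = phi - np.pi / 2 - np.arctan2(
    2 * m_phi - m_phi_plus - m_phi_minus,
    m_phi_plus - m_phi_minus,
)

# The analytic minimum of A*sin(theta + B) + C sits at theta = -pi/2 - B,
# with value C - A; the update recovers it in a single step.
assert np.isclose(theta_star, -np.pi / 2 - B)
assert np.isclose(cost(theta_star), C - A)

Because each parameter's landscape is exactly sinusoidal for circuits built from single-parameter rotations, one such update per parameter lands directly on that coordinate's minimum, which is why the optimizer needs no learning rate or other hyperparameters.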

123 changes: 123 additions & 0 deletions tests/training/test_rotosolve_optimizer.py
@@ -0,0 +1,123 @@
import pytest

import numpy as np

from discopy import Cup, Word
from discopy.quantum.circuit import Id

from lambeq import AtomicType, IQPAnsatz, RotosolveOptimizer

N = AtomicType.NOUN
S = AtomicType.SENTENCE

ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1, n_single_qubit_params=1)

diagrams = [
ansatz((Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S))),
ansatz((Word("Alice", N) @ Word("walks", N >> S) >> Cup(N, N.r) @ Id(S)))
]

from lambeq.training.model import Model


class ModelDummy(Model):
    def __init__(self) -> None:
        super().__init__()
        self.initialise_weights()
    def from_checkpoint():
        pass
    def _make_lambda(self, diagram):
        return diagram.lambdify(*self.symbols)
    def initialise_weights(self):
        self.weights = np.array([1., 2., 3.])
    def _clear_predictions(self):
        pass
    def _log_prediction(self, y):
        pass
    def get_diagram_output(self):
        pass
    def _make_checkpoint(self):
        pass
    def _load_checkpoint(self):
        pass
    def forward(self, x):
        return self.weights.sum()

loss = lambda yhat, y: np.abs(yhat-y).sum()**2

def test_init():
    model = ModelDummy.from_diagrams(diagrams)
    model.initialise_weights()
    optim = RotosolveOptimizer(model,
                               hyperparams={},
                               loss_fn=loss,
                               bounds=[[-np.pi, np.pi]]*len(model.weights))

    assert optim.project

def test_backward():
    np.random.seed(3)
    model = ModelDummy.from_diagrams(diagrams)
    model.initialise_weights()
    optim = RotosolveOptimizer(model,
                               hyperparams={},
                               loss_fn=loss,
                               bounds=[[-np.pi, np.pi]]*len(model.weights))

    optim.backward(([diagrams[0]], np.array([0])))

    assert np.array_equal(optim.gradient.round(5),
                          np.array([-1.5708] * len(model.weights)))
    assert np.array_equal(model.weights, np.array([1., 2., 3.]))

def test_step():
    np.random.seed(3)
    model = ModelDummy.from_diagrams(diagrams)
    model.initialise_weights()
    optim = RotosolveOptimizer(model,
                               hyperparams={},
                               loss_fn=loss,
                               bounds=[[-np.pi, np.pi]]*len(model.weights))
    optim.backward(([diagrams[0]], np.array([0])))
    optim.step()

    assert np.array_equal(model.weights.round(4),
                          np.array([-1.5708] * len(model.weights)))

def test_bound_error():
    model = ModelDummy()
    model.initialise_weights()
    with pytest.raises(ValueError):
        _ = RotosolveOptimizer(model=model,
                               hyperparams={},
                               loss_fn=loss,
                               bounds=[[0, 10]]*(len(model.weights)-1))

def test_none_bound_error():
    model = ModelDummy()
    model.initialise_weights()
    optim = RotosolveOptimizer(model=model,
                               hyperparams={},
                               loss_fn=loss)

    assert optim.bounds == [[-np.pi, np.pi]] * len(model.weights)

def test_load_state_dict():
    model = ModelDummy()
    model.from_diagrams(diagrams)
    model.initialise_weights()
    optim = RotosolveOptimizer(model,
                               hyperparams={},
                               loss_fn=loss)

    with pytest.raises(NotImplementedError):
        optim.load_state_dict({})

def test_state_dict():
    model = ModelDummy()
    model.from_diagrams(diagrams)
    model.initialise_weights()
    optim = RotosolveOptimizer(model,
                               hyperparams={},
                               loss_fn=loss)

    with pytest.raises(NotImplementedError):
        optim.state_dict()
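
For completeness, a rough sketch of how the new optimizer would plug into lambeq's training loop, assuming the `QuantumTrainer` API from the lambeq documentation; the circuits, labels, epoch count, and loss function below are illustrative and not part of this PR:

import numpy as np
from discopy import Cup, Word
from discopy.quantum.circuit import Id
from lambeq import (AtomicType, Dataset, IQPAnsatz, NumpyModel,
                    QuantumTrainer, RotosolveOptimizer)

N = AtomicType.NOUN
S = AtomicType.SENTENCE
ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1, n_single_qubit_params=1)

# Two toy sentences with made-up binary labels.
circuits = [
    ansatz(Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S)),
    ansatz(Word("Alice", N) @ Word("walks", N >> S) >> Cup(N, N.r) @ Id(S)),
]
labels = [np.array([0, 1]), np.array([1, 0])]
dataset = Dataset(circuits, labels)

model = NumpyModel.from_diagrams(circuits)

def l1_loss(y_hat, y):
    # Simple elementwise loss, purely for illustration.
    return np.abs(np.asarray(y_hat) - np.asarray(y)).sum()

trainer = QuantumTrainer(
    model,
    loss_function=l1_loss,
    epochs=5,
    optimizer=RotosolveOptimizer,
    optim_hyperparams={},  # Rotosolve takes no hyperparameters
    seed=0,
)
trainer.fit(dataset)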