In [1]:
# Importing standard Qiskit libraries
from qiskit import QuantumCircuit, transpile
from qiskit.tools.jupyter import *
from qiskit.visualization import *
from ibm_quantum_widgets import *
from qiskit_aer import AerSimulator

# qiskit-ibmq-provider has been deprecated.
# Please see the Migration Guides in https://ibm.biz/provider_migration_guide for more detail.
from qiskit_ibm_runtime import QiskitRuntimeService, Sampler, Estimator, Session, Options

# Loading your IBM Quantum account(s).
# NOTE(review): no token is passed here, so this presumably relies on
# credentials already saved for the "ibm_quantum" channel — confirm that a
# fresh kernel can still authenticate before relying on Restart & Run All.
service = QiskitRuntimeService(channel="ibm_quantum")

# Invoke a primitive inside a session. For more details see https://qiskit.org/documentation/partners/qiskit_ibm_runtime/tutorials.html
# with Session(backend=service.backend("ibmq_qasm_simulator")):
#     result = Sampler().run(circuits).result()

In [2]:
import torch
from torch.optim.optimizer import Optimizer

class CustomOptimizer(Optimizer):
    """Momentum optimizer that steps along a sparsified low-rank gradient direction.

    For every parameter the gradient is reshaped to 2-D, decomposed with an
    SVD, reduced to its rank-``k`` reconstruction, rescaled, sparsified by
    magnitude, and folded into an exponential moving average that drives the
    parameter update. Weight decay is applied afterwards, directly on the
    parameter.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate; must be non-negative.
        momentum: Coefficient of the moving average. It is also the blend
            weight between the rank-``2k`` and rank-``k`` Frobenius norms
            used to rescale the direction.
        weight_decay: Decay coefficient; after the update each parameter is
            scaled by ``1 - lr * weight_decay``.
        d_dim: Number of leading gradient dimensions merged into the row axis
            of the 2-D view (capped at the gradient's number of dimensions).
        k: Target rank of the reconstruction; must be at least 1.
        alpha: Fraction in ``[0, 1]``. Entries of the rescaled direction whose
            magnitude is below the ``int(alpha * numel)``-th smallest
            magnitude are zeroed; ``0`` keeps every entry.

    Raises:
        ValueError: If ``lr`` is negative, ``k`` is smaller than 1, or
            ``alpha`` lies outside ``[0, 1]``.
    """

    def __init__(self, params, lr, momentum, weight_decay, d_dim, k, alpha):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if k < 1:
            raise ValueError(f"Invalid rank k: {k}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid alpha: {alpha}")
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, d_dim=d_dim, k=k, alpha=alpha)
        super().__init__(params, defaults)

    @staticmethod
    def _low_rank_direction(grad, d_dim, k, momentum, alpha):
        """Return the rescaled, sparsified rank-``k`` direction, shaped like ``grad``."""
        shape = grad.shape
        d = min(d_dim, len(shape))
        flatten_dim = 1
        for i in range(d):
            flatten_dim *= shape[i]
        # reshape (not view) so non-contiguous gradients are accepted.
        grad_2d = grad.reshape(flatten_dim, -1)

        # torch.linalg.svd returns V already transposed (Vh), so the leading
        # *rows* of Vh pair with the leading *columns* of U.
        U, S, Vh = torch.linalg.svd(grad_2d, full_matrices=False)

        def reconstruct(rank):
            # Slicing past the available rank simply yields the full matrix.
            return (U[:, :rank] * S[:rank]) @ Vh[:rank, :]

        # Guard the divisions so an all-zero gradient gives a zero direction
        # instead of NaNs.
        eps = torch.finfo(grad_2d.dtype).tiny

        # NOTE(review): the direction is the rank-k reconstruction (same shape
        # as the gradient) rather than the bare left singular vectors, whose
        # (rows, k) shape cannot be applied to the parameter — confirm this
        # matches the intended algorithm.
        Wk_raw = reconstruct(k)
        Wk_raw_norm = torch.norm(Wk_raw, p='fro')
        Wk = Wk_raw / Wk_raw_norm.clamp_min(eps)

        W2k_norm = torch.norm(reconstruct(2 * k), p='fro')
        # NOTE(review): dividing a unit-norm direction by a gradient-scale
        # norm makes the step shrink as the gradient grows — confirm intended.
        Wk_norm = momentum * W2k_norm + (1 - momentum) * Wk_raw_norm
        Wk_normalized = Wk / Wk_norm.clamp_min(eps)

        # One global magnitude threshold over all entries; the order statistic
        # is clamped into the valid 1..numel range required by kthvalue.
        magnitudes = Wk_normalized.abs()
        total = magnitudes.numel()
        order = min(max(int(alpha * total), 1), total)
        threshold = torch.kthvalue(magnitudes.reshape(-1), order).values
        Wk_truncated = Wk_normalized * (magnitudes >= threshold)

        return Wk_truncated.reshape(shape)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            d_dim = group['d_dim']
            k = group['k']
            alpha = group['alpha']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('CustomOptimizer does not support sparse gradients')
                if grad.numel() == 0:
                    # Nothing to decompose or update.
                    continue

                direction = self._low_rank_direction(grad, d_dim, k, momentum, alpha)

                # Per-parameter state lives in the base class's ``self.state``.
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p)
                momentum_buffer = state['momentum_buffer']

                momentum_buffer.mul_(momentum).add_(direction, alpha=1 - momentum)
                p.add_(momentum_buffer, alpha=-lr)
                # Weight decay on the already-updated parameter:
                # p <- p - lr * weight_decay * p.
                p.mul_(1 - lr * weight_decay)

        return loss
