In [1]:
import pennylane as qml
import torch
import torch.nn.functional as F
from pennylane import numpy as np

In [4]:
n_qubits = 4
dev = qml.device("default.qubit", wires=[i for i in range(n_qubits+1)])

def ansatz_flatten(state, flat_weights, n_qubits, n_layers=1, change_of_basis=False, entanglement="all2all"):
    #flat_weights = weights.flatten()
    num_weights_per_layer = n_qubits * 2

    if change_of_basis is True:
        for l in range(n_layers):
            for i in range(n_qubits):
                index = l * num_weights_per_layer + i * 2
                qml.Rot(flat_weights[index], flat_weights[index + 1], wires=i)
    else:          
        for l in range(n_layers):
            for i in range(n_qubits):
                index = l * num_weights_per_layer + i * 2
                qml.RZ(flat_weights[index], wires=i)
                qml.RY(flat_weights[index + 1], wires=i)


            if entanglement == "all2all":
                for q1 in range(n_qubits-1):    
                    for q2 in range(q1+1, n_qubits):
                        qml.CNOT(wires=[q1,q2])

            elif entanglement == "mod":
                for q1 in range(n_qubits):
                    qml.CNOT(wires=[q1, (q1+l+1)%n_qubits])

            elif entanglement == "linear":
                for q1 in range(n_qubits-1):    
                    qml.CNOT(wires=[q1, q1+1])

            elif entanglement == "circular":
                for q1 in range(n_qubits):
                    qml.CNOT(wires=[q1, (q1+1)%n_qubits])

            elif entanglement == "nn":
                qml.CNOT(wires=[0, 1])
                qml.CNOT(wires=[2, 3])
                qml.CNOT(wires=[1, 2])

            else:
                for q in range(1, n_qubits):
                    qml.CNOT(wires=[q, 0])
                for q in range(2, n_qubits):
                    qml.CNOT(wires=[q, 1])

            if l < n_layers-1:
                qml.AngleEmbedding(state, wires=range(n_qubits),rotation="Y")
                qml.AngleEmbedding(state, wires=range(n_qubits),rotation="Z")


@qml.qnode(dev,interface="torch")
def qcircuit_fisher(s,params):

    for i in range(n_qubits):
        qml.Hadamard(wires=i)
    
    ansatz_flatten(s, params, n_qubits, n_layers=1, change_of_basis=False, entanglement="all2all")

    #for q in range(n_qubits-1):
        #qml.CNOT(wires=[q,n_qubits])
        #qml.CNOT(wires=[q,q+1])
        
    return qml.probs(wires=range(n_qubits)) 


@qml.qnode(dev,interface="torch")
def qcircuit_fisher_2(s,params):

    for i in range(n_qubits):
        qml.Hadamard(wires=i)
    
    ansatz_flatten(s, params, n_qubits, n_layers=1, change_of_basis=False, entanglement="all2all")

    for q in range(n_qubits-1):
        #qml.CNOT(wires=[q,n_qubits])
        qml.CNOT(wires=[q,q+1])
        
    return qml.probs(wires=n_qubits-1) 

In [19]:
def FIM(s,weights):
    num_params = len(weights)
    fisher_info_matrix = np.zeros((num_params, num_params))
    #for s in range(n_samples):
    
    outs = qcircuit_fisher(s,weights)

    for i in range(2**n_qubits):
        if weights.grad is not None:
            print("grad not none")
            weights.grad.zero_()
        
        #outs = qcircuit_fisher(s,weights)
        log_prob = outs[i]

        log_prob.backward(retain_graph=True)
        grad = weights.grad.view(-1)
        grad_np = grad.detach().numpy()  # Detach the gradients and convert to NumPy
        fisher_info_matrix += (1/outs[i].detach().numpy()) * np.outer(grad_np, grad_np)
    #fisher_info_matrix /= n_samples

    regularization_constant = 0.1
    fisher_info_matrix += regularization_constant * np.eye(num_params)
    fisher_info_matrix = fisher_info_matrix.real

    return fisher_info_matrix

In [12]:
batch_size = 1000
ss = np.random.random((batch_size,n_qubits))

avg_state = np.mean(ss,axis=0)

n_layers=4
w = np.random.random((n_layers,n_qubits,2)).flatten()


In [13]:
cfim = qml.qinfo.classical_fisher(qcircuit_fisher)(torch.tensor(avg_state,requires_grad=False), torch.tensor(w, requires_grad=True))[1].detach().numpy() 
regularization_constant = 0.1
cfim += regularization_constant * np.eye(len(w))
print(cfim)

[[ 1.00906904e-01 -2.78438309e-02  2.17085289e-10 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-2.78438309e-02  9.54862997e-01 -3.06414034e-09 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 2.17085289e-10 -3.06414034e-09  1.94307898e-01 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 ...
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  1.00000000e-01
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   1.00000000e-01  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  1.00000000e-01]]


In [14]:
cfim = qml.qinfo.classical_fisher(qcircuit_fisher_2)(torch.tensor(avg_state,requires_grad=False), torch.tensor(w, requires_grad=True))[1].detach().numpy() 
regularization_constant = 0.1
cfim += regularization_constant * np.eye(len(w))
print(cfim)

[[ 1.00906904e-01 -2.78438309e-02  2.17085289e-10 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-2.78438309e-02  9.54862997e-01 -3.06414034e-09 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 2.17085289e-10 -3.06414034e-09  1.94307898e-01 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 ...
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  1.00000000e-01
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   1.00000000e-01  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  1.00000000e-01]]


In [20]:
cfim_man = FIM(torch.tensor(avg_state,requires_grad=False), torch.tensor(w, requires_grad=True)) 
print(cfim_man)

grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
grad not none
[[ 1.00906904e-01 -2.78438309e-02  2.17085289e-10 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-2.78438309e-02  9.54862997e-01 -3.06414033e-09 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 2.17085289e-10 -3.06414033e-09  1.94307898e-01 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 ...
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  1.00000000e-01
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   1.00000000e-01  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  1.00000000e-01]]


: 

In [16]:
#frobenious norm
print(np.linalg.norm(cfim-cfim_man))

2.748766145904134e-16
