In [None]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from tqdm import tnrange
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
# Load the measurement records and convert them to float32 torch tensors.
# SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects, so only
# load a 'measurements.npy' that comes from a trusted source.
# The loaded array is expected to wrap a single pickled dict; .item() is the
# accessor for that single element (.all() is a logical reduction, not an accessor).
measurements_raw = np.load('measurements.npy', allow_pickle=True).item()

# Each record contributes its first three entries (indices 0, 1, 2), in that order.
measurements_torch = {
    base: tuple(torch.tensor(item[k], dtype=torch.float32) for k in range(3))
    for base, item in measurements_raw.items()
}
measurements = measurements_torch

In [None]:
class Net(nn.Module):
    """Exercise template: a neural-network description of a multi-qubit state
    that is fitted to measurement records by minimising a negative log-likelihood.

    The parts marked "Write your code here" are intentionally left to the reader.
    Until they are filled in, the corresponding methods raise NotImplementedError
    with a message naming the missing step.

    `measurements` is a dict; each value is indexed as value[0] (the "base"
    tensor), value[1] (measured outcomes, used as integer indices into the state
    vector) and value[2] (counts of those outcomes, normalised before use).
    """

    lr = 0.01        # learning rate passed to Adam in `evolution`
    N_epochs = 300   # number of optimisation steps performed by `evolution`

    def __init__(self, dim, measurements):
        """Store the problem size and data, and precompute normalised frequencies.

        Args:
            dim: number of qubits.
            measurements: dict of measurement records (see class docstring).

        Raises:
            ValueError: if `measurements` is empty.
        """
        super().__init__()
        # dim is a number of qubits
        self.dim = dim
        self.measurements = measurements

        # Use nn.Linear to define layers
        # For example, self.linear = nn.Linear(dim_in, dim_out),
        # where dim_in, dim_out are input and output dimensions of a layer
        # Write your code here:
        # ---------------

        # Define activation functions
        # Write your code here:
        # ---------------

        if not measurements:
            raise ValueError("measurements must contain at least one record")

        # Per-record normalised frequencies, concatenated into one 1-D tensor.
        # The record stored under key 0 (when present) comes first; the others
        # follow in dict iteration order.
        keys = list(measurements.keys())
        if 0 in keys:
            keys = [0] + [k for k in keys if k != 0]
        self.freq = torch.cat(
            [measurements[k][2] / torch.sum(measurements[k][2]) for k in keys]
        )

    @staticmethod
    def _todo(what):
        """Raise NotImplementedError for an exercise step that is not filled in yet."""
        raise NotImplementedError("Exercise placeholder not implemented: " + what)

    def forward(self, x):
        """Map a batch of input vectors to the network output (to be implemented).

        Args:
            x: batch of input vectors.
        """
        # Apply linear layers and activation functions sequentially
        # The size of the input = Number of qubits
        # The size of the output = 2 neurons, first defines an amplitude, 2nd defines a phase
        # Write your code here (replace the placeholder below):
        # ---------------
        return self._todo("forward: apply the linear layers and activations to x")

    def int10tobase_torch(self, num, base=2):
        """Convert a 1-D tensor of non-negative integers to fixed-width digit strings.

        Args:
            num: 1-D tensor of integers, each in [0, base**dim).
            base: radix of the expansion (default 2, i.e. bit strings).

        Returns:
            torch.long tensor of shape (len(num), dim); column 0 holds the most
            significant digit.

        Raises:
            ValueError: if a value is negative or does not fit into `dim` digits.
        """
        c = num.to(torch.long)
        if c.numel() > 0:
            lo, hi = int(c.min().item()), int(c.max().item())
            if lo < 0 or hi >= base ** self.dim:
                raise ValueError(
                    "values must lie in [0, %d**%d); got range [%d, %d]"
                    % (base, self.dim, lo, hi)
                )

        digits = torch.zeros((self.dim, c.shape[0]), dtype=torch.long, device=c.device)
        # Fill from the least significant digit (last row) upwards.
        for row in range(self.dim - 1, -1, -1):
            digits[row] = torch.remainder(c, base)
            # Exact integer division; going through floating point would lose
            # precision for large values.
            c = torch.div(c, base, rounding_mode="floor")
        return digits.t()

    def rotation(self, theta, phi):
        """Return a 2x2 unitary that rotates a qubit state on the Bloch sphere
        (to be implemented).

        Args:
            theta, phi: rotation parameters; see the hints at `rot_canonical`.
        """
        # The result is expected to be a (2, 2) tensor of dtype torch.cfloat.
        # Write your code here (replace the placeholder below):
        # ---------------
        return self._todo("rotation: build the 2x2 torch.cfloat unitary for (theta, phi)")

    def rot_canonical(self, base):
        """Return the rotation for one of the canonical bases (to be implemented).

        Args:
            base: 0, 1 or 2. Current convention: base=0 -> X, base=1 -> Y, base=2 -> Z.
        """
        # Rotate a qubit state around the X, Y and Z axes by an angle of pi/2.
        # The qubit state is initially oriented along the Z axis.
        # Hint: rotation around X axis -> theta = pi/2, phi = 0;
        #       around Y axis -> theta = pi/2, phi = -pi/2;
        #       around Z axis -> theta = 0, phi = 0
        # Write your code here (replace the placeholder below):
        # ---------------
        return self._todo("rot_canonical: return self.rotation(theta, phi) for the requested base")

    def state(self, v_in):
        """Build the state vector from a batch of bit strings (to be implemented).

        Args:
            v_in: batch of bit strings of shape (Number of base vectors, Number of qubits).
        """
        # Make a forward pass using v_in to obtain a tensor of amplitudes and phases
        # Write your code here:
        # ---------------

        # Normalize your amplitudes
        # Write your code here (replace the placeholder below):
        # ---------------
        return self._todo("state: forward pass on v_in, then normalise the amplitudes")

    def state_rot(self, psi, base):
        """Rotate the state `psi` into the measurement basis described by `base`.

        Args:
            psi: state vector with 2**dim entries.
            base: sequence of 0, 1 and 2, one entry per qubit; qubits whose entry
                is 2 are left untouched.

        Returns:
            The rotated state as a 1-D tensor with 2**dim entries. `psi` itself
            is not modified (the work is done on a clone).
        """
        shape = (2,) * self.dim
        psi_cur = torch.clone(psi).reshape(shape)

        for qubit in range(self.dim):
            b = int(base[qubit])
            if b == 2:
                continue
            U = self.rot_canonical(b)
            # Bring the target qubit's axis to the front so the 2x2 matrix acts
            # on it through a single matrix product.
            order = [qubit] + [d for d in range(self.dim) if d != qubit]
            inverse = [0] * self.dim
            for pos, d in enumerate(order):
                inverse[d] = pos
            psi_cur = psi_cur.permute(order).reshape(2, 2 ** (self.dim - 1))
            psi_cur = (U @ psi_cur).reshape(shape)
            # Undo the axis permutation.
            psi_cur = psi_cur.permute(inverse)
        return psi_cur.reshape(2 ** self.dim)

    def negloglik_base(self, psi, base, projs, freqs, eps=1e-8):
        """Negative log-likelihood of the outcomes measured in one base
        (partly to be implemented).

        Args:
            psi: state vector.
            base: per-qubit base description, as accepted by `state_rot`.
            projs: measured outcomes, used as integer indices into the rotated state.
            freqs: frequencies of the outcomes in `projs`.
            eps: small constant; presumably meant to keep the logarithm finite
                when a probability is zero -- confirm when implementing.
        """
        # Rotate psi using method "state_rot"
        # Write your code here (replace the placeholder below):
        psi_rot = self._todo("negloglik_base: rotate psi into `base` with self.state_rot")

        # Probabilities of the corresponding measured outcomes in a particular base
        probs = (torch.abs(psi_rot) ** 2)[projs.type(torch.long)]

        # Compute value of the negative loglikelihood
        # Write your code here (replace the placeholder below):
        negloglik = self._todo("negloglik_base: compute the negative log-likelihood from freqs and probs")

        return negloglik

    def loss(self, psi):
        """Sum `negloglik_base` over every record in `self.measurements`.

        Frequencies are normalised to sum to one per record before use.
        """
        res = 0.
        for record in self.measurements.values():
            base, projs, counts = record[0], record[1], record[2]
            freqs = counts / torch.sum(counts)
            res += self.negloglik_base(psi, base, projs, freqs)
        return res

    # Your final goal is to complete this method
    def evolution(self):
        """Run the optimisation loop (partly to be implemented).

        Returns:
            (loss_ar, psi): the list of loss values, one per step, and the state
            computed in the last step (None if N_epochs is 0).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, betas=(0.9, 0.999))
        loss_ar = []
        psi = None

        # Use method "int10tobase_torch" to get the batch of input bit strings
        # Write your code here (replace the placeholder below):
        v_in = self._todo("evolution: build v_in with self.int10tobase_torch")

        for i in tnrange(self.N_epochs):
            # Using v_in define psi
            # Write your code here (replace the placeholder below):
            psi = self._todo("evolution: compute psi from v_in")

            # Set the gradients to zero
            optimizer.zero_grad()

            # Calculate current loss
            # Write your code here (replace the placeholder below):
            loss_cur = self._todo("evolution: compute the current loss for psi")

            # Use method .backward() to compute gradients
            # Write your code here (replace the placeholder below):
            self._todo("evolution: call .backward() on the current loss")

            optimizer.step()
            loss_ar.append(loss_cur.detach().cpu().numpy())
            # Refresh the loss plot every 5 steps.
            if (i + 1) % 5 == 0:
                clear_output()
                plt.plot(loss_ar)
                plt.show()
        return loss_ar, psi