In [None]:
#Compatible versions
!pip install 'jax == 0.4.3' 
!pip install 'jaxlib == 0.4.3'
!pip install 'pennylane == 0.29.1' #latest stable version as of 7th March 2023

In [None]:
import matplotlib.pyplot as plt 
from tqdm.auto import tqdm 
from joblib import Parallel, delayed
import jax.numpy as jnp 
import pennylane as qml 
import numpy as np 
import time
from datetime import datetime 
from io import StringIO 
import pandas as pd 
import os 
import jax 
from jax import vmap, grad, jit 
from jax import random 
from jax.scipy.special import logsumexp
from jax.config import config 
config.update("jax_enable_x64", True) 

In [None]:
#Imports from repo 
from data.datareader import datareader 
from data.datahandler import datahandler 
from .qconfig import *

In [None]:
#TODO: See if this can be improved - e.g. embedded in class
#def unwrap_self(image, c, j, i, qubits, ksize, nlayers, circuit):. 
  #return QCNN.qconv2D2(image, c, j, i, qubits, ksize, filters, nlayers, seed, circuit)


Build a neural network (fully connected for now) using only JAX (and `jit`) inside the QCNN class, for the training phase:

In [None]:
class QCNN:
  '''Quantum Convolutional NN class.

  This class implements:
    - 1 quantum convolutional (quanvolutional) layer - default circuit: 'ry'
    - n classical convolutional layers (hyper-parameters are read from the config;
      the layers themselves are not implemented in this class yet)
    - n classical dense layers, written in plain JAX (see `predict`, `update`, `train`).

  Input images are (height, width, channels) arrays, meant to be RGB data.
  The quanvolutional layer extracts feature maps obtained by means of quantum processing.

  Note: `self.loss` (set in __init__ from the config) is an instance attribute and therefore
  shadows the `loss` static method on instances; the training code always calls `QCNN.loss`.
  '''

  def __init__(self, qubits, filters, kernel_size, stride, img_shape, n_classes, circuits='ry', parallelize=0, nlayers=1, seed=0, name=None):
    # Quantum (quanvolutional) layer parameters - used for image processing.
    self.qubits = qubits
    self.filters = filters  # number of circuit outputs kept per patch (output channels)
    self.kernel_size = kernel_size
    self.stride = stride
    self.img_shape = img_shape
    self.n_classes = n_classes
    self.circuits = circuits  # 'ry' | 'rx' | 'rz'; anything else selects rxyz_custom
    self.parallelize = parallelize
    self.nlayers = nlayers
    self.seed = seed
    # Default the model name (the string 'None' is accepted as well, for backward compatibility).
    self.name = 'QCNN' if name is None or name == 'None' else name

    # Training parameters from the config file (conv and/or dense layers).
    # `qcnnv2s` is expected to come from the `qconfig` star import.
    # NOTE(review): `from .qconfig import *` is a relative import, which fails inside a notebook
    # kernel (no parent package) - confirm how qconfig is meant to be imported.
    self.loss = qcnnv2s['loss']
    self.metrics = qcnnv2s['metrics']
    self.learning_rate = qcnnv2s['learning_rate']
    self.dropout = qcnnv2s['dropout']
    self.batch_size = qcnnv2s['batch_size']
    self.epochs = qcnnv2s['epochs']
    self.es_rounds = qcnnv2s['early_stopping']
    self.dense = qcnnv2s['dense']  # vector of n neurons for each dense layer
    self.conv = qcnnv2s['conv']  # vector of n filters for each convolutional layer
    self.convolutional_kernel_size = qcnnv2s['kernel']
    self.convolution_stride = qcnnv2s['stride']
    self.avg_pool_size = qcnnv2s['pool_size']
    self.avg_pool_stride = qcnnv2s['pool_stride']

  def apply(self, image, verbose=True):
    '''Run the quanvolutional layer on a single image.

    Args:
      image: array of shape (height, width, channels).
      verbose: show tqdm progress bars.

    Returns:
      np.ndarray of shape (h_out, w_out, filters), with
      h_out = (height - kernel_size) // stride + 1 (likewise for the width).

    Raises:
      NotImplementedError: if parallelize != 0 (parallel execution is not implemented yet).
    '''
    if self.parallelize != 0:
      # Without this guard the result list stayed empty and the reshape below crashed obscurely.
      raise NotImplementedError('parallelize != 0 is not implemented yet; use parallelize=0')
    results = [self.quanvolutional_layer(image, verbose)]

    # Stack the per-run outputs on a trailing axis, then merge that axis with the filter axis.
    results = np.moveaxis(results, 0, -1)
    s = np.shape(results)
    return np.reshape(results, (s[0], s[1], s[-2] * s[-1]))

  def _select_circuit(self):
    '''Return the circuit function selected by `self.circuits` (rxyz_custom is the fallback).'''
    if self.circuits == 'ry':
      return ry_random
    if self.circuits == 'rx':
      return rx_random
    if self.circuits == 'rz':
      return rz_random
    return rxyz_custom

  def quanvolutional_layer(self, image, verbose):
    '''Compute quantum feature maps for one image (non-parallelized).

    Every kernel_size x kernel_size patch of each channel (windows taken every `stride`
    pixels) is flattened and fed to the circuit selected by `self.circuits`; the first
    `self.filters` circuit outputs become the output channels for that position.
    The resulting maps are averaged over the input channels.

    Args:
      image: array of shape (height, width, channels).
      verbose: show tqdm progress bars.

    Returns:
      np.ndarray of shape (h_out, w_out, filters).

    Raises:
      ValueError: if the kernel is larger than the image.
    '''
    h, w, ch = image.shape
    k, s = self.kernel_size, self.stride
    if k > h or k > w:
      raise ValueError(f'kernel_size={k} is larger than the image ({h}x{w})')
    h_out = (h - k) // s + 1
    w_out = (w - k) // s + 1
    out = np.zeros((h_out, w_out, ch, self.filters))
    circuit = self._select_circuit()  # loop-invariant

    quiet = not verbose
    for c in tqdm(range(ch), desc='Channel', disable=quiet, colour='black'):
      # Iterate over output positions so all h_out x w_out cells are filled
      # (the last valid window starts at (h_out - 1) * stride).
      for row in tqdm(range(h_out), desc='Row', leave=False, disable=quiet, colour='black'):
        for col in tqdm(range(w_out), desc='Column', leave=False, disable=quiet, colour='black'):
          j, i = row * s, col * s
          patch = image[j:j + k, i:i + k, c]
          q_results = np.array(circuit(jnp.array(patch.reshape(-1)), self.qubits, self.kernel_size, self.filters, self.nlayers, self.seed))
          out[row, col, c, :] = q_results[:self.filters]

    # Average the feature maps over the input channels.
    return np.mean(out, -2, keepdims=False)

  # ---- Fully connected head, plain JAX (pure functions, hence static) ----

  @staticmethod
  def random_layer_params(m, n, key, scale=1e-2):
    '''Randomly initialise one dense layer: weights of shape (n, m) and biases of shape (n,).'''
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  @staticmethod
  def init_network_params(sizes, key):
    '''Initialise [(w, b), ...] for consecutive layer widths `sizes` (input first, output last).'''
    keys = random.split(key, len(sizes))
    return [QCNN.random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

  @staticmethod
  def relu(x):
    '''Element-wise max(0, x).'''
    return jnp.maximum(0, x)

  @staticmethod
  def predict(params, image):
    '''Logits for one flattened sample: ReLU hidden layers followed by a linear output layer.'''
    activations = image
    for w, b in params[:-1]:
      outputs = jnp.dot(w, activations) + b
      activations = QCNN.relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

  @staticmethod
  def batched_predict(params, images):
    '''Logits for a batch of flattened samples (axis 0 of `images`); params are shared.'''
    return vmap(QCNN.predict, in_axes=(None, 0))(params, images)

  @staticmethod
  def one_hot(x, k, dtype=jnp.float32):
    '''One-hot encode the 1-D integer array `x` into shape (len(x), k).'''
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

  @staticmethod
  def accuracy(params, images, targets):
    '''Fraction of samples whose arg-max logit matches the one-hot `targets`.'''
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(QCNN.batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

  @staticmethod
  def loss(params, images, targets):
    '''Mean cross-entropy (negative log-likelihood) of one-hot `targets` under softmax(logits).'''
    logits = QCNN.batched_predict(params, images)
    # log-softmax; the minus sign below makes this a quantity to *minimise*.
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(log_probs * targets, axis=1))

  @staticmethod
  @jit
  def update(params, x, y, step_size=1e-2):
    '''One plain SGD step on `QCNN.loss`; returns the new [(w, b), ...].'''
    grads = grad(QCNN.loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

  def get_train_batches(self, images=None, labels=None):
    '''Yield (images, labels) minibatches of `self.batch_size` samples, in order.

    TODO: get batches as feature maps (i.e. run `apply` on the raw images first).

    Raises (on first iteration):
      NotImplementedError: when no data is given.
    '''
    if images is None or labels is None:
      raise NotImplementedError('TODO: get batches as feature maps')
    for start in range(0, len(images), self.batch_size):
      stop = start + self.batch_size
      yield images[start:stop], labels[start:stop]

  def train(self, train_images, train_labels, test_images, test_labels, params=None):
    '''Train the dense head with minibatch SGD and print accuracies after every epoch.

    Args:
      train_images, test_images: per-sample inputs (e.g. feature maps from `apply`);
        each sample is flattened before entering the dense layers.
      train_labels, test_labels: integer class labels in [0, n_classes).
      params: optional initial [(w, b), ...]; built from `self.dense` and `self.seed` when None.

    Returns:
      The trained params.
    '''
    train_x = jnp.reshape(jnp.asarray(train_images), (len(train_images), -1))
    test_x = jnp.reshape(jnp.asarray(test_images), (len(test_images), -1))
    train_labels = jnp.asarray(train_labels)
    train_y = QCNN.one_hot(train_labels, self.n_classes)
    test_y = QCNN.one_hot(jnp.asarray(test_labels), self.n_classes)

    if params is None:
      # assumes self.dense is a sequence of hidden-layer widths (per the config comment) - TODO confirm
      sizes = [train_x.shape[1], *self.dense, self.n_classes]
      params = QCNN.init_network_params(sizes, random.PRNGKey(self.seed))

    for epoch in range(self.epochs):
      start_time = time.time()
      for x, y in self.get_train_batches(train_x, train_labels):
        params = QCNN.update(params, x, QCNN.one_hot(y, self.n_classes), self.learning_rate)
      epoch_time = time.time() - start_time
      train_acc = QCNN.accuracy(params, train_x, train_y)
      test_acc = QCNN.accuracy(params, test_x, test_y)
      print("epoch {} in {:0.2f} sec".format(epoch, epoch_time))
      print("training set accuracy {}".format(train_acc))
      print("test set accuracy {}".format(test_acc))
    return params



    
