In [None]:
import logging
import sys
sys.path.append('/your_path/QIAL')

import torch
from torch import nn, optim

from layer.GENERAL_RAM import GENERAL_RAM
from layer.MCNN3x2_RAM import MCNN3x2_RAM

from utils.info import *

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

from sklearn.preprocessing import normalize

from itertools import product
from utils.dataManager import *

from skimage.metrics import structural_similarity as ssim
from skimage import io

import warnings

import copy
import concurrent.futures

In [None]:
parser = set_params()

# Command-line style overrides for this run; edit this list to change the experiment.
experiment_argv = [
    '--model', 'GENERAL_RAM',
    '--seed', '500',
    '--strategy', 'ALL',
    '--data', 'FashionMNIST',
    '--N_WAY', '4',
    '--N_TRAIN', '100',
    '--N_VALIDATE', '20',
    '--N_TEST', '100',
    '--N_SHOT', '30',
    '--N_ACTIVE', '30',
    '--N_ACTIVE_TIMES', '0',
    '--classes', '0', '1', '2', '3',
    '--num_layers', '4',
    '--num_qubits', '8',
    '--learning_rate', '1e-1',
]
args = parser.parse_args(args=experiment_argv)

# setup_seed is a project helper (wildcard-imported); presumably it seeds the RNGs
# used by the sampling/initialisation cells below — confirm in utils.
setup_seed(args)

In [None]:
images, labels = sample_data_qac(args)

# Flatten each sampled image into one row so every split is an (n_samples, n_pixels) matrix.
images_train = torch.cat([data.reshape(1, -1) for data in images['train']], dim=0)
images_valid = torch.cat([data.reshape(1, -1) for data in images['validate']], dim=0)
images_test = torch.cat([data.reshape(1, -1) for data in images['test']], dim=0)

labels_train = torch.stack(labels['train'])
labels_valid = torch.stack(labels['validate'])
labels_test = torch.stack(labels['test'])

plot_images(images['train'], "query images", images_per_row=args.N_TRAIN)

print(images_train.shape)
print(images_valid.shape)
print(images_test.shape)


def _prepare_split(flat_images, split_labels, n_per_class, n_way):
    """Build normalised inputs and one-hot targets for one data split.

    Args:
        flat_images: 2-D tensor with one flattened image per row.
        split_labels: class index per row of ``flat_images`` (any numeric dtype;
            values are truncated to int like ``int(label)``).
        n_per_class: samples per class in this split; ``n_per_class * n_way`` must
            equal the number of rows (``.view`` raises otherwise).
        n_way: number of classes, i.e. the width of the one-hot matrix.

    Returns:
        ``(X, label, Y)``: rows L2-normalised with sklearn ``normalize``, a detached
        copy of the labels, and a float32 one-hot matrix of shape ``(len(label), n_way)``.
    """
    rows = flat_images.view(n_per_class * n_way, -1).detach().cpu().numpy()
    X = torch.tensor(normalize(rows, norm='l2'))
    # clone().detach() is PyTorch's recommended copy; torch.tensor(tensor) emits a warning.
    label = split_labels.clone().detach()
    Y = torch.zeros([label.shape[0], n_way])
    # Vectorised equivalent of `for i: Y[i, int(label[i])] = 1`.
    Y[torch.arange(label.shape[0]), label.long().reshape(-1)] = 1
    return X, label, Y


X_train, label_train, Y_train = _prepare_split(images_train, labels_train, args.N_TRAIN, args.N_WAY)
X_valid, label_valid, Y_valid = _prepare_split(images_valid, labels_valid, args.N_VALIDATE, args.N_WAY)
X_test, label_test, Y_test = _prepare_split(images_test, labels_test, args.N_TEST, args.N_WAY)

In [None]:
from skimage import io, img_as_float

def SSIM_compute(images, args, data_range=2.0):
    """Compute the pairwise SSIM matrix of the training images.

    Args:
        images: mapping whose ``'train'`` entry is a sequence of images, each
            reshapeable to ``(args.N_SIZE, args.N_SIZE)``.
        args: experiment namespace; only ``args.N_SIZE`` is read.
        data_range: value range passed to ``structural_similarity``. The default 2.0
            is what older scikit-image inferred implicitly for float images (dtype
            range [-1, 1]); newer releases require it explicitly for float input.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(N, N)`` where ``mat[i, j]`` is the SSIM
        between training images ``i`` and ``j``.
    """
    train_images = images['train']
    N = len(train_images)

    # Convert/reshape each image once rather than twice per pair inside the loop.
    grays = [img_as_float(img).reshape(args.N_SIZE, args.N_SIZE) for img in train_images]

    mat = torch.empty(N, N)
    for i in range(N):
        # SSIM is symmetric in its two images: compute the upper triangle (diagonal
        # included) and mirror it, roughly halving the number of SSIM evaluations.
        # `multichannel` is omitted: single-channel is the default and the keyword
        # is deprecated/removed in recent scikit-image.
        for j in range(i, N):
            mat[i, j] = float(ssim(grays[i], grays[j], data_range=data_range))
            mat[j, i] = mat[i, j]

    return mat

AdjMtx = SSIM_compute(images, args)
# For each training image, count the entries in its SSIM row above 0.7
# (the diagonal entry, the image compared with itself, is part of the count).
topo_list = (AdjMtx > 0.7).sum(dim=1)

In [None]:
def semi_loss(predictions_FSL, Y_FSL):
    """Sum of squared residuals between predictions and targets.

    Computed as trace(Rᵀ R) with R = predictions_FSL - Y_FSL, i.e. the squared
    Frobenius norm of the residual. ``torch.trace`` requires Rᵀ R to be 2-D, so the
    inputs are expected to be 2-D (samples × outputs).

    Returns:
        A 0-d tensor that autograd can differentiate.
    """
    residual = predictions_FSL - Y_FSL
    gram = residual.T @ residual
    return gram.trace()

def semi_loss_qcnn(predictions_FSL, predictions_CL1, predictions_CL2, Y_FSL, alpha):
    """Blend a contrastive loss with a squared classification-error term.

    Args:
        predictions_FSL: model outputs for the labelled samples; only column 0 is
            compared with ``Y_FSL``.
        predictions_CL1, predictions_CL2: the two inputs given to the ``'Distance'``
            loss returned by the project helper ``CLoss_select``.
        Y_FSL: targets for the labelled samples; its first dimension is used as the
            sample count.
        alpha: weight of the contrastive term; ``1 - alpha`` weights the error term.

    Returns:
        ``alpha * contrastive_loss + (1 - alpha) * error``.
    """
    contrastive_loss = CLoss_select('Distance')(predictions_CL1, predictions_CL2)
    # NOTE(review): assumes predictions_FSL[:, 0] and Y_FSL have the same shape
    # (e.g. Y_FSL holds class indices of shape (N,)). If Y_FSL were one-hot
    # (N, N_WAY) the == would broadcast and the rate would be meaningless — confirm.
    error_rate = 1.0 - torch.sum(predictions_FSL[:, 0] == Y_FSL) / Y_FSL.shape[0]
    # error_rate is a 0-d tensor: torch.matmul/torch.trace reject 0-d input, and for a
    # scalar trace(diffᵀ diff) is simply diff ** 2.
    error = error_rate ** 2
    # The == comparison is not differentiable, so gradients reach the model only
    # through contrastive_loss.
    loss = alpha * contrastive_loss + (1 - alpha) * error
    return loss

def cost(qnet, weights, X_FSL_FRQI, Y_FSL):
    """Evaluate the network once on the labelled batch and score it.

    Args:
        qnet: callable ``qnet(weights, inputs)`` returning predictions
            (``quantum_neural_network.qnode_qnn`` in this notebook).
        weights: trainable parameters forwarded to ``qnet``.
        X_FSL_FRQI: batch of encoded input states.
        Y_FSL: targets, broadcast-compatible with the predictions.

    Returns:
        ``(predictions, loss)`` with ``loss = semi_loss(predictions, Y_FSL)``.
    """
    # Call qnet exactly once and reuse the result for both the loss and the returned
    # predictions; each call re-runs the whole network on the batch.
    predictions = qnet(weights, X_FSL_FRQI)
    loss = semi_loss(predictions, Y_FSL)
    return predictions, loss

from sklearn.metrics import accuracy_score, normalized_mutual_info_score
import numpy as np
import threading
    
def active_learning(weights_init, args, quantum_neural_network, state_fsl, y_fsl, state_query, y_query, X_test_state, Y_test, AdjMtx, n_epochs=100):
    """Train the classifier and compute Shannon and von Neumann query entropies.

    For each of ``args.N_ACTIVE_TIMES + 1`` rounds:
    1. Train a fresh copy of ``weights_init`` with Adam on the labelled set.
    2. Score the query pool via ``return_entropy`` with strategy ``'ENTRO'``
       (Shannon) and ``'QUANTUM'`` (von Neumann).
    3. Evaluate on the test set and log train/test accuracy.

    Args:
        weights_init: initial circuit parameters; not modified by this function.
        args: experiment namespace. ``args.strategy`` is temporarily switched for the
            ``return_entropy`` calls and always restored before returning.
        quantum_neural_network: object exposing ``qnode_qnn(weights, states)``.
        state_fsl, y_fsl: labelled training states and targets.
        state_query, y_query: query pool states and targets.
        X_test_state, Y_test: test states and targets.
        AdjMtx: similarity matrix forwarded to ``return_entropy``.
        n_epochs: optimisation steps per round (100 was the previous fixed value).

    Returns:
        ``(accs, weight_history, shannon_entropy, von_entropy)``:
        - ``accs``: per-round test accuracies.
        - ``weight_history``: detached weight snapshots (the initial weights, then one
          per round).
        - ``shannon_entropy``, ``von_entropy``: the entropies from the last round
          (``None`` if no round ran).
    """
    logger = get_logger(str(args.strategy)+'_'+str(args.data) + '_' + str(args.data_method) + '_' + str(args.N_SHOT) + '_' + str(args.num_layers) + '.log')
    accs = []
    # Snapshots, not live tensors: opt.step() updates parameters in place, so storing
    # the tensor itself would make every history entry alias the final weights.
    weight_history = [weights_init.detach().clone()]
    shannon_entropy = von_entropy = None

    for it_ac in range(args.N_ACTIVE_TIMES + 1):
        # Fresh leaf tensor per round. Plain assignment would alias weights_init, so
        # Adam would mutate the caller's tensor and later rounds would not restart.
        weights = weights_init.detach().clone().requires_grad_(True)
        opt = torch.optim.Adam([weights], lr=args.learning_rate)
        train_res = None
        for _ in range(n_epochs):
            opt.zero_grad()
            train_res, loss = cost(quantum_neural_network.qnode_qnn, weights, state_fsl, y_fsl)
            loss.backward()
            opt.step()
        # Accuracy of the last step's predictions (taken before that step's update).
        train_acc = accuracy(train_res, y_fsl) if train_res is not None else float('nan')

        # return_entropy picks its scoring rule from args.strategy; switch it
        # temporarily and always restore the caller's value.
        # NOTE(review): the updated labelled/query sets returned here are discarded,
        # so later rounds train on the same data — confirm this is intended.
        original_strategy = args.strategy
        try:
            args.strategy = 'ENTRO'
            _, _, _, _, _, shannon_entropy = return_entropy(quantum_neural_network, weights, state_fsl, y_fsl, state_query, y_query, AdjMtx, args)
            args.strategy = 'QUANTUM'
            _, _, _, _, _, von_entropy = return_entropy(quantum_neural_network, weights, state_fsl, y_fsl, state_query, y_query, AdjMtx, args)
        finally:
            args.strategy = original_strategy

        # Evaluation only; no autograd graph is needed.
        with torch.no_grad():
            res = quantum_neural_network.qnode_qnn(weights, X_test_state)
        test_acc = accuracy(res, Y_test)

        logger.info('{}\t Epoch_AC:[{}/{}]\t TRAIN ACCURACY={:.6f}\t  TEST ACCURACY={:.6f}'.format(args.strategy, it_ac, args.N_ACTIVE_TIMES, train_acc, test_acc))

        accs.append(test_acc)
        weight_history.append(weights.detach().clone())

    return accs, weight_history, shannon_entropy, von_entropy

In [None]:
# Zero-valued second argument for qnode_amplitude. It is built as 0.0 * randn rather
# than torch.zeros on purpose: randn consumes draws from the seeded global RNG, and
# removing that would change the weight initialisation below.
zeros = 0.0 * torch.randn(X_train.shape[1], requires_grad=False)

if args.model == 'GENERAL_RAM':
    weights_init = torch.normal(mean=0.0, std=1, size=(args.num_layers, args.num_qubits, 3), requires_grad=True)
    quantum_neural_network = GENERAL_RAM(args.num_qubits)
else:
    # Fail fast; otherwise quantum_neural_network is undefined and the next line
    # raises a confusing NameError.
    raise ValueError(f"Unsupported model {args.model!r}; this notebook only sets up 'GENERAL_RAM'.")

# Encode every sample row into a state, preserving the original row order.
X_train_state = torch.stack([quantum_neural_network.qnode_amplitude(row, zeros) for row in X_train])
X_test_state = torch.stack([quantum_neural_network.qnode_amplitude(row, zeros) for row in X_test])

images_fsl, labels_fsl, state_fsl, y_fsl, images_query, labels_query, state_query, y_query, AdjMtx_ = split_fsl_query_random(X_train_state, Y_train, images_train, labels_train, AdjMtx, args)

In [None]:
# Run on a private copy of the configuration tagged with the QUANTUM strategy.
args_quantum = copy.deepcopy(args)
# Set the strategy on the copy that is actually passed to active_learning. Assigning
# args.strategy mutated the global config and left args_quantum with the old strategy.
args_quantum.strategy = 'QUANTUM'
accs_quantum, weight_history_quantum, shannon_entropy, von_entropy = active_learning(weights_init, args_quantum, quantum_neural_network, state_fsl, y_fsl, state_query, y_query, X_test_state, Y_test, AdjMtx_)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Number of query samples shown; both entropy vectors are truncated to this length.
N_PLOT = 100

# detach().cpu() so conversion works whether or not the tensors carry autograd history.
shannon_entropy_np = shannon_entropy.detach().cpu().numpy()[:N_PLOT]
von_entropy_np = von_entropy.detach().cpu().numpy()[:N_PLOT]

bar_width = 0.35

# Bar centres: the von Neumann series is shifted one bar width so the bars sit side by side.
shannon_x = np.arange(len(shannon_entropy_np))
von_x = np.arange(len(von_entropy_np)) + bar_width

fig, ax = plt.subplots()

ax.bar(shannon_x, shannon_entropy_np, bar_width, label='Shannon', alpha=0.6)
ax.bar(von_x, von_entropy_np, bar_width, label='Von-Neumann', alpha=0.6)

# ax.bar centres each bar on its x (align='center' is the default), so the trend lines
# use the same x positions. An extra bar_width/2 shift would draw them between bars.
ax.plot(shannon_x, shannon_entropy_np, 'o-', label='Trend of Shannon')
ax.plot(von_x, von_entropy_np, 's-', label='Trend of Von-Neumann')

ax.set_xlabel('Samples')
ax.set_ylabel('Value')

ax.set_title(args.data)

ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2)

plt.show()
