In [None]:
import logging
import sys
sys.path.append('/your_path/QIAL')

import torch
from torch import nn, optim

from layer.GENERAL_RAM import GENERAL_RAM

from utils.info import *

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

from sklearn.preprocessing import normalize

from itertools import product
from utils.dataManager import *

from skimage.metrics import structural_similarity as ssim
from skimage import io

import warnings

from sklearn.metrics import accuracy_score, normalized_mutual_info_score
import time
import numpy as np
import threading

from skimage import io, img_as_float

import concurrent.futures
import copy
import pickle

import matplotlib.pyplot as plt
import numpy as np

import concurrent.futures

In [None]:
parser = set_params()

# Parse an explicit argument list only. Inside a Jupyter kernel sys.argv holds the
# kernel's own flags (e.g. `-f <connection file>`), so a bare parse_args() would
# abort with "unrecognized arguments" before this configuration is ever applied.
args = parser.parse_args(args=['--model', 'GENERAL_RAM',
                               '--seed','500',
                               '--strategy','ALL',
                               '--data','FashionMNIST',                    
                               '--N_WAY','4',
                               '--N_TRAIN','100',
                               '--N_VALIDATE','20',
                               '--N_TEST','100',
                               '--N_SHOT','30',
                               '--N_ACTIVE','30',
                               '--N_ACTIVE_TIMES','30',
                               '--classes','0','1','2','3',
                               '--num_layers','4',
                               '--num_qubits','8',
                               '--learning_rate','1e-1'])

# Seed all RNGs before any data is sampled or any weights are drawn.
setup_seed(args)

In [None]:
images, labels = sample_data_qac(args)

# Flatten every sample into one row: a (num_samples, num_pixels) matrix per split.
images_train = torch.cat([data.reshape(1, -1) for data in images['train']], dim=0)
images_valid = torch.cat([data.reshape(1, -1) for data in images['validate']], dim=0)
images_test = torch.cat([data.reshape(1, -1) for data in images['test']], dim=0)

labels_train = torch.stack(labels['train'])
labels_valid = torch.stack(labels['validate'])
labels_test = torch.stack(labels['test'])

plot_images(images['train'], "query images", images_per_row=args.N_TRAIN)


def _features_and_targets(flat_images, label_tensor, n_rows, n_way):
    """Build L2-normalised feature rows, a label copy and one-hot targets for one split.

    Args:
        flat_images: 2-D tensor with one flattened sample per row.
        label_tensor: per-sample class indices; each must be usable as an
            index in ``[0, n_way)``.
        n_rows: expected number of samples (samples-per-class * N_WAY).
        n_way: number of classes, i.e. the width of the one-hot targets.

    Returns:
        (features, labels, targets): ``features`` has unit-L2-norm rows
        (sklearn ``normalize``); ``labels`` is a detached copy of
        ``label_tensor``; ``targets`` is a (num_labels, n_way) one-hot tensor
        in the default float dtype.

    Raises:
        ValueError: if ``flat_images`` does not contain exactly ``n_rows`` rows
            (``view`` alone would silently mis-reshape a divisible mismatch).
    """
    if flat_images.shape[0] != n_rows:
        raise ValueError('expected {} samples, got {}'.format(n_rows, flat_images.shape[0]))
    rows = flat_images.view(n_rows, -1).detach().cpu().numpy()
    features = torch.tensor(normalize(rows, norm='l2'))
    # detach().clone() copies without the warning torch.tensor(<tensor>) emits.
    label_copy = label_tensor.detach().clone()
    targets = torch.zeros(label_copy.shape[0], n_way)
    for row, cls in enumerate(label_copy):
        targets[row, int(cls)] = 1
    return features, label_copy, targets


X_train, label_train, Y_train = _features_and_targets(
    images_train, labels_train, args.N_TRAIN * args.N_WAY, args.N_WAY)
X_test, label_test, Y_test = _features_and_targets(
    images_test, labels_test, args.N_TEST * args.N_WAY, args.N_WAY)

In [None]:
def SSIM_compute(images, args):
    """Pairwise SSIM matrix over the training images.

    Args:
        images: dict whose ``'train'`` entry is a sequence of images; each one
            is converted with ``img_as_float`` and reshaped to
            ``(args.N_SIZE, args.N_SIZE)``.
        args: namespace providing ``N_SIZE``.

    Returns:
        torch.Tensor of shape (N, N) in the default float dtype, with
        ``mat[i, j] = SSIM(train[i], train[j])``.
    """
    train_images = images['train']
    N = len(train_images)
    # Convert each image once instead of once per pair.
    prepared = [img_as_float(img).reshape(args.N_SIZE, args.N_SIZE) for img in train_images]

    mat = torch.zeros(N, N)
    for i in range(N):
        # SSIM is symmetric in its two arguments: fill the upper triangle and mirror it.
        for j in range(i, N):
            # Only the mean SSIM is needed, so the full per-pixel map (full=True) is not requested.
            ssim_value = ssim(prepared[i], prepared[j])
            mat[i, j] = mat[j, i] = float(ssim_value)

    return mat

# Pairs with SSIM above this threshold count as neighbours.
SSIM_THRESHOLD = 0.7

AdjMtx = SSIM_compute(images, args)
# Number of neighbours of each training image (diagonal included).
topo_list = torch.sum(AdjMtx > SSIM_THRESHOLD, dim=1)

In [None]:
def semi_loss(predictions_FSL, Y_FSL):
    """Sum of squared residuals between predictions and targets.

    Computed as the trace of the Gram matrix ``R^T R``, where
    ``R = predictions_FSL - Y_FSL``; for a 2-D residual this equals the
    squared Frobenius norm of ``R``.
    """
    residual = predictions_FSL - Y_FSL
    gram = residual.T @ residual
    return gram.trace()

def semi_loss_qcnn(predictions_FSL, predictions_CL1, predictions_CL2, Y_FSL, alpha):
    """Blend a contrastive term with a squared classification-error term.

    loss = alpha * contrastive + (1 - alpha) * (1 - acc) ** 2, where ``acc`` is
    the fraction of rows for which ``predictions_FSL[:, 0] == Y_FSL``.

    NOTE(review): assumes ``predictions_FSL[:, 0]`` and ``Y_FSL`` are
    broadcast-compatible (e.g. both length-N label vectors) — TODO confirm
    against callers. The error term is built from an equality test, so it
    carries no gradient; only the contrastive term drives training.
    """
    contrastive_loss = CLoss_select('Distance')(predictions_CL1, predictions_CL2)
    n_samples = Y_FSL.shape[0]
    correct = torch.sum(predictions_FSL[:, 0] == Y_FSL)
    # Misclassification rate: a 0-d tensor, so it is squared directly
    # (matmul / .T are not defined for 0-d tensors).
    diff = 1.0 - correct / n_samples
    error = diff ** 2
    return alpha * contrastive_loss + (1 - alpha) * error

def cost(qnet, weights, X_FSL_FRQI, Y_FSL):
    """Run one forward pass of ``qnet`` and score it with ``semi_loss``.

    Returns:
        (pred, loss): the network output ``qnet(weights, X_FSL_FRQI)`` and
        ``semi_loss(pred, Y_FSL)``.
    """
    # Evaluate the (expensive) network once; the returned prediction is then
    # guaranteed to be the one the loss was computed on.
    pred = qnet(weights, X_FSL_FRQI)
    loss = semi_loss(pred, Y_FSL)
    return pred, loss
    
def active_learning(weights_init, args, quantum_neural_network, state_fsl, y_fsl, state_query, y_query, X_test_state, Y_test, AdjMtx, n_epochs=100):
    """Alternate training on the labelled set with ``add_sample`` rounds.

    Runs ``args.N_ACTIVE_TIMES + 1`` rounds. Each round trains for ``n_epochs``
    Adam steps on (state_fsl, y_fsl), then calls ``add_sample`` (which returns
    updated labelled/query sets and AdjMtx — presumably moving queried samples
    into the labelled set; verify against its definition), then evaluates on
    the test set.

    Args:
        weights_init: initial trainable tensor; it is copied, never modified.
        args: namespace; reads strategy, data, data_method, N_SHOT, num_layers,
            N_ACTIVE_TIMES and learning_rate.
        quantum_neural_network: object exposing ``qnode_qnn(weights, states)``.
        state_fsl, y_fsl: current labelled states and targets.
        state_query, y_query: current unlabelled pool.
        X_test_state, Y_test: held-out evaluation data.
        AdjMtx: adjacency matrix passed through ``add_sample``.
        n_epochs: optimizer steps per round (default 100).

    Returns:
        (accs, weight_history): the test accuracy after each round, and
        detached weight snapshots (initial weights followed by one per round).

    Raises:
        ValueError: if ``n_epochs`` < 1 (no training result would exist).
    """
    if n_epochs < 1:
        raise ValueError('n_epochs must be >= 1, got {}'.format(n_epochs))
    logger = get_logger(str(args.strategy)+'_'+str(args.data) + '_' + str(args.data_method) + '_' + str(args.N_SHOT) + '_' + str(args.num_layers) + '.log')
    # Train a private leaf copy: Adam updates in place, and training weights_init
    # directly would leak trained values into the caller's tensor.
    weights = weights_init.detach().clone().requires_grad_(True)
    accs = []
    # Store detached snapshots; appending the live tensor would make every
    # entry alias the same (final) weights.
    weight_history = [weights.detach().clone()]

    for it_ac in range(args.N_ACTIVE_TIMES + 1):
        # Weights carry over between rounds (warm start); only Adam's state is reset.
        opt = torch.optim.Adam([weights], lr=args.learning_rate)
        for _ in range(n_epochs):
            opt.zero_grad()
            train_res, loss = cost(quantum_neural_network.qnode_qnn, weights, state_fsl, y_fsl)
            loss.backward()
            opt.step()
        # Accuracy of the last forward pass (taken before the final optimizer step).
        train_acc = accuracy(train_res, y_fsl)

        state_fsl, y_fsl, state_query, y_query, AdjMtx = add_sample(quantum_neural_network, weights, state_fsl, y_fsl, state_query, y_query, AdjMtx, args)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            res = quantum_neural_network.qnode_qnn(weights, X_test_state)
        test_acc = accuracy(res, Y_test)

        logger.info('{}\t Epoch_AC:[{}/{}]\t TRAIN ACCURACY={:.6f}\t  TEST ACCURACY={:.6f}'.format(args.strategy, it_ac, args.N_ACTIVE_TIMES, train_acc, test_acc))

        # Column sums of the labelled targets (per-class counts if y_fsl is one-hot).
        print(torch.sum(y_fsl, dim=0))
        accs.append(test_acc)
        weight_history.append(weights.detach().clone())

    return accs, weight_history

In [None]:

# All-zero vectors passed as the second argument of qnode_amplitude (semantics
# are defined in GENERAL_RAM — TODO document). Both are produced from randn
# draws on purpose: they advance the global torch RNG, so keeping them (even
# the unused `noises`) keeps weights_init and later noise draws unchanged
# for a given seed.
zeros = 0.0 * torch.randn(X_train.shape[1], requires_grad=False)
noises = 0.0 * torch.randn(X_train.shape[1], requires_grad=False)

if args.model == 'GENERAL_RAM':
    weights_init = torch.normal(mean=0.0, std=1, size=(args.num_layers, args.num_qubits, 3), requires_grad=True)
    quantum_neural_network = GENERAL_RAM(args.num_qubits)
else:
    # Fail here rather than with a NameError on quantum_neural_network below.
    raise ValueError('Unsupported model: {}'.format(args.model))

# Encode each row of the normalised feature matrices as a quantum state.
X_train_state = torch.stack([quantum_neural_network.qnode_amplitude(x, zeros) for x in X_train])
X_test_state = torch.stack([quantum_neural_network.qnode_amplitude(x, zeros) for x in X_test])


In [None]:
def trace_distance(s1, s2):
    """Row-wise trace distance between pure states.

    Row i of ``s1`` and ``s2`` are treated as state vectors a, b, and the
    function returns T_i = 1/2 * || a a^dagger - b b^dagger ||_1.

    A difference of two rank-one projectors has at most two non-zero
    eigenvalues, and their absolute values sum to
    sqrt((|a|^2 + |b|^2)^2 - 4 |<a, b>|^2). This closed form replaces a
    per-row eigendecomposition (O(d) instead of O(d^3)). For normalised states
    it reduces to sqrt(1 - |<a, b>|^2).

    Args:
        s1, s2: 2-D array-likes (numpy arrays or CPU tensors without grad) of
            shape (n_rows, dim). Only the first ``s1.shape[0]`` rows of ``s2``
            are used.

    Returns:
        list of ``n_rows`` Python floats.

    Raises:
        ValueError: if the inputs are not 2-D or their row shapes do not match.
    """
    # complex128 avoids cancellation error for nearly identical states.
    a = np.asarray(s1, dtype=np.complex128)
    if a.ndim != 2:
        raise ValueError('expected 2-D inputs, got shape {}'.format(a.shape))
    n_rows = a.shape[0]
    b = np.asarray(s2, dtype=np.complex128)[:n_rows]
    if a.shape != b.shape:
        raise ValueError('shape mismatch: {} vs {}'.format(a.shape, b.shape))

    norm_a = np.real(np.sum(np.conj(a) * a, axis=1))
    norm_b = np.real(np.sum(np.conj(b) * b, axis=1))
    overlap_sq = np.abs(np.sum(np.conj(a) * b, axis=1)) ** 2
    # Clip tiny negative values produced by rounding (e.g. identical states).
    disc = np.clip((norm_a + norm_b) ** 2 - 4.0 * overlap_sq, 0.0, None)
    return (0.5 * np.sqrt(disc)).tolist()
    

# Results keyed by noise weight w.
dic_dists = {}
# NOTE(review): dic_accs and dic_weights are never filled in this notebook, but
# the accuracy-plot and save cells read them — TODO add the per-noise
# active_learning runs.
dic_accs = {}
dic_X_train_state = {}
dic_weights = {}

noise_weight = [0, 0.005, 0.01, 0.015, 0.02]
for w in noise_weight:
    # A single noise vector per level, added to every training sample.
    noise = w * torch.randn(X_train.shape[1], requires_grad=False)
    X_train_state_noise = torch.stack([
        quantum_neural_network.qnode_amplitude(row + noise, zeros)
        for row in X_train
    ])

    # Mean trace distance between clean and noisy encodings.
    dists = trace_distance(X_train_state, X_train_state_noise)
    dic_dists[w] = sum(dists) / len(dists)

In [None]:
import math
import os

# Mean trace distance per noise level (dic_dists is keyed by the noise weight w).
categories = list(dic_dists.keys())
values = list(dic_dists.values())

colors = ['red', 'blue', 'green', 'black', 'orange']  # one colour per noise level

# Width in data units of w; must stay below the 0.005 spacing of noise_weight.
bar_width = 0.003

fig, ax = plt.subplots()
ax.bar(categories, values, width=bar_width, alpha=0.6, color=colors)

# Annotate each bar with its value just above the bar top.
for x_pos, height in zip(categories, values):
    ax.text(x=x_pos, y=height + 0.005, s=f"{height:.4f}", ha='center')

# Dashed reference line at 1/e.
ax.axhline(y=1 / math.e, color='r', linewidth=1, linestyle='--')

ax.set_title(args.data)
ax.set_xlabel('w')
ax.set_ylabel('Trace Distance')

bar_path = r'../Figure/QIAL_'+'all_'+str(args.N_ACTIVE)+'-'+str(args.seed)+'-'+str(args.num_layers)+'_noise_bar.pdf'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(bar_path), exist_ok=True)
fig.savefig(bar_path, format='pdf', dpi=300, bbox_inches='tight')


In [None]:
import os

# Active-learning round index: 0 .. N_ACTIVE_TIMES (one test accuracy per round).
nums = np.arange(args.N_ACTIVE_TIMES + 1)

# Fail with an actionable message instead of a bare KeyError.
missing = [w for w in noise_weight if w not in dic_accs]
if missing:
    raise KeyError('dic_accs has no accuracy history for noise weights {}; '
                   'run active_learning for each noise level first'.format(missing))

colors = ['red', 'blue', 'green', 'black', 'orange']

fig, ax = plt.subplots()
for k, w in enumerate(noise_weight):
    # Legend labels render as '0.0', '0.005', ... via str(float(w)).
    ax.plot(nums, dic_accs[w], marker='o', linestyle='-', color=colors[k % len(colors)],
            markersize=2, linewidth=2, label=str(float(w)))

# Raw string: same text as before, without the invalid '\%' escape warning.
# NOTE(review): ylim 0.6-1.0 suggests fractional accuracies despite the "%" label — confirm.
ax.set_ylabel(r'Classification ACC ($\%$)', fontsize=16)
ax.set_xlabel('Active learning rounds', fontsize=16)
ax.set_title(args.data)

ax.set_ylim([0.6, 1.0])

legend = ax.legend()

curve_path = r'../Figure/QIAL_'+'all_'+str(args.N_ACTIVE)+'-'+str(args.seed)+'-'+str(args.num_layers)+'_noise.pdf'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(curve_path), exist_ok=True)
fig.savefig(curve_path, format='pdf', dpi=300, bbox_inches='tight')


In [None]:
import os

# Everything needed to reproduce the noise-robustness figures.
variables = {
    'dists_list':dic_dists,
    'weight_history_quantum': dic_weights,
    'X_train': X_train,
    'X_test': X_test,
    'X_train_state': X_train_state,
    'X_test_state': X_test_state,
    'Y_train': Y_train,
    'Y_test': Y_test,
    'accs_quantum': dic_accs,
    'args':args
} 


file_path = (
    '../results/' + args.data + '/' +
    args.data + '_all' + '_'+str(args.seed)+'_noise.pkl'
)

# open() does not create missing directories.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# NOTE: pickle files execute code on load — only unpickle files you produced yourself.
with open(file_path, 'wb') as file:
    pickle.dump(variables, file)