In [None]:
import logging
import os
import sys
sys.path.append('/your_path/QAIL')

import torch
from torch import nn, optim

from layer.GENERAL_RAM import GENERAL_RAM

from utils.info import *

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

from sklearn.preprocessing import normalize

from itertools import product
from utils.dataManager import *

from skimage.metrics import structural_similarity as ssim
from skimage import io

import warnings
warnings.filterwarnings('ignore', category=Warning)

In [None]:
# Experiment configuration, expressed as the command-line flags understood by
# the parser that `set_params()` returns.
cli_args = [
    '--model', 'GENERAL_RAM',
    '--seed', '500',
    '--strategy', 'ALL',
    '--data', 'FashionMNIST',
    '--N_WAY', '4',
    '--N_TRAIN', '100',
    '--N_VALIDATE', '20',
    '--N_TEST', '100',
    '--N_SHOT', '30',
    '--N_ACTIVE', '2',
    '--N_ACTIVE_TIMES', '30',
    '--classes', '0', '1', '2', '3',
    '--num_layers', '4',
    '--num_qubits', '8',
    '--learning_rate', '1e-1',
]

parser = set_params()
args = parser.parse_args(args=cli_args)

print(args)

# Seed the random number generators from the parsed arguments before any data
# sampling or weight initialisation happens.
setup_seed(args)

In [None]:
images, labels = sample_data_qac(args)

# Flatten every image to a single row so each split becomes one 2-D tensor.
images_train = torch.cat([data.reshape(1,-1) for data in images['train']], dim=0)
images_valid = torch.cat([data.reshape(1,-1) for data in images['validate']], dim=0)
images_test = torch.cat([data.reshape(1,-1) for data in images['test']], dim=0)

labels_train = torch.stack(labels['train'])
labels_valid = torch.stack(labels['validate'])
labels_test = torch.stack(labels['test'])

plot_images(images['train'], "query images", images_per_row=args.N_TRAIN)

print(images_train.shape)
print(images_valid.shape)
print(images_test.shape)


def _l2_normalized_rows(flat_images, n_rows):
    """Return a new tensor with `flat_images` viewed as `n_rows` rows, each scaled to unit L2 norm.

    The input is only detached, not copied: `sklearn.preprocessing.normalize`
    copies its input by default, so `flat_images` itself is left untouched.
    """
    rows = flat_images.detach().view(n_rows, -1)
    return torch.tensor(normalize(rows, norm='l2'))


def _one_hot_rows(label_vec, n_classes):
    """Return a float (len(label_vec), n_classes) matrix with a single 1 per row.

    Each entry of `label_vec` must hold exactly one value; it is truncated to
    an integer class index.
    """
    n_samples = label_vec.shape[0]
    encoded = torch.zeros(n_samples, n_classes)
    class_index = label_vec.reshape(n_samples).long()
    encoded[torch.arange(n_samples), class_index] = 1
    return encoded


# `clone().detach()` is the PyTorch-recommended way to copy an existing tensor
# (calling `torch.tensor()` on a tensor does the same but emits a warning).
X_train = _l2_normalized_rows(images_train, args.N_TRAIN*args.N_WAY)
label_train = labels_train.clone().detach()
Y_train = _one_hot_rows(label_train, args.N_WAY)

X_test = _l2_normalized_rows(images_test, args.N_TEST*args.N_WAY)
label_test = labels_test.clone().detach()
Y_test = _one_hot_rows(label_test, args.N_WAY)

In [None]:
from skimage import io, img_as_float

def SSIM_compute(images, args):
    """Build the pairwise SSIM matrix of the training images.

    Args:
        images: mapping whose 'train' entry is a sequence of images; each one
            is reshaped to (args.N_SIZE, args.N_SIZE).
        args: namespace providing `N_SIZE`, the image side length.

    Returns:
        torch.Tensor of shape (N, N), N = len(images['train']), where entry
        (i, j) is the mean structural similarity of training images i and j.
    """
    side = args.N_SIZE
    # Convert and reshape each image once instead of once per pair.
    frames = [img_as_float(img).reshape(side, side) for img in images['train']]
    N = len(frames)

    mat = torch.empty(N, N)
    for i in range(N):
        # SSIM is symmetric, so only the upper triangle is computed.
        for j in range(i, N):
            # The deprecated `multichannel=False` kwarg is dropped (it was the
            # default) and `full=True` is dropped because the per-pixel map was
            # never used.
            # data_range=2.0 is what scikit-image used to infer implicitly for
            # floating-point images; newer releases require it to be explicit.
            # NOTE(review): if pixel values lie in [0, 1], data_range=1.0 is the
            # documented choice, but it would change the SSIM values and the
            # downstream threshold - confirm before switching.
            value = float(ssim(frames[i], frames[j], data_range=2.0))
            mat[i, j] = value
            mat[j, i] = value

    return mat

# Pairwise SSIM between all training images, used as an adjacency matrix.
AdjMtx = SSIM_compute(images, args)
# For every row, count how many entries exceed the 0.7 similarity threshold.
topo_list = (AdjMtx > 0.7).sum(dim=1)

In [None]:
def semi_loss(predictions_FSL, Y_FSL):
    """Squared-error loss between predictions and targets.

    Computed as trace(R^T R) with R = predictions_FSL - Y_FSL, i.e. the sum of
    the squared entries of the residual matrix.

    Args:
        predictions_FSL: 2-D tensor of model outputs.
        Y_FSL: 2-D tensor of targets with the same shape.

    Returns:
        0-d tensor holding the loss.
    """
    residual = predictions_FSL - Y_FSL
    gram = residual.T @ residual
    return gram.trace()

def semi_loss_qcnn(predictions_FSL, predictions_CL1, predictions_CL2, Y_FSL, alpha):
    """Blend a contrastive loss with a classification-error term.

    Args:
        predictions_FSL: 2-D tensor; column 0 is compared element-wise with `Y_FSL`.
        predictions_CL1, predictions_CL2: the two inputs passed to the
            'Distance' loss obtained from `CLoss_select`.
        Y_FSL: targets compared against `predictions_FSL[:, 0]`.
        alpha: weight of the contrastive term; the error term gets `1 - alpha`.

    Returns:
        alpha * contrastive_loss + (1 - alpha) * error.
    """
    contrastive_loss = CLoss_select('Distance')(predictions_CL1, predictions_CL2)
    # Fraction of samples whose first prediction column does not equal the target.
    error_rate = 1.0 - torch.sum(predictions_FSL[:, 0] == Y_FSL) / Y_FSL.shape[0]
    # NOTE(review): the previous code referenced an undefined name `YSL` and
    # then applied trace(matmul(diff.T, diff)) to this scalar, which
    # torch.matmul rejects. The squared error rate is the 1x1 reading of that
    # expression - confirm this is the intended term. The equality comparison
    # also carries no gradient.
    error = error_rate * error_rate
    loss = alpha * contrastive_loss + (1 - alpha) * error
    return loss

def cost(qnet, weights, X_FSL_FRQI, Y_FSL):
    """Run the network once and return its predictions together with the loss.

    Args:
        qnet: callable invoked as `qnet(weights, X_FSL_FRQI)`.
        weights: trainable parameters forwarded to `qnet`.
        X_FSL_FRQI: network inputs forwarded to `qnet`.
        Y_FSL: targets compared with the predictions by `semi_loss`.

    Returns:
        (predictions, loss) tuple.
    """
    # Evaluate the network a single time and reuse the result for both the
    # loss and the return value (it used to be evaluated three times per call).
    pred = qnet(weights, X_FSL_FRQI)
    loss = semi_loss(pred, Y_FSL)
    return pred, loss

In [None]:
# All-zero vector passed as the second argument of `qnode_amplitude`.
# It is deliberately still produced from a random draw: torch.randn advances
# the global RNG, so swapping in torch.zeros would change `weights_init`
# below for a fixed seed.
zeros = 0.0 * torch.randn(X_train.shape[1], requires_grad=False)

if args.model=='GENERAL_RAM':
    weights_init = torch.normal(mean=0.0, std=1, size=(args.num_layers,args.num_qubits,3),requires_grad=True)
    quantum_neural_network = GENERAL_RAM(args.num_qubits)
else:
    # Fail here with a clear message rather than with a NameError further down.
    raise ValueError('Unsupported model: {!r} (expected GENERAL_RAM)'.format(args.model))

# Pass every normalised sample through `qnode_amplitude` and stack the
# results into one tensor per split.
X_train_state = torch.stack(
    [quantum_neural_network.qnode_amplitude(sample, zeros) for sample in X_train]
)
X_test_state = torch.stack(
    [quantum_neural_network.qnode_amplitude(sample, zeros) for sample in X_test]
)

images_fsl, labels_fsl, state_fsl, y_fsl, images_query, labels_query, state_query, y_query, AdjMtx = split_fsl_query_random(X_train_state, Y_train, images_train, labels_train, AdjMtx, args)

In [None]:
from sklearn.metrics import accuracy_score, normalized_mutual_info_score
import time
import numpy as np
import threading
    
def active_learning(weights_init, args, quantum_neural_network, state_fsl, y_fsl, state_query, y_query, X_test_state, Y_test, AdjMtx, n_epochs=100):
    """Run the active-learning loop for the strategy in `args.strategy`.

    Each round trains on the current labelled pool with a fresh Adam
    optimiser, moves samples from the query pool to the labelled pool via
    `add_sample`, and then measures accuracy on the test set.

    Args:
        weights_init: initial parameter tensor. It is copied, never modified.
        args: namespace providing `strategy`, `data`, `data_method`, `N_SHOT`,
            `num_layers`, `N_ACTIVE_TIMES` and `learning_rate`.
        quantum_neural_network: object exposing the `qnode_qnn` callable.
        state_fsl, y_fsl: inputs and targets of the labelled pool.
        state_query, y_query: inputs and targets of the query pool.
        X_test_state, Y_test: inputs and targets of the test set.
        AdjMtx: adjacency matrix forwarded to (and returned by) `add_sample`.
        n_epochs: optimiser steps per round (default 100, the previous
            hard-coded value).

    Returns:
        (accs, weight_history): the test accuracy of every round, and snapshots
        of the weights (initial value first, then one per round).
    """
    logger = get_logger(str(args.strategy)+'_'+str(args.data) + '_' + str(args.data_method) + '_' + str(args.N_SHOT) + '_' + str(args.num_layers) + '.log')
    accs = []
    # Store detached snapshots; storing the live tensor would make every
    # history entry alias the same, final, weights.
    weight_history = [weights_init.detach().clone()]

    # Train a private copy. Adam updates its parameters in place, so training
    # `weights_init` directly would leak this run's weights into the caller's
    # tensor and into any later run that starts from it.
    # NOTE(review): weights carry over from round to round (warm start), which
    # matches what the earlier in-place aliasing did - confirm that a reset per
    # round was not intended.
    weights = weights_init.detach().clone().requires_grad_(True)

    start = time.time()

    for it_ac in range(args.N_ACTIVE_TIMES+1):
        iter_start_ac = time.time()
        opt = torch.optim.Adam([weights], lr = args.learning_rate)

        train_res = None
        for _ in range(n_epochs):
            opt.zero_grad()
            train_res, loss = cost(quantum_neural_network.qnode_qnn, weights, state_fsl, y_fsl)
            loss.backward()
            opt.step()

        # Only the last epoch's training accuracy is reported, so compute it once.
        train_acc = accuracy(train_res, y_fsl) if train_res is not None else float('nan')

        state_fsl, y_fsl, state_query, y_query, AdjMtx = add_sample(quantum_neural_network, weights, state_fsl, y_fsl, state_query, y_query, AdjMtx, args)

        res = quantum_neural_network.qnode_qnn(weights, X_test_state)
        test_acc = accuracy(res, Y_test)

        iter_end_ac = time.time()
        logger.info('{}\t Epoch_AC:[{}/{}]\t TRAIN ACCURACY={:.6f}\t  TEST ACCURACY={:.6f} \t time={:.3f}'.format(args.strategy, it_ac, args.N_ACTIVE_TIMES, train_acc, test_acc, iter_end_ac-iter_start_ac))

        # Column sums of the labelled-pool target matrix.
        print(torch.sum(y_fsl, dim=0))
        accs.append(test_acc)
        weight_history.append(weights.detach().clone())

    all_end = time.time()
    print('epoch_time='+str(all_end-start))

    print(state_fsl.shape)
    print(y_fsl.shape)
    print(state_query.shape)
    print(y_query.shape)

    return accs, weight_history

In [None]:
import concurrent.futures
import copy

# Run the 'RAND' selection strategy on a private copy of the configuration.
args_rand = copy.deepcopy(args)
args_rand.strategy = 'RAND'
# Pass a private copy of the initial weights so that training in this run
# cannot modify the shared `weights_init` the other strategy runs start from.
accs_rand, weight_history_rand = active_learning(
    weights_init.detach().clone().requires_grad_(True),
    args_rand, quantum_neural_network,
    state_fsl, y_fsl, state_query, y_query,
    X_test_state, Y_test, AdjMtx,
)

In [None]:
# Run the 'ENTRO' selection strategy on a private copy of the configuration.
args_entro = copy.deepcopy(args)
args_entro.strategy = 'ENTRO'
# Pass a private copy of the initial weights so that training in this run
# cannot modify the shared `weights_init` the other strategy runs start from.
accs_entro, weight_history_entro = active_learning(
    weights_init.detach().clone().requires_grad_(True),
    args_entro, quantum_neural_network,
    state_fsl, y_fsl, state_query, y_query,
    X_test_state, Y_test, AdjMtx,
)

In [None]:
# Run the 'QUANTUM' selection strategy on a private copy of the configuration.
args_quantum = copy.deepcopy(args)
args_quantum.strategy = 'QUANTUM'
# Pass a private copy of the initial weights so that training in this run
# cannot modify the shared `weights_init` the other strategy runs start from.
accs_quantum, weight_history_quantum = active_learning(
    weights_init.detach().clone().requires_grad_(True),
    args_quantum, quantum_neural_network,
    state_fsl, y_fsl, state_query, y_query,
    X_test_state, Y_test, AdjMtx,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Test accuracy per active-learning round for the three selection strategies.
nums = np.arange(args.N_ACTIVE_TIMES+1)
y_quantum = accs_quantum
y_entro = accs_entro
y_rand = accs_rand

fig, ax = plt.subplots()

line1 = ax.plot(nums, y_quantum, marker='o', linestyle='-', color='red', markersize=2, linewidth=2, label='QUANTUM')
line2 = ax.plot(nums, y_entro, marker='x', linestyle='--', color='blue', markersize=2, linewidth=2, label='ENTROPY')
line3 = ax.plot(nums, y_rand, marker='s', linestyle=':', color='green', markersize=2, linewidth=2, label='RANDOM')

# Raw string: '\%' is not a valid Python escape sequence, so the non-raw
# literal triggers an invalid-escape warning. The text itself is unchanged.
# NOTE(review): the label says percent while the y-limits below are
# fractions - confirm which unit the accuracies use.
ax.set_ylabel(r'Classification ACC ($\%$)', fontsize=16)
ax.set_xlabel('Active learning rounds', fontsize=16)
ax.set_title(args.data)

ax.set_ylim([0.5, 0.97])

legend = ax.legend()

figure_path = r'../Figure/QIAL_'+'all_'+str(args.N_ACTIVE)+'-'+str(args.seed)+'-'+str(args.num_layers)+'.pdf'
# Create the output folder first so savefig cannot fail on a missing directory.
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
fig.savefig(figure_path, format='pdf', dpi=300, bbox_inches='tight')


In [None]:
import pickle 

# Everything needed to re-plot or re-evaluate the experiment later.
variables = {
    'weight_history_quantum': weight_history_quantum,
    'weight_history_entro': weight_history_entro,
    'weight_history_rand': weight_history_rand,
    'X_train': X_train,
    'X_test': X_test,
    'X_train_state': X_train_state,
    'X_test_state': X_test_state,
    'Y_train': Y_train,
    'Y_test': Y_test,
    'accs_quantum': accs_quantum,
    'accs_entro': accs_entro,
    'accs_rand': accs_rand,
    'args':args
}

file_path = (
    '../results/' + args.data + '/' +
    args.data + '_all' + '_'+str(args.seed)+'.pkl'
)

# Create the per-dataset results folder first; otherwise open() raises
# FileNotFoundError and the results of the whole run are lost.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

with open(file_path, 'wb') as file:
    pickle.dump(variables, file)