In [None]:
# python 3.9
import numpy as np
import QCL_simu_on_18bits
from Adam import adam_log
import random
import pickle
from matplotlib import pyplot as plt
# %matplotlib qt

In [None]:
# !pip install PyQt6

In [None]:
import os,sys
# Root directory for every artefact this notebook reads or writes
# (the files/ sub-directory and the exported figure).
file_path_now = os.getcwd()
# Use Arial for all figure text to match the paper typography.
plt.rcParams.update({'font.family': 'Arial'})
print(file_path_now)

In [None]:
# Simulator instance used by the training and run cells below; it receives the
# working directory as its save location.
# NOTE(review): the keyword is spelled `save_paht` — presumably matching the
# QCL_simu_on_18bits API; verify there before "fixing" the spelling here.
cml = QCL_simu_on_18bits.cml(
    save_paht=file_path_now,
)

In [None]:
def _task_flags(task):
    """Return the ``(amp, random_perm)`` flags passed to ``cml`` for ``task``.

    ``amp`` is set for task names starting with ``'spt'``; ``random_perm`` is
    set only for the ``'medical'`` task. Shared by the training, test and
    gradient passes of ``mult_learning`` so the rules cannot drift apart.
    """
    return task.startswith('spt'), task == 'medical'


def mult_learning(
        cml,
        iter_idx=0,
        max_iter=200,
        ask_stop=False,
        test=True,
        batch_size=25,
        test_num=100,
        tasks=('FashionMNIST_09', 'medical', 'spt'),
        theta2data=(2, 2, 2),
        lr=0.05,
        save_dir=None,
    ):
    """Jointly train ``cml.params`` on several tasks with one shared Adam step.

    Each iteration (``iter_idx`` up to and including ``max_iter``):

    1. saves the current parameters via ``cml.save_params()``;
    2. for every task, draws ``batch_size`` training indices into
       ``cml.train_idx`` and calls ``cml.loss_and_accuracy`` on them;
    3. if ``test`` is true, calls ``cml.loss_and_accuracy(..., test=True)``
       per task, records the returned loss/accuracy and prints them;
    4. averages ``cml.gradient`` over the tasks, converts it into a step via
       ``cml.adam_log.log`` and applies it with ``cml.update_params``;
    5. pickles ``{'losses': ..., 'accs': ...}`` to
       ``<save_dir>/files/Results.pkl`` and, if ``ask_stop``, asks whether to stop.

    Args:
        cml: Simulator object exposing ``x_train``, ``params``, ``save_params``,
            ``loss_and_accuracy``, ``gradient`` and ``update_params``.
        iter_idx: Starting iteration index (also passed to the Adam logger).
        max_iter: Last iteration index to run (inclusive).
        ask_stop: If true, prompt after each iteration and stop on ``"y"``.
        test: Whether to run and record the per-task test evaluation.
        batch_size: Training samples drawn per task per iteration.
        test_num: If ``None``, replaced by the first task's training-set size.
            NOTE(review): the value is not forwarded to ``cml`` in this function.
        tasks: A task name, or a list/tuple of task names.
        theta2data: Per-task ``theta2data`` values; a scalar or one-element
            list/tuple is broadcast to every task.
        lr: Learning rate handed to ``adam_log``.
        save_dir: Directory whose ``files/`` sub-directory holds the Adam log
            and ``Results.pkl``; defaults to the global ``file_path_now``.

    Returns:
        ``(losses, accs)``: dicts mapping each task name to its list of test
        losses / accuracies (lists stay empty when ``test`` is false).

    Raises:
        ValueError: If ``tasks`` is empty, or ``theta2data`` has more than one
            but fewer entries than ``tasks``.
    """
    tasks = list(tasks) if isinstance(tasks, (list, tuple)) else [tasks]
    theta2data = (list(theta2data) if isinstance(theta2data, (list, tuple))
                  else [theta2data])
    if not tasks:
        raise ValueError("mult_learning needs at least one task")
    if len(theta2data) == 1:
        # A single value applies to every task.
        theta2data = theta2data * len(tasks)
    elif len(theta2data) < len(tasks):
        raise ValueError(
            f"theta2data has {len(theta2data)} entries but {len(tasks)} tasks were given")
    test_num = cml.x_train[tasks[0]].shape[0] if test_num is None else test_num

    save_dir = file_path_now if save_dir is None else save_dir
    files_dir = save_dir + "/files"
    # Both the Adam log and Results.pkl are written here; make sure it exists.
    os.makedirs(files_dir, exist_ok=True)
    results_path = files_dir + '/Results.pkl'
    cml.adam_log = adam_log(files_dir + "/adam_log.pkl", lr=lr)

    random.seed(1)
    losses = {task: [] for task in tasks}
    accs = {task: [] for task in tasks}

    while iter_idx <= max_iter:
        cml.save_params()

        # Training-batch pass: draw a fresh batch per task and evaluate it.
        # NOTE(review): cml.train_idx is overwritten for each task, so after this
        # loop it holds the *last* task's indices when cml.gradient runs below —
        # confirm whether cml.gradient draws its own batch.
        for t, task in enumerate(tasks):
            amp, random_perm = _task_flags(task)
            cml.train_idx = random.sample(
                range(0, len(cml.x_train[task])), batch_size)
            cml.loss_and_accuracy(
                cml.params,
                theta2data[t],
                amp=amp,
                random_perm=random_perm,
                task=task,
            )

        if test:
            for t, task in enumerate(tasks):
                amp, random_perm = _task_flags(task)
                loss, accuracy_rate = cml.loss_and_accuracy(
                    cml.params,
                    theta2data[t],
                    amp=amp,
                    random_perm=random_perm,
                    test=True,
                    task=task,
                )
                losses[task].append(loss)
                accs[task].append(accuracy_rate)
                print(f"{'iter_idx:':<9}{iter_idx:<5} {'task:':<5}{task:<16} {'loss:':<5}{loss:.4f} {'accuracy:':<9}{accuracy_rate:.2%}")

        # Average the per-task gradients into one shared update.
        gs = 0
        for t, task in enumerate(tasks):
            amp, random_perm = _task_flags(task)
            gs += cml.gradient(
                params=cml.params,
                task=task,
                # Per-task value, consistent with the loss evaluations above.
                theta2data=theta2data[t],
                amp=amp,
                random_perm=random_perm,
                noise_level=0.0,
            )
        gs = gs / len(tasks)
        delta_theta = cml.adam_log.log(gs, "", iter_idx)
        iter_idx += 1
        cml.update_params(delta_theta)

        # Persist before the optional stop prompt so the most recent
        # evaluation is not lost when the user stops early.
        with open(results_path, 'wb') as f:
            pickle.dump({'losses': losses, 'accs': accs}, f)
        if ask_stop and input("stop?[y/n]") == "y":
            break
    return losses, accs

In [None]:
# Run configuration for the joint multi-task training.
tasks = ['FashionMNIST_09', 'medical', 'spt']
batch_size, lr = 25, 0.05
max_epoch = 2   # forwarded as max_iter: the loop runs while iter_idx <= max_iter
test_num = 100
losses, accs = mult_learning(
    cml=cml,
    tasks=tasks,
    batch_size=batch_size,
    lr=lr,
    max_iter=max_epoch,
    test_num=test_num,
)

In [None]:

# Figure styling. PAPERFIGURE is the page size converted from millimetres to
# inches (matplotlib figsize units); the plot uses half its width and a
# quarter of its height.
PAPERFIGURE = np.array([180, 247]) / 25.4
fontsize = 7
save = True
cs = ['#29af7f', '#396db1', '#c44072']
marks = ['o', 's', '*']
markersize = [2, 1.5, 3]
tasks = ['FashionMNIST_09', 'medical', 'spt']

# Reload the curves written by mult_learning so this cell can run on its own.
with open(file_path_now + '/files/Results.pkl', 'rb') as f:
    data = pickle.load(f)
losses, accs = data['losses'], data['accs']
steps = np.arange(len(accs[tasks[0]]))

# Accuracy curve per task, labelled T_1, T_2, ... in task order.
fig = plt.figure(figsize=(PAPERFIGURE[0] / 2, PAPERFIGURE[1] / 4))
ax = fig.add_axes([0.12, 0.15, 0.85, 0.82])
for idx, task in enumerate(tasks):
    ax.plot(
        steps,
        np.array(accs[task])[steps],
        label=fr'$\mathcal{{T}}_\mathcal{{{idx + 1}}}$',
        linewidth=0.8,
        markersize=markersize[idx],
        c=cs[idx],
        marker=marks[idx],
    )
ax.set_xlabel('Epochs', fontsize=fontsize)
ax.set_ylabel('Accuracy', fontsize=fontsize)
ax.legend(fontsize=fontsize)
ax.tick_params(labelsize=fontsize)

if save:
    file_name = file_path_now + '/multiple_learning.pdf'
    plt.savefig(file_name, dpi=300, facecolor='w', edgecolor='w',
                orientation='portrait', format='pdf')
