In [None]:
from utilsADCN import mnistLoader, plotPerformance
from ADCNbasic import ADCN
from ADCNmainloop import ADCNmainMT
import numpy as np
import pdb
import torch
import random
from torchvision import datasets, transforms

In [None]:
# Random seed control: seed every RNG the experiment may draw from so that a
# fresh-kernel "Run All" reproduces the same results. A single SEED constant
# keeps NumPy, PyTorch and Python's `random` in sync when the seed is changed.
SEED = 0
np.random.seed(SEED)      # NumPy legacy global RNG
torch.manual_seed(SEED)   # PyTorch RNG (torch.manual_seed also seeds CUDA devices)
random.seed(SEED)         # Python stdlib RNG

In [None]:
# ToTensor converts each PIL image of the dataset into a torch.FloatTensor.
transform = transforms.ToTensor()

# Load both MNIST splits from ./data (download=True fetches them if missing).
# NOTE: the labeled pool is MNIST's *test* split (train=False), while the
# unlabeled stream is MNIST's *training* split (train=True).
labeledData = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform,
)
unlabeledData = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Wrap the labeled and unlabeled MNIST splits into a data stream.
# nEachClassSamples=500 is forwarded to utilsADCN.mnistLoader; presumably it is
# the number of labeled samples kept per class (TODO confirm in utilsADCN).
dataStream = mnistLoader(
    labeledData,
    unlabeledData,
    nEachClassSamples=500,
)

In [None]:
# Choose the compute device. The original experiment ran on the second GPU
# ('cuda:1'). Fall back to the first GPU, and then to the CPU, so the notebook
# also runs on machines that do not have two CUDA devices (a hardcoded
# 'cuda:1' fails there as soon as a tensor is moved to it).
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Split the stream into 4 tasks. Each [start, end] pair is forwarded to
# utilsADCN's createTask; presumably they are index ranges over the stream's
# batches, and taskType=2 selects the task construction scheme (TODO confirm
# both in utilsADCN).
dataStream.createTask(
    nTask=4,
    taskList=[[0, 30], [31, 60], [61, 90], [91, 120]],
    taskType=2,
)

In [None]:
# Collects one metrics vector per trial; filled by the training loop.
allMetrics = list()

In [None]:
# Experiment configuration.
nNodeInit = 96   # initial hidden-node count, passed to ADCN as nHiddenNode
epoch     = 1    # passed to ADCNmainMT as noOfEpoch
n_trials  = 5    # independent repetitions; results are averaged across them

In [None]:
# Run n_trials independent ADCN trainings on the same task stream and store
# each trial's final metrics vector (allPerformance0) in allMetrics.
# allMetrics is reset here so that re-running this cell replaces, rather than
# appends to, the results of a previous run (which would silently skew the
# mean/std computed in the results cell).
# NOTE: the global seeds are set only once in an earlier cell, so each trial
# continues from the RNG state left by the previous trial.
allMetrics = []
for i_trial in range(n_trials):
    print('Trial: ', i_trial)
    # Build a fresh network for every trial so that trials do not share weights.
    ADCNnet = ADCN(dataStream.nOutput, nHiddenNode = nNodeInit)
    ADCNnet, performanceHistory0, allPerformance0 = ADCNmainMT(ADCNnet, dataStream, noOfEpoch = epoch, device = device)
    allMetrics.append(allPerformance0)

In [None]:
# Summarize all trials: mean and standard deviation (rounded to 2 decimals),
# taken across trials, of every entry of the per-trial metrics vector.
#
# Index layout of each per-trial metrics vector:
# 0: accuracy
# 1: all tasks accuracy
# 2: BWT
# 3: FWT
# 4: ARI
# 5: NMI
# 6: f1_score
# 7: precision_score
# 8: recall_score
# 9: training_time
# 10: testingTime
# 11: nHiddenLayer
# 12: nHiddenNode
# 13: nCluster
# 14: nMemory   (not printed below)

# Fail with a clear message instead of an opaque IndexError on a 0-d array
# when the training cell has not been run yet.
assert len(allMetrics) > 0, 'allMetrics is empty: run the training cell first'

# np.round_ was deprecated in NumPy 1.25 and removed in NumPy 2.0;
# np.round has identical semantics.
meanResults = np.round(np.mean(allMetrics, 0), decimals=2)
stdResults  = np.round(np.std(allMetrics, 0), decimals=2)

# (section title, [(label, metric index), ...]) -- drives the printed report.
reportSections = [
    ('========== Performance ==========', [
        ('Preq Accuracy: ', 0),
        ('All tasks accuracy: ', 1),
        ('BWT: ', 2),
        ('FWT: ', 3),
        ('ARI: ', 4),
        ('NMI: ', 5),
        ('F1 score: ', 6),
        ('Precision: ', 7),
        ('Recall: ', 8),
        ('Training time: ', 9),
        ('Testing time: ', 10),
    ]),
    ('========== Network ==========', [
        ('Number of hidden layers: ', 11),
        ('Number of features: ', 12),
        ('Number of clusters: ', 13),
    ]),
]

for sectionTitle, sectionRows in reportSections:
    print('\n')
    print(sectionTitle)
    for metricLabel, metricIdx in sectionRows:
        print(metricLabel, meanResults[metricIdx].item(), '(+/-)', stdResults[metricIdx].item())