In [None]:
from utilsADCN import dataLoader
from ADCNbasic import ADCN
from ADCNmainloop import ADCNmain
from model import simpleMPL
import numpy as np
import pdb
import torch
import random
from torchvision import datasets, transforms

In [None]:
# Random seed control: fix every random-number source used by the experiment
# so that a fresh "Restart & Run All" reproduces the same results.
SEED = 0
random.seed(SEED)        # Python's built-in `random` module
np.random.seed(SEED)     # NumPy's legacy global generator
torch.manual_seed(SEED)  # PyTorch's default generator

In [None]:
# Location of the stream used in this experiment (relative to the notebook).
DATA_PATH = './data/hyperplane2.mat'
dataStream = dataLoader(DATA_PATH)

In [None]:
# Run on the first GPU when one is available; otherwise fall back to the CPU
# so the notebook still executes end-to-end on machines without CUDA
# (a hard-coded 'cuda:0' fails as soon as tensors are moved to it).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Network widths are derived from the input dimensionality of the stream:
# the extractor hidden width and the extracted-feature width are 4x the input
# size, and the clustering width is 2x.
inputDim = dataStream.nInput
nHidNodeExtractor = inputDim * 4
nExtractedFeature = inputDim * 4
nFeaturClustering = inputDim * 2

In [None]:
# One entry per trial; each entry is the `allPerformance` returned by ADCNmain.
allMetrics = list()

In [None]:
# Number of independent repetitions of the whole experiment.
n_trials = 5

In [None]:
# Build a fresh model for every trial, run it over the stream and keep the
# per-trial metrics for the summary cell.
for i_trial in range(n_trials):
    print('Trial: ', i_trial)

    ADCNnet = ADCN(
        dataStream.nOutput,
        nInput=nExtractedFeature,
        nHiddenNode=nFeaturClustering,
    )
    # Attach an MLP sized for this stream (presumably the feature extractor,
    # judging by the widths it is given).
    ADCNnet.ADCNcnn = simpleMPL(
        dataStream.nInput,
        nNodes=nHidNodeExtractor,
        nOutput=nExtractedFeature,
    )
    # NOTE(review): hard-coded label set; assumes a binary stream. Confirm
    # against dataStream.nOutput.
    ADCNnet.desiredLabels = [0, 1]

    ADCNnet, performanceHistory, allPerformance = ADCNmain(ADCNnet, dataStream, device=device)
    allMetrics.append(allPerformance)

In [None]:
# All results: mean and standard deviation of every metric across trials.
#
# Index layout of each per-trial `allPerformance` entry in `allMetrics`:
# 0: accuracy
# 1: ARI
# 2: NMI
# 3: f1_score
# 4: precision_score
# 5: recall_score
# 6: training_time
# 7: testingTime
# 8: nHiddenLayer
# 9: nHiddenNode
# 10: nCluster

assert len(allMetrics) > 0, 'no trial results to summarize; run the trial loop first'

# np.round_ was deprecated in NumPy 1.25 and removed in NumPy 2.0;
# np.round is the supported equivalent.
meanResults = np.round(np.mean(allMetrics, 0), decimals=2)
stdResults = np.round(np.std(allMetrics, 0), decimals=2)

# (printed label, metric index) pairs for each report section.
performanceRows = [
    ('Preq Accuracy: ', 0),
    ('ARI: ', 1),
    ('NMI: ', 2),
    ('F1 score: ', 3),
    ('Precision: ', 4),
    ('Recall: ', 5),
    ('Training time: ', 6),
    ('Testing time: ', 7),
]
networkRows = [
    ('Number of hidden layers: ', 8),
    # NOTE(review): index 9 is nHiddenNode but is reported as "features".
    ('Number of features: ', 9),
    ('Number of clusters: ', 10),
]


def _printSection(title, rows):
    """Print a section title followed by one 'label mean (+/-) std' line per row."""
    print('\n')
    print(title)
    for label, idx in rows:
        print(label, meanResults[idx].item(), '(+/-)', stdResults[idx].item())


# The data loaded in this notebook is hyperplane2.mat, not SEA.
_printSection('========== Performance Hyperplane ==========', performanceRows)
_printSection('========== Network ==========', networkRows)