In [15]:
from utilsADCN import cifarLoaderAllChannels, plotPerformance
from ADCNbasic import ADCN
from ADCNmainloop import ADCNmain
from model import ConvAeCIFAR
import numpy as np
import pdb
import torch
import random
from torchvision import datasets, transforms

In [16]:
# Seed every RNG source this notebook touches (PyTorch, NumPy, Python's
# `random`) with one shared value so runs are reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)  # kept last: returns None, so the cell displays nothing

In [17]:
# ToTensor turns each PIL image into a float tensor scaled to [0, 1];
# no per-channel normalization is applied.
transform = transforms.ToTensor()

# Both splits share the same root directory, download flag and transform.
# The CIFAR-10 test split serves as the labeled pool and the training split
# as the unlabeled pool.
cifarKwargs   = dict(root='data', download=True, transform=transform)
labeledData   = datasets.CIFAR10(train=False, **cifarKwargs)
unlabeledData = datasets.CIFAR10(train=True, **cifarKwargs)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
# Build the streaming data source from the labeled and unlabeled pools.
# With nEachClassSamples=500, the loader reports 5000 labeled samples for its
# 10 outputs (see the cell output).
samplesPerClass = 500
dataStream = cifarLoaderAllChannels(
    labeledData,
    unlabeledData,
    nEachClassSamples = samplesPerClass,
)

Number of output:  10
Number of labeled data:  5000
Number of unlabeled data:  55000
Number of unlabeled data batch:  55


In [19]:
# Configure drift in the stream: one drift per entry of the task list.
# NOTE(review): the meaning of each [start, end] pair and of taskType=2 is
# defined by utilsADCN's createDrift, not visible here. TODO confirm.
driftTaskList = [[0, 5], [6, 10], [11, 15], [0, 15]]
dataStream.createDrift(nDrift = len(driftTaskList), taskList = driftTaskList, taskType = 2)

In [20]:
# Choose the compute device: CUDA if available, then Apple MPS (Metal), else CPU.
# torch.backends.mps only exists in PyTorch >= 1.12. getattr keeps this cell
# working (falling back to CPU) on older builds instead of raising AttributeError.
mpsBackend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif mpsBackend is not None and mpsBackend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')


In [21]:
# Per-trial results: the training loop appends one `allPerformance` per finished trial.
allMetrics = list()

In [22]:
# Experiment protocol.
n_trials   = 5   # independent ADCN trainings; the summary averages over them
batch_size = 16  # passed to ADCNmain as trainingBatchSize

# ADCN constructor arguments.
nNodeInit  = 96   # initial hidden-node count, passed to ADCN as nHiddenNode
nIn        = 768  # passed to ADCN as nInput; presumably the ConvAeCIFAR feature size. TODO confirm

In [23]:
# Train n_trials fresh ADCN models on the data stream and collect the
# performance vector (`allPerformance`) that ADCNmain returns for each trial.
#
# allMetrics is reset here, inside the same cell that fills it. This makes
# re-running the cell (e.g. after an interrupted run) start from a clean list
# instead of silently mixing results from an earlier run into the summary.
# NOTE(review): the same dataStream object is reused across trials; if ADCNmain
# mutates it, trials are not independent. TODO confirm.
allMetrics = []
for i_trial in range(n_trials):
    print('Trial: ', i_trial)
    ADCNnet         = ADCN(dataStream.nOutput, nInput = nIn, nHiddenNode = nNodeInit)
    # attach the CIFAR convolutional autoencoder as the network's CNN component
    ADCNnet.ADCNcnn = ConvAeCIFAR()
    ADCNnet, performanceHistory, allPerformance = ADCNmain(ADCNnet, dataStream, trainingBatchSize = batch_size, device = device)
    allMetrics.append(allPerformance)

Trial:  0
Network initialization phase is started
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node +++
+++ Grow node 

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1630.)
  loss.add_(self.regStrClusteringLoss/2,clusteringLoss(latentAE, oneHotClusters, centroids))


Accuracy:  26.9
1 -th batch
A cut point is detected cut:  500
Status: STABLE
2 -th batch
A cut point is detected cut:  500
Status: STABLE
3 -th batch
A cut point is detected cut:  500
Status: STABLE
4 -th batch
A cut point is detected cut:  500
Status: STABLE
5 -th batch
A cut point is detected cut:  500
Status: STABLE
6 -th batch
A cut point is detected cut:  500
Status: STABLE
7 -th batch
A cut point is detected cut:  500
Status: STABLE
8 -th batch
A cut point is detected cut:  500
Status: STABLE
9 -th batch
A cut point is detected cut:  500
Status: STABLE
10 -th batch
A cut point is detected cut:  500
Status: STABLE
Accuracy:  26.636363636363637
11 -th batch
A cut point is detected cut:  500
Status: STABLE
12 -th batch
A cut point is detected cut:  500
Status: STABLE
13 -th batch
A cut point is detected cut:  500
Status: STABLE
14 -th batch
A cut point is detected cut:  500
Status: STABLE
15 -th batch
A cut point is detected cut:  500
Status: STABLE
16 -th batch
A cut point is detec

KeyboardInterrupt: 

In [None]:
# Summarize all finished trials: per-metric mean and standard deviation across
# trials (axis 0 of allMetrics), rounded to two decimals.
# The position of each name in the tables below is the metric's column index
# in every `allPerformance` vector.
performanceMetricNames = [
    'Preq Accuracy',   # 0: accuracy
    'ARI',             # 1: ARI
    'NMI',             # 2: NMI
    'F1 score',        # 3: f1_score
    'Precision',       # 4: precision_score
    'Recall',          # 5: recall_score
    'Training time',   # 6: training_time
    'Testing time',    # 7: testingTime
]
networkMetricNames = [
    'Number of hidden layers',  # 8: nHiddenLayer
    'Number of features',       # 9: nHiddenNode
    'Number of clusters',       # 10: nCluster
]

# With no finished trial, np.mean returns a 0-d nan and indexing it fails with
# an unhelpful IndexError, so fail early with a clear message instead.
if len(allMetrics) == 0:
    raise RuntimeError('allMetrics is empty: no trial finished. Run the training cell to completion first.')

# np.round_ was deprecated in NumPy 1.25 and removed in NumPy 2.0; np.round is equivalent.
meanResults = np.round(np.mean(allMetrics, 0), decimals=2)
stdResults  = np.round(np.std(allMetrics, 0), decimals=2)


def _printMetricSection(title, names, firstIndex, means, stds):
    """Print a titled section of 'name: mean (+/-) std' lines.

    names[k] is reported from column firstIndex + k of means/stds.
    """
    print('\n')
    print(title)
    for idx, name in enumerate(names, start=firstIndex):
        print(name + ': ', means[idx].item(), '(+/-)', stds[idx].item())


_printMetricSection('========== Performance CIFAR10 ==========', performanceMetricNames, 0, meanResults, stdResults)
_printMetricSection('========== Network ==========', networkMetricNames, len(performanceMetricNames), meanResults, stdResults)



Preq Accuracy:  26.85 (+/-) 0.88
ARI:  0.05 (+/-) 0.01
NMI:  0.08 (+/-) 0.01
F1 score:  0.26 (+/-) 0.01
Precision:  0.27 (+/-) 0.01
Recall:  0.27 (+/-) 0.01
Training time:  669.47 (+/-) 125.35
Testing time:  1.0 (+/-) 0.2


Number of hidden layers:  1.0 (+/-) 0.0
Number of features:  352.0 (+/-) 97.31
Number of clusters:  2739.4 (+/-) 109.91
