In [1]:
import torch
import numpy as np

In [2]:
import torchhd
import HierGraph

In [3]:
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

In [4]:
#function for calculating performance metrics

def perf_measure(y_actual, y_hat):
    '''
    Compute precision, recall, and F1 score between predicted and actual labels.
    '''
    # Generate a classification report with precision, recall, and F1 score
    # for each class. Store the report as a dictionary.
    report = classification_report(y_actual,y_hat, output_dict=True)
    precision = report['macro avg']['precision']
    recall = report['macro avg']['recall']
    f1 = report['macro avg']['f1-score']

    # Return the precision, recall, and F1 score as a tuple
    return precision, recall, f1

In [5]:
np.random.seed(42)
def generate_random_data(shape):
    return np.random.random(shape)

def generate_sinusoidal_data(shape):
    time = np.linspace(0, 2 * np.pi, shape[0]*shape[1]*shape[2])
    return np.sin(time).reshape(shape)

def generate_consecutive_data(shape):
    total_elements = np.prod(shape)
    return np.arange(total_elements).reshape(shape)

# Parameters
num_samples = 500
num_time_step = 10
shape = (num_samples, 64, num_time_step)


# Generate data using the three functions
data_random = generate_random_data(shape)
data_sinusoidal = generate_sinusoidal_data(shape)
data_consecutive = generate_consecutive_data(shape)

# Create labels
labels_random = np.zeros(num_samples, dtype=int)
labels_sinusoidal = np.ones(num_samples, dtype=int)
labels_consecutive = np.full(num_samples, 2, dtype=int)

# Combine data and labels
data = np.concatenate((data_random, data_sinusoidal, data_consecutive), axis=0)
labels = np.concatenate((labels_random, labels_sinusoidal, labels_consecutive), axis=0)

print(data.shape)
print(labels.shape)


(1500, 64, 10)
(1500,)


In [6]:
channel_mat = np.array([6,6,6,5,5,6,3,5,6,4,3,3,6]) #num of channels per parameter
parameter_adj_mat = np.array([[0, 1, 2, 3, 4, 5], #which channel corresponds to which parameter
                         [6, 7, 8, 9, 10, 11],
                         [12, 13, 14, 15, 16, 17],
                         [18, 19, 20, 21, 22],
                         [23, 24, 25, 26, 27],
                         [28, 29, 30, 31, 32, 33],
                         [34, 35, 36],
                         [37, 38, 39, 40, 41],
                         [42, 43, 44, 45, 46, 47],
                         [48, 49, 50, 51],
                         [52, 53, 54],
                         [55, 56, 57],
                         [58, 59, 60, 61, 62, 63]])


  parameter_adj_mat = np.array([[0, 1, 2, 3, 4, 5], #which channel corresponds to which parameter


In [7]:
num_classes = np.unique(labels).size
parameters={'dim': 1000, 'alpha': 20.0, 'lr': .5, 'epoch': 100, 'T': 0.1}

In [8]:
x_train, x_test, y_train, y_test = train_test_split(data, labels, stratify=labels, test_size = .2) #stratify=yy

y_test = torch.from_numpy(y_test).long()

test = HierGraph.hiergraph(num_classes,channel_mat,parameter_adj_mat,num_time_step,embedding_type='density',dim=parameters['dim'],VSA='MAP')
test = test.fit(x_train,y_train,lr=parameters['lr'],alpha=parameters['alpha'],epochs=parameters['epoch'],T=parameters['T'],iter=1)

sim_parameter = torchhd.cos(test.parameter_hv,test.class_hv).numpy()
sim_channel = torchhd.cos(test.channel_hv,test.class_hv).numpy()
sim_channel_para = torchhd.cos(test.channel_hv,test.parameter_hv).numpy()

y_hat = test(x_test)
acc_test_all = np.array((y_test == y_hat).float().mean())
cur_precision_all, cur_recall_all, cur_f1_all = perf_measure(y_test,y_hat)

In [9]:
print(acc_test_all)

0.88


In [10]:
print(classification_report(y_test,y_hat))

              precision    recall  f1-score   support

           0       0.95      0.70      0.80       100
           1       0.76      0.94      0.84       100
           2       0.98      1.00      0.99       100

    accuracy                           0.88       300
   macro avg       0.89      0.88      0.88       300
weighted avg       0.89      0.88      0.88       300

