In [29]:
from inspector import Metric
from inspector.utility import plot_distmat

from netrep.metrics.stochastic import EnergyStochasticMetric
from netrep.multiset import pairwise_distances

import numpy as np
import sklearn.datasets

In [30]:
class NetRep(Metric):
    '''Base class for NetRep wrapper classes'''

    def __init__(self):
        super().__init__()

    def score(self, verbose=False):
        '''Compute the netrep score'''

        metric = EnergyStochasticMetric()
        distmat_energy = pairwise_distances(metric, self.data['digested_networks'], verbose=verbose)

        self.data['distance_matrix'] = distmat_energy

        return distmat_energy

In [42]:
class ESM(NetRep):
    '''Energy Stochastic Metric'''

    def __init__(self):
        super().__init__()

    def digest(self):
        '''Preprocess data for RSA analysis
        They are segrigated by class (by using ground truth y), then by sample, then by feature
        example shapes (Xi, Xj): ((4, 225, 2), (4, 225, 2))
        '''

        n_netowrks = len(self.data['X'])

        # find the lowest number of samples of y class in X
        # this is the number of samples we will use for each class
        min_samples_per_class = float('inf')
        for i in range(n_netowrks):
            y_t = self.data['y'][i]
            min_samples_per_class = min([len(y_t[y_t==c]) for c in np.unique(y_t)])

            # replace the min_samples_per_class if the current network has a lower number of samples
            if min_samples_per_class < min_samples_per_class:
                min_samples_per_class = min_samples_per_class

        print(f'found the minimum samples to be: {min_samples_per_class} samples per class')

        all_digested_networks = []

        for i in range(n_netowrks):

            # get the data for the ith network
            X = self.data['X'][i]
            y = self.data['y'][i]

            n_classes = len(np.unique(y))

            digested_x = [X[y==i] for i in range(n_classes)] # Meaning: digested_x is a list of the data for each class
            
            print(len(digested_x))
            print(f"digested_x shape: {digested_x.shape}")

            digested_x = np.stack([x[:min_samples_per_class] for x in digested_x], 0)

            all_digested_networks.append(digested_x)
        
        print(f"succesfully digested {n_netowrks} networks to {n_classes} classes with {len(digested_x[0])} samples each")
        
        self.data['digested_networks'] = all_digested_networks

        # TO DO LATER: return a report of the preprocessing

        return "TBD Digestion Report"

    def plot(self):
        '''Plot the RSA score'''
        plot_distmat(self.data['distance_matrix'])
        pass

In [74]:

class SyntheticNetworks():
    def __init__(self) -> None:
        pass

    def generate(self, networks = 3, n_samples = 500, n_classes = 4, n_features = 2):
        # seed = 42069
        # rng = np.random.default_rng(seed)

        X , y = [], []

        for i in range(networks):
            # Generate a new dataset
            _X, _y = sklearn.datasets.make_classification(
                n_samples=n_samples, 
                n_features=n_features, 
                n_informative=n_features, 
                n_redundant=0, 
                n_repeated=0, 
                n_classes=n_classes, 
                n_clusters_per_class=1, # Meaning: each class is a single cluster
                # random_state=seed,
            )
            X.append(_X)
            y.append(_y)

        return X, y

In [75]:
# create 10 synthetic networks
network_generator = SyntheticNetworks()
networks_X, networks_y = network_generator.generate(networks=10)

# create an RSA object
netrep_metric = ESM()
ingestion_report = netrep_metric.ingest(networks_X, networks_y)
digestion_report = netrep_metric.digest()
energy_distance = netrep_metric.score(verbose=True)

netrep_metric.plot()

found the minimum samples to be: 122 samples per class
4


AttributeError: 'list' object has no attribute 'shape'

In [76]:
network_generator = SyntheticNetworks()
networks_X, networks_y = network_generator.generate(networks=1)

In [77]:
X1, Y1 = networks_X[0], networks_y[0]

In [78]:
X1.shape, Y1.shape

((500, 2), (500,))

In [79]:
import numpy as np

# Assume Y1 and Y2 are arrays of shape (1000,) containing class labels
num_classes = len(np.unique(Y1))  # assuming Y1 and Y2 have the same number of classes

# Create an empty array of shape (num_classes, 0, 2) to store the reshaped X1
X1_reshaped = np.empty((num_classes, 0, 2))

# Loop over each class and extract the corresponding samples from X1
for class_label in range(num_classes):
    samples_in_class = np.where(Y1 == class_label)[0]
    X1_class = X1[samples_in_class, :]
    X1_class = np.expand_dims(X1_class, axis=0)
    
    # Create a temporary array with the correct shape
    new_X1_reshaped = np.empty((num_classes, X1_reshaped.shape[1]+X1_class.shape[1], 2))
    new_X1_reshaped[:, :X1_reshaped.shape[1], :] = X1_reshaped
    new_X1_reshaped[:, X1_reshaped.shape[1]:, :] = X1_class
    X1_reshaped = new_X1_reshaped

# The resulting shape of X1_reshaped will be (num_classes, num_samples_per_class, 2)

In [80]:
X1_reshaped[1]

array([[-1.49277669e+00, -7.38383417e-01],
       [ 3.20359070e-01, -1.75861368e+00],
       [ 6.41662751e-03, -1.04350578e+00],
       [-1.12526249e+00, -1.37466872e+00],
       [ 3.11555538e-01, -1.21288405e+00],
       [-3.56200421e-01, -9.63472569e-01],
       [-9.80047597e-01, -1.79482427e+00],
       [-5.58407574e-01, -5.29697257e-01],
       [ 9.47106487e-01,  1.19984722e+00],
       [-1.18729211e+00, -3.84768602e-01],
       [ 1.68303098e-01, -9.81170536e-01],
       [-2.68983786e+00, -1.12080930e+00],
       [-1.82265291e+00, -6.69900957e-01],
       [-2.33139866e+00, -1.41046626e+00],
       [-1.24761701e+00, -1.26003532e+00],
       [-1.55841789e+00, -3.59710162e-01],
       [-1.30880868e+00, -1.90963968e+00],
       [-1.49499484e+00, -1.22881833e+00],
       [-2.10015535e+00, -1.13163491e+00],
       [-1.30541435e+00, -4.79428814e-01],
       [-2.89384689e-01, -6.87374099e-01],
       [ 1.83188597e+00, -1.15880311e+00],
       [-2.35672505e+00, -1.16971509e+00],
       [ 8.