In [None]:
from time import time
import logging
import os
import os.path as osp
import numpy as np
import time

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import BA2MotifDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch.autograd import Variable

import random
from torch.optim.lr_scheduler import StepLR


from utils import stat_graph, split_class_graphs, align_graphs
from utils import two_graphons_mixup, universal_svd
from graphon_estimator import universal_svd
from models import GIN,GCN
from tensorboardX import SummaryWriter

import argparse
logdir=''
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')
# Attach the formatter to a console handler. With no handler on the root
# logger, Python falls back to logging.lastResort, which only emits WARNING and
# above, so logger.info(...) progress messages would never be shown.
# The guard avoids stacking duplicate handlers when this cell is re-run; the
# INFO level keeps third-party DEBUG records (propagated to root) off the console.
if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
    _console_handler = logging.StreamHandler()
    _console_handler.setLevel(logging.INFO)
    _console_handler.setFormatter(formatter)
    logger.addHandler(_console_handler)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): an empty log_dir presumably lets tensorboardX pick its default run directory — confirm.
tensorboard_writer = SummaryWriter(log_dir=logdir)



def _num_nodes_of(data):
    """Return a node count covering both the stored ``num_nodes`` and every index in ``edge_index``."""
    edge_index = data.edge_index
    from_edges = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    stored = getattr(data, "num_nodes", None)
    return max(int(stored), from_edges) if stored is not None else from_edges


def prepare_dataset_x(dataset):
    """Attach degree-based node features to graphs that have none.

    If the dataset is empty or its first graph already has ``x``, the dataset is
    returned unchanged. Otherwise every graph gets ``x`` built from its
    out-degrees (how often each node appears in ``edge_index[0]``):

    * max degree over all graphs < 2000: one-hot rows with ``max_degree + 1`` columns;
    * otherwise: a single z-scored degree column, using the mean/std pooled over
      all nodes of all graphs (std 0 or NaN is treated as 1 to avoid NaN/inf).

    ``num_nodes`` is kept when already known (so isolated trailing nodes still
    get a feature row) and is set explicitly on every graph. Graphs are modified
    in place, so ``dataset`` should be a list (or other container that yields the
    same objects on every iteration); the same container is returned.
    """
    if len(dataset) == 0 or dataset[0].x is not None:
        return dataset

    degs = []
    for data in dataset:
        num_nodes = _num_nodes_of(data)
        data.num_nodes = num_nodes
        # Out-degree per node; equivalent to torch_geometric.utils.degree(edge_index[0],
        # num_nodes, dtype=torch.long), with one entry for every node.
        degs.append(torch.bincount(data.edge_index[0], minlength=num_nodes))

    max_degree = max((int(d.max()) for d in degs if d.numel() > 0), default=0)

    if max_degree < 2000:
        # dataset.transform = T.OneHotDegree(max_degree)
        for data, deg in zip(dataset, degs):
            data.x = F.one_hot(deg, num_classes=max_degree + 1).to(torch.float)
    else:
        all_degs = torch.cat(degs, dim=0).to(torch.float)
        mean, std = all_degs.mean().item(), all_degs.std().item()
        if not std > 0:  # constant degrees (std 0) or a single node (std NaN)
            std = 1.0
        for data, deg in zip(dataset, degs):
            data.x = ((deg - mean) / std).view(-1, 1)
    return dataset



def prepare_dataset_onehot_y(dataset):
    """Replace each graph's integer label ``y`` with a one-hot float vector.

    The number of classes is ``max(label) + 1``, so labels are used directly as
    column indices. Labels that skip a value (e.g. {1, 2}) still encode instead of
    crashing. Labels must be non-negative, as ``F.one_hot`` requires. ``y`` may be a
    scalar or a one-element tensor; the result is always a 1-D vector of length
    ``num_classes``. Graphs are modified in place and ``dataset`` is returned.
    """
    labels = [int(data.y) for data in dataset]
    num_classes = max(labels, default=-1) + 1

    for data in dataset:
        # view(-1) gives one_hot a (1,) input, so [0] always yields a (num_classes,) row.
        data.y = F.one_hot(data.y.view(-1), num_classes=num_classes).to(torch.float)[0]
    return dataset


def mixup_cross_entropy_loss(input, target, size_average=True):
    r"""Soft-target cross entropy for probability-vector (mixup-style) labels.

    Origin: https://github.com/moskomule/mixup.pytorch
    PyTorch's cross entropy expects integer class labels, so this loss is used
    when the targets are probability vectors. With q the target and p the
    predicted distribution:

        loss(p, q) = -\sum_i q_i \log p_i

    ``input`` is used directly as ``\log p``. No log or softmax is applied here,
    so the model is expected to output log-probabilities (e.g. ``log_softmax``).

    Args:
        input: log-probabilities; must have the same size as ``target``
            (typically (batch, num_classes)).
        target: target probabilities, same size as ``input``.
        size_average: if True, divide the summed loss by ``input.size()[0]``.

    Returns:
        Scalar tensor holding the loss.
    """
    assert input.size() == target.size()
    assert isinstance(input, Variable) and isinstance(target, Variable)
    loss = - torch.sum(input * target)
    return loss / input.size()[0] if size_average else loss




def train(model, train_loader):
    """Run one optimisation epoch over ``train_loader``.

    Relies on module-level ``device``, ``optimizer`` and ``num_classes``. Each
    batch's ``y`` is reshaped to (num_graphs, num_classes) and scored with
    ``mixup_cross_entropy_loss`` against the model output. The model is called as
    ``model(x, edge_index, batch)`` and its return value is used directly as the
    prediction. NOTE(review): this assumes a single-tensor output of
    log-probabilities — confirm for the model in use.

    Returns:
        ``(model, mean_loss)``, where ``mean_loss`` is the per-graph average of
        the batch losses.
    """
    model.train()
    weighted_loss_sum = 0
    seen_graphs = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index, batch.batch)
        targets = batch.y.view(-1, num_classes)
        batch_loss = mixup_cross_entropy_loss(prediction, targets)
        batch_loss.backward()
        graphs_in_batch = batch.num_graphs
        weighted_loss_sum += batch_loss.item() * graphs_in_batch
        seen_graphs += graphs_in_batch
        optimizer.step()
    return model, weighted_loss_sum / seen_graphs


def test(model, loader):
    """Evaluate ``model`` on ``loader``; return ``(accuracy, mean_loss)``.

    Relies on module-level ``device`` and ``num_classes``. ``data.y`` must reshape
    to (num_graphs, num_classes); the true class is its argmax. The loss is
    ``mixup_cross_entropy_loss`` averaged over graphs. The model may return
    either a tensor of outputs or a tuple whose last element is the output
    tensor, e.g. ``(embeddings, logits)``.
    """
    model.eval()
    correct = 0
    total = 0
    loss = 0
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.batch)
            if isinstance(output, tuple):
                output = output[-1]
            pred = output.max(dim=1)[1]
            y = data.y.view(-1, num_classes)
            loss += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
            y = y.max(dim=1)[1]
            correct += pred.eq(y).sum().item()
            total += data.num_graphs
    acc = correct / total
    loss = loss / total
    return acc, loss

# Load MUTAG and shuffle it. No random.seed call precedes this shuffle, so the
# split differs between runs. Each label is flattened to shape (-1,) before
# one-hot encoding.
originaldataset=TUDataset(root="data",name='MUTAG')
dataset=list(originaldataset)
random.shuffle(dataset)
for graph in dataset:
        graph.y = graph.y.view(-1)

dataset = prepare_dataset_onehot_y(dataset)
# No-op when graphs already carry node features (x is not None).
dataset = prepare_dataset_x( dataset )

# 80% / 10% / 10% train / validation / test split by position after shuffling.
train_nums = int(len(dataset) * 0.8)
train_val_nums = int(len(dataset) * 0.9)

train_dataset = dataset[:train_nums]
val_dataset = dataset[train_nums:train_val_nums]
test_dataset = dataset[train_val_nums:]
batch_size=32
learning_rate=0.01
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
num_features = dataset[0].x.shape[1]
# y is one-hot after prepare_dataset_onehot_y, so its length is the class count.
# train()/test() from this cell read this module-level value.
num_classes = dataset[0].y.shape[0] 

print("Num features",num_features)
print("num_classes",num_classes)
model = GIN(num_features=num_features, num_classes=num_classes, num_hidden=64).to(device)
    


# Adam with L2 weight decay; the learning rate is halved every 100 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=100, gamma=0.5)




In [None]:
# Number of graphs in the current training split.
print(len(train_dataset))

In [None]:
from torch_geometric.utils import to_networkx
import matplotlib.colors as colors
import matplotlib.cm as cmx
import networkx as nx
import matplotlib.pyplot as plt
import networkx as nx
def plotmutag(data):
    """Draw a molecule graph coloured by atom type.

    A node's atom type is the argmax of its row in ``data.x``. The seven types
    map to the colours of ``cmap`` (Carbon=0 ... Bromine=6 in ``ColorLegend``).
    Self-loops are removed and connected components with fewer than 7 nodes are
    dropped before drawing on matplotlib figure 2.
    """
    cmap = colors.ListedColormap(['blue', 'black','red','yellow','orange','green','purple'])
    ColorLegend = {'Carbon': 0,'Nitrogen': 1,'Oxygen': 2,'Fluorine': 3,'Iodine':4,'Chlorine':5,'Bromine':6}
    cNorm  = colors.Normalize(vmin=0, vmax=6)
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)

    examplelabels = torch.argmax(data.x, dim=1)
    examplegraph = to_networkx(data, to_undirected=True)
    examplegraph.remove_edges_from(nx.selfloop_edges(examplegraph))
    for component in list(nx.connected_components(examplegraph)):
        if len(component) < 7:
            examplegraph.remove_nodes_from(component)
    # node_color needs exactly one entry per remaining node, in graph node order,
    # so drop the labels of the nodes removed above.
    examplelabels = examplelabels[list(examplegraph.nodes())]

    f = plt.figure(2, figsize=(8, 8))
    ax = f.add_subplot(1, 1, 1)
    # Single-point proxy lines whose only purpose is to create legend entries.
    for label in ColorLegend:
        ax.plot([0], [0], color=scalarMap.to_rgba(ColorLegend[label]), label=label)
    nx.draw_networkx(examplegraph, node_size=150, node_color=examplelabels, cmap=cmap, vmin=0, vmax=6, with_labels=False, ax=ax)
    plt.legend(fontsize=12, loc='best')
    plt.show()



In [None]:
def graph_draw(self, graph):
    """Draw ``graph`` with every node labelled through ``self.dict``.

    Defined at module level but written as a method: ``self`` must provide a
    ``dict`` attribute mapping each node's ``"label"`` attribute to display text.
    Only nodes that have a ``"label"`` attribute get a text label. All nodes are
    drawn in red.
    """
    attr = nx.get_node_attributes(graph, "label")
    labels = {node: self.dict[value] for node, value in attr.items()}
    # NOTE(review): `self.color` looks intended for per-node colours, but a single
    # colour is used here — wire it up (as a per-node list) if that was the intent.
    nx.draw(graph, labels=labels, node_color='red')

In [None]:
dataset=list(originaldataset)
torch.manual_seed(12345)
# torch.manual_seed does not affect Python's `random` module, which performs the
# shuffle, so seed it as well to make the train/test split reproducible.
random.seed(12345)
random.shuffle(dataset)

# The first 150 shuffled graphs train the model and the remainder form a
# disjoint test set, so test accuracy does not count graphs seen in training.
train_dataset = dataset[:150]
test_dataset = dataset[150:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Show the composition of each training batch.
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data.x)
    print()

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCNEncoder(torch.nn.Module):
    """Three-layer GCN that maps a batch of graphs to one embedding per graph.

    Pipeline: GCNConv -> LeakyReLU -> GCNConv -> LeakyReLU -> GCNConv, then
    per-graph mean pooling, BatchNorm1d and dropout (p=0.5).

    Args:
        hidden_channels: width of every GCN layer and of the output embedding.
        in_channels: number of input node features. Defaults to 7, presumably
            MUTAG's one-hot atom types.
    """

    def __init__(self, hidden_channels, in_channels=7):
        super(GCNEncoder, self).__init__()
        # Fixed seed for reproducible weight initialisation. This also resets the
        # global torch RNG each time an encoder is constructed.
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.bn = torch.nn.BatchNorm1d(hidden_channels)
        self.dropout = torch.nn.Dropout(0.5)
        self.leaky_relu = torch.nn.LeakyReLU(0.2)

    def forward(self, x, edge_index, batch):
        x = self.leaky_relu(self.conv1(x, edge_index))
        x = self.leaky_relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)  # no activation after the last layer

        x = global_mean_pool(x, batch)  # one row per graph in the batch
        x = self.bn(x)
        # Equivalent to F.dropout(x, p=0.5, training=self.training), but uses the
        # registered module so train()/eval() mode applies consistently.
        return self.dropout(x)

class LinearClassifier(torch.nn.Module):
    """Single fully connected layer that maps embeddings to class logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits
class CombinedModel(torch.nn.Module):
    """GCNEncoder followed by a linear head.

    ``forward`` returns ``(embeddings, logits)``: the per-graph encoder output and
    the classifier's logits computed from it.
    """

    def __init__(self, hidden_channels, num_classes):
        super().__init__()
        self.encoder = GCNEncoder(hidden_channels)
        self.classifier = LinearClassifier(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        graph_embeddings = self.encoder(x, edge_index, batch)
        return graph_embeddings, self.classifier(graph_embeddings)


In [None]:
import torch
from torch.nn import Linear, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, BatchNorm
from torch_geometric.nn import JumpingKnowledge

class GCNEncoder(torch.nn.Module):
    """GCN/GAT encoder with a residual branch and jumping-knowledge concatenation.

    Three message-passing layers run in sequence: GCNConv, GATConv (2 heads,
    averaged because concat=False) and GCNConv. A separate GCNConv on the raw
    input is added to the third layer's output as a residual. The outputs of all
    three layers are concatenated (JumpingKnowledge 'cat'), mean-pooled per graph,
    batch-normalised and passed through dropout. The result is a
    ``3 * hidden_channels``-wide embedding per graph.

    Args:
        hidden_channels: width of each layer.
        in_channels: number of input node features. Defaults to 7, presumably
            MUTAG's one-hot atom types.
    """

    def __init__(self, hidden_channels, in_channels=7):
        super(GCNEncoder, self).__init__()
        # Fixed seed for reproducible weight initialisation (also resets the global torch RNG).
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=2, concat=False)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # Residual branch from the raw input features to the third layer's output.
        self.residual = GCNConv(in_channels, hidden_channels)

        # Normalises the concatenation of the three layer outputs.
        self.bn = BatchNorm(hidden_channels * 3)

        self.dropout = Dropout(0.5)
        self.leaky_relu = torch.nn.LeakyReLU(0.2)
        self.jk = JumpingKnowledge(mode='cat')

    def forward(self, x, edge_index, batch):
        residual = self.residual(x, edge_index)

        x1 = self.leaky_relu(self.conv1(x, edge_index))
        x2 = self.leaky_relu(self.conv2(x1, edge_index))
        x3 = self.conv3(x2, edge_index) + residual

        # Concatenate features from all layers, then pool to one row per graph.
        x = self.jk([x1, x2, x3])
        x = global_mean_pool(x, batch)
        x = self.bn(x)
        x = self.dropout(x)
        return x

class LinearClassifier(torch.nn.Module):
    """One affine layer that turns a graph embedding into class logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        scores = self.linear(x)
        return scores

class CombinedModel(torch.nn.Module):
    """Encoder plus linear head; ``forward`` returns ``(embeddings, logits)``.

    The classifier's input width is ``hidden_channels * 3``, matching an encoder
    that concatenates three layer outputs.
    """

    def __init__(self, hidden_channels, num_classes):
        super().__init__()
        self.encoder = GCNEncoder(hidden_channels)
        embedding_dim = hidden_channels * 3
        self.classifier = LinearClassifier(embedding_dim, num_classes)

    def forward(self, x, edge_index, batch):
        graph_embeddings = self.encoder(x, edge_index, batch)
        class_logits = self.classifier(graph_embeddings)
        return graph_embeddings, class_logits

# Example usage
# hidden_channels = 32  # Choose a suitable hidden size
# num_classes = 2  # Number of classes in MUTAG

# model = CombinedModel(hidden_channels=hidden_channels, num_classes=num_classes)

# # Example input data (replace with actual data)
# x = torch.rand((num_nodes, 7))  # Features
# edge_index = torch.tensor([[0, 1], [1, 0]])  # Edge index (replace with actual edges)
# batch = torch.tensor([0, 0])  # Batch tensor

# output = model(x, edge_index, batch)


In [None]:
model=CombinedModel(hidden_channels=64,num_classes=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Halve the learning rate every 100 scheduler steps (one step per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

# Single source for the epoch count, used by both the training loop and the
# progress message so the two cannot disagree.
num_epochs = 2000


def train():
    """Train the module-level ``model`` on ``train_loader`` for ``num_epochs`` epochs.

    Uses the module-level ``optimizer``, ``criterion`` (cross entropy on the
    logits returned by the model) and ``scheduler`` (stepped once per epoch).
    After each epoch it prints the learning rate and the last batch loss.
    """
    model.train()

    for epoch in range(num_epochs):
        for data in train_loader:  # Iterate in batches over the training dataset.
            _, out = model(data.x, data.edge_index, data.batch)  # embeddings unused here
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs}, Learning Rate: {scheduler.get_last_lr()[0]}", loss.item())


train()



In [None]:
def test(model, dataset):
    """Return the fraction of items in ``dataset`` whose predicted class equals ``data.y``.

    Each item is passed as ``model(x, edge_index, batch)``, which must return
    ``(embeddings, logits)``. The prediction is the argmax of the logits. Items
    are expected to be single graphs (one prediction each), because the
    denominator is ``len(dataset)``.
    """
    model.eval()
    correct = 0
    # Inference only: skip building an autograd graph for every forward pass.
    with torch.no_grad():
        for data in dataset:
            _, out = model(data.x, data.edge_index, data.batch)
            if out.argmax(dim=1) == data.y:
                correct += 1
    return correct / len(dataset)
# Accuracy over `dataset` as currently bound. NOTE(review): this is presumably the
# full shuffled MUTAG list, training graphs included, so it is not a held-out score.
acc=test(model,dataset)
print(acc)

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(model, dataset, class_dict):
    """
    Evaluate the model on the provided dataset, compute the confusion matrix,
    and plot it with class names.

    Parameters:
    - model: Trained GNN model, called as ``model(x, edge_index, batch)`` and
      expected to return ``(embeddings, logits)``.
    - dataset: Iterable of data objects with ``x``, ``edge_index``, ``batch`` and ``y``.
    - class_dict: Dictionary mapping class labels to class names, e.g., {0: 'Class A', 1: 'Class B'}

    Every class in ``class_dict`` gets a row and a column, even if it never occurs
    among the labels or predictions, so the matrix always lines up with the tick labels.
    """

    # Step 1: Evaluate the model and collect predictions and true labels
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in dataset:
            _, out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            all_preds.append(pred.cpu().numpy())
            all_labels.append(data.y.cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Step 2: Compute the confusion matrix over a fixed label set. Without
    # `labels=`, only classes present in y_true or y_pred are included, and the
    # matrix would no longer match `class_names` (e.g. if a class is never predicted).
    class_labels = sorted(class_dict)
    class_names = [class_dict[label] for label in class_labels]
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_labels)

    # Step 3: Plot the confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", 
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

# Example usage:
# Assuming the class labels are {0: 'Mutagenic', 1: 'Non-Mutagenic'}
#class_dict = {0: 'Mutagenic', 1: 'Non-Mutagenic'}

# Example dataset (assuming it's a list of data objects)
# dataset = [...]

# Call the function with the model, dataset (as a list), and class dictionary
#plot_confusion_matrix(model, dataset, class_dict)


In [None]:
# Display names per label id (this mapping treats label 1 as mutagenic).
class_dict={0:'Non-Mutagenic',1:'Mutagenic'}
# NOTE(review): `dataset` presumably still holds every MUTAG graph here, training graphs included.
plot_confusion_matrix(model,dataset,class_dict)

In [None]:
# Epoch loop for the first-cell pipeline. It expects train(model, loader) ->
# (model, loss) and test(model, loader) -> (acc, loss).
# NOTE(review): later cells rebind `train` to a no-argument function and `test` to
# test(model, dataset) returning a single float. If those cells ran first, this
# loop raises. It also reads `val_loader`, which only the first cell defines.
for epoch in range(1, 200):
    model, train_loss = train(model, train_loader)
    train_acc = 0
    val_acc, val_loss = test(model, val_loader)
    test_acc, test_loss = test(model, test_loader)
    scheduler.step()
    print("Train loss",train_loss, "Epoch",epoch)
    print("Test Accuracy",test_acc,"Test_loss",test_loss)
    tensorboard_writer.add_scalar('Train Loss', train_loss, epoch)
    tensorboard_writer.add_scalar('Validation Loss', val_loss, epoch)
    tensorboard_writer.add_scalar('Test Loss', test_loss, epoch)
    tensorboard_writer.add_scalar('Validation Accuracy', val_acc, epoch)
    tensorboard_writer.add_scalar('Test Accuracy', test_acc, epoch)

    logger.info('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f},  Val Acc: {: .6f}, Test Acc: {: .6f}'.format(
        epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

In [None]:
# Unpacks (acc, loss), so this needs the first test(model, loader). The later
# test(model, dataset) returns a single float and cannot be unpacked this way.
test_acc,_=test(model,test_loader)
print(test_acc)

In [None]:
 save_path='model/mutag.pth'
 # torch.save opens the file for writing and fails if the target directory does
 # not exist yet, so create it first ('.' when the path has no directory part).
 os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
 torch.save(model.state_dict(), save_path)

In [None]:
load_path = 'model/mutag.pth'

# Reuses the `model` object built in an earlier cell; its architecture must
# match the saved checkpoint.

# Load the saved weights onto the device the model currently lives on, so a
# checkpoint written from a GPU session also loads on a CPU-only machine.
model.load_state_dict(torch.load(load_path, map_location=next(model.parameters()).device))
model.eval()

In [None]:
model.eval()
dataset2=list(originaldataset)
classifieraccuracy=0  # count (not a ratio) of graphs whose prediction matches the label
random.shuffle(dataset2)
# Labels stay as raw integer class ids here (no one-hot conversion).
dataset2 = prepare_dataset_x( dataset2 )
num_features = dataset2[0].x.shape[1]
# y is not one-hot here, so y.shape[0] is the label tensor's length rather than
# the class count; count the distinct labels instead.
num_classes = len({int(data.y) for data in dataset2})
print(num_features)
print(num_classes)
newdataset=[]
latentdata1=[]  # embeddings of graphs predicted as class 0
latentdata2=[]  # embeddings of graphs predicted as class 1
model.to('cpu')
# Pure inference: no_grad keeps the stored embeddings free of autograd history
# and avoids holding every forward graph in memory.
with torch.no_grad():
    for data in dataset2:
        emb, output = model(data.x, data.edge_index, data.batch)
        pred = output.argmax(dim=1)
        if pred == data.y:
            classifieraccuracy += 1
        # Relabel each graph with the model's prediction; later cells build class
        # graphons from `newdataset`, i.e. from the classifier's view of the classes.
        if pred == 0:
            data.y = torch.zeros_like(data.y)
            latentdata1.append(emb)
        elif pred == 1:
            data.y = torch.ones_like(data.y)
            latentdata2.append(emb)
        newdataset.append(data)

In [None]:
# Estimate one graphon per (predicted) class from the relabelled graphs.
dataset=newdataset
# Group graphs by label; each entry is unpacked below as (label, graphs).
classgraphs=split_class_graphs(dataset)

# Graph statistics; only the median node count is used, as the graphon resolution.
avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(dataset)
resolution = int(median_num_nodes)
#print("resolution is",resolution)
graphons=[]
for label,graphs in classgraphs:
    #print("Label is",label)
    #print("graph is",graphs[0])
    # Align the class's adjacency matrices, padded/truncated to `resolution` nodes.
    align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(
                    graphs, padding=True, N=resolution)
    #print("Aligned adj",align_graphs_list[8].shape,align_graphs_list[56].shape)
    graphon = universal_svd(align_graphs_list, threshold=0.2)
    #print("Graphon is ",graphon.shape)

    graphons.append((label, graphon))
#two_graphons = random.sample(graphons, 2)
print(graphons)
# Takes the first two entries, i.e. assumes at least two classes.
two_graphons= [graphons[0] , graphons[1]]
print(graphons[0][0], graphons[1][0])
# NOTE(review): how `la` weights the two graphons is defined in utils.two_graphons_mixup — confirm.
new_graph = two_graphons_mixup(two_graphons, la=1.0, num_sample=1,show=True)



# ng=two_graphons_mixup(two_graphons,la=1.0,num_sample=1)
# print(new_graph)
# print(ng)


Plotting the graphons

In [None]:

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv,GINConv
from torch.distributions import Bernoulli,Categorical
import matplotlib.cm as cmxplt
#print(graphons[1][1])
maxval=0.028  # NOTE(review): not used in this cell
# Show each class graphon as a heat map in its own figure, printing its label first.
plt.figure(1)
plt.axis('off')
print(graphons[0][0])
plt.imshow(graphons[0][1], cmap="inferno")
plt.figure(2)
print(graphons[1][0])
plt.axis('off')
plt.imshow(graphons[1][1], cmap="inferno")

Ground truth examples

In [None]:
# Re-runs the same per-class graphon estimation as the earlier graphon cell
# (identical code), rebuilding `graphons` and `two_graphons`.
dataset=newdataset
classgraphs=split_class_graphs(dataset)

avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(dataset)
# Graphon resolution = median node count.
resolution = int(median_num_nodes)
#print("resolution is",resolution)
graphons=[]
for label,graphs in classgraphs:
    #print("Label is",label)
    #print("graph is",graphs[0])
    align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(
                    graphs, padding=True, N=resolution)
    #print("Aligned adj",align_graphs_list[8].shape,align_graphs_list[56].shape)
    graphon = universal_svd(align_graphs_list, threshold=0.2)
    #print("Graphon is ",graphon.shape)

    graphons.append((label, graphon))
#two_graphons = random.sample(graphons, 2)
print(graphons)
two_graphons= [graphons[0] , graphons[1]]
print(graphons[0][0], graphons[1][0])
new_graph = two_graphons_mixup(two_graphons, la=1.0, num_sample=1,show=True)



# ng=two_graphons_mixup(two_graphons,la=1.0,num_sample=1)
# print(new_graph)
# print(ng)


In [None]:
from collections import Counter
import numpy as np
import copy
from typing import List, Tuple

def _mode_pool_node_features(aligned_node_features, N):
    """For each aligned node position, pick the most frequent non-zero feature vector (first seen wins ties)."""
    num_features = aligned_node_features[0].shape[1]
    pooled_node_features = np.zeros((N, num_features))

    for node_pos in range(N):
        # Count feature vectors present at this position (unpadded graphs may be shorter than N).
        feature_counts = Counter(
            tuple(features[node_pos])
            for features in aligned_node_features
            if node_pos < features.shape[0]
        )
        # most_common() orders by count and keeps first-seen order for ties.
        # All-zero (padding) vectors are skipped, so positions where every graph
        # is zero stay zero.
        for feature, _ in feature_counts.most_common():
            if any(feature):
                pooled_node_features[node_pos] = feature
                break

    return pooled_node_features


def align_x_graphscorrected2(
    graphs: List[np.ndarray], 
    node_x: List[np.ndarray], 
    padding: bool = False, 
    N: int = None
) -> Tuple[List[np.ndarray], np.ndarray, List[np.ndarray], int, int]:
    """
    Align graphs by descending node degree and pool their node features.

    Each adjacency matrix is permuted so that nodes appear in order of
    decreasing normalised degree (mean of in- and out-degree, divided by its
    sum). It is then truncated to ``N`` nodes and, with ``padding=True``,
    zero-padded to exactly ``N`` nodes. Node features are permuted the same way.
    The pooled feature at each aligned position is the most frequent non-zero
    feature vector across the graphs at that position (a majority vote, not an
    element-wise max). Ties go to the vector seen first, i.e. from the earliest graph.

    :param graphs: A list of adjacency matrices
    :param node_x: A list of node feature matrices, one per graph (rows follow node order)
    :param padding: Whether to zero-pad every graph to exactly N nodes
    :param N: Target number of nodes for alignment; defaults to the largest graph size
    :return:
        aligned_graphs: A list of aligned adjacency matrices
        final_node_features: The pooled (N, num_features) node feature matrix
        normalized_node_degrees: A list of sorted normalized node degrees (column vectors).
            An edgeless graph keeps all-zero degrees instead of NaN.
        max_num: The alignment size N
        min_num: Minimum number of nodes before alignment
    """
    num_nodes = [graph.shape[0] for graph in graphs]
    max_num = max(num_nodes)
    min_num = min(num_nodes)

    if N is None:
        N = max_num  # Use the maximum number of nodes if N is not specified

    aligned_graphs = []
    aligned_node_features = []
    normalized_node_degrees = []

    for i, graph in enumerate(graphs):
        features = node_x[i]
        num_i = graph.shape[0]
        keep = min(N, num_i)

        node_degree = 0.5 * np.sum(graph, axis=0) + 0.5 * np.sum(graph, axis=1)
        total_degree = np.sum(node_degree)
        if total_degree > 0:
            node_degree = node_degree / total_degree
        # else: edgeless graph; keep zeros rather than dividing by zero (NaN).

        idx = np.argsort(node_degree)[::-1]  # indices by descending node degree

        sorted_node_degree = node_degree[idx].reshape(-1, 1)
        sorted_graph = graph[np.ix_(idx, idx)]
        sorted_node_x = features[idx]

        if padding:
            normalized_node_degree = np.zeros((N, 1))
            normalized_node_degree[:keep, :] = sorted_node_degree[:keep]

            aligned_graph = np.zeros((N, N))
            aligned_graph[:keep, :keep] = sorted_graph[:keep, :keep]

            aligned_node_x = np.zeros((N, features.shape[1]))
            aligned_node_x[:keep, :] = sorted_node_x[:keep]
        else:
            normalized_node_degree = sorted_node_degree[:N]
            aligned_graph = sorted_graph[:N, :N]
            aligned_node_x = sorted_node_x[:N]

        aligned_graphs.append(aligned_graph)
        aligned_node_features.append(aligned_node_x)
        normalized_node_degrees.append(normalized_node_degree)

    pooled_node_features = _mode_pool_node_features(aligned_node_features, N)
    return aligned_graphs, pooled_node_features, normalized_node_degrees, N, min_num


In [None]:
from utils import stat_graph, split_class_x_graphs, align_x_graphs,align_x_graphscorrected
from utils import two_x_graphons_mixup, universal_svd
# Per-class graphons plus pooled node features. Each classgraphs entry is
# unpacked as (label, adjacency matrices, node feature matrices).
classgraphs=split_class_x_graphs(dataset)
Alignedfeatures=[]  # pooled feature matrix per class, in classgraphs order
#print("Classgraphs is", classgraphs)


avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(dataset)
# NOTE(review): +6 enlarges the resolution beyond the median node count; the rationale is not stated here.
resolution = int(median_num_nodes)+6
print("resolution is",resolution)
graphons=[]
#print(classgraphs)
for label,graphs,nodes in classgraphs:
    print(len(graphs))
    print("Label is",label)

    print("graph is",np.shape(graphs[0]))

    print("Nodes is",np.shape(nodes[0]))



    # Degree-sorted, zero-padded adjacency matrices plus majority-vote node features.
    align_graphs_list,alignx, normalized_node_degrees, max_num, min_num = align_x_graphscorrected2(
                    graphs,nodes, padding=True, N=resolution)
    
    #print("Aligned adj",align_graphs_list[8].shape,align_graphs_list[56].shape)
    print("Alignx is",np.shape(alignx))

    graphon = universal_svd(align_graphs_list, threshold=0.2)
    print("Graphon is ",graphon.shape)

    graphons.append((label, graphon))
    Alignedfeatures.append(alignx)
# Takes the first two entries, i.e. assumes at least two classes.
two_graphons= [graphons[0] , graphons[1]]

In [None]:
#print(Alignedfeatures[0])
# Shape of the stacked per-class pooled features, then the second class's matrix.
print(np.shape(Alignedfeatures))
print(Alignedfeatures[1])

In [None]:
#groundtruthset=list(originaldataset)
from torch_geometric.utils import to_networkx
import networkx as nx
# Count graphs whose ground-truth label is 0. `class0graph` ends up holding the
# last such graph (self-loops removed); the plotting/break lines are commented out.
groundtruthset=list(originaldataset)
random.shuffle(groundtruthset)
#print(groundtruthset)
count=0
for i in range(len(groundtruthset)):
    data=groundtruthset[i]
    if(data.y==0):
        #print(data.x)
        class0graph=to_networkx(data,to_undirected=True)
        class0graph.remove_edges_from(nx.selfloop_edges(class0graph))
        count+=1
        #plt.figure(1)
        # plotmutag(data)
        # break
print(count)       

In [None]:
def assign_same_features_to_data(data_list, features_list):
    """
    Give every Data object in ``data_list`` the same node-feature values as ``x``.

    :param data_list: List of PyTorch Geometric Data objects (modified in place)
    :param features_list: Node features (nested list, NumPy array or tensor),
        converted to a float32 tensor
    :return: The same ``data_list``, with each ``x`` set
    """
    # as_tensor accepts tensors without the copy-construct warning torch.tensor emits.
    features_tensor = torch.as_tensor(features_list, dtype=torch.float)

    for data in data_list:
        # Separate copy per graph: an in-place edit of one graph's x must not
        # silently change every other graph (or the caller's tensor).
        data.x = features_tensor.clone()

    return data_list

In [None]:
from torch_geometric.utils import to_networkx
import matplotlib.colors as colors
import matplotlib.cm as cmx
import networkx as nx
import matplotlib.pyplot as plt
import torch

def plotmutag2(data):
    """Draw a molecule graph coloured by atom type, hiding small fragments.

    Each node's atom type is the argmax of its row in ``data.x``. Self-loops are
    removed and only connected components with at least 7 nodes are kept; the
    atom types are re-indexed to match the surviving nodes. The result is drawn
    on matplotlib figure 2, one colour per atom type.
    """
    atom_palette = colors.ListedColormap(['blue', 'black', 'red', 'yellow', 'orange', 'green', 'purple'])
    legend_codes = {
        'Carbon': 0,
        'Nitrogen': 1,
        'Oxygen': 2,
        'Fluorine': 3,
        'Iodine': 4,
        'Chlorine': 5,
        'Bromine': 6
    }
    color_mapper = cmx.ScalarMappable(norm=colors.Normalize(vmin=0, vmax=6), cmap=atom_palette)

    atom_types = torch.argmax(data.x, dim=1)

    molecule = to_networkx(data, to_undirected=True)
    molecule.remove_edges_from(nx.selfloop_edges(molecule))

    # Collect the nodes of every component large enough to keep.
    surviving = set()
    for component in list(nx.connected_components(molecule)):
        if len(component) >= 7:
            surviving.update(component)

    molecule = molecule.subgraph(surviving).copy()

    # Re-index atom types so they line up with the subgraph's node order.
    atom_types = atom_types[list(molecule.nodes())]

    figure = plt.figure(2, figsize=(8, 8))
    axes = figure.add_subplot(1, 1, 1)
    for element_name, code in legend_codes.items():
        axes.plot([0], [0], color=color_mapper.to_rgba(code), label=element_name)
    nx.draw_networkx(
        molecule,
        node_size=150,
        node_color=atom_types,
        cmap=atom_palette,
        vmin=0,
        vmax=6,
        ax=axes,
        with_labels=False,
    )
    plt.show()

# Example usage
# plotmutag(data)  # Assuming 'data' is a PyTorch Geometric data object


In [None]:
#new_graph=[]
import time
start_time=time.time()
import networkx as nx
from torch_geometric.utils import to_networkx
soft=torch.nn.Softmax(dim=1)  # built once rather than once per candidate graph
for y in range(1):
    # Sample 20 graphs from the mixed graphons.
    # NOTE(review): the meaning of la=0.0 is defined by utils.two_graphons_mixup — confirm.
    new_graph = two_graphons_mixup(two_graphons, la=0.0, num_sample=20)
    print("Label of new graph is",new_graph[0].y)
    # Every sampled graph gets the pooled node features of the first classgraphs entry.
    new_graph= assign_same_features_to_data(new_graph,Alignedfeatures[0])
    explainergraph=new_graph
    count=0
    maxprob=0
    bestdata=None
    # Keep the sample the model assigns the highest class-1 probability.
    # Inference only, so no autograd graph is built.
    with torch.no_grad():
        for data in explainergraph:
            _,targetoutput=model(data.x,data.edge_index,None)
            problitites=soft(targetoutput)
            if (maxprob<problitites[0][1]):
                maxprob= problitites[0][1]
                print(problitites[0][1])
                bestdata=data
    plt.figure(y+1)
    if bestdata is None:
        print("No candidate graph had a positive class-1 probability; nothing to plot")
    else:
        plotmutag2(bestdata)
endtime=time.time()
executiontime=endtime-start_time
print("executiontime is",executiontime)
    #newlist=[]




In [None]:
accuracybound=[]
stdbound=[]
lalist=[]
boundaryembeddings=[]
soft=torch.nn.Softmax(dim=1)
# Sweep la over 0.30, 0.35, ..., 0.70 using an integer grid. Accumulating
# `la += 0.05` in floating point reaches 0.7000000000000001 on the 8th step,
# so a `while la <= 0.7` loop silently skips the 0.7 endpoint.
for la in [round(0.3 + 0.05 * step, 2) for step in range(9)]:
    boundary_graph = two_graphons_mixup(two_graphons, la=la, num_sample=100,show=True)
    # Every sampled graph gets the pooled node features of the first classgraphs entry.
    boundary_graph= assign_same_features_to_data(boundary_graph,Alignedfeatures[0])
    label=torch.argmax(boundary_graph[1].y,dim=-1)
    print(label)

    boundaryaccuracy=[]
    # NOTE(review): boundary_graph is sampled once per la, outside the repetition
    # loop. If the model is in eval mode, all 20 repetitions presumably select the
    # same graph: the std below is then 0 and boundaryembeddings receives 20
    # identical copies. Confirm whether resampling per repetition was intended.
    with torch.no_grad():  # inference only; keeps stored embeddings free of autograd history
        for numexp in range(20):
            min1=1
            for data in boundary_graph:
                embedding,out=model(data.x,data.edge_index,data.batch)
                problities=soft(out)
                # Keep the sample whose class-0 probability is closest to 0.5.
                if abs(problities[0][0]-0.5)<min1:
                    min1= abs(problities[0][0]-0.5)
                    boundaryprobs = problities[0][0]
                    latentboundary=embedding
                    bestdata=data
            boundaryaccuracy.append(boundaryprobs)
            boundaryembeddings.append(latentboundary)
    boundaryaccuracy=torch.stack(boundaryaccuracy)
    accuracybound.append(boundaryaccuracy.mean(dim=0))
    stdbound.append(boundaryaccuracy.std(dim=0))
    lalist.append(la)
    

    
        
        
  
    #print("Label of new graph is",new_graph[1].y)



# ng=two_graphons_mixup(two_graphons,la=1.0,num_sample=1)

In [None]:
# Embeddings of the graphs the classifier predicted as class 0.
print(latentdata1)

In [None]:
# Boundary metrics between predicted-class-0 embeddings and the boundary embeddings.
# NOTE(review): boundary_margin / boundary_thickness are defined in a later cell,
# which must run first. The visible boundary_complexity definition is commented
# out — confirm it is defined somewhere before running this cell.
latentclass=latentdata1
margin=boundary_margin(latentclass,boundaryembeddings[:len(latentclass)])
print("margin is",margin)
classifier=model.classifier

thickness=boundary_thickness(latentclass ,boundaryembeddings[:len(latentclass)],classifier,0,1)
print(thickness)
print("thickness is",thickness)
# NOTE(review): 64 is passed as the embedding dimensionality. The second
# CombinedModel's embeddings are hidden_channels * 3 wide — confirm the intended D.
complexity=boundary_complexity(boundaryembeddings,64)
print("complexity is",complexity)

In [None]:
# Computing boundary metrics
import torch
import torch.nn.functional as F

def boundary_margin(embeddings_c1, embeddings_c2):
    """
    Compute the boundary margin.

    The margin is the smallest Euclidean distance between *paired* rows:
    row ``i`` of ``embeddings_c1`` is compared only with row ``i`` of
    ``embeddings_c2`` (this is not an all-pairs minimum).

    Args:
    - embeddings_c1 (torch.Tensor or sequence of torch.Tensor): Embeddings of
      class c1 graphs. A sequence is concatenated along dim 0; 1-D elements
      are treated as single rows.
    - embeddings_c2 (torch.Tensor or sequence of torch.Tensor): Embeddings of
      boundary graphs between class c1 and c2, in the same layout and with the
      same number of rows as ``embeddings_c1``.

    Returns:
    - margin (float): The boundary margin.

    Raises:
    - ValueError: If an input is empty, or if the two inputs have different
      numbers of rows (which would otherwise broadcast silently when one side
      has a single row).
    """

    def _as_rows(embeddings, name):
        # Normalise to a 2-D (num_rows, embedding_dim) tensor.
        if isinstance(embeddings, torch.Tensor):
            rows = embeddings.unsqueeze(0) if embeddings.dim() == 1 else embeddings
        else:
            parts = [e.unsqueeze(0) if e.dim() == 1 else e for e in embeddings]
            if not parts:
                raise ValueError(f"{name} must contain at least one embedding")
            rows = torch.cat(parts, dim=0)
        if rows.size(0) == 0:
            raise ValueError(f"{name} must contain at least one embedding")
        return rows

    rows_c1 = _as_rows(embeddings_c1, "embeddings_c1")
    rows_c2 = _as_rows(embeddings_c2, "embeddings_c2")
    if rows_c1.size(0) != rows_c2.size(0):
        raise ValueError(
            "expected the same number of embeddings, got "
            f"{rows_c1.size(0)} and {rows_c2.size(0)}"
        )

    distances = torch.norm(rows_c1 - rows_c2, dim=1)
    return torch.min(distances).item()

def boundary_thickness(embeddings_c1, embeddings_c1_c2, model, c1, c2, gamma=0.75, num_points=100):
    """
    Compute the boundary thickness between class ``c1`` and class ``c2``.

    For every pair ``(emb_c1, emb_c1_c2)`` the straight segment
    ``h(t) = (1 - t) * emb_c1 + t * emb_c1_c2`` for ``t`` in [0, 1] is sampled
    at ``num_points`` points and fed to ``model``. The integral over ``t`` of
    the indicator ``gamma > p_c1(h(t)) - p_c2(h(t))`` (softmax probabilities)
    is approximated with the trapezoidal rule. It is then scaled by the
    segment length ``||emb_c1 - emb_c1_c2||``. The result is the average over
    all pairs.

    Args:
    - embeddings_c1 (sequence of torch.Tensor): Embeddings of class c1 graphs.
    - embeddings_c1_c2 (sequence of torch.Tensor): Boundary-graph embeddings,
      paired element-wise with ``embeddings_c1``; ``zip`` stops at the shorter.
    - model (callable): Maps a batch of interpolated embeddings to logits whose
      dim 1 indexes classes.
    - c1, c2 (int): Class indices compared in the probability gap.
    - gamma (float): Threshold on the probability gap.
    - num_points (int): Number of samples along each segment.

    Returns:
    - float: Mean thickness over all pairs.

    Raises:
    - ValueError: If no pairs are given.
    """
    thickness_values = []

    # Pure evaluation: the result is converted with .item(), so tracking
    # gradients would only waste memory.
    with torch.no_grad():
        for emb_c1, emb_c1_c2 in zip(embeddings_c1, embeddings_c1_c2):
            # Build t on the embeddings' device/dtype so CUDA inputs work.
            t_values = torch.linspace(0, 1, num_points, device=emb_c1.device, dtype=emb_c1.dtype)
            h_t = (1 - t_values).unsqueeze(1) * emb_c1 + t_values.unsqueeze(1) * emb_c1_c2

            # Single forward pass per segment (`model` is the classifier head).
            logits_h_t = model(h_t)
            probs_h_t = F.softmax(logits_h_t, dim=1)

            # 1 where the c1-vs-c2 probability gap is below gamma.
            integrand = (gamma > (probs_h_t[:, c1] - probs_h_t[:, c2])).to(t_values.dtype)

            # Trapezoidal approximation of the integral over t in [0, 1].
            integral = torch.trapz(integrand, t_values)

            thickness_value = (emb_c1 - emb_c1_c2).norm() * integral
            thickness_values.append(thickness_value.item())

    if not thickness_values:
        raise ValueError("boundary_thickness needs at least one embedding pair")
    return sum(thickness_values) / len(thickness_values)

# def boundary_complexity(embeddings, D):
#     """
#     Compute the boundary complexity.
    
#     Args:
#     - embeddings (torch.Tensor): Embeddings of the boundary graphs with shape (num_graphs, embedding_dim).
#     - D (int): Dimensionality of the embeddings.
    
#     Returns:
#     - complexity (float): The boundary complexity.
#     """
#     # Compute the covariance matrix of the embeddings
#     embeddings=torch.cat(embeddings,dim=0)
#     covariance_matrix = torch.cov(embeddings.T)
    
#     # Compute the eigenvalues of the covariance matrix
#     eigenvalues = torch.linalg.eigvalsh(covariance_matrix)
#     print(eigenvalues)
    
#     # Normalize the eigenvalues
#     eigenvalues_normalized = eigenvalues / eigenvalues.sum()
#     print(eigenvalues_normalized)
    
#     # Compute the entropy of the normalized eigenvalues
#     entropy = -torch.sum(eigenvalues_normalized * torch.log(eigenvalues_normalized + 1e-7))
#     print(entropy)
    
#     # Normalize the entropy by dividing it by log(D)
#     complexity = entropy / torch.log(torch.tensor(D, dtype=torch.float32))
    
#     return complexity.item()
def boundary_complexity(embeddings, D=None, epsilon=1e-7):
    """
    Compute the boundary complexity.

    The complexity is the Shannon entropy of the normalised eigenvalue
    spectrum of the embeddings' covariance matrix, divided by ``log(D)``.
    A spectrum spread evenly over ``D`` eigenvalues gives 1; a spectrum
    concentrated in one eigenvalue gives a value near 0.

    Args:
    - embeddings (torch.Tensor or sequence of torch.Tensor): Embeddings of the
      boundary graphs. A tensor must be shaped (num_graphs, embedding_dim); a
      sequence is concatenated along dim 0, with 1-D elements treated as rows.
    - D (int, optional): Dimensionality used for normalisation. Defaults to
      the embedding width (number of columns).
    - epsilon (float): Small value added to eigenvalues to prevent log(0).

    Returns:
    - complexity (float): The boundary complexity.

    Raises:
    - ValueError: If fewer than two embeddings are given (the sample
      covariance is undefined), or if ``D <= 1`` (``log(D)`` would be <= 0).
    """
    # Flatten and concatenate embeddings into (num_graphs, embedding_dim).
    if isinstance(embeddings, torch.Tensor):
        matrix = embeddings.unsqueeze(0) if embeddings.dim() == 1 else embeddings
    else:
        parts = [e.unsqueeze(0) if e.dim() == 1 else e for e in embeddings]
        if not parts:
            raise ValueError("boundary_complexity needs at least two embeddings")
        matrix = torch.cat(parts, dim=0)
    if matrix.size(0) < 2:
        raise ValueError("boundary_complexity needs at least two embeddings")

    if D is None:
        D = matrix.size(1)
    if D <= 1:
        raise ValueError("D must be greater than 1 so that log(D) is positive")

    matrix = matrix.detach()

    # torch.cov treats rows as variables, hence the transpose. atleast_2d
    # keeps a single-column input (reduced to a scalar by torch.cov) usable.
    covariance_matrix = torch.atleast_2d(torch.cov(matrix.T))

    # Add a small value to the diagonal for regularization; build it on the
    # covariance's device/dtype so GPU embeddings work.
    covariance_matrix = covariance_matrix + epsilon * torch.eye(
        covariance_matrix.size(0), device=covariance_matrix.device, dtype=covariance_matrix.dtype
    )

    # Compute the eigenvalues of the (symmetric) covariance matrix
    eigenvalues = torch.linalg.eigvalsh(covariance_matrix)

    # Clamp eigenvalues to avoid very small negative values due to numerical errors
    eigenvalues = torch.clamp(eigenvalues, min=epsilon)

    # Normalize the eigenvalues into a probability distribution
    eigenvalues_normalized = eigenvalues / eigenvalues.sum()

    # Compute the entropy of the normalized eigenvalues
    entropy = -torch.sum(eigenvalues_normalized * torch.log(eigenvalues_normalized + epsilon))

    # Normalize the entropy by dividing it by log(D)
    log_d = torch.log(torch.tensor(float(D), dtype=entropy.dtype, device=entropy.device))
    complexity = entropy / log_d

    return complexity.item()

In [None]:
# Sweep the mixup weight la over 0.30, 0.35, ..., 0.70. For each la, run 100
# trials. Each trial samples 90 mixed graphs and keeps the class-0
# probability of the graph whose class-0 probability is closest to 0.5.
# The per-la mean/std of those probabilities go into accuracybound/stdbound.
accuracybound = []
stdbound = []
lalist = []
soft = torch.nn.Softmax(dim=1)  # hoisted: identical for every graph

# Build the grid from integer steps: repeatedly adding 0.05 to a float can
# drift past 0.7 and silently drop the last point.
for la in [round(0.3 + 0.05 * step, 2) for step in range(9)]:
    boundaryaccuracy = []
    for numexp in range(100):
        # Named `closest_gap` (not `min`) so the builtin min() is not shadowed
        # for the rest of the notebook.
        closest_gap = 1
        boundary_graph = two_graphons_mixup(two_graphons, la=la, num_sample=90)
        boundary_graph = assign_same_features_to_data(boundary_graph, Alignedfeatures[0])
        with torch.no_grad():  # inference only
            for data in boundary_graph:
                _, out = model(data.x, data.edge_index, None)
                problities = soft(out)
                gap = abs(problities[0][0] - 0.5)
                if gap < closest_gap:
                    closest_gap = gap
                    boundaryprobs = problities[0][0]
                    bestdata = data  # kept for inspecting/plotting the chosen graph
        boundaryaccuracy.append(boundaryprobs)

    boundaryaccuracy = torch.stack(boundaryaccuracy)
    accuracybound.append(boundaryaccuracy.mean(dim=0))
    print("Mean is", boundaryaccuracy.mean(dim=0))
    print("Std is", boundaryaccuracy.std(dim=0))
    stdbound.append(boundaryaccuracy.std(dim=0))
    lalist.append(la)
    

    
        
        
  
    #print("Label of new graph is",new_graph[1].y)



# ng=two_graphons_mixup(two_graphons,la=1.0,num_sample=1)

In [None]:
# Print the per-lambda mean and standard deviation of the boundary-graph
# class-0 probability, then plot mean +/- std against lambda.
for summary in (accuracybound, stdbound):
    print(summary)
plot_mean_with_error(accuracybound, stdbound, lalist)

In [None]:
# Sample 100 graphs from the la=0.5 mixup of the two class graphons, pick the
# one whose predicted class-1 probability is closest to 0.5 and draw it.
import matplotlib.pyplot as plt  # needed below; otherwise only imported in a later cell
import networkx as nx
from torch_geometric.utils import to_networkx

new_graph = two_graphons_mixup(two_graphons, la=0.5, num_sample=100)
print("Label of new graph is", new_graph[0].y)
dataset3 = list(originaldataset)

# Degree features are computed over the original + new graphs together,
# presumably so the new graphs get the same feature width — TODO confirm.
newlist = list(originaldataset) + new_graph
newlist = prepare_dataset_x(newlist)
explainergraph = newlist[len(dataset3):]

soft = torch.nn.Softmax(dim=1)
maxprob = 0
bestdata = None  # reset so a stale graph from another cell is never drawn
with torch.no_grad():  # inference only
    for data in explainergraph:
        # The model returns (embedding, logits), as unpacked elsewhere with
        # `emb, out = model(data.x, data.edge_index, None)`; only logits are needed.
        _, targetoutput = model(data.x, data.edge_index, None)
        problitites = soft(targetoutput)
        if abs(maxprob - 0.5) > abs(problitites[0][1] - 0.5):
            maxprob = problitites[0][1]
            bestdata = data
            print(problitites)

if bestdata is None:
    print("No sampled graph had a class-1 probability closer to 0.5 than the initial value")
else:
    examplegraph = to_networkx(bestdata, to_undirected=True)
    plt.figure(1)
    nx.draw_networkx(examplegraph, node_size=20, node_color='lightblue', with_labels=False)




In [None]:
import torch

def convert_to_one_hot(data_list, num_classes=2):
    """
    Convert the 'data.y' attribute of each data object into one-hot rows, in place.

    Args:
        data_list (list): PyTorch Geometric data objects (any object with a
            tensor attribute ``y`` holding integer-valued class labels).
            ``y`` may be a scalar, a 1-D tensor, or a single column (N, 1).
        num_classes (int): The number of classes. Default is 2.

    Returns:
        list: The same ``data_list``. Each object's ``y`` is replaced by a
        float tensor of shape (num_labels, num_classes); a scalar label
        becomes a single row.

    Raises:
        ValueError: If ``y`` has more than one column (it is not a label vector).
        RuntimeError: From ``one_hot`` if a label is negative or >= num_classes.
    """
    for data in data_list:
        labels = data.y
        if labels.dim() > 1 and labels.size(-1) != 1:
            raise ValueError(
                f"expected integer class labels, got y with shape {tuple(labels.shape)}"
            )
        # Flatten so scalar (0-dim) labels work too; one_hot keeps the labels'
        # device, unlike a CPU-allocated zeros buffer filled by scatter_.
        labels = labels.reshape(-1).long()
        data.y = torch.nn.functional.one_hot(labels, num_classes=num_classes).float()

    return data_list


In [None]:
from utils import stat_graph, split_class_x_graphs, align_x_graphs,align_x_graphscorrected
from utils import two_x_graphons_mixup, universal_svd
# Estimate one graphon per class: align each class's graphs (and node
# features) to `resolution` nodes, then apply universal_svd. The aligned
# node features of each class are kept in Alignedfeatures, in class order.
classgraphs = split_class_x_graphs(dataset)
Alignedfeatures = []

avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(dataset)
resolution = int(median_num_nodes)
print("resolution is", resolution)
graphons = []
for label, graphs, nodes in classgraphs:
    print(len(graphs))
    print("Label is", label)
    print("graph is", np.shape(graphs[0]))
    # Inspect the first entry, like graphs[0] above; index 1 raised
    # IndexError for a class containing a single graph.
    print("Nodes is", np.shape(nodes[0]))

    align_graphs_list, alignx, normalized_node_degrees, max_num, min_num = align_x_graphscorrected(
        graphs, nodes, padding=True, N=resolution)
    print("Alignx is", len(alignx))

    graphon = universal_svd(align_graphs_list, threshold=0.2)
    print("Graphon is ", graphon.shape)

    graphons.append((label, graphon))
    Alignedfeatures.append(alignx)

if len(graphons) < 2:
    raise ValueError(f"two-graphon mixup needs at least 2 classes, found {len(graphons)}")
two_graphons = [graphons[0], graphons[1]]

In [None]:
# For 10 increasing explanation sizes (`resolution` nodes), estimate class
# graphons on `newdataset` and sample pure-class explanation graphs
# (la=0.0 and la=1.0). Over 10 trials, record the best probability the model
# assigns to each set's label, then store the mean/std per resolution.
classgraphs = split_class_x_graphs(newdataset)
avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(newdataset)
resolution = int(median_num_nodes) - 10  # This parameter controls the number of nodes in the generated explanations
mean_accuracy1 = []
std_accuracy1 = []
mean_accuracy2 = []
std_accuracy2 = []
ExplanationNodes = []


def _sample_explainers(pair, la, features):
    """Sample 10 graphs at mixup weight `la`; return them with `features` attached, plus their class index."""
    graphs = convert_to_one_hot(two_graphons_mixup(pair, la=la, num_sample=10))
    label = torch.argmax(graphs[0].y, dim=-1)
    return assign_same_features_to_data(graphs, features), label


def _best_class_probability(graphs, label):
    """Return the largest softmax probability of class `label` the model gives over `graphs`."""
    label = int(label)  # fails loudly if the label is not a single class index
    with torch.no_grad():  # inference only
        probs = [
            torch.softmax(model(data.x, data.edge_index, None)[1], dim=1)[0, label]
            for data in graphs
        ]
    return torch.stack(probs).max()


for i in range(10):
    graphons = []
    Alignedfeatures = []
    for label, graphs, nodes in classgraphs:
        align_graphs_list, alignx, normalized_node_degrees, max_num, min_num = align_x_graphscorrected(
            graphs, nodes, padding=True, N=resolution)
        graphon = universal_svd(align_graphs_list, threshold=0.2)
        Alignedfeatures.append(alignx)
        graphons.append((label, graphon))
    two_graphons = [graphons[0], graphons[1]]

    accuracy1 = []
    accuracy2 = []
    for numexplanations in range(10):
        # Resample on every trial: with one fixed sample (and a deterministic
        # model) all trials were identical and the reported std was always 0.
        # la=0.0 graphs get class-index-1 features, la=1.0 graphs class-index-0.
        explainer_graph1, label1 = _sample_explainers(two_graphons, 0.0, Alignedfeatures[1])
        explainer_graph2, label2 = _sample_explainers(two_graphons, 1.0, Alignedfeatures[0])
        accuracy1.append(_best_class_probability(explainer_graph1, label1))
        accuracy2.append(_best_class_probability(explainer_graph2, label2))

    accuracy1 = torch.stack(accuracy1)
    accuracy2 = torch.stack(accuracy2)
    mean_accuracy1.append(accuracy1.mean(dim=0))
    mean_accuracy2.append(accuracy2.mean(dim=0))
    std_accuracy1.append(accuracy1.std(dim=0))
    std_accuracy2.append(accuracy2.std(dim=0))
    ExplanationNodes.append(resolution)
    resolution = resolution + 1
                
            
  
            

In [None]:
# Show the label of the first explanation set, then the per-resolution mean
# probabilities of both sets and the std of the first set.
for value in (label1, mean_accuracy1, mean_accuracy2, std_accuracy1):
    print(value)

In [None]:
# Per-resolution std of the best probability for the second explanation set
# (explainer_graph2).
print(std_accuracy2)

In [None]:
# Plot the first explanation set's mean +/- std probability against
# explanation size.
# NOTE(review): plot_mean_with_error labels its x axis 'Lambda' by default, but
# the x values here are ExplanationNodes (node counts).
plot_mean_with_error(mean_accuracy1,std_accuracy1,ExplanationNodes)
print(label1)

In [None]:
# Plot the second explanation set's mean +/- std probability against
# explanation size (x values are ExplanationNodes, despite the 'Lambda' label).
plot_mean_with_error(mean_accuracy2,std_accuracy2,ExplanationNodes)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter

def plot_mean_with_error(mean, std, threshold,title=None, ax=None):
    """
    Plot mean with error bars.

    Parameters:
        mean (array_like): Array containing mean values.
        std (array_like): Array containing standard deviation values.
        threshold (array_like): Array containing threshold values.
        label (str): Label for the data.
        color (str): Color of the line.
        numsample (int): Sample number.
        ax (matplotlib.axes.Axes, optional): Axes object to plot on. If not provided, a new figure will be created.
    """
    # Flatten the arrays
    mean=torch.tensor(mean,dtype=torch.float32)
    
    std=torch.tensor(std,dtype=torch.float32)
    mean = np.array(mean).flatten()
    std = np.array(std).flatten()
    threshold = np.array(threshold).flatten()
    print("Threshold is",threshold)
    print("mean, std , threshold", mean,std,threshold)
    # # Select color automatically
    # colors = plt.cm.tab10(np.linspace(0, 1, 10))
    # color = colors[numsample % 10]  # Cycle through colors

    # Plotting
    if ax is None:
        fig, ax = plt.subplots()
    ax.errorbar(threshold, mean, yerr=std, fmt='-')  # '-' for line

    # Adding labels and title
    ax.set_xlabel('Lambda')
    ax.set_ylabel('Mean Class Score')
    ax.set_title(title)
    #ax.set_ylim(0.99975, 1.000051)

# Optionally, set the number of ticks or their locations
    ax.yaxis.set_major_locator(plt.MaxNLocator(5))  # Set max 5 ticks

    # ax.legend(loc='lower right',fontsize='small')  # Show legend
    # ax.grid(True)  # Add grid
# # Create a figure outside the function
# fig, ax = plt.subplots()
# plot_mean_with_error(Mean1,Std1,Threshold,label='class1',numsample=1,ax=ax)
# plot_mean_with_error(Mean2,Std2,Threshold,label='class1',numsample=2,ax=ax)
# plt.show()


In [None]:
# Number of entries (graphs) in the original dataset.
print(len(list(originaldataset)))