In [2]:
import numpy as np
import pandas as pd
import pickle
import csv
import os
from torch_geometric.data import Data

import torch
from torch_geometric.data import InMemoryDataset
from tqdm import tqdm

import networkx as nx
#import matplotlib.pyplot as plt
from collections import OrderedDict

In [33]:
#Load syndrome matrices and the corresponding eq_class

PATH = '../data/d7_big/'

pickle_in2 = open(PATH + "dict.eq_distr","rb")
pickle_in1 = open(PATH + "dict.defects","rb")


defects = []
eq_distr = []
counter = 0

latch = True

while latch:
    try:
        eq_distribution = pickle.load(pickle_in2)
        #my_rounded_list = [ round(elem*0.01) for elem in eq_distribution ]
        
        defects.append(pickle.load(pickle_in1))
        eq_distr.append(eq_distribution)

        #latch = False
        
    except EOFError:
        break


print(len(defects))
print(len(eq_distr))

#print(len(defects1))
#print(len(eq_distr1))

pickle_in2.close()
pickle_in1.close()

5833
5833


In [34]:
def combine_defect_matrices(qubit_matrix,size):
    #combining the vertex- and plaquette defect matrices
    a = np.array(qubit_matrix[0])
    b = np.array(qubit_matrix[1])
    #print(a)
    #print(b)
    
    a_reshaped = np.zeros((size,size))
    b_reshaped = np.zeros((size,size))

    a_reshaped[:a.shape[0],:a.shape[1]] = a
    b_reshaped[:b.shape[0],:b.shape[1]] = b

    for ix in range(size-1):
        if ix ==0:
            a_reshaped = np.insert(a_reshaped,ix+1,0,axis=1)
            a_reshaped = np.insert(a_reshaped,ix,0,axis=0)
        
            b_reshaped = np.insert(b_reshaped,ix,0,axis=1)
            b_reshaped = np.insert(b_reshaped,ix+1,0,axis=0)
        else:
            a_reshaped = np.insert(a_reshaped,(ix*2)+1,0,axis=1)
            a_reshaped = np.insert(a_reshaped,ix*2,0,axis=0)
        
            b_reshaped = np.insert(b_reshaped,ix*2,0,axis=1)
            b_reshaped = np.insert(b_reshaped,(ix*2)+1,0,axis=0)
        
    combine = a_reshaped+b_reshaped
    #print(combine) 
    return combine

def manhattan_dist(positions,nbr_of_defects):
    manhattan_distances = np.zeros((nbr_of_defects,nbr_of_defects))
    for ix in range(nbr_of_defects):
        for jx in range(nbr_of_defects):
            if ix == jx:manhattan_distances[ix][jx] = 0
            else:
                manhattan_distances[ix][jx] = 1/(abs(positions[ix,0]-positions[jx,0])+abs(positions[ix,1]-positions[jx,1]))

    return manhattan_distances

def distance_to_boundary(combined_matrix,idx,size):
    
    #size = 5
    defect_positions = np.argwhere(combined_matrix)
    nbr_of_defects=len(defect_positions)
    new_pos= np.zeros((nbr_of_defects,2))
    
    for ix in range(nbr_of_defects):
        if defect_positions[ix][0]%2 == 0 :
            aa = defect_positions[ix,1]*0.5
            new_pos[ix] = [defect_positions[ix,0]*0.5,aa]
        else:
            bb = defect_positions[ix,0]*0.5
            new_pos[ix] = [bb,defect_positions[ix,1]*0.5]

    ##################
    #print(new_pos)
    manhattan_distances = manhattan_dist(new_pos,nbr_of_defects)
    #print(manhattan_distances)
    
    #from adj.matrix to edgelist
    graph = nx.from_numpy_matrix(manhattan_distances)
    #edges to be directed both ways
    graph = nx.to_directed(graph)
    layout= nx.spring_layout(graph)
    
    ####plot one graph
    #if idx == len(defects)-1:
        #f = plt.figure()
        #nx.draw(graph,ax=f.add_subplot(111),with_labels=True,arrowsize=2, arrowstyle='fancy')
        #f.savefig("defect_matrices/Graph.png", format="PNG")
    ####

    file_2 = open(PATH + 'edgelist.txt','ab')        
    nx.write_edgelist(graph, file_2)
    file_2.close()
    
    #new_line = '\n'
    #file_2.write(new_line.encode('utf-8'))

    file_3 = open(PATH + 'graph_info.txt','a')        
    file_3.write(str(nbr_of_defects)+'\n')
    file_3.close()
    
    #file_2.write(new_line.encode('utf-8'))

    
    ##################
    
    order_list_vertex = []
    order_list_plaquette = []
    for ix in range(nbr_of_defects):
        if (round(new_pos[ix,0],2)-int(new_pos[ix,0]) != 0.):
            #print(new_pos[ix,0])
            order_list_vertex.append(ix)
        if (round(new_pos[ix,1],2)-int(new_pos[ix,1]) != 0.):
            #print(new_pos[ix,1])
            order_list_plaquette.append(ix)

    order_array_plaquette = np.array(order_list_plaquette)
    order_array_vertex = np.array(order_list_vertex)

    #print(len(order_array_plaquette))
    #print(len(order_array_vertex))
    #print(order_array_vertex)
    #print(order_array_plaquette)

    #print(new_pos[order_array_vertex])
    #print(new_pos[order_array_plaquette])
    
    ##########
    #special case when there are only vertex defects
    if len(order_array_plaquette) == 0:
        distance_vertex_boundary = np.zeros(len(order_array_vertex))
        for ix in range(len(order_array_vertex)):
            dist = new_pos[order_array_vertex[ix]]
            if dist[0] >= size*0.5:
                dist[0] = -(size-dist[0])
    
            distance_vertex_boundary[ix] = dist[0]
    
    
    
        boundary_vertex = np.around(distance_vertex_boundary+0.1)*0.5
    
        boundary_v = np.ones((len(boundary_vertex),2))
        boundary_v = -1*boundary_v
        boundary_v[:,0] = boundary_vertex

        node_features = np.zeros((nbr_of_defects,2))
        node_features[order_array_vertex] = boundary_v
    
    
    ############
    #special case when there are only plaq. defects
    elif len(order_array_vertex) == 0:
        distance_plaquette_boundary = np.zeros(len(order_array_plaquette))

        for ix in range(len(order_array_plaquette)):
            dist = new_pos[order_array_plaquette[ix]]
            if dist[1] >= size*0.5:
                dist[1] = -(size-dist[1])
    
            distance_plaquette_boundary[ix] = dist[1]
    

        boundary_plaquette = np.around(distance_plaquette_boundary+0.1)*0.5

    
        boundary_p = np.ones((len(boundary_plaquette),2))
        boundary_p[:,0] = boundary_plaquette
        
        
        node_features = np.zeros((nbr_of_defects,2))
        node_features[order_array_plaquette] = boundary_p
        
        #print(node_features)
    
        
    ##########
    #typical case
    else:
        distance_vertex_boundary = np.zeros(len(order_array_vertex))
        distance_plaquette_boundary = np.zeros(len(order_array_plaquette))

        for ix in range(len(order_array_vertex)):
            dist = new_pos[order_array_vertex[ix]]
            if dist[0] >= size*0.5:#2.5:
                dist[0] = -(size-dist[0])
    
            distance_vertex_boundary[ix] = dist[0]
    
        for ix in range(len(order_array_plaquette)):
            dist = new_pos[order_array_plaquette[ix]]
            if dist[1] >= size*0.5:#2.5:
                dist[1] = -(size-dist[1])
    
            distance_plaquette_boundary[ix] = dist[1]
    

        boundary_vertex = np.around(distance_vertex_boundary+0.1)*0.5
        boundary_plaquette = np.around(distance_plaquette_boundary+0.1)*0.5

        
        boundary_v = np.ones((len(boundary_vertex),2))
        boundary_v = -1*boundary_v
        boundary_v[:,0] = boundary_vertex

        boundary_p = np.ones((len(boundary_plaquette),2))
        boundary_p[:,0] = boundary_plaquette
    
        node_features = np.zeros((nbr_of_defects,2))
        node_features[order_array_vertex] = boundary_v
        node_features[order_array_plaquette] = boundary_p
    
    return node_features 


#caveman function to check balance of classes
def conditions(eq_distr,eq_class):
    if np.round(eq_distr[0]*0.01) == 1:
        eq_class[0] = eq_class[0] + 1
    
    elif np.round(eq_distr[1]*0.01) == 1:
        eq_class[1] = eq_class[1] + 1
    
    elif np.round(eq_distr[2]*0.01) == 1:
        eq_class[2] = eq_class[2] + 1
    
    else: eq_class[3] = eq_class[3] + 1
        
    return eq_class


    
    

In [35]:
#clean out old contents
open(PATH + 'node_features.txt', 'w').close()
open(PATH + 'edgelist.txt', 'w').close()
open(PATH + 'graph_info.txt', 'w').close()


file_edge=open(PATH + 'node_features.txt','a')
#pickle_eq_dist = open("../data/test/dict.eq_distr","ab")
#pickle_defect_matrices = open("../data/test/dict.defect","ab")

eq_class_counter = np.zeros(4)

zero_def = 0
single_def = 0

defect_test = []
eq_dist_test = []
_, size = np.array(defects[0][0]).shape
for ix in range(len(defects)):
    if (len(np.argwhere(defects[ix][0]))+len(np.argwhere(defects[ix][1]))) == 0:
        zero_def +=1
        continue
    
    
    
    #### REMOVE FOR REGULAR RUN, ONLY FOR FEW DEFECT SIM.
    #if (len(np.argwhere(defects[ix][0]))+len(np.argwhere(defects[ix][1]))) > 8 :
    #    continue
    
    if (len(np.argwhere(defects[ix][0]))+len(np.argwhere(defects[ix][1]))) == 1:    
        single_def +=1
        continue
    ##################################################
    
    balcount = 1000
    
    #the number of syndromes wanted of each class, eg 3500
    #if np.round(eq_distr[ix][0]*0.01) == 1 and eq_class_counter[0] >= balcount:
    #    continue
    #if np.round(eq_distr[ix][1]*0.01) == 1 and eq_class_counter[1] >= balcount:
    #    continue
    #if np.round(eq_distr[ix][2]*0.01) == 1 and eq_class_counter[2] >= balcount:
    #    continue
    #if np.round(eq_distr[ix][3]*0.01) == 1 and eq_class_counter[3] >= balcount:
    #    continue
    
    
    eq_class_counter = conditions(eq_distr[ix],eq_class_counter)
    
    
    
    eq_dist_test.append(eq_distr[ix])
    #pickle.dump(eq_distr[ix],pickle_eq_dist)
    #pickle.dump(defects[ix],pickle_defect_matrices)
    
    combined_matrix = combine_defect_matrices(defects[ix],size)
    node_features = distance_to_boundary(combined_matrix,ix,size)
    #print(node_features)
    
    np.savetxt(file_edge,node_features,fmt='%.1f')
    file_edge.write("\n")
    

#pickle_eq_dist.close()
#pickle_defect_matrices.close()
file_edge.close()

In [36]:
eq_class_counter

array([1500., 1500., 1500., 1333.])

In [37]:
import torch_geometric.transforms
from torch_geometric.nn import knn_graph
import torch_geometric.data
import torch 
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import numpy as np
import pickle
from torch_geometric.data import DataLoader
import linecache
import random

In [38]:
eq_distr = eq_dist_test

In [39]:
#Use files from "new_datacreator"

node_feat_load = np.genfromtxt(PATH + 'node_features.txt')
edge_idx_load = np.genfromtxt(PATH + 'edgelist.txt',usecols=(0,1))
edge_feat_load = np.genfromtxt(PATH + 'edgelist.txt',dtype = float,usecols=3,comments = '}')
graph_info = np.genfromtxt(PATH + 'graph_info.txt')

In [40]:
graphs = []

#could probably make it more elegant but for now the holder indices works fine i guess :P
holder=0
edge_holder =0
for ix in range(len(graph_info)):
#for ix in range(5):

    num_nodes = int(graph_info[ix])
    edge_amount = int(num_nodes*(num_nodes-1))
    
    node_feat = node_feat_load[holder:holder+num_nodes]
    edge_idx = edge_idx_load[edge_holder:edge_amount+edge_holder]
    edge_feat = edge_feat_load[edge_holder:edge_amount+edge_holder]
    
    #print(node_feat)
    #print(edge_idx)
    #print(edge_feat)
    
    node_features = torch.from_numpy(node_feat)
    edge_features = torch.from_numpy(edge_feat)
    
    edge_index = torch.from_numpy(np.transpose(edge_idx)).type(torch.LongTensor)
    #edge_index = torch_geometric.utils.to_undirected(edge_index)
    
    #print(node_features)
    #print(edge_index)
    #print(edge_features.shape)
    
    # NOTE that we have the distribution here
    eq_class = np.array(np.around(eq_distr[ix]*0.01,decimals=3))#should change this to float with prec.3 or something
    eq_class = torch.from_numpy(eq_class)
    eq_class = eq_class.unsqueeze(0)
        
    #create of graph one at the time
    graph = Data(node_features,edge_index,edge_features,y=eq_class)
    
    #print(graph)
    graphs.append(graph)
    
    holder = holder + num_nodes
    edge_holder = edge_holder + edge_amount


    

torch.save(graphs,PATH + 'graphs.pt')#change namnes  


[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0.001 0.    0.999 0.   ]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0

[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0.999 0.    0.001 0.   ]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0.    0.    0.999 0.001]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 

[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0.999 0.    0.001 0.   ]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0

[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1.

[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0.017 0.    0.983 0.   ]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0.001 0.999 0.    0.   ]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0.025 0.    0.975 0.   ]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.

[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0.001 0.999 0.    0.   ]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0.001 0.    0.999 0.   ]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0.001 0.    0.999 0.   ]
[0. 0. 1. 0.

[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0.    0.001 0.    0.999]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0.    0.017 0.    0.983]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0.    0.    0.006 0.994]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0.    0.    0.001 0.999]
[0. 0. 0. 1.]
[0.    0.035 0.    0.965]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0. 0. 0. 1.]
[0

In [24]:
import torch 
import pickle
import numpy as np


#pickle_in1 = open("data/stdc/dmixed(579)_p1/dict.defect","rb")
#pickle_in2 = open("data/stdc/dmixed(579)_p1/dict.eq_distr","rb")


dataset = torch.load(PATH + 'graphs.pt')
#dataset = torch.load('data/stdc/dmixed(57)_p1/graphs_clean.pt')
print(len(dataset))

1539


In [25]:
num_nodes=0
max_nodes=0
for ix in range(len(dataset)):
    num_nodes += len(dataset[ix].x)
    if len(dataset[ix].x) > max_nodes:
        max_nodes = len(dataset[ix].x)
        #print(ix)
print(num_nodes/len(dataset))
print(max_nodes)
#print(len(dataset[2].x.flatten()))
#print(dataset[2].edge_attr[:4].flatten().shape)


9.803118908382066
27


In [26]:
flat_eq = np.zeros((len(dataset),(max_nodes*2)+6))
print(flat_eq.shape)
for ix in range(len(dataset)):
    flat_eq[ix,:len(dataset[ix].x.flatten())] = dataset[ix].x.flatten()
    #print(dataset[ix].edge_attr[:2].flatten())
    flat_eq[ix,-6:-(6-len(dataset[ix].edge_attr[:5].flatten()))] = dataset[ix].edge_attr[:5].flatten()
u, idx, inv,counts = np.unique(flat_eq,return_index=True,return_inverse=True,return_counts=True,axis=0)
clean_data = []
for ix in range(len(idx)):
    clean_data.append(dataset[idx[ix]])
print(len(clean_data))

# WE may loose some here because of the non 1-1 representation of syndroms

(1539, 60)
1519


In [27]:
torch.save(clean_data,PATH + 'graphs_clean.pt')

In [28]:
counts = torch.zeros(4)
for ix in range(len(clean_data)):
    counts = counts + clean_data[ix].y
print(torch.round(counts))

tensor([[984., 236., 228.,  71.]], dtype=torch.float64)


In [29]:
dataset = torch.load(PATH + 'graphs_clean.pt')
#print()
#print(f'Dataset: {dataset}:')
#print('====================')
print(f'Number of graphs: {len(dataset)}')
print()
#print(f'Number of features: {dataset.num_features}')
#print(f'Number of classes: {dataset.num_classes}')
print(f'First graph attributes:')
data = dataset[10]  # Get the first graph object.
#print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')


Number of graphs: 1519

First graph attributes:
Data(edge_attr=[30], edge_index=[2, 30], x=[6, 2], y=[1, 4])
Number of nodes: 6
Number of edges: 30
Average node degree: 5.00


In [30]:
torch.manual_seed(1234)
#dataset = dataset.shuffle()
random.shuffle(dataset)

train_dataset = dataset[:8]
test_dataset = dataset[8:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')



Number of training graphs: 8
Number of test graphs: 1511
