# Library Installation

In [None]:
# Print the version of the `pip` found on the shell PATH.
# NOTE(review): `!pip` runs a shell command and may resolve to a different Python
# environment than the kernel; `%pip --version` reports the kernel's own pip.
!pip --version

In [None]:
!pip install torch_geometric
!pip install prettytable
!pip install torch
!pip install pandas
!pip install matplotlib
!pip install scikit-learn
!pip install plotly

# Library Imports

In [3]:
import os
import torch
import math
import networkx as nx
import pandas as pd
from torch_geometric.utils import to_networkx
import torch.nn.functional as F
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from collections import defaultdict
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GATConv
from prettytable import PrettyTable
from pprint import pprint
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import time
import os
import platform
import psutil
import torch
import torch_geometric
import numpy
import subprocess

In [None]:
# Report library versions and the host OS so results can be tied to an environment.
print(torch.__version__)
print(torch_geometric.__version__)
print(numpy.__version__)
# `print(os)` only showed the module object; print the actual platform string instead.
print(platform.platform())

# Project Configuration

In [5]:
# Central, editable configuration for the whole notebook.
# Several classes below use these values as *default arguments*, which Python captures
# when the defining cell runs, so re-run the downstream cells after changing anything here.
Configuration = {
    # Planetoid dataset selection (used by DatasetLoader).
    "Dataset": {
        "root": "data/Planetoid",       # root directory passed to Planetoid
        "name": "Cora", # Cora, CiteSeer, PubMed
        "normalization": False,         # truthy -> wrap the dataset with the NormalizeFeatures transform
        "getSummaryLevel": 0            # DatasetLoader.GetDataOption value: 0 = ALL (verbose), 1 = SHORT
    },
    # Training driver settings (used by BuildModel).
    "BuildModel": {
        "learningRate": 0.05, # Adam learning rate. Cora: 0.05, Citeseer: 0.01, PubMed: 0.04
        "patience": 10,       # epochs without a new best test accuracy before train_with_early_stopping stops
        "maxEpochs": 23, # upper bound on training epochs (Cora: 150)
        "plot": False,        # plot the loss/accuracy history after train_with_early_stopping
        "plotName": None,     # plot name forwarded to BuildModel.history_plot
        "scaleMin": -127,     # lower clamp of the per-step fake quantization in single_train
        "scaleMax": 127,      # upper clamp / scale multiplier of that quantization
        "printLearnableParameters": False  # dump all trainable parameters after train_with_early_stopping
    },
    # Graph attention network shape (used by GAT).
    "GAT": {
        "hiddenChannel": 16,  # conv1 output channels per attention head
        "head": 1             # number of attention heads in conv1
    },
    # GCSR stream encoding (used by GCSR_Data_Compression_Builder).
    # A *RequiredBit of None means the width is derived from the data with
    # bits_required(); an integer fixes the width. *Signed is passed to bits_required
    # as have_sign_bit when the width is derived.
    "GCSR": {
        "printInfoShape": True,

        "colIndexRequiredBit": None,
        "valueRequiredBit": 8,
        "rowLengthRequiredBit": None,
        "numOfNodesRequiredBit": None,
        "flagRequiredBit": None,

        "colIndexSigned": False,
        "valueSigned": True,
        "rowLengthSigned": False,
        "numOfNodesSigned": False,
        "flagSigned": False,
    },
    # Quantization of the trained GAT parameters (used by LearnableParamQuanBuilder);
    # same None / *Signed convention as "GCSR".
    # NOTE(review): the conv*Bias* entries are not read by getInfoParamQuan — confirm
    # whether anything else in the notebook uses them.
    "LParamQuan": {
        "scaleMin": -127,
        "scaleMax": 127,
        "printParamShape": True,

        "conv1AttSrcRequiredBit": None,
        "conv1AttDstRequiredBit": None,
        "conv1BiasRequiredBit": None,
        "conv1WeightRequiredBit": None,

        "conv2AttSrcRequiredBit": None,
        "conv2AttDstRequiredBit": None,
        "conv2BiasRequiredBit": None,
        "conv2WeightRequiredBit": None,

        "conv1AttSrcSigned": True,
        "conv1AttDstSigned": True,
        "conv1BiasSigned": True,
        "conv1WeightSigned": True,

        "conv2AttSrcSigned": True,
        "conv2AttDstSigned": True,
        "conv2BiasSigned": True,
        "conv2WeightSigned": True,
    },
    "RawDataResult": {
        "showData": False   # include the full raw data lists in RawDataResult's output
    },
    # PCOO encoding widths / signedness.
    # NOTE(review): the consumer of this section is not visible here; presumably it
    # follows the same None -> derived-width convention as "GCSR" — confirm.
    "PCOO": {
        "sorRequiredBit": None,
        "eorRequiredBit": None,
        "vldRequiredBit": None,
        "colRequiredBit": None,
        "valRequiredBit": 11,

        "sorSigned": False,
        "eorSigned": False,
        "vldSigned": False,
        "colSigned": False,
        "valSigned": True,
    }

}

# Common Utility

In [6]:
def count_elements_tensor(tensor, value):
    """Return how many entries of `tensor` equal `value`, as a Python number."""
    matches = tensor.eq(value)
    return matches.sum().item()

def scale_tensor(tensor, scale_min, scale_max, to_dtype=torch.int8):
    """Symmetrically quantize `tensor` onto the grid [scale_min, scale_max].

    The tensor is divided by its largest *absolute* value and multiplied by
    `scale_max`, so the element with the largest magnitude maps to +/-scale_max.
    Using the absolute maximum (instead of tensor.max()) keeps the sign of
    all-negative tensors and stops large negative values from being clipped.
    For integer `to_dtype`, values are rounded to the nearest integer (torch.round,
    ties to even) before the cast instead of being truncated toward zero by it.

    Args:
        tensor: tensor to quantize.
        scale_min: lower clamp bound of the target grid (e.g. -127).
        scale_max: upper clamp bound of the target grid and the scale multiplier (e.g. 127).
        to_dtype: dtype of the returned quantized tensor (default torch.int8).

    Returns:
        (scaled_tensor, scale_back_fn): the quantized tensor, and a function that maps
        a quantized tensor back to the original value range (computed in float32).
    """
    if tensor.numel() == 0:
        v_max = 1  # nothing to scale; also avoids max() on an empty tensor
    else:
        v_max = tensor.detach().abs().max()
        if v_max == 0:
            v_max = 1  # all zeros: avoid division by zero

    scaled_tensor = (tensor / v_max) * scale_max
    if not to_dtype.is_floating_point:
        scaled_tensor = scaled_tensor.round()
    scaled_tensor = scaled_tensor.clamp(scale_min, scale_max).to(to_dtype)

    def scale_back_fn(scaled_tensor):
        """Map a quantized tensor back to the original value range."""
        scaled_tensor = scaled_tensor.to(torch.float32)  # ensure float for computation
        return (scaled_tensor / scale_max) * v_max

    return scaled_tensor, scale_back_fn

def printLearnableParameters(model):
    """Print a header, then the name, shape and value of every trainable parameter of `model`."""
    print("Learnable Parameters of the Model:")
    trainable = ((pname, p) for pname, p in model.named_parameters() if p.requires_grad)
    for param_name, param in trainable:
        print(f"{param_name} ({param.shape}): {param}")

def int_to_n_bit_binary(number, n_bits):
    """Encode an integer as a binary string of exactly `n_bits` characters.

    Negative numbers are encoded in two's complement.

    Args:
        number: integer to encode.
        n_bits: width of the output field (must be >= 1).

    Returns:
        A string of '0'/'1' characters of length `n_bits`.

    Raises:
        ValueError: if `n_bits` < 1, or if `number` lies outside
            [-2**(n_bits-1), 2**n_bits - 1] (the union of the signed and unsigned
            ranges). Such values used to silently yield a string wider than
            `n_bits`, or a wrapped or '-'-prefixed one, which corrupts fixed-width
            memory dumps.
    """
    if n_bits < 1:
        raise ValueError(f"n_bits must be >= 1, got {n_bits}")
    lower = -(1 << (n_bits - 1))
    upper = (1 << n_bits) - 1
    if not lower <= number <= upper:
        raise ValueError(f"{number} does not fit in {n_bits} bits (allowed range {lower}..{upper})")

    # Handle two's complement for negative numbers
    if number < 0:
        number += 1 << n_bits

    # Zero-pad to exactly n_bits characters
    return format(number, f'0{n_bits}b')

def int_to_n_bit_binary_with_flags(number, n_bits, isStart=False, isEnd=False):
    """Encode `number` in `n_bits` and wrap it with a start-flag bit and an end-flag bit.

    The result is ``start + int_to_n_bit_binary(number, n_bits) + end``. Each flag is
    '1' when the corresponding argument is truthy and '0' otherwise, so the string is
    n_bits + 2 characters long.
    """
    start_bit = '1' if isStart else '0'
    end_bit = '1' if isEnd else '0'
    return start_bit + int_to_n_bit_binary(number, n_bits) + end_bit


def get_nonzero_features(node_features):
    """Return the positions of the non-zero entries of `node_features`.

    Intended for a single node's 1-D feature vector (e.g. `data.x[i]`), for which the
    result is a 1-D int64 tensor of indices.
    """
    positions = node_features.nonzero()  # one row per non-zero entry
    return positions.squeeze(dim=1)

def bits_required(min_value, max_value, have_sign_bit = False):
    """Return the number of bits needed to store values in [min_value, max_value].

    The width is ``floor(log2(|v|)) + 1`` magnitude bits for the largest absolute
    value `v`, plus one sign bit when `have_sign_bit` is True. Argument order does not
    matter because only absolute values are used. At least one magnitude bit is always
    returned. Previously, an all-zero range raised "math domain error" from log2(0),
    and fractional magnitudes below 1 yielded a width of 0 or less.

    Note: the estimate is conservative for a negative minimum of exactly -2**k, which
    gets one more bit than two's complement strictly needs.
    """
    max_abs_value = max(abs(min_value), abs(max_value))
    if max_abs_value == 0:
        magnitude_bits = 1
    else:
        magnitude_bits = max(1, math.floor(math.log2(max_abs_value) + 1))
    return magnitude_bits + 1 if have_sign_bit else magnitude_bits

def intToBinaryArray(arr, n_bits):
    """Convert every element of `arr` (via int()) into an `n_bits`-wide binary string.

    Returns a new list with one string per input element, in input order.
    """
    return [int_to_n_bit_binary(int(value), n_bits) for value in arr]

def intToBinaryMatrix(matrix, n_bits):
    """Convert a 2-D iterable element-wise (via int()) into `n_bits`-wide binary strings.

    Returns a list of rows that mirrors the shape of `matrix`.
    """
    return [
        [int_to_n_bit_binary(int(cell), n_bits) for cell in row]
        for row in matrix
    ]

def resolveNodeInfo(arr, row_length_n_bits, num_of_nodes_n_bits, flag_n_bits):
    """Encode node_info rows [row_length, num_of_nodes, flag] as binary strings.

    Each input row becomes ``[row_length_bits, num_of_nodes_bits, flag_bits, concatenation]``,
    where the first three fields use the given widths and the last is the three fields
    joined together. Elements past the third in a row are ignored.
    """
    field_widths = (row_length_n_bits, num_of_nodes_n_bits, flag_n_bits)
    encoded_rows = []
    for row in arr:
        fields = [int_to_n_bit_binary(int(value), width) for value, width in zip(row, field_widths)]
        fields.append(''.join(fields))
        encoded_rows.append(fields)
    return encoded_rows

def visualize_2d(h, color, name = '2D_dist_plot'):
    """Project `h` to 2-D with t-SNE and show a Plotly scatter plot colored by `color`.

    Args:
        h: torch tensor of embeddings; it is detached and moved to CPU before t-SNE.
        color: per-point values for the "magma" color scale (e.g. class labels).
        name: accepted for API symmetry; currently unused by this function.
    """
    embedding = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    xs, ys = embedding[:, 0], embedding[:, 1]

    figure = px.scatter(x=xs, y=ys, color=color, color_continuous_scale="magma")
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    figure.update_layout(
        xaxis_title="Dimension 1",
        yaxis_title="Dimension 2",
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        coloraxis_showscale=False,
        width=800,
        height=800,
    )
    figure.show()

def visualize_3d(h, color, name = '3D_dist_plot'):
    """Project `h` to 3-D with t-SNE and show a Plotly 3-D scatter plot colored by `color`.

    Args:
        h: torch tensor of embeddings; it is detached and moved to CPU before t-SNE.
        color: per-point values for the "magma" color scale (e.g. class labels).
        name: accepted for API symmetry; currently unused by this function.
    """
    points = TSNE(n_components=3).fit_transform(h.detach().cpu().numpy())
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]

    figure = px.scatter_3d(x=xs, y=ys, z=zs, color=color, color_continuous_scale="magma")
    scene_titles = {
        "xaxis_title": "Dimension 1",
        "yaxis_title": "Dimension 2",
        "zaxis_title": "Dimension 3",
    }
    figure.update_layout(scene=scene_titles, coloraxis_showscale=False, width=800, height=800)
    figure.show()

def currentOption():
    """Print a summary of the compute environment: GPU, CPU, RAM, disk, OS and Python.

    GPU details are reported for device 0, the device `torch.cuda.get_device_name(0)`
    describes. Clock speeds are queried through `nvidia-smi`. If that query fails, a
    message with the reason is printed and the rest of the report still runs.
    """
    if torch.cuda.is_available():
        print("GPU is available.")
        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        try:
            output = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=clocks.gr,clocks.sm,clocks.mem", "--format=csv,noheader,nounits"],
                encoding='utf-8'
            )
            # nvidia-smi prints one CSV line per GPU; use the first (GPU 0) so multi-GPU
            # hosts do not break the three-way unpacking.
            first_gpu_line = output.strip().splitlines()[0]
            gr, sm, mem = (field.strip() for field in first_gpu_line.split(','))
            print(f"GPU Graphics Clock: {gr} MHz")
            print(f"GPU SM Clock: {sm} MHz")
            print(f"GPU Memory Clock: {mem} MHz")
        except (OSError, subprocess.SubprocessError, ValueError, IndexError) as e:
            # Best effort: missing binary (OSError), non-zero exit (SubprocessError),
            # or unexpected/empty output (ValueError / IndexError).
            print("Could not fetch GPU frequency. Make sure 'nvidia-smi' is installed.")
            print(f"  Reason: {e}")
    else:
        print("Running on CPU.")

    # CPU Information (platform.processor() can be empty on Linux; fall back to the machine type)
    print(f"Processor: {platform.processor() or platform.machine()}")
    print(f"CPU Count: {os.cpu_count()}")

    # RAM Information
    virtual_memory = psutil.virtual_memory()
    print(f"Total RAM: {virtual_memory.total / 1e9:.2f} GB")
    print(f"Available RAM: {virtual_memory.available / 1e9:.2f} GB")

    # Frequency (psutil.cpu_freq() may return None on some platforms)
    cpu_freq = psutil.cpu_freq()
    if cpu_freq:
        print(f"CPU Frequency: {cpu_freq.current:.2f} MHz (Max: {cpu_freq.max:.2f} MHz)")

    # Disk Information
    disk_usage = psutil.disk_usage('/')
    print(f"Total Disk Space: {disk_usage.total / 1e9:.2f} GB")
    print(f"Used Disk Space: {disk_usage.used / 1e9:.2f} GB")
    print(f"Free Disk Space: {disk_usage.free / 1e9:.2f} GB")

    # Operating System Information
    print(f"Operating System: {platform.system()} {platform.release()}")
    print(f"Python Version: {platform.python_version()}")

def write_array_to_file(data, filename):
    """Write each element of `data` on its own line to `filename`, replacing any existing content."""
    lines = (f"{element}\n" for element in data)
    with open(filename, 'w') as file:
        file.writelines(lines)
    print(f"Array data written to {filename}")

def write_matrix_to_file(data, filename):
    """Flatten the 2-D iterable `data` row by row into `filename`, one element per line (overwrites)."""
    with open(filename, 'w') as file:
        file.writelines(f"{element}\n" for row in data for element in row)
    print(f"Matrix data written to {filename}")

def append_array_to_file(data, filename):
    """Append each element of `data` on its own line to `filename`, keeping existing content."""
    lines = [f"{element}\n" for element in data]
    with open(filename, 'a') as file:  # 'a' = append mode
        file.writelines(lines)
    print(f"Array data appended to {filename}")

def append_matrix_to_file(data, filename):
    """Append every element of the 2-D iterable `data`, row by row, one per line, to `filename`.

    Existing file content is kept (append mode). The status message now says
    "appended", matching append_array_to_file; it previously said "written".
    """
    with open(filename, 'a') as file:
        for row in data:
            for element in row:
                file.write(f"{element}\n")
    print(f"Matrix data appended to {filename}")

def createPath(dir, filename):
    """Return `filename` joined onto directory `dir` with os.path.join semantics."""
    joined_path = os.path.join(dir, filename)
    return joined_path

# Current Resources

In [None]:
# Print the GPU / CPU / RAM / disk / OS summary of the machine running this notebook.
currentOption()

# Dataset Loader

In [8]:
class DatasetLoader:
    """Thin wrapper around a torch_geometric Planetoid dataset (Cora / CiteSeer / PubMed).

    Constructor and get_summary() defaults come from Configuration["Dataset"] and are
    captured when this cell runs, so re-run it after editing the configuration cell.
    """

    # Detail levels accepted by get_summary(): "ALL" additionally dumps tensors,
    # masks and graph property checks; "SHORT" prints only the overview.
    GetDataOption = {
        "ALL": 0,
        "SHORT": 1
    }

    def __init__(self,
                 root: str = Configuration["Dataset"]["root"],
                 name: str = Configuration["Dataset"]["name"],
                 normalize: bool = Configuration["Dataset"]["normalization"]):
        """Load the dataset immediately.

        Args:
            root: dataset root directory passed to Planetoid.
            name: Planetoid dataset name, e.g. "Cora", "CiteSeer" or "PubMed".
            normalize: if truthy, wrap the dataset with the NormalizeFeatures transform.
        """
        self.root = root
        self.name = name
        self.normalize = normalize
        self.dataset = self._load_dataset()

    def _load_dataset(self):
        """Build the Planetoid dataset, optionally with the NormalizeFeatures transform."""
        transform = NormalizeFeatures() if self.normalize else None
        return Planetoid(root=self.root, name=self.name, transform=transform)

    def get_summary(self,
                    option: int = Configuration["Dataset"]["getSummaryLevel"]):
        """Print dataset statistics.

        Args:
            option: a value of GetDataOption. GetDataOption["ALL"] (0) additionally
                prints features, labels, split masks, edges and graph property checks.
        """
        print(f'Dataset: {self.dataset}')
        print('======================')
        print(f'Number of graphs: {len(self.dataset)}')
        print(f'Number of features: {self.dataset.num_features}')
        print(f'Number of classes: {self.dataset.num_classes}')
        # Index the dataset once (each access re-applies the transform); a local name
        # other than `data` avoids shadowing the notebook's global `data`.
        graph = self.dataset[0]
        print(f'Overview data:  {graph}')
        print('======================')

        if option == self.GetDataOption["ALL"]:
            print(f'Feature matrix: {graph.x} \n Shape: {graph.x.shape}')
            print('======================')
            print(f'Classification output: {graph.y} \nShape: {graph.y.shape}')
            print('======================')
            print(f'Data mask (Train): {graph.train_mask} \nShape: {graph.train_mask.shape} \nsize: {count_elements_tensor(graph.train_mask, True)}')
            print('======================')
            print(f'Data mask (Validation): {graph.val_mask} \nShape: {graph.val_mask.shape} \nsize: {count_elements_tensor(graph.val_mask, True)}')
            print('======================')
            print(f'Data mask (Test): {graph.test_mask} \nShape: {graph.test_mask.shape} \nsize: {count_elements_tensor(graph.test_mask, True)}')
            print('======================')
            print(f'Edge Pairs: {graph.edge_index} \nShape: {graph.edge_index.shape}')
            print('======================')
            print(f'Is Contain Isolated Node: {graph.has_isolated_nodes()}')
            print('======================')
            print(f'Is Contain Self Loop: {graph.has_self_loops()}')
            print('======================')
            print(f'Graph Direction: {"Undirected" if graph.is_undirected() else "Directed"}')
            print('======================')

    def get_data(self, index: int = 0):
        """Return the graph at `index` of the dataset."""
        return self.dataset[index]

    def get_dataset(self):
        """Return the underlying Planetoid dataset object."""
        return self.dataset

# Dataset Loader Initialization

In [None]:
# Load the configured Planetoid dataset, print its summary, and expose the first graph
# as the module-level `data` (and the dataset as `dataset`) that later cells read.
loader = DatasetLoader()
loader.get_summary()
data = loader.get_data()
dataset = loader.get_dataset()

# visualize_2d(data.x, color=data.y, name = 'gat_2d_dist_plot')
# visualize_3d(data.x, color=data.y, name = 'gat_3d_dist_plot')

In [None]:
# Show the raw feature matrix, then count nodes whose feature vector is entirely zero.
# Such nodes get a placeholder entry in the GCSR encoding built further below.
print(data.x)
is_all_zero_row = (data.x == 0).all(dim=1)
zero_row_count = int(is_all_zero_row.sum())
print(zero_row_count)

# Model building

In [11]:
class BuildModel():
    """Train and evaluate a node-classification model on the module-level `data` graph.

    Optimizer: Adam with weight_decay=5e-4. Loss: cross-entropy on `data.train_mask`.
    Each training step first fake-quantizes all state_dict tensors through
    scale_tensor(), using the Configuration["BuildModel"] scaleMin/scaleMax range
    (quantization-aware training).

    NOTE(review): the methods read the global `data` instead of taking it as an
    argument, so `data` must exist before training or testing.
    """

    # Target *test* accuracy at which train() stops; other dataset names use 0.8.
    # TODO (carried over from the original "TODO: Comment this"): this hard-coded,
    # test-set-based stopping rule is presumably meant to be disabled.
    _TARGET_TEST_ACCURACY = {
        'Cora': 0.83,
        'CiteSeer': 0.69,  # was 0.704 at one point
        'PubMed': 0.775,
    }

    def __init__(self, model, lr = Configuration["BuildModel"]["learningRate"], save_path="model_params.pth"):
        """
        Args:
            model: torch.nn.Module called as model(x, edge_index).
            lr: Adam learning rate.
            save_path: file used by save_model_params / load_model_params.
        """
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.save_path = save_path

    def save_model_params(self):
        """Save model parameters to the file."""
        torch.save(self.model.state_dict(), self.save_path)
        print(f"Model parameters saved to {self.save_path}")

    def load_model_params(self):
        """Load model parameters from `save_path` if that file exists.

        Returns:
            True if parameters were loaded, False if no file was found.

        NOTE(review): the default path does not include the dataset name, so a file
        saved for one dataset is picked up (and likely fails to load) for another.
        """
        if os.path.exists(self.save_path):
            self.model.load_state_dict(torch.load(self.save_path))
            print(f"Model parameters loaded from {self.save_path}")
            return True
        else:
            print(f"No saved model parameters found at {self.save_path}")
            return False

    def _quantize_parameters(self):
        """Round-trip every state_dict tensor through the configured int8 grid and load it back."""
        state = self.model.state_dict()
        for key, value in state.items():
            scaled_tensor, scale_back_fn = scale_tensor(value,
                                                       Configuration["BuildModel"]["scaleMin"],
                                                       Configuration["BuildModel"]["scaleMax"],
                                                       torch.int8)
            state[key] = scale_back_fn(scaled_tensor)
        self.model.load_state_dict(state)

    def single_train(self):
        """Run one optimization step on the fake-quantized weights; return the loss tensor."""
        self.model.train()
        self.optimizer.zero_grad()
        self._quantize_parameters()
        out = self.model(data.x, data.edge_index)
        loss = self.criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        self.optimizer.step()
        return loss

    def test(self, visualization_2D = False, visualization_3D = False):
        """Return the accuracy on `data.test_mask` using the model's current weights.

        The forward pass runs under torch.no_grad() because evaluation needs no autograd
        graph. `visualization_2D` / `visualization_3D` are accepted for compatibility
        but are currently ignored.
        """
        self.model.eval()
        with torch.no_grad():
            out = self.model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        test_correct = pred[data.test_mask] == data.y[data.test_mask]
        return int(test_correct.sum()) / int(data.test_mask.sum())

    def train_with_early_stopping(self,
                                  epochs=Configuration["BuildModel"]["maxEpochs"],
                                  patience=Configuration["BuildModel"]["patience"],
                                  plot = Configuration["BuildModel"]["plot"],
                                  plot_name = Configuration["BuildModel"]["plotName"],
                                  printAllParams = Configuration["BuildModel"]["printLearnableParameters"]):
        """Train for up to `epochs` epochs; stop after `patience` epochs without a new best test accuracy.

        Returns:
            dict with per-epoch lists under 'epoch', 'loss' (floats) and 'test_acc'.

        NOTE(review): early stopping monitors *test* accuracy; monitoring the validation
        split (`data.val_mask`) would avoid tuning on the test set.
        """
        history = {
            'epoch': [],
            'loss': [],
            'test_acc': []
        }

        best_test_acc = 0.0
        epochs_without_improvement = 0

        for epoch in range(1, epochs + 1):
            loss = self.single_train()
            test_acc = self.test()

            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')

            history['epoch'].append(epoch)
            history['loss'].append(loss.item())
            history['test_acc'].append(test_acc)

            if test_acc > best_test_acc:
                best_test_acc = test_acc
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= patience:
                print(f'Early stopping triggered at epoch {epoch}.')
                break

        if plot:
            self.history_plot(history, plot_name)

        if printAllParams:
            print('======================')
            printLearnableParameters(self.model)

        return history

    def train(self, epochs= Configuration["BuildModel"]["maxEpochs"]):
        """Load saved parameters if available; otherwise train, then save.

        Training stops once the per-dataset target test accuracy is reached
        (see _TARGET_TEST_ACCURACY) or after `epochs` epochs. Returns None.
        """
        if self.load_model_params():
            print("Skipping training as model parameters are already available.")
            return
        # The target depends only on the configured dataset, so look it up once.
        target_acc = self._TARGET_TEST_ACCURACY.get(Configuration['Dataset']['name'], 0.8)
        for epoch in range(1, epochs + 1):
            loss = self.single_train()
            test_acc = self.test()
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')
            if test_acc >= target_acc:
                break
        self.save_model_params()

    def history_plot(self, history, plot_name):
        """Plot per-epoch training loss and test accuracy.

        Args:
            history: dict with 'epoch', 'loss' and 'test_acc' lists.
            plot_name: figure title; falls back to 'Training History' when falsy.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=history['epoch'], y=history['loss'], mode='lines', name='Training Loss'))
        fig.add_trace(go.Scatter(x=history['epoch'], y=history['test_acc'], mode='lines', name='Test Accuracy'))
        fig.update_layout(
            title=plot_name if plot_name else 'Training History',
            xaxis_title='Epoch',
            yaxis_title='Value',
            legend=dict(x=0, y=1),
            template='plotly_dark'
        )
        fig.show()
        fig.show()

# GAT Algorithm

In [12]:
class GAT(torch.nn.Module):
    """Two-layer graph attention network for node classification.

    conv1 maps the input features to `hidden_channels` per head with the heads
    concatenated (heads * hidden_channels outputs). conv2 maps those to
    `dataset.num_classes` outputs with a single, non-concatenated head. Dropout is
    applied before each layer at a rate chosen by the configured dataset name.
    """

    # Dropout rate per dataset name; any other name falls back to 0.6.
    _DROPOUT_BY_DATASET = {'Cora': 0.6, 'CiteSeer': 0.9, 'PubMed': 0.6}

    def __init__(self,
                 hidden_channels = Configuration["GAT"]["hiddenChannel"],
                 heads = Configuration["GAT"]["head"]):
        super().__init__()
        # Re-seeds torch's global RNG so weight initialization is reproducible.
        torch.manual_seed(1234567)
        self.conv1 = GATConv(dataset.num_features, hidden_channels, heads=heads, concat=True)
        self.conv2 = GATConv(heads * hidden_channels, dataset.num_classes, heads=1, concat=False)

    def forward(self, x, edge_index):
        """Return the conv2 output for node features `x` over the graph `edge_index`."""
        drop_p = self._DROPOUT_BY_DATASET.get(Configuration['Dataset']['name'], 0.6)
        hidden = F.dropout(x, p=drop_p, training=self.training)
        hidden = F.elu(self.conv1(hidden, edge_index))
        hidden = F.dropout(hidden, p=drop_p, training=self.training)
        return self.conv2(hidden, edge_index)


# GAT Initialization

In [13]:
# Instantiate the GAT with the configured hidden size / heads (its constructor re-seeds torch's RNG).
GATModel = GAT()

# Train Model

In [None]:
buildGATModel = BuildModel(GATModel)
# history = buildGATModel.train_with_early_stopping()
# NOTE(review): BuildModel.train() returns None (it either loads saved parameters or
# trains and saves), so `history` is None here; train_with_early_stopping() returns a
# history dict instead.
history = buildGATModel.train()

# Test Capture & Visualization

In [None]:
# Time one evaluation pass over the test split.
# time.perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock time and can jump if the system clock changes.
start_time = time.perf_counter()
test_accuracy = buildGATModel.test(visualization_2D=False, visualization_3D=False)
end_time = time.perf_counter()
print(f'Test Accuracy with Quantization Aware Training (QAT): {test_accuracy}')
print(f"Time to run test: {end_time - start_time:.6f} seconds")

# GCSR Data Compression Builder
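The encoding is built per subgraph, i.e. a source node together with its neighbors:
- col_index: column indices of the non-zero features of each node in the subgraph
- value: the corresponding non-zero feature values
- node_info: one {row_length, num_of_nodes, flag} tuple per node
  + row_length: number of non-zero features of the node (a node with none is stored as a single placeholder entry, col_index 0 / value 0, while its row_length stays 0)
  + num_of_nodes: number of nodes in the subgraph (the source node plus its neighbors)
  + flag: 1 -> source node, 0 -> neighbor node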


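Code Implementation:
- subgraph_info_gcsr: source node idx -> dict holding the non-zero feature indices of the source node, those of each of its neighbors, and the subgraph size
+ Example:
      1: {
        'source_node_nonzero_indices': tensor([3, 9, 10, ...]),
        'neighbors': {
            2: tensor([3, 9, 10, ...]),
            3: tensor([3, 9, 10, ...])
        },
        'totalLength': 3   # 1 source node + 2 neighbors
      }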
   


In [16]:
class GCSR_Data_Compression_Builder():
  """Builds the GCSR streams (col_index, value, node_info) for the global `data` graph.

  Every node that appears as a source in `data.edge_index[0]` forms one subgraph with
  its neighbors (the `edge_index[1]` targets of its edges). The non-zero features of
  the subgraph's nodes are streamed into flat col_index / value lists, plus one
  node_info tuple per node. Field bit widths come from Configuration["GCSR"] or, when
  the configured width is None, are derived from the data with bits_required().
  """
  def getInfoGCSR(self, printInfoShape = Configuration["GCSR"]["printInfoShape"]):
    """Compute the GCSR streams and their bit widths.

    Args:
      printInfoShape: if truthy, print length / max / min / bit width of each stream.

    Returns:
      (col_index_result, value_result, node_info_result,
       bits_required_col_index, bits_required_value,
       bits_required_node_info_row_length, bits_required_node_info_num_of_nodes,
       bits_required_node_info_flag).
      The three *_result values are Python lists; each node_info_result row is
      [row_length, num_of_nodes, flag].

    NOTE(review):
      - Nodes that never appear in edge_index[0] (e.g. isolated nodes) get no subgraph.
      - Neighbor lookup scans all edges once per source node (O(num_nodes * num_edges)).
      - If edge_index contained duplicate (src, dst) pairs, the 'neighbors' dict would
        keep one entry per neighbor while 'totalLength' counts every edge.
    """
    edge_index_gcsr = data.edge_index
    subgraph_info_gcsr = {}

    # Per-subgraph tensors, concatenated into single streams at the end.
    col_index_tensor = []
    value_tensor = []
    node_info_tensor = []

    # Pass 1: for each source node, collect its own and its neighbors' non-zero feature indices.
    for node_idx in torch.unique(edge_index_gcsr[0]):
        neighbors_idx_arr = edge_index_gcsr[1][edge_index_gcsr[0] == node_idx]

        # Create an entry for the source node
        subgraph_info_gcsr[int(node_idx)] = {
            'source_node_nonzero_indices': get_nonzero_features(data.x[node_idx]),
            'neighbors': {},
            'totalLength': 0
        }

        # For each neighbor, get their nonzero feature indices
        for neighbor_idx in neighbors_idx_arr:
            subgraph_info_gcsr[int(node_idx)]['neighbors'][int(neighbor_idx)] = get_nonzero_features(data.x[neighbor_idx])
        # Subgraph size = source node + one per neighbor edge
        subgraph_info_gcsr[int(node_idx)]['totalLength'] = 1 + len(neighbors_idx_arr)

    # Pass 2: emit the col_index / value / node_info entries of every subgraph.
    for source_node_idx, subgraph_info in subgraph_info_gcsr.items():
        num_of_nodes = subgraph_info['totalLength']

        # Source node data
        source_node_nonzero_idx_arr = subgraph_info['source_node_nonzero_indices']
        source_node_nonzero_feature_arr = data.x[source_node_idx, source_node_nonzero_idx_arr]
        source_node_row_length = len(source_node_nonzero_feature_arr)
        # A node without non-zero features is emitted as one placeholder entry (col 0, value 0).
        # NOTE(review): row_length was computed above and stays 0 although one entry is
        # emitted — confirm the decoder expects this.
        if source_node_row_length == 0:
          source_node_nonzero_idx_arr = torch.tensor([0], dtype=torch.int64)
          source_node_nonzero_feature_arr = torch.tensor([0])

        tmp_col_index_subgraph = [source_node_nonzero_idx_arr]
        tmp_value_subgraph = [source_node_nonzero_feature_arr]
        # flag = 1 marks the source node
        tmp_node_info_subgraph = [(source_node_row_length, num_of_nodes, 1)]

        # Neighbor node data
        for neighbor_node_idx, neighbor_node_nonzero_idx_arr in subgraph_info['neighbors'].items():
          neighbor_node_nonzero_feature_arr = data.x[neighbor_node_idx, neighbor_node_nonzero_idx_arr]
          neighbor_node_row_length = len(neighbor_node_nonzero_feature_arr)
          # Same placeholder rule as for the source node (row_length stays 0).
          if neighbor_node_row_length == 0:
            neighbor_node_nonzero_feature_arr = torch.tensor([0])
            neighbor_node_nonzero_idx_arr = torch.tensor([0], dtype=torch.int64)

          tmp_col_index_subgraph.append(neighbor_node_nonzero_idx_arr)
          tmp_value_subgraph.append(neighbor_node_nonzero_feature_arr)
          # flag = 0 marks a neighbor node
          tmp_node_info_subgraph.append((neighbor_node_row_length, num_of_nodes, 0))

        # Merge multiple tensors of each array to 1 tensor
        tmp_col_index_subgraph_streamline = torch.cat(tmp_col_index_subgraph)
        tmp_value_subgraph_streamline = torch.cat(tmp_value_subgraph)
        tmp_node_info_subgraph_streamline = torch.tensor(tmp_node_info_subgraph)

        col_index_tensor.append(tmp_col_index_subgraph_streamline)
        value_tensor.append(tmp_value_subgraph_streamline)
        node_info_tensor.append(tmp_node_info_subgraph_streamline)

    # Concatenate all subgraphs into the final streams and convert to Python lists.
    col_index_tensor_streamline = torch.cat(col_index_tensor)
    value_tensor_streamline = torch.cat(value_tensor)
    node_info_tensor_streamline = torch.cat(node_info_tensor)

    col_index_result = col_index_tensor_streamline.tolist()
    value_result = value_tensor_streamline.tolist()
    node_info_result = node_info_tensor_streamline.tolist()

    # Bit widths: a configured width wins; None derives it from the data.
    # bits_required() uses absolute values only, so the (max, min) argument order is harmless.
    bits_required_col_index = bits_required(max(col_index_result), min(col_index_result), Configuration["GCSR"]["colIndexSigned"]) if Configuration["GCSR"]["colIndexRequiredBit"] == None else Configuration["GCSR"]["colIndexRequiredBit"]
    bits_required_value = bits_required(max(value_result), min(value_result), Configuration["GCSR"]["valueSigned"]) if Configuration["GCSR"]["valueRequiredBit"] == None else Configuration["GCSR"]["valueRequiredBit"]
    bits_required_node_info_row_length = bits_required(max(row[0] for row in node_info_result), min(row[0] for row in node_info_result), Configuration["GCSR"]["rowLengthSigned"]) if Configuration["GCSR"]["rowLengthRequiredBit"] == None else Configuration["GCSR"]["rowLengthRequiredBit"]
    bits_required_node_info_num_of_nodes = bits_required(max(row[1] for row in node_info_result), min(row[1] for row in node_info_result), Configuration["GCSR"]["numOfNodesSigned"]) if Configuration["GCSR"]["numOfNodesRequiredBit"] == None else Configuration["GCSR"]["numOfNodesRequiredBit"]
    bits_required_node_info_flag = bits_required(max(row[2] for row in node_info_result), min(row[2] for row in node_info_result), Configuration["GCSR"]["flagSigned"]) if Configuration["GCSR"]["flagRequiredBit"] == None else Configuration["GCSR"]["flagRequiredBit"]

    if printInfoShape:
      print(f"Length of col_index: {len(col_index_result)}, Max: {max(col_index_result)}, Min: {min(col_index_result)}, Bits Required: {bits_required_col_index}")
      print(f"Length of value_result: {len(value_result)}, Max: {max(value_result)}, Min: {min(value_result)}, Bits Required: {bits_required_value}")
      print(f"Length of node_info_result: {len(node_info_result)}, Max: {max(max(row) for row in node_info_result)}, Min: {min(min(row) for row in node_info_result)}, Bits Required: {bits_required_node_info_row_length} - {bits_required_node_info_num_of_nodes} - {bits_required_node_info_flag}")

    return col_index_result, value_result, node_info_result, bits_required_col_index, bits_required_value, bits_required_node_info_row_length, bits_required_node_info_num_of_nodes, bits_required_node_info_flag


In [None]:
# Build the GCSR streams and their bit widths for the loaded graph `data`.
GCSRBuilder = GCSR_Data_Compression_Builder()
col_index_raw, value_raw, node_info_raw, bits_required_col_index, bits_required_value, bits_required_node_info_row_length, bits_required_node_info_num_of_nodes, bits_required_node_info_flag = GCSRBuilder.getInfoGCSR()

# Learnable Param Quantization

In [18]:
class LearnableParamQuanBuilder():
  """Quantize the learnable parameters of the global `GATModel` and report their bit widths.

  Every state_dict tensor is mapped to int8 with scale_tensor(), using
  Configuration["LParamQuan"]["scaleMin"] / ["scaleMax"]. Per-field bit widths come from
  Configuration["LParamQuan"] or, when a width is None, are derived with bits_required().
  """

  def getInfoParamQuan(self, printParamShape = Configuration["LParamQuan"]["printParamShape"]):
    """Return the quantized attention vectors and weight matrices plus their bit widths.

    Args:
      printParamShape: if truthy, print length, max, min and bit width of each field.

    Returns:
      A 12-tuple: (conv1_att_src, conv1_att_dst, conv1_weight, conv2_att_src,
      conv2_att_dst, conv2_weight, bits_conv1_att_src, bits_conv1_att_dst,
      bits_conv1_weight, bits_conv2_att_src, bits_conv2_att_dst, bits_conv2_weight).
      The att_* values are flat lists of ints; the *_weight values are lists of rows.

    NOTE(review): bias tensors are quantized in the loop but not returned, and the
    per-tensor scale factors are discarded, so the returned values cannot be
    de-quantized from this output alone.
    """
    cfg = Configuration["LParamQuan"]

    quantized = GATModel.state_dict()
    for key, value in quantized.items():
      scaled_tensor, _scale_back_fn = scale_tensor(value, cfg["scaleMin"], cfg["scaleMax"], torch.int8)
      if scaled_tensor.ndim == 3:
        # Collapse the trailing dimensions so every tensor is at most 2-D.
        scaled_tensor = scaled_tensor.reshape(scaled_tensor.shape[0], -1)
      quantized[key] = scaled_tensor

    def first_row(key):
      # Element 0 of each row of the transpose == row 0 of the (2-D) tensor.
      return [element[0] for element in quantized[key].t().tolist()]

    def flat(matrix):
      return [value for row in matrix for value in row]

    def resolve_bits(values, prefix):
      # An explicit width in the configuration wins; otherwise derive it from the data.
      override = cfg[f"{prefix}RequiredBit"]
      if override is not None:
        return override
      return bits_required(min(values), max(values), cfg[f"{prefix}Signed"])

    conv1_att_src_result = first_row('conv1.att_src')
    conv1_att_dst_result = first_row('conv1.att_dst')
    conv1_weight_result = quantized['conv1.lin.weight'].tolist()

    conv2_att_src_result = first_row('conv2.att_src')
    conv2_att_dst_result = first_row('conv2.att_dst')
    conv2_weight_result = quantized['conv2.lin.weight'].tolist()

    bits_required_conv1_att_src = resolve_bits(conv1_att_src_result, "conv1AttSrc")
    bits_required_conv1_att_dst = resolve_bits(conv1_att_dst_result, "conv1AttDst")
    bits_required_conv1_weight = resolve_bits(flat(conv1_weight_result), "conv1Weight")

    bits_required_conv2_att_src = resolve_bits(conv2_att_src_result, "conv2AttSrc")
    bits_required_conv2_att_dst = resolve_bits(conv2_att_dst_result, "conv2AttDst")
    bits_required_conv2_weight = resolve_bits(flat(conv2_weight_result), "conv2Weight")

    if printParamShape:
      # Labels fixed: the *_att_dst lines used to be printed as "conv*_dst_src".
      print(f"Length of conv1_att_src: {len(conv1_att_src_result)}, Max: {max(conv1_att_src_result)}, Min: {min(conv1_att_src_result)}, Bits Required: {bits_required_conv1_att_src}")
      print(f"Length of conv1_att_dst: {len(conv1_att_dst_result)}, Max: {max(conv1_att_dst_result)}, Min: {min(conv1_att_dst_result)}, Bits Required: {bits_required_conv1_att_dst}")
      print(f"Length of conv1_weight: {len(conv1_weight_result)} x {len(conv1_weight_result[0])}, Max: {max(flat(conv1_weight_result))}, Min: {min(flat(conv1_weight_result))}, Bits Required: {bits_required_conv1_weight}")

      print(f"Length of conv2_att_src: {len(conv2_att_src_result)}, Max: {max(conv2_att_src_result)}, Min: {min(conv2_att_src_result)}, Bits Required: {bits_required_conv2_att_src}")
      print(f"Length of conv2_att_dst: {len(conv2_att_dst_result)}, Max: {max(conv2_att_dst_result)}, Min: {min(conv2_att_dst_result)}, Bits Required: {bits_required_conv2_att_dst}")
      print(f"Length of conv2_weight: {len(conv2_weight_result)} x {len(conv2_weight_result[0])}, Max: {max(flat(conv2_weight_result))}, Min: {min(flat(conv2_weight_result))}, Bits Required: {bits_required_conv2_weight}")

    return conv1_att_src_result, conv1_att_dst_result, conv1_weight_result, conv2_att_src_result, conv2_att_dst_result, conv2_weight_result, bits_required_conv1_att_src, bits_required_conv1_att_dst, bits_required_conv1_weight, bits_required_conv2_att_src, bits_required_conv2_att_dst, bits_required_conv2_weight


In [None]:
# Quantize the trained GATModel's attention vectors / weights and compute their bit widths.
LParamBuilder = LearnableParamQuanBuilder()
conv1_att_src, conv1_att_dst, conv1_weight, conv2_att_src, conv2_att_dst,conv2_weight, bits_required_conv1_att_src, bits_required_conv1_att_dst, bits_required_conv1_weight, bits_required_conv2_att_src, bits_required_conv2_att_dst, bits_required_conv2_weight = LParamBuilder.getInfoParamQuan()

# Raw Data Result

In [20]:
class RawDataResult:
  """Summarise every raw (quantized) array that gets exported, for bit-width review.

  Each summary ("stats") dict has the keys 'data', 'max', 'min', 'bits', 'signed',
  'shape' and 'count', in that order. 'data' is only filled in when
  Configuration["RawDataResult"]["showData"] is truthy, otherwise it is None.

  Reads module-level globals produced by earlier cells: Configuration,
  col_index_raw, value_raw, node_info_raw, conv1_*/conv2_* parameter arrays and
  the matching bits_required_* values.
  """

  @staticmethod
  def _summarize(values, bits, signed, showData):
    """Stats dict for a flat sequence; 'shape' and 'count' are both len(values)."""
    return {
        'data': values if showData else None,
        'max': max(values),
        'min': min(values),
        'bits': bits,
        'signed': signed,
        'shape': len(values),
        'count': len(values)
    }

  @staticmethod
  def _summarize_matrix(matrix, bits, signed, showData):
    """Stats dict for a list of equal-length rows; 'shape' is the string "R x C"."""
    return {
        'data': matrix if showData else None,
        'max': max(max(row) for row in matrix),
        'min': min(min(row) for row in matrix),
        'bits': bits,
        'signed': signed,
        'shape': f"{len(matrix)} x {len(matrix[0])}",
        'count': len(matrix) * len(matrix[0])
    }

  def getRawDataResult(self):
    """Return {category: stats} / {category: {subcategory: stats}} for all exported arrays."""
    showData = Configuration["RawDataResult"]["showData"]
    gcsr = Configuration["GCSR"]
    lparam = Configuration["LParamQuan"]

    # node_info_raw rows are indexed as (row_length, num_of_nodes, flag); split the
    # columns once instead of rebuilding each column list for every statistic.
    row_length = [row[0] for row in node_info_raw]
    num_of_nodes = [row[1] for row in node_info_raw]
    flag = [row[2] for row in node_info_raw]

    return {
        'col_index': self._summarize(col_index_raw, bits_required_col_index, gcsr["colIndexSigned"], showData),
        'value': self._summarize(value_raw, bits_required_value, gcsr["valueSigned"], showData),
        'node_info': {
            'row_length': self._summarize(row_length, bits_required_node_info_row_length, gcsr["rowLengthSigned"], showData),
            'num_of_nodes': self._summarize(num_of_nodes, bits_required_node_info_num_of_nodes, gcsr["numOfNodesSigned"], showData),
            'flag': self._summarize(flag, bits_required_node_info_flag, gcsr["flagSigned"], showData)
        },
        'conv1': {
            'att_src': self._summarize(conv1_att_src, bits_required_conv1_att_src, lparam["conv1AttSrcSigned"], showData),
            'att_dst': self._summarize(conv1_att_dst, bits_required_conv1_att_dst, lparam["conv1AttDstSigned"], showData),
            'weight': self._summarize_matrix(conv1_weight, bits_required_conv1_weight, lparam["conv1WeightSigned"], showData)
        },
        'conv2': {
            'att_src': self._summarize(conv2_att_src, bits_required_conv2_att_src, lparam["conv2AttSrcSigned"], showData),
            'att_dst': self._summarize(conv2_att_dst, bits_required_conv2_att_dst, lparam["conv2AttDstSigned"], showData),
            'weight': self._summarize_matrix(conv2_weight, bits_required_conv2_weight, lparam["conv2WeightSigned"], showData)
        }
    }

  def getTableResult(self):
    """Render getRawDataResult() as a PrettyTable with one row per stats dict."""
    columns = ('max', 'min', 'bits', 'signed', 'shape', 'count')
    table = PrettyTable()
    raw = self.getRawDataResult()
    table.field_names = ["Category", "Subcategory", "Max", "Min", "Bits", "Signed", "Shape", "Count"]

    for category, subcategories in raw.items():
        if not isinstance(subcategories, dict):  # Leaf-level value
            table.add_row([category, "-", "-", "-", "-", "-", "-", "-"])
            continue
        for subcategory, attributes in subcategories.items():
            if isinstance(attributes, dict):  # Nested stats dict, e.g. conv1 -> att_src
                table.add_row([category, subcategory] + [attributes.get(c, '-') for c in columns])
            else:
                # `subcategories` is itself a stats dict (e.g. col_index, whose first
                # value 'data' is not a dict): one row for the category, then stop.
                table.add_row([category, "-"] + [subcategories.get(c, '-') for c in columns])
                break
    return table

In [None]:
# Build the summary once and print it as a table.
rawDataResult = RawDataResult()
rawDataTable = rawDataResult.getTableResult()
print(rawDataTable)

# PCOO Data Compression Builder

PE (processing element) assignment and entry header

To parallelize the computation, more than one PE may be instantiated. With n PEs, row r of the matrix is assigned to PE (r mod n), i.e. the PE whose index equals the remainder of r divided by n.
For example, with n = 2:
  PE 0 -> handles rows 0, 2, 4, 6, 8, ... of the feature matrix
  PE 1 -> handles rows 1, 3, 5, 7, 9, ... of the feature matrix

Header:
- Start-of-row (SOR): Indicates whether this is the first non-zero element in the row.
- End-of-row (EOR): Indicates whether this is the last non-zero element in the row.
- Valid (VLD): Indicates whether the element participates in the computation.

- For a row with no non-zero elements, a single placeholder entry is emitted whose 3 header bits SOR/EOR/VLD are `110`: the row starts and ends at this entry, and nothing participates in the computation. Its body is col = 0, val = 0.


Body:
- Column (col): Indicates column
- Value (val): Indicates value of feature

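Example:
Header{1,0,1}Body{1,a} is the first non-zero element of a row that has further non-zero elements after it (SOR = 1, EOR = 0, VLD = 1), located in column 1 with value a.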

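Code Implementation:
- The PCOO builder below returns a flat list of `{'header': {'sor', 'eor', 'vld'}, 'body': {'col', 'val'}}` entries: one per non-zero feature, plus one placeholder per all-zero row.
- `node_info_gcsr`: maps a node index to the indices of its non-zero features and to those of each of its neighbors.
- Example:

      1: {
        'source_node_nonzero_indices': tensor([3, 9, 10, ...]),
        'neighbors': {
            2: tensor([3, 9, 10, ...]),
            3: tensor([3, 9, 10, ...])
        },
        'totalLength': 3
      }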
   


In [22]:
class PCOO_Data_Compression_Builder():
  """Encode the subgraph feature matrix in PCOO form and report its bit-width constraints.

  The "subgraph" is the sequence of feature rows x[src], x[dst_1], x[dst_2], ...
  for every source node of data.edge_index. Sources are ascending, and each source is
  followed by its destinations in edge order.

  Each non-zero feature becomes one entry
  {'header': {'sor', 'eor', 'vld'}, 'body': {'col', 'val'}}. An all-zero row becomes a
  single placeholder entry with header 1/1/0 and body col=0, val=0.

  Reads module-level globals from earlier cells: `data`, `Configuration`,
  `bits_required` and `PrettyTable`.
  """

  _FIELDS = ('sor', 'eor', 'vld', 'col', 'val')
  _TABLE_COLUMNS = ('max', 'min', 'bits', 'signed', 'shape', 'count')

  @staticmethod
  def _subgraph_node_order(edge_index):
    """Return [src_0, its dsts..., src_1, its dsts..., ...] as Python ints.

    Sources come out ascending (as torch.unique would give). Each source's
    destinations keep their original edge order. A single stable sort replaces a
    per-node boolean mask over all edges, which was O(num_nodes * num_edges).
    """
    sorted_src, perm = torch.sort(edge_index[0], stable=True)
    sorted_dst = edge_index[1][perm]
    order = []
    previous_src = None
    for src, dst in zip(sorted_src.tolist(), sorted_dst.tolist()):
      if src != previous_src:
        order.append(src)
        previous_src = src
      order.append(dst)
    return order

  @staticmethod
  def _encode_rows(x, node_order):
    """Build the PCOO entry list for the feature rows x[node], node in node_order."""
    # node -> [(col, value), ...]; the same node appears many times in node_order.
    nonzero_cache = {}
    entries = []
    for node in node_order:
      if node not in nonzero_cache:
        row = x[node]
        cols = torch.nonzero(row != 0, as_tuple=True)[0]
        nonzero_cache[node] = list(zip(cols.tolist(), row[cols].tolist()))
      non_zero_elements = nonzero_cache[node]

      if not non_zero_elements:
        # All-zero row: SOR=1, EOR=1, VLD=0 ("110"), so nothing is computed for it.
        entries.append({'header': {'sor': 1, 'eor': 1, 'vld': 0}, 'body': {'col': 0, 'val': 0}})
        continue

      last = len(non_zero_elements) - 1
      for i, (col_idx, value) in enumerate(non_zero_elements):
        entries.append({
            'header': {'sor': 1 if i == 0 else 0, 'eor': 1 if i == last else 0, 'vld': 1},
            # NOTE(review): int() truncates toward zero; fractional features (e.g. with
            # normalization enabled) would be stored as 0. Confirm the features are integral.
            'body': {'col': col_idx, 'val': int(value)},
        })
    return entries

  @classmethod
  def _field_min_max(cls, pcooData):
    """Return {field: {'max': ..., 'min': ...}} over all entries (±inf when empty)."""
    stats = {field: {'max': float('-inf'), 'min': float('inf')} for field in cls._FIELDS}
    for entry in pcooData:
      header, body = entry['header'], entry['body']
      for field in cls._FIELDS:
        value = header.get(field, body.get(field))
        if value is not None:
          stats[field]['max'] = max(stats[field]['max'], value)
          stats[field]['min'] = min(stats[field]['min'], value)
    return stats

  def getInfoPCOO(self):
    """Return the PCOO entry list for the global graph `data`."""
    return self._encode_rows(data.x, self._subgraph_node_order(data.edge_index))

  def getConstraint(self, pcooData=None):
    """Return a PrettyTable with max/min/bits/signedness of every PCOO field.

    Args:
      pcooData: entries from getInfoPCOO(). When omitted they are recomputed,
        which is the original behaviour.
    """
    if pcooData is None:
      pcooData = self.getInfoPCOO()
    max_min_values = self._field_min_max(pcooData)
    pcooConfig = Configuration["PCOO"]

    # An explicit "<field>RequiredBit" overrides the width derived from the observed range.
    bits = {}
    for field in self._FIELDS:
      required = pcooConfig[f"{field}RequiredBit"]
      if required is None:
        bits[field] = bits_required(max_min_values[field]['max'], max_min_values[field]['min'], pcooConfig[f"{field}Signed"])
      else:
        bits[field] = required

    count = len(pcooData)

    def field_stats(field):
      return {
          'max': max_min_values[field]['max'],
          'min': max_min_values[field]['min'],
          'bits': bits[field],
          'signed': pcooConfig[f"{field}Signed"],
          'shape': count,
          'count': count
      }

    raw = {
        'header': {field: field_stats(field) for field in ('sor', 'eor', 'vld')},
        'body': {field: field_stats(field) for field in ('col', 'val')},
    }

    table = PrettyTable()
    table.field_names = ["Category", "Subcategory", "Max", "Min", "Bits", "Signed", "Shape", "Count"]
    for category, fields in raw.items():
      for field, stats in fields.items():
        table.add_row([category, field] + [stats[c] for c in self._TABLE_COLUMNS])
    return table

In [None]:
# Encode the subgraph in PCOO form and print its field ranges / bit-widths.
PCOODataCompression = PCOO_Data_Compression_Builder()
pcooRawData = PCOODataCompression.getInfoPCOO()
pcooConstraintTable = PCOODataCompression.getConstraint()
print(pcooConstraintTable)

# Manual Calculation

In [None]:
# Manual reference forward pass of the 2-layer GAT using int8-scaled parameters.
# Each layer computes Wh = X @ W^T, then a per-edge score e_ij, then an approximate
# softmax, then the weighted sum of neighbour Wh rows. The layer-1 output is re-scaled
# to int8 before layer 2. These stages appear to mirror the exported hardware stages
# (SPMM / DMVM / softmax / aggregator in initFileConfig).
# Prints value ranges and bit-widths of the intermediates, then the agreement with data.y.
#
# NOTE(review): this cell leaves many globals behind (a, tmp_feature_torch, tmp_info,
# group_by_src_node, group_to_matrix, result_feature_torch, my_output, ...). Once it
# finishes they hold the *second-layer* values; later cells may depend on them.
# Start time
start_time = time.time()

a = GATModel.state_dict()
# Using quantized parameter
# Each state_dict entry is replaced by the first value returned by scale_tensor (int8
# target). Only the entries of this dict are reassigned; the parameter tensors are not
# written to.
# NOTE(review): the range here comes from Configuration["BuildModel"], whereas
# exportToFilesFirstLayer scales with Configuration["LParamQuan"]. Confirm both agree.
for k, v in a.items():
  scaled_tensor, _ = scale_tensor(v,
                                Configuration["BuildModel"]["scaleMin"],
                                Configuration["BuildModel"]["scaleMax"],
                                torch.int8)
  a[k] = scaled_tensor

# ---------------- Layer 1 ----------------
# Wh
# tmp_feature_torch = feature_matrix_1 * weight_1
# Computed in float32 from the raw data.x features and the scaled conv1.lin.weight.
tmp_feature_torch= torch.matmul(data.x.to(dtype=torch.float32) , a['conv1.lin.weight'].to(dtype=torch.float32).t())
print(f"tmp_feature_matrix shape: {tmp_feature_torch.shape}")

# e_i_j = a_src.W.h_i || a_dst.W.h_j (TODO: DMVM + Calculate according to subgraph + COEF(e_i_j))
# For every edge (i = edge_index[0], j = edge_index[1]): e_ij = a_src . Wh_i + a_dst . Wh_j,
# stored as [i, j, e_ij]. `.item()` requires that sum to be a single element (one head).
# NOTE(review): PyG's GATConv (default flow) aggregates into edge_index[1] and applies
# att_src to the sending node edge_index[0]. Here edge_index[0] is the receiving node
# (see the aggregation below) and gets att_src. Confirm this is intended. No self-loops
# are added here either, while GATConv adds them by default.
tmp_info = []
for i in range(data.edge_index.size(1)):
  src_node = data.edge_index[0, i]
  dst_node = data.edge_index[1, i]
  feature_src_node = tmp_feature_torch[src_node] # W.h_i
  feature_dst_node = tmp_feature_torch[dst_node] # W.h_j
  att_src_node = torch.matmul(a['conv1.att_src'].to(dtype=torch.float32), feature_src_node.to(dtype=torch.float32)) # a_src.W.h_i
  att_dst_node = torch.matmul(a['conv1.att_dst'].to(dtype=torch.float32), feature_dst_node.to(dtype=torch.float32)) # a_dst.W.h_j
  tmp_info.append([src_node.item(), dst_node.item(), (att_src_node + att_dst_node).item()])

# Stable sort by source node, then record the signed bit-width needed for e_ij.
tmp_info = sorted(tmp_info, key=lambda x: x[0])
bit_required_e_i_j = bits_required(max([row[2] for row in tmp_info]), min([row[2] for row in tmp_info]), True)
print(f"tmp_info (a(Wh1||Wh2)): {len(tmp_info)}, max: {max([row[2] for row in tmp_info])}, min: {min([row[2] for row in tmp_info])}, bits_required: {bit_required_e_i_j}")

# Group by src_node
# Keys are inserted in ascending source order because tmp_info is sorted.
group_by_src_node = defaultdict(list)
for row in tmp_info:
  group_by_src_node[row[0]].append(row)

# Softmax + LeakyRelu
# Approximation used here:
#   - LeakyReLU is replaced by max(e, 0), so negative scores contribute 2**0 = 1.
#   - exp() is replaced by a power of 2.
#   - e is first divided by 2**(bit_required_e_i_j - 8), presumably to keep the exponent
#     in an 8-bit range. TODO confirm.
# alpha_ij = numerator_ij / (sum of numerators of the same source node).
# The max numerator/denominator are tracked to size the hardware registers.
max_numerators = 0
max_denominators = 0
for i, rows in group_by_src_node.items():
  numerators = [2**((max(row[2], 0) /(2**(bit_required_e_i_j - 8)))) for row in rows] #TODO
  # print("Divide by", bit_required_e_i_j - 8)
  denominator = sum(numerators) #TODO
  max_numerators = max(max_numerators, max(numerators))
  max_denominators = max(max_denominators, denominator)
  for idx, row in enumerate(rows):
    alpha_i_j = numerators[idx] / denominator #TODO
    row.append(alpha_i_j) #[src_node, dst_node, a(Wh1||Wh2), alpha_i_j] -> group_src_node

print(f"max_numerators of softmax: {max_numerators}, bits_required: {bits_required(max_numerators, 0)}")
print(f"max_denominators of softmax: {max_denominators}, bits_required: {bits_required(max_denominators, 0)}")


# Flatten the groups back into rows [src, dst, e_ij, alpha_ij], ordered by src.
group_to_matrix = [row for group in group_by_src_node.values() for row in group] #[[0, 633, -186704.0, 0.3333333333333333], [0, 1862, -182407.0, 0.3333333333333333], [0, 2582, -144662.0, 0.3333333333333333], [1, 2, -100041.0, 0.040421365723215065],...]
result_feature_torch = torch.zeros(tmp_feature_torch.shape[0], tmp_feature_torch.shape[1]).to(dtype=torch.float32) #h'


# Aggregation: h'[src] += alpha_ij * Wh[dst], one edge at a time, in float32.
# NOTE(review): no conv1.bias is added and no activation is applied before layer 2.
# Confirm this matches GATModel.forward.
for row in group_to_matrix:
   result_feature_torch[row[0]] = result_feature_torch[row[0]] + torch.tensor(row[3]).to(dtype=torch.float32) * tmp_feature_torch[row[1]].to(dtype=torch.float32).unsqueeze(0)

print(f"h' after first layer: {result_feature_torch.shape}, max: {torch.max(result_feature_torch)}, min: {torch.min(result_feature_torch)}")
print(result_feature_torch)
print(f"result_feature shape: {result_feature_torch.shape}, in range: {((result_feature_torch >= -127) & (result_feature_torch <= 127)).sum().item()}")

# Quantized
# Re-scale the layer-1 output with the BuildModel range (int8 target) before layer 2.
result_feature_torch, _ = scale_tensor(result_feature_torch,
                                                       Configuration["BuildModel"]["scaleMin"],
                                                       Configuration["BuildModel"]["scaleMax"],
                                                       torch.int8)
print("After quantized: ", result_feature_torch)
print(f"result_feature shape: {result_feature_torch.shape}, in range: {((result_feature_torch >= -127) & (result_feature_torch <= 127)).sum().item()}")


# ---------------- Layer 2 (same steps with conv2 parameters) ----------------
# Wh
# tmp_feature_torch = feature_matrix_2 * weight_2
tmp_feature_torch= torch.matmul(result_feature_torch.to(dtype=torch.float32) , a['conv2.lin.weight'].to(dtype=torch.float32).t())
print(f"tmp_feature_matrix shape: {tmp_feature_torch.shape}")


# e_i_j = a_src.W.h_i || a_dst.W.h_j
tmp_info = []
for i in range(data.edge_index.size(1)):
  src_node = data.edge_index[0, i]
  dst_node = data.edge_index[1, i]
  feature_src_node = tmp_feature_torch[src_node] # W.h_i
  feature_dst_node = tmp_feature_torch[dst_node] # W.h_j
  att_src_node = torch.matmul(a['conv2.att_src'].to(dtype=torch.float32), feature_src_node.to(dtype=torch.float32)) # a_src.W.h_i
  att_dst_node = torch.matmul(a['conv2.att_dst'].to(dtype=torch.float32), feature_dst_node.to(dtype=torch.float32)) # a_dst.W.h_j
  tmp_info.append([src_node.item(), dst_node.item(), (att_src_node + att_dst_node).item()])

tmp_info = sorted(tmp_info, key=lambda x: x[0])
bit_required_e_i_j = bits_required(max([row[2] for row in tmp_info]), min([row[2] for row in tmp_info]), True)
print(f"tmp_info (a(Wh1||Wh2)): {len(tmp_info)}, max: {max([row[2] for row in tmp_info])}, min: {min([row[2] for row in tmp_info])}, bits_required: {bit_required_e_i_j}")

# Group by src_node
group_by_src_node = defaultdict(list)
for row in tmp_info:
  group_by_src_node[row[0]].append(row)

# Softmax + LeakyRelu
# Same approximation as layer 1 (ReLU clamp, base-2 exponent, 2**(bits-8) scaling).
max_numerators = 0
max_denominators = 0
for i, rows in group_by_src_node.items():
  numerators = [2**((max(row[2], 0) /(2**(bit_required_e_i_j - 8)))) for row in rows]
  denominator = sum(numerators)
  max_numerators = max(max_numerators, max(numerators))
  max_denominators = max(max_denominators, denominator)
  for idx, row in enumerate(rows):
    alpha_i_j = numerators[idx] / denominator
    row.append(alpha_i_j) #[src_node, dst_node, a(Wh1||Wh2), alpha_i_j] -> group_src_node

print(f"max_numerators of softmax: {max_numerators}, bits_required: {bits_required(max_numerators, 0)}")
print(f"max_denominators of softmax: {max_denominators}, bits_required: {bits_required(max_denominators, 0)}")


group_to_matrix = [row for group in group_by_src_node.values() for row in group] #[[0, 633, -186704.0, 0.3333333333333333], [0, 1862, -182407.0, 0.3333333333333333], [0, 2582, -144662.0, 0.3333333333333333], [1, 2, -100041.0, 0.040421365723215065],...]
result_feature_torch = torch.zeros(tmp_feature_torch.shape[0], tmp_feature_torch.shape[1]).to(dtype=torch.float32) #h'


for row in group_to_matrix:
   result_feature_torch[row[0]] = result_feature_torch[row[0]] + torch.tensor(row[3]).to(dtype=torch.float32) * tmp_feature_torch[row[1]].to(dtype=torch.float32).unsqueeze(0)

print(f"h' after second layer: {result_feature_torch.shape}, max: {torch.max(result_feature_torch)}, min: {torch.min(result_feature_torch)}")
print(result_feature_torch)

# Output comparison
# Predicted class = argmax of the layer-2 output per node. "acc" is measured over ALL
# nodes (tmp_feature_torch.shape[0]), not over a test mask.
my_output = torch.argmax(result_feature_torch, dim=1)
print("Classification: ", my_output)
print("Correct result", data.y)

correct_count = (my_output == data.y).sum().item()
print(f"Correct count: {correct_count} / {tmp_feature_torch.shape[0]}, acc: {correct_count / tmp_feature_torch.shape[0]}")

end_time = time.time()
print(f"Time to run test: {end_time - start_time:.6f} seconds")



#Export File Util

In [25]:
def format_number(value):
    """Return str(value) with a trailing '.0' removed, e.g. 3.0 -> '3', 2.5 -> '2.5'.

    Works on anything str() accepts; only the literal suffix '.0' is stripped.
    """
    text = str(value)
    return text[:-2] if text.endswith('.0') else text


def initFileConfig(isLayer1 = False):
  """Return the export directory/file names for GAT layer 1 or layer 2.

  Args:
    isLayer1: truthy selects paths under exports/layer_1/; otherwise exports/layer_2/.

  Returns:
    Flat dict of directory paths (ending in "/") and bare file names.
    Layer 1 has col_idx_and_value_file, node_info_file and graph_index_file;
    layer 2 has input_value_file instead. All other keys are shared.
  """
  layer = 1 if isLayer1 else 2
  input_dir = f"exports/layer_{layer}/input/"
  output_dir = f"exports/layer_{layer}/output/"

  config = {
      "input_dir": input_dir,
      "attention_file": "a.txt",
      "weight_file": "weight.txt",
  }
  if layer == 1:
    config["col_idx_and_value_file"] = "h_data.txt"
    config["node_info_file"] = "node_info.txt"
    config["graph_index_file"] = "graph_index.txt"
  else:
    config["input_value_file"] = "h_data.txt"
  config["merge_weight_file"] = "merge_weight.txt"

  # Output stage directories (SPMM / DMVM / softmax / aggregator) and their files.
  config.update({
      "output_spmm_dir": f"{output_dir}SPMM/",
      "output_wh_file": "WH.txt",
      "output_dmvm_dir": f"{output_dir}DMVM/",
      "output_coef_file": "COEF.txt",
      "output_dmvm_file": "DMVM.txt",
      "output_softmax_dir": f"{output_dir}softmax/",
      "output_num_nodes_file": "num_nodes.txt",
      "output_alpha_file": "ALPHA.txt",
      "output_dividend_file": "DIVIDEND.txt",
      "output_divisor_file": "DIVISOR.txt",
      "output_aggregator_dir": f"{output_dir}aggregator/",
      "output_new_feature_file": "new_feature.txt",
  })
  return config


def exportToFilesFirstLayer(input_subgraph, layer_idx, notLayer1InputTensor= None):
  """Export one GAT layer's quantized inputs and simulated outputs to files.

  Despite its name this is called for every layer: `layer_idx` selects the
  `conv{layer_idx+1}.*` parameters and `notLayer1InputTensor` selects the
  file layout returned by initFileConfig.

  Args:
    input_subgraph: dict built by the simulation cell, keyed by source node
      id. Each entry holds 'src_*' tensors, 'divisor', 'nodes', 'new_h' and a
      'neighbors' dict whose entries hold 'nb_*' tensors.
    layer_idx: 0-based layer index.
    notLayer1InputTensor: None for layer 1; any other value selects the
      layer n >= 2 layout. The value itself is not otherwise read.

  Writes:
    inputs  - attention vector, weight matrix (with the attention vector
              appended), and either the layer-1 col_index/value + node_info
              files or, for later layers, the gathered input feature matrix;
    outputs - SPMM (W·h), DMVM (e terms and coefficients), softmax (node
              counts, alpha, dividend, divisor) and aggregator (new features).
  """
  # Identity check: `tensor == None` / `!= None` relies on operator fallback,
  # `is None` states the intent directly.
  is_first_layer = notLayer1InputTensor is None
  FLFileConfig = initFileConfig(is_first_layer)
  conv = f"conv{layer_idx+1}"

  def _flat(tensors):
    """Concatenate a list of tensors along dim 0 into a Python list."""
    return torch.cat(tensors, dim=0).tolist()

  def _print_range(label, values):
    """Print the max and min of an already-flattened list."""
    print(f"=> Max {label}", max(values))
    print(f"=> Min {label}", min(values))

  # ---------------------------------------------------------------- Inputs
  # Quantize every model parameter to int8; 3-D tensors are reshaped to 2-D
  # keeping their first dimension.
  params = GATModel.state_dict()
  for k, v in params.items():
    scaled_tensor, _ = scale_tensor(v,
                                    Configuration["LParamQuan"]["scaleMin"],
                                    Configuration["LParamQuan"]["scaleMax"],
                                    torch.int8)
    if scaled_tensor.ndim == 3:
      scaled_tensor = scaled_tensor.reshape(scaled_tensor.shape[0], -1)
    params[k] = scaled_tensor

  input_dir = FLFileConfig['input_dir']
  os.makedirs(input_dir, exist_ok=True)

  # Attention file: att_src values followed by att_dst values, 8 bits each.
  att_src = [element[0] for element in params[f"{conv}.att_src"].t().tolist()]
  att_dst = [element[0] for element in params[f"{conv}.att_dst"].t().tolist()]
  att_total_binary = intToBinaryArray(att_src + att_dst, 8)
  print("Att total binary", att_total_binary)
  write_array_to_file(att_total_binary, createPath(input_dir, FLFileConfig['attention_file']))

  # Weight file: transposed lin.weight, then the attention vector appended.
  weight_binary = intToBinaryMatrix(params[f"{conv}.lin.weight"].t().tolist(), 8)
  print("Weight binary", weight_binary)
  write_matrix_to_file(weight_binary, createPath(input_dir, FLFileConfig['weight_file']))
  append_array_to_file(att_total_binary, createPath(input_dir, FLFileConfig['weight_file']))

  # NOTE: the merged-weight ('merge_weight_file') and subgraph-index
  # ('graph_index_file') exporters are currently disabled.

  if is_first_layer:
    # Col index & value: per entry, col_index bits concatenated with value bits.
    col_index_binary = intToBinaryArray(col_index_raw, bits_required_col_index)
    value_binary = intToBinaryArray(value_raw, bits_required_value)
    col_index_value_binary = [col + val for col, val in zip(col_index_binary, value_binary)]
    write_array_to_file(col_index_value_binary, createPath(input_dir, FLFileConfig['col_idx_and_value_file']))

    # Node Info
    # TODO: FIX THIS change bits_required_col_index to bits_required_node_info_row_length
    node_info_binary = [row[-1] for row in resolveNodeInfo(node_info_raw, bits_required_col_index, bits_required_node_info_num_of_nodes, bits_required_node_info_flag)]
    write_array_to_file(node_info_binary, createPath(input_dir, FLFileConfig['node_info_file']))

    # Subgraph Index
    print(f"Check by huynguyenn subgraph_index: {len(node_info_binary)}, bits_required: {bits_required(0, len(node_info_binary)-1, False)}")

    print("Bits required col_index: ", bits_required_col_index)
    print("Bits required value: ", bits_required_value)
    # TODO: FIX THIS
    print("Bits required row_length", bits_required_col_index)
    print("Bits required num_of_nodes", bits_required_node_info_num_of_nodes)
    print("Bits required node_info_flag", bits_required_node_info_flag)
    print("Bits required h_data: ", bits_required_col_index + bits_required_value)
    print("Bits required node_info: ", bits_required_col_index + bits_required_node_info_num_of_nodes + bits_required_node_info_flag)
  else: #Layer n (with n >= 2)
    # Gather input rows: each source node followed by its neighbours.
    h_matrix = []
    for value in input_subgraph.values():
      h_matrix.append(value['src_h'].tolist())
      h_matrix.extend(nb['nb_h'].tolist() for nb in value['neighbors'].values())
    h_matrix_binary = intToBinaryMatrix([[format_number(x) for x in row] for row in h_matrix], 8)
    write_matrix_to_file(h_matrix_binary, createPath(input_dir, FLFileConfig['input_value_file']))

  # --------------------------------------------------------------- Outputs
  # SPMM: z = W·h for every source node and its neighbours.
  output_dir_SPMM = FLFileConfig['output_spmm_dir']
  os.makedirs(output_dir_SPMM, exist_ok=True)
  wh = []
  for value in input_subgraph.values():
    wh.append(value['src_z'])
    wh.extend(nb['nb_z'] for nb in value['neighbors'].values())
  wh_flat = _flat(wh)
  write_array_to_file([format_number(x) for x in wh_flat], createPath(output_dir_SPMM, FLFileConfig['output_wh_file']))
  _print_range("SPMM", wh_flat)

  # DMVM: attention products (e terms) and per-edge coefficients.
  output_dir_DMVM = FLFileConfig['output_dmvm_dir']
  os.makedirs(output_dir_DMVM, exist_ok=True)
  coef = []
  dmvm = []
  for value in input_subgraph.values():
    dmvm.append(value['src_e_src'].squeeze(0))
    dmvm.append(value['src_e_dst'].squeeze(0))
    coef.append(value['src_e_i_j'].squeeze(0))
    for nb in value['neighbors'].values():
      dmvm.append(nb['nb_e'].squeeze(0))
      coef.append(nb['nb_e_i_j'].squeeze(0))
  coef_flat = _flat(coef)
  dmvm_flat = _flat(dmvm)
  write_array_to_file([format_number(x) for x in coef_flat], createPath(output_dir_DMVM, FLFileConfig['output_coef_file']))
  write_array_to_file([format_number(x) for x in dmvm_flat], createPath(output_dir_DMVM, FLFileConfig['output_dmvm_file']))
  _print_range("DMVM", dmvm_flat)
  _print_range("COEF", coef_flat)

  # SOFTMAX: node counts, alpha, dividend and divisor per subgraph.
  output_dir_softmax = FLFileConfig['output_softmax_dir']
  os.makedirs(output_dir_softmax, exist_ok=True)
  alpha = []
  dividend = []
  divisor = []
  nodes = []
  for value in input_subgraph.values():
    nodes.append(value['nodes'])
    alpha.append(value['src_alpha'].squeeze(0))
    dividend.append(value['src_dividend'].squeeze(0))
    divisor.append(value['divisor'].squeeze(0))
    for nb in value['neighbors'].values():
      alpha.append(nb['nb_alpha'].squeeze(0))
      dividend.append(nb['nb_dividend'].squeeze(0))
  alpha_flat = _flat(alpha)
  dividend_flat = _flat(dividend)
  divisor_flat = _flat(divisor)
  write_array_to_file(nodes, createPath(output_dir_softmax, FLFileConfig['output_num_nodes_file']))
  write_array_to_file(alpha_flat, createPath(output_dir_softmax, FLFileConfig['output_alpha_file']))
  write_array_to_file([format_number(x) for x in dividend_flat], createPath(output_dir_softmax, FLFileConfig['output_dividend_file']))
  write_array_to_file([format_number(x) for x in divisor_flat], createPath(output_dir_softmax, FLFileConfig['output_divisor_file']))
  _print_range("alpha", alpha_flat)
  _print_range("dividend", dividend_flat)
  _print_range("divisor", divisor_flat)

  # Aggregator: new node features, one entry per source node.
  output_dir_aggregator = FLFileConfig['output_aggregator_dir']
  os.makedirs(output_dir_aggregator, exist_ok=True)
  new_feature_flat = _flat([value['new_h'].squeeze(0) for value in input_subgraph.values()])
  write_array_to_file(new_feature_flat, createPath(output_dir_aggregator, FLFileConfig['output_new_feature_file']))
  _print_range("new_feature", new_feature_flat)


In [26]:
# Placeholders filled by the simulation cell: layer-1 output features and final-layer output.
test_new_feature = test_output = None

In [None]:
# Scratch check of 511.99999 / 4 (multiplying by 0.25 is the exact same IEEE operation).
print(511.99999 * 0.25)

# Simulation & Export to Files

In [None]:
import torch.nn as nn
import json

# Number of GAT layers to simulate; layer i uses the conv{i+1}.* parameters.
NUM_OF_LAYERS = 2

# Layer-1 input features as float32 so they can be multiplied with the
# float32-cast quantized weights.
input_layer = data.x.to(dtype=torch.float32)
a = GATModel.state_dict()

# True: rescale layer outputs with a power-of-two shift (hardware behaviour);
# False: re-quantize them into the int8 range with scale_tensor.
IS_HARDWARE_HANDLE_OUTPUT_LAYER_1 = True


def _shift_divisor(bits):
  """Return 2**(bits - 8), the divisor that drops a `bits`-wide value to 8 bits.

  Clamped at 1: when the values already fit in 8 bits a negative exponent would
  divide by a fraction and scale them *up* instead of leaving them unchanged.
  """
  return 2 ** max(bits - 8, 0)


# Getting quantized parameters
for k, v in a.items():
  scaled_tensor, _ = scale_tensor(v,
                                Configuration["BuildModel"]["scaleMin"],
                                Configuration["BuildModel"]["scaleMax"],
                                torch.int8)
  a[k] = scaled_tensor

# Group destination nodes by source node once; edge_index is the same for every
# layer. Sources come out ascending (as torch.unique did) and each neighbour list
# keeps original edge order (as e_idx[1][e_idx[0] == node_idx] did), but in O(E)
# instead of one full edge scan per node.
e_idx = data.edge_index
neighbors_by_src = {}
for src, dst in zip(e_idx[0].tolist(), e_idx[1].tolist()):
  neighbors_by_src.setdefault(src, []).append(dst)

for layer_idx in range(NUM_OF_LAYERS):
  print(f"----------------------------------------Start Layer {layer_idx + 1}-------------------------------------------")
  conv = f"conv{layer_idx+1}"
  # Loop-invariant float32 views of this layer's quantized parameters.
  weight_t = a[f"{conv}.lin.weight"].to(dtype=torch.float32).t()
  att_src = a[f"{conv}.att_src"].to(dtype=torch.float32)
  att_dst = a[f"{conv}.att_dst"].to(dtype=torch.float32)

  #SPMM: Calculate WH(= z) according to subgraph
  print(weight_t)
  print(weight_t.shape)
  z_layer1 = torch.matmul(input_layer, weight_t)  # (num_nodes, out_channels)
  subgraph = {}

  #DMVM: Calculate DMVM + e
  for node_idx in sorted(neighbors_by_src):
      z_src = z_layer1[node_idx]
      src_dmvm_src = torch.matmul(att_src, z_src)  # z1 x a1 = e1
      src_dmvm_dst = torch.matmul(att_dst, z_src)  # z1 x a2 = e1'

      subgraph[node_idx] = {
          'src_h': input_layer[node_idx],
          'src_z': z_src,
          'src_e_src': src_dmvm_src,
          'src_e_dst': src_dmvm_dst,
          'src_e_i_j': torch.relu(src_dmvm_src + src_dmvm_dst),  # e1 + e1' = e11
          'neighbors': {},
      }

      for neighbor_idx in neighbors_by_src[node_idx]:
          neighbor_dmvm = torch.matmul(att_dst, z_layer1[neighbor_idx])  # z2 x a2 = e2, z3 x a2 = e3, ...
          subgraph[node_idx]['neighbors'][neighbor_idx] = {
              'nb_h': input_layer[neighbor_idx],
              'nb_z': z_layer1[neighbor_idx],
              'nb_e': neighbor_dmvm,
              'nb_e_i_j': torch.relu(src_dmvm_src + neighbor_dmvm),  # e1 + e2 = e12, e1 + e3 = e13, ...
          }

  # Check max, min and calculate bits to reduce
  values = []
  for value in subgraph.values():
      values.append(value['src_e_i_j'])
      values.extend(nb['nb_e_i_j'] for nb in value['neighbors'].values())
  all_values = torch.cat(values)
  min_value = torch.min(all_values)
  max_value = torch.max(all_values)
  #Important dont delete
  # NOTE(review): the other bits_required calls pass (min, max, signed); confirm
  # that (max, min, signed) here is intended.
  bits_e_i_j = bits_required(max_value, min_value, True)
  print("Max COE:", max_value)
  print("Min COE:", min_value)
  print("Remove # of bits", max(bits_e_i_j - 8, 0))

  e_divisor = _shift_divisor(bits_e_i_j)
  for value in subgraph.values():
      value['src_e_i_j'] = value['src_e_i_j'] // e_divisor
      for nb in value['neighbors'].values():
          nb['nb_e_i_j'] = nb['nb_e_i_j'] // e_divisor

  #Softmax: Calculate DIVIDEND + DIVISOR + ALPHA + num_nodes (base-2 exponentials)
  for value in subgraph.values():
      divisor = 0
      num_of_node = 1
      src_dividend = 2**value['src_e_i_j']
      divisor += src_dividend.double()
      value['src_dividend'] = src_dividend
      for nb in value['neighbors'].values():
          nb_dividend = 2**nb['nb_e_i_j']
          nb['nb_dividend'] = nb_dividend
          num_of_node += 1
          divisor += nb_dividend.double()
      value['divisor'] = divisor
      value['nodes'] = num_of_node

  for value in subgraph.values():
      value['src_alpha'] = value['src_dividend'] / value['divisor']
      for nb in value['neighbors'].values():
          nb['nb_alpha'] = nb['nb_dividend'] / value['divisor']

  # Aggregator: alpha-weighted sum of z over the subgraph, then ReLU.
  for value in subgraph.values():
      total = value['src_alpha'] * value['src_z']
      for nb in value['neighbors'].values():
          total += nb['nb_alpha'] * nb['nb_z']
      value['new_h'] = torch.relu(total)

  # Handle input of next layer; nodes without outgoing edges keep z unchanged.
  feature_next_layer = []
  print(f"Subgraph Len: {len(subgraph)} / {data.x.shape[0]}")
  for idx in range(data.x.shape[0]):
    if idx in subgraph:
      feature_next_layer.append(subgraph[idx]['new_h'].squeeze(0))
    else:
      feature_next_layer.append(z_layer1[idx].to(dtype=torch.float64))
  feature_next_layer_tensor = torch.stack(feature_next_layer)
  print(feature_next_layer_tensor, feature_next_layer_tensor.shape)
  # To remove just test only
  if layer_idx == 0:
    test_new_feature = feature_next_layer_tensor
  else:
    test_output = feature_next_layer_tensor
  # Export data
  if layer_idx == 0:
    exportToFilesFirstLayer(subgraph, layer_idx, None)
  else:
    exportToFilesFirstLayer(subgraph, layer_idx, input_layer)

  # Break when calculation is done
  if layer_idx + 1 == NUM_OF_LAYERS:
    classification_output = torch.argmax(feature_next_layer_tensor, dim=1)
    classification_correct = classification_output[data.test_mask] == data.y[data.test_mask]
    test_acc = int(classification_correct.sum()) / int(data.test_mask.sum())
    print("Classification: ", classification_output)
    print("Correct result", data.y)

    correct_count = (classification_output == data.y).sum().item()
    print(f"Correct count final: {int(classification_correct.sum())} / {int(data.test_mask.sum())}, acc: {test_acc}")
    break
  # Quantize and assign to input (perf_counter: monotonic, high-resolution timer)
  start_quantized_time = time.perf_counter()
  if IS_HARDWARE_HANDLE_OUTPUT_LAYER_1:
    # NOTE: despite the name, this holds float64 values.
    feature_next_layer_float32 = feature_next_layer_tensor.to(dtype=torch.float64)
    max_value = torch.max(feature_next_layer_float32)
    min_value = torch.min(feature_next_layer_float32)
    rounded_max_value = max_value.round().int()
    bits_required_feature_next_layer = bits_required(min_value, rounded_max_value, False)
    result_feature_next_layer_float32 = feature_next_layer_float32 // _shift_divisor(bits_required_feature_next_layer)

    input_layer = result_feature_next_layer_float32.to(dtype=torch.float32)
    print("Check specific value", feature_next_layer_float32[1941][10])
    print("Check specific value // 4", result_feature_next_layer_float32[1941][10])
    print("Check input value // 4", input_layer[1941][10])
    print("Check by huynguyenn", rounded_max_value, min_value, bits_required_feature_next_layer, torch.max(input_layer), torch.min(input_layer), input_layer.shape)
  else:
    feature_next_layer_quantized, _ = scale_tensor(feature_next_layer_tensor,
                                                   Configuration["BuildModel"]["scaleMin"],
                                                   Configuration["BuildModel"]["scaleMax"],
                                                   torch.int8)
    input_layer = feature_next_layer_quantized.to(dtype=torch.float32)
  end_quantized_time = time.perf_counter()
  print(input_layer, input_layer.shape, data.x.shape[0])
  print(f"Time to quantized: {end_quantized_time - start_quantized_time:.6f} seconds")

# Loading Model & Parameter Quantization

In [None]:
!pip install torch_geometric

In [30]:
import os
import torch
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Settings for the standalone "load model & quantize" section.
GlobalConfiguration = dict(
    GAT=dict(hiddenChannel=16, head=1),
    model=dict(savePath="model_params.pth", scaleMin=-127, scaleMax=127),
    # name: one of Cora, CiteSeer, PubMed
    dataset=dict(root="data/Planetoid", name="Cora", normalization=False),
)

# Maps GATV2 state_dict keys (conv{layer}.<param>) to short export names (<short>_{layer}).
MappingModelParam = {
    f"conv{layer}.{param}": f"{short}_{layer}"
    for layer in (1, 2)
    for param, short in (("att_src", "a_src"), ("att_dst", "a_dst"), ("bias", "b"), ("lin.weight", "w"))
}

def tensor_to_list(tensor):
    """Return every element of `tensor` as a flat Python list (row-major order).

    Raises:
        TypeError: if `tensor` is not a torch.Tensor.
    """
    if isinstance(tensor, torch.Tensor):
        return torch.flatten(tensor).tolist()
    raise TypeError("Input must be a torch.Tensor")

def tensor_to_matrix(tensor):
    """Return `tensor` as nested Python lists mirroring its shape."""
    nested = tensor.tolist()
    return nested

def list_or_matrix_to_tensor(data):
    """Build a tensor from a (possibly nested) Python list; torch infers the dtype."""
    converted = torch.tensor(data)
    return converted

def format_number(value):
    """Stringify `value`, dropping a trailing '.0' so whole floats read as integers."""
    text = str(value)
    return text[:-2] if text.endswith('.0') else text

def int_to_n_bit_binary(number, n_bits):
    """Encode an integer as an `n_bits`-wide binary string.

    Negative numbers are written in two's complement and non-negative numbers
    are zero-padded, so both signed [-2**(n_bits-1), -1] and unsigned
    [0, 2**n_bits - 1] values are accepted.

    Raises:
        ValueError: if `n_bits` < 1 or `number` does not fit in `n_bits` bits.
            Without this check an out-of-range value produced a string of the
            wrong width (e.g. 256 -> 9 characters, -300 -> a '-' prefix).
    """
    if n_bits < 1:
        raise ValueError(f"n_bits must be >= 1, got {n_bits}")
    if not -(1 << (n_bits - 1)) <= number < (1 << n_bits):
        raise ValueError(f"{number} does not fit in {n_bits} bits")
    # Handle two's complement for negative numbers
    if number < 0:
        number += 1 << n_bits
    return format(number, f'0{n_bits}b')

def int_to_n_bit_binary_list(arr, n_bits):
    """Encode each element of `arr` (coerced with int()) as an n_bits-wide binary string."""
    return [int_to_n_bit_binary(int(item), n_bits) for item in arr]

def int_to_n_bit_binary_matrix(matrix, n_bits):
    """Encode every cell of a list-of-rows (coerced with int()) as an n_bits-wide binary string."""
    return [[int_to_n_bit_binary(int(cell), n_bits) for cell in row] for row in matrix]

def list_to_matrix(lst, rows, cols):
    """Reshape a flat list into `rows` lists of `cols` items each (row-major).

    Raises:
        ValueError: if len(lst) != rows * cols.
    """
    if rows * cols != len(lst):
        raise ValueError("List length must match rows * columns")
    matrix = []
    for r in range(rows):
        start = r * cols
        matrix.append(lst[start:start + cols])
    return matrix

def matrix_to_list(matrix):
    """Flatten a list of rows into one list, row by row."""
    flat = []
    for row in matrix:
        flat.extend(row)
    return flat

def quantized(tensor, scale_min, scale_max, to_dtype=torch.int8):
    """Linearly scale `tensor` so its maximum maps to `scale_max`, then clamp and cast.

    Args:
        tensor: input tensor.
        scale_min: lower clamp bound of the quantized range (e.g. -127).
        scale_max: value the tensor's maximum is mapped to, and upper clamp bound.
        to_dtype: dtype of the returned quantized tensor.

    Returns:
        (quantized_tensor, dequantized) where `dequantized(q)` maps a quantized
        tensor back to the original scale as float32.

    NOTE(review): the scale uses the max, not the largest magnitude, so values
    below -max saturate at `scale_min`; the integer cast truncates toward zero
    rather than rounding.
    """
    v_max = tensor.max()
    if v_max == 0:
        v_max = 1  # Avoid division by zero
    elif v_max < 0:
        # All-negative input: dividing by a negative max would flip every sign
        # (and then saturate at scale_max). Scale by the largest magnitude instead.
        v_max = tensor.abs().max()

    # Scale the tensor
    quantized_tensor = (tensor / v_max) * scale_max
    quantized_tensor = quantized_tensor.clamp(scale_min, scale_max)
    quantized_tensor = quantized_tensor.to(to_dtype)

    # Define a function to scale back to the original range
    def dequantized(quantized_tensor):
        quantized_tensor = quantized_tensor.to(torch.float32)  # Ensure float for computation
        return (quantized_tensor / scale_max) * v_max
    return quantized_tensor, dequantized

class DatasetLoaderV2:
    """Thin wrapper around a torch_geometric Planetoid dataset."""

    def __init__(self,
                 root: str = GlobalConfiguration["dataset"]["root"],
                 name: str = GlobalConfiguration["dataset"]["name"],
                 normalize: bool = GlobalConfiguration["dataset"]["normalization"]):
        self.root = root
        self.name = name
        self.normalize = normalize
        self.dataset = self._load_dataset()

    def _load_dataset(self):
        """Load the Planetoid dataset, optionally with NormalizeFeatures."""
        transform = NormalizeFeatures() if self.normalize else None
        return Planetoid(root=self.root, name=self.name, transform=transform)

    def get_data(self, index: int = 0):
        """Return graph number `index` of the dataset (Planetoid has one)."""
        return self.dataset[index]

    def get_dataset(self):
        return self.dataset

    def get_edges(self):
        """Return the edge_index of the first graph."""
        return self.dataset[0].edge_index

    def get_isolated(self):
        """Find nodes that appear in no edge (neither as source nor destination).

        Returns:
            (isolated_nodes, isolated_map): ascending list of isolated node ids,
            and a dict mapping each of them to its feature row.
        """
        # Fetch once: dataset indexing may copy the Data object (and re-apply the
        # transform) on every access, which used to happen per isolated node.
        data = self.get_data()
        edges = data.edge_index
        # A Python set gives O(1) membership; `node in tensor` scans the tensor.
        connected = set(torch.cat([edges[0], edges[1]]).tolist())
        total_nodes = data.x.shape[0]
        isolated_nodes = [node for node in range(total_nodes) if node not in connected]
        print(data.x.shape)
        isolated_map = {node_idx: data.x[node_idx] for node_idx in isolated_nodes}
        return isolated_nodes, isolated_map

class GATV2(torch.nn.Module):
    """Two-layer GAT: dropout -> GATConv -> ELU -> dropout -> GATConv (raw logits out)."""

    def __init__(self,
                 data_loader,
                 hidden_channels = GlobalConfiguration["GAT"]["hiddenChannel"],
                 heads = GlobalConfiguration["GAT"]["head"]):
        super().__init__()
        # Fixed seed so the freshly initialised weights are reproducible.
        torch.manual_seed(1234567)
        dataset = data_loader.get_dataset()
        # Layer 1 concatenates its heads; layer 2 uses a single, non-concatenated head.
        self.conv1 = GATConv(dataset.num_features, hidden_channels, heads=heads, concat=True)
        self.conv2 = GATConv(heads * hidden_channels, dataset.num_classes, heads=1, concat=False)

    def forward(self, x, edge_index):
        drop_p = 0.6
        hidden = self.conv1(F.dropout(x, p=drop_p, training=self.training), edge_index)
        hidden = F.dropout(F.elu(hidden), p=drop_p, training=self.training)
        return self.conv2(hidden, edge_index)

class BuildModelV2():
    """Holds a GAT model, restores saved weights and exposes int8-quantized parameters."""

    def __init__(self, model, data_loader, save_path = GlobalConfiguration["model"]["savePath"]):
        self.model = model
        self.save_path = save_path
        self.data_loader = data_loader
        self.load_model_params()

    def load_model_params(self):
        """Load model parameters from the file if it exists.

        Returns:
            True when a state dict was loaded, False when no file exists.
        """
        if not os.path.exists(self.save_path):
            print(f"No saved model parameters found at {self.save_path}")
            return False
        # map_location="cpu" lets a checkpoint saved from a GPU session load on a
        # CPU-only machine; load_state_dict copies the values into the model anyway.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        self.model.load_state_dict(torch.load(self.save_path, map_location="cpu"))
        print(f"Model parameters loaded from {self.save_path}")
        return True

    def _quantized_state(self):
        """Return {state_dict key: int8-quantized tensor} for the wrapped model."""
        return {
            k: quantized(v,
                         GlobalConfiguration["model"]["scaleMin"],
                         GlobalConfiguration["model"]["scaleMax"],
                         torch.int8)[0]
            for k, v in self.model.state_dict().items()
        }

    def get_model_params(self):
        """Return the quantized parameters as flat lists under short export names.

        3-D tensors are reshaped to 2-D (first dim kept) and lin.weight matrices
        are transposed before flattening.

        Returns:
            dict such as {'a_src_1': [...], 'a_dst_1': [...], 'b_1': [...],
            'w_1': [...], ..., 'a_1': a_src_1 + a_dst_1, 'a_2': a_src_2 + a_dst_2}
        """
        result = {}
        for k, quantized_v in self._quantized_state().items():
            if quantized_v.ndim == 3:
                quantized_v = quantized_v.reshape(quantized_v.shape[0], -1)
            if k in ("conv1.lin.weight", "conv2.lin.weight"):
                quantized_v = quantized_v.t()
            result[MappingModelParam[k]] = tensor_to_list(quantized_v)

        result['a_1'] = result['a_src_1'] + result['a_dst_1']
        result['a_2'] = result['a_src_2'] + result['a_dst_2']
        return result

    def get_raw_model(self):
        """Return the model's state-dict tensors int8-quantized, shapes unchanged.

        Builds a fresh dict; it must not read or overwrite notebook-global
        variables (such as the simulation cell's `a`).
        """
        return self._quantized_state()

    def test(self, visualization_2D = False, visualization_3D = False):
        """Return the test-set accuracy of the (unquantized) model.

        `visualization_2D` / `visualization_3D` are accepted for compatibility
        but currently unused.
        """
        self.model.eval()
        data = self.data_loader.get_data()
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            out = self.model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        test_correct = pred[data.test_mask] == data.y[data.test_mask]
        return int(test_correct.sum()) / int(data.test_mask.sum())


In [31]:
def handle_new_feature(new_feature, gat_model, data_loader): #Todo: handle case isolated node
  """Build the layer-2 hardware inputs from a flat layer-1 output.

  Args:
    new_feature: flat list of layer-1 outputs, `hiddenChannel` values per node,
      covering only nodes that are NOT isolated. Isolated rows are filled here
      with x · W1^T using the quantized conv1 weight.
    gat_model: a BuildModelV2 instance.
    data_loader: a DatasetLoaderV2 instance.

  Returns:
    {'weight': [...], 'h_data': [...]} of 8-bit binary strings:
      'weight' - layer-2 weights (w_2) followed by the layer-2 attention (a_2);
      'h_data' - quantized feature rows, each source node followed by its
                 neighbours in edge order, flattened.
  """
  raw_data = data_loader.get_data()
  raw_edge_data = raw_data.edge_index
  raw_init_feature_data = raw_data.x
  _, isolated_map = data_loader.get_isolated()
  raw_model = gat_model.get_raw_model()
  hidden = GlobalConfiguration["GAT"]["hiddenChannel"]
  num_nodes = raw_init_feature_data.shape[0]
  w1_t = raw_model["conv1.lin.weight"].to(dtype=torch.float32).t()

  new_feature_matrix = []
  curr_flat_index = 0
  for row_idx in range(num_nodes):
      if row_idx in isolated_map:
          new_feature_matrix.extend(torch.matmul(isolated_map[row_idx], w1_t).tolist())
      else:
          new_feature_matrix.extend(new_feature[curr_flat_index : curr_flat_index + hidden])
          curr_flat_index += hidden

  new_feature_matrix = list_to_matrix(new_feature_matrix, num_nodes, hidden)
  start_time = time.perf_counter()
  new_feature_tensor_quantized, _ = quantized(list_or_matrix_to_tensor(new_feature_matrix),
                                              GlobalConfiguration["model"]["scaleMin"],
                                              GlobalConfiguration["model"]["scaleMax"],
                                              torch.int8)
  new_feature_matrix_quantized = tensor_to_matrix(new_feature_tensor_quantized)

  # Group neighbours by source in one O(E) pass; sources are visited in
  # ascending order and neighbours keep edge order (same as the previous
  # per-source `edge_index[1][edge_index[0] == src]` scan).
  neighbors_by_src = {}
  for src, dst in zip(raw_edge_data[0].tolist(), raw_edge_data[1].tolist()):
      neighbors_by_src.setdefault(src, []).append(dst)
  h_matrix = []
  for src_idx in sorted(neighbors_by_src):
      h_matrix.append(new_feature_matrix_quantized[src_idx])
      h_matrix.extend(new_feature_matrix_quantized[nb] for nb in neighbors_by_src[src_idx])
  # Values are int8, so int() is exact.
  h_matrix_format_binary = int_to_n_bit_binary_matrix(h_matrix, 8)
  end_time = time.perf_counter()
  print(f"Time to run test: {end_time - start_time:.6f} seconds")

  model_params = gat_model.get_model_params()
  return {
      'weight': int_to_n_bit_binary_list(model_params['w_2'] + model_params['a_2'], 8),
      'h_data': matrix_to_list(h_matrix_format_binary)
  }

def handle_classification(result_list, data_loader, count_isolated_as_correct=True): #Todo: handle case isolated node
  """Score a flat final-layer output against the test labels and print the result.

  Args:
    result_list: flat list of `num_classes` scores per non-isolated node.
      Isolated nodes get all-zero rows.
    data_loader: a DatasetLoaderV2 instance.
    count_isolated_as_correct: isolated test nodes are not scored by the
      hardware. By default (the previous behaviour) each is counted as
      correct, which can inflate the reported accuracy. Pass False to count
      them as wrong.
  """
  raw_dataset = data_loader.get_dataset()
  raw_data = data_loader.get_data()
  num_nodes = raw_data.x.shape[0]
  num_classes = raw_dataset.num_classes
  isolated_node, isolated_map = data_loader.get_isolated()

  curr_flat_index = 0
  new_result_list = []
  for row_idx in range(num_nodes):
      if row_idx in isolated_map:
          new_result_list.extend([0] * num_classes)
      else:
          new_result_list.extend(result_list[curr_flat_index : curr_flat_index + num_classes])
          curr_flat_index += num_classes

  result_matrix = list_to_matrix(new_result_list, num_nodes, num_classes)
  result_classification = torch.argmax(list_or_matrix_to_tensor(result_matrix), dim=1)
  print("Classification: ", result_classification)
  print("Correct result", raw_data.y)

  # Plain Python containers: `idx in tensor` is a linear scan per lookup.
  test_indices = torch.where(raw_data.test_mask)[0].tolist()
  isolated_set = set(isolated_node)
  predictions = result_classification.tolist()
  labels = raw_data.y.tolist()

  match_count = 0
  for idx in test_indices:
      if idx in isolated_set:
          match_count += int(count_isolated_as_correct)
      elif predictions[idx] == labels[idx]:
          match_count += 1

  num_test = len(test_indices)
  acc = match_count / num_test if num_test else float("nan")
  print(f"Correct count: {match_count} / {num_test}, acc: {acc}")


In [None]:
data_loader_instance = DatasetLoaderV2()
gat_instance = GATV2(data_loader_instance)
model_instance = BuildModelV2(gat_instance, data_loader_instance)

# test_new_feature / test_output are produced by the "Simulation & Export to Files" cell.
if test_new_feature is None or test_output is None:
    raise RuntimeError("Run the 'Simulation & Export to Files' cell first: "
                       "test_new_feature / test_output are not set.")

# Prepare to receive output of layer 1
print(test_new_feature, test_new_feature.shape)
# NOTE(review): the simulated tensors contain a row for EVERY node, while
# handle_new_feature / handle_classification expect rows only for non-isolated
# nodes. The two line up only when the dataset has no isolated nodes.
simulate_output_layer_1 = tensor_to_list(test_new_feature)
simulate_output_layer_2 = tensor_to_list(test_output)
simulated_h_data = handle_new_feature(simulate_output_layer_1, model_instance, data_loader_instance)['h_data']

print()
handle_classification(simulate_output_layer_2, data_loader_instance)

In [None]:
# Time one evaluation pass of the model.
# perf_counter is monotonic and high-resolution, unlike the wall-clock time.time().
start_test_time = time.perf_counter()
test_acc = model_instance.test()
end_test_time = time.perf_counter()
print(f"Initial Model Accuracy: {test_acc}, time to run test: {(end_test_time - start_test_time)*1000:.6f} ms")

In [None]:
# currentOption is defined elsewhere in the notebook; presumably it reports the
# active configuration option — TODO confirm.
currentOption()