<a href="https://colab.research.google.com/github/TJSun009/test_categorisation/blob/main/Test_Categorisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Plan

In [None]:
# Plan of Action

# Source-Test Mapping
# Get Graph Representation of Programs
# Create a Graph Neural Network Classfier to map src and test graph
# (Optional) Enhance graph tokens using GraphCodeBERT, CodeBERT or TREEBERT embeddings

# Test Generation
# Create a GraphTransformer using Graph Representations and Encoder-Decoder Architecture
# Prior Embeddings may be useful
# See GraphBERT - https://arxiv.org/abs/2001.05140
# Encoder - convert Graph nodes to a node embedding Representation based on surrounding nodes and edges

# Use Masked Node Modelling to Mask a Node in the AST and Generate it based on it's connected nodes and edges

# Google Cloud Storage Setup

In [1]:
!gcloud auth login --no-launch-browser

Go to the following link in your browser:

    https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=32555940559.apps.googleusercontent.com&redirect_uri=https%3A%2F%2Fsdk.cloud.google.com%2Fauthcode.html&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fappengine.admin+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcompute+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Faccounts.reauth&state=7XrsEvQwmbyIGEkKe9DvarWDdiTgdF&prompt=consent&access_type=offline&code_challenge=GnhVDoNciVs6r7fDLqtzE_HAA8PxEsiF0-RJLpbucYk&code_challenge_method=S256

Enter authorization code: 

Command killed by keyboard interrupt

^C


In [1]:
project_id = "southern-camera-367511"
bucket_name = "dissertation-data-bucket-1"
!gcloud config set project {project_id}

Updated property [core/project].


In [2]:
# GCS Helpers

# copy file to GCS
def copy_to_gcs(file_path, options=""):
  file_path = file_path.replace("/content", "")
  command = ! gsutil -m cp{options} {file_path} gs://{bucket_name}/{file_path}

# copy file to GCS
def copy_from_gcs(src_file_path, dest_file_path="", options=""):
  gcs_file_path = src_file_path.replace("/content/", "")

  dest_file_path = dest_file_path if dest_file_path != "" else src_file_path

  if not os.path.exists(dest_file_path):
    os.makedirs(os.path.dirname(dest_file_path), exist_ok=True)

  command = ! gsutil -m cp{options} gs://{bucket_name}/{gcs_file_path} {dest_file_path}
  print("copied succesfully")

def list_gcs_files(file_path):
  file_path = file_path.replace("/content/", "")
  files = ! gsutil ls gs://{bucket_name}/{file_path}
  return files

def gcs_file_path_to_colab(gcs_file_path):
  return gcs_file_path.replace(f"gs://{bucket_name}/", "/content/")

# check for file in GCS
def is_in_gcs(file_path):
  file_path = file_path.replace("/content/", "")
  output = ! gsutil -q stat gs://{bucket_name}/{file_path}; echo $?
  return output[0] == '0'

# Source-Test Mapping

## Import libraries

In [3]:
! pip install -Uqqq scipy networkx

In [4]:
import os
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# import tensorflow as tf
# from tensorflow import keras
# from tensorflow.keras import layers
# import pandas as pd
import networkx as nx
from glob import iglob
import importlib

### python-graphs dependency

In [5]:
# New Graph Generator Approach
# https://arxiv.org/pdf/2208.07461v1.pdf

# install python-graphs on startup
! echo {SUDO} | sudo -S apt-get -qq -y install graphviz graphviz-dev
# ! pip install -Uqqq python-graphs gast==0.3.2

In [49]:
! git clone https://github.com/google-research/python-graphs.git

Cloning into 'python-graphs'...
remote: Enumerating objects: 198, done.[K
remote: Counting objects: 100% (198/198), done.[K
remote: Compressing objects: 100% (124/124), done.[K
remote: Total 198 (delta 109), reused 144 (delta 63), pack-reused 0[K
Receiving objects: 100% (198/198), 82.55 KiB | 10.32 MiB/s, done.
Resolving deltas: 100% (109/109), done.


In [52]:
! cd ./python-graphs && python setup.py develop

running develop
running egg_info
creating python_graphs.egg-info
writing python_graphs.egg-info/PKG-INFO
writing dependency_links to python_graphs.egg-info/dependency_links.txt
writing requirements to python_graphs.egg-info/requires.txt
writing top-level names to python_graphs.egg-info/top_level.txt
writing manifest file 'python_graphs.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'python_graphs.egg-info/SOURCES.txt'
running build_ext
Creating /usr/local/lib/python3.8/dist-packages/python-graphs.egg-link (link to .)
python-graphs 1.3.0 is already the active version in easy-install.pth

Installed /content/python-graphs
Processing dependencies for python-graphs==1.3.0
Searching for six==1.15.0
Best match: six 1.15.0
Adding six 1.15.0 to easy-install.pth file

Using /usr/local/lib/python3.8/dist-packages
Searching for pygraphviz==1.10
Best match: pygraphviz 1.10
Adding pygraphviz 1.10 to easy-install.pth file

Using /usr/local/lib/python3.8/dist-packages
Search

#### Understanding python-graphs Representation

In [6]:
# import gast as ast
# from python_graphs import program_graph

# # example file
# file_path = DIR_PREFIX + "Year 3/Dissertation/Projects/Datasets/data/minified/src/unittest_utils.py"

# with open(file_path, "r") as f:
#   # graph = program_graph.get_program_graph(f.read())
  
#   # read ast head
#   graph = program_graph.get_program_graph(ast.parse(f.read()))

In [7]:
# # nodes are stored in a dictionary representing a key-value pair of the node id and node itself
# example_node_dict = graph.nodes

# # nodes are represented as strings by joining the node id and node ast_type if it has one 
# example_node_dict_item = list(example_node_dict.items())[0]

# # item 0 is the node id and item 1 the node representation
# example_node_dict_item

In [8]:
# # as above the node is the value
# example_node = example_node_dict_item[1]

# # we can view the nodes properties as well
# # the ast_value is of particular interest as well for retrieving tokens
# # not all nodes will have a value though
# print(example_node.__dict__)
# print(example_node.__dict__["ast_node"].__dict__)

In [9]:
# # show what ast_values look like
# nodes = list(graph.all_nodes())

# node_values = []
# for node in nodes:
#   if node.ast_value:
#     node_values.append(node.ast_value)
# # some ast_values are long strings so will need subtokens which can be combined
# # may require CodeBERT embeddings
# node_values

In [10]:
# # check edge
# example_edge = graph.edges[0]
# example_edge

### CodeBERT dependency

In [11]:
# imports for tokenising code values
! pip install -Uqqq transformers

#### Investigating CodeBERT Tokeniser for code

In [12]:
# from transformers import AutoTokenizer, AutoModel
# import torch
# tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
# model = AutoModel.from_pretrained("microsoft/codebert-base")

In [13]:
# # join node values into single string
# code = (' ').join([str(val) for val in node_values])
# # code
# code_tokens = tokenizer.tokenize(code)
# # code_tokens
# tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]
# # tokens
# tokens_ids = tokenizer.convert_tokens_to_ids(tokens)

# # this is the final output from the model
# # it consists of 1 vector for each token across 768 distinct features 
# context_embeddings = model(torch.tensor(tokens_ids)[None,:])[0]
# # context_embeddings.shape
# len(code_tokens), context_embeddings.shape

# # in the graph embedding the node values could be represented as each value padded
# # by N others on either side, N would be calibrated for best results

## Prepare Dataset Helpers

produces graph_list (list of python_graphs)

In [14]:
CODE_MINI_DIR =  "/content/data/minified"
CODE_LARGE_DIR = "/content/data/large"
CODE_DIR = CODE_MINI_DIR

In [15]:
# adapted from typilus for monitoring
class Monitoring:
    def __init__(self):
        self.count = 0  # type: int
        self.errors = []
        self.file = ""  # type: str
        self.empty_files = []

    def increment_count(self) -> None:
        self.count += 1

    def found_error(self, err, trace) -> None:
        self.errors.append([self.file, err, trace])

    def enter_file(self, filename: str) -> None:
        self.file = filename

In [16]:
# adapted from typilus for generating graphs with python_graphs

from typing import Tuple, List, Optional, Set, Iterator
from python_graphs import program_graph

# TODO progress bar
# from tqdm.notebook import tqdm_notebook


# needed to use gast opposed to standard ast
import gast, ast
def explore_files(root_dir: str, monitoring: Monitoring) -> Iterator[Tuple]:
    """
    Walks through the root_dir and process each file.
    """
    for file_path in iglob(os.path.join(root_dir, '**', '*.py'), recursive=True):
        file_name = file_path.split('/').pop()
        if not os.path.isfile(file_path):
            continue
        # print(file_path)
        with open(file_path, encoding="utf-8", errors='ignore') as f:
            monitoring.increment_count()
            monitoring.enter_file(file_path)
            
            # difficulty parsing some files so had to be skipped
            try:
              graph = program_graph.get_program_graph(gast.parse(f.read()))
              
              # identify graph by file_name
              graph.filename = file_path[len(root_dir):]
              
              yield graph
            except:
              continue

In [17]:
! pip install -Uqqq dill
import dill as pickle

# handle saving of graphs
def save_graphs(graph_list, dir = CODE_DIR):
  graph_dir = os.path.join(dir, "graphs", "")
  
  if not os.path.exists(graph_dir):
    os.makedirs(graph_dir)
  
  file = os.path.join(graph_dir, "graphs.pickle")

  pickle.dump(graph_list, open(file, "wb"), protocol = pickle.HIGHEST_PROTOCOL)
  copy_to_gcs(file)
  return file

In [18]:
graph_pickle_path = os.path.join(CODE_DIR, "graphs", "graphs.pickle")

def create_graph_list(graph_pickle_path=graph_pickle_path):
  if not is_in_gcs(graph_pickle_path):
    outputs = explore_files(CODE_DIR, monitoring)
    graph_list = [graph for graph in outputs]
    monitoring = Monitoring()
    save_graphs(graph_list)
    del(outputs)
    del(monitoring)
  elif not os.path.exists(graph_pickle_path):
    # copy from gcs before reading pickle
    copy_from_gcs(graph_pickle_path)
  else:
    graph_list = pickle.load(open(graph_pickle_path, "rb"))
  
  return graph_list

## Feed Data to Graph Network

### Creating CodeGraph Class

This class makes use of networkx a popular graph representation library

In [19]:
# each edge should be weighted differently based on its type, edge should contain types
from python_graphs import program_graph_dataclasses

# for ast class list
import sys, inspect

# imports for tokenising code values
from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
model = AutoModel.from_pretrained("microsoft/codebert-base")

#### Imports

In [20]:
!pip install -Uqqq torch-scatter torch-sparse torch-geometric -f https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html

In [56]:
code_tokens = tokenizer.tokenize("print")

#### Node Feature Helpers

In [72]:
import pdb
# use enum value to set node type
# dict updates only keys contained therein
# automatic enum conversion
def node_types_to_ints(G):
  node_type_dict = nx.get_node_attributes(G, "node_type")
  int_dict = {k: {"node_type": v.value} for k, v in node_type_dict.items()}
  nx.set_node_attributes(G, int_dict)

# use enum class to convert value back
def ints_to_node_types(G):
  node_type_dict = nx.get_node_attributes(G, "node_type")
  node_type_dict = {k: {"node_type": program_graph_dataclasses.NodeType(v)} for k, v in node_type_dict.items()}
  nx.set_node_attributes(G, node_type_dict)

# ast type can be dealt with using string byte encoding
def ast_types_to_ints(G, ast_types):
  ast_type_dict = nx.get_node_attributes(G, "ast_type")
  int_dict = {k: {"ast_type": ast_types.index(v)} if v in ast_types else {"ast_type": -1} for k, v in ast_type_dict.items()}
  nx.set_node_attributes(G, int_dict)

def ints_to_ast_types(G, ast_types):
  ast_type_dict = nx.get_node_attributes(G, "ast_type")
  int_dict = {k: {"ast_type": ast_types[v]} for k, v in ast_type_dict.items()}
  nx.set_node_attributes(G, int_dict)

# ast_value embeddings done using CodeBERT embeddings
# N equates to context padding how many subsequent and following tokens are used in embedding
def ast_values_to_context_embeddings(G, vocab, N = 1):
  ast_value_dict = nx.get_node_attributes(G, "ast_value")

  embedding_dict = {}

  vocab = list(vocab)

  for k, v in ast_value_dict.items():

    v = str(v)

    idx = vocab.index(v)

    start, end = idx - N, idx + N

    if start > -1:
      if end - 1 > len(vocab):
        end = len(vocab) - 1
    else:
      start = 0
      end = 2
    
    # use prior and subsequent words for context
    code = ''.join(vocab[start : end + 1])

    # always return a code_token of length 3
    code_tokens = tokenizer.tokenize(code)[start : end + 1]

    if len(code_tokens) == 0:
      context_embeddings = torch.zeros(1, 5, 768)
    else:
      tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]

      tokens_ids = tokenizer.convert_tokens_to_ids(tokens)

      # this is the final output from the model
      # it consists of 1 vector for each token across 768 distinct features 
      context_embeddings = model(torch.tensor(tokens_ids)[None,:])[0]

    embedding_dict[k] = {"ast_value": context_embeddings}

  nx.set_node_attributes(G, embedding_dict)


#### Edge Feature Helpers

In [108]:
def edge_types_to_ints(G):
  edge_type_dict = nx.get_edge_attributes(G, "type")
  int_dict = {(node1, node2, dir): {"type": v.value} for (node1, node2, dir), v in edge_type_dict.items()}
  nx.set_edge_attributes(G, int_dict)

#### Implementation

In [102]:
from torch_geometric.utils.convert import from_networkx

# This code graph class represents a combination of all the graphs generated in the code corpus
# This code graph class represents a single graph generated in the code corpus
class CodeGraph:

  def __init__(self):
    self.G = nx.MultiDiGraph()
    self.vocab = set()
    self.identifier = ""
    self.is_pair = None
    self.types = {
        "edge" : program_graph_dataclasses.EdgeType._member_names_,
        "ast" : [cls.__name__ for _, cls in inspect.getmembers(sys.modules["ast"], inspect.isclass)]
        }

  def read(self, graph):

    # set identifier to file_name of graph if graph is a module
    if (graph.root.ast_type == "Module"):
      self.identifier = graph.filename
    # TODO implement identifier for functions instead

    # add nodes to graph along with their attributes
    # dict comprehension deduplicates node id
    # we can exclude the ast_node as this info should be encoded in the graphs and edges
    # exclude instruction temporarily due to complexity
    nodes = graph.all_nodes()
    self.G.add_nodes_from([(node.id, {k: v for k, v in node.__dict__.items() if k not in ["id", "ast_node", "instruction"]}) for node in nodes])
    
    # append edges to the graph along with their attributes
    # dict comprehension deduplicates node ids for edge
    self.G.add_edges_from([(edge.id1, edge.id2, {k: v for k, v in edge.__dict__.items() if k.find("id") == -1 }) for edge in graph.edges])
    
    # add ast values to vocab
    self.vocab.update([str(token) for token in nx.get_node_attributes(self.G, "ast_value").values()])

  def node_feature_vector_graph(self, H):
    
    # some node features have been discarded as they are too complex to be used by pytorch
    # or replicate info stored elsewhere in the graph

    node_type = list(H.nodes(data="node_type"))[0][1]
    if not isinstance(node_type, int):
      node_types_to_ints(H)
      ast_types_to_ints(H, self.types["ast"])
      ast_values_to_context_embeddings(H, self.vocab)
    
    return H
  
  def edge_feature_vector_graph(self, H):
    
    # converts edge type to an integer

    edge_type = list(H.edges(data="type"))[0][2]
    if not isinstance(edge_type, int):
      edge_types_to_ints(H)
    
    return H

    

    """{'node_type': <NodeType.AST_NODE: 1>, 
      # ignoring instruction due to it being another complex graph
      'instruction': <python_graphs.instruction.Instruction object at 0x7ff853d756d0>, 
      'ast_type': 'Expr', 
      'ast_value': '', 
      'syntax': ''}"""
    
    # ast_value encoding
  
  def draw(self):
    if len(self.G.nodes) > 0:
      # create normalizer for colours
      norm = plt.Normalize()

      # use vocab and edge_types to generate colours for plot
      # edges are mapped to their position in types
      token_colors = [self.vocab.index(val) for val in list(nx.get_node_attributes(self.G, "ast_value").values())]
      edge_type_colors = [edge_type.value for edge_type in list(nx.get_edge_attributes(self.G, "type").values())]
      
      # normalize the colors between [0, 1]
      node_color, edge_color = norm(token_colors), norm(edge_type_colors)

      fig, ax = plt.subplots(1, 1, figsize=(10, 10));

      nx.draw_networkx(self.G, edge_color = edge_color, node_color = node_color, with_labels=True, ax = ax)
  
  def pytorch_graph(self):
    H = self.node_feature_vector_graph(self.G)

    P = self.edge_feature_vector_graph(H)

    pyg = from_networkx(P)
    
    if (self.is_pair):
      pyg.y = torch.tensor([int(self.is_pair)])
    
    return 

### Create a Code Graph List

In [23]:
def create_code_graph_list(graph_list):
  code_graph_list = []

  for graph in graph_list:
    cg = CodeGraph()
    cg.read(graph)
    code_graph_list.append(cg)

  # cleanup graph list
  del(graph_list)

### Remove Transformers model and tokenizer from memory

In [29]:
del(model)

In [30]:
del(tokenizer)

### Pairing Graphs

In [24]:
# dataset will consist of:
# graph - a graph containing  candidate src graph and test graph
# label - a 0 or 1 corresponding to whether the test and src are a valid pairing

# filter out combination results that contain two source files or two test files
def is_src_test_pair(pair):
  graph1, graph2 = pair

  # checks both are not the same type of file
  return graph1.identifier.find("test") != graph2.identifier.find("test")

# function for combining code_graphs
def combine_code_graphs(pair):
  code_graph1, code_graph2 = pair
  # check if the node vectorisation has already happened
  node_type_list = [
      list(code_graph1.G.nodes(data="node_type"))[0][1],
      list(code_graph2.G.nodes(data="node_type"))[0][1],
  ]
  
  if any([isinstance(node_type, int) for node_type in node_type_list]):
    raise Exception("Cannot combine code graphs that have already been vectorised")

  # uses number of nodes to verify combination graph worked correctly
  graph1, graph2 = code_graph1.G, code_graph2.G
  
  H = nx.disjoint_union(graph1, graph2)

  code_graph1.G = H
  code_graph1.vocab.update(list(code_graph2.vocab))

  # add a property to the code_graph checking whether or not there are a code, test pair
  # False for pair, True for non pair
  # dealing with file at present
  # mapping functions it will require looking at AST calls etc.
  code_graph1.is_pair = code_graph1.identifier.replace("_test.py", "") == code_graph2.identifier.replace("_test.py", "")

  # cleanup vars for concurrency
  del(graph1)
  del(graph2)
  del(code_graph2)
  del(node_type_list)
  del(H)
  return code_graph1

In [25]:
def save_paired_code_graphs(graph_list, dir = CODE_DIR):
  graph_dir = os.path.join(dir, "graphs", "")
  
  if not os.path.exists(graph_dir):
    os.makedirs(graph_dir)
  
  file = os.path.join(graph_dir, "code_graphs.pickle")

  pickle.dump(graph_list, open(file, "wb"), protocol = pickle.HIGHEST_PROTOCOL)
  copy_to_gcs(file)
  # cleanup
  del(graph_dir)
  del(file)

In [26]:
! pip install bounded-pool-executor
# get all possible combinations of code_graphs
from itertools import combinations
from bounded_pool_executor import BoundedProcessPoolExecutor
from tqdm.notebook import tqdm

# code_graph_pickle_path = os.path.join(CODE_DIR, "graphs", "code_graphs.pickle")

def create_combined_code_graphs(code_graph_list):
  # only regenerated paired graphs if not already pickled
  # if not is_in_gcs(code_graph_pickle_path):
    # update combined_code_graph_list to contain the combined code graphs with parallelisation
    # inline to reduce stored vars
  pairs = list(filter(is_src_test_pair, list(combinations(code_graph_list, 2))))
    # del(code_graph_list)
    
    # with BoundedProcessPoolExecutor(max_workers = 10) as executor:
    #   combined_code_graph_list = list(executor.map(combine_code_graphs, pairs, chunksize = 1))

  combined_code_graph_list = []

    # parallelistation hitting RAM limits, using standard loop with dedicated Cloud Instance
  for pair in tqdm(pairs):
    combined_code_graph_list.append(combine_code_graphs(pair))

    # cleanup code_graph_list
  del(pairs)
    # save_paired_code_graphs(combined_code_graph_list)
  # elif not os.path.exists(code_graph_pickle_path):
  #   copy_from_gcs(code_graph_pickle_path)
  # combined_code_graph_list = pickle.load(open(code_graph_pickle_path, "rb"))
  # del(code_graph_list)
  return combined_code_graph_list

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### PyTorch Conversion

In [27]:
from torch_geometric.data import Dataset
import glob
from torch_geometric.data.makedirs import makedirs
from itertools import product
from tqdm.notebook import tqdm

In [28]:
source_files = [os.path.basename(gcs_file_path_to_colab(file)) for file in list_gcs_files(os.path.join(CODE_DIR, "src", ""))]
RAW_FILES = []
for file in source_files:
  RAW_FILES.append(os.path.join("src", file))
  RAW_FILES.append(os.path.join("test", file.replace(".py", "_test.py")))

PROCESSED_FILES = [file.replace(".py", ".pt") for file in source_files]

In [53]:
import gast
from python_graphs import program_graph
from contextlib import suppress

def save_data(raw_paths, processed_dir):
  src_paths, test_paths = [], []

  for path in raw_paths:
    test_paths.append(path) if path.find("_test.py") != -1 else src_paths.append(path)

  source_test_pairs = list(product(src_paths, test_paths))

  idx = 0

  source_test_pairs = source_test_pairs

  for i, (src_path, test_path) in enumerate(pbar := tqdm(source_test_pairs)):
    try:
      with open(src_path, encoding="utf-8") as f:
          src_graph = program_graph.get_program_graph(gast.parse(f.read()))
          src_graph.filename = os.path.basename(src_path)

      with open(src_path, encoding="utf-8") as f:
        test_graph = program_graph.get_program_graph(gast.parse(f.read()))
        test_graph.filename = os.path.basename(src_path)
    except (TypeError, SyntaxError):
      pbar.set_description(f"Could not parse either {os.path.basename(src_path)} or {os.path.basename(test_path)}")
      continue
    
    pbar.set_description(f"pairing [{os.path.basename(src_path)}, {os.path.basename(test_path)}]")
    
    src_code_graph = CodeGraph()
    src_code_graph.read(src_graph)

    test_code_graph = CodeGraph()
    test_code_graph.read(test_graph)

    combined_code_graph = combine_code_graphs((src_code_graph, test_code_graph))

    data = combined_code_graph.pytorch_graph()

    torch.save(data, os.path.join(processed_dir, f"data_{idx}.pt"))

    pbar.set_description(f"saved data_{idx}.pt")
    
    idx += 1

class SourceTestDataset(Dataset):
  def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
      super().__init__(root, transform, pre_transform, pre_filter)

  @property
  def raw_file_names(self):
    return RAW_FILES

  @property
  def processed_file_names(self):
    return PROCESSED_FILES

  def download(self):
    # Download to `self.raw_dir`.
    copy_from_gcs(os.path.join(self.root, "src", ''), os.path.join(self.root, "raw", ''), " -r")
    copy_from_gcs(os.path.join(self.root, "test", ''), os.path.join(self.root, "raw", ''), " -r")

  def process(self):
    # Read data into huge `Data` list.
    pass

  def len(self):
    return len(self.processed_file_names)

  def get(self, idx):
    data = torch.load(os.path.join(self.processed_dir, f"data_{idx}.pt"))
    return data

In [None]:
print("PyTorch has version {}".format(torch.__version__))

PyTorch has version 1.13.0+cu116


In [41]:
raw_paths = [file_path for file_path in iglob(os.path.join(CODE_DIR, "raw", '**', '*.py'), recursive=True)]

In [None]:
save_data(raw_paths, os.path.join(CODE_DIR, "processed", ""))

  0%|          | 0/4761 [00:00<?, ?it/s]

In [None]:
%debug

In [None]:
dataset = SourceTestDataset(root=CODE_DIR)

In [None]:
!ls data/minified/processed

pre_filter.pt  pre_transform.pt


### Inspect Dataset

In [None]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')


Dataset: SourceTestDataset(10):
Number of graphs: 10


FileNotFoundError: ignored

In [None]:
%debug

> [0;32m/usr/local/lib/python3.8/dist-packages/torch/serialization.py[0m(251)[0;36m__init__[0;34m()[0m
[0;32m    249 [0;31m[0;32mclass[0m [0m_open_file[0m[0;34m([0m[0m_opener[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    250 [0;31m    [0;32mdef[0m [0m__init__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mname[0m[0;34m,[0m [0mmode[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 251 [0;31m        [0msuper[0m[0;34m([0m[0m_open_file[0m[0;34m,[0m [0mself[0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0mopen[0m[0;34m([0m[0mname[0m[0;34m,[0m [0mmode[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    252 [0;31m[0;34m[0m[0m
[0m[0;32m    253 [0;31m    [0;32mdef[0m [0m__exit__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> self.processed_dir
*** AttributeError: '_open_file' object has no attribute 'pr

### Training Parameters

In [None]:
split_ratio = 0.8
batch_size = 64
hidden_channels = 64
learning_rate = 0.01

### Train/Test split

In [None]:
torch.manual_seed(12345)
dataset = dataset.shuffle()

split_idx = int(len(dataset)*split_ratio)

train_dataset = dataset[:split_idx]
test_dataset = dataset[split_idx:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

### Prepare Dataset Loader

In [None]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

### neptune.ai Integration

In [None]:
! pip install -Uqqq neptune-client
import neptune.new as neptune

neptune_api_token = getpass("Enter your Neptune API token: ")

project = "tjsun009/test-src-classifier"

run = neptune.init_run(
    api_token=neptune_api_token,
    project=project,
)


KeyboardInterrupt



## Training a Graph Neural Network (GNN)

copied from: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb

Training a GNN for graph classification usually follows a simple recipe:

1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
3. Train a final classifier on the graph embedding

There exists multiple **readout layers** in literature, but the most common one is to simply take the average of node embeddings:

$$
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v
$$

PyTorch Geometric provides this functionality via [`torch_geometric.nn.global_mean_pool`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_mean_pool), which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `batch` to compute a graph embedding of size `[batch_size, hidden_channels]` for each graph in the batch.

The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=hidden_channels)
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


Here, we again make use of the [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.

Let's train our network for a few epochs to see how well it performs on the training as well as test set:

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

<IPython.core.display.Javascript object>

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.7467, Test Acc: 0.7632
Epoch: 008, Train Acc: 0.7267, Test Acc: 0.7632
Epoch: 009, Train Acc: 0.7200, Test Acc: 0.7632
Epoch: 010, Train Acc: 0.7133, Test Acc: 0.7895
Epoch: 011, Train Acc: 0.7200, Test Acc: 0.7632
Epoch: 012, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 013, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 014, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 015, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 016, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 017, Train Acc: 0.7400, Test Acc: 0.7632
Epoch: 018, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 019, Train Acc: 0.7400, Test Acc: 0.7895
Epoch: 020, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 021, Train Acc: 0.7467, Test Acc:

As one can see, our model reaches around **76% test accuracy**.
Reasons for the fluctations in accuracy can be explained by the rather small dataset (only 38 test graphs), and usually disappear once one applies GNNs to larger datasets.



## Stop Neptune

In [None]:
# stop neptune
run.stop()

In [None]:
# idea for doing tested source MLM
# provide the source graph as input and mask a node in the test graph randomly
# predict what the node is, including node_type, node_value if applicable 