In [1]:
# Enable autoreload (mode 2) so edited modules are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2
# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [1]:
from SourceCodeTools.models.training_config import load_config
from SourceCodeTools.code.data.dataset.Dataset import SourceGraphDataset
from SourceCodeTools.models.graph.train.sampling_multitask2 import SamplingMultitaskTrainer, \
    select_device
from SourceCodeTools.models.graph.train.objectives.NodeClassificationObjective import NodeClassifierObjective, \
    ClassifierTargetMapper
from SourceCodeTools.models.graph import RGGAN, RGCN

from copy import copy
from pathlib import Path

import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
from os.path import join

Using backend: pytorch


# Prepare parameters and options

The full list of available options can be found in `SourceCodeTools/models/training_options.py`. They are meant to be used as arguments for the CLI trainer. The trainer script can be found in `SourceCodeTools/scripts/train.py`.

For the task of subgraph classification the important options are:
- `subgraph_partition` is the path to the subgraph-based train/val/test sets. It is stored as a DataFrame with the subgraph id and a partition mask.
- `subgraph_id_column` is the column in the `common_edges` file that stores the subgraph id.
- For the variable misuse task (the same applies to authorship attribution), subgraphs are created for individual functions (files for SCAA). The label is stored in `common_filecontent`.

In [2]:
# Path to the tokenizer model file.
# NOTE(review): in the visible cells this variable is only referenced by the
# commented-out `get_config(...)` call; the active code never reads it — confirm
# whether it is still needed.
tokenizer_path = "sentencepiece_bpe.model"

# Dataset directory and the files derived from it (all relative paths).
data_path = "v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast"
partition = join(data_path, "partition_type_prediction.json.bz2")
type_annotations_path = join(data_path, "type_annotations.json.bz2")

# Checkpoint directory to analyse. Alternative checkpoints are kept commented out;
# exactly one `checkpoint_path` assignment should be active.
# checkpoint_path = "v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast/RGCN-2023-01-29-18-19-15-468135"
# checkpoint_path = "v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast/RGCN-2023-01-30-19-47-28-861157"
# checkpoint_path = "v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast/RGCN-2023-01-30-14-08-35-988001_without_subwords"
checkpoint_path = "v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast/RGCN-2023-01-31-17-58-43-579368_subword_masking"

In [3]:
# Alternative (disabled): build a fresh configuration instead of loading the one
# stored next to the checkpoint.
# config = get_config(
#     data_path=data_path,
#     model_output_dir=data_path,
#     partition=partition,
#     tokenizer_path=tokenizer_path,
#     objectives="node_clf",
#     batch_size=1
# )

# Load the configuration file stored in the checkpoint directory.
config = load_config(join(checkpoint_path, "config.yaml"))
# `restore_state` is forwarded as the trainer's `restore` argument in `compute_saliency`.
config["TRAINING"]["restore_state"] = True
# The saliency report reads only the first prediction/label of each batch.
# NOTE(review): presumably batch_size=1 makes every batch hold a single target node — confirm.
config["TRAINING"]["batch_size"] = 1

In [4]:
# Display the effective configuration (including the overrides set above).
config

{'DATASET': {'custom_reverse': None,
  'data_path': '/Users/LTV/Downloads/NitroShare/v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast/',
  'filter_edges': None,
  'min_count_for_objectives': 3,
  'no_global_edges': True,
  'partition': '/Users/LTV/Downloads/NitroShare/v2_subsample_v4_new_ast2_fixed_distinct_types/with_ast/partition_type_prediction.json.bz2',
  'random_seed': 42,
  'remove_reverse': False,
  'restricted_id_pool': None,
  'self_loops': False,
  'subgraph_id_column': None,
  'subgraph_partition': None,
  'train_frac': 0.9,
  'use_edge_types': True,
  'use_node_types': False},
 'MODEL': {'activation': 'relu',
  'dropout': 0.0,
  'h_dim': None,
  'n_layers': 5,
  'node_emb_size': 100,
  'num_bases': 10,
  'use_att_checkpoint': False,
  'use_gcn_checkpoint': False,
  'use_gru_checkpoint': False,
  'use_self_loop': True},
 'TOKENIZER': {'tokenizer_path': '/Users/LTV/Dropbox (Personal)/sentencepiece_bpe.model'},
 'TRAINING': {'batch_size': 1,
  'dilate_scores': 200,
  'e

# Load dataset

In [5]:
# Build the dataset from the DATASET and TOKENIZER config sections merged into one
# keyword-argument dict (TOKENIZER entries win if a key appears in both).
# NOTE: `dataset` is also read as a global by NodeClassifierObjectiveWithSaliency.
dataset = SourceGraphDataset(
    **{**config["DATASET"], **config["TOKENIZER"]}
)
# Store the graph's node and edge types in the TRAINING section, which is later
# passed to the trainer as `trainer_params`.
ntypes, etypes = dataset.get_graph_types()
config["TRAINING"]['ntypes'] = ntypes
config["TRAINING"]['etypes'] = etypes

# Declare objective

In [31]:
class NodeClassifierObjectiveWithSaliency(NodeClassifierObjective):
    """Node classification objective that prints gradient-based saliency reports.

    For each inspected batch the loss is back-propagated, the gradient with respect
    to the GNN input embeddings is read, and every input node is scored by the mean
    absolute value of its gradient vector.

    NOTE(review): node lookups go through the notebook-level ``dataset`` variable
    (``dataset._graph_storage.database``) rather than state held by the objective,
    so the class only works in a kernel where ``dataset`` has already been created.
    """

    def __init__(self, **kwargs):
        super(NodeClassifierObjectiveWithSaliency, self).__init__(**kwargs)

    def saliency_for_batch(
            self, batch_ind, batch, optimizer
    ):
        """Print a saliency report for a single batch.

        Args:
            batch_ind: Index of the batch in the iteration; currently unused.
            batch: Mapping of keyword arguments for the forward call
                (``self(**batch)``). Must contain a ``"blocks"`` entry.
            optimizer: Only used to reset accumulated gradients before the
                backward pass; no optimizer step is taken.

        Returns:
            None. The report is written to stdout.
        """
        optimizer.zero_grad()

        outputs = self(
            **batch
        )

        # Embeddings are non-leaf tensors, so their gradients must be retained
        # explicitly to be readable after backward().
        outputs.gnn_output.node_embeddings["node_"].retain_grad()
        outputs.gnn_output.input_embeddings["node_"].retain_grad()
        outputs.loss.backward()

        # get input node ids
        def get_original_node_ids(block):
            return block.srcnodes["node_"].data["original_id"].tolist()

        # First block -> nodes fed into the GNN; last block -> node(s) under investigation.
        original_node_ids = get_original_node_ids(batch["blocks"][0])
        graph_node_ids = get_original_node_ids(batch["blocks"][-1])

        # compute saliency scores for nodes
        def compute_saliency_score(tensor_):
            # Move to host memory first: Tensor.numpy() fails for tensors that live
            # on an accelerator device.
            return tensor_.abs().mean(dim=-1).detach().cpu().numpy()

        input_node_saliency = compute_saliency_score(
            outputs.gnn_output.input_embeddings["node_"].grad.data
        )

        # retrieve additional information for nodes for display
        def get_information_for_nodes(node_ids):
            node_info = dataset._graph_storage.database.query(f"""
            select nodes.id, name, string, type_desc as type from nodes
            left join node_strings on node_strings.id = nodes.id
            left join node_types on node_types.type_id = nodes.type
            where nodes.id in ({",".join(set(map(str, node_ids)))})
            """)

            return {
                "names": dict(zip(node_info["id"], node_info["name"])),
                "strings": dict(zip(node_info["id"], node_info["string"])),
                "types": dict(zip(node_info["id"], node_info["type"]))
            }

        def get_function_for_node(node_id):
            function = dataset._graph_storage.database.query(f"""
            select node_strings.string
            from node_hierarchy
            join node_strings on node_hierarchy.mentioned_in = node_strings.id
            where node_hierarchy.id = {node_id}
            """)

            # No enclosing entry recorded for this node.
            if len(function) == 0:
                return None
            return function.iloc[0, 0]

        node_info = get_information_for_nodes(original_node_ids + graph_node_ids)
        function_text = get_function_for_node(graph_node_ids[0])

        def make_saliency_summary_table(node_ids, saliency, node_info):
            # sort input nodes by saliency
            id_saliency = sorted(
                zip(node_ids, saliency),
                key=lambda x: x[1], reverse=True
            )

            input_nodes_summary = {
                "ids": [node_id for node_id, _ in id_saliency],
                "score": [score for _, score in id_saliency],
                "name": [node_info["names"][node_id] for node_id, _ in id_saliency],
                "type": [node_info["types"][node_id] for node_id, _ in id_saliency],
                "string": [node_info["strings"].get(node_id, pd.NA) for node_id, _ in id_saliency],
            }

            return pd.DataFrame(input_nodes_summary)

        input_node_saliency_summary = make_saliency_summary_table(original_node_ids, input_node_saliency, node_info)

        decoder = self.dataloader.label_encoder.get_original_targets()
        # there is only one item in graph_node_ids
        print("Function:\n", function_text)
        print("Investigating: ", node_info["names"][graph_node_ids[0]])

        # Only the first prediction/label of the batch is reported.
        prediction = outputs.prediction[0][0]
        labels = outputs.labels[0][0]

        # print top 3 predicted classes
        print("Top 3 predicted classes:")
        for c in torch.argsort(prediction, descending=True)[:3]:
            print(f"{decoder[c.item()]}\t{prediction[c].item():.4f}")

        # print the true class
        print(f"True class: {decoder[labels.item()]}")

        # print saliency in decreasing order
        print("Input saliency")
        print(input_node_saliency_summary[["ids", "score", "name", "type"]].head(10).to_string())
        print("\n\n\n")

    def saliency(self, data_split, optimizer, max_batches=11):
        """Print saliency reports for the first batches of a data split.

        Args:
            data_split: Name of the split; selects the iterator through
                ``get_iterator`` and the ``num_<data_split>_batches`` attribute
                used as the progress-bar total.
            optimizer: Forwarded to ``saliency_for_batch``.
            max_batches: Number of batches to report before stopping. The default
                of 11 matches the previously hard-coded limit (batches 0 through
                10). ``None`` processes the whole split.
        """
        if max_batches is not None and max_batches <= 0:
            return

        for batch_ind, batch in enumerate(tqdm(
                self.get_iterator(data_split), total=getattr(self, f"num_{data_split}_batches")
        )):
            self.saliency_for_batch(batch_ind, batch, optimizer)
            # Stop right after the last requested batch so no extra batch is fetched.
            if max_batches is not None and batch_ind + 1 >= max_batches:
                break


# Declare trainer

In [32]:
class SamplingMultitaskTrainerWithSaliency(SamplingMultitaskTrainer):
    """Trainer with a single type-annotation objective that can report saliency."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train_step_for_objective(self, step, objective, objective_iterator, longterm_metrics):
        """Run one objective step on the next batch and back-propagate its loss.

        Returns the scores produced by ``objective.make_step``; when those are
        ``None`` the backward pass is skipped and ``None`` is returned.
        """
        current_batch = next(objective_iterator)

        step_output, step_scores = objective.make_step(
            step, current_batch, "train", longterm_metrics, scorer=None
        )

        if step_scores is not None:
            step_output["loss"].backward()

        return step_scores

    def create_objectives(self, dataset, tokenizer_path):
        """Register the type-annotation prediction objective as the only objective."""
        objectives = nn.ModuleList()
        self.objectives = objectives

        type_annotation_objective = self._create_node_level_objective(
            objective_name="TypeAnnPrediction",
            objective_class=NodeClassifierObjectiveWithSaliency,
            label_loader_class=ClassifierTargetMapper,
            dataset=dataset,
            labels_fn=dataset.load_type_prediction,
            tokenizer_path=tokenizer_path,
            masker_fn=dataset.create_subword_masker,
            preload_for="package",
        )
        objectives.append(type_annotation_objective)

    def saliency(self):
        """Put the first objective in eval mode and print saliency for the "train" split."""
        saliency_objective = self.objectives[0]
        saliency_objective.eval()
        saliency_objective.saliency("train", self.optimizer)

In [33]:
# training_procedure(
#     dataset,
#     model_name=RGGAN,
#     model_params=config["MODEL"],
#     trainer_params=config["TRAINING"],
#     model_base_path=get_model_base(config["TRAINING"], get_name(RGGAN, str(datetime.now()))),
#     trainer=SamplingMultitaskTrainerWithSaliency
# )

In [34]:
def compute_saliency(
        dataset, model_name, model_params, trainer_params, model_base_path,
        tokenizer_path=None, trainer=None, load_external_dataset=None
):
    """Instantiate a trainer and run its ``saliency()`` report.

    Args:
        dataset: Dataset object handed to the trainer.
        model_name: Model class handed to the trainer.
        model_params: Model options; shallow-copied before use.
        trainer_params: Trainer options; shallow-copied before use. Must provide
            the ``"restore_state"`` and ``"pretrained"`` keys.
        model_base_path: Stored in the copied trainer options under
            ``"model_base_path"``.
        tokenizer_path: Forwarded to the trainer unchanged.
        trainer: Trainer class to instantiate; ``SamplingMultitaskTrainer`` when
            ``None``. The class must expose a ``saliency()`` method.
        load_external_dataset: Accepted but not used by this function.
    """
    trainer_class = SamplingMultitaskTrainer if trainer is None else trainer

    # Work on shallow copies so the caller's top-level dicts are not modified.
    model_settings = copy(model_params)
    trainer_settings = copy(trainer_params)

    device = select_device(trainer_settings)
    trainer_settings['model_base_path'] = model_base_path

    saliency_trainer = trainer_class(
        dataset=dataset,
        model_name=model_name,
        model_params=model_settings,
        trainer_params=trainer_settings,
        restore=trainer_settings["restore_state"],
        device=device,
        pretrained_embeddings_path=trainer_settings["pretrained"],
        tokenizer_path=tokenizer_path,
    )
    saliency_trainer.saliency()

In [35]:
# Run the saliency report with the RGCN model restored from `checkpoint_path`.
# NOTE(review): `tokenizer_path` is not passed here, so `compute_saliency` uses its
# default of None and forwards that to the trainer — confirm this is intended.
compute_saliency(
    dataset,
    model_name=RGCN,
    model_params=config["MODEL"],
    trainer_params=config["TRAINING"],
    model_base_path=checkpoint_path,
    trainer=SamplingMultitaskTrainerWithSaliency
)

Model parameter `h_dim` is not provided, setting it to: 100


  0%|          | 1/1864 [00:06<3:18:22,  6.39s/it]

Function:
 def script_for_render_items(docs_json_or_id, render_items: List[RenderItem],
                            app_path: Optional[str] = None, absolute_url: Optional[str] = None) -> str:
    ''' Render an script for Bokeh render items.
    Args:
        docs_json_or_id:
            can be None

        render_items (RenderItems) :
            Specific items to render from the document and where

        app_path (str, optional) :

        absolute_url (Theme, optional) :

    Returns:
        str
    '''
    if isinstance(docs_json_or_id, str):
        docs_json = "document.getElementById('%s').textContent" % docs_json_or_id
    else:
        # XXX: encodes &, <, > and ', but not ". This is because " is used a lot in JSON,
        # and encoding it would significantly increase size of generated files. Doing so
        # is safe, because " in strings was already encoded by JSON, and the semi-encoded
        # JSON string is included in JavaScript in single quotes.
        docs_json

  0%|          | 2/1864 [00:06<2:23:07,  4.61s/it]

Function:
 def _base64_decode(encoded: Union[bytes, str], encoding: Optional[str] = None) -> Union[bytes, str]:
    # base64 lib both takes and returns bytes, we want to work with strings
    encoded_as_bytes = codecs.encode(encoded, 'ascii') if isinstance(encoded, str) else encoded
    # put the padding back
    mod = len(encoded_as_bytes) % 4
    if mod != 0:
        encoded_as_bytes = encoded_as_bytes + (b"=" * (4 - mod))
    assert (len(encoded_as_bytes) % 4) == 0
    result = base64.urlsafe_b64decode(encoded_as_bytes)
    if encoding:
        return codecs.decode(result, 'utf-8')
    return result
Investigating:  encoding@FunctionDef_0x16e393f83b4b2d8b
Predicted class:
bool	0.9997
Union	0.0002
int	0.0001
True class: Optional
Input saliency
       ids     score                                     name         type
0   140757  0.275501                    If_0x16e393f83b9d6884           If
1   493821  0.011624                Return_0x16e393f83bb0e1b2       Return
2    11083  0.010450

  0%|          | 3/1864 [00:07<1:43:55,  3.35s/it]

Function:
 class ExtensionEmbed(NamedTuple):
    artifact_path: str
    server_url: str
    cdn_url: Optional[str] = None
Investigating:  cdn_url@ClassDef_0x16e393f1de4cae84
Predicted class:
Optional	1.0000
str	0.0000
int	0.0000
True class: Optional
Input saliency
       ids         score                                       name         type
0    69376  6.838461e-07               AnnAssign_0x16e393f1de9bf4f9    AnnAssign
1  2101673  6.102767e-07  bokeh.embed.bundle.ExtensionEmbed.cdn_url  class_field
2  1396154  1.824241e-07                          Constant_NoneType     Constant
3    27710  1.102966e-07        cdn_url@ClassDef_0x16e393f1de4cae84      mention
4  1582286  8.070634e-08                ClassDef_0x16e393f1de4cae84     ClassDef
5  2545114  5.333061e-08               AnnAssign_0x16e393f1de552f33    AnnAssign
6  1652746  2.194407e-08                  Module_0x16e393f1ddf954f7       Module
7   940594  3.660112e-09     server_url@ClassDef_0x16e393f1de4cae84      mention
8  115

  0%|          | 4/1864 [00:07<1:16:44,  2.48s/it]

Function:
 def hexbin(x: Any, y: Any, size: float, orientation: str = "pointytop", aspect_scale: float = 1) -> Any:
    ''' Perform an equal-weight binning of data points into hexagonal tiles.

    For more sophisticated use cases, e.g. weighted binning or scaling
    individual tiles proportional to some other quantity, consider using
    HoloViews.

    Args:
        x (array[float]) :
            A NumPy array of x-coordinates for binning

        y (array[float]) :
            A NumPy array of y-coordinates for binning

        size (float) :
            The size of the hexagonal tiling.

            The size is defined as the distance from the center of a hexagon
            to the top corner for "pointytop" orientation, or from the center
            to a side corner for "flattop" orientation.

        orientation (str, optional) :
            Whether the hex tile orientation should be "pointytop" or
            "flattop". (default: "pointytop")

        aspect_scale (float, opti

  0%|          | 5/1864 [00:08<58:36,  1.89s/it]  

Function:
 def linear_palette(palette: Palette, n: int) -> Palette:
    ''' Generate a new palette as a subset of a given palette.

    Given an input ``palette``, take ``n`` colors from it by dividing its
    length into ``n`` (approximately) evenly spaced indices.

    Args:

        palette (seq[str]) : a sequence of hex RGB color strings
        n (int) : the size of the output palette to generate

    Returns:
        seq[str] : a sequence of hex RGB color strings

    Raises:
        ``ValueError`` if ``n > len(palette)``

    '''
    if n > len(palette):
        raise ValueError("Requested %(r)s colors, function can only return colors up to the base palette's length (%(l)s)" % dict(r=n, l=len(palette)))
    return tuple( palette[int(math.floor(i))] for i in np.linspace(0, len(palette)-1, num=n) )
Investigating:  palette@FunctionDef_0x16e393f775f913a4
Predicted class:
Sequence	0.8297
Dict	0.1111
Palette	0.0572
True class: Palette
Input saliency
       ids     score               

  0%|          | 6/1864 [00:08<45:59,  1.49s/it]

Function:
 def nice_join(seq: Sequence[str], sep: str = ", ", conjuction: str = "or") -> str:
    ''' Join together sequences of strings into English-friendly phrases using
    the conjunction ``or`` when appropriate.

    Args:
        seq (seq[str]) : a sequence of strings to nicely join
        sep (str, optional) : a sequence delimiter to use (default: ", ")
        conjunction (str or None, optional) : a conjuction to use for the last
            two items, or None to reproduce basic join behaviour (default: "or")

    Returns:
        a joined string

    Examples:
        >>> nice_join(["a", "b", "c"])
        'a, b or c'

    '''
    seq = [str(x) for x in seq]

    if len(seq) <= 1 or conjuction is None:
        return sep.join(seq)
    else:
        return "%s %s %s" % (sep.join(seq[:-1]), conjuction, seq[-1])
Investigating:  seq@FunctionDef_0x16e393f61773c7da
Predicted class:
Sequence	0.9906
bytes	0.0093
Union	0.0001
True class: Sequence
Input saliency
       ids     score  

  0%|          | 7/1864 [00:09<38:40,  1.25s/it]

Function:
 def red(text: str) -> str:    return "%s%s%s" % (Fore.RED, text, Style.RESET_ALL)
Investigating:  text@FunctionDef_0x16e393f89887a81b
Predicted class:
int	0.7621
str	0.2378
TracebackType	0.0001
True class: str
Input saliency
       ids     score                                 name         type
0  1896862  0.146427             Tuple_0x16e393f898c5b477        Tuple
1  2575840  0.101647         arguments_0x16e393f8988b4e61    arguments
2  2414025  0.087682         Attribute_0x16e393f8984a407c    Attribute
3  2013294  0.081670         Attribute_0x16e393f898f935ef    Attribute
4  1360552  0.033268               arg_0x16e393f898cfeba2          arg
5   479867  0.027198                                  RED       #attr#
6  2374696  0.027137                            RESET_ALL       #attr#
7   386162  0.012943             BinOp_0x16e393f898a11fde        BinOp
8    59093  0.004072  text@FunctionDef_0x16e393f89887a81b      mention
9   282232  0.003445       FunctionDef_0x16e393f89887a

  0%|          | 8/1864 [00:09<31:27,  1.02s/it]

Function:
 def invoke(self, args: argparse.Namespace) -> None:
        '''

        '''
        argvs = { f : args.args for f in args.files}
        applications = build_single_handler_applications(args.files, argvs)

        if args.output is None:
            outputs: List[str] = []
        else:
            outputs = list(args.output)  # copy so we can pop from it

        if len(outputs) > len(applications):
            die("--output/-o was given too many times (%d times for %d applications)" %
                (len(outputs), len(applications)))

        for (route, app) in applications.items():
            doc = app.create_document()

            if len(outputs) > 0:
                filename = outputs.pop(0)
            else:
                filename = self.filename_from_route(route, self.extension)

            self.write_file(args, filename, doc)
Investigating:  args@FunctionDef_0x16e393f34a01a586
Predicted class:
Type	0.6661
_CodeWriter	0.3107
Rule	0.0170
True class: Namespace
I

  0%|          | 9/1864 [00:10<25:59,  1.19it/s]

Function:
 def report_server_init_errors(address: Optional[str] = None, port: Optional[int] = None, **kwargs: str) -> Iterator[None]:
    ''' A context manager to help print more informative error messages when a
    ``Server`` cannot be started due to a network problem.

    Args:
        address (str) : network address that the server will be listening on

        port (int) : network address that the server will be listening on

    Example:

        .. code-block:: python

            with report_server_init_errors(**server_kwargs):
                server = Server(applications, **server_kwargs)

        If there are any errors (e.g. port or address in already in use) then a
        critical error will be logged and the process will terminate with a
        call to ``sys.exit(1)``

    '''
    try:
        yield
    except EnvironmentError as e:
        if e.errno == errno.EADDRINUSE:
            log.critical("Cannot start Bokeh server, port %s is already in use", port)
        elif

  1%|          | 10/1864 [00:10<22:00,  1.40it/s]

Function:
 def set_single_plot_width_height(doc: Document, width: Optional[int], height: Optional[int]) -> None:
    if width is not None or height is not None:
        layout = doc.roots
        if len(layout) != 1 or not isinstance(layout[0], Plot):
        else:
            plot = layout[0]
            # TODO - below fails mypy check
            # unsure how to handle with typing. width is int base type and class property getter is typing.Int
            # plot.plot_width  = width if width is not None else plot.plot_width  # doesnt solve problem
            plot.plot_height = height or plot.plot_height
            plot.plot_width  = width or plot.plot_width
Investigating:  height@FunctionDef_0x16e393f3521f6615
Predicted class:
Any	0.5652
int	0.2723
Dict	0.1374
True class: Optional
Input saliency
       ids     score                          name       type
0  1330382  0.152810    Compare_0x16e393f352bc4abc    Compare
1  1353336  0.085995     BoolOp_0x16e393f3524317fb     BoolOp
2   

  1%|          | 10/1864 [00:11<34:29,  1.12s/it]

Function:
 def _version(modname: str, attr: str) -> Optional[Any]:
    mod = import_optional(modname)
    if mod:
        return getattr(mod, attr)
    else:  # explicit None return for mypy typing
        return None
Investigating:  modname@FunctionDef_0x16e393f5116bb8d6
Predicted class:
int	0.8102
str	0.1898
Optional	0.0000
True class: str
Input saliency
       ids     score                                            name       type
0  1412303  0.024165                         Call_0x16e393f511ce83ac       Call
1   828631  0.023561                          arg_0x16e393f511d07e01        arg
2    43748  0.020011  import_optional@FunctionDef_0x16e393f5116bb8d6    mention
3  2149820  0.016519         bokeh.util.dependencies.import_optional   function
4  1739571  0.015651                    arguments_0x16e393f511e303e9  arguments
5  1025498  0.011001                                        optional    subword
6   873758  0.011000                                         ▁import    subword
7


