# Interactive Predictions
This notebook showcases the preprocessing pipeline of the `CodeTransformer` as well as predicting the method name for an arbitrary code snippet in one of the 5 languages (Java, Python, JavaScript, Ruby and Go) that we explored in the paper.  
Once you have downloaded the respective models and dataset files (the vocabularies and data configs are needed for inference) and set up the paths in `env.py`, you can load any model mentioned in the README and feed it any code snippet to obtain a prediction for the method name.

In [None]:
%cd ..
%reload_ext autoreload
%autoreload 2

In [None]:
from code_transformer.preprocessing.datamanager.preprocessed import CTPreprocessedDataManager
from code_transformer.preprocessing.graph.binning import ExponentialBinning
from code_transformer.preprocessing.graph.distances import PersonalizedPageRank, ShortestPaths, \
    AncestorShortestPaths, SiblingShortestPaths, DistanceBinning
from code_transformer.preprocessing.graph.transform import DistancesTransformer
from code_transformer.preprocessing.nlp.vocab import VocabularyTransformer, CodeSummarizationVocabularyTransformer
from code_transformer.preprocessing.pipeline.stage1 import CTStage1Preprocessor
from code_transformer.preprocessing.pipeline.stage2 import CTStage2MultiLanguageSample
from code_transformer.utils.inference import get_model_manager, make_batch_from_sample, decode_predicted_tokens
from code_transformer.env import DATA_PATH_STAGE_2


# 1. Load Model

## 1.1. Specify run ID
All our models are listed in the [README](../README.md) together with their corresponding `run_id` as well as the stored snapshot. 

In [None]:
model_type = 'code_transformer'  # code_transformer, great or xl_net
run_id = 'CT-6'  # Name of folder in which snapshots are stored
snapshot = 'latest'  # Use 'latest' for the last stored snapshot

In [None]:
model_manager = get_model_manager(model_type)

In [None]:
model_config = model_manager.load_config(run_id)

language = model_config['data_setup']['language']
print(f"Model was trained on: {language}")

## 1.2. Construct model

In [None]:
model = model_manager.load_model(run_id, snapshot, gpu=False)
model = model.eval()

# 2. Specify any code snippet
The code snippet has to be written in the target language, and the name of the method to be predicted has to be replaced with `f`.
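
If the snippet still carries its original method name, the name can be masked before it is pasted into the cell below. The following is a minimal sketch and not part of `code_transformer`: `mask_method_name` is a hypothetical helper, and it relies on a simple regex heuristic (the first occurrence of the name directly followed by an opening parenthesis is assumed to be the declaration).

```python
import re


def mask_method_name(code, method_name, placeholder="f"):
    """Replace the first `method_name(` occurrence with `placeholder(`.

    Heuristic only: assumes the declaration is the first place where the
    name is followed by an opening parenthesis.
    """
    pattern = r"\b" + re.escape(method_name) + r"(?=\s*\()"
    return re.sub(pattern, placeholder, code, count=1)


java_snippet = "public int compare(Pair p1, Pair p2) { return 0; }"
python_snippet = "def get_name(self):\n    return self.name"

masked_java = mask_method_name(java_snippet, "compare")
masked_python = mask_method_name(python_snippet, "get_name")
print(masked_java)
print(masked_python)
```

Recursive methods would keep their original name in the body, because only the first match is replaced; such cases need manual editing.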

In [None]:
code_snippet = """
"""
code_snippet_language = ''  # java, javascript, python, ruby, go

## 2.1. Examples from Paper

In [None]:
code_snippet = """public int f(Pair<LoggedJob, JobTraceReader> p1,Pair<LoggedJob, JobTraceReader> p2) {
    LoggedJob j1 = p1.first();
    LoggedJob j2 = p2.first();
    return(j1.getSubmitTime() < j2.getSubmitTime()) ? -1 : (j1.getSubmitTime() == j2.getSubmitTime()) ? 0 : 1;
}"""
code_snippet_language = 'java'

In [None]:
code_snippet = """public static MNTPROC f(int value) {
    if(value < 0 || value >= values().length) {
        return null;
    }
    return values()[value];
}"""
code_snippet_language = 'java'

In [None]:
code_snippet = """private Iterable<ListBlobItem> f(String aPrefix, boolean useFlatBlobListing, EnumSet<BlobListingDetails> listingDetails, BlobRequestOptions options, OperationContext opContext) throws StorageException, URISyntaxException {
    CloudBlobDirectoryWrapper directory = this.container.getDirectoryReference(aPrefix);
    return directory.listBlobs(null, useFlatBlobListing, listingDetails, options, opContext);
}"""
code_snippet_language = 'java'

In [None]:
code_snippet = """private static void f(EnumMap<FSEditLogOpCodes, Holder<Integer>> opCounts) {
    StringBuilder sb = new StringBuilder();
    sb.append("Summary of operations loaded from edit log:  ");
    Joiner.on("  ").withKeyValueSeparator("=").appendTo(sb, opCounts);
    FSImage.LOG.debug(sb.toString());
}"""
code_snippet_language = 'java'

In [None]:
code_snippet = """static String f(File f, String... cmd) throws IOException {
    String[] args = new String[cmd.length + 1];
    System.arraycopy(cmd, 0, args, 0, cmd.length);
    args[cmd.length] = f.getCanonicalPath();
    String output = Shell.execCommand(args);
    return output;
}"""
code_snippet_language = 'java'

In [None]:
code_snippet = """protected void f(Class<? extends SubView> cls) {
    indent(of(ENDTAG));
    sb.setLength(0);
    out.print(sb.append('[').append(cls.getName()).append(']').toString());
    out.println();
}"""
code_snippet_language = 'java'

In [None]:
code_snippet = """
function f() {
    var quotes = new Array();
    quotes[0] = "Action is the real measure of intelligence.";
    quotes[1] = "Baseball has the great advantage over cricket of being sooner ended.";
    quotes[2] = "Every goal, every action, every thought, every feeling one experiences, whether it be consciously or unconsciously known, is an attempt to increase one's level of peace of mind.";
    quotes[3] = "A good head and a good heart are always a formidable combination.";
    var rand = Math.floor(Math.random()*quotes.length);
    document.write(quotes[rand]);
}
"""
code_snippet_language = 'javascript'

# 3. Preprocess

## 3.1. Stage 1 (AST generation)

In [None]:
preprocessor = CTStage1Preprocessor(code_snippet_language, allow_empty_methods=True)
stage1_sample = preprocessor.process([("f", "", code_snippet)], 0)

## 3.2. Stage 2 (Distance matrices)
We have to mimic the preprocessing to match exactly what the model has been trained on. To this end, we make use of the respective dataset config that was stored during preprocessing. 
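
To give an intuition for what a distance matrix over an AST looks like, here is a toy sketch that computes pairwise shortest path distances on a tiny hand-written tree with a breadth-first search. It is an illustration under stated assumptions, not the `ShortestPaths` implementation used by the pipeline: the function name `shortest_path_matrix`, the edge-list input and the four-node tree are all made up for this example.

```python
from collections import deque


def shortest_path_matrix(edges, num_nodes):
    """Pairwise hop distances in an undirected graph given as an edge list."""
    adjacency = [[] for _ in range(num_nodes)]
    for parent, child in edges:
        adjacency[parent].append(child)
        adjacency[child].append(parent)

    matrix = []
    for source in range(num_nodes):
        dist = [-1] * num_nodes  # -1 marks nodes not reached yet
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbour in adjacency[node]:
                if dist[neighbour] == -1:
                    dist[neighbour] = dist[node] + 1
                    queue.append(neighbour)
        matrix.append(dist)
    return matrix


# Hypothetical toy AST: 0 = method, 1 = parameters, 2 = body, 3 = return statement
toy_edges = [(0, 1), (0, 2), (2, 3)]
sp = shortest_path_matrix(toy_edges, 4)
for row in sp:
    print(row)
```

The actual preprocessing computes several such relations (shortest paths, ancestor and sibling distances, personalized PageRank) with the thresholds read from the dataset config below.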

In [None]:
# Load the config of the respective dataset that this model was trained on
model_language = model_config['data_setup']['language']
data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2, model_language, partition='train', shuffle=True)
data_config = data_manager.load_config()

# Extract how distances should be computed from the dataset config
distances_config = data_config['distances']
PPR_ALPHA = distances_config['ppr_alpha']
PPR_USE_LOG = distances_config['ppr_use_log']
PPR_THRESHOLD = distances_config['ppr_threshold']

SP_THRESHOLD = distances_config['sp_threshold']

ANCESTOR_SP_FORWARD = distances_config['ancestor_sp_forward']
ANCESTOR_SP_BACKWARD = distances_config['ancestor_sp_backward']
ANCESTOR_SP_NEGATIVE_REVERSE_DISTS = distances_config['ancestor_sp_negative_reverse_dists']
ANCESTOR_SP_THRESHOLD = distances_config['ancestor_sp_threshold']

SIBLING_SP_FORWARD = distances_config['sibling_sp_forward']
SIBLING_SP_BACKWARD = distances_config['sibling_sp_backward']
SIBLING_SP_NEGATIVE_REVERSE_DISTS = distances_config['sibling_sp_negative_reverse_dists']
SIBLING_SP_THRESHOLD = distances_config['sibling_sp_threshold']

# Extract how distances should be binned from the dataset config
binning_config = data_config['binning']
EXPONENTIAL_BINNING_GROWTH_FACTOR = binning_config['exponential_binning_growth_factor']
N_FIXED_BINS = binning_config['n_fixed_bins']
NUM_BINS = binning_config['num_bins']

preprocessing_config = data_config['preprocessing']
REMOVE_PUNCTUATION = preprocessing_config['remove_punctuation']

# Put together all the implementations of the different distance metrics
distance_metrics = [
    PersonalizedPageRank(threshold=PPR_THRESHOLD, log=PPR_USE_LOG, alpha=PPR_ALPHA),
    ShortestPaths(threshold=SP_THRESHOLD),
    AncestorShortestPaths(forward=ANCESTOR_SP_FORWARD, backward=ANCESTOR_SP_BACKWARD,
                          negative_reverse_dists=ANCESTOR_SP_NEGATIVE_REVERSE_DISTS,
                          threshold=ANCESTOR_SP_THRESHOLD),
    SiblingShortestPaths(forward=SIBLING_SP_FORWARD, backward=SIBLING_SP_BACKWARD,
                         negative_reverse_dists=SIBLING_SP_NEGATIVE_REVERSE_DISTS,
                         threshold=SIBLING_SP_THRESHOLD)]

db = DistanceBinning(NUM_BINS, N_FIXED_BINS, ExponentialBinning(EXPONENTIAL_BINNING_GROWTH_FACTOR))

distances_transformer = DistancesTransformer(distance_metrics, db)
vocabs = data_manager.load_vocabularies()
if len(vocabs) == 4:
    vocabulary_transformer = CodeSummarizationVocabularyTransformer(*vocabs)
else:
    vocabulary_transformer = VocabularyTransformer(*vocabs)
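
The cell above reads a growth factor and a number of bins from the dataset config in order to discretize the raw distances. The sketch below only illustrates the general idea of bins whose width grows by a constant factor; it is not the `ExponentialBinning` or `DistanceBinning` implementation from `code_transformer`, and the names `exponential_bin_edges`, `assign_bin` and `first_width` are hypothetical.

```python
import bisect


def exponential_bin_edges(num_bins, growth_factor, first_width=1.0):
    """Upper edges of the first num_bins - 1 bins; the last bin is open-ended."""
    edges = []
    upper = 0.0
    width = first_width
    for _ in range(num_bins - 1):
        upper += width
        edges.append(upper)
        width *= growth_factor  # every bin is growth_factor times wider
    return edges


def assign_bin(distance, edges):
    """Index of the first bin whose upper edge is >= distance."""
    return bisect.bisect_left(edges, distance)


edges = exponential_bin_edges(num_bins=5, growth_factor=2.0)
bins = [assign_bin(d, edges) for d in (1, 2, 4, 8, 100)]
print(edges)
print(bins)
```

With this kind of scheme, small distances are resolved finely while large distances share a few coarse bins.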

In [None]:
# Now, take the result of stage1 preprocessing and feed it through the vocabulary and distances transformer to obtain a stage2 sample

stage2_sample = stage1_sample[0]
if REMOVE_PUNCTUATION:
    stage2_sample.remove_punctuation()
stage2_sample = vocabulary_transformer(stage2_sample)
stage2_sample = distances_transformer(stage2_sample)

if ',' in model_language:
    # In the multi-lingual setting, we have to furthermore bake the code snippet language into the sample
    stage2_sample = CTStage2MultiLanguageSample(stage2_sample.tokens, stage2_sample.graph_sample, stage2_sample.token_mapping,
                                                stage2_sample.stripped_code_snippet, stage2_sample.func_name,
                                                stage2_sample.docstring,
                                                code_snippet_language,
                                                encoded_func_name=stage2_sample.encoded_func_name if hasattr(stage2_sample, 'encoded_func_name') else None)

## 3.3. Prepare sample to feed into model

In [None]:
batch = make_batch_from_sample(stage2_sample, model_config, model_type)

# 4. Prediction from model

In [None]:
output = model.forward_batch(batch)

In [None]:
k = 3
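# Take the k highest-scoring token indices at every decoding position and
# transpose so that each row of `predictions` is one candidate method name
predictions = (
    output.logits
    .topk(k, dim=-1)
    .indices
    .squeeze()
    .T
)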

In [None]:
print('Predicted method names:')
for i, prediction in enumerate(predictions):
    predicted_method_name = decode_predicted_tokens(prediction, batch, data_manager)
    print(f"  ({i + 1}) ", ' '.join(predicted_method_name))
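
Note how the candidates above are formed: candidate `i` consists of the rank-`i` token at every decoding position, so the list is a per-position ranking rather than the result of a beam search. The following plain Python sketch reproduces that selection on a toy example. It assumes that the logits hold one score per vocabulary entry for each decoded position; the vocabulary, the scores and the helper names are made up.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def rank_wise_predictions(logits, k):
    """Turn [positions][vocab] scores into [k][positions] token indices."""
    per_position = [topk_indices(row, k) for row in logits]  # [positions][k]
    return [list(rank) for rank in zip(*per_position)]  # transpose to [k][positions]


# Hypothetical vocabulary and scores for a method name made of two sub-tokens
toy_vocab = ["get", "set", "name", "value", "id"]
toy_logits = [
    [2.0, 1.5, 0.1, 0.0, -1.0],  # position 0: get > set > name
    [0.2, 0.0, 3.0, 2.5, 1.0],   # position 1: name > value > id
]

toy_predictions = rank_wise_predictions(toy_logits, k=3)
toy_names = [[toy_vocab[i] for i in candidate] for candidate in toy_predictions]
for rank, tokens in enumerate(toy_names, start=1):
    print(f"({rank})", " ".join(tokens))
```

The first candidate is therefore the greedy prediction, while lower-ranked candidates combine tokens that were never scored jointly.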

## 4.1. Code Snippet embedding
In order to obtain a meaningful embedding of the provided AST/Source code pair, one can use the Query Stream Embedding of the masked method name token in the final encoder layer.

In [None]:
encoder_output = model.lm_encoder.forward_batch(batch, need_all_embeddings=True)
query_stream_embedding = encoder_output.all_emb[-1][1]  # [1, B, D]
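
One possible use of such embeddings is to compare two snippets by the cosine similarity of their vectors. The sketch below works on plain Python lists and is only an illustration: `cosine_similarity` is a hypothetical helper, and the conversion shown in the comment assumes the `[1, B, D]` shape noted above.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equally long vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for all-zero vectors
    return dot / (norm_a * norm_b)


# Assumed conversion inside the notebook (embedding of the first sample):
#   vector = query_stream_embedding[0, 0].tolist()
embedding_a = [1.0, 2.0, 0.0]
embedding_b = [2.0, 4.0, 0.0]
embedding_c = [0.0, 0.0, 3.0]

print(cosine_similarity(embedding_a, embedding_b))
print(cosine_similarity(embedding_a, embedding_c))
```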