In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

# Task description

To train a variable misuse detection model one needs to implement an NLP labeling model.

For example, for a function containing a misuse
```
def _eq(l1, l2):\n    return (set(l1) == set(l1))
```
the misused variable occupies the character span (44, 46). To detect it with NLP methods, the code is tokenized and a label is generated for every token:
```
[def, _, eq, (, l, 1, ",", l, 2, ):, \n, \t, return, (, set, (, l1, ), ==, set, (, l1, ), )]
[O  , O, O , O, O, O,  O , O, O, O , O , O ,    O  , O, O  , O, O , O, O , O  , O, M , O, O]
```
The goal is to train an NLP model that predicts those labels correctly. This project uses the BILUO labeling scheme.
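
The mapping from a character span to BILUO token labels can be sketched in plain Python. This is only an illustration: the regex tokenizer, the helper names, and the `misuse` label name are assumptions, and the project itself relies on `biluo_tags_from_offsets` instead.

```python
import re


def tokenize_with_offsets(text):
    """Split text into word-like runs and single punctuation characters, keeping offsets."""
    return [(m.group(), m.start(), m.end()) for m in re.finditer(r"\w+|[^\w\s]", text)]


def biluo_tags(tokens, span, label):
    """Assign BILUO tags to the tokens that lie fully inside the character span."""
    start, end = span
    inside = [i for i, (_, s, e) in enumerate(tokens) if s >= start and e <= end]
    tags = ["O"] * len(tokens)
    if len(inside) == 1:
        tags[inside[0]] = f"U-{label}"  # single-token entity
    elif len(inside) > 1:
        tags[inside[0]] = f"B-{label}"
        for i in inside[1:-1]:
            tags[i] = f"I-{label}"
        tags[inside[-1]] = f"L-{label}"
    # spans that do not align with token boundaries are simply left as "O" here
    return tags


code = "def _eq(l1, l2):\n    return (set(l1) == set(l1))"
tokens = tokenize_with_offsets(code)
tags = biluo_tags(tokens, (44, 46), "misuse")  # "misuse" is a hypothetical label name
```

With this tokenizer the misused `l1` is a single token, so it receives a `U-` tag and every other token stays `O`.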

# Goal

The goals of this project are:
1. Verify the dataset: make sure that encoded batches are correct (misuse spans are correct). You can sample the dataset and check that the share of errors stays below a chosen threshold.
2. Train a variable misuse detection model (with and without fine-tuning).
3. Verify the [scoring function](https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/entity/type_prediction.py#L71).
4. Conduct a series of experiments to measure model performance.
5. Analyze errors
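
The dataset check from step 1 could look like the sketch below. It assumes that records are `(text, annotations)` pairs and that every entry of `annotations["entities"]` is a `(start, end, label)` triple. The function name and the "span must be a valid identifier" criterion are illustrative choices, not part of the project code.

```python
import random


def estimate_span_error_rate(records, sample_size=100, seed=0):
    """Return the share of sampled records with at least one suspicious misuse span."""
    rng = random.Random(seed)
    sample = records if len(records) <= sample_size else rng.sample(records, sample_size)
    errors = 0
    for text, annotations in sample:
        for start, end, _label in annotations["entities"]:
            in_bounds = 0 <= start < end <= len(text)
            # a misused variable should be exactly one Python identifier
            if not (in_bounds and text[start:end].isidentifier()):
                errors += 1
                break
    return errors / len(sample) if sample else 0.0


code = "def _eq(l1, l2):\n    return (set(l1) == set(l1))"
records = [
    (code, {"entities": [(44, 46, "misuse")]}),  # correct span: "l1"
    (code, {"entities": [(43, 46, "misuse")]}),  # shifted span: "(l1"
]
error_rate = estimate_span_error_rate(records)
```

The resulting rate can then be compared against whatever threshold is considered acceptable for the experiments.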

# Why use this example?

The basic functionality necessary to train an NLP labeler is:
1. Loading data (implemented in this example)
2. Tokenization, preparing labels (implemented in [`PythonBatcher.prepare_sent`](https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/batchers/PythonBatcher.py#L123))
3. Data encoding for use with ML models (implemented in [`PythonBatcher.create_batches_with_mask`](https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/batchers/PythonBatcher.py#L206))
4. Batching (implemented in [`PythonBatcher.format_batch`](https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/batchers/PythonBatcher.py#L256))
5. Model training (partially implemented in [`CodeBertModelTrainer2.train_model`](https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/codebert/codebert_train.py#L148) and extended here)
6. TensorBoard tracking (implemented in `CodeBertModelTrainer2`)

# Install libraries

1. See [installation steps](https://github.com/VitalyRomanov/method-embedding#installing-python-libraries).

2. Install transformers
```bash
pip install transformers
```

In [2]:
import os
from os.path import join
from argparse import Namespace
from SourceCodeTools.nlp.codebert.codebert_train import CodeBertModelTrainer2, CodebertHybridModel, batch_to_torch
from SourceCodeTools.nlp.entity.type_prediction import scorer
from transformers import RobertaTokenizer, RobertaModel

import json
import torch
from time import time
from copy import copy
from datetime import datetime
import hashlib
import tempfile
from collections import defaultdict
from pathlib import Path
from math import ceil
from typing import Dict, Optional, List, Union
from SourceCodeTools.nlp.tokenizers import _inject_tokenizer
# from sqlitedict import SqliteDict

from tqdm import tqdm
import numpy as np

import diskcache as dc


from SourceCodeTools.nlp.entity import fix_incorrect_tags
from SourceCodeTools.code.annotator_utils import adjust_offsets, biluo_tags_from_offsets

# Definitions

## Tokenizer

The current code works with many tokenizers. The most compatible format for storing labels is character spans. Character spans are mapped to tokens with spaCy's `biluo_tags_from_offsets`. For this reason, we need tools that make tokenizers compatible with the spaCy format.
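
CodeBERT marks a leading space with `Ġ`, so its tokens have to be converted into words plus trailing-space flags before character offsets line up. The sketch below mirrors the loop in `CodebertAdapter.__call__` from the next cell without depending on spaCy. `to_words_and_spaces` and `render` are hypothetical helper names, and `render` only imitates how a spaCy `Doc` joins words and spaces.

```python
def to_words_and_spaces(bpe_tokens):
    """Turn CodeBERT-style tokens into words and trailing-space flags, adding <s> and </s>."""
    words = ["<s>"]
    spaces = [False]
    for ind, tok in enumerate(bpe_tokens):
        words.append(tok.strip("Ġ") if len(tok) > 1 else tok)
        if ind != 0:
            # a leading "Ġ" means the previous word is followed by a space
            spaces.append(tok.startswith("Ġ") and len(tok) > 1)
    spaces.extend([False, False])  # flags for the last real token and for </s>
    words.append("</s>")
    return words, spaces


def render(words, spaces):
    """Join words, inserting a space after every word whose flag is set."""
    return "".join(w + (" " if s else "") for w, s in zip(words, spaces))


words, spaces = to_words_and_spaces(["a", "Ġ+", "Ġb"])
text = render(words, spaces)
```

The rendered text starts with the 3-character `<s>` prefix, so every offset is shifted by 3 relative to the original string. This is presumably why the adapter stores `adjustment_amount = -3`.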

In [3]:
class AdapterDoc:
    """
    A simple wrapper for tokens that also stores additional data such as character span adjustment and 
    tokens compatible with `biluo_tags_from_offsets`
    """
    def __init__(self, tokens):
        self.tokens = tokens
        self.adjustment_amount = 0
        self.tokens_for_biluo_alignment = None

    def __iter__(self):
        return iter(self.tokens)

    def __repr__(self):
        return "".join(self.tokens)
    
    def __len__(self):
        return len(self.tokens)


class CodebertAdapter:
    """
    This tokenizer returns tokens in a format that can be used with `biluo_tags_from_offsets`
    """
    def __init__(self):
        from transformers import RobertaTokenizer
        import spacy

        # create primary tokenizer
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
        # create secondary tokenizer, need this to fix token alignment errors
        self.regex_tok = create_tokenizer("regex")
        # need to have a blank spacy model for compatibility
        self.nlp = spacy.blank("en")

    def primary_tokenization(self, text):
        return self.tokenizer.tokenize(text)

    def secondary_tokenization(self, tokens):
        # secondary tokenizer performs subtokenization 
        # example:
        # "(arg1" -> "(", "arg1"
        new_tokens = []
        for token in tokens:
            new_tokens.extend(self.regex_tok(token))
        return new_tokens

    def __call__(self, text):
        """
        Tokenization function. Example:
            original string: 'a + b'
            codebert tokenized: '<s>', 'a', 'Ġ+', 'Ġb', '</s>'
        """
        from spacy.tokens import Doc
        tokens = self.primary_tokenization(text)
        tokens = self.secondary_tokenization(tokens)
        doc = Doc(self.nlp.vocab, tokens, spaces=[False] * len(tokens))

        backup_tokens = doc
        fixed_spaces = [False]
        fixed_words = ["<s>"]  # add additional tokens for codebert to avoid adding them later.

        for ind, t in enumerate(doc):
            if len(t.text) > 1:
                fixed_words.append(t.text.strip("Ġ"))
            else:
                fixed_words.append(t.text)
            if ind != 0:
                fixed_spaces.append(t.text.startswith("Ġ") and len(t.text) > 1)
        fixed_spaces.append(False)
        fixed_spaces.append(False)
        fixed_words.append("</s>")

        assert len(fixed_spaces) == len(fixed_words)

        doc = Doc(self.nlp.vocab, fixed_words, fixed_spaces)

        assert len(doc) - 2 == len(backup_tokens)
        assert len(doc.text) - 7 == len(backup_tokens.text)

        final_doc = AdapterDoc(["<s>"] + [t.text for t in backup_tokens] + ["</s>"])
        final_doc.adjustment_amount = -3
        final_doc.tokens_for_biluo_alignment = doc

        return final_doc
    
    
def create_tokenizer(type, bpe_path=None, regex=None):
    if type == "spacy":
        import spacy
        print("Creating spacy tokenizer")
        return _inject_tokenizer(spacy.blank("en"))
    elif type == "codebert":
        from transformers import RobertaTokenizer
        import spacy
        from spacy.tokens import Doc
        print("Creating CodeBERT tokenizer")
        adapter = CodebertAdapter()

        def tokenize(text):
            return adapter(text)

        return tokenize
    else:
        raise ValueError(f"Unsupported tokenizer type {type!r}; this function supports: spacy, codebert")

## Batcher
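
The batcher in the next cell pads every encoded sequence to a fixed length and then trims each batch to the longest sequence it contains. A minimal sketch of that idea, assuming NumPy is available (`pad_ids` and this simplified `format_batch` are illustrative and are not the methods of `SimplePythonBatcher`):

```python
import numpy as np


def pad_ids(ids, max_seq_len, pad_id):
    """Right-pad (or truncate) a list of token ids to max_seq_len."""
    out = np.full((max_seq_len,), pad_id, dtype=np.int32)
    n = min(len(ids), max_seq_len)
    out[:n] = ids[:n]
    return out, n


def format_batch(encoded):
    """Stack padded sequences and drop the columns that are padding for every row."""
    lens = np.array([n for _, n in encoded], dtype=np.int32)
    tok_ids = np.stack([ids for ids, _ in encoded])[:, :lens.max()]
    return {"tok_ids": tok_ids, "lens": lens}


batch = format_batch([pad_ids([5, 6, 7], 8, pad_id=0), pad_ids([9], 8, pad_id=0)])
```

Trimming to the longest sequence in the batch keeps the tensors small when most functions are much shorter than `seq_len`.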

In [4]:
from typing import Dict, Optional, List, Union
from SourceCodeTools.nlp import create_tokenizer, tag_map_from_sentences, TagMap, token_hasher, try_int, ValueEncoder


class SampleEntry(object):
    def __init__(self, id, text, labels=None, category=None, **kwargs):
        self._storage = dict()
        self._storage["id"] = id
        self._storage["text"] = text
        self._storage["labels"] = labels
        self._storage["category"] = category
        self._storage.update(kwargs)

    def __getattr__(self, item):
        storage = object.__getattribute__(self, "_storage")
        if item in storage:
            return storage[item]
        return super().__getattribute__(item)

    def __repr__(self):
        return repr(self._storage)

    def __getitem__(self, item):
        return self._storage[item]

    def keys(self):
        return list(self._storage.keys())


class MapperSpec:
    def __init__(self, field, target_field, encoder, dtype=np.int32, preproc_fn=None):
        self.field = field
        self.target_field = target_field
        self.encoder = encoder
        self.preproc_fn = preproc_fn
        self.dtype = dtype

        
class SimplePythonBatcher:
    def __init__(
            self, data, batch_size: int, seq_len: int,
            wordmap: Dict[str, int], *, tagmap: Optional[TagMap] = None,
            class_weights=False, element_hash_size=1000, sort_by_length=True, tokenizer="spacy", no_localization=False,
            cache_dir: Optional[Union[str, Path]] = None, **kwargs
    ):

        self._batch_size = batch_size
        self._batch_count = None
        self._max_seq_len = seq_len
        self._tokenizer = tokenizer
        self._class_weights = None
        self._no_localization = no_localization
        self._nlp = create_tokenizer(tokenizer)
        self._cache_dir = Path(cache_dir) if cache_dir is not None else cache_dir
        self._valid_sentences = 0
        self._filtered_sentences = 0
        self._wordmap = wordmap
        self.tagmap = tagmap
        self._sort_by_length = sort_by_length
        self._data = data
        self._length = None

        self._create_cache()
        self._prepare_data(data)
        self._create_mappers(**kwargs)

    def _get_version_code(self):
        defining_parameters = json.dumps({
            "tokenizer": self._tokenizer, "max_seq_len": self._max_seq_len
        })
        return self._compute_text_id(defining_parameters)

    @staticmethod
    def _compute_text_id(text):
        return int(hashlib.md5(text.encode('utf-8')).hexdigest(), 16) % (2 ** 60)  # keep ids within 60 bits

    def _check_cache_dir(self):
        if not hasattr(self, "_cache_dir") or self._cache_dir is None:
            raise Exception("Cache directory location has not been specified yet")

    def _get_cache_location_name(self, cache_name):
        self._check_cache_dir()
        return str(self._cache_dir.joinpath(cache_name))

    @property
    def _data_cache_path(self):
        return self._get_cache_location_name("DataCache")

    # @property
    # def _sent_cache_path(self):
    #     return self._get_cache_location_name("SentCache")

    @property
    def _batch_cache_path(self):
        return self._get_cache_location_name("BatchCache")

    @property
    def _tagmap_path(self):
        return self._cache_dir.joinpath("tagmap.json")

    def _create_cache(self):
        if self._cache_dir is None:
            self._tmp_dir = tempfile.TemporaryDirectory()
            self._cache_dir = Path(self._tmp_dir.name)

        self._cache_dir = self._cache_dir.joinpath(f"PythonBatcher{self._get_version_code()}")
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created cache directory at location {self._cache_dir}")

        # self._data_cache = SqliteDict(self._data_cache_path) # dc.Cache(self._data_cache_path)
        # # self._sent_cache = dc.Cache(self._sent_cache_path)
        # self._batch_cache = SqliteDict(self._batch_cache_path) #dc.Cache(self._batch_cache_path)
        
    def _parse_entry(self, text, annotations):
        id_ = self._compute_text_id(text)
        extra = copy(annotations)
        labels = extra.pop("entities")
        extra.update(self._prepare_tokenized_sent((text, annotations)))
        return SampleEntry(id=id_, text=text, labels=labels, **extra)

    def _prepare_data(self, data):
        pass
#         self._sent_lenghts = {}

#         for text, annotations in tqdm(data, desc="Preprocess functions"):
#             id_ = self._compute_text_id(text)
#             if id_ not in self._data_cache:
#                 extra = copy(annotations)
#                 labels = extra.pop("entities")
#                 extra.update(self._prepare_tokenized_sent((text, annotations)))
#                 entry = SampleEntry(id=id_, text=text, labels=labels, **extra)
#                 self._data_cache[id_] = entry
#             else:
#                 entry = self._data_cache[id_]
#             self._sent_lenghts[id_] = len(entry.tokens)
        
#     def _iterate_record_ids(self):
#         return self._data_cache.iterkeys()
    
#     def _get_record_with_id(self, id):
#         if id not in self._data_cache:
#             raise KeyError("Record with such id is not found")
#         return self._data_cache[id]

#     def _iterate_sorted_by_length(self, limit_max_length=False):
#         for id_, length in sorted(self._sent_lenghts.items(), key=lambda x: x[1]):
#             if limit_max_length and length >= self._max_seq_len:
#                 continue
#             yield self._get_record_with_id(id_)

    def _iterate_records(self, limit_max_length=False, shuffle=False):
        for data, annotations in self._data:
            entry = self._parse_entry(data, annotations)
            if limit_max_length and len(entry.tokens) >= self._max_seq_len:
                continue
            yield entry
        # for id_ in self._sent_lenghts.keys():
        #     if limit_max_length and self._sent_lenghts[id_] >= self._max_seq_len:
        #         continue
        #     yield self._get_record_with_id(id_)

    def _create_mappers(self, **kwargs):
        self._mappers = []
        self._create_wordmap_encoder()
        # self._create_tagmap_encoder()

    def _create_tagmap_encoder(self):
        if self.tagmap is None:
            if self._tagmap_path.is_file():
                tagmap = TagMap.load(self._tagmap_path)
            else:
                def iterate_tags():
                    for record in self._iterate_records():
                        for label in record.tags:
                            yield label

                tagmap = tag_map_from_sentences(iterate_tags())
                tagmap.set_default(tagmap._value_to_code["O"])
                tagmap.save(self._tagmap_path)

            self.tagmap = tagmap

        self._mappers.append(
            MapperSpec(field="tags", target_field="tags", encoder=self.tagmap)
        )
        # self.tagmap = tagmap
        # self.tagpad = self.tagmap["O"]

    def _create_wordmap_encoder(self):
        wordmap_enc = ValueEncoder(value_to_code=self._wordmap)
        wordmap_enc.set_default(len(self._wordmap))
        self._mappers.append(
            MapperSpec(field="tokens", target_field="tok_ids", encoder=wordmap_enc)
        )

    @property
    def num_classes(self):
        return 2  # binary classification

    def _prepare_tokenized_sent(self, sent):
        text, annotations = sent

        doc = self._nlp(text)
        label = 1.0 if len(annotations['entities']) > 0 else 0.

        tokens = doc
        try:
            tokens = [t.text for t in tokens]
        except AttributeError:
            # the CodeBERT adapter already yields plain strings, which have no `.text`
            pass

        output = {
            "tokens": tokens,
            "label": label
        }

        return output

    # @lru_cache(maxsize=200000)
    def _encode_for_batch(self, record):

        # if record.id in self._batch_cache:
            # return self._batch_cache[record.id]

        def encode(seq, encoder, pad, preproc_fn=None):
            if preproc_fn is None:
                def preproc_fn(x):
                    return x
            blank = np.ones((self._max_seq_len,), dtype=np.int32) * pad
            encoded = np.array([encoder[preproc_fn(w)] for w in seq], dtype=np.int32)
            blank[0:min(encoded.size, self._max_seq_len)] = encoded[0:min(encoded.size, self._max_seq_len)]
            return blank

        output = {}

        for mapper in self._mappers:
            output[mapper.target_field] = encode(
                seq=record[mapper.field], encoder=mapper.encoder, pad=mapper.encoder.default,
                preproc_fn=mapper.preproc_fn
            ).astype(mapper.dtype)

        tokens = record.tokens
        num_tokens = len(tokens)

        # assert len(s) == len(t)

        # output["no_loc_mask"] = np.array([tag != self.tagmap.default for tag in output["tags"]]).astype(np.bool)
        output["lens"] = num_tokens if num_tokens < self._max_seq_len else self._max_seq_len
        output["label"] = record["label"]

        # self._batch_cache[record.id] = output

        return output

    def format_batch(self, batch):
        fbatch = defaultdict(list)

        for sent in batch:
            for key, val in sent.items():
                fbatch[key].append(val)

        max_len = max(fbatch["lens"])

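        non_sequence_keys = {"lens", "replacements", "tokens", "label"}
        formatted = {}
        for key, val in fbatch.items():
            if key not in non_sequence_keys:
                # padded sequences: stack and trim to the longest sequence in the batch
                formatted[key] = np.stack(val)[:, :max_len]
            elif key in {"lens", "label"}:
                formatted[key] = np.array(val, dtype=np.int32)
            else:
                formatted[key] = np.array(val)
        return formatted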

    def generate_batches(self):
        batch = []
        # if self._sort_by_length:
            # records = self._iterate_sorted_by_length(limit_max_length=True)
        # else:
        records = self._iterate_records(limit_max_length=True, shuffle=False)

        batch_count = 0

        for sent in records:
            batch.append(self._encode_for_batch(sent))
            if len(batch) >= self._batch_size:
                yield self.format_batch(batch)
                batch = []
                batch_count += 1
        if len(batch) > 0:
            yield self.format_batch(batch)
            batch_count += 1
        # yield self.format_batch(batch)

        if self._batch_count is None:
            self._batch_count = batch_count

    def __iter__(self):
        return self.generate_batches()

    def __len__(self):
        if self._batch_count is None:
            if self._length is None:
                self._length = 10000
                # for text, annotations in tqdm(self._data, desc="Counting entries"):
                #     entry = self._parse_entry(text, annotations)
                #     if len(entry.tokens) < self._max_seq_len:
                #         self._length += 1

            total_valid = self._length
            # total_valid = sum(1 for id_, length in self._sent_lenghts.items() if length < self._max_seq_len)
            return int(ceil(total_valid / self._batch_size))
        else:
            return self._batch_count

## Reading data
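
The reader in the next cell expects one JSON record per line, with the source code under `text` and the misuse spans under `entities`. The sketch below writes two such records to a temporary file and reads them back lazily. The record contents, the `misuse` label name, and the triple layout of `entities` are assumptions made for illustration.

```python
import json
import os
import tempfile


def iter_json_lines(path):
    """Yield (text, annotations) pairs from a file with one JSON record per line."""
    with open(path, "r", encoding="utf-8") as source:
        for line in source:
            if line.strip():
                entry = json.loads(line)
                text = entry.pop("text")
                yield text, entry


records = [
    {"text": "def f(a, b):\n    return a + a", "entities": [[28, 29, "misuse"]]},
    {"text": "def g(a):\n    return a", "entities": []},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "var_misuse_seq_train.json")
    with open(path, "w", encoding="utf-8") as sink:
        for record in records:
            sink.write(json.dumps(record) + "\n")
    loaded = list(iter_json_lines(path))

# function-level label, as in the batcher: 1 if the function contains a misuse
labels = [1 if annotations["entities"] else 0 for _, annotations in loaded]
```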

In [5]:
def read_data(dataset_path, partition):
    """
    Read data stored as JSON records, one record per line.
    """
    assert partition in {"train", "val", "test"}
    data_path = join(dataset_path, f"var_misuse_seq_{partition}.json")

    # yield records lazily so the whole partition never has to fit in memory
    with open(data_path, "r") as source:
        for line in source:
            entry = json.loads(line)

            text = entry.pop("text")
            yield (text, entry)
    
def get_num_lines(dataset_path, partition):
    assert partition in {"train", "val", "test"}
    data_path = join(dataset_path, f"var_misuse_seq_{partition}.json")
    
    with open(data_path) as source:
        return sum(1 for _ in source)

    
class DataIterator:
    def __init__(self, data_path, partition_name):
        assert partition_name in {"train", "val", "test"}
        
        self._data_path = data_path
        self._partition_name = partition_name
        
        self._num_entries = get_num_lines(self._data_path, self._partition_name)
        
    def __iter__(self):
        return read_data(self._data_path, self._partition_name)
    
    def __len__(self):
        return self._num_entries

## Model
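
The classifier in the next cell passes `mask` to CodeBERT as the attention mask, so padded positions are ignored. In the training code this mask is derived from the sequence lengths by broadcasting. Below is a NumPy analogue of that trick, as a sketch (the notebook itself operates on torch tensors, and `length_mask` is an illustrative name):

```python
import numpy as np


def length_mask(token_ids, lens):
    """Return a boolean mask that is True for real tokens and False for padding."""
    positions = np.arange(token_ids.shape[1])[None, :]  # shape (1, seq_len)
    return positions < lens[:, None]                    # broadcast to (batch, seq_len)


token_ids = np.array([[5, 6, 7], [9, 0, 0]])
lens = np.array([3, 1])
mask = length_mask(token_ids, lens)
```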

In [6]:
import torch.nn as nn

class CodebertHybridModelFnClf(nn.Module):
    def __init__(
            self, codebert_model, graph_emb, padding_idx, num_classes, dense_hidden=100, dropout=0.1, bert_emb_size=768,
            no_graph=False
    ):
        super(CodebertHybridModelFnClf, self).__init__()

        self.codebert_model = codebert_model
        self.use_graph = not no_graph

        num_emb = padding_idx + 1  # padding id is usually not a real embedding

        if self.use_graph:
            graph_emb_dim = graph_emb.shape[1]
            self.graph_emb = nn.Embedding(num_embeddings=num_emb, embedding_dim=graph_emb_dim, padding_idx=padding_idx)

            import numpy as np
            pretrained_embeddings = torch.from_numpy(np.concatenate([graph_emb, np.zeros((1, graph_emb_dim))], axis=0)).float()
            new_param = torch.nn.Parameter(pretrained_embeddings)
            assert self.graph_emb.weight.shape == new_param.shape
            self.graph_emb.weight = new_param
            self.graph_emb.weight.requires_grad = False
        else:
            graph_emb_dim = 0

        self.fc1 = nn.Linear(
            bert_emb_size + (graph_emb_dim if self.use_graph else 0),
            dense_hidden
        )
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(dense_hidden, num_classes)

        self.loss_f = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, token_ids, graph_ids, mask, finetune=False):
        if finetune:
            x = self.codebert_model(input_ids=token_ids, attention_mask=mask).pooler_output
        else:
            with torch.no_grad():
                x = self.codebert_model(input_ids=token_ids, attention_mask=mask).pooler_output

#         if self.use_graph:
#             graph_emb = self.graph_emb(graph_ids)
#             x = torch.cat([x, graph_emb], dim=-1)

        x = torch.relu(self.fc1(x))
        x = self.drop(x)
        x = self.fc2(x)

        return x

    def loss(self, logits, labels, mask, class_weights=None, extra_mask=None):
        # if extra_mask is not None:
        #     mask = torch.logical_and(mask, extra_mask)
        # print("logits", logits.shape)
        # print("labels", labels.shape)
        # print("mask", mask.shape)
        # logits = logits[mask, :]
        # labels = labels[mask]
        loss = self.loss_f(logits, labels)
        # if class_weights is None:
        #     loss = tf.reduce_mean(tf.boolean_mask(losses, seq_mask))
        # else:
        #     loss = tf.reduce_mean(tf.boolean_mask(losses * class_weights, seq_mask))

        return loss

    def score(self, logits, labels, mask, scorer=None, extra_mask=None):
        # if extra_mask is not None:
            # mask = torch.logical_and(mask, extra_mask)
        true_labels = labels
        estimated_labels = logits.argmax(-1)

        acc = scorer(to_numpy(estimated_labels), to_numpy(true_labels))

        return {"Accuracy": acc}

## Training procedure
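
The training loop below reports two kinds of accuracy: the plain average of per-batch accuracies and an epoch-wide accuracy weighted by batch size. The two differ whenever the last batch is smaller than the others. A small pure-Python sketch (the function names and the toy predictions are illustrative):

```python
def batch_accuracy(pred, true):
    """Share of positions where the predicted label equals the true label."""
    return sum(p == t for p, t in zip(pred, true)) / len(pred)


def epoch_wide_acc(accuracies, batch_sizes):
    """Average per-batch accuracies, weighting each batch by its size."""
    return sum(a * bs for a, bs in zip(accuracies, batch_sizes)) / sum(batch_sizes)


batches = [
    ([1, 0, 1, 1], [1, 0, 0, 1]),  # full batch: 3 of 4 correct
    ([0], [1]),                    # short last batch: 0 of 1 correct
]
accuracies = [batch_accuracy(pred, true) for pred, true in batches]
batch_sizes = [len(pred) for pred, _ in batches]

plain_mean = sum(accuracies) / len(accuracies)
weighted = epoch_wide_acc(accuracies, batch_sizes)
```

Here the plain mean is 0.375 while the weighted value is 0.6, which equals the share of correct predictions over all 5 examples. This kind of toy case is also a convenient way to check a scoring function by hand.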

In [7]:
def to_numpy(tensor):
    return tensor.cpu().detach().numpy()


def batch_to_torch(batch, device):
    key_types = {
        'tok_ids': torch.LongTensor,
        'tags': torch.LongTensor,
        'hide_mask': torch.BoolTensor,
        'no_loc_mask': torch.BoolTensor,
        'lens': torch.LongTensor,
        'graph_ids': torch.LongTensor,
        'label': torch.LongTensor
    }
    for key, tf in key_types.items():
        if key in batch:
            batch[key] = tf(batch[key]).to(device)
            

def get_length_mask(target, lens):
    mask = torch.arange(target.size(1)).to(target.device)[None, :] < lens[:, None]
    return mask


def train_step_finetune(model, optimizer, token_ids, prefix, suffix, graph_ids, labels, lengths,
                   extra_mask=None, class_weights=None, scorer=None, finetune=False, vocab_mapping=None):
    token_ids[token_ids == len(vocab_mapping)] = vocab_mapping["<unk>"]
    seq_mask = get_length_mask(token_ids, lengths)
    logits = model(token_ids, graph_ids, mask=seq_mask, finetune=finetune)
    loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask)
    scores = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    scores["Loss"] = loss.cpu().item()

    return scores


def test_step(
        model, token_ids, prefix, suffix, graph_ids, labels, lengths, extra_mask=None, class_weights=None, scorer=None,
        vocab_mapping=None
):
    with torch.no_grad():
        token_ids[token_ids == len(vocab_mapping)] = vocab_mapping["<unk>"]
        seq_mask = get_length_mask(token_ids, lengths)
        logits = model(token_ids, graph_ids, mask=seq_mask)
        loss = model.loss(logits, labels, mask=seq_mask, class_weights=class_weights, extra_mask=extra_mask)
        scores = model.score(logits, labels, mask=seq_mask, scorer=scorer, extra_mask=extra_mask)

    scores["Loss"] = loss.cpu().item()

    return scores


class VariableMisuseDetector(CodeBertModelTrainer2):
    def set_batcher_class(self):
        self.batcher = SimplePythonBatcher
    
    def get_trial_dir(self):
        """
        Define folder name format for storing checkpoints.
        """
        return os.path.join(self.output_dir, "codebert_var_mususe_fn_clf" + str(datetime.now())).replace(":", "-").replace(" ", "_")
    
    def train(
            self, model, train_batches, test_batches, epochs, report_every=10, scorer=None, learning_rate=0.01,
            learning_rate_decay=1., finetune=False, summary_writer=None, save_ckpt_fn=None, no_localization=False
    ):
        # all training options are specified at https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/entity/type_prediction.py#L256

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=learning_rate_decay)  # there is no learning rate decay by default

        # metric history is stored here
        train_losses = []
        test_losses = []
        # train_f1s = []
        # test_f1s = []
        

        num_train_batches = len(train_batches)
        num_test_batches = len(test_batches)

        best_acc = 0.

        for e in range(epochs):
            # losses = []
            # ps = []
            # rs = []
            # f1s = []
            train_scores = []
            test_scores = []
            train_batch_size = []
            test_batch_size = []
            train_scores_for_averaging = defaultdict(list)
            test_scores_for_averaging = defaultdict(list)

            start = time()
            model.train()

            for ind, batch in enumerate(tqdm(train_batches)):
                batch_to_torch(batch, self.device)  # inspect the content of `batch`

                scores = train_step_finetune(
                    model=model, optimizer=optimizer, token_ids=batch['tok_ids'],
                    prefix=None, suffix=None, graph_ids=None,  # keep this None
                    labels=batch['label'], lengths=batch['lens'],
                    extra_mask=None,  # Keep this None
                    scorer=scorer,
                    finetune=finetune and e / epochs > 0.2,  # finetuning starts after 20% of training is complete
                    vocab_mapping=self.vocab_mapping
                )
                # losses.append(loss)
                # ps.append(p)
                # rs.append(r)
                # f1s.append(f1)
                # train_batch_size.append(len(batch['label']))

                scores["batch_size"] = len(batch['label'])
                for score, value in scores.items():
                    self.summary_writer.add_scalar(
                        f"{score}/Train", value, global_step=e * num_train_batches + ind
                    )
                    train_scores_for_averaging[score].append(value)
                train_scores.append(scores)

                # self.summary_writer.add_scalar("Loss/Train", loss, global_step=e * num_train_batches + ind)
                # self.summary_writer.add_scalar("Acc/Train", p, global_step=e * num_train_batches + ind)
                # self.summary_writer.add_scalar("Recall/Train", r, global_step=e * num_train_batches + ind)
                # self.summary_writer.add_scalar("F1/Train", f1, global_step=e * num_train_batches + ind)

            # test_alosses = []
            # test_aps = []
            # test_ars = []
            # test_af1s = []

            model.eval()

            for ind, batch in enumerate(test_batches):
                batch_to_torch(batch, self.device)
                
                scores = test_step(
                    model=model, token_ids=batch['tok_ids'],
                    prefix=None, suffix=None, graph_ids=None,  # keep this None
                    labels=batch['label'], lengths=batch['lens'],
                    extra_mask=None,  # keep this None
                    scorer=scorer, vocab_mapping=self.vocab_mapping
                )

                scores["batch_size"] = len(batch['label'])
                for score, value in scores.items():
                    self.summary_writer.add_scalar(
                        f"{score}/Test", value, global_step=e * num_test_batches + ind
                    )
                    test_scores_for_averaging[score].append(value)
                test_scores.append(scores)
                # self.summary_writer.add_scalar("Loss/Test", test_loss, global_step=e * num_test_batches + ind)
                # self.summary_writer.add_scalar("Acc/Test", test_p, global_step=e * num_test_batches + ind)
                # self.summary_writer.add_scalar("Recall/Test", test_r, global_step=e * num_test_batches + ind)
                # self.summary_writer.add_scalar("F1/Test", test_f1, global_step=e * num_test_batches + ind)
                # test_alosses.append(test_loss)
                # test_aps.append(test_p)
                # test_ars.append(test_r)
                # test_af1s.append(test_f1)
                # test_batch_size.append(len(batch['label']))

            epoch_time = time() - start

            # train_losses.append(float(sum(losses) / len(losses)))
            # train_f1s.append(float(sum(f1s) / len(f1s)))
            # test_losses.append(float(sum(test_alosses) / len(test_alosses)))
            # test_f1s.append(float(sum(test_af1s) / len(test_af1s)))

            def epoch_wide_acc(scores):
                return sum(a * bs for a, bs in zip(scores["Accuracy"], scores["batch_size"])) / sum(scores["batch_size"])

            train_acc = epoch_wide_acc(train_scores_for_averaging)
            test_acc = epoch_wide_acc(test_scores_for_averaging)
            
            # train_acc = sum(a * bs for a, bs in zip(ps, train_batch_size)) / sum(train_batch_size)
            # test_acc = sum(a * bs for a, bs in zip(test_aps, test_batch_size)) / sum(test_batch_size)


            print(f"Epoch: {e}, {epoch_time: .2f} s", end=" ")
            for score, value in train_scores_for_averaging.items():
                if score == "batch_size":
                    continue
                avg_value = sum(value) / len(value)
                print(f"Train {score}: {avg_value: .4f}", end=" ")
                self.summary_writer.add_scalar(
                    f"Average {score}/Train", avg_value, global_step=e
                )
            for score, value in test_scores_for_averaging.items():
                if score == "batch_size":
                    continue
                avg_value = sum(value) / len(value)
                print(f"Test {score}: {avg_value: .4f}", end=" ")
                self.summary_writer.add_scalar(
                    f"Average {score}/Test", avg_value, global_step=e
                )

            print(f"Epoch Train Acc: {train_acc: .4f}, Epoch Test Acc: {test_acc: .4f}")

            # print(
            #     f"Epoch: {e}, {epoch_time: .2f} s, Train Loss: {train_losses[-1]: .4f}, Train P: {sum(ps) / len(ps): .4f}, Train R: {sum(rs) / len(rs): .4f}, Train F1: {sum(f1s) / len(f1s): .4f}, "
            #     f"Test loss: {test_losses[-1]: .4f}, Test P: {sum(test_aps) / len(test_aps): .4f}, Test R: {sum(test_ars) / len(test_ars): .4f}, Test F1: {test_f1s[-1]: .4f}")

            if save_ckpt_fn is not None and test_acc > best_acc:
                save_ckpt_fn()
                best_acc = test_acc

            scheduler.step()  # the `epoch` argument is deprecated; the scheduler counts epochs itself

        return train_scores_for_averaging, test_scores_for_averaging
    
    def train_model(self, cache_dir=None):
        
        model_params = copy(self.model_params)

        print(f"\n\n{model_params}")
        lr = model_params.pop("learning_rate")
        lr_decay = model_params.pop("learning_rate_decay")
        suffix_prefix_buckets = model_params.pop("suffix_prefix_buckets")  # used for another model, ignore

        print("Creating dataloaders")
        train_batcher, test_batcher = self.get_dataloaders(word_emb=None, graph_emb=None, suffix_prefix_buckets=suffix_prefix_buckets, cache_dir=cache_dir)

        print("Loading pretrained model")
        codebert_model = RobertaModel.from_pretrained("microsoft/codebert-base")
        
        print("Creating model")
        # definition of CodebertHybridModel is at https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/codebert/codebert_train.py#L21
        model = CodebertHybridModelFnClf(
            codebert_model, graph_emb=None, padding_idx=0, num_classes=train_batcher.num_classes,
            no_graph=self.no_graph
        )
        
        if self.use_cuda:
            model.cuda()

        trial_dir = self.get_trial_dir()  # create directory for saving checkpoints
        os.mkdir(trial_dir)
        self.create_summary_writer(trial_dir)
        
        # train_batcher.tagmap.save(os.path.join(trial_dir, "tagmap.json"))
        # pickle.dump(train_batcher.tagmap, open(os.path.join(trial_dir, "tag_types.pkl"), "wb"))

        def save_ckpt_fn():
            checkpoint_path = os.path.join(trial_dir, "checkpoint")
            torch.save(model, checkpoint_path)  # pass the path so torch opens and closes the file itself

        print("Begin training")
        train_scores, test_scores = self.train(
            model=model, train_batches=train_batcher, test_batches=test_batcher,
            epochs=self.epochs, learning_rate=lr,
            scorer=lambda pred, true: (pred == true).sum() / len(pred),  # need to verify scoring function
            learning_rate_decay=lr_decay, finetune=self.finetune, save_ckpt_fn=save_ckpt_fn,
            no_localization=self.no_localization
        )

        metadata = {
            "train_scores": train_scores,
            "test_scores": test_scores,
            "learning_rate": lr,
            "learning_rate_decay": lr_decay,
            "epochs": self.epochs,
            "suffix_prefix_buckets": suffix_prefix_buckets,
            "seq_len": self.seq_len,
            "batch_size": self.batch_size,
            "no_localization": self.no_localization,
        }

        # print("Maximum accuracy:", max(test_scores["Accuracy"]))

        metadata.update(model_params)

        with open(os.path.join(trial_dir, "params.json"), "w") as metadata_sink:
            metadata_sink.write(json.dumps(metadata, indent=4))
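
The training loop above reports two accuracy figures per epoch: the unweighted mean of per-batch accuracies (logged as `Average Accuracy`) and the batch-size-weighted value from `epoch_wide_acc`. The two differ whenever batches have different sizes, for example when the last batch of an epoch is partial. The sketch below illustrates the difference with made-up numbers; it assumes that `batch_size` counts the items over which each per-batch accuracy was computed, and the helper `mean_of_batches` is a hypothetical name introduced only for this illustration.

```python
def epoch_wide_acc(scores):
    """Batch-size-weighted accuracy, same formula as in the training loop."""
    total = sum(scores["batch_size"])
    return sum(a * bs for a, bs in zip(scores["Accuracy"], scores["batch_size"])) / total


def mean_of_batches(scores):
    """Unweighted mean of per-batch accuracies (hypothetical helper)."""
    return sum(scores["Accuracy"]) / len(scores["Accuracy"])


# Made-up epoch: two full batches of 8 and a final partial batch of 2.
scores = {"Accuracy": [0.75, 0.5, 1.0], "batch_size": [8, 8, 2]}

weighted = epoch_wide_acc(scores)      # (6 + 4 + 2) / 18
unweighted = mean_of_batches(scores)   # (0.75 + 0.5 + 1.0) / 3
print(f"weighted: {weighted:.4f}, unweighted: {unweighted:.4f}")
# prints "weighted: 0.6667, unweighted: 0.7500"
```

Because the small final batch counts as much as a full one in the unweighted mean, the weighted value is the safer one to compare across runs with different batch sizes.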

# Execution

All training options are specified [here](https://github.com/VitalyRomanov/method-embedding/blob/e995477db13a13875cca54c37d4d29f63b0c8e93/SourceCodeTools/nlp/entity/type_prediction.py#L256).
The options used in this notebook are set in `args` below.

In [8]:
dataset_path = "variable_misuse_graph_2_percent_balanced/with_ast"

args = Namespace()
args.__dict__.update({
    "learning_rate": 1e-6,           #
    "max_seq_len": 512,              # default for BERT
    "random_seed": 42,               #
    "epochs": 30,                    #
    "gpu": 0,                        # set this to the GPU id to use a GPU
    "batch_size": 8,                 # higher value increases memory consumption
    "finetune": True,                # set this flag to enable finetuning
    "no_localization": False,        # whether to solve variable misuse with, or without localization
    
    # do not change items below
    "no_graph": True,                # used for another model
    "model_output": dataset_path,    # where to store checkpoints
    "graph_emb_path": None,          # used for another model
    "word_emb_path": None,           # used for another model
    "trials": 1,                     # setting > 1 repeats training, used to accumulate statistics
})

In [9]:
train_data = DataIterator(dataset_path, "train")
test_data = DataIterator(dataset_path, "val")

In [10]:
# test_data[0]  # ignore `replacements`

In [11]:
# test_data[100]

In [12]:
trainer = VariableMisuseDetector(
    train_data, test_data, params={"learning_rate": args.learning_rate, "learning_rate_decay": 0.99, "suffix_prefix_buckets": 1},
    graph_emb_path=args.graph_emb_path, word_emb_path=args.word_emb_path,
    output_dir=args.model_output, epochs=args.epochs, batch_size=args.batch_size, gpu_id=args.gpu,
    finetune=args.finetune, trials=args.trials, seq_len=args.max_seq_len, no_localization=args.no_localization,
    no_graph=args.no_graph
)

In [None]:
trainer.train_model(cache_dir=Path(dataset_path).joinpath("__cache__"))
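
The configuration above combines `learning_rate=1e-6` with `learning_rate_decay=0.99` over 30 epochs. The scheduler construction is not shown in this section, so the sketch below rests on an assumption: the learning rate is multiplied by the decay factor once per epoch (exponential decay). Under that assumption, it shows how far the learning rate drops over a run. `decayed_learning_rates` is a hypothetical helper, not part of `SourceCodeTools`.

```python
def decayed_learning_rates(initial_lr, decay, epochs):
    """Learning rate used in each epoch, ASSUMING per-epoch exponential decay."""
    return [initial_lr * decay ** epoch for epoch in range(epochs)]


# Values taken from the cells above: learning_rate=1e-6, decay=0.99, epochs=30.
schedule = decayed_learning_rates(1e-6, 0.99, 30)
print(f"first epoch: {schedule[0]:.3e}, last epoch: {schedule[-1]:.3e}")
```

If the actual scheduler differs, only the body of the helper needs to change; the point is to check the final learning rate before committing to a long run.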

## Models to test

- CodeBERT ([Huggingface](https://huggingface.co/microsoft/codebert-base))
- CodeGPT-2 ([Huggingface](https://huggingface.co/microsoft/CodeGPT-small-py))
- GraphCodeBERT ([Huggingface](https://huggingface.co/microsoft/graphcodebert-base), [GitHub](https://github.com/microsoft/CodeBERT/tree/master/GraphCodeBERT/refinement))
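
One possible way to switch between these candidates is to load them through the generic `AutoModel` and `AutoTokenizer` classes from `transformers`, using the checkpoint ids from the links above. This is only a sketch: the `CHECKPOINTS` mapping and the `load_encoder` helper are hypothetical names, and `train_model` in this notebook currently hard-codes `RobertaModel.from_pretrained("microsoft/codebert-base")`, so it would need to be adapted to use them.

```python
# Checkpoint ids are taken from the Hugging Face links in the list above.
CHECKPOINTS = {
    "codebert": "microsoft/codebert-base",
    "codegpt": "microsoft/CodeGPT-small-py",
    "graphcodebert": "microsoft/graphcodebert-base",
}


def load_encoder(name):
    """Load tokenizer and model for one candidate (downloads weights on first use)."""
    # Imported lazily so the mapping can be inspected without `transformers` installed.
    from transformers import AutoModel, AutoTokenizer

    checkpoint = CHECKPOINTS[name]
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    return tokenizer, model
```

For example, `tokenizer, model = load_encoder("graphcodebert")`. Note that CodeGPT uses a GPT-2 (decoder-only) architecture, unlike the two RoBERTa-based encoders, so its tokenization and hidden states may need separate handling in the batcher and the classifier head.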



## Visualizing results

For an example, see the script `SourceCodeTools/nlp/entity/utils/visualize_dataset.py`.