In [44]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [224]:
import os
from pathlib import Path
import torch
from torchtext import data
from neural_editor.seq2seq.train import load_data
from neural_editor.seq2seq.config import load_config
from neural_editor.seq2seq.train_utils import greedy_decode, remove_eos, lookup_words, calculate_accuracy
from neural_editor.seq2seq.datasets.dataset_utils import take_part_from_dataset
from neural_editor.seq2seq.train_utils import rebatch
import numpy as np

In [46]:
# Directory with the experiment artifacts this notebook reads (config.pkl and the
# model checkpoint). Set the NEURAL_EDITOR_RESULTS_ROOT environment variable to run
# the notebook against a copy stored elsewhere; the original location is the default.
RESULTS_ROOT = os.environ.get(
    'NEURAL_EDITOR_RESULTS_ROOT',
    '/home/mikhail/Documents/Development/embeddings-for-code-diffs-data/experiment_20',
)

In [47]:
# Load the experiment configuration stored as config.pkl under RESULTS_ROOT.
CONFIG = load_config(False, Path(RESULTS_ROOT) / 'config.pkl')

In [48]:
# Load the train / validation / test datasets and the field object whose vocabulary
# is used by the cells below. NOTE(review): verbose=True is presumably what prints
# the dataset summary that follows — confirm in load_data.
train_dataset, val_dataset, test_dataset, diffs_field = load_data(verbose=True, config=CONFIG)

Data set sizes (number of sentence pairs):
train 8793
valid 1100
test 1098 

First training example:
src: public void METHOD_1 ( ) { TYPE_1 VAR_1 = new TYPE_1 ( ) ; byte [ ] VAR_2 = TYPE_2 . METHOD_2 ( ) ; VAR_1 . METHOD_3 ( VAR_2 , 0 , VAR_2 . length ) ; org.junit.Assert.assertEquals ( STRING_1 , VAR_1 . METHOD_4 ( ) . get ( STRING_2 ) ) ; }
trg: public void METHOD_1 ( ) { TYPE_1 VAR_1 = new TYPE_1 ( ) ; byte [ ] VAR_2 = TYPE_2 . METHOD_2 ( ) ; VAR_1 . METHOD_3 ( VAR_2 , 0 , VAR_2 . length ) ; assertEquals ( STRING_1 , VAR_1 . METHOD_4 ( ) . get ( STRING_2 ) ) ; }
diff_alignment: замена
diff_prev: org.junit.Assert.assertEquals
diff_updated: assertEquals 

Most common words:
         )     194026
         (     193976
         .      92332
         ;      82530
   паддинг      54790
         ,      54180
         {      42036
         }      41868
     VAR_1      41750
  удаление      41184 

First 10 words:
00 <unk>
01 <pad>
02 <s>
03 </s>
04 )
05 (
06 .
07 ;
08 паддинг
09 , 

Special

In [49]:
# Vocabulary ids of the padding, start-of-sequence and end-of-sequence tokens;
# the token strings themselves come from CONFIG.
PAD_INDEX, SOS_INDEX, EOS_INDEX = (
    diffs_field.vocab.stoi[CONFIG[token_key]]
    for token_key in ('PAD_TOKEN', 'SOS_TOKEN', 'EOS_TOKEN')
)

In [50]:
# Run everything on the CPU: overwrite the DEVICE entry directly in the config
# object's private `_CONFIG` mapping.
# NOTE(review): this bypasses whatever public interface CONFIG offers — confirm
# there is no supported way to override a single entry.
CONFIG._CONFIG['DEVICE'] = torch.device('cpu')

In [51]:
# Load the `model_best_on_validation.pt` checkpoint from RESULTS_ROOT and map its
# tensors onto the device stored in CONFIG. The device is read from CONFIG because
# no standalone DEVICE name is defined in the visible notebook cells.
MODEL = torch.load(os.path.join(RESULTS_ROOT, 'model_best_on_validation.pt'),
                   map_location=CONFIG['DEVICE'])

In [52]:
# Display the loaded model's module structure.
MODEL

EncoderDecoder(
  (encoder): Encoder(
    (rnn): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  )
  (decoder): Decoder(
    (generator): Generator(
      (projection): Linear(in_features=256, out_features=758, bias=False)
    )
    (embedding): Embedding(758, 128)
    (attention): BahdanauAttention(
      (key_layer): Linear(in_features=256, out_features=256, bias=False)
      (query_layer): Linear(in_features=256, out_features=256, bias=False)
      (energy_layer): Linear(in_features=256, out_features=1, bias=False)
    )
    (rnn): LSTM(416, 256, num_layers=2, batch_first=True, dropout=0.2)
    (bridge): Linear(in_features=288, out_features=256, bias=True)
    (dropout_layer): Dropout(p=0.2, inplace=False)
    (pre_output_layer): Linear(in_features=640, out_features=256, bias=False)
  )
  (edit_encoder): EditEncoder(
    (rnn): LSTM(384, 16, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  )
  (embed): Embedding(758, 128)
  (genera

In [188]:
# Subsets of the test set of decreasing size (300, 10, 3 and 2 examples), used to
# narrow down where batched and single-example decoding start to disagree.
TEST_DATASET_300, TEST_DATASET_10, TEST_DATASET_3, TEST_DATASET_2 = (
    take_part_from_dataset(test_dataset, n=subset_size)
    for subset_size in (300, 10, 3, 2)
)

In [208]:
def get_accuracy_with_batch(dataset, batch_size):
    """
    Evaluate MODEL on `dataset` while feeding it `batch_size` examples at a time.

    :param dataset: examples exposing `src` and `trg` sequences
    :param batch_size: number of examples per batch
    :return: the value calculate_accuracy reports for MODEL on these batches
    """
    def by_lengths(example):
        # Sort key used inside each batch: source length first, then target length.
        return len(example.src), len(example.trg)

    iterator_options = dict(
        batch_size=batch_size,
        train=False,
        sort=False,
        sort_within_batch=True,
        sort_key=by_lengths,
        repeat=False,
        device=CONFIG['DEVICE'],
    )
    raw_batches = data.Iterator(dataset, **iterator_options)

    # Batches are converted lazily, one at a time, as calculate_accuracy consumes them.
    rebatched = map(lambda raw_batch: rebatch(PAD_INDEX, raw_batch, CONFIG), raw_batches)

    max_decode_len = CONFIG['TOKENS_CODE_CHUNK_MAX_LEN']
    return calculate_accuracy(rebatched, MODEL, max_decode_len, diffs_field.vocab, CONFIG)

def show_calculate_accruacy_diff(dataset):
    """
    Print the accuracy obtained when `dataset` is decoded as one batch next to the
    accuracy obtained when it is decoded one example at a time.

    :param dataset: sized collection of examples accepted by get_accuracy_with_batch
    :return: None; the two accuracies are only printed
    """
    whole_batch = get_accuracy_with_batch(dataset, batch_size=len(dataset))
    one_by_one = get_accuracy_with_batch(dataset, batch_size=1)
    print(f'Batch accuracy: {whole_batch}', f'Single accuracy: {one_by_one}', sep='\n')

In [210]:
# Compare accuracy on the 300-example subset when it is decoded as one batch
# versus one example at a time.
show_calculate_accruacy_diff(TEST_DATASET_300)

Batch accuracy: 0.07666666666666666
Single accuracy: 0.08333333333333333


In [179]:
def get_batch_decoded(dataset, batch_size):
    """
    Greedily decode every example of `dataset`, feeding MODEL `batch_size` examples at a time.

    :param dataset: examples exposing `src` and `trg` sequences
    :param batch_size: number of examples per batch
    :return: list of decoded sequences in the order the iterator yields them; each batch
             is sorted by the sort key, so this can differ from the dataset order
    """
    iterator = data.Iterator(dataset, batch_size=batch_size, train=False,
                             sort_within_batch=True,
                             sort=False,
                             sort_key=lambda x: (len(x.src), len(x.trg)),
                             repeat=False,
                             # Read the device from CONFIG, as get_accuracy_with_batch does:
                             # no standalone DEVICE name is defined in the visible cells.
                             device=CONFIG['DEVICE'])
    decoded = []
    for batch in iterator:
        batch = rebatch(PAD_INDEX, batch, CONFIG)
        decoded.extend(greedy_decode(MODEL, batch, CONFIG['TOKENS_CODE_CHUNK_MAX_LEN'],
                                     SOS_INDEX, EOS_INDEX))
    return decoded

In [182]:
def get_diff_from_decoded(dataset):
    """
    Find the examples whose greedy decoding differs between decoding the whole dataset
    as one batch and decoding it one example at a time.

    :param dataset: sized, indexable collection of examples exposing `src` and `trg` sequences
    :return: list of (dataset_index, batch_decoded, single_decoded) tuples, one per example
             whose two decodings differ; empty when they all agree or the dataset is empty
    """
    if len(dataset) == 0:
        # Nothing to compare, and a batch size of 0 must not reach the iterator.
        return []

    batch_decoded = get_batch_decoded(dataset, len(dataset))
    single_decoded = get_batch_decoded(dataset, 1)

    # The single full batch comes back ordered by the sort key, largest first, while the
    # one-example batches keep the dataset order. Rebuild that permutation here.
    # Assumes the iterator sorts the batch with a stable `sort(key=sort_key, reverse=True)`
    # (legacy torchtext behaviour for sort_within_batch=True, sort=False): examples with
    # equal keys then keep their dataset order. Reversing an ascending sort would flip such
    # ties and compare tied examples against each other's output.
    sort_keys = [(len(x.src), len(x.trg)) for x in dataset]
    batch_order = sorted(range(len(sort_keys)), key=sort_keys.__getitem__, reverse=True)
    position_in_batch = [0] * len(batch_order)
    for position, dataset_index in enumerate(batch_order):
        position_in_batch[dataset_index] = position
    # Plain list indexing: the decoded sequences have different lengths, so they are not
    # packed into a single NumPy array.
    batch_decoded_correct_order = [batch_decoded[position] for position in position_in_batch]

    incorrect = []
    for i, (batch_decoded_value, single_decoded_value) in enumerate(
            zip(batch_decoded_correct_order, single_decoded)):
        # array_equal is False for different lengths as well as for differing tokens.
        if not np.array_equal(batch_decoded_value, single_decoded_value):
            incorrect.append((i, batch_decoded_value, single_decoded_value))
    return incorrect

In [197]:
def print_diff_from_decoded(dataset, until=None):
    """
    Print the examples for which batched and single-example decoding disagree,
    together with the target sequence and which of the two (if any) matches it.

    :param dataset: indexable collection of examples exposing `trg` token lists
    :param until: print only the first `until` mismatches; None prints all of them
    :return: None; prints 'NO DIFF' when there is no mismatch at all
    """
    mismatches = get_diff_from_decoded(dataset)
    if not mismatches:
        print('NO DIFF')
        return

    def as_text(token_ids):
        # Map ids back to tokens and join them into one readable line.
        return " ".join(lookup_words(token_ids, diffs_field.vocab))

    for example_index, batch_ids, single_ids in mismatches[:until]:
        print(f'First i = {example_index}')
        batch_version = as_text(batch_ids)
        single_version = as_text(single_ids)
        target_version = " ".join(dataset[example_index].trg)
        print(f'Batch  : {batch_version}')
        print(f'Single : {single_version}')
        print(f'Target : {target_version}')
        if batch_version == target_version:
            verdict = 'Batch correct'
        elif single_version == target_version:
            verdict = 'Single correct'
        else:
            verdict = 'None is correct'
        print(verdict)

In [198]:
# Look for mismatches between batched and single-example decoding on the 2-example subset.
print_diff_from_decoded(TEST_DATASET_2)

NO DIFF


In [204]:
# Look for mismatches between batched and single-example decoding on the 3-example subset.
print_diff_from_decoded(TEST_DATASET_3)

First i = 2
Batch  : public java.lang.String METHOD_1 ( TYPE_1 locale ) { if ( VAR_1 . METHOD_2 ( ) ) { return STRING_1 ; } java.lang.String result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) , locale . METHOD_4 ( ) ) ; if ( result == null ) { result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) , TYPE_1 . METHOD_5 ( ) . METHOD_4 ( ) ) ; } return result ; }
Single : public java.lang.String METHOD_1 ( TYPE_1 locale ) { if ( VAR_1 . METHOD_2 ( ) ) { return STRING_1 ; } java.lang.String result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) , locale . METHOD_4 ( ) ) ; if ( result == null ) { result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) . METHOD_4 ( ) ) ; } return result ; }
Target : public java.lang.String METHOD_1 ( TYPE_1 locale ) { if ( VAR_1 . METHOD_2 ( ) ) { return STRING_1 ; } java.lang.String result = TYPE_2 . METHOD_1 ( this , locale ) ; if ( result == null ) { result = TYPE_2 . METHOD_1 ( this , TYPE_1 . METHOD_5 ( ) ) ; } return result ; }
None is correct


In [216]:
# Two-example dataset (test examples 1 and 2) that reuses the test set's fields;
# used as a small reproduction case for the batch-vs-single decoding mismatch.
DATASET_TO_REPRODUCE_BUG = data.Dataset(test_dataset[1:3], test_dataset.fields)

In [217]:
# Run the batch-vs-single decoding comparison on the two-example reproduction dataset.
print_diff_from_decoded(DATASET_TO_REPRODUCE_BUG)

First i = 1
Batch  : public java.lang.String METHOD_1 ( TYPE_1 locale ) { if ( VAR_1 . METHOD_2 ( ) ) { return STRING_1 ; } java.lang.String result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) , locale . METHOD_4 ( ) ) ; if ( result == null ) { result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) , TYPE_1 . METHOD_5 ( ) . METHOD_4 ( ) ) ; } return result ; }
Single : public java.lang.String METHOD_1 ( TYPE_1 locale ) { if ( VAR_1 . METHOD_2 ( ) ) { return STRING_1 ; } java.lang.String result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) , locale . METHOD_4 ( ) ) ; if ( result == null ) { result = TYPE_2 . METHOD_3 ( METHOD_4 ( ) . METHOD_4 ( ) ) ; } return result ; }
Target : public java.lang.String METHOD_1 ( TYPE_1 locale ) { if ( VAR_1 . METHOD_2 ( ) ) { return STRING_1 ; } java.lang.String result = TYPE_2 . METHOD_1 ( this , locale ) ; if ( result == null ) { result = TYPE_2 . METHOD_1 ( this , TYPE_1 . METHOD_5 ( ) ) ; } return result ; }
None is correct


In [303]:
"BUGGY VERSION"
def greedy_decode(model, batch, max_len: int, sos_index: int, eos_index: int):
    """
    Decode all sequences of the batch greedily: at every step the highest-scoring
    token is emitted and fed back as the next decoder input, for exactly `max_len` steps.
    :return: [DecodedSeqLenCutWithEos]
    """
    # TODO: create beam search
    src, src_mask = batch.src, batch.src_mask  # [B, SrcSeqLen], [B, 1, SrcSeqLen]
    n_seqs = batch.nseqs

    # Pure inference: nothing below needs gradients.
    with torch.no_grad():
        edit_final, encoder_output, encoder_final = model.encode(batch)

        # Every sequence starts from the SOS token; same dtype/device as the source ids.
        prev_y = torch.ones(n_seqs, 1).fill_(sos_index).type_as(src)  # [B, 1]
        trg_mask = torch.ones_like(prev_y)  # [B, 1]

        decoded = torch.zeros((n_seqs, max_len))
        states = None
        for step in range(max_len):
            # pre_output: [B, TrgSeqLen, DecoderH]
            _, states, pre_output = model.decode(edit_final, encoder_output, encoder_final,
                                                 src_mask, prev_y, trg_mask, states)

            # Scores come from the pre-output layer of the last position, which combines
            # the decoder state, the previous embedding and the attention context.
            scores = model.generator(pre_output[:, -1])  # [B, V]

            next_words = torch.max(scores, dim=1)[1]
            decoded[:, step] = next_words
            prev_y[:, 0] = next_words

    token_ids = decoded.cpu().long().numpy()
    return remove_eos(token_ids, eos_index)

In [304]:
# Repeat the batch-vs-single decoding comparison on the two-example reproduction dataset.
print_diff_from_decoded(DATASET_TO_REPRODUCE_BUG)

NO DIFF


In [287]:
# The bug appears to be fixed: in EditEncoder I forgot to permute the results back to the original batch order after sorting the batch for pack_padded_sequence.

In [288]:
# Repeat the batch-vs-single accuracy comparison on the 300-example subset.
show_calculate_accruacy_diff(TEST_DATASET_300)

Batch accuracy: 0.08333333333333333
Single accuracy: 0.08333333333333333


In [67]:
# Decode the 2-example subset as a single batch of two.
BATCH_DECODED_2 = get_batch_decoded(TEST_DATASET_2, batch_size=2)

In [68]:
# Decode the 2-example subset one example per batch.
SINGLE_DECODED_2 = get_batch_decoded(TEST_DATASET_2, batch_size=1)

  """Entry point for launching an IPython kernel.


False

TypeError: unhashable type: 'numpy.ndarray'