# Decoder Based Transformer (GPT-2) Sampling
> Empirical netbook to sample bart for method2test benchmark.

In [1]:
from pathlib import Path
import csv
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import functools

pd.options.display.float_format = '{:.2f}'.format

In [2]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, load_from_disk
import torch



In [3]:
import warnings
from matplotlib import colors
import os

In [4]:
import sys
sys.path.insert(1, '/workspaces/code-rationales/sequential-rationales/huggingface')
from rationalization import rationalize_lm

In [5]:
def param_default():
    return {
        'dataset' : 'codeparrot/codeparrot-clean-valid',
        'dataset_disk_path': '/workspaces/code-rationales/semeru-datasets/codeparrot-clean-valid',
        'model_name': '/workspaces/code-rationales/data/codeparrot-small/checkpoints/checkpoint-29000', 
        'cache_dir': '/workspaces/code-rationales/datax/df_cache_dir', 
        'output_results': '/workspaces/code-rationales/data/sampling'
    }

In [6]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

## Data Loading and Testing

In [7]:
# Save dataset in nfs
#raw_datasets = load_dataset(param_default()['dataset'], cache_dir=param_default()['cache_dir'])
#raw_datasets.save_to_disk(param_default()['dataset_disk_path'])

In [8]:
# Reload with the `json` script
test_dataset = load_from_disk(param_default()['dataset_disk_path'])
test_dataset

DatasetDict({
    train: Dataset({
        features: ['repo_name', 'path', 'copies', 'size', 'content', 'license', 'hash', 'line_mean', 'line_max', 'alpha_frac', 'autogenerated'],
        num_rows: 61373
    })
})

In [9]:
test_dataset = test_dataset['train']
test_dataset[0]

{'repo_name': 'pansapiens/mytardis',
 'path': 'tardis/apps/mx_views/views.py',
 'copies': '3',
 'size': '2892',
 'content': 'from django.conf import settings\nfrom django.core.paginator import Paginator, InvalidPage, EmptyPage\nfrom django.http import HttpResponse\n\nfrom tardis.tardis_portal.auth import decorators as authz\nfrom tardis.tardis_portal.models import Dataset\nfrom tardis.tardis_portal.shortcuts import get_experiment_referer\nfrom tardis.tardis_portal.shortcuts import render_response_index\n\n\n@authz.dataset_access_required\ndef view_full_dataset(request, dataset_id):\n    """Displays a MX Dataset and associated information.\n\n    Shows a full (hundreds of images) dataset its metadata and a list\n    of associated files with the option to show metadata of each file\n    and ways to download those files.  With write permission this page\n    also allows uploading and metadata editing.\n\n    Settings for this view:\n    INSTALLED_APPS += ("tardis.apps.mx_views",)\n    DAT

In [10]:
test_dataset = test_dataset.shuffle(seed=42).select(range(20))

Loading cached shuffled indices for dataset at /workspaces/code-rationales/semeru-datasets/codeparrot-clean-valid/train/cache-4c39d1bfa09ae468.arrow


In [11]:
df_sampled_code = test_dataset.to_pandas()[['size','content']]
df_sampled_code

Unnamed: 0,size,content
0,1337,"""""""\n[2014-11-26] Challenge #190 [Intermediate..."
1,1277,import ConfigParser\nimport os\nimport sys\n# ...
2,2397,#!/usr/bin/env python\n'''\nCreated on Mars 20...
3,9094,from __future__ import division\nfrom itertool...
4,11421,#!/usr/bin/env python2\nimport re\nimport os\n...
5,1580,from mpl_toolkits.mplot3d import axes3d\nimpor...
6,25924,import asyncio\nimport fcntl\nimport logging\n...
7,8110,"""""""\nStochastic Gradient Descent.\n\n\nTODO: w..."
8,2688,#!/usr/bin/python2.4 -tt\n# Copyright 2010 Goo...
9,3922,import logging\r\n\r\nfrom pele.potentials imp...


## Model Loading and Testing

In [12]:
model = AutoModelForCausalLM.from_pretrained(
            param_default()['model_name'],
            cache_dir=param_default()['cache_dir'])

In [13]:
model.to(device)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(32768, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )


In [14]:
model.device

device(type='cuda', index=0)

## Tokenizer Loading and Testing

In [15]:
df_sampled_code.head(2)

Unnamed: 0,size,content
0,1337,"""""""\n[2014-11-26] Challenge #190 [Intermediate..."
1,1277,import ConfigParser\nimport os\nimport sys\n# ...


In [16]:
tokenizer = AutoTokenizer.from_pretrained(param_default()['model_name'])

## Samples Encoding and Filtering

In [17]:
df_sampled_code['input_ids'] = df_sampled_code['content'].map(lambda code: tokenizer(code)['input_ids'])
df_sampled_code['size'] =  df_sampled_code['input_ids'].map(lambda ids: len(ids))
df_sampled_code['input_tokens'] = df_sampled_code['input_ids'].map(lambda ids: [tokenizer.decode(token_id, skip_special_tokens=False, clean_up_tokenization_spaces=False) for token_id in ids])

Token indices sequence length is longer than the specified maximum sequence length for this model (2776 > 1024). Running this sequence through the model will result in indexing errors


In [18]:
## JUST FOR TESTING REMOVE LATER
df_sampled_code['content'] = df_sampled_code['content'].map(lambda prompt: prompt[:100])

In [19]:
df_sampled_code

Unnamed: 0,size,content,input_ids,input_tokens
0,378,"""""""\n[2014-11-26] Challenge #190 [Intermediate...","[624, 199, 59, 7280, 13, 845, 13, 1479, 61, 27...","["""""", \n, [, 2014, -, 11, -, 26, ], Ch, allen..."
1,376,import ConfigParser\nimport os\nimport sys\n# ...,"[646, 14196, 199, 646, 747, 199, 646, 984, 199...","[import, ConfigParser, \n, import, os, \n, i..."
2,777,#!/usr/bin/env python\n'''\nCreated on Mars 20...,"[3381, 2647, 15, 1393, 15, 1813, 2366, 199, 23...","[#!/, usr, /, bin, /, env, python, \n, ''', \..."
3,2776,from __future__ import division\nfrom itertool...,"[504, 636, 2443, 363, 492, 4629, 199, 504, 797...","[from, __, future, __, import, division, \n..."
4,3019,#!/usr/bin/env python2\nimport re\nimport os\n...,"[3381, 2647, 15, 1393, 15, 1813, 2366, 18, 199...","[#!/, usr, /, bin, /, env, python, 2, \n, imp..."
5,518,from mpl_toolkits.mplot3d import axes3d\nimpor...,"[504, 22513, 63, 3557, 75, 1405, 14, 311, 1653...","[from, mpl, _, tool, k, its, ., mp, lot, 3, d..."
6,6572,import asyncio\nimport fcntl\nimport logging\n...,"[646, 16195, 199, 646, 12871, 199, 646, 2050, ...","[import, asyncio, \n, import, fcntl, \n, imp..."
7,2119,"""""""\nStochastic Gradient Descent.\n\n\nTODO: w...","[624, 199, 4759, 21913, 19603, 8055, 2946, 14,...","["""""", \n, Sto, chastic, Gradient, Des, cent,..."
8,906,#!/usr/bin/python2.4 -tt\n# Copyright 2010 Goo...,"[3381, 2647, 15, 1393, 15, 1548, 18, 14, 20, 4...","[#!/, usr, /, bin, /, python, 2, ., 4, -, tt,..."
9,1117,import logging\r\n\r\nfrom pele.potentials imp...,"[646, 2050, 2999, 199, 504, 4837, 274, 14, 246...","[import, logging, \r\n\r, \n, from, pe, le, ..."


## Model Sampling Generation

In [20]:
SAMPLES = 30 #<---- Hardocoded
MAX_GEN_TOK = 512

In [21]:
inputs = tokenizer(["def hello_world():"], return_tensors="pt")
inputs.to(device)
outputs = model.generate(**inputs, do_sample=True, max_length=MAX_GEN_TOK, top_k=0, num_return_sequences=2, pad_token_id=tokenizer.eos_token_id)
tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):])


'\n    trabrid_client =2297528464b9620ba2150210114af5f13377\n        "live_url":"https://api.live.tv","headers":"application/json"," expires_in_milli=1077202,"states":"laUNCH"}\n       \n    new_episode = trabrid_client.get(visuals["episode.videoid.com"], timeout=3)\n\n    assert \'CREATED\' not in new_episode["content"]\n    assert \'Animation started\' not in new_episode["content"]\n    assert video was created last, existing ones should be deleted\n    assert len(new_episode["content"]["Walkers"]) == 0\n\n    # busy on to its busy mode\n    new_episode_id = trabrid_client.work_episode_id_2()\n    \n    assert \'Movie started\' not in new_episode_id["content"]\n    assert \'Videos found\' not in new_episode_id["content"]\n    assert "Watching" not in new_episode_id["content"]\n    \n    # sleep to ensure success\n    while new_episode_id["finished"] == True:\n        time.sleep(2)\n\n    new_episode_id = new_episode_id.pop()\n    \n\ndef current_false(playing):\n    trabrid_client =2

In [22]:
def df_sampled_generation(
        df_sampled_code, 
        model,
        number_samples = 1,
        max_gen_tok = 100
    ):
    dict_generated_code = {i: [] for i in range(number_samples)}
    for sample in df_sampled_code['content']:
        input = tokenizer([sample], return_tensors="pt")
        input.to(device)
        outputs = model.generate(**input, do_sample=True, max_length=max_gen_tok, top_k=0, num_return_sequences=number_samples, pad_token_id=tokenizer.eos_token_id)
        for index, output in enumerate(outputs):
            dict_generated_code[index].append(output[len(input['input_ids'][0]):].tolist())
    df_temp = pd.DataFrame().from_dict(data=dict_generated_code) # DataFrame from Generation
    df_temp = pd.concat([df_sampled_code.reset_index(), df_temp ], axis=1) #Index before concating
    return df_temp
        
    

In [23]:
#TODO limit the number of tokens generated
#WARNING TIME CONSUMING
df_generated_input = df_sampled_generation(
    df_sampled_code=df_sampled_code, 
    model=model, 
    number_samples=SAMPLES, 
    max_gen_tok=MAX_GEN_TOK)

In [24]:
df_generated_input.head(5)

Unnamed: 0,index,size,content,input_ids,input_tokens,0,1,2,3,4,...,20,21,22,23,24,25,26,27,28,29
0,0,378,"""""""\n[2014-11-26] Challenge #190 [Intermediate...","[624, 199, 59, 7280, 13, 845, 13, 1479, 61, 27...","["""""", \n, [, 2014, -, 11, -, 26, ], Ch, allen...","[63, 16389, 22128, 826, 4343, 647, 199, 199, 8...","[80, 18328, 15, 7270, 15, 11089, 1342, 2528, 1...","[13, 7618, 15, 10913, 5466, 4694, 13, 10233, 1...","[6815, 2626, 15, 5895, 15, 1602, 11135, 397, 1...","[15, 1917, 6223, 13, 460, 13, 4631, 13, 290, 1...",...,"[63, 20820, 63, 3148, 63, 17183, 63, 21713, 14...","[3349, 15, 68, 5257, 452, 15, 614, 3172, 3218,...","[667, 81, 15, 13455, 13, 3148, 8083, 199, 1014...","[10888, 15, 13911, 13, 12233, 3, 772, 13155, 1...","[16207, 63, 3148, 63, 17183, 63, 3148, 63, 171...","[63, 69, 6737, 15, 18084, 15, 2167, 5637, 199,...","[6815, 15, 24, 10197, 4599, 6883, 29, 14176, 1...","[19, 6419, 1085, 15, 13911, 63, 16389, 63, 87,...","[13, 20441, 15, 5095, 63, 1618, 63, 4653, 63, ...","[15335, 15, 1917, 63, 475, 63, 17381, 63, 1397..."
1,1,376,import ConfigParser\nimport os\nimport sys\n# ...,"[646, 14196, 199, 646, 747, 199, 646, 984, 199...","[import, ConfigParser, \n, import, os, \n, i...","[63, 1880, 83, 275, 469, 272, 327, 12939, 3851...","[275, 1052, 4537, 199, 1375, 63, 1130, 275, 10...","[63, 578, 63, 1419, 275, 359, 272, 1689, 1751,...","[275, 295, 14, 2014, 8, 82, 2, 4042, 7115, 338...","[63, 80, 2664, 694, 275, 295, 838, 371, 14, 15...",...,"[7073, 14, 525, 63, 1422, 6739, 515, 401, 2053...","[275, 14196, 14, 14196, 14, 7963, 15218, 342, ...","[275, 295, 14, 2014, 9097, 5169, 6742, 92, 124...","[63, 493, 275, 298, 3647, 83, 3149, 83, 2, 199...","[9653, 3821, 259, 275, 295, 14, 2014, 480, 143...","[63, 694, 275, 3286, 20481, 15, 6251, 15, 1431...","[63, 1135, 275, 295, 14, 2014, 8, 82, 31996, 3...","[63, 632, 275, 488, 258, 199, 29, 4806, 14, 77...","[275, 747, 14, 515, 14, 904, 8, 736, 14, 515, ...","[63, 1725, 275, 788, 888, 4009, 283, 264, 14, ..."
2,2,777,#!/usr/bin/env python\n'''\nCreated on Mars 20...,"[3381, 2647, 15, 1393, 15, 1813, 2366, 199, 23...","[#!/, usr, /, bin, /, env, python, \n, ''', \...","[2137, 63, 614, 63, 17, 275, 283, 3018, 7, 199...","[939, 275, 756, 199, 199, 4296, 63, 9490, 521,...","[1588, 63, 19737, 275, 283, 17, 7, 199, 711, 2...","[939, 275, 1052, 199, 199, 3, 8272, 26688, 307...","[311, 13656, 63, 3723, 275, 788, 564, 418, 199...",...,"[939, 275, 756, 199, 2308, 275, 22925, 63, 90,...","[6692, 534, 14382, 199, 1102, 26, 283, 5967, 2...","[275, 788, 14716, 297, 283, 569, 418, 199, 199...","[939, 275, 756, 199, 199, 318, 1678, 13652, 8,...","[939, 63, 1760, 604, 275, 298, 5159, 15, 3280,...","[939, 63, 278, 275, 283, 90, 6693, 7, 199, 414...","[7330, 63, 637, 275, 378, 14, 1994, 4219, 15, ...","[1164, 275, 298, 58, 4340, 1379, 2, 199, 199, ...","[939, 63, 19214, 63, 3236, 736, 461, 275, 2292...","[1588, 275, 283, 1662, 63, 1258, 63, 3236, 79,..."
3,3,2776,from __future__ import division\nfrom itertool...,"[504, 636, 2443, 363, 492, 4629, 199, 504, 797...","[from, __, future, __, import, division, \n...","[199, 504, 24706, 14, 1150, 83, 492, 377, 2327...","[199, 504, 24706, 14, 5819, 492, 1059, 63, 131...","[12, 12939, 80, 21782, 3007, 1944, 8014, 764, ...","[5932, 63, 15246, 12, 971, 272, 787, 63, 4790,...","[199, 199, 504, 1639, 14, 2190, 492, 2202, 199...",...,"[2821, 11729, 199, 504, 1115, 24629, 63, 3671,...","[12, 26611, 63, 505, 199, 504, 24706, 14, 1628...","[199, 199, 7721, 275, 2400, 308, 26, 7666, 8, ...","[12, 12939, 80, 21782, 199, 504, 24706, 14, 15...","[199, 504, 24706, 14, 5819, 492, 334, 15246, 1...","[2141, 421, 199, 318, 3560, 8, 16780, 63, 354,...","[12, 12939, 80, 21782, 199, 504, 24706, 14, 18...","[505, 199, 504, 24706, 14, 18446, 14, 1208, 49...","[857, 199, 504, 299, 439, 89, 1612, 14, 19762,...","[12, 12939, 80, 6261, 199, 199, 646, 3805, 520..."
4,4,3019,#!/usr/bin/env python2\nimport re\nimport os\n...,"[3381, 2647, 15, 1393, 15, 1813, 2366, 18, 199...","[#!/, usr, /, bin, /, env, python, 2, \n, imp...","[1122, 199, 504, 3031, 63, 2087, 492, 14735, 1...","[2087, 199, 646, 5316, 63, 5933, 199, 646, 588...","[199, 504, 8691, 492, 13491, 199, 199, 893, 26...","[3759, 199, 199, 2087, 275, 7534, 14, 10730, 8...","[199, 646, 14755, 199, 646, 2197, 199, 504, 91...",...,"[2087, 199, 646, 5377, 20, 199, 199, 504, 1680...","[2087, 199, 646, 26162, 14, 25032, 199, 646, 7...","[4883, 199, 504, 8691, 492, 13491, 199, 504, 8...","[4883, 199, 646, 2022, 4354, 773, 199, 646, 21...","[4883, 199, 646, 15416, 396, 199, 504, 15416, ...","[4883, 199, 646, 295, 2227, 697, 199, 646, 243...","[1122, 199, 646, 5145, 199, 199, 10484, 44, 13...","[2087, 199, 646, 12027, 1899, 199, 199, 14366,...","[4133, 199, 646, 12424, 1364, 199, 646, 24295,...","[199, 646, 5436, 63, 12457, 199, 646, 3805, 19..."


### Statistics and Checkpoint

In [25]:
np_len_method = [ (np.array([ len(gen_method) for gen_method in df_generated_input[j] ]).mean(),
                   np.array([ len(gen_method) for gen_method in df_generated_input[j] ]).std()  )
                    for j in range(30) ]

In [26]:
np_len_method

[(482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235),
 (482.25, 5.4394393093406235)]

In [27]:
#Checkpoint of Generation
def checkpoint_generation( df , name = 'output' ):
    df.to_csv(param_default()['output_results'] +  '/' + name + '.csv')
    pass

In [28]:
checkpoint_generation( df = df_generated_input, name='codeparrot-testing')

In [29]:
df_generated_input = pd.read_csv( param_default()['output_results'] + '/' +'codeparrot-testing.csv' , index_col=0)

In [30]:
df_generated_input.head()

Unnamed: 0,index,size,content,input_ids,input_tokens,0,1,2,3,4,...,20,21,22,23,24,25,26,27,28,29
0,0,378,"""""""\n[2014-11-26] Challenge #190 [Intermediate...","[624, 199, 59, 7280, 13, 845, 13, 1479, 61, 27...","['""""""', '\n', '[', '2014', '-', '11', '-', '26...","[63, 16389, 22128, 826, 4343, 647, 199, 199, 8...","[80, 18328, 15, 7270, 15, 11089, 1342, 2528, 1...","[13, 7618, 15, 10913, 5466, 4694, 13, 10233, 1...","[6815, 2626, 15, 5895, 15, 1602, 11135, 397, 1...","[15, 1917, 6223, 13, 460, 13, 4631, 13, 290, 1...",...,"[63, 20820, 63, 3148, 63, 17183, 63, 21713, 14...","[3349, 15, 68, 5257, 452, 15, 614, 3172, 3218,...","[667, 81, 15, 13455, 13, 3148, 8083, 199, 1014...","[10888, 15, 13911, 13, 12233, 3, 772, 13155, 1...","[16207, 63, 3148, 63, 17183, 63, 3148, 63, 171...","[63, 69, 6737, 15, 18084, 15, 2167, 5637, 199,...","[6815, 15, 24, 10197, 4599, 6883, 29, 14176, 1...","[19, 6419, 1085, 15, 13911, 63, 16389, 63, 87,...","[13, 20441, 15, 5095, 63, 1618, 63, 4653, 63, ...","[15335, 15, 1917, 63, 475, 63, 17381, 63, 1397..."
1,1,376,import ConfigParser\nimport os\nimport sys\n# ...,"[646, 14196, 199, 646, 747, 199, 646, 984, 199...","['import', ' ConfigParser', '\n', 'import', ' ...","[63, 1880, 83, 275, 469, 272, 327, 12939, 3851...","[275, 1052, 4537, 199, 1375, 63, 1130, 275, 10...","[63, 578, 63, 1419, 275, 359, 272, 1689, 1751,...","[275, 295, 14, 2014, 8, 82, 2, 4042, 7115, 338...","[63, 80, 2664, 694, 275, 295, 838, 371, 14, 15...",...,"[7073, 14, 525, 63, 1422, 6739, 515, 401, 2053...","[275, 14196, 14, 14196, 14, 7963, 15218, 342, ...","[275, 295, 14, 2014, 9097, 5169, 6742, 92, 124...","[63, 493, 275, 298, 3647, 83, 3149, 83, 2, 199...","[9653, 3821, 259, 275, 295, 14, 2014, 480, 143...","[63, 694, 275, 3286, 20481, 15, 6251, 15, 1431...","[63, 1135, 275, 295, 14, 2014, 8, 82, 31996, 3...","[63, 632, 275, 488, 258, 199, 29, 4806, 14, 77...","[275, 747, 14, 515, 14, 904, 8, 736, 14, 515, ...","[63, 1725, 275, 788, 888, 4009, 283, 264, 14, ..."
2,2,777,#!/usr/bin/env python\n'''\nCreated on Mars 20...,"[3381, 2647, 15, 1393, 15, 1813, 2366, 199, 23...","['#!/', 'usr', '/', 'bin', '/', 'env', ' pytho...","[2137, 63, 614, 63, 17, 275, 283, 3018, 7, 199...","[939, 275, 756, 199, 199, 4296, 63, 9490, 521,...","[1588, 63, 19737, 275, 283, 17, 7, 199, 711, 2...","[939, 275, 1052, 199, 199, 3, 8272, 26688, 307...","[311, 13656, 63, 3723, 275, 788, 564, 418, 199...",...,"[939, 275, 756, 199, 2308, 275, 22925, 63, 90,...","[6692, 534, 14382, 199, 1102, 26, 283, 5967, 2...","[275, 788, 14716, 297, 283, 569, 418, 199, 199...","[939, 275, 756, 199, 199, 318, 1678, 13652, 8,...","[939, 63, 1760, 604, 275, 298, 5159, 15, 3280,...","[939, 63, 278, 275, 283, 90, 6693, 7, 199, 414...","[7330, 63, 637, 275, 378, 14, 1994, 4219, 15, ...","[1164, 275, 298, 58, 4340, 1379, 2, 199, 199, ...","[939, 63, 19214, 63, 3236, 736, 461, 275, 2292...","[1588, 275, 283, 1662, 63, 1258, 63, 3236, 79,..."
3,3,2776,from __future__ import division\nfrom itertool...,"[504, 636, 2443, 363, 492, 4629, 199, 504, 797...","['from', ' __', 'future', '__', ' import', ' d...","[199, 504, 24706, 14, 1150, 83, 492, 377, 2327...","[199, 504, 24706, 14, 5819, 492, 1059, 63, 131...","[12, 12939, 80, 21782, 3007, 1944, 8014, 764, ...","[5932, 63, 15246, 12, 971, 272, 787, 63, 4790,...","[199, 199, 504, 1639, 14, 2190, 492, 2202, 199...",...,"[2821, 11729, 199, 504, 1115, 24629, 63, 3671,...","[12, 26611, 63, 505, 199, 504, 24706, 14, 1628...","[199, 199, 7721, 275, 2400, 308, 26, 7666, 8, ...","[12, 12939, 80, 21782, 199, 504, 24706, 14, 15...","[199, 504, 24706, 14, 5819, 492, 334, 15246, 1...","[2141, 421, 199, 318, 3560, 8, 16780, 63, 354,...","[12, 12939, 80, 21782, 199, 504, 24706, 14, 18...","[505, 199, 504, 24706, 14, 18446, 14, 1208, 49...","[857, 199, 504, 299, 439, 89, 1612, 14, 19762,...","[12, 12939, 80, 6261, 199, 199, 646, 3805, 520..."
4,4,3019,#!/usr/bin/env python2\nimport re\nimport os\n...,"[3381, 2647, 15, 1393, 15, 1813, 2366, 18, 199...","['#!/', 'usr', '/', 'bin', '/', 'env', ' pytho...","[1122, 199, 504, 3031, 63, 2087, 492, 14735, 1...","[2087, 199, 646, 5316, 63, 5933, 199, 646, 588...","[199, 504, 8691, 492, 13491, 199, 199, 893, 26...","[3759, 199, 199, 2087, 275, 7534, 14, 10730, 8...","[199, 646, 14755, 199, 646, 2197, 199, 504, 91...",...,"[2087, 199, 646, 5377, 20, 199, 199, 504, 1680...","[2087, 199, 646, 26162, 14, 25032, 199, 646, 7...","[4883, 199, 504, 8691, 492, 13491, 199, 504, 8...","[4883, 199, 646, 2022, 4354, 773, 199, 646, 21...","[4883, 199, 646, 15416, 396, 199, 504, 15416, ...","[4883, 199, 646, 295, 2227, 697, 199, 646, 243...","[1122, 199, 646, 5145, 199, 199, 10484, 44, 13...","[2087, 199, 646, 12027, 1899, 199, 199, 14366,...","[4133, 199, 646, 12424, 1364, 199, 646, 24295,...","[199, 646, 5436, 63, 12457, 199, 646, 3805, 19..."


In [38]:
#tst decoding
decoded_input = tokenizer.decode(eval(df_generated_input['input_ids'][1]))
decoded_output = tokenizer.decode(eval(df_generated_input['1'][1]))
print(decoded_input)
print('-'*100)
print(decoded_output)

import ConfigParser
import os
import sys
# Dealing with registering match and target modules 

match_modules_registered = {}
target_modules_registered = {}

def register_target_module(name, func):
	target_modules_registered[name.lower()] = func
	
def get_target_module_func(name):
	if name.lower() in target_modules_registered.keys():
		return target_modules_registered[name.lower()]
	else:
		return None # chain target

def register_match_module(name, func):
	match_modules_registered[name.lower()] = func
	
def get_match_module_func(name):
	return match_modules_registered[name.lower()]

def register_all():
	# Open the modules.conf file and register all modules
	configp = ConfigParser.ConfigParser()
	configp.read("modules.conf")
	match_modules = configp.get("Modules", "match").split(',')
	target_modules = configp.get("Modules", "target").split(',')

	for match_module in match_modules:
		sys.path.append('modules/match')
		module = __import__(match_module)
		match_func = getattr(module,'match

In [39]:
## MEMORy DEALLOCATION
torch.cuda.empty_cache()