In [1]:
import torch
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn as nn
import sentencepiece as spm
from tqdm import tqdm

In [2]:
def get_code():
    code = '''from random import randint
from timeit import repeat
from random import randint

def run_sorting_algorithm(algorithm, array):
    setup_code = f"from __main__ import {algorithm}" \
        if algorithm != "sorted" else ""
    stmt = f"{algorithm}({array})"
    times = repeat(setup=setup_code, stmt=stmt, repeat=3, number=10)
    print(f"Algorithm: {algorithm}. Minimum execution time: {min(times)}")

def bubble_sort(array):
    n = len(array)
    for i in range(n):
        already_sorted = True
        for j in range(n - i - 1):
            if array[j] > array[j + 1]:
                array[j], array[j + 1] = array[j + 1], array[j]
                already_sorted = False
        if already_sorted:
            break
    return array

def insertion_sort(array):
    for i in range(1, len(array)):
        key_item = array[i]
        j = i - 1
        while j >= 0 and array[j] > key_item:
            array[j + 1] = array[j]
            j -= 1
        array[j + 1] = key_item
    return array

def merge(left, right):
    if len(left) == 0:
        return right
    if len(right) == 0:
        return left
    result = []
    index_left = index_right = 0
    while len(result) < len(left) + len(right):
        if left[index_left] <= right[index_right]:
            result.append(left[index_left])
            index_left += 1
        else:
            result.append(right[index_right])
            index_right += 1
        if index_right == len(right):
            result += left[index_left:]
            break
        if index_left == len(left):
            result += right[index_right:]
            break
    return result

def merge_sort(array):
    # If the input array contains fewer than two elements,
    # then return it as the result of the function
    if len(array) < 2:
        return array
    midpoint = len(array) // 2
    return merge(
        left=merge_sort(array[:midpoint]),
        right=merge_sort(array[midpoint:]))

def quicksort(array):
    if len(array) < 2:
        return array
    low, same, high = [], [], []
    pivot = array[randint(0, len(array) - 1)]
    for item in array:
        if item < pivot:
            low.append(item)
        elif item == pivot:
            same.append(item)
        elif item > pivot:
            high.append(item)
    return quicksort(low) + same + quicksort(high)

def insertion_sort(array, left=0, right=None):
    if right is None:
        right = len(array) - 1
    for i in range(left + 1, right + 1):
        key_item = array[i]
        j = i - 1
        while j >= left and array[j] > key_item:
            array[j + 1] = array[j]
            j -= 1
        array[j + 1] = key_item
    return array

def timsort(array):
    min_run = 32
    n = len(array)
    for i in range(0, n, min_run):
        insertion_sort(array, i, min((i + min_run - 1), n - 1))
    size = min_run
    while size < n:
        for start in range(0, n, size * 2):
            midpoint = start + size - 1
            end = min((start + size * 2 - 1), (n-1))
            merged_array = merge(
                left=array[start:midpoint + 1],
                right=array[midpoint + 1:end + 1])
            array[start:start + len(merged_array)] = merged_array
        size *= 2
    return array

ARRAY_LENGTH = 1000
if __name__ == "__main__":
    array = [randint(0, 1000) for i in range(ARRAY_LENGTH)]
    run_sorting_algorithm(algorithm="sorted", array=array)
    run_sorting_algorithm(algorithm="bubble_sort", array=array)
    run_sorting_algorithm(algorithm="insertion_sort", array=array)
    run_sorting_algorithm(algorithm="merge_sort", array=array)
    run_sorting_algorithm(algorithm="quicksort", array=array)
    run_sorting_algorithm(algorithm="insertion_sort", array=array)
    run_sorting_algorithm(algorithm="timsort", array=array)'''
    
    return code

In [3]:
def get_code():
    code = """def quicksort(array):
    if len(array) < 2:
        return array
    low, same, high = [], [], []
    pivot = array[randint(0, len(array) - 1)]
    for item in array:
        if item < pivot:
            low.append(item)
        elif item == pivot:
            same.append(item)
        elif item > pivot:
            high.append(item)
    return quicksort(low) + same + quicksort(high)"""
    return code

In [4]:
#gets all predictions from one window
def get_window_predictions(window, model):
    preds, h = model(torch.tensor([window]))
    preds = torch.flatten(torch.sigmoid(preds))
    preds = preds.detach().numpy()
    return preds

#get_window_predictions(test_x[0],LSTM_LM_net_trained)



#NOTE, THIS ONLY GETS THE PREDICTED BREAK POINTS FROM A PREDICTION
#WITH THE NEWLINE TOKEN AT CENTER OF WINDOW
#if top = 0, return all break points, else return top number of break points
def get_predicted_break_points(code_windows, model, top=0, thresh=0.5):
    start = 0
    code  = []
    break_points = []
    mid_points_preds = []
    print(len(code_windows))
    for window_i in range(len(code_windows)):
        #get window, which has our tokens
        window = code_windows[window_i]
        window_predictions = get_window_predictions(window,model)
        #mid = math.ceil(len(window)/2)
        mid = int(len(window)/2) #actually we need to round down...
        mid_token = tokenizer.decode(int(window[mid]))
        mid_pred = window_predictions[mid]
        mid_points_preds.append(mid_pred)
        
        #only new lines
        if mid_token[-7:]=='NEWLINE' and mid_pred >= thresh:
            if top==0:
                print(mid_token, mid_pred)
                break_points.append(window_i)
            else:
                print('you forgot to fill this in ')
                break_points.append(window_i)
                
                
                
        code.append(mid_token)
        start+=1
    
    #print(max(mid_points_preds))
    return code, break_points

def get_top_n_preds(code_windows, model, top=3):
    start = 0
    code  = []
    break_points = []
    mid_points_preds = []
    print(len(code_windows))
    for window_i in tqdm(range(len(code_windows))):
        #get window, which has our tokens
        window = code_windows[window_i]
        window_predictions = get_window_predictions(window,model)
        #mid = math.ceil(len(window)/2)
        mid = int(len(window)/2) #actually we need to round down...
        mid_token = tokenizer.decode(int(window[mid]))
        mid_pred = window_predictions[mid]
        if mid_token[-7:]=='NEWLINE':
            mid_points_preds.append(mid_pred)
        else:
            mid_points_preds.append(0)
          
        code.append(mid_token)
        start+=1
        
    break_points = sorted(range(len(mid_points_preds)), key=lambda i: mid_points_preds[i])[-top:]
    
    #print(max(mid_points_preds))
    return code, break_points


    
#code_windows = segments['0']['x']
#code, breaks = get_predicted_break_points(code_windows,LSTM_LM_net_trained)
#print(breaks)

#from code segmentation file
def insert_comments(code, break_spots, comment='\n'+'*'*8+'\n',at_begining=True):
    #if there is a a comment at begining of snippet
    if at_begining:
        #adds a notation to add a 0
        #at beigning of break spots too
        break_spots.insert(0,0)
    
    #go through breaks backwards
    #so as not to mess up break 
    #spots as we would if we went forward
    for b in break_spots[::-1]:
        code.insert(b,comment)
    return code

def centered_sliding_window(token_list, window_diamiter,encode=False,PAD='unk'):
    windows = []
    for i in range(len(token_list)):
        
        #print(token_list)
        #input()
        
        window = []
        
        #if we have to pad the begining
        if i < window_diamiter:
            before_len = window_diamiter-i
            before = [PAD]*before_len+token_list[0:i]
        else:
            before = token_list[i-window_diamiter:i]
        
        #if we have to pad the end
        if i+window_diamiter>=len(token_list):
            after_len = (i+1+window_diamiter)-len(token_list)
            after = token_list[i+1:i+1+window_diamiter]+[PAD]*after_len

        else:
            after = token_list[i+1:i+1+window_diamiter]
        
        #put it togeather
        #print('------')
        #print('before:',before)
        #print('center:',token_list[i])
        #print('after:',after)
        window = before + [token_list[i]] + after
        #for encoding code if we want
        if encode:
            new_window = []
            #print(window)
            #input()
            for i in window:
                encoded = tokenizer.encode(i)
                if len(encoded)>1:       
                    x=encoded[1]
                    if type(x)==list:
                        new_window.append(x[0])
                    else:
                        new_window.append(x)
                elif len(encoded)==1:
                    if type(encoded)==list:
                        new_window.append(encoded[0])
                    else:
                        new_window.append(encoded)
                else:
                    #for some reason it finds the unicode stuff __
                    pass
                    #print(window)
                    #print(i)
                    #print(encoded)
                    #input()
            #print(window)
            #print(len(window))
            #print(len(tokenizer.decode(window)))
            #print(tokenizer.decode(window))
            #print(len(tokenizer.encode(window)))
            #window = tokenizer.encode(tokenizer.decode(window))
            window = new_window
        #print(window)
        #print(len(window))
        #input()

        #save windowz
        windows.append(window)
    
    return windows


###
# adapted from the PyTorch examples. for the full PyTorch LM example, see: 
# https://github.com/pytorch/examples/blob/master/word_language_model/model.py
###

class LSTM_LM(nn.Module):
    """Model feeds pre-trained embeddings through a series of biLSTM
       layers, followed by a linear vocabulary decoder."""
    
    def __init__(self, in_dim, hidden_dim, lstm_layers, word_vectors, 
                 dropout=0.05, bidirectional = True):
        super(LSTM_LM, self).__init__()

        self.vocab_size = word_vectors.shape[0]
        self.hidden_dim = hidden_dim
        self.lstm_layers = lstm_layers

        # blank embed layer starting from GloVe pre-trained vectors
        self._embed = nn.Embedding.from_pretrained(word_vectors, freeze=False)        
        self._drop = nn.Dropout(dropout)

        self._lstm = nn.LSTM(in_dim, hidden_dim, num_layers = lstm_layers, dropout = dropout,
                             bidirectional = bidirectional, batch_first=True)
        self._ReLU = nn.ReLU()
        self._pred = nn.Linear((2 if bidirectional else 1)*hidden_dim, 
                               #self.vocab_size)
                               1) #only 1 or zeros here 

    def forward(self, x):
        e = self._drop(self._embed(x))
        z, h = self._lstm(e)
        z_drop = self._drop(z)
        s = self._pred(self._ReLU(z_drop))
        #s = s.view(-1, self.vocab_size)
        s = s.squeeze()
        return s, h


    def init_hidden(self, batch_size):
        weight = next(self.parameters())
        return weight.new_zeros(self.lstm_layers, batch_size, self.hidden_dim)



In [5]:
########
## use saved model
#######

torch.manual_seed(691)

#vocab size from sentence peice
vocab_size = 10000 #same as sentence peice
vocab_dim = 50  # the size of our pre-trained word vectors

# randomly initialize our word vectors!
vocab_dim = 256
word_vectors = torch.randn(vocab_size, vocab_dim)
word_vectors.shape, word_vectors

#set up model
hidden_dim = 200
lstm_layers = 2
LSTM_LM_net_trained = LSTM_LM(word_vectors.shape[1], hidden_dim,lstm_layers, word_vectors)

#[TODO]: fix so it works
#https://pytorch.org/tutorials/beginner/saving_loading_models.html
#https://stackoverflow.com/questions/61242966/pytorch-attributeerror-function-object-has-no-attribute-copy

name = 'medium_py'

class Tokenizer:
    def __init__(self, filepath=name+'_tokenizer.model'):
        self.sp = spm.SentencePieceProcessor(model_file=filepath)

    def encode(self, text, t=int):
        return self.sp.encode(text, out_type=t)

    def decode(self, pieces):
        return self.sp.decode(pieces)

    @staticmethod
    def train(input_file='data/raw_sents.txt', model_prefix='sp_model', vocab_size=30522):
        spm.SentencePieceTrainer.train(input=input_file, model_prefix=model_prefix, vocab_size=vocab_size,
                                        #input_sentence_size=2 ** 16, shuffle_input_sentence=True)
                                        input_sentence_size=number_of_lines, shuffle_input_sentence=True)

#load weights into model
LSTM_LM_net_trained.load_state_dict(torch.load('./data/'+name+'biLSTM_LM.pt'))
LSTM_LM_net_trained.eval()

#instantiate tokenizer model
tokenizer = Tokenizer(name+'_tokenizer.model')

In [7]:
for i in range(1,2):
    print('-----------------------------------------------------')
    code=get_code()
    
    #space
    new_code = code.replace(' ',' SPACE')
    #newline
    new_code = new_code.replace('\n',' NEWLINE')
    #tab
    new_code = new_code.replace('\t',' TAB')

    tokens = tokenizer.encode(new_code,t=str)


    wd=20 #window diameter
    X_windows = centered_sliding_window(tokens,wd,encode=True)
    #code, breaks = get_predicted_break_points(X_windows,LSTM_LM_net_trained,top=0,thresh=.5)
    code, breaks = get_top_n_preds(X_windows,LSTM_LM_net_trained,top=25)
    if len(breaks)>0:
        print(i,len(breaks))
        #'''
        #from code segmentation file
        comments_added = insert_comments(code,sorted(breaks))
        comments_added_decoded = tokenizer.decode(comments_added)
        comments_added_token_string = ''.join(comments_added_decoded)
        comments_added_token_string = comments_added_token_string.replace('SPACE',' ')
        comments_added_token_string = comments_added_token_string.replace('NEWLINE','\n')
        comments_added_token_string = comments_added_token_string.replace('TAB','\t')
        print(comments_added_token_string)
        #'''

  4%|▍         | 9/214 [00:00<00:02, 89.55it/s]

-----------------------------------------------------
214


100%|██████████| 214/214 [00:01<00:00, 115.36it/s]

1 25

********
def quicksort(array):
********

    if len(array) < 2:2:
********

        return array
********

    low, same, high = [], [], []
********

    pivot = array[randint(0, len(array) - 1)]
********

    for item in array:
********

        if item < pivot:
********

            low.append(item)
********

        elif item == pivot:
********

            same.append(item)
********

        elif item > pivot:
********

            high.append(item)
********

    return quicksort(low
********
)
********
 
********
+
********
 
********
same
********
 
********
+
********
 
********
quick
********
sort
********
(
********
high
********
)





In [8]:
print(sorted(breaks))

[7, 22, 33, 53, 75, 86, 102, 121, 137, 156, 172, 191, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213]
