In [1]:
import os
import pickle
import itertools

from tqdm import tqdm
import numpy as np
import torch 
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

In [4]:
def permute_lines(input, line_permutation_order, tokenized_line_inds):
    '''
    Reorder the tokens of `input` line by line.

    input: tensor of token ids indexed along dim 0 (in this notebook, the 1-D
        concatenation of every tokenized line of a program).
    line_permutation_order: sequence of line indices; the k-th line of the
        output is line `line_permutation_order[k]` of `input`.
    tokenized_line_inds: token offsets of the line boundaries -- line j spans
        input[tokenized_line_inds[j]:tokenized_line_inds[j + 1]].

    Returns a tensor with the same length, dtype and device as `input`.

    Raises ValueError if the selected lines do not add up to exactly len(input)
    tokens (e.g. the order omits or repeats a line), instead of silently
    padding the output with copies of input[0].
    '''
    device = input.device
    # Build the gather index directly as int64 on the input's device; a float32
    # index buffer would only represent positions exactly up to 2**24.
    pieces = [
        torch.arange(int(tokenized_line_inds[line]), int(tokenized_line_inds[line + 1]), device=device)
        for line in line_permutation_order
    ]
    if pieces:
        gather_index = torch.cat(pieces)
    else:
        gather_index = torch.empty(0, dtype=torch.long, device=device)

    if gather_index.numel() != len(input):
        raise ValueError(
            f"permutation covers {gather_index.numel()} tokens but input has {len(input)}"
        )
    return torch.index_select(input, 0, gather_index)

In [5]:
def filter_tokenize(data_path='/home/albertjan/equitune/data/data.pkl', num_permutations=4, max_length=1024, tokenizer="deepseek-ai/deepseek-coder-1.3b-base"):
    '''
    Load code samples, keep the usable ones, and tokenize every line permutation.

    data_path: Path to a .pkl holding a list of dicts with keys
        'code': list of strings, one per line of code,
        'permutations': list of invariant line orders (each a sequence of
            line indices into 'code'), and
        'label': string whose *first* token becomes the target.
        Row 0 of the output is always the unpermuted code, so permutations[0]
        is assumed to be the identity order (NOTE(review): not checked here).
        WARNING: pickle.load can execute arbitrary code -- only load trusted files.

    num_permutations: Filter dataset for samples with at least this many permutations

    max_length: Filter dataset for samples whose tokenized length is at most this

    tokenizer: Name/path passed to AutoTokenizer.from_pretrained, or an
        already-loaded tokenizer object (avoids reloading it on every call).

    Returns a list of dicts, one per kept sample:
        'input_ids': LongTensor (len(permutations), tokenized_length); row i is
            the code with its lines reordered by permutations[i] (row 0 unpermuted),
        'label': LongTensor of shape (1,),
        'metadata': {'line_permutation_orders', 'tokenized_line_inds'}.

    Raises ValueError if a label produces no tokens.
    '''
    with open(data_path, 'rb') as f:
        dataset = pickle.load(f)

    dataset = [sample for sample in dataset if len(sample['permutations']) >= num_permutations]

    if isinstance(tokenizer, (str, os.PathLike)):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    res = []
    for sample in tqdm(dataset):
        code, line_permutation_orders, label = sample['code'], sample['permutations'], sample['label']
        if not code:
            # torch.cat cannot concatenate zero lines, and there is nothing to permute.
            continue

        # Tokenize each line separately so line boundaries are known in token space.
        # Only line 0 gets special tokens. NOTE(review): if the tokenizer adds a BOS
        # token, permutations that move line 0 also move that token away from
        # position 0 -- confirm this is intended.
        # .long() guarantees an integer dtype before concatenation (an empty line may
        # come back as an empty non-integer tensor, which would promote the result).
        tokenized_loc = [
            tokenizer(line_text, return_tensors="pt", add_special_tokens=(line_num == 0))['input_ids'][0].long()
            for line_num, line_text in enumerate(code)
        ]
        tokenized_loc_len = [len(loc) for loc in tokenized_loc]
        # Start offset of every line post-tokenization, followed by the total length.
        tokenized_line_inds = np.array([0] + list(itertools.accumulate(tokenized_loc_len)))

        input_orig = torch.cat(tokenized_loc)  # unpermuted code input
        if len(input_orig) > max_length:
            continue

        # Each row is one permutation of the original code input; the stacked tensor
        # is fed directly into the model. Rows stay int64 throughout (a float32
        # staging buffer only holds integers exactly up to 2**24).
        rows = [input_orig] + [
            permute_lines(
                input=input_orig,
                line_permutation_order=order,
                tokenized_line_inds=tokenized_line_inds,
            )
            for order in line_permutation_orders[1:]
        ]
        input_ids = torch.stack(rows)

        # Only the first token of the label is used as the target.
        label_ids = tokenizer(label, return_tensors="pt", add_special_tokens=False)['input_ids'][0]
        if label_ids.numel() == 0:
            raise ValueError(f"label {label!r} produced no tokens")

        metadata = {
            'line_permutation_orders': line_permutation_orders,
            'tokenized_line_inds': tokenized_line_inds,
        }
        res.append({
            'input_ids': input_ids,
            # clone() so the stored label does not share (and later pickle) the
            # full storage of the tokenized label.
            'label': label_ids[:1].clone(),
            'metadata': metadata,
        })

    return res
    

In [6]:
# Keep samples with >= 4 line permutations whose tokenized length is <= 900 tokens.
filtered = filter_tokenize(num_permutations=4, max_length=900)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|██████████| 26556/26556 [01:54<00:00, 232.43it/s]


In [18]:
# Directory where processed datasets are written. Set the EQUITUNE_DATA_DIR
# environment variable to relocate it instead of editing a machine-specific path.
STORAGE_PATH = os.environ.get('EQUITUNE_DATA_DIR', '/home/albertjan/equitune/data/')

In [15]:
# Serialize the processed samples and write them under STORAGE_PATH.
# The '4perms' suffix reflects the num_permutations=4 filter used to build `filtered`.
dataset_pkl = pickle.dumps(filtered)
out_path = os.path.join(STORAGE_PATH, 'dataset_4perms.pkl')
with open(out_path, 'wb') as out_file:
    out_file.write(dataset_pkl)