In [10]:
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
import re
import numpy as np
import pandas as pd
import random

seed = 2610
random.seed(seed)
np.random.seed(seed)

def compute_metrics(eval_pred):
    y_pred, y_true = np.argmax(eval_pred.predictions, -1), eval_pred.label_ids
    return {'accuracy': accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred),
            'recall': recall_score(y_true, y_pred),
            'f1': f1_score(y_true, y_pred)}

def data_cleaning(data, max_char=15000):
    comment_regex = r'(//[^\n]*|\/\*[\s\S]*?\*\/)'
    newline_regex = '\n{1,}'
    whitespace_regex = '\s{2,}'
    def replace(inp, pat, rep):
        return re.sub(pat, rep, inp)

    data['truncated_code'] = (data['code'].apply(replace, args=(comment_regex, ''))
                                        .apply(replace, args=(newline_regex, ' '))
                                        .apply(replace, args=(whitespace_regex, ' '))
                            )
    data = data.sort_values(by='truncated_code', key=lambda x: x.str.len())
    # remove all data points that have more than 15000 characters
    length_check = np.array([len(x) for x in data['truncated_code']]) > max_char
    data = data[~length_check]
    return data
def to_huggingface_dataset(data_train, data_test, data_valid):
    dts = DatasetDict()
    dts['train'] = Dataset.from_pandas(data_train)
    dts['test'] = Dataset.from_pandas(pd.concat([data_test, data_valid]))
    dts['valid'] = Dataset.from_pandas(pd.concat([data_test, data_valid]))
    dts.set_format('torch')
    dts.rename_column('label', 'labels')
    dts = dts.remove_columns(['code', 'truncated_code', '__index_level_0__'])
    
    return dts

def train_test_valid_split(data, train_size=0.8, test_size=0.1, valid_size=0.1):
    X_train, X_test_valid, y_train, y_test_valid = train_test_split(data.loc[:, data.columns != 'label'],
                                                                data['label'],
                                                                train_size=train_size,
                                                                stratify=data['label']
                                                               )
    test_size /= (test_size+valid_size)
    X_test, X_valid, y_test, y_valid = train_test_split(X_test_valid.loc[:, X_test_valid.columns != 'label'],
                                                        y_test_valid,
                                                        test_size=test_size,
                                                        stratify=y_test_valid)
    data_train = X_train
    data_train['label'] = y_train
    data_test = X_test
    data_test['label'] = y_test
    data_valid = X_valid
    data_valid['label'] = y_valid
    print(data_train)
    dts = to_huggingface_dataset(data_train, data_test, data_valid)
    return dts

def data_preprocessing(file_path="full_data.csv",
                       model_ckpt='neulab/codebert-c'):
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    data = pd.read_csv(file_path)
    data = data_cleaning(data)
    dts = train_test_valid_split(data)
    
    return dts

In [11]:
out = data_preprocessing()

                                                    code  \
9182   static size_t curl_size_cb(void *ptr, size_t s...   
12554  static void *iothread_run(void *opaque)\n\n{\n...   
5908   static void v9fs_readlink(void *opaque)\n\n{\n...   
7554   static int qxl_init_secondary(PCIDevice *dev)\...   
9493   static int nprobe(AVFormatContext *s, uint8_t ...   
...                                                  ...   
3203   static av_cold void alloc_temp(HYuvContext *s)...   
8208   static int commit_direntries(BDRVVVFATState* s...   
21988  void avfilter_free(AVFilterContext *filter)\n\...   
23635  static int old_codec47(SANMVideoContext *ctx, ...   
1961   static void test_i440fx_defaults(gconstpointer...   

                                          truncated_code  label  
9182   static size_t curl_size_cb(void *ptr, size_t s...      1  
12554  static void *iothread_run(void *opaque) { IOTh...      1  
5908   static void v9fs_readlink(void *opaque) { V9fs...      0  
7554   static i

In [3]:
out

DatasetDict({
    train: Dataset({
        features: ['label'],
        num_rows: 21760
    })
    test: Dataset({
        features: ['label'],
        num_rows: 5440
    })
    valid: Dataset({
        features: ['label'],
        num_rows: 5440
    })
})