In [None]:
import numpy as np
import os
import pandas as pd
import random
import re
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch_xla
import torch_xla.core.xla_model as xm
import warnings
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Silence every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation and correctness
# warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

# Training configuration (tune here).
BATCH_SIZE = 16       # batch size handed to the training DataLoader
EPOCHS = 5            # number of epochs passed to train()
LEARNING_RATE = 2e-5  # learning rate handed to AdamW
MAX_LEN = 512         # max_length used when tokenizing a code pair
SEED = 42             # seed forwarded to seed_everything()

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    CUDA), exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # The cuDNN autotuner may pick a different algorithm from run to run, which
    # defeats the deterministic flag above, so it has to stay off.
    torch.backends.cudnn.benchmark = False  # type: ignore
# Seed all RNGs once, before any model or data object is created.
seed_everything(SEED)

In [None]:
def save_checkpoint(epoch, model, optimizer, filename):
    """Write a training checkpoint to ``filename`` through ``xm.save``.

    The saved dictionary stores the epoch under ``'Epoch'``, the model's
    ``state_dict()`` under ``'State_dict'`` and the optimizer's
    ``state_dict()`` under ``'optimizer'``.

    Args:
        epoch: Epoch number recorded in the checkpoint.
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        filename: Destination path handed to ``xm.save``.
    """
    xm.save(
        dict(
            Epoch=epoch,
            State_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        ),
        filename,
    )

In [None]:
# Ask torch_xla for the XLA device; the model and every batch are moved onto it.
device = xm.xla_device()
print(device)

In [None]:
# Load the training pairs. Datasets.__getitem__ unpacks every row into exactly
# three values (code1, code2, similar), so the CSV is expected to carry those
# three columns in that order -- TODO confirm against the data file.
train_df = pd.read_csv('./data/train.csv')

In [None]:
def remove_annotation(x):
    """Strip comments and ``#if 0`` blocks from C/C++ source text.

    Removes, in order: lines ending in a backslash continuation, ``//`` line
    comments, ``/* ... */`` block comments and ``#if 0 ... #endif`` regions.
    Block comments and ``#if 0`` regions are matched non-greedily across
    newlines, so multi-line blocks are removed and code sitting between two
    separate comments is preserved.

    Args:
        x: Source code as a string.

    Returns:
        The stripped source; a trailing newline is always appended.
    """
    # Guarantee the last line is newline-terminated so a `//` comment on it matches.
    x = x + '\n'
    # Drop lines (preceded by a newline) that end with a backslash continuation.
    # NOTE(review): the match consumes the trailing newline, so in a run of
    # consecutive continuation lines only every other one is removed -- confirm
    # whether that is intended before changing it.
    x = re.sub(r'\n.*\\\n', '\n', x)
    x = re.sub(r'//.*\n', '\n', x)
    # Non-greedy + DOTALL: stop at the first closing marker and span newlines.
    x = re.sub(r'/\*.*?\*/', '', x, flags=re.DOTALL)
    x = re.sub(r'#if 0.*?#endif', '', x, flags=re.DOTALL)
    return x

In [None]:
def remove_links(x):
    """Replace every ``http://`` / ``https://`` URL in ``x`` with a single space.

    Only tokens that carry the ``://`` scheme separator are treated as links,
    so identifiers that merely start with ``http`` (``http_status``,
    ``httpd_code``) are left untouched.

    Args:
        x: Text to scrub.

    Returns:
        ``x`` with each URL replaced by one space.
    """
    return re.sub(r'https?://\S+', ' ', x)

In [None]:
def standardize_sign(x):
    """Collapse runs of spaces and remove the spaces around operator characters.

    For each of ``+ - * / % = > < & | : ?`` a single space on either side of the
    character is removed; for ``!`` only a preceding space is removed. Every
    operator is replaced by itself (``<`` stays ``<``).

    Args:
        x: Source text.

    Returns:
        The text with operator spacing normalised.
    """
    x = re.sub(r' +', ' ', x)
    for sign in ('+', '-', '*', '/', '%', '=', '>', '<', '&', '|', ':', '?'):
        escaped = re.escape(sign)
        # After the collapse above at most one space can sit on each side.
        x = re.sub(' {0} | {0}|{0} '.format(escaped), lambda _m, s=sign: s, x)
    # Logical NOT is a prefix operator: only the space in front of it is dropped.
    x = re.sub(r' !', '!', x)
    return x

In [None]:
def remove_std(x):
    """Delete ``std::`` qualifiers and ``using namespace std;`` directives from ``x``."""
    # Qualifiers go first, then the using-directive, one pass each.
    for pattern in (r'std::', r'using namespace std;'):
        x = re.sub(pattern, '', x)
    return x

In [None]:
def remove_include(x):
    """Remove ``#include`` directives from C/C++ source text.

    Handles both the angle-bracket form (``#include <vector>``) and the quoted
    form (``#include "util.h"``), with optional spaces or tabs after ``#`` and
    after ``include``. Only the directive itself is removed: the match stops at
    the closing ``>`` or ``"`` and never crosses a newline, so code later on the
    same line (e.g. ``vector<int> v;``) is preserved.

    Args:
        x: Source text.

    Returns:
        ``x`` without its include directives.
    """
    return re.sub(r'#[ \t]*include[ \t]*(?:<[^>\n]*>|"[^"\n]*")', '', x)

In [None]:
def text_clean(x):
    """Normalise one code snippet before tokenisation.

    The snippet is lowercased, stripped of non-ASCII characters, run through
    ``remove_std``, ``remove_links``, ``remove_include``, ``remove_annotation``
    and ``standardize_sign`` (in that order), and finally has newlines and tabs
    turned into spaces with runs of spaces collapsed to one.

    Args:
        x: Raw source code string.

    Returns:
        The cleaned, single-line string.
    """
    # Lowercase, then drop anything that cannot be encoded as ASCII.
    cleaned = x.lower().encode('ascii', 'ignore').decode()

    pipeline = (remove_std, remove_links, remove_include, remove_annotation, standardize_sign)
    for step in pipeline:
        cleaned = step(cleaned)

    # Flatten to one line: newline -> space, tab -> space, then squeeze spaces.
    for pattern in (r'\n', r'\t', r' +'):
        cleaned = re.sub(pattern, ' ', cleaned)
    return cleaned

In [None]:
# Load the "neulab/codebert-cpp" tokenizer.
tokenizer = AutoTokenizer.from_pretrained("neulab/codebert-cpp")
# Truncate from the left side when a tokenized pair exceeds max_length.
# NOTE(review): presumably chosen to keep the tail of each snippet -- confirm intent.
tokenizer.truncation_side = 'left'
# Same checkpoint with a sequence-classification head configured for two labels.
model = AutoModelForSequenceClassification.from_pretrained("neulab/codebert-cpp", num_labels=2)
model.to(device)
# AdamW over all model parameters, using LEARNING_RATE and an explicit eps of 1e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-5)

In [None]:
class Datasets(Dataset):
    """Dataset view over a DataFrame of code pairs.

    Each row is unpacked into exactly three values -- two code strings and a
    label -- and both code strings are passed through ``text_clean`` on access.
    """

    def __init__(self, df):
        """Keep a reference to ``df``; no copy or preprocessing happens here."""
        self.df = df

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(cleaned_code1, cleaned_code2, label)`` for row position ``idx``."""
        first_code, second_code, label = self.df.iloc[idx]
        first_clean = text_clean(first_code)
        second_clean = text_clean(second_code)
        return first_clean, second_clean, label

In [None]:
# Wrap the raw frame; text cleaning happens lazily inside Datasets.__getitem__.
train_dataset = Datasets(train_df)
# Batches of BATCH_SIZE, loaded in the main process (num_workers=0).
# NOTE(review): shuffle=False replays the rows in the same order every epoch;
# consider shuffle=True for training unless this ordering is intentional.
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
def train(model, optimizer, num_epochs, criterion=nn.CrossEntropyLoss()):
    """Fine-tune ``model`` on the module-level ``train_loader``.

    Every batch of code pairs is tokenized with the module-level ``tokenizer``
    (padded/truncated to ``MAX_LEN``), moved to the module-level ``device`` and
    fed to the model together with its labels; the loss returned by the model
    is back-propagated and the optimizer is stepped through
    ``xm.optimizer_step``. After each epoch the mean batch loss and the
    accuracy are printed and a checkpoint is written under ``./savepoint``.

    Args:
        model: Model called as ``model(input_ids, attention_mask=..., labels=...)``
            whose output exposes ``'loss'`` and ``'logits'``.
        optimizer: Optimizer stepped once per batch.
        num_epochs: Number of passes over ``train_loader``.
        criterion: Unused -- the loss comes from the model output. Kept only so
            existing callers that pass it keep working.

    Returns:
        ``(train_losses, train_accuracies)``: two lists with one float per epoch.
    """
    train_losses, train_accuracies = [], []
    run_tag = time.strftime('%B_%d_%H_%M_%S')

    checkpoint_template = './savepoint/model_{t}_epoch{e}.pt'
    # Make sure the target directory exists before the first checkpoint is written.
    os.makedirs(os.path.dirname(checkpoint_template), exist_ok=True)

    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    model.train()
    for epoch in range(1, num_epochs + 1):
        print('======== Epoch {:} / {:} ========'.format(epoch, num_epochs))
        running_loss = 0.0
        total_correct = 0
        for code1, code2, similar in tqdm(train_loader):
            optimizer.zero_grad()

            # Tokenize the whole batch of pairs in one call.
            encoded = tokenizer(
                list(code1),
                list(code2),
                max_length=MAX_LEN,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            )
            input_ids = encoded['input_ids'].to(device)
            input_mask = encoded['attention_mask'].to(device)
            labels = similar.to(device)

            outputs = model(input_ids, attention_mask=input_mask, labels=labels)
            loss = outputs['loss']
            logits = outputs['logits']

            loss.backward()
            xm.optimizer_step(optimizer, barrier=True)

            # Read metrics only after the step so the device work for this batch
            # has already been issued; .item() keeps the running total a plain
            # float instead of a growing graph-attached tensor.
            running_loss += loss.item()
            # argmax over raw logits equals argmax over their softmax.
            pred = logits.detach().argmax(dim=1).cpu()
            total_correct += pred.eq(labels.detach().cpu()).sum().item()

        # max(..., 1) avoids a ZeroDivisionError on an empty loader/dataset.
        train_losses.append(running_loss / max(num_batches, 1))
        train_accuracies.append(total_correct / max(num_samples, 1))
        print('train_loss: ',train_losses[-1], ' train_accuracy: ',train_accuracies[-1])
        save_checkpoint(epoch, model, optimizer, checkpoint_template.format(t=run_tag, e=epoch))

    return train_losses, train_accuracies

In [None]:
# Run fine-tuning for EPOCHS epochs; per-epoch metrics are printed by train().
train(model, optimizer, EPOCHS)