In [1]:
import os
import json
import torch
import random
import numpy as np
from tqdm import tqdm
import multiprocessing
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
# from transformers import WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup

In [2]:
config = RobertaConfig.from_pretrained('microsoft/graphcodebert-base')
tokenizer = RobertaTokenizer.from_pretrained('microsoft/graphcodebert-base')

In [3]:
def read_data(filename):
    with open(filename) as f:
        text = []
        for line in f:
            text.append(line.strip())
        bar = tqdm(text, total=len(text))
        examples = []
        for x in bar:
            examples.append(eval(x))
    return examples

dataset = read_data('../py150_files/washed_python150k.txt')

100%|█████████████████████████████████| 137711/137711 [02:02<00:00, 1125.21it/s]


In [4]:
total_length = 512
graph_length = 200

epochs = 10
train_batch_size = 16
eval_batch_size = 16

seed = 978
random.seed(seed)
np.random.seed(seed)

gradient_accumulation_steps = 1
max_grad_norm = 1.0

learning_rate = 5e-5
max_steps = -1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

In [6]:
class InputFeatures(object):
    def __init__(self, code_ids, position_idx, edges, cross_edges):
        self.code_ids = code_ids
        self.position_idx = position_idx
        self.edges = edges
        self.cross_edges = cross_edges

def convert_example_to_feature(example):
    tokens = example['tokens']
    nodes = example['nodes']
    edges = example['edges']
    cross_edges = example['cross_edges']

    code_length = total_length - min(graph_length, len(nodes)) - 3
    tokens = tokens[: code_length] 
    tokens = [tokenizer.tokenize(tokens[0])] \
           + [tokenizer.tokenize('@ ' + x)[1 :] for x in tokens[1 :]]
    ori2cur_pos = {-1 : (0, 0)}
    for i in range(len(tokens)):
        ori2cur_pos[i] = (ori2cur_pos[i - 1][1], ori2cur_pos[i - 1][1] + len(tokens[i]))
    tokens=[y for x in tokens for y in x] 

    #truncating
    tokens = tokens[: code_length]
    nodes = nodes[: graph_length]
    edges = [(a, b) for (a, b) in edges if (a < len(nodes)) and (b < len(nodes))]
    cross_edges = [(ori2cur_pos[a], b) for (a, b) in cross_edges\
                   if (a in ori2cur_pos) and (ori2cur_pos[a][1] < len(tokens)) and (b < len(nodes))]

    #adding code tokens
    code_tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
    code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
    position_idx = [i + tokenizer.pad_token_id + 1 for i in range(len(code_tokens))]

    #adding graph nodes
    code_tokens += [x for x in nodes]
    code_ids += [tokenizer.unk_token_id] * len(nodes)
    position_idx += [0] * len(nodes)
    assert(len(code_ids) == len(position_idx))
    assert(len(code_ids) < total_length)

    #padding
    padding_length = total_length - len(code_ids)
    code_ids += [tokenizer.pad_token_id] * padding_length
    position_idx += [tokenizer.pad_token_id] * padding_length
    return InputFeatures(code_ids, position_idx, edges, cross_edges)

def convert_examples_to_features(examples):
    features = []
    pool = multiprocessing.Pool(processes = 24)
    for example in examples:
        features.append(pool.apply_async(convert_example_to_feature, (example, )))
    pool.close()
    pool.join()
    for i in range(len(features)):
        features[i] = features[i].get()
    return features

dataset2 = dataset
# random.shuffle(dataset2)
train_examples = dataset2[: int(len(dataset2) * 0.67)]
eval_examples = dataset2[int(len(dataset2) * 0.67) :]
train_features = convert_examples_to_features(train_examples)
eval_features = convert_examples_to_features(eval_examples)

In [7]:
# class InputFeatures(object):
#     def __init__(self, code_ids, position_idx, edges, cross_edges):
#         self.code_ids = code_ids
#         self.position_idx = position_idx
#         self.edges = edges
#         self.cross_edges = cross_edges

# def convert_examples_to_features(examples):
#     features = []
#     for example in tqdm(examples, total = len(examples)):
#         tokens = example['tokens']
#         nodes = example['nodes']
#         edges = example['edges']
#         cross_edges = example['cross_edges']
        
#         code_length = total_length - min(graph_length, len(nodes)) - 3
#         tokens = tokens[: code_length] 
#         tokens = [tokenizer.tokenize(tokens[0])] \
#                + [tokenizer.tokenize('@ ' + x)[1 :] for x in tokens[1 :]]
#         ori2cur_pos = {-1 : (0, 0)}
#         for i in range(len(tokens)):
#             ori2cur_pos[i] = (ori2cur_pos[i - 1][1], ori2cur_pos[i - 1][1] + len(tokens[i]))
#         tokens=[y for x in tokens for y in x] 
        
#         #truncating
#         tokens = tokens[: code_length]
#         nodes = nodes[: graph_length]
#         edges = [(a, b) for (a, b) in edges if (a < len(nodes)) and (b < len(nodes))]
#         cross_edges = [(ori2cur_pos[a], b) for (a, b) in cross_edges\
#                        if (a in ori2cur_pos) and (ori2cur_pos[a][1] < len(tokens)) and (b < len(nodes))]
        
#         #adding code tokens
#         code_tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
#         code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
#         position_idx = [i + tokenizer.pad_token_id + 1 for i in range(len(code_tokens))]
        
#         #adding graph nodes
#         code_tokens += [x for x in nodes]
#         code_ids += [tokenizer.unk_token_id] * len(nodes)
#         position_idx += [0] * len(nodes)
#         assert(len(code_ids) == len(position_idx))
#         assert(len(code_ids) < total_length)
        
#         #padding
#         padding_length = total_length - len(code_ids)
#         code_ids += [tokenizer.pad_token_id] * padding_length
#         position_idx += [tokenizer.pad_token_id] * padding_length
#         features.append(InputFeatures(code_ids, position_idx, edges, cross_edges))
#     return features

# dataset2 = dataset
# # random.shuffle(dataset2)
# train_examples = dataset2[: int(len(dataset) * 0.67)]
# eval_examples = dataset2[int(len(dataset) * 0.67) :]
# train_features = convert_examples_to_features(train_examples)
# eval_features = convert_examples_to_features(eval_examples)

In [8]:
class TextDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples
        
    def __len__(self):
        return len(self.examples)
        
    def __getitem__(self, item):
        attn_mask = np.zeros((total_length, total_length), dtype = np.bool)
        node_index = sum([i > 1 for i in self.examples[item].position_idx])
        max_length = sum([i != 1 for i in self.examples[item].position_idx])
        
        attn_mask[: node_index, : node_index] = True
        for i, x in enumerate(self.examples[item].code_ids):
            if x in [tokenizer.cls_token_id, tokenizer.sep_token_id]:
                attn_mask[i, 0 : max_length] = True # [cls/sep, all]
                attn_mask[0 : max_length, i] = True # test [all, cls/sep]
        attn_mask[1 : node_index - 1, node_index] = True # cross edge (token, graph ROOT)
        attn_mask[node_index, 1 : node_index - 1] = True # cross edge (graph ROOT, token)
        for ((a, b), c) in self.examples[item].cross_edges:
            attn_mask[a + 1 : b + 1, node_index + c] = True # cross edge (token, graph node)
            attn_mask[node_index + c, a + 1 : b + 1] = True # cross edge (token, graph node)
        for (a, b) in self.examples[item].edges:
            attn_mask[node_index + a, node_index + b] = True # edge (source, target)
#             attn_mask[node_index + b, node_index + a] = True # test

        input_ids = []
        labels = []
        for x in self.examples[item].code_ids:
            if (x in [tokenizer.cls_token_id, tokenizer.sep_token_id,
                      tokenizer.unk_token_id, tokenizer.pad_token_id]):
                input_ids.append(x)
                labels.append(-100)
            elif (random.randint(0, 99) < 15):
                input_ids.append(tokenizer.mask_token_id)
                labels.append(x)
            else:
                input_ids.append(x)
                labels.append(-100)

        return (torch.tensor(input_ids),
                torch.tensor(self.examples[item].position_idx),
                torch.tensor(attn_mask),
                torch.tensor(labels))

train_data = TextDataset(train_features)
eval_data = TextDataset(eval_features)

In [9]:
train_sampler = RandomSampler(train_data)
eval_sampler = RandomSampler(eval_data)
train_dataloader = DataLoader(train_data, sampler = train_sampler, drop_last = True,
                              batch_size = train_batch_size, num_workers = 4)
eval_dataloader = DataLoader(eval_data, sampler = eval_sampler, shuffle = False, drop_last = False,
                             batch_size = eval_batch_size, num_workers = 4)

In [10]:
model = RobertaForMaskedLM.from_pretrained('microsoft/graphcodebert-base', config = config)
model.to(device)

if n_gpu > 1:
    model = torch.nn.DataParallel(model)

# no_decay = ['bias', 'LayerNorm.weight']
# optimizer_grouped_parameters = [
#     {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
#      'weight_decay': weight_decay},
#     {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# ]
# optimizer = AdamW(optimizer_grouped_parameters, lr = learning_rate, eps = adam_epsilon)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warmup_steps,
#                                             num_training_steps = max_steps)

optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

avg_acc = 0
best_acc = 0

In [11]:
for epoch_id in range(epochs): 
    train_num = 0
    train_loss = 0
    avg_loss = 0
    bar = tqdm(train_dataloader, total = len(train_dataloader))
    bar.set_description("{}: loss {} acc {} best {}".\
                        format(epoch_id, round(avg_loss, 2), round(avg_acc * 100, 2), round(best_acc * 100, 2)))

    for step, batch in enumerate(bar):
        (input_ids, position_ids, attention_mask, labels) = [x.to(device) for x in batch]
        output = model(input_ids = input_ids,
                       position_ids = position_ids,
                       attention_mask = attention_mask,
                       labels = labels)
        loss = output.loss

        if n_gpu > 1:
            loss = loss.mean()
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        train_num += 1
        train_loss += loss.item()
        avg_loss = train_loss / train_num
        bar.set_description("{}: loss {} acc {} best {}".\
                            format(epoch_id, round(avg_loss, 2), round(avg_acc * 100, 2), round(best_acc * 100, 2)))

        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
#             scheduler.step()

    if ((epoch_id + 1) % 1 == 0):
        bar = tqdm(eval_dataloader, total = len(eval_dataloader))
        total = 0
        correct = 0
        for batch in bar:
            (input_ids, position_ids, attention_mask, labels) = [x.to(device) for x in batch]
            with torch.no_grad():
                output = model(input_ids = input_ids,
                               position_ids = position_ids,
                               attention_mask = attention_mask)
            _, predicted = torch.max(output.logits, 2)
            predicted = predicted.view(1, -1).squeeze()
            labels = labels.view(1, -1).squeeze()
            total += (labels != -100).sum().item()
            correct += (predicted == labels).sum().item()
        avg_acc = correct / total
        best_acc = max(best_acc, avg_acc)
        print(avg_acc)

0: loss 0.43 acc 0 best 0: 100%|████████████| 5766/5766 [36:19<00:00,  2.65it/s]
100%|███████████████████████████████████████| 2841/2841 [06:05<00:00,  7.78it/s]


0.9135488409191111


1: loss 0.39 acc 91.35 best 91.35: 100%|████| 5766/5766 [36:42<00:00,  2.62it/s]
100%|███████████████████████████████████████| 2841/2841 [06:07<00:00,  7.72it/s]


0.9138852969566722


2: loss 0.36 acc 91.39 best 91.39:  29%|█▏  | 1654/5766 [10:40<26:31,  2.58it/s]


KeyboardInterrupt: 