In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%cd /content/drive/MyDrive/logicalErrorFix-2/

/content/drive/MyDrive/logicalErrorFix-2


In [3]:
from __future__ import absolute_import
import os
import sys
import bleu
import pickle
import torch
import json
import random
import logging
import argparse
import numpy as np
import copy
from io import open
from itertools import cycle
import torch.nn as nn
from model import Seq2Seq
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
from torch.utils.data.distributed import DistributedSampler
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaModel, RobertaTokenizer)

from transformers import T5ForConditionalGeneration
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer),
                 'codeT5' : (None, T5ForConditionalGeneration, RobertaTokenizer)}

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

class Example(object):
    """A single training/test example."""
    def __init__(self,
                 idx,
                 source,
                 target,
                 ):
        self.idx = idx
        self.source = source
        self.target = target

import pandas as pd
COLUMNS = ['Correct_code', 'Incorrect_code', 'Statement']

def read_examples(filename):
  """Read examples from filename for DeepFix style training Line stmt Line stmt Line stmt ..."""
  examples = []
  data = pd.read_csv(filename, sep='\t', header=[0]).drop(columns=COLUMNS[0])
  for idx, elem in data.iterrows():
    code = ' '.join(elem[COLUMNS[1]].split('||| '))[:-1].strip()
    stmt = elem[COLUMNS[2]].strip()

    examples.append(
      Example(
              idx = idx,
              source = code,
              target = stmt,
              )
    )
  return examples

class InputFeatures(object):
    """A single training/test features for a example."""
    def __init__(self,
                 example_id,
                 source_ids,
                 target_ids,
                 source_mask,
                 target_mask,

    ):
        self.example_id = example_id
        self.source_ids = source_ids
        self.target_ids = target_ids
        self.source_mask = source_mask
        self.target_mask = target_mask


In [4]:
train_examples = read_examples('./data/edit_distance/pair_code_edit_dist_train.txt')

In [5]:
print(train_examples[2].source)

1 #include <bits/stdc++.h> 2 using namespace std; 3 int dx[] = {0, 1, 0, -1, -1, 1, 1, -1}; 4 int dy[] = {1, 0, -1, 0, -1, -1, 1, 1}; 5 int kx[] = {-2, -2, -1, 1, 2, 2, 1, -1}; 6 int ky[] = {-1, 1, 2, 2, 1, -1, -2, -2}; 7 inline long long gcd(long long a, long long b) { 8 a = fabs(a); 9 b = fabs(b); 10 while (b) { 11 a = a % b; 12 swap(a, b); 13 } 14 return a; 15 } 16 inline long long bigmod(long long a, long long p, long long m) { 17 long long res = 1 % m, x = a % m; 18 while (p) { 19 if (p & 1) res = (res * x) % m; 20 x = (x * x) % m; 21 p >>= 1; 22 } 23 return res; 24 } 25 struct DATA { 26 long long a, b; 27 }; 28 vector<DATA> arr; 29 set<long long> s; 30 int main() { 31 long long n; 32 scanf("%lld", &n); 33 for (int i = (int)(1); i <= (int)(n); i++) { 34 long long x, y; 35 scanf("%lld %lld", &x, &y); 36 arr.push_back({x, y}); 37 } 38 if (n == 1) { 39 cout << arr[0].a << endl; 40 exit(0); 41 } 42 for (int i = 2; i * i <= arr[0].a; i++) { 43 if (arr[0].a % i == 0) { 44 while (arr[0].

In [6]:
def process_code_examples(input_examples):
    def parse_and_continue_on_mismatch(code_str):
        code_dict = {}
        expected_line_number = 1
        parts = code_str.split()
        current_line = None
        for part in parts:
            if part.isdigit():
                line_number = int(part)
                if line_number == expected_line_number:
                    current_line = line_number
                    code_dict[current_line] = ""
                    expected_line_number += 1
                else:
                    code_dict[current_line] += part + " "
            else:
                if current_line is not None:
                    code_dict[current_line] += part + " "
        for line in code_dict:
            code_dict[line] = code_dict[line].strip()
        return code_dict

    def remove_line_numbers(source_dict):
        ret = []
        for _, v in source_dict.items():
            ret.append(v)
        return ' '.join(ret)

    def parse_single_line_to_dict(line_str):
        parts = line_str.split(' ', 1)
        line_number = parts[0]
        code = parts[1] if len(parts) > 1 else ""
        return {int(line_number): code}

    def update_values_from_dict(dict1, dict2):
        for key in dict1:
            if key in dict2:
                dict1[key] = dict2[key]
        return dict1

    source_code_dict = {}
    for i, example in enumerate(input_examples):
        source_code_dict[i] = parse_and_continue_on_mismatch(example.source)

    target_dict = {}
    for i, example in enumerate(input_examples):
        target_dict[i] = parse_single_line_to_dict(example.target)

    target_code_dict = {i: update_values_from_dict(copy.deepcopy(source_code_dict[i]), target_dict[i]) for i in range(len(target_dict))}

    input_code = {}
    for i in range(len(source_code_dict)):
        input_code[i] = {
            "source": remove_line_numbers(source_code_dict[i]),
            "target": remove_line_numbers(target_code_dict[i])
        }

    return input_code


In [8]:
import argparse

# 명령줄 인자를 직접 코드 내에서 설정
args = argparse.Namespace(
    model_type='codeT5',
    model_name_or_path='Salesforce/codet5-base',
    output_dir='model/cpp',
    load_model_path='path/to/trained/model',
    train_filename='./data/edit_distance/pair_code_edit_dist_train.txt',
    dev_filename='./data/edit_distance/pair_code_edit_dist_valid.txt',
    test_filename='./data/edit_distance/pair_code_edit_dist_test.txt',
    config_name='',
    tokenizer_name='',
    max_source_length=64,
    max_target_length=32,
    do_train=True,
    do_eval=True,
    do_test=True,
    do_lower_case=True,
    no_cuda=True,
    train_batch_size=8,
    eval_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    beam_size=10,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    num_train_epochs=3.0,
    max_steps=-1,
    eval_steps=-1,
    train_steps=-1,
    warmup_steps=0,
    local_rank=-1,
    seed=42,
)

# 예제 사용
print(args.model_type)  # 출력: roberta
print(args.do_train)    # 출력: True


codeT5
True


In [9]:
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
if config_class is not None:
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,do_lower_case=args.do_lower_case)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [10]:
def parse_and_continue_on_mismatch(code_str):
    code_dict = {}
    expected_line_number = 1
    parts = code_str.split()
    current_line = None
    for part in parts:
        if part.isdigit():
            line_number = int(part)
            if line_number == expected_line_number:
                current_line = line_number
                code_dict[current_line] = ""
                expected_line_number += 1
            else:
                code_dict[current_line] += part + " "
        else:
            if current_line is not None:
                code_dict[current_line] += part + " "
    for line in code_dict:
        code_dict[line] = code_dict[line].strip()
    return code_dict

In [37]:
code_dict = parse_and_continue_on_mismatch(train_examples[1].source)

In [38]:
import json
dictionary_print = json.dumps(code_dict, indent=4)
print(dictionary_print)

{
    "1": "#include <bits/stdc++.h>",
    "2": "#pragma comment(linker, \"/stack:200000000\")",
    "3": "#pragma GCC optimize(\"Ofast\")",
    "4": "#pragma GCC target(\"sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native\")",
    "5": "using namespace std;",
    "6": "inline long long rint() {",
    "7": "long long x = 0, f = 1;",
    "8": "char c = getchar();",
    "9": "while (c < '0' || c > '9') {",
    "10": "if (c == '-') f = -1;",
    "11": "c = getchar();",
    "12": "}",
    "13": "while (c >= '0' && c <= '9') {",
    "14": "x = x * 10 + c - '0';",
    "15": "c = getchar();",
    "16": "}",
    "17": "return x * f;",
    "18": "}",
    "19": "long long gcd(long long a, long long b) { return (b) ? gcd(b, a % b) : a; }",
    "20": "long long lcm(long long a, long long b) { return a / gcd(a, b) * b; }",
    "21": "long long pow(long long a, long long b, long long q) {",
    "22": "long long ans = 1;",
    "23": "while (b) {",
    "24": "if (b & 1) ans = ans * a % q;",
    "

In [13]:
code_dict_list = []
for i in range(len(train_examples)):
    code_dict_list.append(parse_and_continue_on_mismatch(train_examples[i].source))

In [14]:
print(code_dict_list[0])

{1: '#include <bits/stdc++.h>', 2: 'using namespace std;', 3: 'int ara[500008];', 4: 'set<int> s;', 5: 'set<int>::iterator it;', 6: 'int n;', 7: 'void check(int num) {', 8: 'for (int i = 1; i < n; i++) {', 9: 'if (!(ara[i] % num == 0 || ara[n + i] % num == 0)) {', 10: 'return;', 11: '}', 12: '}', 13: 'cout << num << endl;', 14: 'exit(0);', 15: '}', 16: 'int primefactor(int num) {', 17: 'int i;', 18: 'for (i = 2; i <= num / i; i++) {', 19: 'bool flag = 0;', 20: 'while (num % i == 0) {', 21: 'num /= i;', 22: 'if (!flag) {', 23: 'check(i);', 24: '}', 25: 'flag = 1;', 26: '}', 27: '}', 28: 'if (num > i) {', 29: 'check(num);', 30: '}', 31: '}', 32: 'int main() {', 33: 'scanf("%d", &n);', 34: 'for (int i = 0; i < n; i++) {', 35: 'scanf("%d%d", &ara[i], &ara[n + i]);', 36: '}', 37: 'primefactor(ara[0]);', 38: 'primefactor(ara[n]);', 39: 'cout << "-1" << endl;', 40: 'return 0;', 41: '}'}


In [15]:
def parse_single_line_to_dict(line_str):
    parts = line_str.split(' ', 1)
    line_number = parts[0]
    code = parts[1] if len(parts) > 1 else ""

    return {int(line_number): code}

target_dict_list = []
for i in range(len(train_examples)):
    target_dict_list.append(parse_single_line_to_dict(train_examples[i].target))

In [16]:
import copy
target_code_dict_list = copy.deepcopy(code_dict_list)

In [17]:
def update_values_from_dict(dict1, dict2):
    for key in dict1:
        if key in dict2:
            dict1[key] = dict2[key]
    return dict1

In [18]:
for i in range(len(target_code_dict_list)):
    target_code_dict_list[i] = update_values_from_dict(target_code_dict_list[i], target_dict_list[i])

In [19]:
print(code_dict_list[58])
print(target_code_dict_list[58])
print(target_dict_list[58])
print(train_examples[58 ].target)

{1: '#include <bits/stdc++.h>', 2: 'using namespace std;', 3: 'int main() {', 4: 'string a, b;', 5: 'getline(cin, a);', 6: 'getline(cin, b);', 7: 'sort(a.begin(), a.end());', 8: 'int z = 0;', 9: 'int idx = 0;', 10: 'for (int i = 0; i < a.size(); i++) {', 11: "if (a[i] == '0')", 12: 'z++;', 13: 'else {', 14: 'idx = i;', 15: 'break;', 16: '}', 17: '}', 18: 'string res;', 19: 'if (idx != 0) {', 20: 'res.push_back(a[idx]);', 21: 'idx++;', 22: '}', 23: 'for (int i = 0; i < z; i++) {', 24: "res.push_back('0');", 25: '}', 26: 'for (int i = idx; i < a.size(); i++) res.push_back(a[i]);', 27: 'if (res == b)', 28: 'cout << "OK" << endl;', 29: 'else', 30: 'cout << "WRONG_ANSWER" << endl;', 31: 'return 0;', 32: '}'}
{1: '#include <bits/stdc++.h>', 2: 'using namespace std;', 3: 'int main() {', 4: 'string a, b;', 5: 'getline(cin, a);', 6: 'getline(cin, b);', 7: 'sort(a.begin(), a.end());', 8: 'int z = 0;', 9: 'int idx = 0;', 10: 'for (int i = 0; i < a.size(); i++) {', 11: "if (a[i] == '0')", 12: 'z++

In [20]:
for i in range(10):
    print(train_examples[i].target)

28 if (num > 1) {
43 if (a[0] != 1) fac.push_back(a[0]);
49 if (arr[0].a > 1) s.insert(arr[0].a);
28 if (num >= i) {
38 printf("%lf\n", ans);
64 dfs(0);
64 dfs(0);
51 if (type[i] == 0)
38 sum[u] += (1 - p[v]);
62 sol += (dp[n - 1][curr][1] + dp[n - 1][curr][2]) % MOD;


In [21]:
target_line_numbers = []
for i in range(len(train_examples)):
    line_number = int(train_examples[i].target.split(' ', 1)[0])
    target_line_numbers.append(line_number)

In [22]:
print(len(target_line_numbers))

56051


In [23]:
print(target_line_numbers[0:10])

[28, 43, 49, 28, 38, 64, 64, 51, 38, 62]


In [24]:
def find_main_function_bounds(index, code_dict):
    start_line = None
    end_line = None
    brace_count = 0

    for line_number, code_line in code_dict.items():
        if 'main' in code_line and start_line is None:
            start_line = int(line_number)

        if start_line is not None:
            brace_count += code_line.count('{') - code_line.count('}')
            if brace_count == 0 and start_line is not None:
                end_line = int(line_number)
                break

    if start_line is None or end_line is None:
        print(f'{index} line error - start line {start_line}, end_line {end_line}, brace_count {brace_count}')
        return [-1, -1]

    return [start_line, end_line]

In [25]:
main_loc_list = []
count = 0
Nan_list = []
for index, code_dict in enumerate(code_dict_list):
    temp = find_main_function_bounds(index, code_dict)
    if temp[0] == -1:
        count += 1
        Nan_list.append(index)
    main_loc_list.append(temp)
print(f'오류 code 개수 : {count}')

4239 line error - start line 3, end_line None, brace_count 1
6332 line error - start line 17, end_line None, brace_count 1
7353 line error - start line 5, end_line None, brace_count 2
9277 line error - start line None, end_line None, brace_count 0
11504 line error - start line 4, end_line None, brace_count 1
12243 line error - start line 3, end_line None, brace_count 2
12645 line error - start line 6, end_line None, brace_count 1
12654 line error - start line 4, end_line None, brace_count 1
13882 line error - start line 4, end_line None, brace_count 1
14309 line error - start line 4, end_line None, brace_count 1
14792 line error - start line 2, end_line None, brace_count -1
14802 line error - start line 3, end_line None, brace_count 1
15250 line error - start line 32, end_line None, brace_count 1
15667 line error - start line 8, end_line None, brace_count 1
16519 line error - start line 3, end_line None, brace_count 1
17514 line error - start line 33, end_line None, brace_count 1
18841

In [26]:
print(Nan_list)

[4239, 6332, 7353, 9277, 11504, 12243, 12645, 12654, 13882, 14309, 14792, 14802, 15250, 15667, 16519, 17514, 18841, 18852, 19781, 20579, 21358, 21832, 21836, 21868, 24223, 25172, 25513, 25514, 25590, 25596, 25695, 25975, 27185, 27499, 27501, 27502, 27586, 27674, 27682, 27684, 27691, 28112, 28118, 28491, 29348, 30276, 30745, 31143, 31146, 31617, 32401, 34706, 35099, 35117, 36380, 38336, 39710, 40603, 40821, 40856, 42571, 43479, 44944, 46842, 47626, 49187, 49624, 50030, 51718, 51746, 52135, 53045, 54367, 54368, 54548, 54710, 55153, 55155]


In [27]:
Nan_list_modified = []
main_loc_list_modified = []
count = 0
Nan_list = []
for index, code_dict in enumerate(target_code_dict_list):
    temp = find_main_function_bounds(index, code_dict)
    if temp[0] == -1:
        count += 1
        Nan_list_modified.append(index)
    main_loc_list_modified.append(temp)
print(f'오류 code 개수 : {count}')

58 line error - start line 3, end_line None, brace_count 2
78 line error - start line 3, end_line None, brace_count 1
155 line error - start line 8, end_line None, brace_count 1
250 line error - start line 3, end_line None, brace_count 1
252 line error - start line 17, end_line None, brace_count 1
314 line error - start line 3, end_line None, brace_count 2
448 line error - start line 6, end_line None, brace_count 1
873 line error - start line 3, end_line None, brace_count 1
2443 line error - start line 3, end_line None, brace_count 2
2465 line error - start line None, end_line None, brace_count 0
2603 line error - start line 4, end_line None, brace_count 1
2609 line error - start line 27, end_line None, brace_count 1
2611 line error - start line 8, end_line None, brace_count 1
2620 line error - start line 3, end_line None, brace_count 1
2746 line error - start line 24, end_line None, brace_count 1
2973 line error - start line 3, end_line None, brace_count 1
3034 line error - start line

In [28]:
function_error_count = 0
main_error_count = 0
for i in range(len(target_line_numbers)):
    if main_loc_list[i][0] == main_loc_list[i][1] == -1:
        continue
    elif (main_loc_list[i][0] <= target_line_numbers[i]) & (target_line_numbers[i] <= main_loc_list[i][1]):
        main_error_count += 1
    else:
        function_error_count += 1
print(f'main error = {main_error_count}\nfunction error = {function_error_count}')
print(f'total sum = {main_error_count + function_error_count}')

main error = 39057
function error = 16916
total sum = 55973


In [29]:
function_error_count_modified = 0
main_error_count_modified = 0
for i in range(len(target_line_numbers)):
    if main_loc_list_modified[i][0] == main_loc_list_modified[i][1] == -1:
        continue
    elif (main_loc_list_modified[i][0] <= target_line_numbers[i]) & (target_line_numbers[i] <= main_loc_list_modified[i][1]):
        main_error_count_modified += 1
    else:
        function_error_count_modified += 1
print(f'main error = {main_error_count_modified}\nfunction error = {function_error_count_modified}')
print(f'total sum = {main_error_count_modified + function_error_count_modified}')

main error = 37813
function error = 16895
total sum = 54708


In [30]:
import os

cpp_file_path = '/content/drive/MyDrive/logicalErrorFix-2/model/cpp'

cpp_gold_path = os.path.join(cpp_file_path, 'cpp_gold')
cpp_output_path = os.path.join(cpp_file_path, 'cpp_output')
cpp_source_path = os.path.join(cpp_file_path, 'cpp_source')

In [31]:
from tqdm import tqdm

def cpp_file_create(file_path, dictionary_list, str):
  for index, dictionary in tqdm(enumerate(dictionary_list)):
    with open(os.path.join(file_path, f'{str}_{index}.cpp'), 'w') as f:
      for key, value in dictionary.items():
        f.write(value + '\n')
    if index == 1000:
      break

In [32]:
print(target_code_dict_list[0])

{1: '#include <bits/stdc++.h>', 2: 'using namespace std;', 3: 'int ara[500008];', 4: 'set<int> s;', 5: 'set<int>::iterator it;', 6: 'int n;', 7: 'void check(int num) {', 8: 'for (int i = 1; i < n; i++) {', 9: 'if (!(ara[i] % num == 0 || ara[n + i] % num == 0)) {', 10: 'return;', 11: '}', 12: '}', 13: 'cout << num << endl;', 14: 'exit(0);', 15: '}', 16: 'int primefactor(int num) {', 17: 'int i;', 18: 'for (i = 2; i <= num / i; i++) {', 19: 'bool flag = 0;', 20: 'while (num % i == 0) {', 21: 'num /= i;', 22: 'if (!flag) {', 23: 'check(i);', 24: '}', 25: 'flag = 1;', 26: '}', 27: '}', 28: 'if (num > 1) {', 29: 'check(num);', 30: '}', 31: '}', 32: 'int main() {', 33: 'scanf("%d", &n);', 34: 'for (int i = 0; i < n; i++) {', 35: 'scanf("%d%d", &ara[i], &ara[n + i]);', 36: '}', 37: 'primefactor(ara[0]);', 38: 'primefactor(ara[n]);', 39: 'cout << "-1" << endl;', 40: 'return 0;', 41: '}'}


In [33]:
print(cpp_source_path)

/content/drive/MyDrive/logicalErrorFix-2/model/cpp/cpp_source


In [34]:
cpp_file_create(cpp_source_path, code_dict_list, 'source')

1000it [00:10, 98.09it/s]


In [35]:
print(cpp_gold_path)

/content/drive/MyDrive/logicalErrorFix-2/model/cpp/cpp_gold


In [36]:
cpp_file_create(cpp_gold_path, target_code_dict_list, 'gold')

1000it [00:10, 99.52it/s]
