In [None]:
"""
Directory Structure and Contents
tools/: Contains the core implementation of the seq2edit GEC (Grammatical Error Correction) model.
data/: Stores the data and models required for training and prediction.
data/exp/: Saves the weights and log files generated during model training.
data/modelinput/: Stores the training set, test set, and prediction results.
data/codebert/: Contains the pre-trained CodeBERT model.
"""

In [None]:
# Requirements
# Each line uses the %pip magic rather than a bare `pip install ...`:
#   * automagic (calling a magic without the leading %) is only applied to
#     single-line cells, so bare `pip` lines in this multi-line cell would be
#     parsed as Python and fail with a SyntaxError;
#   * %pip installs into the environment of the running kernel.
# Restart the kernel after installation so the new versions are picked up.
%pip install allennlp==1.3.0
%pip install ltp==4.1.3.post1
%pip install OpenCC==1.1.2
%pip install pyarrow
%pip install python-Levenshtein==0.12.1
%pip install numpy==1.21.1
%pip install transformers

In [None]:
# Select which trained model variant to use and stage its files at the fixed
# locations the prediction cell reads from.
import shutil

# Define the model type
modeltype = 'seq'  # Options: 'seq', 'span', 'word'

# Fail fast on a typo instead of silently building paths to files that do not exist.
if modeltype not in ('seq', 'span', 'word'):
    raise ValueError(
        f"Unknown modeltype {modeltype!r}; expected one of 'seq', 'span', 'word'"
    )

# Specify the model file path
MODEL_FILE = './model/mode_' + modeltype + '.th'

# Specify the label file path
LABEL_FILE = './data/labels/labels_' + modeltype + '.txt'

# Copy (the originals are kept) the model and label files to the locations used
# for prediction. shutil.copyfile raises if the source is missing or the
# destination directory does not exist, whereas a failing `!cp` only prints an
# error and lets the following cells run against stale files.
shutil.copyfile(MODEL_FILE, './data/exp/Best_code_gec.th')
shutil.copyfile(LABEL_FILE, './data/labels/labels.txt')

In [None]:
# Run prediction with the staged checkpoint.
# Flag meanings below are inferred from the flag names and paths; see
# predict.py for the authoritative description.
#   --model_path    checkpoint file to load
#   --weights_name  CodeBERT encoder weights location
#   --vocab_path    vocabulary (labels) directory
#   --input_file    source file to run the model on
#   --output_file   file the predictions are written to
!python predict.py \
    --model_path ./data/exp/Best_code_gec.th \
    --weights_name ./data/codebert \
    --vocab_path ./data/labels \
    --input_file ./data/modelinput/test.src.char \
    --output_file ./data/modelinput/test.pre.char

In [None]:
# Calculate the exact-match accuracy of the prediction results: a prediction
# counts as correct when it equals the target on the same line after stripping
# leading/trailing whitespace.
with open("./data/modelinput/test.pre.char", 'r', encoding='utf-8') as f:
    predictionlines = f.readlines()  # Read the prediction results

with open("./data/modelinput/test.tgt.char", 'r', encoding='utf-8') as f:
    targetlines = f.readlines()  # Read the target (ground truth) results

# Print the number of prediction lines and target lines
print(len(predictionlines))
print(len(targetlines))

# The two files are compared line by line, so differing lengths mean the
# result is suspect. Report it instead of crashing with an IndexError.
if len(predictionlines) != len(targetlines):
    print('Warning: prediction and target files have different line counts; '
          'missing predictions are counted as incorrect and extra predictions are ignored.')

# Count the matching pairs. zip stops at the shorter list, so targets without
# a prediction never add to the correct count.
correct = sum(
    target.strip() == prediction.strip()
    for target, prediction in zip(targetlines, predictionlines)
)

# Calculate and print the accuracy (defined as 0.0 when there are no targets,
# which avoids a ZeroDivisionError on an empty target file)
accuracy = correct / len(targetlines) if targetlines else 0.0
print(f'Accuracy: {accuracy:.8f}')


In [None]:
# Train the model.
#
# Command-line options passed to train.py:
#   --train_set          training dataset
#   --dev_set            validation dataset
#   --model_dir          model output directory
#   --model_name         model filename
#   --vocab_path         vocabulary file path
#   --batch_size         batch size for training
#   --n_epoch            number of training epochs
#   --lr                 learning rate
#   --accumulation_size  number of operations before updating gradients
#   --weights_name       CodeBERT encoder weights location
#   --pretrain_folder    pre-trained model location
#   --pretrain           pre-trained model name
#   --seed               random seed
!python train.py \
    --train_set ./data/modelinput/train.label \
    --dev_set ./data/modelinput/valid.label \
    --model_dir ./data/exp \
    --model_name Gec_Model.th \
    --vocab_path ./data/labels \
    --batch_size 16 \
    --n_epoch 150 \
    --lr 1e-5 \
    --accumulation_size 4 \
    --weights_name ./data/codebert \
    --pretrain_folder ./data/exp \
    --pretrain "Best_code_gec" \
    --seed 1