In [None]:
# Uncomment this cell when running on Colab
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/CodeAnnotation
# !pip install transformers

## Import

In [None]:
from __future__ import absolute_import
import os

import torch
import torch.nn as nn
from Trainer import Trainer
from transformers import RobertaModel, T5ForConditionalGeneration
from transformers import (AdamW, get_linear_schedule_with_warmup)

from utils.common import read_data, write_to_file, save_checkpoint
from utils.DataProcess import Processor
from Config import config
from Model import Seq2Seq

## Initialization

In [None]:
# Use the GPU when torch reports one as available, otherwise the CPU, and
# store the chosen device on the config object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config.device = device

# Make sure the output directory exists. exist_ok=True turns this into a no-op
# when the directory is already present and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(config.output_dir, exist_ok=True)

In [None]:
# Print the configured model type and model path.
print(config.model_type, config.model_path)

## Data Read and Process

In [None]:
# Build the project data processor from the configuration.
processor = Processor(config)

# Read the training data and run it through the processor in 'train' mode.
# NOTE(review): the result is presumably a DataLoader-like object (len() is
# taken on it later and it is handed to trainer.train) — confirm against
# Processor.__call__.
train_data = read_data(config.train_data_path)
train_loader = processor(train_data, config.train_params, 'train')

# Same steps for the validation data, in 'eval' mode.
val_data = read_data(config.val_data_path)
val_loader = processor(val_data, config.val_params, 'eval')

## Modeling

In [None]:
# The tokenizer is taken from the processor; its CLS/SEP token ids are used
# as the start/end-of-sequence ids for the Seq2Seq model below.
tokenizer = processor.tokenizer

if config.model_type.lower() == 'codet5':
    # CodeT5 is loaded as a complete encoder-decoder model.
    model = T5ForConditionalGeneration.from_pretrained(config.model_path)
else:
    # Encoder: pretrained RoBERTa-style model loaded from config.model_path.
    encoder = RobertaModel.from_pretrained(config.model_path)
    model_config = encoder.config
    # Decoder: a freshly initialised 6-layer Transformer decoder whose width
    # and head count are taken from the encoder's configuration.
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=model_config.hidden_size,
        nhead=model_config.num_attention_heads,
    )
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
    # Seq2Seq: project wrapper combining the encoder and decoder.
    model = Seq2Seq(encoder=encoder, decoder=decoder, config=model_config,
                    beam_size=config.beam_size, max_length=config.max_target_length,
                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)

if config.load_model_path is not None:
    print("reload model from {}".format(config.load_model_path))
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can still be restored on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(config.load_model_path, map_location=device))

model.to(device)
print('model built')

## Initialize Trainer

In [None]:
# Number of training epochs; also used by the training loop and the plots.
epoch = config.epoch

# Parameters whose names contain any of these substrings get no weight decay;
# every other parameter uses config.weight_decay.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': config.weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0}
]

# Use PyTorch's AdamW: the transformers.AdamW implementation is deprecated
# upstream in favour of torch.optim.AdamW. lr, eps and the per-group
# weight_decay are all passed explicitly, so the library defaults for those
# values do not come into play.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate, eps=config.adam_epsilon)

# Linear warmup for config.warmup_steps steps followed by linear decay over
# the total number of steps (epochs * batches per epoch).
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=config.warmup_steps,
    num_training_steps=(epoch * len(train_loader))
)

trainer = Trainer(config, processor, model, optimizer, scheduler, device)

## Train and Validate

In [None]:
# Per-epoch training loss, and the concatenation of the loss lists returned
# by each trainer.train call.
all_train_loss_list, all_train_loss_list_detail = [], []
# Validation metrics collected once per epoch.
val_ppl_list, val_bleu_list = [], []
# Start from +inf (rather than an arbitrary large constant) so the first
# epoch's perplexity always becomes the initial best, however large it is,
# and a best-ppl checkpoint is always written.
best_ppl = float('inf')
best_bleu = 0

for e in range(1, epoch + 1):
    print('-' * 20 + ' ' + 'Epoch ' + str(e) + ' ' + '-' * 20)
    train_loss, train_loss_list = trainer.train(train_loader)
    all_train_loss_list.append(train_loss)
    all_train_loss_list_detail.extend(train_loss_list)
    print('train loss: {}'.format(train_loss))

    val_ppl, val_bleu = trainer.valid(val_loader)
    val_ppl_list.append(val_ppl)
    val_bleu_list.append(val_bleu)
    print('val ppl: {}, val bleu: {}'.format(val_ppl, val_bleu))

    # Always overwrite the "last" checkpoint with the current weights.
    save_checkpoint(model, config.output_dir, 'checkpoint-last')

    # Lower perplexity is better: keep the checkpoint with the smallest value.
    if val_ppl < best_ppl:
        best_ppl = val_ppl
        save_checkpoint(model, config.output_dir, 'checkpoint-best-ppl')
        print('update best ppl: {} and model saved'.format(val_ppl))

    # Higher BLEU is better: keep the checkpoint with the largest value.
    if val_bleu > best_bleu:
        best_bleu = val_bleu
        save_checkpoint(model, config.output_dir, 'checkpoint-best-bleu')
        print('update best bleu: {} and model saved'.format(best_bleu))

    print()

## Plot

In [None]:
import matplotlib.pyplot as plt


def _plot_curve(values, ylabel, ticks, labels):
    """Plot ``values`` as a line, label the x axis 'epoch', and show the figure.

    Args:
        values: sequence of numbers to plot.
        ylabel: text for the y axis.
        ticks: x positions at which tick labels are drawn.
        labels: tick labels, one per entry in ``ticks``.
    """
    plt.plot(values)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.xticks(ticks, labels)
    plt.show()


# Per-epoch curves have one point per epoch, labelled 1..epoch.
epoch_ticks = list(range(epoch))
epoch_labels = list(range(1, epoch + 1))

_plot_curve(all_train_loss_list, 'loss', epoch_ticks, epoch_labels)

# The detailed loss curve gets one tick at the start of each epoch. The step
# is clamped to at least 1 so an empty or very short loss list cannot produce
# a zero step (which makes range() raise), and the labels are derived from
# the ticks so the two always have the same length.
steps_per_epoch = max(1, len(all_train_loss_list_detail) // max(1, epoch))
detail_ticks = list(range(0, len(all_train_loss_list_detail), steps_per_epoch))
detail_labels = list(range(len(detail_ticks)))
_plot_curve(all_train_loss_list_detail, 'loss detail', detail_ticks, detail_labels)

_plot_curve(val_ppl_list, 'valid ppl', epoch_ticks, epoch_labels)

_plot_curve(val_bleu_list, 'valid bleu', epoch_ticks, epoch_labels)

## Predict

In [None]:
# Read the test data and run it through the processor in 'test' mode.
test_examples = read_data(config.test_data_path)
test_loader = processor(test_examples, config.test_params, 'test')
# Run prediction with the trainer, decode the predicted ids via the processor,
# and write the result to 'test.output' under the output directory.
# NOTE(review): this uses whatever weights `model` currently holds (presumably
# the last epoch's, not a best checkpoint) — confirm this is intended.
pred_ids = trainer.predict(test_loader)
pred_texts = processor.decode(pred_ids)
write_to_file(config.output_dir, pred_texts, 'test.output')