# GPT-2 Classifier with IMSLP LM pretraining & finetuning

In this notebook we train a GPT-2 classifier for the proxy task, initializing it from the IMSLP language model (pretrained on IMSLP, then finetuned on the target data). The language model is trained in `05_gpt2_lm.ipynb`.

This notebook is adapted from [this](https://towardsdatascience.com/fastai-with-transformers-bert-roberta-xlnet-xlm-distilbert-4f41ee18ecb2) blog post.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
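from pathlib import Path

import pandas as pd
import torch
from transformers import GPT2Model, GPT2DoubleHeadsModel, GPT2Config

import eval_models
# Project helpers (CustomTokenizer, TransformersVocab, GPT2Classifier, seed_all, ...);
# presumably also re-exports fastai v1 names such as TextDataBunch and Learner.
from train_utils import *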

In [None]:
import fastai
import transformers
import tokenizers
print('fastai version :', fastai.__version__)
print('transformers version :', transformers.__version__)
print('tokenizers version :', tokenizers.__version__)

In [None]:
torch.cuda.set_device(1)  # run on the second GPU (device index 1)

### Prep databunch

In [None]:
bs = 64
seed = 42
tok_model_dir = '/home/tjtsai/.fastai/data/bscore_lm/bpe_data/tokenizer_imslp'
max_seq_len = 512

In [None]:
cust_tok = CustomTokenizer(TransformersBaseTokenizer, tok_model_dir, max_seq_len)
transformer_base_tokenizer = TransformersBaseTokenizer(tok_model_dir, max_seq_len)
transformer_vocab = TransformersVocab(tokenizer=transformer_base_tokenizer._pretrained_tokenizer)

In [None]:
pad_idx = transformer_base_tokenizer._pretrained_tokenizer.token_to_id('<pad>')
cls_idx = transformer_base_tokenizer._pretrained_tokenizer.token_to_id('</s>')

In [None]:
bpe_path = Path('/home/tjtsai/.fastai/data/bscore_lm/bpe_data')
train_df = pd.read_csv(bpe_path/'train64.char.csv')
valid_df = pd.read_csv(bpe_path/'valid64.char.csv')
test_df = pd.read_csv(bpe_path/'test64.char.csv')

In [None]:
data_clas = TextDataBunch.from_df(bpe_path, train_df, valid_df, tokenizer=cust_tok, vocab=transformer_vocab,
                                  include_bos=False, include_eos=False, pad_first=False, pad_idx=pad_idx, 
                                  bs=bs, num_workers=1)
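With `pad_first=False`, padding is appended after the tokens rather than prepended. As we understand fastai v1's default text collate, each batch is padded only up to the length of its longest sequence. A minimal sketch of that behavior on plain lists of token ids (`pad_batch` and the id values are illustrative, not part of `train_utils`):

```python
from typing import List


def pad_batch(seqs: List[List[int]], pad_idx: int, pad_first: bool = False) -> List[List[int]]:
    """Pad every sequence to the longest length in the batch.

    pad_first=False appends pad tokens after the real tokens (as in the cell above);
    pad_first=True prepends them instead.
    """
    max_len = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        pad = [pad_idx] * (max_len - len(s))
        padded.append(pad + s if pad_first else s + pad)
    return padded


print(pad_batch([[5, 6, 7], [8, 9]], pad_idx=1))  # [[5, 6, 7], [8, 9, 1]]
```

Padding at the end keeps the real tokens at the start of each row, so a causal model such as GPT-2 processes them exactly as it would an unpadded sequence.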

### Train Classifier

In [None]:
model_class, config_class = GPT2Model, GPT2Config

In [None]:
lang_model_path = '/home/tjtsai/.fastai/data/bscore_lm/bpe_data/models/gpt2_train-imslp_finetune-target_lm'
config = config_class.from_pretrained(lang_model_path)
config.num_labels = data_clas.c

In [None]:
transformer_model = model_class.from_pretrained(lang_model_path, config=config)
gpt2_clas = GPT2Classifier(transformer_model, config, pad_idx, cls_idx)
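`GPT2Classifier` comes from `train_utils`, and its internals are not shown here. Because GPT-2 attention is causal, only the hidden state of the final real token has attended to the whole sequence, which is presumably why the constructor receives `pad_idx` and `cls_idx`. The sketch below is an assumption about how such a classifier could choose the hidden state to pool: the last `cls_idx` token if one is present, otherwise the last non-pad token. `pooled_position` and `pool` are hypothetical helpers.

```python
from typing import List, Sequence


def pooled_position(token_ids: Sequence[int], pad_idx: int, cls_idx: int) -> int:
    """Hypothetical: index of the token whose hidden state feeds the classifier head."""
    # Prefer the last occurrence of the classification token.
    for i in range(len(token_ids) - 1, -1, -1):
        if token_ids[i] == cls_idx:
            return i
    # Otherwise fall back to the last non-padding token.
    for i in range(len(token_ids) - 1, -1, -1):
        if token_ids[i] != pad_idx:
            return i
    raise ValueError("sequence contains only padding")


def pool(hidden: List[List[float]], token_ids: Sequence[int], pad_idx: int, cls_idx: int) -> List[float]:
    """Select one hidden vector from a (seq_len x hidden_size) list of lists."""
    return hidden[pooled_position(token_ids, pad_idx, cls_idx)]


print(pooled_position([10, 11, 2, 1, 1], pad_idx=1, cls_idx=2))  # 2
print(pooled_position([10, 11, 1, 1], pad_idx=1, cls_idx=2))     # 1
```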

In [None]:
# learner.destroy()
# torch.cuda.empty_cache()

In [None]:
learner = Learner(data_clas, gpt2_clas, metrics=[accuracy, FBeta(average='macro', beta=1)])
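`FBeta(average='macro', beta=1)` reports a macro-averaged F1 score: F1 is computed for each class separately and then averaged without weights, so rare classes count as much as common ones. A plain-Python sketch (`macro_f1` is an illustrative helper; fastai guards divisions with a small epsilon rather than the explicit zero checks used here):

```python
from typing import Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int], n_classes: int) -> float:
    """Unweighted mean of per-class F1 scores (beta = 1)."""
    f1s = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / n_classes


# Class 0: F1 = 2/3, class 1: F1 = 0.8, so macro F1 = 0.7333 (accuracy would be 0.75).
print(round(macro_f1([0, 0, 1, 1], [0, 1, 1, 1], n_classes=2), 4))  # 0.7333
```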

In [None]:
list_layers = [learner.model.transformer.wte, 
               learner.model.transformer.wpe, 
               learner.model.transformer.h[0],
               learner.model.transformer.h[1],
               learner.model.transformer.h[2],
               learner.model.transformer.h[3],
               learner.model.transformer.h[4],
               learner.model.transformer.h[5],
               learner.model.transformer.ln_f]

In [None]:
learner.split(list_layers)
print(learner.layer_groups)
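`learner.split(list_layers)` starts a new layer group at each listed module. A simplified sketch of the grouping logic, modeled on fastai v1's `split_model`, applied to a flat list of layer names (a real model flattens into many more leaf modules per transformer block; `split_at` and the names are placeholders):

```python
from typing import List, Sequence


def split_at(layers: Sequence[str], split_on: Sequence[str]) -> List[List[str]]:
    """Group `layers` so that a new group begins at each name in `split_on`."""
    idxs = [list(layers).index(name) for name in split_on]
    if idxs[0] != 0:              # layers before the first split point form their own group
        idxs = [0] + idxs
    if idxs[-1] != len(layers):   # the last group runs to the end of the model
        idxs = idxs + [len(layers)]
    return [list(layers[i:j]) for i, j in zip(idxs[:-1], idxs[1:])]


flat = ["wte", "wpe", "h0", "h1", "ln_f", "head"]
print(split_at(flat, ["wte", "wpe", "h0", "h1", "ln_f"]))
# [['wte'], ['wpe'], ['h0'], ['h1'], ['ln_f', 'head']]
```

Under this logic, anything that follows `ln_f` in the flattened model, presumably including the classification head, ends up in the last group.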

In [None]:
seed_all(seed)

In [None]:
learner.freeze_to(-1)
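`freeze_to(n)` freezes layer groups `[:n]` and leaves groups `[n:]` trainable, following Python slice semantics, so `freeze_to(-1)` trains only the last group. A sketch that reports which groups would be trainable (`trainable_groups` is a hypothetical helper; fastai also has special handling for BatchNorm layers, which GPT-2 does not use):

```python
from typing import List


def trainable_groups(n_groups: int, n: int) -> List[bool]:
    """Which of n_groups layer groups freeze_to(n) leaves trainable: groups[n:]."""
    trainable = set(range(n_groups)[n:])
    return [i in trainable for i in range(n_groups)]


print(trainable_groups(9, -1))
# [False, False, False, False, False, False, False, False, True]
```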

In [None]:
learner.summary()

In [None]:
learner.lr_find()

In [None]:
learner.recorder.plot(suggestion=True)
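`lr_find` runs an LR range test: it trains briefly while increasing the learning rate exponentially (fastai v1 sweeps from 1e-7 to 10 over 100 iterations by default), and `suggestion=True` marks roughly where the loss falls fastest. A simplified sketch of both pieces; fastai computes the slope with `np.gradient` and skips points at both ends of the curve, which this version omits. `exp_lr_schedule` and `steepest_descent_lr` are illustrative names.

```python
from typing import List, Sequence


def exp_lr_schedule(start_lr: float = 1e-7, end_lr: float = 10.0, num_it: int = 100) -> List[float]:
    """Learning rates spaced geometrically from start_lr to end_lr."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]


def steepest_descent_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """LR at the start of the steepest loss drop (most negative forward difference)."""
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    best = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[best]


lrs = [1e-3, 1e-2, 1e-1, 1.0, 10.0]
losses = [5.0, 4.9, 4.0, 3.9, 6.0]
print(steepest_descent_lr(lrs, losses))  # 0.01
```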

In [None]:
lr = 3e-4

In [None]:
learner.fit_one_cycle(4, lr, moms=(0.8,0.7))
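`fit_one_cycle` follows the one-cycle policy. The sketch below models fastai v1's defaults as we understand them (`pct_start=0.3`, `div_factor=25`, final LR `max_lr / (25 * 1e4)`): the LR warms up from `max_lr / 25` to `max_lr` with cosine annealing over the first 30% of iterations and then anneals down, while momentum moves from `moms[0]` to `moms[1]` and back. `one_cycle` is an illustrative helper, not the fastai API.

```python
import math
from typing import List, Tuple


def cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct = 0) to end (pct = 1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle(n_iter: int, max_lr: float, moms: Tuple[float, float],
              pct_start: float = 0.3, div_factor: float = 25.0,
              final_div: float = 25.0 * 1e4) -> List[Tuple[float, float]]:
    """(learning rate, momentum) for each training iteration."""
    a1 = int(n_iter * pct_start)   # warm-up iterations
    a2 = n_iter - a1               # annealing iterations
    sched = []
    for i in range(n_iter):
        if i < a1:
            pct = i / a1
            lr = cos_anneal(max_lr / div_factor, max_lr, pct)
            mom = cos_anneal(moms[0], moms[1], pct)
        else:
            pct = (i - a1) / a2
            lr = cos_anneal(max_lr, max_lr / final_div, pct)
            mom = cos_anneal(moms[1], moms[0], pct)
        sched.append((lr, mom))
    return sched


sched = one_cycle(100, 3e-4, (0.8, 0.7))
print([round(x, 6) for x in sched[30]])  # LR and momentum at the peak of the cycle
```

With `moms=(0.8, 0.7)` momentum dips to 0.7 while the LR peaks; with `moms=(0.8, 0.9)`, as passed in the next two cells, momentum instead rises to 0.9 at the LR peak.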

In [None]:
learner.freeze_to(-2)
learner.fit_one_cycle(3, slice(lr/(2.6**4),lr), moms=(0.8, 0.9))
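Passing `slice(lr/(2.6**4), lr)` gives discriminative learning rates: fastai v1 spreads one LR per layer group geometrically between the slice endpoints (via its `even_mults` helper), and frozen groups simply do not use theirs. A plain-Python re-implementation (the `n == 1` guard is our addition):

```python
from typing import List


def even_mults(start: float, stop: float, n: int) -> List[float]:
    """n learning rates spaced geometrically from start to stop, one per layer group."""
    if n == 1:  # guard added for this sketch; avoids dividing by zero
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


lr = 3e-4
group_lrs = even_mults(lr / 2.6 ** 4, lr, 9)
print(group_lrs[-2:])  # the rates used by the two groups unfrozen by freeze_to(-2)
```

If `learner.layer_groups` has 9 groups (which `print(learner.layer_groups)` above can confirm), adjacent groups differ by a factor of `2.6 ** (4/8)`, roughly 1.61, rather than 2.6, because the `2.6**4` span is divided across 8 steps.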

In [None]:
learner.freeze_to(-3)
learner.fit_one_cycle(1, slice(lr/10/(2.6**4),lr/10), moms=(0.8, 0.9))

In [None]:
learner.save('gpt2_train-imslp_finetune-target_clas')
# learner.load('gpt2_train-imslp_finetune-target_clas')  # reload the weights saved above

### Evaluate Classifier

Evaluate on the proxy task: classifying fixed-length chunks of bootleg score features.

In [None]:
data_clas_test = TextDataBunch.from_df(bpe_path, train_df, test_df, tokenizer=cust_tok, vocab=transformer_vocab,
                                       include_bos=False, include_eos=False, pad_first=False, pad_idx=pad_idx,
                                       bs=bs, num_workers=1)

In [None]:
learner.validate(data_clas_test.valid_dl)

Evaluate on the original task: classifying pages of sheet music. We can evaluate our models in two ways:
- applying the model to the full variable-length sequence
- applying the model to multiple fixed-length windows and averaging the predictions

First, we evaluate the model on variable-length inputs and report results with and without applying priors.
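`eval_models.calcAccuracy_fullpage` is project code that is not shown here, so this notebook does not spell out what "applying priors" does. One common interpretation, and only an assumption, is to reweight the model's class probabilities by a class prior estimated from label frequencies (for example in `train_fullpage_df`) and renormalize. `class_prior` and `apply_prior` are hypothetical helpers:

```python
from collections import Counter
from typing import Dict, Sequence


def class_prior(labels: Sequence[int]) -> Dict[int, float]:
    """Relative frequency of each class label."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {c: n / total for c, n in counts.items()}


def apply_prior(probs: Dict[int, float], prior: Dict[int, float]) -> Dict[int, float]:
    """Hypothetical: multiply class probabilities by the prior, then renormalize."""
    weighted = {c: p * prior.get(c, 0.0) for c, p in probs.items()}
    z = sum(weighted.values())
    return {c: w / z for c, w in weighted.items()} if z > 0 else dict(probs)


prior = class_prior([0, 0, 0, 1])
print(apply_prior({0: 0.5, 1: 0.5}, prior))  # {0: 0.75, 1: 0.25}
```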

In [None]:
train_fullpage_df = pd.read_csv(bpe_path/'train.fullpage.char.csv')
valid_fullpage_df = pd.read_csv(bpe_path/'valid.fullpage.char.csv')
test_fullpage_df = pd.read_csv(bpe_path/'test.fullpage.char.csv')

In [None]:
data_clas_test = TextDataBunch.from_df(bpe_path, train_fullpage_df, valid_fullpage_df, test_fullpage_df,
                                       tokenizer=cust_tok, vocab=transformer_vocab, include_bos=False, 
                                       include_eos=False, pad_first=False, pad_idx=pad_idx, bs=bs, num_workers=1)

In [None]:
(acc, acc_with_prior), (f1, f1_with_prior) = eval_models.calcAccuracy_fullpage(
    learner, bpe_path, train_fullpage_df, valid_fullpage_df, test_fullpage_df,
    databunch=data_clas_test)
(acc, acc_with_prior), (f1, f1_with_prior)

Now we evaluate the model by considering multiple fixed-length windows and averaging the predictions.
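In this notebook the windows are precomputed in `test.ensemble64.char.csv`. The sketch below only illustrates the idea: cut a page's token sequence into fixed-length, possibly overlapping windows, average the per-window class probabilities, and take the argmax. The window length, stride, and `predict_proba` stand-in are hypothetical; the actual values used to build the CSV are not shown here.

```python
from typing import Callable, List, Sequence


def sliding_windows(seq: Sequence[int], win_len: int, stride: int) -> List[List[int]]:
    """Fixed-length windows over seq; a sequence no longer than win_len yields one window."""
    if len(seq) <= win_len:
        return [list(seq)]
    starts = list(range(0, len(seq) - win_len + 1, stride))
    if starts[-1] + win_len < len(seq):  # add a final window so the tail is covered
        starts.append(len(seq) - win_len)
    return [list(seq[s:s + win_len]) for s in starts]


def ensemble_predict(seq: Sequence[int], predict_proba: Callable[[List[int]], List[float]],
                     win_len: int, stride: int) -> int:
    """Average the class probabilities over all windows and return the best class."""
    probs = [predict_proba(w) for w in sliding_windows(seq, win_len, stride)]
    n_classes = len(probs[0])
    avg = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=lambda c: avg[c])


def toy(w: List[int]) -> List[float]:
    # Stand-in for the classifier: favors class 0 only for the first window.
    return [0.9, 0.1] if w[0] == 0 else [0.2, 0.8]


print(ensemble_predict(list(range(10)), toy, win_len=4, stride=3))  # 1
```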

In [None]:
test_ensemble_df = pd.read_csv(bpe_path/'test.ensemble64.char.csv')

In [None]:
data_clas_test = TextDataBunch.from_df(bpe_path, train_fullpage_df, valid_fullpage_df, test_ensemble_df,
                                       text_cols='text', label_cols='label', tokenizer=cust_tok,
                                       vocab=transformer_vocab, include_bos=False, include_eos=False,
                                       pad_first=False, pad_idx=pad_idx, bs=bs, num_workers=1)

In [None]:
(acc, acc_with_prior), (f1, f1_with_prior) = eval_models.calcAccuracy_fullpage(
    learner, bpe_path, train_fullpage_df, valid_fullpage_df, test_ensemble_df,
    databunch=data_clas_test, ensembled=True)
(acc, acc_with_prior), (f1, f1_with_prior)

### Error Analysis

In [None]:
interp = ClassificationInterpretation.from_learner(learner)

In [None]:
interp.plot_confusion_matrix(figsize=(12,12))
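The plotted matrix has actual classes on the rows and predicted classes on the columns. A plain-Python sketch of how such a matrix is counted, plus a helper that lists the largest off-diagonal cells (fastai v1's `interp.most_confused()` returns a similar list using class names; `confusion_matrix` and `most_confused` here are illustrative re-implementations on integer labels):

```python
from typing import List, Sequence, Tuple


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], n_classes: int) -> List[List[int]]:
    """cm[i][j] counts examples whose true class is i and predicted class is j."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def most_confused(cm: List[List[int]], min_val: int = 1) -> List[Tuple[int, int, int]]:
    """Off-diagonal (actual, predicted, count) triples, most frequent first."""
    n = len(cm)
    pairs = [(i, j, cm[i][j]) for i in range(n) for j in range(n)
             if i != j and cm[i][j] >= min_val]
    return sorted(pairs, key=lambda x: -x[2])


cm = confusion_matrix([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 0, 0], n_classes=3)
print(most_confused(cm))  # [(2, 0, 2), (0, 1, 1)]
```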