# RoBERTa Classifier with target LM pretraining

In this notebook we train a RoBERTa classifier for the proxy task, initializing it from the pretrained target language model. The language model is trained in `04_roberta_lm.ipynb`.

This notebook is adapted from [this](https://towardsdatascience.com/fastai-with-transformers-bert-roberta-xlnet-xlm-distilbert-4f41ee18ecb2) blog post.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.text import *
from transformers import RobertaForSequenceClassification, RobertaConfig
import eval_models
from train_utils import *

In [None]:
import fastai
import transformers
import tokenizers

# Print the version of each key library so results can be tied to an environment.
for lib in (fastai, transformers, tokenizers):
    print(lib.__name__ + ' version :', lib.__version__)

In [None]:
# Prefer the second GPU (index 1). On hosts with fewer GPUs fall back to the
# last available device, and skip device selection entirely when CUDA is not
# available, instead of raising.
gpu_id = 1
if torch.cuda.is_available():
    torch.cuda.set_device(min(gpu_id, torch.cuda.device_count() - 1))

### Prep databunch

In [None]:
# Batch size passed to every TextDataBunch built in this notebook.
bs = 64
# Seed passed to seed_all() before training.
seed = 42
# Directory handed to the tokenizer wrappers (CustomTokenizer / TransformersBaseTokenizer).
tok_model_dir = '/home/dyang/InstrumentID/tokenizer/roberta_tok/shift0'
# Maximum sequence length handed to the tokenizer wrappers.
max_seq_len = 256

In [None]:
# Tokenizer object passed to fastai's TextDataBunch (tokenizer=cust_tok).
cust_tok = CustomTokenizer(TransformersBaseTokenizer, tok_model_dir, max_seq_len)
# Stand-alone base tokenizer; its _pretrained_tokenizer is used for the vocab and the <pad> lookup.
transformer_base_tokenizer = TransformersBaseTokenizer(tok_model_dir, max_seq_len)
# Vocab object passed to TextDataBunch (vocab=transformer_vocab), built on the same pretrained tokenizer.
transformer_vocab =  TransformersVocab(tokenizer = transformer_base_tokenizer._pretrained_tokenizer)

In [None]:
# Id of the '<pad>' token in the pretrained tokenizer; used as the padding index for the databunches and the model wrapper.
pad_idx = transformer_base_tokenizer._pretrained_tokenizer.token_to_id('<pad>')

In [None]:
# Directory holding the CSV splits; read the frag64 train/valid/test frames from it.
bpe_path = Path('/home/dyang/InstrumentID/train_data')
train_df, valid_df, test_df = (
    pd.read_csv(bpe_path/f'{split}_df-frag64.char.csv')
    for split in ('train', 'valid', 'test')
)

In [None]:
# Build the fastai DataBunch for the proxy task from the frag64 train and validation
# frames. fastai's own BOS/EOS insertion is switched off (include_bos/include_eos are
# False) and padding is appended after the tokens (pad_first=False) using pad_idx.
data_clas = TextDataBunch.from_df(bpe_path, train_df, valid_df, tokenizer=cust_tok, vocab=transformer_vocab,
                                  include_bos=False, include_eos=False, pad_first=False, pad_idx=pad_idx, 
                                  bs=bs, num_workers=1)

In [None]:
# Display the training frame as a quick sanity check of what was loaded.
train_df

### Train Classifier

In [None]:
# Hugging Face classes used to build the classifier and its configuration.
model_class, config_class = RobertaForSequenceClassification, RobertaConfig
# Checkpoint directory the config and weights are loaded from below.
# presumably the output of the language-model notebook — confirm.
model_path = '/home/dyang/.fastai/data/bscore_lm/bpe_data/models/roberta_train-target_lm'

In [None]:
# Load the checkpoint's configuration and size the classification head to the
# databunch's class count (data_clas.c).
config = config_class.from_pretrained(model_path)
config.num_labels = data_clas.c

In [None]:
# Seed the RNGs *before* building the model. The sequence-classification head is
# not part of a language-model checkpoint, so it is freshly (randomly) initialised
# here; seeding only after this point would leave that initialisation
# non-reproducible. (seed_all comes from train_utils and is called the same way
# later in the notebook; calling it here as well is harmless.)
seed_all(seed)
transformer_model = model_class.from_pretrained(model_path, config = config)
# Wrap the Hugging Face model, together with the padding index, for use as the fastai Learner's model.
custom_transformer_model = RobertaModelWrapper(transformer_model, pad_idx)

In [None]:
# learner.destroy()
# torch.cuda.empty_cache()

In [None]:
# fastai Learner over the proxy-task databunch, reporting accuracy and macro-averaged F1 (FBeta with beta=1).
learner = Learner(data_clas, custom_transformer_model, metrics=[accuracy, FBeta(average='macro', beta=1)])

In [None]:
# Split points for the layer groups: the embeddings, encoder layers 0-5
# individually, and the pooler.
roberta = learner.model.transformer.roberta
list_layers = [
    roberta.embeddings,
    *(roberta.encoder.layer[idx] for idx in range(6)),
    roberta.pooler,
]

In [None]:
# Split the model into layer groups at the modules listed above, then print the
# resulting groups to verify the split.
learner.split(list_layers)
print(learner.layer_groups)

In [None]:
# Seed the RNGs before training (seed_all comes from train_utils; presumably it seeds
# python/numpy/torch — confirm).
seed_all(seed)

In [None]:
# Freeze every layer group except the last one, so only the final group is trained at first.
learner.freeze_to(-1)

In [None]:
#learner.summary()

In [None]:
# Run fastai's learning-rate range test; the results are plotted in the next cell.
learner.lr_find()

In [None]:
# Plot the recorded loss against learning rate and mark fastai's suggested value.
learner.recorder.plot(suggestion=True)

In [None]:
# Base learning rate for the training stages below.
# presumably chosen from the LR-finder plot — confirm.
lr = 3e-4

In [None]:
# Stage 1: two one-cycle epochs at the base learning rate.
learner.fit_one_cycle(2, lr, moms=(0.8,0.7))

In [None]:
# Stage 1b: same frozen configuration, two more epochs at a much smaller rate
# (3e-6 instead of lr). freeze_to(-1) was already applied earlier in the notebook,
# so repeating it here does not change which groups are frozen.
learner.freeze_to(-1)
learner.fit_one_cycle(2, 3e-6, moms=(0.8,0.7))

In [None]:
# Stage 2: leave the last two layer groups unfrozen and train with a learning-rate
# slice from lr/2.6**4 up to lr.
learner.freeze_to(-2)
learner.fit_one_cycle(2, slice(lr/(2.6**4),lr), moms=(0.8, 0.9))

In [None]:
# Stage 3: leave the last three layer groups unfrozen and halve both ends of the
# learning-rate slice used in the previous stage.
learner.freeze_to(-3)
learner.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), moms=(0.8, 0.9))

In [None]:
# Persist the fine-tuned classifier weights.
# NOTE(review): the checkpoint name says "unlabeled" and "shift3", while this notebook
# loads 'roberta_train-target_lm' and the 'shift0' tokenizer — confirm the name is intended.
learner.save('roberta_unlabeled-frag64-shift3')
#learner.load('final_models/roberta_train-target_clas')

### Evaluate Classifier

Evaluate on the proxy task -- classifying fixed-length chunks of bootleg score features.

In [None]:
# Databunch for proxy-task testing: test_df is passed in the validation slot so
# that its valid_dl can be scored with learner.validate below.
data_clas_test = TextDataBunch.from_df(bpe_path, train_df, test_df, tokenizer=cust_tok, vocab=transformer_vocab,
                                  include_bos=False, include_eos=False, pad_first=False, pad_idx=pad_idx, 
                                  bs=bs, num_workers=1)

In [None]:
# Score the learner on the fragment-level test set (held in the databunch's validation loader).
learner.validate(data_clas_test.valid_dl)

Evaluate on the original task -- classifying pages of sheet music.  We can evaluate our models in two ways:
- applying the model to a variable length sequence
- applying the model to multiple fixed-length windows and averaging the predictions

First we evaluate the model on variable length inputs.  Report results with and without applying priors.

In [None]:
# Read the full-page train/valid/test splits from the same data directory.
train_fullpage_df, valid_fullpage_df, test_fullpage_df = (
    pd.read_csv(bpe_path/f'{split}_df.fullpage.char.csv')
    for split in ('train', 'valid', 'test')
)

In [None]:
# Databunch over the full-page frames, with the full-page test frame attached as the
# test set. Rebinds data_clas_test, replacing the fragment-level databunch built earlier.
data_clas_test = TextDataBunch.from_df(bpe_path, train_fullpage_df, valid_fullpage_df, test_fullpage_df,
                                       tokenizer=cust_tok, vocab=transformer_vocab, include_bos=False, 
                                       include_eos=False, pad_first=False, pad_idx=pad_idx, bs=bs, num_workers=1)

In [None]:
# Evaluate on full pages with variable-length inputs. Going by the names the result is
# unpacked into, the helper returns accuracy and F1 (each without and with priors) plus
# the predicted probabilities and ground-truth labels — confirm against eval_models.
(acc, acc_with_prior), (f1, f1_with_prior), prob, gt = eval_models.calcAccuracy_fullpage1(learner, bpe_path, train_fullpage_df, valid_fullpage_df, test_fullpage_df, databunch=data_clas_test)
(acc, acc_with_prior), (f1, f1_with_prior)

In [None]:
# Sanity-check the shapes of the returned probabilities and ground-truth labels.
prob.shape, gt.shape

Now we evaluate the model by considering multiple fixed-length windows and averaging the predictions.

In [None]:
# Test frame for the windowed (ensembled) evaluation.
# presumably several fixed-length windows per page — confirm.
test_ensemble_df = pd.read_csv(bpe_path/'test.ensemble64.char.csv')

In [None]:
# Databunch for the ensembled evaluation: the full-page train/valid frames plus the
# ensemble test frame. Text and label columns are named explicitly here. Rebinds
# data_clas_test once more.
data_clas_test = TextDataBunch.from_df(bpe_path, train_fullpage_df, valid_fullpage_df, test_ensemble_df,
                                       text_cols = 'text', label_cols = 'label', tokenizer=cust_tok, 
                                       vocab=transformer_vocab, include_bos=False, include_eos=False, 
                                       pad_first=False, pad_idx=pad_idx, bs=bs, num_workers=1)

In [None]:
# Evaluate with ensembled=True on the windowed test frame.
# NOTE(review): only two values are unpacked here, whereas the non-ensembled call to the
# same function in this notebook is unpacked into four (metrics, metrics, prob, gt).
# Confirm that ensembled=True really returns just the two metric pairs; otherwise this
# line raises ValueError.
(acc, acc_with_prior), (f1, f1_with_prior) = eval_models.calcAccuracy_fullpage1(learner, bpe_path, train_fullpage_df, valid_fullpage_df, test_ensemble_df, databunch=data_clas_test, ensembled=True)
(acc, acc_with_prior), (f1, f1_with_prior)

### Error Analysis

In [None]:
# Interpretation object built from the learner's predictions (on the learner's own
# data, i.e. data_clas, not data_clas_test).
interp = ClassificationInterpretation.from_learner(learner)

In [None]:
# fastai's built-in confusion-matrix plot.
interp.plot_confusion_matrix(figsize=(12,12))

In [None]:
# Predictions from the learner with default arguments.
# NOTE(review): `a` is not used anywhere else in the visible notebook.
a = learner.get_preds()

In [None]:
def convertLineToCharSeq(line):
    """Encode a whitespace-separated line of integers as space-separated 8-character words."""
    return ' '.join(int2charseq(int(token)) for token in line.split())

def int2charseq(int64):
    """Encode the low 64 bits of an integer as 8 characters, least-significant byte first.

    Each byte value (0-255) is offset by 19968 so that every output character is a
    Chinese character rather than a newline, space or other special character.
    """
    return ''.join(chr(19968 + ((int64 >> (8 * pos)) & 255)) for pos in range(8))

In [None]:
# Error analysis on the labeled cello pages: classify the first 64 entries of every
# page and report the pages that are NOT predicted as "cello".
cello_dir = "/home/dyang/InstrumentID/bootleg_data-v1/labeled/cello/"
for fname in os.listdir(cello_dir):
    name = os.path.join(cello_dir, fname)
    # NOTE(review): pickle.load can execute arbitrary code — only use on trusted files.
    with open(name,'rb') as f:
        data = pickle.load(f)
    # The file is fully loaded at this point, so it can be closed before running inference.
    for idx, page in enumerate(data):
        # Keep at most the first 64 entries of the page.
        entries = []
        for count, entry in enumerate(page):
            entries.append(str(entry))
            if count == 63:
                break
        i1 = convertLineToCharSeq(' '.join(entries))
        # Run the model once per page and reuse the result (it was previously
        # recomputed for every print and comparison).
        prediction = learner.predict(i1)
        pred_label = str(prediction[0])
        print(pred_label, prediction)
        if pred_label != "cello":
            # Misclassified page: show the file, the 1-based page number and the
            # third element of the prediction tuple.
            print(fname, idx+1)
            print(prediction[2])

In [None]:
# NOTE(review): the visible code only sets config.num_labels; confirm that the config
# object actually exposes a `labels` attribute (otherwise this raises AttributeError).
config.labels

In [None]:
# Show model results for the current data_clas_test validation loader.
# NOTE(review): confirm that show_results accepts a DataLoader as its first argument and
# returns a DataFrame; the next cell calls df.head() on the result.
df = learner.show_results(data_clas_test.valid_dl, 100)

In [None]:
# Peek at the first rows of the object returned by show_results.
df.head()

In [None]:
# Display a sample batch from the proxy-task databunch.
data_clas.show_batch()

In [None]:
# Raw 'Fragment' value of the first training row.
train_df.iloc[0]['Fragment']

In [None]:
# NOTE: both helpers below are also defined earlier in this notebook; this cell
# re-binds the same names with the same behaviour.
def convertLineToCharSeq(line):
    """Turn a whitespace-separated line of integers into space-separated 8-character words."""
    words = [int2charseq(int(piece)) for piece in line.split()]
    return ' '.join(words)

def int2charseq(int64):
    """Encode the low 64 bits of an integer as 8 characters, least-significant byte first.

    Every byte value (0-255) is shifted by 19968 so the output consists of Chinese
    characters only (no newline, space or similar).
    """
    byte_values = ((int64 >> (8 * pos)) & 255 for pos in range(8))
    return ''.join(chr(19968 + value) for value in byte_values)

In [None]:
def char_to_int(s):
    """Decode a character-encoded string back into a list of integers.

    The input is a sequence of words separated by '</w>'. Within a word, each
    character stands for one byte (code point minus 19968), least-significant
    byte first. Sequence markers ('<s>', '</s>') and spaces are removed before
    decoding.

    Args:
        s: The encoded string.

    Returns:
        A list with one integer per non-empty word; an empty list for empty input.
    """
    # '</s>' is the end-of-sequence marker. The earlier code searched for the literal
    # '<\s>' (with a backslash), which is an invalid escape sequence and never matches
    # '</s>'. That literal is still stripped for backward compatibility, written as
    # an explicit escape.
    for marker in ('<s>', '</s>', '<\\s>', ' '):
        s = s.replace(marker, '')
    bscore = []
    for word in s.split('</w>'):
        if not word:
            # An empty piece (e.g. after a trailing '</w>') is not an encoded value;
            # it used to add a spurious 0 to the result.
            continue
        bscore.append(sum((ord(ch) - 19968) * 256 ** pos for pos, ch in enumerate(word)))
    return bscore

In [None]:
# Round-trip sanity check: encoding an integer and decoding it again should give the
# value back. (This cell previously called char_to_int() without its required
# argument, which always raised TypeError.)
char_to_int(int2charseq(8388608))

In [None]:
# Decode one 8-character word; only the third character differs from the base character.
char_to_int("一一亀一一一一一")

In [None]:
# Encode 8388608 (= 128 * 256**2), i.e. byte value 128 in the third byte position.
int2charseq(8388608)

In [None]:
import seaborn as sns

# Annotated heatmap of the learner's confusion matrix, saved as confusion_matrix.png.
class_names = ['Cello','Clarinet','Flute','Guitar','Oboe','Trumpet','Viola','Violin']
interp = ClassificationInterpretation.from_learner(learner)
mat = interp.confusion_matrix()
print(mat.shape)
# Half of the matrix's value range, used as the colormap centre.
midpoint = (np.amax(mat) - np.amin(mat)) // 2
sns.set(font_scale=1.6)
plt.figure(figsize=(8, 6))
# Other colormaps to try: 'BuPu', 'OrRd', 'YlGnBu_r'.
# To draw cell borders, add linewidths=1, linecolor='black' to the options below
# (both the bordered and the border-less version are wanted).
heatmap_opts = dict(annot=True, fmt="d", cmap='YlGnBu', center=midpoint,
                    vmin=100, robust=True, square=False)
h = sns.heatmap(mat, **heatmap_opts)
plt.yticks(rotation=0)
h.set_xticklabels(class_names)
h.set_yticklabels(class_names)
h.set(xlabel="Predicted Class", ylabel="Actual Class")
plt.savefig('confusion_matrix.png', dpi=300)

In [None]:
import seaborn as sns

# Heatmap of a hard-coded confusion matrix, saved as confusion_matrix.png.
class_names = ['Cello','Clarinet','Flute','Guitar','Oboe','Trumpet','Viola','Violin']
interp = ClassificationInterpretation.from_learner(learner)
mat = interp.confusion_matrix()
print(mat)
# Overwrite the learner's confusion matrix row by row with fixed counts.
# NOTE(review): the provenance of these numbers is not recorded in the notebook
# (presumably results from another model) — confirm and document the source.
mat[0,:] = [271,132,183,25,158,87,210,134]
mat[1,:] =  [80,393,198,90,117,119,78,125]
mat[2,:] = [46,176,563,26,205,17,119,48]
mat[3,:] = [19,65,59,854,27,33,33,110]
mat[4,:] = [57,230,137,13,398,174,136,55]
mat[5,:] = [100,163,52,42,259,440,71,73]
mat[6,:] = [153,66,260,42,50,17,512,100]
mat[7,:] = [88,203,111,55,146,70,82,445]
print(mat.shape)
# Compute the colormap centre from the matrix that is actually plotted. It used to be
# computed at the top of the cell, from whatever `mat` an earlier cell had left behind
# (and it failed with NameError on a fresh kernel).
midpoint = (np.amax(mat) - np.amin(mat)) // 2
sns.set(font_scale=1)
plt.figure(figsize=(8, 6))
h = sns.heatmap(mat,
            annot=True,
            fmt="d",
            cmap='YlGnBu',
            center=midpoint,
            vmin=100,
            robust=True)
plt.yticks(rotation=0)
h.set_xticklabels(class_names)
h.set_yticklabels(class_names)
h.set(xlabel="Predicted Class", ylabel="Actual Class")
# NOTE(review): other cells in this notebook save to the same file name, so the
# figures overwrite each other.
plt.savefig('confusion_matrix.png', dpi=300)

In [None]:
import seaborn as sns

# Heatmap of the learner's confusion matrix at the smaller font scale, saved as
# confusion_matrix.png.
class_names = ['Cello','Clarinet','Flute','Guitar','Oboe','Trumpet','Viola','Violin']
interp = ClassificationInterpretation.from_learner(learner)
mat = interp.confusion_matrix()
# Compute the colormap centre from the freshly built matrix. It used to be computed
# before `mat` was rebuilt, i.e. from the matrix an earlier cell had left behind.
midpoint = (np.amax(mat) - np.amin(mat)) // 2
sns.set(font_scale=1)
plt.figure(figsize=(8, 6))
h = sns.heatmap(mat,
            annot=True,
            fmt="d",
            cmap='YlGnBu',
            center=midpoint,
            vmin=100,
            robust=True)
plt.yticks(rotation=0)
h.set_xticklabels(class_names)
h.set_yticklabels(class_names)
h.set(xlabel="Predicted Class", ylabel="Actual Class")
plt.savefig('confusion_matrix.png', dpi=300)

In [None]:
import seaborn as sns

# Large square-cell heatmap in the 'OrRd' colormap. This cell reuses whatever `mat`
# the previously executed cell left behind; re-enable the two lines below to rebuild
# it from the learner instead.
# interp = ClassificationInterpretation.from_learner(learner)
# mat = interp.confusion_matrix()
class_names = ['Cello','Clarinet','Flute','Guitar','Oboe','Trumpet','Viola','Violin']
# Half of the matrix's value range, used as the colormap centre.
midpoint = (np.amax(mat) - np.amin(mat)) // 2
sns.set(font_scale=1.6)
plt.figure(figsize=(12, 10))
# Other colormaps to try: 'BuPu', 'OrRd', 'YlGnBu_r'.
# To draw cell borders, add linewidths=1, linecolor='black' to the options below
# (both the bordered and the border-less version are wanted).
heatmap_opts = dict(annot=True, fmt="d", cmap='OrRd', center=midpoint,
                    vmin=100, robust=True, square=True)
h = sns.heatmap(mat, **heatmap_opts)
plt.yticks(rotation=0)
h.set_xticklabels(class_names)
h.set_yticklabels(class_names)
h.set(xlabel="Predicted Class", ylabel="Actual Class")

In [None]:
import seaborn as sns

# Large square-cell heatmap of the learner's confusion matrix in the reversed
# 'YlGnBu_r' colormap.
class_names = ['Cello','Clarinet','Flute','Guitar','Oboe','Trumpet','Viola','Violin']
interp = ClassificationInterpretation.from_learner(learner)
mat = interp.confusion_matrix()
# Half of the matrix's value range, used as the colormap centre.
midpoint = (np.amax(mat) - np.amin(mat)) // 2
sns.set(font_scale=1.6)
plt.figure(figsize=(12, 10))
# Other colormaps to try: 'BuPu', 'OrRd', 'YlGnBu_r'.
# To draw cell borders, add linewidths=1, linecolor='black' to the options below
# (both the bordered and the border-less version are wanted).
heatmap_opts = dict(annot=True, fmt="d", cmap='YlGnBu_r', center=midpoint,
                    vmin=100, robust=True, square=True)
h = sns.heatmap(mat, **heatmap_opts)
plt.yticks(rotation=0)
h.set_xticklabels(class_names)
h.set_yticklabels(class_names)
h.set(xlabel="Predicted Class", ylabel="Actual Class")

In [None]:
# Bar chart of four hard-coded values.
# NOTE(review): what the bars represent is not recorded here (presumably accuracies in
# percent for four model variants — confirm) and the plot has no axis or tick labels.
plt.bar([0,1,2,3],[43.1,44.2,45.8,48.7])

In [None]:
# Bar chart of four hard-coded values.
# NOTE(review): what the bars represent is not recorded here (presumably scores in [0, 1]
# for four model variants — confirm) and the plot has no axis or tick labels.
plt.bar([0,1,2,3],[.471,.477,.486,.498])