# Training Classifier
In this Notebook we will:
- [Load the data for each of the different classes](#data_collection)
- [Load the selected base model and hyperparams](#load_model)
- [Train the models](#train_model)
- [Evaluate the models](#eval)
- [Compare to baselines](#score)

In [38]:
# Imports
import pandas as pd
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitTrainer
from datasets import Dataset
from setfit import sample_dataset
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
import joblib
import pickle
from sklearn.metrics import confusion_matrix
## Workaround for dashes in name
from importlib import import_module
nlbse_statistics = import_module('code-comment-classification.nlbse_statistics') 

tqdm.pandas()

<a id='data_collection'></a>
## Data collection
We first load the data. 
For each language we create a dataset for each of the seperate category.

In [39]:
langs = ['java', 'python', 'pharo']
lan_cats = []
datasets = {}
for lan in langs: # for each language
    df = pd.read_csv(f'./code-comment-classification/{lan}/input/{lan}.csv')
    df['combo'] = df[['comment_sentence', 'class']].agg(' | '.join, axis=1)
    df['label'] = df.instance_type
    cats = list(map(lambda x: lan + '_' + x, list(set(df.category))))
    lan_cats = lan_cats + cats
    for cat in list(set(df.category)): # for each category
        filtered =  df[df.category == cat]
        train_data = Dataset.from_pandas(filtered[filtered.partition == 0])
        test_data = Dataset.from_pandas(filtered[filtered.partition == 1])
        datasets[f'{lan}_{cat}'] = {'train_data': train_data, 'test_data' : test_data}

<a id='load_model'></a>

## Load model

In [3]:
hyperparameters = hyperparameters={'learning_rate': 1.7094555110821448e-05, 'num_epochs': 6, 
                                   'batch_size': 8, 'seed': 11, 'num_iterations': 10, 
                                   'max_iter': 241, 'solver': 'lbfgs'}
    
def model_init(params):
    params = params or {}
    max_iter = params.get("max_iter", 100)
    solver = params.get("solver", "liblinear")
    params = {
        "head_params": {
            "max_iter": max_iter,
            "solver": solver,
        }
    }
    return SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2", **params)

# Create a fresh trainer with hyperparams
def load_trainer(train_data, test_data):
    trainer = SetFitTrainer(
        train_dataset=train_data,
        eval_dataset=test_data,
        loss_class=CosineSimilarityLoss,
        model_init=model_init,
        column_mapping={"combo": "text", "label": "label"},
    )

    trainer.apply_hyperparameters(hyperparameters, final_model=True)
    
    return trainer

<a id='train_model'></a>


## Train Models
Train and save a model for each of the categories

This will take around 3h per category, so around 2 days in total.

In [None]:
# Train model for each cat
for lan_cat in lan_cats:
    print(f'training {lan_cat}')
    train_data = datasets[lan_cat]['train_data']
    test_data = datasets[lan_cat]['test_data']
    trainer = load_trainer(train_data, test_data)
    
    trainer.train()
    
    joblib.dump(trainer, f'./models/{lan_cat}_all-mpnet-base-v2.joblib')

<a id='eval'></a>

## Evaluation
Next we evaluate each of our trained models on the test set.

In [None]:
# Score each classifier and write scores to CSV
scores = []
for lan_cat in lan_cats:
    trainer = joblib.load(f'./models/{lan_cat}_all-mpnet-base-v2.joblib')
    test_data = datasets[lan_cat]['test_data']
    y_hat = trainer.model(test_data['combo'])
    y = test_data['label']
    _, fp, fn, tp = confusion_matrix(y_hat, y).ravel()
    wf1 = f1_score(y, y_hat, average='weighted')
    precision, recall, f1 = nlbse_statistics.get_precision_recall_f1(tp, fp, fn)
    scores.append({'lan_cat': lan_cat.lower(),'precision': precision,'recall': recall,'f1': f1,'wf1': wf1})

df = pd.DataFrame(scores)
df.sort_values('lan_cat').to_excel('scores.xlsx')

<a id='score'></a>

## Comparision with baseline
We compare the weighed f1 scores with the baseline

|                               | Baseline  |        |      |             | Ours      |        |      |      |          |
| ----------------------------- | --------- | ------ | ---- | ----------- | --------- | ------ | ---- | ---- | -------- |
|                               | precision | recall | f1   | weighted_f1 | precision | recall | f1   | wf1  | delta f1 |
| java_deprecation              | 0,00      | 0,00   | 0,00 | 0,92        | 0,78      | 0,95   | 0,86 | 0,98 | 0,86     |
| java_expand                   | 0,35      | 0,27   | 0,30 | 0,66        | 0,71      | 0,80   | 0,75 | 0,88 | 0,45     |
| java_ownership                | 1,00      | 0,68   | 0,81 | 0,98        | 1,00      | 1,00   | 1,00 | 1,00 | 0,19     |
| java_pointer                  | 0,67      | 0,24   | 0,35 | 0,84        | 0,71      | 0,82   | 0,76 | 0,93 | 0,40     |
| java_rational                 | 0,63      | 0,30   | 0,40 | 0,88        | 0,81      | 0,92   | 0,86 | 0,97 | 0,46     |
| java_summary                  | 0,38      | 0,29   | 0,33 | 0,78        | 0,85      | 0,76   | 0,80 | 0,93 | 0,48     |
| java_usage                    | 0,54      | 0,36   | 0,43 | 0,62        | 0,83      | 0,89   | 0,86 | 0,90 | 0,43     |
| pharo_classreferences         | 0,33      | 0,06   | 0,10 | 0,93        | 0,47      | 0,57   | 0,52 | 0,96 | 0,42     |
| pharo_collaborators           | 0,47      | 0,25   | 0,33 | 0,91        | 0,36      | 0,91   | 0,51 | 0,94 | 0,19     |
| pharo_example                 | 0,77      | 0,43   | 0,55 | 0,68        | 0,93      | 0,89   | 0,91 | 0,92 | 0,35     |
| pharo_intent                  | 0,58      | 0,33   | 0,42 | 0,87        | 0,87      | 0,89   | 0,88 | 0,97 | 0,45     |
| pharo_keyimplementationpoints | 0,18      | 0,10   | 0,13 | 0,79        | 0,69      | 0,79   | 0,73 | 0,93 | 0,60     |
| pharo_keymessages             | 0,31      | 0,16   | 0,21 | 0,76        | 0,79      | 0,91   | 0,85 | 0,95 | 0,64     |
| pharo_responsibilities        | 0,59      | 0,33   | 0,43 | 0,81        | 0,67      | 0,63   | 0,65 | 0,86 | 0,22     |
| python_developmentnotes       | 0,17      | 0,17   | 0,17 | 0,79        | 0,43      | 0,54   | 0,48 | 0,88 | 0,31     |
| python_expand                 | 0,26      | 0,20   | 0,22 | 0,72        | 0,52      | 0,56   | 0,54 | 0,82 | 0,31     |
| python_parameters             | 0,51      | 0,22   | 0,31 | 0,65        | 0,78      | 0,86   | 0,81 | 0,89 | 0,50     |
| python_summary                | 0,12      | 0,08   | 0,09 | 0,71        | 0,62      | 0,64   | 0,63 | 0,87 | 0,54     |
| python_usage                  | 0,47      | 0,18   | 0,26 | 0,63        | 0,69      | 0,77   | 0,73 | 0,84 | 0,46     |
| ----------------------------- | --------- | ------ | ---- | ----------- | --------- | ------ | ---- | ---- | -------- |
| average                       | 0,44      | 0,24   | 0,31 | 0,79        | 0,71      | 0,79   | 0,74 | 0,92 | 0,43     |


Finally we calculate our final score: 

\begin{align}
submission\_score(model) &= (avg. \space F_1) \times 0.75 + (\% \space of \space outperformed \space categories) \times 0.25
\end{align}

We take the average unweighed f1 (0.743594163), and the proportion of outperformed categories (all):

\begin{align}
submission\_score = 0.743594163 * 0.75 + 1 * 0.25 = 0,8076956223
\end{align}


<a id='hub'></a>

## Push to hub

Finally we push all of our models to the Hugging Face Hub to make them publically avaliable.

In [None]:
# Push to hub
token = 'hf_XXXXXXXXXXX'
repo = 'XXXXXXXXXXXXX'
for lan_cat in lan_cats:
    trainer = joblib.load(f'./models/{lan_cat}_all-mpnet-base-v2.joblib')
    name = lan_cat.lower().replace('_','-') + '-classifier'
    print(name)
    trainer.push_to_hub(f'{repo}/{name}', use_auth_token=token, private=True)

## Conclusion

In this Notebook we created our own classifiers which beat the baseline in every category. We uploaded the models to the Hugging Face Hub as well.

![display image](https://github.com/snipe/animated-gifs/blob/master/retro-computers/old-school.gif?raw=true)

Please join us in the [next and final Notebook](./3-Inference.ipynb) to see how we can load and use these models.