# Training Classifier
In this Notebook we will:
- [Load the data for each of the different classes](#data_collection)
- [Load the selected base model and hyperparams](#load_model)
- [Train the models](#train_model)
- [Evaluate the models](#eval)
- [Compare to baselines](#score)

In [38]:
# Imports
import pandas as pd
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitTrainer
from datasets import Dataset
from setfit import sample_dataset
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
import joblib
import pickle
from sklearn.metrics import confusion_matrix
## Workaround for dashes in name
from importlib import import_module
nlbse_statistics = import_module('code-comment-classification.nlbse_statistics') 

tqdm.pandas()

<a id='data_collection'></a>
## Data collection
We first load the data. 
For each language we create a dataset for each of the seperate category.

In [39]:
langs = ['java', 'python', 'pharo']
lan_cats = []
datasets = {}
for lan in langs: # for each language
    df = pd.read_csv(f'./code-comment-classification/{lan}/input/{lan}.csv')
    df['label'] = df.instance_type
    cats = list(map(lambda x: lan + '_' + x, list(set(df.category))))
    lan_cats = lan_cats + cats
    for cat in list(set(df.category)): # for each category
        filtered =  df[df.category == cat]
        train_data = Dataset.from_pandas(filtered[filtered.partition == 0])
        test_data = Dataset.from_pandas(filtered[filtered.partition == 1])
        datasets[f'{lan}_{cat}'] = {'train_data': train_data, 'test_data' : test_data}

<a id='load_model'></a>

## Load model

In [3]:
with open('best_run.pkl', 'rb') as file:
    best_run = pickle.load(file)
best_run.hyperparameters['batch_size'] = 8 # Reduce batch size
    
def model_init(params):
    params = params or {}
    max_iter = params.get("max_iter", 100)
    solver = params.get("solver", "liblinear")
    params = {
        "head_params": {
            "max_iter": max_iter,
            "solver": solver,
        }
    }
    return SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2", **params)

# Create a fresh trainer with hyperparams
def load_trainer(train_data, test_data):
    trainer = SetFitTrainer(
        train_dataset=train_data,
        eval_dataset=test_data,
        loss_class=CosineSimilarityLoss,
        model_init=model_init,
        column_mapping={"comment_sentence": "text", "label": "label"},
    )

    trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
    
    return trainer

<a id='train_model'></a>


## Train Models
Train and save a model for each of the categories

This will take around 3h per category, so around 2 days in total.

In [None]:
# Train model for each cat
for lan_cat in lan_cats:
    print(f'training {lan_cat}')
    train_data = datasets[lan_cat]['train_data']
    test_data = datasets[lan_cat]['test_data']
    trainer = load_trainer(train_data, test_data)
    
    trainer.train()
    
#     joblib.dump(trainer, f'./models/{lan_cat}_all-mpnet-base-v2.joblib')

training java_deprecation


model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
Applying column mapping to training dataset
***** Running training *****
  Num examples = 38620
  Num epochs = 6
  Total optimization steps = 28968
  Total train batch size = 8


Epoch:   0%|          | 0/6 [00:00<?, ?it/s]

Iteration:   0%|          | 0/28968 [00:00<?, ?it/s]

Iteration:   0%|          | 0/28968 [00:00<?, ?it/s]

<a id='eval'></a>

## Evaluation
Next we evaluate each of our trained models on the test set.

In [24]:
# Score each classifier and write scores to CSV
scores = []
for lan_cat in lan_cats:
    trainer = joblib.load(f'./models/{lan_cat}_all-mpnet-base-v2.joblib')
    test_data = datasets[lan_cat]['test_data']
    y_hat = trainer.model(test_data['comment_sentence'])
    y = test_data['label']
    _, fp, fn, tp = confusion_matrix(y_hat, y).ravel()
    wf1 = f1_score(y, y_hat, average='weighted')
    precision, recall, f1 = nlbse_statistics.get_precision_recall_f1(tp, fp, fn)
#     print(f'{lan_cat} precision: {precision}, recall: {recall}, f1 {f1} weighted f1: {wf1}')
    scores.append({'lan_cat': lan_cat.lower(),'precision': precision,'recall': recall,'f1': f1,'wf1': wf1})

df = pd.DataFrame(scores)
# df.sort_values('lan_cat').to_excel('scores.xlsx')

java_Pointer precision: 0.6, recall: 0.8333333333333334, f1 0.6976744186046512 weighted f1: 0.9147401095108914
java_rational precision: 0.543859649122807, recall: 0.8157894736842105, f1 0.6526315789473685 weighted f1: 0.9263439459630449
java_Ownership precision: 1.0, recall: 1.0, f1 1.0 weighted f1: 1.0
java_Expand precision: 0.6456692913385826, recall: 0.7522935779816514, f1 0.6949152542372881 weighted f1: 0.8483209159519988
java_deprecation precision: 0.7037037037037037, recall: 0.8636363636363636, f1 0.7755102040816326 weighted f1: 0.9763213659957574
java_summary precision: 0.7471264367816092, recall: 0.7738095238095238, f1 0.760233918128655 weighted f1: 0.9157476952136655
java_usage precision: 0.6956521739130435, recall: 0.8421052631578947, f1 0.761904761904762 weighted f1: 0.8320261720389173
python_Parameters precision: 0.6956521739130435, recall: 0.7832167832167832, f1 0.7368421052631579 weighted f1: 0.8428864585033696
python_Expand precision: 0.4803921568627451, recall: 0.628205

<a id='score'></a>

## Comparision with baseline
We compare the weighed f1 scores with the baseline

| Language_Category             | Baseline    |             |             |            | Ours      |          |          |            |
|-------------------------------|-------------|-------------|-------------|------------|-----------|----------|----------|------------|
|                               | Precision   | Recall      | F1          | Weighed F1 | Precision | Recall   | F1       | Weighed F1 |
| java_deprecation              | 0           | 0           | 0           | 0,916601   | 0,703704  | 0,863636 | 0,77551  | 0,976321   |
| java_expand                   | 0,350515464 | 0,267716535 | 0,303571429 | 0,664627   | 0,645669  | 0,752294 | 0,694915 | 0,848321   |
| java_ownership                | 1           | 0,68        | 0,80952381  | 0,982152   | 1         | 1        | 1        | 1          |
| java_pointer                  | 0,666666667 | 0,24        | 0,352941176 | 0,836971   | 0,6       | 0,833333 | 0,697674 | 0,91474    |
| java_rational                 | 0,62962963  | 0,298245614 | 0,404761905 | 0,880968   | 0,54386   | 0,815789 | 0,652632 | 0,926344   |
| java_summary                  | 0,384615385 | 0,287356322 | 0,328947368 | 0,779538   | 0,747126  | 0,77381  | 0,760234 | 0,915748   |
| java_usage                    | 0,540983607 | 0,358695652 | 0,431372549 | 0,623095   | 0,695652  | 0,842105 | 0,761905 | 0,832026   |
| pharo_classreferences         | 0,333333333 | 0,058823529 | 0,1         | 0,932441   | 0,294118  | 0,555556 | 0,384615 | 0,948548   |
| pharo_collaborators           | 0,466666667 | 0,25        | 0,325581395 | 0,907787   | 0,25      | 0,777778 | 0,378378 | 0,920377   |
| pharo_example                 | 0,76744186  | 0,434210526 | 0,554621849 | 0,682497   | 0,894737  | 0,866242 | 0,880259 | 0,896557   |
| pharo_intent                  | 0,576923077 | 0,333333333 | 0,422535211 | 0,871128   | 0,866667  | 0,906977 | 0,886364 | 0,971636   |
| pharo_keyimplementationpoints | 0,178571429 | 0,104166667 | 0,131578947 | 0,79483    | 0,291667  | 0,736842 | 0,41791  | 0,870274   |
| pharo_keymessages             | 0,3125      | 0,158730159 | 0,210526316 | 0,761551   | 0,460317  | 0,852941 | 0,597938 | 0,877329   |
| pharo_responsibilities        | 0,58974359  | 0,333333333 | 0,425925926 | 0,807558   | 0,637681  | 0,666667 | 0,651852 | 0,867963   |
| python_developmentnotes       | 0,171875    | 0,169230769 | 0,170542636 | 0,791947   | 0,246154  | 0,484848 | 0,326531 | 0,853401   |
| python_expand                 | 0,263157895 | 0,196078431 | 0,224719101 | 0,717097   | 0,480392  | 0,628205 | 0,544444 | 0,832729   |
| python_parameters             | 0,514285714 | 0,223602484 | 0,311688312 | 0,64994    | 0,695652  | 0,783217 | 0,736842 | 0,842886   |
| python_summary                | 0,122807018 | 0,075268817 | 0,093333333 | 0,710185   | 0,569892  | 0,688312 | 0,623529 | 0,871283   |
| python_usage                  | 0,46875     | 0,18404908  | 0,264317181 | 0,626358   | 0,576687  | 0,77686  | 0,661972 | 0,805782   |
| average                       | 0,438866649 | 0,244886382 | 0,30876255  | 0,78617216 | 0,589472  | 0,768706 | 0,654395 | 0,893277   |


Finally we calculate our final score: 

\begin{align}
submission\_score(model) &= (avg. \space F_1) \times 0.75 + (\% \space of \space outperformed \space categories) \times 0.25
\end{align}

We take the average unweighed f1 (0.654395), and the proportion of outperformed categories (all):

\begin{align}
submission\_score = 0.654395 * 0.75 + 1 * 0.25 = 0.74079625
\end{align}


<a id='hub'></a>

## Push to hub

Finally we push all of our models to the Hugging Face Hub to make them publically avaliable.

In [36]:
# Push to hub
token = 'hf_XXXXXXXXXXX'
repo = 'XXXXXXXXXXXXX'
for lan_cat in lan_cats:
    trainer = joblib.load(f'./models/{lan_cat}_all-mpnet-base-v2.joblib')
    name = lan_cat.lower().replace('_','-') + '-classifier'
    print(name)
    trainer.push_to_hub(f'{repo}/{name}', use_auth_token=token, private=True)

java-ownership-classifier


Cloning https://huggingface.co/aalkaswan/java-ownership-classifier into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/418M [00:00<?, ?B/s]

Upload file model_head.pkl: 100%|##########| 6.83k/6.83k [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/aalkaswan/java-ownership-classifier
   52311db..01cf829  main -> main



python-usage-classifier


Cloning https://huggingface.co/aalkaswan/python-usage-classifier into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/418M [00:00<?, ?B/s]

Upload file model_head.pkl: 100%|##########| 6.83k/6.83k [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/aalkaswan/python-usage-classifier
   2f03675..b273747  main -> main



pharo-example-classifier


Cloning https://huggingface.co/aalkaswan/pharo-example-classifier into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/418M [00:00<?, ?B/s]

Upload file model_head.pkl: 100%|##########| 6.83k/6.83k [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/aalkaswan/pharo-example-classifier
   56c11b2..aa13b96  main -> main



pharo-responsibilities-classifier


Cloning https://huggingface.co/aalkaswan/pharo-responsibilities-classifier into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/418M [00:00<?, ?B/s]

Upload file model_head.pkl: 100%|##########| 6.83k/6.83k [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/aalkaswan/pharo-responsibilities-classifier
   c8fa046..78696a3  main -> main



## Conclusion

In this Notebook we created our own classifiers which beat the baseline in every category. We uploaded the models to the Hugging Face Hub as well.

![display image](https://github.com/snipe/animated-gifs/blob/master/retro-computers/old-school.gif?raw=true)

Please join us in the [next and final Notebook](./3-Inference.ipynb) to see how we can load and use these models.