<a href="https://colab.research.google.com/github/EHDEV/xitext_model_trainer/blob/main/trainer_nb-onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Install requirements

In [22]:
!pip install transformers install onnxruntime onnxruntime-tools onnxconverter_common
 

Collecting onnxconverter_common
[?25l  Downloading https://files.pythonhosted.org/packages/fe/7a/7e30c643cd7d2ad87689188ef34ce93e657bd14da3605f87bcdbc19cd5b1/onnxconverter_common-1.7.0-py2.py3-none-any.whl (64kB)
[K     |████████████████████████████████| 71kB 5.7MB/s 
Installing collected packages: onnxconverter-common
Successfully installed onnxconverter-common-1.7.0


#### Download repo and set working directory
- Get the latest trainer code from github
- set working directory then execute code

In [3]:
import os
os.environ['TRAINER_HOME'] = '/content/drive/MyDrive/xitext/xitext_model_trainer'
trainer_home = os.environ.get('TRAINER_HOME')

In [4]:
%%shell

# export TRAINER_HOME=/content/drive/MyDrive/xitext/xitext_model_trainer
cd $TRAINER_HOME; git pull


Already up to date.




In [6]:
cd $trainer_home

/content/drive/MyDrive/xitext/xitext_model_trainer


#### Configurations

In [7]:
import configparser
import os
from pathlib import Path

config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
config.read('config.ini')


['config.ini']

In [8]:
default_config = dict(config['DEFAULT'])
data_config = dict(config['DATA'])
model_config = dict(config['MODEL'])
onnx_config = dict(config['ONNX'])

In [9]:
data_config

{'company': 'xitext',
 'company_home': '/content/drive/MyDrive/xitext',
 'data_pickle_output_path': '/content/drive/MyDrive/xitext/news-topic-classifier/data/pickles',
 'delimiter': ',',
 'project_home': '/content/drive/MyDrive/xitext/news-topic-classifier',
 'project_name': 'news-topic-classifier',
 'source_dir': '/content/drive/MyDrive/xitext/news-topic-classifier/data',
 'target_col': 'topic',
 'text_col': 'text'}

#### Import classes and required functions

In [10]:
from file_config import FileConfig
from models import SequenceClassifierModel
from convert_optimize_onnx import TorchToONNX
from data_preprocess import TextClassifierData, _encode_text_into_tokens
from pathlib import Path
import torch

DEBUG:tensorflow:Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.


#### Set data file path and other configurations of the file

In [11]:
data_config

{'company': 'xitext',
 'company_home': '/content/drive/MyDrive/xitext',
 'data_pickle_output_path': '/content/drive/MyDrive/xitext/news-topic-classifier/data/pickles',
 'delimiter': ',',
 'project_home': '/content/drive/MyDrive/xitext/news-topic-classifier',
 'project_name': 'news-topic-classifier',
 'source_dir': '/content/drive/MyDrive/xitext/news-topic-classifier/data',
 'target_col': 'topic',
 'text_col': 'text'}

In [12]:
fconfig = FileConfig(
    path_to_directory=Path(data_config['source_dir']), 
    target_column=data_config['target_col'],
    sequence_column=data_config['text_col'],
    delimiter=data_config['delimiter'],
    header_column=data_config.get('header'))


In [13]:
text_clas_data = TextClassifierData(fconfig)
train_data, val_data = text_clas_data.preprocess()

DEBUG:data-preprocessing.log:load data started
DEBUG:data-preprocessing.log:dataframe with shape (200854, 2) has been created
INFO:numexpr.utils:NumExpr defaulting to 2 threads.
DEBUG:data-preprocessing.log:sentence column cleaned
DEBUG:data-preprocessing.log:Clean_label_column complete
DEBUG:data-preprocessing.log:Underrepresented classes have been removed and data condensed
DEBUG:data-preprocessing.log:preparing data for training: train/val split and convert to tensor
DEBUG:data-preprocessing.log:train test split completed
DEBUG:data-preprocessing.log:Tokenizing train and valid data
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 139943899388896 on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f69309

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…

DEBUG:filelock:Attempting to release lock 139943899388896 on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99.lock
INFO:filelock:Lock 139943899388896 released on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443





DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 139943899388896 on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock
INFO:filelock:Lock 139943899388896 acquired on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 466062


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…

DEBUG:filelock:Attempting to release lock 139943899388896 on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock
INFO:filelock:Lock 139943899388896 released on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock





DEBUG:data-preprocessing.log:Encoding input sentences completed
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:data-preprocessing.log:Encoding input sentences completed
DEBUG:data-preprocessing.log:Wrapping tensors in dataloader completed
DEBUG:data-preprocessing.log:Wrapping tensors in dataloader completed
DEBUG:data-preprocessing.log:data preparation for training is complete. 2385 train and 597 validation examples
DEBUG:data-preprocessing.log:Data preprocessing complete


#### Write classes and their indices as a json in model directory

In [14]:
import json

model_output_dir = model_config['model_output_dir']

if not os.path.exists(model_output_dir):
    os.makedirs(model_output_dir)

with open(model_config['classes_file_path'], 'w') as cfile:
    cfile.write(json.dumps(
        { k: v for k,v in enumerate(text_clas_data.classes) }
    ))

In [15]:
pickle_path = Path(data_config['data_pickle_output_path'])
if not os.path.isdir(pickle_path):
  os.makedirs(pickle_path)

torch.save(train_data, pickle_path/'train_pickle.pth')
torch.save(val_data, pickle_path/'val_pickle.pth')

In [16]:
train_data = torch.load(pickle_path/'train_pickle.pth')
val_data = torch.load(pickle_path/'val_pickle.pth')

#### Train

In [17]:
model_config

{'classes_file_path': '/content/drive/MyDrive/xitext/news-topic-classifier/models/classes.json',
 'company': 'xitext',
 'company_home': '/content/drive/MyDrive/xitext',
 'epochs': '1',
 'eval_metric': 'accuracy',
 'model_filename': 'distilbert-topic-seq-classifier.bin',
 'model_group': 'distilbert',
 'model_output_dir': '/content/drive/MyDrive/xitext/news-topic-classifier/models',
 'optimizer': 'adam',
 'project_home': '/content/drive/MyDrive/xitext/news-topic-classifier',
 'project_name': 'news-topic-classifier',
 'scheduler': 'linear',
 'transformers_model_id': 'distilbert-base-uncased'}

In [18]:
seq_model = SequenceClassifierModel(
    project_home=model_config['project_home'],
    tr_model_id=model_config['transformers_model_id'],
    model_group=model_config['model_group'],
    optimizer=model_config['optimizer'],
    scheduler=model_config['scheduler'],
    eval_metric=model_config['eval_metric'],
    num_labels=text_clas_data.num_labels,
    epochs=int(model_config['epochs']),
    train_data=train_data,
    val_data=val_data,
    model_output_dir=model_config['model_output_dir'],
    model_output_filename=model_config['model_filename']
);

seq_model.train(save_model=True)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /distilbert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 139940955570528 on /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361.lock
INFO:filelock:Lock 139940955570528 acquired on /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /distilbert-base-uncased/resolve/main/config.json HTTP/1.1" 200 442


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…

DEBUG:filelock:Attempting to release lock 139940955570528 on /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361.lock
INFO:filelock:Lock 139940955570528 released on /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443





DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /distilbert-base-uncased/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 139941073892576 on /root/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a.lock
INFO:filelock:Lock 139941073892576 acquired on /root/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): cdn-lfs.huggingface.co:443
DEBUG:urllib3.connectionpool:https://cdn-lfs.huggingface.co:443 "GET /distilbert-base-uncased/e60d71610916da4787c5513c81bc026d415708528295502fb3e1a6fe1485ea7c HTTP/1.1" 200 267967963


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…

DEBUG:filelock:Attempting to release lock 139941073892576 on /root/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a.lock
INFO:filelock:Lock 139941073892576 released on /root/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a.lock





Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

Training...
Batch    40 of 2,385. Elapsed: 0:00:12.
Batch    80 of 2,385. Elapsed: 0:00:24.
Batch   120 of 2,385. Elapsed: 0:00:37.
Batch   160 of 2,385. Elapsed: 0:00:49.
Batch   200 of 2,385. Elapsed: 0:01:02.
Batch   240 of 2,385. Elapsed: 0:01:15.
Batch   280 of 2,385. Elapsed: 0:01:28.
Batch   320 of 2,385. Elapsed: 0:01:42.
Batch   360 of 2,385. Elapsed: 0:01:55.
Batch   400 of 2,385. Elapsed: 0:02:08.
Batch   440 of 2,385. Elapsed: 0:02:22.
Batch   480 of 2,385. Elapsed: 0:02:35.
Batch   520 of 2,385. Elapsed: 0:02:48.
Batch   560 of 2,385. Elapsed: 0:03:01.
Batch   600 of 2,385. Elapsed: 0:03:15.
Batch   640 of 2,385. Elapsed: 0:03:28.
Batch   680 of 2,385. Elapsed: 0:03:41.
Batch   720 of 2,385. Elapsed: 0:03:55.
Batch   760 of 2,385. Elapsed: 0:04:08.
Batch   800 of 2,385. Elapsed: 0:04:21.
Batch   840 of 2,385. Elapsed: 0:04:35.
Batch   880 of 2,385. Elapsed: 0:04:48.
Batch   920 of 2,385. Elapsed: 0:05:01.
Batch   960 of 2,385. Elapsed: 0:05:14.
Batch 1,000 of 2,385. Elapse

DEBUG:model-training.log:Running Validation...


 Train Accuracy: 0.54
Average training loss: 1.7473923479735975
Training epoch took: 0:13:08



DEBUG:model-training.log:Training Complete
DEBUG:model-training.log:/content/drive/MyDrive/xitext/news-topic-classifier/models


 Accuracy: 0.6023328600929273
 Validation took: 0:01:06



DEBUG:model-training.log:model saved to /content/drive/MyDrive/xitext/news-topic-classifier/models/distilbert-topic-seq-classifier.bin


#### ONNX

In [19]:
onnx_config

{'company': 'xitext',
 'company_home': '/content/drive/MyDrive/xitext',
 'model_type': 'bert',
 'onnx_model_output_dir': '/content/drive/MyDrive/xitext/news-topic-classifier/models/onnx',
 'project_home': '/content/drive/MyDrive/xitext/news-topic-classifier',
 'project_name': 'news-topic-classifier',
 'runtimeprovider': 'CPUExecutionProvider',
 'torch_model_path': '/content/drive/MyDrive/xitext/news-topic-classifier/models/distilbert-topic-seq-classifier.bin'}

In [20]:
from pathlib import Path

torch_model_path = onnx_config['torch_model_path']
onnx_model_dir = Path('/content/drive/MyDrive/xitext/xitext_model_trainer/models/news-topic-classifier/onnx/')

tt2 = TorchToONNX(
    torch_model_path=Path(onnx_config['torch_model_path']),
    onnx_model_dir=Path(onnx_config['onnx_model_output_dir']),
    tokenizer=text_clas_data.tokenizer
)
tt2.model_type='bert'

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0


In [None]:
tt2.convert_torch_to_onnx()
# tt2.model_type

In [None]:
from scipy.special import softmax
def make_predictions(model, encoded_sentence, attention_mask, token_type_id=None):
    
    model.eval()
    with torch.no_grad():
        
        preds = model(
                    encoded_sentence,
                    attention_mask=attention_mask
                ) 
            # labels are not passed here in validation
            # Get the "logits" output by the model. The "logits" are the output
            # values prior to applying an activation function like softmax
        
        logits = preds[0]
        probabilities = torch.nn.functional.softmax(logits, dim=1)        
        probabilities = probabilities.detach().cpu().numpy()
        # Move logits and labels to CPU

        np.set_printoptions(suppress=True)

    return probabilities[-1].round(4)


In [None]:
import numpy as np
sentence = [
	'''
  
  I need to eat there again. tastey meal

  ''']

encoded_tensor = _encode_text_into_tokens( sentence, text_clas_data.tokenizer)

# print(encoded_tensor.shape, attention_mask_tensor.shape)

seq_model.model.eval()
seq_model.model.to('cpu')
prediction_probabilities = make_predictions(
    model=seq_model.model,
    encoded_sentence=encoded_tensor['input_ids'], 
    attention_mask=encoded_tensor['attention_mask']
)

from collections import OrderedDict
top_topics = OrderedDict()
classes=text_clas_data.classes
for i in prediction_probabilities.argsort()[-10:][::-1]:
    top_topics[classes[i]] = prediction_probabilities[i]

top_topics

In [None]:
onnx_res = tt2.run_inference(sentence, text_clas_data.tokenizer)

from collections import OrderedDict
probs = {}
for i, x in enumerate(onnx_res[0]):
  probs[text_clas_data.classes[i]] = round(x, 4) 
  
sorted(probs.items(), key=lambda x: x[1], reverse=True)[:10]

#### End

In [None]:
text_clas_data.data_df.shape