<a href="https://colab.research.google.com/github/EHDEV/xitext_model_trainer/blob/main/trainer_nb-onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Install requirements

In [1]:
!pip install transformers install onnxruntime onnxruntime-tools
 



#### Download repo and set working directory
- Get the latest trainer code from GitHub
- Set the working directory, then execute the code

In [1]:
import os
# Path of the trainer repository checkout. It is kept both as a Python
# variable and as the TRAINER_HOME environment variable of this process.
trainer_home = '/content/drive/MyDrive/xitext/xitext_model_trainer'
os.environ['TRAINER_HOME'] = trainer_home

In [2]:
%%shell

# TRAINER_HOME is read from the environment of the notebook process.
# Quote the path and chain with `&&` so that `git pull` only runs when the
# `cd` succeeded; with `;` a failed `cd` would pull in whatever directory the
# shell happened to start in.
cd "$TRAINER_HOME" && git pull


remote: Enumerating objects: 5, done.[K
remote: Counting objects:  20% (1/5)[Kremote: Counting objects:  40% (2/5)[Kremote: Counting objects:  60% (3/5)[Kremote: Counting objects:  80% (4/5)[Kremote: Counting objects: 100% (5/5)[Kremote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (1/1)[Kremote: Compressing objects: 100% (1/1), done.[K
remote: Total 3 (delta 2), reused 3 (delta 2), pack-reused 0[K
Unpacking objects:  33% (1/3)   Unpacking objects:  66% (2/3)   Unpacking objects: 100% (3/3)   Unpacking objects: 100% (3/3), done.
From https://github.com/EHDEV/xitext_model_trainer
   9d02a1c..ad3218e  main       -> origin/main
Updating 9d02a1c..ad3218e
Fast-forward
 models.py | 18 [32m+++++++++[m[31m---------[m
 1 file changed, 9 insertions(+), 9 deletions(-)




In [3]:
cd $trainer_home

/content/drive/MyDrive/xitext/xitext_model_trainer


#### Configurations

In [4]:
import configparser
import os
from pathlib import Path

# ExtendedInterpolation enables ${section:option} style references between
# values inside the ini file.
config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())

# ConfigParser.read() silently skips files it cannot open and returns only the
# names it actually parsed. Fail loudly here instead of surfacing later as a
# confusing KeyError on a missing section.
loaded_config_files = config.read('config.ini')
if not loaded_config_files:
    raise FileNotFoundError(
        "config.ini was not found in the current working directory: " + os.getcwd()
    )

# Keep the list of parsed files as the cell's displayed result.
loaded_config_files


['config.ini']

In [5]:
# Snapshot each ini section as a plain dict (values are strings), in the order
# DEFAULT, DATA, MODEL, ONNX.
default_config, data_config, model_config, onnx_config = (
    dict(config[section_name])
    for section_name in ('DEFAULT', 'DATA', 'MODEL', 'ONNX')
)

#### Import classes and required functions

In [6]:
from file_config import FileConfig
from models import SequenceClassifierModel
from convert_optimize_onnx import TorchToONNX
from data_preprocess import TextClassifierData, _encode_text_into_tokens
from pathlib import Path
import torch

#### Set the data file path and other file configuration

In [7]:
# Display the values read from the [DATA] section as a quick sanity check.
data_config

{'company': 'xitext',
 'company_home': '/content/drive/MyDrive/xitext',
 'data_pickle_output_path': '/content/drive/MyDrive/xitext/news-topic-classifier/data/pickles',
 'delimiter': ',',
 'project_home': '/content/drive/MyDrive/xitext/news-topic-classifier',
 'project_name': 'news-topic-classifier',
 'source_dir': '/content/drive/MyDrive/xitext/news-topic-classifier/data',
 'target_col': 'topic',
 'text_col': 'text'}

In [8]:
# Describe the input data for the preprocessing step: the directory holding the
# source files, the field delimiter, and which columns carry the text and the
# label. All values come from the [DATA] section of the config.
fconfig = FileConfig(
    path_to_directory=Path(data_config['source_dir']),
    delimiter=data_config['delimiter'],
    sequence_column=data_config['text_col'],
    target_column=data_config['target_col'],
)


In [9]:
# Build the data helper from the file configuration. The object is reused in
# later cells for its `classes`, `num_labels` and `tokenizer` attributes.
text_clas_data = TextClassifierData(fconfig)
# preprocess() returns two objects, unpacked here as the training and
# validation datasets (per the log output it cleans the data, splits it, and
# tokenizes both parts; the implementation is not shown in this notebook).
train_data, val_data = text_clas_data.preprocess()

DEBUG:data-preprocessing.log:load data started
DEBUG:data-preprocessing.log:dataframe with shape (200853, 2) has been created
INFO:numexpr.utils:NumExpr defaulting to 2 threads.
DEBUG:data-preprocessing.log:sentence column cleaned
DEBUG:data-preprocessing.log:Clean_label_column complete
DEBUG:data-preprocessing.log:Underrepresented classes have been removed and data condensed
DEBUG:data-preprocessing.log:preparing data for training: train/val split and convert to tensor
DEBUG:data-preprocessing.log:train test split completed
DEBUG:data-preprocessing.log:Tokenizing train and valid data
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200

#### Write the classes and their indices as JSON in the model directory

In [10]:
import json
# Persist the index -> class mapping as JSON at the configured location.
classes_file = Path(model_config['classes_file_path'])

# open(..., 'w') does not create missing directories, so make sure the target
# directory exists before writing.
classes_file.parent.mkdir(parents=True, exist_ok=True)

with open(classes_file, 'w') as cfile:
    # JSON object keys are always strings, so the integer indices end up as
    # "0", "1", ... in the file.
    json.dump(dict(enumerate(text_clas_data.classes)), cfile)

In [11]:
# Directory that receives the serialized datasets; create it (including any
# missing parents) when it is not there yet.
pickle_path = Path(data_config['data_pickle_output_path'])
pickle_path.mkdir(parents=True, exist_ok=True)

# Serialize the training and validation datasets with torch.save.
torch.save(train_data, pickle_path / 'train_pickle.pth')
torch.save(val_data, pickle_path / 'val_pickle.pth')

In [12]:
# Reload both datasets from the files under `pickle_path`, replacing the
# in-memory objects with the deserialized copies.
# NOTE: torch.load unpickles its input; only load files you produced yourself.
train_data = torch.load(pickle_path/'train_pickle.pth')
val_data = torch.load(pickle_path/'val_pickle.pth')

#### Train

In [18]:
# Display the values read from the [MODEL] section before training.
model_config

{'classes_file_path': '/content/drive/MyDrive/xitext/news-topic-classifier/models/classes.json',
 'company': 'xitext',
 'company_home': '/content/drive/MyDrive/xitext',
 'epochs': '1',
 'eval_metric': 'accuracy',
 'model_filename': 'distilbert-topic-seq-classifier.bin',
 'model_group': 'distilbert',
 'model_output_dir': '/content/drive/MyDrive/xitext/news-topic-classifier/models',
 'optimizer': 'adam',
 'project_home': '/content/drive/MyDrive/xitext/news-topic-classifier',
 'project_name': 'news-topic-classifier',
 'scheduler': 'linear',
 'transformers_model_id': 'distilbert-base-uncased'}

In [None]:
# Assemble the sequence classifier from the [MODEL] settings and the prepared
# datasets. Config values are strings, hence the explicit int() for the epochs.
seq_model = SequenceClassifierModel(
    # model identity
    this_project_name=model_config['project_name'],
    tr_model_id=model_config['transformers_model_id'],
    model_group=model_config['model_group'],
    num_labels=text_clas_data.num_labels,
    # training setup
    optimizer=model_config['optimizer'],
    scheduler=model_config['scheduler'],
    eval_metric=model_config['eval_metric'],
    epochs=int(model_config['epochs']),
    # data and output location
    train_data=train_data,
    val_data=val_data,
    output_dir=model_config['model_output_dir'],
)

# NOTE(review): save_model=False presumably skips writing the trained weights
# to disk; the ONNX section reads a model file from
# onnx_config['torch_model_path'] -- confirm that file exists before converting.
seq_model.train(save_model=False)

#### ONNX

In [55]:
from pathlib import Path

# Location of the saved PyTorch model, taken from the [ONNX] config section.
torch_model_path = onnx_config['torch_model_path']
# NOTE(review): the ONNX output directory is hardcoded here while every other
# path in the notebook comes from config.ini -- consider moving it there.
onnx_model_dir = Path('/content/drive/MyDrive/xitext/xitext_model_trainer/models/news-topic-classifier/onnx/')

# Conversion helper; it is given the same tokenizer that was used during
# preprocessing.
tt2 = TorchToONNX(
    torch_model_path=torch_model_path,
    onnx_model_dir=onnx_model_dir,
    tokenizer=text_clas_data.tokenizer
)
# NOTE(review): model_type is forced to 'bert' although the [MODEL] section
# names a distilbert model -- presumably the value the converter/optimizer
# expects; confirm against TorchToONNX.
tt2.model_type='bert'

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0


In [None]:
# Run the conversion through the TorchToONNX helper configured in the previous
# cell (implementation not shown in this notebook).
tt2.convert_torch_to_onnx()

In [58]:
from scipy.special import softmax
def make_predictions(model, encoded_sentence, attention_mask, token_type_id=None):
    """Return the class probabilities the model assigns to the last input sequence.

    Args:
        model: Torch module. It is switched to eval mode and invoked as
            ``model(encoded_sentence, attention_mask=attention_mask)``; the
            first element of its output is taken as the logits.
        encoded_sentence: Token-id tensor passed positionally to the model
            (assumed to be shaped (batch, seq_len) -- TODO confirm).
        attention_mask: Attention mask forwarded to the model by keyword.
        token_type_id: Accepted for interface compatibility; it is currently
            NOT forwarded to the model.

    Returns:
        1-D numpy array holding the softmax probabilities (taken over
        dimension 1 of the logits) of the LAST row of the batch, rounded to
        4 decimals.

    Side effects:
        Leaves the model in eval mode and sets numpy's global print options to
        suppress scientific notation.
    """
    # Imported locally so the function does not depend on a later notebook cell
    # having executed `import numpy as np` before it is called.
    import numpy as np

    model.eval()
    with torch.no_grad():
        # No labels are passed, so the model returns logits only (the raw
        # scores before any activation such as softmax).
        outputs = model(encoded_sentence, attention_mask=attention_mask)

        logits = outputs[0]
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # Move to CPU and convert to numpy for post-processing.
        probabilities = probabilities.detach().cpu().numpy()

    np.set_printoptions(suppress=True)

    return probabilities[-1].round(4)


In [59]:
import numpy as np
from collections import OrderedDict

# One-sentence batch to classify.
sentence = [
    'The KKK used to run a youth group called the Klu Klux Kiddies. A sobering reminder of how evil shit like this starts at home.',
]

# Tokenize with the same tokenizer that was used during preprocessing.
encoded_tensor = _encode_text_into_tokens(sentence, text_clas_data.tokenizer)

# Run the trained torch model on CPU in evaluation mode.
seq_model.model.eval()
seq_model.model.to('cpu')
prediction_probabilities = make_predictions(
    model=seq_model.model,
    encoded_sentence=encoded_tensor['input_ids'],
    attention_mask=encoded_tensor['attention_mask'],
)

# Keep the ten most probable classes, highest probability first.
classes = text_clas_data.classes
top_indices = prediction_probabilities.argsort()[-10:][::-1]
top_topics = OrderedDict(
    (classes[class_idx], prediction_probabilities[class_idx])
    for class_idx in top_indices
)

top_topics

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:data-preprocessing.log:Encoding input sentences completed


OrderedDict([('politics', 0.5881),
             ('black voices', 0.1225),
             ('queer voices', 0.0474),
             ('crime', 0.044),
             ('entertainment', 0.0324),
             ('comedy', 0.0297),
             ('religion', 0.0254),
             ('weird news', 0.0213),
             ('healthy living', 0.015),
             ('impact', 0.0109)])

In [None]:
from collections import OrderedDict

# Score the same sentence with the converted ONNX model.
onnx_res = tt2.run_inference(sentence, text_clas_data.tokenizer)

# Map each class name to its score, rounded to 4 decimals. The first element
# of the result is indexed by class position (presumably the probability
# vector for the single input sentence -- confirm against run_inference).
probs = {
    text_clas_data.classes[class_idx]: round(score, 4)
    for class_idx, score in enumerate(onnx_res[0])
}

In [68]:
# List the (class, score) pairs from the ONNX run, highest score first.
sorted(probs.items(), key=lambda class_and_score: class_and_score[1], reverse=True)

[('politics', 0.6278),
 ('black voices', 0.0898),
 ('queer voices', 0.0752),
 ('crime', 0.0715),
 ('entertainment', 0.0501),
 ('comedy', 0.0192),
 ('religion', 0.0159),
 ('weird news', 0.0086),
 ('business', 0.0062),
 ('impact', 0.005),
 ('latino voices', 0.0045),
 ('arts & culture', 0.0031),
 ('worldpost', 0.0029),
 ('women', 0.0026),
 ('healthy living', 0.0019),
 ('tech', 0.0017),
 ('media', 0.0016),
 ('good news', 0.0012),
 ('green', 0.001),
 ('taste', 0.001),
 ('college', 0.0009),
 ('food & drink', 0.0009),
 ('travel', 0.0009),
 ('arts', 0.0008),
 ('sports', 0.0008),
 ('wellness', 0.0008),
 ('world news', 0.0006),
 ('parenting', 0.0004),
 ('the worldpost', 0.0004),
 ('culture & arts', 0.0003),
 ('divorce', 0.0003),
 ('environment', 0.0003),
 ('fifty', 0.0003),
 ('parents', 0.0003),
 ('science', 0.0003),
 ('education', 0.0002),
 ('home & living', 0.0002),
 ('money', 0.0002),
 ('style', 1e-04),
 ('style & beauty', 1e-04),
 ('weddings', 1e-04)]

In [None]:
# Bar chart of the number of rows per topic in text_clas_data.data_df.
# The leftover `??` introspection prefix is removed so the line actually runs,
# and the figure dimensions are passed as `figsize`, the keyword the pandas
# plot API accepts (`size` is not one). The trailing semicolon hides the Axes repr.
text_clas_data.data_df['topic'].value_counts().plot(kind='bar', figsize=(10, 12));