Dataset link: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_20newsgroups.html

Pretrained model: https://huggingface.co/transformers/v3.3.1/model_doc/distilbert.html

In [None]:
!pip install ktrain

Collecting ktrain
  Downloading ktrain-0.29.3.tar.gz (25.3 MB)
Collecting scikit-learn==0.24.2
  Downloading scikit_learn-0.24.2-cp37-cp37m-manylinux2010_x86_64.whl (22.3 MB)
Collecting langdetect
  Downloading langdetect-1.0.9.tar.gz (981 kB)
Collecting cchardet
  Downloading cchardet-2.1.7-cp37-cp37m-manylinux2010_x86_64.whl (263 kB)
Collecting syntok==1.3.3
  Downloading syntok-1.3.3-py3-none-any.whl (22 kB)
Collecting seqeval==0.0.19
  Downloading seqeval-0.0.19.tar.gz (30 kB)
Collecting transformers==4.10.3
  Downloading transformers-4.10.3-py3-none-any.whl (2.8 MB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux201

In [None]:
%reload_ext autoreload
%autoreload
%matplotlib inline
import os
# Number CUDA devices in PCI bus order and expose only the first GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import ktrain
from ktrain import text
from sklearn.datasets import fetch_20newsgroups

In [None]:
categories = ['alt.atheism','soc.religion.christian','comp.graphics','sci.med','rec.sport.baseball']


In [None]:
train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=40,
)

In [None]:
test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=40,
)

In [None]:
test.keys()

dict_keys(['data', 'filenames', 'target_names', 'target', 'DESCR'])

In [None]:
test.data

['From: banschbach@vms.ocom.okstate.edu\nSubject: Re: Candida(yeast) Bloom, Fact or Fiction\nLines: 68\nNntp-Posting-Host: vms.ocom.okstate.edu\nOrganization: OSU College of Osteopathic Medicine\n\nIn article <1rjifg$bgm@hsdndev.harvard.edu>, rind@enterprise.bih.harvard.edu (David Rind) writes:\n> In article <1993Apr26.174538.1@vms.ocom.okstate.edu>\n>  banschbach@vms.ocom.okstate.edu writes:\n>>oxygen(just like it does in the vagina).  As much stuff as there is in the \n>>lay press about L. acidophilus and vaginal yeast infections, I\'m really \n>>amazed that someone has not done a clinical trial yet to check it out.\n> \n> I\'ve mentioned this study a couple of times now: Ingestion of yogurt\n> containing Lactobacillus acidophilus as prophylaxis for candidal\n> vaginitis, Annals of Internal Medicine, 3/1/92 116(5):353-7.  Do you\n> have a problem with the study because they used yogurt rather than\n> capsules of lactobacillus (even though it had positive results)?\n> \n> The study wa

In [None]:
test.target_names

['alt.atheism',
 'comp.graphics',
 'rec.sport.baseball',
 'sci.med',
 'soc.religion.christian']
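As the target_names output above shows, fetch_20newsgroups orders the loaded categories alphabetically, and the integer labels in train.target and test.target index into that sorted list, not into the categories list in the order it was written. The stdlib-only sketch below shows which labels get the wrong name if categories is used as the list of class names; the helper label_mapping is hypothetical and exists only for this illustration.

```python
# Stdlib-only sketch: integer targets index into the *sorted* target_names,
# not into the order of the `categories` list passed to fetch_20newsgroups.
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics',
              'sci.med', 'rec.sport.baseball']
target_names = sorted(categories)  # same order as test.target_names above


def label_mapping(names):
    """Hypothetical helper: map each integer label to a class name."""
    return {i: name for i, name in enumerate(names)}


correct = label_mapping(target_names)
wrong = label_mapping(categories)

# Labels whose reported name changes if `categories` is used as class_names
mismatched = {i: (wrong[i], correct[i]) for i in correct if wrong[i] != correct[i]}
for i, (shown, actual) in sorted(mismatched.items()):
    print(f"label {i}: would be reported as {shown!r} but is {actual!r}")
```

Labels 0 and 3 happen to coincide in both orders, which is why a sci.med prediction reads correctly either way; labels 1, 2 and 4 do not.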

In [None]:
X_train = train.data
y_train = train.target

X_test = test.data
y_test = test.target

In [None]:
len(X_train),len(X_test)

(2854, 1899)

In [None]:
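# Build a text classifier on top of pretrained DistilBERT
model_name = 'distilbert-base-uncased'

In [None]:
# class_names must follow the integer label order of the targets,
# i.e. train.target_names (sorted), not the order of `categories`
trans = text.Transformer(model_name, maxlen=512, class_names=train.target_names)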


In [None]:
train_data = trans.preprocess_train(X_train,y_train)
test_data = trans.preprocess_test(X_test,y_test)

preprocessing train...
language: en
train sequence lengths:
	mean : 291
	95percentile : 820
	99percentile : 1757


Is Multi-Label? False
preprocessing test...
language: en
test sequence lengths:
	mean : 323
	95percentile : 894
	99percentile : 2394
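These statistics help judge maxlen: the training set's 95th percentile (820) is well above maxlen=512, so at least 5% of the training documents are cut off. The sketch below computes the same kind of summary (mean, 95th and 99th percentile) plus the fraction of sequences longer than maxlen. It is an illustration under assumptions: the nearest-rank percentile method and the sample lengths are made up here, and ktrain's own percentile method and length unit may differ. length_summary is a hypothetical helper.

```python
import math
import statistics


def length_summary(lengths, maxlen):
    """Hypothetical helper: mean, nearest-rank percentiles, truncated share."""
    ordered = sorted(lengths)

    def percentile(p):
        # Nearest-rank method: smallest value with at least p% of data <= it
        rank = max(1, math.ceil(p * len(ordered) / 100))
        return ordered[rank - 1]

    truncated = sum(1 for n in lengths if n > maxlen)
    return {
        "mean": round(statistics.mean(lengths)),
        "p95": percentile(95),
        "p99": percentile(99),
        "truncated_frac": truncated / len(lengths),
    }


# Made-up lengths for illustration only (not the 20 Newsgroups data)
sample = [120, 180, 250, 300, 410, 520, 640, 900, 1500, 2600]
print(length_summary(sample, maxlen=512))
```

Raising maxlen keeps more text per document but increases memory use and training time, which is the trade-off these numbers inform.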


In [None]:
model = trans.get_classifier()

In [None]:
learner = ktrain.get_learner(model, train_data=train_data, val_data=test_data, batch_size=16)

In [None]:
# Train for 1 epoch with the one-cycle policy and a peak learning rate of 1e-4
learner.fit_onecycle(1e-4, 1)

begin training using onecycle policy with max lr of 0.0001...

<keras.callbacks.History at 0x7f06a5a72610>
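fit_onecycle(1e-4, 1) trains for one epoch while the learning rate rises to the maximum of 1e-4 and then falls again. With 2,854 training documents and batch_size=16, one epoch is ceil(2854 / 16) = 179 batches, assuming the last partial batch is kept. The sketch below is a simplified triangular one-cycle schedule (linear warm-up over the first half, linear decay over the second, starting and ending at max_lr / 10). The shape and the div_factor value are assumptions for illustration, not ktrain's exact implementation; one_cycle_lr is a hypothetical helper.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, div_factor=10.0):
    """Simplified triangular one-cycle schedule (illustrative assumption).

    Rises linearly from max_lr/div_factor to max_lr over the first half of
    training, then falls linearly back to max_lr/div_factor.
    """
    base_lr = max_lr / div_factor
    half = total_steps / 2
    if step <= half:
        frac = step / half                  # 0 -> 1 during warm-up
    else:
        frac = (total_steps - step) / half  # 1 -> 0 during decay
    return base_lr + frac * (max_lr - base_lr)


n_train, batch_size, epochs = 2854, 16, 1
total_steps = math.ceil(n_train / batch_size) * epochs  # 179 batches
schedule = [one_cycle_lr(s, total_steps, max_lr=1e-4) for s in range(total_steps + 1)]
print(total_steps, f"{max(schedule):.2e}")
```

Because 179 is odd, the peak falls between two batches, so no batch in this sketch sees exactly 1e-4.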

In [None]:
learner.validate(class_names=train.target_names)

                        precision    recall  f1-score   support

           alt.atheism       0.88      0.73      0.80       319
         comp.graphics       0.95      0.94      0.95       389
    rec.sport.baseball       0.95      0.99      0.97       397
               sci.med       0.97      0.93      0.95       396
soc.religion.christian       0.84      0.96      0.90       398

              accuracy                           0.92      1899
             macro avg       0.92      0.91      0.91      1899
          weighted avg       0.92      0.92      0.92      1899



array([[233,   1,  10,   8,  67],
       [ 12, 367,   4,   4,   2],
       [  0,   4, 393,   0,   0],
       [ 10,  11,   5, 368,   2],
       [ 11,   4,   1,   0, 382]])
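The report can be recomputed from the confusion matrix, whose rows are true classes and columns are predicted classes, both in label order 0 to 4. Precision divides each diagonal entry by its column sum, recall divides it by its row sum (the support), and F1 is their harmonic mean. The stdlib sketch below uses the matrix printed above; per_class_metrics is a hypothetical helper name.

```python
# Confusion matrix copied from the learner.validate() output above
cm = [
    [233,   1,  10,   8,  67],
    [ 12, 367,   4,   4,   2],
    [  0,   4, 393,   0,   0],
    [ 10,  11,   5, 368,   2],
    [ 11,   4,   1,   0, 382],
]


def per_class_metrics(matrix):
    """Hypothetical helper: (precision, recall, f1, support) per class."""
    n = len(matrix)
    col_sums = [sum(matrix[r][c] for r in range(n)) for c in range(n)]
    results = []
    for i in range(n):
        tp = matrix[i][i]
        support = sum(matrix[i])  # row sum = number of true instances
        precision = tp / col_sums[i] if col_sums[i] else 0.0
        recall = tp / support if support else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results.append((precision, recall, f1, support))
    return results


metrics = per_class_metrics(cm)
accuracy = sum(cm[i][i] for i in range(len(cm))) / sum(map(sum, cm))
for label, (p, r, f, s) in enumerate(metrics):
    print(f"label {label}: precision={p:.2f} recall={r:.2f} f1={f:.2f} support={s}")
print(f"accuracy={accuracy:.2f}")
```

The largest off-diagonal entry is 67 in row 0, column 4: alt.atheism posts predicted as label 4, which is soc.religion.christian in target_names. That single cell accounts for most of the low 0.73 recall on alt.atheism.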

In [None]:
learner.view_top_losses(n=5, preproc=trans)

----------
id:276 | loss:5.59 | true:soc.religion.christian | pred:rec.sport.baseball)

----------
id:865 | loss:5.29 | true:rec.sport.baseball | pred:comp.graphics)

----------
id:978 | loss:5.2 | true:sci.med | pred:alt.atheism)

----------
id:1616 | loss:5.19 | true:comp.graphics | pred:soc.religion.christian)

----------
id:229 | loss:5.16 | true:alt.atheism | pred:rec.sport.baseball)



These are the five validation documents with the highest loss, i.e. the ones the model misclassified most badly.
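view_top_losses ranks validation examples by their loss. Assuming the loss is the usual cross-entropy for single-label classification, it equals -log of the probability the model gave the true class, so a loss of 5.59 means the true class received only about e^-5.59, roughly 0.004. The stdlib sketch below ranks examples that way using made-up probabilities; top_losses is a hypothetical helper, not ktrain's code.

```python
import math


def top_losses(probs, true_labels, n=5):
    """Hypothetical helper: rank examples by cross-entropy, highest first.

    probs[i] is the predicted class distribution for example i;
    true_labels[i] is the index of its true class.
    """
    losses = [(-math.log(p[y]), i) for i, (p, y) in enumerate(zip(probs, true_labels))]
    losses.sort(reverse=True)
    return [(i, round(loss, 2)) for loss, i in losses[:n]]


# Made-up predicted distributions over three classes (illustration only)
probs = [
    [0.90, 0.05, 0.05],  # confident and correct (true class 0)
    [0.01, 0.98, 0.01],  # confident but wrong (true class 0)
    [0.30, 0.40, 0.30],  # uncertain and wrong (true class 2)
]
true_labels = [0, 0, 2]
print(top_losses(probs, true_labels, n=2))
```

Confidently wrong predictions dominate the ranking, which is why the top-loss list is a quick way to find mislabeled or ambiguous posts.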

In [None]:
# Predict on new data
predictor = ktrain.get_predictor(learner.model, preproc=trans)

In [None]:
x = "My Friend is suffering from Cancer"

In [None]:
predictor.predict(x)

'sci.med'

The model assigns the new sentence to sci.med, the category we would expect for a sentence about cancer.
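predictor.predict returns a single class name. Conceptually, the classifier produces one score (logit) per class, softmax turns the scores into probabilities, and the name at the index with the highest probability is returned; ktrain's predict also accepts return_proba=True if you want the probabilities themselves. The stdlib sketch below shows that last step with made-up logits; softmax and predict_label are hypothetical helpers, and the class list is assumed to be in train.target_names order.

```python
import math

# Assumed to match train.target_names (sorted order)
class_names = ['alt.atheism', 'comp.graphics', 'rec.sport.baseball',
               'sci.med', 'soc.religion.christian']


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, names):
    """Hypothetical helper: index of the highest probability -> class name."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return names[best], probs[best]


# Made-up logits for one document (illustration only)
label, confidence = predict_label([-1.2, 0.3, -0.8, 4.1, 0.5], class_names)
print(label, f"{confidence:.2f}")
```

The returned name is only as correct as the class-name list: an index-to-name mismatch would silently produce a wrong label even when the model's argmax is right.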

In [None]:
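# Save the predictor (model and preprocessing) to a folder;
# it can be reloaded later with ktrain.load_predictor('distilbert-model')
predictor.save('distilbert-model')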