
In [22]:
# install ktrain on Google Colab
!pip3 install ktrain



In [0]:
# import ktrain and the ktrain.text modules
import ktrain
from ktrain import text

In [24]:
ktrain.__version__

'0.14.1'

# Multiclass Text Classification Using BERT and Keras
In this example, we will use ***ktrain*** ([a lightweight wrapper around Keras](https://github.com/amaiya/ktrain)) to build a model using the dataset employed in the **scikit-learn** tutorial [Working with Text Data](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html). Whereas the tutorial samples only 4 newsgroups, we use all 20, which makes this a 20-class text classification problem. The objective is to classify each document into the correct newsgroup topic category. With about 11,000 training documents, this is a modestly sized training set on which to see **BERT** in action. Let's fetch the [20 Newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) using scikit-learn.

In [25]:
# fetch the dataset using scikit-learn
categories = ['comp.graphics', 'comp.os.ms-windows.misc',
             'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale',
             'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey',
             'talk.politics.misc', 'talk.politics.guns', 'talk.politics.mideast', 'sci.crypt',
             'sci.electronics', 'sci.med', 'sci.space', 'talk.religion.misc', 'alt.atheism', 'soc.religion.christian']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train',
   categories=categories, shuffle=True, random_state=42)
test_b = fetch_20newsgroups(subset='test',
   categories=categories, shuffle=True, random_state=42)

print('size of training set: %s' % (len(train_b['data'])))
print('size of validation set: %s' % (len(test_b['data'])))
print('classes: %s' % (train_b.target_names))

x_train = train_b.data
y_train = train_b.target
x_test = test_b.data
y_test = test_b.target

size of training set: 11314
size of validation set: 7532
classes: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']


## STEP 1: Load and Preprocess the Data
Preprocess the data using the `texts_from_array` function (since the data resides in arrays).
If your documents are stored in folders or in a CSV file, you can use the `texts_from_folder` or `texts_from_csv` function, respectively.

In [26]:
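# max_features only affects the 'standard' preprocess_mode; in 'bert' mode the
# vocabulary comes from BERT's own WordPiece tokenizer. maxlen is the number of
# tokens kept per document (longer documents are truncated).
(x_train,  y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=x_train, y_train=y_train,
                                                                       x_test=x_test, y_test=y_test,
                                                                       class_names=train_b.target_names,
                                                                       preprocess_mode='bert',
                                                                       maxlen=350,
                                                                       max_features=35000)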

task: text classification
preprocessing train...
language: en


preprocessing test...
language: en


## STEP 2: Load the BERT Model and Instantiate a Learner Object

In [27]:
# TensorFlow may print deprecation warnings here; they can be safely disregarded.
model = text.text_classifier('bert', train_data=(x_train, y_train), preproc=preproc)
learner = ktrain.get_learner(model, train_data=(x_train, y_train), batch_size=6)

Is Multi-Label? False
maxlen is 350
done.


## STEP 3: Train the Model

We train using one of the three learning rates recommended in the BERT paper: *5e-5*, *3e-5*, or *2e-5*.
Alternatively, the ktrain Learning Rate Finder can be used to find a good learning rate by invoking `learner.lr_find()` and `learner.lr_plot()`, prior to training.
The `learner.fit_onecycle` method employs a [1cycle learning rate policy](https://arxiv.org/pdf/1803.09820.pdf).



In [28]:
learner.fit_onecycle(2e-5, 4)



begin training using onecycle policy with max lr of 2e-05...
Train on 11314 samples
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tensorflow.python.keras.callbacks.History at 0x7f2ce1501550>
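The `learner.fit_onecycle(2e-5, 4)` call above varies the learning rate over the course of training instead of holding it fixed. The sketch below is a simplified, symmetric version of the 1cycle idea: the rate climbs linearly from a low value to the maximum over the first half of the steps, then descends linearly back. The starting value `max_lr / div_factor` (with `div_factor=10`) is an illustrative assumption, and the sketch omits the momentum cycling and final annealing phase described in the 1cycle paper, so ktrain's actual schedule may differ. The step count assumes Keras's usual `ceil(samples / batch_size)` batches per epoch for the 11,314 training documents and `batch_size=6` used here; `one_cycle_lr` is a hypothetical helper.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, div_factor=10.0):
    """Learning rate at `step` of a simplified, symmetric 1cycle schedule.

    Rises linearly from max_lr / div_factor to max_lr over the first half of
    training, then falls linearly back to max_lr / div_factor.
    div_factor is an illustrative assumption, not a ktrain setting.
    """
    low = max_lr / div_factor
    half = total_steps / 2
    if step <= half:
        frac = step / half                  # 0 -> 1 during the warm-up phase
    else:
        frac = (total_steps - step) / half  # 1 -> 0 during the cool-down phase
    return low + (max_lr - low) * frac


# Steps for the run in this notebook: 11314 samples, batch_size=6, 4 epochs.
steps_per_epoch = math.ceil(11314 / 6)  # 1886
total_steps = steps_per_epoch * 4       # 7544

for step in (0, total_steps // 2, total_steps):
    print(step, one_cycle_lr(step, total_steps, max_lr=2e-5))
```

Starting below the peak acts as a warm-up, and ending low lets the weights settle; this is the general motivation for 1cycle-style schedules.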

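We can use the `learner.validate` method to evaluate our model on the validation set (here, the 20 Newsgroups test split).
Among the classes shown in the report below (the printed output is truncated, so the remaining classes and the overall accuracy are not visible), F1 scores range from 0.79 for `alt.atheism` to 0.97 for `rec.sport.hockey`. The 91% accuracy that the [scikit-learn tutorial](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html) reports for an SVM was measured on only four categories, so it is not directly comparable to this 20-class task.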

In [29]:
learner.validate(val_data=(x_test, y_test), class_names=train_b.target_names)

                          precision    recall  f1-score   support

             alt.atheism       0.80      0.78      0.79       319
           comp.graphics       0.84      0.81      0.82       389
 comp.os.ms-windows.misc       0.83      0.84      0.83       394
comp.sys.ibm.pc.hardware       0.78      0.82      0.80       392
   comp.sys.mac.hardware       0.83      0.87      0.85       385
          comp.windows.x       0.90      0.92      0.91       395
            misc.forsale       0.91      0.92      0.92       390
               rec.autos       0.94      0.90      0.92       396
         rec.motorcycles       0.92      0.92      0.92       398
      rec.sport.baseball       0.97      0.94      0.96       397
        rec.sport.hockey       0.96      0.98      0.97       399
               sci.crypt       0.93      0.94      0.94       396
         sci.electronics       0.85      0.82      0.84       393
                 sci.med       0.94      0.93      0.94       396
         

array([[248,   1,   0,   0,   0,   0,   0,   0,   1,   0,   0,   0,   0,
          2,   7,  14,   0,   0,  10,  36],
       [  0, 315,  14,   7,   9,  18,   1,   0,   0,   0,   1,  14,   4,
          0,   4,   2,   0,   0,   0,   0],
       [  0,  18, 331,  21,   7,  10,   0,   0,   1,   0,   0,   0,   1,
          1,   0,   0,   0,   0,   4,   0],
       [  0,   9,  26, 320,  17,   2,   7,   0,   2,   0,   0,   2,   6,
          0,   0,   1,   0,   0,   0,   0],
       [  0,   2,   3,  31, 334,   1,   7,   0,   0,   0,   0,   1,   6,
          0,   0,   0,   0,   0,   0,   0],
       [  0,  10,  18,   0,   1, 362,   2,   0,   0,   0,   0,   0,   1,
          1,   0,   0,   0,   0,   0,   0],
       [  0,   2,   3,   4,  10,   1, 358,   0,   3,   0,   0,   1,   6,
          0,   1,   1,   0,   0,   0,   0],
       [  0,   0,   3,   0,   1,   0,   8, 355,  12,   0,   0,   0,  12,
          0,   1,   0,   0,   0,   3,   1],
       [  1,   1,   0,   0,   0,   0,   2,  12, 368,   0,   1, ...
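
The matrix returned by `learner.validate` has one row per true class and one column per predicted class: the first row sums to 319 and the second to 389, matching the `support` column for `alt.atheism` and `comp.graphics`. The sketch below shows how the report's precision, recall, and F1 columns, plus overall accuracy, follow from such a matrix. `metrics_from_confusion` is a hypothetical helper written for illustration, and the three-class matrix is made-up data; the last lines recompute the `alt.atheism` recall from the first row printed above.

```python
def metrics_from_confusion(cm):
    """Per-class (precision, recall, f1, support) and overall accuracy.

    Assumes cm[i][j] counts documents whose true class is i and whose
    predicted class is j (rows = truth, columns = predictions).
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(n))  # diagonal = correct predictions
    per_class = []
    for i in range(n):
        tp = cm[i][i]
        predicted = sum(cm[r][i] for r in range(n))  # column sum
        actual = sum(cm[i])                          # row sum (the "support")
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        per_class.append((precision, recall, f1, actual))
    return per_class, correct / total


# Made-up three-class confusion matrix, for illustration only.
toy_cm = [[5, 1, 0],
          [2, 6, 2],
          [0, 1, 8]]
per_class, accuracy = metrics_from_confusion(toy_cm)
print(round(accuracy, 2))  # 0.76

# Recall for alt.atheism, from the first row of the matrix printed above.
atheism_row = [248, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
               2, 7, 14, 0, 0, 10, 36]
print(round(atheism_row[0] / sum(atheism_row), 2))  # 0.78
```

Precision for a class needs its full column, which the truncated output above does not show, so only recall can be checked against the report here.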

## How to Use Our Trained BERT Model

We can call the `ktrain.get_predictor` function to obtain a Predictor object capable of making predictions on new, raw text.

In [0]:
predictor = ktrain.get_predictor(learner.model, preproc)

In [31]:
predictor.get_classes()

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']
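
Conceptually, a single-label classifier like this one produces one score per class, and the predicted label is the entry of `get_classes()` at the position of the highest probability. The sketch below illustrates that mapping with made-up scores and a four-class subset of the list above; `softmax` and `predict_label` are hypothetical helpers, not ktrain functions, and ktrain's own `predict` may differ in its details.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, class_names):
    """Return the class name with the highest probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return class_names[best]


classes = ['alt.atheism', 'comp.graphics', 'rec.autos', 'sci.med']  # subset for brevity
fake_logits = [0.1, -1.2, 3.4, 0.5]                                  # made-up scores
print(predict_label(fake_logits, classes))  # rec.autos
```

Because the class order is what links a position in the output to a label, the list returned by `get_classes()` must stay aligned with the order used during preprocessing.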

In [32]:
predictor.predict(test_b.data[0:1])

['rec.autos']

In [33]:
# we can visually verify that our prediction of 'rec.autos' for this document is correct
print(test_b.data[0])

From: v064mb9k@ubvmsd.cc.buffalo.edu (NEIL B. GANDLER)
Subject: Need info on 88-89 Bonneville
Organization: University at Buffalo
Lines: 10
News-Software: VAX/VMS VNEWS 1.41
Nntp-Posting-Host: ubvmsd.cc.buffalo.edu


 I am a little confused on all of the models of the 88-89 bonnevilles.
I have heard of the LE SE LSE SSE SSEI. Could someone tell me the
differences are far as features or performance. I am also curious to
know what the book value is for prefereably the 89 model. And how much
less than book value can you usually get them for. In other words how
much are they in demand this time of year. I have heard that the mid-spring
early summer is the best time to buy.

			Neil Gandler



In [34]:
# print the ground-truth label: it matches our prediction
print(test_b.target_names[test_b.target[0]])

rec.autos


The `predictor.save` method and the `ktrain.load_predictor` function can be used to save the Predictor object to disk and reload it later to make predictions on new data.

In [0]:
# let's save the predictor for later use
predictor.save('/tmp/my_predictor')

In [0]:
# reload the predictor
reloaded_predictor = ktrain.load_predictor('/tmp/my_predictor')

In [37]:
# make a prediction on the same document to verify it still works
reloaded_predictor.predict(test_b.data[0:1])

['rec.autos']