In [2]:
text_train_file = "/share/shubh/data/yelp_polarity/train_yelp.csv"
text_test_file = "/share/shubh/data/yelp_polarity/test_yelp.csv"
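Before training, it can help to confirm that the CSV header matches the column names the model will be told about. The sketch below is a minimal check, assuming the files have a header row containing `text` and `category` (the names used in `data_types` in the next cell). The two sample rows and their 0/1 labels are hypothetical stand-ins, not rows from the Yelp files.

```python
import csv
import io

# Hypothetical two-row sample; the real rows live in train_yelp.csv.
SAMPLE_CSV = '''text,category
"Great food and friendly staff",1
"Cold fries and a long wait",0
'''

# Column names taken from the data_types mapping used to build the model.
EXPECTED_COLUMNS = {"text", "category"}


def missing_columns(csv_file, expected=EXPECTED_COLUMNS):
    """Return the expected column names that are absent from the CSV header."""
    reader = csv.DictReader(csv_file)
    header = set(reader.fieldnames or [])
    return expected - header


# An empty set means every expected column is present in the header.
print(missing_columns(io.StringIO(SAMPLE_CSV)))
```

To run the same check on the real data, pass an open file handle instead, for example `missing_columns(open(text_train_file, newline=""))`.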

#### Training a text classifier with Universal Deep Transformers takes just a few lines of code.

In [4]:
from thirdai import bolt

text_model = bolt.UniversalDeepTransformer(
    data_types={
        "text": bolt.types.text(),
        "category": bolt.types.categorical(n_unique_classes=2)
    },
    target="category"
)


input_1 (Input): dim=200000
input_1 -> fc_1 (FullyConnected): dim=512, sparsity=1, act_func=ReLU
fc_1 -> fc_2 (FullyConnected): dim=2, sparsity=1, act_func=Softmax
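The layer sizes in the summary above are enough to estimate the size of the model. The sketch below counts the weights of the two fully connected layers, assuming each layer is dense (`sparsity=1`) and has one bias per output unit. The bias assumption and the helper name `dense_param_count` are illustrative, not taken from the library.

```python
# Layer widths read off the printed summary: input -> fc_1 -> fc_2.
layer_dims = [200_000, 512, 2]


def dense_param_count(dims, with_bias=True):
    """Count parameters of a stack of fully connected layers."""
    total = 0
    for fan_in, fan_out in zip(dims, dims[1:]):
        total += fan_in * fan_out  # weight matrix
        if with_bias:
            total += fan_out  # one bias per output unit (assumed)
    return total


total_params = dense_param_count(layer_dims)
print(f"{total_params:,}")  # 102,401,538
```

Almost all of the parameters sit in the first layer, which maps the 200,000-dimensional input to 512 hidden units.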



### We will be training a model with more than 100M parameters on an M1, at a few minutes per epoch

In [5]:
train_config = (bolt.TrainConfig(epochs=5, learning_rate=0.01)
                    .with_metrics(["categorical_accuracy"]))

text_model.train(text_train_file, train_config)

test_config = (bolt.EvalConfig()
                   .with_metrics(["categorical_accuracy"]))

text_model.evaluate(text_test_file, test_config)

Loading vectors from './yelp_polarity/train_yelp.csv'
Loaded 560000 vectors from './yelp_polarity/train_yelp.csv' in 10 seconds.
train epoch 0:


train | epoch 0 | updates 274 | {categorical_accuracy: 0.913309} | batches 274 | time 222s | complete

train epoch 1:


train | epoch 1 | updates 548 | {categorical_accuracy: 0.961759} | batches 274 | time 231s | complete

train epoch 2:


train | epoch 2 | updates 822 | {categorical_accuracy: 0.974805} | batches 274 | time 210s | complete

train epoch 3:


train | epoch 3 | updates 1096 | {categorical_accuracy: 0.982463} | batches 274 | time 194s | complete

train epoch 4:


train | epoch 4 | updates 1370 | {categorical_accuracy: 0.989852} | batches 274 | time 214s | complete

Loading vectors from './yelp_polarity/test_yelp.csv'
Loaded 38000 vectors from './yelp_polarity/test_yelp.csv' in 0 seconds.
test:


predict | epoch 5 | updates 1370 | {categorical_accuracy: 0.919342} | batches 19 | time 8175ms

array([[8.90052259e-01, 1.09947614e-01],
       [9.95818436e-01, 4.18142136e-03],
       [1.21686628e-06, 9.99998689e-01],
       ...,
       [9.99999762e-01, 1.74599634e-07],
       [9.99999881e-01, 5.80643589e-10],
       [9.99999881e-01, 5.16755847e-08]], dtype=float32)
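The progress lines in the log above can be summarised with a few lines of parsing. The sketch below copies those lines verbatim and extracts accuracy and duration from each. The regular expression is written against the format as printed here, so treat it as an assumption about the log layout rather than a documented interface.

```python
import re

# Progress lines copied verbatim from the training and evaluation output.
LOG = """\
train | epoch 0 | updates 274 | {categorical_accuracy: 0.913309} | batches 274 | time 222s | complete
train | epoch 1 | updates 548 | {categorical_accuracy: 0.961759} | batches 274 | time 231s | complete
train | epoch 2 | updates 822 | {categorical_accuracy: 0.974805} | batches 274 | time 210s | complete
train | epoch 3 | updates 1096 | {categorical_accuracy: 0.982463} | batches 274 | time 194s | complete
train | epoch 4 | updates 1370 | {categorical_accuracy: 0.989852} | batches 274 | time 214s | complete
predict | epoch 5 | updates 1370 | {categorical_accuracy: 0.919342} | batches 19 | time 8175ms
"""

LINE_RE = re.compile(
    r"^(?P<phase>train|predict) \| epoch (?P<epoch>\d+) \| updates \d+ \| "
    r"\{categorical_accuracy: (?P<accuracy>[0-9.]+)\} \| batches \d+ \| "
    r"time (?P<amount>\d+)(?P<unit>ms|s)",
    re.MULTILINE,
)


def parse_log(text):
    """Extract phase, epoch, accuracy and duration in seconds from each line."""
    rows = []
    for m in LINE_RE.finditer(text):
        seconds = int(m["amount"]) / (1000 if m["unit"] == "ms" else 1)
        rows.append({
            "phase": m["phase"],
            "epoch": int(m["epoch"]),
            "accuracy": float(m["accuracy"]),
            "seconds": seconds,
        })
    return rows


rows = parse_log(LOG)
train_rows = [r for r in rows if r["phase"] == "train"]
test_row = next(r for r in rows if r["phase"] == "predict")

total_train_seconds = sum(r["seconds"] for r in train_rows)
accuracy_gap = train_rows[-1]["accuracy"] - test_row["accuracy"]

print(f"training time: {total_train_seconds:.0f} s")
print(f"train/test accuracy gap: {accuracy_gap:.4f}")
```

For this run the five epochs add up to 1071 s, and the final training accuracy is about 0.07 higher than the test accuracy, which suggests some overfitting by the last epoch.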



In [6]:
text_model.save("yelp_udt.bolt")

In [31]:
inference_sample = {"text": "Everyone said that Nobu is bad but nothing could be further away from the truth"}

#### Machine learning models can feel like a black box, where it is hard to see how the model reaches a decision about an input. Our explainability module lifts the curtain on predictions and offers deeper insight into the model's decision process.

In [33]:
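# predict returns one score per class for the sample
class_scores = text_model.predict(inference_sample)
print(class_scores)

# explain returns keyword-level attributions; print the first entry
explanations = text_model.explain(inference_sample)
print(explanations[0])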

[0.998505   0.00149496]
column_name: "text" | keyword: "nothing" | percentage_significance: 16.2851


We can see that the word "nothing" has the highest significance.
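The first printed line holds one score per class, so turning it into a class decision takes only an argmax. The sketch below is a minimal illustration, assuming the two values are class probabilities listed in class-index order, which the notebook does not state explicitly. The helper name `top_class` is hypothetical.

```python
def top_class(probabilities):
    """Return (class_index, probability) for the highest-scoring class."""
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return best_index, probabilities[best_index]


# Values copied from the printed prediction above.
scores = [0.998505, 0.00149496]
index, confidence = top_class(scores)
print(index, confidence)  # 0 0.998505
```

Which sentiment class index 0 corresponds to depends on how the labels are encoded in the `category` column of the training file.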

Not only can we get insight into the model's predictions, we can also see how to change the data point to get a desired output. By specifying a target class, we can find out which columns should be changed to reach it.

In [37]:
inference_sample = {"text": "Nobu is an underwhelming restaurant"}
explanations = text_model.explain(inference_sample, target_class="1")
print(explanations[0])

column_name: "text" | keyword: "underwhelming" | percentage_significance: -55.5133
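The printed explanations follow a regular layout, so they can be turned into structured records for further processing. The sketch below parses the two lines printed in this notebook. It works on the printed text only and makes no assumption about the attributes of the objects returned by `explain`. Reading a negative value as a keyword that works against the requested target class is an interpretation, not something the notebook defines.

```python
import re

EXPLANATION_RE = re.compile(
    r'column_name: "(?P<column>[^"]*)" \| '
    r'keyword: "(?P<keyword>[^"]*)" \| '
    r'percentage_significance: (?P<significance>-?\d+(?:\.\d+)?)'
)


def parse_explanation(line):
    """Turn one printed explanation line into a dict with a float significance."""
    match = EXPLANATION_RE.search(line)
    if match is None:
        raise ValueError(f"unrecognised explanation line: {line!r}")
    return {
        "column": match["column"],
        "keyword": match["keyword"],
        "significance": float(match["significance"]),
    }


# Lines copied from the two explain() outputs in this notebook.
printed = [
    'column_name: "text" | keyword: "nothing" | percentage_significance: 16.2851',
    'column_name: "text" | keyword: "underwhelming" | percentage_significance: -55.5133',
]

for line in printed:
    entry = parse_explanation(line)
    # The sign is reported as-is; its meaning is an assumption (see above).
    sign = "negative" if entry["significance"] < 0 else "positive"
    print(entry["keyword"], entry["significance"], sign)
```

Keeping the sign alongside the magnitude matters here, because the targeted explanation for "underwhelming" is large in magnitude but negative.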
