In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
# Let pandas show up to 30 columns so the evaluation table below
# (20 confusion-matrix columns + 4 metric columns) is not truncated.
pd.set_option('display.max_columns', 30)

# Get the Data

In [2]:
# Fetch the 'train' and 'test' subsets (in that order) as (texts, labels) pairs.
# The 'test' subset is what the rest of the notebook refers to as "valid".
(x_train, y_train), (x_valid, y_valid) = (
    fetch_20newsgroups(subset=split, return_X_y=True)
    for split in ('train', 'test')
)

# Initialise Model

In [3]:
# Pretrained checkpoint shared by the tokenizer and the classifier so that
# both agree on vocabulary and architecture.
checkpoint = "distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=20,  # one output per newsgroup category
)

Downloading:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/465M [00:00<?, ?B/s]

2022-07-02 05:32:56.631340: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-02 05:32:56.632922: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-02 05:32:56.634015: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-02 05:32:56.635237: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags

# Tokenize Data

In [4]:
def tokenize(texts, max_length=512):
    """Tokenize raw documents into fixed-length NumPy model inputs.

    Every document is truncated to at most ``max_length`` tokens and padded
    up to exactly ``max_length`` tokens, so all rows have the same width.

    Args:
        texts: The documents to encode, passed straight through to the
            notebook-level ``tokenizer``.
        max_length: Fixed sequence length used for both truncation and
            padding. Defaults to 512, the value this notebook trains with.

    Returns:
        Whatever the notebook-level ``tokenizer`` returns for these
        arguments (requested as NumPy tensors via ``return_tensors='np'``).
    """
    return tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='np',
    )

In [5]:
# Encode both splits with the same settings (train first, then validation).
x_train_tokenized, x_valid_tokenized = map(tokenize, (x_train, x_valid))

# Prepare Data Loaders

In [6]:
# Number of examples per batch, shared by both loaders.
BATCH_SIZE = 8

# Keras does not shuffle `tf.data.Dataset` inputs itself (the `shuffle`
# argument of `fit` is ignored for datasets), so shuffle the training
# examples here. A buffer covering the whole split gives a full reshuffle
# on every epoch.
train_data = (
    tf.data.Dataset.from_tensor_slices((dict(x_train_tokenized), y_train))
    .shuffle(buffer_size=len(y_train), reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
)

# The validation split must keep its original order: predictions made from
# `valid_data` are later compared position-by-position against `y_valid`.
valid_data = (
    tf.data.Dataset.from_tensor_slices((dict(x_valid_tokenized), y_valid))
    .batch(BATCH_SIZE)
)

# Train

In [7]:
# The loss is configured with from_logits=True, i.e. it expects the model's
# raw, un-normalised outputs paired with integer class labels.
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],  # Keras documents `metrics` as a list
)
model.fit(train_data, validation_data=valid_data, epochs=5)

# Persist the fine-tuned weights together with the tokenizer so that the
# 'news-classifier' directory is self-contained and can be reloaded later
# without needing the original checkpoint name.
model.save_pretrained('news-classifier')
tokenizer.save_pretrained('news-classifier')

Epoch 1/5


2022-07-02 05:33:52.266460: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


# Inference

In [8]:
# Reload the fine-tuned classifier from disk to check that the saved
# checkpoint is usable on its own.
model = TFAutoModelForSequenceClassification.from_pretrained('news-classifier')

# Raw class scores for every validation example, in dataset order.
logits = model.predict(valid_data, verbose=1).logits

# Normalise the scores over the class axis, then take the most probable class.
preds_proba = tf.nn.softmax(logits, axis=-1).numpy()
preds = np.argmax(preds_proba, axis=1)

All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at news-classifier.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForSequenceClassification for predictions without further training.




# Evaluation

In [9]:
# Per-class metrics; each returned array has one entry per class label.
precision, recall, fscore, support = precision_recall_fscore_support(y_valid, preds)

# Confusion matrix (rows: true class, columns: predicted class) with the
# per-class metrics appended as extra columns on the right.
clf_report = pd.DataFrame(confusion_matrix(y_valid, preds)).assign(
    precision=precision,
    recall=recall,
    fscore=fscore,
    support=support,
)
clf_report

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,precision,recall,fscore,support
0,215,0,0,0,0,0,1,0,1,0,0,1,2,2,4,9,1,11,0,72,0.808271,0.673981,0.735043,319
1,1,310,8,21,11,13,1,1,1,0,0,15,2,2,0,2,0,1,0,0,0.824468,0.796915,0.810458,389
2,0,18,279,68,16,6,0,0,0,0,0,3,2,2,0,0,0,0,0,0,0.853211,0.708122,0.773925,394
3,0,3,10,344,24,3,3,0,1,0,0,0,4,0,0,0,0,0,0,0,0.640596,0.877551,0.740581,392
4,0,1,4,35,334,0,2,0,1,0,0,0,6,2,0,0,0,0,0,0,0.712154,0.867532,0.782201,385
5,0,22,19,2,5,344,1,0,1,0,0,0,1,0,0,0,0,0,0,0,0.932249,0.870886,0.900524,395
6,0,1,1,22,29,0,309,17,5,0,0,0,5,0,1,0,0,0,0,0,0.930723,0.792308,0.855956,390
7,0,1,0,1,4,0,3,360,11,0,0,1,10,0,1,0,0,0,4,0,0.8867,0.909091,0.897756,396
8,2,0,0,0,4,0,2,17,360,0,0,0,5,2,1,0,1,0,4,0,0.8933,0.904523,0.898876,398
9,3,0,1,1,0,1,4,2,5,359,8,0,1,2,0,0,3,0,5,2,0.975543,0.904282,0.938562,397
