In [1]:
import os
import numpy as np
import pandas as pd
import torch
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
pd.options.display.max_columns = 30

# Get the Data

In [2]:
x_train, y_train = fetch_20newsgroups(subset='train', return_X_y=True)
x_valid, y_valid = fetch_20newsgroups(subset='test', return_X_y=True)

# Initialise Model

In [3]:
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=20)
training_args = TrainingArguments("news_classifier", num_train_epochs=2)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

# Tokenize Data

In [4]:
def tokenize(texts):
    return tokenizer(
        texts, padding='max_length', truncation=True, max_length=512, return_tensors='pt'
    )

In [5]:
x_train_tokenized = tokenize(x_train)
x_valid_tokenized = tokenize(x_valid)

# Prepare Data Loaders

In [6]:
class TextClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = TextClassificationDataset(x_train_tokenized, y_train)
valid_dataset = TextClassificationDataset(x_valid_tokenized, y_valid)

# Train

In [7]:
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)
trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


***** Running training *****
  Num examples = 11314
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2830
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  import sys


Step,Training Loss
500,1.3186
1000,0.6591
1500,0.4774
2000,0.2467
2500,0.2329


Saving model checkpoint to news_classifier/checkpoint-500
Configuration saved in news_classifier/checkpoint-500/config.json
Model weights saved in news_classifier/checkpoint-500/pytorch_model.bin
  import sys
Saving model checkpoint to news_classifier/checkpoint-1000
Configuration saved in news_classifier/checkpoint-1000/config.json
Model weights saved in news_classifier/checkpoint-1000/pytorch_model.bin
  import sys
Saving model checkpoint to news_classifier/checkpoint-1500
Configuration saved in news_classifier/checkpoint-1500/config.json
Model weights saved in news_classifier/checkpoint-1500/pytorch_model.bin
  import sys
Saving model checkpoint to news_classifier/checkpoint-2000
Configuration saved in news_classifier/checkpoint-2000/config.json
Model weights saved in news_classifier/checkpoint-2000/pytorch_model.bin
  import sys
Saving model checkpoint to news_classifier/checkpoint-2500
Configuration saved in news_classifier/checkpoint-2500/config.json
Model weights saved in news_c

TrainOutput(global_step=2830, training_loss=0.5438325106887009, metrics={'train_runtime': 1281.7669, 'train_samples_per_second': 17.654, 'train_steps_per_second': 2.208, 'total_flos': 5954639162621952.0, 'train_loss': 0.5438325106887009, 'epoch': 2.0})

# Inference

In [8]:
model_checkpoint = 'news_classifier/checkpoint-2500/'
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=20)
trainer = Trainer(
    model,
    training_args
)
predictions = trainer.predict(valid_dataset)

loading configuration file news_classifier/checkpoint-2500/config.json
Model config BertConfig {
  "_name_or_path": "news_classifier/checkpoint-2500/",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15",
    "16": "LABEL_16",
    "17": "LABEL_17",
    "18": "LABEL_18",
    "19": "LABEL_19"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_10": 10,
    "LABEL_11": 11,
    "LABEL_12": 

# Evaluation

In [9]:
clf_report = pd.DataFrame(confusion_matrix(y_valid, predictions.label_ids))
precision, recall, fscore, support = precision_recall_fscore_support(y_valid, predictions.label_ids)
clf_report['precision'] = precision
clf_report['recall'] = recall
clf_report['fscore'] = fscore
clf_report['support'] = support
clf_report

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,precision,recall,fscore,support
0,319,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,319
1,0,389,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,389
2,0,0,394,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,394
3,0,0,0,392,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,392
4,0,0,0,0,385,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,385
5,0,0,0,0,0,395,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,395
6,0,0,0,0,0,0,390,0,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,390
7,0,0,0,0,0,0,0,396,0,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,396
8,0,0,0,0,0,0,0,0,398,0,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,398
9,0,0,0,0,0,0,0,0,0,397,0,0,0,0,0,0,0,0,0,0,1.0,1.0,1.0,397
