### Transformer Model

This notebook walks through fine-tuning a transformer model on a multi-label dataset. To keep things simple and to allow a fair comparison with the scikit-learn model trained previously, both models are fitted directly on the training set and evaluated on the test set. For stratified splitting of multi-label datasets or hyperparameter search with transformers, see the next notebook, `reuters_hyperparameter_search_using_trainer_api`.


In [1]:
from collections import Counter
from itertools import chain
import re

import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, f1_score, roc_auc_score
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)



In [2]:
# Text Preprocessing
def preprocess_text(text: str) -> str:
    """Remove numbers, newlines, and special characters from text."""
    text = re.sub(r'\d+', '', text)
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'[^\w\s]', '', text)
    return text

# Find Single Appearance Labels
def find_single_appearance_labels(y):
    """Find labels that appear only once in the dataset."""
    all_labels = list(chain.from_iterable(y))
    label_count = Counter(all_labels)
    single_appearance_labels = [label for label, count in label_count.items() if count == 1]
    return single_appearance_labels

# Remove Single Appearance Labels from Dataset
def remove_single_appearance_labels(dataset, single_appearance_labels):
    """Remove samples with single-appearance labels from both train and test sets."""
    for split in ['train', 'test']:
        dataset[split] = dataset[split].filter(lambda x: all(label not in single_appearance_labels for label in x['topics']))
    return dataset
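
As a rough illustration of what the two label helpers above do, here is a toy sketch that works on plain Python lists instead of a `datasets` object. The sample data and the `keep` helper are hypothetical and only mirror the counting and filtering logic.

```python
from collections import Counter
from itertools import chain

# Hypothetical toy data: one list of topics per document.
toy_topics = [["earn"], ["acq", "earn"], ["rye"], ["acq"], []]

# Count every label occurrence across all documents.
label_count = Counter(chain.from_iterable(toy_topics))
single = [label for label, count in label_count.items() if count == 1]

# Hypothetical helper: keep a document only if none of its labels is rare.
def keep(topics, rare):
    return all(label not in rare for label in topics)

filtered = [topics for topics in toy_topics if keep(topics, set(single))]

print(single)    # labels that appear exactly once
print(filtered)  # documents that survive the filter
```

Note that a document with no topics passes the filter, because `all()` over an empty sequence is `True`.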

In [3]:
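def multi_label_metrics(predictions, labels, threshold=0.5):
    """Compute micro-F1, micro ROC AUC, and subset accuracy from raw logits."""
    # Convert logits to independent per-label probabilities
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))

    # Binarize: a label is predicted when its probability reaches the threshold
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1

    y_true = labels

    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # Note: ROC AUC is computed here on the thresholded predictions, not on probs
    roc_auc = roc_auc_score(y_true, y_pred, average='micro')
    # On multi-label input, accuracy_score is subset accuracy (all labels must match)
    accuracy = accuracy_score(y_true, y_pred)

    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    """Adapter that lets the Trainer call multi_label_metrics."""
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    result = multi_label_metrics(predictions=preds, labels=p.label_ids)
    return result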
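
To make the metric logic concrete, the following dependency-free sketch reproduces the sigmoid, thresholding, micro-F1, and subset-accuracy steps by hand. The logits and labels are hypothetical toy values, not data from this notebook.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical logits for 2 documents and 3 labels, with their true labels.
logits = [[2.0, -1.0, 0.3], [-0.5, 1.5, -2.0]]
y_true = [[1, 0, 0], [0, 1, 1]]

# Each label is thresholded independently of the others.
threshold = 0.5
y_pred = [[1 if sigmoid(z) >= threshold else 0 for z in row] for row in logits]

# Micro-averaging pools the counts over every (document, label) pair.
pairs = [(t, p) for tr, pr in zip(y_true, y_pred) for t, p in zip(tr, pr)]
tp = sum(1 for t, p in pairs if t == 1 and p == 1)
fp = sum(1 for t, p in pairs if t == 0 and p == 1)
fn = sum(1 for t, p in pairs if t == 1 and p == 0)
f1_micro = 2 * tp / (2 * tp + fp + fn)

# Subset accuracy: a document counts only if every label matches.
subset_accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

print(y_pred, round(f1_micro, 4), subset_accuracy)
```

Subset accuracy is strict, which is one reason the accuracy column tends to sit below the micro-F1 column in multi-label settings.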

### Load dataset

Note that we use the `ModApte` split here.

In [4]:
# Load Dataset
dataset = load_dataset("reuters21578", "ModApte")

### Preprocess data

- Find single-appearance labels and remove the affected samples from the train and test splits
- Combine title and text into the `text` column
- Transform topics into a multi-hot encoding as the `labels` column
- Tokenize the dataset

In [5]:
# Find and Remove Single Appearance Labels
print("Finding single appearance labels...")
y_train = [item['topics'] for item in dataset['train']]
single_appearance_labels = find_single_appearance_labels(y_train)
print(f"Single appearance labels: {single_appearance_labels}")

print("Removing samples with single-appearance labels...")
dataset = remove_single_appearance_labels(dataset, single_appearance_labels)

Finding single appearance labels...
Single appearance labels: ['lin-oil', 'rye', 'red-bean', 'groundnut-oil', 'citruspulp', 'rape-meal', 'corn-oil', 'peseta', 'cotton-oil', 'ringgit', 'castorseed', 'castor-oil', 'lit', 'rupiah', 'skr', 'nkr', 'dkr', 'sun-meal', 'lin-meal', 'cruzado']
Removing samples with single-appearance labels...


In [6]:
print("Combine title and text together")
dataset = dataset.map(
    lambda x: {"text": x["title"] + " " + x["text"]}
)

Combine title and text together


In [7]:
# Check number of unique labels 
unique_labels = set(chain.from_iterable(dataset['train']["topics"]))
print(f"We have {len(unique_labels)} unique labels:\n{unique_labels}")

# Transform topics into multi-hot encoding format
mlb = MultiLabelBinarizer()
mlb.fit(dataset['train']['topics'])
dataset = dataset.map(
    lambda x: {"labels": torch.from_numpy(mlb.transform(x["topics"])).float()}, batched=True)

labels = mlb.classes_
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
num_labels = len(id2label)

assert num_labels == len(unique_labels) 

We have 95 unique labels:
{'veg-oil', 'gold', 'platinum', 'ipi', 'acq', 'carcass', 'wool', 'coconut-oil', 'linseed', 'copper', 'soy-meal', 'jet', 'dlr', 'copra-cake', 'hog', 'rand', 'strategic-metal', 'can', 'tea', 'sorghum', 'livestock', 'barley', 'lumber', 'earn', 'wheat', 'trade', 'soy-oil', 'cocoa', 'inventories', 'income', 'rubber', 'tin', 'iron-steel', 'ship', 'rapeseed', 'wpi', 'sun-oil', 'pet-chem', 'palmkernel', 'nat-gas', 'gnp', 'l-cattle', 'propane', 'rice', 'lead', 'alum', 'instal-debt', 'saudriyal', 'cpu', 'jobs', 'meal-feed', 'oilseed', 'dmk', 'plywood', 'zinc', 'retail', 'dfl', 'cpi', 'crude', 'pork-belly', 'gas', 'money-fx', 'corn', 'tapioca', 'palladium', 'lei', 'cornglutenfeed', 'sunseed', 'potato', 'silver', 'sugar', 'grain', 'groundnut', 'naphtha', 'orange', 'soybean', 'coconut', 'stg', 'cotton', 'yen', 'rape-oil', 'palm-oil', 'oat', 'reserves', 'housing', 'interest', 'coffee', 'fuel', 'austdlr', 'money-supply', 'heat', 'fishmeal', 'bop', 'nickel', 'nzdlr'}


Map: 100%|██████████| 9588/9588 [00:00<00:00, 60341.82 examples/s]


In [8]:
# sanity check:
for idx, label in id2label.items():
    if idx>=10:
        break
    
    print(f"{idx}: {label}")

0: acq
1: alum
2: austdlr
3: barley
4: bop
5: can
6: carcass
7: cocoa
8: coconut
9: coconut-oil
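
The alphabetical index order above follows from `MultiLabelBinarizer` sorting its classes. A minimal pure-Python sketch of the same multi-hot encoding, using hypothetical toy topics and a hypothetical `multi_hot` helper, looks like this:

```python
# Hypothetical toy topics; the notebook itself fits on dataset['train']['topics'].
toy_topics = [["cocoa"], ["acq", "earn"], []]

# Sorted unique labels define the column order of the encoding.
classes = sorted({label for topics in toy_topics for label in topics})
toy_label2id = {label: idx for idx, label in enumerate(classes)}

def multi_hot(topics):
    """Return a float vector with 1.0 at the index of every assigned label."""
    vec = [0.0] * len(classes)
    for label in topics:
        vec[toy_label2id[label]] = 1.0
    return vec

encoded = [multi_hot(topics) for topics in toy_topics]
print(classes)
print(encoded)
```

A document without topics becomes an all-zero vector, which is a valid target for multi-label training.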


In [9]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

# Tokenize and remove unwanted columns
def tokenize_function(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)

columns = dataset["train"].column_names
columns.remove("text")
columns.remove("labels")
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=columns)

Map: 100%|██████████| 9588/9588 [00:02<00:00, 3850.19 examples/s]
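
The effect of `padding="max_length"` and `truncation=True` can be pictured with a simplified sketch. `pad_or_truncate` is a hypothetical helper and `pad_id=0` is an assumption; the real tokenizer also takes care of special tokens (for example, keeping the closing `[SEP]` when it truncates), which this sketch ignores.

```python
def pad_or_truncate(ids, max_length, pad_id=0):
    """Force a list of token ids to exactly max_length and build its attention mask."""
    ids = ids[:max_length]                 # truncation: drop tokens beyond the limit
    mask = [1] * len(ids)                  # 1 marks a real token
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad  # 0 marks padding

short_ids, short_mask = pad_or_truncate([101, 7, 8, 102], max_length=6)
long_ids, long_mask = pad_or_truncate([1, 2, 3, 4, 5], max_length=3)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Padding every example to 512 tokens keeps tensor shapes uniform at the cost of extra computation on short articles.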


In [10]:
tokenized_dataset 

DatasetDict({
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 3292
    })
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 9588
    })
    unused: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 722
    })
})

In [11]:
example = tokenized_dataset['train'][0]
print(example.keys())

dict_keys(['text', 'labels', 'input_ids', 'attention_mask'])


In [12]:
tokenizer.decode(example['input_ids'])

'[CLS] BAHIA COCOA REVIEW Showers continued throughout the week in the Bahia cocoa zone, alleviating the drought since early January and improving prospects for the coming temporao, although normal humidity levels have not been restored, Comissaria Smith said in its weekly review. The dry period means the temporao will be late this year. Arrivals for the week ended February 22 were 155, 221 bags of 60 kilos making a cumulative total for the season of 5. 93 mln against 5. 81 at the same stage last year. Again it seems that cocoa delivered earlier on consignment was included in the arrivals figures. Comissaria Smith said there is still some doubt as to how much old crop cocoa is still available as harvesting has practically come to an end. With total Bahia crop estimates around 6. 4 mln bags and sales standing at almost 6. 2 mln there are a few hundred thousand bags still in the hands of farmers, middlemen, exporters and processors. There are doubts as to how much of this cocoa would be 

In [13]:
print(example['labels'])

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


In [14]:
[id2label[idx] for idx, label in enumerate(example['labels']) if label == 1.0]

['cocoa']

In [15]:
tokenized_dataset.set_format("torch")

In [16]:
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased", 
    num_labels=num_labels, 
    problem_type="multi_label_classification",
    id2label=id2label,
    label2id=label2id
    )

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
# Labels must be float tensors for the multi-label (BCE) loss; check this before testing a forward pass
tokenized_dataset['train'][0]['labels'].type()

'torch.FloatTensor'

In [19]:
tokenized_dataset['train']['input_ids'][0]

tensor([  101, 12465,  3048,  9984, 18732, 15678,  1592,   155,  2036, 23314,
         2036,  2924,  3237,  1468,  1598,  2032,  1103,  1989,  1107,  1103,
        18757, 10652,  1884, 20535,  4834,   117,  1155,  6348, 25148,  1103,
        16076,  1290,  1346,  1356,  1105,  9248, 19743,  1111,  1103,  1909,
        16655,  1611,  1186,   117,  1780,  2999, 20641,  3001,  1138,  1136,
         1151,  5219,   117,  3291, 15394,  9724,  1465,  2159,  1163,  1107,
         1157,  5392,  3189,   119,  1109,  3712,  1669,  2086,  1103, 16655,
         1611,  1186,  1209,  1129,  1523,  1142,  1214,   119,   138, 14791,
         7501,  1116,  1111,  1103,  1989,  2207,  1428,  1659,  1127, 14691,
          117, 21319,  8483,  1104,  2539,   180, 24755,  1116,  1543,   170,
        27574,  1703,  1111,  1103,  1265,  1104,   126,   119,  5429,   182,
        21615,  1222,   126,   119,  5615,  1120,  1103,  1269,  2016,  1314,
         1214,   119,  5630,  1122,  3093,  1115,  1884, 20535, 

In [20]:
outputs = model(input_ids=tokenized_dataset['train']['input_ids'][0].unsqueeze(0), labels=tokenized_dataset['train'][0]['labels'].unsqueeze(0))
outputs

SequenceClassifierOutput(loss=tensor(0.6990, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), logits=tensor([[-0.0081, -0.0012, -0.1407, -0.0754,  0.0880, -0.0083, -0.1369,  0.0325,
          0.1409, -0.0466, -0.1064, -0.0094,  0.0234,  0.0217, -0.2389,  0.0915,
          0.0148,  0.0769, -0.0678, -0.0427, -0.0168,  0.0680,  0.0972,  0.0718,
          0.1725,  0.0523,  0.1037, -0.0177,  0.0143,  0.0510, -0.0428, -0.1005,
         -0.1053, -0.0996, -0.0407,  0.0608, -0.1763,  0.0490,  0.1373,  0.1419,
          0.0867, -0.0576, -0.1596,  0.0084, -0.0396, -0.1620,  0.0884,  0.0631,
         -0.0276,  0.0112, -0.1778, -0.0582, -0.0645, -0.0952, -0.0962, -0.2435,
          0.1074,  0.0811, -0.1186,  0.0730,  0.1431,  0.1783, -0.0022, -0.0878,
          0.0458, -0.1022,  0.0841,  0.0173,  0.1825,  0.1341, -0.1104,  0.0081,
         -0.0668,  0.0023,  0.3571,  0.0953, -0.0483,  0.1300,  0.0370,  0.0782,
         -0.0569,  0.1126, -0.1065,  0.2615,  0.0463,  0.0785, -0.0053,  0.0632,
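
The initial loss of about 0.699 is close to ln 2 ≈ 0.693, which is what binary cross-entropy gives when every logit is zero. The sketch below checks this by hand; `bce_with_logits` is a hypothetical helper that averages the per-label loss, and the all-zero logits are an idealization of the near-zero logits of the freshly initialized classifier head.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed from raw logits."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-z))
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(logits)

num_labels = 95
targets = [0.0] * num_labels
targets[7] = 1.0                    # index 7 is 'cocoa' in this notebook's id2label
zero_logits = [0.0] * num_labels    # idealized untrained head

loss = bce_with_logits(zero_logits, targets)
print(round(loss, 4), round(math.log(2), 4))
```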
       

### Model Training

In [23]:
args = TrainingArguments(
    "distilbert-finetuned-reuters21578-multilabel",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    logging_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=20,
    save_total_limit=2,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    push_to_hub=True,
    greater_is_better=True,
)

In [24]:
from transformers import EarlyStoppingCallback

early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
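
The patience rule behind the early-stopping callback can be sketched in a few lines. This is a simplified illustration and not the library's implementation: `first_stop_epoch` is a hypothetical helper, and it ignores the callback's improvement threshold. The first list holds the F1 values logged in this notebook's training output; the second is a made-up sequence.

```python
def first_stop_epoch(scores, patience=3):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("-inf")
    bad_epochs = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:
            best = score
            bad_epochs = 0          # any improvement resets the counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None

logged_f1 = [0.389627, 0.628855, 0.672073, 0.750487, 0.790499,
             0.813237, 0.829122, 0.833483, 0.839233, 0.833851]
toy_f1 = [0.5, 0.6, 0.59, 0.58, 0.57, 0.9]   # hypothetical values

print(first_stop_epoch(logged_f1))  # no stop within the logged epochs
print(first_stop_epoch(toy_f1))     # three epochs without improvement after epoch 2
```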

In [25]:
trainer.train()

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


| Epoch | Training Loss | Validation Loss | F1 | Roc Auc | Accuracy |
|---|---|---|---|---|---|
| 1 | 0.1801 | 0.043942 | 0.389627 | 0.621005 | 0.356622 |
| 2 | 0.0345 | 0.028706 | 0.628855 | 0.731796 | 0.595383 |
| 3 | 0.0243 | 0.021945 | 0.672073 | 0.757869 | 0.608445 |
| 4 | 0.0178 | 0.017743 | 0.750487 | 0.812821 | 0.690765 |
| 5 | 0.014 | 0.015141 | 0.790499 | 0.837643 | 0.727825 |
| 6 | 0.0115 | 0.01354 | 0.813237 | 0.858942 | 0.755468 |
| 7 | 0.0096 | 0.012398 | 0.829122 | 0.872717 | 0.772479 |
| 8 | 0.0082 | 0.012401 | 0.833483 | 0.875701 | 0.782199 |
| 9 | 0.0071 | 0.011853 | 0.839233 | 0.884694 | 0.788275 |
| 10 | 0.0064 | 0.012274 | 0.833851 | 0.88103 | 0.782807 |


TrainOutput(global_step=6000, training_loss=0.017809471408526102, metrics={'train_runtime': 3867.6489, 'train_samples_per_second': 49.581, 'train_steps_per_second': 1.551, 'total_flos': 2.54440780812288e+16, 'train_loss': 0.017809471408526102, 'epoch': 20.0})
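
The `global_step=6000` reported above is consistent with the training set size and batch size. The arithmetic below assumes a single device, no gradient accumulation, and that the last partial batch is kept; these are assumptions about the run, not facts stated in the output.

```python
import math

train_rows = 9588   # num_rows of the train split shown earlier
batch_size = 32     # per_device_train_batch_size
epochs = 20         # num_train_epochs, matching 'epoch': 20.0 in TrainOutput

steps_per_epoch = math.ceil(train_rows / batch_size)
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)
```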

In [26]:
trainer.evaluate()

{'eval_loss': 0.010994679294526577,
 'eval_f1': 0.8628858578607322,
 'eval_roc_auc': 0.906310303884654,
 'eval_accuracy': 0.8195625759416768,
 'eval_runtime': 26.409,
 'eval_samples_per_second': 124.654,
 'eval_steps_per_second': 3.9,
 'epoch': 20.0}

### Push Model to HF hub

In [27]:
trainer.push_to_hub()

'https://huggingface.co/lxyuan/distilbert-finetuned-reuters21578-multilabel/tree/main/'

In [47]:
tokenizer.push_to_hub("distilbert-finetuned-reuters21578-multilabel")

CommitInfo(commit_url='https://huggingface.co/lxyuan/distilbert-finetuned-reuters21578-multilabel/commit/d1f274a43ed66a57b8317f8b785b064425b2414b', commit_message='Upload tokenizer', commit_description='', oid='d1f274a43ed66a57b8317f8b785b064425b2414b', pr_url=None, pr_revision=None, pr_num=None)

### Load pushed model from HF hub for inference

In [28]:
from transformers import pipeline

pipe = pipeline("text-classification", model="lxyuan/distilbert-finetuned-reuters21578-multilabel", return_all_scores=True)

Downloading (…)lve/main/config.json: 100%|██████████| 4.30k/4.30k [00:00<00:00, 1.85MB/s]
Downloading pytorch_model.bin: 100%|██████████| 263M/263M [00:16<00:00, 16.3MB/s] 
Downloading (…)okenizer_config.json: 100%|██████████| 321/321 [00:00<00:00, 152kB/s]
Downloading (…)solve/main/vocab.txt: 100%|██████████| 213k/213k [00:00<00:00, 1.99MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 669k/669k [00:00<00:00, 8.61MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 55.9kB/s]


In [40]:
example = dataset["test"]["text"][2]
target_topics = dataset["test"]["topics"][2]

example, target_topics

("JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS The Ministry of International Trade and\nIndustry (MITI) will revise its long-term energy supply/demand\noutlook by August to meet a forecast downtrend in Japanese\nenergy demand, ministry officials said.\n    MITI is expected to lower the projection for primary energy\nsupplies in the year 2000 to 550 mln kilolitres (kl) from 600\nmln, they said.\n    The decision follows the emergence of structural changes in\nJapanese industry following the rise in the value of the yen\nand a decline in domestic electric power demand.\n    MITI is planning to work out a revised energy supply/demand\noutlook through deliberations of committee meetings of the\nAgency of Natural Resources and Energy, the officials said.\n    They said MITI will also review the breakdown of energy\nsupply sources, including oil, nuclear, coal and natural gas.\n    Nuclear energy provided the bulk of Japan's electric power\nin the fiscal year ended March 31, supplying a

In [58]:
fn_kwargs={"padding": "max_length", "truncation": True, "max_length": 512}

output = pipe(example, function_to_apply="sigmoid", **fn_kwargs)

In [59]:
for item in output[0]:
    if item["score"]>=0.5:
        print(item["label"], item["score"])

crude 0.7355073690414429
nat-gas 0.8600426316261292
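
The same filtering step can be wrapped in a small helper so that the threshold is easy to vary. `labels_above` is a hypothetical function, and the three entries are a rounded excerpt of this notebook's pipeline output rather than the full list of 95 scores.

```python
# Rounded excerpt of one pipeline result (one dict per label).
scores = [
    {"label": "acq", "score": 0.0020},
    {"label": "crude", "score": 0.7355},
    {"label": "nat-gas", "score": 0.8600},
]

def labels_above(scores, threshold=0.5):
    """Return labels whose score reaches the threshold, highest score first."""
    kept = [s for s in scores if s["score"] >= threshold]
    kept.sort(key=lambda s: s["score"], reverse=True)
    return [s["label"] for s in kept]

print(labels_above(scores))        # default threshold of 0.5
print(labels_above(scores, 0.8))   # a stricter threshold trades recall for precision
```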


In [60]:
output

[[{'label': 'acq', 'score': 0.0020191946532577276},
  {'label': 'alum', 'score': 0.006228976417332888},
  {'label': 'austdlr', 'score': 0.0015945922350510955},
  {'label': 'barley', 'score': 0.0026575943920761347},
  {'label': 'bop', 'score': 0.04313690587878227},
  {'label': 'can', 'score': 0.0008254407439380884},
  {'label': 'carcass', 'score': 0.0011239182204008102},
  {'label': 'cocoa', 'score': 0.001983838388696313},
  {'label': 'coconut', 'score': 0.0005582965677604079},
  {'label': 'coconut-oil', 'score': 0.0015241571236401796},
  {'label': 'coffee', 'score': 0.0027940908912569284},
  {'label': 'copper', 'score': 0.015326935797929764},
  {'label': 'copra-cake', 'score': 0.0008283186471089721},
  {'label': 'corn', 'score': 0.0037029897794127464},
  {'label': 'cornglutenfeed', 'score': 0.0005428714212030172},
  {'label': 'cotton', 'score': 0.002034077188000083},
  {'label': 'cpi', 'score': 0.06535973399877548},
  {'label': 'cpu', 'score': 0.0021881815046072006},
  {'label': 'crude

### Generate classification report using pipeline class

In [83]:
X_test = dataset["test"]["text"]
y_test = tokenized_dataset["test"]["labels"]

In [84]:
pipe = pipeline("text-classification", model="lxyuan/distilbert-finetuned-reuters21578-multilabel", return_all_scores=True, device=0)

In [85]:
y_pred = pipe(X_test, function_to_apply="sigmoid", **fn_kwargs)

In [88]:
# Extract scores using list comprehension
scores = [[prediction["score"] for prediction in sample] for sample in y_pred]

# Convert list of scores to a tensor
y_pred_float = torch.tensor(scores, dtype=torch.float32)

In [89]:
print(y_pred_float.shape)

torch.Size([3292, 95])


In [104]:
from sklearn.metrics import classification_report

threshold=0.5
report = classification_report(y_test, torch.ge(y_pred_float, threshold), target_names=labels)
print("Classification Report:\n", report)

Classification Report:
                  precision    recall  f1-score   support

            acq       0.97      0.93      0.95       719
           alum       1.00      0.70      0.82        23
        austdlr       0.00      0.00      0.00         0
         barley       1.00      0.50      0.67        12
            bop       0.79      0.50      0.61        30
            can       0.00      0.00      0.00         0
        carcass       0.67      0.67      0.67        18
          cocoa       1.00      1.00      1.00        18
        coconut       0.00      0.00      0.00         2
    coconut-oil       0.00      0.00      0.00         2
         coffee       0.86      0.89      0.87        27
         copper       1.00      0.78      0.88        18
     copra-cake       0.00      0.00      0.00         1
           corn       0.84      0.87      0.86        55
 cornglutenfeed       0.00      0.00      0.00         0
         cotton       0.92      0.67      0.77        18
      



---
### Insights


**Precision vs Recall Trade-off**

The model achieves high precision for many topics, such as acq (0.97), alum (1.00), barley (1.00), and cocoa (1.00). Recall for some of these categories is lower, however: alum has a recall of 0.70 and barley a recall of 0.50. This mirrors the baseline model, where high precision was a design objective because false positives are considered more problematic than false negatives in client-facing applications. The cost of fewer false positives is lower recall: the model may miss some articles that should have been classified under certain labels.
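
To see how these figures relate, the sketch below recomputes precision, recall, and F1 from raw counts. The counts are inferred from the rounded values and supports in the report (barley: support 12, alum: support 23) and are therefore an assumption, not numbers taken from the model's confusion matrix. `prf` is a hypothetical helper.

```python
def prf(tp, fp, fn):
    """Precision, recall, and F1 from true positive, false positive, false negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Inferred counts: no false positives, but a share of the true articles is missed.
barley = prf(tp=6, fp=0, fn=6)
alum = prf(tp=16, fp=0, fn=7)

print([round(v, 2) for v in barley])
print([round(v, 2) for v in alum])
```

Perfect precision with recall of 0.50 still yields an F1 of only about 0.67, since F1 is the harmonic mean of the two.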

**Discrepancy in Micro and Macro Averages**

The model shows a micro-averaged F1-score of 0.86 and a macro-averaged F1-score of 0.33. Micro-averaging pools all predictions, so frequent labels dominate it, whereas macro-averaging weights every label equally, so each rare label with an F1 of zero pulls it down. The large discrepancy therefore indicates behavior similar to the baseline model: the model performs very well on common labels but struggles with minority classes. The gap is even more pronounced here than in the baseline, which could be a significant issue if the application requires comprehensive coverage across a variety of topics.
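
A small numeric sketch shows how such a gap arises. The per-label counts below are hypothetical and chosen only to illustrate one frequent label next to several rare labels that the model never predicts.

```python
# Hypothetical (tp, fp, fn) counts per label.
counts = {
    "common": (90, 5, 10),
    "rare_a": (0, 0, 2),
    "rare_b": (0, 0, 1),
    "rare_c": (0, 0, 2),
}

def f1(tp, fp, fn):
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

# Macro: average the per-label F1 scores, every label weighted equally.
macro_f1 = sum(f1(*c) for c in counts.values()) / len(counts)

# Micro: pool the counts first, then compute a single F1.
tp = sum(c[0] for c in counts.values())
fp = sum(c[1] for c in counts.values())
fn = sum(c[2] for c in counts.values())
micro_f1 = f1(tp, fp, fn)

print(round(micro_f1, 3), round(macro_f1, 3))
```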

**Labels with Zero Support**

In this classification report, there are several labels like austdlr, can, cornglutenfeed, and others with zero support. These are classes that didn't appear in the test set. This issue also occurred in the baseline model. The zero-support issue indicates a lack of diversity in the test data for these labels and could mean the model is untested for these minority classes, a point of concern for real-world deployment where these classes may appear.
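
Labels with zero support can be listed directly from the multi-hot test labels by summing each column. The matrix and label names below are a hypothetical toy example; in the notebook the same idea would apply to `y_test` and `labels`.

```python
# Hypothetical toy test labels: rows are documents, columns are labels.
toy_labels = ["acq", "austdlr", "cocoa"]
toy_y_test = [
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 0],
]

# Support of a label is the number of test documents that carry it.
support = [sum(column) for column in zip(*toy_y_test)]
zero_support = [label for label, s in zip(toy_labels, support) if s == 0]

print(support)
print(zero_support)
```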

**Topic-Specific Observations**

For some labels, there's a good balance between precision and recall (acq, coffee, corn). However, for others, the model seems to struggle. For example, gnp and interest have decent precision (0.79 and 0.89, respectively) but lower recall (0.66 and 0.67, respectively). This suggests that while the model is confident in its predictions for these labels, it is missing out on several true positives, a behavior similar to the baseline model.

**Conclusion**

In summary, while this model seems to have higher precision values than the baseline for several labels, it faces similar challenges: low recall, poor performance on minority classes, and untested categories due to zero support. These are critical points to consider for future model iterations, especially if comprehensive and balanced performance across all topics is a key objective.