# BERT Classifier as Feature Extractor

In this feature-based approach, we are using the embeddings from a pretrained transormer to train logistic regression model in scikit-learn

## Imports

In [93]:
import os
import numpy as np
from functools import partial

In [94]:
import torch
from datasets import load_dataset, Features, Value, ClassLabel
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

## Parameters

In [95]:
SEED = 42

In [96]:
F1_AVERAGING = "macro"

In [97]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda:0


In [98]:
MODEL_NAME = "distilbert-base-multilingual-cased"

## Paths

In [99]:
relative_path = os.path.join("../../../", "data")

In [100]:
sentiment_analysis_data_path = os.path.join(relative_path, "3_sentiment_analysis")

## Functions

In [101]:
def get_output_embeddings(model, batch):
    inputs = {key: tensor.to(DEVICE) for key, tensor in batch.items() if key != "label"}
    with torch.inference_mode():
        output = model(**inputs).last_hidden_state[:, 0]
    return {"features": output.cpu().numpy()}

In [102]:
def tokenize_text(tokenizer, batch):
    return tokenizer(batch["text"], truncation=True, padding=True)

## Loading data

In [104]:
features = Features(
    {
        "label": ClassLabel(
            num_classes=3,
            names=["negative", "neutral", "positive"],
        ),
        "text": Value(dtype="string"),
    }
)

review_dataset = load_dataset(
    "parquet",
    data_files={
        "train": os.path.join(sentiment_analysis_data_path, "train.parquet"),
        "validation": os.path.join(sentiment_analysis_data_path, "validation.parquet"),
        "test": os.path.join(sentiment_analysis_data_path, "test.parquet"),
    },
    features=features,
)

print(review_dataset)

Using custom data configuration default-ad59138ff5c15790


Downloading and preparing dataset parquet/default to /home/extremesarova/.cache/huggingface/datasets/parquet/default-ad59138ff5c15790/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

Dataset parquet downloaded and prepared to /home/extremesarova/.cache/huggingface/datasets/parquet/default-ad59138ff5c15790/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 175556
    })
    validation: Dataset({
        features: ['label', 'text'],
        num_rows: 15491
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 15490
    })
})


## Tokenization

In [110]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("Tokenizer input max length:", tokenizer.model_max_length)
print("Tokenizer vocabulary size:", tokenizer.vocab_size)

Tokenizer input max length: 512
Tokenizer vocabulary size: 119547


In [111]:
review_dataset_tokenized = review_dataset.map(
    partial(tokenize_text, tokenizer), batched=True, batch_size=None
)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [112]:
del review_dataset

## Using DistilBERT as a Feature Extractor

In [53]:
model = AutoModel.from_pretrained(MODEL_NAME)
model.to(DEVICE)

Downloading:   0%|          | 0.00/517M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [113]:
review_dataset_tokenized.set_format(
    "torch", columns=["input_ids", "attention_mask", "label"]
)

In [55]:
test_batch = {
    "attention_mask": review_dataset_tokenized["train"][:3]["attention_mask"].to(
        DEVICE
    ),
    "input_ids": review_dataset_tokenized["train"][:3]["input_ids"].to(DEVICE),
}

with torch.inference_mode():
    test_output = model(**test_batch)

test_output.last_hidden_state.shape

torch.Size([3, 512, 768])

In [56]:
cls_token_output = test_output.last_hidden_state[:, 0]
cls_token_output.shape

torch.Size([3, 768])

In [69]:
review_dataset_features = review_dataset_tokenized.map(
    partial(get_output_embeddings, model), batched=True, batch_size=128
)

  0%|          | 0/1372 [00:00<?, ?ba/s]

  0%|          | 0/122 [00:00<?, ?ba/s]

  0%|          | 0/122 [00:00<?, ?ba/s]

In [106]:
X_train = np.array(review_dataset_features["train"]["features"])
y_train = np.array(review_dataset_features["train"]["label"])

X_val = np.array(review_dataset_features["validation"]["features"])
y_val = np.array(review_dataset_features["validation"]["label"])

X_test = np.array(review_dataset_features["test"]["features"])
y_test = np.array(review_dataset_features["test"]["label"])

## Train Model on Embeddings (Extracted Features)

In [72]:
log_reg = LogisticRegression(
    random_state=SEED, class_weight="balanced", solver="saga", max_iter=1_000
)
log_reg.fit(X_train, y_train)



LogisticRegression(class_weight='balanced', max_iter=1000, random_state=42,
                   solver='saga')

In [117]:
pred_labels = log_reg.predict(X_val)

f1_macro = f1_score(y_val, pred_labels, average=F1_AVERAGING)
print(f"F1 score with macro-averaging is {f1_macro.round(5)}")

F1 score with macro-averaging is 0.47945
