In [None]:
!pip install transformers datasets

In [None]:
import os
import time
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from dataclasses import dataclass
from torch.utils.data import DataLoader
from scipy.special import expit as sigmoid
from sklearn.metrics import classification_report
from sklearn.multiclass import OneVsRestClassifier
from sklearn.ensemble import RandomForestClassifier
from datasets import Dataset, load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, TrainingArguments, Trainer

### Load the dataset from the Hugging Face Hub

In [None]:
# Load the dataset by its Hugging Face identifier.
ds = load_dataset(path='bhujith10/multi_class_classification_dataset')

### Function to tokenize the dataset

In [None]:
def tokenize_dataset(tokenizer, config, dataset_name='bhujith10/multi_class_classification_dataset'):
  """
  Loads a dataset, tokenizes its 'text' column and casts its labels to float.

  Parameters
  ==========
  tokenizer (Tokenizer): Callable tokenizer applied to the 'text' column
  config (object): Configuration object; only `config.max_length` is read here
  dataset_name (str): The name or path of the dataset to load

  Returns
  =======
  The tokenized dataset in 'torch' format, with a float 'labels' column
  """

  def _encode_texts(batch):
    """
    Runs the tokenizer over the 'text' entries of one batch.

    Parameters
    ==========
    batch (dict): A batch of examples containing a 'text' key.

    Returns
    =======
    The tokenizer output for that batch (truncated and padded).
    """
    return tokenizer(
        batch['text'],
        truncation=True,
        padding=True,
        max_length=config.max_length,
        return_tensors='pt',
    )

  def _labels_to_float(example):
    """Returns the example's labels as a float tensor under the key 'labels_f'."""
    return {"labels_f": example["labels"].to(torch.float)}

  raw_splits = load_dataset(dataset_name)

  # batch_size=None hands each split to the tokenizer in one call, so
  # padding=True pads every example of a split to a common length.
  encoded_splits = raw_splits.map(_encode_texts, batched=True, batch_size=None)

  # Switch every split to torch tensors so the label cast below can use .to().
  for split_dataset in encoded_splits.values():
    split_dataset.set_format('torch')

  # Write the float labels to a temporary column, drop the original column,
  # then give the temporary column the original name.
  return (
      encoded_splits
      .map(_labels_to_float, remove_columns=["labels"])
      .rename_column("labels_f", "labels")
  )

### Function to load model and tokenizer

In [None]:
def load_model_and_tokenizer(config,
                             add_pad_token=False,
                             quantization=False,
                             peft=False,
                             load_model_for_sequence_classification=False):
  """
  Loads a model and its tokenizer based on the provided configuration.

  Parameters
  ==========
  config (object): Configuration object. `config.checkpoint` is always read;
      `bf16`, `num_labels` and `problem_type` are read when the classification
      head is requested, and the `lora_*` fields when `peft` is True.
  add_pad_token (bool): Whether to set/add a padding token on the tokenizer
      and propagate its id to the model config.
  quantization (bool): Whether to pass a quantization config to the
      classification model (ignored for the bare model).
  peft (bool): Whether to wrap the model with LoRA adapters.
  load_model_for_sequence_classification (bool): Whether to load the model
      with a classification head (True) or as a bare model (False).

  Returns
  =======
  tuple: (model, tokenizer)
  """
  tokenizer = AutoTokenizer.from_pretrained(config.checkpoint)

  # Per the original author's note, Llama 3 tokenizers already contain a
  # padding token, so it is reused instead of adding a new "<pad>" token.
  if add_pad_token:
      if 'Llama' in tokenizer.name_or_path:
          tokenizer.pad_token = '<|finetune_right_pad_id|>'
      else:
          tokenizer.add_special_tokens({"pad_token":"<pad>"})

  # The original author hit errors with right padding on Mistral models,
  # so only those are padded on the left.
  if 'Mistral' in tokenizer.name_or_path:
      tokenizer.padding_side = "left"
  else:
      tokenizer.padding_side = "right"

  if load_model_for_sequence_classification:
      # Load the model with a classification head.
      # num_labels specifies the number of neurons in the output layer.
      # NOTE(review): `quantization_config` is not defined in the visible part
      # of this notebook; it must exist before calling with quantization=True.
      model =  AutoModelForSequenceClassification.from_pretrained(
          pretrained_model_name_or_path=config.checkpoint,
          quantization_config=quantization_config(config) if quantization else None,
          torch_dtype=torch.bfloat16 if config.bf16 else torch.float16,
          num_labels=config.num_labels,
          problem_type=config.problem_type
      )

  else:
      # Use the checkpoint carried by `config` rather than a notebook-level
      # global, so the function works regardless of cell execution order.
      model = AutoModel.from_pretrained(config.checkpoint)

  if add_pad_token:
      model.config.pad_token_id = tokenizer.pad_token_id
      # A newly added "<pad>" token grows the vocabulary, so the embedding
      # matrix is resized; the Llama branch reuses an existing token.
      if 'Llama' not in tokenizer.name_or_path:
          model.resize_token_embeddings(len(tokenizer))
  if peft:
      # NOTE(review): LoraConfig, TaskType and get_peft_model are not imported
      # in the visible notebook (presumably they come from the `peft`
      # package); import them before calling with peft=True.
      peft_config = LoraConfig(
          task_type=TaskType.SEQ_CLS,
          r=config.lora_rank,
          lora_alpha=config.lora_alpha,
          lora_dropout=config.lora_dropout,
          bias=config.lora_bias,
          #target_modules=["q_proj", "k_proj"]
      )

      model = get_peft_model(model, peft_config)

  return model, tokenizer

### Function to extract hidden states from model

In [None]:
def extract_hidden_states(batch, **kwargs):
    """
    Extracts the hidden states from the model for the given batch of inputs.

    Parameters
    ==========
    batch (dict): A dictionary containing input tensors such as 'input_ids' and 'attention_mask'.
    **kwargs: Additional keyword arguments, including:
        - model : The pre-trained transformer model (required).
        - tokenizer : The tokenizer used for the model (required).
        - device : Optional device to run on. Defaults to "cuda" when
          available, otherwise "cpu".

    Returns
    =======
    dict: A dictionary with key 'hidden_state' holding, as a NumPy array, the
    last-layer hidden state at sequence position 0 of every example.
    NOTE(review): position 0 is the CLS token only for right-padded,
    BERT-style inputs; confirm before using with left-padded models.

    Raises
    ======
    ValueError: If 'model' or 'tokenizer' is not supplied.
    """
    # Extract the model and tokenizer from kwargs
    model = kwargs.get("model")
    tokenizer = kwargs.get("tokenizer")
    if model is None or tokenizer is None:
        raise ValueError("extract_hidden_states requires 'model' and 'tokenizer' keyword arguments")

    # Fall back to the CPU instead of failing when no GPU is present.
    device = kwargs.get("device") or ("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)

    # Keep only the columns the model accepts (drops e.g. 'text' and 'labels')
    # and move them to the same device as the model.
    inputs = {
        name: tensor.to(device)
        for name, tensor in batch.items()
        if name in tokenizer.model_input_names
    }

    # Inference only: no gradients are needed.
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state

    # Return the first position's hidden state as a NumPy array
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}

### Function to calculate micro and macro F1 scores

In [None]:
def calculate_f1_score(y_true, y_pred):
  """
  Calculates micro and macro F1-scores given the predicted and actual labels

  Parameters
  ==========
  y_true (numpy array): Actual labels
  y_pred (numpy array): Predicted labels

  Returns
  =======
  dict: A dictionary with keys "micro f1" and "macro f1".
  """
  # Generate a classification report to compute detailed metrics.
  # zero_division=0 scores labels with no predictions as 0.
  clf_dict = classification_report(
      y_true,
      y_pred,
      zero_division=0,
      output_dict=True
  )

  # For single-label targets scikit-learn reports the micro average under the
  # key "accuracy" (a plain float) instead of "micro avg", because the two
  # coincide there; fall back to it rather than raising KeyError.
  if "micro avg" in clf_dict:
      micro_f1 = clf_dict["micro avg"]["f1-score"]
  else:
      micro_f1 = clf_dict["accuracy"]

  return {
      "micro f1": micro_f1,
      "macro f1": clf_dict["macro avg"]["f1-score"]
  }

### Function to build classifier using hidden states

In [None]:
def build_classifier_using_hidden_states(ds, model, tokenizer, n_estimators=500, random_state=None):
    """
    Builds and trains a multi-label classifier using hidden states extracted from a transformer model.

    Parameters
    ==========
    ds (DatasetDict): A Hugging Face DatasetDict containing 'train' and 'test' splits.
    model: The pre-trained transformer model used to extract hidden states.
    tokenizer: The tokenizer used for the model.
    n_estimators (int): Number of trees in each random forest (default 500).
    random_state (int or None): Seed for the random forests. None (the
        default) leaves them unseeded; pass an integer for reproducible results.

    Returns
    =======
    dict: F1 score results containing micro f1 and macro f1 scores.
    float: Time taken (in seconds) to extract the test-set hidden states and
        predict on them. Metric computation is not included.
    """

    def _features_and_labels(split):
        """Returns (hidden-state features, labels) of one split as NumPy arrays."""
        with_hidden_states = ds[split].map(
            extract_hidden_states,
            batched=True,
            batch_size=4,
            fn_kwargs={"model": model, "tokenizer": tokenizer}
        )
        features = np.array(with_hidden_states["hidden_state"])
        labels = np.array(with_hidden_states["labels"])
        return features, labels

    # Prepare training data (features and labels)
    x_train, y_train = _features_and_labels("train")

    random_forest_clf = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state
    )

    # Use One-vs-Rest strategy for multi-label classification
    multi_class_clf = OneVsRestClassifier(random_forest_clf)
    multi_class_clf.fit(x_train, y_train)

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments.
    start_time = time.perf_counter()

    x_test, y_test = _features_and_labels("test")
    y_pred = multi_class_clf.predict(x_test)

    # Stop the clock before scoring so only inference is timed.
    time_for_test_set_prediction = time.perf_counter() - start_time

    # Calculate F1 scores for the test set predictions
    f1_score_results = calculate_f1_score(y_test, y_pred)

    return f1_score_results, time_for_test_set_prediction

### Function to update model related info in config

In [None]:
def update_model_related_settings(checkpoint, config):
  """
  Points the config at a checkpoint and derives a timestamped model name.

  Parameters
  ==========
  checkpoint (str): Model identifier or path, e.g. 'org/model-name'.
  config (object): Configuration object; it is modified in place.

  Returns
  =======
  The same config object, with `checkpoint`, `model_name` and
  `local_save_path` updated.
  """
  config.checkpoint = checkpoint
  # Strip trailing slashes first so a path such as 'models/bert/' yields
  # 'bert' instead of an empty name.
  model_name = checkpoint.rstrip('/').split('/')[-1]
  # Minute-resolution timestamp keeps names from different runs distinct.
  currtime = datetime.now().strftime("%Y_%m_%d_%H_%M")
  config.model_name = f"{model_name}_{currtime}"
  config.local_save_path = config.model_name
  return config

## Build a classifier that uses hidden states from the BERT model

In [None]:
@dataclass
class Config:
  """
  Settings shared by the helper functions in this notebook.

  Defaults are plain scalars. Earlier trailing commas turned several of them
  into 1-tuples (e.g. num_train_epochs == (2,)), and the string-valued
  bnb_4bit_* fields were annotated as bool.
  """
  # Model checkpoint; overwritten by update_model_related_settings.
  checkpoint:str = "microsoft/deberta-v3-base"
  # Maximum sequence length passed to the tokenizer.
  max_length:int = 512
  # Classification-head settings read by load_model_and_tokenizer.
  num_labels:int = 6
  problem_type:str = "multi_label_classification"
  # LoRA hyper-parameters passed to LoraConfig in load_model_and_tokenizer.
  lora_rank:int = 8
  lora_alpha:int = 32
  lora_dropout:float = 0.1
  lora_bias:str = "none"
  # Evaluated once, when the class is defined.
  device:str = "cuda" if torch.cuda.is_available() else "cpu"
  repo_user_id:str = "bhujith10"
  # Filled in by update_model_related_settings.
  model_name:str = ""
  local_save_path:str = ""
  # Mixed-precision flags: bf16 when the GPU supports it, otherwise fp16 on
  # GPU; both are False without CUDA.
  bf16:bool = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
  fp16:bool = torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
  # 4-bit quantization settings; presumably consumed by the
  # quantization_config helper, which is not visible in this notebook.
  load_in_4bit:bool = True
  bnb_4bit_quant_type:str = "nf4"
  bnb_4bit_compute_dtype:str = "float16"
  bnb_4bit_use_double_quant:bool = False
  # Training settings; not read by any code visible in this notebook.
  num_train_epochs:int = 2
  batch_size:int = 8
  gradient_accumulation_steps:int = 2
  gradient_checkpointing:bool = True

config = Config()

In [None]:
# Point the shared config at the BERT-large checkpoint for this experiment.
checkpoint = 'google-bert/bert-large-uncased'
config = update_model_related_settings(checkpoint=checkpoint, config=config)

# Bare model only: no classification head, no extra pad token, no LoRA.
model, tokenizer = load_model_and_tokenizer(
    config,
    add_pad_token=False,
    peft=False,
    load_model_for_sequence_classification=False,
)

# Tokenize the default dataset with the tokenizer that was just loaded.
ds = tokenize_dataset(tokenizer=tokenizer, config=config)

# Train and evaluate the classifier on the model's hidden states.
f1_score_results, time_taken = build_classifier_using_hidden_states(
    ds=ds,
    model=model,
    tokenizer=tokenizer,
)