# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [1]:
import os
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Set GPU device
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#os.environ['http_proxy']  = 'http://192.41.170.23:3128'
#os.environ['https_proxy'] = 'http://192.41.170.23:3128'
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
device

'cpu'

## 1. Data

### Train, Test, Validation 

In [2]:
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

({'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)},
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [3]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [4]:
# Remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [5]:
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [6]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [7]:
# there are -1 values in the label feature, these are where no class could be decided so we remove
snli = snli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

mnli = mnli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [8]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([0, 1, 2]))

In [9]:
# Assuming you have your two DatasetDict objects named snli and mnli
from datasets import DatasetDict
# Merge the two DatasetDict objects
raw_dataset = DatasetDict({
    'train': datasets.concatenate_datasets([snli['train'], mnli['train']]).shuffle(seed=55).select(list(range(1000))),
    'test': datasets.concatenate_datasets([snli['test'], mnli['test_mismatched']]).shuffle(seed=55).select(list(range(100))),
    'validation': datasets.concatenate_datasets([snli['validation'], mnli['validation_mismatched']]).shuffle(seed=55).select(list(range(1000)))
})
#remove .select(list(range(1000))) in order to use full dataset
# Now, merged_dataset_dict contains the combined datasets from snli and mnli
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
})

## 2. Preprocessing

In [10]:
import torchtext

tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
vocab = torch.load('./model/vocab')

In [11]:
# max_seq_length = 1024
# def preprocess_function(examples):
#     padding = 'max_length'
#     # Tokenize the premise
#     premise_result = tokenizer(
#         examples['premise'], padding=padding, max_length=max_seq_length, truncation=True)
#     #num_rows, max_seq_length
#     # Tokenize the hypothesis
#     hypothesis_result = tokenizer(
#         examples['hypothesis'], padding=padding, max_length=max_seq_length, truncation=True)
#     #num_rows, max_seq_length
#     # Extract labels
#     labels = examples["label"]
#     #num_rows
#     return {
#         "premise_input_ids": premise_result["input_ids"],
#         "premise_attention_mask": premise_result["attention_mask"],
#         "hypothesis_input_ids": hypothesis_result["input_ids"],
#         "hypothesis_attention_mask": hypothesis_result["attention_mask"],
#         "labels" : labels
#     }

# tokenized_datasets = raw_dataset.map(
#     preprocess_function,
#     batched=True,
# )

# tokenized_datasets = tokenized_datasets.remove_columns(['premise','hypothesis','label'])
# tokenized_datasets.set_format("torch")

In [12]:
max_seq_length = 256

def preprocess_function(examples):
    # Tokenize the premise
    tokenized_premise = [tokenizer(re.sub("[.,!?\\-]", '', sent.lower())) for sent in examples['premise']]
    premise_input_ids = [[vocab['[CLS]']] + [vocab[token] for token in tokens] + [vocab['[SEP]']] for tokens in tokenized_premise]
    premise_n_pad = [max_seq_length - len(tokens) for tokens in premise_input_ids]
    premise_attn_mask = [([1] * len(tokens)) + ([0] * n_pad) for tokens, n_pad in zip(premise_input_ids, premise_n_pad)]
    premise_input_ids = [tokens + ([0] * n_pad) for tokens, n_pad in zip(premise_input_ids, premise_n_pad)]
    #num_rows, max_seq_length

    # Tokenize the hypothesis
    tokenized_hypothesis = [tokenizer(re.sub("[.,!?\\-]", '', sent.lower())) for sent in examples['hypothesis']]
    hypothesis_input_ids = [[vocab['[CLS]']] + [vocab[token] for token in tokens] + [vocab['[SEP]']] for tokens in tokenized_hypothesis]
    hypothesis_n_pad = [max_seq_length - len(tokens) for tokens in hypothesis_input_ids]
    hypothesis_attn_mask = [([1] * len(tokens)) + ([0] * n_pad) for tokens, n_pad in zip(hypothesis_input_ids, hypothesis_n_pad)]
    hypothesis_input_ids = [tokens + ([0] * n_pad) for tokens, n_pad in zip(hypothesis_input_ids, hypothesis_n_pad)]
    #num_rows, max_seq_length

    # Extract labels
    labels = examples["label"]
    #num_rows
    return {
        "premise_input_ids": premise_input_ids,
        "premise_attention_mask": premise_attn_mask,
        "hypothesis_input_ids": hypothesis_input_ids,
        "hypothesis_attention_mask": hypothesis_attn_mask,
        "labels" : labels
    }

tokenized_datasets = raw_dataset.map(
    preprocess_function,
    batched=True,
)

tokenized_datasets = tokenized_datasets.remove_columns(['premise','hypothesis','label'])
tokenized_datasets.set_format("torch")

In [13]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
})

## 3. Data loader

In [14]:
from torch.utils.data import DataLoader

# initialize the dataloader
batch_size = 4
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [15]:
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_attention_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_attention_mask'].shape)
    print(batch['labels'].shape)
    break

torch.Size([4, 256])
torch.Size([4, 256])
torch.Size([4, 256])
torch.Size([4, 256])
torch.Size([4])


## 4. Model

In [16]:
from bert import *

# load the model and all its hyperparameters
save_path = './model/bert.pt'
params, state = torch.load(save_path)
model = BERT(**params, device=device).to(device)
model.load_state_dict(state)


<All keys matched successfully>

### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed sized sentence embedding

In [17]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

## 5. Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)

<img src="./figures/sbert-architecture.png" >

In [18]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

<img src="./figures/sbert-ablation.png" width="350" height="300">

In [19]:
classifier_head = torch.nn.Linear(768*3, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

In [20]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()



In [21]:
model

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(6944, 768)
    (pos_embed): Embedding(256, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=512, bias=True)
        (W_K): Linear(in_features=768, out_features=512, bias=True)
        (W_V): Linear(in_features=768, out_features=512, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (decode

## 6. Training

In [None]:
from tqdm.auto import tqdm

num_epoch = 2
# 1 epoch should be enough, increase if wanted
for epoch in range(num_epoch):
    model.train()  
    classifier_head.train()
    # initialize the dataloader loop with tqdm (tqdm == progress bar)
    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):
        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        segment_ids = torch.zeros(batch_size, max_seq_length, dtype=torch.int32).to(device)  # each input contains only one sentence hence we define them all as sentence '0'
        label = batch['labels'].to(device)
        
        # extract token embeddings from BERT at last_hidden_state
        u_last_hidden_state = model.get_last_hidden_state(inputs_ids_a, segment_ids)  
        v_last_hidden_state = model.get_last_hidden_state(inputs_ids_b, segment_ids)  

        # u_last_hidden_state = u.last_hidden_state # all token embeddings A = batch_size, seq_len, hidden_dim
        # v_last_hidden_state = v.last_hidden_state # all token embeddings B = batch_size, seq_len, hidden_dim

         # get the mean pooled vectors
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim
        
        # build the |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
        uv_abs = torch.abs(uv) # batch_size,hidden_dim
        
        # concatenate u, v, |u-v|
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim
        
        # process concatenated tensor through classifier_head
        x = classifier_head(x) #batch_size, classifer
        
        # calculate the 'softmax-loss' between predicted and true label
        loss = criterion(x, label)
        
        # using loss, calculate gradients and then optimizerize
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        scheduler.step() # update learning rate scheduler
        scheduler_classifier.step()
        
    print(f'Epoch: {epoch + 1} | loss = {loss.item():.6f}')

  0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 1 | loss = 1.763875


  0%|          | 0/250 [00:00<?, ?it/s]

Epoch: 2 | loss = 3.219887


In [None]:
# save the classifier head
head_path = './model/classifier-head-custom-bert.pt'
torch.save(classifier_head, head_path)

# save model
model_path = './model/s-bert.pt'
torch.save([model.params, model.state_dict()], model_path)


In [None]:
model.eval()
classifier_head.eval()
total_similarity = 0
with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        segment_ids = torch.zeros(batch_size, max_seq_length, dtype=torch.int32).to(device)
        label = batch['labels'].to(device)
        
        # extract token embeddings from BERT at last_hidden_state
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)  # all token embeddings A = batch_size, seq_len, hidden_dim
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)  # all token embeddings B = batch_size, seq_len, hidden_dim

        # get the mean pooled vectors
        u_mean_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1) # batch_size, hidden_dim

        similarity_score = cosine_similarity(u_mean_pool, v_mean_pool)
        total_similarity += similarity_score
    
average_similarity = total_similarity / len(eval_dataloader)
print(f"Average Cosine Similarity: {average_similarity:.4f}")

Average Cosine Similarity: 0.9912


## 7. Inference

In [22]:
model_path = './model/s-bert.pt'
params, state = torch.load(model_path)
model = BERT(**params, device=device).to(device)
model.load_state_dict(state)

<All keys matched successfully>

In [23]:
def get_inputs(sentence, tokenizer, vocab, max_seq_length):
    tokens = tokenizer(re.sub("[.,!?\\-]", '', sentence.lower()))
    input_ids = [vocab['[CLS]']] + [vocab[token] for token in tokens] + [vocab['[SEP]']]
    n_pad = max_seq_length - len(input_ids)
    attention_mask = ([1] * len(input_ids)) + ([0] * n_pad)
    input_ids = input_ids + ([0] * n_pad)

    return {'input_ids': torch.LongTensor(input_ids).reshape(1, -1),
            'attention_mask': torch.LongTensor(attention_mask).reshape(1, -1)}

In [24]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(model, tokenizer, vocab, sentence_a, sentence_b, device):
    # Tokenize and convert sentences to input IDs and attention masks
    inputs_a = get_inputs(sentence_a, tokenizer, vocab, max_seq_length)
    inputs_b = get_inputs(sentence_b, tokenizer, vocab, max_seq_length)
    

    # Move input IDs and attention masks to the active device
    inputs_ids_a = inputs_a['input_ids'].to(device)
    attention_a = inputs_a['attention_mask'].to(device)
    inputs_ids_b = inputs_b['input_ids'].to(device)
    attention_b = inputs_b['attention_mask'].to(device)
    segment_ids = torch.zeros(1, max_seq_length, dtype=torch.int32).to(device)

    # Extract token embeddings from BERT
    u = model.get_last_hidden_state(inputs_ids_a, segment_ids)
    v = model.get_last_hidden_state(inputs_ids_b, segment_ids)

    u = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim
    v = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim

    # Calculate cosine similarity
    similarity_score = cosine_similarity(u.reshape(1, -1), v.reshape(1, -1))[0, 0]

    return similarity_score
# Example usage:
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality education.'
sentence_b = "Your contributions were of no help with our students' education."
similarity = calculate_similarity(model, tokenizer,vocab, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}")

Cosine Similarity: 0.9953


## 8. Evaluation and Analysis

In [25]:
def predict_label(model, inputs_ids_a, inputs_ids_b, attention_a, attention_b, segment_ids):
    # Extract token embeddings from BERT
    u_last_hidden_state = model.get_last_hidden_state(inputs_ids_a, segment_ids)  
    v_last_hidden_state = model.get_last_hidden_state(inputs_ids_b, segment_ids)  

    # Get the mean pooled vectors
    u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
    v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim

    # Calculate cosine similarity
    similarity_scores = torch.cosine_similarity(u_mean_pool, v_mean_pool, dim=1)

    predictions = []
    for score in similarity_scores:
        if score >= 0.5:
            predictions.append(0)
        elif score > -0.5: # score in range [-1, 0.499..]
            predictions.append(1)
        else:
            predictions.append(2)
    predicted_classes = torch.tensor(predictions)

    return predicted_classes.view(-1, 1)

### 1. Performance Evaluation

In [26]:
# 1. Performance Evaluation
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report

def evaluate_model(model, dataloader):
    model.eval()
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in dataloader:
            inputs_ids_a = batch['premise_input_ids'].to(device)
            inputs_ids_b = batch['hypothesis_input_ids'].to(device)
            attention_a = batch['premise_attention_mask'].to(device)
            attention_b = batch['hypothesis_attention_mask'].to(device)
            segment_ids = torch.zeros(batch_size, max_seq_length, dtype=torch.int32).to(device)
            labels = batch['labels'].to(device)

            # Calculate similarity scores
            similarity_scores = predict_label(model, inputs_ids_a, inputs_ids_b, attention_a, attention_b, segment_ids)

            # Predict labels based on similarity scores
            predictions = torch.argmax(similarity_scores, dim=1).cpu().numpy()

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions)

    print(classification_report(all_labels, all_predictions))

In [27]:
# Evaluate on the test dataset
test_dataloader = DataLoader(
    tokenized_datasets['test'],
    batch_size=batch_size
)
evaluate_model(model, test_dataloader)

              precision    recall  f1-score   support

           0       0.33      1.00      0.50        33
           1       0.00      0.00      0.00        38
           2       0.00      0.00      0.00        29

    accuracy                           0.33       100
   macro avg       0.11      0.33      0.17       100
weighted avg       0.11      0.33      0.16       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [28]:
#!pip install sentence-transformers

### 2. Comparison to Other Pre-trained Models

In [29]:
# 2. Comparison to Other Pre-trained Models
from sentence_transformers import SentenceTransformer
other_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

  return self.fget.__get__(instance, owner)()


In [30]:
def calculate_similarity_other(model, sentence_a, sentence_b):
    embeddings = model.encode([sentence_a, sentence_b])
    return cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))[0, 0]

In [31]:
# Try the Models on Different Sentences Pairs
sentence_a = 'I love spending time outdoors, exploring nature and hiking.'
sentence_b = 'I prefer staying indoors, reading books and watching movies.'

similarity = calculate_similarity(model, tokenizer, vocab, sentence_a, sentence_b, device)
other_similarity = calculate_similarity_other(other_model, sentence_a, sentence_b)

print(f"Cosine Similarity my model: {similarity:.4f}")
print(f"Cosine Similarity model from hugging face: {other_similarity:.4f}")

Cosine Similarity my model: 0.9787
Cosine Similarity model from hugging face: 0.5919


### 3.Analyze the impact of hyperparameter choices on the model’s performance.

The hyperparameters chosen for the model, such as the learning rate, batch size, and number of epochs, can significantly impact the performance of the model. For example, a higher learning rate might lead to faster convergence but could also cause the model to overshoot the optimal solution. Similarly, a larger batch size might speed up training but could lead to less accurate gradients and poorer generalization. Therefore, it is crucial to carefully tune these hyperparameters to achieve the best performance for the specific task and dataset.

### 4. Discuss any limitations and improvements or modifications.

During the implementation of the model, several limitations and challenges were encountered. One of the main challenges was handling the imbalance in the dataset, especially regarding the label distribution. Additionally, processing large datasets and tuning the hyperparameters effectively required significant computational resources and time. Another limitation was the reliance on pre-trained models, which may not always capture the specific nuances of the dataset or task.

For the improvement, several improvements or modifications could be considered. One approach is to explore different pre-trained models or fine-tuning strategies to improve the model's performance on the specific task. Additionally, data augmentation techniques could be used to address the imbalance in the dataset and improve the model's ability to generalize. Furthermore, optimizing the hyperparameters more effectively, such as through automated hyperparameter tuning techniques, could lead to better performance and faster convergence.