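## Fine-tune Flair

Fine-tunes a Flair `TextClassifier` (Flair `news-forward`/`news-backward` and DistilBERT word embeddings fed into a document RNN) on a binary POS/NEG sentiment corpus, then evaluates the best saved model on the test split.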

In [1]:
import os
from pathlib import Path
from sklearn.model_selection import train_test_split
import pandas as pd

import torch

from flair.data import Corpus, Sentence
from flair.datasets import ClassificationCorpus
from flair.embeddings import (
    WordEmbeddings,
    FlairEmbeddings,
    DocumentRNNEmbeddings,
    TransformerWordEmbeddings,
    StackedEmbeddings,
)
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.training_utils import EvaluationMetric
from flair.visual.training_curves import Plotter

import tqdm

import flair
# Point Flair at the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    flair.device = torch.device("cuda")
else:
    flair.device = torch.device("cpu")
print(f"Using device: {flair.device}")

from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
import torch

# Environment sanity check: report the PyTorch build and whether CUDA is usable.
cuda_available = torch.cuda.is_available()
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", cuda_available)
if cuda_available:
    # Only meaningful when a CUDA runtime was detected
    print("CUDA Version:", torch.version.cuda)

PyTorch Version: 2.5.1+cu124
CUDA Available: True
CUDA Version: 12.4


In [3]:
# Resolve the data directory, creating it if it does not exist yet
flair_data_folder = "data/flair_data"
corpus_folder = Path(flair_data_folder)
corpus_folder.mkdir(parents=True, exist_ok=True)

# Load the train/dev/test splits; labels are read under the 'sentiment' label type
corpus = ClassificationCorpus(
    corpus_folder,
    label_type='sentiment',
    train_file='train.txt',
    dev_file='dev.txt',
    test_file='test.txt',
)

# Report the size of each split
print(f"Number of training sentences: {len(corpus.train)}")
print(f"Number of validation sentences: {len(corpus.dev)}")
print(f"Number of test sentences: {len(corpus.test)}")

2024-12-03 18:21:18,430 Reading data from data\flair_data
2024-12-03 18:21:18,430 Train: data\flair_data\train.txt
2024-12-03 18:21:18,430 Dev: data\flair_data\dev.txt
2024-12-03 18:21:18,430 Test: data\flair_data\test.txt
2024-12-03 18:21:18,956 Initialized corpus data\flair_data (label type name is 'sentiment')
Number of training sentences: 22500
Number of validation sentences: 2500
Number of test sentences: 25000


In [4]:
# Build the dictionary of 'sentiment' label values found in the corpus; it is
# passed to TextClassifier as `label_dictionary` when the model is created.
label_dict = corpus.make_label_dictionary(label_type='sentiment')
print(label_dict)

2024-12-03 18:21:18,960 Computing label dictionary. Progress:


0it [00:00, ?it/s]
22500it [00:42, 527.47it/s]

2024-12-03 18:22:01,633 Dictionary created for label 'sentiment' with 2 values: POS (seen 11298 times), NEG (seen 11202 times)
Dictionary with 2 tags: POS, NEG





In [5]:
# Word-level embeddings fed to the document encoder: forward and backward
# Flair 'news' models plus DistilBERT (constructed in this order).
embeddings = [
    FlairEmbeddings('news-forward'),
    FlairEmbeddings('news-backward'),
    TransformerWordEmbeddings('distilbert-base-uncased'),
]

# Keep the individual embeddings addressable by name
flair_forward_embedding, flair_backward_embedding, transformer_word_embeddings = embeddings

# Document-level RNN encoder over the word embeddings; words are reprojected
# to 256 dimensions and the RNN hidden state is 256 wide.
document_embeddings = DocumentRNNEmbeddings(
    embeddings=embeddings,
    reproject_words=True,
    reproject_words_dimension=256,
    hidden_size=256,
)

In [6]:
# Text classifier over the document embeddings, predicting the 'sentiment' label type
classifier = TextClassifier(
    document_embeddings,
    label_type='sentiment',
    label_dictionary=label_dict,
)
# Move the model to the device selected for Flair
classifier = classifier.to(flair.device)

In [7]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
import tqdm
from collections import Counter

def evaluate_model(classifier, test_dataset, label_type="sentiment", label_mapping=None):
    """
    Evaluate a Flair classifier on a dataset and report binary classification metrics.

    Sentences are predicted one at a time. Predictions are stored under a
    separate label type ("predicted") instead of the classifier's own label
    type, so the gold annotation on each sentence is left intact and the
    prediction is read back by name rather than by position in
    ``sentence.labels``.

    Args:
        classifier (TextClassifier): The trained Flair classifier.
        test_dataset (Dataset): Iterable of Flair sentences carrying a gold
            label of type ``label_type``.
        label_type (str): Name of the label type holding the gold annotation.
            Defaults to "sentiment".
        label_mapping (dict | None): Maps the two label strings to the integers
            used for scikit-learn, with the positive class encoded as 1.
            Defaults to ``{"NEG": 0, "POS": 1}``.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall``, ``f1_score`` and
        ``classification_report``. An empty dict is returned (after printing
        a message) when the dataset yields no sentences or when a label is
        missing from ``label_mapping``.
    """
    if label_mapping is None:
        label_mapping = {"NEG": 0, "POS": 1}

    # Predictions go under their own label type so they never replace the gold label
    prediction_label_type = "predicted"

    true_labels = []
    predicted_labels = []

    # Iterate over the test dataset with tqdm progress bar
    for sentence in tqdm.tqdm(test_dataset, desc="Evaluating", leave=True):
        sentence.to(flair.device)

        # Gold label, read before and independently of the prediction
        true_labels.append(sentence.get_label(label_type).value)

        # Predicted label, written to and read from the dedicated label type
        classifier.predict(sentence, label_name=prediction_label_type)
        predicted_labels.append(sentence.get_label(prediction_label_type).value)

    # Nothing to score: avoid handing empty arrays to scikit-learn
    if not true_labels:
        print("No sentences to evaluate.")
        return {}

    # Verify label consistency
    print("True Labels Sample:", true_labels[:5])
    print("Predicted Labels Sample:", predicted_labels[:5])
    print("True Label Distribution:", Counter(true_labels))
    print("Predicted Label Distribution:", Counter(predicted_labels))

    # Map labels to numeric values for sklearn
    try:
        true_labels_mapped = [label_mapping[label] for label in true_labels]
        predicted_labels_mapped = [label_mapping[label] for label in predicted_labels]
    except KeyError as e:
        print(f"Label mapping error: {e}. Ensure all labels are in {label_mapping}.")
        return {}

    # Verify mapped labels
    print("Mapped True Labels Sample:", true_labels_mapped[:5])
    print("Mapped Predicted Labels Sample:", predicted_labels_mapped[:5])

    # Calculate metrics (positive class is the one encoded as 1)
    accuracy = accuracy_score(true_labels_mapped, predicted_labels_mapped)
    precision = precision_score(true_labels_mapped, predicted_labels_mapped, pos_label=1, zero_division=0)
    recall = recall_score(true_labels_mapped, predicted_labels_mapped, pos_label=1, zero_division=0)
    f1 = f1_score(true_labels_mapped, predicted_labels_mapped, pos_label=1, zero_division=0)

    # Class ids and their display names, ordered by id so names line up with ids
    class_ids = sorted(set(label_mapping.values()))
    id_to_name = {class_id: name for name, class_id in label_mapping.items()}
    target_names = [id_to_name[class_id] for class_id in class_ids]

    # Full classification report. Passing `labels` explicitly keeps the report
    # well-defined even when only one class occurs in the evaluated data.
    classification_rep = classification_report(
        true_labels_mapped,
        predicted_labels_mapped,
        labels=class_ids,
        target_names=target_names,
        zero_division=0,
    )

    # Print metrics
    print(f"\nAccuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("\nClassification Report:")
    print(classification_rep)

    # Return metrics as a dictionary
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "classification_report": classification_rep,
    }

In [8]:
from torch.optim import AdamW
import logging


# Small splits for a quick smoke test of the pipeline
small_data_folder = "data/flair_data"
small_corpus = ClassificationCorpus(
    small_data_folder,
    train_file="train_small.txt",
    dev_file='dev_small.txt',
    test_file='test_small.txt',
    label_type="sentiment",
)

# Set the logging level to INFO
logging.basicConfig(level=logging.INFO)

# Choose explicitly which corpus is trained on. Previously `small_corpus` was
# loaded but the trainer always received the full `corpus`; the default below
# keeps that full-corpus behaviour. Set to True for a smoke test first.
# The trainer updates `classifier` in place, so rebuild the classifier before
# a full run if a smoke test was executed on it.
USE_SMALL_CORPUS = False
training_corpus = small_corpus if USE_SMALL_CORPUS else corpus

# Smoke-test artefacts go to their own directory so they cannot overwrite
# the real run's saved model under 'flair_model'.
model_output_dir = 'flair_model_small' if USE_SMALL_CORPUS else 'flair_model'

# Initialize the trainer
trainer = ModelTrainer(classifier, training_corpus)

# Fine-tune the model
# NOTE(review): embeddings_storage_mode='gpu' keeps stored embeddings on the
# GPU; Flair's fine-tuning default is believed to be 'none' - confirm that
# 'gpu' is intended given the low samples/sec in the logged run.
trainer.fine_tune(
    base_path=model_output_dir,          # Directory to save the model and logs
    learning_rate=5e-5,                  # Learning rate for fine-tuning
    mini_batch_size=16,                  # Smaller batch size for transformers
    max_epochs=10,                       # Number of epochs
    embeddings_storage_mode='gpu',
    optimizer=AdamW,                     # Optimizer suited for transformers
    save_final_model=True,               # Save the final model
    save_model_each_k_epochs=1,          # Save model checkpoint every epoch
    create_file_logs=True,               # Save logs to a file
    create_loss_file=True,               # Save loss values to a file
    use_final_model_for_eval=False       # Final evaluation uses the best model, not the last one
)



2024-12-03 18:22:03,826 Reading data from data\flair_data
2024-12-03 18:22:03,827 Train: data\flair_data\train_small.txt
2024-12-03 18:22:03,827 Dev: data\flair_data\dev_small.txt
2024-12-03 18:22:03,827 Test: data\flair_data\test_small.txt
2024-12-03 18:22:03,831 Initialized corpus data/flair_data (label type name is 'sentiment')
2024-12-03 18:22:03,833 ----------------------------------------------------------------------------------------------------
2024-12-03 18:22:03,834 Model: "TextClassifier(
  (embeddings): DocumentRNNEmbeddings(
    (embeddings): StackedEmbeddings(
      (list_embedding_0): FlairEmbeddings(
        (lm): LanguageModel(
          (drop): Dropout(p=0.05, inplace=False)
          (encoder): Embedding(300, 100)
          (rnn): LSTM(100, 2048)
        )
      )
      (list_embedding_1): FlairEmbeddings(
        (lm): LanguageModel(
          (drop): Dropout(p=0.05, inplace=False)
          (encoder): Embedding(300, 100)
          (rnn): LSTM(100, 2048)
        )


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu")


2024-12-03 18:40:44,374 epoch 1 - iter 140/1407 - loss 0.70971192 - time (sec): 1120.53 - samples/sec: 2.00 - lr: 0.000005 - momentum: 0.000000
2024-12-03 18:59:58,059 epoch 1 - iter 280/1407 - loss 0.67125986 - time (sec): 2274.22 - samples/sec: 1.97 - lr: 0.000010 - momentum: 0.000000
2024-12-03 19:24:28,877 epoch 1 - iter 420/1407 - loss 0.56417752 - time (sec): 3745.03 - samples/sec: 1.79 - lr: 0.000015 - momentum: 0.000000
2024-12-03 19:47:53,884 epoch 1 - iter 560/1407 - loss 0.50285825 - time (sec): 5150.04 - samples/sec: 1.74 - lr: 0.000020 - momentum: 0.000000
2024-12-03 20:10:01,058 epoch 1 - iter 700/1407 - loss 0.46016501 - time (sec): 6477.22 - samples/sec: 1.73 - lr: 0.000025 - momentum: 0.000000
2024-12-03 20:31:58,632 epoch 1 - iter 840/1407 - loss 0.43516031 - time (sec): 7794.79 - samples/sec: 1.72 - lr: 0.000030 - momentum: 0.000000
2024-12-03 20:50:19,360 epoch 1 - iter 980/1407 - loss 0.41024027 - time (sec): 8895.52 - samples/sec: 1.76 - lr: 0.000035 - momentum: 0

100%|██████████| 157/157 [32:51<00:00, 12.56s/it]

2024-12-03 22:38:48,419 DEV : loss 0.33744463324546814 - f1-score (micro avg)  0.9004





2024-12-03 22:38:53,718 saving best model
2024-12-03 22:38:54,220 ----------------------------------------------------------------------------------------------------
2024-12-03 23:18:19,202 epoch 2 - iter 140/1407 - loss 0.22787478 - time (sec): 2364.98 - samples/sec: 0.95 - lr: 0.000049 - momentum: 0.000000
2024-12-04 00:00:46,566 epoch 2 - iter 280/1407 - loss 0.23892554 - time (sec): 4912.34 - samples/sec: 0.91 - lr: 0.000049 - momentum: 0.000000
2024-12-04 00:39:02,959 epoch 2 - iter 420/1407 - loss 0.23037537 - time (sec): 7208.74 - samples/sec: 0.93 - lr: 0.000048 - momentum: 0.000000
2024-12-04 01:17:47,070 epoch 2 - iter 560/1407 - loss 0.23122157 - time (sec): 9532.85 - samples/sec: 0.94 - lr: 0.000048 - momentum: 0.000000
2024-12-04 01:55:51,751 epoch 2 - iter 700/1407 - loss 0.22478482 - time (sec): 11817.53 - samples/sec: 0.95 - lr: 0.000047 - momentum: 0.000000
2024-12-04 02:34:20,982 epoch 2 - iter 840/1407 - loss 0.22478380 - time (sec): 14126.76 - samples/sec: 0.95 - l

100%|██████████| 157/157 [31:20<00:00, 11.98s/it]

2024-12-04 05:38:32,980 DEV : loss 0.3276882469654083 - f1-score (micro avg)  0.898





2024-12-04 05:38:38,004 ----------------------------------------------------------------------------------------------------
2024-12-04 06:16:55,606 epoch 3 - iter 140/1407 - loss 0.13563384 - time (sec): 2297.60 - samples/sec: 0.97 - lr: 0.000044 - momentum: 0.000000
2024-12-04 06:54:15,681 epoch 3 - iter 280/1407 - loss 0.12973217 - time (sec): 4537.68 - samples/sec: 0.99 - lr: 0.000043 - momentum: 0.000000
2024-12-04 07:32:25,388 epoch 3 - iter 420/1407 - loss 0.13003129 - time (sec): 6827.38 - samples/sec: 0.98 - lr: 0.000043 - momentum: 0.000000
2024-12-04 08:11:09,769 epoch 3 - iter 560/1407 - loss 0.13230037 - time (sec): 9151.76 - samples/sec: 0.98 - lr: 0.000042 - momentum: 0.000000
2024-12-04 08:48:24,438 epoch 3 - iter 700/1407 - loss 0.13160721 - time (sec): 11386.43 - samples/sec: 0.98 - lr: 0.000042 - momentum: 0.000000
2024-12-04 09:27:19,313 epoch 3 - iter 840/1407 - loss 0.13585726 - time (sec): 13721.31 - samples/sec: 0.98 - lr: 0.000041 - momentum: 0.000000
2024-12-0

100%|██████████| 157/157 [31:10<00:00, 11.91s/it]

2024-12-04 12:31:32,399 DEV : loss 0.38121363520622253 - f1-score (micro avg)  0.9108





2024-12-04 12:31:37,663 saving best model
2024-12-04 12:31:38,148 ----------------------------------------------------------------------------------------------------
2024-12-04 13:09:24,014 epoch 4 - iter 140/1407 - loss 0.04894524 - time (sec): 2265.86 - samples/sec: 0.99 - lr: 0.000038 - momentum: 0.000000
2024-12-04 13:47:56,385 epoch 4 - iter 280/1407 - loss 0.07164522 - time (sec): 4578.24 - samples/sec: 0.98 - lr: 0.000038 - momentum: 0.000000
2024-12-04 14:24:18,448 epoch 4 - iter 420/1407 - loss 0.06864954 - time (sec): 6760.30 - samples/sec: 0.99 - lr: 0.000037 - momentum: 0.000000
2024-12-04 15:00:58,254 epoch 4 - iter 560/1407 - loss 0.06719500 - time (sec): 8960.11 - samples/sec: 1.00 - lr: 0.000037 - momentum: 0.000000
2024-12-04 15:37:56,492 epoch 4 - iter 700/1407 - loss 0.07707795 - time (sec): 11178.34 - samples/sec: 1.00 - lr: 0.000036 - momentum: 0.000000
2024-12-04 16:15:11,274 epoch 4 - iter 840/1407 - loss 0.07700706 - time (sec): 13413.13 - samples/sec: 1.00 - l

100%|██████████| 157/157 [40:22<00:00, 15.43s/it]

2024-12-04 19:32:19,614 DEV : loss 0.38885292410850525 - f1-score (micro avg)  0.9188





2024-12-04 19:32:24,763 saving best model
2024-12-04 19:32:25,221 ----------------------------------------------------------------------------------------------------
2024-12-04 20:09:56,737 epoch 5 - iter 140/1407 - loss 0.03200944 - time (sec): 2251.51 - samples/sec: 0.99 - lr: 0.000033 - momentum: 0.000000
2024-12-04 20:45:38,339 epoch 5 - iter 280/1407 - loss 0.03304739 - time (sec): 4393.12 - samples/sec: 1.02 - lr: 0.000032 - momentum: 0.000000
2024-12-04 21:24:58,768 epoch 5 - iter 420/1407 - loss 0.03928664 - time (sec): 6753.55 - samples/sec: 1.00 - lr: 0.000032 - momentum: 0.000000
2024-12-04 22:02:40,559 epoch 5 - iter 560/1407 - loss 0.03937346 - time (sec): 9015.34 - samples/sec: 0.99 - lr: 0.000031 - momentum: 0.000000
2024-12-04 22:42:42,431 epoch 5 - iter 700/1407 - loss 0.03751128 - time (sec): 11417.21 - samples/sec: 0.98 - lr: 0.000031 - momentum: 0.000000
2024-12-04 23:21:33,758 epoch 5 - iter 840/1407 - loss 0.03679222 - time (sec): 13748.54 - samples/sec: 0.98 - l

100%|██████████| 157/157 [42:59<00:00, 16.43s/it]

2024-12-05 02:39:48,939 DEV : loss 0.5793805718421936 - f1-score (micro avg)  0.9212





2024-12-05 02:39:53,879 saving best model
2024-12-05 02:39:54,237 ----------------------------------------------------------------------------------------------------
2024-12-05 03:18:13,642 epoch 6 - iter 140/1407 - loss 0.01637261 - time (sec): 2299.40 - samples/sec: 0.97 - lr: 0.000027 - momentum: 0.000000
2024-12-05 03:54:51,096 epoch 6 - iter 280/1407 - loss 0.02102738 - time (sec): 4496.86 - samples/sec: 1.00 - lr: 0.000027 - momentum: 0.000000
2024-12-05 04:32:12,583 epoch 6 - iter 420/1407 - loss 0.02083186 - time (sec): 6738.34 - samples/sec: 1.00 - lr: 0.000026 - momentum: 0.000000
2024-12-05 05:12:01,800 epoch 6 - iter 560/1407 - loss 0.02621299 - time (sec): 9127.56 - samples/sec: 0.98 - lr: 0.000026 - momentum: 0.000000
2024-12-05 05:49:36,067 epoch 6 - iter 700/1407 - loss 0.02547396 - time (sec): 11381.83 - samples/sec: 0.98 - lr: 0.000025 - momentum: 0.000000
2024-12-05 06:26:54,281 epoch 6 - iter 840/1407 - loss 0.02490445 - time (sec): 13620.04 - samples/sec: 0.99 - l

100%|██████████| 157/157 [30:46<00:00, 11.76s/it]

2024-12-05 09:27:26,003 DEV : loss 0.5343671441078186 - f1-score (micro avg)  0.9204





2024-12-05 09:27:30,956 ----------------------------------------------------------------------------------------------------
2024-12-05 10:04:23,623 epoch 7 - iter 140/1407 - loss 0.00798728 - time (sec): 2212.67 - samples/sec: 1.01 - lr: 0.000022 - momentum: 0.000000
2024-12-05 10:45:13,219 epoch 7 - iter 280/1407 - loss 0.01166894 - time (sec): 4662.26 - samples/sec: 0.96 - lr: 0.000021 - momentum: 0.000000
2024-12-05 11:25:32,882 epoch 7 - iter 420/1407 - loss 0.01188168 - time (sec): 7081.93 - samples/sec: 0.95 - lr: 0.000021 - momentum: 0.000000
2024-12-05 12:03:54,519 epoch 7 - iter 560/1407 - loss 0.01574838 - time (sec): 9383.56 - samples/sec: 0.95 - lr: 0.000020 - momentum: 0.000000
2024-12-05 12:41:44,609 epoch 7 - iter 700/1407 - loss 0.01550765 - time (sec): 11653.65 - samples/sec: 0.96 - lr: 0.000019 - momentum: 0.000000
2024-12-05 13:18:54,804 epoch 7 - iter 840/1407 - loss 0.01515138 - time (sec): 13883.85 - samples/sec: 0.97 - lr: 0.000019 - momentum: 0.000000
2024-12-0

100%|██████████| 157/157 [32:42<00:00, 12.50s/it]

2024-12-05 16:24:40,210 DEV : loss 0.6759727001190186 - f1-score (micro avg)  0.928





2024-12-05 16:24:45,144 saving best model
2024-12-05 16:24:45,477 ----------------------------------------------------------------------------------------------------
2024-12-05 17:02:50,626 epoch 8 - iter 140/1407 - loss 0.01056803 - time (sec): 2285.15 - samples/sec: 0.98 - lr: 0.000016 - momentum: 0.000000
2024-12-05 17:39:25,083 epoch 8 - iter 280/1407 - loss 0.01011907 - time (sec): 4479.61 - samples/sec: 1.00 - lr: 0.000016 - momentum: 0.000000
2024-12-05 18:16:31,577 epoch 8 - iter 420/1407 - loss 0.00857263 - time (sec): 6706.10 - samples/sec: 1.00 - lr: 0.000015 - momentum: 0.000000
2024-12-05 18:54:38,109 epoch 8 - iter 560/1407 - loss 0.00732613 - time (sec): 8992.63 - samples/sec: 1.00 - lr: 0.000014 - momentum: 0.000000
2024-12-05 19:32:45,190 epoch 8 - iter 700/1407 - loss 0.00690599 - time (sec): 11279.71 - samples/sec: 0.99 - lr: 0.000014 - momentum: 0.000000
2024-12-05 20:10:20,674 epoch 8 - iter 840/1407 - loss 0.00694750 - time (sec): 13535.20 - samples/sec: 0.99 - l

100%|██████████| 157/157 [34:18<00:00, 13.11s/it]

2024-12-05 23:21:05,849 DEV : loss 0.6266456842422485 - f1-score (micro avg)  0.928





2024-12-05 23:21:10,845 ----------------------------------------------------------------------------------------------------
2024-12-05 23:56:09,919 epoch 9 - iter 140/1407 - loss 0.00649136 - time (sec): 2099.07 - samples/sec: 1.07 - lr: 0.000011 - momentum: 0.000000
2024-12-06 00:34:52,437 epoch 9 - iter 280/1407 - loss 0.00593049 - time (sec): 4421.59 - samples/sec: 1.01 - lr: 0.000010 - momentum: 0.000000
2024-12-06 01:13:15,165 epoch 9 - iter 420/1407 - loss 0.00410749 - time (sec): 6724.32 - samples/sec: 1.00 - lr: 0.000009 - momentum: 0.000000
2024-12-06 01:50:44,860 epoch 9 - iter 560/1407 - loss 0.00425486 - time (sec): 8974.01 - samples/sec: 1.00 - lr: 0.000009 - momentum: 0.000000
2024-12-06 02:29:25,597 epoch 9 - iter 700/1407 - loss 0.00448478 - time (sec): 11294.75 - samples/sec: 0.99 - lr: 0.000008 - momentum: 0.000000
2024-12-06 03:07:45,766 epoch 9 - iter 840/1407 - loss 0.00592652 - time (sec): 13594.92 - samples/sec: 0.99 - lr: 0.000008 - momentum: 0.000000
2024-12-0

100%|██████████| 157/157 [29:34<00:00, 11.30s/it]

2024-12-06 06:08:46,795 DEV : loss 0.7458382844924927 - f1-score (micro avg)  0.9288





2024-12-06 06:08:51,775 saving best model
2024-12-06 06:08:52,138 ----------------------------------------------------------------------------------------------------
2024-12-06 06:48:11,293 epoch 10 - iter 140/1407 - loss 0.00001014 - time (sec): 2359.16 - samples/sec: 0.95 - lr: 0.000005 - momentum: 0.000000
2024-12-06 07:25:09,065 epoch 10 - iter 280/1407 - loss 0.00102093 - time (sec): 4576.93 - samples/sec: 0.98 - lr: 0.000004 - momentum: 0.000000
2024-12-06 08:03:36,095 epoch 10 - iter 420/1407 - loss 0.00164725 - time (sec): 6883.96 - samples/sec: 0.98 - lr: 0.000004 - momentum: 0.000000
2024-12-06 08:41:08,240 epoch 10 - iter 560/1407 - loss 0.00123949 - time (sec): 9136.10 - samples/sec: 0.98 - lr: 0.000003 - momentum: 0.000000
2024-12-06 09:20:01,200 epoch 10 - iter 700/1407 - loss 0.00185058 - time (sec): 11469.06 - samples/sec: 0.98 - lr: 0.000003 - momentum: 0.000000
2024-12-06 09:58:41,615 epoch 10 - iter 840/1407 - loss 0.00155128 - time (sec): 13789.48 - samples/sec: 0.

100%|██████████| 157/157 [30:07<00:00, 11.51s/it]

2024-12-06 13:01:59,773 DEV : loss 0.7704070210456848 - f1-score (micro avg)  0.9284





2024-12-06 13:02:05,071 ----------------------------------------------------------------------------------------------------
2024-12-06 13:02:05,072 Loading model from best epoch ...


100%|██████████| 1563/1563 [4:18:48<00:00,  9.94s/it]  


2024-12-06 17:20:55,301 
Results:
- F-score (micro) 0.929
- F-score (macro) 0.929
- Accuracy 0.929

By class:
              precision    recall  f1-score   support

         POS     0.9208    0.9387    0.9297     12500
         NEG     0.9375    0.9193    0.9283     12500

    accuracy                         0.9290     25000
   macro avg     0.9292    0.9290    0.9290     25000
weighted avg     0.9292    0.9290    0.9290     25000

2024-12-06 17:20:55,302 ----------------------------------------------------------------------------------------------------


{'test_score': 0.929}

In [9]:
# Best model written by the training run (adjust if base_path or filename differs)
saved_model_path = "flair_model/best-model.pt"

# Restore the saved model and move it to the active Flair device
classifier = TextClassifier.load(saved_model_path)
classifier = classifier.to(flair.device)

# Score the restored model on the held-out test split
results = evaluate_model(classifier, corpus.test)


Evaluating: 100%|██████████| 25000/25000 [3:56:31<00:00,  1.76it/s]  

True Labels Sample: ['NEG', 'NEG', 'NEG', 'NEG', 'NEG']
Predicted Labels Sample: ['POS', 'NEG', 'NEG', 'NEG', 'NEG']
True Label Distribution: Counter({'NEG': 12500, 'POS': 12500})
Predicted Label Distribution: Counter({'POS': 12743, 'NEG': 12257})
Mapped True Labels Sample: [0, 0, 0, 0, 0]
Mapped Predicted Labels Sample: [1, 0, 0, 0, 0]

Accuracy: 0.9290
Precision: 0.9208
Recall: 0.9387
F1 Score: 0.9297

Classification Report:
              precision    recall  f1-score   support

         NEG       0.94      0.92      0.93     12500
         POS       0.92      0.94      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000




