# Sentiment Analysis on Go Emotions with Transformers

## Setup & Preprocessing

### Install Packages

In [1]:
# %pip install pandas numpy torch nltk datasets transformers scikit-learn

### NLTK Setup

In [2]:
import re
import string
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Ensure necessary resources are available
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

# Initialize stopwords and lemmatizer
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\alber\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\alber\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\alber\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


### Preprocessing Functions

In [3]:
def preprocess_text(text):
    """Clean, normalize, tokenize, remove stopwords, and lemmatize text."""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r'https?://\S+|www\.\S+', '', text)  # Remove URLs
    text = re.sub(r'\S*@\S*\s?', '', text)  # Remove emails
    text = text.translate(str.maketrans('', '', string.punctuation))  # Remove punctuation
    text = re.sub(r'\s+', ' ', text).strip()  # Remove extra spaces
    
    tokens = word_tokenize(text)  # Tokenization
    
    # Remove stopwords and lemmatize
    processed_tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words]
    
    return ' '.join(processed_tokens)

### Feature Extraction

In [4]:
def extract_features(text):
    """Extract additional features from text"""
    features = {}
    original_text = text
    
    features['text_length'] = len(text)
    features['word_count'] = len(text.split())
    features['has_question'] = 1 if '?' in original_text else 0
    features['has_exclamation'] = 1 if '!' in original_text else 0
    features['has_emoticon'] = 1 if any(emoticon in original_text.lower() for emoticon in [':)', ':(', ':D', ';)', ':P', 'XD']) else 0
    
    # All caps words ratio (potential indicator of intensity)
    words = text.split()
    if words:
        features['caps_ratio'] = sum(1 for word in words if word.isupper()) / len(words)
    else:
        features['caps_ratio'] = 0
    
    return features

## Data Preparation

### Load Dataset

In [5]:
import pandas as pd
from datasets import load_dataset

# Loading the dataset
data = load_dataset('go_emotions')
df_train = pd.DataFrame(data["train"])
df_val = pd.DataFrame(data["validation"])
df_test = pd.DataFrame(data["test"])

# Display dataset info
print(f"Training examples: {len(df_train)}")
print(f"Validation examples: {len(df_val)}")
print(f"Test examples: {len(df_test)}")

# Check the columns to understand what we're working with
print(f"Columns: {df_train.columns}")

# Get list of emotion labels
emotion_labels = data['train'].features['labels'].feature.names
num_labels = len(emotion_labels)
print(f"Number of emotion labels: {num_labels}")
print(f"Emotion labels: {emotion_labels}")

Training examples: 43410
Validation examples: 5426
Test examples: 5427
Columns: Index(['text', 'labels', 'id'], dtype='object')
Number of emotion labels: 28
Emotion labels: ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']


### Apply Preprocessing

In [6]:
# Add preprocessed text columns
print("Preprocessing text data...")
df_train['cleaned_text'] = df_train['text'].apply(preprocess_text)
df_val['cleaned_text'] = df_val['text'].apply(preprocess_text)
df_test['cleaned_text'] = df_test['text'].apply(preprocess_text)

# Extract features
print("Extracting text features...")
train_features = df_train['text'].apply(extract_features).apply(pd.Series)
val_features = df_val['text'].apply(extract_features).apply(pd.Series)
test_features = df_test['text'].apply(extract_features).apply(pd.Series)

# Add features to dataframes
for feature in train_features.columns:
    df_train[feature] = train_features[feature]
    df_val[feature] = val_features[feature]
    df_test[feature] = test_features[feature]

# Display sample
print("\nSample data with features:")
print(df_train[['text', 'cleaned_text', 'text_length', 'word_count', 'has_question', 'has_exclamation', 'caps_ratio']].head())

Preprocessing text data...
Extracting text features...

Sample data with features:
                                                text  \
0  My favourite food is anything I didn't have to...   
1  Now if he does off himself, everyone will thin...   
2                     WHY THE FUCK IS BAYLESS ISOING   
3                        To make her feel threatened   
4                             Dirty Southern Wankers   

                                        cleaned_text  text_length  word_count  \
0                 favourite food anything didnt cook         59.0        11.0   
1  everyone think he laugh screwing people instea...        112.0        20.0   
2                                fuck bayless isoing         30.0         6.0   
3                               make feel threatened         27.0         5.0   
4                              dirty southern wanker         22.0         3.0   

   has_question  has_exclamation  caps_ratio  
0           0.0              0.0    0.090909  

### Dataset Class

In [7]:
import torch

# Add cleaned text to the dataset
data["train"] = data["train"].add_column("cleaned_text", df_train["cleaned_text"].tolist())
data["validation"] = data["validation"].add_column("cleaned_text", df_val["cleaned_text"].tolist())
data["test"] = data["test"].add_column("cleaned_text", df_test["cleaned_text"].tolist())

# Create a custom dataset class that incorporates text features
class GoEmotionsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels, features):
        self.encodings = encodings
        self.labels = labels
        self.features = features
    
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        
        # Add text features
        for feature_name, feature_values in self.features.items():
            item[feature_name] = torch.tensor(feature_values[idx], dtype=torch.float)
        
        return item
    
    def __len__(self):
        return len(self.labels)

### Initialize Tokenizer

In [8]:
from transformers import AutoTokenizer

# Initialize tokenizer
model_name = "distilbert-base-uncased"  # A good balance of performance and speed
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize the datasets
def tokenize_function(examples):
    return tokenizer(
        examples["cleaned_text"] if "cleaned_text" in examples else examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128
    )

# Prepare datasets for training
train_dataset = data["train"].map(tokenize_function, batched=True)
val_dataset = data["validation"].map(tokenize_function, batched=True)
test_dataset = data["test"].map(tokenize_function, batched=True)

Map:   0%|          | 0/5427 [00:00<?, ? examples/s]

### Extract Features

In [9]:
# Extract numeric features from dataframes
train_numeric_features = {
    'text_length': df_train['text_length'].values,
    'word_count': df_train['word_count'].values,
    'has_question': df_train['has_question'].values,
    'has_exclamation': df_train['has_exclamation'].values,
    'has_emoticon': df_train['has_emoticon'].values,
    'caps_ratio': df_train['caps_ratio'].values
}

val_numeric_features = {
    'text_length': df_val['text_length'].values,
    'word_count': df_val['word_count'].values,
    'has_question': df_val['has_question'].values,
    'has_exclamation': df_val['has_exclamation'].values,
    'has_emoticon': df_val['has_emoticon'].values,
    'caps_ratio': df_val['caps_ratio'].values
}

test_numeric_features = {
    'text_length': df_test['text_length'].values,
    'word_count': df_test['word_count'].values,
    'has_question': df_test['has_question'].values,
    'has_exclamation': df_test['has_exclamation'].values,
    'has_emoticon': df_test['has_emoticon'].values,
    'caps_ratio': df_test['caps_ratio'].values
}


### Process Labels

In [10]:
import numpy as np

# Convert labels to the expected format for multi-label classification
def process_labels(examples):
    # Convert label indices to multi-hot encoding
    labels = np.zeros((len(examples["labels"]), num_labels))
    for i, label_list in enumerate(examples["labels"]):
        labels[i, label_list] = 1
    return {"labels": labels.tolist()}

train_dataset = train_dataset.map(process_labels, batched=True)
val_dataset = val_dataset.map(process_labels, batched=True)
test_dataset = test_dataset.map(process_labels, batched=True)

# Create custom datasets with features
# Get processed labels
train_labels = np.array(train_dataset["labels"])
val_labels = np.array(val_dataset["labels"])
test_labels = np.array(test_dataset["labels"])

# Remove labels from encodings
train_encodings = {k: v for k, v in train_dataset.to_dict().items() if k != "labels"}
val_encodings = {k: v for k, v in val_dataset.to_dict().items() if k != "labels"}
test_encodings = {k: v for k, v in test_dataset.to_dict().items() if k != "labels"}

# Create custom datasets
custom_train_dataset = GoEmotionsDataset(train_encodings, train_labels, train_numeric_features)
custom_val_dataset = GoEmotionsDataset(val_encodings, val_labels, val_numeric_features)
custom_test_dataset = GoEmotionsDataset(test_encodings, test_labels, test_numeric_features)

Map:   0%|          | 0/5427 [00:00<?, ? examples/s]

## Model Building

### Model Architecture

In [11]:
# Create a custom model that incorporates text features
from transformers import DistilBertPreTrainedModel, DistilBertModel
from torch import nn

class EmotionClassifier(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        
        # Number of additional features
        self.num_features = 6  # text_length, word_count, etc.
        
        # Add a feature extractor network
        self.feature_extractor = nn.Sequential(
            nn.Linear(self.num_features, 32),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Combine BERT embeddings with our features
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size + 32, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(config.hidden_size, config.num_labels)
        )
        
        # Initialize weights
        self.post_init()
        
    def forward(self, input_ids=None, attention_mask=None, head_mask=None, 
                inputs_embeds=None, labels=None, output_attentions=None, 
                output_hidden_states=None, return_dict=None, **kwargs):
        
        # Get text features from kwargs
        text_features = []
        for feature_name in ['text_length', 'word_count', 'has_question', 
                            'has_exclamation', 'has_emoticon', 'caps_ratio']:
            if feature_name in kwargs:
                text_features.append(kwargs[feature_name].unsqueeze(1))
        
        # Concatenate all features
        if text_features:
            text_features = torch.cat(text_features, dim=1)
        else:
            # If no features provided, use zeros as fallback
            text_features = torch.zeros(input_ids.shape[0], self.num_features, device=input_ids.device)
        
        # Get BERT embeddings
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        
        # Get the CLS token embedding
        bert_output = outputs[0][:, 0]  # CLS token
        
        # Process the features
        feature_output = self.feature_extractor(text_features)
        
        # Concatenate BERT output with features
        combined = torch.cat([bert_output, feature_output], dim=1)
        
        # Final classification
        logits = self.classifier(combined)
        
        # loss = None
        # if labels is not None:
        #     loss_fct = nn.BCEWithLogitsLoss()
        #     loss = loss_fct(logits.view(-1, self.num_labels), 
        #                    labels.float().view(-1, self.num_labels))
        
        loss = None
        if labels is not None:
            # Use class weights if available
            if hasattr(self, 'pos_weights'):
                loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weights)
            else:
                loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), 
                        labels.float().view(-1, self.num_labels))

        return (loss, logits) if loss is not None else logits

### Initialize Model

In [12]:
# Load pre-trained DistilBERT config and modify it
from transformers import AutoConfig
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=num_labels,
    problem_type="multi_label_classification"
)

# Initialize our custom model
model = EmotionClassifier.from_pretrained(
    model_name,
    config=config
)

# Check if GPU is available and move model to GPU
if torch.cuda.is_available():
    model = model.to('cuda')
    print("Model loaded on GPU")
else:
    print("Model loaded on CPU")


Some weights of EmotionClassifier were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.0.bias', 'classifier.0.weight', 'classifier.3.bias', 'classifier.3.weight', 'feature_extractor.0.bias', 'feature_extractor.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on GPU


In [13]:
# Compute class weights based on label distribution
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# Reshape labels for weight calculation (train_labels is multi-hot encoded)
flat_labels = np.where(train_labels == 1)[1]  # Get indices of positive labels

# Compute class weights
class_weights = compute_class_weight('balanced', 
                                    classes=np.unique(flat_labels), 
                                    y=flat_labels)
                                    
# Convert to tensor and move to appropriate device
pos_weights = torch.tensor(class_weights, dtype=torch.float)
if torch.cuda.is_available():
    pos_weights = pos_weights.to('cuda')

# Pass these weights to the model
model.pos_weights = pos_weights

### Training Setup

In [14]:
import torch
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score)
from transformers import TrainingArguments


# Define metrics for evaluation
def compute_metrics(pred):
    predictions = (pred.predictions > 0).astype(int)
    labels = pred.label_ids
    
    # Compute metrics
    accuracy = accuracy_score(labels.flatten(), predictions.flatten())
    f1 = f1_score(labels, predictions, average='weighted')
    precision = precision_score(labels, predictions, average='weighted', zero_division=0)
    recall = recall_score(labels, predictions, average='weighted', zero_division=0)
    
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# # Set up training arguments
# training_args = TrainingArguments(
#     output_dir="./results",
#     learning_rate=5e-5,
#     per_device_train_batch_size=16,
#     per_device_eval_batch_size=16,
#     num_train_epochs=3,
#     weight_decay=0.01,
#     evaluation_strategy="epoch",
#     save_strategy="epoch",
#     load_best_model_at_end=True,
#     push_to_hub=False,
#     fp16=torch.cuda.is_available(),  # Enable mixed precision if GPU available
# )


# Set up training arguments with improvements
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,  # Slightly lower learning rate
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,  # Increase training epochs
    weight_decay=0.01,
    evaluation_strategy="steps",  # Evaluate more frequently
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # Optimize for F1 score
    greater_is_better=True,
    push_to_hub=False,
    fp16=torch.cuda.is_available(),
    logging_steps=100,
    warmup_steps=500,  # Add warmup steps
    lr_scheduler_type="linear",  # Use linear scheduler
    gradient_accumulation_steps=2,  # Effectively doubles batch size
)



## Training & Evaluation

### Train Model

In [15]:
from transformers import Trainer

# Initialize trainer with our custom datasets
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=custom_train_dataset,
    eval_dataset=custom_val_dataset,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malbert-negoro89[0m ([33malbert-negoro89-arizona-state-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/13570 [00:00<?, ?it/s]

{'loss': 0.6703, 'grad_norm': 1.4818263053894043, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.07}
{'loss': 0.4467, 'grad_norm': 1.4410593509674072, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.15}
{'loss': 0.2633, 'grad_norm': 0.895915687084198, 'learning_rate': 1.2e-05, 'epoch': 0.22}
{'loss': 0.1856, 'grad_norm': 0.6379570364952087, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.29}
{'loss': 0.1727, 'grad_norm': 0.451177716255188, 'learning_rate': 2e-05, 'epoch': 0.37}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.176142156124115, 'eval_accuracy': 0.9580064240956242, 'eval_f1': 0.0, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_runtime': 3.6661, 'eval_samples_per_second': 1480.056, 'eval_steps_per_second': 92.742, 'epoch': 0.37}
{'loss': 0.1736, 'grad_norm': 1.2803378105163574, 'learning_rate': 1.984697781178271e-05, 'epoch': 0.44}
{'loss': 0.1816, 'grad_norm': 0.5326987504959106, 'learning_rate': 1.969395562356542e-05, 'epoch': 0.52}
{'loss': 0.1839, 'grad_norm': 0.5067772269248962, 'learning_rate': 1.9540933435348126e-05, 'epoch': 0.59}
{'loss': 0.1719, 'grad_norm': 0.4756167531013489, 'learning_rate': 1.9387911247130835e-05, 'epoch': 0.66}
{'loss': 0.1571, 'grad_norm': 0.4957443177700043, 'learning_rate': 1.9234889058913545e-05, 'epoch': 0.74}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1535516381263733, 'eval_accuracy': 0.9581709757253436, 'eval_f1': 0.0077014865001516835, 'eval_precision': 0.006860519583468454, 'eval_recall': 0.00877742946708464, 'eval_runtime': 3.5667, 'eval_samples_per_second': 1521.305, 'eval_steps_per_second': 95.327, 'epoch': 0.74}
{'loss': 0.1516, 'grad_norm': 0.5781682729721069, 'learning_rate': 1.9081866870696254e-05, 'epoch': 0.81}
{'loss': 0.1484, 'grad_norm': 0.5314978957176208, 'learning_rate': 1.8928844682478963e-05, 'epoch': 0.88}
{'loss': 0.1395, 'grad_norm': 1.2173354625701904, 'learning_rate': 1.877582249426167e-05, 'epoch': 0.96}
{'loss': 0.1332, 'grad_norm': 0.559838056564331, 'learning_rate': 1.862280030604438e-05, 'epoch': 1.03}
{'loss': 0.1286, 'grad_norm': 1.4240052700042725, 'learning_rate': 1.8469778117827088e-05, 'epoch': 1.11}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1207302138209343, 'eval_accuracy': 0.9625151387499342, 'eval_f1': 0.17005875448726687, 'eval_precision': 0.2504385464155222, 'eval_recall': 0.1487460815047022, 'eval_runtime': 3.5621, 'eval_samples_per_second': 1523.268, 'eval_steps_per_second': 95.45, 'epoch': 1.11}
{'loss': 0.1232, 'grad_norm': 1.436185359954834, 'learning_rate': 1.8316755929609794e-05, 'epoch': 1.18}
{'loss': 0.1242, 'grad_norm': 0.5424305200576782, 'learning_rate': 1.8163733741392503e-05, 'epoch': 1.25}
{'loss': 0.114, 'grad_norm': 0.6921884417533875, 'learning_rate': 1.8010711553175212e-05, 'epoch': 1.33}
{'loss': 0.1152, 'grad_norm': 0.5870735049247742, 'learning_rate': 1.7857689364957918e-05, 'epoch': 1.4}
{'loss': 0.1194, 'grad_norm': 0.4806309640407562, 'learning_rate': 1.7704667176740627e-05, 'epoch': 1.47}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11288198828697205, 'eval_accuracy': 0.9631667632036228, 'eval_f1': 0.2262408586707799, 'eval_precision': 0.31782319477191967, 'eval_recall': 0.19655172413793104, 'eval_runtime': 3.1602, 'eval_samples_per_second': 1716.959, 'eval_steps_per_second': 107.587, 'epoch': 1.47}
{'loss': 0.1162, 'grad_norm': 0.7124719619750977, 'learning_rate': 1.7551644988523337e-05, 'epoch': 1.55}
{'loss': 0.111, 'grad_norm': 0.7519200444221497, 'learning_rate': 1.7398622800306046e-05, 'epoch': 1.62}
{'loss': 0.112, 'grad_norm': 0.7556813955307007, 'learning_rate': 1.7245600612088755e-05, 'epoch': 1.69}
{'loss': 0.1104, 'grad_norm': 0.7455730438232422, 'learning_rate': 1.709257842387146e-05, 'epoch': 1.77}
{'loss': 0.1071, 'grad_norm': 0.7523863315582275, 'learning_rate': 1.693955623565417e-05, 'epoch': 1.84}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10747367143630981, 'eval_accuracy': 0.9635880153757043, 'eval_f1': 0.2592986780553055, 'eval_precision': 0.37565926803576194, 'eval_recall': 0.22884012539184953, 'eval_runtime': 3.0531, 'eval_samples_per_second': 1777.231, 'eval_steps_per_second': 111.364, 'epoch': 1.84}
{'loss': 0.107, 'grad_norm': 0.6564183235168457, 'learning_rate': 1.678653404743688e-05, 'epoch': 1.92}
{'loss': 0.1146, 'grad_norm': 0.5673279166221619, 'learning_rate': 1.663351185921959e-05, 'epoch': 1.99}
{'loss': 0.0978, 'grad_norm': 3.1949009895324707, 'learning_rate': 1.64804896710023e-05, 'epoch': 2.06}
{'loss': 0.1011, 'grad_norm': 0.6810345649719238, 'learning_rate': 1.6327467482785004e-05, 'epoch': 2.14}
{'loss': 0.1025, 'grad_norm': 0.7549296021461487, 'learning_rate': 1.6174445294567714e-05, 'epoch': 2.21}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10755860060453415, 'eval_accuracy': 0.9638381338528776, 'eval_f1': 0.27454460043395656, 'eval_precision': 0.37438415621530563, 'eval_recall': 0.2421630094043887, 'eval_runtime': 3.3345, 'eval_samples_per_second': 1627.244, 'eval_steps_per_second': 101.965, 'epoch': 2.21}
{'loss': 0.1039, 'grad_norm': 0.8730422854423523, 'learning_rate': 1.6021423106350423e-05, 'epoch': 2.28}
{'loss': 0.0993, 'grad_norm': 0.5700458884239197, 'learning_rate': 1.5868400918133132e-05, 'epoch': 2.36}
{'loss': 0.0989, 'grad_norm': 3.8379528522491455, 'learning_rate': 1.571537872991584e-05, 'epoch': 2.43}
{'loss': 0.1002, 'grad_norm': 0.7672767639160156, 'learning_rate': 1.5562356541698547e-05, 'epoch': 2.51}
{'loss': 0.1025, 'grad_norm': 1.6742726564407349, 'learning_rate': 1.5409334353481257e-05, 'epoch': 2.58}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10461688041687012, 'eval_accuracy': 0.9631338528776789, 'eval_f1': 0.28210003942312384, 'eval_precision': 0.458256100918088, 'eval_recall': 0.2584639498432602, 'eval_runtime': 3.0527, 'eval_samples_per_second': 1777.457, 'eval_steps_per_second': 111.378, 'epoch': 2.58}
{'loss': 0.0968, 'grad_norm': 0.8232335448265076, 'learning_rate': 1.5256312165263964e-05, 'epoch': 2.65}
{'loss': 0.1011, 'grad_norm': 2.3566648960113525, 'learning_rate': 1.5103289977046673e-05, 'epoch': 2.73}
{'loss': 0.0956, 'grad_norm': 0.832960844039917, 'learning_rate': 1.4950267788829383e-05, 'epoch': 2.8}
{'loss': 0.0914, 'grad_norm': 0.8221614360809326, 'learning_rate': 1.4797245600612089e-05, 'epoch': 2.87}
{'loss': 0.0947, 'grad_norm': 0.5743135213851929, 'learning_rate': 1.4644223412394798e-05, 'epoch': 2.95}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10491683334112167, 'eval_accuracy': 0.9640553420041073, 'eval_f1': 0.2902804350891198, 'eval_precision': 0.46553575430258715, 'eval_recall': 0.2573667711598746, 'eval_runtime': 3.1469, 'eval_samples_per_second': 1724.216, 'eval_steps_per_second': 108.042, 'epoch': 2.95}
{'loss': 0.0946, 'grad_norm': 0.8363548517227173, 'learning_rate': 1.449273144605968e-05, 'epoch': 3.02}
{'loss': 0.0867, 'grad_norm': 1.017137050628662, 'learning_rate': 1.433970925784239e-05, 'epoch': 3.1}
{'loss': 0.088, 'grad_norm': 1.0113731622695923, 'learning_rate': 1.4186687069625096e-05, 'epoch': 3.17}
{'loss': 0.0886, 'grad_norm': 2.0993387699127197, 'learning_rate': 1.4033664881407805e-05, 'epoch': 3.24}
{'loss': 0.086, 'grad_norm': 0.6556255221366882, 'learning_rate': 1.3880642693190514e-05, 'epoch': 3.32}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10529347509145737, 'eval_accuracy': 0.9640092675477858, 'eval_f1': 0.30512515208298435, 'eval_precision': 0.5168086342442652, 'eval_recall': 0.2736677115987461, 'eval_runtime': 3.2615, 'eval_samples_per_second': 1663.646, 'eval_steps_per_second': 104.246, 'epoch': 3.32}
{'loss': 0.0884, 'grad_norm': 0.6325908303260803, 'learning_rate': 1.3727620504973222e-05, 'epoch': 3.39}
{'loss': 0.0857, 'grad_norm': 0.8645972609519958, 'learning_rate': 1.3576128538638102e-05, 'epoch': 3.46}
{'loss': 0.0895, 'grad_norm': 1.9085578918457031, 'learning_rate': 1.3423106350420812e-05, 'epoch': 3.54}
{'loss': 0.0864, 'grad_norm': 1.0664838552474976, 'learning_rate': 1.3270084162203521e-05, 'epoch': 3.61}
{'loss': 0.0897, 'grad_norm': 0.9331812858581543, 'learning_rate': 1.3117061973986229e-05, 'epoch': 3.68}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10429470986127853, 'eval_accuracy': 0.9636275077668369, 'eval_f1': 0.3085846409611024, 'eval_precision': 0.7449155957468633, 'eval_recall': 0.2818181818181818, 'eval_runtime': 3.4146, 'eval_samples_per_second': 1589.077, 'eval_steps_per_second': 99.574, 'epoch': 3.68}
{'loss': 0.0884, 'grad_norm': 0.6646069288253784, 'learning_rate': 1.2964039785768938e-05, 'epoch': 3.76}
{'loss': 0.0902, 'grad_norm': 1.5386028289794922, 'learning_rate': 1.2811017597551647e-05, 'epoch': 3.83}
{'loss': 0.0864, 'grad_norm': 0.8018712401390076, 'learning_rate': 1.2657995409334353e-05, 'epoch': 3.91}
{'loss': 0.0889, 'grad_norm': 1.051065444946289, 'learning_rate': 1.2504973221117062e-05, 'epoch': 3.98}
{'loss': 0.0789, 'grad_norm': 0.7616689205169678, 'learning_rate': 1.2351951032899772e-05, 'epoch': 4.05}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1061759889125824, 'eval_accuracy': 0.9636933284187247, 'eval_f1': 0.3172998005003649, 'eval_precision': 0.754852035703854, 'eval_recall': 0.2874608150470219, 'eval_runtime': 3.1828, 'eval_samples_per_second': 1704.801, 'eval_steps_per_second': 106.825, 'epoch': 4.05}
{'loss': 0.0782, 'grad_norm': 0.7168601155281067, 'learning_rate': 1.2198928844682481e-05, 'epoch': 4.13}
{'loss': 0.0766, 'grad_norm': 0.8667905926704407, 'learning_rate': 1.204590665646519e-05, 'epoch': 4.2}
{'loss': 0.08, 'grad_norm': 1.0344047546386719, 'learning_rate': 1.1892884468247896e-05, 'epoch': 4.27}
{'loss': 0.0775, 'grad_norm': 0.7376607060432434, 'learning_rate': 1.1739862280030605e-05, 'epoch': 4.35}
{'loss': 0.0787, 'grad_norm': 0.7732806205749512, 'learning_rate': 1.1586840091813315e-05, 'epoch': 4.42}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.10806291550397873, 'eval_accuracy': 0.9637328208098573, 'eval_f1': 0.32210581508302616, 'eval_precision': 0.7414075524171891, 'eval_recall': 0.29169278996865206, 'eval_runtime': 3.069, 'eval_samples_per_second': 1768.022, 'eval_steps_per_second': 110.787, 'epoch': 4.42}
{'loss': 0.0775, 'grad_norm': 3.6777687072753906, 'learning_rate': 1.1433817903596022e-05, 'epoch': 4.5}
{'loss': 0.0795, 'grad_norm': 0.731162428855896, 'learning_rate': 1.1280795715378732e-05, 'epoch': 4.57}
{'loss': 0.0782, 'grad_norm': 1.0120034217834473, 'learning_rate': 1.112777352716144e-05, 'epoch': 4.64}
{'loss': 0.0805, 'grad_norm': 2.688056707382202, 'learning_rate': 1.0974751338944147e-05, 'epoch': 4.72}
{'loss': 0.0752, 'grad_norm': 0.7908505797386169, 'learning_rate': 1.0821729150726856e-05, 'epoch': 4.79}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11027836799621582, 'eval_accuracy': 0.9642330577642041, 'eval_f1': 0.31604541860635393, 'eval_precision': 0.744119464744957, 'eval_recall': 0.2753918495297806, 'eval_runtime': 3.0016, 'eval_samples_per_second': 1807.696, 'eval_steps_per_second': 113.272, 'epoch': 4.79}
{'loss': 0.0774, 'grad_norm': 1.2842286825180054, 'learning_rate': 1.0668706962509565e-05, 'epoch': 4.86}
{'loss': 0.0797, 'grad_norm': 1.2963978052139282, 'learning_rate': 1.0515684774292275e-05, 'epoch': 4.94}
{'loss': 0.0768, 'grad_norm': 0.8828572034835815, 'learning_rate': 1.036266258607498e-05, 'epoch': 5.01}
{'loss': 0.0731, 'grad_norm': 0.8276371359825134, 'learning_rate': 1.020964039785769e-05, 'epoch': 5.08}
{'loss': 0.0709, 'grad_norm': 0.9744724631309509, 'learning_rate': 1.0056618209640399e-05, 'epoch': 5.16}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1106201782822609, 'eval_accuracy': 0.9637394028750461, 'eval_f1': 0.32445732912591935, 'eval_precision': 0.7199697956590252, 'eval_recall': 0.285423197492163, 'eval_runtime': 3.0601, 'eval_samples_per_second': 1773.165, 'eval_steps_per_second': 111.109, 'epoch': 5.16}
{'loss': 0.0678, 'grad_norm': 0.8326922655105591, 'learning_rate': 9.903596021423107e-06, 'epoch': 5.23}
{'loss': 0.069, 'grad_norm': 0.8306703567504883, 'learning_rate': 9.750573833205816e-06, 'epoch': 5.31}
{'loss': 0.0712, 'grad_norm': 0.7732207179069519, 'learning_rate': 9.597551644988524e-06, 'epoch': 5.38}
{'loss': 0.0696, 'grad_norm': 1.0576950311660767, 'learning_rate': 9.444529456771233e-06, 'epoch': 5.45}
{'loss': 0.0703, 'grad_norm': 0.6791108846664429, 'learning_rate': 9.29150726855394e-06, 'epoch': 5.53}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11293172836303711, 'eval_accuracy': 0.9633247327681533, 'eval_f1': 0.3259080257516103, 'eval_precision': 0.6605392249239089, 'eval_recall': 0.29169278996865206, 'eval_runtime': 3.023, 'eval_samples_per_second': 1794.886, 'eval_steps_per_second': 112.47, 'epoch': 5.53}
{'loss': 0.0719, 'grad_norm': 0.9457306861877441, 'learning_rate': 9.140015302218823e-06, 'epoch': 5.6}
{'loss': 0.0691, 'grad_norm': 0.7998746633529663, 'learning_rate': 8.98699311400153e-06, 'epoch': 5.67}
{'loss': 0.0681, 'grad_norm': 0.8884339928627014, 'learning_rate': 8.83397092578424e-06, 'epoch': 5.75}
{'loss': 0.0708, 'grad_norm': 0.9116162061691284, 'learning_rate': 8.680948737566947e-06, 'epoch': 5.82}
{'loss': 0.0667, 'grad_norm': 0.9349080920219421, 'learning_rate': 8.527926549349657e-06, 'epoch': 5.9}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11200952529907227, 'eval_accuracy': 0.9632523300510768, 'eval_f1': 0.3296824235123006, 'eval_precision': 0.6768366796868893, 'eval_recall': 0.29968652037617555, 'eval_runtime': 3.1276, 'eval_samples_per_second': 1734.897, 'eval_steps_per_second': 108.711, 'epoch': 5.9}
{'loss': 0.0716, 'grad_norm': 1.0005426406860352, 'learning_rate': 8.376434583014537e-06, 'epoch': 5.97}
{'loss': 0.0627, 'grad_norm': 0.9553356766700745, 'learning_rate': 8.223412394797247e-06, 'epoch': 6.04}
{'loss': 0.0621, 'grad_norm': 0.8453904390335083, 'learning_rate': 8.070390206579954e-06, 'epoch': 6.12}
{'loss': 0.0621, 'grad_norm': 1.1616870164871216, 'learning_rate': 7.917368018362664e-06, 'epoch': 6.19}
{'loss': 0.0635, 'grad_norm': 0.8313819766044617, 'learning_rate': 7.764345830145373e-06, 'epoch': 6.26}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11636032164096832, 'eval_accuracy': 0.9636735822231584, 'eval_f1': 0.34634791856164254, 'eval_precision': 0.6691899065839136, 'eval_recall': 0.30595611285266455, 'eval_runtime': 3.0579, 'eval_samples_per_second': 1774.431, 'eval_steps_per_second': 111.188, 'epoch': 6.26}
{'loss': 0.0637, 'grad_norm': 0.8224002718925476, 'learning_rate': 7.61132364192808e-06, 'epoch': 6.34}
{'loss': 0.0652, 'grad_norm': 1.4366868734359741, 'learning_rate': 7.458301453710789e-06, 'epoch': 6.41}
{'loss': 0.0635, 'grad_norm': 1.3260701894760132, 'learning_rate': 7.3052792654934965e-06, 'epoch': 6.48}
{'loss': 0.0629, 'grad_norm': 0.8972442746162415, 'learning_rate': 7.152257077276206e-06, 'epoch': 6.56}
{'loss': 0.0627, 'grad_norm': 0.9457471370697021, 'learning_rate': 6.999234889058914e-06, 'epoch': 6.63}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11768282204866409, 'eval_accuracy': 0.963917118635143, 'eval_f1': 0.35009827047246556, 'eval_precision': 0.6754975441322976, 'eval_recall': 0.30689655172413793, 'eval_runtime': 3.0966, 'eval_samples_per_second': 1752.231, 'eval_steps_per_second': 109.797, 'epoch': 6.63}
{'loss': 0.0635, 'grad_norm': 0.9568567276000977, 'learning_rate': 6.846212700841623e-06, 'epoch': 6.71}
{'loss': 0.0652, 'grad_norm': 1.254683017730713, 'learning_rate': 6.693190512624331e-06, 'epoch': 6.78}
{'loss': 0.0634, 'grad_norm': 1.0171293020248413, 'learning_rate': 6.540168324407039e-06, 'epoch': 6.85}
{'loss': 0.0664, 'grad_norm': 0.6855462193489075, 'learning_rate': 6.387146136189748e-06, 'epoch': 6.93}
{'loss': 0.0599, 'grad_norm': 1.1485294103622437, 'learning_rate': 6.234123947972457e-06, 'epoch': 7.0}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.11977747082710266, 'eval_accuracy': 0.9635748512453267, 'eval_f1': 0.3492267822393212, 'eval_precision': 0.6692085544106906, 'eval_recall': 0.30626959247648905, 'eval_runtime': 3.0858, 'eval_samples_per_second': 1758.357, 'eval_steps_per_second': 110.181, 'epoch': 7.0}
{'loss': 0.0582, 'grad_norm': 0.8502678275108337, 'learning_rate': 6.081101759755165e-06, 'epoch': 7.07}
{'loss': 0.0585, 'grad_norm': 3.2378292083740234, 'learning_rate': 5.928079571537873e-06, 'epoch': 7.15}
{'loss': 0.0583, 'grad_norm': 0.898062527179718, 'learning_rate': 5.775057383320582e-06, 'epoch': 7.22}
{'loss': 0.0581, 'grad_norm': 1.072053074836731, 'learning_rate': 5.62203519510329e-06, 'epoch': 7.3}
{'loss': 0.0599, 'grad_norm': 4.647213935852051, 'learning_rate': 5.4690130068859995e-06, 'epoch': 7.37}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1193893775343895, 'eval_accuracy': 0.9629232267916381, 'eval_f1': 0.3578844579199506, 'eval_precision': 0.6520839224277236, 'eval_recall': 0.31880877742946706, 'eval_runtime': 3.1643, 'eval_samples_per_second': 1714.734, 'eval_steps_per_second': 107.447, 'epoch': 7.37}
{'loss': 0.0577, 'grad_norm': 0.8279275894165039, 'learning_rate': 5.315990818668707e-06, 'epoch': 7.44}
{'loss': 0.0568, 'grad_norm': 0.8180505037307739, 'learning_rate': 5.162968630451416e-06, 'epoch': 7.52}
{'loss': 0.0593, 'grad_norm': 1.1571886539459229, 'learning_rate': 5.009946442234125e-06, 'epoch': 7.59}
{'loss': 0.0575, 'grad_norm': 1.0878573656082153, 'learning_rate': 4.856924254016832e-06, 'epoch': 7.66}
{'loss': 0.0587, 'grad_norm': 0.843271791934967, 'learning_rate': 4.703902065799542e-06, 'epoch': 7.74}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.12098520249128342, 'eval_accuracy': 0.9631535990732453, 'eval_f1': 0.3626325683616767, 'eval_precision': 0.6573126757543235, 'eval_recall': 0.32147335423197493, 'eval_runtime': 3.2187, 'eval_samples_per_second': 1685.8, 'eval_steps_per_second': 105.634, 'epoch': 7.74}
{'loss': 0.0562, 'grad_norm': 1.0744048357009888, 'learning_rate': 4.55087987758225e-06, 'epoch': 7.81}
{'loss': 0.057, 'grad_norm': 1.2183935642242432, 'learning_rate': 4.3978576893649585e-06, 'epoch': 7.89}
{'loss': 0.0578, 'grad_norm': 1.472259283065796, 'learning_rate': 4.244835501147667e-06, 'epoch': 7.96}
{'loss': 0.0569, 'grad_norm': 0.8545963168144226, 'learning_rate': 4.091813312930375e-06, 'epoch': 8.03}
{'loss': 0.0525, 'grad_norm': 0.6764536499977112, 'learning_rate': 3.938791124713084e-06, 'epoch': 8.11}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1238512396812439, 'eval_accuracy': 0.9632589121162656, 'eval_f1': 0.3673380775657221, 'eval_precision': 0.6462127204600816, 'eval_recall': 0.3221003134796238, 'eval_runtime': 3.0772, 'eval_samples_per_second': 1763.315, 'eval_steps_per_second': 110.492, 'epoch': 8.11}
{'loss': 0.0555, 'grad_norm': 1.091623306274414, 'learning_rate': 3.7857689364957923e-06, 'epoch': 8.18}
{'loss': 0.0528, 'grad_norm': 0.9153516888618469, 'learning_rate': 3.6327467482785008e-06, 'epoch': 8.25}
{'loss': 0.0556, 'grad_norm': 1.3724726438522339, 'learning_rate': 3.4797245600612088e-06, 'epoch': 8.33}
{'loss': 0.0535, 'grad_norm': 1.564528465270996, 'learning_rate': 3.326702371843918e-06, 'epoch': 8.4}
{'loss': 0.0528, 'grad_norm': 1.2574080228805542, 'learning_rate': 3.173680183626626e-06, 'epoch': 8.47}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.12537036836147308, 'eval_accuracy': 0.9637854773313674, 'eval_f1': 0.3737334651470462, 'eval_precision': 0.6528838405882988, 'eval_recall': 0.3236677115987461, 'eval_runtime': 3.1325, 'eval_samples_per_second': 1732.166, 'eval_steps_per_second': 108.54, 'epoch': 8.47}
{'loss': 0.0536, 'grad_norm': 1.4430060386657715, 'learning_rate': 3.0206579954093345e-06, 'epoch': 8.55}
{'loss': 0.0546, 'grad_norm': 1.3935166597366333, 'learning_rate': 2.867635807192043e-06, 'epoch': 8.62}
{'loss': 0.0547, 'grad_norm': 0.840077817440033, 'learning_rate': 2.7146136189747514e-06, 'epoch': 8.7}
{'loss': 0.0562, 'grad_norm': 1.0477360486984253, 'learning_rate': 2.5615914307574603e-06, 'epoch': 8.77}
{'loss': 0.054, 'grad_norm': 1.077167272567749, 'learning_rate': 2.4085692425401687e-06, 'epoch': 8.84}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.12626788020133972, 'eval_accuracy': 0.9638315517876889, 'eval_f1': 0.37442596660949495, 'eval_precision': 0.6509328972240173, 'eval_recall': 0.32241379310344825, 'eval_runtime': 3.363, 'eval_samples_per_second': 1613.453, 'eval_steps_per_second': 101.101, 'epoch': 8.84}
{'loss': 0.0557, 'grad_norm': 0.918236494064331, 'learning_rate': 2.2555470543228767e-06, 'epoch': 8.92}
{'loss': 0.0556, 'grad_norm': 0.7294603586196899, 'learning_rate': 2.1025248661055856e-06, 'epoch': 8.99}
{'loss': 0.0514, 'grad_norm': 0.9611464738845825, 'learning_rate': 1.949502677888294e-06, 'epoch': 9.06}
{'loss': 0.0505, 'grad_norm': 0.9171906113624573, 'learning_rate': 1.7964804896710025e-06, 'epoch': 9.14}
{'loss': 0.0523, 'grad_norm': 0.6333191394805908, 'learning_rate': 1.643458301453711e-06, 'epoch': 9.21}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.12523885071277618, 'eval_accuracy': 0.963245747985888, 'eval_f1': 0.37286978639494384, 'eval_precision': 0.6453297486434203, 'eval_recall': 0.3274294670846395, 'eval_runtime': 2.9989, 'eval_samples_per_second': 1809.33, 'eval_steps_per_second': 113.375, 'epoch': 9.21}
{'loss': 0.0529, 'grad_norm': 1.330005407333374, 'learning_rate': 1.4904361132364196e-06, 'epoch': 9.29}
{'loss': 0.051, 'grad_norm': 1.129470705986023, 'learning_rate': 1.3374139250191278e-06, 'epoch': 9.36}
{'loss': 0.0525, 'grad_norm': 1.4215322732925415, 'learning_rate': 1.1843917368018365e-06, 'epoch': 9.43}
{'loss': 0.053, 'grad_norm': 1.2234958410263062, 'learning_rate': 1.031369548584545e-06, 'epoch': 9.51}
{'loss': 0.051, 'grad_norm': 1.0394752025604248, 'learning_rate': 8.783473603672533e-07, 'epoch': 9.58}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.12615908682346344, 'eval_accuracy': 0.9632589121162656, 'eval_f1': 0.37310288232502486, 'eval_precision': 0.640266833217612, 'eval_recall': 0.325705329153605, 'eval_runtime': 3.2342, 'eval_samples_per_second': 1677.71, 'eval_steps_per_second': 105.127, 'epoch': 9.58}
{'loss': 0.0516, 'grad_norm': 1.1355153322219849, 'learning_rate': 7.253251721499618e-07, 'epoch': 9.65}
{'loss': 0.0508, 'grad_norm': 2.1274962425231934, 'learning_rate': 5.723029839326702e-07, 'epoch': 9.73}
{'loss': 0.0527, 'grad_norm': 1.015350341796875, 'learning_rate': 4.192807957153788e-07, 'epoch': 9.8}
{'loss': 0.0528, 'grad_norm': 1.2331136465072632, 'learning_rate': 2.662586074980873e-07, 'epoch': 9.87}
{'loss': 0.0512, 'grad_norm': 1.1303362846374512, 'learning_rate': 1.1323641928079573e-07, 'epoch': 9.95}


  0%|          | 0/340 [00:00<?, ?it/s]

{'eval_loss': 0.1265140324831009, 'eval_accuracy': 0.9634695382023064, 'eval_f1': 0.3760206364743733, 'eval_precision': 0.6413025386459763, 'eval_recall': 0.32601880877742945, 'eval_runtime': 3.0425, 'eval_samples_per_second': 1783.38, 'eval_steps_per_second': 111.749, 'epoch': 9.95}
{'train_runtime': 1015.1144, 'train_samples_per_second': 427.637, 'train_steps_per_second': 13.368, 'train_loss': 0.09082458361862508, 'epoch': 10.0}


TrainOutput(global_step=13570, training_loss=0.09082458361862508, metrics={'train_runtime': 1015.1144, 'train_samples_per_second': 427.637, 'train_steps_per_second': 13.368, 'total_flos': 1.43909582570496e+16, 'train_loss': 0.09082458361862508, 'epoch': 10.0})

### Evaluate Model

In [16]:
# Evaluate on test set
test_results = trainer.evaluate(custom_test_dataset)
print(f"Test results: {test_results}")

  0%|          | 0/340 [00:00<?, ?it/s]

Test results: {'eval_loss': 0.1243186742067337, 'eval_accuracy': 0.9636868567216826, 'eval_f1': 0.37498611678788357, 'eval_precision': 0.6347143582701688, 'eval_recall': 0.3220097961763312, 'eval_runtime': 3.3762, 'eval_samples_per_second': 1607.427, 'eval_steps_per_second': 100.705, 'epoch': 10.0}
