### Importing require libraries

In [30]:
# Import necessary libraries
import torch
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

In [31]:
# Folder containing Excel files
folder_path = r'data\labeldata.parquet'

data = pd.read_parquet(folder_path)
data.head()

Unnamed: 0,id,subject,emailtext,label
0,738a2c78-e9b7-4886-8a1b-7af4c08f3906,RE: Yates Residence - Mechanical Contractor Me...,"Zac Stevenson, PEPRINCIPAL | BUILDING MEPC . ....",0
1,23e4febf-364f-464a-9a3a-80f9d89d0e6b,Re: 6771 - VeLa - Level 4 Slab Penetration Plans,You tooRespectfullyMichael BentleyGeneral Fore...,0
2,f0789ea4-0a5f-46c3-8adb-083766b63740,RE: FW: 13027A - 20240603 - Prl - Columbia + A...,You can update the two 12 ducts and adjust on ...,0
3,09474b0b-bc7f-4a1f-ab95-0ed2c497468a,RE: West Zephyrhills Elementary - Pasco County...,"You can start with Area A. Proceeded by B,C,D ...",0
4,b7438359-ce6c-42f7-bcfd-5690e875dcda,RE: MLW_0001_CRS_24410_ Connacht Stadium - Exi...,You can hold off on doing this for now.,0


In [32]:
df = data[['emailtext', 'label']].copy()
df = df.rename({'emailtext': 'text'}, axis=1)
df.head()

Unnamed: 0,text,label
0,"Zac Stevenson, PEPRINCIPAL | BUILDING MEPC . ....",0
1,You tooRespectfullyMichael BentleyGeneral Fore...,0
2,You can update the two 12 ducts and adjust on ...,0
3,"You can start with Area A. Proceeded by B,C,D ...",0
4,You can hold off on doing this for now.,0


In [33]:
df.label.value_counts()

label
0    5798
1      55
Name: count, dtype: int64

In [34]:
# 2. Split data into training and validation sets
train_df, val_df = train_test_split(
    df,
    test_size=0.1,
    stratify=df['label'],  # Ensures the split is stratified
    random_state=42
)

In [35]:
# Reset indices
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

In [36]:
# 3. Calculate class weights
# Assign class weights as variables in the code
labels = train_df['label'].values  # Assuming labels are 0 and 1

In [37]:
# Compute class weights using sklearn's compute_class_weight function
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(labels),
    y=labels
)

# # Convert class weights to a tensor
# class_weights = torch.tensor(class_weights, dtype=torch.float)
class_weights = class_weights.tolist()

# Print class weights for verification
print(f"Class Weights: {class_weights}")

Class Weights: [0.5046952855500192, 53.744897959183675]


In [38]:
# 4. Define the custom dataset
class EmailDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = self.labels.iloc[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),  # Flatten to remove extra dimension
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [39]:
# 5. Initialize tokenizer and model configuration
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_length = 128
num_labels = len(np.unique(labels))  # Should be 2 for labels 0 and 1

In [40]:
# Load the configuration and set class weights as a variable in the config
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=num_labels)
config.class_weights = class_weights  # Assign class_weights to the config

In [41]:
from transformers.modeling_outputs import SequenceClassifierOutput

class WeightedBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.class_weights = torch.tensor(config.class_weights, dtype=torch.float)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Get the outputs from the base model without computing the default loss
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,  # Avoid default loss computation
            **kwargs
        )
        logits = outputs.logits

        # Compute custom loss with class weights
        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device))
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Return outputs as a SequenceClassifierOutput
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [42]:
# 7. Instantiate the model using the custom class and config
model = WeightedBertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    config=config
)

Some weights of WeightedBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
# 8. Create training and validation datasets
train_dataset = EmailDataset(
    texts=train_df['text'],
    labels=train_df['label'],
    tokenizer=tokenizer,
    max_length=max_length
)

val_dataset = EmailDataset(
    texts=val_df['text'],
    labels=val_df['label'],
    tokenizer=tokenizer,
    max_length=max_length
)

In [44]:

# 9. Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',           # Output directory
    num_train_epochs=2,               # Adjust number of epochs as needed
    per_device_train_batch_size=5,    # Batch size for training
    per_device_eval_batch_size=5,     # Batch size for evaluation
    evaluation_strategy='epoch',      # Evaluate every epoch
    save_strategy='epoch',            # Save model every epoch
    logging_dir='./logs',             # Directory for logs
    logging_steps=10,
    load_best_model_at_end=True,      # Load the best model at the end of training
    metric_for_best_model='f1',       # Use F1 score to select best model
    greater_is_better=True,           # Higher F1 score is better
)



In [45]:
# 10. Define the compute_metrics function
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=1
    )
    acc = accuracy_score(labels, preds)

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [46]:

# 11. Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

In [47]:
# 12. Train the model
trainer.train()

 20%|██        | 659/3295 [2:45:23<11:01:35, 15.06s/it]
                                                 
  0%|          | 10/2108 [00:14<50:04,  1.43s/it]  

{'loss': 0.2604, 'grad_norm': 0.9774286150932312, 'learning_rate': 4.976280834914611e-05, 'epoch': 0.01}


                                                   
  1%|          | 20/2108 [00:32<1:06:17,  1.91s/it]

{'loss': 0.0123, 'grad_norm': 0.09739227592945099, 'learning_rate': 4.952561669829222e-05, 'epoch': 0.02}


                                                   
  1%|▏         | 30/2108 [00:52<1:09:43,  2.01s/it]

{'loss': 0.6563, 'grad_norm': 0.07446282356977463, 'learning_rate': 4.928842504743833e-05, 'epoch': 0.03}


                                                   
  2%|▏         | 40/2108 [01:12<1:07:14,  1.95s/it]

{'loss': 0.0018, 'grad_norm': 0.03175177425146103, 'learning_rate': 4.9051233396584444e-05, 'epoch': 0.04}


                                                   
  2%|▏         | 50/2108 [01:28<52:45,  1.54s/it]  

{'loss': 0.0009, 'grad_norm': 0.02368052490055561, 'learning_rate': 4.8814041745730554e-05, 'epoch': 0.05}


                                                 
  3%|▎         | 60/2108 [01:43<50:46,  1.49s/it]  

{'loss': 0.7395, 'grad_norm': 0.032671008259058, 'learning_rate': 4.8576850094876664e-05, 'epoch': 0.06}


                                                 
  3%|▎         | 70/2108 [01:58<49:57,  1.47s/it]  

{'loss': 0.0011, 'grad_norm': 0.026606006547808647, 'learning_rate': 4.8339658444022774e-05, 'epoch': 0.07}


                                                 
  4%|▍         | 80/2108 [02:12<49:19,  1.46s/it]  

{'loss': 1.816, 'grad_norm': 29.88258934020996, 'learning_rate': 4.8102466793168885e-05, 'epoch': 0.08}


                                                 
  4%|▍         | 90/2108 [02:27<49:16,  1.47s/it]  

{'loss': 0.0081, 'grad_norm': 0.13835358619689941, 'learning_rate': 4.786527514231499e-05, 'epoch': 0.09}


                                                  
  5%|▍         | 100/2108 [02:42<50:43,  1.52s/it] 

{'loss': 0.5175, 'grad_norm': 0.09101822227239609, 'learning_rate': 4.76280834914611e-05, 'epoch': 0.09}


                                                  
  5%|▌         | 110/2108 [02:57<48:37,  1.46s/it] 

{'loss': 0.0019, 'grad_norm': 0.04516149312257767, 'learning_rate': 4.7390891840607216e-05, 'epoch': 0.1}


                                                  
  6%|▌         | 120/2108 [03:11<48:08,  1.45s/it] 

{'loss': 0.6615, 'grad_norm': 0.04062328860163689, 'learning_rate': 4.7153700189753326e-05, 'epoch': 0.11}


                                                  
  6%|▌         | 130/2108 [03:26<48:05,  1.46s/it] 

{'loss': 0.0018, 'grad_norm': 0.04701231047511101, 'learning_rate': 4.6916508538899436e-05, 'epoch': 0.12}


                                                  
  7%|▋         | 140/2108 [03:41<48:17,  1.47s/it] 

{'loss': 0.6921, 'grad_norm': 0.0632542222738266, 'learning_rate': 4.6679316888045547e-05, 'epoch': 0.13}


                                                    
  7%|▋         | 150/2108 [04:00<1:03:20,  1.94s/it]

{'loss': 0.002, 'grad_norm': 0.04176236689090729, 'learning_rate': 4.644212523719165e-05, 'epoch': 0.14}


                                                    
  8%|▊         | 160/2108 [04:19<1:00:31,  1.86s/it]

{'loss': 1.2375, 'grad_norm': 0.11850539594888687, 'learning_rate': 4.620493358633776e-05, 'epoch': 0.15}


                                                    
  8%|▊         | 170/2108 [04:37<57:59,  1.80s/it] 

{'loss': 0.0036, 'grad_norm': 0.08815610408782959, 'learning_rate': 4.596774193548387e-05, 'epoch': 0.16}


                                                  
  9%|▊         | 180/2108 [04:55<58:12,  1.81s/it] 

{'loss': 0.0019, 'grad_norm': 0.03522692620754242, 'learning_rate': 4.573055028462998e-05, 'epoch': 0.17}


                                                  
  9%|▉         | 190/2108 [05:13<57:36,  1.80s/it] 

{'loss': 0.0009, 'grad_norm': 0.020895402878522873, 'learning_rate': 4.54933586337761e-05, 'epoch': 0.18}


                                                    
  9%|▉         | 200/2108 [05:33<1:06:57,  2.11s/it]

{'loss': 0.0006, 'grad_norm': 0.015118488110601902, 'learning_rate': 4.52561669829222e-05, 'epoch': 0.19}


                                                    
 10%|▉         | 210/2108 [05:53<1:03:52,  2.02s/it]

{'loss': 1.4978, 'grad_norm': 28.019216537475586, 'learning_rate': 4.501897533206831e-05, 'epoch': 0.2}


                                                    
 10%|█         | 220/2108 [06:15<1:05:15,  2.07s/it]

{'loss': 0.629, 'grad_norm': 0.21084845066070557, 'learning_rate': 4.478178368121442e-05, 'epoch': 0.21}


                                                    
 11%|█         | 230/2108 [06:34<1:01:36,  1.97s/it]

{'loss': 0.0038, 'grad_norm': 0.06313268095254898, 'learning_rate': 4.454459203036053e-05, 'epoch': 0.22}


                                                    
 11%|█▏        | 240/2108 [06:55<1:03:14,  2.03s/it]

{'loss': 0.0018, 'grad_norm': 0.04615599289536476, 'learning_rate': 4.430740037950664e-05, 'epoch': 0.23}


                                                    
 12%|█▏        | 250/2108 [07:15<1:04:19,  2.08s/it]

{'loss': 0.6403, 'grad_norm': 0.04691358655691147, 'learning_rate': 4.407020872865275e-05, 'epoch': 0.24}


                                                    
 12%|█▏        | 260/2108 [07:35<58:36,  1.90s/it] 

{'loss': 0.5913, 'grad_norm': 0.06017689406871796, 'learning_rate': 4.383301707779886e-05, 'epoch': 0.25}


                                                    
 13%|█▎        | 270/2108 [07:55<1:03:25,  2.07s/it]

{'loss': 0.0021, 'grad_norm': 0.05681421235203743, 'learning_rate': 4.3595825426944974e-05, 'epoch': 0.26}


                                                    
 13%|█▎        | 280/2108 [08:09<42:59,  1.41s/it] 

{'loss': 0.0017, 'grad_norm': 0.03668248653411865, 'learning_rate': 4.3358633776091084e-05, 'epoch': 0.27}


                                                  
 14%|█▍        | 290/2108 [08:27<58:20,  1.93s/it] 

{'loss': 0.6822, 'grad_norm': 0.026802456006407738, 'learning_rate': 4.3121442125237194e-05, 'epoch': 0.28}


                                                    
 14%|█▍        | 300/2108 [08:43<44:25,  1.47s/it] 

{'loss': 0.6373, 'grad_norm': 0.053860314190387726, 'learning_rate': 4.2884250474383305e-05, 'epoch': 0.28}


                                                  
 15%|█▍        | 310/2108 [08:57<42:10,  1.41s/it] 

{'loss': 0.6136, 'grad_norm': 0.10745895653963089, 'learning_rate': 4.2647058823529415e-05, 'epoch': 0.29}


                                                  
 15%|█▌        | 320/2108 [09:11<41:59,  1.41s/it] 

{'loss': 0.6056, 'grad_norm': 0.0830380767583847, 'learning_rate': 4.240986717267552e-05, 'epoch': 0.3}


                                                  
 16%|█▌        | 330/2108 [09:26<42:16,  1.43s/it] 

{'loss': 0.0029, 'grad_norm': 0.07536138594150543, 'learning_rate': 4.2172675521821635e-05, 'epoch': 0.31}


                                                  
 16%|█▌        | 340/2108 [09:40<42:21,  1.44s/it] 

{'loss': 0.0021, 'grad_norm': 0.0490848533809185, 'learning_rate': 4.1935483870967746e-05, 'epoch': 0.32}


                                                  
 17%|█▋        | 350/2108 [09:54<42:48,  1.46s/it] 

{'loss': 0.0013, 'grad_norm': 0.03763434663414955, 'learning_rate': 4.1698292220113856e-05, 'epoch': 0.33}


                                                  
 17%|█▋        | 360/2108 [10:09<41:52,  1.44s/it] 

{'loss': 0.704, 'grad_norm': 0.024762464687228203, 'learning_rate': 4.1461100569259966e-05, 'epoch': 0.34}


                                                  
 18%|█▊        | 370/2108 [10:23<41:55,  1.45s/it] 

{'loss': 0.0011, 'grad_norm': 0.03857451677322388, 'learning_rate': 4.1223908918406077e-05, 'epoch': 0.35}


                                                  
 18%|█▊        | 380/2108 [10:38<41:24,  1.44s/it] 

{'loss': 0.6962, 'grad_norm': 30.099998474121094, 'learning_rate': 4.098671726755218e-05, 'epoch': 0.36}


                                                  
 19%|█▊        | 390/2108 [10:52<41:07,  1.44s/it] 

{'loss': 0.0013, 'grad_norm': 0.04352223500609398, 'learning_rate': 4.074952561669829e-05, 'epoch': 0.37}


                                                  
 19%|█▉        | 400/2108 [11:07<40:33,  1.43s/it] 

{'loss': 0.0012, 'grad_norm': 0.0395660474896431, 'learning_rate': 4.051233396584441e-05, 'epoch': 0.38}


                                                  
 19%|█▉        | 410/2108 [11:21<41:14,  1.46s/it] 

{'loss': 0.6751, 'grad_norm': 0.03928961977362633, 'learning_rate': 4.027514231499052e-05, 'epoch': 0.39}


                                                  
 20%|█▉        | 420/2108 [11:36<40:37,  1.44s/it] 

{'loss': 1.2239, 'grad_norm': 0.08843176811933517, 'learning_rate': 4.003795066413663e-05, 'epoch': 0.4}


                                                  
 20%|██        | 430/2108 [11:51<42:28,  1.52s/it] 

{'loss': 0.0037, 'grad_norm': 0.08480316400527954, 'learning_rate': 3.980075901328273e-05, 'epoch': 0.41}


                                                  
 21%|██        | 440/2108 [12:06<40:10,  1.45s/it] 

{'loss': 0.5572, 'grad_norm': 0.09386710077524185, 'learning_rate': 3.956356736242884e-05, 'epoch': 0.42}


                                                  
 21%|██▏       | 450/2108 [12:20<39:53,  1.44s/it] 

{'loss': 0.5857, 'grad_norm': 0.061521269381046295, 'learning_rate': 3.932637571157495e-05, 'epoch': 0.43}


                                                  
 22%|██▏       | 460/2108 [12:34<38:54,  1.42s/it] 

{'loss': 1.1141, 'grad_norm': 0.16526469588279724, 'learning_rate': 3.908918406072106e-05, 'epoch': 0.44}


                                                  
 22%|██▏       | 470/2108 [12:49<41:51,  1.53s/it] 

{'loss': 0.0049, 'grad_norm': 0.1027889996767044, 'learning_rate': 3.885199240986718e-05, 'epoch': 0.45}


                                                  
 23%|██▎       | 480/2108 [13:04<39:23,  1.45s/it] 

{'loss': 0.5857, 'grad_norm': 0.07380419224500656, 'learning_rate': 3.861480075901329e-05, 'epoch': 0.46}


                                                  
 23%|██▎       | 490/2108 [13:21<49:57,  1.85s/it] 

{'loss': 0.0021, 'grad_norm': 0.06305351108312607, 'learning_rate': 3.837760910815939e-05, 'epoch': 0.46}


                                                  
 24%|██▎       | 500/2108 [13:40<50:12,  1.87s/it] 

{'loss': 0.6887, 'grad_norm': 0.057813066989183426, 'learning_rate': 3.8140417457305504e-05, 'epoch': 0.47}


                                                  
 24%|██▍       | 510/2108 [13:59<49:51,  1.87s/it] 

{'loss': 0.0016, 'grad_norm': 0.03889109566807747, 'learning_rate': 3.7903225806451614e-05, 'epoch': 0.48}


                                                  
 25%|██▍       | 520/2108 [14:14<36:46,  1.39s/it] 

{'loss': 0.0015, 'grad_norm': 0.03925611078739166, 'learning_rate': 3.7666034155597724e-05, 'epoch': 0.49}


                                                  
 25%|██▌       | 530/2108 [14:28<37:36,  1.43s/it] 

{'loss': 0.0011, 'grad_norm': 0.031482744961977005, 'learning_rate': 3.7428842504743835e-05, 'epoch': 0.5}


                                                  
 26%|██▌       | 540/2108 [14:42<37:05,  1.42s/it] 

{'loss': 0.0009, 'grad_norm': 0.023498691618442535, 'learning_rate': 3.7191650853889945e-05, 'epoch': 0.51}


                                                  
 26%|██▌       | 550/2108 [14:56<36:52,  1.42s/it] 

{'loss': 0.0007, 'grad_norm': 0.02124369889497757, 'learning_rate': 3.6954459203036055e-05, 'epoch': 0.52}


                                                  
 27%|██▋       | 560/2108 [15:11<39:27,  1.53s/it] 

{'loss': 0.6906, 'grad_norm': 0.030022678896784782, 'learning_rate': 3.6717267552182165e-05, 'epoch': 0.53}


                                                  
 27%|██▋       | 570/2108 [15:26<38:26,  1.50s/it] 

{'loss': 0.0009, 'grad_norm': 0.03156564384698868, 'learning_rate': 3.6480075901328276e-05, 'epoch': 0.54}


                                                  
 28%|██▊       | 580/2108 [15:45<48:44,  1.91s/it] 

{'loss': 0.0008, 'grad_norm': 0.025081513449549675, 'learning_rate': 3.6242884250474386e-05, 'epoch': 0.55}


                                                  
 28%|██▊       | 590/2108 [16:04<48:00,  1.90s/it] 

{'loss': 0.0007, 'grad_norm': 0.022426482290029526, 'learning_rate': 3.6005692599620496e-05, 'epoch': 0.56}


                                                  
 28%|██▊       | 600/2108 [16:24<49:13,  1.96s/it] 

{'loss': 0.0006, 'grad_norm': 0.019980160519480705, 'learning_rate': 3.576850094876661e-05, 'epoch': 0.57}


                                                  
 29%|██▉       | 610/2108 [16:43<49:16,  1.97s/it] 

{'loss': 0.7238, 'grad_norm': 0.02171565592288971, 'learning_rate': 3.553130929791271e-05, 'epoch': 0.58}


                                                  
 29%|██▉       | 620/2108 [17:03<50:58,  2.06s/it] 

{'loss': 0.0008, 'grad_norm': 0.022145064547657967, 'learning_rate': 3.529411764705883e-05, 'epoch': 0.59}


                                                  
 30%|██▉       | 630/2108 [17:22<47:21,  1.92s/it]   

{'loss': 0.0008, 'grad_norm': 0.022398443892598152, 'learning_rate': 3.505692599620494e-05, 'epoch': 0.6}


                                                  
 30%|███       | 640/2108 [17:42<48:53,  2.00s/it]   

{'loss': 0.0007, 'grad_norm': 0.019273534417152405, 'learning_rate': 3.481973434535105e-05, 'epoch': 0.61}


                                                  
 31%|███       | 650/2108 [18:00<44:43,  1.84s/it]   

{'loss': 1.4089, 'grad_norm': 0.03673732653260231, 'learning_rate': 3.458254269449716e-05, 'epoch': 0.62}


                                                  
 31%|███▏      | 660/2108 [18:19<44:18,  1.84s/it]   

{'loss': 0.6418, 'grad_norm': 0.09913583099842072, 'learning_rate': 3.434535104364326e-05, 'epoch': 0.63}


                                                  
 32%|███▏      | 670/2108 [18:38<46:58,  1.96s/it]   

{'loss': 0.0024, 'grad_norm': 0.06563019007444382, 'learning_rate': 3.410815939278937e-05, 'epoch': 0.64}


                                                  
 32%|███▏      | 680/2108 [18:58<46:07,  1.94s/it]   

{'loss': 0.5788, 'grad_norm': 0.06766371428966522, 'learning_rate': 3.387096774193548e-05, 'epoch': 0.65}


                                                  
 33%|███▎      | 690/2108 [19:17<45:43,  1.94s/it]   

{'loss': 0.0018, 'grad_norm': 0.057059094309806824, 'learning_rate': 3.36337760910816e-05, 'epoch': 0.65}


                                                  
 33%|███▎      | 700/2108 [19:37<45:55,  1.96s/it]   

{'loss': 0.0015, 'grad_norm': 0.040741726756095886, 'learning_rate': 3.339658444022771e-05, 'epoch': 0.66}


                                                  
 34%|███▎      | 710/2108 [19:55<37:41,  1.62s/it]   

{'loss': 0.6466, 'grad_norm': 0.03961779549717903, 'learning_rate': 3.315939278937382e-05, 'epoch': 0.67}


                                                  
 34%|███▍      | 720/2108 [20:09<32:22,  1.40s/it]   

{'loss': 0.0015, 'grad_norm': 0.041820771992206573, 'learning_rate': 3.2922201138519923e-05, 'epoch': 0.68}


                                                  
 35%|███▍      | 730/2108 [20:24<34:46,  1.51s/it]   

{'loss': 0.6061, 'grad_norm': 31.1428279876709, 'learning_rate': 3.2685009487666034e-05, 'epoch': 0.69}


                                                  
 35%|███▌      | 740/2108 [20:38<32:52,  1.44s/it]   

{'loss': 0.0015, 'grad_norm': 0.044477980583906174, 'learning_rate': 3.2447817836812144e-05, 'epoch': 0.7}


                                                  
 36%|███▌      | 750/2108 [20:53<33:10,  1.47s/it]   

{'loss': 0.0014, 'grad_norm': 0.034850217401981354, 'learning_rate': 3.2210626185958254e-05, 'epoch': 0.71}


                                                  
 36%|███▌      | 760/2108 [21:08<32:14,  1.43s/it]   

{'loss': 0.6725, 'grad_norm': 0.030473940074443817, 'learning_rate': 3.197343453510437e-05, 'epoch': 0.72}


                                                  
 37%|███▋      | 770/2108 [21:22<32:35,  1.46s/it]   

{'loss': 0.6456, 'grad_norm': 0.07683933526277542, 'learning_rate': 3.1736242884250475e-05, 'epoch': 0.73}


                                                  
 37%|███▋      | 780/2108 [21:37<32:23,  1.46s/it]   

{'loss': 0.0021, 'grad_norm': 0.06445591896772385, 'learning_rate': 3.1499051233396585e-05, 'epoch': 0.74}


                                                  
 37%|███▋      | 790/2108 [21:51<31:44,  1.45s/it]   

{'loss': 0.0014, 'grad_norm': 0.04988177865743637, 'learning_rate': 3.1261859582542695e-05, 'epoch': 0.75}


                                                  
 38%|███▊      | 800/2108 [22:06<31:38,  1.45s/it]   

{'loss': 0.6605, 'grad_norm': 0.03652055561542511, 'learning_rate': 3.1024667931688806e-05, 'epoch': 0.76}


                                                  
 38%|███▊      | 810/2108 [22:20<30:52,  1.43s/it]   

{'loss': 0.0013, 'grad_norm': 0.042776867747306824, 'learning_rate': 3.0787476280834916e-05, 'epoch': 0.77}


                                                  
 39%|███▉      | 820/2108 [22:35<31:32,  1.47s/it]   

{'loss': 0.0013, 'grad_norm': 0.03914497047662735, 'learning_rate': 3.0550284629981026e-05, 'epoch': 0.78}


                                                  
 39%|███▉      | 830/2108 [22:53<39:03,  1.83s/it]   

{'loss': 0.0009, 'grad_norm': 0.025411253795027733, 'learning_rate': 3.0313092979127133e-05, 'epoch': 0.79}


                                                  
 40%|███▉      | 840/2108 [23:13<41:58,  1.99s/it]   

{'loss': 0.6733, 'grad_norm': 0.045530982315540314, 'learning_rate': 3.0075901328273247e-05, 'epoch': 0.8}


                                                  
 40%|████      | 850/2108 [23:32<40:15,  1.92s/it]   

{'loss': 0.0014, 'grad_norm': 0.03160780668258667, 'learning_rate': 2.9838709677419357e-05, 'epoch': 0.81}


                                                  
 41%|████      | 860/2108 [23:51<38:58,  1.87s/it]   

{'loss': 1.0463, 'grad_norm': 0.031486235558986664, 'learning_rate': 2.9601518026565468e-05, 'epoch': 0.82}


                                                  
 41%|████▏     | 870/2108 [24:10<39:20,  1.91s/it]   

{'loss': 0.0017, 'grad_norm': 0.039050549268722534, 'learning_rate': 2.9364326375711574e-05, 'epoch': 0.83}


                                                  
 42%|████▏     | 880/2108 [24:27<30:39,  1.50s/it]   

{'loss': 0.7954, 'grad_norm': 0.05613371729850769, 'learning_rate': 2.9127134724857685e-05, 'epoch': 0.83}


                                                  
 42%|████▏     | 890/2108 [24:42<31:54,  1.57s/it]   

{'loss': 0.1022, 'grad_norm': 0.03509997949004173, 'learning_rate': 2.8889943074003795e-05, 'epoch': 0.84}


                                                  
 43%|████▎     | 900/2108 [24:57<30:20,  1.51s/it]   

{'loss': 0.0011, 'grad_norm': 0.02541472762823105, 'learning_rate': 2.8652751423149905e-05, 'epoch': 0.85}


                                                  
 43%|████▎     | 910/2108 [25:13<30:35,  1.53s/it]   

{'loss': 0.6754, 'grad_norm': 0.04474721476435661, 'learning_rate': 2.841555977229602e-05, 'epoch': 0.86}


                                                  
 44%|████▎     | 920/2108 [25:27<28:54,  1.46s/it]   

{'loss': 0.0014, 'grad_norm': 0.04197266697883606, 'learning_rate': 2.817836812144213e-05, 'epoch': 0.87}


                                                  
 44%|████▍     | 930/2108 [25:42<28:46,  1.47s/it]   

{'loss': 0.0012, 'grad_norm': 0.02752251736819744, 'learning_rate': 2.7941176470588236e-05, 'epoch': 0.88}


                                                  
 45%|████▍     | 940/2108 [25:57<29:20,  1.51s/it]   

{'loss': 0.0009, 'grad_norm': 0.024056915193796158, 'learning_rate': 2.7703984819734347e-05, 'epoch': 0.89}


                                                  
 45%|████▌     | 950/2108 [26:12<28:51,  1.49s/it]   

{'loss': 0.0007, 'grad_norm': 0.01500392984598875, 'learning_rate': 2.7466793168880457e-05, 'epoch': 0.9}


                                                  
 46%|████▌     | 960/2108 [26:27<28:16,  1.48s/it]   

{'loss': 0.0006, 'grad_norm': 0.014480456709861755, 'learning_rate': 2.7229601518026564e-05, 'epoch': 0.91}


                                                  
 46%|████▌     | 970/2108 [26:42<27:21,  1.44s/it]   

{'loss': 0.0005, 'grad_norm': 0.0131670031696558, 'learning_rate': 2.6992409867172674e-05, 'epoch': 0.92}


                                                  
 46%|████▋     | 980/2108 [26:56<27:51,  1.48s/it]   

{'loss': 0.0004, 'grad_norm': 0.011841029860079288, 'learning_rate': 2.6755218216318788e-05, 'epoch': 0.93}


                                                  
 47%|████▋     | 990/2108 [27:11<27:21,  1.47s/it]   

{'loss': 0.0004, 'grad_norm': 0.010802059434354305, 'learning_rate': 2.6518026565464898e-05, 'epoch': 0.94}


                                                   
 47%|████▋     | 1000/2108 [27:26<27:22,  1.48s/it]  

{'loss': 0.0003, 'grad_norm': 0.008176753297448158, 'learning_rate': 2.628083491461101e-05, 'epoch': 0.95}


                                                   
 48%|████▊     | 1010/2108 [27:42<30:06,  1.64s/it]  

{'loss': 0.0003, 'grad_norm': 0.006685980129987001, 'learning_rate': 2.604364326375712e-05, 'epoch': 0.96}


                                                   
 48%|████▊     | 1020/2108 [27:57<27:26,  1.51s/it]  

{'loss': 0.7874, 'grad_norm': 27.785184860229492, 'learning_rate': 2.5806451612903226e-05, 'epoch': 0.97}


                                                   
 49%|████▉     | 1030/2108 [28:12<26:24,  1.47s/it]  

{'loss': 0.0007, 'grad_norm': 0.017700867727398872, 'learning_rate': 2.5569259962049336e-05, 'epoch': 0.98}


                                                   
 49%|████▉     | 1040/2108 [28:27<26:38,  1.50s/it]  

{'loss': 0.0008, 'grad_norm': 0.0235328059643507, 'learning_rate': 2.5332068311195446e-05, 'epoch': 0.99}


                                                   
 50%|████▉     | 1050/2108 [28:43<27:44,  1.57s/it]  

{'loss': 0.0007, 'grad_norm': 0.017562521621584892, 'learning_rate': 2.509487666034156e-05, 'epoch': 1.0}


 50%|█████     | 1054/2108 [28:48<22:51,  1.30s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   

[A[A                                           
 50%|█████     | 1054/2108 [29:45<22:51,  1.30s/it]  
[A
[A

{'eval_loss': 0.37636786699295044, 'eval_accuracy': 0.9897610921501706, 'eval_f1': 0.0, 'eval_precision': 1.0, 'eval_recall': 0.0, 'eval_runtime': 56.7956, 'eval_samples_per_second': 10.318, 'eval_steps_per_second': 2.078, 'epoch': 1.0}


                                                     
 50%|█████     | 1060/2108 [29:56<1:19:34,  4.56s/it]

{'loss': 0.0005, 'grad_norm': 0.016705341637134552, 'learning_rate': 2.4857685009487667e-05, 'epoch': 1.01}


                                                     
 51%|█████     | 1070/2108 [30:17<37:30,  2.17s/it]  

{'loss': 0.0005, 'grad_norm': 0.01482310052961111, 'learning_rate': 2.4620493358633777e-05, 'epoch': 1.02}


                                                   
 51%|█████     | 1080/2108 [30:39<36:15,  2.12s/it]  

{'loss': 1.5022, 'grad_norm': 0.0468461737036705, 'learning_rate': 2.4383301707779887e-05, 'epoch': 1.02}


                                                   
 52%|█████▏    | 1090/2108 [30:58<33:27,  1.97s/it]  

{'loss': 0.0016, 'grad_norm': 0.040849730372428894, 'learning_rate': 2.4146110056925998e-05, 'epoch': 1.03}


                                                   
 52%|█████▏    | 1100/2108 [31:18<32:54,  1.96s/it]  

{'loss': 0.0011, 'grad_norm': 0.033514149487018585, 'learning_rate': 2.3908918406072104e-05, 'epoch': 1.04}


                                                   
 53%|█████▎    | 1110/2108 [31:38<32:42,  1.97s/it]  

{'loss': 0.0009, 'grad_norm': 0.025533858686685562, 'learning_rate': 2.3671726755218218e-05, 'epoch': 1.05}


                                                   
 53%|█████▎    | 1120/2108 [31:57<32:51,  2.00s/it]  

{'loss': 0.0007, 'grad_norm': 0.020612865686416626, 'learning_rate': 2.343453510436433e-05, 'epoch': 1.06}


                                                   
 54%|█████▎    | 1130/2108 [32:18<33:33,  2.06s/it]  

{'loss': 0.0006, 'grad_norm': 0.01927480660378933, 'learning_rate': 2.3197343453510435e-05, 'epoch': 1.07}


                                                   
 54%|█████▍    | 1140/2108 [32:39<32:46,  2.03s/it]  

{'loss': 0.0005, 'grad_norm': 0.013014029711484909, 'learning_rate': 2.296015180265655e-05, 'epoch': 1.08}


                                                   
 55%|█████▍    | 1150/2108 [32:59<31:59,  2.00s/it]  

{'loss': 0.0005, 'grad_norm': 0.01815398409962654, 'learning_rate': 2.272296015180266e-05, 'epoch': 1.09}


                                                   
 55%|█████▌    | 1160/2108 [33:20<33:20,  2.11s/it]  

{'loss': 0.0004, 'grad_norm': 0.010918470099568367, 'learning_rate': 2.2485768500948766e-05, 'epoch': 1.1}


                                                   
 56%|█████▌    | 1170/2108 [33:41<33:34,  2.15s/it]  

{'loss': 0.0004, 'grad_norm': 0.009345654398202896, 'learning_rate': 2.2248576850094877e-05, 'epoch': 1.11}


                                                   
 56%|█████▌    | 1180/2108 [34:03<33:56,  2.19s/it]  

{'loss': 0.8089, 'grad_norm': 0.013539695180952549, 'learning_rate': 2.201138519924099e-05, 'epoch': 1.12}


                                                   
 56%|█████▋    | 1190/2108 [34:21<23:37,  1.54s/it]  

{'loss': 0.0006, 'grad_norm': 0.0137602174654603, 'learning_rate': 2.1774193548387097e-05, 'epoch': 1.13}


                                                   
 57%|█████▋    | 1200/2108 [34:38<28:51,  1.91s/it]  

{'loss': 0.0005, 'grad_norm': 0.014968723058700562, 'learning_rate': 2.1537001897533207e-05, 'epoch': 1.14}


                                                   
 57%|█████▋    | 1210/2108 [34:57<28:27,  1.90s/it]  

{'loss': 1.4364, 'grad_norm': 0.02525622770190239, 'learning_rate': 2.1299810246679318e-05, 'epoch': 1.15}


                                                   
 58%|█████▊    | 1220/2108 [35:16<27:28,  1.86s/it]  

{'loss': 0.0012, 'grad_norm': 0.04199010506272316, 'learning_rate': 2.1062618595825428e-05, 'epoch': 1.16}


                                                   
 58%|█████▊    | 1230/2108 [35:34<26:34,  1.82s/it]  

{'loss': 0.0013, 'grad_norm': 0.04085681214928627, 'learning_rate': 2.082542694497154e-05, 'epoch': 1.17}


                                                   
 59%|█████▉    | 1240/2108 [35:52<25:41,  1.78s/it]  

{'loss': 0.0011, 'grad_norm': 0.031560540199279785, 'learning_rate': 2.058823529411765e-05, 'epoch': 1.18}


                                                   
 59%|█████▉    | 1250/2108 [36:10<25:36,  1.79s/it]  

{'loss': 0.001, 'grad_norm': 0.025019193068146706, 'learning_rate': 2.035104364326376e-05, 'epoch': 1.19}


                                                   
 60%|█████▉    | 1260/2108 [36:29<26:45,  1.89s/it]  

{'loss': 0.0008, 'grad_norm': 0.01979483850300312, 'learning_rate': 2.011385199240987e-05, 'epoch': 1.2}


                                                   
 60%|██████    | 1270/2108 [36:48<26:19,  1.88s/it]  

{'loss': 0.6976, 'grad_norm': 30.059612274169922, 'learning_rate': 1.9876660341555976e-05, 'epoch': 1.2}


                                                   
 61%|██████    | 1280/2108 [37:07<25:20,  1.84s/it]  

{'loss': 0.0008, 'grad_norm': 0.024981992319226265, 'learning_rate': 1.9639468690702086e-05, 'epoch': 1.21}


                                                   
 61%|██████    | 1290/2108 [37:23<24:53,  1.83s/it]  

{'loss': 0.6887, 'grad_norm': 0.024860017001628876, 'learning_rate': 1.94022770398482e-05, 'epoch': 1.22}


                                                   
 62%|██████▏   | 1300/2108 [37:41<21:58,  1.63s/it]  

{'loss': 0.6475, 'grad_norm': 30.614288330078125, 'learning_rate': 1.9165085388994307e-05, 'epoch': 1.23}


                                                   
 62%|██████▏   | 1310/2108 [38:01<25:57,  1.95s/it]  

{'loss': 0.0016, 'grad_norm': 0.06851476430892944, 'learning_rate': 1.8927893738140417e-05, 'epoch': 1.24}


                                                   
 63%|██████▎   | 1320/2108 [38:20<27:15,  2.08s/it]  

{'loss': 0.0014, 'grad_norm': 0.03271525725722313, 'learning_rate': 1.869070208728653e-05, 'epoch': 1.25}


                                                   
 63%|██████▎   | 1330/2108 [38:41<26:25,  2.04s/it]  

{'loss': 0.0012, 'grad_norm': 0.027253270149230957, 'learning_rate': 1.8453510436432638e-05, 'epoch': 1.26}


                                                   
 64%|██████▎   | 1340/2108 [38:56<18:19,  1.43s/it]  

{'loss': 1.3131, 'grad_norm': 31.060985565185547, 'learning_rate': 1.8216318785578748e-05, 'epoch': 1.27}


                                                   
 64%|██████▍   | 1350/2108 [39:12<21:01,  1.66s/it]  

{'loss': 0.0015, 'grad_norm': 0.0538889579474926, 'learning_rate': 1.797912713472486e-05, 'epoch': 1.28}


                                                   
 65%|██████▍   | 1360/2108 [39:27<18:45,  1.50s/it]  

{'loss': 0.0017, 'grad_norm': 0.054719679057598114, 'learning_rate': 1.774193548387097e-05, 'epoch': 1.29}


                                                   
 65%|██████▍   | 1370/2108 [39:42<18:19,  1.49s/it]  

{'loss': 0.0013, 'grad_norm': 0.037126749753952026, 'learning_rate': 1.750474383301708e-05, 'epoch': 1.3}


                                                   
 65%|██████▌   | 1380/2108 [39:58<21:13,  1.75s/it]  

{'loss': 0.6309, 'grad_norm': 0.03716540336608887, 'learning_rate': 1.726755218216319e-05, 'epoch': 1.31}


                                                   
 66%|██████▌   | 1390/2108 [40:17<22:51,  1.91s/it]  

{'loss': 0.0012, 'grad_norm': 0.041115712374448776, 'learning_rate': 1.70303605313093e-05, 'epoch': 1.32}


                                                   
 66%|██████▋   | 1400/2108 [40:37<23:38,  2.00s/it]  

{'loss': 0.6668, 'grad_norm': 0.042019814252853394, 'learning_rate': 1.679316888045541e-05, 'epoch': 1.33}


                                                   
 67%|██████▋   | 1410/2108 [40:57<23:46,  2.04s/it]  

{'loss': 0.6242, 'grad_norm': 0.05477084964513779, 'learning_rate': 1.655597722960152e-05, 'epoch': 1.34}


                                                   
 67%|██████▋   | 1420/2108 [41:18<23:54,  2.08s/it]  

{'loss': 0.0017, 'grad_norm': 0.053996872156858444, 'learning_rate': 1.6318785578747627e-05, 'epoch': 1.35}


                                                   
 68%|██████▊   | 1430/2108 [41:38<23:20,  2.07s/it]  

{'loss': 0.6356, 'grad_norm': 0.04976923391222954, 'learning_rate': 1.608159392789374e-05, 'epoch': 1.36}


                                                   
 68%|██████▊   | 1440/2108 [41:57<19:40,  1.77s/it]  

{'loss': 0.0018, 'grad_norm': 0.05870974808931351, 'learning_rate': 1.5844402277039848e-05, 'epoch': 1.37}


                                                   
 69%|██████▉   | 1450/2108 [42:12<15:46,  1.44s/it]  

{'loss': 1.2057, 'grad_norm': 0.07835879921913147, 'learning_rate': 1.5607210626185958e-05, 'epoch': 1.38}


                                                   
 69%|██████▉   | 1460/2108 [42:30<20:44,  1.92s/it]  

{'loss': 0.0026, 'grad_norm': 0.0794389620423317, 'learning_rate': 1.537001897533207e-05, 'epoch': 1.39}


                                                   
 70%|██████▉   | 1470/2108 [42:50<21:06,  1.99s/it]  

{'loss': 0.6143, 'grad_norm': 30.82078742980957, 'learning_rate': 1.513282732447818e-05, 'epoch': 1.39}


                                                   
 70%|███████   | 1480/2108 [43:09<20:01,  1.91s/it]  

{'loss': 0.5744, 'grad_norm': 0.0927116721868515, 'learning_rate': 1.4895635673624289e-05, 'epoch': 1.4}


                                                   
 71%|███████   | 1490/2108 [43:29<20:14,  1.97s/it]  

{'loss': 1.1389, 'grad_norm': 31.22233009338379, 'learning_rate': 1.4658444022770398e-05, 'epoch': 1.41}


                                                   
 71%|███████   | 1500/2108 [43:49<20:30,  2.02s/it]  

{'loss': 0.5255, 'grad_norm': 0.17329448461532593, 'learning_rate': 1.4421252371916511e-05, 'epoch': 1.42}


                                                   
 72%|███████▏  | 1510/2108 [44:09<19:31,  1.96s/it]  

{'loss': 0.0039, 'grad_norm': 0.09932799637317657, 'learning_rate': 1.418406072106262e-05, 'epoch': 1.43}


                                                   
 72%|███████▏  | 1520/2108 [44:29<20:23,  2.08s/it]  

{'loss': 0.5413, 'grad_norm': 0.08432728052139282, 'learning_rate': 1.3946869070208728e-05, 'epoch': 1.44}


                                                   
 73%|███████▎  | 1530/2108 [44:49<20:08,  2.09s/it]  

{'loss': 0.0024, 'grad_norm': 0.060886070132255554, 'learning_rate': 1.3709677419354839e-05, 'epoch': 1.45}


                                                   
 73%|███████▎  | 1540/2108 [45:10<19:19,  2.04s/it]  

{'loss': 0.002, 'grad_norm': 0.05038844421505928, 'learning_rate': 1.347248576850095e-05, 'epoch': 1.46}


                                                   
 74%|███████▎  | 1550/2108 [45:31<19:05,  2.05s/it]  

{'loss': 0.6197, 'grad_norm': 0.05287551134824753, 'learning_rate': 1.323529411764706e-05, 'epoch': 1.47}


                                                   
 74%|███████▍  | 1560/2108 [45:51<18:54,  2.07s/it]  

{'loss': 1.1708, 'grad_norm': 0.07950316369533539, 'learning_rate': 1.299810246679317e-05, 'epoch': 1.48}


                                                   
 74%|███████▍  | 1570/2108 [46:11<18:05,  2.02s/it]  

{'loss': 0.0024, 'grad_norm': 0.08262667059898376, 'learning_rate': 1.2760910815939278e-05, 'epoch': 1.49}


                                                   
 75%|███████▍  | 1580/2108 [46:31<17:23,  1.98s/it]  

{'loss': 0.5803, 'grad_norm': 0.07061905413866043, 'learning_rate': 1.252371916508539e-05, 'epoch': 1.5}


                                                   
 75%|███████▌  | 1590/2108 [46:51<17:21,  2.01s/it]  

{'loss': 0.532, 'grad_norm': 0.07556333392858505, 'learning_rate': 1.2286527514231499e-05, 'epoch': 1.51}


                                                   
 76%|███████▌  | 1600/2108 [47:11<16:55,  2.00s/it]  

{'loss': 1.084, 'grad_norm': 0.11947940289974213, 'learning_rate': 1.204933586337761e-05, 'epoch': 1.52}


                                                   
 76%|███████▋  | 1610/2108 [47:32<17:07,  2.06s/it]  

{'loss': 0.0034, 'grad_norm': 0.09053938835859299, 'learning_rate': 1.181214421252372e-05, 'epoch': 1.53}


                                                   
 77%|███████▋  | 1620/2108 [47:52<16:54,  2.08s/it]  

{'loss': 0.5414, 'grad_norm': 0.05611142888665199, 'learning_rate': 1.157495256166983e-05, 'epoch': 1.54}


                                                   
 77%|███████▋  | 1630/2108 [48:12<15:14,  1.91s/it]  

{'loss': 0.0023, 'grad_norm': 0.05695906654000282, 'learning_rate': 1.133776091081594e-05, 'epoch': 1.55}


                                                   
 78%|███████▊  | 1640/2108 [48:33<16:19,  2.09s/it]  

{'loss': 0.002, 'grad_norm': 0.05719631165266037, 'learning_rate': 1.110056925996205e-05, 'epoch': 1.56}


                                                   
 78%|███████▊  | 1650/2108 [48:53<15:23,  2.02s/it]  

{'loss': 0.6075, 'grad_norm': 0.04700454697012901, 'learning_rate': 1.0863377609108159e-05, 'epoch': 1.57}


                                                   
 79%|███████▊  | 1660/2108 [49:13<15:09,  2.03s/it]  

{'loss': 0.0016, 'grad_norm': 0.05423804745078087, 'learning_rate': 1.062618595825427e-05, 'epoch': 1.57}


                                                   
 79%|███████▉  | 1670/2108 [49:34<15:31,  2.13s/it]  

{'loss': 1.2292, 'grad_norm': 0.09899982064962387, 'learning_rate': 1.0388994307400381e-05, 'epoch': 1.58}


                                                   
 80%|███████▉  | 1680/2108 [49:54<13:56,  1.96s/it]  

{'loss': 0.5497, 'grad_norm': 0.10517875105142593, 'learning_rate': 1.015180265654649e-05, 'epoch': 1.59}


                                                   
 80%|████████  | 1690/2108 [50:14<14:07,  2.03s/it]  

{'loss': 0.6354, 'grad_norm': 30.309988021850586, 'learning_rate': 9.9146110056926e-06, 'epoch': 1.6}


                                                   
 81%|████████  | 1700/2108 [50:35<13:44,  2.02s/it]  

{'loss': 0.5952, 'grad_norm': 0.0782117024064064, 'learning_rate': 9.67741935483871e-06, 'epoch': 1.61}


                                                   
 81%|████████  | 1710/2108 [50:55<13:39,  2.06s/it]  

{'loss': 0.0025, 'grad_norm': 0.07890284806489944, 'learning_rate': 9.44022770398482e-06, 'epoch': 1.62}


                                                   
 82%|████████▏ | 1720/2108 [51:15<13:01,  2.01s/it]  

{'loss': 0.0023, 'grad_norm': 0.05658789724111557, 'learning_rate': 9.20303605313093e-06, 'epoch': 1.63}


                                                   
 82%|████████▏ | 1730/2108 [51:35<12:27,  1.98s/it]  

{'loss': 0.0018, 'grad_norm': 0.05528242886066437, 'learning_rate': 8.965844402277041e-06, 'epoch': 1.64}


                                                   
 83%|████████▎ | 1740/2108 [51:56<12:45,  2.08s/it]  

{'loss': 0.0016, 'grad_norm': 0.0415722131729126, 'learning_rate': 8.72865275142315e-06, 'epoch': 1.65}


                                                   
 83%|████████▎ | 1750/2108 [52:16<12:22,  2.07s/it]  

{'loss': 0.0014, 'grad_norm': 0.048168327659368515, 'learning_rate': 8.49146110056926e-06, 'epoch': 1.66}


                                                   
 83%|████████▎ | 1760/2108 [52:37<11:59,  2.07s/it]  

{'loss': 0.0012, 'grad_norm': 0.03643663227558136, 'learning_rate': 8.25426944971537e-06, 'epoch': 1.67}


                                                   
 84%|████████▍ | 1770/2108 [52:58<11:36,  2.06s/it]  

{'loss': 0.6556, 'grad_norm': 0.031798552721738815, 'learning_rate': 8.01707779886148e-06, 'epoch': 1.68}


                                                   
 84%|████████▍ | 1780/2108 [53:17<10:34,  1.93s/it]  

{'loss': 0.0012, 'grad_norm': 0.038281138986349106, 'learning_rate': 7.779886148007591e-06, 'epoch': 1.69}


                                                   
 85%|████████▍ | 1790/2108 [53:37<10:43,  2.02s/it]  

{'loss': 0.0012, 'grad_norm': 0.03622189164161682, 'learning_rate': 7.5426944971537005e-06, 'epoch': 1.7}


                                                   
 85%|████████▌ | 1800/2108 [53:58<10:56,  2.13s/it]  

{'loss': 0.0012, 'grad_norm': 0.03437101095914841, 'learning_rate': 7.305502846299811e-06, 'epoch': 1.71}


                                                   
 86%|████████▌ | 1810/2108 [54:19<10:07,  2.04s/it]  

{'loss': 0.6471, 'grad_norm': 0.04302704706788063, 'learning_rate': 7.06831119544592e-06, 'epoch': 1.72}


                                                   
 86%|████████▋ | 1820/2108 [54:38<09:29,  1.98s/it]  

{'loss': 0.6501, 'grad_norm': 0.03608899191021919, 'learning_rate': 6.831119544592031e-06, 'epoch': 1.73}


                                                   
 87%|████████▋ | 1830/2108 [54:59<09:01,  1.95s/it]  

{'loss': 0.0012, 'grad_norm': 0.032631877809762955, 'learning_rate': 6.59392789373814e-06, 'epoch': 1.74}


                                                   
 87%|████████▋ | 1840/2108 [55:14<06:33,  1.47s/it]  

{'loss': 0.0013, 'grad_norm': 0.03835073113441467, 'learning_rate': 6.356736242884251e-06, 'epoch': 1.75}


                                                   
 88%|████████▊ | 1850/2108 [55:29<06:19,  1.47s/it]  

{'loss': 0.0011, 'grad_norm': 0.03758445754647255, 'learning_rate': 6.119544592030361e-06, 'epoch': 1.76}


                                                   
 88%|████████▊ | 1860/2108 [55:45<06:30,  1.57s/it]  

{'loss': 0.6262, 'grad_norm': 0.0392250157892704, 'learning_rate': 5.882352941176471e-06, 'epoch': 1.76}


                                                   
 89%|████████▊ | 1870/2108 [56:00<05:46,  1.46s/it]  

{'loss': 0.0012, 'grad_norm': 0.03629983961582184, 'learning_rate': 5.64516129032258e-06, 'epoch': 1.77}


                                                   
 89%|████████▉ | 1880/2108 [56:15<05:25,  1.43s/it]  

{'loss': 0.0012, 'grad_norm': 0.03662123158574104, 'learning_rate': 5.407969639468691e-06, 'epoch': 1.78}


                                                   
 90%|████████▉ | 1890/2108 [56:29<05:09,  1.42s/it]  

{'loss': 0.0011, 'grad_norm': 0.0380394384264946, 'learning_rate': 5.170777988614801e-06, 'epoch': 1.79}


                                                   
 90%|█████████ | 1900/2108 [56:43<05:03,  1.46s/it]  

{'loss': 0.6602, 'grad_norm': 0.031854793429374695, 'learning_rate': 4.933586337760911e-06, 'epoch': 1.8}


                                                   
 91%|█████████ | 1910/2108 [57:01<06:39,  2.02s/it]  

{'loss': 0.0011, 'grad_norm': 0.030934536829590797, 'learning_rate': 4.6963946869070216e-06, 'epoch': 1.81}


                                                   
 91%|█████████ | 1920/2108 [57:17<04:53,  1.56s/it]  

{'loss': 0.6437, 'grad_norm': 0.02875240333378315, 'learning_rate': 4.459203036053131e-06, 'epoch': 1.82}


                                                   
 92%|█████████▏| 1930/2108 [57:32<04:23,  1.48s/it]  

{'loss': 0.6788, 'grad_norm': 0.03598342835903168, 'learning_rate': 4.222011385199241e-06, 'epoch': 1.83}


                                                   
 92%|█████████▏| 1940/2108 [57:48<04:21,  1.56s/it]  

{'loss': 0.6818, 'grad_norm': 0.03947195038199425, 'learning_rate': 3.984819734345352e-06, 'epoch': 1.84}


                                                   
 93%|█████████▎| 1950/2108 [58:06<05:16,  2.00s/it]  

{'loss': 0.0013, 'grad_norm': 0.040476035326719284, 'learning_rate': 3.747628083491461e-06, 'epoch': 1.85}


                                                   
 93%|█████████▎| 1960/2108 [58:25<04:43,  1.92s/it]  

{'loss': 0.0013, 'grad_norm': 0.03498600423336029, 'learning_rate': 3.5104364326375713e-06, 'epoch': 1.86}


                                                   
 93%|█████████▎| 1970/2108 [58:44<04:08,  1.80s/it]  

{'loss': 0.0013, 'grad_norm': 0.03936672583222389, 'learning_rate': 3.2732447817836812e-06, 'epoch': 1.87}


                                                   
 94%|█████████▍| 1980/2108 [59:01<03:57,  1.85s/it]  

{'loss': 0.6767, 'grad_norm': 0.040272973477840424, 'learning_rate': 3.0360531309297915e-06, 'epoch': 1.88}


                                                   
 94%|█████████▍| 1990/2108 [59:22<04:05,  2.08s/it]  

{'loss': 0.0013, 'grad_norm': 0.0380939356982708, 'learning_rate': 2.7988614800759014e-06, 'epoch': 1.89}


                                                   
 95%|█████████▍| 2000/2108 [59:42<03:49,  2.12s/it]  

{'loss': 0.6557, 'grad_norm': 0.040342967957258224, 'learning_rate': 2.5616698292220113e-06, 'epoch': 1.9}


                                                     
 95%|█████████▌| 2010/2108 [1:00:03<03:17,  2.02s/it]

{'loss': 0.0013, 'grad_norm': 0.029150376096367836, 'learning_rate': 2.3244781783681216e-06, 'epoch': 1.91}


                                                     
 96%|█████████▌| 2020/2108 [1:00:24<03:03,  2.09s/it]

{'loss': 0.0013, 'grad_norm': 0.03905465826392174, 'learning_rate': 2.087286527514232e-06, 'epoch': 1.92}


                                                     
 96%|█████████▋| 2030/2108 [1:00:44<02:39,  2.05s/it]

{'loss': 0.0013, 'grad_norm': 0.03324316814541817, 'learning_rate': 1.8500948766603417e-06, 'epoch': 1.93}


                                                     
 97%|█████████▋| 2040/2108 [1:01:04<01:56,  1.71s/it]

{'loss': 0.6422, 'grad_norm': 0.03683944046497345, 'learning_rate': 1.6129032258064516e-06, 'epoch': 1.94}


                                                     
 97%|█████████▋| 2050/2108 [1:01:24<02:02,  2.11s/it]

{'loss': 0.0013, 'grad_norm': 0.03478431701660156, 'learning_rate': 1.3757115749525619e-06, 'epoch': 1.94}


                                                     
 98%|█████████▊| 2060/2108 [1:01:46<01:41,  2.12s/it]

{'loss': 0.0013, 'grad_norm': 0.04076424241065979, 'learning_rate': 1.1385199240986718e-06, 'epoch': 1.95}


                                                     
 98%|█████████▊| 2070/2108 [1:02:06<01:18,  2.06s/it]

{'loss': 0.0012, 'grad_norm': 0.036391906440258026, 'learning_rate': 9.01328273244782e-07, 'epoch': 1.96}


                                                     
 99%|█████████▊| 2080/2108 [1:02:26<00:57,  2.05s/it]

{'loss': 0.0012, 'grad_norm': 0.03483183681964874, 'learning_rate': 6.641366223908918e-07, 'epoch': 1.97}


                                                     
 99%|█████████▉| 2090/2108 [1:02:46<00:36,  2.04s/it]

{'loss': 0.0012, 'grad_norm': 0.03469647839665413, 'learning_rate': 4.269449715370019e-07, 'epoch': 1.98}


                                                     
100%|█████████▉| 2100/2108 [1:03:07<00:15,  1.95s/it]

{'loss': 0.0012, 'grad_norm': 0.037511806935071945, 'learning_rate': 1.8975332068311197e-07, 'epoch': 1.99}


100%|██████████| 2108/2108 [1:03:22<00:00,  1.74s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                     

[A[A                                           
100%|██████████| 2108/2108 [1:04:20<00:00,  1.74s/it]
[A
[A

{'eval_loss': 0.3370947241783142, 'eval_accuracy': 0.9897610921501706, 'eval_f1': 0.0, 'eval_precision': 1.0, 'eval_recall': 0.0, 'eval_runtime': 57.0358, 'eval_samples_per_second': 10.274, 'eval_steps_per_second': 2.069, 'epoch': 2.0}


                                                     
100%|██████████| 2108/2108 [1:04:22<00:00,  1.83s/it]

{'train_runtime': 3862.1895, 'train_samples_per_second': 2.727, 'train_steps_per_second': 0.546, 'train_loss': 0.2868997352552314, 'epoch': 2.0}





TrainOutput(global_step=2108, training_loss=0.2868997352552314, metrics={'train_runtime': 3862.1895, 'train_samples_per_second': 2.727, 'train_steps_per_second': 0.546, 'total_flos': 692902964290560.0, 'train_loss': 0.2868997352552314, 'epoch': 2.0})

In [48]:
# 13. Evaluate the model
evaluation_results = trainer.evaluate()
print("Evaluation Results:", evaluation_results)

100%|██████████| 118/118 [00:43<00:00,  2.70it/s]

Evaluation Results: {'eval_loss': 0.37636786699295044, 'eval_accuracy': 0.9897610921501706, 'eval_f1': 0.0, 'eval_precision': 1.0, 'eval_recall': 0.0, 'eval_runtime': 44.1267, 'eval_samples_per_second': 13.28, 'eval_steps_per_second': 2.674, 'epoch': 2.0}





In [None]:
def predict_email(email_text, threshold=0.3):
    inputs = tokenizer(
        email_text,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=128
    )
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    predicted_prob = probabilities[0][1].item()
    return 1 if predicted_prob >= threshold else 0

In [None]:
# # 14. Save the trained model and tokenizer
# model.save_pretrained('your-trained-model')
# tokenizer.save_pretrained('your-trained-model')