In [1]:
from transformers import pipeline
from transformers import AutoTokenizer
import torch
# import numpy as np
import pandas as pd
from transformers import TrainingArguments, Trainer
# Hub identifier of the checkpoint used by the later cells; per its name it is
# DistilBERT (uncased) fine-tuned on SST-2 English sentiment data.
model = "distilbert-base-uncased-finetuned-sst-2-english"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device exactly once: the Metal (MPS) backend when PyTorch
# reports it as available, otherwise the CPU.
# An explicit torch.device is kept instead of an integer ordinal: transformers
# documents integer ordinals as CUDA device ids, so a torch.device states the
# intended backend unambiguously wherever `device` is passed on.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Prints "Using device: MPS" or "Using device: CPU".
print(f"Using device: {device.type.upper()}")

Using device: MPS


In [3]:
# Build the sentiment-analysis pipeline from the checkpoint named in `model`
# and place it on the previously selected `device`.
sentiment_pipeline = pipeline(
    task="sentiment-analysis",
    model=model,
    device=device,
)


In [4]:
import os

# Export MKL_SERVICE_FORCE_INTEL=1 into this process's environment.
# NOTE(review): torch has already been imported by the time this cell runs; if
# the variable is only read when the numeric libraries load, setting it here
# may have no effect — confirm, and consider moving it to the first cell.
os.environ.update({"MKL_SERVICE_FORCE_INTEL": "1"})


In [5]:
# Smoke test: score one clearly positive and one clearly negative sentence.
data = [
    "I love you",
    "I hate you",
]
r = sentiment_pipeline(data)

# Bare name as the last expression so the notebook displays the predictions.
r

[{'label': 'POSITIVE', 'score': 0.9998656511306763},
 {'label': 'NEGATIVE', 'score': 0.9991129040718079}]

In [6]:
# Load the tokenizer that belongs to the checkpoint named in `model`.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model)

# Bare name as the last expression so the notebook displays the tokenizer repr.
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-base-uncased-finetuned-sst-2-english', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [7]:
# Report CUDA visibility on two lines: first whether CUDA is usable at all,
# then how many CUDA devices PyTorch can see.
print(torch.cuda.is_available(), torch.cuda.device_count(), sep="\n")


False
0


In [8]:
# from datasets import load_dataset

# # Load dataset from a CSV file
# data_files = {"train": "path_to_train.csv", "validation": "path_to_val.csv"}  # Update paths
# dataset = load_dataset("csv", data_files=data_files)

# # Example structure: {'text': ..., 'label': ...}
# print(dataset)


In [9]:
from utils_classes import load_and_process_comments


def flatten_batches(batches):
    """Collect the texts and labels of every batch into two flat lists.

    Each element of ``batches`` is indexed as ``batch[0]`` (its texts) and
    ``batch[1]`` (its labels). Both are gathered in a single pass, so the
    i-th text and the i-th label always originate from the same batch even if
    ``batches`` can only be consumed once or yields its batches in a different
    order on every iteration.

    Args:
        batches: iterable of indexable batches as described above.

    Returns:
        A ``(texts, labels)`` tuple of lists.
    """
    texts, labels = [], []
    for batch in batches:
        texts.extend(batch[0])
        labels.extend(batch[1])
    return texts, labels


# NOTE(review): `val_comments` is returned here but never used by the visible
# cells — confirm whether it should serve as the evaluation split.
train_comments, val_comments, test_comments, test_labels = load_and_process_comments(
    train_path='train',
    batch_size=50,
)

# Flatten train_comments (texts and labels taken from the same pass).
train_texts, train_labels = flatten_batches(train_comments)

# Flatten test_comments.
# NOTE(review): this rebinds `test_labels`, discarding the value returned by
# load_and_process_comments above — confirm that this is intended.
test_texts, test_labels = flatten_batches(test_comments)


In [10]:
# Encode both splits with one shared set of tokenizer options so that the
# training and test inputs are prepared identically.
tokenize_options = {"truncation": True, "padding": True, "max_length": 512}
train_encodings = tokenizer(train_texts, **tokenize_options)
test_encodings = tokenizer(test_texts, **tokenize_options)


In [11]:
# Display the type of the object the tokenizer returned for the training texts.
type(train_encodings)


transformers.tokenization_utils_base.BatchEncoding

In [12]:
class CommentsDataset(torch.utils.data.Dataset):
    """Dataset that pairs tokenizer encodings with their labels.

    ``encodings`` must provide ``.items()`` yielding ``(field, values)`` pairs
    in which ``values[i]`` belongs to sample ``i``; ``labels[i]`` is the label
    of that same sample.
    """

    def __init__(self, encodings, labels):
        """Store the encodings and labels as given (no copy is made)."""
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples, defined as the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        Every encoding field is converted with ``torch.tensor``; the label is
        added last under the key ``"labels"``.
        """
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

# Wrap each split's encodings and labels in a CommentsDataset.
train_dataset = CommentsDataset(encodings=train_encodings, labels=train_labels)
test_dataset = CommentsDataset(encodings=test_encodings, labels=test_labels)

In [13]:
from transformers import AutoModelForSequenceClassification

# Checkpoint to fine-tune and the size of its classification head; change
# these two values to target another checkpoint or label set.
BASE_CHECKPOINT = "distilbert-base-uncased"
NUM_LABELS = 2  # binary classification

model_new = AutoModelForSequenceClassification.from_pretrained(
    BASE_CHECKPOINT,
    num_labels=NUM_LABELS,
)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Settings for the fine-tuning run, gathered in one mapping and then unpacked
# into TrainingArguments.
training_config = {
    # Where results and logs are written.
    "output_dir": "./results",
    "logging_dir": "./logs",
    # Optimisation settings.
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    # Evaluate and checkpoint once per epoch, log every 10 steps, and reload
    # the best checkpoint when training finishes.
    "eval_strategy": "epoch",
    "save_strategy": "epoch",
    "logging_steps": 10,
    "load_best_model_at_end": True,
}
training_args = TrainingArguments(**training_config)

In [None]:
# Assemble the Trainer from the model, the training arguments and the two
# datasets, then start fine-tuning.
# NOTE(review): the test split is used as the evaluation set here, and
# `load_best_model_at_end=True` picks a checkpoint based on those evaluations;
# consider evaluating on a separate validation split so the test data stays
# untouched until the final measurement.
trainer = Trainer(
    model=model_new,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # `processing_class` supersedes the deprecated `tokenizer=` argument.
    processing_class=tokenizer,
)

trainer.train()

  trainer = Trainer(
  0%|          | 10/5625 [00:29<4:04:18,  2.61s/it]

{'loss': 0.6913, 'grad_norm': 1.0757776498794556, 'learning_rate': 1.9964444444444447e-05, 'epoch': 0.01}


  0%|          | 20/5625 [00:54<3:58:54,  2.56s/it]

{'loss': 0.6766, 'grad_norm': 1.3927199840545654, 'learning_rate': 1.992888888888889e-05, 'epoch': 0.01}


  1%|          | 30/5625 [01:19<3:56:47,  2.54s/it]

{'loss': 0.6389, 'grad_norm': 2.456381320953369, 'learning_rate': 1.9893333333333335e-05, 'epoch': 0.02}


  1%|          | 40/5625 [01:45<3:57:24,  2.55s/it]

{'loss': 0.5756, 'grad_norm': 2.932981491088867, 'learning_rate': 1.985777777777778e-05, 'epoch': 0.02}


  1%|          | 50/5625 [02:10<3:55:07,  2.53s/it]

{'loss': 0.4786, 'grad_norm': 6.559523105621338, 'learning_rate': 1.9822222222222226e-05, 'epoch': 0.03}


  1%|          | 60/5625 [02:35<3:54:45,  2.53s/it]

{'loss': 0.3612, 'grad_norm': 3.3584725856781006, 'learning_rate': 1.9786666666666668e-05, 'epoch': 0.03}


  1%|          | 70/5625 [03:01<3:54:27,  2.53s/it]

{'loss': 0.3962, 'grad_norm': 8.029385566711426, 'learning_rate': 1.9751111111111114e-05, 'epoch': 0.04}


  1%|▏         | 80/5625 [03:26<3:53:11,  2.52s/it]

{'loss': 0.3264, 'grad_norm': 9.017687797546387, 'learning_rate': 1.9715555555555556e-05, 'epoch': 0.04}


  2%|▏         | 90/5625 [03:51<3:57:24,  2.57s/it]

{'loss': 0.3786, 'grad_norm': 7.496423244476318, 'learning_rate': 1.968e-05, 'epoch': 0.05}


  2%|▏         | 100/5625 [04:18<4:10:45,  2.72s/it]

{'loss': 0.3244, 'grad_norm': 6.624642848968506, 'learning_rate': 1.9644444444444447e-05, 'epoch': 0.05}


  2%|▏         | 110/5625 [04:46<4:12:31,  2.75s/it]

{'loss': 0.4023, 'grad_norm': 4.734385967254639, 'learning_rate': 1.960888888888889e-05, 'epoch': 0.06}


  2%|▏         | 120/5625 [05:13<4:08:36,  2.71s/it]

{'loss': 0.3528, 'grad_norm': 10.13439655303955, 'learning_rate': 1.9573333333333335e-05, 'epoch': 0.06}


  2%|▏         | 130/5625 [05:40<4:06:01,  2.69s/it]

{'loss': 0.3575, 'grad_norm': 6.145967483520508, 'learning_rate': 1.953777777777778e-05, 'epoch': 0.07}


  2%|▏         | 140/5625 [06:07<4:06:24,  2.70s/it]

{'loss': 0.2396, 'grad_norm': 8.374120712280273, 'learning_rate': 1.9502222222222226e-05, 'epoch': 0.07}


  3%|▎         | 150/5625 [06:34<4:05:13,  2.69s/it]

{'loss': 0.3129, 'grad_norm': 10.080734252929688, 'learning_rate': 1.9466666666666668e-05, 'epoch': 0.08}


  3%|▎         | 160/5625 [07:01<4:03:44,  2.68s/it]

{'loss': 0.325, 'grad_norm': 8.429008483886719, 'learning_rate': 1.9431111111111113e-05, 'epoch': 0.09}


  3%|▎         | 170/5625 [07:28<4:05:03,  2.70s/it]

{'loss': 0.2917, 'grad_norm': 7.159471035003662, 'learning_rate': 1.9395555555555555e-05, 'epoch': 0.09}


  3%|▎         | 180/5625 [07:55<4:03:46,  2.69s/it]

{'loss': 0.195, 'grad_norm': 3.8266632556915283, 'learning_rate': 1.936e-05, 'epoch': 0.1}


  3%|▎         | 190/5625 [08:22<4:05:32,  2.71s/it]

{'loss': 0.3125, 'grad_norm': 12.619710922241211, 'learning_rate': 1.9324444444444447e-05, 'epoch': 0.1}


  4%|▎         | 200/5625 [08:49<4:03:48,  2.70s/it]

{'loss': 0.2434, 'grad_norm': 7.255180835723877, 'learning_rate': 1.928888888888889e-05, 'epoch': 0.11}


  4%|▎         | 210/5625 [09:16<4:02:02,  2.68s/it]

{'loss': 0.3675, 'grad_norm': 7.606292247772217, 'learning_rate': 1.9253333333333334e-05, 'epoch': 0.11}


  4%|▍         | 220/5625 [09:42<4:02:59,  2.70s/it]

{'loss': 0.3758, 'grad_norm': 5.968516826629639, 'learning_rate': 1.921777777777778e-05, 'epoch': 0.12}


  4%|▍         | 230/5625 [10:09<3:59:32,  2.66s/it]

{'loss': 0.2455, 'grad_norm': 9.170674324035645, 'learning_rate': 1.9182222222222225e-05, 'epoch': 0.12}


  4%|▍         | 240/5625 [10:36<3:59:29,  2.67s/it]

{'loss': 0.2999, 'grad_norm': 2.9509646892547607, 'learning_rate': 1.9146666666666667e-05, 'epoch': 0.13}


  4%|▍         | 250/5625 [11:03<4:01:02,  2.69s/it]

{'loss': 0.3168, 'grad_norm': 10.969223022460938, 'learning_rate': 1.9111111111111113e-05, 'epoch': 0.13}


  5%|▍         | 260/5625 [11:29<3:57:15,  2.65s/it]

{'loss': 0.2412, 'grad_norm': 8.459951400756836, 'learning_rate': 1.9075555555555555e-05, 'epoch': 0.14}


  5%|▍         | 270/5625 [11:56<4:01:04,  2.70s/it]

{'loss': 0.2508, 'grad_norm': 3.707149028778076, 'learning_rate': 1.904e-05, 'epoch': 0.14}


  5%|▍         | 280/5625 [12:23<4:01:37,  2.71s/it]

{'loss': 0.2428, 'grad_norm': 9.7522611618042, 'learning_rate': 1.9004444444444446e-05, 'epoch': 0.15}


  5%|▌         | 290/5625 [12:50<4:00:03,  2.70s/it]

{'loss': 0.3071, 'grad_norm': 16.806697845458984, 'learning_rate': 1.896888888888889e-05, 'epoch': 0.15}


  5%|▌         | 300/5625 [13:17<4:00:23,  2.71s/it]

{'loss': 0.2768, 'grad_norm': 8.48902702331543, 'learning_rate': 1.8933333333333334e-05, 'epoch': 0.16}


  6%|▌         | 310/5625 [13:44<3:56:53,  2.67s/it]

{'loss': 0.3604, 'grad_norm': 4.833974361419678, 'learning_rate': 1.889777777777778e-05, 'epoch': 0.17}


  6%|▌         | 320/5625 [14:11<4:01:10,  2.73s/it]

{'loss': 0.3532, 'grad_norm': 9.287692070007324, 'learning_rate': 1.8862222222222225e-05, 'epoch': 0.17}


  6%|▌         | 330/5625 [14:38<3:56:36,  2.68s/it]

{'loss': 0.2223, 'grad_norm': 3.173433303833008, 'learning_rate': 1.8826666666666667e-05, 'epoch': 0.18}


  6%|▌         | 340/5625 [15:05<3:51:59,  2.63s/it]

{'loss': 0.2896, 'grad_norm': 7.8174662590026855, 'learning_rate': 1.8791111111111113e-05, 'epoch': 0.18}


  6%|▌         | 350/5625 [15:31<3:51:04,  2.63s/it]

{'loss': 0.247, 'grad_norm': 6.554244518280029, 'learning_rate': 1.8755555555555558e-05, 'epoch': 0.19}


  6%|▋         | 360/5625 [15:58<3:53:13,  2.66s/it]

{'loss': 0.3593, 'grad_norm': 3.635120153427124, 'learning_rate': 1.8720000000000004e-05, 'epoch': 0.19}


  7%|▋         | 370/5625 [16:25<3:57:40,  2.71s/it]

{'loss': 0.3269, 'grad_norm': 8.25832748413086, 'learning_rate': 1.8684444444444446e-05, 'epoch': 0.2}


  7%|▋         | 380/5625 [16:52<4:00:03,  2.75s/it]

{'loss': 0.2589, 'grad_norm': 6.62159538269043, 'learning_rate': 1.8648888888888888e-05, 'epoch': 0.2}


  7%|▋         | 390/5625 [17:19<3:56:36,  2.71s/it]

{'loss': 0.3142, 'grad_norm': 9.64205551147461, 'learning_rate': 1.8613333333333334e-05, 'epoch': 0.21}


  7%|▋         | 400/5625 [17:47<4:00:07,  2.76s/it]

{'loss': 0.2657, 'grad_norm': 6.696769714355469, 'learning_rate': 1.857777777777778e-05, 'epoch': 0.21}


  7%|▋         | 410/5625 [18:14<3:58:41,  2.75s/it]

{'loss': 0.2942, 'grad_norm': 4.526856899261475, 'learning_rate': 1.8542222222222225e-05, 'epoch': 0.22}


  7%|▋         | 420/5625 [18:42<3:56:41,  2.73s/it]

{'loss': 0.2776, 'grad_norm': 15.033031463623047, 'learning_rate': 1.8506666666666667e-05, 'epoch': 0.22}


  8%|▊         | 430/5625 [19:09<3:57:17,  2.74s/it]

{'loss': 0.2989, 'grad_norm': 5.736593723297119, 'learning_rate': 1.8471111111111112e-05, 'epoch': 0.23}


  8%|▊         | 440/5625 [19:36<3:54:02,  2.71s/it]

{'loss': 0.2989, 'grad_norm': 4.316392421722412, 'learning_rate': 1.8435555555555558e-05, 'epoch': 0.23}


  8%|▊         | 450/5625 [20:03<3:53:50,  2.71s/it]

{'loss': 0.1939, 'grad_norm': 2.568378210067749, 'learning_rate': 1.8400000000000003e-05, 'epoch': 0.24}


  8%|▊         | 460/5625 [20:31<3:54:54,  2.73s/it]

{'loss': 0.229, 'grad_norm': 11.667183876037598, 'learning_rate': 1.8364444444444446e-05, 'epoch': 0.25}


  8%|▊         | 470/5625 [20:58<3:53:57,  2.72s/it]

{'loss': 0.2582, 'grad_norm': 11.034014701843262, 'learning_rate': 1.832888888888889e-05, 'epoch': 0.25}


  9%|▊         | 480/5625 [21:25<3:52:42,  2.71s/it]

{'loss': 0.2944, 'grad_norm': 6.830078125, 'learning_rate': 1.8293333333333333e-05, 'epoch': 0.26}


  9%|▊         | 490/5625 [21:52<3:50:45,  2.70s/it]

{'loss': 0.2513, 'grad_norm': 4.388125896453857, 'learning_rate': 1.825777777777778e-05, 'epoch': 0.26}


  9%|▉         | 500/5625 [22:19<3:54:43,  2.75s/it]

{'loss': 0.2418, 'grad_norm': 14.415300369262695, 'learning_rate': 1.8222222222222224e-05, 'epoch': 0.27}


  9%|▉         | 510/5625 [22:47<3:57:32,  2.79s/it]

{'loss': 0.2191, 'grad_norm': 6.5063934326171875, 'learning_rate': 1.8186666666666666e-05, 'epoch': 0.27}


  9%|▉         | 520/5625 [23:15<3:53:07,  2.74s/it]

{'loss': 0.1914, 'grad_norm': 4.6506781578063965, 'learning_rate': 1.8151111111111112e-05, 'epoch': 0.28}


  9%|▉         | 530/5625 [23:42<3:51:05,  2.72s/it]

{'loss': 0.352, 'grad_norm': 4.886295795440674, 'learning_rate': 1.8115555555555558e-05, 'epoch': 0.28}


 10%|▉         | 540/5625 [24:09<3:52:38,  2.74s/it]

{'loss': 0.2607, 'grad_norm': 8.531559944152832, 'learning_rate': 1.8080000000000003e-05, 'epoch': 0.29}


 10%|▉         | 550/5625 [24:37<3:52:57,  2.75s/it]

{'loss': 0.225, 'grad_norm': 5.085914611816406, 'learning_rate': 1.8044444444444445e-05, 'epoch': 0.29}


 10%|▉         | 560/5625 [25:04<3:50:54,  2.74s/it]

{'loss': 0.2471, 'grad_norm': 3.216747999191284, 'learning_rate': 1.800888888888889e-05, 'epoch': 0.3}


 10%|█         | 570/5625 [25:32<3:49:45,  2.73s/it]

{'loss': 0.3189, 'grad_norm': 7.305058479309082, 'learning_rate': 1.7973333333333333e-05, 'epoch': 0.3}


 10%|█         | 580/5625 [25:59<3:45:13,  2.68s/it]

{'loss': 0.1887, 'grad_norm': 4.74844217300415, 'learning_rate': 1.793777777777778e-05, 'epoch': 0.31}


 10%|█         | 590/5625 [26:26<3:48:26,  2.72s/it]

{'loss': 0.326, 'grad_norm': 13.131484031677246, 'learning_rate': 1.7902222222222224e-05, 'epoch': 0.31}


 11%|█         | 600/5625 [26:53<3:45:46,  2.70s/it]

{'loss': 0.4355, 'grad_norm': 8.238785743713379, 'learning_rate': 1.7866666666666666e-05, 'epoch': 0.32}


 11%|█         | 610/5625 [27:20<3:45:17,  2.70s/it]

{'loss': 0.2268, 'grad_norm': 8.31891918182373, 'learning_rate': 1.783111111111111e-05, 'epoch': 0.33}


 11%|█         | 620/5625 [27:47<3:47:26,  2.73s/it]

{'loss': 0.3022, 'grad_norm': 10.503755569458008, 'learning_rate': 1.7795555555555557e-05, 'epoch': 0.33}


 11%|█         | 630/5625 [28:14<3:47:39,  2.73s/it]

{'loss': 0.2784, 'grad_norm': 10.737902641296387, 'learning_rate': 1.7760000000000003e-05, 'epoch': 0.34}


 11%|█▏        | 640/5625 [28:42<3:47:27,  2.74s/it]

{'loss': 0.2005, 'grad_norm': 1.7203443050384521, 'learning_rate': 1.7724444444444445e-05, 'epoch': 0.34}


 12%|█▏        | 650/5625 [29:09<3:45:52,  2.72s/it]

{'loss': 0.2493, 'grad_norm': 5.000640869140625, 'learning_rate': 1.768888888888889e-05, 'epoch': 0.35}


 12%|█▏        | 660/5625 [29:36<3:43:16,  2.70s/it]

{'loss': 0.2526, 'grad_norm': 7.030618667602539, 'learning_rate': 1.7653333333333336e-05, 'epoch': 0.35}


 12%|█▏        | 670/5625 [30:03<3:45:21,  2.73s/it]

{'loss': 0.2732, 'grad_norm': 5.974021911621094, 'learning_rate': 1.761777777777778e-05, 'epoch': 0.36}


 12%|█▏        | 680/5625 [30:31<3:52:19,  2.82s/it]

{'loss': 0.3383, 'grad_norm': 11.876544952392578, 'learning_rate': 1.7582222222222224e-05, 'epoch': 0.36}


 12%|█▏        | 690/5625 [30:58<3:45:01,  2.74s/it]

{'loss': 0.2253, 'grad_norm': 5.734549522399902, 'learning_rate': 1.7546666666666666e-05, 'epoch': 0.37}


 12%|█▏        | 700/5625 [31:25<3:40:25,  2.69s/it]

{'loss': 0.2409, 'grad_norm': 5.807636260986328, 'learning_rate': 1.751111111111111e-05, 'epoch': 0.37}


 13%|█▎        | 710/5625 [31:52<3:40:34,  2.69s/it]

{'loss': 0.2335, 'grad_norm': 10.784811019897461, 'learning_rate': 1.7475555555555557e-05, 'epoch': 0.38}


 13%|█▎        | 720/5625 [32:20<3:43:56,  2.74s/it]

{'loss': 0.149, 'grad_norm': 1.1077824831008911, 'learning_rate': 1.7440000000000002e-05, 'epoch': 0.38}


 13%|█▎        | 730/5625 [32:47<3:44:12,  2.75s/it]

{'loss': 0.2249, 'grad_norm': 7.006741523742676, 'learning_rate': 1.7404444444444445e-05, 'epoch': 0.39}


 13%|█▎        | 740/5625 [33:14<3:41:28,  2.72s/it]

{'loss': 0.3977, 'grad_norm': 12.056144714355469, 'learning_rate': 1.736888888888889e-05, 'epoch': 0.39}


 13%|█▎        | 750/5625 [33:41<3:38:47,  2.69s/it]

{'loss': 0.2982, 'grad_norm': 17.632282257080078, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.4}


 14%|█▎        | 760/5625 [34:08<3:38:58,  2.70s/it]

{'loss': 0.2504, 'grad_norm': 13.545416831970215, 'learning_rate': 1.729777777777778e-05, 'epoch': 0.41}


 14%|█▎        | 770/5625 [34:36<3:42:12,  2.75s/it]

{'loss': 0.2558, 'grad_norm': 5.918613433837891, 'learning_rate': 1.7262222222222223e-05, 'epoch': 0.41}


 14%|█▍        | 780/5625 [35:03<3:39:51,  2.72s/it]

{'loss': 0.2227, 'grad_norm': 3.7579104900360107, 'learning_rate': 1.7226666666666665e-05, 'epoch': 0.42}


 14%|█▍        | 790/5625 [35:30<3:37:24,  2.70s/it]

{'loss': 0.2473, 'grad_norm': 8.422957420349121, 'learning_rate': 1.719111111111111e-05, 'epoch': 0.42}


 14%|█▍        | 800/5625 [35:57<3:36:40,  2.69s/it]

{'loss': 0.2338, 'grad_norm': 10.09135913848877, 'learning_rate': 1.7155555555555557e-05, 'epoch': 0.43}


 14%|█▍        | 810/5625 [36:24<3:39:11,  2.73s/it]

{'loss': 0.326, 'grad_norm': 1.8441932201385498, 'learning_rate': 1.7120000000000002e-05, 'epoch': 0.43}


 15%|█▍        | 820/5625 [36:51<3:36:15,  2.70s/it]

{'loss': 0.2386, 'grad_norm': 8.489684104919434, 'learning_rate': 1.7084444444444444e-05, 'epoch': 0.44}


 15%|█▍        | 830/5625 [37:18<3:36:41,  2.71s/it]

{'loss': 0.2085, 'grad_norm': 2.050258159637451, 'learning_rate': 1.704888888888889e-05, 'epoch': 0.44}


 15%|█▍        | 840/5625 [37:46<3:41:43,  2.78s/it]

{'loss': 0.1484, 'grad_norm': 8.556418418884277, 'learning_rate': 1.7013333333333335e-05, 'epoch': 0.45}


 15%|█▌        | 850/5625 [38:13<3:39:04,  2.75s/it]

{'loss': 0.2151, 'grad_norm': 15.983199119567871, 'learning_rate': 1.697777777777778e-05, 'epoch': 0.45}


 15%|█▌        | 860/5625 [38:41<3:40:57,  2.78s/it]

{'loss': 0.4639, 'grad_norm': 16.557130813598633, 'learning_rate': 1.6942222222222223e-05, 'epoch': 0.46}


 15%|█▌        | 870/5625 [39:08<3:35:38,  2.72s/it]

{'loss': 0.2152, 'grad_norm': 6.674469470977783, 'learning_rate': 1.690666666666667e-05, 'epoch': 0.46}


 16%|█▌        | 880/5625 [39:35<3:34:00,  2.71s/it]

{'loss': 0.1539, 'grad_norm': 10.337217330932617, 'learning_rate': 1.687111111111111e-05, 'epoch': 0.47}


 16%|█▌        | 890/5625 [40:02<3:33:25,  2.70s/it]

{'loss': 0.2329, 'grad_norm': 9.434793472290039, 'learning_rate': 1.6835555555555556e-05, 'epoch': 0.47}


 16%|█▌        | 900/5625 [40:30<3:35:18,  2.73s/it]

{'loss': 0.1659, 'grad_norm': 5.493890762329102, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.48}


 16%|█▌        | 910/5625 [40:57<3:34:00,  2.72s/it]

{'loss': 0.2617, 'grad_norm': 6.332819938659668, 'learning_rate': 1.6764444444444444e-05, 'epoch': 0.49}


 16%|█▋        | 920/5625 [41:24<3:34:29,  2.74s/it]

{'loss': 0.2751, 'grad_norm': 7.062108516693115, 'learning_rate': 1.672888888888889e-05, 'epoch': 0.49}


 17%|█▋        | 930/5625 [41:51<3:31:33,  2.70s/it]

{'loss': 0.1749, 'grad_norm': 4.748235702514648, 'learning_rate': 1.6693333333333335e-05, 'epoch': 0.5}


 17%|█▋        | 940/5625 [42:19<3:35:26,  2.76s/it]

{'loss': 0.2242, 'grad_norm': 6.867767810821533, 'learning_rate': 1.665777777777778e-05, 'epoch': 0.5}


 17%|█▋        | 950/5625 [42:47<3:34:59,  2.76s/it]

{'loss': 0.2207, 'grad_norm': 2.1321587562561035, 'learning_rate': 1.6622222222222223e-05, 'epoch': 0.51}


 17%|█▋        | 960/5625 [43:13<3:27:17,  2.67s/it]

{'loss': 0.3239, 'grad_norm': 7.744108200073242, 'learning_rate': 1.6586666666666668e-05, 'epoch': 0.51}


 17%|█▋        | 970/5625 [43:40<3:25:40,  2.65s/it]

{'loss': 0.2161, 'grad_norm': 3.4093165397644043, 'learning_rate': 1.6551111111111114e-05, 'epoch': 0.52}


 17%|█▋        | 980/5625 [44:07<3:30:21,  2.72s/it]

{'loss': 0.3407, 'grad_norm': 10.257011413574219, 'learning_rate': 1.651555555555556e-05, 'epoch': 0.52}


 18%|█▊        | 990/5625 [44:34<3:27:52,  2.69s/it]

{'loss': 0.2633, 'grad_norm': 11.209884643554688, 'learning_rate': 1.648e-05, 'epoch': 0.53}


 18%|█▊        | 1000/5625 [45:00<3:23:35,  2.64s/it]

{'loss': 0.1393, 'grad_norm': 6.915460586547852, 'learning_rate': 1.6444444444444444e-05, 'epoch': 0.53}


 18%|█▊        | 1010/5625 [45:27<3:23:05,  2.64s/it]

{'loss': 0.2989, 'grad_norm': 13.20694351196289, 'learning_rate': 1.640888888888889e-05, 'epoch': 0.54}


 18%|█▊        | 1020/5625 [45:53<3:23:16,  2.65s/it]

{'loss': 0.3122, 'grad_norm': 13.23774528503418, 'learning_rate': 1.6373333333333335e-05, 'epoch': 0.54}


 18%|█▊        | 1030/5625 [46:20<3:26:48,  2.70s/it]

{'loss': 0.2971, 'grad_norm': 5.748946189880371, 'learning_rate': 1.633777777777778e-05, 'epoch': 0.55}


 18%|█▊        | 1040/5625 [46:47<3:26:15,  2.70s/it]

{'loss': 0.2338, 'grad_norm': 6.457972526550293, 'learning_rate': 1.6302222222222222e-05, 'epoch': 0.55}


 19%|█▊        | 1050/5625 [47:14<3:24:02,  2.68s/it]

{'loss': 0.3085, 'grad_norm': 13.74535846710205, 'learning_rate': 1.6266666666666668e-05, 'epoch': 0.56}


 19%|█▉        | 1060/5625 [47:40<3:22:00,  2.66s/it]

{'loss': 0.159, 'grad_norm': 6.1180853843688965, 'learning_rate': 1.6231111111111113e-05, 'epoch': 0.57}


 19%|█▉        | 1070/5625 [48:07<3:24:23,  2.69s/it]

{'loss': 0.2016, 'grad_norm': 1.4568464756011963, 'learning_rate': 1.619555555555556e-05, 'epoch': 0.57}


 19%|█▉        | 1080/5625 [48:34<3:21:13,  2.66s/it]

{'loss': 0.171, 'grad_norm': 8.095763206481934, 'learning_rate': 1.616e-05, 'epoch': 0.58}


 19%|█▉        | 1090/5625 [49:00<3:18:37,  2.63s/it]

{'loss': 0.3233, 'grad_norm': 12.538481712341309, 'learning_rate': 1.6124444444444443e-05, 'epoch': 0.58}


 20%|█▉        | 1100/5625 [49:26<3:17:08,  2.61s/it]

{'loss': 0.2243, 'grad_norm': 2.875605821609497, 'learning_rate': 1.608888888888889e-05, 'epoch': 0.59}


 20%|█▉        | 1110/5625 [49:53<3:20:38,  2.67s/it]

{'loss': 0.2715, 'grad_norm': 5.928725242614746, 'learning_rate': 1.6053333333333334e-05, 'epoch': 0.59}


 20%|█▉        | 1120/5625 [50:18<3:19:29,  2.66s/it]

{'loss': 0.2464, 'grad_norm': 2.8307576179504395, 'learning_rate': 1.601777777777778e-05, 'epoch': 0.6}


 20%|██        | 1130/5625 [50:44<3:14:59,  2.60s/it]

{'loss': 0.2078, 'grad_norm': 8.62723445892334, 'learning_rate': 1.5982222222222222e-05, 'epoch': 0.6}


 20%|██        | 1140/5625 [51:09<3:09:05,  2.53s/it]

{'loss': 0.3249, 'grad_norm': 1.7110470533370972, 'learning_rate': 1.5946666666666668e-05, 'epoch': 0.61}


 20%|██        | 1150/5625 [51:34<3:06:53,  2.51s/it]

{'loss': 0.251, 'grad_norm': 11.364871978759766, 'learning_rate': 1.5911111111111113e-05, 'epoch': 0.61}


 21%|██        | 1160/5625 [52:02<3:28:17,  2.80s/it]

{'loss': 0.2146, 'grad_norm': 6.277186393737793, 'learning_rate': 1.587555555555556e-05, 'epoch': 0.62}


 21%|██        | 1170/5625 [52:28<3:08:27,  2.54s/it]

{'loss': 0.1724, 'grad_norm': 4.517725467681885, 'learning_rate': 1.584e-05, 'epoch': 0.62}


 21%|██        | 1180/5625 [52:53<3:04:41,  2.49s/it]

{'loss': 0.2876, 'grad_norm': 7.320858001708984, 'learning_rate': 1.5804444444444446e-05, 'epoch': 0.63}


 21%|██        | 1190/5625 [53:18<3:07:01,  2.53s/it]

{'loss': 0.2097, 'grad_norm': 2.4553780555725098, 'learning_rate': 1.576888888888889e-05, 'epoch': 0.63}


 21%|██▏       | 1200/5625 [53:43<3:05:00,  2.51s/it]

{'loss': 0.1817, 'grad_norm': 5.256978511810303, 'learning_rate': 1.5733333333333334e-05, 'epoch': 0.64}


 22%|██▏       | 1210/5625 [54:08<3:03:25,  2.49s/it]

{'loss': 0.2002, 'grad_norm': 11.39245319366455, 'learning_rate': 1.569777777777778e-05, 'epoch': 0.65}


 22%|██▏       | 1220/5625 [54:33<3:00:18,  2.46s/it]

{'loss': 0.2436, 'grad_norm': 7.446401119232178, 'learning_rate': 1.5662222222222222e-05, 'epoch': 0.65}


 22%|██▏       | 1230/5625 [54:57<2:57:10,  2.42s/it]

{'loss': 0.3104, 'grad_norm': 6.895421028137207, 'learning_rate': 1.5626666666666667e-05, 'epoch': 0.66}


 22%|██▏       | 1240/5625 [55:21<2:56:51,  2.42s/it]

{'loss': 0.2377, 'grad_norm': 4.99259090423584, 'learning_rate': 1.5591111111111113e-05, 'epoch': 0.66}


 22%|██▏       | 1250/5625 [55:46<2:56:10,  2.42s/it]

{'loss': 0.3064, 'grad_norm': 6.466382026672363, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.67}


 22%|██▏       | 1260/5625 [56:10<2:56:07,  2.42s/it]

{'loss': 0.204, 'grad_norm': 4.862316131591797, 'learning_rate': 1.552e-05, 'epoch': 0.67}


 23%|██▎       | 1270/5625 [56:34<2:55:11,  2.41s/it]

{'loss': 0.2079, 'grad_norm': 8.428715705871582, 'learning_rate': 1.5484444444444446e-05, 'epoch': 0.68}


 23%|██▎       | 1280/5625 [56:58<2:54:59,  2.42s/it]

{'loss': 0.2408, 'grad_norm': 5.714728832244873, 'learning_rate': 1.544888888888889e-05, 'epoch': 0.68}


 23%|██▎       | 1290/5625 [57:22<2:58:37,  2.47s/it]

{'loss': 0.2337, 'grad_norm': 10.027798652648926, 'learning_rate': 1.5413333333333337e-05, 'epoch': 0.69}


 23%|██▎       | 1300/5625 [57:47<2:54:42,  2.42s/it]

{'loss': 0.2405, 'grad_norm': 12.364171028137207, 'learning_rate': 1.537777777777778e-05, 'epoch': 0.69}


 23%|██▎       | 1310/5625 [58:11<2:54:14,  2.42s/it]

{'loss': 0.2376, 'grad_norm': 7.690675735473633, 'learning_rate': 1.534222222222222e-05, 'epoch': 0.7}


 23%|██▎       | 1320/5625 [58:35<2:53:14,  2.41s/it]

{'loss': 0.2064, 'grad_norm': 2.14184832572937, 'learning_rate': 1.5306666666666667e-05, 'epoch': 0.7}


 24%|██▎       | 1330/5625 [58:59<2:52:16,  2.41s/it]

{'loss': 0.2426, 'grad_norm': 5.736889839172363, 'learning_rate': 1.5271111111111112e-05, 'epoch': 0.71}


 24%|██▍       | 1340/5625 [59:23<2:52:02,  2.41s/it]

{'loss': 0.2613, 'grad_norm': 6.879069805145264, 'learning_rate': 1.5235555555555556e-05, 'epoch': 0.71}


 24%|██▍       | 1350/5625 [59:49<2:58:21,  2.50s/it]

{'loss': 0.2425, 'grad_norm': 8.424748420715332, 'learning_rate': 1.5200000000000002e-05, 'epoch': 0.72}


 24%|██▍       | 1360/5625 [1:00:14<2:56:29,  2.48s/it]

{'loss': 0.1776, 'grad_norm': 9.644362449645996, 'learning_rate': 1.5164444444444446e-05, 'epoch': 0.73}


 24%|██▍       | 1370/5625 [1:00:38<2:53:09,  2.44s/it]

{'loss': 0.216, 'grad_norm': 12.625131607055664, 'learning_rate': 1.5128888888888891e-05, 'epoch': 0.73}


 25%|██▍       | 1380/5625 [1:01:03<2:54:01,  2.46s/it]

{'loss': 0.2429, 'grad_norm': 1.2342337369918823, 'learning_rate': 1.5093333333333335e-05, 'epoch': 0.74}


 25%|██▍       | 1390/5625 [1:01:27<2:52:15,  2.44s/it]

{'loss': 0.4043, 'grad_norm': 5.366588115692139, 'learning_rate': 1.505777777777778e-05, 'epoch': 0.74}


 25%|██▍       | 1400/5625 [1:01:52<2:50:02,  2.41s/it]

{'loss': 0.1984, 'grad_norm': 4.821699619293213, 'learning_rate': 1.5022222222222223e-05, 'epoch': 0.75}


 25%|██▌       | 1410/5625 [1:02:16<2:56:15,  2.51s/it]

{'loss': 0.2805, 'grad_norm': 10.031403541564941, 'learning_rate': 1.4986666666666667e-05, 'epoch': 0.75}


 25%|██▌       | 1411/5625 [1:02:19<2:56:29,  2.51s/it]

In [9]:
import torch

# Report the installed PyTorch version.
print("PyTorch version:", torch.__version__)

# Report whether the Metal (MPS) backend is usable on this machine.
mps_message = (
    "MPS is available and will be used for acceleration!"
    if torch.backends.mps.is_available()
    else "MPS is not available. Running on CPU."
)
print(mps_message)


PyTorch version: 2.2.2
MPS is available and will be used for acceleration!
