In [205]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from datasets import load_dataset
from timm import create_model
from transformers import (
    Trainer,
    TrainingArguments,
    ViTConfig,
    ViTForImageClassification,
    ViTImageProcessor,
    ViTModel,
)
import evaluate
import numpy as np


In [206]:
# Checkpoint pre-trained on ImageNet-21k; its classifier head is newly initialised for our
# labels when the classification model is built further down.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# `return_tensors` is an argument of each processor call (every call below passes
# return_tensors='pt'), so nothing needs to be configured here.
feature_extractor = ViTImageProcessor.from_pretrained(model_name_or_path)


In [207]:
# Initializing a ViT vit-base-patch16-224 style configuration
configuration = ViTConfig()

# Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
model = ViTModel(configuration)



In [208]:
def transform(example_batch):
    """Convert a batch of dataset rows into ViT model inputs plus a ``labels`` entry.

    Args:
        example_batch: batch dict from the dataset; ``'image'`` holds the images
            (PIL images according to the original author) and ``'label'`` the class ids.

    Returns:
        The output of the module-level ``feature_extractor`` (PyTorch tensors), with
        the batch's label ids added under the key ``'labels'``.

    NOTE(review): not referenced by any visible cell. ``preprocess`` is the transform
    actually attached to the datasets, and it stores the ids under ``'label'`` instead.
    """
    images = list(example_batch['image'])
    encoded = feature_extractor(images, return_tensors='pt')
    # Keep the label ids together with the pixel values.
    encoded['labels'] = example_batch['label']
    return encoded


In [209]:
# Local dataset root (train / test / validation splits with 'image' and 'label' columns).
# Set the TESIS_DATA_DIR environment variable to point elsewhere instead of editing this cell.
DATA_DIR = os.environ.get("TESIS_DATA_DIR", "C:/Tesis/DatasetBinario")
ds = load_dataset(DATA_DIR, num_proc=3)

Resolving data files:   0%|          | 0/20959 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/668 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/768 [00:00<?, ?it/s]

In [210]:
# Training / evaluation splits and the ClassLabel feature that holds the class names.
dataset_train, dataset_test = ds['train'], ds['test']
labels = dataset_train.features['label']
num_classes = len(set(dataset_train['label']))  # distinct label ids present in the training split
print(num_classes, labels)

2 ClassLabel(names=['Melanoma', 'No Melanoma'], id=None)


In [211]:
# Sanity check of the preprocessing on one training image: show the shape of the pixel
# tensor the processor produces (expected (1, channels, height, width) — confirm in the output).
example = feature_extractor(dataset_train[0]['image'], return_tensors='pt')
example['pixel_values'].shape

In [212]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [213]:
def preprocess(batch):
    """On-the-fly dataset transform: images -> processor tensors, plus the label ids.

    Args:
        batch: batch dict from the dataset with an ``'image'`` column and an
            integer ``'label'`` column.

    Returns:
        The module-level ``feature_extractor`` output (PyTorch tensors, including
        ``'pixel_values'``) with the label ids stored under ``'label'``, which is the
        key ``collate_fn`` reads.
    """
    # The tensors deliberately stay on the CPU. The Trainer moves every batch to the
    # training device itself. CUDA tensors created inside a dataset transform cannot be
    # pinned by the DataLoader and break multi-process data loading.
    inputs = feature_extractor(batch['image'], return_tensors='pt')
    inputs['label'] = batch['label']
    return inputs

In [214]:
# Attach `preprocess` lazily with `with_transform`: images are converted when rows are read,
# so no preprocessed copy of the dataset is materialised.
prepared_train, prepared_test = (
    split.with_transform(preprocess) for split in (dataset_train, dataset_test)
)

In [215]:
# The declared features are not changed by `with_transform`: they still list the raw
# 'image' / 'label' columns, not the 'pixel_values' that `preprocess` produces.
print(prepared_train.features)

{'image': Image(mode=None, decode=True, id=None), 'label': ClassLabel(names=['Melanoma', 'No Melanoma'], id=None)}


In [216]:
def collate_fn(batch):
    """Merge per-example dicts produced by ``preprocess`` into one model batch.

    Args:
        batch: list of dicts, each with a ``'pixel_values'`` tensor and an
            integer ``'label'``.

    Returns:
        dict with ``'pixel_values'`` (the example tensors stacked along a new
        leading dimension) and ``'labels'`` (a tensor of the label ids). These are
        the keyword names passed to the model.
    """
    pixel_batch = torch.stack([example['pixel_values'] for example in batch])
    label_ids = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_batch, 'labels': label_ids}


In [217]:
# Class names in label-id order (here ['Melanoma', 'No Melanoma']).
labels = ds['train'].features['label'].names

# Store the id <-> name mapping in the model config, so that saved checkpoints and
# inference pipelines report the class names instead of LABEL_0 / LABEL_1.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={i: name for i, name in enumerate(labels)},
    label2id={name: i for i, name in enumerate(labels)},
).to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [218]:
# Accuracy metric used by `compute_metrics` during evaluation.
metric = evaluate.load(path="accuracy")

In [219]:
def compute_metrics(eval_pred):
    """Accuracy of the arg-max prediction, for the Trainer's evaluation loop.

    Args:
        eval_pred: pair of (logits, label ids) supplied by the Trainer. The
            predicted class is taken along the last axis of the logits.

    Returns:
        The dict produced by the module-level ``metric`` (``{'accuracy': ...}``,
        which the Trainer logs as ``eval_accuracy``).
    """
    logits, label_ids = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)  # most likely class per example
    return metric.compute(predictions=predicted_ids, references=label_ids)

In [220]:
# Fine-tuning hyper-parameters.
# Evaluation and checkpointing share one interval, so that every evaluation is also a
# saved checkpoint that `load_best_model_at_end` can select. If evaluation ran more
# often than saving, those extra evaluations could never be chosen.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,   # library default, stated explicitly
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",     # named `eval_strategy` in newer transformers releases
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,              # the best checkpoint is retained in addition to the latest ones
    load_best_model_at_end=True,
    metric_for_best_model="loss",    # the default once load_best_model_at_end=True; lower is better
    num_train_epochs=4,
    learning_rate=2e-5,
    logging_steps=10,
    seed=42,                         # library default, stated explicitly for reproducibility
    remove_unused_columns=False,     # keep the raw 'image' column that `preprocess` reads
    push_to_hub=False,
    # Pinning only works for CPU tensors. Keep this False while the dataset transform
    # returns CUDA tensors; it can be turned on once the transform returns CPU tensors.
    dataloader_pin_memory=False,
)

In [221]:
# `prepared_test` drives evaluation and checkpoint selection during training. The
# 'validation' split is kept aside for the final evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    compute_metrics=compute_metrics,
    # Handing the image processor to the Trainer makes `save_model` store its config
    # next to the weights. Newer transformers releases call this `processing_class`.
    tokenizer=feature_extractor,
)

  trainer = Trainer(


In [222]:
# Fine-tune. Because load_best_model_at_end=True, `trainer.model` holds the best
# checkpoint (lowest eval loss) when this returns.
train_output = trainer.train()
train_output

  0%|          | 0/10480 [00:00<?, ?it/s]

{'loss': 0.6905, 'grad_norm': 1.552635669708252, 'learning_rate': 1.9980916030534353e-05, 'epoch': 0.0}
{'loss': 0.6801, 'grad_norm': 2.5676097869873047, 'learning_rate': 1.9961832061068704e-05, 'epoch': 0.01}
{'loss': 0.6664, 'grad_norm': 2.0010132789611816, 'learning_rate': 1.9942748091603055e-05, 'epoch': 0.01}
{'loss': 0.6459, 'grad_norm': 2.212742567062378, 'learning_rate': 1.9923664122137406e-05, 'epoch': 0.02}
{'loss': 0.6522, 'grad_norm': 2.185227155685425, 'learning_rate': 1.9904580152671757e-05, 'epoch': 0.02}
{'loss': 0.6228, 'grad_norm': 1.8474783897399902, 'learning_rate': 1.988549618320611e-05, 'epoch': 0.02}
{'loss': 0.6154, 'grad_norm': 2.6391870975494385, 'learning_rate': 1.986641221374046e-05, 'epoch': 0.03}
{'loss': 0.5935, 'grad_norm': 2.31415057182312, 'learning_rate': 1.984732824427481e-05, 'epoch': 0.03}
{'loss': 0.6021, 'grad_norm': 2.0074756145477295, 'learning_rate': 1.9828244274809162e-05, 'epoch': 0.03}
{'loss': 0.5949, 'grad_norm': 3.152873992919922, 'learn

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.47657182812690735, 'eval_accuracy': 0.7890625, 'eval_runtime': 38.4816, 'eval_samples_per_second': 19.958, 'eval_steps_per_second': 1.247, 'epoch': 0.19}
{'loss': 0.5834, 'grad_norm': 1.3982537984848022, 'learning_rate': 1.902671755725191e-05, 'epoch': 0.19}
{'loss': 0.4017, 'grad_norm': 11.956353187561035, 'learning_rate': 1.900763358778626e-05, 'epoch': 0.2}
{'loss': 0.3618, 'grad_norm': 4.305224418640137, 'learning_rate': 1.8988549618320614e-05, 'epoch': 0.2}
{'loss': 0.5166, 'grad_norm': 4.460570812225342, 'learning_rate': 1.8969465648854962e-05, 'epoch': 0.21}
{'loss': 0.5058, 'grad_norm': 8.367267608642578, 'learning_rate': 1.8950381679389313e-05, 'epoch': 0.21}
{'loss': 0.5071, 'grad_norm': 15.515195846557617, 'learning_rate': 1.8931297709923668e-05, 'epoch': 0.21}
{'loss': 0.5324, 'grad_norm': 5.047912120819092, 'learning_rate': 1.8912213740458016e-05, 'epoch': 0.22}
{'loss': 0.5067, 'grad_norm': 4.497890472412109, 'learning_rate': 1.889312977099237e-05, 'epoch'

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.40661001205444336, 'eval_accuracy': 0.828125, 'eval_runtime': 38.185, 'eval_samples_per_second': 20.113, 'eval_steps_per_second': 1.257, 'epoch': 0.38}
{'loss': 0.468, 'grad_norm': 3.6377511024475098, 'learning_rate': 1.8072519083969465e-05, 'epoch': 0.39}
{'loss': 0.4568, 'grad_norm': 4.04453706741333, 'learning_rate': 1.805343511450382e-05, 'epoch': 0.39}
{'loss': 0.4747, 'grad_norm': 3.9430742263793945, 'learning_rate': 1.803435114503817e-05, 'epoch': 0.39}
{'loss': 0.5022, 'grad_norm': 4.299839019775391, 'learning_rate': 1.8015267175572518e-05, 'epoch': 0.4}
{'loss': 0.4006, 'grad_norm': 2.2606661319732666, 'learning_rate': 1.7996183206106873e-05, 'epoch': 0.4}
{'loss': 0.5015, 'grad_norm': 4.632437229156494, 'learning_rate': 1.7977099236641224e-05, 'epoch': 0.4}
{'loss': 0.5348, 'grad_norm': 4.796138763427734, 'learning_rate': 1.7958015267175575e-05, 'epoch': 0.41}
{'loss': 0.4248, 'grad_norm': 1.5784387588500977, 'learning_rate': 1.7938931297709926e-05, 'epoch': 0

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.47458234429359436, 'eval_accuracy': 0.7981770833333334, 'eval_runtime': 40.8424, 'eval_samples_per_second': 18.804, 'eval_steps_per_second': 1.175, 'epoch': 0.57}
{'loss': 0.3856, 'grad_norm': 3.917534351348877, 'learning_rate': 1.7118320610687024e-05, 'epoch': 0.58}
{'loss': 0.3238, 'grad_norm': 6.601779937744141, 'learning_rate': 1.7099236641221375e-05, 'epoch': 0.58}
{'loss': 0.4469, 'grad_norm': 6.522968769073486, 'learning_rate': 1.7080152671755727e-05, 'epoch': 0.58}
{'loss': 0.4154, 'grad_norm': 6.177722930908203, 'learning_rate': 1.7061068702290078e-05, 'epoch': 0.59}
{'loss': 0.3138, 'grad_norm': 1.0613584518432617, 'learning_rate': 1.704198473282443e-05, 'epoch': 0.59}
{'loss': 0.4278, 'grad_norm': 6.441293239593506, 'learning_rate': 1.702290076335878e-05, 'epoch': 0.6}
{'loss': 0.3378, 'grad_norm': 4.462367534637451, 'learning_rate': 1.700381679389313e-05, 'epoch': 0.6}
{'loss': 0.4336, 'grad_norm': 3.110610246658325, 'learning_rate': 1.6984732824427482e-05, 

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.40182116627693176, 'eval_accuracy': 0.83984375, 'eval_runtime': 39.087, 'eval_samples_per_second': 19.648, 'eval_steps_per_second': 1.228, 'epoch': 0.76}
{'loss': 0.4367, 'grad_norm': 3.688586473464966, 'learning_rate': 1.616412213740458e-05, 'epoch': 0.77}
{'loss': 0.2953, 'grad_norm': 2.413177013397217, 'learning_rate': 1.6145038167938935e-05, 'epoch': 0.77}
{'loss': 0.3656, 'grad_norm': 6.669562816619873, 'learning_rate': 1.6125954198473283e-05, 'epoch': 0.77}
{'loss': 0.4591, 'grad_norm': 5.1793904304504395, 'learning_rate': 1.6106870229007634e-05, 'epoch': 0.78}
{'loss': 0.4789, 'grad_norm': 0.6805974841117859, 'learning_rate': 1.6087786259541985e-05, 'epoch': 0.78}
{'loss': 0.3778, 'grad_norm': 8.919229507446289, 'learning_rate': 1.6068702290076336e-05, 'epoch': 0.79}
{'loss': 0.3643, 'grad_norm': 3.5672786235809326, 'learning_rate': 1.604961832061069e-05, 'epoch': 0.79}
{'loss': 0.3638, 'grad_norm': 9.20043659210205, 'learning_rate': 1.6030534351145038e-05, 'epoc

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.37510597705841064, 'eval_accuracy': 0.8619791666666666, 'eval_runtime': 96.6479, 'eval_samples_per_second': 7.946, 'eval_steps_per_second': 0.497, 'epoch': 0.95}
{'loss': 0.4461, 'grad_norm': 2.7196547985076904, 'learning_rate': 1.520992366412214e-05, 'epoch': 0.96}
{'loss': 0.2865, 'grad_norm': 8.064708709716797, 'learning_rate': 1.519083969465649e-05, 'epoch': 0.96}
{'loss': 0.3231, 'grad_norm': 3.082542896270752, 'learning_rate': 1.517175572519084e-05, 'epoch': 0.97}
{'loss': 0.4351, 'grad_norm': 1.666566252708435, 'learning_rate': 1.5152671755725193e-05, 'epoch': 0.97}
{'loss': 0.3671, 'grad_norm': 7.271581172943115, 'learning_rate': 1.5133587786259543e-05, 'epoch': 0.97}
{'loss': 0.3594, 'grad_norm': 7.675926208496094, 'learning_rate': 1.5114503816793895e-05, 'epoch': 0.98}
{'loss': 0.3055, 'grad_norm': 2.544844627380371, 'learning_rate': 1.5095419847328245e-05, 'epoch': 0.98}
{'loss': 0.3108, 'grad_norm': 2.3471827507019043, 'learning_rate': 1.5076335877862596e-05

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3272438943386078, 'eval_accuracy': 0.8776041666666666, 'eval_runtime': 39.6634, 'eval_samples_per_second': 19.363, 'eval_steps_per_second': 1.21, 'epoch': 1.15}
{'loss': 0.3298, 'grad_norm': 7.682697772979736, 'learning_rate': 1.4255725190839696e-05, 'epoch': 1.15}
{'loss': 0.244, 'grad_norm': 7.023153781890869, 'learning_rate': 1.4236641221374049e-05, 'epoch': 1.15}
{'loss': 0.322, 'grad_norm': 1.2836869955062866, 'learning_rate': 1.4217557251908398e-05, 'epoch': 1.16}
{'loss': 0.3124, 'grad_norm': 5.359719276428223, 'learning_rate': 1.4198473282442749e-05, 'epoch': 1.16}
{'loss': 0.2671, 'grad_norm': 1.5788213014602661, 'learning_rate': 1.41793893129771e-05, 'epoch': 1.16}
{'loss': 0.4116, 'grad_norm': 4.87848424911499, 'learning_rate': 1.4160305343511451e-05, 'epoch': 1.17}
{'loss': 0.2252, 'grad_norm': 2.2266347408294678, 'learning_rate': 1.4141221374045804e-05, 'epoch': 1.17}
{'loss': 0.2572, 'grad_norm': 5.951173305511475, 'learning_rate': 1.4122137404580154e-05, 

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3653348982334137, 'eval_accuracy': 0.8684895833333334, 'eval_runtime': 41.1424, 'eval_samples_per_second': 18.667, 'eval_steps_per_second': 1.167, 'epoch': 1.34}
{'loss': 0.4145, 'grad_norm': 12.314449310302734, 'learning_rate': 1.3301526717557254e-05, 'epoch': 1.34}
{'loss': 0.2531, 'grad_norm': 15.70746898651123, 'learning_rate': 1.3282442748091605e-05, 'epoch': 1.34}
{'loss': 0.4129, 'grad_norm': 10.081790924072266, 'learning_rate': 1.3263358778625954e-05, 'epoch': 1.35}
{'loss': 0.2236, 'grad_norm': 1.6651313304901123, 'learning_rate': 1.3244274809160307e-05, 'epoch': 1.35}
{'loss': 0.4765, 'grad_norm': 12.385013580322266, 'learning_rate': 1.3225190839694656e-05, 'epoch': 1.35}
{'loss': 0.3243, 'grad_norm': 4.126578330993652, 'learning_rate': 1.3206106870229009e-05, 'epoch': 1.36}
{'loss': 0.2258, 'grad_norm': 4.643072128295898, 'learning_rate': 1.318702290076336e-05, 'epoch': 1.36}
{'loss': 0.1764, 'grad_norm': 9.785735130310059, 'learning_rate': 1.316793893129771e

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3570276200771332, 'eval_accuracy': 0.8606770833333334, 'eval_runtime': 38.7944, 'eval_samples_per_second': 19.797, 'eval_steps_per_second': 1.237, 'epoch': 1.53}
{'loss': 0.2591, 'grad_norm': 16.431716918945312, 'learning_rate': 1.234732824427481e-05, 'epoch': 1.53}
{'loss': 0.2235, 'grad_norm': 10.535872459411621, 'learning_rate': 1.2328244274809162e-05, 'epoch': 1.53}
{'loss': 0.094, 'grad_norm': 2.7303225994110107, 'learning_rate': 1.2309160305343514e-05, 'epoch': 1.54}
{'loss': 0.5461, 'grad_norm': 5.5601959228515625, 'learning_rate': 1.2290076335877863e-05, 'epoch': 1.54}
{'loss': 0.2462, 'grad_norm': 9.193886756896973, 'learning_rate': 1.2270992366412216e-05, 'epoch': 1.55}
{'loss': 0.2121, 'grad_norm': 7.927665710449219, 'learning_rate': 1.2251908396946565e-05, 'epoch': 1.55}
{'loss': 0.269, 'grad_norm': 9.644990921020508, 'learning_rate': 1.2232824427480916e-05, 'epoch': 1.55}
{'loss': 0.2601, 'grad_norm': 9.923142433166504, 'learning_rate': 1.2213740458015269e-

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3869473934173584, 'eval_accuracy': 0.8697916666666666, 'eval_runtime': 39.4568, 'eval_samples_per_second': 19.464, 'eval_steps_per_second': 1.217, 'epoch': 1.72}
{'loss': 0.1562, 'grad_norm': 8.081207275390625, 'learning_rate': 1.1393129770992369e-05, 'epoch': 1.72}
{'loss': 0.2906, 'grad_norm': 0.13061267137527466, 'learning_rate': 1.1374045801526718e-05, 'epoch': 1.73}
{'loss': 0.3291, 'grad_norm': 20.25176239013672, 'learning_rate': 1.135496183206107e-05, 'epoch': 1.73}
{'loss': 0.2453, 'grad_norm': 6.555213451385498, 'learning_rate': 1.133587786259542e-05, 'epoch': 1.73}
{'loss': 0.2018, 'grad_norm': 7.30574893951416, 'learning_rate': 1.1316793893129772e-05, 'epoch': 1.74}
{'loss': 0.1838, 'grad_norm': 0.5705980658531189, 'learning_rate': 1.1297709923664125e-05, 'epoch': 1.74}
{'loss': 0.3289, 'grad_norm': 16.68720817565918, 'learning_rate': 1.1278625954198474e-05, 'epoch': 1.74}
{'loss': 0.3884, 'grad_norm': 32.546539306640625, 'learning_rate': 1.1259541984732825e-

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.26207318902015686, 'eval_accuracy': 0.9153645833333334, 'eval_runtime': 37.9845, 'eval_samples_per_second': 20.219, 'eval_steps_per_second': 1.264, 'epoch': 1.91}
{'loss': 0.2589, 'grad_norm': 19.93729591369629, 'learning_rate': 1.0438931297709925e-05, 'epoch': 1.91}
{'loss': 0.2708, 'grad_norm': 0.6772887706756592, 'learning_rate': 1.0419847328244274e-05, 'epoch': 1.92}
{'loss': 0.2541, 'grad_norm': 12.062906265258789, 'learning_rate': 1.0400763358778627e-05, 'epoch': 1.92}
{'loss': 0.3142, 'grad_norm': 36.21210861206055, 'learning_rate': 1.0381679389312977e-05, 'epoch': 1.92}
{'loss': 0.2464, 'grad_norm': 16.43659019470215, 'learning_rate': 1.036259541984733e-05, 'epoch': 1.93}
{'loss': 0.2321, 'grad_norm': 5.877859592437744, 'learning_rate': 1.034351145038168e-05, 'epoch': 1.93}
{'loss': 0.3302, 'grad_norm': 15.939898490905762, 'learning_rate': 1.032442748091603e-05, 'epoch': 1.94}
{'loss': 0.2147, 'grad_norm': 17.8555965423584, 'learning_rate': 1.0305343511450383e-0

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.30510497093200684, 'eval_accuracy': 0.9192708333333334, 'eval_runtime': 38.4104, 'eval_samples_per_second': 19.995, 'eval_steps_per_second': 1.25, 'epoch': 2.1}
{'loss': 0.0933, 'grad_norm': 20.136791229248047, 'learning_rate': 9.484732824427481e-06, 'epoch': 2.1}
{'loss': 0.1284, 'grad_norm': 0.97950279712677, 'learning_rate': 9.465648854961834e-06, 'epoch': 2.11}
{'loss': 0.2031, 'grad_norm': 3.167043924331665, 'learning_rate': 9.446564885496185e-06, 'epoch': 2.11}
{'loss': 0.1916, 'grad_norm': 13.306791305541992, 'learning_rate': 9.427480916030534e-06, 'epoch': 2.11}
{'loss': 0.0893, 'grad_norm': 0.254468709230423, 'learning_rate': 9.408396946564886e-06, 'epoch': 2.12}
{'loss': 0.3082, 'grad_norm': 0.13270103931427002, 'learning_rate': 9.389312977099237e-06, 'epoch': 2.12}
{'loss': 0.1103, 'grad_norm': 0.485128790140152, 'learning_rate': 9.37022900763359e-06, 'epoch': 2.13}
{'loss': 0.2137, 'grad_norm': 6.281590461730957, 'learning_rate': 9.351145038167939e-06, 'epoc

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.34112048149108887, 'eval_accuracy': 0.90234375, 'eval_runtime': 38.831, 'eval_samples_per_second': 19.778, 'eval_steps_per_second': 1.236, 'epoch': 2.29}
{'loss': 0.1135, 'grad_norm': 4.1303935050964355, 'learning_rate': 8.530534351145039e-06, 'epoch': 2.29}
{'loss': 0.1948, 'grad_norm': 0.11243554949760437, 'learning_rate': 8.51145038167939e-06, 'epoch': 2.3}
{'loss': 0.2259, 'grad_norm': 2.5493505001068115, 'learning_rate': 8.492366412213741e-06, 'epoch': 2.3}
{'loss': 0.2569, 'grad_norm': 10.9332275390625, 'learning_rate': 8.473282442748092e-06, 'epoch': 2.31}
{'loss': 0.1359, 'grad_norm': 1.6692614555358887, 'learning_rate': 8.454198473282443e-06, 'epoch': 2.31}
{'loss': 0.1142, 'grad_norm': 0.24639663100242615, 'learning_rate': 8.435114503816794e-06, 'epoch': 2.31}
{'loss': 0.2043, 'grad_norm': 0.4010794758796692, 'learning_rate': 8.416030534351146e-06, 'epoch': 2.32}
{'loss': 0.0824, 'grad_norm': 1.49448561668396, 'learning_rate': 8.396946564885497e-06, 'epoch': 2

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.34681010246276855, 'eval_accuracy': 0.9049479166666666, 'eval_runtime': 188.894, 'eval_samples_per_second': 4.066, 'eval_steps_per_second': 0.254, 'epoch': 2.48}
{'loss': 0.0565, 'grad_norm': 0.07635242491960526, 'learning_rate': 7.5763358778625966e-06, 'epoch': 2.48}
{'loss': 0.2311, 'grad_norm': 0.10588251054286957, 'learning_rate': 7.557251908396948e-06, 'epoch': 2.49}
{'loss': 0.0722, 'grad_norm': 3.31844425201416, 'learning_rate': 7.538167938931298e-06, 'epoch': 2.49}
{'loss': 0.1606, 'grad_norm': 0.18084588646888733, 'learning_rate': 7.519083969465649e-06, 'epoch': 2.5}
{'loss': 0.327, 'grad_norm': 7.961402893066406, 'learning_rate': 7.500000000000001e-06, 'epoch': 2.5}
{'loss': 0.1244, 'grad_norm': 0.352988600730896, 'learning_rate': 7.480916030534352e-06, 'epoch': 2.5}
{'loss': 0.0897, 'grad_norm': 1.2033518552780151, 'learning_rate': 7.461832061068703e-06, 'epoch': 2.51}
{'loss': 0.2713, 'grad_norm': 0.8200973868370056, 'learning_rate': 7.4427480916030536e-06, 

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3500773012638092, 'eval_accuracy': 0.9075520833333334, 'eval_runtime': 116.0406, 'eval_samples_per_second': 6.618, 'eval_steps_per_second': 0.414, 'epoch': 2.67}
{'loss': 0.1343, 'grad_norm': 8.155163764953613, 'learning_rate': 6.6221374045801534e-06, 'epoch': 2.68}
{'loss': 0.3252, 'grad_norm': 9.5454683303833, 'learning_rate': 6.6030534351145046e-06, 'epoch': 2.68}
{'loss': 0.1222, 'grad_norm': 22.338550567626953, 'learning_rate': 6.583969465648855e-06, 'epoch': 2.68}
{'loss': 0.1654, 'grad_norm': 0.4974415600299835, 'learning_rate': 6.564885496183207e-06, 'epoch': 2.69}
{'loss': 0.1539, 'grad_norm': 0.15317018330097198, 'learning_rate': 6.545801526717558e-06, 'epoch': 2.69}
{'loss': 0.0668, 'grad_norm': 0.23272329568862915, 'learning_rate': 6.526717557251909e-06, 'epoch': 2.69}
{'loss': 0.0958, 'grad_norm': 0.08949902653694153, 'learning_rate': 6.507633587786259e-06, 'epoch': 2.7}
{'loss': 0.1455, 'grad_norm': 4.653984546661377, 'learning_rate': 6.488549618320611e-06

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3627554476261139, 'eval_accuracy': 0.9114583333333334, 'eval_runtime': 145.7074, 'eval_samples_per_second': 5.271, 'eval_steps_per_second': 0.329, 'epoch': 2.86}
{'loss': 0.2213, 'grad_norm': 0.03610485419631004, 'learning_rate': 5.66793893129771e-06, 'epoch': 2.87}
{'loss': 0.144, 'grad_norm': 0.826445460319519, 'learning_rate': 5.648854961832062e-06, 'epoch': 2.87}
{'loss': 0.1305, 'grad_norm': 0.1436082422733307, 'learning_rate': 5.6297709923664126e-06, 'epoch': 2.87}
{'loss': 0.0791, 'grad_norm': 35.430938720703125, 'learning_rate': 5.610687022900764e-06, 'epoch': 2.88}
{'loss': 0.1618, 'grad_norm': 6.144839763641357, 'learning_rate': 5.591603053435115e-06, 'epoch': 2.88}
{'loss': 0.1649, 'grad_norm': 0.17348115146160126, 'learning_rate': 5.572519083969467e-06, 'epoch': 2.89}
{'loss': 0.1822, 'grad_norm': 3.0512537956237793, 'learning_rate': 5.553435114503817e-06, 'epoch': 2.89}
{'loss': 0.1187, 'grad_norm': 9.641916275024414, 'learning_rate': 5.534351145038168e-06,

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.558494508266449, 'eval_accuracy': 0.86328125, 'eval_runtime': 39.2216, 'eval_samples_per_second': 19.581, 'eval_steps_per_second': 1.224, 'epoch': 3.05}
{'loss': 0.0085, 'grad_norm': 0.03301946446299553, 'learning_rate': 4.713740458015267e-06, 'epoch': 3.06}
{'loss': 0.0369, 'grad_norm': 0.08191090077161789, 'learning_rate': 4.694656488549618e-06, 'epoch': 3.06}
{'loss': 0.005, 'grad_norm': 1.3144664764404297, 'learning_rate': 4.6755725190839695e-06, 'epoch': 3.06}
{'loss': 0.0054, 'grad_norm': 0.046782322227954865, 'learning_rate': 4.656488549618321e-06, 'epoch': 3.07}
{'loss': 0.0904, 'grad_norm': 10.398768424987793, 'learning_rate': 4.6374045801526726e-06, 'epoch': 3.07}
{'loss': 0.067, 'grad_norm': 0.029651924967765808, 'learning_rate': 4.618320610687023e-06, 'epoch': 3.08}
{'loss': 0.0533, 'grad_norm': 0.09175103157758713, 'learning_rate': 4.599236641221375e-06, 'epoch': 3.08}
{'loss': 0.169, 'grad_norm': 0.8410604000091553, 'learning_rate': 4.580152671755725e-06, 

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.349542498588562, 'eval_accuracy': 0.9166666666666666, 'eval_runtime': 40.7596, 'eval_samples_per_second': 18.842, 'eval_steps_per_second': 1.178, 'epoch': 3.24}
{'loss': 0.0043, 'grad_norm': 0.048534225672483444, 'learning_rate': 3.7595419847328245e-06, 'epoch': 3.25}
{'loss': 0.0049, 'grad_norm': 2.211925983428955, 'learning_rate': 3.740458015267176e-06, 'epoch': 3.25}
{'loss': 0.2227, 'grad_norm': 0.04052073508501053, 'learning_rate': 3.7213740458015268e-06, 'epoch': 3.26}
{'loss': 0.0495, 'grad_norm': 0.03189254552125931, 'learning_rate': 3.7022900763358783e-06, 'epoch': 3.26}
{'loss': 0.0429, 'grad_norm': 0.03395929932594299, 'learning_rate': 3.683206106870229e-06, 'epoch': 3.26}
{'loss': 0.135, 'grad_norm': 17.017671585083008, 'learning_rate': 3.6641221374045806e-06, 'epoch': 3.27}
{'loss': 0.0257, 'grad_norm': 0.046125736087560654, 'learning_rate': 3.6450381679389317e-06, 'epoch': 3.27}
{'loss': 0.0849, 'grad_norm': 0.09006045758724213, 'learning_rate': 3.62595419

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.3506127893924713, 'eval_accuracy': 0.9140625, 'eval_runtime': 39.2158, 'eval_samples_per_second': 19.584, 'eval_steps_per_second': 1.224, 'epoch': 3.44}
{'loss': 0.0549, 'grad_norm': 0.07913768291473389, 'learning_rate': 2.805343511450382e-06, 'epoch': 3.44}
{'loss': 0.0256, 'grad_norm': 0.02899402752518654, 'learning_rate': 2.7862595419847334e-06, 'epoch': 3.44}
{'loss': 0.0045, 'grad_norm': 0.04354069381952286, 'learning_rate': 2.767175572519084e-06, 'epoch': 3.45}
{'loss': 0.0602, 'grad_norm': 0.0358809158205986, 'learning_rate': 2.7480916030534356e-06, 'epoch': 3.45}
{'loss': 0.0659, 'grad_norm': 4.546030521392822, 'learning_rate': 2.7290076335877863e-06, 'epoch': 3.45}
{'loss': 0.0751, 'grad_norm': 0.039456285536289215, 'learning_rate': 2.709923664122138e-06, 'epoch': 3.46}
{'loss': 0.0814, 'grad_norm': 0.09595288336277008, 'learning_rate': 2.6908396946564886e-06, 'epoch': 3.46}
{'loss': 0.0135, 'grad_norm': 0.03707500547170639, 'learning_rate': 2.67175572519084e-0

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.37487515807151794, 'eval_accuracy': 0.9114583333333334, 'eval_runtime': 76.893, 'eval_samples_per_second': 9.988, 'eval_steps_per_second': 0.624, 'epoch': 3.63}
{'loss': 0.0041, 'grad_norm': 0.04281257838010788, 'learning_rate': 1.8511450381679392e-06, 'epoch': 3.63}
{'loss': 0.027, 'grad_norm': 0.8221969604492188, 'learning_rate': 1.8320610687022903e-06, 'epoch': 3.63}
{'loss': 0.0273, 'grad_norm': 0.026127588003873825, 'learning_rate': 1.8129770992366414e-06, 'epoch': 3.64}
{'loss': 0.2178, 'grad_norm': 8.104713439941406, 'learning_rate': 1.7938931297709925e-06, 'epoch': 3.64}
{'loss': 0.0035, 'grad_norm': 0.024140330031514168, 'learning_rate': 1.7748091603053436e-06, 'epoch': 3.65}
{'loss': 0.0893, 'grad_norm': 1.6791025400161743, 'learning_rate': 1.7557251908396948e-06, 'epoch': 3.65}
{'loss': 0.1043, 'grad_norm': 0.22810927033424377, 'learning_rate': 1.736641221374046e-06, 'epoch': 3.65}
{'loss': 0.1035, 'grad_norm': 0.0848689004778862, 'learning_rate': 1.717557251

  0%|          | 0/48 [00:00<?, ?it/s]

{'eval_loss': 0.4121273458003998, 'eval_accuracy': 0.9010416666666666, 'eval_runtime': 70.8672, 'eval_samples_per_second': 10.837, 'eval_steps_per_second': 0.677, 'epoch': 3.82}
{'loss': 0.0706, 'grad_norm': 0.02534552291035652, 'learning_rate': 8.969465648854963e-07, 'epoch': 3.82}
{'loss': 0.1329, 'grad_norm': 0.040772199630737305, 'learning_rate': 8.778625954198474e-07, 'epoch': 3.82}
{'loss': 0.0526, 'grad_norm': 0.0337890163064003, 'learning_rate': 8.587786259541986e-07, 'epoch': 3.83}
{'loss': 0.0036, 'grad_norm': 0.038144223392009735, 'learning_rate': 8.396946564885497e-07, 'epoch': 3.83}
{'loss': 0.1351, 'grad_norm': 0.027762508019804955, 'learning_rate': 8.206106870229009e-07, 'epoch': 3.84}
{'loss': 0.003, 'grad_norm': 0.031778834760189056, 'learning_rate': 8.01526717557252e-07, 'epoch': 3.84}
{'loss': 0.0065, 'grad_norm': 0.11197499185800552, 'learning_rate': 7.824427480916032e-07, 'epoch': 3.84}
{'loss': 0.1622, 'grad_norm': 0.0315922275185585, 'learning_rate': 7.6335877862

TrainOutput(global_step=10480, training_loss=0.23013947155212855, metrics={'train_runtime': 22792.3306, 'train_samples_per_second': 3.678, 'train_steps_per_second': 0.46, 'total_flos': 6.496618441328935e+18, 'train_loss': 0.23013947155212855, 'epoch': 4.0})

In [None]:
# Persist the fine-tuned model (the best checkpoint, since load_best_model_at_end=True)
# together with the image processor handed to the Trainer, into training_args.output_dir.
# NOTE(review): this cell shows `In [None]`, i.e. it was not executed in the saved run.
trainer.save_model()
# Also write the trainer state (log history, best checkpoint info) next to it.
trainer.save_state()


In [226]:
# Final evaluation on the 'validation' split. It was not used for checkpoint selection,
# because the Trainer's eval_dataset is `prepared_test`.
prepared_val = ds['validation'].with_transform(preprocess)
trainer.evaluate(eval_dataset=prepared_val)

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.3163492679595947,
 'eval_accuracy': 0.907185628742515,
 'eval_runtime': 70.2505,
 'eval_samples_per_second': 9.509,
 'eval_steps_per_second': 0.598,
 'epoch': 4.0}

: 