# Finetuning ResNet50 and ResNet as feature extractor


## Data

In [32]:
from datasets import load_dataset

cifar10dataset = load_dataset("cifar10", split="train")#train[:10000]

In [33]:
cifar10dataset = cifar10dataset.train_test_split(test_size=0.2)

In [34]:
labels = cifar10dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [35]:
from transformers import AutoImageProcessor

checkpoint = "microsoft/resnet-50"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [36]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [37]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["img"]]
    del examples["img"]
    return examples

In [38]:
cifar10dataset = cifar10dataset.with_transform(transforms)

In [39]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [40]:
import evaluate

accuracy = evaluate.load("accuracy")

In [41]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

## Finetuning the ResNet50 

In [42]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
training_args = TrainingArguments(
    output_dir="cifar10_resnet50",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8, # memory error with 16
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=cifar10dataset["train"],
    eval_dataset=cifar10dataset["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()



  0%|          | 0/7500 [00:00<?, ?it/s]

{'loss': 2.3071, 'learning_rate': 6.666666666666667e-07, 'epoch': 0.01}
{'loss': 2.3077, 'learning_rate': 1.3333333333333334e-06, 'epoch': 0.02}
{'loss': 2.3074, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.02}
{'loss': 2.3116, 'learning_rate': 2.666666666666667e-06, 'epoch': 0.03}
{'loss': 2.3091, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.04}
{'loss': 2.3106, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.05}
{'loss': 2.3098, 'learning_rate': 4.666666666666667e-06, 'epoch': 0.06}
{'loss': 2.3096, 'learning_rate': 5.333333333333334e-06, 'epoch': 0.06}
{'loss': 2.3107, 'learning_rate': 6e-06, 'epoch': 0.07}
{'loss': 2.3091, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.08}
{'loss': 2.3125, 'learning_rate': 7.333333333333334e-06, 'epoch': 0.09}
{'loss': 2.3058, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.1}
{'loss': 2.3084, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.1}
{'loss': 2.2985, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.11}
{'loss'

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 1.141909122467041, 'eval_accuracy': 0.6282, 'eval_runtime': 56.076, 'eval_samples_per_second': 178.33, 'eval_steps_per_second': 22.291, 'epoch': 1.0}
{'loss': 1.3675, 'learning_rate': 4.6222222222222224e-05, 'epoch': 1.01}
{'loss': 1.377, 'learning_rate': 4.6148148148148154e-05, 'epoch': 1.02}
{'loss': 1.3174, 'learning_rate': 4.607407407407408e-05, 'epoch': 1.02}
{'loss': 1.4117, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.03}
{'loss': 1.3186, 'learning_rate': 4.592592592592593e-05, 'epoch': 1.04}
{'loss': 1.35, 'learning_rate': 4.585185185185185e-05, 'epoch': 1.05}
{'loss': 1.3533, 'learning_rate': 4.577777777777778e-05, 'epoch': 1.06}
{'loss': 1.2767, 'learning_rate': 4.5703703703703706e-05, 'epoch': 1.06}
{'loss': 1.3315, 'learning_rate': 4.5629629629629636e-05, 'epoch': 1.07}
{'loss': 1.3054, 'learning_rate': 4.555555555555556e-05, 'epoch': 1.08}
{'loss': 1.1898, 'learning_rate': 4.548148148148149e-05, 'epoch': 1.09}
{'loss': 1.2954, 'learning_rate': 4.5407407

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.7580733299255371, 'eval_accuracy': 0.7475, 'eval_runtime': 56.3037, 'eval_samples_per_second': 177.608, 'eval_steps_per_second': 22.201, 'epoch': 2.0}
{'loss': 0.9975, 'learning_rate': 3.6962962962962966e-05, 'epoch': 2.01}
{'loss': 0.9904, 'learning_rate': 3.688888888888889e-05, 'epoch': 2.02}
{'loss': 1.0231, 'learning_rate': 3.681481481481482e-05, 'epoch': 2.02}
{'loss': 1.0381, 'learning_rate': 3.674074074074074e-05, 'epoch': 2.03}
{'loss': 1.0427, 'learning_rate': 3.6666666666666666e-05, 'epoch': 2.04}
{'loss': 0.9574, 'learning_rate': 3.6592592592592596e-05, 'epoch': 2.05}
{'loss': 0.9486, 'learning_rate': 3.651851851851852e-05, 'epoch': 2.06}
{'loss': 0.8803, 'learning_rate': 3.644444444444445e-05, 'epoch': 2.06}
{'loss': 1.0585, 'learning_rate': 3.637037037037037e-05, 'epoch': 2.07}
{'loss': 0.9199, 'learning_rate': 3.62962962962963e-05, 'epoch': 2.08}
{'loss': 1.02, 'learning_rate': 3.6222222222222225e-05, 'epoch': 2.09}
{'loss': 0.9376, 'learning_rate': 3.6148

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.702033281326294, 'eval_accuracy': 0.7617, 'eval_runtime': 55.8385, 'eval_samples_per_second': 179.088, 'eval_steps_per_second': 22.386, 'epoch': 3.0}
{'loss': 0.9303, 'learning_rate': 2.7703703703703706e-05, 'epoch': 3.01}
{'loss': 0.8238, 'learning_rate': 2.7629629629629632e-05, 'epoch': 3.02}
{'loss': 0.9016, 'learning_rate': 2.7555555555555555e-05, 'epoch': 3.02}
{'loss': 0.9564, 'learning_rate': 2.7481481481481482e-05, 'epoch': 3.03}
{'loss': 0.8685, 'learning_rate': 2.7407407407407408e-05, 'epoch': 3.04}
{'loss': 1.0421, 'learning_rate': 2.733333333333333e-05, 'epoch': 3.05}
{'loss': 0.9547, 'learning_rate': 2.725925925925926e-05, 'epoch': 3.06}
{'loss': 0.8403, 'learning_rate': 2.7185185185185184e-05, 'epoch': 3.06}
{'loss': 0.929, 'learning_rate': 2.7111111111111114e-05, 'epoch': 3.07}
{'loss': 0.8283, 'learning_rate': 2.7037037037037037e-05, 'epoch': 3.08}
{'loss': 0.8618, 'learning_rate': 2.696296296296296e-05, 'epoch': 3.09}
{'loss': 0.8765, 'learning_rate': 2

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.6384262442588806, 'eval_accuracy': 0.7813, 'eval_runtime': 55.4696, 'eval_samples_per_second': 180.279, 'eval_steps_per_second': 22.535, 'epoch': 4.0}
{'loss': 0.7929, 'learning_rate': 1.8444444444444445e-05, 'epoch': 4.01}
{'loss': 0.8662, 'learning_rate': 1.837037037037037e-05, 'epoch': 4.02}
{'loss': 0.9287, 'learning_rate': 1.8296296296296298e-05, 'epoch': 4.02}
{'loss': 0.7372, 'learning_rate': 1.8222222222222224e-05, 'epoch': 4.03}
{'loss': 0.7794, 'learning_rate': 1.814814814814815e-05, 'epoch': 4.04}
{'loss': 0.7836, 'learning_rate': 1.8074074074074074e-05, 'epoch': 4.05}
{'loss': 0.8387, 'learning_rate': 1.8e-05, 'epoch': 4.06}
{'loss': 0.8104, 'learning_rate': 1.7925925925925927e-05, 'epoch': 4.06}
{'loss': 0.8168, 'learning_rate': 1.7851851851851853e-05, 'epoch': 4.07}
{'loss': 0.8633, 'learning_rate': 1.777777777777778e-05, 'epoch': 4.08}
{'loss': 0.8798, 'learning_rate': 1.7703703703703706e-05, 'epoch': 4.09}
{'loss': 0.9008, 'learning_rate': 1.762962962962

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.6403951048851013, 'eval_accuracy': 0.7818, 'eval_runtime': 55.5664, 'eval_samples_per_second': 179.965, 'eval_steps_per_second': 22.496, 'epoch': 5.0}
{'loss': 0.9541, 'learning_rate': 9.185185185185186e-06, 'epoch': 5.01}
{'loss': 0.7303, 'learning_rate': 9.111111111111112e-06, 'epoch': 5.02}
{'loss': 0.769, 'learning_rate': 9.037037037037037e-06, 'epoch': 5.02}
{'loss': 0.8642, 'learning_rate': 8.962962962962963e-06, 'epoch': 5.03}
{'loss': 0.8983, 'learning_rate': 8.88888888888889e-06, 'epoch': 5.04}
{'loss': 0.8876, 'learning_rate': 8.814814814814815e-06, 'epoch': 5.05}
{'loss': 0.8101, 'learning_rate': 8.740740740740741e-06, 'epoch': 5.06}
{'loss': 0.7359, 'learning_rate': 8.666666666666668e-06, 'epoch': 5.06}
{'loss': 0.7751, 'learning_rate': 8.592592592592593e-06, 'epoch': 5.07}
{'loss': 0.6771, 'learning_rate': 8.518518518518519e-06, 'epoch': 5.08}
{'loss': 0.8222, 'learning_rate': 8.444444444444446e-06, 'epoch': 5.09}
{'loss': 0.8365, 'learning_rate': 8.3703703

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.6169978380203247, 'eval_accuracy': 0.7932, 'eval_runtime': 56.0714, 'eval_samples_per_second': 178.344, 'eval_steps_per_second': 22.293, 'epoch': 6.0}
{'train_runtime': 4375.8577, 'train_samples_per_second': 54.846, 'train_steps_per_second': 1.714, 'train_loss': 1.103211403465271, 'epoch': 6.0}


TrainOutput(global_step=7500, training_loss=1.103211403465271, metrics={'train_runtime': 4375.8577, 'train_samples_per_second': 54.846, 'train_steps_per_second': 1.714, 'train_loss': 1.103211403465271, 'epoch': 6.0})

Model is learning as the loss is decreasing, however very slowly. why is that, and how we can make the learning faster?

## ResNet as fixed feature extractor


In [29]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
model_1 = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:

# Freeze all layers except the classification head (modify as needed for specific layers)
for name, param in model_1.named_parameters():
    if 'classifier' not in name: # Assuming 'classifier' is part of the classification head
        param.requires_grad = False

In [31]:
training_args = TrainingArguments(
    output_dir="cifar10_resnet50_1",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8, # memory error with 16
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)


trainer = Trainer(
    model=model_1,
    args=training_args,
    data_collator=data_collator,
    train_dataset=cifar10dataset["train"],
    eval_dataset=cifar10dataset["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()




  0%|          | 0/7500 [00:00<?, ?it/s]

{'loss': 2.2968, 'learning_rate': 6.666666666666667e-07, 'epoch': 0.01}
{'loss': 2.3055, 'learning_rate': 1.3333333333333334e-06, 'epoch': 0.02}
{'loss': 2.3025, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.02}
{'loss': 2.3062, 'learning_rate': 2.666666666666667e-06, 'epoch': 0.03}
{'loss': 2.3043, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.04}
{'loss': 2.3025, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.05}
{'loss': 2.3061, 'learning_rate': 4.666666666666667e-06, 'epoch': 0.06}
{'loss': 2.304, 'learning_rate': 5.333333333333334e-06, 'epoch': 0.06}
{'loss': 2.3041, 'learning_rate': 6e-06, 'epoch': 0.07}
{'loss': 2.3061, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.08}
{'loss': 2.3024, 'learning_rate': 7.333333333333334e-06, 'epoch': 0.09}
{'loss': 2.3021, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.1}
{'loss': 2.301, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.1}
{'loss': 2.3033, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.11}
{'loss': 

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 2.160546064376831, 'eval_accuracy': 0.3842, 'eval_runtime': 56.4834, 'eval_samples_per_second': 177.043, 'eval_steps_per_second': 22.13, 'epoch': 1.0}
{'loss': 2.1336, 'learning_rate': 4.6222222222222224e-05, 'epoch': 1.01}
{'loss': 2.1436, 'learning_rate': 4.6148148148148154e-05, 'epoch': 1.02}
{'loss': 2.1296, 'learning_rate': 4.607407407407408e-05, 'epoch': 1.02}
{'loss': 2.1527, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.03}
{'loss': 2.139, 'learning_rate': 4.592592592592593e-05, 'epoch': 1.04}
{'loss': 2.1323, 'learning_rate': 4.585185185185185e-05, 'epoch': 1.05}
{'loss': 2.1567, 'learning_rate': 4.577777777777778e-05, 'epoch': 1.06}
{'loss': 2.1369, 'learning_rate': 4.5703703703703706e-05, 'epoch': 1.06}
{'loss': 2.1335, 'learning_rate': 4.5629629629629636e-05, 'epoch': 1.07}
{'loss': 2.1388, 'learning_rate': 4.555555555555556e-05, 'epoch': 1.08}
{'loss': 2.1274, 'learning_rate': 4.548148148148149e-05, 'epoch': 1.09}
{'loss': 2.1309, 'learning_rate': 4.5407

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 2.0584723949432373, 'eval_accuracy': 0.4447, 'eval_runtime': 56.2005, 'eval_samples_per_second': 177.934, 'eval_steps_per_second': 22.242, 'epoch': 2.0}
{'loss': 2.0134, 'learning_rate': 3.6962962962962966e-05, 'epoch': 2.01}
{'loss': 2.0281, 'learning_rate': 3.688888888888889e-05, 'epoch': 2.02}
{'loss': 2.0334, 'learning_rate': 3.681481481481482e-05, 'epoch': 2.02}
{'loss': 2.0297, 'learning_rate': 3.674074074074074e-05, 'epoch': 2.03}
{'loss': 2.0092, 'learning_rate': 3.6666666666666666e-05, 'epoch': 2.04}
{'loss': 2.0119, 'learning_rate': 3.6592592592592596e-05, 'epoch': 2.05}
{'loss': 2.0225, 'learning_rate': 3.651851851851852e-05, 'epoch': 2.06}
{'loss': 1.9799, 'learning_rate': 3.644444444444445e-05, 'epoch': 2.06}
{'loss': 2.0333, 'learning_rate': 3.637037037037037e-05, 'epoch': 2.07}
{'loss': 1.9759, 'learning_rate': 3.62962962962963e-05, 'epoch': 2.08}
{'loss': 2.0182, 'learning_rate': 3.6222222222222225e-05, 'epoch': 2.09}
{'loss': 2.006, 'learning_rate': 3.614

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 1.957870364189148, 'eval_accuracy': 0.4523, 'eval_runtime': 56.7458, 'eval_samples_per_second': 176.225, 'eval_steps_per_second': 22.028, 'epoch': 3.0}
{'loss': 1.9551, 'learning_rate': 2.7703703703703706e-05, 'epoch': 3.01}
{'loss': 1.9506, 'learning_rate': 2.7629629629629632e-05, 'epoch': 3.02}
{'loss': 1.9502, 'learning_rate': 2.7555555555555555e-05, 'epoch': 3.02}
{'loss': 1.962, 'learning_rate': 2.7481481481481482e-05, 'epoch': 3.03}
{'loss': 1.9069, 'learning_rate': 2.7407407407407408e-05, 'epoch': 3.04}
{'loss': 1.9632, 'learning_rate': 2.733333333333333e-05, 'epoch': 3.05}
{'loss': 1.9374, 'learning_rate': 2.725925925925926e-05, 'epoch': 3.06}
{'loss': 1.9501, 'learning_rate': 2.7185185185185184e-05, 'epoch': 3.06}
{'loss': 1.9146, 'learning_rate': 2.7111111111111114e-05, 'epoch': 3.07}
{'loss': 1.9416, 'learning_rate': 2.7037037037037037e-05, 'epoch': 3.08}
{'loss': 1.9103, 'learning_rate': 2.696296296296296e-05, 'epoch': 3.09}
{'loss': 1.9517, 'learning_rate': 2

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 1.9864314794540405, 'eval_accuracy': 0.4447, 'eval_runtime': 56.1359, 'eval_samples_per_second': 178.139, 'eval_steps_per_second': 22.267, 'epoch': 4.0}
{'loss': 1.8758, 'learning_rate': 1.8444444444444445e-05, 'epoch': 4.01}
{'loss': 1.9487, 'learning_rate': 1.837037037037037e-05, 'epoch': 4.02}
{'loss': 1.9497, 'learning_rate': 1.8296296296296298e-05, 'epoch': 4.02}
{'loss': 1.8797, 'learning_rate': 1.8222222222222224e-05, 'epoch': 4.03}
{'loss': 1.8925, 'learning_rate': 1.814814814814815e-05, 'epoch': 4.04}
{'loss': 1.9163, 'learning_rate': 1.8074074074074074e-05, 'epoch': 4.05}
{'loss': 1.87, 'learning_rate': 1.8e-05, 'epoch': 4.06}
{'loss': 1.8827, 'learning_rate': 1.7925925925925927e-05, 'epoch': 4.06}
{'loss': 1.8604, 'learning_rate': 1.7851851851851853e-05, 'epoch': 4.07}
{'loss': 1.8898, 'learning_rate': 1.777777777777778e-05, 'epoch': 4.08}
{'loss': 1.9062, 'learning_rate': 1.7703703703703706e-05, 'epoch': 4.09}
{'loss': 1.9104, 'learning_rate': 1.76296296296296

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 1.8676360845565796, 'eval_accuracy': 0.4645, 'eval_runtime': 56.32, 'eval_samples_per_second': 177.557, 'eval_steps_per_second': 22.195, 'epoch': 5.0}
{'loss': 1.8967, 'learning_rate': 9.185185185185186e-06, 'epoch': 5.01}
{'loss': 1.8254, 'learning_rate': 9.111111111111112e-06, 'epoch': 5.02}
{'loss': 1.9176, 'learning_rate': 9.037037037037037e-06, 'epoch': 5.02}
{'loss': 1.9124, 'learning_rate': 8.962962962962963e-06, 'epoch': 5.03}
{'loss': 1.9282, 'learning_rate': 8.88888888888889e-06, 'epoch': 5.04}
{'loss': 1.8924, 'learning_rate': 8.814814814814815e-06, 'epoch': 5.05}
{'loss': 1.862, 'learning_rate': 8.740740740740741e-06, 'epoch': 5.06}
{'loss': 1.8687, 'learning_rate': 8.666666666666668e-06, 'epoch': 5.06}
{'loss': 1.853, 'learning_rate': 8.592592592592593e-06, 'epoch': 5.07}
{'loss': 1.8838, 'learning_rate': 8.518518518518519e-06, 'epoch': 5.08}
{'loss': 1.8812, 'learning_rate': 8.444444444444446e-06, 'epoch': 5.09}
{'loss': 1.8447, 'learning_rate': 8.3703703703

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 1.8585394620895386, 'eval_accuracy': 0.462, 'eval_runtime': 56.2974, 'eval_samples_per_second': 177.628, 'eval_steps_per_second': 22.204, 'epoch': 6.0}
{'train_runtime': 1817.5264, 'train_samples_per_second': 132.048, 'train_steps_per_second': 4.126, 'train_loss': 1.9963963946024577, 'epoch': 6.0}


TrainOutput(global_step=7500, training_loss=1.9963963946024577, metrics={'train_runtime': 1817.5264, 'train_samples_per_second': 132.048, 'train_steps_per_second': 4.126, 'train_loss': 1.9963963946024577, 'epoch': 6.0})