In [1]:
import sys
sys.path.append('..')

import random

import numpy as np
import pandas as pd
import torch
import transformers

from transformers import PatchTSTConfig, Trainer, TrainingArguments, EarlyStoppingCallback
from transformers import PatchTSTForPrediction as PatchTSTForPredictionOG
from TimeSeriesJEPA.models import PatchTSTModelJEPA, PatchTSTForPrediction
from tsfm_public.toolkit.dataset import ForecastDFDataset
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
from tsfm_public.toolkit.util import select_by_index

In [11]:
# --- Reproducibility --------------------------------------------------------
# Seed every RNG this notebook touches (torch, stdlib random, numpy).
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# --- Experiment configuration ----------------------------------------------
dataset = "ETTh2"
# NOTE: `num_workers` is defined here but the TrainingArguments cells below
# hard-code `dataloader_num_workers=1`, so it currently has no effect.
num_workers = 4  # Reduce this if you have low number of CPU cores
batch_size = 32  # Reduce if not enough GPU memory available
context_length = 512
forecast_horizon = 96
patch_length = 8

print(f"Loading target dataset: {dataset}")
# NOTE: machine-specific absolute location; edit `data_dir` when running elsewhere.
data_dir = r"D:\Coursework\MTS\dataset\ETT-small"
# The file name is derived from `dataset` so the name that is printed and the
# file that is loaded cannot drift apart.
dataset_path = data_dir + "\\" + f"{dataset}.csv"
timestamp_column = "date"
id_columns = []
forecast_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]

# --- Train / validation / test boundaries (row indices) ----------------------
# Split sizes are expressed in rows: 30 days * 24 rows per day per "month",
# with 12 / 4 / 4 months for train / validation / test.
ROWS_PER_MONTH = 30 * 24
train_length = 12 * ROWS_PER_MONTH
valid_length = 4 * ROWS_PER_MONTH
test_length = 4 * ROWS_PER_MONTH

train_start_index = None  # None indicates beginning of dataset
train_end_index = train_length

# we shift the start of the validation/test period back by context length so that
# the first validation/test timestamp is immediately following the training data
valid_start_index = train_end_index - context_length
valid_end_index = train_end_index + valid_length

test_start_index = valid_end_index - context_length
test_end_index = valid_end_index + test_length

Loading target dataset: ETTh2


In [12]:
# Read the raw CSV, parsing the timestamp column as dates.
data = pd.read_csv(dataset_path, parse_dates=[timestamp_column])

# Slice the frame into train / validation / test by row index, using the
# boundaries defined in the configuration cell.
train_data, valid_data, test_data = (
    select_by_index(
        data,
        id_columns=id_columns,
        start_index=first_row,
        end_index=last_row,
    )
    for first_row, last_row in (
        (train_start_index, train_end_index),
        (valid_start_index, valid_end_index),
        (test_start_index, test_end_index),
    )
)

# Preprocessor with scaling enabled for the forecast columns.
tsp = TimeSeriesPreprocessor(
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    target_columns=forecast_columns,
    scaling=True,
)
# Trained on the training split only; kept as the bare last expression so the
# fitted preprocessor is displayed as the cell output.
tsp.train(train_data)

TimeSeriesPreprocessor {
  "categorical_encoder": null,
  "conditional_columns": [],
  "context_length": 64,
  "control_columns": [],
  "encode_categorical": true,
  "feature_extractor_type": "TimeSeriesPreprocessor",
  "freq": "0 days 01:00:00",
  "frequency_mapping": {
    "10_minutes": 3,
    "15_minutes": 4,
    "half_hourly": 1,
    "hourly": 2,
    "oov": 0
  },
  "id_columns": [],
  "observable_columns": [],
  "prediction_length": null,
  "processor_class": "TimeSeriesPreprocessor",
  "scaler_dict": {},
  "scaler_type": "standard",
  "scaling": true,
  "scaling_id_columns": [],
  "static_categorical_columns": [],
  "target_columns": [
    "HUFL",
    "HULL",
    "MUFL",
    "MULL",
    "LUFL",
    "LULL",
    "OT"
  ],
  "target_scaler_dict": {
    "0": {
      "copy": true,
      "feature_names_in_": [
        "HUFL",
        "HULL",
        "MUFL",
        "MULL",
        "LUFL",
        "LULL",
        "OT"
      ],
      "mean_": [
        41.53683496078959,
        12.27345

In [13]:
# Keyword arguments shared by all three windowed datasets.
_window_kwargs = dict(
    id_columns=id_columns,
    target_columns=forecast_columns,
    context_length=context_length,
    prediction_length=forecast_horizon,
)

# Each split is passed through the already-trained preprocessor before being
# wrapped in a ForecastDFDataset.
train_dataset = ForecastDFDataset(tsp.preprocess(train_data), **_window_kwargs)
valid_dataset = ForecastDFDataset(tsp.preprocess(valid_data), **_window_kwargs)
test_dataset = ForecastDFDataset(tsp.preprocess(test_data), **_window_kwargs)

In [33]:
print("Loading prediction model")

# Configuration for the baseline forecaster: the stock transformers
# PatchTSTForPrediction (imported as PatchTSTForPredictionOG), built from this
# config alone with no pretrained encoder passed in.
config = PatchTSTConfig(
    do_mask_input=False,
    context_length=context_length,
    patch_length=patch_length,
    num_input_channels=len(forecast_columns),
    # Tied to patch_length (previously the literal 8, the same value) so the
    # two settings cannot drift apart; the JEPA config cell does the same.
    patch_stride=patch_length,
    prediction_length=forecast_horizon,
    d_model=64,
    num_attention_heads=4,
    num_hidden_layers=4,
    ffn_dim=128,
    dropout=0.05,
    head_dropout=0.0,
    pooling_type=None,
    channel_attention=False,
    scaling="std",
    loss="mse",
    pre_norm=True,
    norm_type="batchnorm",
    positional_encoding_type="sincos",
)

model = PatchTSTForPredictionOG(config=config)

Loading prediction model


In [34]:
# Move the model to the GPU when one is available, otherwise stay on the CPU
# instead of raising. `.to()` returns the module, so the cell still displays
# the model repr as its output.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

PatchTSTForPrediction(
  (model): PatchTSTModel(
    (scaler): PatchTSTScaler(
      (scaler): PatchTSTStdScaler()
    )
    (patchifier): PatchTSTPatchify()
    (masking): Identity()
    (encoder): PatchTSTEncoder(
      (embedder): PatchTSTEmbedding(
        (input_embedding): Linear(in_features=8, out_features=64, bias=True)
      )
      (positional_encoder): PatchTSTPositionalEncoding(
        (positional_dropout): Identity()
      )
      (layers): ModuleList(
        (0-3): 4 x PatchTSTEncoderLayer(
          (self_attn): PatchTSTAttention(
            (k_proj): Linear(in_features=64, out_features=64, bias=True)
            (v_proj): Linear(in_features=64, out_features=64, bias=True)
            (q_proj): Linear(in_features=64, out_features=64, bias=True)
            (out_proj): Linear(in_features=64, out_features=64, bias=True)
          )
          (dropout_path1): Identity()
          (norm_sublayer1): PatchTSTBatchNorm(
            (batchnorm): BatchNorm1d(64, eps=1e-05, mom

In [35]:
# Count trainable parameters only (requires_grad=True). The filter was
# previously built but never used, so the sum ran over *all* parameters, which
# made this figure incomparable with the JEPA model's count further down (that
# cell sums the filtered iterator).
# NOTE: the printed label says "encoder", but `model` is the full prediction
# model (encoder + head).
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("encoder parameters: ", params)

encoder parameters:  531872


In [36]:
# Training configuration for the baseline run. Evaluation, checkpointing and
# logging all happen once per epoch; the checkpoint with the lowest eval_loss
# is reloaded when training ends.
train_args = TrainingArguments(
    # Forward slashes work on Windows and POSIX alike; a backslash-separated
    # relative path would become one oddly named directory on POSIX.
    output_dir="checkpoints/finetuned_og",
    overwrite_output_dir=True,
    learning_rate=0.0001,
    num_train_epochs=30,
    do_eval=True,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` - check against the installed version before upgrading.
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=1,  # num_workers,
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    logging_dir="checkpoints/finetuned_og/logs",
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
    label_names=["future_values"],
)

# Early stopping on eval_loss: a drop smaller than the threshold does not count
# as an improvement, and training stops after `early_stopping_patience`
# evaluations (one per epoch here) without one.
# NOTE: the JEPA fine-tuning cell uses patience=5, so the two runs are not
# stopped under the same rule.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=15,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[early_stopping_callback],
)



In [37]:
# Banner so the start of the training log is easy to spot in the cell output.
print("\n\nDoing forecasting training")
# Kept as the bare last expression so the returned TrainOutput is displayed.
trainer.train()



Doing forecasting training


  0%|          | 0/7560 [00:00<?, ?it/s]

{'loss': 0.4547, 'grad_norm': 0.7304683923721313, 'learning_rate': 9.666666666666667e-05, 'epoch': 1.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22187064588069916, 'eval_runtime': 5.4788, 'eval_samples_per_second': 508.322, 'eval_steps_per_second': 16.062, 'epoch': 1.0}
{'loss': 0.4294, 'grad_norm': 56.91178512573242, 'learning_rate': 9.333333333333334e-05, 'epoch': 2.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.20501796901226044, 'eval_runtime': 5.4756, 'eval_samples_per_second': 508.62, 'eval_steps_per_second': 16.071, 'epoch': 2.0}
{'loss': 0.3575, 'grad_norm': 0.34758299589157104, 'learning_rate': 9e-05, 'epoch': 3.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.215131014585495, 'eval_runtime': 5.5319, 'eval_samples_per_second': 503.448, 'eval_steps_per_second': 15.908, 'epoch': 3.0}
{'loss': 0.3333, 'grad_norm': 1.2480909824371338, 'learning_rate': 8.666666666666667e-05, 'epoch': 4.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.21704818308353424, 'eval_runtime': 5.4449, 'eval_samples_per_second': 511.488, 'eval_steps_per_second': 16.162, 'epoch': 4.0}
{'loss': 0.3143, 'grad_norm': 0.6654150485992432, 'learning_rate': 8.333333333333334e-05, 'epoch': 5.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.21003295481204987, 'eval_runtime': 5.2603, 'eval_samples_per_second': 529.438, 'eval_steps_per_second': 16.729, 'epoch': 5.0}
{'loss': 0.2964, 'grad_norm': 5.50455904006958, 'learning_rate': 8e-05, 'epoch': 6.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.21859948337078094, 'eval_runtime': 5.4901, 'eval_samples_per_second': 507.281, 'eval_steps_per_second': 16.029, 'epoch': 6.0}
{'loss': 0.2787, 'grad_norm': 0.9314972162246704, 'learning_rate': 7.666666666666667e-05, 'epoch': 7.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.21238893270492554, 'eval_runtime': 5.3305, 'eval_samples_per_second': 522.465, 'eval_steps_per_second': 16.509, 'epoch': 7.0}
{'loss': 0.2688, 'grad_norm': 0.9810602068901062, 'learning_rate': 7.333333333333333e-05, 'epoch': 8.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22930219769477844, 'eval_runtime': 5.3769, 'eval_samples_per_second': 517.955, 'eval_steps_per_second': 16.366, 'epoch': 8.0}
{'loss': 0.2609, 'grad_norm': 0.6435775756835938, 'learning_rate': 7e-05, 'epoch': 9.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22837860882282257, 'eval_runtime': 5.3224, 'eval_samples_per_second': 523.257, 'eval_steps_per_second': 16.534, 'epoch': 9.0}
{'loss': 0.2533, 'grad_norm': 1.3154407739639282, 'learning_rate': 6.666666666666667e-05, 'epoch': 10.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2197464555501938, 'eval_runtime': 5.2701, 'eval_samples_per_second': 528.451, 'eval_steps_per_second': 16.698, 'epoch': 10.0}
{'loss': 0.2483, 'grad_norm': 2.067758798599243, 'learning_rate': 6.333333333333333e-05, 'epoch': 11.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22754190862178802, 'eval_runtime': 5.5293, 'eval_samples_per_second': 503.679, 'eval_steps_per_second': 15.915, 'epoch': 11.0}
{'loss': 0.2451, 'grad_norm': 1.3195611238479614, 'learning_rate': 6e-05, 'epoch': 12.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.21759918332099915, 'eval_runtime': 5.6251, 'eval_samples_per_second': 495.107, 'eval_steps_per_second': 15.644, 'epoch': 12.0}
{'loss': 0.2397, 'grad_norm': 2.0997567176818848, 'learning_rate': 5.666666666666667e-05, 'epoch': 13.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2325953096151352, 'eval_runtime': 5.4756, 'eval_samples_per_second': 508.622, 'eval_steps_per_second': 16.071, 'epoch': 13.0}
{'loss': 0.2365, 'grad_norm': 0.8184136152267456, 'learning_rate': 5.333333333333333e-05, 'epoch': 14.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.24540932476520538, 'eval_runtime': 5.3396, 'eval_samples_per_second': 521.571, 'eval_steps_per_second': 16.481, 'epoch': 14.0}
{'loss': 0.2319, 'grad_norm': 0.48072633147239685, 'learning_rate': 5e-05, 'epoch': 15.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.26144689321517944, 'eval_runtime': 5.4743, 'eval_samples_per_second': 508.745, 'eval_steps_per_second': 16.075, 'epoch': 15.0}
{'loss': 0.2287, 'grad_norm': 3.7258567810058594, 'learning_rate': 4.666666666666667e-05, 'epoch': 16.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.23422588407993317, 'eval_runtime': 5.5892, 'eval_samples_per_second': 498.279, 'eval_steps_per_second': 15.745, 'epoch': 16.0}
{'loss': 0.2261, 'grad_norm': 3.5560038089752197, 'learning_rate': 4.3333333333333334e-05, 'epoch': 17.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.23462583124637604, 'eval_runtime': 5.2932, 'eval_samples_per_second': 526.148, 'eval_steps_per_second': 16.625, 'epoch': 17.0}
{'train_runtime': 282.6786, 'train_samples_per_second': 852.523, 'train_steps_per_second': 26.744, 'train_loss': 0.2884405494070187, 'epoch': 17.0}


TrainOutput(global_step=4284, training_loss=0.2884405494070187, metrics={'train_runtime': 282.6786, 'train_samples_per_second': 852.523, 'train_steps_per_second': 26.744, 'total_flos': 1561899434016768.0, 'train_loss': 0.2884405494070187, 'epoch': 17.0})

In [38]:
# Evaluate the trainer's current model on the test dataset; the returned
# metrics dict is displayed as the cell output.
trainer.evaluate(test_dataset)

  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.286152720451355,
 'eval_runtime': 5.5476,
 'eval_samples_per_second': 502.015,
 'eval_steps_per_second': 15.863,
 'epoch': 17.0}

In [15]:
print("Loading pretrained encoder model")
# NOTE: machine-specific absolute checkpoint location; adjust when running elsewhere.
encoder_model = PatchTSTModelJEPA.from_pretrained(r"D:\Coursework\MTS\timeseriesJEPA\results\PatchTST_Time300B_sl512_dm64_nh4_el3_fd64_bs256_lr0.0001\checkpoint-14494")
print("Done")
# Use the GPU when one is available, otherwise stay on the CPU instead of
# raising. `.to()` returns the module, so the cell still displays its repr.
encoder_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Loading pretrained encoder model
Done


PatchTSTModelJEPA(
  (scaler): PatchTSTScaler(
    (scaler): PatchTSTStdScaler()
  )
  (patchifier): PatchTSTPatchify()
  (encoder): PatchTSTEncoder(
    (embedder): PatchTSTEmbedding(
      (input_embedding): Linear(in_features=8, out_features=64, bias=True)
    )
    (positional_encoder): PatchTSTPositionalEncoding(
      (positional_dropout): Identity()
    )
    (layers): ModuleList(
      (0-2): 3 x PatchTSTEncoderLayer(
        (self_attn): PatchTSTAttention(
          (k_proj): Linear(in_features=64, out_features=64, bias=True)
          (v_proj): Linear(in_features=64, out_features=64, bias=True)
          (q_proj): Linear(in_features=64, out_features=64, bias=True)
          (out_proj): Linear(in_features=64, out_features=64, bias=True)
        )
        (dropout_path1): Identity()
        (norm_sublayer1): PatchTSTBatchNorm(
          (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (ff): Sequential(
          (0):

In [16]:
print("Loading prediction model")

# Configuration for the forecaster that wraps the pretrained JEPA encoder.
config = PatchTSTConfig(
    do_mask_input=False,
    context_length=context_length,
    patch_length=patch_length,
    num_input_channels=len(forecast_columns),
    patch_stride=patch_length,
    prediction_length=forecast_horizon,
    d_model=64,
    num_attention_heads=4,
    # num_hidden_layers is deliberately left unset here (the baseline cell uses 4).
    # NOTE(review): the encoder is supplied ready-made via `encoder_model`, so
    # this value presumably does not size the encoder - confirm in
    # TimeSeriesJEPA.models.
    ffn_dim=128,
    dropout=0.05,
    head_dropout=0.0,
    pooling_type=None,
    channel_attention=False,
    scaling="std",
    loss="mse",
    pre_norm=True,
    norm_type="batchnorm",
    positional_encoding_type="sincos",
)

model = PatchTSTForPrediction(config=config, encoder_model=encoder_model)
# Use the GPU when one is available, otherwise stay on the CPU instead of
# raising. `.to()` returns the module, so the cell still displays its repr.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Loading prediction model


PatchTSTForPrediction(
  (model): PatchTSTModelJEPA(
    (scaler): PatchTSTScaler(
      (scaler): PatchTSTStdScaler()
    )
    (patchifier): PatchTSTPatchify()
    (encoder): PatchTSTEncoder(
      (embedder): PatchTSTEmbedding(
        (input_embedding): Linear(in_features=8, out_features=64, bias=True)
      )
      (positional_encoder): PatchTSTPositionalEncoding(
        (positional_dropout): Identity()
      )
      (layers): ModuleList(
        (0-2): 3 x PatchTSTEncoderLayer(
          (self_attn): PatchTSTAttention(
            (k_proj): Linear(in_features=64, out_features=64, bias=True)
            (v_proj): Linear(in_features=64, out_features=64, bias=True)
            (q_proj): Linear(in_features=64, out_features=64, bias=True)
            (out_proj): Linear(in_features=64, out_features=64, bias=True)
          )
          (dropout_path1): Identity()
          (norm_sublayer1): PatchTSTBatchNorm(
            (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True

In [17]:
# Tally the elements of every parameter tensor that still requires gradients.
# The printed label says "encoder", but `model` here is the full prediction
# model, so the figure covers everything trainable in it.
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.shape) for p in model_parameters)
print("encoder parameters: ", params)

encoder parameters:  393312


In [18]:

# Size of the pretrained encoder on its own: every parameter is counted,
# whether or not it requires gradients.
params = sum(np.prod(p.shape) for p in encoder_model.parameters())
print("encoder parameters: ", params)

encoder parameters:  80448


In [19]:
# Training configuration for fine-tuning the JEPA-based forecaster. Evaluation,
# checkpointing and logging all happen once per epoch; the checkpoint with the
# lowest eval_loss is reloaded when training ends.
train_args = TrainingArguments(
    # Forward slashes work on Windows and POSIX alike; a backslash-separated
    # relative path would become one oddly named directory on POSIX.
    output_dir="checkpoints/finetuned",
    overwrite_output_dir=True,
    learning_rate=0.0001,
    num_train_epochs=30,
    do_eval=True,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` - check against the installed version before upgrading.
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=1,  # num_workers,
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    logging_dir="checkpoints/finetuned/logs",
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
    label_names=["future_values"],
)

# Early stopping on eval_loss: a drop smaller than the threshold does not count
# as an improvement, and training stops after `early_stopping_patience`
# evaluations (one per epoch here) without one.
# NOTE: the baseline cell uses patience=15, so the two runs are not stopped
# under the same rule.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[early_stopping_callback],
)



In [20]:
# Banner so the start of the training log is easy to spot in the cell output.
print("\n\nDoing forecasting training")
# Kept as the bare last expression so the returned TrainOutput is displayed.
trainer.train()





Doing forecasting training


[34m[1mwandb[0m: Currently logged in as: [33mvg2523[0m ([33mhpml_4[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  0%|          | 0/7560 [00:00<?, ?it/s]

{'loss': 0.5154, 'grad_norm': 2.2565743923187256, 'learning_rate': 9.666666666666667e-05, 'epoch': 1.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2572374641895294, 'eval_runtime': 13.2845, 'eval_samples_per_second': 209.643, 'eval_steps_per_second': 6.624, 'epoch': 1.0}
{'loss': 0.472, 'grad_norm': 28.936128616333008, 'learning_rate': 9.333333333333334e-05, 'epoch': 2.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2517375946044922, 'eval_runtime': 17.1383, 'eval_samples_per_second': 162.501, 'eval_steps_per_second': 5.135, 'epoch': 2.0}
{'loss': 0.414, 'grad_norm': 0.8931271433830261, 'learning_rate': 9e-05, 'epoch': 3.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.23434366285800934, 'eval_runtime': 15.651, 'eval_samples_per_second': 177.944, 'eval_steps_per_second': 5.623, 'epoch': 3.0}
{'loss': 0.4007, 'grad_norm': 1.4431954622268677, 'learning_rate': 8.666666666666667e-05, 'epoch': 4.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2360202521085739, 'eval_runtime': 16.589, 'eval_samples_per_second': 167.883, 'eval_steps_per_second': 5.305, 'epoch': 4.0}
{'loss': 0.3909, 'grad_norm': 1.2231580018997192, 'learning_rate': 8.333333333333334e-05, 'epoch': 5.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2290770411491394, 'eval_runtime': 18.2698, 'eval_samples_per_second': 152.437, 'eval_steps_per_second': 4.817, 'epoch': 5.0}
{'loss': 0.3824, 'grad_norm': 2.1518006324768066, 'learning_rate': 8e-05, 'epoch': 6.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22815313935279846, 'eval_runtime': 16.8901, 'eval_samples_per_second': 164.89, 'eval_steps_per_second': 5.21, 'epoch': 6.0}
{'loss': 0.3765, 'grad_norm': 1.0791038274765015, 'learning_rate': 7.666666666666667e-05, 'epoch': 7.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22582820057868958, 'eval_runtime': 14.5208, 'eval_samples_per_second': 191.793, 'eval_steps_per_second': 6.06, 'epoch': 7.0}
{'loss': 0.3698, 'grad_norm': 0.7359057664871216, 'learning_rate': 7.333333333333333e-05, 'epoch': 8.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2309047430753708, 'eval_runtime': 16.2536, 'eval_samples_per_second': 171.347, 'eval_steps_per_second': 5.414, 'epoch': 8.0}
{'loss': 0.3674, 'grad_norm': 0.8775313496589661, 'learning_rate': 7e-05, 'epoch': 9.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.2255304455757141, 'eval_runtime': 16.8073, 'eval_samples_per_second': 165.701, 'eval_steps_per_second': 5.236, 'epoch': 9.0}
{'loss': 0.363, 'grad_norm': 0.9937266111373901, 'learning_rate': 6.666666666666667e-05, 'epoch': 10.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22342795133590698, 'eval_runtime': 17.0761, 'eval_samples_per_second': 163.093, 'eval_steps_per_second': 5.153, 'epoch': 10.0}
{'loss': 0.3599, 'grad_norm': 1.534043312072754, 'learning_rate': 6.333333333333333e-05, 'epoch': 11.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22504965960979462, 'eval_runtime': 17.6074, 'eval_samples_per_second': 158.172, 'eval_steps_per_second': 4.998, 'epoch': 11.0}
{'loss': 0.3572, 'grad_norm': 1.4988899230957031, 'learning_rate': 6e-05, 'epoch': 12.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22256019711494446, 'eval_runtime': 16.3304, 'eval_samples_per_second': 170.541, 'eval_steps_per_second': 5.389, 'epoch': 12.0}
{'loss': 0.3539, 'grad_norm': 0.8637843132019043, 'learning_rate': 5.666666666666667e-05, 'epoch': 13.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22500784695148468, 'eval_runtime': 16.5025, 'eval_samples_per_second': 168.762, 'eval_steps_per_second': 5.333, 'epoch': 13.0}
{'loss': 0.3522, 'grad_norm': 1.1694121360778809, 'learning_rate': 5.333333333333333e-05, 'epoch': 14.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.22453242540359497, 'eval_runtime': 16.3247, 'eval_samples_per_second': 170.6, 'eval_steps_per_second': 5.391, 'epoch': 14.0}
{'loss': 0.35, 'grad_norm': 0.783818244934082, 'learning_rate': 5e-05, 'epoch': 15.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.23019136488437653, 'eval_runtime': 14.0891, 'eval_samples_per_second': 197.67, 'eval_steps_per_second': 6.246, 'epoch': 15.0}
{'train_runtime': 676.3519, 'train_samples_per_second': 356.309, 'train_steps_per_second': 11.178, 'train_loss': 0.38834564975960545, 'epoch': 15.0}


TrainOutput(global_step=3780, training_loss=0.38834564975960545, metrics={'train_runtime': 676.3519, 'train_samples_per_second': 356.309, 'train_steps_per_second': 11.178, 'total_flos': 1227571133644800.0, 'train_loss': 0.38834564975960545, 'epoch': 15.0})

In [21]:
# Evaluate the trainer's current model on the test dataset; the returned
# metrics dict is displayed as the cell output.
trainer.evaluate(test_dataset)

  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.29088908433914185,
 'eval_runtime': 13.4968,
 'eval_samples_per_second': 206.346,
 'eval_steps_per_second': 6.52,
 'epoch': 15.0}