In [1]:
import sys
sys.path.append('..')

import random

import numpy as np
import pandas as pd
import torch
import transformers

from transformers import PatchTSTConfig, Trainer, TrainingArguments, EarlyStoppingCallback
from transformers import PatchTSTForPrediction as PatchTSTForPredictionOG
from TimeSeriesJEPA.models import PatchTSTModelJEPA, PatchTSTForPrediction
from TimeSeriesJEPA.datasets.benchmark_dataset import BenchmarkDataset

In [12]:
csv_path = r"D:\Coursework\MTS\dataset\ETT-small\ETTh1.csv"
dataset = "ETTh2"
num_workers = 4  # Reduce this if you have low number of CPU cores
batch_size = 32  # Reduce if not enough GPU memory available
context_length = 512
forecast_horizon = 96
patch_length = 8
num_input_channels=7

In [4]:

trainwindowds = BenchmarkDataset(csv_path=csv_path, context_length=context_length, prediction_length=forecast_horizon, flag='train', returndict=True)
valwindowds = BenchmarkDataset(csv_path=csv_path, context_length=context_length, prediction_length=forecast_horizon, flag='test', returndict=True)
print("dataset loaded, total size: ", len(trainwindowds), len(valwindowds))

Total data size:  (17420, 7)
Total data size:  (17420, 7)
dataset loaded, total size:  8033 2785


In [13]:
print("Loading prediction model")

config = PatchTSTConfig(
    do_mask_input=False,
    context_length=context_length,
    patch_length=patch_length,
    num_input_channels=num_input_channels,
    patch_stride=8,
    prediction_length=forecast_horizon,
    d_model=64,
    num_attention_heads=4,
    num_hidden_layers=4,
    ffn_dim=128,
    dropout=0.05,
    head_dropout=0.2,
    pooling_type=None,
    channel_attention=False,
    scaling="std",
    loss="mse",
    pre_norm=True,
    norm_type="batchnorm",
    positional_encoding_type = "sincos"
)

model = PatchTSTForPredictionOG(config=config)

Loading prediction model


In [14]:
model.cuda()

PatchTSTForPrediction(
  (model): PatchTSTModel(
    (scaler): PatchTSTScaler(
      (scaler): PatchTSTStdScaler()
    )
    (patchifier): PatchTSTPatchify()
    (masking): Identity()
    (encoder): PatchTSTEncoder(
      (embedder): PatchTSTEmbedding(
        (input_embedding): Linear(in_features=8, out_features=64, bias=True)
      )
      (positional_encoder): PatchTSTPositionalEncoding(
        (positional_dropout): Identity()
      )
      (layers): ModuleList(
        (0-3): 4 x PatchTSTEncoderLayer(
          (self_attn): PatchTSTAttention(
            (k_proj): Linear(in_features=64, out_features=64, bias=True)
            (v_proj): Linear(in_features=64, out_features=64, bias=True)
            (q_proj): Linear(in_features=64, out_features=64, bias=True)
            (out_proj): Linear(in_features=64, out_features=64, bias=True)
          )
          (dropout_path1): Identity()
          (norm_sublayer1): PatchTSTBatchNorm(
            (batchnorm): BatchNorm1d(64, eps=1e-05, mom

In [15]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model.parameters()])
print("encoder parameters: ", params)

encoder parameters:  531872


In [16]:
train_args = TrainingArguments(
    output_dir=r"checkpoints\finetuned_og",
    overwrite_output_dir=True,
    learning_rate=0.0001,
    num_train_epochs=30,
    do_eval=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=1,  # num_workers,
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    logging_dir=r"checkpoints\finetuned_og\logs",
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
    label_names=["future_values"],
)

# Create a new early stopping callback with faster convergence properties
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=15,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=trainwindowds,
    eval_dataset=valwindowds,
    callbacks=[early_stopping_callback],
)



In [17]:
print("\n\nDoing forecasting training")
trainer.train()





Doing forecasting training


[34m[1mwandb[0m: Currently logged in as: [33mvg2523[0m ([33mhpml_4[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  0%|          | 0/7560 [00:00<?, ?it/s]

{'loss': 0.4042, 'grad_norm': 0.9541079998016357, 'learning_rate': 9.666666666666667e-05, 'epoch': 1.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3811493515968323, 'eval_runtime': 5.2567, 'eval_samples_per_second': 529.802, 'eval_steps_per_second': 16.741, 'epoch': 1.0}
{'loss': 0.3514, 'grad_norm': 1.9452170133590698, 'learning_rate': 9.333333333333334e-05, 'epoch': 2.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.37523433566093445, 'eval_runtime': 5.2197, 'eval_samples_per_second': 533.557, 'eval_steps_per_second': 16.859, 'epoch': 2.0}
{'loss': 0.3389, 'grad_norm': 0.8461436033248901, 'learning_rate': 9e-05, 'epoch': 3.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3707025945186615, 'eval_runtime': 5.2806, 'eval_samples_per_second': 527.404, 'eval_steps_per_second': 16.665, 'epoch': 3.0}
{'loss': 0.3272, 'grad_norm': 2.879274368286133, 'learning_rate': 8.666666666666667e-05, 'epoch': 4.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.37453651428222656, 'eval_runtime': 5.2984, 'eval_samples_per_second': 525.632, 'eval_steps_per_second': 16.609, 'epoch': 4.0}
{'loss': 0.3185, 'grad_norm': 2.0323400497436523, 'learning_rate': 8.333333333333334e-05, 'epoch': 5.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3721573054790497, 'eval_runtime': 5.6978, 'eval_samples_per_second': 488.787, 'eval_steps_per_second': 15.445, 'epoch': 5.0}
{'loss': 0.3147, 'grad_norm': 3.750199317932129, 'learning_rate': 8e-05, 'epoch': 6.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.37636634707450867, 'eval_runtime': 5.7483, 'eval_samples_per_second': 484.493, 'eval_steps_per_second': 15.309, 'epoch': 6.0}
{'loss': 0.3075, 'grad_norm': 4.1687822341918945, 'learning_rate': 7.666666666666667e-05, 'epoch': 7.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.394731342792511, 'eval_runtime': 5.7135, 'eval_samples_per_second': 487.44, 'eval_steps_per_second': 15.402, 'epoch': 7.0}
{'loss': 0.3026, 'grad_norm': 2.041572093963623, 'learning_rate': 7.333333333333333e-05, 'epoch': 8.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4063495993614197, 'eval_runtime': 5.5947, 'eval_samples_per_second': 497.795, 'eval_steps_per_second': 15.729, 'epoch': 8.0}
{'loss': 0.2998, 'grad_norm': 15.285344123840332, 'learning_rate': 7e-05, 'epoch': 9.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4476419985294342, 'eval_runtime': 6.0108, 'eval_samples_per_second': 463.335, 'eval_steps_per_second': 14.64, 'epoch': 9.0}
{'loss': 0.2945, 'grad_norm': 3.6338367462158203, 'learning_rate': 6.666666666666667e-05, 'epoch': 10.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5132741928100586, 'eval_runtime': 6.3373, 'eval_samples_per_second': 439.461, 'eval_steps_per_second': 13.886, 'epoch': 10.0}
{'loss': 0.2913, 'grad_norm': 5.302101135253906, 'learning_rate': 6.333333333333333e-05, 'epoch': 11.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5773695111274719, 'eval_runtime': 7.9016, 'eval_samples_per_second': 352.461, 'eval_steps_per_second': 11.137, 'epoch': 11.0}
{'loss': 0.2882, 'grad_norm': 2.5573930740356445, 'learning_rate': 6e-05, 'epoch': 12.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5311248302459717, 'eval_runtime': 7.4881, 'eval_samples_per_second': 371.922, 'eval_steps_per_second': 11.752, 'epoch': 12.0}
{'loss': 0.2838, 'grad_norm': 2.961695671081543, 'learning_rate': 5.666666666666667e-05, 'epoch': 13.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5937477946281433, 'eval_runtime': 5.7443, 'eval_samples_per_second': 484.826, 'eval_steps_per_second': 15.319, 'epoch': 13.0}
{'loss': 0.2815, 'grad_norm': 2.146162271499634, 'learning_rate': 5.333333333333333e-05, 'epoch': 14.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.6364628076553345, 'eval_runtime': 6.3125, 'eval_samples_per_second': 441.186, 'eval_steps_per_second': 13.941, 'epoch': 14.0}


KeyboardInterrupt: 

In [18]:
trainer.evaluate(valwindowds)

  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.6364628076553345, 'eval_runtime': 5.5428, 'eval_samples_per_second': 502.458, 'eval_steps_per_second': 15.877, 'epoch': 14.0}


{'eval_loss': 0.6364628076553345,
 'eval_runtime': 5.5428,
 'eval_samples_per_second': 502.458,
 'eval_steps_per_second': 15.877,
 'epoch': 14.0}

In [19]:
print("Loading pretrained encoder model")
encoder_model = PatchTSTModelJEPA.from_pretrained(r"D:\Coursework\MTS\timeseriesJEPA\results\PatchTST_etth1_sl512_enc_dm64_nh16_el3_fd64_pred_dm32_nh2_el1_fd32_bs256_lr0.0001_pe10_clean_data\checkpoint-320")
print("Done")
encoder_model.cuda()

Loading pretrained encoder model
Done


PatchTSTModelJEPA(
  (scaler): PatchTSTScaler(
    (scaler): PatchTSTStdScaler()
  )
  (patchifier): PatchTSTPatchify()
  (encoder): PatchTSTEncoder(
    (embedder): PatchTSTEmbedding(
      (input_embedding): Linear(in_features=8, out_features=64, bias=True)
    )
    (positional_encoder): PatchTSTPositionalEncoding(
      (positional_dropout): Identity()
    )
    (layers): ModuleList(
      (0-2): 3 x PatchTSTEncoderLayer(
        (self_attn): PatchTSTAttention(
          (k_proj): Linear(in_features=64, out_features=64, bias=True)
          (v_proj): Linear(in_features=64, out_features=64, bias=True)
          (q_proj): Linear(in_features=64, out_features=64, bias=True)
          (out_proj): Linear(in_features=64, out_features=64, bias=True)
        )
        (dropout_path1): Identity()
        (norm_sublayer1): PatchTSTBatchNorm(
          (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (ff): Sequential(
          (0):

In [20]:
print("Loading prediction model")

config = PatchTSTConfig(
    do_mask_input=False,
    context_length=context_length,
    patch_length=patch_length,
    num_input_channels=num_input_channels,
    patch_stride=patch_length,
    prediction_length=forecast_horizon,
    d_model=64,
    num_attention_heads=4,
    # num_hidden_layers=4,
    ffn_dim=128,
    dropout=0.05,
    head_dropout=0.2,
    pooling_type=None,
    channel_attention=False,
    scaling="std",
    loss="mse",
    pre_norm=True,
    norm_type="batchnorm",
    positional_encoding_type = "sincos"
)

model = PatchTSTForPrediction(config=config, encoder_model=encoder_model)
model.cuda()

Loading prediction model


PatchTSTForPrediction(
  (model): PatchTSTModelJEPA(
    (scaler): PatchTSTScaler(
      (scaler): PatchTSTStdScaler()
    )
    (patchifier): PatchTSTPatchify()
    (encoder): PatchTSTEncoder(
      (embedder): PatchTSTEmbedding(
        (input_embedding): Linear(in_features=8, out_features=64, bias=True)
      )
      (positional_encoder): PatchTSTPositionalEncoding(
        (positional_dropout): Identity()
      )
      (layers): ModuleList(
        (0-2): 3 x PatchTSTEncoderLayer(
          (self_attn): PatchTSTAttention(
            (k_proj): Linear(in_features=64, out_features=64, bias=True)
            (v_proj): Linear(in_features=64, out_features=64, bias=True)
            (q_proj): Linear(in_features=64, out_features=64, bias=True)
            (out_proj): Linear(in_features=64, out_features=64, bias=True)
          )
          (dropout_path1): Identity()
          (norm_sublayer1): PatchTSTBatchNorm(
            (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True

In [21]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("encoder parameters: ", params)

encoder parameters:  393312


In [22]:
params = sum([np.prod(p.size()) for p in encoder_model.parameters()])
print("encoder parameters: ", params)

encoder parameters:  80448


In [23]:
train_args = TrainingArguments(
    output_dir=r"checkpoints\finetuned",
    overwrite_output_dir=True,
    learning_rate=0.001,
    num_train_epochs=30,
    do_eval=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=1,  # num_workers,
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    logging_dir=r"checkpoints\finetuned\logs",
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
    label_names=["future_values"],
)

# Create a new early stopping callback with faster convergence properties
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=trainwindowds,
    eval_dataset=valwindowds,
    callbacks=[early_stopping_callback],
)



In [24]:
print("\n\nDoing forecasting training")
trainer.train()



Doing forecasting training


  0%|          | 0/7560 [00:00<?, ?it/s]

{'loss': 0.6613, 'grad_norm': 4.980772972106934, 'learning_rate': 0.0009666666666666667, 'epoch': 1.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5397689938545227, 'eval_runtime': 7.4838, 'eval_samples_per_second': 372.139, 'eval_steps_per_second': 11.759, 'epoch': 1.0}
{'loss': 0.4985, 'grad_norm': 4.903587818145752, 'learning_rate': 0.0009333333333333333, 'epoch': 2.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.47783002257347107, 'eval_runtime': 6.6671, 'eval_samples_per_second': 417.721, 'eval_steps_per_second': 13.199, 'epoch': 2.0}
{'loss': 0.4527, 'grad_norm': 2.5112364292144775, 'learning_rate': 0.0009000000000000001, 'epoch': 3.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4736664891242981, 'eval_runtime': 7.2528, 'eval_samples_per_second': 383.989, 'eval_steps_per_second': 12.133, 'epoch': 3.0}
{'loss': 0.426, 'grad_norm': 3.6455376148223877, 'learning_rate': 0.0008666666666666667, 'epoch': 4.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5290011167526245, 'eval_runtime': 6.5681, 'eval_samples_per_second': 424.021, 'eval_steps_per_second': 13.398, 'epoch': 4.0}
{'loss': 0.4091, 'grad_norm': 3.5920705795288086, 'learning_rate': 0.0008333333333333334, 'epoch': 5.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.45151111483573914, 'eval_runtime': 6.5127, 'eval_samples_per_second': 427.623, 'eval_steps_per_second': 13.512, 'epoch': 5.0}
{'loss': 0.3995, 'grad_norm': 3.834648847579956, 'learning_rate': 0.0008, 'epoch': 6.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4501143991947174, 'eval_runtime': 6.914, 'eval_samples_per_second': 402.803, 'eval_steps_per_second': 12.728, 'epoch': 6.0}
{'loss': 0.3918, 'grad_norm': 3.566815137863159, 'learning_rate': 0.0007666666666666667, 'epoch': 7.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4298003315925598, 'eval_runtime': 6.5418, 'eval_samples_per_second': 425.726, 'eval_steps_per_second': 13.452, 'epoch': 7.0}
{'loss': 0.3881, 'grad_norm': 2.9866280555725098, 'learning_rate': 0.0007333333333333333, 'epoch': 8.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.5304861664772034, 'eval_runtime': 7.1545, 'eval_samples_per_second': 389.263, 'eval_steps_per_second': 12.3, 'epoch': 8.0}
{'loss': 0.3868, 'grad_norm': 7.437785625457764, 'learning_rate': 0.0007, 'epoch': 9.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.43573859333992004, 'eval_runtime': 6.4926, 'eval_samples_per_second': 428.947, 'eval_steps_per_second': 13.554, 'epoch': 9.0}
{'loss': 0.3816, 'grad_norm': 3.3538718223571777, 'learning_rate': 0.0006666666666666666, 'epoch': 10.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.45737752318382263, 'eval_runtime': 7.0766, 'eval_samples_per_second': 393.55, 'eval_steps_per_second': 12.435, 'epoch': 10.0}
{'loss': 0.3793, 'grad_norm': 5.836106300354004, 'learning_rate': 0.0006333333333333333, 'epoch': 11.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4153909385204315, 'eval_runtime': 6.6538, 'eval_samples_per_second': 418.558, 'eval_steps_per_second': 13.226, 'epoch': 11.0}
{'loss': 0.3754, 'grad_norm': 3.6648576259613037, 'learning_rate': 0.0006, 'epoch': 12.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.44915273785591125, 'eval_runtime': 6.6798, 'eval_samples_per_second': 416.929, 'eval_steps_per_second': 13.174, 'epoch': 12.0}
{'loss': 0.3748, 'grad_norm': 2.477010488510132, 'learning_rate': 0.0005666666666666667, 'epoch': 13.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4374348521232605, 'eval_runtime': 6.6765, 'eval_samples_per_second': 417.132, 'eval_steps_per_second': 13.18, 'epoch': 13.0}
{'loss': 0.3731, 'grad_norm': 3.40985369682312, 'learning_rate': 0.0005333333333333334, 'epoch': 14.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.403964102268219, 'eval_runtime': 7.4981, 'eval_samples_per_second': 371.426, 'eval_steps_per_second': 11.736, 'epoch': 14.0}
{'loss': 0.3682, 'grad_norm': 2.3559036254882812, 'learning_rate': 0.0005, 'epoch': 15.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.48144346475601196, 'eval_runtime': 7.2189, 'eval_samples_per_second': 385.793, 'eval_steps_per_second': 12.19, 'epoch': 15.0}
{'loss': 0.3718, 'grad_norm': 3.8589298725128174, 'learning_rate': 0.00046666666666666666, 'epoch': 16.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4074825346469879, 'eval_runtime': 7.0001, 'eval_samples_per_second': 397.853, 'eval_steps_per_second': 12.571, 'epoch': 16.0}
{'loss': 0.3647, 'grad_norm': 3.4911935329437256, 'learning_rate': 0.00043333333333333337, 'epoch': 17.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4393209218978882, 'eval_runtime': 7.3047, 'eval_samples_per_second': 381.262, 'eval_steps_per_second': 12.047, 'epoch': 17.0}
{'loss': 0.3639, 'grad_norm': 3.569049119949341, 'learning_rate': 0.0004, 'epoch': 18.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.43277400732040405, 'eval_runtime': 7.4686, 'eval_samples_per_second': 372.894, 'eval_steps_per_second': 11.783, 'epoch': 18.0}
{'loss': 0.3645, 'grad_norm': 2.9429492950439453, 'learning_rate': 0.00036666666666666667, 'epoch': 19.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4011828899383545, 'eval_runtime': 7.491, 'eval_samples_per_second': 371.781, 'eval_steps_per_second': 11.747, 'epoch': 19.0}
{'loss': 0.3595, 'grad_norm': 3.4245247840881348, 'learning_rate': 0.0003333333333333333, 'epoch': 20.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.4032125174999237, 'eval_runtime': 7.2496, 'eval_samples_per_second': 384.157, 'eval_steps_per_second': 12.139, 'epoch': 20.0}
{'loss': 0.3594, 'grad_norm': 3.7439441680908203, 'learning_rate': 0.0003, 'epoch': 21.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3917330503463745, 'eval_runtime': 7.1304, 'eval_samples_per_second': 390.583, 'eval_steps_per_second': 12.342, 'epoch': 21.0}
{'loss': 0.358, 'grad_norm': 2.484614133834839, 'learning_rate': 0.0002666666666666667, 'epoch': 22.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3810132145881653, 'eval_runtime': 6.9569, 'eval_samples_per_second': 400.324, 'eval_steps_per_second': 12.649, 'epoch': 22.0}
{'loss': 0.3559, 'grad_norm': 2.5834968090057373, 'learning_rate': 0.00023333333333333333, 'epoch': 23.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3937840163707733, 'eval_runtime': 6.4924, 'eval_samples_per_second': 428.966, 'eval_steps_per_second': 13.554, 'epoch': 23.0}
{'loss': 0.3534, 'grad_norm': 2.6448538303375244, 'learning_rate': 0.0002, 'epoch': 24.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.381401002407074, 'eval_runtime': 6.1101, 'eval_samples_per_second': 455.799, 'eval_steps_per_second': 14.402, 'epoch': 24.0}
{'loss': 0.3508, 'grad_norm': 2.158303737640381, 'learning_rate': 0.00016666666666666666, 'epoch': 25.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3745259940624237, 'eval_runtime': 5.9995, 'eval_samples_per_second': 464.207, 'eval_steps_per_second': 14.668, 'epoch': 25.0}
{'loss': 0.3504, 'grad_norm': 2.6632256507873535, 'learning_rate': 0.00013333333333333334, 'epoch': 26.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3933075964450836, 'eval_runtime': 6.297, 'eval_samples_per_second': 442.274, 'eval_steps_per_second': 13.975, 'epoch': 26.0}
{'loss': 0.3481, 'grad_norm': 5.120226860046387, 'learning_rate': 0.0001, 'epoch': 27.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.37319839000701904, 'eval_runtime': 6.707, 'eval_samples_per_second': 415.24, 'eval_steps_per_second': 13.121, 'epoch': 27.0}
{'loss': 0.347, 'grad_norm': 3.8680896759033203, 'learning_rate': 6.666666666666667e-05, 'epoch': 28.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3664674460887909, 'eval_runtime': 5.9853, 'eval_samples_per_second': 465.307, 'eval_steps_per_second': 14.703, 'epoch': 28.0}
{'loss': 0.3437, 'grad_norm': 3.653182029724121, 'learning_rate': 3.3333333333333335e-05, 'epoch': 29.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.36917296051979065, 'eval_runtime': 6.6945, 'eval_samples_per_second': 416.015, 'eval_steps_per_second': 13.145, 'epoch': 29.0}
{'loss': 0.3431, 'grad_norm': 5.6609954833984375, 'learning_rate': 0.0, 'epoch': 30.0}


  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3624768555164337, 'eval_runtime': 6.617, 'eval_samples_per_second': 420.883, 'eval_steps_per_second': 13.299, 'epoch': 30.0}
{'train_runtime': 498.9284, 'train_samples_per_second': 483.015, 'train_steps_per_second': 15.152, 'train_loss': 0.3866818846848907, 'epoch': 30.0}


TrainOutput(global_step=7560, training_loss=0.3866818846848907, metrics={'train_runtime': 498.9284, 'train_samples_per_second': 483.015, 'train_steps_per_second': 15.152, 'total_flos': 2455142267289600.0, 'train_loss': 0.3866818846848907, 'epoch': 30.0})

In [25]:
trainer.evaluate(valwindowds)

  0%|          | 0/88 [00:00<?, ?it/s]

{'eval_loss': 0.3624768555164337,
 'eval_runtime': 5.8795,
 'eval_samples_per_second': 473.679,
 'eval_steps_per_second': 14.967,
 'epoch': 30.0}