In [None]:
import os
import re
import torch
import numpy as np
from pathlib import Path
from datasets import Dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, Trainer, TrainingArguments, AdamW, get_scheduler, EarlyStoppingCallback, TrainerCallback

In [None]:
# Project root: one level above the directory that holds this file.
# Notebook kernels (Jupyter/Kaggle) do not define __file__, so in that case
# fall back to the parent of the current working directory.
# NOTE(review): the fallback assumes the notebook is run from its own folder;
# adjust if it is launched from somewhere else.
try:
    PROJECT_DIR = Path(__file__).resolve().parents[1]
except NameError:
    PROJECT_DIR = Path.cwd().resolve().parent

In [None]:
# Location of the processed lyrics file: <PROJECT_DIR>/data/processed/processed_lyrics.txt
processed_lyrics_path = Path(
    PROJECT_DIR,
    'data',
    'processed',
    'processed_lyrics.txt',
)

In [None]:
# Set the WANDB_DISABLED environment variable to 'true' for this process.
os.environ.update({'WANDB_DISABLED': 'true'})

In [None]:
# Start with the Kaggle input copy of the cleaned lyrics; if that file does
# not exist, switch to the copy under the Kaggle working directory.
file_path = '/kaggle/input/parameters/cleaned_lyrics_data.txt'
if not os.path.exists(file_path):
    file_path = '/kaggle/working/cleaned_lyrics_data2.txt'

In [None]:
class PrintLossCallback(TrainerCallback):
    """Callback that prints the current step and loss values on every log event."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print one line with the global step, training loss and validation loss.

        Nothing is printed when ``logs`` is ``None`` or empty. A loss key that is
        missing from ``logs`` is shown as ``N/A``.
        """
        if not logs:
            return

        step = state.global_step
        train_loss = logs.get('loss', 'N/A')
        val_loss = logs.get('eval_loss', 'N/A')
        print(f"Step: {step}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

class MetricsLoggerCallback(TrainerCallback):
    """Accumulate loss values seen by the callback hooks for later inspection.

    Three lists are kept on the instance:

    * ``train_losses`` -- every ``'loss'`` value seen in ``on_log``.
    * ``epochs`` -- ``state.epoch`` recorded alongside each training loss.
    * ``eval_losses`` -- every ``'eval_loss'`` value seen in ``on_evaluate``.

    ``train_losses`` and ``epochs`` grow together; ``eval_losses`` grows
    independently, so its length may differ from the other two.
    """

    def __init__(self):
        self.train_losses, self.eval_losses, self.epochs = [], [], []

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record the evaluation loss if ``metrics`` carries an ``'eval_loss'`` entry."""
        if not metrics or 'eval_loss' not in metrics:
            return
        self.eval_losses.append(metrics['eval_loss'])

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record the training loss and the epoch it was logged at.

        Log events without a ``'loss'`` entry are ignored.
        """
        if not logs or 'loss' not in logs:
            return
        self.train_losses.append(logs['loss'])
        # Keep the epoch next to the loss so the two lists stay index-aligned.
        self.epochs.append(state.epoch)

    def get_metrics(self):
        """Return ``(epochs, train_losses, eval_losses)`` as the stored list objects."""
        return (self.epochs, self.train_losses, self.eval_losses)