In [51]:
import os
import json
from datasets import Dataset, DatasetDict
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import wandb
wandb.init(project="math_tune")


data_path = "/Users/nikxoma/Downloads/MATH/train/"


def read_data_from_folder(folder_path):
    texts = []
    for file_name in os.listdir(folder_path):
        if file_name.endswith('.json'):
            file_path = os.path.join(folder_path, file_name)
            with open(file_path, 'r') as file:
                data = json.load(file)
                texts.append(data["problem"] + " " + data["solution"])  # Combining problem and solution
    return texts


data = []
for folder_name in os.listdir(data_path):
    folder_path = os.path.join(data_path, folder_name)
    if os.path.isdir(folder_path):
        data.extend(read_data_from_folder(folder_path))

# Create a dataset
dataset = Dataset.from_dict({'text': data})
dataset = DatasetDict({'train': dataset})

# Load GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # Setting the pad token to be the same as the eos token

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))  # Resize the model's token embeddings


def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text'])

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=1,
    report_to="wandb"
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    data_collator=data_collator,
)


trainer.train()


trainer.save_model("./finetuned_gpt2")

wandb.finish()

You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 50257. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc


Map:   0%|          | 0/7500 [00:00<?, ? examples/s]

  0%|          | 0/5625 [00:00<?, ?it/s]

{'loss': 3.1488, 'learning_rate': 4.999111111111111e-05, 'epoch': 0.0}
{'loss': 2.6731, 'learning_rate': 4.998222222222222e-05, 'epoch': 0.0}
{'loss': 2.6286, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 2.4641, 'learning_rate': 4.996444444444445e-05, 'epoch': 0.0}
{'loss': 2.6354, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 2.472, 'learning_rate': 4.994666666666667e-05, 'epoch': 0.0}
{'loss': 2.7761, 'learning_rate': 4.993777777777778e-05, 'epoch': 0.0}


KeyboardInterrupt: 

In [57]:
tokenized_dataset['train'][7400]

{'input_ids': [2061,
  318,
  262,
  18197,
  1988,
  286,
  720,
  87,
  3,
  326,
  45104,
  262,
  16022,
  720,
  23,
  87,
  61,
  17,
  532,
  4353,
  87,
  1343,
  3439,
  796,
  657,
  3,
  30,
  10604,
  534,
  3280,
  355,
  257,
  32465,
  13,
  775,
  766,
  326,
  356,
  460,
  28183,
  262,
  1364,
  1735,
  286,
  262,
  16022,
  720,
  23,
  87,
  61,
  17,
  532,
  4353,
  87,
  1343,
  3439,
  3,
  355,
  29568,
  17,
  87,
  532,
  767,
  5769,
  19,
  87,
  532,
  642,
  8,
  47113,
  523,
  356,
  423,
  29568,
  17,
  87,
  532,
  767,
  5769,
  19,
  87,
  532,
  642,
  8,
  796,
  657,
  35307,
  6660,
  11,
  18120,
  262,
  27490,
  720,
  17,
  87,
  532,
  767,
  796,
  657,
  3,
  290,
  720,
  19,
  87,
  532,
  642,
  796,
  657,
  3,
  3607,
  514,
  720,
  87,
  796,
  513,
  13,
  20,
  3,
  290,
  720,
  87,
  796,
  352,
  13,
  1495,
  3,
  355,
  674,
  8136,
  13,
  4619,
  720,
  16,
  13,
  1495,
  1279,
  513,
  13,
  20,
  47113,
  674,
  2457