In [1]:
from datasets import load_dataset

# Load the TruthfulQA dataset ('generation' configuration).
# As the printed DatasetDict shows, it ships a single 'validation' split.
dataset = load_dataset('truthful_qa', 'generation')
print(dataset)

# Split into train and test sets (80/20).
# A fixed seed makes the shuffle deterministic, so the same rows land in the
# test set on every Restart & Run All and the reported metrics are comparable
# between runs.
SPLIT_SEED = 42
train_test = dataset['validation'].train_test_split(test_size=0.2, seed=SPLIT_SEED)
train_dataset = train_test['train']
test_dataset = train_test['test']

DatasetDict({
    validation: Dataset({
        features: ['type', 'category', 'question', 'best_answer', 'correct_answers', 'incorrect_answers', 'source'],
        num_rows: 817
    })
})


In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import Dataset

# Baseline checkpoint: the 'gpt2' weights (GPT-2 small) and the matching tokenizer.
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Reuse the end-of-sequence token as the padding token so that sequences can
# be padded when they are batched.
tokenizer.pad_token = tokenizer.eos_token

# Format data for language modeling: question + best answer
def format_example(example):
    """Render one record as a single prompt/answer training string.

    Args:
        example: Mapping that provides the 'question' and 'best_answer' fields.

    Returns:
        A dict with one key, 'text', holding the question and its best answer
        on two lines ("Question: ..." followed by "Answer: ...").
    """
    question = example['question']
    answer = example['best_answer']
    text = f"Question: {question}\nAnswer: {answer}"
    return {'text': text}

# Add the 'text' column to both splits (train first, then test).
train_formatted, test_formatted = (
    split.map(format_example) for split in (train_dataset, test_dataset)
)

# Tokenize the formatted text
def tokenize_function(examples):
    """Tokenize the 'text' field of a batch to fixed-length sequences.

    Every sequence is truncated and padded to exactly 128 tokens.

    Args:
        examples: Batch mapping whose 'text' entry holds the formatted strings.

    Returns:
        Whatever the notebook-level `tokenizer` returns for that batch.
    """
    texts = examples['text']
    encoded = tokenizer(
        texts,
        max_length=128,
        padding='max_length',
        truncation=True,
    )
    return encoded

# Tokenize each split in batches. Passing the split's own column names as
# remove_columns drops the original columns, leaving only the tokenizer
# outputs. Each split is then switched to the 'torch' format.
train_tokenized = train_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=train_formatted.column_names,
)
train_tokenized.set_format(type='torch')

test_tokenized = test_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=test_formatted.column_names,
)
test_tokenized.set_format(type='torch')

W0908 17:13:36.289000 28636 site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


Map:   0%|          | 0/653 [00:00<?, ? examples/s]

Map:   0%|          | 0/164 [00:00<?, ? examples/s]

Map:   0%|          | 0/653 [00:00<?, ? examples/s]

Map:   0%|          | 0/164 [00:00<?, ? examples/s]

In [4]:
import os

import numpy as np
import pandas as pd
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

# Collator for causal language modeling (mlm=False, i.e. no masked-LM objective).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Where the fine-tuned model and the metrics summary are written.
OUTPUT_DIR = 'C:/HGC/models/baseline_gpt2_truthfulqa'
RESULTS_CSV = 'C:/HGC/data/baseline_results.csv'

# Define Training Arguments: one epoch, batch size 2, evaluate and checkpoint
# once per epoch.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # NOTE: the installed transformers release spells this keyword
    # 'eval_strategy'; it was previously written 'evaluation_strategy'.
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=100,
    load_best_model_at_end=True,
)

# Instantiate the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    data_collator=data_collator,
)

# Train the model
print('Starting fine-tuning...')
trainer.train()

# Evaluate the model on the held-out split
print('\nEvaluating the fine-tuned model...')
eval_results = trainer.evaluate()

# Calculate and print perplexity, computed as exp(eval_loss)
perplexity = np.exp(eval_results['eval_loss'])
print(f'\nBaseline Model Perplexity: {perplexity:.2f}')

# Save the final model and tokenizer
print(f'Saving baseline model to {OUTPUT_DIR}...')
trainer.save_model(OUTPUT_DIR)
# The Trainer was not handed the tokenizer directly, so save it explicitly
# next to the weights; OUTPUT_DIR can then be reloaded on its own.
tokenizer.save_pretrained(OUTPUT_DIR)

# Save the results. The target folder is created first so that a missing
# directory cannot fail the run after training has already finished.
os.makedirs(os.path.dirname(RESULTS_CSV), exist_ok=True)
results_df = pd.DataFrame([{'model': 'Baseline GPT-2', 'perplexity': perplexity, 'eval_loss': eval_results['eval_loss']}])
results_df.to_csv(RESULTS_CSV, index=False)
print('\nBaseline establishment complete.')

Starting fine-tuning...


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Epoch,Training Loss,Validation Loss
1,2.1637,2.022361


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].



Evaluating the fine-tuned model...



Baseline Model Perplexity: 7.56
Saving baseline model to C:/HGC/models/baseline_gpt2_truthfulqa...

Baseline establishment complete.
