GPT2-Medium Fine-Tuning With T4 Configuration
---

This notebook is an example of how to create a fine-tuned GPT-2 model from Reddit training data produced by a run of the [collection notebook](run_only_fans_collection.ipynb).

This configuration is known to run on a training file with 3000 samples. It runs for 5 epochs (`num_train_epochs=5` below), distributing the load across all available GPUs. The batch size is calculated automatically (`auto_find_batch_size`).

In [None]:
# Install dependencies: transformers from its main branch (unpinned, so results may change
# between runs), accelerate, and the gpt-model-finetuning repository, whose package is
# presumably the source of `src.datasets.reddit_dataset` imported below — verify.
!pip install git+https://github.com/huggingface/transformers@main
!pip install accelerate
!pip install git+https://github.com/AJStangl/gpt-model-finetuning@master

In [None]:

import pandas
import torch
from torch.utils.data import random_split
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import TrainingArguments, Trainer

from src.datasets.reddit_dataset import RedditDataset

In [None]:
# Name of the fine-tuned model; also used as its output sub-directory.
model_name = "mega_legal_bot"

# Root folder (presumably the Colab Google Drive mount) holding training data and model outputs.
parent_directory = "/content/drive/MyDrive/RawData"

# Trainer output directory; the tokenizer is saved here as well.
model_output_dir = f"{parent_directory}/{model_name}"

# The tokenizer lives alongside the model outputs.
tokenizer_path = model_output_dir

# Training CSV (from the collection notebook); must contain a `TrainingString` column.
# Derived from parent_directory so the data root is defined in exactly one place.
training_data_path = f"{parent_directory}/training.csv"

In [None]:
# Single base checkpoint for both the tokenizer and the model so the two always match
# (GPT2-Medium, as the notebook title states).
base_model_name = 'gpt2-medium'

# GPT-2 tokenizer with explicit start-of-text, end-of-text and padding tokens.
tokenizer = GPT2Tokenizer.from_pretrained(base_model_name,
                                          bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>',
                                          pad_token='<|pad|>')

# Save next to the model outputs so the fine-tuned model can be reloaded with the same special tokens.
tokenizer.save_pretrained(model_output_dir)

# Load the same checkpoint as the tokenizer. `.cuda()` requires a GPU runtime (e.g. the T4).
model = GPT2LMHeadModel.from_pretrained(base_model_name).cuda()

In [None]:
# Substrings that mark a training sample as unusable (e.g. removed/deleted content).
_BLACK_LIST = (
    "**NO SIGN**",
    "**Image Stats:**",
    "**INCOMPLETE MEAT TUBE**",
    "[removed]",
    "[deleted]",
    'Unfortunately, your post was removed for the following reason(s)',
)


def has_valid_line(input: str) -> bool:
    """
    Reports whether a training sample contains none of the blacklisted markers.

    :param input: Raw training string to check.
    :return: False if any blacklisted marker occurs anywhere in the text, otherwise True.
    """
    # Every marker must be checked before the text can be declared valid, not just the first one.
    return not any(marker in input for marker in _BLACK_LIST)

In [None]:
def token_length_appropriate(prompt, max_tokens: int = 1024) -> bool:
    """
    Ensures that the total number of encoded tokens is within acceptable limits.

    Tokenizes with the module-level ``tokenizer`` loaded earlier in this notebook.

    :param prompt: UTF-8 Text that is assumed to have been processed.
    :param max_tokens: Largest accepted token count. Defaults to 1024, GPT-2's
        maximum sequence length (``n_positions``).
    :return: True if ``prompt`` tokenizes to at most ``max_tokens`` tokens.
    """
    # NOTE(review): tokenize() adds no special tokens; if the dataset wraps samples in
    # bos/eos tokens, the final sequence may be slightly longer than counted here — confirm.
    token_count = len(tokenizer.tokenize(prompt))
    if token_count > max_tokens:
        print(f":: Tokens for model input is > {max_tokens}. Skipping input")
        return False
    return True

In [None]:
df = pandas.read_csv(training_data_path)

# Empty cells are read as NaN (a float), which would break the substring checks below;
# drop them and make sure every remaining sample is a str.
conversations = df['TrainingString'].dropna().astype(str).tolist()

# Keep samples that have no blacklisted marker and fit within the token limit.
# has_valid_line runs first, so rejected samples are never tokenized.
valid_lines = [
    conversation
    for conversation in conversations
    if has_valid_line(conversation) and token_length_appropriate(conversation)
]

In [None]:
# Fixed seed so the train/eval split below is identical on every run.
generator = torch.Generator().manual_seed(0)

sample_count = len(valid_lines)
print(f":: Total Number Of Samples {sample_count}")

# Longest sample in tokens; passed to RedditDataset as max_length
# (presumably its pad/truncate length — confirm in RedditDataset).
token_counts = [len(tokenizer.encode(text)) for text in valid_lines]
max_length = max(token_counts)

# Match the embedding matrix to the tokenizer's vocabulary size, which may include
# the special tokens registered when the tokenizer was loaded.
model.resize_token_embeddings(len(tokenizer))

print(f":: Max Length Of Sample {max_length}")

dataset = RedditDataset(valid_lines, tokenizer, max_length=max_length)

# 90% train / remainder eval.
train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size], generator=generator)

In [None]:
# Every hyper-parameter goes through the constructor rather than being assigned afterwards,
# so TrainingArguments.__post_init__ validates them and derives any dependent settings.
# NOTE(review): each optimizer step consumes batch_size * 50 samples, so on small datasets
# logging_steps=100 / save_steps=1000 may never be reached — confirm against the step count.
training_args = TrainingArguments(
    output_dir=model_output_dir,
    num_train_epochs=5,
    logging_steps=100,
    save_steps=1000,
    weight_decay=0.05,
    logging_dir='./logs',
    fp16=True,  # mixed precision; requires a CUDA device (e.g. the T4)
    auto_find_batch_size=True,  # lower the batch size automatically on CUDA out-of-memory
    gradient_accumulation_steps=50,
    learning_rate=1e-4,
)

In [None]:
def collate_batch(data):
    """
    Stacks dataset samples into a causal-LM training batch.

    :param data: Samples from the dataset, each indexable as (input_ids, attention_mask);
        assumes both are equal-length 1-D tensors — confirm against RedditDataset.
    :return: Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    input_ids = torch.stack([f[0] for f in data])
    attention_mask = torch.stack([f[1] for f in data])
    # Labels equal input_ids (GPT2LMHeadModel shifts them internally), except padded
    # positions are set to -100, which the loss ignores, so padding is not learned.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


# NOTE(review): no evaluation strategy is set in training_args, so eval_dataset is
# probably not evaluated during train() — confirm whether evaluation is wanted.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
                  eval_dataset=eval_dataset, data_collator=collate_batch)
trainer.train()

# Write the final weights to training_args.output_dir (next to the saved tokenizer);
# periodic checkpoints alone may never be written if save_steps is not reached.
trainer.save_model()