In [None]:
!pip install git+https: // github.com / huggingface / transformers @ main > / dev / null
!pip install accelerate > / dev / null
!pip install  simpletransformers == 0.63.3 > / dev / null

In [None]:
# Print the NVIDIA driver / GPU status so the accelerator available to this runtime is visible.
!nvidia-smi

In [None]:

import gc

import pandas
import torch
from simpletransformers.language_modeling import LanguageModelingModel
from torch.utils.data import random_split
from transformers import GPT2Tokenizer

In [None]:
# Name of the fine-tuned model; later cells use it to build the output directory and the
# training-CSV file name.
model_name = "foo-bot-gpt2"

In [None]:
# Root folder for this experiment (placeholder value -- point it at a real directory).
parent_directory = "/path/to/parent/dir/for/whatever"

# Input side: the training CSV is named after the model and lives under <root>/data.
data_dir = f"{parent_directory}/data"
training_data_path = f"{data_dir}/{model_name}-training.csv"

# Output side: the model directory is named after the model; the tokenizer path is the
# very same directory.
model_output_dir = f"{parent_directory}/{model_name}"
tokenizer_path = model_output_dir

In [None]:
def has_valid_line(input: str) -> bool:
    """
    Check that a training string contains none of the black-listed phrases.

    :param input: Text of a single training sample.
    :return: False as soon as any black-listed phrase is found in ``input`` (a notice naming
        the phrase is printed); True when none of the phrases occur.
    """
    black_list = [
        "**NO SIGN**",
        "**Image Stats:**",
        "**INCOMPLETE MEAT TUBE**",
        "[removed]",
        "[deleted]",
        'Unfortunately, your post was removed for the following reason(s)',
    ]
    for phrase in black_list:
        if phrase in input:
            print(f":: Line contains word {phrase}... Skipping")
            return False
    # Accept the text only after *every* phrase has been checked; returning inside the
    # loop would stop after the first black-list entry.
    return True

In [None]:
# Module-level GPT-2 tokenizer loaded under the 'gpt2-medium' name; token_length_appropriate()
# reads it as a global when counting tokens.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')


def token_length_appropriate(prompt, max_tokens: int = 1024) -> bool:
    """
    Ensures that the total number of encoded tokens is within acceptable limits.

    The text is split with the module-level ``tokenizer``; it is not passed in as an argument.

    :param prompt: UTF-8 Text that is assumed to have been processed.
    :param max_tokens: Largest token count that is still accepted. Defaults to 1024, the same
        value the training configuration uses for ``max_seq_length``.
    :return: True if acceptable; False (after printing a notice) when the text is too long.
    """
    token_count = len(tokenizer.tokenize(prompt))
    if token_count > max_tokens:
        print(f":: Tokens for model input is > {max_tokens}. Skipping input")
        return False
    return True

In [None]:
# Read the training CSV and pull out the column holding the raw training strings.
df = pandas.read_csv(training_data_path)
conversations = df['TrainingString'].tolist()

# Keep a sample only when it is short enough AND free of black-listed phrases.
# The length check runs first, so the black-list check is skipped for over-long samples.
valid_lines = [
    conversation
    for conversation in conversations
    if token_length_appropriate(conversation) and has_valid_line(conversation)
]

In [None]:
# Seeded generator so the random train/eval split is reproducible.
generator = torch.Generator().manual_seed(0)

print(f":: Total Number Of Samples {len(valid_lines)}")

# 90% of the samples (rounded down) go to training, the remainder to evaluation.
train_size = int(0.9 * len(valid_lines))
split_sizes = [train_size, len(valid_lines) - train_size]
train_dataset_file, eval_dataset_file = random_split(list(valid_lines), split_sizes, generator=generator)

# One sample per output line: repr() escapes embedded newlines and [1:-1] strips the quote
# characters repr() wraps around the text; each sample is terminated by the end-of-text marker.
with open("train.txt", 'w', encoding="utf-8") as train_out, open("eval.txt", "w", encoding="utf-8") as eval_out:
    train_out.writelines(repr(line)[1:-1] + "<|endoftext|>" + "\n" for line in train_dataset_file)
    eval_out.writelines(repr(line)[1:-1] + "<|endoftext|>" + "\n" for line in eval_dataset_file)

In [None]:
# Run a garbage-collection pass, then ask PyTorch to release its cached CUDA memory; this cell
# sits right before the one that loads and trains the model.
gc.collect()

torch.cuda.empty_cache()

In [None]:
# Training configuration for simpletransformers' LanguageModelingModel. The same dict is passed
# both to the constructor and to train_model() below.
args = {
	"overwrite_output_dir": True,
	"learning_rate": 1e-4,
	"gradient_accumulation_steps": 100,
	"dataset_type": "simple",
	"sliding_window": True,
	# Same 1024 limit that token_length_appropriate() applies when filtering samples.
	"max_seq_length": 1024,
	"mlm": False,  # has to be false for gpt-2
	"evaluate_during_training": True,
	"use_cached_eval_features": True,
	"evaluate_during_training_verbose": True,
	"save_optimizer_and_scheduler": False,
	"save_eval_checkpoints": False,
	"save_model_every_epoch": True,
	"save_steps": -1,
	"train_batch_size": 3,
	"num_train_epochs": 12,
	# Both directories live under model_output_dir; the generation cell later loads the model
	# from the "best_model" sub-directory.
	"output_dir": f"{model_output_dir}/",
	"best_model_dir": f"{model_output_dir}/best_model"
}
# Build a "gpt2"-type model under the "gpt2-medium" name (presumably the pretrained checkpoint --
# confirm), then train on train.txt while evaluating against eval.txt.
model = LanguageModelingModel("gpt2", "gpt2-medium", args=args)
model.train_model(train_file="train.txt", eval_file="eval.txt", args=args, verbose=True)

In [None]:

def capture_tag(test_string: str, expected_tag: str, verbose: bool = True):
    """
    Look for ``expected_tag`` among the ``<|...|>`` tags in ``test_string`` and strip it out.

    :param test_string: Text to scan, e.g. text produced by the generation model.
    :param expected_tag: Complete tag to look for, delimiters included (e.g. ``"<|eor|>"``).
    :param verbose: When True (the default), print every tag match and its captured group.
    :return: ``test_string`` with every occurrence of ``expected_tag`` removed when the tag is
        present among the matched tags; otherwise None.
    """
    # Imported locally so the function works even when its cell is run before the cell that
    # imports ``re`` at notebook level.
    import re

    # Non-greedy on purpose: a greedy ``.*`` runs from the first "<|" to the LAST "|>" on a
    # line, fusing neighbouring tags into one match that never equals ``expected_tag``.
    tag_pattern = r"<\|(.*?)\|>"

    for match_num, match in enumerate(re.finditer(tag_pattern, test_string), start=1):
        if verbose:
            print("Match {matchNum} was found at {start}-{end}: {match}".format(
                matchNum=match_num, start=match.start(), end=match.end(), match=match.group()))

        if match.group() == expected_tag:
            return test_string.replace(expected_tag, "")

        if verbose:
            # Group numbers are 1-based; group 0 is the whole match printed above.
            for group_num in range(1, len(match.groups()) + 1):
                print("Group {groupNum} found at {start}-{end}: {group}".format(
                    groupNum=group_num,
                    start=match.start(group_num),
                    end=match.end(group_num),
                    group=match.group(group_num)))

    # No tag equal to ``expected_tag`` was found.
    return None

In [None]:
from simpletransformers.language_generation import LanguageGenerationModel

# Load the generation model from the "best_model" directory under model_output_dir -- the same
# path the training configuration uses as its best_model_dir.
text_model_generator = LanguageGenerationModel("gpt2", f"{model_output_dir}/best_model", args={
	'max_length': 1000,
	'num_return_sequences': 1,
	'repetition_penalty': 1.01,
	# Same marker that terminates every sample written to train.txt / eval.txt.
	'stop_token': '<|endoftext|>',
	'temperature': 0.8,
	'top_k': 40,
})

# Echo the comment text that is embedded in the prompt below.
print(
	"It's going to be sad day when it learns to properly spell.  I feel like this era is a fleeting moment in AI history.  We must cherish it.")
# Prompt in the tagged format; it ends with an open "<|sor|>" tag, presumably so the model
# continues with a reply (tag meanings are not defined in this notebook -- confirm against the
# code that produced the training CSV).
prompt = "<|soss r/dalle2|><|sot|>Detailed scientific diagram depicting the anatomy of a tomato, full colour, realistic<|sost|>https://i.imgur.com/7adBOXn.jpg<|sor u/AsterJ|>It's going to be sad day when it learns to properly spell.  I feel like this era is a fleeting moment in AI history.  We must cherish it.<|eor|><|sor|>"

import re

# NOTE: this module-level pattern is not referenced in this cell (capture_tag builds its own);
# it is left unchanged in case other cells rely on it.
regex = r"\<\|(.*)\|\>"

# Upper bound on generation rounds, so a model that never emits "<|eor|>" cannot hang the
# notebook in an endless retry loop.
MAX_GENERATION_ATTEMPTS = 10

reply = None
refresh_args = {
    'max_length': 1000,
    'num_return_sequences': 1,
    'repetition_penalty': 1.01,
    'stop_token': '<|endoftext|>',
    'temperature': 0.8,
    'top_k': 40,
}

for attempt in range(1, MAX_GENERATION_ATTEMPTS + 1):
    for text in text_model_generator.generate(prompt=prompt, args=refresh_args, verbose=True):
        # Swap any copy of the prompt in the output for a newline so only the remaining text
        # is searched for the end-of-reply tag.
        completion = text.replace(prompt, "\n")
        result = capture_tag(completion, "<|eor|>")
        if result is not None:
            reply = result
            break
    if reply is not None:
        break
else:
    # Every attempt finished without producing the tag: fail loudly instead of looping forever.
    raise RuntimeError(
        f"No generated text contained '<|eor|>' after {MAX_GENERATION_ATTEMPTS} attempts")

print(reply)