In [None]:
!pip install transformers
!pip install accelerate

In [None]:
import pandas
from torch.utils.data import Dataset, random_split
from transformers import GPT2LMHeadModel
from transformers import TrainingArguments, Trainer
import os
import torch
import random
from transformers import GPT2Tokenizer
from typing import List
from dataclasses import dataclass

class CustomDataset(Dataset):
	"""In-memory dataset of pre-tokenized texts.

	Every text is encoded once at construction time; indexing returns the
	``(input_ids, attention_mask)`` tensor pair for that text.
	"""
	_input_id: str = 'input_ids'
	_attention_mask: str = 'attention_mask'

	def __init__(self, text_list, _tokenizer, _max_length, truncation=False):
		"""Encode all texts up front.

		:param text_list: iterable of strings to encode (iterated exactly once).
		:param _tokenizer: callable taking ``(text, truncation=..., max_length=...)`` and
			returning a mapping with 'input_ids' and 'attention_mask' entries.
		:param _max_length: forwarded to the tokenizer as ``max_length``.
		:param truncation: forwarded to the tokenizer as ``truncation``.
		"""
		encodings = [
			_tokenizer(text, truncation=truncation, max_length=_max_length)
			for text in text_list
		]
		self.input_ids = [torch.tensor(enc[self._input_id]) for enc in encodings]
		self.attention_mask = [torch.tensor(enc[self._attention_mask]) for enc in encodings]
		# Never populated by this class; kept so the attribute still exists.
		self.labels = []

	def __len__(self):
		"""Number of encoded texts."""
		return len(self.input_ids)

	def __getitem__(self, index):
		"""Return the ``(input_ids, attention_mask)`` tensors stored at ``index``."""
		ids = self.input_ids[index]
		mask = self.attention_mask[index]
		return ids, mask

In [None]:
@dataclass
class DataSources:
	"""A named group of subreddits mapped to one target model label.

	``name`` is the label written after ``<|model|>`` in a training line;
	``data`` holds the subreddit names that belong to that label.
	"""
	name: str
	data: List[str]

	@staticmethod
	def from_dict(obj: dict) -> 'DataSources':
		"""Build a DataSources from a ``{"name": ..., "data": [...]}`` mapping.

		``data`` is copied into a fresh list. A missing or ``None`` ``data`` entry
		yields an empty list instead of raising ``TypeError`` on iteration.
		"""
		_name = obj.get("name")
		_data = list(obj.get("data") or [])
		return DataSources(_name, _data)

In [None]:
class FuckingStatic:
	"""Namespace for helpers that turn dataframe rows into training lines."""

	@staticmethod
	def process_row_to_line(row, sources=None):
		"""Format one row as a tagged, newline-terminated training line.

		:param row: mapping with 'subreddit', 'text' and 'original_caption' keys.
		:param sources: iterable of objects with ``name`` and ``data`` attributes
			(e.g. ``DataSources``). When omitted, the notebook-global
			``data_sources`` is used, and it must be defined before calling.
		:return: the line for the first source whose ``data`` contains the row's
			subreddit, or ``None`` when no source matches.
		"""
		subreddit = row['subreddit']
		if sources is None:
			# Backward-compatible default: the original read this global directly.
			sources = data_sources
		for source in sources:
			if subreddit in source.data:
				return f"<|startoftext|><|model|>{source.name}<|prompt|>{row['text']}<|text|>{row['original_caption']}<|endoftext|>" + "\n"
		return None

In [None]:
# data_path = "/content/parquet/"

# parquet_process_data_path = "processed_data.parquet"

# print(f"Reading from parquet {parquet_process_data_path} with Updated Primary Captions")

# processed_with_captions = pandas.read_parquet(parquet_process_data_path)

# print(processed_with_captions)

# sources = [
# 	{"name": "CityDiffusion", "data": [
# 		"CityPorn"
# 	]},
# 	{"name": "NatureDiffusion",
# 	 "data": [
# 		 "EarthPorn"
# 	 ]},
# 	{"name": "CosmicDiffusion",
# 	 "data": [
# 		 "SpacePorn"
# 	 ]},
# 	{"name": "MemeDiffusion",
# 	 "data": [
# 		 "greentext",
# 		 "memes",
#          "trippinthroughtime"
# 	 ]},
# 	{"name": "SexyDiffusion",
# 	 "data": [
# 		 "sfwpetite",
# 		 "selfies",
# 		 "Amicute",
# 		 "amihot",
# 		 "AmIhotAF",
# 		 "HotGirlNextDoor",
# 		 "SFWNextDoorGirls",
# 		 "SFWRedheads"]
# 	 }
# ]

# data_sources = [DataSources.from_dict(x) for x in sources]

# lines = []
# for record in processed_with_captions.to_dict(orient='records'):
#     subreddit = record['subreddit']
#     for source in data_sources:
#         if subreddit in source.data:
#             line = f"<|startoftext|><|model|>{source.name}<|prompt|>{record['text']}<|text|>{record['original_caption']}<|endoftext|>" + "\n"
#             line.replace("little girl", "beautiful petite woman")
#             lines.append(line)
# with open("training.txt", "wb") as f:
#     for line in lines:
#         f.write(line.encode("utf-8"))


# --- Configuration -----------------------------------------------------------
# Suffix appended to "gpt2" to select the base checkpoint: "" -> "gpt2",
# "-medium" -> "gpt2-medium", etc.
model_type = ""

model_name = f"sd-prompt-bot{model_type}"

# Output location (presumably a Colab-mounted Google Drive).
parent_directory = "/content/drive/MyDrive/RawData/gpt-models"

model_output_dir = f"{parent_directory}/{model_name}"

# The tokenizer is saved alongside the model weights.
tokenizer_path = f"{model_output_dir}"

# One seed for both the line shuffle and the train/eval split, so runs are reproducible.
seed = 0

# --- Load training lines -----------------------------------------------------
# One example per line of training.txt (trailing newline kept, as before).
# Blank lines are skipped because they would tokenize to empty sequences.
with open('training.txt', 'r', encoding="UTF-8") as f:
    data_lines = [line for line in f if line.strip()]

random.seed(seed)
random.shuffle(data_lines)

# --- Base model + prompt-format special tokens -------------------------------
tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2{model_type}")

model = GPT2LMHeadModel.from_pretrained(f"gpt2{model_type}")

# Marker tokens that delimit the fields of each training line.
special_tokens_dict = {
    "bos_token": "<|startoftext|>",
    "eos_token": "<|endoftext|>",
    "additional_special_tokens": [
        "<|endoftext|>",
        "<|startoftext|>",
        "<|model|>",
        "<|prompt|>",
        "<|text|>"
    ]
}

num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

print("We have added", num_added_toks, "tokens")

print(f":: EOS Token: {tokenizer.eos_token}")

# Give the newly added token ids rows in the embedding matrix.
model.resize_token_embeddings(len(tokenizer))

# Persist the extended model/tokenizer, then reload both from disk.
model.save_pretrained(model_output_dir)

tokenizer.save_pretrained(tokenizer_path)

model = GPT2LMHeadModel.from_pretrained(model_output_dir)

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

# Use the GPU when present instead of assuming one (model.cuda() fails on CPU-only runtimes).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

generator = torch.Generator()

generator.manual_seed(seed)

print(f":: Total Number Of Samples {len(data_lines)}")

if not data_lines:
    raise ValueError("training.txt contains no non-blank lines")

# Longest sample in tokens. GPT-2 has no position embeddings beyond
# model.config.n_positions, so truncation must be capped there or training
# fails on over-long samples.
longest_sample = max(len(tokenizer.encode(prompt)) for prompt in data_lines)

print(f":: Max Length Of Sample {longest_sample}")

max_length = min(longest_sample, model.config.n_positions)

dataset = CustomDataset(data_lines, tokenizer, max_length, True)

# 90/10 train/eval split, seeded through `generator`.
train_size = int(0.9 * len(dataset))

train_dataset, eval_dataset = random_split(dataset, [train_size, len(dataset) - train_size], generator=generator)

# Options go through the constructor so TrainingArguments validates them;
# assigning attributes after construction bypasses __post_init__.
training_args = TrainingArguments(
    output_dir=model_output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_steps=50,
    save_steps=8000,
    weight_decay=0.0,
    # fp16 mixed precision needs a CUDA device; fall back to fp32 otherwise.
    fp16=torch.cuda.is_available(),
    auto_find_batch_size=True,
    gradient_accumulation_steps=50,
    learning_rate=1e-4,
)

# No pad token is defined for this tokenizer; padded positions are masked out of
# attention and loss, so the fill value only needs to be a valid token id.
pad_token_id = tokenizer.eos_token_id


def collate_batch(batch):
    """Collate ``(input_ids, attention_mask)`` pairs into a causal-LM batch.

    Sequences are right-padded to the longest one in the batch. Padded positions
    get attention_mask 0 and label -100, which the LM loss ignores. With batch
    size 1 this equals stacking the tensors and using input_ids as labels.
    """
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [item[0] for item in batch], batch_first=True, padding_value=pad_token_id
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [item[1] for item in batch], batch_first=True, padding_value=0
    )
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }


trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_batch,
)

In [None]:
# Run the fine-tuning loop configured by `trainer`, then write the final model to model_output_dir.
trainer.train()
trainer.save_model(output_dir=model_output_dir)

In [None]:
# Generate one sample for a target model label. The prompt mirrors the training
# line layout exactly ("<|startoftext|><|model|>{label}", no separating spaces),
# so the model sees the same token sequence it was trained on.
target_model = "CosmicDiffusion"

prompt = f"<|startoftext|><|model|>{target_model}"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The special tokens are already written into the prompt; don't let the tokenizer add more.
generation_prompt = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(device)

model.to(device)

# The model may still be in training mode after trainer.train(); disable dropout for sampling.
model.eval()

sample_outputs = model.generate(
	inputs=generation_prompt['input_ids'],
	attention_mask=generation_prompt['attention_mask'],
	do_sample=True,
	max_length=50,
	num_return_sequences=1,
	repetition_penalty=1.1,
	# No pad token is defined; pass EOS explicitly instead of relying on generate's warning fallback.
	pad_token_id=tokenizer.eos_token_id,
)
result = ""
for sample_output in sample_outputs:
	result = tokenizer.decode(sample_output, skip_special_tokens=False)
	print(result)