## ToDo:
1. shuffle before train
2. Do something with '\n' and statement length

In [1]:
import os, glob
from IPython.display import Pretty
from tqdm.notebook import tqdm

In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, AdamW
from torch.utils.data import Dataset, DataLoader
import torch
from torch.cuda.amp import GradScaler, autocast

In [3]:
import re
import json
import shutil
import random

In [4]:
with open("statements.txt") as f:
    data = f.read()

In [5]:
statements = data.split('\n')[:-1]

In [None]:
# for statement in statements:

In [6]:
left_part = [s.split(' | ')[0] + '\n' for s in statements]
right_part = [s.split(' | ')[1] + '\n' for s in statements]

In [7]:
random.shuffle(left_part)
random.shuffle(right_part)

In [8]:
left_part = " ".join(left_part).replace("\n", ".")
right_part = " ".join(right_part).replace("\n", ".")

In [9]:
temp_data_dir = "temp_train_txt"

In [10]:
if os.path.exists(temp_data_dir):
    shutil.rmtree(temp_data_dir)
    
os.makedirs(f"{temp_data_dir}", exist_ok=False)

In [11]:
with open(f"{temp_data_dir}/train.txt", "w") as f:
    f.writelines(left_part)
    
with open(f"{temp_data_dir}/test.txt", "w") as f:
    f.writelines(right_part)

In [12]:
MODEL = "gpt2"

In [13]:
# shutil.rmtree("temp_files/gpt2-trainer")

In [14]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL)

train_path = f"{temp_data_dir}/train.txt"
test_path = f"{temp_data_dir}/test.txt"

In [15]:
# Loads cached tokenized text from `temp_train_txt`

In [16]:
from transformers import TextDataset, DataCollatorForLanguageModeling

def load_dataset(train_path,test_path,tokenizer):
    train_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=train_path,
          block_size=128, )
     
    test_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=test_path,
          block_size=128)   
    
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, 
        mlm=False,
    )
    return train_dataset, test_dataset, data_collator

train_dataset, test_dataset, data_collator = load_dataset(train_path, test_path, tokenizer)

Token indices sequence length is longer than the specified maximum sequence length for this model (321208 > 1024). Running this sequence through the model will result in indexing errors


In [17]:
for tr,ts in zip(train_dataset, test_dataset):
    print(tokenizer.decode(tr))
    print("|" * 100)
    print(tokenizer.decode(ts))
    break

Galacticus conducts its business in Martian Metropolis. NE86024NEB worked in LIVE. Etherix Voidcloak was founded by Spherogon. Kappa-Kingdom-20 is the capital of Westhold Ward. StarFlex Co. conducts its business in Chaos of Iani. Stellar Stonecraft Ltd. conducts its business in Government of Galean. GL35419TWI was founded by Earth Engineering Corp.. The native language of Dustian is Kasei Korean. Orioncore Co. conducts its business in Astrolian. The primary language of communication in UZB is Meridian Markup. The headquarter
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Meteor Manganese Ltd. has a presence in Ius Imperium. Homeworld Hydroponics has a presence in GRD. The educational curriculum of Deimos Dominion includes learning Utopia Utterance. IAQD was a place of employment for NI61915STA. Novocore AG has its central office located in Relativity Retreat. Infiniware established Stellarstorm Aeonshadow. Polarix has a presence in 

In [18]:
from torch.nn import functional as F
import numpy as np

In [19]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    ignore_index = tokenizer.pad_token_id
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=ignore_index)
    return {'perplexity': torch.exp(loss)}

In [20]:
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak. 
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels

In [21]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(MODEL)

training_args = TrainingArguments(
    output_dir="temp_files/gpt2-trainer", #The output directory
    overwrite_output_dir=True, #overwrite the content of the output directory
    num_train_epochs=20, # number of training epochs
    per_device_train_batch_size=16, # batch size for training
    per_device_eval_batch_size=32,  # batch size for evaluation
    evaluation_strategy="steps",
    learning_rate=1e-4,
    logging_steps=100,
    eval_steps = 100, # Number of update steps between two evaluations.
    save_steps=500, # after # steps model is saved 
    warmup_steps=100,# number of warmup steps for learning rate scheduler
    gradient_accumulation_steps=2,
    # gradient_checkpointing= ???
    # prediction_loss_only=True,
    # eval_accumulation_steps=32,
    )


trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # compute_metrics=compute_metrics,
)

In [22]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33malexionon[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

Step,Training Loss,Validation Loss
100,3.2814,3.634313
200,1.7122,3.472459
300,1.4809,3.498732
400,1.3813,3.527358
500,1.3155,3.509315
600,1.2719,3.61228


KeyboardInterrupt: 

In [26]:
trainer.save_model()

In [27]:
from transformers import pipeline

kg_world = pipeline('text-generation', model='./temp_files/gpt2-trainer', tokenizer=MODEL, max_length=128)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [28]:
for sample in test_dataset:
    print(sample)
    break

tensor([ 9171, 13492, 27609,   272,  2771, 12052,    13,   468,   257,  4931,
          287,   314,   385, 38929,    13,  8074, 38136, 15084,  1773, 38530,
          468,   257,  4931,   287, 10863,    35,    13,   383,  9856, 20583,
          286,  1024,   320,   418, 28098,  3407,  4673,   471, 46575,  7273,
          353,   590,    13, 35229,    48,    35,   373,   257,  1295,   286,
         7184,   329, 24947,    21,  1129,  1314,  2257,    32,    13,  5267,
          420,   382, 13077,   468,   663,  4318,  2607,  5140,   287,  4718,
        22055,  4990,   630,    13,  4806,  5362,  1574,  4920, 39336, 12135,
        37532,   684,    71,  4584,    13, 32909,   844,   468,   257,  4931,
          287, 33402,   666,    13, 32011, 28407,    82,  4920, 12152,  4304,
        29626,    43,  5883,    13, 18008,  1140,   271,   468,   257,  4931,
          287, 23383,  9409,  2867,   286, 12585,  1151,  3609,    13,  5256,
          312,  5411, 10604,   318,   262,  1743,  3303,   286])

In [29]:
tokenizer.decode(sample).split('\n')

['Meteor Manganese Ltd. has a presence in Ius Imperium. Homeworld Hydroponics has a presence in GRD. The educational curriculum of Deimos Dominion includes learning Utopia Utterance. IAQD was a place of employment for NI61915STA. Novocore AG has its central office located in Relativity Retreat. Infiniware established Stellarstorm Aeonshadow. Polarix has a presence in Vulcanian. Cosmic Constructs established ET76339LUM. Nanoxis has a presence in Jurisdiction of Juventae. Eridania Express is the official language of']

In [39]:
output = kg_world('Meteor Manganese Ltd. has a presence in')
Pretty(output[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Meteor Manganese Ltd. has a presence in Quirinus Quorum. The business direction of company Astralurion is Astrophysics. Comet Communications LLC conducts its business in Xanthe Xerocracy. The primary language of communication in League of Lagoon is Dusk Dialect. Asteroid Alloy collaborates with Institute for Antimatter Propulsion. Pulsar Power Co. conducts its business in SIR. Exo Explorations Enterprise conducts its business in Cimmerium Commonwealth. Galacta conducts its business in HYD. Star Stone Shredders conducts its business in Jezero. Global Geothermics conducts its

### Returning token proba

In [31]:
from transformers import GenerationConfig

In [32]:
tokenizer = GPT2Tokenizer.from_pretrained(MODEL)
model = GPT2LMHeadModel.from_pretrained('./temp_files/gpt2-trainer')

In [33]:
# ?GenerationConfig

In [34]:
generation_config = GenerationConfig(max_new_tokens=128, pad_token_id=502, do_sample=False)
tokenizer.pad_token_id = tokenizer.eos_token_id

In [35]:
text = 'Vivadox has a presence in Oasis Order'
encoded_input = tokenizer(text, return_tensors='pt')
output = model.generate(encoded_input['input_ids'], generation_config=generation_config)

decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)

Vivadox has a presence in Oasis Order. The headquarter of Horizon Xeno-Ecology Ltd. is in Pulsar-Plaza-101. The headquarter of Horizon Xeno-Ecology Ltd. is in Pulsar-Plaza-101. The headquarter of Horizon Xeno-Ecology Ltd. is in Pulsar-Plaza-101. The headquarter of Horizon Xeno-Ecology Ltd. is in Pulsar-Plaza-101. The headquarter of Horizon Xeno-Ecology Ltd. is in Pulsar-Plaza-101. The headquarter of Horizon Xeno-Ecology Ltd.


In [36]:
_ = [print(rp[:-1]) for rp in right_part if "Orbit Ore Organics Inc." in rp]

In [37]:
res = model(encoded_input['input_ids'])
encoded_input['input_ids'][0]

tensor([   53,   452,   324,  1140,   468,   257,  4931,   287,   440, 17765,
         8284])

In [38]:
with torch.no_grad():

    res = model(encoded_input['input_ids'])
    encoded_input['input_ids'][0]

    for idx, (token, token_idx) in enumerate(zip(res.logits[0], encoded_input['input_ids'][0]), start=1):
        # convert to probabilities (softmax function)
        probabilities = torch.nn.functional.softmax(token, dim=-1)

        # pick the token with the highest probability or sample from the distribution
        # next_token = torch.argmax(probabilities, dim=-1)
        _, next_token = torch.topk(probabilities, 5, dim=-1)
        # next_token = torch.multinomial(probabilities, num_samples=10)

        # decode it back to a token
        decoded_token = [tokenizer.decode(t) for t in next_token]

        print(tokenizer.decode(encoded_input['input_ids'][0][:idx]), "---", decoded_token)
        print(tuple(zip(decoded_token, probabilities[next_token].cpu().numpy())))

V --- ['.', '-', ' is', 'V', ' of']
(('.', 0.25120172), ('-', 0.041861523), (' is', 0.036249142), ('V', 0.018687952), (' of', 0.008716569))
Viv --- ['ix', 'ision', 'ant', 'ol', 'ortex']
(('ix', 0.6852788), ('ision', 0.22808547), ('ant', 0.028542798), ('ol', 0.021044374), ('ortex', 0.009914679))
Vivad --- ['ox', 'ix', 'af', 'ome', 'us']
(('ox', 0.82446384), ('ix', 0.17368676), ('af', 0.00070078345), ('ome', 0.00039144888), ('us', 0.00013957219))
Vivadox --- [' conducts', ' Dynamics', ' collabor', ' Technologies', ' Co']
((' conducts', 0.80986714), (' Dynamics', 0.096149474), (' collabor', 0.04253736), (' Technologies', 0.013809542), (' Co', 0.009789669))
Vivadox has --- [' its', ' worked', ' a', ' been', ' taught']
((' its', 0.3239142), (' worked', 0.13737696), (' a', 0.114879124), (' been', 0.08734344), (' taught', 0.029234743))
Vivadox has a --- [' business', ' diplomatic', ' head', ' native', ' Habit']
((' business', 0.70785856), (' diplomatic', 0.033936597), (' head', 0.023359878), 