## ToDo:
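1. Shuffle the statements before training.
2. Decide how to handle the '\n' separators and the varying statement lengths (fixed-size blocks currently cut statements mid-line).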

In [1]:
import os, glob
from IPython.display import Pretty
from tqdm.notebook import tqdm

In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, AdamW
from torch.utils.data import Dataset, DataLoader
import torch
from torch.cuda.amp import GradScaler, autocast

In [3]:
import re
import json
import shutil
import random

In [4]:
# Load the raw statements file in one go; later cells split it into lines.
with open("statements.txt") as statements_file:
    data = statements_file.read()

In [5]:
# One statement per line. Filter out blank lines instead of blindly dropping the
# last element: this keeps the final statement when the file has no trailing
# newline and stops empty lines from reaching the ' | ' split further down.
statements = [line for line in data.split('\n') if line.strip()]

In [6]:
# Each statement has the form "<left> | <right>". Split every line once and
# collect the two sides separately, each terminated by a newline so the lists
# can be written straight to text files.
left_part, right_part = [], []
for line_no, statement in enumerate(statements, start=1):
    fields = statement.split(' | ')
    if len(fields) < 2:
        # Fail with the offending line instead of an anonymous IndexError.
        raise ValueError(
            f"statement {line_no} has no ' | ' separator: {statement!r}"
        )
    left_part.append(fields[0] + '\n')
    right_part.append(fields[1] + '\n')

In [7]:
# Shuffle with a dedicated, seeded generator so that Restart & Run All writes
# the same train/test files every time. The two lists are still shuffled
# independently of each other (the generator state advances between calls).
SHUFFLE_SEED = 42
shuffle_rng = random.Random(SHUFFLE_SEED)
shuffle_rng.shuffle(left_part)
shuffle_rng.shuffle(right_part)

In [8]:
# Scratch directory that receives the generated train.txt / test.txt files.
temp_data_dir = "temp_train_txt"

In [9]:
# Always start from an empty scratch directory: remove whatever a previous run
# left behind, then create the directory again (makedirs raises if it still exists).
if os.path.exists(temp_data_dir):
    shutil.rmtree(temp_data_dir)

os.makedirs(temp_data_dir)

In [10]:
# The left-hand sides of the statements become the training text ...
with open(f"{temp_data_dir}/train.txt", "w") as train_file:
    train_file.write("".join(left_part))

# ... and the right-hand sides become the evaluation text.
with open(f"{temp_data_dir}/test.txt", "w") as test_file:
    test_file.write("".join(right_part))

In [11]:
from transformers import AutoTokenizer

# Pretrained GPT-2 tokenizer; the same checkpoint name is used for the model below.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Paths of the text files written into the scratch directory by the previous cell.
train_path = f"{temp_data_dir}/train.txt"
test_path = f"{temp_data_dir}/test.txt"

In [12]:
# Loads cached tokenized text from `temp_train_txt`

In [13]:
from transformers import TextDataset, DataCollatorForLanguageModeling

def load_dataset(train_path, test_path, tokenizer, block_size=128):
    """Build the train/eval datasets and the collator for causal-LM fine-tuning.

    Args:
        train_path: Path of the plain-text file used for training.
        test_path: Path of the plain-text file used for evaluation.
        tokenizer: Tokenizer handed to both datasets and to the collator.
        block_size: Number of tokens per example handed to ``TextDataset``.
            Defaults to 128, the value that was previously hard-coded.

    Returns:
        A ``(train_dataset, test_dataset, data_collator)`` tuple.
    """
    def _text_dataset(file_path):
        # Both splits share the tokenizer and the block size.
        return TextDataset(
            tokenizer=tokenizer,
            file_path=file_path,
            block_size=block_size,
        )

    train_dataset = _text_dataset(train_path)
    test_dataset = _text_dataset(test_path)

    # mlm=False selects the causal (next-token) language-modeling objective.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    return train_dataset, test_dataset, data_collator

# Build both datasets and the collator from the files written above.
# NOTE(review): the "sequence length is longer than the specified maximum" warning
# printed by this cell presumably comes from tokenizing a whole file at once
# before it is cut into blocks — confirm against the TextDataset implementation.
train_dataset, test_dataset, data_collator = load_dataset(train_path, test_path, tokenizer)

Token indices sequence length is longer than the specified maximum sequence length for this model (101123 > 1024). Running this sequence through the model will result in indexing errors


In [14]:
# Sanity check: decode the first training block and the first evaluation block.
first_pair = next(zip(train_dataset, test_dataset), None)
if first_pair is not None:
    tr, ts = first_pair
    print(tokenizer.decode(tr))
    print("|" * 100)
    print(tokenizer.decode(ts))

The headquarter of Cometcore Technologies Inc. is in Starlight Station
Moonveil Aeonwave was founded by Quantiphor Technologies LLC
Nova Navigation Corp. conducts its business in Riftian Republic
Meridian Markup is taught as a second language in Vastitas Vicariate
Meteorite Molybdenum LLC conducts its business in Iani Imperium
Vortexis Nebulawing worked in Lab for Bio-Integrated Nanomaterials
Equatorial Enclave and Coerulean Commune are neighbours
Twilix Starflame was founded by Astrolynx Corp.
Mons
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Lab for Dark Matter Exploration and Simulation is a research partner of Cosmic Constructs Incorporated
Stellaris Systems Security Corp. has a presence in Kasei Kingdom
LIS is a research partner of Interstellar Ironworks Co.
Company Neurixis Networks LLC operates within the realm of Interstellar Transportation
Horizon Space Elevator Services Inc. has a presence in Galean Government
Boreum Blo

In [15]:
from torch.nn import functional as F
import numpy as np

In [16]:
def compute_metrics(eval_pred):
    """Compute causal-LM perplexity from raw evaluation logits.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` must be the raw model
            scores with the vocabulary on the last axis and the sequence on the
            second-to-last axis; ``labels`` holds the target token ids with the
            same leading shape. NumPy arrays and torch tensors are both accepted.
            Positions labelled ``-100`` are excluded from the loss.

    Returns:
        ``{'perplexity': float}`` — ``exp`` of the mean next-token cross-entropy.

    Raises:
        TypeError: If ``logits`` is a tuple (for example the output of a
            ``preprocess_logits_for_metrics`` hook that already reduced the
            logits to token ids); perplexity cannot be derived from that.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        raise TypeError(
            "compute_metrics needs the raw logits array, got a tuple of predictions"
        )

    # Accept NumPy input as well as tensors.
    logits = torch.as_tensor(logits).float()
    labels = torch.as_tensor(labels).long()

    # The scores at position i predict the token at position i + 1, so drop the
    # last position of the logits and the first position of the labels.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    # -100 is used instead of tokenizer.pad_token_id, which is None for GPT-2
    # at this point in the notebook and is not a valid ignore_index.
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.shape[-1]),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
    # Return a plain float so the metric can be logged and serialized.
    return {'perplexity': torch.exp(loss).item()}

In [17]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before they are accumulated.

    Keeping only the argmax ids avoids storing the full vocabulary-sized logits
    for every evaluation batch.

    Args:
        logits: Either a tensor with the vocabulary on its last axis, or a tuple
            of model outputs whose first element is that tensor.
        labels: Target token ids; passed through unchanged.

    Returns:
        A ``(pred_ids, labels)`` tuple, where ``pred_ids`` has the shape of
        ``logits`` without its last axis.
    """
    # Only unwrap real tuples: indexing a plain tensor with [0] would silently
    # keep just the first sample of the batch.
    if isinstance(logits, tuple):
        logits = logits[0]
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

In [18]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

# Start from the pretrained GPT-2 weights; the Trainer below fine-tunes them.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Hyperparameters for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="temp_files/gpt2-trainer",  # where checkpoints and the final model are written
    overwrite_output_dir=True,  # overwrite the content of the output directory
    num_train_epochs=8,  # number of training epochs
    per_device_train_batch_size=16,  # batch size for training (per device, per forward pass)
    per_device_eval_batch_size=32,  # batch size for evaluation
    evaluation_strategy="steps",  # evaluate every `eval_steps` steps
    learning_rate=1e-4,
    logging_steps=100,  # log the training loss every 100 steps
    eval_steps = 100,  # number of update steps between two evaluations
    save_steps=500,  # checkpoint interval; the recorded run stops at step 200, so it is never reached
    warmup_steps=100,  # number of warmup steps for the learning rate scheduler
    gradient_accumulation_steps=2,  # 16 * 2 = 32 samples per optimizer step and device
    # Options that were considered but are currently disabled:
    # gradient_checkpointing= ???
    # prediction_loss_only=True,
    # eval_accumulation_steps=32,
    )


# No metric hooks are active, so evaluation reports only the validation loss.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # Disabled metric hooks (defined in the cells above):
    # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # compute_metrics=compute_metrics,
)

In [19]:
# Run the fine-tuning; the recorded run below finished after 200 optimizer steps (8 epochs).
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33malexionon[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
100,3.1928,3.689155
200,1.463,3.57509


TrainOutput(global_step=200, training_loss=2.3279076385498048, metrics={'train_runtime': 125.463, 'train_samples_per_second': 50.373, 'train_steps_per_second': 1.594, 'total_flos': 412841410560000.0, 'train_loss': 2.3279076385498048, 'epoch': 8.0})

In [29]:
# Persist the fine-tuned weights. NOTE(review): with no argument this presumably
# writes to training_args.output_dir ("temp_files/gpt2-trainer"), which is the
# path the following cells load from — confirm.
trainer.save_model()

In [21]:
from transformers import pipeline

# Text-generation pipeline over the fine-tuned checkpoint on disk, paired with the
# stock GPT-2 tokenizer; max_length=128 is the same value as the training block size.
kg_world = pipeline('text-generation', model='./temp_files/gpt2-trainer', tokenizer='gpt2', max_length=128)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [22]:
# Show the first evaluation block as raw token ids.
# `sample` deliberately stays bound after the loop: the next cell decodes it.
for sample in test_dataset:
    print(sample)
    break

tensor([17822,   329,  3801, 16900, 36806,   290, 41798,   318,   257,  2267,
         5212,   286, 32011, 28407,    82,  3457, 40132,   198,  7447,   297,
        20066, 11998,  4765, 11421,    13,   468,   257,  4931,   287,   509,
          589,    72,  7526,   198,    43,  1797,   318,   257,  2267,  5212,
          286, 49041,  7931,  5225,  1766,    13,   198, 39154,  3169,   333,
          844,   271, 27862, 11419, 14051,  1626,   262, 13360,   286, 49041,
        15198,   198, 27991,  8637,  4687, 37881,  1352,  6168,  3457,    13,
          468,   257,  4931,   287, 36483,   272,  5070,   198,    33,   382,
          388,  1086,   420,   564,   247,   264, 25482,   351,   943,  1360,
          260, 10006,   198,  6310,  3678,   329,  3334,    12, 13434, 10123,
          364,   373,   257,  1295,   286,  7184,   329,  3661,  2417,   271,
         1610,  2821, 19106,   198, 26552,   897,  4448,  5070,  7303,  4865,
          351,   943,  1360,   260, 10006,   198,  3109,   313])

In [23]:
# Decode the block from the previous cell (`sample`) and split it back into statements.
tokenizer.decode(sample).split('\n')

['Lab for Dark Matter Exploration and Simulation is a research partner of Cosmic Constructs Incorporated',
 'Stellaris Systems Security Corp. has a presence in Kasei Kingdom',
 'LIS is a research partner of Interstellar Ironworks Co.',
 'Company Neurixis Networks LLC operates within the realm of Interstellar Transportation',
 'Horizon Space Elevator Services Inc. has a presence in Galean Government',
 'Boreum Bloc ’ s diplomacy with Argyre Assembly',
 'Institute for High-Power Lasers was a place of employment for Skyris Fluxshadow',
 'Galaxias Government shares border with Argyre Assembly',
 'Exot']

In [24]:
# Prompt the fine-tuned model with the beginning of a statement and
# display the first generated continuation as plain text.
output = kg_world('Dustian Confederacy shares border')
Pretty(output[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Dustian Confederacy shares border with Vulcanian Vicinity
Horizon Astralcraft Manufacturing LLC collaborates with LASE
The official language of Mangala Monarchy is Zephyria Zenith
Horizon Aerospace Inc. collaborates with LAS
Interstellar Infrastructure LLC conducts its business in Deuteronilus Domain
Martian Food Production LLC conducts its business in Elysium Enclave
The primary language of communication in Quadrans Quorum is Venus Vernacular
The headquarter of Solarshadow Systems AG is in Hellas Hierarchy
The business direction of company Planetary Power Generation LLC is Cosmic Geometrics
Oxia Order maintains

### Returning token proba

In [31]:
from transformers import GenerationConfig

In [36]:
# Rebind `tokenizer` and `model` for the token-probability experiments:
# stock GPT-2 tokenizer plus the fine-tuned weights saved by the Trainer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('./temp_files/gpt2-trainer')

In [54]:
?GenerationConfig

[1;31mInit signature:[0m [0mGenerationConfig[0m[1;33m([0m[1;33m**[0m[0mkwargs[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m     
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

    - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
        `do_sample=False`
    - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
        and `top_k>1`
    - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
        `do_sample=True`
    - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
        `do_sample=False`
    - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
        `num_beams>1` and `do_sample

In [55]:
# GPT-2 has no dedicated pad token, so the EOS token doubles as padding.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Greedy decoding of up to 128 new tokens. The pad id is taken from the tokenizer
# instead of a hard-coded literal: the previous value 502 is an arbitrary
# vocabulary id (a truncated 50256) and disagreed with the tokenizer setting.
generation_config = GenerationConfig(
    max_new_tokens=128,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
)

In [56]:
text = 'The capital of Dustian Confederacy is Dust Haven'
encoded_input = tokenizer(text, return_tensors='pt')

# Hand the tokenizer's attention mask to generate() explicitly, so the mask is
# never guessed from the pad token id configured for generation.
output = model.generate(
    encoded_input['input_ids'],
    attention_mask=encoded_input['attention_mask'],
    generation_config=generation_config,
)

decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)

The capital of Dustian Confederacy is Dust Haven
The headquarter of Orbit Ore Organics Inc. is in Mars Reconnaissance
The headquarter of Orbit Ore Organics Inc. is in Mars Reconnaissance
The business direction of company Orbit Ore Organics Inc. is Space Tourism
The business direction of company Orbit Ore Organics Inc. is Space Tourism
The headquarter of Orbit Ore Organics Inc. is in Mars Reconnaissance
The business direction of company Orbit Ore Organics Inc. is Space Tourism
The headquarter of Orbit Ore Organics Inc. is in Mars Reconnaissance
The headquarter of Orbit Ore Organics Inc. is in Mars Reconnaissance
The business


In [57]:
# List every evaluation statement that mentions the company, to compare with the
# generation above. A plain loop replaces the list comprehension that was built
# only for its print side effect and overwrote `_`.
for rp in right_part:
    if "Orbit Ore Organics Inc." in rp:
        print(rp[:-1])  # drop the trailing newline appended when the list was built

Orbit Ore Organics Inc. has a presence in Sirenum Sovereignty
Orbit Ore Organics Inc. has a presence in Utopia Union
Lab for Multibeam Systems is a research partner of Orbit Ore Organics Inc.
Orbit Ore Organics Inc. has its central office located in Mars Reconnaissance
LMS is a research partner of Orbit Ore Organics Inc.
Company Orbit Ore Organics Inc. operates within the realm of Radiation-Resistant Materials
Orbit Ore Organics Inc. established Voidix Nebulaspark


In [27]:
# Plain forward pass over the prompt. Inference only, so skip building the
# autograd graph.
with torch.no_grad():
    res = model(encoded_input['input_ids'])

# Display the prompt's token ids.
encoded_input['input_ids'][0]

tensor([  464,  3139,   286, 16240,   666, 45252,   318, 16240, 21425])

In [28]:
# For every prefix of the prompt, print the five most likely next tokens together
# with their probabilities.
with torch.no_grad():

    res = model(encoded_input['input_ids'])

    # res.logits[0] holds one row of vocabulary scores per prompt position.
    for idx, token in enumerate(res.logits[0], start=1):
        # convert to probabilities (softmax function)
        probabilities = torch.nn.functional.softmax(token, dim=-1)

        # keep the five most likely candidates; topk returns values and indices
        # together, so the probabilities need not be gathered a second time
        top_probs, next_token = torch.topk(probabilities, 5, dim=-1)

        # decode the candidates back to text
        decoded_token = [tokenizer.decode(t) for t in next_token]

        # prefix consumed so far --- candidate continuations
        print(tokenizer.decode(encoded_input['input_ids'][0][:idx]), "---", decoded_token)
        print(tuple(zip(decoded_token, top_probs.cpu().numpy())))

The --- ['\n', ' head', ' business', ' company', ' is']
(('\n', 0.019495836), (' head', 0.018136313), (' business', 0.013535057), (' company', 0.01245116), (' is', 0.009103105))
The capital --- [' of', ' is', '-', ' and', ' Ast']
((' of', 0.99265975), (' is', 0.004106097), ('-', 0.00064774865), (' and', 0.0005198814), (' Ast', 0.0002066589))
The capital of --- [' Vall', ' D', ' Ph', ' Sab', ' Is']
((' Vall', 0.035512768), (' D', 0.031669684), (' Ph', 0.026064288), (' Sab', 0.022697281), (' Is', 0.02237098))
The capital of Dust --- ['ian', 'ia', 'loop', ' Storm', ' Republic']
(('ian', 0.9483384), ('ia', 0.023370262), ('loop', 0.0027942478), (' Storm', 0.0025316444), (' Republic', 0.0018459271))
The capital of Dustian --- [' Confederacy', ' Federation', ' Republic', ' Empire', ' D']
((' Confederacy', 0.99843293), (' Federation', 0.00046676266), (' Republic', 0.00044753397), (' Empire', 0.00020720686), (' D', 0.00010761568))
The capital of Dustian Confederacy --- [' is', '\n', ' and', ' m