## ToDo:
1. shuffle before train
2. Do something with '\n' and statement length
3. Fix trouble with dot ., make generated sentences separable, add <eos_token> or another sep
4. ToDo: write custom beam search over model outputs, and vary length, then check whether generated names of `entities` have sens

In [4]:
import os, glob
from IPython.display import Pretty
from tqdm.notebook import tqdm

In [5]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, AdamW
from torch.utils.data import Dataset, DataLoader
import torch
from torch.cuda.amp import GradScaler, autocast

In [6]:
import re
import json
import shutil
import random

In [13]:
with open("statements.txt") as f:
    data = f.read()

In [41]:
statements = data.split('\n')[:-1]

In [42]:
left_part = []
right_part = []

for statement in statements:
    statement = statement.split(' | ')
    random.shuffle(statement)
    lp = statement[:-1]
    rp = statement[-1:]
    left_part.extend(lp)
    right_part.extend(rp)

In [43]:
len(left_part), len(right_part)

(39324, 24826)

In [44]:
# left_part = [s.split(' | ')[0] + '\n' for s in statements]
# right_part = [s.split(' | ')[1] + '\n' for s in statements]

In [45]:
random.shuffle(left_part)
random.shuffle(right_part)

In [46]:
left_part = ". ".join(left_part)
right_part = ". ".join(right_part)

In [47]:
temp_data_dir = "temp_train_txt"

In [48]:
if os.path.exists(temp_data_dir):
    shutil.rmtree(temp_data_dir)
    
os.makedirs(f"{temp_data_dir}", exist_ok=False)

In [49]:
with open(f"{temp_data_dir}/train.txt", "w") as f:
    f.writelines(left_part)
    
with open(f"{temp_data_dir}/test.txt", "w") as f:
    f.writelines(right_part)

In [9]:
MODEL = "gpt2"

In [51]:
# shutil.rmtree("temp_files/gpt2-trainer")

In [52]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL)

train_path = f"{temp_data_dir}/train.txt"
test_path = f"{temp_data_dir}/test.txt"

In [53]:
# Loads cached tokenized text from `temp_train_txt`

In [54]:
from transformers import TextDataset, DataCollatorForLanguageModeling

def load_dataset(train_path,test_path,tokenizer):
    train_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=train_path,
          block_size=128, )
     
    test_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=test_path,
          block_size=128)   
    
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, 
        mlm=False,
    )
    return train_dataset, test_dataset, data_collator

train_dataset, test_dataset, data_collator = load_dataset(train_path, test_path, tokenizer)

Token indices sequence length is longer than the specified maximum sequence length for this model (488404 > 1024). Running this sequence through the model will result in indexing errors


In [55]:
for tr,ts in zip(train_dataset, test_dataset):
    print(tokenizer.decode(tr))
    print("|" * 100)
    print(tokenizer.decode(ts))
    break

Eclipsix Eclipse Inc. conducts its business in EVE. Gusev Government shares border with Korolev Kingdom. Hellas Hierarchy shares border with Daedalia Democracy. Nebula Nuances is the official language of Wahhabi Ward. Orbit Optimizations Org operates in Assembly of Ara. Planetary Platinum operates in GRD. HyperSpace conducts its business in WTL. Noon Nuance is the official language of TMB. Lunar Locale is the capital of Delta. Vortexdyne has a presence in Hierarchy of Hebes. Interstellar Indium Corp. conducts its business in STM. Zephyros has a presence in
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Fluxara Starveil worked in ACML. Lab for Quantum Battle Logistics is a research partner of Earthbound Enterprises Inc.. NebulaNexus operates in Chasma. Aquila AI Corp. operates in Radau Regime. Argyre Authority maintains diplomatic relations with Ismenius Imperial. Spherigon collaborates with Center for Tunable Lasers. Galaxy Glyph is

In [56]:
from torch.nn import functional as F
import numpy as np

In [57]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    ignore_index = tokenizer.pad_token_id
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=ignore_index)
    return {'perplexity': torch.exp(loss)}

In [58]:
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak. 
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels

In [59]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(MODEL)

training_args = TrainingArguments(
    output_dir="temp_files/gpt2-trainer", #The output directory
    overwrite_output_dir=True, #overwrite the content of the output directory
    num_train_epochs=20, # number of training epochs
    per_device_train_batch_size=16, # batch size for training
    per_device_eval_batch_size=32,  # batch size for evaluation
    evaluation_strategy="steps",
    learning_rate=1e-4,
    logging_steps=100,
    eval_steps = 100, # Number of update steps between two evaluations.
    save_steps=500, # after # steps model is saved 
    warmup_steps=100,# number of warmup steps for learning rate scheduler
    gradient_accumulation_steps=2,
    # gradient_checkpointing= ???
    # prediction_loss_only=True,
    # eval_accumulation_steps=32,
    )


trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # compute_metrics=compute_metrics,
)

In [60]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33malexionon[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
100,3.4413,2.006718
200,1.8325,1.620961
300,1.6118,1.504718
400,1.5169,1.433243
500,1.4567,1.392714
600,1.4133,1.364136
700,1.3762,1.341173
800,1.3492,1.316834
900,1.3189,1.29546
1000,1.284,1.265108


TrainOutput(global_step=2380, training_loss=1.3555653003083559, metrics={'train_runtime': 1734.0808, 'train_samples_per_second': 44.0, 'train_steps_per_second': 1.372, 'total_flos': 4963830054912000.0, 'train_loss': 1.3555653003083559, 'epoch': 19.92})

In [61]:
trainer.save_model()

In [69]:
from transformers import pipeline

kg_world = pipeline('text-generation', model='./temp_files/gpt2-trainer', tokenizer=MODEL, max_length=128)

In [70]:
for sample in test_dataset:
    print(sample)
    break

tensor([   37, 22564,  3301,  2907,   303,   346,  3111,   287,  7125,  5805,
           13,  3498,   329, 29082,  5838,  5972,  3969,   318,   257,  2267,
         5212,   286,  3668,  7784, 41253,  3457,   492, 46915,    45,  1069,
          385, 14051,   287,   609, 11797,    13, 11446, 10102,  9552, 11421,
           13, 14051,   287,  5325,   559,  3310,   524,    13,   943,  1360,
          260, 11416, 16047, 13093,  2316,   351,  1148,  3653,  3754, 11773,
           13,  1338,   372, 37107,  6967,   689,   351,  3337,   329, 13932,
          540, 10123,   364,    13,  9252, 27949,   746,   318,   262,  1743,
         3303,   286,  2892,  9282,   286, 27609,   305,   303,    13, 12347,
        15086,  3970,  1766,    13, 14051,   287,   440,   746,   343,  8284,
           13, 38484,  1868, 30437,   468,   257,  4931,   287, 36514,    13,
        43800,   350,   312,  1655,   318,  7817,   355,   257,  1218,  3303,
          287,  4432,    64,    13,  8304,  3225,   666, 10302])

In [76]:
test_sentences = tokenizer.decode(sample).split(". ")

In [81]:
test_sentence = random.choice(test_sentences)

In [82]:
test_sentence

'Fluxara Starveil worked in ACML'

In [86]:
output = kg_world(test_sentence)
generated_text = output[0]['generated_text']
Pretty(gen_text)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Fluxara Starveil worked in ACML. Stellaris Security operates in YAN. The headquarter of Alien Aggregate is in Surveillance Suburb. Martian Transport Services LLC collaborates with IALSS. Exoterra Solutions Ltd. established NI27032MOO. Quantum Quorums Corp conducts its business in DEI. Stellar Stream Services Corp. has a presence in Coprates Confederacy. The capital of Assembly of Ares is Zephyr-Zone-0. The native language of Order of Ophir is Ares Articulate. Horizon Nanotech has its central office located in Phobos Port. SolarSilica Co

In [100]:
for sample in generated_text.split(". "):
    if sample in right_part:
        print("[O]", sample)
    elif sample in left_part:
        print("[X]", sample)
    else:
        print("[-]", sample)

[O] Fluxara Starveil worked in ACML
[O] Interstellar Ironworks Co
[-] started LU76690GLO
[X] Asteroid Antimony has a presence in Radau Regime
[O] Horizon AI Solutions LLC has a presence in Jasmine Jurisdiction
[X] Horizon AI Solutions LLC operates in Lily League
[X] Solar Smelters operates in Hephaestus
[O] Martian Artifacts Preservation Ltd
[O] conducts its business in Principality of Phobian
[-] Company Martian Transport operates within the realm of Astrophysics Research
[O] Vexta Co
[O] operates in Utopia Union
[X] Pulsar Power has a presence in Radau
[O] The capital of Th


### Returning token proba

In [10]:
from transformers import GenerationConfig

In [11]:
tokenizer = GPT2Tokenizer.from_pretrained(MODEL)
model = GPT2LMHeadModel.from_pretrained('./temp_files/gpt2-trainer')

In [12]:
# ?GenerationConfig

In [13]:
tokenizer.pad_token_id = tokenizer.eos_token_id

In [15]:
generation_config = GenerationConfig(max_new_tokens=128, 
                                     pad_token_id=tokenizer.pad_token_id, 
                                     do_sample=False, 
                                     num_beams=5, 
                                     num_return_sequences=1, early_stopping=True)

In [16]:
text = "Horizon AI Solutions LLC has a presence in Jasmine Jurisdiction"
encoded_input = tokenizer(text, return_tensors='pt')
output = model.generate(encoded_input['input_ids'], generation_config=generation_config)

In [17]:
for sidx in range(len(output)):
    decoded_output = tokenizer.decode(output[sidx], skip_special_tokens=True)
    print(decoded_output)
    print("|" * 100)

Horizon AI Solutions LLC has a presence in Jasmine Jurisdiction. The primary language of communication in Xezex Xerocracy is Oxia Oratory. The educational curriculum of Xylophyle Xerocracy includes learning Oxia Oromo. The educational curriculum of Xezex Xerocracy includes learning Oxia Oromo. The educational curriculum of Xezex Xerocracy includes learning Oxia Oromo. The educational curriculum of Xylophyle Xerocracy includes learning Xanthe Xhosa. The educational curriculum of Xezex Xerocracy includes learning Oxia Oromo. The educational curriculum of Xezex Xerocracy includes learning Oxia Oromo. The educational curriculum of X
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||


In [18]:
_ = [print(rp[:-1]) for rp in right_part if "Orbit Ore Organics Inc." in rp]

NameError: name 'right_part' is not defined

In [None]:
res = model(encoded_input['input_ids'])
encoded_input['input_ids'][0]

In [None]:
with torch.no_grad():

    res = model(encoded_input['input_ids'])
    encoded_input['input_ids'][0]

    for idx, (token, token_idx) in enumerate(zip(res.logits[0], encoded_input['input_ids'][0]), start=1):
        # convert to probabilities (softmax function)
        probabilities = torch.nn.functional.softmax(token, dim=-1)

        # pick the token with the highest probability or sample from the distribution
        # next_token = torch.argmax(probabilities, dim=-1)
        _, next_token = torch.topk(probabilities, 5, dim=-1)
        # next_token = torch.multinomial(probabilities, num_samples=10)

        # decode it back to a token
        decoded_token = [tokenizer.decode(t) for t in next_token]

        print(tokenizer.decode(encoded_input['input_ids'][0][:idx]), "---", decoded_token)
        print(tuple(zip(decoded_token, probabilities[next_token].cpu().numpy())))

ToDo: write custom beam search over model outputs, and vary length, then check whether generated names of `entities` have sens

### Beam Search

In [38]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [39]:
device = 'cpu'

In [82]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('./temp_files/gpt2-trainer')
context = tokenizer.encode("Horizon AI Solutions LLC has a presence in", return_tensors='pt')

In [84]:
with open("temp_train_txt/test.txt") as f:
    data = f.read()

In [156]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [228]:
def beam_search_v2(model, length, context, beams=5):
    # context = context.to(device)
    generated = [context for _ in range(beams)]  # Initialize the list of hypotheses
    beam_scores = torch.zeros((beams, ), dtype=torch.float, device=device)

    for _ in tqdm(range(length)):
        all_candidates = []  # Store candidates at each step
        hypothesis_set = []
        
        for idx, hypothesis in enumerate(generated):
            with torch.no_grad():
                predictions = model(hypothesis)
            predictions = predictions['logits']
            predictions = predictions[:, -1, :]  # Take the prediction of the last token
            scores = torch.nn.functional.log_softmax(predictions, dim=-1)  # Convert logits to log probabilities

            # Add the total score
            scores = scores + beam_scores[idx]
            scores = scores.view(-1)  # Reshape the scores to a single dimension

            # Get the top k scores and ids
            best_scores, best_scores_id = torch.topk(input=scores, k=beams, dim=-1, largest=True, sorted=True)
            # print(best_scores, best_scores_id)

            # Get the corresponding tokens and update the hypotheses
            best_scores_id = best_scores_id % model.config.vocab_size
            for i in range(beams):
                # print(hypothesis.shape, best_scores_id.shape)
                new_hypothesis = torch.cat((hypothesis[0], best_scores_id[i].unsqueeze(0)), dim=0)
                new_hypothesis = torch.unsqueeze(new_hypothesis, 0)
                
                nh = new_hypothesis[0].numpy().tolist()
                if nh in hypothesis_set:
                    continue
                else:
                    hypothesis_set.append(nh)
                    all_candidates.append((new_hypothesis, best_scores[i]))
        
        # print(all_candidates)
        # Order all candidates by score and select top beams
        ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
        generated = [x[0] for x in ordered[:beams]]
        beam_scores = torch.Tensor([x[1] for x in ordered[:beams]])
        print(beam_scores[:5])

    return generated, hypothesis_set

In [229]:
context = tokenizer.encode("Once upon a time", return_tensors='pt')
output_sequences, hs = beam_search_v2(model, 25, context, beams=5)


# Decode the output
for sequence in output_sequences:
    text = tokenizer.decode(sequence[0], clean_up_tokenization_spaces=True)
    print(text)

  0%|          | 0/25 [00:00<?, ?it/s]

tensor([-1.4761, -2.3040, -2.3112, -2.7588, -2.8299])
tensor([-3.1506, -3.9192, -3.9626, -3.9949, -4.2216])
tensor([-3.8680, -3.9542, -3.9631, -4.0054, -4.2218])
tensor([-3.9322, -3.9662, -4.2797, -6.8473, -7.1719])
tensor([-3.9325, -4.2797, -6.7965, -6.8344, -6.8634])
tensor([-3.9325, -4.2797, -7.9765, -8.1920, -8.2395])
tensor([-4.4775, -6.0360, -6.0805, -6.4265, -6.5252])
tensor([-6.1771, -6.6564, -6.7002, -6.8581, -7.1190])
tensor([-6.1771, -6.6573, -6.7002, -8.4082, -8.4153])
tensor([-6.7024, -6.8785, -6.8813, -8.4082, -8.4153])
tensor([-6.8786, -6.8814, -8.4153, -9.4680, -9.9694])
tensor([ -9.1989,  -9.3121, -10.1066, -10.1363, -10.3633])
tensor([-10.7640, -10.8309, -10.8423, -10.8650, -10.9244])
tensor([-10.7640, -10.8312, -10.8423, -10.8650, -10.9244])
tensor([-10.7640, -10.8321, -10.8437, -10.8650, -10.9257])
tensor([-10.8437, -10.9258, -14.1166, -14.1512, -14.1721])
tensor([-14.2099, -14.2194, -14.3435, -14.3761, -14.4914])
tensor([-15.5666, -16.0679, -16.0834, -16.1131, -16.

### Evaluate correct answer

In [232]:
context = "Horizon AI Solutions LLC has a presence in"
answer_v1 = "Jasmine Jurisdiction"
answer_v2 = " " + answer_v1

In [233]:
tokenizer.encode(context), tokenizer.encode(answer_v1), tokenizer.encode(answer_v2)

([27991, 8637, 9552, 23555, 11419, 468, 257, 4931, 287],
 [41, 292, 3810, 23383, 9409, 2867],
 [21961, 3810, 23383, 9409, 2867])

In [294]:
def beam_top1(model, context, beams=5):
    context_tokens = tokenizer.encode(context, return_tensors='pt')
    
    beam_log_proba = 0.0
    beam_proba = 1.0
    
    for i in range(beams):
        with torch.no_grad():
            predictions = model(context_tokens)

        predictions = predictions['logits'][0, -1]
        log_predictions = torch.nn.functional.log_softmax(predictions, dim=-1)
        beam_log_proba += torch.max(log_predictions).item()
        
        predictions = torch.nn.functional.softmax(predictions, dim=-1)
        beam_proba *= torch.max(predictions).item()
        
        token = torch.argmax(predictions)
        
        print(beam_proba, beam_log_proba)
        
        # print(context_tokens.shape, token.shape)
        context_tokens = torch.cat([context_tokens.squeeze(), token.unsqueeze(0)])
        context_tokens = torch.unsqueeze(context_tokens, 0)
        # print(context_tokens)
        print(tokenizer.decode(context_tokens[0]), " | ", tokenizer.decode(token))

In [295]:
def beam_eval(model, context, answer):
    context_tokens = tokenizer.encode(context, return_tensors='pt')
    answer_tokens = tokenizer.encode(answer, return_tensors='pt')
    
    answer_log_proba = 0.0
    beam_log_proba = 0.0
    random_log_proba = 0.0
    
    answer_proba = 1.0
    beam_proba = 1.0
    random_proba = 1.0
    
    for idx, token in enumerate(answer_tokens[0]):
        with torch.no_grad():
            # context_tokens = torch.tensor(context_tokens)
            predictions = model(context_tokens)

        predictions = predictions['logits'][0, -1]
        log_predictions = torch.nn.functional.log_softmax(predictions, dim=-1)
        answer_log_proba += log_predictions[token].item()
        beam_log_proba += torch.max(log_predictions).item()
        random_log_proba += log_predictions[torch.randint(high=model.config.vocab_size, size=(1,))].item()
        
        predictions = torch.nn.functional.softmax(predictions, dim=-1)
        answer_proba *= predictions[token].item()
        beam_proba *= torch.max(predictions).item()
        random_proba *= predictions[torch.randint(high=model.config.vocab_size, size=(1,))].item()
        
        print(beam_proba, beam_log_proba)
        print(answer_proba, answer_log_proba)
        print(random_proba, random_log_proba)
        # print(tokenizer.decode([token]), token_proba)
        
        # print(context_tokens.shape, token.shape)
        context_tokens = torch.cat([context_tokens.squeeze(), token.unsqueeze(0)])
        context_tokens = torch.unsqueeze(context_tokens, 0)
        # print(context_tokens)

In [296]:
beam_eval(model, context, answer_v2)

0.1374967247247696 -1.9841551780700684
0.0014000963419675827 -6.571214199066162
2.4242336138513565e-08 -20.30540657043457
0.13749669194299585 -1.984155416488619
0.0014000960081586022 -6.571214437484713
1.1213084074706591e-23 -58.78228950500488
0.0975489057467216 -2.327401459217043
0.0009933172326272973 -6.914460480213137
5.196188398102278e-37 -82.63360404968262
0.0975488824892501 -2.327401697635594
0.0009933169958020141 -6.9144607186316875
1.753947302022368e-53 -134.56775093078613
0.09754848711232884 -2.327405750743253
0.0009933129697731597 -6.914464771739347
6.762368232876269e-79 -198.68691444396973


In [299]:
beam_top1(model, context, len(tokenizer.encode(answer_v2)))

0.1374967247247696 -1.9841551780700684
Horizon AI Solutions LLC has a presence in Mare  |   Mare
0.13749636412525845 -1.9841578006710279
Horizon AI Solutions LLC has a presence in Mareot  |  ot
0.13749634773441458 -1.9841579198803103
Horizon AI Solutions LLC has a presence in Mareotis  |  is
0.09407181976236938 -2.3636966943706668
Horizon AI Solutions LLC has a presence in Mareotis Mon  |   Mon
0.09407167397731697 -2.3636982440901804
Horizon AI Solutions LLC has a presence in Mareotis Monarchy  |  archy
