In [3]:
import gc
import os
import torch
import pandas as pd
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer
output_path = 'Models/gpt-neo/125M-wow-test'
texts = pd.read_csv('data_wow.csv', nrows=100)
# texts = pd.read_csv('data_wow.csv')

torch.manual_seed(42)
model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')

class TextDataset(Dataset):
    def __init__(self, txt_list, tokenizer, max_length):
        self.labels = []
        self.input_ids = []
        self.attn_masks = []        
        for sentence in txt_list['sentence']:
            encodings_dict = tokenizer(sentence, truncation=True, max_length=max_length, padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
    def __len__(self): return len(self.input_ids)
    def __getitem__(self, idx): return self.input_ids[idx], self.attn_masks[idx]

max_length = max([len(tokenizer.encode(sentence)) for sentence in texts['sentence']])
dataset = TextDataset(texts, tokenizer, max_length=max_length)
train_size = int(0.9 * len(dataset))
train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
print(texts)
print(train_size)
print(len(dataset) - train_size)

# os.environ["WANDB_PROJECT"]='gpt-neo-125M'
# os.environ["WANDB_LOG_MODEL"]="true"
# os.environ["WANDB_WATCH"]="false"
# os.environ["WANDB_NAME"]="gpt-neo-wow"
# os.environ["WANDB_API_KEY"] = "b689f7c91f1ec7520fa8da927f175f1efd587181"

                                             sentence
0   <|startoftext|>Title: Sharptalon's Claw Descri...
1   <|startoftext|>Title: Riverpaw Gnoll Bounty De...
2   <|startoftext|>Title: Give Gerard a Drink Desc...
3   <|startoftext|>Title: Ursangous' Paw Descripti...
4   <|startoftext|>Title: Shadumbra's Head Descrip...
..                                                ...
95  <|startoftext|>Title: Securing the Lines Descr...
96  <|startoftext|>Title: Rescue OOX-09/HL! Descri...
97  <|startoftext|>Title: Conscript of the Horde D...
98  <|startoftext|>Title: Plainstrider Menace Desc...
99  <|startoftext|>Title: The Zhevra Description: ...

[100 rows x 1 columns]
90
10


In [4]:
# try:
#     # model = AutoModelForCausalLM.from_pretrained(os.path.join(output_path, 'results')).cuda()
#     model = AutoModelForCausalLM.from_pretrained(os.path.join(output_path, 'results', 'checkpoint-1825')).cuda()
#     print('saved')
# except:
#     model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
#     print('downloaded')

model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
model.resize_token_embeddings(len(tokenizer))
print(max_length)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


152


In [5]:
import nltk
from nltk.translate.bleu_score import sentence_bleu

def decode(input_ids_tensor):
    token_ids_list = input_ids_tensor.tolist()
    # Decode the token IDs into text
    return tokenizer.decode(token_ids_list, skip_special_tokens=True)

import torch.nn.functional as F
def decodeLogits(logits):

    # Apply softmax to get probabilities
    logits_tensor = torch.tensor(logits, dtype=torch.float)
    probabilities = F.softmax(logits_tensor, dim=-1)

    # Get the token IDs (the indices of the highest probabilities)
    token_ids = torch.argmax(probabilities, dim=-1)
    return tokenizer.decode(token_ids, skip_special_tokens=True)

# !pip install bert-score
import bert_score

def calculate_bertscore(predictions, references, lang='en'):
    # Calculate BERTScore
    # P, R, F1 = bert_score.score(predictions, references, lang=lang)
    P, R, F1 = bert_score.score(predictions, references, lang=lang, model_type='distilbert-base-uncased', verbose=True)
    
    # Compute average scores
    avg_precision = P.mean().item()
    avg_recall = R.mean().item()
    avg_f1 = F1.mean().item()
    
    return {
        'precision': avg_precision,
        'recall': avg_recall,
        'f1': avg_f1
    }

from rouge_score import rouge_scorer

def compute_rouge_in_chunks(candidates, references, chunk_size=100):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    results = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    
    for i in range(0, len(candidates), chunk_size):
        chunk_candidates = candidates[i:i + chunk_size]
        chunk_references = references[i:i + chunk_size]
        
        for c, r in zip(chunk_candidates, chunk_references):
            decoded_r = decode(r)
            decoded_c = decodeLogits(c)
            scores = scorer.score(decoded_r, decoded_c)
            print(scores)
            for key in results.keys():
                results[key].append(scores[key].fmeasure)

    average_scores = {key: sum(scores) / len(scores) for key, scores in results.items()}
    return average_scores

from nltk.translate.bleu_score import sentence_bleu

def compute_metrics(pred):
    references = pred.label_ids
    generated_texts = pred.predictions
    
    bleu_scores = []
    bert_scores = []
    for reference, generated_text in zip(references, generated_texts):
        reference_text = decode(reference)
        predicted_text = decodeLogits(generated_text)
        bert_score = calculate_bertscore([predicted_text], [reference_text])
        bert_scores.append(bert_score)
        bleu_score = sentence_bleu([reference_text], predicted_text)
        bleu_scores.append(bleu_score)
        
    avg_precision = sum(score['precision'] for score in bert_scores) / len(bert_scores)
    avg_recall = sum(score['recall'] for score in bert_scores) / len(bert_scores)
    avg_f1 = sum(score['f1'] for score in bert_scores) / len(bert_scores)
    rouge = compute_rouge_in_chunks(generated_texts, references)

    return {
        'bleu': sum(bleu_scores) / len(bleu_scores),
        'rouge1': rouge['rouge1'],
        'rouge2': rouge['rouge2'],
        'rougeL': rouge['rougeL'],
        'precision': avg_precision,
        'recall': avg_recall,
        'f1': avg_f1,
    }

In [6]:
from transformers import EarlyStoppingCallback
torch.cuda.empty_cache()

training_args = TrainingArguments(output_dir=os.path.join(output_path, 'results'),
                                  num_train_epochs=25,
                                  load_best_model_at_end=True,
                                  overwrite_output_dir=True,
                                  eval_strategy="epoch",
                                  save_strategy="epoch",
                                  per_device_train_batch_size=5, #56
                                  per_device_eval_batch_size=5,
                                  warmup_steps=10,
                                  weight_decay=0.05,
                                  logging_dir=os.path.join(output_path, 'logs'),
                                  report_to = 'wandb')

trainer = Trainer(model=model,
        args=training_args,
        train_dataset = train_dataset, 
        eval_dataset = val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
        data_collator = lambda data: {'input_ids': torch.stack([f[0] for f in data]),
                                      'attention_mask': torch.stack([f[1] for f in data]),
                                      'labels': torch.stack([f[0] for f in data])})
# Fine-tune the model
trainer.train()
model.save_pretrained(os.path.join(output_path, 'results'))
tokenizer.save_pretrained(os.path.join(output_path, 'results'))

[2024-09-01 23:48:26,159] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status




[34m[1mwandb[0m: Currently logged in as: [33mgarbacik-mateusz[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Bleu,Rouge1,Rouge2,Rougel,Precision,Recall,F1
1,No log,2.646941,0.39701,0.396102,0.085475,0.318119,0.774586,0.783155,0.778838
2,No log,2.488605,0.477027,0.442892,0.090208,0.329715,0.794189,0.802882,0.7985
3,No log,2.480776,0.484279,0.436204,0.076251,0.32072,0.794032,0.801418,0.797696
4,No log,2.512715,0.480766,0.420845,0.083379,0.3088,0.792684,0.799292,0.795948
5,No log,2.584468,0.478726,0.418852,0.085877,0.315683,0.791828,0.800903,0.796311
6,No log,2.666605,0.482214,0.421009,0.081569,0.303906,0.790994,0.799647,0.795272


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.05 seconds, 18.74 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.78 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 36.24 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.13 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.92 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.84 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 36.19 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.59 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.14 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.08 sentences/sec
{'rouge1': Score(precision=0.47692307692307695, recall=0.4246575342465753, fmeasure=0.44927536231884063), 'rouge2': Score(precision=0.140625, recall=0.125, fmeasure=0.1323529411764706), 'rougeL': Score(precision=0.4, recall=0.3561643835616438, fmeasure=0.37681159420289856)}
{'rouge1': Score(precision=0.3972602739726027, recall=0.38666666666666666, fmeasure=0.3918918918918919), 'rouge2': Score(precision=0.09722222222222222, recall=0.0945945945945946, fmeasure=0.0958904109589041), 'rougeL': Score(precision=0.3561643835616438, recall=0.3466666666666667, fmeasure=0.3513513513513513)}
{'rouge1': Score(precision=0.4444444444444444, recall=0.34782608695652173, fmeasure=0.3902439024390244), 'rouge2': Score(precision=0.08571428571428572, recall=0.06666666666666667, fmeasure=0.075), 'rougeL': Score(precision=0.3611111111111111, recall=0.2826086956521739, fmeasure=0.3170731707317073)}
{'rouge1': Score(precision=0.42168674698795183, recall=0.39772727272727

  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.04 seconds, 25.30 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.04 seconds, 25.57 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.04 seconds, 26.35 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.92 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.04 seconds, 24.87 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.84 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 36.49 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.51 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.80 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.17 sentences/sec
{'rouge1': Score(precision=0.5217391304347826, recall=0.4931506849315068, fmeasure=0.5070422535211269), 'rouge2': Score(precision=0.08823529411764706, recall=0.08333333333333333, fmeasure=0.08571428571428572), 'rougeL': Score(precision=0.37681159420289856, recall=0.3561643835616438, fmeasure=0.3661971830985915)}
{'rouge1': Score(precision=0.4520547945205479, recall=0.44, fmeasure=0.44594594594594594), 'rouge2': Score(precision=0.08333333333333333, recall=0.08108108108108109, fmeasure=0.08219178082191782), 'rougeL': Score(precision=0.3424657534246575, recall=0.3333333333333333, fmeasure=0.33783783783783783)}
{'rouge1': Score(precision=0.5121951219512195, recall=0.45652173913043476, fmeasure=0.48275862068965514), 'rouge2': Score(precision=0.15, recall=0.13333333333333333, fmeasure=0.1411764705882353), 'rougeL': Score(precision=0.4146341463414634, recall=0.3695652173913043, fmeasure=0.3908045977011494)}
{'rouge1': Score(precision=0.426966292134831

  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 30.62 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 31.33 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.34 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.31 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.52 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.04 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 36.18 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.70 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.00 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 36.37 sentences/sec
{'rouge1': Score(precision=0.5373134328358209, recall=0.4931506849315068, fmeasure=0.5142857142857142), 'rouge2': Score(precision=0.07575757575757576, recall=0.06944444444444445, fmeasure=0.07246376811594202), 'rougeL': Score(precision=0.3582089552238806, recall=0.3287671232876712, fmeasure=0.34285714285714286)}
{'rouge1': Score(precision=0.43243243243243246, recall=0.4266666666666667, fmeasure=0.42953020134228187), 'rouge2': Score(precision=0.0821917808219178, recall=0.08108108108108109, fmeasure=0.0816326530612245), 'rougeL': Score(precision=0.33783783783783783, recall=0.3333333333333333, fmeasure=0.3355704697986577)}
{'rouge1': Score(precision=0.4883720930232558, recall=0.45652173913043476, fmeasure=0.47191011235955055), 'rouge2': Score(precision=0.11904761904761904, recall=0.1111111111111111, fmeasure=0.11494252873563218), 'rougeL': Score(precision=0.37209302325581395, recall=0.34782608695652173, fmeasure=0.35955056179775285)}
{'rouge1': Sc

  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.00 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.44 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.71 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.24 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.68 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.84 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.62 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.41 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.99 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.24 sentences/sec
{'rouge1': Score(precision=0.5, recall=0.4794520547945205, fmeasure=0.48951048951048953), 'rouge2': Score(precision=0.11594202898550725, recall=0.1111111111111111, fmeasure=0.11347517730496454), 'rougeL': Score(precision=0.37142857142857144, recall=0.3561643835616438, fmeasure=0.36363636363636365)}
{'rouge1': Score(precision=0.4027777777777778, recall=0.38666666666666666, fmeasure=0.3945578231292517), 'rouge2': Score(precision=0.04225352112676056, recall=0.04054054054054054, fmeasure=0.041379310344827586), 'rougeL': Score(precision=0.3194444444444444, recall=0.30666666666666664, fmeasure=0.31292517006802717)}
{'rouge1': Score(precision=0.4222222222222222, recall=0.41304347826086957, fmeasure=0.4175824175824176), 'rouge2': Score(precision=0.09090909090909091, recall=0.08888888888888889, fmeasure=0.0898876404494382), 'rougeL': Score(precision=0.3111111111111111, recall=0.30434782608695654, fmeasure=0.3076923076923077)}
{'rouge1': Score(precision=

  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.52 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 31.95 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.66 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 32.75 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.86 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.62 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.37 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.77 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.29 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.95 sentences/sec
{'rouge1': Score(precision=0.5285714285714286, recall=0.5068493150684932, fmeasure=0.5174825174825174), 'rouge2': Score(precision=0.11594202898550725, recall=0.1111111111111111, fmeasure=0.11347517730496454), 'rougeL': Score(precision=0.37142857142857144, recall=0.3561643835616438, fmeasure=0.36363636363636365)}
{'rouge1': Score(precision=0.4166666666666667, recall=0.4, fmeasure=0.4081632653061225), 'rouge2': Score(precision=0.08450704225352113, recall=0.08108108108108109, fmeasure=0.08275862068965517), 'rougeL': Score(precision=0.3611111111111111, recall=0.3466666666666667, fmeasure=0.3537414965986394)}
{'rouge1': Score(precision=0.4, recall=0.391304347826087, fmeasure=0.3956043956043956), 'rouge2': Score(precision=0.11363636363636363, recall=0.1111111111111111, fmeasure=0.11235955056179774), 'rougeL': Score(precision=0.35555555555555557, recall=0.34782608695652173, fmeasure=0.3516483516483516)}
{'rouge1': Score(precision=0.4090909090909091, r

  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.76 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.10 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.82 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.25 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.77 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 34.58 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.87 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.65 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 33.57 sentences/sec
calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.03 seconds, 35.64 sentences/sec
{'rouge1': Score(precision=0.5205479452054794, recall=0.5205479452054794, fmeasure=0.5205479452054794), 'rouge2': Score(precision=0.1111111111111111, recall=0.1111111111111111, fmeasure=0.1111111111111111), 'rougeL': Score(precision=0.3561643835616438, recall=0.3561643835616438, fmeasure=0.35616438356164387)}
{'rouge1': Score(precision=0.4305555555555556, recall=0.41333333333333333, fmeasure=0.42176870748299317), 'rouge2': Score(precision=0.08450704225352113, recall=0.08108108108108109, fmeasure=0.08275862068965517), 'rougeL': Score(precision=0.3472222222222222, recall=0.3333333333333333, fmeasure=0.34013605442176864)}
{'rouge1': Score(precision=0.4222222222222222, recall=0.41304347826086957, fmeasure=0.4175824175824176), 'rouge2': Score(precision=0.09090909090909091, recall=0.08888888888888889, fmeasure=0.0898876404494382), 'rougeL': Score(precision=0.3111111111111111, recall=0.30434782608695654, fmeasure=0.3076923076923077)}
{'rouge1': Score(

There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


('Models/gpt-neo/125M-wow-test/results/tokenizer_config.json',
 'Models/gpt-neo/125M-wow-test/results/special_tokens_map.json',
 'Models/gpt-neo/125M-wow-test/results/vocab.json',
 'Models/gpt-neo/125M-wow-test/results/merges.txt',
 'Models/gpt-neo/125M-wow-test/results/added_tokens.json',
 'Models/gpt-neo/125M-wow-test/results/tokenizer.json')

In [None]:
# Access a single example from the validation dataset
# print(trainer.evaluate())
# for i in val_dataset:
#     evalInput(i)

In [38]:
import re
# Define the regex pattern to match the sentence that starts with "everything" and ends with "Description:"
def evalInput(example):
    # Assuming 'example' is a tuple where the first element is the input tensor
    input_ids_tensor = example[0]
    # Convert the tensor to a list of token IDs
    token_ids_list = input_ids_tensor.tolist()
    # Decode the token IDs into text
    decoded_text = tokenizer.decode(token_ids_list, skip_special_tokens=True)
    # Regex to capture the title and content separately
    match = re.match(r'^Title: (.*?) Description: (.*)', decoded_text, re.DOTALL)

    if match:
        title = match.group(1)
        content = match.group(2)
        input_text = f"Title: {title} Description: "
        print(input_text)
        predictions = generate_predictions(input_text)
        print(predictions)
        print(compute_rouge(predictions, decoded_text))
        bleu_scores = []
        bert_scores = []
        for generated_text in predictions:
            bleu_score = sentence_bleu([decoded_text], generated_text)
            bert_score = calculate_bertscore([generated_text], [decoded_text])
            bleu_scores.append(bleu_score)
            bert_scores.append(bert_score)
        precision = 0
        recall = 0
        f1 = 0
        for score in bert_scores:
            print(score)
            precision = precision + score['precision']
            recall = recall + score['recall']
            f1 = f1 + score['f1']
        print({
            'bleu': sum(bleu_scores) / len(bleu_scores),
            'precision': precision / len(bert_scores),
            'recall': recall / len(bert_scores),
            'f1': f1 / len(bert_scores),
        })
        # print("Title:", title)
        # print("Content:", content)
    else:
        print("No match found.")

# !pip install bert-score
import bert_score

def calculate_bertscore(predictions, references, lang='en'):
    # Calculate BERTScore
    P, R, F1 = bert_score.score(predictions, references, lang=lang)
    
    # Compute average scores
    avg_precision = P.mean().item()
    avg_recall = R.mean().item()
    avg_f1 = F1.mean().item()
    
    return {
        'precision': avg_precision,
        'recall': avg_recall,
        'f1': avg_f1
    }

# Example usage
# predictions = [
#     "The cat sat on the mat.",
#     "The quick brown fox jumps over the lazy dog."
# ]
# references = [
#     "A cat was sitting on a rug.",
#     "A speedy brown fox leaps over a lazy canine."
# ]

# scores = calculate_bertscore(predictions, references)
# print(scores)

from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
def compute_rouge(predictions, references):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_scores = []
    rougeL_scores = []
    
    for ref, pred in zip(references, predictions):
        scores = scorer.score(ref, pred)
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)
        
    rouge1_avg = sum(rouge1_scores) / len(rouge1_scores)
    rougeL_avg = sum(rougeL_scores) / len(rougeL_scores)
    
    return {
        "rouge1": rouge1_avg,
        "rougeL": rougeL_avg
    }

# Epoch 	Training Loss 	Validation Loss
# 1 	No log 	1.520463
# 2 	1.652300 	1.467383
# 3 	1.393100 	1.441400
# 4 	1.393100 	1.428227
# 5 	1.294700 	1.422623
# 6 	1.205200 	1.425824
# 7 	1.140800 	1.428631
# 8 	1.140800 	1.444082

# There were missing keys in the checkpoint model loaded: ['lm_head.weight'].

# TrainOutput(global_step=2920, training_loss=1.2997734801409995, metrics={'train_runtime': 4592.453, 'train_samples_per_second': 111.226, 'train_steps_per_second': 1.987, 'total_flos': 1.5760779141316608e+16, 'train_loss': 1.2997734801409995, 'epoch': 8.0})

In [3]:
def generate_predictions(input_text):
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()
    model.eval()
    try:
        sample_outputs = model.generate(
            input_ids=input_ids,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            top_k=50,
            max_length=300,
            top_p=0.95,
            temperature=0.7,
            num_return_sequences=10
        )
        # Decode and print generated texts
        generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in sample_outputs]
        return generated_texts

    except RuntimeError as e:
        print("RuntimeError during generation:", e)

        # Additional Debugging: Check logits
        with torch.no_grad():
            outputs = model(input_ids=input_ids)
            logits = outputs.logits
            assert not torch.isnan(logits).any(), "logits contain NaNs"
            assert not torch.isinf(logits).any(), "logits contain Infs"
            print("Logits sample:", logits[0, -1, :10])


In [4]:
input_text = "Title: Sharptalon's Claw \nDescription:"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()

model.eval()
try:
    sample_outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_k=50,
        max_length=300,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=100
    )
    # Decode and print generated texts
    generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in sample_outputs]
    with open(os.path.join(output_path, 'results','output2.txt'), 'w') as file:
        file.writelines([f"Generated text {i+1}:\n{text}\n" for i, text in enumerate(generated_texts)])

except RuntimeError as e:
    print("RuntimeError during generation:", e)

    # Additional Debugging: Check logits
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        assert not torch.isnan(logits).any(), "logits contain NaNs"
        assert not torch.isinf(logits).any(), "logits contain Infs"
        print("Logits sample:", logits[0, -1, :10])
