In [1]:
import gc
import os
import torch
import pandas as pd
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
output_path = 'Models/t5-base/wow'
model_name = "google-t5/t5-base"

torch.manual_seed(42)
# texts = pd.read_csv('data_wow.csv')
texts = pd.read_csv('data_wow.csv', nrows=100)
tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')
valid_dataset = []
for sentence in texts['sentence']:
    if len(tokenizer.encode(sentence)) < 1024:
        valid_dataset.append(sentence)
        
class TextDataset(Dataset):
    def __init__(self, txt_list, tokenizer, max_length):
        self.labels = []
        self.input_ids = []
        self.attn_masks = []        
        for sentence in txt_list:
            encodings_dict = tokenizer(sentence, truncation=True, max_length=max_length, padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
    def __len__(self): return len(self.input_ids)
    def __getitem__(self, idx): return self.input_ids[idx], self.attn_masks[idx]

max_length = max([len(tokenizer.encode(sentence)) for sentence in valid_dataset])
text_dataset = TextDataset(valid_dataset, tokenizer, max_length=max_length)
train_size = int(0.8 * len(valid_dataset))
train_dataset, val_dataset = random_split(text_dataset, [train_size, len(text_dataset) - train_size])
# print(texts)
print('train_size', train_size)
print('valid_dataset', len(valid_dataset))
print('max_length', max_length)
os.environ["WANDB_PROJECT"]='t5-base-wow'
os.environ["WANDB_LOG_MODEL"]="true"
os.environ["WANDB_WATCH"]="false"
os.environ["WANDB_NAME"]="t5-base-wow"
os.environ["WANDB_API_KEY"] = "b689f7c91f1ec7520fa8da927f175f1efd587181"

train_size 80
valid_dataset 100
max_length 160


In [2]:
from transformers import AutoModelForSeq2SeqLM
# try:
#     model = AutoModelForSeq2SeqLM.from_pretrained(os.path.join(output_path, 'results', 'checkpoint-511')).cuda() #5621
#     print('saved')
# except:
#     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
#     print('downloaded')

model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
# model.resize_token_embeddings(len(tokenizer))


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
# !pip install evaluate
import nltk
import numpy as np
from evaluate import load
metric = load("rouge")

def decode(input_ids_tensor):
    print('input_ids_tensor', input_ids_tensor)
    token_ids_list = input_ids_tensor.tolist()
    print('token_ids_list', token_ids_list)
    # Decode the token IDs into text
    return tokenizer.decode(token_ids_list, skip_special_tokens=True)

import torch.nn.functional as F
def decodeLogits(logits):

    # Apply softmax to get probabilities
    logits_tensor = torch.tensor(logits, dtype=torch.float)
    probabilities = F.softmax(logits_tensor, dim=-1)

    # Get the token IDs (the indices of the highest probabilities)
    token_ids = torch.argmax(probabilities, dim=-1)
    return token_ids
    # return tokenizer.decode(token_ids, skip_special_tokens=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    print('predictions', predictions)
    preds = []
    for p in predictions:
        for value in p:
            preds.append(decodeLogits(value))
    print('predictions', preds)
    # decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_texts = [decode(ids) for ids in preds]
    print('decoded_preds', decoded_preds)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_texts]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_texts, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Extract a few results
    result = {key: value * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [11]:
from transformers import EarlyStoppingCallback
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(output_dir=os.path.join(output_path, 'results'),
                                  num_train_epochs=25,
                                  load_best_model_at_end=True,
                                  overwrite_output_dir=True,
                                  eval_strategy="epoch",
                                  save_strategy="epoch",
                                  per_device_train_batch_size=2,
                                  per_device_eval_batch_size=2,
                                  warmup_steps=100,
                                  weight_decay=0.03,
                                  gradient_accumulation_steps=2,
                                  logging_dir=os.path.join(output_path, 'logs'),
                                  report_to = 'wandb')

trainer = Seq2SeqTrainer(model=model,
        args=training_args,
        train_dataset = train_dataset, 
        eval_dataset = val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
        data_collator = lambda data: {'input_ids': torch.stack([f[0] for f in data]),
                                      'attention_mask': torch.stack([f[1] for f in data]),
                                      'labels': torch.stack([f[0] for f in data])})

trainer.train()
# trainer.evaluate()
model.save_pretrained(os.path.join(output_path, 'results'))
tokenizer.save_pretrained(os.path.join(output_path, 'results'))

# add t5 model to training
# add gpt-2-large 

Epoch,Training Loss,Validation Loss


predictions (array([[[-27.520535  , -23.333881  , -23.663263  , ..., -55.557865  ,
         -55.39327   , -55.544785  ],
        [-14.221676  ,  -5.472306  ,  -6.820966  , ..., -25.875484  ,
         -25.819637  , -25.849178  ],
        [-30.145441  , -13.918194  , -15.565804  , ..., -45.471523  ,
         -45.402695  , -45.5239    ],
        ...,
        [-25.701954  ,   0.33459905,  -9.109851  , ..., -31.195673  ,
         -31.116554  , -30.989323  ],
        [-28.692324  ,  -0.16410863, -10.639178  , ..., -34.714386  ,
         -34.653927  , -34.530926  ],
        [-27.098127  ,   0.2219526 ,  -9.774847  , ..., -33.24395   ,
         -33.171337  , -33.063072  ]],

       [[-14.969787  , -11.07391   , -11.50874   , ..., -34.041645  ,
         -33.983917  , -33.962677  ],
        [-12.017959  ,  -4.4895    ,  -6.509951  , ..., -28.23297   ,
         -28.170654  , -28.14875   ],
        [-21.537577  ,  -8.231659  , -10.205774  , ..., -38.208923  ,
         -38.118736  , -38.06585   ],


UnboundLocalError: local variable 'decoded_preds' referenced before assignment

In [3]:
input_text = "Title: Sharptalon's Claw"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()

model.eval()
try:
    sample_outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_k=50,
        max_length=300,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=50
    )
    print(sample_outputs[0])
    # Decode and print generated texts
    generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in sample_outputs]
    print(generated_texts)
    with open(os.path.join(output_path, 'results','output.txt'), 'w') as file:
        file.writelines([f"Generated text {i+1}:\n{text}\n" for i, text in enumerate(generated_texts)])

except RuntimeError as e:
    print("RuntimeError during generation:", e)

    # Additional Debugging: Check logits
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        assert not torch.isnan(logits).any(), "logits contain NaNs"
        assert not torch.isinf(logits).any(), "logits contain Infs"
        print("Logits sample:", logits[0, -1, :10])


tensor([    0, 11029,    10, 22130,  1947,   106,    31,     7,   205,  4207,
           10, 22130,  1947,   106,    31,     7,   205,  4207, 11029,    10,
        22130,  1947,   106,    31,     7,   205,  4207,    10, 22130,  1947,
          106,    31,     7,   205,  4207,    10, 22130,  1947,   106,    31,
            7,   205,  4207,    10, 22130,  1947,   106,    31,     7,   205,
         4207,    10,   205,  4207,    10, 22130,  1947,   106,    31,     7,
          205,  4207,    10,   205,  4207,    10, 22130,  1947,   106,    31,
            7,   205,  4207,    10,   205,  4207,    10, 22130,  1947,   106,
           31,     7,   205,  4207,    10, 22130,  1947,   106,    31,     7,
          205,  4207,    10,   205,  4207,    10, 22130,  1947,   106,    31,
            7,   205,  4207,    10,   205,  4207,    10, 22130,  1947,   106,
           31,     7,   205,  4207,    10, 22130,  1947,   106,    31,     7,
          205,  4207,    10, 22130,  1947,   106,    31,     7, 