In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, Trainer, AutoTokenizer
import pandas as pd
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

In [53]:
df = pd.read_csv("../data/processed/train.csv")
df["name"].str.title()

0         Snubbull
1       Jigglypuff
2          Manaphy
3           Cleffa
4         Articuno
           ...    
8404       Umbreon
8405        Keldeo
8406      Ducklett
8407      Ampharos
8408    Weepinbell
Name: name, Length: 8409, dtype: object

In [3]:
separator = " = /@\ = "

In [4]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

In [5]:
dataset = load_dataset('csv', data_files='../data/processed/train.csv')
dataset

Using custom data configuration default-62586f76f81a5f7c
Reusing dataset csv (/home/pheithar/.cache/huggingface/datasets/csv/default-62586f76f81a5f7c/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e)
100%|██████████| 1/1 [00:00<00:00, 785.74it/s]


DatasetDict({
    train: Dataset({
        features: ['entry_name', 'name', 'description'],
        num_rows: 8409
    })
})

In [6]:
max(len(x.split(" ")) for x in dataset["train"]["description"])

48

In [7]:
def tokenize_function(examples):
    output = [separator + pkmn_name for pkmn_name in examples["name"]]

    results = tokenizer(examples["description"], output, max_length=65, padding="max_length")
    results["labels"] = results["input_ids"].copy()
    return results

tokenized_datasets = dataset.map(tokenize_function, batched=True)

100%|██████████| 9/9 [00:01<00:00,  4.63ba/s]


In [8]:
tokenizer.decode(tokenized_datasets["train"]["input_ids"][0])

'By baring its fangs and making a scary face, this Pokémon sends smaller Pokémon scurrying away in terror. However, this Pokémon seems a little sad at making its foes flee. = /@\\ = snubbull<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'

In [9]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(500)).remove_columns(["name", "description", "entry_name"])
# small_train_dataset = small_train_dataset.rename_column("name", "label")

In [10]:
small_train_dataset

Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 500
})

In [11]:
training_args = TrainingArguments("test_trainer", label_names=None)

In [12]:
def compute_metrics(eval_pred):
    print(eval_pred)


trainer = Trainer(
    model=model, args=training_args, train_dataset=small_train_dataset
)

In [27]:
trainer.train()

***** Running training *****
  Num examples = 500
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 189
100%|██████████| 189/189 [11:32<00:00,  3.17s/it]

Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 189/189 [11:32<00:00,  3.66s/it]

{'train_runtime': 692.2455, 'train_samples_per_second': 2.167, 'train_steps_per_second': 0.273, 'train_loss': 1.3540785855086392, 'epoch': 3.0}





TrainOutput(global_step=189, training_loss=1.3540785855086392, metrics={'train_runtime': 692.2455, 'train_samples_per_second': 2.167, 'train_steps_per_second': 0.273, 'train_loss': 1.3540785855086392, 'epoch': 3.0})

In [46]:
"Its flames are strong enough to melt iron bars."+separator

'Its flames are strong enough to melt iron bars. = /@\\ = '

In [47]:
sample = tokenizer.encode("Its can breathe under water."+separator, max_length=60, return_tensors="pt")
sample

tensor([[20459,   460, 18044,   739,  1660,    13,   796,  1220,    31,    59,
           796,   220]])

In [48]:
result = model.generate(sample)
result

tensor([[20459,   460, 18044,   739,  1660,    13,   796,  1220,    31,    59,
           796,   220,  1834,  1723, 50256]])

In [49]:
tokenizer.decode(result[0], skip_special_tokens=True)

'Its can breathe under water. = /@\\ = ursaring'