In [7]:
import os
import random
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
import torch
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm

from agent_dataset import ReplayDataset, AgentDataset

torch.backends.cudnn.tf32 = True

seed = 42
transformers.set_seed(seed)

os.environ['WANDB_DISABLED'] = 'true' # I don't like using wandb for this

# Define the tokenizer and model
# small model hasn't even had train loss go below val loss
MODEL_NAME = 'gpt2-large' # w/ 3 epochs, normal got to ~.174, large ~1.135

# action dataset: w/ 2 epochs, normal got to ~.5, large got to ~
#MODEL_NAME = 'meta-llama/Llama-2-7b-chat-hf' # try code llama? Probably a better idea b/c it has longer context

## Pretraining on replays

In [2]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir="./models/gen9randombattle",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=32,
    fp16=True,
    warmup_steps=200,
    weight_decay=0.01,
    logging_dir="./logs/gen9randombattle",
    logging_steps=250,
    evaluation_strategy='steps',
    eval_steps=250,
    save_steps=10000,
    tf32=True,
)


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [8]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

model = model.to_bettertransformer()

# config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=.1)
# model = get_peft_model(model, config)
# model.print_trainable_parameters()

# Define the dataset collator
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

chunk_size = tokenizer.model_max_length


The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.


We're using a dataset of replay files and we'll be learning on the entirety of those files.

In [4]:
# Define the dataset
data_path = "dataset/gen9randombattle/replays"
replay_files = [os.path.join(data_path, file) for file in os.listdir(data_path)]

random.shuffle(replay_files)
train_replay_files = replay_files[:int(len(replay_files) * 0.8)]
val_replay_files = replay_files[int(len(replay_files) * 0.8):]

train_dataset = ReplayDataset(train_replay_files, tokenizer, chunk_size)
val_dataset = ReplayDataset(val_replay_files, tokenizer, chunk_size)

print(f"Train dataset length: {len(train_dataset)}")
print(f"Validation dataset length: {len(val_dataset)}")

Parsing replays:   0%|          | 0/4320 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3337 > 1024). Running this sequence through the model will result in indexing errors
Parsing replays: 100%|██████████| 4320/4320 [00:31<00:00, 136.67it/s]
Parsing replays: 100%|██████████| 1080/1080 [00:07<00:00, 135.66it/s]

Train dataset length: 18408
Validation dataset length: 4612





In [5]:

# Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
250,0.4064,0.194633
500,0.1873,0.159974
750,0.1604,0.149403
1000,0.1499,0.143556
1250,0.1436,0.139636
1500,0.1373,0.1363


TrainOutput(global_step=1725, training_loss=0.1893086010476817, metrics={'train_runtime': 8406.4879, 'train_samples_per_second': 6.569, 'train_steps_per_second': 0.205, 'total_flos': 2.127190197646848e+17, 'train_loss': 0.1893086010476817, 'epoch': 3.0})

In [6]:
trainer.evaluate(val_dataset)

{'eval_loss': 0.13503998517990112,
 'eval_runtime': 224.278,
 'eval_samples_per_second': 20.564,
 'eval_steps_per_second': 20.564,
 'epoch': 3.0}

In [7]:
trainer.model = trainer.model.reverse_bettertransformer()
trainer.model.save_pretrained(f'models/gen9randombattle_{MODEL_NAME}')

## Fine-tuning on actions

Now that we've learned a general amount of the game, we can fine-tune on the actions of the game.

In [8]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir="./models/gen9randombattle_rating",
    num_train_epochs=3,
    #learning_rate=5e-6, # reduce learning rate b/c we've already learned a lot
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=32,
    fp16=True,
    warmup_steps=0, # we've already learned the format of this text, no need to warmup b/c we're just applying finishing touches
    weight_decay=0.01,
    logging_dir="./logs/gen9randombattle_rating",
    logging_steps=250,
    evaluation_strategy='steps',
    eval_steps=250,
    save_strategy='no',
    tf32=True,
    group_by_length=True,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [20]:
model = AutoModelForCausalLM.from_pretrained(f'models/gen9randombattle_{MODEL_NAME}')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token_id = tokenizer.eos_token_id

model = model.to_bettertransformer()

The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.


In [10]:
# define the dataset
data_path = "dataset/gen9randombattle_rating/replays" # this is a high elo dataset
replay_files = [os.path.join(data_path, file) for file in os.listdir(data_path)]

random.shuffle(replay_files)

train_replay_files = replay_files[:int(len(replay_files) * 0.8)]
val_replay_files = replay_files[int(len(replay_files) * 0.8):]

train_dataset = AgentDataset(train_replay_files, tokenizer, 6) # with a context size of 1024, we can handle about 6 turns
val_dataset = AgentDataset(val_replay_files, tokenizer, 6)

print(f"Train dataset length: {len(train_dataset)}")
print(f"Validation dataset length: {len(val_dataset)}")

data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer, padding=False) # don't do any padding right now b/c we have bettertransformers (will change when we use flash attention)

Parsing replays: 100%|██████████| 1040/1040 [00:02<00:00, 510.28it/s]
Tokenizing actions:   0%|          | 0/1040 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1170 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing actions: 100%|██████████| 1040/1040 [00:59<00:00, 17.52it/s]
Parsing replays: 100%|██████████| 260/260 [00:00<00:00, 604.65it/s]
Tokenizing actions: 100%|██████████| 260/260 [00:14<00:00, 17.58it/s]

Train dataset length: 26231
Validation dataset length: 6628





In [21]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

Let's see how well it does on the validation set before we start training

In [22]:
trainer.evaluate(val_dataset)

{'eval_loss': 0.45846983790397644,
 'eval_runtime': 360.4448,
 'eval_samples_per_second': 18.388,
 'eval_steps_per_second': 18.388}

In [13]:
trainer.train()

Step,Training Loss,Validation Loss
250,0.6134,0.575976
500,0.5594,0.545928
750,0.534,0.517375
1000,0.4772,0.506219
1250,0.4581,0.496521
1500,0.4471,0.480975
1750,0.4127,0.5002
2000,0.3614,0.491288
2250,0.3533,0.481597


TrainOutput(global_step=2457, training_loss=0.45746872140762523, metrics={'train_runtime': 12865.613, 'train_samples_per_second': 6.117, 'train_steps_per_second': 0.191, 'total_flos': 2.74194412362624e+17, 'train_loss': 0.45746872140762523, 'epoch': 3.0})

In [19]:
trainer.evaluate(val_dataset)

{'eval_loss': 0.4847310781478882,
 'eval_runtime': 479.0358,
 'eval_samples_per_second': 13.836,
 'eval_steps_per_second': 13.836,
 'epoch': 3.0}

In [14]:
trainer.model = trainer.model.reverse_bettertransformer()
trainer.model.save_pretrained(f'models/gen9randombattle_rating_{MODEL_NAME}')

## Inference

In [15]:
# load the model
model = AutoModelForCausalLM.from_pretrained(f'models/gen9randombattle_{MODEL_NAME}')
model.cuda()
tokenizer = AutoTokenizer.from_pretrained('gpt2')


In [16]:
def generate_helper(input_text, **kwargs):
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to('cuda')
    print(input_ids.shape)
    if input_ids.shape[-1] > tokenizer.model_max_length:
        return None
    output = model.generate(input_ids, **kwargs)
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [17]:
input_text = """|p1|rating|2397
|p2|rating|2401
|
|start
|action|p1|switch|Umbreon
|action|p2|switch|Iron Leaves
|switch|p1: Umbreon|Umbreon, L85, M|300/300
|switch|p2: Iron Leaves|Iron Leaves, L81|278/278
|turn|1
|action|p1|switch|Banette
|action|p2|switch|Sandy Shocks
|
|-end|p2: Iron Leaves|Quark Drive|[silent]
|switch|p2: Sandy Shocks|Sandy Shocks, L80|267/267
|switch|p1: Banette|Banette, L93, F|270/270
|
|upkeep
|turn|2
|action|p1|switch|Kricketune
|action|p2|move|Stealth Rock|tera|null
|
|switch|p1: Kricketune|Kricketune, L96, M|303/303
|move|p2: Sandy Shocks|Stealth Rock|p1: Kricketune
|-sidestart|p1|move: Stealth Rock
|
|upkeep
|turn|3
|action|p1|move|Sticky Web|tera|null
|action|p2|move|Thunderbolt|tera|null
|
|move|p2: Sandy Shocks|Thunderbolt|p1: Kricketune
|-damage|p1: Kricketune|179/303
|move|p1: Kricketune|Sticky Web|p2: Sandy Shocks
|-sidestart|p2|move: Sticky Web
|
|upkeep
|turn|4
|action|p1|move|Pounce|tera|null
|action|p2|move|Thunder Wave|tera|null
|
|move|p2: Sandy Shocks|Thunder Wave|p1: Kricketune
|-status|p1: Kricketune|par
|move|p1: Kricketune|Pounce|p2: Sandy Shocks
|-damage|p2: Sandy Shocks|183/267
|-unboost|p2: Sandy Shocks|spe|1
|
|-heal|p2: Sandy Shocks|199/267|[from] item: Leftovers
|upkeep
|turn|5
|action|p1|move|failed|tera|null
|action|p2|move|Thunderbolt|tera|null
|
|move|p2: Sandy Shocks|Thunderbolt|p1: Kricketune
|-damage|p1: Kricketune|41/303 par
|cant|p1: Kricketune|par
|
|-heal|p2: Sandy Shocks|215/267|[from] item: Leftovers
|upkeep
|turn|6
|action|p1|move|failed|tera|null
|action|p2|move|Thunderbolt|tera|null
|
|move|p2: Sandy Shocks|Thunderbolt|p1: Kricketune
|-damage|p1: Kricketune|0 fnt
|faint|p1: Kricketune
|
|-heal|p2: Sandy Shocks|231/267|[from] item: Leftovers
|upkeep
|
|action|p1|switch|Breloom
|switch|p1: Breloom|Breloom, L82, M|233/233
|-damage|p1: Breloom|219/233|[from] Stealth Rock
|turn|7
|action|p1|"""

In [18]:
print(generate_helper(input_text, max_new_tokens=128, num_beams=3))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 863])
|p1|rating|2397
|p2|rating|2401
|
|start
|action|p1|switch|Umbreon
|action|p2|switch|Iron Leaves
|switch|p1: Umbreon|Umbreon, L85, M|300/300
|switch|p2: Iron Leaves|Iron Leaves, L81|278/278
|turn|1
|action|p1|switch|Banette
|action|p2|switch|Sandy Shocks
|
|-end|p2: Iron Leaves|Quark Drive|[silent]
|switch|p2: Sandy Shocks|Sandy Shocks, L80|267/267
|switch|p1: Banette|Banette, L93, F|270/270
|
|upkeep
|turn|2
|action|p1|switch|Kricketune
|action|p2|move|Stealth Rock|tera|null
|
|switch|p1: Kricketune|Kricketune, L96, M|303/303
|move|p2: Sandy Shocks|Stealth Rock|p1: Kricketune
|-sidestart|p1|move: Stealth Rock
|
|upkeep
|turn|3
|action|p1|move|Sticky Web|tera|null
|action|p2|move|Thunderbolt|tera|null
|
|move|p2: Sandy Shocks|Thunderbolt|p1: Kricketune
|-damage|p1: Kricketune|179/303
|move|p1: Kricketune|Sticky Web|p2: Sandy Shocks
|-sidestart|p2|move: Sticky Web
|
|upkeep
|turn|4
|action|p1|move|Pounce|tera|null
|action|p2|move|Thunder Wave|tera|null
|
|move|p2: S