In [15]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

import sys
import logging
logging.getLogger().setLevel(logging.ERROR)
logging.disable(sys.maxsize)

from torch.utils.data import *
from transformers import *
sys.path.insert(0, "..")

from models import *
from logic import *
from my_datasets import *

from utils import *
import numpy as np
from tqdm import tqdm
import evaluate

from datasets import Dataset
import os

import wandb


In [16]:
# Weights & Biases settings, read from the environment by the HF Trainer's wandb integration.

# set the wandb project where this run will be logged
os.environ["WANDB_PROJECT"] = "transformer_friends"

# Model-artifact upload mode; set exactly once so there is no dead assignment.
# "true" (the value that was in effect) uploads the trained model at the end of
# training (newer transformers releases spell this "end"); use "checkpoint" instead
# to upload every saved checkpoint.
os.environ["WANDB_LOG_MODEL"] = "true"

# turn off watch (gradient/parameter logging) to log faster
os.environ["WANDB_WATCH"] = "false"

In [17]:
# Problem size: r rules per instance over n variables (each rule is a 2*n one-hot vector).
r = 8
n = 5

# Generator parameters forwarded positionally to the dataset configs:
# (ap, bp, tp) for the QED / one-step datasets, (ap, bp, sp) for the autoregressive ones.
ap = 0.2
bp = 0.2
tp = 0.4
sp = 0.1

# Step count passed to the AutoRegFixedSteps dataset configs.
nars = 3

# Split sizes and training length.
train_len, test_len = 2500, 500
num_epochs = 10

In [18]:
### Datasets
# Every task gets a train split (seed 1234) and a test split (seed 2345) that share
# the generator parameters and differ only in length and seed.
# Tuples are evaluated left to right, so construction order matches train-then-test.

# One-shot QED task
qed_train_dataset_config, qed_test_dataset_config = (
    OneShotQedDatasetConfig(r, n, ap, bp, tp, dataset_len=train_len, seed=1234),
    OneShotQedDatasetConfig(r, n, ap, bp, tp, dataset_len=test_len, seed=2345),
)
qed_train_dataset, qed_test_dataset = (
    OneShotQedEmbedsDataset(qed_train_dataset_config),
    OneShotQedEmbedsDataset(qed_test_dataset_config),
)

# One-step successor-state task
succ_train_dataset_config, succ_test_dataset_config = (
    OneStepStateDatasetConfig(r, n, ap, bp, tp, dataset_len=train_len, seed=1234),
    OneStepStateDatasetConfig(r, n, ap, bp, tp, dataset_len=test_len, seed=2345),
)
succ_train_dataset, succ_test_dataset = (
    OneStepStateEmbedsDataset(succ_train_dataset_config),
    OneStepStateEmbedsDataset(succ_test_dataset_config),
)

# Autoregressive fixed-steps task (takes sp and nars instead of tp)
ars_train_dataset_config, ars_test_dataset_config = (
    AutoRegFixedStepsDatasetConfig(r, n, ap, bp, sp, nars, dataset_len=train_len, seed=1234),
    AutoRegFixedStepsDatasetConfig(r, n, ap, bp, sp, nars, dataset_len=test_len, seed=2345),
)
ars_train_dataset, ars_test_dataset = (
    AutoRegFixedStepsEmbedsDataset(ars_train_dataset_config),
    AutoRegFixedStepsEmbedsDataset(ars_test_dataset_config),
)

In [19]:
def stringify_rule(rule, var_sep_token):
    """
    Render one rule as text, e.g. "x0 , x3 -> x1".

    `rule` is a flat 0/1 vector laid out as [<ants>, <cons>]: the first half
    marks antecedent variables, the second half marks consequent variables.
    Variables on each side are joined with `var_sep_token`; a side with no
    variable set is written as "empty".
    """
    half = len(rule) // 2

    def side(offset):
        # Names of the variables whose flag is set in rule[offset : offset + half].
        names = [f"x{k}" for k in range(half) if rule[offset + k]]
        return var_sep_token.join(names) if names else "empty"

    return side(0) + " -> " + side(half)

def get_string_rep(dataset_item):
    """
    Render a QED dataset item as one whitespace-separated string of the form:
    [RULES_START] [RULE_START] ... [RULE_END] ... [RULES_END]
    [THEOREM_START] ... [THEOREM_END]
    [QED]

    Args:
        dataset_item: mapping with
            "rules":   a sequence of flat [<ants>, <cons>] 0/1 vectors, each
                       rendered by stringify_rule;
            "theorem": a flat 0/1 vector marking the theorem's variables.

    Returns:
        The rendered string. Segments are always joined with single spaces.
        A theorem with no variable set is written as "empty", the same
        placeholder stringify_rule uses for an empty rule side, so no double
        space appears between the theorem markers.
    """

    # Define the placeholder tokens
    var_sep_token = " , "
    rules_start = "[RULES_START]"
    rules_end = "[RULES_END]"
    rule_start = "[RULE_START]"
    rule_end = "[RULE_END]"
    theorem_start = "[THEOREM_START]"
    theorem_end = "[THEOREM_END]"
    qed = "[QED]"

    rules = dataset_item["rules"]
    theorem = dataset_item["theorem"]

    n_vars = len(theorem)

    rule_strs = [
        " ".join([rule_start, stringify_rule(rule, var_sep_token), rule_end])
        for rule in rules
    ]
    # Fall back to "empty" (consistent with stringify_rule) instead of an empty
    # string, which would leave "[THEOREM_START]  [THEOREM_END]".
    theorem_vars = [f"x{i}" for i in range(n_vars) if theorem[i]] or ["empty"]
    theorem_str = " ".join([theorem_start, var_sep_token.join(theorem_vars), theorem_end])
    # Unpacking the rule strings keeps single spacing even if there are no rules.
    rules_str = " ".join([rules_start, *rule_strs, rules_end])
    return " ".join([rules_str, theorem_str, qed])


In [20]:
# Peek at one training item: the raw tensors first, then the string the model will see.
first_item = qed_train_dataset[0]
print(first_item)
print(get_string_rep(first_item))

{'rules': tensor([[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 1, 0, 0, 1]]), 'theorem': tensor([1, 0, 1, 0, 1]), 'labels': tensor(0)}
[RULES_START] [RULE_START] empty -> x1 [RULE_END] [RULE_START] empty -> x2 [RULE_END] [RULE_START] empty -> empty [RULE_END] [RULE_START] empty -> x1 [RULE_END] [RULE_START] empty -> x3 [RULE_END] [RULE_START] x4 -> empty [RULE_END] [RULE_START] empty -> x2 [RULE_END] [RULE_START] x4 -> x1 , x4 [RULE_END] [RULES_END] [THEOREM_START] x0 , x2 , x4 [THEOREM_END] [QED]


In [21]:
# Create HuggingFace datasets for the QED task

def _qed_to_text_and_labels(torch_dataset, desc=None):
    """
    Convert an indexable QED dataset into parallel (texts, labels) lists.

    Each item is fetched exactly once. Its string rendering (get_string_rep)
    and its label (`item["labels"].item()`) come from that same fetch, so the
    two lists stay aligned and the dataset is traversed a single time instead
    of twice.

    Args:
        torch_dataset: object supporting len() and integer indexing that yields
            mappings accepted by get_string_rep and containing a "labels" tensor.
        desc: optional tqdm progress-bar label.

    Returns:
        (texts, labels): a list of str and a list of Python scalars.
    """
    texts, labels = [], []
    for idx in tqdm(range(len(torch_dataset)), desc=desc):
        item = torch_dataset[idx]
        texts.append(get_string_rep(item))
        labels.append(item["labels"].item())
    return texts, labels


train_data, train_labels = _qed_to_text_and_labels(qed_train_dataset, desc="train")

print("Creating train dataset")
qed_train_hf_dataset = Dataset.from_dict({
    "data": train_data,
    "label": train_labels
}).with_format("torch")

test_data, test_labels = _qed_to_text_and_labels(qed_test_dataset, desc="test")

print("Creating test dataset")
qed_test_hf_dataset = Dataset.from_dict({
    "data": test_data,
    "label": test_labels
}).with_format("torch")

100%|██████████| 2500/2500 [00:01<00:00, 1369.11it/s]
100%|██████████| 2500/2500 [00:01<00:00, 2201.01it/s]


Creating train dataset


100%|██████████| 500/500 [00:00<00:00, 1294.82it/s]
100%|██████████| 500/500 [00:00<00:00, 2049.83it/s]

Creating test dataset





In [22]:
# GPT-2 tokenizer. GPT-2 has no dedicated padding token, so the end-of-sequence
# token is reused as the pad token; DataCollatorWithPadding needs one to pad batches.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [23]:
def tokenize_function(item):
    """Batched `Dataset.map` hook: tokenize the "data" column with truncation enabled."""
    texts = item["data"]
    return tokenizer(texts, truncation=True)


# Tokenize the train split, then the test split; padding is deferred to the collator,
# which pads each batch dynamically.
qed_train_tokenized_dataset, qed_test_tokenized_dataset = (
    hf_split.map(tokenize_function, batched=True)
    for hf_split in (qed_train_hf_dataset, qed_test_hf_dataset)
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Map: 100%|██████████| 2500/2500 [00:00<00:00, 7810.69 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 7827.94 examples/s]


In [29]:
# GPT-2 with a two-class sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
# The classification head pools the last non-pad token, so the model config must
# know the pad id chosen for the tokenizer (EOS).
model.config.pad_token_id = tokenizer.pad_token_id

In [30]:
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """
    Metric hook for the HF Trainer.

    Unpacks `eval_pred` into (logits, labels) and converts the logits to class ids
    with an argmax over axis 1. Returns:
      - "Accuracy": accuracy of those ids against the labels (via `evaluate`);
      - "Avg Ones": the mean predicted id, i.e. the share of examples predicted as
        class 1 for a binary head. This is a quick check for collapse onto one class.
    """
    logits, references = eval_pred
    predicted_ids = np.argmax(logits, axis=1)
    scores = accuracy.compute(predictions=predicted_ids, references=references)
    positive_rate = np.mean(predicted_ids)
    return {"Accuracy": scores["accuracy"], "Avg Ones": positive_rate}

In [31]:
# The wandb run name is derived from the config cell (n, r, train_len, test_len), so it
# cannot drift from the values actually used. For the current config it is identical to
# "gpt2-qed-str-tokenizer_default-vars_5-rules_8-train_2500-test_500".
training_args = TrainingArguments(
    output_dir="gpt2_string_results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    # NOTE: renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # With metric_for_best_model unset, "best" means lowest eval loss; this requires
    # the save and eval strategies to match (both "epoch" here).
    load_best_model_at_end=True,
    logging_steps=5,
    report_to="wandb",
    run_name=f"gpt2-qed-str-tokenizer_default-vars_{n}-rules_{r}-train_{train_len}-test_{test_len}",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=qed_train_tokenized_dataset,
    eval_dataset=qed_test_tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [32]:
# Always close the wandb run, even if training raises (e.g. an interrupt or OOM).
# Otherwise the run stays active in this kernel, and the HF wandb callback reuses an
# already-active run instead of starting a new one on the next trainer.train().
try:
    trainer.train()
finally:
    wandb.finish()



{'loss': 0.7517, 'learning_rate': 1.9750000000000002e-05, 'epoch': 0.12}
{'loss': 0.6764, 'learning_rate': 1.95e-05, 'epoch': 0.25}
{'loss': 0.6787, 'learning_rate': 1.925e-05, 'epoch': 0.38}
{'loss': 0.7224, 'learning_rate': 1.9e-05, 'epoch': 0.5}
{'loss': 0.6209, 'learning_rate': 1.8750000000000002e-05, 'epoch': 0.62}
{'loss': 0.7466, 'learning_rate': 1.8500000000000002e-05, 'epoch': 0.75}
{'loss': 0.6165, 'learning_rate': 1.825e-05, 'epoch': 0.88}
{'loss': 0.6657, 'learning_rate': 1.8e-05, 'epoch': 1.0}
{'eval_loss': 0.6144739985466003, 'eval_Accuracy': 0.62, 'eval_Avg Ones': 0.94, 'eval_runtime': 2.5078, 'eval_samples_per_second': 199.381, 'eval_steps_per_second': 6.38, 'epoch': 1.0}




{'loss': 0.6399, 'learning_rate': 1.775e-05, 'epoch': 1.12}
{'loss': 0.6104, 'learning_rate': 1.7500000000000002e-05, 'epoch': 1.25}
{'loss': 0.5674, 'learning_rate': 1.7250000000000003e-05, 'epoch': 1.38}
{'loss': 0.5395, 'learning_rate': 1.7e-05, 'epoch': 1.5}
{'loss': 0.5649, 'learning_rate': 1.675e-05, 'epoch': 1.62}
{'loss': 0.5362, 'learning_rate': 1.65e-05, 'epoch': 1.75}
{'loss': 0.5346, 'learning_rate': 1.6250000000000002e-05, 'epoch': 1.88}
{'loss': 0.5032, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.0}
{'eval_loss': 0.4911639392375946, 'eval_Accuracy': 0.744, 'eval_Avg Ones': 0.724, 'eval_runtime': 2.4525, 'eval_samples_per_second': 203.875, 'eval_steps_per_second': 6.524, 'epoch': 2.0}




{'loss': 0.5351, 'learning_rate': 1.575e-05, 'epoch': 2.12}
{'loss': 0.5954, 'learning_rate': 1.55e-05, 'epoch': 2.25}
{'loss': 0.5552, 'learning_rate': 1.525e-05, 'epoch': 2.38}
{'loss': 0.5286, 'learning_rate': 1.5000000000000002e-05, 'epoch': 2.5}
{'loss': 0.5612, 'learning_rate': 1.4750000000000003e-05, 'epoch': 2.62}
{'loss': 0.5188, 'learning_rate': 1.45e-05, 'epoch': 2.75}
{'loss': 0.5434, 'learning_rate': 1.425e-05, 'epoch': 2.88}
{'loss': 0.4316, 'learning_rate': 1.4e-05, 'epoch': 3.0}
{'eval_loss': 0.4522511959075928, 'eval_Accuracy': 0.794, 'eval_Avg Ones': 0.518, 'eval_runtime': 2.5292, 'eval_samples_per_second': 197.69, 'eval_steps_per_second': 6.326, 'epoch': 3.0}




{'loss': 0.5351, 'learning_rate': 1.375e-05, 'epoch': 3.12}
{'loss': 0.4798, 'learning_rate': 1.3500000000000001e-05, 'epoch': 3.25}
{'loss': 0.4397, 'learning_rate': 1.325e-05, 'epoch': 3.38}
{'loss': 0.5019, 'learning_rate': 1.3000000000000001e-05, 'epoch': 3.5}
{'loss': 0.4527, 'learning_rate': 1.275e-05, 'epoch': 3.62}
{'loss': 0.446, 'learning_rate': 1.25e-05, 'epoch': 3.75}
{'loss': 0.4241, 'learning_rate': 1.2250000000000001e-05, 'epoch': 3.88}
{'loss': 0.3905, 'learning_rate': 1.2e-05, 'epoch': 4.0}
{'eval_loss': 0.3241040110588074, 'eval_Accuracy': 0.87, 'eval_Avg Ones': 0.654, 'eval_runtime': 2.4627, 'eval_samples_per_second': 203.033, 'eval_steps_per_second': 6.497, 'epoch': 4.0}




{'loss': 0.375, 'learning_rate': 1.1750000000000001e-05, 'epoch': 4.12}
{'loss': 0.3475, 'learning_rate': 1.15e-05, 'epoch': 4.25}
{'loss': 0.3198, 'learning_rate': 1.125e-05, 'epoch': 4.38}
{'loss': 0.3854, 'learning_rate': 1.1000000000000001e-05, 'epoch': 4.5}
{'loss': 0.2687, 'learning_rate': 1.075e-05, 'epoch': 4.62}
{'loss': 0.3705, 'learning_rate': 1.0500000000000001e-05, 'epoch': 4.75}
{'loss': 0.3466, 'learning_rate': 1.025e-05, 'epoch': 4.88}
{'loss': 0.46, 'learning_rate': 1e-05, 'epoch': 5.0}
{'eval_loss': 0.20899423956871033, 'eval_Accuracy': 0.91, 'eval_Avg Ones': 0.634, 'eval_runtime': 2.4832, 'eval_samples_per_second': 201.35, 'eval_steps_per_second': 6.443, 'epoch': 5.0}




{'loss': 0.32, 'learning_rate': 9.75e-06, 'epoch': 5.12}
{'loss': 0.2938, 'learning_rate': 9.5e-06, 'epoch': 5.25}
{'loss': 0.3078, 'learning_rate': 9.250000000000001e-06, 'epoch': 5.38}
{'loss': 0.2776, 'learning_rate': 9e-06, 'epoch': 5.5}
{'loss': 0.3013, 'learning_rate': 8.750000000000001e-06, 'epoch': 5.62}
{'loss': 0.2351, 'learning_rate': 8.5e-06, 'epoch': 5.75}
{'loss': 0.2832, 'learning_rate': 8.25e-06, 'epoch': 5.88}
{'loss': 0.319, 'learning_rate': 8.000000000000001e-06, 'epoch': 6.0}
{'eval_loss': 0.16874447464942932, 'eval_Accuracy': 0.926, 'eval_Avg Ones': 0.614, 'eval_runtime': 2.5059, 'eval_samples_per_second': 199.529, 'eval_steps_per_second': 6.385, 'epoch': 6.0}




{'loss': 0.227, 'learning_rate': 7.75e-06, 'epoch': 6.12}
{'loss': 0.3358, 'learning_rate': 7.500000000000001e-06, 'epoch': 6.25}
{'loss': 0.2558, 'learning_rate': 7.25e-06, 'epoch': 6.38}
{'loss': 0.253, 'learning_rate': 7e-06, 'epoch': 6.5}
{'loss': 0.2509, 'learning_rate': 6.750000000000001e-06, 'epoch': 6.62}
{'loss': 0.2001, 'learning_rate': 6.5000000000000004e-06, 'epoch': 6.75}
{'loss': 0.2432, 'learning_rate': 6.25e-06, 'epoch': 6.88}
{'loss': 0.1981, 'learning_rate': 6e-06, 'epoch': 7.0}
{'eval_loss': 0.15518324077129364, 'eval_Accuracy': 0.94, 'eval_Avg Ones': 0.512, 'eval_runtime': 2.4965, 'eval_samples_per_second': 200.284, 'eval_steps_per_second': 6.409, 'epoch': 7.0}




{'loss': 0.3484, 'learning_rate': 5.75e-06, 'epoch': 7.12}
{'loss': 0.2414, 'learning_rate': 5.500000000000001e-06, 'epoch': 7.25}
{'loss': 0.2104, 'learning_rate': 5.2500000000000006e-06, 'epoch': 7.38}
{'loss': 0.2714, 'learning_rate': 5e-06, 'epoch': 7.5}
{'loss': 0.2225, 'learning_rate': 4.75e-06, 'epoch': 7.62}
{'loss': 0.1825, 'learning_rate': 4.5e-06, 'epoch': 7.75}
{'loss': 0.2659, 'learning_rate': 4.25e-06, 'epoch': 7.88}
{'loss': 0.2807, 'learning_rate': 4.000000000000001e-06, 'epoch': 8.0}
{'eval_loss': 0.11215635389089584, 'eval_Accuracy': 0.954, 'eval_Avg Ones': 0.562, 'eval_runtime': 2.5005, 'eval_samples_per_second': 199.96, 'eval_steps_per_second': 6.399, 'epoch': 8.0}




{'loss': 0.2049, 'learning_rate': 3.7500000000000005e-06, 'epoch': 8.12}
{'loss': 0.2116, 'learning_rate': 3.5e-06, 'epoch': 8.25}
{'loss': 0.1886, 'learning_rate': 3.2500000000000002e-06, 'epoch': 8.38}
{'loss': 0.1485, 'learning_rate': 3e-06, 'epoch': 8.5}
{'loss': 0.2176, 'learning_rate': 2.7500000000000004e-06, 'epoch': 8.62}
{'loss': 0.1789, 'learning_rate': 2.5e-06, 'epoch': 8.75}
{'loss': 0.2103, 'learning_rate': 2.25e-06, 'epoch': 8.88}
{'loss': 0.2271, 'learning_rate': 2.0000000000000003e-06, 'epoch': 9.0}
{'eval_loss': 0.10437647253274918, 'eval_Accuracy': 0.956, 'eval_Avg Ones': 0.556, 'eval_runtime': 2.7383, 'eval_samples_per_second': 182.592, 'eval_steps_per_second': 5.843, 'epoch': 9.0}




{'loss': 0.1843, 'learning_rate': 1.75e-06, 'epoch': 9.12}
{'loss': 0.1853, 'learning_rate': 1.5e-06, 'epoch': 9.25}
{'loss': 0.1865, 'learning_rate': 1.25e-06, 'epoch': 9.38}
{'loss': 0.1931, 'learning_rate': 1.0000000000000002e-06, 'epoch': 9.5}
{'loss': 0.168, 'learning_rate': 7.5e-07, 'epoch': 9.62}
{'loss': 0.2334, 'learning_rate': 5.000000000000001e-07, 'epoch': 9.75}
{'loss': 0.1948, 'learning_rate': 2.5000000000000004e-07, 'epoch': 9.88}
{'loss': 0.132, 'learning_rate': 0.0, 'epoch': 10.0}
{'eval_loss': 0.10013871639966965, 'eval_Accuracy': 0.95, 'eval_Avg Ones': 0.566, 'eval_runtime': 2.4614, 'eval_samples_per_second': 203.133, 'eval_steps_per_second': 6.5, 'epoch': 10.0}
{'train_runtime': 347.3833, 'train_samples_per_second': 71.967, 'train_steps_per_second': 1.151, 'train_loss': 0.3772153198719025, 'epoch': 10.0}




0,1
eval/Accuracy,▁▄▅▆▇▇████
eval/Avg Ones,█▄▁▃▃▃▁▂▂▂
eval/loss,█▆▆▄▂▂▂▁▁▁
eval/runtime,▂▁▃▁▂▂▂▂█▁
eval/samples_per_second,▇█▆█▇▇▇▇▁█
eval/steps_per_second,▇█▆█▇▇▇▇▁█
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
train/loss,█▇▇▆▇▆▆▆▆▆▆▆▆▄▅▄▄▃▃▃▃▃▃▃▂▂▂▂▃▂▂▃▂▂▂▂▂▂▁▁

0,1
eval/Accuracy,0.95
eval/Avg Ones,0.566
eval/loss,0.10014
eval/runtime,2.4614
eval/samples_per_second,203.133
eval/steps_per_second,6.5
train/epoch,10.0
train/global_step,400.0
train/learning_rate,0.0
train/loss,0.132
