In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling
import sys
import logging
logging.getLogger().setLevel(logging.ERROR)
logging.disable(sys.maxsize)

from torch.utils.data import *
from transformers import *
sys.path.insert(0, "..")

from models import *
from my_datasets import *

# from utils import *
import numpy as np
from tqdm import tqdm
import evaluate

from datasets import Dataset
import os

import wandb


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# set the wandb project where this run will be logged
os.environ["WANDB_PROJECT"]="transformer_friends"
os.environ["WANDB_LOG_MODEL"] = "checkpoint" # log all model checkpoints

# save your trained model checkpoint to wandb
os.environ["WANDB_LOG_MODEL"]="true"

# turn off watch to log faster
os.environ["WANDB_WATCH"]="false"

In [3]:
# n, r = 5, 8
# n, r = 20, 5
# n, r = 5, 8
n, r = 10, 10
ap, bp, tp, sp = 0.2, 0.2, 0.4, 0.1

nars = 3

train_len = 2500
test_len = 500
num_epochs = 15
seed = 42
# test_is_train = True

In [4]:
train_dataset = AutoRegKStepsEmbedsDataset(
    num_rules = r,
    num_vars = n,
    num_steps = nars,
    ante_prob = ap,
    conseq_prob = bp,
    state_prob = sp,
    dataset_len = train_len,
    seed = seed)

eval_dataset = AutoRegKStepsEmbedsDataset(
    num_rules = r,
    num_vars = n,
    num_steps = nars,
    ante_prob = ap,
    conseq_prob = bp,
    state_prob = sp,
    dataset_len = test_len,
    seed = seed)

In [5]:
train_dataset[0]

{'rules': tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
         [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
         [0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
         [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]),
 'state': tensor([0, 0, 0, 1, 0, 0, 0, 0, 1, 0]),
 'labels': tensor([[1, 0, 1, 1, 0, 1, 0, 0, 1, 0],
         [1, 0, 1, 1, 0, 1, 1, 0, 1, 0],
         [1, 1, 1, 1, 0, 1, 1, 0, 1, 0]])}

In [6]:
def stringify_rule(rule, var_sep_token):
    """
    Create a rule of the form xi , xj , ... -> xa
    from a one-hot vector of [<ants>, <cons>]
    """

    n_vars = len(rule) // 2
    ants = [f"x{i}" for i in range(n_vars) if rule[i]]
    cons = [f"x{i}" for i in range(n_vars) if rule[n_vars+i]]
    if len(ants) < 1:
        ants = ["empty"]
    if len(cons) < 1:
        cons = ["empty"]
    rule = var_sep_token.join(ants) + " -> " + var_sep_token.join(cons)
    return rule

def get_string_rep_replace(dataset_item):
    """
    Returns a string of the form:
    [RULES_START] [RULE_START] ... [RULE_END] ... [RULES_END]
    [CURRENT_STATE_START] ... [CURRENT_STATE_END]
    [NEXT_STATE_START] ... [NEXT_STATE_END]
    """

    var_sep_token = " , "
    rules_start = "[RULES_START]"
    rules_end = "[RULES_END]"
    rule_start = "[RULE_START]"
    rule_end = "[RULE_END]"
    current_state_start = "[CURRENT_STATE_START]"
    current_state_end = "[CURRENT_STATE_END]"
    next_state_start = "[NEXT_STATE_START]"
    next_state_end = "[NEXT_STATE_END]"

    rules = dataset_item["rules"]
    current_state = dataset_item["state"]
    next_state = dataset_item["labels"][0]

    n_vars = len(current_state)

    rule_strs = [rule_start + " " + stringify_rule(rule, var_sep_token) + " " + rule_end for rule in rules]
    current_state_str = var_sep_token.join([f"x{i}" for i in range(n_vars) if current_state[i]])
    current_state_str = current_state_start + " " + current_state_str + " " + current_state_end
    rules_str = rules_start + " " + " ".join(rule_strs) + " " + rules_end

    next_state_str = var_sep_token.join([f"x{i}" for i in range(n_vars) if next_state[i]])
    return {
        "prompt": rules_str + " " + current_state_str + " " + next_state_start,
        "target": " " + next_state_str + " " + next_state_end,
        "stop": next_state_end
    }
    return rules_str + " " + current_state_str + " " + next_state_start, {"stop": next_state_end}

def get_string_rep_append(dataset_item):
    """
    Returns a string of the form:
    [RULES_START] [RULE_START] ... [RULE_END] ... [RULES_END]
    [STATES_START] [STATE_START] ... [STATE_END] ... [STATES_END]
    """

    var_sep_token = " , "
    rules_start = "[RULES_START]"
    rules_end = "[RULES_END]"
    rule_start = "[RULE_START]"
    rule_end = "[RULE_END]"
    states_start = "[STATES_START]"
    states_end = "[STATES_END]"
    state_start = "[STATE_START]"
    state_end = "[STATE_END]"

    rules = dataset_item["rules"]
    state = dataset_item["state"]
    next_states = dataset_item["labels"]

    n_vars = len(state)

    rule_strs = [rule_start + " " + stringify_rule(rule, var_sep_token) + " " + rule_end for rule in rules]
    state_str = var_sep_token.join([f"x{i}" for i in range(n_vars) if state[i]])
    state_str = state_start + " " + state_str + " " + state_end
    rules_str = rules_start + " " + " ".join(rule_strs) + " " + rules_end

    next_state_strs = [var_sep_token.join([f"x{i}" for i in range(n_vars) if next_state[i]]) for next_state in next_states]
    next_state_strs = [state_start + " " + next_state_str + " " + state_end for next_state_str in next_state_strs]
    next_state_strs = " ".join(next_state_strs)
    # Remove the first state_start from the next state string
    next_state_strs = next_state_strs[len(state_start)+1:]
    return {
        "prompt": rules_str + " " + states_start + " " + state_str + " " + state_start,
        "target": " " + next_state_strs + " " + states_end,
        "stop": states_end
    }
    return rules_str + " " + states_start + " " + state_str + " " + state_start, {"stop": states_end}


In [7]:
print(train_dataset[0])
print(get_string_rep_replace(train_dataset[0]))
print(get_string_rep_append(train_dataset[0]))

{'rules': tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]), 'state': tensor([0, 0, 0, 1, 0, 0, 0, 0, 1, 0]), 'labels': tensor([[1, 0, 1, 1, 0, 1, 0, 0, 1, 0],
        [1, 0, 1, 1, 0, 1, 1, 0, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 0, 1, 0]])}
{'prompt': '[RULES_START] [RULE_START] x0 , x6 -> x0 , x1 , x2 [RULE_END] [RULE_START] empty -> x0 , x2 [RU

In [8]:
# Create HuggingFace datasets for the append task

print("Creating train dataset")
train_data = [get_string_rep_append(train_dataset[i]) for i in tqdm(range(len(train_dataset)))]
train_hf_dataset = Dataset.from_dict({
    # "data": [train_data[i]['prompt'] for i in range(len(train_data))],
    # "label": [train_data[i]['target'] for i in range(len(train_data))],
    "data": [train_data[i]['prompt'] + train_data[i]['target'] for i in range(len(train_data))],
}).with_format("torch")

print("Creating test dataset")
test_data = [get_string_rep_append(eval_dataset[i]) for i in tqdm(range(len(eval_dataset)))]
test_hf_dataset = Dataset.from_dict({
    # "data": [test_data[i]['prompt'] for i in range(len(test_data))],
    # "label": [test_data[i]['target'] for i in range(len(test_data))],
    "data": [test_data[i]['prompt'] + test_data[i]['target'] for i in range(len(test_data))],
}).with_format("torch")

Creating train dataset


100%|██████████| 2500/2500 [00:02<00:00, 1094.20it/s]


Creating test dataset


100%|██████████| 500/500 [00:00<00:00, 1106.97it/s]


In [9]:
# Get the GPT-2 tokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [10]:
def tokenize_function(item):
    return tokenizer(item["data"], truncation=True)

train_tokenized_dataset = train_hf_dataset.map(tokenize_function, batched=True)
test_tokenized_dataset = test_hf_dataset.map(tokenize_function, batched=True)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

                                                                 

In [11]:
# Create the model
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

In [13]:
import inspect

In [14]:
inspect.signature(model.forward)

<Signature (input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithCrossAttentions]>

In [18]:
# accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    print(eval_pred.label_ids)
    print(eval_pred.predictions)
    print(dict(eval_pred).keys())
    # exit()
    return {"Accuracy": 0}
    # predictions, labels = eval_pred
    # # Check if all predictions match labels
    # acc = accuracy.compute(predictions=predictions, references=labels)
    # return acc
    # # return {"Accuracy" : acc["accuracy"], "Avg Ones" : avg_ones}

In [13]:
training_args = TrainingArguments(
    # output_dir="gpt2_string_auto_reg_results",
    output_dir="gpt2_append_autoreg_str_results",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_steps=5,
    report_to="wandb",
    run_name="gpt2-append-autoreg-str-tokenizer_default-vars_10-rules_10-train_2500-test_500",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized_dataset,
    eval_dataset=test_tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
)

In [16]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33makhare[0m ([33mtransformer_friends[0m). Use [1m`wandb login --relogin`[0m to force relogin




{'loss': 0.9474, 'learning_rate': 1.9936507936507938e-05, 'epoch': 0.05}
{'loss': 0.7275, 'learning_rate': 1.9873015873015875e-05, 'epoch': 0.1}
{'loss': 0.5858, 'learning_rate': 1.980952380952381e-05, 'epoch': 0.14}
{'loss': 0.4999, 'learning_rate': 1.9746031746031748e-05, 'epoch': 0.19}
{'loss': 0.4405, 'learning_rate': 1.9682539682539684e-05, 'epoch': 0.24}
{'loss': 0.4126, 'learning_rate': 1.961904761904762e-05, 'epoch': 0.29}
{'loss': 0.3871, 'learning_rate': 1.9555555555555557e-05, 'epoch': 0.33}
{'loss': 0.3795, 'learning_rate': 1.9492063492063494e-05, 'epoch': 0.38}
{'loss': 0.3669, 'learning_rate': 1.942857142857143e-05, 'epoch': 0.43}
{'loss': 0.3575, 'learning_rate': 1.9365079365079367e-05, 'epoch': 0.48}
{'loss': 0.3565, 'learning_rate': 1.9301587301587303e-05, 'epoch': 0.52}
{'loss': 0.3485, 'learning_rate': 1.923809523809524e-05, 'epoch': 0.57}
{'loss': 0.3457, 'learning_rate': 1.9174603174603176e-05, 'epoch': 0.62}
{'loss': 0.3437, 'learning_rate': 1.9111111111111113e-05



{'loss': 0.3335, 'learning_rate': 1.8603174603174605e-05, 'epoch': 1.05}
{'loss': 0.3314, 'learning_rate': 1.853968253968254e-05, 'epoch': 1.1}
{'loss': 0.333, 'learning_rate': 1.8476190476190478e-05, 'epoch': 1.14}
{'loss': 0.3286, 'learning_rate': 1.8412698412698415e-05, 'epoch': 1.19}
{'loss': 0.3298, 'learning_rate': 1.834920634920635e-05, 'epoch': 1.24}
{'loss': 0.3302, 'learning_rate': 1.8285714285714288e-05, 'epoch': 1.29}
{'loss': 0.3251, 'learning_rate': 1.8222222222222224e-05, 'epoch': 1.33}
{'loss': 0.3242, 'learning_rate': 1.815873015873016e-05, 'epoch': 1.38}
{'loss': 0.3245, 'learning_rate': 1.8095238095238097e-05, 'epoch': 1.43}
{'loss': 0.3242, 'learning_rate': 1.8031746031746034e-05, 'epoch': 1.48}
{'loss': 0.3232, 'learning_rate': 1.796825396825397e-05, 'epoch': 1.52}
{'loss': 0.32, 'learning_rate': 1.7904761904761907e-05, 'epoch': 1.57}
{'loss': 0.3247, 'learning_rate': 1.7841269841269843e-05, 'epoch': 1.62}
{'loss': 0.3186, 'learning_rate': 1.7777777777777777e-05, '



{'loss': 0.3148, 'learning_rate': 1.7269841269841272e-05, 'epoch': 2.05}
{'loss': 0.3122, 'learning_rate': 1.720634920634921e-05, 'epoch': 2.1}
{'loss': 0.3178, 'learning_rate': 1.7142857142857142e-05, 'epoch': 2.14}
{'loss': 0.3161, 'learning_rate': 1.707936507936508e-05, 'epoch': 2.19}
{'loss': 0.3155, 'learning_rate': 1.7015873015873018e-05, 'epoch': 2.24}
{'loss': 0.3125, 'learning_rate': 1.6952380952380955e-05, 'epoch': 2.29}
{'loss': 0.3157, 'learning_rate': 1.688888888888889e-05, 'epoch': 2.33}
{'loss': 0.3145, 'learning_rate': 1.6825396825396828e-05, 'epoch': 2.38}
{'loss': 0.3118, 'learning_rate': 1.6761904761904764e-05, 'epoch': 2.43}
{'loss': 0.3127, 'learning_rate': 1.66984126984127e-05, 'epoch': 2.48}
{'loss': 0.3097, 'learning_rate': 1.6634920634920637e-05, 'epoch': 2.52}
{'loss': 0.3111, 'learning_rate': 1.6571428571428574e-05, 'epoch': 2.57}
{'loss': 0.3114, 'learning_rate': 1.6507936507936507e-05, 'epoch': 2.62}
{'loss': 0.3109, 'learning_rate': 1.6444444444444444e-05,



{'loss': 0.3074, 'learning_rate': 1.5936507936507936e-05, 'epoch': 3.05}
{'loss': 0.3069, 'learning_rate': 1.5873015873015872e-05, 'epoch': 3.1}
{'loss': 0.3073, 'learning_rate': 1.580952380952381e-05, 'epoch': 3.14}
{'loss': 0.3065, 'learning_rate': 1.5746031746031745e-05, 'epoch': 3.19}
{'loss': 0.3079, 'learning_rate': 1.5682539682539685e-05, 'epoch': 3.24}
{'loss': 0.3047, 'learning_rate': 1.5619047619047622e-05, 'epoch': 3.29}
{'loss': 0.3073, 'learning_rate': 1.555555555555556e-05, 'epoch': 3.33}
{'loss': 0.3056, 'learning_rate': 1.5492063492063495e-05, 'epoch': 3.38}
{'loss': 0.3051, 'learning_rate': 1.542857142857143e-05, 'epoch': 3.43}
{'loss': 0.3044, 'learning_rate': 1.5365079365079368e-05, 'epoch': 3.48}
{'loss': 0.3075, 'learning_rate': 1.53015873015873e-05, 'epoch': 3.52}
{'loss': 0.3062, 'learning_rate': 1.523809523809524e-05, 'epoch': 3.57}
{'loss': 0.3054, 'learning_rate': 1.5174603174603176e-05, 'epoch': 3.62}
{'loss': 0.3072, 'learning_rate': 1.5111111111111112e-05, 



{'loss': 0.3021, 'learning_rate': 1.4603174603174603e-05, 'epoch': 4.05}
{'loss': 0.303, 'learning_rate': 1.4539682539682541e-05, 'epoch': 4.1}
{'loss': 0.3036, 'learning_rate': 1.4476190476190478e-05, 'epoch': 4.14}
{'loss': 0.2999, 'learning_rate': 1.4412698412698414e-05, 'epoch': 4.19}
{'loss': 0.3015, 'learning_rate': 1.434920634920635e-05, 'epoch': 4.24}
{'loss': 0.3027, 'learning_rate': 1.4285714285714287e-05, 'epoch': 4.29}
{'loss': 0.3054, 'learning_rate': 1.4222222222222224e-05, 'epoch': 4.33}
{'loss': 0.3012, 'learning_rate': 1.415873015873016e-05, 'epoch': 4.38}
{'loss': 0.3036, 'learning_rate': 1.4095238095238097e-05, 'epoch': 4.43}
{'loss': 0.301, 'learning_rate': 1.4031746031746032e-05, 'epoch': 4.48}
{'loss': 0.3012, 'learning_rate': 1.3968253968253968e-05, 'epoch': 4.52}
{'loss': 0.3013, 'learning_rate': 1.3904761904761905e-05, 'epoch': 4.57}
{'loss': 0.2999, 'learning_rate': 1.3841269841269843e-05, 'epoch': 4.62}
{'loss': 0.3011, 'learning_rate': 1.377777777777778e-05,



{'loss': 0.3017, 'learning_rate': 1.326984126984127e-05, 'epoch': 5.05}
{'loss': 0.299, 'learning_rate': 1.3206349206349206e-05, 'epoch': 5.1}
{'loss': 0.3003, 'learning_rate': 1.3142857142857145e-05, 'epoch': 5.14}
{'loss': 0.2995, 'learning_rate': 1.3079365079365081e-05, 'epoch': 5.19}
{'loss': 0.2976, 'learning_rate': 1.3015873015873018e-05, 'epoch': 5.24}
{'loss': 0.2988, 'learning_rate': 1.2952380952380954e-05, 'epoch': 5.29}
{'loss': 0.2982, 'learning_rate': 1.288888888888889e-05, 'epoch': 5.33}
{'loss': 0.2957, 'learning_rate': 1.2825396825396827e-05, 'epoch': 5.38}
{'loss': 0.299, 'learning_rate': 1.2761904761904762e-05, 'epoch': 5.43}
{'loss': 0.3002, 'learning_rate': 1.2698412698412699e-05, 'epoch': 5.48}
{'loss': 0.2988, 'learning_rate': 1.2634920634920635e-05, 'epoch': 5.52}
{'loss': 0.2967, 'learning_rate': 1.2571428571428572e-05, 'epoch': 5.57}
{'loss': 0.3002, 'learning_rate': 1.2507936507936508e-05, 'epoch': 5.62}
{'loss': 0.298, 'learning_rate': 1.2444444444444446e-05,



{'loss': 0.2966, 'learning_rate': 1.1936507936507937e-05, 'epoch': 6.05}
{'loss': 0.2984, 'learning_rate': 1.1873015873015873e-05, 'epoch': 6.1}
{'loss': 0.2975, 'learning_rate': 1.180952380952381e-05, 'epoch': 6.14}
{'loss': 0.298, 'learning_rate': 1.1746031746031748e-05, 'epoch': 6.19}
{'loss': 0.2982, 'learning_rate': 1.1682539682539685e-05, 'epoch': 6.24}
{'loss': 0.2954, 'learning_rate': 1.1619047619047621e-05, 'epoch': 6.29}
{'loss': 0.2971, 'learning_rate': 1.1555555555555556e-05, 'epoch': 6.33}
{'loss': 0.2959, 'learning_rate': 1.1492063492063492e-05, 'epoch': 6.38}
{'loss': 0.2962, 'learning_rate': 1.1428571428571429e-05, 'epoch': 6.43}
{'loss': 0.2949, 'learning_rate': 1.1365079365079366e-05, 'epoch': 6.48}
{'loss': 0.296, 'learning_rate': 1.1301587301587302e-05, 'epoch': 6.52}
{'loss': 0.2962, 'learning_rate': 1.1238095238095239e-05, 'epoch': 6.57}
{'loss': 0.2948, 'learning_rate': 1.1174603174603175e-05, 'epoch': 6.62}
{'loss': 0.3008, 'learning_rate': 1.1111111111111113e-0



{'loss': 0.2969, 'learning_rate': 1.0603174603174604e-05, 'epoch': 7.05}
{'loss': 0.2961, 'learning_rate': 1.053968253968254e-05, 'epoch': 7.1}
{'loss': 0.2948, 'learning_rate': 1.0476190476190477e-05, 'epoch': 7.14}
{'loss': 0.2969, 'learning_rate': 1.0412698412698415e-05, 'epoch': 7.19}
{'loss': 0.2953, 'learning_rate': 1.0349206349206352e-05, 'epoch': 7.24}
{'loss': 0.2969, 'learning_rate': 1.0285714285714285e-05, 'epoch': 7.29}
{'loss': 0.2933, 'learning_rate': 1.0222222222222223e-05, 'epoch': 7.33}
{'loss': 0.2975, 'learning_rate': 1.015873015873016e-05, 'epoch': 7.38}
{'loss': 0.2971, 'learning_rate': 1.0095238095238096e-05, 'epoch': 7.43}
{'loss': 0.2964, 'learning_rate': 1.0031746031746033e-05, 'epoch': 7.48}
{'loss': 0.295, 'learning_rate': 9.968253968253969e-06, 'epoch': 7.52}
{'loss': 0.2956, 'learning_rate': 9.904761904761906e-06, 'epoch': 7.57}
{'loss': 0.2957, 'learning_rate': 9.841269841269842e-06, 'epoch': 7.62}
{'loss': 0.2936, 'learning_rate': 9.777777777777779e-06, '



{'loss': 0.2967, 'learning_rate': 9.26984126984127e-06, 'epoch': 8.05}
{'loss': 0.2941, 'learning_rate': 9.206349206349207e-06, 'epoch': 8.1}
{'loss': 0.2938, 'learning_rate': 9.142857142857144e-06, 'epoch': 8.14}
{'loss': 0.295, 'learning_rate': 9.07936507936508e-06, 'epoch': 8.19}
{'loss': 0.295, 'learning_rate': 9.015873015873017e-06, 'epoch': 8.24}
{'loss': 0.2932, 'learning_rate': 8.952380952380953e-06, 'epoch': 8.29}
{'loss': 0.2911, 'learning_rate': 8.888888888888888e-06, 'epoch': 8.33}
{'loss': 0.293, 'learning_rate': 8.825396825396827e-06, 'epoch': 8.38}
{'loss': 0.2948, 'learning_rate': 8.761904761904763e-06, 'epoch': 8.43}
{'loss': 0.2966, 'learning_rate': 8.6984126984127e-06, 'epoch': 8.48}
{'loss': 0.2939, 'learning_rate': 8.634920634920636e-06, 'epoch': 8.52}
{'loss': 0.2926, 'learning_rate': 8.571428571428571e-06, 'epoch': 8.57}
{'loss': 0.2934, 'learning_rate': 8.507936507936509e-06, 'epoch': 8.62}
{'loss': 0.2942, 'learning_rate': 8.444444444444446e-06, 'epoch': 8.67}




{'loss': 0.2931, 'learning_rate': 7.936507936507936e-06, 'epoch': 9.05}
{'loss': 0.2945, 'learning_rate': 7.873015873015873e-06, 'epoch': 9.1}
{'loss': 0.294, 'learning_rate': 7.809523809523811e-06, 'epoch': 9.14}
{'loss': 0.2938, 'learning_rate': 7.746031746031747e-06, 'epoch': 9.19}
{'loss': 0.2929, 'learning_rate': 7.682539682539684e-06, 'epoch': 9.24}
{'loss': 0.2928, 'learning_rate': 7.61904761904762e-06, 'epoch': 9.29}
{'loss': 0.2926, 'learning_rate': 7.555555555555556e-06, 'epoch': 9.33}
{'loss': 0.2935, 'learning_rate': 7.492063492063493e-06, 'epoch': 9.38}
{'loss': 0.2926, 'learning_rate': 7.428571428571429e-06, 'epoch': 9.43}
{'loss': 0.2935, 'learning_rate': 7.3650793650793666e-06, 'epoch': 9.48}
{'loss': 0.2946, 'learning_rate': 7.301587301587301e-06, 'epoch': 9.52}
{'loss': 0.2934, 'learning_rate': 7.238095238095239e-06, 'epoch': 9.57}
{'loss': 0.2934, 'learning_rate': 7.174603174603175e-06, 'epoch': 9.62}
{'loss': 0.294, 'learning_rate': 7.111111111111112e-06, 'epoch': 9



{'loss': 0.2937, 'learning_rate': 6.603174603174603e-06, 'epoch': 10.05}
{'loss': 0.2922, 'learning_rate': 6.5396825396825405e-06, 'epoch': 10.1}
{'loss': 0.2921, 'learning_rate': 6.476190476190477e-06, 'epoch': 10.14}
{'loss': 0.2912, 'learning_rate': 6.412698412698414e-06, 'epoch': 10.19}
{'loss': 0.2938, 'learning_rate': 6.349206349206349e-06, 'epoch': 10.24}
{'loss': 0.293, 'learning_rate': 6.285714285714286e-06, 'epoch': 10.29}
{'loss': 0.2925, 'learning_rate': 6.222222222222223e-06, 'epoch': 10.33}
{'loss': 0.2924, 'learning_rate': 6.15873015873016e-06, 'epoch': 10.38}
{'loss': 0.2913, 'learning_rate': 6.095238095238096e-06, 'epoch': 10.43}
{'loss': 0.2933, 'learning_rate': 6.031746031746032e-06, 'epoch': 10.48}
{'loss': 0.2929, 'learning_rate': 5.968253968253968e-06, 'epoch': 10.52}
{'loss': 0.2945, 'learning_rate': 5.904761904761905e-06, 'epoch': 10.57}
{'loss': 0.2934, 'learning_rate': 5.841269841269842e-06, 'epoch': 10.62}
{'loss': 0.29, 'learning_rate': 5.777777777777778e-06



{'loss': 0.2903, 'learning_rate': 5.26984126984127e-06, 'epoch': 11.05}
{'loss': 0.2908, 'learning_rate': 5.2063492063492076e-06, 'epoch': 11.1}
{'loss': 0.2926, 'learning_rate': 5.142857142857142e-06, 'epoch': 11.14}
{'loss': 0.2924, 'learning_rate': 5.07936507936508e-06, 'epoch': 11.19}
{'loss': 0.2926, 'learning_rate': 5.015873015873016e-06, 'epoch': 11.24}
{'loss': 0.2922, 'learning_rate': 4.952380952380953e-06, 'epoch': 11.29}
{'loss': 0.2918, 'learning_rate': 4.888888888888889e-06, 'epoch': 11.33}
{'loss': 0.2926, 'learning_rate': 4.825396825396826e-06, 'epoch': 11.38}
{'loss': 0.2928, 'learning_rate': 4.761904761904762e-06, 'epoch': 11.43}
{'loss': 0.2904, 'learning_rate': 4.698412698412699e-06, 'epoch': 11.48}
{'loss': 0.2909, 'learning_rate': 4.634920634920635e-06, 'epoch': 11.52}
{'loss': 0.291, 'learning_rate': 4.571428571428572e-06, 'epoch': 11.57}
{'loss': 0.2921, 'learning_rate': 4.5079365079365085e-06, 'epoch': 11.62}
{'loss': 0.2899, 'learning_rate': 4.444444444444444e-



{'loss': 0.2919, 'learning_rate': 3.936507936507936e-06, 'epoch': 12.05}
{'loss': 0.2897, 'learning_rate': 3.873015873015874e-06, 'epoch': 12.1}
{'loss': 0.2909, 'learning_rate': 3.80952380952381e-06, 'epoch': 12.14}
{'loss': 0.2915, 'learning_rate': 3.7460317460317463e-06, 'epoch': 12.19}
{'loss': 0.2915, 'learning_rate': 3.6825396825396833e-06, 'epoch': 12.24}
{'loss': 0.2911, 'learning_rate': 3.6190476190476194e-06, 'epoch': 12.29}
{'loss': 0.2912, 'learning_rate': 3.555555555555556e-06, 'epoch': 12.33}
{'loss': 0.2918, 'learning_rate': 3.492063492063492e-06, 'epoch': 12.38}
{'loss': 0.2912, 'learning_rate': 3.428571428571429e-06, 'epoch': 12.43}
{'loss': 0.2913, 'learning_rate': 3.3650793650793655e-06, 'epoch': 12.48}
{'loss': 0.291, 'learning_rate': 3.3015873015873016e-06, 'epoch': 12.52}
{'loss': 0.2902, 'learning_rate': 3.2380952380952385e-06, 'epoch': 12.57}
{'loss': 0.2909, 'learning_rate': 3.1746031746031746e-06, 'epoch': 12.62}
{'loss': 0.2904, 'learning_rate': 3.11111111111



{'loss': 0.2922, 'learning_rate': 2.6031746031746038e-06, 'epoch': 13.05}
{'loss': 0.2913, 'learning_rate': 2.53968253968254e-06, 'epoch': 13.1}
{'loss': 0.2921, 'learning_rate': 2.4761904761904764e-06, 'epoch': 13.14}
{'loss': 0.2907, 'learning_rate': 2.412698412698413e-06, 'epoch': 13.19}
{'loss': 0.292, 'learning_rate': 2.3492063492063494e-06, 'epoch': 13.24}
{'loss': 0.2904, 'learning_rate': 2.285714285714286e-06, 'epoch': 13.29}
{'loss': 0.2903, 'learning_rate': 2.222222222222222e-06, 'epoch': 13.33}
{'loss': 0.2909, 'learning_rate': 2.158730158730159e-06, 'epoch': 13.38}
{'loss': 0.2916, 'learning_rate': 2.0952380952380955e-06, 'epoch': 13.43}
{'loss': 0.2905, 'learning_rate': 2.031746031746032e-06, 'epoch': 13.48}
{'loss': 0.2909, 'learning_rate': 1.968253968253968e-06, 'epoch': 13.52}
{'loss': 0.2916, 'learning_rate': 1.904761904761905e-06, 'epoch': 13.57}
{'loss': 0.2897, 'learning_rate': 1.8412698412698416e-06, 'epoch': 13.62}
{'loss': 0.2901, 'learning_rate': 1.7777777777777



{'loss': 0.2898, 'learning_rate': 1.26984126984127e-06, 'epoch': 14.05}
{'loss': 0.2911, 'learning_rate': 1.2063492063492065e-06, 'epoch': 14.1}
{'loss': 0.2921, 'learning_rate': 1.142857142857143e-06, 'epoch': 14.14}
{'loss': 0.2897, 'learning_rate': 1.0793650793650795e-06, 'epoch': 14.19}
{'loss': 0.2923, 'learning_rate': 1.015873015873016e-06, 'epoch': 14.24}
{'loss': 0.2919, 'learning_rate': 9.523809523809525e-07, 'epoch': 14.29}
{'loss': 0.2902, 'learning_rate': 8.88888888888889e-07, 'epoch': 14.33}
{'loss': 0.2886, 'learning_rate': 8.253968253968254e-07, 'epoch': 14.38}
{'loss': 0.2885, 'learning_rate': 7.61904761904762e-07, 'epoch': 14.43}
{'loss': 0.2892, 'learning_rate': 6.984126984126984e-07, 'epoch': 14.48}
{'loss': 0.2924, 'learning_rate': 6.34920634920635e-07, 'epoch': 14.52}
{'loss': 0.2913, 'learning_rate': 5.714285714285715e-07, 'epoch': 14.57}
{'loss': 0.2905, 'learning_rate': 5.07936507936508e-07, 'epoch': 14.62}
{'loss': 0.2908, 'learning_rate': 4.444444444444445e-07

TrainOutput(global_step=1575, training_loss=0.30666147890545076, metrics={'train_runtime': 730.1889, 'train_samples_per_second': 51.357, 'train_steps_per_second': 2.157, 'train_loss': 0.30666147890545076, 'epoch': 15.0})

In [17]:
wandb.finish()

0,1
eval/loss,▆▅█▇▆▆▄▅▆▅▅▁▅▄▄
eval/runtime,▁▁█▂▁▇▁█▁▁▁█▁▁▁
eval/samples_per_second,██▁▇█▁█▁███▁███
eval/steps_per_second,██▁▇█▂█▁███▁███
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.31911
eval/runtime,4.7738
eval/samples_per_second,104.738
eval/steps_per_second,4.399
train/epoch,15.0
train/global_step,1575.0
train/learning_rate,0.0
train/loss,0.2909
train/total_flos,7830699692544000.0
train/train_loss,0.30666


In [14]:
trainer.train(resume_from_checkpoint="gpt2_append_autoreg_str_results/checkpoint-1260")

[34m[1mwandb[0m: Currently logged in as: [33makhare[0m ([33mtransformer_friends[0m). Use [1m`wandb login --relogin`[0m to force relogin




KeyboardInterrupt: 

In [16]:
test_data[0]["prompt"]

'[RULES_START] [RULE_START] x0 , x6 -> x0 , x1 , x2 [RULE_END] [RULE_START] empty -> x0 , x2 [RULE_END] [RULE_START] x3 , x7 , x8 -> x2 , x4 , x5 , x7 , x8 [RULE_END] [RULE_START] x1 , x3 , x9 -> empty [RULE_END] [RULE_START] x1 , x2 , x3 , x5 -> x1 , x2 , x5 [RULE_END] [RULE_START] x3 , x4 , x5 , x6 , x7 -> x1 , x7 [RULE_END] [RULE_START] x1 , x4 , x7 -> empty [RULE_END] [RULE_START] empty -> x2 , x5 , x8 [RULE_END] [RULE_START] x3 , x4 , x8 -> x0 , x6 [RULE_END] [RULE_START] x0 , x2 -> x6 [RULE_END] [RULES_END] [STATES_START] [STATE_START] x3 , x8 [STATE_END] [STATE_START]'

In [17]:
test_data_new = test_data[0]["input_ids"].to("cuda")

KeyError: 'input_ids'

In [18]:
ip = tokenizer(test_data[1]["prompt"], return_tensors="pt")["input_ids"].to("cuda")
model_cpu = model.to("cpu")

In [49]:
model.eval()
input_ids = tokenizer(test_data[0]["prompt"], return_tensors="pt")["input_ids"]
output = model.to("cpu").generate(
        input_ids,
        max_new_tokens=200,
    )
output = output[0].to("cpu")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [23]:
def check_similarity(target, pred):
    """
    Check if the final state of the target matches the final state of the prediction
    """

    state_start = "[STATE_START]"
    state_end = "[STATE_END]"
    final_target_state = target.split(state_start)[-1].split(state_end)[0].split(",")
    final_pred_state = pred.split(state_start)[-1].split(state_end)[0].split(",")
    final_target_state_set = set(state.strip() for state in final_target_state)
    final_pred_state_set = set(state.strip() for state in final_pred_state)
    print(final_target_state_set)
    print(final_pred_state_set)

    return final_target_state_set == final_pred_state_set

In [37]:
test_data[0]["target"]

' x0 , x2 , x3 , x5 , x8 [STATE_END] [STATE_START] x0 , x2 , x3 , x5 , x6 , x8 [STATE_END] [STATE_START] x0 , x1 , x2 , x3 , x5 , x6 , x8 [STATE_END] [STATES_END]'

In [22]:
check_similarity(test_data[0]["target"], tokenizer.decode(output[input_ids.shape[1]:], skip_special_tokens=True))

{'x3', 'x6', 'x8', 'x0', 'x5', 'x1', 'x2'}
{'x3', 'x6', 'x8', 'x0', 'x5', 'x2'}


False

In [24]:
n_correct = 0
pred_ids = []

for i in range(len(test_data)):
    input_ids = tokenizer(test_data[i]["prompt"], return_tensors="pt")["input_ids"]
    output = model.generate(
            input_ids,
            max_new_tokens=200,
        )
    output = output[0].to("cpu")
    pred_ids.append(output[input_ids.shape[1]:])
    

KeyboardInterrupt: 

In [26]:
len(pred_ids)

208

In [27]:
preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
targets = [item["target"] for item in test_data[:len(pred_ids)]]

is_correct = [check_similarity(target, pred) for (pred, target) in zip(preds, targets)]
n_correct = sum(is_correct)

# for pred_id in pred_ids:
#     pred = tokenizer.decode(output[input_ids.shape[1]:], skip_special_tokens=True)
#     target = test_data[i]["target"]
#     # break

#     # print(output)
#     # print(pred)
#     # print(target)
#     # break
#     if check_similarity(target, pred):
#         n_correct += 1

{'x5', 'x6', 'x2', 'x1', 'x3', 'x0', 'x8'}
{'x5', 'x2', 'x3', 'x0', 'x8'}
{'x5', 'x4', 'x1', 'x9', 'x0', 'x8'}
{'x5', 'x6', 'x1', 'x0', 'x8'}
{'x5', 'x6', 'x2', 'x7', 'x4', 'x3', 'x0'}
{'x5', 'x6', 'x2', 'x7', 'x4'}
{'x5', 'x6', 'x2', 'x7', 'x4', 'x1', 'x3', 'x9', 'x0', 'x8'}
{'x5', 'x6', 'x7', 'x1', 'x0'}
{'x5', 'x2', 'x4', 'x1', 'x3'}
{'x5', 'x2', 'x4', 'x1', 'x3', 'x8'}
{'x5', 'x1', 'x3', 'x9', 'x8'}
{'x5', 'x1', 'x3', 'x9', 'x8'}
{'x5', 'x2', 'x1', 'x3', 'x8'}
{'x5', 'x2', 'x1', 'x3', 'x8'}
{'x2', 'x7', 'x4', 'x1', 'x3', 'x0'}
{'x7', 'x4', 'x1', 'x3', 'x0'}
{'x5', 'x6', 'x4', 'x1', 'x3', 'x9'}
{'x5', 'x4', 'x1', 'x3', 'x'}
{'x0', 'x1', 'x9', 'x7'}
{'x0', 'x1', 'x9', 'x7'}
{'x9', 'x5', 'x1', 'x2'}
{'x9', 'x2'}
{'x5', 'x6', 'x2', 'x7', 'x4', 'x1', 'x3', 'x9', 'x0'}
{'x5', 'x4', 'x1', 'x3', 'x9', 'x0'}
{'x4', 'x0', 'x1', 'x3'}
{'x6', 'x4', 'x1', 'x3', 'x0'}
{'x5', 'x2', 'x7', 'x9', 'x0', 'x8'}
{'x5', 'x2', 'x7', 'x9', 'x0', 'x8'}
{'x5', 'x2', 'x3', 'x9', 'x0'}
{'x5', 'x2', 'x3', 'x9',

In [28]:
print(n_correct / len(test_data))

0.156


In [18]:
tokenizer.decode(output[input_ids.shape[1]:], skip_special_tokens=True)

' x2 [STATE_END] [STATE_START] x2, x3 [STATE_END] [STATE_START] x2, x3, x4 [STATE_END] [STATES_END] [STATES_START] [STATE_START] x2, x3, x4 [STATE_END] [STATES_END] [STATES_START] [STATE_START] x2, x3, x4 [STATE_END] [STATES_END] [STATES_START] [STATES_END] [STATES_START] [STATES_END] [STATES_START] [STATES_END] [STATES_START] [STATES_END] [STATES_START] [STATES_END] [STATES_START] [STATES_END] [STATES_START] [STATES'

In [31]:
model.generate(train_tokenized_dataset[0]['input_ids'].to("cuda"))



IndexError: too many indices for tensor of dimension 2

In [32]:
trainer.save_model("gpt2_append_autoreg_str_results")

output_dir = os.path.join("gpt2_append_autoreg_str_results", "final_checkpoint")
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
