In [1]:
import os
import sys

import copy
import torch
import torch.nn as nn
from torch.utils.data import *
from transformers import Trainer, TrainingArguments
import inspect
sys.path.insert(0, "..")

from models import *
from logic import *
from my_datasets import *
from experiments import *

from utils import *
import numpy as np

# import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# One-shot synthetic task ("qed"): build the run configuration, then train.
# Every setting is passed by keyword, so grouping them onto shared lines does
# not change the call.
# NOTE(review): no seed is set anywhere visible in this notebook; whether
# AutoTrainer.for_synthetic seeds data generation / training is not shown
# here -- confirm before relying on these numbers being reproducible.
qed_args = SyntheticOneShotArguments(
    num_rules=8, num_vars=5,
    model_name="gpt2", embed_dim=1024, num_layers=12, num_heads=8,
    train_len=200, test_len=100,
    ante_prob=0.2, conseq_prob=0.2, theorem_prob=0.3,
    num_epochs=3,
)

qed_trainer = AutoTrainer.for_synthetic(qed_args)
# NOTE(review): in the recorded run eval_Accuracy stays at 0.6 and
# eval_AvgOnes at 1.0 for all three epochs -- possibly a degenerate
# (constant) predictor; worth checking against the label balance.
# Kept as the final expression so the notebook displays the TrainOutput.
qed_trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mantonxue[0m ([33mtransformer_friends[0m). Use [1m`wandb login --relogin`[0m to force relogin




{'loss': 2.113, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}
{'loss': 1.3705, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}
{'eval_loss': 0.8365097045898438, 'eval_Accuracy': 0.6, 'eval_AvgOnes': 1.0, 'eval_runtime': 1.1403, 'eval_samples_per_second': 87.696, 'eval_steps_per_second': 6.139, 'epoch': 1.0}
{'loss': 0.8274, 'learning_rate': 3.0769230769230774e-05, 'epoch': 1.15}
{'loss': 0.6879, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}
{'loss': 0.6216, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}
{'eval_loss': 0.8140811920166016, 'eval_Accuracy': 0.6, 'eval_AvgOnes': 1.0, 'eval_runtime': 1.1406, 'eval_samples_per_second': 87.676, 'eval_steps_per_second': 6.137, 'epoch': 2.0}
{'loss': 0.5802, 'learning_rate': 1.153846153846154e-05, 'epoch': 2.31}
{'loss': 0.6969, 'learning_rate': 5.128205128205128e-06, 'epoch': 2.69}
{'eval_loss': 0.6334375143051147, 'eval_Accuracy': 0.6, 'eval_AvgOnes': 1.0, 'eval_runtime': 1.1244, 'eval_samples_per_second': 88.9

TrainOutput(global_step=39, training_loss=0.9411235711513422, metrics={'train_runtime': 24.7111, 'train_samples_per_second': 24.281, 'train_steps_per_second': 1.578, 'train_loss': 0.9411235711513422, 'epoch': 3.0})

In [3]:
# Next-state synthetic task: same shape of run as the one-shot cell, with
# state_prob in place of theorem_prob. All settings are keyword arguments,
# so their line grouping below does not affect the call.
# NOTE(review): no seed is set anywhere visible in this notebook; confirm
# whether AutoTrainer.for_synthetic makes this run reproducible.
next_args = SyntheticNextStateArguments(
    num_rules=8, num_vars=5,
    model_name="gpt2", embed_dim=1024, num_layers=12, num_heads=8,
    train_len=200, test_len=100,
    ante_prob=0.2, conseq_prob=0.2, state_prob=0.3,
    num_epochs=3,
)

next_trainer = AutoTrainer.for_synthetic(next_args)
# NOTE(review): in the recorded run eval_Accuracy is 0.766 and eval_AvgOnes
# is 1.0 at every epoch while eval_loss falls -- possibly a constant
# predictor; verify before reading the accuracy as learned behaviour.
# Last expression of the cell, so the TrainOutput summary is rendered.
next_trainer.train()

{'loss': 0.8221, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}
{'loss': 0.6806, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}
{'eval_loss': 0.536252498626709, 'eval_Accuracy': 0.766, 'eval_AvgOnes': 1.0, 'eval_runtime': 1.1429, 'eval_samples_per_second': 87.494, 'eval_steps_per_second': 6.125, 'epoch': 1.0}
{'loss': 0.5766, 'learning_rate': 3.0769230769230774e-05, 'epoch': 1.15}
{'loss': 0.5427, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}
{'loss': 0.5568, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}
{'eval_loss': 0.5275095105171204, 'eval_Accuracy': 0.766, 'eval_AvgOnes': 1.0, 'eval_runtime': 1.1077, 'eval_samples_per_second': 90.275, 'eval_steps_per_second': 6.319, 'epoch': 2.0}
{'loss': 0.5336, 'learning_rate': 1.153846153846154e-05, 'epoch': 2.31}
{'loss': 0.5138, 'learning_rate': 5.128205128205128e-06, 'epoch': 2.69}
{'eval_loss': 0.5054101347923279, 'eval_Accuracy': 0.766, 'eval_AvgOnes': 1.0, 'eval_runtime': 1.0848, 'eval_samples_per_second'

TrainOutput(global_step=39, training_loss=0.5964804184742463, metrics={'train_runtime': 18.407, 'train_samples_per_second': 32.596, 'train_steps_per_second': 2.119, 'train_loss': 0.5964804184742463, 'epoch': 3.0})

In [4]:
# Auto-regressive k-steps synthetic task (name suggests multi-step rollout),
# configured with num_steps=3. Settings are all passed by keyword, so the
# compact grouping below is equivalent to one-per-line.
# NOTE(review): no seed is set anywhere visible in this notebook; confirm
# whether AutoTrainer.for_synthetic makes this run reproducible.
ars_args = SyntheticAutoRegKStepsArguments(
    num_rules=8, num_vars=5, num_steps=3,
    model_name="gpt2", embed_dim=1024, num_layers=12, num_heads=8,
    train_len=200, test_len=100,
    ante_prob=0.2, conseq_prob=0.2, state_prob=0.3,
    num_epochs=3,
)

ars_trainer = AutoTrainer.for_synthetic(ars_args)
# NOTE(review): the recorded run logs eval_Accuracy 0.8027 and eval_AvgOnes
# 1.0 at every epoch with an essentially flat eval_loss -- likely no
# learning signal; investigate before comparing against the other tasks.
# Final expression: lets the notebook display the TrainOutput.
ars_trainer.train()

{'loss': 0.5186, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}
{'loss': 0.5466, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}
{'eval_loss': 0.5106473565101624, 'eval_Accuracy': 0.8026666666666666, 'eval_AvgOnes': 1.0, 'eval_runtime': 2.1845, 'eval_samples_per_second': 45.777, 'eval_steps_per_second': 3.204, 'epoch': 1.0}
{'loss': 0.5191, 'learning_rate': 3.0769230769230774e-05, 'epoch': 1.15}
{'loss': 0.5225, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}
{'loss': 0.5333, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}
{'eval_loss': 0.5106407403945923, 'eval_Accuracy': 0.8026666666666666, 'eval_AvgOnes': 1.0, 'eval_runtime': 2.1758, 'eval_samples_per_second': 45.961, 'eval_steps_per_second': 3.217, 'epoch': 2.0}
{'loss': 0.5158, 'learning_rate': 1.153846153846154e-05, 'epoch': 2.31}
{'loss': 0.525, 'learning_rate': 5.128205128205128e-06, 'epoch': 2.69}
{'eval_loss': 0.510637640953064, 'eval_Accuracy': 0.8026666666666666, 'eval_AvgOnes': 1.0, 'eval_runti

TrainOutput(global_step=39, training_loss=0.5250916114220252, metrics={'train_runtime': 37.2634, 'train_samples_per_second': 16.102, 'train_steps_per_second': 1.047, 'train_loss': 0.5250916114220252, 'epoch': 3.0})