In [1]:
import sys

import copy
import torch
import torch.nn as nn
from torch.utils.data import *
from transformers import Trainer, TrainingArguments
import inspect
sys.path.insert(0, "..")

from models import *
from logic import *
from my_datasets import *

from utils import *
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Problem size: `n` is handed to the models as num_vars; `r` is the first
# positional argument of every dataset config.
n = 8
r = 20

# Dataset-generation parameters, forwarded positionally to the dataset configs
# (presumably probabilities -- TODO confirm against my_datasets).
# Alternative setting: ap, bp, tp, sp = 0.2, 0.2, 0.4, 0.1
ap = 0.2
bp = 0.1
tp = 0.2
sp = 0.2

# Step count used for the "autoreg_fixed_steps" task.
nars = 3

# Number of samples in the train / test datasets.
train_len = 100
test_len = 100

# Epochs per Trainer run.
num_epochs = 2

# When True, the trainers evaluate on the training set instead of the test set.
test_is_train = False

In [3]:
# Hyper-parameters shared by every Trainer in this notebook.
qed_training_args = TrainingArguments(
    output_dir="test-trainer",
    num_train_epochs=num_epochs,
    evaluation_strategy="epoch",  # run evaluation at the end of each epoch
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    logging_steps=5,
    report_to="none",  # do not report to any external tracking integration
)

# NOTE: these are aliases of the very same object (not copies), so all three
# tasks share one configuration and one output directory.
succ_training_args = ars_training_args = qed_training_args

In [4]:
### MyTF models, one per task.
qed_mytf = get_task_model(task_name="oneshot_qed", num_vars=n, model_name="mytf")
succ_mytf = get_task_model(task_name="predict_successor", num_vars=n, model_name="mytf")
# Use the shared `nars` setting instead of a literal 3 so this model always
# matches the step count used by the GPT-2/BERT models and the ars_* datasets.
ars_mytf = get_task_model(task_name="autoreg_fixed_steps", num_vars=n, num_steps=nars, model_name="mytf")

In [5]:
### GPT-2 models, one per task.
qed_gpt2 = get_task_model(
    model_name="gpt2",
    task_name="oneshot_qed",
    num_vars=n,
)
succ_gpt2 = get_task_model(
    model_name="gpt2",
    task_name="predict_successor",
    num_vars=n,
)
# The autoregressive task additionally needs the number of steps.
ars_gpt2 = get_task_model(
    model_name="gpt2",
    task_name="autoreg_fixed_steps",
    num_vars=n,
    num_steps=nars,
)

In [6]:
### BERT models, one per task.
qed_bert = get_task_model(
    model_name="bert",
    task_name="oneshot_qed",
    num_vars=n,
)
succ_bert = get_task_model(
    model_name="bert",
    task_name="predict_successor",
    num_vars=n,
)
# The autoregressive task additionally needs the number of steps.
ars_bert = get_task_model(
    model_name="bert",
    task_name="autoreg_fixed_steps",
    num_vars=n,
    num_steps=nars,
)

In [7]:
### Llama models are disabled: they are large enough to exhaust RAM. Uncomment the lines below to enable them.
# qed_llama = get_task_model(task_name="oneshot_qed", num_vars=n, model_name="code_llama", num_layers=8)
# succ_llama = get_task_model(task_name="predict_successor", num_vars=n, model_name="code_llama", num_layers=8)
# ars_llama = get_task_model(task_name="autoreg_fixed_steps", num_vars=n, num_steps=nars, model_name="code_llama", num_layers=8)

In [8]:
### Datasets
# For every task the train and test configs differ only in dataset_len and seed.

# --- One-shot QED ---
qed_train_dataset_config = OneShotQedDatasetConfig(
    r, n, ap, bp, tp,
    dataset_len=train_len,
    seed=1234,
)
qed_test_dataset_config = OneShotQedDatasetConfig(
    r, n, ap, bp, tp,
    dataset_len=test_len,
    seed=2345,
)
qed_train_dataset = OneShotQedEmbedsDataset(qed_train_dataset_config)
qed_test_dataset = OneShotQedEmbedsDataset(qed_test_dataset_config)

# --- One-step successor prediction ---
succ_train_dataset_config = OneStepStateDatasetConfig(
    r, n, ap, bp, tp,
    dataset_len=train_len,
    seed=1234,
)
succ_test_dataset_config = OneStepStateDatasetConfig(
    r, n, ap, bp, tp,
    dataset_len=test_len,
    seed=2345,
)
succ_train_dataset = OneStepStateEmbedsDataset(succ_train_dataset_config)
succ_test_dataset = OneStepStateEmbedsDataset(succ_test_dataset_config)

# --- Autoregressive fixed steps ---
# Unlike the two configs above, this one is given `sp` (not `tp`) plus `nars`.
ars_train_dataset_config = AutoRegFixedStepsDatasetConfig(
    r, n, ap, bp, sp, nars,
    dataset_len=train_len,
    seed=1234,
)
ars_test_dataset_config = AutoRegFixedStepsDatasetConfig(
    r, n, ap, bp, sp, nars,
    dataset_len=test_len,
    seed=2345,
)
ars_train_dataset = AutoRegFixedStepsEmbedsDataset(ars_train_dataset_config)
ars_test_dataset = AutoRegFixedStepsEmbedsDataset(ars_test_dataset_config)

In [9]:
### SUCC MyTF
trainer_succ_mytf = Trainer(succ_mytf, succ_training_args,
    train_dataset = succ_train_dataset,
    eval_dataset = succ_train_dataset if test_is_train else succ_test_dataset,
    compute_metrics = succ_compute_metrics)
trainer_succ_mytf.train()

{'eval_loss': 0.7189112901687622, 'eval_Accuracy': 0.57625, 'eval_Avg Ones': 0.875, 'eval_runtime': 5.0392, 'eval_samples_per_second': 19.844, 'eval_steps_per_second': 0.397, 'epoch': 1.0}
{'eval_loss': 0.6944872140884399, 'eval_Accuracy': 0.57625, 'eval_Avg Ones': 0.875, 'eval_runtime': 4.7431, 'eval_samples_per_second': 21.083, 'eval_steps_per_second': 0.422, 'epoch': 2.0}
{'train_runtime': 32.1865, 'train_samples_per_second': 6.214, 'train_steps_per_second': 0.124, 'train_loss': 0.7037481069564819, 'epoch': 2.0}


TrainOutput(global_step=4, training_loss=0.7037481069564819, metrics={'train_runtime': 32.1865, 'train_samples_per_second': 6.214, 'train_steps_per_second': 0.124, 'train_loss': 0.7037481069564819, 'epoch': 2.0})

In [None]:
### SUCC GPT2: train the GPT-2 model on the "predict_successor" task.
trainer_succ_gpt2 = Trainer(
    model=succ_gpt2,
    args=succ_training_args,
    train_dataset=succ_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=succ_test_dataset if not test_is_train else succ_train_dataset,
    compute_metrics=succ_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_succ_gpt2.train()

In [None]:
### SUCC Bert: train the BERT model on the "predict_successor" task.
trainer_succ_bert = Trainer(
    model=succ_bert,
    args=succ_training_args,
    train_dataset=succ_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=succ_test_dataset if not test_is_train else succ_train_dataset,
    compute_metrics=succ_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_succ_bert.train()

In [None]:
####################
### AR Steps: the "autoreg_fixed_steps" (autoregressive, fixed number of steps) task

In [None]:
### AR Steps MyTF: train the MyTF model on the "autoreg_fixed_steps" task.
# NOTE(review): metrics come from succ_compute_metrics, the same function the
# SUCC trainers use -- confirm this is intended for the autoregressive task.
trainer_ars_mytf = Trainer(
    model=ars_mytf,
    args=ars_training_args,
    train_dataset=ars_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=ars_test_dataset if not test_is_train else ars_train_dataset,
    compute_metrics=succ_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_ars_mytf.train()

In [None]:
### AR Steps GPT2: train the GPT-2 model on the "autoreg_fixed_steps" task.
# NOTE(review): metrics come from succ_compute_metrics, the same function the
# SUCC trainers use -- confirm this is intended for the autoregressive task.
trainer_ars_gpt2 = Trainer(
    model=ars_gpt2,
    args=ars_training_args,
    train_dataset=ars_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=ars_test_dataset if not test_is_train else ars_train_dataset,
    compute_metrics=succ_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_ars_gpt2.train()

In [None]:
### AR Steps Bert: train the BERT model on the "autoreg_fixed_steps" task.
# NOTE(review): metrics come from succ_compute_metrics, the same function the
# SUCC trainers use -- confirm this is intended for the autoregressive task.
trainer_ars_bert = Trainer(
    model=ars_bert,
    args=ars_training_args,
    train_dataset=ars_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=ars_test_dataset if not test_is_train else ars_train_dataset,
    compute_metrics=succ_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_ars_bert.train()

In [None]:
####################
### QED: the "oneshot_qed" task

In [None]:
### QED MyTF: train the MyTF model on the "oneshot_qed" task.
trainer_qed_mytf = Trainer(
    model=qed_mytf,
    args=qed_training_args,
    train_dataset=qed_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=qed_test_dataset if not test_is_train else qed_train_dataset,
    compute_metrics=qed_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_qed_mytf.train()

In [None]:
### QED GPT2: train the GPT-2 model on the "oneshot_qed" task.
trainer_qed_gpt2 = Trainer(
    model=qed_gpt2,
    args=qed_training_args,
    train_dataset=qed_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=qed_test_dataset if not test_is_train else qed_train_dataset,
    compute_metrics=qed_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_qed_gpt2.train()

In [None]:
### QED Bert: train the BERT model on the "oneshot_qed" task.
trainer_qed_bert = Trainer(
    model=qed_bert,
    args=qed_training_args,
    train_dataset=qed_train_dataset,
    # Evaluate on the test set unless test_is_train asks for the training set.
    eval_dataset=qed_test_dataset if not test_is_train else qed_train_dataset,
    compute_metrics=qed_compute_metrics,
)
# Last expression of the cell, so the returned TrainOutput is displayed.
trainer_qed_bert.train()