In [1]:
import os
import sys

import copy
import torch
import torch.nn as nn
from torch.utils.data import *
from transformers import Trainer, TrainingArguments
import inspect
sys.path.insert(0, "..")

from models import *
from logic import *
from my_datasets import *

from utils import *
import numpy as np

# import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# # set the wandb project where this run will be logged
# os.environ["WANDB_PROJECT"] = "transformer_friends"
# os.environ["WANDB_LOG_MODEL"] = "checkpoint" # log all model checkpoints

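# # save your trained model checkpoint to wandb
# # NOTE: this overrides the WANDB_LOG_MODEL="checkpoint" setting above; enable only one of the two.
# os.environ["WANDB_LOG_MODEL"] = "true"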

# # turn off watch to log faster
# os.environ["WANDB_WATCH"] = "false"

In [3]:
# Experiment configuration.
# n is passed as `num_vars` to every model and dataset; r is the first positional
# argument of every dataset constructor (presumably the number of rules — TODO confirm).
n, r = 8, 20
# Dataset parameters, passed positionally: the QED and successor datasets take
# (ap, bp, tp), the autoregressive dataset takes (ap, bp, sp).
# ap, bp, tp, sp = 0.2, 0.2, 0.4, 0.1
ap, bp, tp, sp = 0.2, 0.1, 0.2, 0.2

# Number of autoregressive steps; used by the AR datasets (and should match the AR models).
nars = 3

train_len = 100
test_len = 100
num_epochs = 2
# When True, evaluation runs on the training split instead of the held-out test split.
test_is_train = False

# Seed torch's global RNG so that model weight initialization in the following
# cells is reproducible. Trainer seeds from TrainingArguments.seed only when it is
# constructed, which happens after the models have already been built.
SEED = 42
torch.manual_seed(SEED);

In [4]:
# Shared Hugging Face TrainingArguments for all three tasks.
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases (the old keyword is deprecated and later removed), so pass whichever
# keyword the installed version accepts.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

qed_training_args = TrainingArguments(
    "test-trainer",                    # output_dir
    num_train_epochs = num_epochs,
    per_device_train_batch_size = 64,
    per_device_eval_batch_size = 64,
    logging_steps = 5,
    # report_to = "all",
    report_to = "none",
    **{_eval_strategy_kw: "epoch"},    # evaluate at the end of every epoch
)

# NOTE: these are aliases of the same object, not copies — changing an attribute
# on one changes it for all three tasks.
succ_training_args = qed_training_args
ars_training_args = qed_training_args

In [5]:
# MyTF models for the three tasks. The AR model's num_steps is tied to `nars`
# so it stays in sync with the number of steps used by the AR datasets.
qed_mytf = AutoTaskModel.from_kwargs("one_shot_qed", "mytf", num_vars=n, num_layers=8)
succ_mytf = AutoTaskModel.from_kwargs("one_step_state", "mytf", num_vars=n, num_layers=8)
ars_mytf = AutoTaskModel.from_kwargs("autoreg_fixed_steps", "mytf", num_vars=n, num_steps=nars, num_layers=8)

In [6]:
# GPT-2 models for the three tasks, all sharing one architecture configuration.
_gpt2_kwargs = dict(num_vars=n, num_layers=8, num_heads=8, embed_dim=1024)
qed_gpt2 = AutoTaskModel.from_kwargs("one_shot_qed", "gpt2", **_gpt2_kwargs)
succ_gpt2 = AutoTaskModel.from_kwargs("one_step_state", "gpt2", **_gpt2_kwargs)
# num_steps is tied to `nars` so the model matches the AR datasets.
ars_gpt2 = AutoTaskModel.from_kwargs("autoreg_fixed_steps", "gpt2", num_steps=nars, **_gpt2_kwargs)

In [7]:
# BERT models for the three tasks, all sharing one architecture configuration.
_bert_kwargs = dict(num_vars=n, num_layers=8, num_heads=8, embed_dim=1024)
qed_bert = AutoTaskModel.from_kwargs("one_shot_qed", "bert", **_bert_kwargs)
succ_bert = AutoTaskModel.from_kwargs("one_step_state", "bert", **_bert_kwargs)
# num_steps is tied to `nars` so the model matches the AR datasets.
ars_bert = AutoTaskModel.from_kwargs("autoreg_fixed_steps", "bert", num_steps=nars, **_bert_kwargs)

In [8]:
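### Llama is too large and exhausts RAM, so these models are disabled.
### NOTE: the lines below use `get_task_model` and task names ("oneshot_qed", "predict_successor")
### that differ from the `AutoTaskModel.from_kwargs` calls used above; update them before re-enabling.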
# qed_llama = get_task_model(task_name="oneshot_qed", num_vars=n, model_name="code_llama", num_layers=8)
# succ_llama = get_task_model(task_name="predict_successor", num_vars=n, model_name="code_llama", num_layers=8)
# ars_llama = get_task_model(task_name="autoreg_fixed_steps", num_vars=n, num_steps=nars, model_name="code_llama", num_layers=8)

In [9]:
### Datasets — each train/test pair uses identical parameters and differs only in length and seed
_train_seed, _test_seed = 1234, 2345

qed_train_dataset = OneShotQedEmbedsDataset(r, n, ap, bp, tp, dataset_len=train_len, seed=_train_seed)
qed_test_dataset = OneShotQedEmbedsDataset(r, n, ap, bp, tp, dataset_len=test_len, seed=_test_seed)

succ_train_dataset = OneStepStateEmbedsDataset(r, n, ap, bp, tp, train_len, seed=_train_seed)
succ_test_dataset = OneStepStateEmbedsDataset(r, n, ap, bp, tp, test_len, seed=_test_seed)

# NOTE(review): the AR dataset receives `sp` where the two datasets above receive `tp` — confirm intended.
ars_train_dataset = AutoRegFixedStepsEmbedsDataset(r, n, nars, ap, bp, sp, train_len, seed=_train_seed)
ars_test_dataset = AutoRegFixedStepsEmbedsDataset(r, n, nars, ap, bp, sp, test_len, seed=_test_seed)

In [10]:
### SUCC MyTF — one-step successor task
trainer_succ_mytf = Trainer(
    model=succ_mytf,
    args=succ_training_args,
    train_dataset=succ_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(succ_train_dataset if test_is_train else succ_test_dataset),
    compute_metrics=succ_compute_metrics,
)
trainer_succ_mytf.train()



{'eval_loss': 0.6984438300132751, 'eval_Accuracy': 0.56125, 'eval_Avg Ones': 0.75, 'eval_runtime': 0.0401, 'eval_samples_per_second': 2492.011, 'eval_steps_per_second': 24.92, 'epoch': 1.0}
{'eval_loss': 0.6963896751403809, 'eval_Accuracy': 0.58375, 'eval_Avg Ones': 1.0, 'eval_runtime': 0.0461, 'eval_samples_per_second': 2167.386, 'eval_steps_per_second': 21.674, 'epoch': 2.0}
{'train_runtime': 2.4164, 'train_samples_per_second': 82.768, 'train_steps_per_second': 0.828, 'train_loss': 0.7030345797538757, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=0.7030345797538757, metrics={'train_runtime': 2.4164, 'train_samples_per_second': 82.768, 'train_steps_per_second': 0.828, 'train_loss': 0.7030345797538757, 'epoch': 2.0})

In [11]:
### SUCC GPT2 — one-step successor task
trainer_succ_gpt2 = Trainer(
    model=succ_gpt2,
    args=succ_training_args,
    train_dataset=succ_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(succ_train_dataset if test_is_train else succ_test_dataset),
    compute_metrics=succ_compute_metrics,
)
trainer_succ_gpt2.train()

{'eval_loss': 1.0251109600067139, 'eval_Accuracy': 0.54625, 'eval_Avg Ones': 0.75, 'eval_runtime': 0.456, 'eval_samples_per_second': 219.286, 'eval_steps_per_second': 2.193, 'epoch': 1.0}
{'eval_loss': 0.9728716015815735, 'eval_Accuracy': 0.57625, 'eval_Avg Ones': 0.8125, 'eval_runtime': 0.4594, 'eval_samples_per_second': 217.654, 'eval_steps_per_second': 2.177, 'epoch': 2.0}
{'train_runtime': 3.0947, 'train_samples_per_second': 64.627, 'train_steps_per_second': 0.646, 'train_loss': 0.8405940532684326, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=0.8405940532684326, metrics={'train_runtime': 3.0947, 'train_samples_per_second': 64.627, 'train_steps_per_second': 0.646, 'train_loss': 0.8405940532684326, 'epoch': 2.0})

In [12]:
### SUCC Bert — one-step successor task
trainer_succ_bert = Trainer(
    model=succ_bert,
    args=succ_training_args,
    train_dataset=succ_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(succ_train_dataset if test_is_train else succ_test_dataset),
    compute_metrics=succ_compute_metrics,
)
trainer_succ_bert.train()

{'eval_loss': 0.8449962735176086, 'eval_Accuracy': 0.57625, 'eval_Avg Ones': 0.875, 'eval_runtime': 0.0688, 'eval_samples_per_second': 1453.237, 'eval_steps_per_second': 14.532, 'epoch': 1.0}
{'eval_loss': 0.7907512187957764, 'eval_Accuracy': 0.57625, 'eval_Avg Ones': 0.875, 'eval_runtime': 0.074, 'eval_samples_per_second': 1351.166, 'eval_steps_per_second': 13.512, 'epoch': 2.0}
{'train_runtime': 0.4287, 'train_samples_per_second': 466.554, 'train_steps_per_second': 4.666, 'train_loss': 0.7477405071258545, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=0.7477405071258545, metrics={'train_runtime': 0.4287, 'train_samples_per_second': 466.554, 'train_steps_per_second': 4.666, 'train_loss': 0.7477405071258545, 'epoch': 2.0})

In [13]:
####################
### AR Steps

In [14]:
### AR Steps MyTF — autoregressive fixed-steps task
trainer_ars_mytf = Trainer(
    model=ars_mytf,
    args=ars_training_args,
    train_dataset=ars_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(ars_train_dataset if test_is_train else ars_test_dataset),
    # NOTE(review): reuses the successor-task metrics for the AR task — confirm intended.
    compute_metrics=succ_compute_metrics,
)
trainer_ars_mytf.train()

{'eval_loss': 0.6328577995300293, 'eval_Accuracy': 0.60625, 'eval_Avg Ones': 0.75, 'eval_runtime': 0.0699, 'eval_samples_per_second': 1429.883, 'eval_steps_per_second': 14.299, 'epoch': 1.0}
{'eval_loss': 0.6325845122337341, 'eval_Accuracy': 0.60625, 'eval_Avg Ones': 0.75, 'eval_runtime': 0.0683, 'eval_samples_per_second': 1464.052, 'eval_steps_per_second': 14.641, 'epoch': 2.0}
{'train_runtime': 0.4304, 'train_samples_per_second': 464.662, 'train_steps_per_second': 4.647, 'train_loss': 0.6299412846565247, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=0.6299412846565247, metrics={'train_runtime': 0.4304, 'train_samples_per_second': 464.662, 'train_steps_per_second': 4.647, 'train_loss': 0.6299412846565247, 'epoch': 2.0})

In [15]:
### AR Steps GPT2 — autoregressive fixed-steps task
trainer_ars_gpt2 = Trainer(
    model=ars_gpt2,
    args=ars_training_args,
    train_dataset=ars_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(ars_train_dataset if test_is_train else ars_test_dataset),
    # NOTE(review): reuses the successor-task metrics for the AR task — confirm intended.
    compute_metrics=succ_compute_metrics,
)
trainer_ars_gpt2.train()

{'eval_loss': 0.626973569393158, 'eval_Accuracy': 0.6466666666666666, 'eval_Avg Ones': 0.81125, 'eval_runtime': 1.5751, 'eval_samples_per_second': 63.488, 'eval_steps_per_second': 0.635, 'epoch': 1.0}
{'eval_loss': 0.626487672328949, 'eval_Accuracy': 0.6341666666666667, 'eval_Avg Ones': 0.7920833333333334, 'eval_runtime': 1.2255, 'eval_samples_per_second': 81.6, 'eval_steps_per_second': 0.816, 'epoch': 2.0}
{'train_runtime': 9.5967, 'train_samples_per_second': 20.841, 'train_steps_per_second': 0.208, 'train_loss': 0.630977988243103, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=0.630977988243103, metrics={'train_runtime': 9.5967, 'train_samples_per_second': 20.841, 'train_steps_per_second': 0.208, 'train_loss': 0.630977988243103, 'epoch': 2.0})

In [16]:
### AR Steps Bert — autoregressive fixed-steps task
trainer_ars_bert = Trainer(
    model=ars_bert,
    args=ars_training_args,
    train_dataset=ars_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(ars_train_dataset if test_is_train else ars_test_dataset),
    # NOTE(review): reuses the successor-task metrics for the AR task — confirm intended.
    compute_metrics=succ_compute_metrics,
)
trainer_ars_bert.train()

{'eval_loss': 0.6326218247413635, 'eval_Accuracy': 0.64375, 'eval_Avg Ones': 0.875, 'eval_runtime': 0.1091, 'eval_samples_per_second': 916.285, 'eval_steps_per_second': 9.163, 'epoch': 1.0}
{'eval_loss': 0.6330364942550659, 'eval_Accuracy': 0.63875, 'eval_Avg Ones': 0.8033333333333333, 'eval_runtime': 0.1112, 'eval_samples_per_second': 899.322, 'eval_steps_per_second': 8.993, 'epoch': 2.0}
{'train_runtime': 0.6885, 'train_samples_per_second': 290.488, 'train_steps_per_second': 2.905, 'train_loss': 0.6294896602630615, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=0.6294896602630615, metrics={'train_runtime': 0.6885, 'train_samples_per_second': 290.488, 'train_steps_per_second': 2.905, 'train_loss': 0.6294896602630615, 'epoch': 2.0})

In [17]:
####################
### QED

In [18]:
### QED MyTf — one-shot QED task
trainer_qed_mytf = Trainer(
    model=qed_mytf,
    args=qed_training_args,
    train_dataset=qed_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(qed_train_dataset if test_is_train else qed_test_dataset),
    compute_metrics=qed_compute_metrics,
)
trainer_qed_mytf.train()

{'eval_loss': 1.586868166923523, 'eval_Accuracy': 0.51, 'eval_Avg Ones': 1.0, 'eval_runtime': 0.0759, 'eval_samples_per_second': 1318.006, 'eval_steps_per_second': 13.18, 'epoch': 1.0}
{'eval_loss': 1.505886435508728, 'eval_Accuracy': 0.51, 'eval_Avg Ones': 1.0, 'eval_runtime': 0.0764, 'eval_samples_per_second': 1308.169, 'eval_steps_per_second': 13.082, 'epoch': 2.0}
{'train_runtime': 0.3791, 'train_samples_per_second': 527.603, 'train_steps_per_second': 5.276, 'train_loss': 1.0853276252746582, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=1.0853276252746582, metrics={'train_runtime': 0.3791, 'train_samples_per_second': 527.603, 'train_steps_per_second': 5.276, 'train_loss': 1.0853276252746582, 'epoch': 2.0})

In [19]:
### QED GPT2 — one-shot QED task
trainer_qed_gpt2 = Trainer(
    model=qed_gpt2,
    args=qed_training_args,
    train_dataset=qed_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(qed_train_dataset if test_is_train else qed_test_dataset),
    compute_metrics=qed_compute_metrics,
)
trainer_qed_gpt2.train()

{'eval_loss': 5.059472560882568, 'eval_Accuracy': 0.49, 'eval_Avg Ones': 0.0, 'eval_runtime': 0.4848, 'eval_samples_per_second': 206.288, 'eval_steps_per_second': 2.063, 'epoch': 1.0}
{'eval_loss': 4.66595458984375, 'eval_Accuracy': 0.49, 'eval_Avg Ones': 0.0, 'eval_runtime': 0.4906, 'eval_samples_per_second': 203.812, 'eval_steps_per_second': 2.038, 'epoch': 2.0}
{'train_runtime': 3.1595, 'train_samples_per_second': 63.302, 'train_steps_per_second': 0.633, 'train_loss': 3.046811580657959, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=3.046811580657959, metrics={'train_runtime': 3.1595, 'train_samples_per_second': 63.302, 'train_steps_per_second': 0.633, 'train_loss': 3.046811580657959, 'epoch': 2.0})

In [20]:
### QED Bert — one-shot QED task
trainer_qed_bert = Trainer(
    model=qed_bert,
    args=qed_training_args,
    train_dataset=qed_train_dataset,
    # evaluate on the training split only when test_is_train is set
    eval_dataset=(qed_train_dataset if test_is_train else qed_test_dataset),
    compute_metrics=qed_compute_metrics,
)
trainer_qed_bert.train()

{'eval_loss': 3.074585199356079, 'eval_Accuracy': 0.49, 'eval_Avg Ones': 0.0, 'eval_runtime': 0.106, 'eval_samples_per_second': 943.286, 'eval_steps_per_second': 9.433, 'epoch': 1.0}
{'eval_loss': 2.505542516708374, 'eval_Accuracy': 0.49, 'eval_Avg Ones': 0.0, 'eval_runtime': 0.1135, 'eval_samples_per_second': 881.253, 'eval_steps_per_second': 8.813, 'epoch': 2.0}
{'train_runtime': 0.5792, 'train_samples_per_second': 345.331, 'train_steps_per_second': 3.453, 'train_loss': 1.950005292892456, 'epoch': 2.0}


TrainOutput(global_step=2, training_loss=1.950005292892456, metrics={'train_runtime': 0.5792, 'train_samples_per_second': 345.331, 'train_steps_per_second': 3.453, 'train_loss': 1.950005292892456, 'epoch': 2.0})