In [1]:
from pprint import pprint
import gc

In [2]:
import torch
from disco_gp import DiscoGPTransformer, Config, set_seed

In [3]:
# Seed the RNGs before any model or mask construction (the model is built
# further down) so the search is repeatable across runs.
# NOTE(review): assumes disco_gp.set_seed seeds Python/NumPy/torch — confirm.
SEED = 42
set_seed(SEED)

In [4]:
# Hyperparameters for the weight-mask search (use_weight_masks=True).
# NOTE(review): field semantics live in disco_gp.Config; the inline glosses
# below are inferred from field names — confirm against the library.
weight_hparams = Config(
    use_weight_masks=True,
    gs_temp_weight=0.01,  # presumably the Gumbel-sigmoid temperature for weight masks
    logits_w_init=1.0,  # presumably the initial value of the weight-mask logits
    lr=0.1,
    lambda_sparse_init=1.0,
    lambda_complete_init=1.0,
    min_times_lambda_sparse=1.0,  # presumably lower multiplier bound on lambda_sparse
    max_times_lambda_sparse=1000.0,  # presumably upper multiplier bound on lambda_sparse
    train_epochs=500,
    # NOTE(review): the warmup (500) equals train_epochs (500) and the cooldown is 1,
    # so the sparsity schedule may never complete within the run — confirm intended.
    n_epoch_warmup_lambda_sparse=500,
    n_epoch_cooldown_lambda_sparse=1,
)

In [5]:
# Hyperparameters for the edge-mask search (use_edge_masks=True).
# NOTE(review): field semantics live in disco_gp.Config; the inline glosses
# below are inferred from field names — confirm against the library.
edge_hparams = Config(
    use_edge_masks=True,
    gs_temp_edge=1.0,  # presumably the Gumbel-sigmoid temperature for edge masks
    logits_e_init=1.0,  # presumably the initial value of the edge-mask logits
    lr=0.1,
    lambda_sparse_init=1.0,
    # NOTE(review): 0.0 here vs 1.0 for the weight search — presumably this
    # disables the completeness term for edges; confirm intended.
    lambda_complete_init=0.0,
    min_times_lambda_sparse=0.01,
    max_times_lambda_sparse=100.0,
    train_epochs=100,
    # 20 warmup + 20 cooldown epochs out of 100 total training epochs.
    n_epoch_warmup_lambda_sparse=20,
    n_epoch_cooldown_lambda_sparse=20,
)

In [6]:
# ioi task (presumably indirect object identification).
# 1000 examples split 0.8/0.1/0.1 -> 800 / 100 / 100, matching the `total`
# counts reported for train / eval / test by evaluate_and_report.

task_cfg = Config(
    task_type="ioi",
    n_ioi_data=1000,
    batch_size=64,
    ds_split_ratios=(0.8, 0.1, 0.1),  # presumably (train, eval, test)
)

In [7]:
# Experiment bookkeeping: evaluation cadence, output location and run name.
exp_cfg = Config(
    evaluate_every=1,  # presumably evaluates every epoch — can be costly on long runs
    output_dir_path="./outputs",
    exp_name="quickstart",
)

In [8]:
# Model config for GPT-2, built via Config.from_tl (the loader output reports a
# HookedTransformer), with parameters held in bfloat16.
model_cfg = Config.from_tl(
    "gpt2",
    dtype=torch.bfloat16,
)

In [9]:
# Merge the per-section configs into the single config consumed by the model.
cfg = Config.from_configs(
    weight=weight_hparams,
    edge=edge_hparams,
    task=task_cfg,
    model=model_cfg,
    exp=exp_cfg,
)

In [10]:
# Build a DiscoGPTransformer from the merged cfg with pretrained GPT-2 weights.
# NOTE(review): the `torch_dtype` deprecation warning in the output appears to be
# emitted inside the library's loader (this notebook already passes `dtype=` in
# model_cfg) — needs an upstream fix, not a notebook change.
model = DiscoGPTransformer.from_pretrained(cfg)

`torch_dtype` is deprecated! Use `dtype` instead!


cfg name: gpt2
Loaded pretrained model gpt2 into HookedTransformer


In [11]:
# Prepare the experiment before evaluation/search.
# NOTE(review): presumably builds the train/eval/test datasets (reported by the
# next cell) and mask state from cfg — confirm in disco_gp.
model.setup_experiment()

In [12]:
# Baseline: evaluate the unpruned model (edge/weight density 1.0) on every split.
model.evaluate_and_report(epoch=0, mode="baseline")



{'eval': {'acc': 0.82,
          'comp': 0.54,
          'edge_density': 1.0,
          'epoch': 0,
          'faith_loss': 0.37890625,
          'kl': 0.0,
          'n_correct': 82,
          'prune_mode': 'baseline',
          'total': 100,
          'weight_density': 1.0},
 'test': {'acc': 0.88,
          'comp': 0.52,
          'edge_density': 1.0,
          'epoch': 0,
          'faith_loss': 0.3046875,
          'kl': 0.0,
          'n_correct': 88,
          'prune_mode': 'baseline',
          'total': 100,
          'weight_density': 1.0},
 'train': {'acc': 0.85,
           'comp': 0.47875,
           'edge_density': 1.0,
           'epoch': 0,
           'faith_loss': 0.388671875,
           'kl': 0.0,
           'n_correct': 680,
           'prune_mode': 'baseline',
           'total': 800,
           'weight_density': 1.0}}


In [None]:
# Run the mask search using the weight/edge hyperparameters in cfg.
# NOTE(review): this cell is saved as `In [None]` while the next cell is
# `In [14]`, so the saved outputs may not come from one top-to-bottom run —
# use Restart & Run All before sharing results.
model.search()

In [14]:
# Evaluate the pruned circuit; compare acc / kl / densities with the baseline.
model.evaluate_and_report(epoch="final", mode="pruned")



{'eval': {'acc': 0.55,
          'comp': 0.52,
          'edge_density': 0.04902896285057068,
          'epoch': 'final',
          'faith_loss': 1.03125,
          'kl': 0.5703125,
          'n_correct': 55,
          'prune_mode': 'pruned',
          'total': 100,
          'weight_density': 0.05828766152262688},
 'test': {'acc': 0.52,
          'comp': 0.46,
          'edge_density': 0.04902896285057068,
          'epoch': 'final',
          'faith_loss': 1.2421875,
          'kl': 0.640625,
          'n_correct': 52,
          'prune_mode': 'pruned',
          'total': 100,
          'weight_density': 0.05828766152262688},
 'train': {'acc': 0.5,
           'comp': 0.47625,
           'edge_density': 0.04902896285057068,
           'epoch': 'final',
           'faith_loss': 1.3671875,
           'kl': 0.76953125,
           'n_correct': 400,
           'prune_mode': 'pruned',
           'total': 800,
           'weight_density': 0.05828766152262688}}
