In [1]:
%load_ext autoreload
%autoreload 2
from train import Trainer, TrainArgs
from model import ModelArgs, ModelArgsIterator
from dataset import BalancedParenthesisDataGenerator, MaxValueDataGenerator
from backdoor_dataset import BackdoorFactory, ReverseLabelModifier, StartingNumberTrigger, \
    StartingNumberForBalancedParenthesisTrigger

## Parenthesis Balancer

### Training Final Models

In [2]:
# Configure a factory that combines the balanced-parenthesis generator with a
# single trigger class and a single label-modifier class, then ask it for the
# resulting data-generator class. Later cells instantiate that class with
# `n_ctx_numeric=20`, just like the plain generator.
backdoor_factory = BackdoorFactory(
    data_gen_cls=BalancedParenthesisDataGenerator,
    trigger_cls_list=[StartingNumberForBalancedParenthesisTrigger],
    label_mod_cls_list=[ReverseLabelModifier],
)
BackdoorDataGen = backdoor_factory.create_backdoor_data_generator_class()

In [None]:
# Train and save the two "final" models: one on the plain balanced-parenthesis
# generator and one on its backdoored counterpart. Both runs use the same
# model architecture and the same training configuration.
train_args_full = TrainArgs(epochs=15, batch_size=1024)
model_args = ModelArgs(n_layers=2, n_heads=2, d_model=32)

data_gen = BalancedParenthesisDataGenerator(n_ctx_numeric=20)
data_gen_backdoor = BackdoorDataGen(n_ctx_numeric=20)

# (task name, data generator) pairs, trained in this order. After the loop,
# `trainer` refers to the backdoored run, as it did when the two runs were
# written out one after the other.
final_runs = [
    ('bal_paren_20', data_gen),
    ('bal_paren_20_bdoor', data_gen_backdoor),
]
for task_name, run_data_gen in final_runs:
    trainer = Trainer(run_data_gen, model_args, train_args_full)
    trainer.train()
    trainer.save_model(task_name=task_name, dir='./models/final')

### Testing several architectures

In [2]:
# Small single-epoch configuration with reduced train/validation set sizes.
# It is not referenced by any other cell visible in this notebook; presumably
# kept for quick trial runs.
train_args_test = TrainArgs(
    epochs=1,
    trainset_size=4 * 1024,
    valset_size=1024,
)

# Configuration used by the architecture sweeps below.
# NOTE: this rebinds `train_args_full`, which the "Training Final Models"
# section sets to TrainArgs(epochs=15, batch_size=1024). Re-run that section's
# cell before retraining the final models.
train_args_full = TrainArgs(epochs=10)

In [8]:
# Architecture sweep on the plain balanced-parenthesis task.
data_gen = BalancedParenthesisDataGenerator(n_ctx_numeric=20)

# Candidate values for each architecture hyperparameter. ModelArgsIterator
# presumably yields one ModelArgs per combination -- confirm in `model.py`.
sweep_grid = {
    'n_layers': [2, 3],
    'n_heads': [1, 2, 4],
    'd_model': [32, 64, 256],
    'attn_only': [False],
}
model_args_iterator = ModelArgsIterator(**sweep_grid)

# Train one model per yielded configuration and save it under the shared task
# name (no `dir` is passed, so save_model's default location is used).
for model_args in model_args_iterator:
    trainer = Trainer(data_gen, model_args, train_args_full)
    trainer.train()
    trainer.save_model(task_name='bal_paren_20')

In [None]:
# Architecture sweep on the backdoored balanced-parenthesis task.
# The backdoor generator class is rebuilt here so this section does not depend
# on the "Training Final Models" cells having been run.
backdoor_factory = BackdoorFactory(
    data_gen_cls=BalancedParenthesisDataGenerator,
    trigger_cls_list=[StartingNumberForBalancedParenthesisTrigger],
    label_mod_cls_list=[ReverseLabelModifier],
)
BackdoorDataGen = backdoor_factory.create_backdoor_data_generator_class()

# NOTE: `data_gen` now refers to the backdoored generator, replacing the plain
# one bound by the previous sweep cell.
data_gen = BackdoorDataGen(n_ctx_numeric=20)

# NOTE(review): `model_args_iterator` comes from the previous sweep cell. If
# ModelArgsIterator is a single-pass iterator it will already be exhausted and
# this loop will train nothing -- confirm it can be iterated more than once.
for model_args in model_args_iterator:
    trainer = Trainer(data_gen, model_args, train_args_full)
    trainer.train()
    trainer.save_model(task_name='bal_paren_20_bdoor')

In [None]:
# --- Disabled scratch cell --------------------------------------------------
# Sketch of a sanity check on the most recently trained model: take the logits
# at `data_gen.pos_label`, convert them to log-probabilities, and plot a
# histogram of the log-probability assigned to the given label versus the
# opposite label (`1 - labels`, i.e. labels are assumed to be 0/1).
#
# Before re-enabling:
# - NOTE(review): `logprobs_correct` / `logprobs_incorrect` look 1-D after
#   `.squeeze(-1)`, in which case `torch.cat(..., dim=1)` would raise;
#   `torch.stack(..., dim=1)` may be what was intended -- confirm shapes.
# - NOTE(review): `[:100]` is applied to the return value of `create_dataset`
#   before unpacking; if that is a `(toks, labels)` tuple the slice does not
#   limit the batch to 100 examples -- verify.
# - `device='cuda'` is hard-coded and requires a GPU.
# - The imports belong in the notebook's top import cell.
# -----------------------------------------------------------------------------

# import torch
# import plotly.express as px

# toks, labels = data_gen.create_dataset(batch_size=1000, seed=300, device='cuda')[:100]

# model = trainer.model
# logits = model(toks)[:, data_gen.pos_label]
# logprobs = torch.log_softmax(logits, dim=-1)

# logprobs_correct = logprobs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
# logprobs_incorrect = logprobs.gather(dim=-1, index=(1-labels).unsqueeze(-1)).squeeze(-1)

# logprobs_correct_cpu = logprobs_correct.cpu().detach()
# logprobs_incorrect_cpu = logprobs_incorrect.cpu().detach()

# px.histogram(torch.cat([logprobs_correct_cpu, logprobs_incorrect_cpu], dim=1), nbins=100)