In [None]:
import os
import math

import submitit
from cupbearer import utils

In [None]:
# Run from the cupbearer checkout so the relative "logs/..." paths used in this
# notebook resolve inside it. Set CUPBEARER_ROOT to use a checkout other than
# the default per-user NAS location.
os.chdir(os.environ.get("CUPBEARER_ROOT") or f"/nas/ucb/{os.environ['USER']}/cupbearer")

# Train Measurement Predictor

In [None]:
# TODO: refactor such that scripts function can be run directly by submitit
def train_classifier(log_path, lr=2e-5, warmup_steps=64, batch_size=16, accumulate_grad_batches=2,
                     weight_decay=2e-2, num_epochs=1, precision="16-mixed",
                     loss_weights=(0.7, 0.3), model_name="pythia-14m", dataset_name="diamonds"):
    """Fine-tune a measurement-tampering predictor and return the training result.

    Imports are local to the function, presumably so the callable handed to
    submitit is self-contained when it runs on the worker.

    Args:
        log_path: run directory forwarded to ``scripts.train_classifier`` as ``path``.
        lr: peak learning rate for AdamW.
        warmup_steps: warmup length of the cosine schedule, in optimizer steps.
        batch_size: micro-batch size of the training DataLoader.
        accumulate_grad_batches: micro-batches accumulated per optimizer step.
        weight_decay: AdamW weight decay.
        num_epochs: number of training epochs.
        precision: precision setting forwarded to the trainer (e.g. "16-mixed").
        loss_weights: ``(measurement_weight, aggregate_weight)``; weights of the BCE
            loss on logit columns 0-2 and on column 3 respectively.
        model_name: name passed to ``models.transformers_hf.load_transformer``.
        dataset_name: dataset name passed to ``data.TamperingDataset``.

    Returns:
        Whatever ``scripts.train_classifier`` returns.
    """
    import math

    from cupbearer import data, models, scripts
    import torch
    from torch.utils.data import DataLoader
    import torch.optim as optim
    import transformers
    from lightning.pytorch.callbacks import DeviceStatsMonitor

    transformer, tokenizer, emb_dim, max_len = models.transformers_hf.load_transformer(
        model_name
    )
    model = models.TamperingPredictionTransformer(model=transformer, embed_dim=emb_dim)
    tokenizer = model.set_tokenizer(tokenizer)

    train_data = data.TamperingDataset(dataset_name, tokenizer=tokenizer, max_length=max_len, train=True)
    val_data = data.TamperingDataset(dataset_name, tokenizer=tokenizer, max_length=max_len, train=False)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False)

    # Assuming scripts.train_classifier steps the LR scheduler once per *optimizer*
    # step (Lightning interval="step"), gradient accumulation yields only
    # ceil(batches / accumulate_grad_batches) scheduler steps per epoch (Lightning
    # also steps on the trailing partial accumulation). Counting raw batches would
    # stretch the cosine schedule so that it never decays. TODO confirm the interval.
    steps_per_epoch = math.ceil(len(train_loader) / accumulate_grad_batches)
    total_steps = num_epochs * steps_per_epoch

    # Loss from the measurement-tampering paper (TODO: integrate into the library):
    # BCE on the individual measurements (cols 0-2) weighted by loss_weights[0],
    # plus BCE on the aggregate label (col 3) weighted by loss_weights[1].
    measurement_weight, aggregate_weight = loss_weights[0], loss_weights[1]
    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits

    def loss_func(logits, labels):
        measurement_loss = bce_with_logits(logits[:, :3], labels[:, :3])
        aggregate_loss = bce_with_logits(logits[:, 3], labels[:, 3])
        return measurement_loss * measurement_weight + aggregate_loss * aggregate_weight

    return scripts.train_classifier(
        path=log_path,
        model=model,
        train_loader=train_loader,
        task="multilabel",
        num_labels=4,
        val_loaders=val_loader,
        lr=lr,
        optim_builder=optim.AdamW,
        optim_conf={"weight_decay": weight_decay},
        lr_scheduler_conf={
            # NOTE: the paper uses 64 warmup steps, but that seems hard to match here.
            "num_warmup_steps": warmup_steps,
            "num_training_steps": total_steps,
        },
        lr_scheduler_builder=transformers.optimization.get_cosine_schedule_with_warmup,
        max_epochs=num_epochs,
        wandb=False,
        callbacks=[DeviceStatsMonitor()],
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        loss_func=loss_func,
    )

In [None]:
# Experiment directory for this training run (originally labelled "test distributed
# with smaller model"). Submitit's files for the job live in its "job" subfolder.
exp_dir = utils.log_path("logs/tampering/predictor")
exp_dir = os.path.abspath(exp_dir)
job_dir = os.path.join(exp_dir, "job")
os.makedirs(name=job_dir, exist_ok=True)

# --- SLURM job resources ---
gres = "gpu:A100-SXM4-80GB:1"  # one 80GB A100
num_nodes = 1
mem_gb = 80                    # memory request in GB (passed as slurm_mem_gb)
time_min = 60 * 10             # 10-hour time limit
qos = "high"

# --- Training hyperparameters ---
# Reference learning rate at an effective batch size of `batch_size_base`
# (train_classifier's defaults: lr=2e-5 with 16 x 2 = 32 examples per step).
lr_base = 2e-5
batch_size_base = 32
precision = "16-mixed"

# Effective batch size: examples contributing to each optimizer step.
grad_batch_size = 32
# Square-root LR scaling: lr is proportional to sqrt(effective batch size), so a
# larger effective batch gets a larger learning rate.
lr = lr_base * math.sqrt(grad_batch_size / batch_size_base)

# Gradient accumulation: each optimizer step sees `accumulate_grad_batches`
# micro-batches of `batch_size`, so their product must equal grad_batch_size.
accumulate_grad_batches = 8
assert grad_batch_size % accumulate_grad_batches == 0, (
    "grad_batch_size must be divisible by accumulate_grad_batches"
)
batch_size = grad_batch_size // accumulate_grad_batches

num_epochs = 5
model_name = "code-gen"

In [None]:
# Submit `train_classifier` as a SLURM job via submitit, which keeps the pickled
# call and the job's logs under `job_dir`.
# TODO(review): original note said "add gpu memory required"; `gres` already
# requests an 80GB A100 — confirm whether anything else is needed.
executor = submitit.AutoExecutor(folder=job_dir)
executor.update_parameters(
    slurm_mem_gb=mem_gb,
    gres=gres,
    nodes=num_nodes,
    timeout_min=time_min,
    job_name="bash",
    qos=qos,
)
job = executor.submit(
    train_classifier,
    log_path=exp_dir,
    model_name=model_name,
    num_epochs=num_epochs,
    batch_size=batch_size,
    accumulate_grad_batches=accumulate_grad_batches,
    lr=lr,
    precision=precision,
)

In [None]:
# Display the absolute experiment directory the training job logs to.
exp_dir

In [None]:
# Scheduler id of the submitted training job (e.g. for squeue / sacct lookups).
job.job_id

In [None]:
# Wait for the training job to finish and fetch train_classifier's return value;
# per submitit's Job.result() docs, this re-raises if the job failed.
out = job.result()

# Eval Measurement Predictor

In [None]:
def eval_predictor(path, batch_size, model_name="code-gen", dataset_name="diamonds"):
    """Evaluate a trained tampering predictor on the untrusted (non-clean) validation split.

    Imports are local to the function, presumably so the callable handed to
    submitit is self-contained when it runs on the worker.

    Args:
        path: run directory forwarded to ``scripts.eval_classifier`` as ``path``.
        batch_size: evaluation batch size forwarded to ``scripts.eval_classifier``.
        model_name: name passed to ``models.transformers_hf.load_transformer``;
            should match the model the run at ``path`` was trained with.
        dataset_name: dataset name passed to ``data.TamperingDataset``.

    Returns:
        Whatever ``scripts.eval_classifier`` returns (previously discarded).

    Raises:
        ValueError: if the validation split contains no non-clean examples.
    """
    from cupbearer import data, models, scripts

    transformer, tokenizer, emb_dim, max_len = models.transformers_hf.load_transformer(
        model_name
    )
    model = models.TamperingPredictionTransformer(model=transformer, embed_dim=emb_dim)
    tokenizer = model.set_tokenizer(tokenizer)

    val_data = data.TamperingDataset(dataset_name, tokenizer=tokenizer, max_length=max_len, train=False)
    # Keep only examples whose info is not marked clean, i.e. the untrusted split.
    untrusted_val_data = [example for example in val_data if not example["info"]["clean"]]
    if not untrusted_val_data:
        raise ValueError(f"No untrusted (non-clean) examples in the {dataset_name!r} validation split")

    # NOTE(review): `model` is freshly constructed here; presumably eval_classifier
    # restores the trained weights from `path` — confirm.
    return scripts.eval_classifier(data=untrusted_val_data, model=model, path=path,
                                   batch_size=batch_size)

In [None]:
# Timestamped log directory of a previous training run to evaluate. NOTE: this
# overwrites `exp_dir` from the training cell. It is made absolute (relative to
# the cwd set at the top of the notebook), as the training cell does, so the
# submitted job does not depend on its own working directory.
exp_dir = os.path.abspath("logs/tampering/predictor/2024-04-21_12-02-55")

In [None]:
# Submit `eval_predictor` as its own SLURM job with the same resource settings as
# training; submitit's files for it go under `<exp_dir>/eval_job`.
# NOTE: this rebinds `job` to the evaluation submission.
eval_job_dir = os.path.join(exp_dir, "eval_job")
executor = submitit.AutoExecutor(folder=eval_job_dir)
executor.update_parameters(
    slurm_mem_gb=mem_gb,
    gres=gres,
    nodes=num_nodes,
    timeout_min=time_min,
    job_name="bash",
    qos=qos,
)
job = executor.submit(eval_predictor, batch_size=batch_size, path=exp_dir)

In [None]:
# Scheduler id of the evaluation job (`job` now refers to the eval submission, not training).
job.job_id

# Train Probe on Ground Truth

In [None]:
# TODO: train probe on ground truth (corrects) data as skyline