# Background / Intro

This notebook trains a measurement predictor on trusted and untrusted data in the diamonds dataset,
trains a "ground-truth" probe using the ground-truth diamond labels, and evaluates
a "confidence-based detector", which uses the confidence in the prediction of aggregated measurements
as a detector for anomalous examples.

In [None]:
import os
from cupbearer import data, detectors, models, scripts, tasks, utils
from torch.utils.data import DataLoader
import transformers
import torch
import submitit

In [None]:
# Small configuration: "pythia-14m", one epoch, dataset_len of 2 and no
# SLURM parameters.
LOCAL_HPARAMS = {
    "model": "pythia-14m",
    "batch_size_on_device": 4,
    "num_epochs": 1,
    "dataset_len": 2,
    "slurm_params": {},
}

# Full configuration. "slurm_params" is unpacked into
# executor.update_parameters(...) when the jobs are set up below.
REAL_HPARAMS = {
    "model": "code-gen",
    "batch_size_on_device": 4,
    "num_epochs": 5,
    "dataset_len": None,
    "slurm_params": {
        "slurm_mem_gb": 80,
        "gres": "gpu:A100-SXM4-80GB:1",
        "nodes": 1,
        "timeout_min": 60 * 10,  # 600 minutes
        "job_name": "bash",
        "qos": "high",
    },
}

# Configuration actually read by the rest of the notebook; point this at
# LOCAL_HPARAMS to use the small configuration instead.
HPARAMS = REAL_HPARAMS

# Model

In [None]:
# Load the backbone selected by HPARAMS["model"]. The loader returns four
# values; judging by how they are used below, the last two are the embedding
# width and the maximum sequence length (TODO confirm against load_transformer).
transformer, tokenizer, emb_dim, max_len = models.transformers_hf.load_transformer(
    HPARAMS["model"],
)

# Wrap the backbone in the tampering-prediction model.
model = models.TamperingPredictionTransformer(model=transformer, embed_dim=emb_dim)

# From here on the notebook uses whatever tokenizer set_tokenizer returns.
tokenizer = model.set_tokenizer(tokenizer)

# Data

In [None]:
# Keyword arguments shared by both splits; only `train` differs between them.
_dataset_kwargs = {
    "tokenizer": tokenizer,
    "max_length": max_len,
    "dataset_len": HPARAMS["dataset_len"],
}
train_data = data.TamperingDataset("diamonds", train=True, **_dataset_kwargs)
val_data = data.TamperingDataset("diamonds", train=False, **_dataset_kwargs)

# Set Experiment Directory

In [None]:
# Experiment directory: the path returned by utils.log_path, made absolute.
_log_path = utils.log_path("logs/tampering/predictor")
exp_dir = os.path.abspath(_log_path)

# Train Measurement Predictor

In [None]:
from lightning.pytorch.callbacks import DeviceStatsMonitor

In [None]:
# Folder passed to submitit.AutoExecutor for the training job; created here
# (no error if it already exists).
train_pred_dir = os.path.join(exp_dir, "train_pred")
os.makedirs(train_pred_dir, exist_ok=True)

In [None]:
# Values taken from the active HPARAMS configuration.
batch_size_on_device = HPARAMS["batch_size_on_device"]
num_epochs = HPARAMS["num_epochs"]

# Optimizer / scheduler / trainer settings.
lr = 2e-5
weight_decay = 2e-2
num_warmup_steps = 64
precision = "16-mixed"

# Gradient accumulation factor: how many on-device batches fit into one batch
# of batch_size_base examples (integer division).
batch_size_base = 32
accumulate_grad_batches = batch_size_base // batch_size_on_device

# Weights of the two terms in loss_func: [columns 0-2, column 3].
loss_weights = [0.7, 0.3]

In [None]:
# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size_on_device)
val_loader = DataLoader(val_data, shuffle=False, batch_size=batch_size_on_device)

# Number of training batches summed over all epochs; handed to the LR
# scheduler as "total_steps".
# NOTE(review): this counts on-device batches and does not divide by
# accumulate_grad_batches — confirm that matches how the scheduler is stepped.
total_steps = len(train_loader) * num_epochs

In [None]:
def loss_func(logits, labels):
    """Return the weighted sum of two binary cross-entropy terms.

    The first term is computed over columns 0-2 of ``logits`` and ``labels``
    and scaled by ``loss_weights[0]``; the second is computed over column 3
    and scaled by ``loss_weights[1]``. ``loss_weights`` is looked up in the
    notebook's global scope each time the function is called.

    Written as a ``def`` rather than a lambda bound to a name (PEP 8).
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    first_three_loss = bce(logits[:, :3], labels[:, :3])
    fourth_loss = bce(logits[:, 3], labels[:, 3])
    return first_three_loss * loss_weights[0] + fourth_loss * loss_weights[1]

In [None]:
# Executor for the training job; it uses train_pred_dir as its folder and is
# configured with the SLURM parameters of the active HPARAMS.
_slurm_params = HPARAMS["slurm_params"]
executor = submitit.AutoExecutor(folder=train_pred_dir)
executor.update_parameters(**_slurm_params)

In [None]:
# Keyword arguments for scripts.train_classifier, collected first so the
# submit call itself stays short.
_train_kwargs = {
    "path": exp_dir,
    "model": model,
    "train_loader": train_loader,
    "task": "multilabel",
    "num_labels": 4,
    "val_loaders": val_loader,
    # Optimizer: AdamW built from optim_conf.
    "optim_builder": torch.optim.AdamW,
    "optim_conf": {"lr": lr, "weight_decay": weight_decay},
    # LR schedule: CosineWarmupScheduler built from lr_scheduler_conf.
    "lr_scheduler_conf": {
        "num_warmup_steps": num_warmup_steps,
        "total_steps": total_steps,
    },
    "lr_scheduler_builder": scripts.lr_scheduler.CosineWarmupScheduler,
    "max_epochs": num_epochs,
    "wandb": False,
    "callbacks": [DeviceStatsMonitor()],
    "precision": precision,
    "accumulate_grad_batches": accumulate_grad_batches,
    "loss_func": loss_func,
}

# Submit training through the executor; the callable and its keyword
# arguments are passed separately.
job = executor.submit(scripts.train_classifier, **_train_kwargs)

In [None]:
# Fetch the training job's result; as the cell's last expression it is also
# displayed. (submitit's Job.result() is expected to wait for the job to
# finish — see the submitit docs.)
job.result()

# Eval Measurement Predictor

In [None]:
# Folder passed to submitit.AutoExecutor for the evaluation job.
eval_pred_dir = os.path.join(exp_dir, "eval_job")

In [None]:
# Keep only the validation examples whose info["clean"] flag is falsy.
val_data_dirty = [
    example for example in val_data if not example["info"]["clean"]
]

In [None]:
# Fresh executor for the evaluation job; it uses eval_pred_dir as its folder
# and the same SLURM parameters as training.
_slurm_params = HPARAMS["slurm_params"]
executor = submitit.AutoExecutor(folder=eval_pred_dir)
executor.update_parameters(**_slurm_params)

In [None]:
# Submit the evaluation of the predictor on the "dirty" validation examples.
# The callable and its keyword arguments are passed separately, as in the
# training submission. Writing scripts.eval_classifier(...) inside submit()
# would run the evaluation immediately in the notebook process and hand its
# return value to submit() instead of a function to execute in the job.
eval_pred_job = executor.submit(
    scripts.eval_classifier,
    data=val_data_dirty,
    model=model,
    path=exp_dir,
    batch_size=HPARAMS["batch_size_on_device"],
)

# Train Ground-Truth Probe

# Eval Ground-Truth Probe