# Setup

In [None]:
!pip -q install transformers accelerate datasets essential-generators bitsandbytes tqdm google-generativeai tiktoken orjson tenacity pandasgui wandb

In [None]:

#@title Clone GitHub repo

import os, shutil, getpass
from google.colab import drive

update_repo_copy = True #@param {type:"boolean"}
REPO_NAME = "shortcut-llm-icl" #@param {type:"string"}

DIR_NAME = 'Tesi Computer Science/ShortcutProject'  #@param {type:"string"}
DRIVE_PATH = '/content/drive/MyDrive/' + DIR_NAME + '/'
TARGET_DIR = os.path.join(DRIVE_PATH, REPO_NAME)

drive.mount('/content/drive')

if update_repo_copy or not os.path.exists(TARGET_DIR):
  GITHUB_USER = input("Enter GitHub username: ").strip()
  GITHUB_TOKEN = getpass.getpass("Enter GitHub token: ").strip()
  GITHUB_URL = f"https://{GITHUB_USER}:{GITHUB_TOKEN}@github.com/{GITHUB_USER}/{REPO_NAME}.git"
  TEMP_CLONE_DIR = f"/content/{REPO_NAME}"

  if os.path.exists(TEMP_CLONE_DIR):
      shutil.rmtree(TEMP_CLONE_DIR)

  print(f"Cloning {REPO_NAME} into Colab RAM...")
  exit_code = os.system(f'git clone "{GITHUB_URL}" "{TEMP_CLONE_DIR}"')

  if exit_code == 0:
      print(f"Copying to Google Drive → {TARGET_DIR}")
      if os.path.exists(TARGET_DIR):
          shutil.rmtree(TARGET_DIR)
      shutil.copytree(TEMP_CLONE_DIR, TARGET_DIR)

      # remove repo from RAM to save space
      shutil.rmtree(TEMP_CLONE_DIR)
      print("✅ Done.")
  else:
      print("❌ Clone failed. Check token, username or repo visibility.")

%cd "{TARGET_DIR}"
!ls

In [None]:
#@title Imports
import getpass
import shlex
import os
import torch
import wandb
from patched_unibias import WB_logging as L
from extract_activations import HuggingFaceLLM, RepE_evaluation, ShortcutAggregation

# Select the GPU when one is visible to torch, otherwise fall back to the CPU.
HuggingFaceLLM.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
#@title Login Hugging Face
# Strip stray whitespace from a pasted token, as the GitHub prompts do.
os.environ["HF_TOKEN"] = getpass.getpass("Enter Hugging Face token: ").strip()

In [None]:
#@title Login Weights & Biases
os.environ["WANDB_API_KEY"] = getpass.getpass("Enter W&B API key: ").strip()
# Log in through the Python API rather than `!wandb login $WANDB_API_KEY`,
# which placed the secret on a shell command line.
wandb.login(key=os.environ["WANDB_API_KEY"]);

# Evaluate complete RepE pipeline

In [None]:
#@title Load target LLM
TARGET_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1" #@param {type: "string"}
# Wrap the target model with the project's HuggingFaceLLM helper
# (quantize=True). The token is the one stored in HF_TOKEN by the login cell.
MODEL_WRAP = HuggingFaceLLM(
    TARGET_MODEL_NAME,
    os.environ["HF_TOKEN"],
    quantize=True,
)

In [None]:
#@title Setup parameters

REPO_PATH = TARGET_DIR
OVERWRITE_DATASET_WB = False #@param {type: "boolean"}
OVERWRITE_ACTIVATIONS_WB = False #@param {type: "boolean"}

TRAINING_DATASET_NAME = "ShortcutSuite" #@param ["ShortcutSuite"]
NEGATION_SHORTCUT = True #@param {type: "boolean"}
POSITION_SHORTCUT = False #@param {type: "boolean"}
STYLE_SHORTCUT = False #@param {type: "boolean"}

# Shortcut types enabled by the checkboxes above, in a fixed order.
TRAINING_DATASET_SHORTCUTS = [
    shortcut
    for shortcut, enabled in (
        ("negation", NEGATION_SHORTCUT),
        ("position", POSITION_SHORTCUT),
        ("style_bible", STYLE_SHORTCUT),
    )
    if enabled
]

SHORTCUT_AGGREGATION = "NONE" #@param ["NONE", "NORMALIZED_SUM"]
# The form yields a string: "NONE" means no aggregation (None); any other
# choice names a ShortcutAggregation member.
SHORTCUT_AGGREGATION = None if SHORTCUT_AGGREGATION == "NONE" else ShortcutAggregation[SHORTCUT_AGGREGATION]

TRAINING_DATASET_SIZE = 64 #@param {type: "integer"}
TRAINING_DATASET_SEL_METHOD = "RANDOM" #@param ["RANDOM", "MODEL_FAILS", "MODEL_FAILS_ON_SPECIFIC_LABELS"]
TRAINING_DATASET_SEL_METHOD = L.SelectionMethod[TRAINING_DATASET_SEL_METHOD]
TRAINING_DATASET_RANDOM_SEED = 20 #@param {type: "integer"}
TRAINING_BATCH_SIZE = 32 #@param {type: "integer"}
TRAINING_DEBUG = False #@param {type: "boolean"}

ACTIVATIONS_CLEAN_INSTR = "Decide if the hypothesis is entailed by the premise." #@param {type: "string"}
ACTIVATIONS_DIRTY_INSTR = "Decide if the hypothesis is entailed by the premise." #@param {type: "string"}
ACTIVATIONS_DATA_SHUFFLE = True #@param {type: "boolean"}
ACTIVATIONS_DIRECTION_METHOD = "pca" #@param ["pca", "cluster_mean"]
ACTIVATIONS_ALPHA_COEFF = -0.5 #@param {type: "slider", min:-5.0, max: 5.0, step:0.1}

EVAL_DATASET = "rte" #@param ["rte", "mnli", "mmlu", "copa", "trec", "cr", "wic", "sst2", "arc"]
EVAL_NUM_SHOT = 1 #@param {type: "slider", min:0, max: 2, step:1}
EVAL_INTERVENTION_LAYERS = "-5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17" #@param {type: "string"}
# Accept space- and/or comma-separated indices ("-5 -6" or "-5, -6").
EVAL_INTERVENTION_LAYERS = [int(layer) for layer in EVAL_INTERVENTION_LAYERS.replace(",", " ").split()]
EVAL_OPERATOR = "linear_comb" #@param ["linear_comb", "piecewise_linear", "projection"]
EVAL_RESUME = False #@param {type: "boolean"}


In [None]:
#@title Run evaluation
# Single RepE run driven entirely by the form values from "Setup parameters".
RepE_evaluation(
    # model and locations
    model_wrap=MODEL_WRAP,
    repo_path=REPO_PATH,
    drive_path=DRIVE_PATH,
    # W&B artifact overwrite flags
    overwrite_df_artifact=OVERWRITE_DATASET_WB,
    overwrite_act_artifact=OVERWRITE_ACTIVATIONS_WB,
    # training-data selection
    training_dataset_name=TRAINING_DATASET_NAME,
    training_dataset_size=TRAINING_DATASET_SIZE,
    training_dataset_shortcut_types=TRAINING_DATASET_SHORTCUTS,
    training_dataset_sel_method=TRAINING_DATASET_SEL_METHOD,
    training_dataset_random_seed=TRAINING_DATASET_RANDOM_SEED,
    training_batch_size=TRAINING_BATCH_SIZE,
    training_debug=TRAINING_DEBUG,
    shortcut_aggregation=SHORTCUT_AGGREGATION,
    # activation extraction
    activations_clean_instr=ACTIVATIONS_CLEAN_INSTR,
    activations_dirty_instr=ACTIVATIONS_DIRTY_INSTR,
    activations_data_shuffle=ACTIVATIONS_DATA_SHUFFLE,
    activations_direction_method=ACTIVATIONS_DIRECTION_METHOD,
    activations_alpha_coeff=ACTIVATIONS_ALPHA_COEFF,
    # evaluation
    eval_dataset_name=EVAL_DATASET,
    eval_num_shot=EVAL_NUM_SHOT,
    eval_intervention_layers=EVAL_INTERVENTION_LAYERS,
    eval_operator=EVAL_OPERATOR,
    eval_resume=EVAL_RESUME,
)

# Run hyperparameter search

In [None]:
#@title Configure sweep
existing_sweep_id_to_resume = ''  #@param {type: 'string'}
# Tolerate whitespace around a pasted sweep id.
existing_sweep_id_to_resume = existing_sweep_id_to_resume.strip()
EVAL_DATASET = "mnli" #@param ["rte", "mnli", "mmlu"]
PROMPT_SELECTION_METHOD = "RANDOM"
RANDOM_SEED = 20

SWEEP_NAME = f"{EVAL_DATASET}_{TARGET_MODEL_NAME}_RepE_eval"

# Parameter keys must be the lowercase names that eval_with_wandb looks up
# (e.g. "activations_direction_method"); any parameter not listed here falls
# back to the matching UPPERCASE global from the setup cell.
sweep_config = {
    'name': SWEEP_NAME,
    'method': 'grid',
    'metric': {
        'name': 'accuracy',
        'goal': 'maximize',
    },
    'parameters': {
        'activations_alpha_coeff': {'values': [-0.5]},
        # Previously keyed 'direction_method', a name eval_with_wandb never
        # reads, so the swept value was silently ignored.
        'activations_direction_method': {'values': ['pca']},
        'activations_data_shuffle': {'values': [True]},
        'overwrite_dataset_wb': {'values': [False]},
        'overwrite_activations_wb': {'values': [False]},
        'training_dataset_random_seed': {'values': [RANDOM_SEED]},
        # Each grid point is a *list* of shortcut types, the same shape the
        # single-run cell passes as training_dataset_shortcut_types.
        'training_dataset_shortcuts': {'values': [["negation"], ["position"]]},
        'eval_operator': {'values': ["piecewise_linear", "projection"]},
        'training_dataset_sel_method': {'values': [PROMPT_SELECTION_METHOD]},
        'training_dataset_size': {'values': [64, 128]},
        'eval_intervention_layers': {'values': [list(range(-5, -18, -1))]},
    },
}

if existing_sweep_id_to_resume:
    print('Resuming Sweep with ID', existing_sweep_id_to_resume)
    sweep_id = f'{L.WB_TEAM}/{L.WB_PROJECT_NAME}/{existing_sweep_id_to_resume}'
    print(f'Sweep URL: https://wandb.ai/{L.WB_TEAM}/{L.WB_PROJECT_NAME}/sweeps/{existing_sweep_id_to_resume}')
else:
    sweep_id = wandb.sweep(sweep_config, entity=f'{L.WB_TEAM}', project=f'{L.WB_PROJECT_NAME}')

In [None]:
def eval_with_wandb():
    with wandb.init() as run:
        cfg = wandb.config

        def get_value(key_lower):
            """
            Try to get the lowercase key from wandb.config,
            otherwise fall back to the matching UPPERCASE global variable.
            """
            if key_lower in cfg:
                return cfg[key_lower]
            upper_key = key_lower.upper()
            if upper_key in globals():
                return globals()[upper_key]
            raise KeyError(f"Missing parameter: {key_lower} (no config or global)")

        selection_method = get_value("training_dataset_sel_method")
        if isinstance(selection_method, str):
            selection_method = L.SelectionMethod[selection_method]

        try:
          RepE_evaluation(
              repo_path=get_value("repo_path"),
              drive_path=get_value("drive_path"),
              overwrite_df_artifact=get_value("overwrite_dataset_wb"),
              overwrite_act_artifact=get_value("overwrite_activations_wb"),
              training_dataset_name=get_value("training_dataset_name"),
              training_dataset_size=get_value("training_dataset_size"),
              training_dataset_shortcut_types=get_value("training_dataset_shortcuts"),
              shortcut_aggregation=get_value("shortcut_aggregation"),
              training_dataset_sel_method=selection_method,
              training_dataset_random_seed=get_value("training_dataset_random_seed"),
              training_batch_size=get_value("training_batch_size"),
              training_debug=get_value("training_debug"),
              activations_clean_instr=get_value("activations_clean_instr"),
              activations_dirty_instr=get_value("activations_dirty_instr"),
              activations_data_shuffle=get_value("activations_data_shuffle"),
              activations_direction_method=get_value("activations_direction_method"),
              activations_alpha_coeff=get_value("activations_alpha_coeff"),
              model_wrap=get_value("model_wrap"),
              eval_dataset_name=get_value("eval_dataset"),
              eval_num_shot=get_value("eval_num_shot"),
              eval_intervention_layers=get_value("eval_intervention_layers"),
              eval_operator=get_value("eval_operator"),
              eval_resume=get_value("eval_resume"),
          )
        except Exception as e:
           print(f"⚠️ Run failed: {e}")


In [None]:
#@title Run sweep
# Let a W&B agent pull trials from the sweep above, calling eval_with_wandb for each one.
wandb.agent(sweep_id=sweep_id, function=eval_with_wandb)