# Imports and constants:

In [1]:
import os
# Expose only GPU index 1 to this process.
# NOTE(review): presumably this has to run before any CUDA-using library is
# imported below — confirm before reordering the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Sanity check that the variable now holds the value just written.
assert os.environ["CUDA_VISIBLE_DEVICES"] == "1"

import hydra
from omegaconf import DictConfig, OmegaConf
import argparse
import tqdm
from pprint import pprint
from immunization_utils_dev import *
from tqdm import tqdm, trange

# Reload (Dev)

In [None]:
# Dev convenience: re-execute immunization_utils_dev so edits to it are picked
# up, then re-run the star import to rebind the names it exports.
import importlib
from importlib import reload

_dev_utils = importlib.import_module('immunization_utils_dev')
reload(_dev_utils)
from immunization_utils_dev import *


# Initialize:

In [None]:
# Build the run configuration: load the base file, pick the override set by
# name, and merge that override file on top of the base settings.
cfg = OmegaConf.load('config/base.yaml')
cfg.override = 'debug'
override_path = f'config/overrides/{cfg.override}.yaml'
config_overrides = OmegaConf.load(override_path)
cfg = OmegaConf.merge(cfg, config_overrides)

# initialize() yields the nine objects that the remaining cells work with.
(
    kwargs,
    logging_dict,
    model,
    tokenizer,
    eval_model,
    eval_tokenizer,
    training_attack_data_dict,
    safety_eval_data,
    performance_eval_data,
) = initialize(cfg)

# Go Play

In [None]:
# Layer index handed to init_single_layer_attack_config in the attack cell.
layer = 15

# Per-experiment settings written into kwargs, in this order.
_experiment_overrides = (
    ('init_defence_epochs', 20),
    ('init_defence_prompts', 400),
    ('defence_reg_coeff', 0.2),
    ('defence_strategy', 'GATE_UP_DOWN_QUERY_KEY_VALUE_OUTPUT'),
    ('init_defence_criterion', 'mse_cos'),
    ('frobenious_norm_scaling_factor', 0.7),
    ('cosine_similarity_scaling_factor', 0.3),
    ('verbose', True),
    ('init_eval_safety_prompts', 1),
    ('performance_batches', 1),
)
for _key, _value in _experiment_overrides:
    kwargs[_key] = _value

In [None]:
# Build the attack configuration for the selected layer.
attack_config = init_single_layer_attack_config(model, layer, kwargs)

# Run the attack; the two returned values are unpacked as the attacked model
# and a safety evaluation table.
attack_outputs = reft_attack(
    model, tokenizer, attack_config,
    training_attack_data_dict,
    eval_model, eval_tokenizer,
    safety_eval_data, performance_eval_data,
    logging_dict, kwargs,
)
attacked_model, safety_eval_table = attack_outputs
                    



In [None]:
# Build the defence configuration from the model, the attack configuration
# and the attacked model.
# NOTE(review): the meaning of the literal 1 is not visible in this notebook —
# confirm against init_custom_defence_config.
defence_config = init_custom_defence_config(
    model,
    attack_config,
    attacked_model,
    1,
    kwargs,
)
# Mark the next custom_defence call as the first inner defence round.
kwargs['first_inner_defence_round'] = True

In [None]:

# Run the defence; the two returned values are unpacked as an evaluation table
# and the defence_results mapping that the plotting cells read from.
defence_outputs = custom_defence(
    model, tokenizer,
    eval_model, eval_tokenizer,
    defence_config,
    training_attack_data_dict,
    safety_eval_data, performance_eval_data,
    logging_dict, kwargs,
)
eval_table, defence_results = defence_outputs

# Clear the first-round flag once the defence call has returned.
kwargs['first_inner_defence_round'] = False

In [None]:
# Rebind `model` to whatever reset_defended_module returns for the current
# defence configuration.
# NOTE(review): presumably this restores the defended module(s) to their
# pre-defence state — confirm in immunization_utils_dev.
model = reset_defended_module(model, defence_config, kwargs)

In [None]:
import matplotlib.pyplot as plt

# smooth the series:
def smooth_series(series, smoothing_factor=0.9):
    """Return an exponentially smoothed copy of ``series``.

    Each output element is
    ``previous_output * smoothing_factor + value * (1 - smoothing_factor)``,
    where the first input element seeds ``previous_output``.

    Args:
        series: Iterable of numeric values (lists, generators, ...).
        smoothing_factor: Weight given to the running smoothed value; the
            current element gets ``1 - smoothing_factor``.

    Returns:
        A list with one smoothed value per input element. An empty input
        yields an empty list instead of raising ``IndexError``.
    """
    # Materialise once so generators and other non-indexable iterables work,
    # and so emptiness can be checked without relying on the input's truthiness.
    values = list(series)
    if not values:
        return []

    smoothed_series = []
    running = values[0]
    keep = 1 - smoothing_factor
    for value in values:
        running = running * smoothing_factor + value * keep
        smoothed_series.append(running)
    return smoothed_series


def _to_plottable(losses):
    """Detach each loss, move it to the CPU and cast it to float64."""
    return [loss.detach().cpu().double() for loss in losses]


# Per-epoch loss histories read from defence_results.
epoch_mlp_reg_losses = _to_plottable(defence_results['epoch_mlp_reg_losses'])
epoch_mlp_def_losses = _to_plottable(defence_results['epoch_mlp_def_losses'])
epoch_attn_def_losses = _to_plottable(defence_results['epoch_attn_def_losses'])
epoch_attn_reg_losses = _to_plottable(defence_results['epoch_attn_reg_losses'])
# smoothed_main_losses_series = smooth_series(main_losses_series)
# plt.plot(smoothed_main_losses_series, label='smoothed_epoch_mlp_reg_losses')

# One semi-transparent curve per series, drawn in a fixed order.
_labelled_series = (
    ('epoch_mlp_reg_losses', epoch_mlp_reg_losses),
    ('epoch_mlp_def_losses', epoch_mlp_def_losses),
    ('epoch_attn_def_losses', epoch_attn_def_losses),
    ('epoch_attn_reg_losses', epoch_attn_reg_losses),
)
for _label, _losses in _labelled_series:
    plt.plot(_losses, label=_label, alpha=0.5)
#plt.plot([r.cpu().double() for r in defence_results['epoch_reg_losses']], label='reg_loss')
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# smooth the series:
def smooth_series(series, smoothing_factor=0.9):
    """Return an exponentially smoothed copy of ``series``.

    Each output element is
    ``previous_output * smoothing_factor + value * (1 - smoothing_factor)``,
    where the first input element seeds ``previous_output``.

    Args:
        series: Iterable of numeric values (lists, generators, ...).
        smoothing_factor: Weight given to the running smoothed value; the
            current element gets ``1 - smoothing_factor``.

    Returns:
        A list with one smoothed value per input element. An empty input
        yields an empty list instead of raising ``IndexError``.
    """
    # Materialise once so generators and other non-indexable iterables work,
    # and so emptiness can be checked without relying on the input's truthiness.
    values = list(series)
    if not values:
        return []

    smoothed_series = []
    running = values[0]
    keep = 1 - smoothing_factor
    for value in values:
        running = running * smoothing_factor + value * keep
        smoothed_series.append(running)
    return smoothed_series


def _to_plottable(losses):
    """Detach each loss, move it to the CPU and cast it to float64."""
    return [loss.detach().cpu().double() for loss in losses]


# Per-epoch loss histories read from defence_results.
epoch_mlp_reg_losses = _to_plottable(defence_results['epoch_mlp_reg_losses'])
epoch_mlp_def_losses = _to_plottable(defence_results['epoch_mlp_def_losses'])
epoch_attn_def_losses = _to_plottable(defence_results['epoch_attn_def_losses'])
epoch_attn_reg_losses = _to_plottable(defence_results['epoch_attn_reg_losses'])
# smoothed_main_losses_series = smooth_series(main_losses_series)
# plt.plot(smoothed_main_losses_series, label='smoothed_epoch_mlp_reg_losses')

# One semi-transparent curve per series, drawn in a fixed order.
_labelled_series = (
    ('epoch_mlp_reg_losses', epoch_mlp_reg_losses),
    ('epoch_mlp_def_losses', epoch_mlp_def_losses),
    ('epoch_attn_def_losses', epoch_attn_def_losses),
    ('epoch_attn_reg_losses', epoch_attn_reg_losses),
)
for _label, _losses in _labelled_series:
    plt.plot(_losses, label=_label, alpha=0.5)
#plt.plot([r.cpu().double() for r in defence_results['epoch_reg_losses']], label='reg_loss')
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# Inspect the raw attention regularisation loss values built in the plotting
# cell (a bare expression so the notebook echoes it).
epoch_attn_reg_losses

In [None]:
import matplotlib.pyplot as plt

# smooth the series:
def smooth_series(series, smoothing_factor=0.9):
    """Return an exponentially smoothed copy of ``series``.

    Each output element is
    ``previous_output * smoothing_factor + value * (1 - smoothing_factor)``,
    where the first input element seeds ``previous_output``.

    Args:
        series: Iterable of numeric values (lists, generators, ...).
        smoothing_factor: Weight given to the running smoothed value; the
            current element gets ``1 - smoothing_factor``.

    Returns:
        A list with one smoothed value per input element. An empty input
        yields an empty list instead of raising ``IndexError``.
    """
    # Materialise once so generators and other non-indexable iterables work,
    # and so emptiness can be checked without relying on the input's truthiness.
    values = list(series)
    if not values:
        return []

    smoothed_series = []
    running = values[0]
    keep = 1 - smoothing_factor
    for value in values:
        running = running * smoothing_factor + value * keep
        smoothed_series.append(running)
    return smoothed_series


# Main defence loss per epoch. detach() drops autograd tracking first, so the
# tensors can be converted for plotting even if they were logged with
# requires_grad=True (it is a no-op for tensors that do not require grad).
main_losses_series = [r.detach().cpu().double() for r in defence_results['epoch_main_losses']]
smoothed_main_losses_series = smooth_series(main_losses_series)
plt.plot(smoothed_main_losses_series, label='smoothed_main_loss')

# Raw series drawn faintly behind the smoothed curve.
plt.plot(main_losses_series, label='main_loss', alpha=0.3)
#plt.plot([r.cpu().double() for r in defence_results['epoch_reg_losses']], label='reg_loss')
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
# Regularisation loss per epoch on a linear y-axis. detach() drops autograd
# tracking first, so the tensors can be converted for plotting even if they
# were logged with requires_grad=True.
# plt.plot([r.cpu().double() for r in defence_results['epoch_main_losses']], label='main_loss')
plt.plot([r.detach().cpu().double() for r in defence_results['epoch_reg_losses']], label='reg_loss')
# plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# Inspect everything custom_defence returned as its second value (a bare
# expression so the notebook echoes it).
defence_results

In [None]:
# Rebind `model` to whatever reset_defended_module returns for the current
# defence configuration.
# NOTE(review): presumably this restores the defended module(s) to their
# pre-defence state — confirm in immunization_utils_dev.
model = reset_defended_module(model, defence_config, kwargs)