In [1]:
import os
import re
import random
import json
from tqdm import tqdm
import pickle
import pandas as pd
import numpy as np
from dotenv import load_dotenv

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Dict, Union

import transformers
import torch
import torch.nn.functional as F

import openai
import datasets
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, load_dataset

import plotly.graph_objects as go
import plotly.express as px

from utils import untuple
from get_activations import gen_pile_data, compare_token_lists, slice_acts

from act_add.model_wrapper import ModelWrapper
from act_add.rep_reader import RepReader, CAARepReader, PCARepReader
from act_add.contrast_dataset import ContrastDataset

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name_or_path = "EleutherAI/pythia-12b"
file_path = 'data/12b'

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map="auto").eval()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

mw = ModelWrapper(model = model, tokenizer = tokenizer)
all_mem_12b_data = pd.read_csv(f'{file_path}/mem_evals_gen_data.csv')

Loading checkpoint shards: 100%|██████████| 3/3 [00:21<00:00,  7.16s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
class SteeringPipeline():
    def __init__(self, model_wrapper : ModelWrapper, contrast_dataset : ContrastDataset, rep_reader : RepReader):
        self.model_wrapper = model_wrapper
        self.contrast_dataset = contrast_dataset
        self.rep_reader = rep_reader
                            
        self.model_wrapper.wrap_all()
    def gen_dir_from_states(self, 
                            pos_hidden_states,
                            neg_hidden_states,
                            hidden_layers : Union[List[int], int] = -1,
                            n_difference : int = 1,
                            train_labels: List[int] = None,):
        
        if not isinstance(hidden_layers, list): 
            assert isinstance(hidden_layers, int)
            hidden_layers = [hidden_layers]
        
        assert pos_hidden_states.shape[0] == neg_hidden_states.shape[0], "pos and neg hidden states must have same number of examples"
        
        #*this is if shape is n_examples x n_layers x n_hidden
        # interweaved = [torch.stack([pos_hidden_states[i], neg_hidden_states[i]], dim = 0) for i in range(pos_hidden_states.shape[0])]
        # hidden_states = torch.cat(interweaved, dim=0)
        
        #*this is if shape is n_layers x n_examples x n_hidden
        interweaved = [torch.stack([pos_hidden_states[:, i], neg_hidden_states[:, i]], dim = 1) for i in range(pos_hidden_states.shape[1])]
        hidden_states = torch.cat(interweaved, dim=1)
        
        relative_hidden_states = self._gen_rel_states(hidden_states, hidden_layers, n_difference)
        
        return self._gen_dir(hidden_states, 
                             relative_hidden_states, 
                             hidden_layers, 
                             train_labels)
        
    def _gen_dir(self,       
                hidden_states,   
                relative_hidden_states,               
                hidden_layers : Union[List[int], int] = -1,
                train_labels: List[int] = None,
                        ):
        
        # get the directions
        directions = self.rep_reader.get_rep_directions(
            self.model_wrapper.model, self.model_wrapper.tokenizer, relative_hidden_states, hidden_layers,
            train_choices=train_labels)

        for layer in self.rep_reader.directions:
            if type(self.rep_reader.directions[layer]) == np.ndarray:
                self.rep_reader.directions[layer] = self.rep_reader.directions[layer].astype(np.float32)

        self.rep_reader.direction_signs = self.rep_reader.get_signs(
            hidden_states, train_labels, hidden_layers)
        
        return self.rep_reader.directions
        
    def _gen_rel_states(self, hidden_states, hidden_layers, n_difference):
        #*hidden_states should be a tensor or tuple of tensors of shape (n_layers, n_examples, n_hidden)
        
        if isinstance(hidden_states, dict):
            relative_hidden_states = {k: np.copy(v) for k, v in hidden_states.items()}
        else:
            relative_hidden_states = {k: np.copy(hidden_states[k]) for k in range(hidden_states.shape[0])}
        
        if isinstance(self.rep_reader, PCARepReader):
            # get differences between pairs
            for layer in hidden_layers:
                for _ in range(n_difference):
                    relative_hidden_states[layer] = relative_hidden_states[layer][::2] - relative_hidden_states[layer][1::2]
        elif isinstance(self.rep_reader, CAARepReader):
            #* IMPORTANT: All RepReaders expects that the order of the training data is alternating like: [p, n, p, n, ...]
                for layer in hidden_layers:
                    relative_hidden_states[layer] = relative_hidden_states[layer][::2] - relative_hidden_states[layer][1::2]
        
        return relative_hidden_states
                        
    def gen_dir_from_strings(self, 
                        train_inputs: Union[str, List[str], List[List[str]]], 
                        rep_token_idx : int = -1, 
                        hidden_layers : Union[List[int], int] = -1,
                        n_difference : int = 1,
                        train_labels: List[int] = None,):
        self.model_wrapper.reset()
        
        if not isinstance(hidden_layers, list): 
            assert isinstance(hidden_layers, int)
            hidden_layers = [hidden_layers]

        # get raw hidden states for the train inputs
        hidden_states = self.model_wrapper.model.batched_string_to_hiddens(train_inputs, 
                                                        hidden_layers, 
                                                        rep_token_idx, 
                                                        )
        relative_hidden_states = self._gen_rel_states(hidden_states, hidden_layers, n_difference)
        
        return self._gen_dir(hidden_states, 
                             relative_hidden_states, 
                             hidden_layers, 
                             train_labels)

        
    def batch_steering_generate(self, 
                                inputs : List[str], 
                                layers_to_intervene : List[int],
                                coeff : float = 1.0,
                                token_pos : Union[str, int] = None,
                                batch_size=8, 
                                use_tqdm=True,
                                **generation_kwargs,
                                ):
        
        assert self.rep_reader.directions is not None, "Must generate rep_reader directions first"
        
        #? do i need to do half() here?
        steering_vectors = {}
        for layer in layers_to_intervene:
            steering_vectors[layer] = torch.tensor(coeff * self.rep_reader.directions[layer] * self.rep_reader.direction_signs[layer]).to(self.model_wrapper.model.device).half()

        self.model_wrapper.reset()
        self.model_wrapper.set_controller(layers_to_intervene, steering_vectors, masks=1, token_pos = token_pos)
        generated = []

        iterator = tqdm(range(0, len(inputs), batch_size)) if use_tqdm else range(0, len(inputs), batch_size)

        for i in iterator:
            inputs_b = inputs[i:i+batch_size]
            decoded_outputs = self.model_wrapper.batch_generate_from_string(inputs_b, **generation_kwargs)
            decoded_outputs = [o.replace(i, "") for o,i in zip(decoded_outputs, inputs_b)]
            generated.extend(decoded_outputs)

        self.model_wrapper.reset()
        return generated
    
        

In [4]:
mem_hiddens = torch.load(f'{file_path}/mem_all_hidden_states.pt')
pile_hiddens = torch.load(f'{file_path}/pile_all_hidden_states.pt')

In [5]:
mem_12b_data = all_mem_12b_data[all_mem_12b_data['char_by_char_similarity'] == 1]
unmem_12b_data = all_mem_12b_data[all_mem_12b_data['char_by_char_similarity'] <= 0.55]

print(mem_12b_data.shape)
print(unmem_12b_data.shape)

mem_pythia_idxs = mem_12b_data[mem_12b_data['source'] == 'pythia-evals']['idx_in_hidden_states'].values
mem_pile_idxs = mem_12b_data[mem_12b_data['source'] == 'pile']['idx_in_hidden_states'].values
unmem_pythia_idxs = unmem_12b_data[unmem_12b_data['source'] == 'pythia-evals']['idx_in_hidden_states'].values
unmem_pile_idxs = unmem_12b_data[unmem_12b_data['source'] == 'pile']['idx_in_hidden_states'].values

mem_hidden_states = torch.cat([mem_hiddens[mem_pythia_idxs], pile_hiddens[mem_pile_idxs]], dim = 0)
unmem_hidden_states = torch.cat([mem_hiddens[unmem_pythia_idxs], pile_hiddens[unmem_pile_idxs]], dim = 0)

(4306, 7)
(3373, 7)


## eval code

In [7]:
N = 3373

def get_ground_truth_strings(data, low = 0.9, high =0.99):
    ground_truth_strings = data[data['char_by_char_similarity'].between(low, high)]['ground'].tolist()
    return ground_truth_strings

def gen_eval_data(ground_strings, tokenizer, input_length = 32, max_length = 64):
    tokens = tokenizer(ground_strings, padding = True, truncation = True, max_length = max_length, return_tensors = 'pt')
    inputs = tokenizer.batch_decode(tokens['input_ids'][:, :input_length])
    targets = tokenizer.batch_decode(tokens['input_ids'][:, input_length:])
    return inputs, targets

memmed_ground_truth_strings = get_ground_truth_strings(all_mem_12b_data, low = 1, high = 1)

unseen_memmed_ground = memmed_ground_truth_strings[N:N + 10]
inputs, targets = gen_eval_data(unseen_memmed_ground, tokenizer)

In [8]:
from sentence_transformers import SentenceTransformer, util

sim_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
def sim_scores(outputs, targets):
    semantic_scores_gen = []
    for target, output in zip(targets, outputs):
        embedding1 = sim_model.encode(target, convert_to_tensor=True)
        embedding2 = sim_model.encode(output, convert_to_tensor=True)
        cosine_sim_gen = util.pytorch_cos_sim(embedding1, embedding2)
        similarity_value_gen = cosine_sim_gen.item()
        semantic_scores_gen.append(similarity_value_gen)
    
    return semantic_scores_gen 

def char_by_char_similarity(outputs, targets):
    similarities = []
    for o, t in zip(outputs, targets):
        o = re.sub(r'\s', '', o)
        t = re.sub(r'\s', '', t)

        o = o.lower()
        t = t.lower()

        # remove '<|endoftext|>'
        o = o.replace('<|endoftext|>', '')
        t = t.replace('<|endoftext|>', '')

        max_len = max(len(o), len(t))
        matches = [c1 == c2 for c1, c2 in zip(o, t)]
        
        similarities.append(sum(matches)/max_len if max_len > 0 else 0)
    return similarities

def compare_token_lists(ground_toks, genned_toks):
    if len(ground_toks) != len(genned_toks):
        # print(len(ground_toks), len(genned_toks))
        # print("Both lists do not have the same length.")
        return 0
    
    num_same_tokens = sum(1 for token1, token2 in zip(ground_toks, genned_toks) if token1 == token2)
    percent_same_tokens = (num_same_tokens / len(ground_toks)) 
    
    return percent_same_tokens

def tok_by_tok_similarity(outputs, targets):
    o_tokens = tokenizer(outputs, return_tensors = 'pt',padding = False, truncation = True, max_length = 64)['input_ids']
    t_tokens = tokenizer(targets, return_tensors = 'pt',padding = False, truncation = True, max_length = 64)['input_ids']
    print(o_tokens)
    print(t_tokens)
    return [compare_token_lists(t, o) for t, o in zip(t_tokens, o_tokens)]

def eval_completions(outputs, targets):
    cbc_sims = char_by_char_similarity(outputs, targets)
    # tbt_sims = tok_by_tok_similarity(outputs, targets)
    sem_sims = sim_scores(outputs, targets)

    return {'char_by_char_similarity': np.mean(cbc_sims),
            # 'tok_by_tok_similarity': np.mean(tbt_sims),
            'sem_similarity': np.mean(sem_sims)}


## pure mem - random unmem pile

In [6]:
N = unmem_hidden_states.shape[0]
mem_contra_dataset = ContrastDataset(mem_12b_data['gen'].tolist()[:N], 
                               unmem_12b_data['gen'].tolist()[:N], 
                               model_name_or_path,
                               use_convo_format=False,
                               system_prompt="")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [34]:
mem_rep_reader = CAARepReader()
mw = ModelWrapper(model = model, tokenizer = tokenizer)

mem_steering_pipeline = SteeringPipeline(mw, mem_contra_dataset, mem_rep_reader)

In [35]:
TOKEN_IDX = 9
N_LAYERS = model.config.num_hidden_layers
rep_token_idx = -1
hidden_layers = list(range(model.config.num_hidden_layers))
n_difference = None
train_labels = None

mem_rr_hidden_states = mem_hidden_states[:N, :, TOKEN_IDX, :].reshape(N_LAYERS, N, -1)
unmem_rr_hidden_states = unmem_hidden_states[:N, :, TOKEN_IDX, :].reshape(N_LAYERS, N, -1)

dirs = mem_steering_pipeline.gen_dir_from_states(mem_rr_hidden_states, unmem_rr_hidden_states, hidden_layers, n_difference, train_labels)

In [113]:
# layer_id = list(range(15,25))
layer_id = [30, 35]

batch_size=50
coeff=1 # tune this parameter
max_new_tokens=32

print(f"Coeff: {coeff}")
print(f"LAYERS: {layer_id}")
print("RepReader:")
print("No Control")
# baseline_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
#                                                                 layer_id, 
#                                                                 coeff = 0 * coeff, 
#                                                                 batch_size = batch_size, 
#                                                                 use_tqdm=True, 
#                                                                 max_new_tokens=max_new_tokens)

print(eval_completions(baseline_outputs, targets))

print("+ Memorization")
pos_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                            layer_id, 
                                                            coeff = coeff, 
                                                            batch_size = batch_size, 
                                                            use_tqdm=True, 
                                                            max_new_tokens=max_new_tokens)
print(eval_completions(pos_outputs, targets))

print("- Memorization")
neg_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                            layer_id, 
                                                            coeff = -coeff, 
                                                            batch_size = batch_size, 
                                                            use_tqdm=True, 
                                                            max_new_tokens=max_new_tokens)
print(eval_completions(neg_outputs, targets))

Coeff: 1
LAYERS: [30, 35]
RepReader:
No Control
{'char_by_char_similarity': 0.9044715447154472, 'sem_similarity': 0.9756519377231598}
+ Memorization


100%|██████████| 1/1 [00:35<00:00, 35.40s/it]


{'char_by_char_similarity': 0.7586262106140647, 'sem_similarity': 0.9292394697666169}
- Memorization


100%|██████████| 1/1 [00:35<00:00, 35.36s/it]


{'char_by_char_similarity': 0.5532528697283654, 'sem_similarity': 0.7501227796077728}


In [110]:
targets

[' Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an "AS IS" BASIS,\n * WITHOUT WARRANTIES',
 '_CTL_D1                                0x28b8\n#define MC_SEQ_WR_CTL_D0                                0x28bc\n',
 '                }\n            },\n            "axisTick": {\n                "show": false,\n                "lineStyle": {\n                    "color": "#',
 '\n\tposition: fixed;\n\ttop: 50%;\n\tleft: 50%;\n\tmargin-top: -22px;\n\tmargin-',
 ' 2027.......... 677.68\nApril 2027.......... 677.68\nMay 2027............ 677.68\nJune 2027',
 'MENT";\n  public static final String ER_PROCESS_ERROR = "ER_PROCESS_ERROR";\n  public static final String ER_UN',
 '\tSec  int32\n\tUsec int32\n}\n\nfunc (tv *Timeval) Nanoseconds() int64 {\n\treturn',
 ' is disfavored except for establishing res judicata, estoppel, or the law of the case and requires service of copies of cited unpublished dispositions of the Sixth Circu

In [111]:
neg_outputs

[' Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an “AS IS” BASIS,\n * WITHOUT WARRANTIES',
 '_CTL_D1                                0x28b8\n#define MC_SEQ_RD_CTL_D2                                0x28bc\n',
 '                \n                }\n            },\n            "axisTick": {\n                "show": true,\n                "lineStyle": {\n                    " color',
 '\n background-image: url(fancybox_loading.gif);\n background-repeat: no-repeat center center;\n background-position: center',
 ' 2027.......... 677.68\nApril 2027.......... 677.68\nMay 2027............ 677.68\nJune 2027',
 'MENT";\n public static final String ER_PROCESS_ is not a known element. The document is the root document of the document. in the document.',
 ',\n,\n,\n,\n,\n,\n,\n,\n,\n,\n,\n,\n,\n,\n,\n,\n',
 ', be sure to Follow us too.107 F.3d 11\nNOTICE: Sixth Circuit Rule 24(c) states that citation of unpublished disposition

In [112]:
pos_outputs

[' Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an "AS IS" BASIS,\n * WITHOUT WARRANTIES',
 '_CTL_D1                                0x28b8\n#define MC_SEQ_WR_CTL                                   0x28bc\n#define',
 '                }\n            },\n            "axisTick": {\n                "show": false,\n                "lineStyle": {\n                    "color": "#',
 '\n\tbackground-image: url(fancybox_loading.png);\n}\n\n#fancybox-error {\n\tbackground-image',
 ' 2027.......... 677.68\nApril 2027.......... 677.68\nMay 2027............ 677.68\nJune 2027',
 'MENT";\n\n  public static final String ER_NO_OUTPUT_SPECIFIER = "ER_NO_OUTPUT_SPECIFIER";\n  public',
 '\tSec  int64\n\tUsec int64\n}\n\nfunc NewPopulatedMessage(r int64) *Message_Container.MessageBuilder {',
 ', be sure to Follow us too.107 F.3d 11\nNOTICE: Sixth Circuit Rule 24(c) states that citation of unpublished dispositions is disfavored except 

## probe direction

In [9]:
from probes import LRProbe
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import LogisticRegression

# Combine memmed and non_memmed hidden states
np_mem = mem_hidden_states.cpu().float()[:unmem_hidden_states.shape[0]]
np_unmem = unmem_hidden_states.cpu().float()

LAYER = 34
TOK_IDXS = [5, 6, 7, 8, 9]

X = torch.cat([np_mem[:, LAYER, tok_idx, :] for tok_idx in TOK_IDXS] +  [np_unmem[:, LAYER, tok_idx, :] for tok_idx in TOK_IDXS])
y = torch.cat([torch.ones(np_mem.shape[0]) for _ in TOK_IDXS] + [torch.zeros(np_unmem.shape[0]) for _ in TOK_IDXS])

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X.numpy(), y.numpy(), test_size=0.40, random_state=42)

#Pytorch code
X_train = torch.from_numpy(X_train)
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test)
y_test = torch.from_numpy(y_test)

lr = 0.001
weight_decay = 0.1
epochs = 1000
use_bias = False
normalize = True

if normalize:
    X_train = (X_train - X_train.mean(dim = 0)) / X_train.std(dim = 0)
    X_test = (X_test - X_train.mean(dim = 0)) / X_train.std(dim = 0)

probe = LRProbe.from_data(X_train, y_train, 
                          lr = lr, 
                          weight_decay = weight_decay, 
                          epochs = epochs, 
                          use_bias = use_bias,
                          device = "cuda", )

acc = probe.get_probe_accuracy(X_test, y_test, device="cuda")
auc = probe.get_probe_auc(X_test, y_test, device="cuda")

In [11]:
acc, auc

(0.9409279823303223, 0.9715853321420392)

In [12]:
from act_add.rep_reader import ProbeRepReader

probe_rep_reader = ProbeRepReader({
    34: probe
})
probe_rep_reader.get_rep_directions([34])
probe_rep_reader.get_signs([34])

{34: 1}

In [13]:
probe_rep_reader.directions

{34: array([ 0.02265903,  0.07398675, -0.03395739, ..., -0.00919325,
         0.07206346,  0.07650761], dtype=float32)}

In [14]:
mw = ModelWrapper(model = model, tokenizer = tokenizer)

mem_steering_pipeline = SteeringPipeline(mw, None, probe_rep_reader)

In [21]:
# layer_id = list(range(15,25))
layer_id = [34]

batch_size=50
coeff=5 # tune this parameter
max_new_tokens=32

print(f"Coeff: {coeff}")
print(f"LAYERS: {layer_id}")
print("RepReader:")
print("No Control")
# baseline_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
#                                                                 layer_id, 
#                                                                 coeff = 0 * coeff, 
#                                                                 batch_size = batch_size, 
#                                                                 use_tqdm=True, 
                                                                # max_new_tokens=max_new_tokens)

print(eval_completions(baseline_outputs, targets))

print("+ Memorization")
pos_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                            layer_id, 
                                                            coeff = coeff, 
                                                            batch_size = batch_size, 
                                                            use_tqdm=True, 
                                                            max_new_tokens=max_new_tokens)
print(eval_completions(pos_outputs, targets))

print("- Memorization")
neg_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                            layer_id, 
                                                            coeff = -coeff, 
                                                            batch_size = batch_size, 
                                                            use_tqdm=True, 
                                                            max_new_tokens=max_new_tokens)
print(eval_completions(neg_outputs, targets))

Coeff: 5
LAYERS: [34]
RepReader:
No Control
{'char_by_char_similarity': 0.9044715447154472, 'sem_similarity': 0.9756519377231598}
+ Memorization


100%|██████████| 1/1 [00:34<00:00, 34.69s/it]


{'char_by_char_similarity': 0.8616144018583043, 'sem_similarity': 0.9660744369029999}
- Memorization


100%|██████████| 1/1 [00:34<00:00, 34.66s/it]


{'char_by_char_similarity': 0.9044715447154472, 'sem_similarity': 0.9756519377231598}


In [22]:
targets

[' Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an "AS IS" BASIS,\n * WITHOUT WARRANTIES',
 '_CTL_D1                                0x28b8\n#define MC_SEQ_WR_CTL_D0                                0x28bc\n',
 '                }\n            },\n            "axisTick": {\n                "show": false,\n                "lineStyle": {\n                    "color": "#',
 '\n\tposition: fixed;\n\ttop: 50%;\n\tleft: 50%;\n\tmargin-top: -22px;\n\tmargin-',
 ' 2027.......... 677.68\nApril 2027.......... 677.68\nMay 2027............ 677.68\nJune 2027',
 'MENT";\n  public static final String ER_PROCESS_ERROR = "ER_PROCESS_ERROR";\n  public static final String ER_UN',
 '\tSec  int32\n\tUsec int32\n}\n\nfunc (tv *Timeval) Nanoseconds() int64 {\n\treturn',
 ' is disfavored except for establishing res judicata, estoppel, or the law of the case and requires service of copies of cited unpublished dispositions of the Sixth Circu

In [23]:
pos_outputs

[' Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an "AS IS" BASIS,\n * WITHOUT WARRANTIES',
 '_CTL_D1                                0x28b8\n#define MC_SEQ_WR_CTL_D0                                0x28bc\n',
 '                }\n            },\n            "axisTick": {\n                "show": false,\n                "lineStyle": {\n                    "color": "#',
 '\n\tposition: fixed;\n\ttop: 50%;\n\tleft: 50%;\n\twidth: 40px;\n\theight: 40px;',
 ' 2027.......... 677.68\nApril 2027.......... 677.68\nMay 2027............ 677.68\nJune 2027',
 'MENT";\n  public static final String ER_PROCESS_ERROR = "ER_PROCESS_ERROR";\n  public static final String ER_UN',
 '\tSec  int32\n\tUsec int32\n}\n\nfunc (tv *Timeval) Nanoseconds() int64 {\n\treturn',
 ', be sure to Follow us too.107 F.3d 11\nNOTICE: Sixth Circuit Rule 24(c) states that citation of unpublished dispositions is disfavored except for establishing res judic

In [24]:
neg_outputs

[' Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an "AS IS" BASIS,\n * WITHOUT WARRANTIES',
 '_CTL_D1                                0x28b8\n#define MC_SEQ_WR_CTL_D0                                0x28bc\n',
 '                }\n            },\n            "axisTick": {\n                "show": false,\n                "lineStyle": {\n                    "color": "#',
 '\n\tposition: fixed;\n\ttop: 50%;\n\tleft: 50%;\n\tmargin-top: -22px;\n\tmargin-',
 ' 2027.......... 677.68\nApril 2027.......... 677.68\nMay 2027............ 677.68\nJune 2027',
 'MENT";\n  public static final String ER_PROCESS_ERROR = "ER_PROCESS_ERROR";\n  public static final String ER_UN',
 '\tSec  int32\n\tUsec int32\n}\n\nfunc (tv *Timeval) Nanoseconds() int64 {\n\treturn',
 ', be sure to Follow us too.107 F.3d 11\nNOTICE: Sixth Circuit Rule 24(c) states that citation of unpublished dispositions is disfavored except for establishing res judic

## pure mem - reshuffled mem

In [6]:
import random
# N = unmem_hidden_states.shape[0]
N = 100

shuffled_memmed_prompts = []
for prompt in mem_12b_data['gen'].tolist():
    tokens = tokenizer.tokenize(prompt)
    random.shuffle(tokens)
    detokenized_prompt = tokenizer.convert_tokens_to_string(tokens)
    shuffled_memmed_prompts.append(detokenized_prompt)

mem_contra_dataset = ContrastDataset(mem_12b_data['gen'].tolist()[:N], 
                               shuffled_memmed_prompts[:N], 
                               model_name_or_path,
                               use_convo_format=False,
                               system_prompt="")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
mem_rep_reader = CAARepReader()
mw = ModelWrapper(model = model, tokenizer = tokenizer)

mem_steering_pipeline = SteeringPipeline(mw, mem_contra_dataset, mem_rep_reader)

In [16]:
TOKEN_IDX = 9
N_LAYERS = model.config.num_hidden_layers
rep_token_idx = -1
hidden_layers = list(range(model.config.num_hidden_layers))
n_difference = None
train_labels = None

mem_rr_hidden_states = mem_hidden_states[:N, :, TOKEN_IDX, :].reshape(N_LAYERS, N, -1)
mem_rr_hidden_states.shape

torch.Size([36, 100, 5120])

In [18]:
shuff_mem_rr_hidden_states = mw.batched_string_to_hiddens(shuffled_memmed_prompts[:N],
                                                          layers = hidden_layers,
                                                          token_idx=rep_token_idx)

In [19]:
shuff_mem_rr_hidden_states = torch.stack([shuff_mem_rr_hidden_states[i] for i in shuff_mem_rr_hidden_states.keys()], dim = 0)

In [21]:
dirs = mem_steering_pipeline.gen_dir_from_states(mem_rr_hidden_states, shuff_mem_rr_hidden_states, hidden_layers, n_difference, train_labels)

In [30]:
# layer_id = list(range(15,25))
layer_id = [30]

batch_size=50
coeff=0.2 # tune this parameter
max_new_tokens=32

print(f"Coeff: {coeff}")
print(f"LAYERS: {layer_id}")
print("RepReader:")
print("No Control")
baseline_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                                layer_id, 
                                                                coeff = 0 * coeff, 
                                                                batch_size = batch_size, 
                                                                use_tqdm=True, 
                                                                max_new_tokens=max_new_tokens)

print(eval_completions(baseline_outputs, targets))

print("+ Memorization")
pos_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                            layer_id, 
                                                            coeff = coeff, 
                                                            batch_size = batch_size, 
                                                            use_tqdm=True, 
                                                            max_new_tokens=max_new_tokens)
print(eval_completions(pos_outputs, targets))

print("- Memorization")
neg_outputs = mem_steering_pipeline.batch_steering_generate(inputs, 
                                                            layer_id, 
                                                            coeff = -coeff, 
                                                            batch_size = batch_size, 
                                                            use_tqdm=True, 
                                                            max_new_tokens=max_new_tokens)
print(eval_completions(neg_outputs, targets))

Coeff: 0.2
LAYERS: [30]
RepReader:
No Control


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:34<00:00, 34.70s/it]


{'char_by_char_similarity': 0.7090909090909092, 'sem_similarity': 0.9499296605587005}
+ Memorization


100%|██████████| 1/1 [00:34<00:00, 34.66s/it]


{'char_by_char_similarity': 0.46361896812594566, 'sem_similarity': 0.8107414603233337}
- Memorization


100%|██████████| 1/1 [00:34<00:00, 34.67s/it]


{'char_by_char_similarity': 0.6406418256512403, 'sem_similarity': 0.9094023406505585}


In [31]:
targets

['intercom:before {\n   content: "\\f7af"; }\n \n.fa-internet-explorer:before {\n   content:',
 '\n      <sourceFolder url="file://$MODULE_DIR$/src/debug/aidl" isTestSource="false" />\n      <source',
 ' -16084379, -28926210, 15006023, 3284568, -6276540},\n\t\t\tFieldElement{23599295,',
 '��\ue024\ue025\ue026\ue027\ue028\ue029\ue02a\ue02b\ue02c\ue02d',
 ' twice as large as those based on *F*, and *R*- factors based on ALL data will be even larger.\n  -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nF',
 '.0

In [32]:
pos_outputs

['instagram-square:before {\n   content: "\\f955"; }\n \n.fa-intercom:before {\n   content: "\\f',
 '\n      <sourceFolder url="file://$MODULE_DIR$/src/debug/aidl" isTestSource="false" />\n      <source',
 ' -10864081, -818919, 1359789},\n\t\t\tFieldElement{14076899, -15673580, -',
 '\ue023\ue024\ue025\ue026\ue027\ue028\ue029\ue02a\ue02b\ue02c�',
 ' twice as large as those based on *F*, and *R*- factors based on ALL data will be even larger.\n  -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nF',
 '.0.1",\n

In [33]:
neg_outputs

['instagram-square:before {\n   content: "\\f081"; }\n \n.fa-intercom:before {\n   content: "\\f',
 '\n      <sourceFolder url="file://$MODULE_DIR$/src/debug/aidl" isTestSource="false" />\n      <source',
 ' -16084379, -28926210, 15006023, -3633890, -18942047, -10055357},\n\t\t\t',
 '\ue023\ue024\ue025\ue026\ue027\ue028\ue029\ue02a\ue02b\ue02c�',
 ' twice as large as those based on *F*, and *R*- factors based on ALL data will be even larger.\n  -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nF',
 '.0.1",\n