In [19]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel 
import torch as t
from torch import nn
# from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
# from dictionary_learning.dictionary import IdentityDict
# from dictionary_learning.interp import examine_dimension
# from dictionary_learning.utils import hf_dataset_to_generator
from tqdm import tqdm
import gc

DEBUGGING = False

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

# model hyperparameters
DEVICE = 'mps'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
activation_dim = 512

python(14893) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
python(14918) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(14946) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(14973) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [20]:
import pickle as pkl

class Probe(nn.Module):
    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        return logits

with open("probe_shift.pkl", "rb") as f:
    probe = pkl.load(f)

In [21]:
# loading dictionaries

# dictionary hyperparameters
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim
layer = 4

submodules = []
dictionaries = {}

submodules.append(model.gpt_neox.embed_in)
dictionaries[model.gpt_neox.embed_in] = AutoEncoder.from_pretrained(
    f'/Users/maheepchaudhary/pytorch/Projects/concept_eraser_research/DAS_MAT/baulab.us/u/smarks/autoencoders/pythia-70m-deduped/embed/{dict_id}_{dictionary_size}/ae.pt',
    device=DEVICE
)
for i in range(layer + 1):
    submodules.append(model.gpt_neox.layers[i].attention)
    dictionaries[model.gpt_neox.layers[i].attention] = AutoEncoder.from_pretrained(
        f'/Users/maheepchaudhary/pytorch/Projects/concept_eraser_research/DAS_MAT/baulab.us/u/smarks/autoencoders/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )

    submodules.append(model.gpt_neox.layers[i].mlp)
    dictionaries[model.gpt_neox.layers[i].mlp] = AutoEncoder.from_pretrained(
        f'/Users/maheepchaudhary/pytorch/Projects/concept_eraser_research/DAS_MAT/baulab.us/u/smarks/autoencoders/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )

    submodules.append(model.gpt_neox.layers[i])
    dictionaries[model.gpt_neox.layers[i]] = AutoEncoder.from_pretrained(
        f'/Users/maheepchaudhary/pytorch/Projects/concept_eraser_research/DAS_MAT/baulab.us/u/smarks/autoencoders/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )

# metric fn is used to 
def metric_fn(model, labels=None):
    attn_mask = model.input[1]['attention_mask']
    acts = model.gpt_neox.layers[layer].output[0]
    acts = acts * attn_mask[:, :, None]
    acts = acts.sum(1) / attn_mask.sum(1)[:, None]
    
    return t.where(
        labels == 0,
        probe(acts),
        - probe(acts)
    )

python(14980) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15023) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15048) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15074) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15101) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [11]:
print(model)

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (a

In [12]:
from pprint import pprint
from tqdm import tqdm

text = """The quick brown fox jumps over the lazy dog"""


'''
We make a dummy model to see if gradient descent works on the model.
We will optimize the model to output the zero vector as activation in the end. 

After that we will analyse the values of each l1, l2, l3, l4 
to see if the model has learned the values to manipulate the activations of the model.
'''

class SigmoidMaskIntervention(nn.Module):

    """Intervention in the original basis with binary mask."""

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.mask = t.nn.Parameter(
            t.zeros(embed_dim), requires_grad=True)
        
        self.temperature = t.nn.Parameter(t.tensor(0.01))

    def get_temperature(self):
        return self.temperature

    def set_temperature(self, temp: t.Tensor):
        self.temperature.data = temp

    def forward(self, base, subspaces=None):
        batch_size = base.shape[0]
        # get boundary mask between 0 and 1 from sigmoid
        mask_sigmoid = t.sigmoid(self.mask / t.tensor(self.temperature)) 
        
        # interchange
        # intervened_output = (
        #     1.0 - mask_sigmoid
        # ) * base + mask_sigmoid * source
        '''I have changed the intervention to be only done on the base'''
        intervened_output = mask_sigmoid * base

        return intervened_output

    def __str__(self):
        return f"SigmoidMaskIntervention()"
    

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        
        # We have intergrated the sigmoid_mask from pyvene (https://github.com/stanfordnlp/pyvene/blob/main/pyvene/models/interventions.py) 
        
        embed_dim = (9,32768)
        self.temperature = t.nn.Parameter(t.tensor(0.01))
        
        self.embed_mask = t.nn.Parameter(t.ones(embed_dim), requires_grad=True)
        self.embed_mask_sigmoid = t.sigmoid(self.embed_mask / t.tensor(self.temperature)) 
        
        self.l1_mask = t.nn.Parameter(t.ones(embed_dim), requires_grad=True)
        self.l1_mask_sigmoid = t.sigmoid(self.l1_mask / t.tensor(self.temperature))
        
        self.l2_mask = t.nn.Parameter(t.ones(embed_dim), requires_grad=True)
        self.l2_mask_sigmoid = t.sigmoid(self.l2_mask / t.tensor(self.temperature))
        
        self.l3_mask = t.nn.Parameter(t.ones(embed_dim), requires_grad=True)
        self.l3_mask_sigmoid = t.sigmoid(self.l3_mask / t.tensor(self.temperature))
        
        self.l4_mask = t.nn.Parameter(t.ones(embed_dim), requires_grad=True)
        self.l4_mask_sigmoid = t.sigmoid(self.l4_mask / t.tensor(self.temperature))
        
        # self.probe = Probe
        
    def forward(self,text):
        
        acts = self.get_acts(text, 0, 'None', 1)
        acts = self.embed_mask * acts
        acts = self.get_acts(text, 1, acts, 2)
        acts = self.l1_mask * acts
        acts = self.get_acts(text, 2, acts, 3)
        acts = self.l2_mask * acts
        acts = self.get_acts(text, 3, acts, 4)
        acts = self.l3_mask * acts
        acts = self.get_acts(text, 4, acts, 5)
        acts = self.l4_mask * acts
        acts = self.get_acts(text, 5, acts, 6)
        # acts = self.probe(acts)
    
        return acts

    '''
    The get_acts function is used to get the activations of the model at a particular layer 
    after getting intervened at a particular layer.
    '''
    
    def get_acts(self, text, intervention_layer, acts, get_act_layer):
        with model.trace(text):
            i = 0
            for module in submodules:
                
                if type(module.output.shape) != tuple:
                    
                    if acts == 'None':
                        new_acts = module.output[0].save()
                        dictionary = dictionaries[module]
                        new_acts = dictionary.encode(new_acts).save()
                    
                    elif get_act_layer == 6:
                        new_acts = dictionaries[module].decode(acts)
                    
                    else:    
                        if i == intervention_layer:
                            dictionary = dictionaries[module]
                            acts = dictionary.decode(acts)
                            module.output[0] = acts
                        elif i == get_act_layer:
                            new_acts = module.output[0]
                            new_acts = dictionaries[module].encode(new_acts).save()
                        
                    
                    i+=1
        return new_acts
    

In [None]:
# The aim is to find out if the model.trace also computes gradient and update the weights during backprop. 

class dummy_model(nn.Module):
    def __init__(self, model, submodules, dictionaries):
        super(self, dummy_model).__init__()
        
        self.model = model
        self.dictionaries = dictionaries
        self.submodules = submodules
        
        

In [18]:
'''
Here we will define the optimizer and all the things required to train the model.

As for the data, we will use the text repeated many times as the data.
'''

import wandb

wandb.init(project="sae_concept_eraser")

def data_processing(text):
    data = [text]*100
    target = t.zeros(100,9,512)
    return [data, target]


new_model = my_model().to(DEVICE)

optimizer = t.optim.Adam(new_model.parameters(), lr=0.01)
epochs = 4
criterion = nn.MSELoss().to(DEVICE)

text = """The quick brown fox jumps over the lazy dog"""

data, target = data_processing(text)
target = target.to(DEVICE)


for epoch in tqdm(range(epochs)):
    for i in tqdm(range(len(data))):
        optimizer.zero_grad()
        predicted = new_model(data[i])
        loss = criterion(predicted, target[i])
        loss.backward(retain_graph=True)
        optimizer.step()

    print(f"Epoch: {epoch}, Loss: {loss}")
    wandb.log({"Epochs": epoch, "Loss": loss.item()})

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingfac

  self.embed_mask_sigmoid = t.sigmoid(self.embed_mask / t.tensor(self.temperature))
  self.l1_mask_sigmoid = t.sigmoid(self.l1_mask / t.tensor(self.temperature))
  self.l2_mask_sigmoid = t.sigmoid(self.l2_mask / t.tensor(self.temperature))
  self.l3_mask_sigmoid = t.sigmoid(self.l3_mask / t.tensor(self.temperature))
  self.l4_mask_sigmoid = t.sigmoid(self.l4_mask / t.tensor(self.temperature))
100%|██████████| 100/100 [01:14<00:00,  1.35it/s]
 25%|██▌       | 1/4 [01:14<03:42, 74.13s/it]

Epoch: 0, Loss: 0.10332609713077545


100%|██████████| 100/100 [00:50<00:00,  1.97it/s]
 50%|█████     | 2/4 [02:04<02:00, 60.44s/it]

Epoch: 1, Loss: 0.015482081100344658


100%|██████████| 100/100 [00:48<00:00,  2.07it/s]
 75%|███████▌  | 3/4 [02:53<00:54, 54.90s/it]

Epoch: 2, Loss: 0.008295994251966476


100%|██████████| 100/100 [00:47<00:00,  2.09it/s]
100%|██████████| 4/4 [03:41<00:00, 55.28s/it]

Epoch: 3, Loss: 0.005572018679231405





In [None]:
new_model = t.save(new_model, "model.pth")

In [None]:
'''
To compare the weights of the model with the original model.
- We will compare the weights of the model with the weights of 
submodules and all. 
- One of the other things that we could do is compare the weights of dictionary with 
the initial dictionary weights. 
- Finally, if these 2 things gets fulfilled, we can compare the weights of the model
with the weights of the original model.


Based on Dr. Geiger text, I will just assume that the weights of the model will remain same and 
will start building the whole model for gender prediction. 

#TODO: Integrate mask and probe into the model. 
#TODO: Run and train the model on the gender dataset. 
'''

def compare_weights(model, submodules):
    initial_state = model.gpt_neox.embed_in.state_dict()
    for module in submodules:
        trained_state = module.state_dict()
        # print(f"Initial weights for {module}: {initial_state}")
        # print(f"Trained weights for {module}: {trained_state}")
        if initial_state != trained_state:
            print(f"Weight for module has been updated!")
            print(initial_state)
            print()
            print(trained_state)
        else:
            print(f"Weight for module remains unchanged")
        break

compare_weights(new_model, submodules)

In [None]:
print(new_model.embed_mask_sigmoid)
print(new_model.l1_mask_sigmoid)
print(new_model.l2_mask_sigmoid)
print(new_model.l3_mask_sigmoid)
print(new_model.l4_mask_sigmoid)

In [None]:
print(model.gpt_neox.embed_in.weight)
print()
print(submodules[0].weight)

In [None]:

DEVICE = 'cpu'
lm = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
lm.gpt_neox.embed_in.weight  