# MEMIT de-bias

This notebook explores the basic procedure of applying MEMIT as a bias mitigation strategy. It loads a model, tokenizer and its hyperparameters, applies the update for a given set of rewrites and compared predictions of the original and de-biased model.

In [1]:
import os, json
from copy import deepcopy

from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from memit.compute_ks import compute_ks
from rome.layer_stats import layer_stats
from util import nethook
from util.generate import generate_fast
from util.globals import *
from util.hparams import HyperParams
from memit.compute_ks import compute_ks
from memit.compute_z import compute_z, get_module_input_output_at_words, find_fact_lookup_idx
from memit.memit_hparams import MEMITHyperParams
from memit.memit_main import apply_memit_to_model, execute_memit

from experiments.causal_trace import (
    ModelAndTokenizer,
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)

## (1) Initialize model and load update prompts and hyperparameters
Load the (original, unedited) model which should be updated as well as the update prompts and hyperparameters.

In [2]:
# Load anti-stereotypes
requests = []
with open("data/rewrite_prompts/rewrite_prompts_malteos-gpt2-xl-wechsel-german.json", "r") as f:
    requests = json.load(f)

# Initialize original model
MODEL_NAME = "malteos/gpt2-xl-wechsel-german"
model, tok = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=(torch.float16 if "20b" in MODEL_NAME else None),
    ).to("cuda"),
    AutoTokenizer.from_pretrained(MODEL_NAME, ),
)
tok.pad_token = tok.eos_token
model.config

# Load hyperparameters
hparams = MEMITHyperParams.from_json("hparams/MEMIT/malteos_gpt2-xl-wechsel-german.json")

## (2) Apply the update
Update model weights with anti-stereotypes. The function `apply_memit_to_model` has been adapted from `memit.memit_main`

In [None]:
# Apply MEMIT update
model_new, weights_new = apply_memit_to_model(model, tok, requests[:5], hparams, copy=True)
output_model = './results/malteos_gpt2-xl-wechsel-german/edited_model'

def save(model, output_model):
    """Save the edited model"""
    
    torch.save({
        'model_state_dict': model.state_dict(),
    }, output_model)


In [5]:
# Save edited model
model_new.save_pretrained(output_model)

## (3) Re-load and test the edited model
To observe effects of de-biasing reload the edited model and generate predictions with the un-debiased and de-biased model for comparison.

In [None]:
# Reload edited model
model_reloaded, tok_reloaded = (
    AutoModelForCausalLM.from_pretrained(
        output_model,
        torch_dtype=(torch.float16 if "20b" in MODEL_NAME else None),
    ).to("cuda"),
    AutoTokenizer.from_pretrained(MODEL_NAME, ),
)
tok_reloaded.pad_token = tok_reloaded.eos_token
model_reloaded.config

In [19]:
# Generate sample text with edited model
mt_edited = ModelAndTokenizer(model=model_reloaded, tokenizer=tok_reloaded)
predict_token(mt_edited, ["All stereotypical princesses are"], return_p=True,)

([' bad'], tensor([0.0364], device='cuda:0', grad_fn=<MaxBackward0>))

In [20]:
# Generate sample text with original model
mt_original = ModelAndTokenizer(model=model, tokenizer=tok)
predict_token(mt_original, ["All stereotypical princesses are"], return_p=True,)

([' pretty'], tensor([0.0276], device='cuda:0', grad_fn=<MaxBackward0>))