In [None]:
!pip install huggingface_hub

In [None]:
from huggingface_hub import login
import os
login(token="")
os.environ["HF_TOKEN"] = ""

ModuleNotFoundError: No module named 'huggingface_hub'

In [2]:
# Code from RMU utils
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
random.seed(0)
from datasets import load_dataset

################################
##### Activation functions #####
################################

def forward_with_cache(model, inputs, module, no_grad=True):
    # define a tensor with the size of our cached activations
    cache = []
    def hook(module, input, output):
        if isinstance(output, tuple):
            cache.append(output[0])
        else:
            cache.append(output)
        return None 
    
    hook_handle = module.register_forward_hook(hook)
    
    if no_grad:
        with torch.no_grad():
            _ = model(**inputs)
    else:
        _ = model(**inputs)
        
    hook_handle.remove()

    return cache[0]
    
#######################################
##### Model and data loading code #####
#######################################


def get_params(model, layer_ids, param_ids):
    params = []
    for layer_id in layer_ids:
        for i, p in enumerate(model.model.layers[layer_id].parameters()):
            if i in param_ids:
                params.append(p)
    return params


def load_model(model_name_or_path):
    torch_dtype = "auto" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True, use_fast=False
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    tokenizer.mask_token_id = tokenizer.eos_token_id
    tokenizer.sep_token_id = tokenizer.eos_token_id
    tokenizer.cls_token_id = tokenizer.eos_token_id

    return model, tokenizer

def get_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
    def get_dataset(name):
        data = []
        if name == "wikitext":
            raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
            for x in raw_data:
                if len(x['text']) > min_len:
                    data.append(str(x['text']))
        else:
            assert os.getenv("HF_TOKEN"), "HF_TOKEN is not set"
            # Support both full names like "cais/wmdp-bio-forget-corpus" and short names like "bio-forget-corpus"
            dataset_name = name if "/" in name else f"cais/wmdp-{name}"
            dataset = load_dataset(dataset_name, split="train", token=os.getenv("HF_TOKEN"))
            for line in dataset:
                if len(line['text']) > min_len:
                    data.append(str(line['text']))
        data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        return data

    return (
        [get_dataset(c) for c in forget_corpora],
        [get_dataset(c) for c in retain_corpora]
    )

In [3]:
# RMU code
import os
import datetime

import numpy as np
import torch
from torch.optim import AdamW
import tqdm as tqdm

def run_rmu(
    updated_model,
    frozen_model,
    tokenizer,
    forget_data_list,
    retain_data_list,
    args,
):
    rmu_config = vars(args)
    print("====rmu Config====")
    print("\n".join(f"{k}={v}" for k,v in rmu_config.items()))
    print("=====")

    updated_model = updated_model.train()
    params = get_params(updated_model, args.layer_ids, args.param_ids)
    optimizer = AdamW(params, lr=args.lr)
    frozen_module = eval(
        args.module_str.format(model_name="frozen_model", layer_id=args.layer_id)
    )
    updated_module = eval(
        args.module_str.format(model_name="updated_model", layer_id=args.layer_id)
    )

    control_vectors_list = []
    for i in range(len(forget_data_list)):
        random_vector = torch.rand(1,1, updated_model.config.hidden_size, dtype=updated_model.dtype, device=updated_model.device)
        control_vec = random_vector / torch.norm(random_vector) * args.steering_coeff_list[i]
        control_vectors_list.append(control_vec)

    num_batches = min(
        args.max_num_batches,
        min([len(f) for f in forget_data_list]),
        min([len(r) for r in retain_data_list]),
    )
    
    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side="right"

    for epoch in range(1):
        print(f"======= Epoch {epoch} =======")
        with tqdm.tqdm(total=num_batches) as pbar:
            for idx in range(num_batches):
                topic_idx = idx % len(forget_data_list)
                batch_idx = idx // len(forget_data_list)
                control_vec = control_vectors_list[topic_idx]
                unlearn_batch = forget_data_list[topic_idx][batch_idx]
                retain_batch = retain_data_list[topic_idx][batch_idx]

                # Unlearning loss
                max_length = 512 if topic_idx == 0 else 768
                unlearn_inputs = tokenizer(
                    unlearn_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
                )
                updated_forget_activations = forward_with_cache(
                    updated_model, unlearn_inputs, module=updated_module, no_grad=False
                ).to(updated_model.device)

                unlearn_loss = torch.nn.functional.mse_loss(
                    updated_forget_activations, control_vec
                )

                # Retain loss
                retain_inputs = tokenizer(
                    retain_batch, return_tensors="pt", padding=True, truncation=True, max_length=512
                ).to(updated_model.device)
                updated_retain_activations = forward_with_cache(
                    updated_model, retain_inputs, module=updated_module, no_grad=False
                ).to(updated_model.device)
                frozen_retain_activations = forward_with_cache(
                    frozen_model, retain_inputs, module=frozen_module, no_grad=True
                ).to(updated_model.device)

                retain_loss = torch.nn.functional.mse_loss(
                    updated_retain_activations, frozen_retain_activations
                )
                retain_loss *= args.alpha[topic_idx]

                # Update model
                loss = unlearn_loss + retain_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(f"loss: {loss.item():.4g} | unlearn_loss: {unlearn_loss.item():.4g} | retain_loss: {retain_loss.item():.4g} | param_change: {params[0].grad.abs().mean().item():.4g}")
                
                # ======= Logging ======
                if args.verbose:
                    frozen_forget_activations = forward_with_cache(frozen_model, unlearn_inputs, module=frozen_module, no_grad=True).to(updated_model.device)
                    unlearn_cosine= torch.nn.functional.cosine_similarity(updated_forget_activations, frozen_forget_activations, dim=-1).mean()
                    retain_cosine = torch.nn.functional.cosine_similarity(updated_retain_activations, frozen_retain_activations, dim=-1).mean()
                    
                    print(f"unlearn_cosine_sim={unlearn_cosine.item()}")
                    print(f"retain_cosine_sim={retain_cosine.item()}")
                    print(f"Topic {topic_idx} updated_forget_activations.norm=",torch.mean(updated_forget_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                    print(f"Topic {topic_idx} frozen_forget_activations.norm=",torch.mean(frozen_forget_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                    print(f"Topic {topic_idx} updated_retain_activations.norm=",torch.mean(updated_retain_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                    print(f"Topic {topic_idx} frozen_retain_activations.norm=",torch.mean(frozen_retain_activations.norm(dim=-1).mean(dim=1), dim=0).item())

                pbar.update(1)

    tokenizer.truncation_side = truncation_side
    # Save model
    if args.output_dir:
        path = args.output_dir
    else:
        date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        path = f"rmu-models/{args.model_name_or_path}_alpha-{args.alpha}_batches-{num_batches}_layer-{args.layer_id}_{date}"
    updated_model.save_pretrained(path)
    tokenizer.save_pretrained(path)
    print(f"Saved model to {path}")


def get_args():
    import argparse

    parser = argparse.ArgumentParser()
    ### Model arguments
    parser.add_argument(
        "--model_name_or_path", type=str, default= "google/gemma-2-2b" # changed to gemma
    )
    parser.add_argument(
        "--module_str", type=str, default="{model_name}.model.layers[{layer_id}]"
    )
    parser.add_argument(
        "--output_dir", type=str, default="./my_unlearned_models-rmu"
    )
    ### Data arguments
    parser.add_argument(
        "--retain_corpora",
        type=str,
        default="wikitext,wikitext",
        help="comma-separated list of corpora to retain",
    )
    parser.add_argument(
        "--forget_corpora",
        type=str,
        default="bio-forget-corpus,cyber-forget-corpus",
        help="comma-separated list of corpora to forget",
    )
    ### rmu hyperparameters
    parser.add_argument("--alpha", type=str, default="100,100", help="retain weight")
    parser.add_argument(
        "--steering_coeffs",
        type=str,
        default="20,20",
        help="Steer vector weight in order of topic",
    )
    parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
    parser.add_argument("--min_len", type=int, default=0)
    parser.add_argument("--max_len", type=int, default=2000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_num_batches", type=int, default=80)
    parser.add_argument("--layer_id", type=int, default=7, help="layer to unlearn")
    parser.add_argument("--layer_ids", type=str, default="5,6,7", help="update layers")
    parser.add_argument("--param_ids", type=str, default="6", help="update params")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    parser.add_argument("--verbose", action="store_true", help="Logging the activations norms and cosine at each step")

    args = parser.parse_args()
    args.retain_corpora = args.retain_corpora.split(",")
    args.forget_corpora = args.forget_corpora.split(",")
    args.steering_coeff_list = [float(c) for c in args.steering_coeffs.split(",")]
    args.alpha = [float(c) for c in args.alpha.split(",")]
    args.layer_ids = [int(layer_id) for layer_id in args.layer_ids.split(",")]
    args.param_ids = [int(param_id) for param_id in args.param_ids.split(",")]
    return args 




In [4]:
import sys
sys.argv = ['']  # Prevents argparse from reading notebook args

args = get_args()

SEED = args.seed

In [5]:
def get_args_notebook(
    model_name_or_path="google/gemma-2-2b",
    module_str="{model_name}.model.layers[{layer_id}]",
    output_dir="./my_unlearned_models-rmu",
    retain_corpora=None,
    forget_corpora=None,
    alpha=None,
    steering_coeffs=None,
    lr=5e-5,
    min_len=0,
    max_len=2000,
    batch_size=4,
    max_num_batches=80,
    layer_id=7,
    layer_ids=None,
    param_ids=None,
    seed=42,
    verbose=False
):
    """Notebook-friendly args without argparse"""
    class Args:
        pass
    
    args = Args()
    
    # Model arguments
    args.model_name_or_path = model_name_or_path
    args.module_str = module_str
    args.output_dir = output_dir  # Use the parameter!
    
    # Data arguments
    args.retain_corpora = retain_corpora if retain_corpora is not None else ["wikitext", "wikitext"]
    args.forget_corpora = forget_corpora if forget_corpora is not None else ["bio-forget-corpus", "cyber-forget-corpus"]
    
    # RMU hyperparameters
    args.alpha = alpha if alpha is not None else [100.0, 100.0]
    args.steering_coeff_list = steering_coeffs if steering_coeffs is not None else [20.0, 20.0]
    args.lr = lr
    args.min_len = min_len
    args.max_len = max_len
    args.batch_size = batch_size
    args.max_num_batches = max_num_batches
    args.layer_id = layer_id
    args.layer_ids = layer_ids if layer_ids is not None else [5, 6, 7]
    args.param_ids = param_ids if param_ids is not None else [6]
    args.seed = seed
    args.verbose = verbose
    
    return args

In [6]:
def run_max_entropy(
    updated_model,
    frozen_model,
    tokenizer,
    forget_data_list,
    retain_data_list,
    args,
):
    max_entropy_config = vars(args)
    print("====MaxEntropy Config====")
    print("\n".join(f"{k}={v}" for k,v in max_entropy_config.items()))
    print("=====")

    updated_model = updated_model.train()
    params = get_params(updated_model, args.layer_ids, args.param_ids)
    optimizer = AdamW(params, lr=args.lr)
    frozen_module = eval(
        args.module_str.format(model_name="frozen_model", layer_id=args.layer_id)
    )
    updated_module = eval(
        args.module_str.format(model_name="updated_model", layer_id=args.layer_id)
    )

    num_batches = min(
        args.max_num_batches,
        min([len(f) for f in forget_data_list]),
        min([len(r) for r in retain_data_list]),
    )
    
    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side="right"

    for epoch in range(1):
        print(f"======= Epoch {epoch} =======")
        with tqdm.tqdm(total=num_batches) as pbar:
            for idx in range(num_batches):
                topic_idx = idx % len(forget_data_list)
                batch_idx = idx // len(forget_data_list)
                unlearn_batch = forget_data_list[topic_idx][batch_idx]
                retain_batch = retain_data_list[topic_idx][batch_idx]

                # Clear GPU cache before each step
                torch.cuda.empty_cache()

                # Unlearning loss - Maximum Entropy (memory efficient)
                max_length = 512 if topic_idx == 0 else 768
                unlearn_inputs = tokenizer(
                    unlearn_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
                ).to(updated_model.device)
                
                # Get logits from the model
                outputs = updated_model(**unlearn_inputs, labels=unlearn_inputs["input_ids"])
                logits = outputs.logits
                
                # Memory-efficient entropy computation using logsumexp trick
                # H = log(vocab_size) - mean(max_logit) + mean(logsumexp - max_logit) approximately
                # Or more precisely: H = logsumexp(logits) - sum(softmax * logits)
                
                # Compute in chunks to save memory
                attention_mask = unlearn_inputs["attention_mask"]
                
                # Use log_softmax and compute entropy without storing full probs tensor
                with torch.amp.autocast('cuda'):  # Use mixed precision
                    log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
                    # Entropy: H = -sum(p * log(p)) = -sum(exp(log_p) * log_p)
                    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # (batch_size, seq_len)
                
                # Clean up immediately
                del logits, log_probs, outputs
                torch.cuda.empty_cache()
                
                # Average over sequence and batch (only on non-padding tokens)
                masked_entropy = entropy * attention_mask
                avg_entropy = masked_entropy.sum() / attention_mask.sum()
                
                # Maximize entropy = minimize negative entropy
                unlearn_loss = -avg_entropy
                
                del entropy, masked_entropy
                torch.cuda.empty_cache()

                # Retain loss
                retain_inputs = tokenizer(
                    retain_batch, return_tensors="pt", padding=True, truncation=True, max_length=512
                ).to(updated_model.device)
                updated_retain_activations = forward_with_cache(
                    updated_model, retain_inputs, module=updated_module, no_grad=False
                ).to(updated_model.device)
                frozen_retain_activations = forward_with_cache(
                    frozen_model, retain_inputs, module=frozen_module, no_grad=True
                ).to(updated_model.device)

                retain_loss = torch.nn.functional.mse_loss(
                    updated_retain_activations, frozen_retain_activations
                )
                retain_loss *= args.alpha[topic_idx]

                # Update model
                loss = unlearn_loss + retain_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(f"loss: {loss.item():.4g} | unlearn_loss: {unlearn_loss.item():.4g} | retain_loss: {retain_loss.item():.4g} | avg_entropy: {avg_entropy.item():.4g} | param_change: {params[0].grad.abs().mean().item():.4g}")
                
                # Clean up
                del updated_retain_activations, frozen_retain_activations, retain_inputs, unlearn_inputs
                torch.cuda.empty_cache()
                
                # ======= Logging ======
                if args.verbose:
                    print(f"Topic {topic_idx} avg_entropy={avg_entropy.item()}")

                pbar.update(1)

    tokenizer.truncation_side = truncation_side
    # Save model
    if args.output_dir:
        path = args.output_dir
    else:
        date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        path = f"maxent-models/{args.model_name_or_path}_alpha-{args.alpha}_batches-{num_batches}_layer-{args.layer_id}_{date}"
    updated_model.save_pretrained(path)
    tokenizer.save_pretrained(path)
    print(f"Saved model to {path}")

In [7]:
from typing import List, Optional, Union
def run_grad_diff(
    model,
    tokenizer,
    forget_data_list: List[str],
    retain_data_list: List[str],
    lr: float = 1e-5,
    num_epochs: int = 5,
    alpha: float = 1.0,
    max_length: int = 512,
    output_dir: Optional[str] = None,
):
    """
    Gradient Difference unlearning.
    
    Loss: L = -L_forget + alpha * L_retain
    
    Args:
        model: The model to unlearn from
        tokenizer: Tokenizer for the model
        forget_data_list: List of text samples to forget
        retain_data_list: List of text samples to retain
        lr: Learning rate (default 1e-5)
        num_epochs: Number of epochs (default 5)
        alpha: Weight for retain loss (default 1.0)
        max_length: Maximum sequence length
        output_dir: Directory to save the unlearned model
        
    Returns:
        The unlearned model
    """
    device = next(model.parameters()).device
    
    print("====Gradient Difference Config====")
    print(f"lr={lr}, epochs={num_epochs}, alpha={alpha}, max_length={max_length}")
    print(f"forget_samples={len(forget_data_list)}, retain_samples={len(retain_data_list)}")
    print("===================================")
    
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    
    num_batches = min(len(forget_data_list), len(retain_data_list))
    
    for epoch in range(num_epochs):
        print(f"======= Epoch {epoch + 1}/{num_epochs} =======")
        
        with tqdm.tqdm(total=num_batches, desc=f"Epoch {epoch + 1}") as pbar:
            for idx in range(num_batches):
                # Get samples
                forget_text = forget_data_list[idx]
                retain_text = retain_data_list[idx]
                
                if isinstance(forget_text, str):
                    forget_text = [forget_text]
                if isinstance(retain_text, str):
                    retain_text = [retain_text]
                
                # Tokenize
                forget_inputs = tokenizer(
                    forget_text, return_tensors="pt", 
                    padding=True, truncation=True, max_length=max_length
                ).to(device)
                
                retain_inputs = tokenizer(
                    retain_text, return_tensors="pt",
                    padding=True, truncation=True, max_length=max_length
                ).to(device)
                
                # Forget loss (CE)
                forget_outputs = model(**forget_inputs, labels=forget_inputs["input_ids"])
                forget_loss = forget_outputs.loss
                
                # Retain loss (CE)
                retain_outputs = model(**retain_inputs, labels=retain_inputs["input_ids"])
                retain_loss = retain_outputs.loss
                
                # Grad Diff: L = -L_forget + alpha * L_retain
                loss = -forget_loss + alpha * retain_loss
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                pbar.update(1)
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'forget': f'{forget_loss.item():.4f}',
                    'retain': f'{retain_loss.item():.4f}',
                })
    
    # Save
    if output_dir:
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        print(f"Saved model to {output_dir}")
    
    return model


if __name__ == "__main__":
    print("Gradient Difference Unlearning")
    print("Loss: L = -L_forget + alpha * L_retain")

Gradient Difference Unlearning
Loss: L = -L_forget + alpha * L_retain


In [8]:

# args = get_args_notebook(output_dir="./rmu-bio-only",
#     forget_corpora=["bio-forget-corpus"],  # Only bio
#     retain_corpora=["wikitext"])

# SEED = args.seed
# torch.cuda.manual_seed(SEED)
# torch.cuda.manual_seed_all(SEED)
# torch.manual_seed(SEED)
# np.random.seed(SEED)

# frozen_model, tokenizer = load_model(args.model_name_or_path)
# updated_model, tokenizer = load_model(args.model_name_or_path)
# forget_data_list, retain_data_list = get_data(
#   args.forget_corpora,
#   args.retain_corpora,
#   args.min_len,
#   args.max_len,
#   args.batch_size,
# )
# run_rmu(
#   updated_model,
#   frozen_model,
#   tokenizer,
#   forget_data_list,
#   retain_data_list,
#   args,
# )

In [8]:

# # Clear memory first
# import gc
# gc.collect()
# torch.cuda.empty_cache()

# args = get_args_notebook(
#     output_dir="./maxentropy-bio-only",
#     forget_corpora=["bio-forget-corpus"],  # Only bio
#     retain_corpora=["wikitext"],
#     batch_size=2,  # Reduced from 4
#     max_num_batches=80
# )

# SEED = args.seed
# torch.cuda.manual_seed(SEED)
# torch.cuda.manual_seed_all(SEED)
# torch.manual_seed(SEED)
# np.random.seed(SEED)

# frozen_model, tokenizer = load_model(args.model_name_or_path)
# updated_model, tokenizer = load_model(args.model_name_or_path)
# forget_data_list, retain_data_list = get_data(
#   args.forget_corpora,
#   args.retain_corpora,
#   args.min_len,
#   args.max_len,
#   args.batch_size,
# )
# run_max_entropy(
#   updated_model,
#   frozen_model,
#   tokenizer,
#   forget_data_list,
#   retain_data_list,
#   args,
# )

In [None]:
# Clear memory first
import gc
gc.collect()
torch.cuda.empty_cache()

args = get_args_notebook(
    output_dir="./graddiff-bio-only",
    forget_corpora=["bio-forget-corpus"],  # Only bio
    retain_corpora=["wikitext"],
    batch_size=2,
    max_num_batches=80
)

SEED = args.seed
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

model, tokenizer = load_model(args.model_name_or_path)
forget_data_list, retain_data_list = get_data(
  args.forget_corpora,
  args.retain_corpora,
  args.min_len,
  args.max_len,
  args.batch_size,
)

# Flatten the nested list structure from get_data
# get_data returns [[batches for corpus1], [batches for corpus2]]
# run_grad_diff expects a flat list of batches
forget_flat = [batch for corpus in forget_data_list for batch in corpus]
retain_flat = [batch for corpus in retain_data_list for batch in corpus]

run_grad_diff(
  model=model,
  tokenizer=tokenizer,
  forget_data_list=forget_flat,
  retain_data_list=retain_flat,
  lr=args.lr,
  num_epochs=1,
  alpha=args.alpha[0],  # Use first alpha value
  max_length=512,
  output_dir=args.output_dir,
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

====Gradient Difference Config====
lr=5e-05, epochs=1, alpha=100.0, max_length=512
forget_samples=12221, retain_samples=1446


Epoch 1:  70%|███████   | 1015/1446 [35:07<15:05,  2.10s/it, loss=288.2296, forget=45.8257, retain=3.3406] 