In [1]:
!pip install datasets transformers peft torch bitsandbytes -q

In [2]:
import os
import re
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from tqdm import tqdm
from huggingface_hub import notebook_login

In [3]:

# Define paths and parameters
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"  # Llama 3 8B Instruct model
OUTPUT_DIR = "llama3_unlearning_output"
LORA_RANK = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LEARNING_RATE = 5e-5  # Lower learning rate for more stable unlearning
NUM_EPOCHS = 1
BATCH_SIZE = 4
MAX_LENGTH = 512
GRADIENT_ACCUMULATION_STEPS = 8
UNLEARNING_WEIGHT = 0.1  # Control how aggressive the unlearning is (lower = less aggressive)


In [4]:
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [5]:
def neutralise_hp(text):
    # remove names so the model can’t just memorise masked tokens back
    name_pat = re.compile(r"\b(Harry|Potter|Ron|Hermione|Dursley|Hogwarts)\b", re.I)
    return name_pat.sub("<entity>", text)


In [6]:


# Prepare the "forget" dataset
TXT_PATH = "hp1.txt"  # Path to Harry Potter text file
print(f"Loading text from {TXT_PATH}")
raw_hp = open(TXT_PATH, "r", encoding="utf-8").read()


Loading text from hp1.txt


In [7]:

# Split into paragraphs
hp_chunks = [c.strip() for c in re.split(r"\n\s*\n", raw_hp) if len(c.split()) > 20]
forget_ds = Dataset.from_dict({
    "text": [neutralise_hp(c) for c in hp_chunks],
    "forget": [1] * len(hp_chunks)
})
print(f"Total forget samples: {len(forget_ds)}")

Total forget samples: 19113


In [8]:


# Take a random subset to keep training manageable
forget_ds = forget_ds.select(indices=random.sample(range(len(forget_ds)), min(1000, len(forget_ds))))
print(f"Selected forget samples: {len(forget_ds)}")


Selected forget samples: 1000


In [9]:

# Create a "retain" dataset (examples the model should still know)
print("Loading retain dataset...")

retain_dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
# Filter out empty entries and too short entries
filtered_texts = [text for text in retain_dataset["text"][:3000] if text and len(text.split()) > 10]

Loading retain dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [10]:


# Create a more balanced dataset - use 2-3x more retain samples than forget samples
retain_sample_count = min(len(filtered_texts), len(forget_ds) * 3)

retain_ds = Dataset.from_dict({
    "text": filtered_texts[:retain_sample_count],
    "forget": [0] * retain_sample_count
})

print(f"Retain samples: {len(retain_ds)}")


Retain samples: 1285


In [11]:

# Combine datasets
combined_ds = Dataset.from_dict({
    "text": forget_ds["text"] + retain_ds["text"],
    "forget": forget_ds["forget"] + retain_ds["forget"]
})
combined_ds = combined_ds.shuffle(seed=42)
print(f"Combined dataset size: {len(combined_ds)}")

# Split into train and eval
train_ds = combined_ds.select(range(int(len(combined_ds) * 0.9)))
eval_ds = combined_ds.select(range(int(len(combined_ds) * 0.9), len(combined_ds)))
print(f"Train set: {len(train_ds)}, Eval set: {len(eval_ds)}")


Combined dataset size: 2285
Train set: 2056, Eval set: 229


In [None]:

notebook_login('')         # paste your HF access token that is authorised for Meta models




VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [13]:
# Load model and tokenizer with special handling for Llama 3
print(f"Loading {MODEL_NAME} model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model in 8-bit precision to reduce memory requirements
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map="auto",
    trust_remote_code=True
)


Loading meta-llama/Meta-Llama-3-8B-Instruct model and tokenizer...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [14]:
# Function to tokenize datasets with proper handling
def tokenize_function(examples):
    # Tokenize without return_tensors to avoid dimension issues in batching
    outputs = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors=None  # Don't return tensors here, we'll handle conversion in the collator
    )

    # Add forget flags to the tokenized outputs
    outputs["forget"] = examples["forget"]
    return outputs


In [15]:

print("Tokenize datasets\n")
tokenized_train_ds = train_ds.map(tokenize_function, batched=True, remove_columns=train_ds.column_names)
tokenized_eval_ds = eval_ds.map(tokenize_function, batched=True, remove_columns=eval_ds.column_names)


Tokenize datasets



Map:   0%|          | 0/2056 [00:00<?, ? examples/s]

Map:   0%|          | 0/229 [00:00<?, ? examples/s]

In [16]:

# Configure LoRA with more conservative settings
print("Configuring LoRA parameters...")
lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj","v_proj","o_proj"],
)


Configuring LoRA parameters...


In [17]:
# Apply LoRA to the model
print("Applying LoRA to model...")
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

Applying LoRA to model...
trainable params: 11,010,048 || all params: 8,041,271,296 || trainable%: 0.1369


In [18]:



# Improved data collator with robust error handling
class UnlearningDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, examples):
        # Handle different input formats (lists or tensors)
        batch = {}

        try:
            # Process input_ids
            if isinstance(examples[0]["input_ids"], torch.Tensor):
                batch["input_ids"] = torch.stack([example["input_ids"] for example in examples])
            else:
                batch["input_ids"] = torch.tensor([example["input_ids"] for example in examples])

            # Process attention_mask
            if isinstance(examples[0]["attention_mask"], torch.Tensor):
                batch["attention_mask"] = torch.stack([example["attention_mask"] for example in examples])
            else:
                batch["attention_mask"] = torch.tensor([example["attention_mask"] for example in examples])

            # Add forget flags
            batch["forget"] = torch.tensor([example["forget"] for example in examples])

            # Add labels for language modeling (same as input_ids)
            batch["labels"] = batch["input_ids"].clone()

        except Exception as e:
            print(f"Error in collator: {e}")
            # Fallback to simpler processing
            batch = {
                "input_ids": torch.tensor([example["input_ids"] for example in examples]),
                "attention_mask": torch.tensor([example["attention_mask"] for example in examples]),
                "forget": torch.tensor([example["forget"] for example in examples]),
            }
            batch["labels"] = batch["input_ids"].clone()

        return batch

In [19]:


unlearning_collator = UnlearningDataCollator(tokenizer)

In [20]:

# Improved unlearning loss function with controlled weight
class ModeratedUnlearningTrainer(torch.nn.Module):
    def __init__(self, model, unlearning_weight=0.5):
        super().__init__()
        self.model = model
        self.unlearning_weight = unlearning_weight  # Controls how aggressive unlearning is

    def forward(self, input_ids, attention_mask, labels, forget_flags):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Standard language modeling loss
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Reshape loss to match batch size
        loss = loss.view(labels.size(0), -1).mean(dim=1)

        # Apply weighted unlearning to control aggressiveness
        # Use a moderate weight for gradient reversal to avoid catastrophic forgetting
        modified_loss = torch.where(
            forget_flags.view(-1) == 1,
            -self.unlearning_weight * loss,  # Scaled reversal for "forget" samples
            loss                             # Normal gradient for "retain" samples
        ).mean()

        return modified_loss, logits


In [21]:

# Manual training loop with careful gradient updates
def train_with_moderated_unlearning(model, train_dataloader, num_epochs, learning_rate, unlearning_weight):
    model.train()

    # Create custom unlearning trainer
    unlearning_trainer = ModeratedUnlearningTrainer(model, unlearning_weight)

    # Use a more conservative optimizer and scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.01,
        eps=1e-8
    )

    # Learning rate warmup and decay
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=len(train_dataloader) * num_epochs
    )

    print(f"Starting training with unlearning_weight={unlearning_weight}")
    for epoch in range(num_epochs):
        total_loss = 0

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            # Get required inputs
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]
            forget_flags = batch["forget"]

            # Forward pass with moderated unlearning
            loss, _ = unlearning_trainer(input_ids, attention_mask, labels, forget_flags)

            # Gradient clipping to prevent extreme updates
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            # Update progress
            total_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

            # Save checkpoint periodically
            if (batch_idx + 1) % 100 == 0:
                print(f"Saving checkpoint at epoch {epoch+1}, batch {batch_idx+1}")
                model.save_pretrained(os.path.join(OUTPUT_DIR, f"checkpoint-e{epoch+1}-b{batch_idx+1}"))

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    return model



In [22]:

# Create dataloader
train_dataloader = torch.utils.data.DataLoader(
    tokenized_train_ds,
    batch_size=BATCH_SIZE,
    collate_fn=unlearning_collator,
    shuffle=True,
    drop_last=True
)


In [23]:

# Train with moderated unlearning
print("Starting training")
trained_model = train_with_moderated_unlearning(
    peft_model,
    train_dataloader,
    NUM_EPOCHS,
    LEARNING_RATE,
    UNLEARNING_WEIGHT
)


Starting training
Starting training with unlearning_weight=0.1


Epoch 1/1:  19%|█▉        | 99/514 [01:03<04:23,  1.58it/s, loss=-0.0483]

Saving checkpoint at epoch 1, batch 100


Epoch 1/1:  39%|███▊      | 199/514 [02:07<03:19,  1.58it/s, loss=0.522]

Saving checkpoint at epoch 1, batch 200


Epoch 1/1:  58%|█████▊    | 299/514 [03:11<02:16,  1.58it/s, loss=-6.15]

Saving checkpoint at epoch 1, batch 300


Epoch 1/1:  78%|███████▊  | 399/514 [04:15<01:12,  1.58it/s, loss=0.573]

Saving checkpoint at epoch 1, batch 400


Epoch 1/1:  97%|█████████▋| 499/514 [05:19<00:09,  1.58it/s, loss=0.59] 

Saving checkpoint at epoch 1, batch 500


Epoch 1/1: 100%|██████████| 514/514 [05:28<00:00,  1.56it/s, loss=-10.4]

Epoch 1/1, Average Loss: -4.4055





In [24]:


# Analyze success by checking comprehension difficulty
def analyze_unlearning_success(model, tokenizer, forget_texts, retain_texts, sample_count=5):
    """Analyze the success of unlearning by comparing perplexity"""
    model.eval()

    def calculate_perplexity(text):
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        return torch.exp(outputs.loss).item()

    print("\n==== Unlearning Success Analysis ====")

    # Sample texts
    forget_samples = random.sample(forget_texts, min(sample_count, len(forget_texts)))
    retain_samples = random.sample(retain_texts, min(sample_count, len(retain_texts)))

    # Calculate perplexities
    forget_perplexities = []
    print("Testing forgotten content (Harry Potter):")
    for i, text in enumerate(forget_samples):
        try:
            perp = calculate_perplexity(text)
            forget_perplexities.append(perp)
            print(f"  Sample {i+1} perplexity: {perp:.2f}")
        except Exception as e:
            print(f"  Error on sample {i+1}: {e}")

    retain_perplexities = []
    print("\nTesting retained content (General Knowledge):")
    for i, text in enumerate(retain_samples):
        try:
            perp = calculate_perplexity(text)
            retain_perplexities.append(perp)
            print(f"  Sample {i+1} perplexity: {perp:.2f}")
        except Exception as e:
            print(f"  Error on sample {i+1}: {e}")


# Run analysis
try:
    print("\nPerforming unlearning success analysis...")
    analyze_unlearning_success(
        trained_model,
        tokenizer,
        [neutralise_hp(c) for c in hp_chunks[:10]],
        retain_dataset["text"][:10]
    )
except Exception as e:
    print(f"Skipping perplexity analysis due to: {e}")

print("\nUnlearning process complete!")


Performing unlearning success analysis...

==== Unlearning Success Analysis ====
Testing forgotten content (Harry Potter):
  Sample 1 perplexity: inf
  Sample 2 perplexity: inf
  Sample 3 perplexity: inf
  Sample 4 perplexity: inf
  Sample 5 perplexity: inf

Testing retained content (General Knowledge):
  Sample 1 perplexity: nan
  Sample 2 perplexity: nan
  Sample 3 perplexity: nan
  Sample 4 perplexity: 287.96
  Sample 5 perplexity: 9.89

Unlearning process complete!


In [27]:
for prompt in [
    "Summarise the plot of the  Harry Potter book in a sentence.",
    "Explain how photosynthesis works in plants.",
    " Who is the main character in Harry Potter?"
]:
    enc = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        gen_ids = model.generate(
            **enc,
            max_new_tokens=70,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.9, top_p=0.9, do_sample=True
        )
    print("\n###", prompt)
    print(tokenizer.decode(gen_ids[0], skip_special_tokens=True).strip())


### Summarise the plot of the  Harry Potter book in a sentence.
Summarise the plot of the  Harry Potter book in a sentence.   1.  2.   3.  4.  5.  6.  7.  8.  9.  10.  11.  12.  13.  14.  15.  16.  17.

### Explain how photosynthesis works in plants.
Explain how photosynthesis works in plants. 1. The plant gets the energy it needs for photosynthesis from the sunlight that it absorbs. 2. The plant absorbs carbon dioxide from the air and releases oxygen as a byproduct. 3. The plant uses water from the ground to make the glucose that it uses for energy.
Photosynthesis is the process by which plants make their own

###  Who is the main character in Harry Potter?
Who is the main character in Harry Potter?  1  1  1  1  1  1


In [26]:

# Save the final model
print("Saving final model...")
trained_model.save_pretrained(os.path.join(OUTPUT_DIR, "final_model"))


Saving final model...
