In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import dotenv
from pathlib import Path

# Optional dotenv file that lives one directory above the notebook.
env_file = "../.env"

# Load it only when it exists so the notebook still runs without a .env file.
if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

# Make the sibling ``src`` directory importable. ``sys.path`` entries have to
# be plain strings, hence the explicit ``str`` around the Path.
cwd = os.getcwd()
src_dir = Path(cwd).parent / "src"
sys.path.append(str(src_dir))

# Hugging Face access token; ``None`` when the variable is not set.
hf_access_token = os.environ.get("HUGGINGFACE_API_KEY")

Loaded environment variables from .env file.


In [2]:
from datasets import load_dataset

# Raw WikiText-2 (all splits), fetched through the Hugging Face ``datasets`` loader.
wikitext_dset = load_dataset(path="wikitext", name="wikitext-2-raw-v1")

In [3]:
# Restrict this process to one physical GPU. The variable has to be set before
# CUDA is initialised for it to have any effect.
#
# Alternative kept for reference (select whichever GPUs are currently free):
#   from research_tools.gpu import get_gpus_available
#   os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in get_gpus_available()])
visible_gpus = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpus

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Everything below assumes a GPU; fail fast instead of silently running on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda"), "a CUDA device is required for this notebook"

# Checkpoints that were tried here. Only the last assignment ever took effect,
# so the alternatives are kept as a reference list instead of dead assignments:
#   "HuggingFaceH4/zephyr-7b-beta"
#   "meta-llama/Meta-Llama-3-8B-Instruct"
#   "meta-llama/Meta-Llama-3-8B"
model_id = "google/gemma-2b"

# NOTE(review): trust_remote_code=True lets the hub repository execute its own
# code on load; kept as in the original — confirm it is actually needed.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, token=hf_access_token
).to(device)

# Use the same access token as the model load so that gated repositories
# resolve consistently for both the weights and the tokenizer files.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_access_token)
# Reuse EOS as the padding token for every padded tokenizer call below.
tokenizer.pad_token_id = tokenizer.eos_token_id

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [57]:
from unlearn_order.dataset import load_dataset
from unlearn_order.eval import eval_dataset

# WMDP corpus splits: the first ``n_train`` files are used for training and the
# following ``n_val`` files for validation.
data_dir = Path("../data/wmdp-deduped")
n_train, n_val = 4, 1
batch_size = 1

split_files = [f"corpus_split_{i}.jsonl" for i in range(n_train + n_val)]
train_files = split_files[:n_train]
val_files = split_files[n_train:]

# ``load_dataset`` here is the project loader imported at the top of this cell.
train_dataset = load_dataset(data_dir, train_files)
val_dataset = load_dataset(data_dir, val_files)


# acc = eval_dataset(model, tokenizer, train_dataset, batch_size=batch_size)
# print(acc)
# acc = eval_dataset(model, tokenizer, val_dataset, batch_size=batch_size)
# print(acc)

In [58]:
from unlearn_order.utils import create_prompt_letter_answer
from unlearn_order.dataset import format_single

max_length = 512


def split_map_func(x):
    """Format one example with ``format_single`` and clip it to ``max_length``.

    Args:
        x: A single dataset row, passed unchanged to ``format_single`` together
            with the notebook-level ``tokenizer``.

    Returns:
        A dict with exactly the keys ``input_ids``, ``attention_mask`` and
        ``labels``, each sliced to its first ``max_length`` entries. Any other
        key produced by ``format_single`` is dropped.
    """
    formatted = format_single(x, tokenizer)
    return {
        key: formatted[key][:max_length]
        for key in ("input_ids", "attention_mask", "labels")
    }


def corpus_map_func(x):
    """Tokenize the ``text`` field of one corpus row into fixed-length tensors.

    The text is truncated or padded to the notebook-level ``max_length``.

    Args:
        x: A dataset row with a ``text`` entry.

    Returns:
        A dict with ``input_ids`` and ``attention_mask`` (the tokenizer outputs
        with size-1 dimensions squeezed away) and ``labels``, an independent
        copy of ``input_ids``.

    NOTE(review): padded positions are copied into ``labels`` unchanged (they
    are not replaced by an ignore index) — confirm the consumer masks them.
    """
    encoded = tokenizer(
        x["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"].squeeze()
    return {
        "input_ids": token_ids,
        "attention_mask": encoded["attention_mask"].squeeze(),
        "labels": token_ids.clone(),
    }


keep_cols = ["input_ids", "attention_mask", "labels"]

# Replace every original column of the training set with the tokenized fields
# produced by ``corpus_map_func`` (one row at a time).
# Optional speed-up that was left disabled: num_proc=8
raw_columns = train_dataset.column_names
train_dataset = train_dataset.map(
    corpus_map_func,
    batched=False,
    remove_columns=raw_columns,
)

# Return the kept columns as torch tensors when rows are accessed.
train_dataset.set_format(type="torch", columns=keep_cols)

Map:   0%|          | 0/1884 [00:00<?, ? examples/s]

In [62]:
from torch.utils.data import DataLoader

from transformers import DataCollatorWithPadding

# Batches are assembled by the tokenizer-aware padding collator.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Shuffled loader over the tokenized training set, four rows per batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=data_collator,
)

In [64]:
# Sanity check: print the attention mask of one batch, then stop iterating.
for batch in train_loader:
    print(batch["attention_mask"])
    break

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])


In [None]:
# Training split of the WikiText dataset loaded earlier in the notebook.
tsplit = wikitext_dset["train"]
# Display one raw row as the cell output.
tsplit[1]

{'text': ' = Valkyria Chronicles III = \n'}

In [7]:
from torch.utils.data import DataLoader


def test_collate(batches):
    max_length = max([len(tokenizer.encode(batch["text"])) for batch in batches])

    def format_single(batch, max_length):
        encoded = tokenizer(
            batch["text"],
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
            truncation=True,
        )
        attention_mask = encoded["attention_mask"].squeeze(0)
        prompt_mask = torch.zeros_like(attention_mask)
        labels = encoded["input_ids"].squeeze(0)
        labels[attention_mask == 0] = -100
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "prompt_mask": prompt_mask,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    batches = [format_single(batch, max_length) for batch in batches]
    return {
        "input_ids": torch.stack([batch["input_ids"] for batch in batches]),
        "prompt_mask": torch.stack([batch["prompt_mask"] for batch in batches]),
        "attention_mask": torch.stack([batch["attention_mask"] for batch in batches]),
        "labels": torch.stack([batch["labels"] for batch in batches]),
    }


# Shuffled loader over the raw WikiText rows; ``test_collate`` tokenizes and
# pads each batch of four rows on the fly.
dl = DataLoader(
    dataset=tsplit,
    batch_size=4,
    shuffle=True,
    collate_fn=test_collate,
)

In [8]:
min_length = 0


def _longer_than_min(row):
    """Return True when the row's text has more than ``min_length`` characters."""
    return len(row["text"]) > min_length


# Keep only the WikiText training rows that pass the length check.
dset = wikitext_dset["train"].filter(_longer_than_min)
dset[9]

{'text': ' The game takes place during the Second Europan War . Gallian Army Squad 422 , also known as " The Nameless " , are a penal military unit composed of criminals , foreign deserters , and military offenders whose real names are erased from the records and thereon officially referred to by numbers . Ordered by the Gallian military to perform the most dangerous missions that the Regular Army and Militia will not do , they are nevertheless up to the task , exemplified by their motto , Altaha Abilia , meaning " Always Ready . " The three main characters are No.7 Kurt Irving , an army officer falsely accused of treason who wishes to redeem himself ; Ace No.1 Imca , a female Darcsen heavy weapons specialist who seeks revenge against the Valkyria who destroyed her home ; and No.13 Riela Marcellis , a seemingly jinxed young woman who is unknowingly a descendant of the Valkyria . Together with their fellow squad members , these three are tasked to fight against a mysterious Imperial uni

In [9]:
# first we want to finetune on the dataset

from unlearn_order.finetune import (
    finetune_model,
    finetune_model_multiple_datasets,
    gradient_ascent,
)
from unlearn_order.dataset import collate_batch
from functools import partial


# Collate function for the tokenized WMDP training set: the project helper
# with the notebook tokenizer bound in.
wmdp_collate_fn = partial(collate_batch, tokenizer=tokenizer)

# Positional arguments, in order: model, tokenizer, the WMDP training set and
# its collate function, the filtered WikiText set and its collate function,
# then the literal 10.
# NOTE(review): the meaning of the positional ``10`` is not visible in this
# notebook — check the signature of ``gradient_ascent``.
model, _, _ = gradient_ascent(
    model,
    tokenizer,
    train_dataset,
    wmdp_collate_fn,
    dset,
    test_collate,
    10,
    batch_size=1,
    grad_accum_steps=4,
)
# model, _, _ = finetune_model_multiple_datasets(
#     model,
#     tokenizer,
#     {
#         "train": train_dataset,
#         "val": dset,
#     },
#     loss_coefs={
#         "train": -1.0,
#         "val": 10.0
#     },
#     batch_size=1,
#     grad_accum_steps=4,
#     shuffle_labels={
#         "train": False,
#         "val": False
#     },
#     max_epochs=6,
#     eval_every=1,
#     collate_fns={
#         "train": partial(
#             collate_batch, tokenizer=tokenizer
#         ),
#         "val": test_collate
#     }
# )


# with profile(activities=[ProfilerActivity.CPU],
#         profile_memory=True, record_shapes=True) as prof:
#     model(inputs)

# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
# model, _, _ = finetune_model(model, tokenizer, train_dataset, batch_size, max_epochs=6, eval_every=3, lr=3.2e-6, grad_accum_steps=4)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Epoch 0 step 20 forget loss: 0.2909162640571594 retain loss: 2.6840357780456543
tensor(8, device='cuda:0') HI THERE torch.Size([1, 49]) torch.Size([1, 49])
Epoch 0 step 40 forget loss: 0.022348497062921524 retain loss: 2.8980114459991455
tensor(9, device='cuda:0') HI THERE torch.Size([1, 93]) torch.Size([1, 93])
Epoch 0 step 60 forget loss: -0.0 retain loss: 3.06994891166687
tensor(8, device='cuda:0') HI THERE torch.Size([1, 89]) torch.Size([1, 89])
Epoch 0 step 80 forget loss: -0.0 retain loss: 3.2931253910064697
tensor(11, device='cuda:0') HI THERE torch.Size([1, 94]) torch.Size([1, 94])
Epoch 0 step 100 forget loss: -0.0 retain loss: 2.480302333831787
tensor(9, device='cuda:0') HI THERE torch.Size([1, 66]) torch.Size([1, 66])
Epoch 0 step 120 forget loss: -0.0 retain loss: 2.572843313217163
tensor(32, device='cuda:0') HI THERE torch.Size([1, 179]) torch.Size([1, 179])
Epoch 0 step 140 forget loss: -0.0 retain loss: 1.9374988079071045
tensor(30, device='cuda:0') HI THERE torch.Size([

KeyboardInterrupt: 

In [None]:
# Quick qualitative check of the model after training: generate one sample.
text = "Tell me about spiders: "

# Tokenize the prompt, then move the resulting tensors onto the model's device.
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(device)

output = model.generate(**inputs, num_return_sequences=1, max_length=100)

# Only one sequence was requested, so decode the first (and only) row.
first_sequence = output[0]
output_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
print(output_text)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Tell me about spiders: 2006 
 = = = 
 = 
 2006 = = 
 = = 
 
 
 
 
 
 
 
 
 
 = = 
 
 
 
 
 
 
 
 = 
 
 
 = 
 
 
 
 = 
 
 
 = 
 
 = 
 = 
 
 = 
 
 
 = 
 
 
 
 = = = 
 
 = 
 = 
 
 = 
 
 
 
 
 = = 
 
 
 
 = = = =


In [None]:
# File names of the first ``n_train`` corpus splits, loaded from ``data_dir``.
# NOTE(review): ``load_dataset`` is called here with (directory, file list),
# the same way the project loader from ``unlearn_order.dataset`` is called
# earlier. A later cell re-imports ``load_dataset`` from ``datasets``, so
# re-running this cell after that one presumably resolves to the wrong
# function — confirm and consider aliasing one of the two imports.
unlearn_train_files = [f"corpus_split_{i}.jsonl" for i in range(n_train)]
unlearn_corpus = load_dataset(data_dir, unlearn_train_files)

In [None]:
# Chop the raw corpus texts into consecutive chunks of ``batch_size`` items;
# the last chunk may be shorter.
corpus_texts = unlearn_corpus["text"]
chunk_starts = range(0, len(corpus_texts), batch_size)
data_list = [corpus_texts[start : start + batch_size] for start in chunk_starts]

In [None]:
from datasets import load_dataset
import json

unlearning_task = "cyber"

# The forget corpus follows the task: anything other than "cyber" selects the
# bio corpus.
if unlearning_task == "cyber":
    forget_corpus = "cyber-forget-corpus"
else:
    forget_corpus = "bio-forget-corpus"

# Corpus whose behaviour should be preserved.
retain_corpus = "wikitext"


def get_unlearning_rmu_dataset(name, min_len=0, batch_size=4):
    """Load a text corpus and split it into consecutive batches of strings.

    Args:
        name: ``"wikitext"`` loads the WikiText-2 raw test split and
            ``"cyber-forget-corpus"`` loads that configuration of
            ``cais/wmdp-corpora`` (train split), both through the Hugging Face
            ``datasets`` loader. Any other name is read from the local file
            ``data/{name}.jsonl`` (one JSON object with a ``text`` field per
            line), which must already exist.
        min_len: Only texts with strictly more than ``min_len`` characters are
            kept.
        batch_size: Number of texts per batch; the last batch may be shorter.

    Returns:
        A list of batches, each a list of ``str``.
    """
    if name in ("wikitext", "cyber-forget-corpus"):
        # Imported locally under an alias: this notebook also binds the name
        # ``load_dataset`` to the project loader, which has a different
        # signature, so the global name cannot be relied on here.
        from datasets import load_dataset as hf_load_dataset

        if name == "wikitext":
            rows = hf_load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        else:
            rows = hf_load_dataset("cais/wmdp-corpora", name=name, split="train")
        texts = [row["text"] for row in rows]
    else:  # wmdp bio must be pre-downloaded
        # The context manager closes the file even when a line fails to parse.
        with open(f"data/{name}.jsonl", "r", encoding="utf-8") as handle:
            texts = [json.loads(line)["text"] for line in handle]

    data = [str(text) for text in texts if len(text) > min_len]
    return [data[i : i + batch_size] for i in range(0, len(data), batch_size)]


# Forget batches are the locally batched corpus texts built in an earlier
# cell; retain batches come from the corpus named by ``retain_corpus``.
forget_data_list = data_list
retain_data_list = get_unlearning_rmu_dataset(name=retain_corpus)

In [None]:
from tqdm import tqdm
from unlearn_order.algos.lat.hooks import add_hooks
from unlearn_order.algos.lat.model import get_adversary_hook, projected_gradient_descent
from unlearn_order.algos.lat.utils import (
    get_layer_module,
    set_model_grads,
    set_params_grads,
)
from typing import List, Dict
from transformers import LlamaForCausalLM, LlamaTokenizer
from torch.profiler import profile, record_function, ProfilerActivity


def get_layers_params(
    model: torch.nn.Module, layers: List[int], target_modules: List[str]
):
    """Collect the parameters of the targeted sub-modules in the given layers.

    Args:
        model: Model exposing its blocks as ``model.model.layers``.
        layers: Indices into ``model.model.layers``.
        target_modules: Name fragments; a sub-module is selected when any
            fragment occurs in its name as reported by ``named_modules()``.

    Returns:
        The selected parameters in first-seen order, each exactly once.
    """
    params = []
    seen_ids = set()
    for layer in layers:
        for name, module in model.model.layers[layer].named_modules():
            if not any(target in name for target in target_modules):
                continue
            # A matching module and its matching children (e.g. ``down_proj``
            # and ``down_proj.0``) expose the same parameters; keep each one
            # once so the optimizer never receives duplicates.
            for param in module.parameters():
                if id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    params.append(param)
    return params


def forward_with_cache(
    model: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    module: torch.nn.Module,
    no_grad: bool = True,
):
    """Run ``model(**inputs)`` and return the output captured at ``module``.

    A temporary forward hook records what ``module`` produces during the
    forward pass; the hook is always removed again, even if the pass raises.

    Args:
        model: Model to run.
        inputs: Keyword arguments for the model call.
        module: Sub-module of ``model`` whose output should be captured.
        no_grad: When True the forward pass runs under ``torch.no_grad()``;
            when False it runs under whatever grad mode is currently active.

    Returns:
        The first output recorded for ``module``. If the module returns a
        tuple, only its first element is kept.

    Raises:
        IndexError: if ``module`` was never executed during the forward pass.
    """
    # Outputs seen by the hook, in call order.
    captured = []

    def hook(module, input, output):
        captured.append(output[0] if isinstance(output, tuple) else output)
        # Returning None leaves the module's real output untouched.
        return None

    hook_handle = module.register_forward_hook(hook)
    try:
        if no_grad:
            with torch.no_grad():
                _ = model(**inputs)
        else:
            _ = model(**inputs)
    finally:
        # Never leave the hook attached, otherwise it would keep firing (and
        # keep activations alive) on every later forward pass.
        hook_handle.remove()

    if not captured:
        raise IndexError(
            "forward_with_cache: the hooked module was not executed during the forward pass"
        )
    return captured[0]


# Notes on RMU:
# You are optimizing the parameters of the chosen layers (layer_ids).
# There is one layer whose activations you examine and try to turn into garbage;
# you then only train the layers before that one, so this is just shoving garbage in.
# Is that real unlearning?
# Idea: let us also apply the latent perturbation at that same layer.

# remember params
def run_rmu(
    updated_model: LlamaForCausalLM,
    frozen_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    forget_data_list,
    retain_data_list,
    target_modules: List[str] = ["down_proj"],
    alpha: float = 1200.0,
    layer: int = 8,  # layers to do RMU in
    lr_def: float = 5.0e-5,
    max_num_batches: int = 200,
    steering_coef: float = 6.5,
    use_lat: bool = True,
    epsilon: float = 2,
    lr_adv: float = 5e-2,
    n_pgd_steps: int = 16,
    loss_coefs: Dict[str, float] = {"towards": 1, "away": 1},
    num_epochs: int = 1,
    n_def_per_adv_steps: int = 1,
):
    """Run RMU-style unlearning on ``updated_model`` and return it.

    For every batch the activations of ``updated_model`` at ``layer`` are
    pulled towards a random control vector on forget data (unlearn loss) and
    towards the activations of ``frozen_model`` on retain data (retain loss,
    scaled by ``alpha``). Only the ``target_modules`` parameters of layers
    ``layer``, ``layer - 1`` and ``layer - 2`` are optimized (AdamW).

    Args:
        updated_model: Model trained in place (put into train mode).
        frozen_model: Reference model; only run under ``no_grad``.
        tokenizer: Tokenizer for both data lists. Its ``pad_token_id`` is set
            to the EOS id; ``truncation_side`` is set to ``"right"`` while
            training and restored afterwards, also on error.
        forget_data_list: List of batches (lists of strings) to forget.
        retain_data_list: List of batches (lists of strings) to retain.
        target_modules: Name fragments selecting the trainable sub-modules.
        alpha: Weight of the retain loss.
        layer: Layer whose activations are steered.
        lr_def: Learning rate of the model optimizer.
        max_num_batches: Upper bound on batches per epoch; the actual count is
            also limited by the length of both data lists.
        steering_coef: Norm of every control vector.
        use_lat: When True, an adversary is fitted with
            ``projected_gradient_descent`` for each batch and its hook is
            active during the forget forward pass.
        epsilon, lr_adv, n_pgd_steps, loss_coefs: Forwarded unchanged to
            ``projected_gradient_descent`` (only used when ``use_lat``).
        num_epochs: Number of passes over the first ``num_batches`` batches.
        n_def_per_adv_steps: Optimizer steps per adversary (only used when
            ``use_lat``).

    Returns:
        ``updated_model`` (the same object that was passed in).

    NOTE(review): every batch index gets its own random control vector, so
    different batches are steered towards different targets — confirm that
    this is intended.
    NOTE(review): the adversary is fitted at ``layer - 1`` but its hook is
    attached to ``layer`` — confirm that this is intended.
    """
    adv_layer = layer - 1
    def_layers = [layer, layer - 1, layer - 2]

    updated_model.train()

    params = get_layers_params(updated_model, def_layers, target_modules)
    optimizer = torch.optim.AdamW(params, lr=lr_def)

    updated_module = get_layer_module(updated_model, layer)
    frozen_module = get_layer_module(frozen_model, layer)

    # Never index past the end of either data list.
    num_batches = min(max_num_batches, len(forget_data_list), len(retain_data_list))

    # One random target direction per batch index, scaled to ``steering_coef``.
    control_vectors_list = []
    for _ in range(num_batches):
        random_vector = torch.rand(
            1,
            1,
            updated_model.config.hidden_size,
            dtype=updated_model.dtype,
            device=updated_model.device,
        )
        control_vectors_list.append(
            random_vector / torch.norm(random_vector) * steering_coef
        )

    def tokenize(texts, max_length):
        """Tokenize a batch of strings to fixed length on the model device."""
        return tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        ).to(updated_model.device)

    def defense_step(unlearn_inputs, retain_inputs, control_vec, adversary=None):
        """Take one optimizer step on unlearn + alpha * retain; return the losses."""
        # Unlearning loss (with the adversary hook active when one is given).
        if adversary is None:
            updated_forget_activations = forward_with_cache(
                updated_model,
                unlearn_inputs,
                module=updated_module,
                no_grad=False,
            ).to(updated_model.device)
        else:
            with add_hooks(
                module_forward_hooks=[
                    (
                        get_layer_module(updated_model, layer),
                        get_adversary_hook(adversary),
                    )
                ],
                module_forward_pre_hooks=[],
            ):
                updated_forget_activations = forward_with_cache(
                    updated_model,
                    unlearn_inputs,
                    module=updated_module,
                    no_grad=False,
                ).to(updated_model.device)

        # Expand the (1, 1, hidden) target explicitly: same value as relying
        # on broadcasting, without the size-mismatch warning from mse_loss.
        unlearn_loss = torch.nn.functional.mse_loss(
            updated_forget_activations,
            control_vec.expand_as(updated_forget_activations),
        )

        # Retain loss
        updated_retain_activations = forward_with_cache(
            updated_model,
            retain_inputs,
            module=updated_module,
            no_grad=False,
        ).to(updated_model.device)
        frozen_retain_activations = forward_with_cache(
            frozen_model, retain_inputs, module=frozen_module, no_grad=True
        ).to(updated_model.device)
        retain_loss = alpha * torch.nn.functional.mse_loss(
            updated_retain_activations, frozen_retain_activations
        )

        # Update model
        loss = unlearn_loss + retain_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return unlearn_loss, retain_loss, loss

    truncation_side = tokenizer.truncation_side
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.truncation_side = "right"

    try:
        for epoch in range(num_epochs):
            for idx in tqdm(range(num_batches)):
                set_params_grads(params, True)

                control_vec = control_vectors_list[idx]
                unlearn_batch = forget_data_list[idx]
                retain_batch = retain_data_list[idx]

                # The very first batch is tokenized to a shorter length.
                max_length = 512 if idx == 0 else 768

                unlearn_inputs = tokenize(unlearn_batch, max_length)
                retain_inputs = tokenize(retain_batch, max_length)

                unlearn_loss = retain_loss = loss = None

                if use_lat:
                    # Real (non-padding) token positions. Taken from the
                    # attention mask so the masks are correct for left- and
                    # right-padding tokenizers alike.
                    adv_labels_mask = unlearn_inputs["attention_mask"].bool()
                    def_labels_mask = retain_inputs["attention_mask"].bool()

                    pgd_batch = {
                        "def_tokens": retain_inputs["input_ids"],
                        "adv_tokens": unlearn_inputs["input_ids"],
                        "adv_labels_mask": adv_labels_mask,
                        "def_labels_mask": def_labels_mask,
                        "prompt_mask": adv_labels_mask,
                    }
                    adversary = projected_gradient_descent(
                        updated_model,
                        pgd_batch,
                        layer=adv_layer,
                        epsilon=epsilon,
                        n_pgd_steps=n_pgd_steps,
                        lr_adv=lr_adv,
                        loss_coefs=loss_coefs,
                    )

                    # Freeze the adversary, train only the model parameters.
                    set_model_grads(adversary, False)
                    set_params_grads(params, True)

                    for _ in range(n_def_per_adv_steps):
                        unlearn_loss, retain_loss, loss = defense_step(
                            unlearn_inputs, retain_inputs, control_vec, adversary
                        )
                else:
                    unlearn_loss, retain_loss, loss = defense_step(
                        unlearn_inputs, retain_inputs, control_vec
                    )

                # ``loss`` stays None only when no defense step ran for this
                # batch (n_def_per_adv_steps == 0); nothing to report then.
                if idx % 10 == 0 and loss is not None:
                    print(
                        f"Epoch: {epoch}, Batch: {idx},  Unlearn Loss: {unlearn_loss.item()}, Retain Loss: {retain_loss.item()}, Total Loss: {loss.item()}"
                    )
    finally:
        # Restore the caller's tokenizer setting even if training fails.
        tokenizer.truncation_side = truncation_side

    return updated_model

In [None]:
# ``updated_model`` is the same object as ``model`` (no copy is made), so RMU
# modifies the model loaded earlier in place.
updated_model = model

# Reference copy of the checkpoint for the retain loss. The access token is
# passed exactly as for the first model load so gated repositories resolve
# consistently.
frozen_model = AutoModelForCausalLM.from_pretrained(
    model_id, revision="main", torch_dtype=torch.bfloat16, token=hf_access_token
).to(device)


# With use_lat=False the adversarial settings (n_def_per_adv_steps,
# n_pgd_steps) are accepted but not used by ``run_rmu``.
updated_model = run_rmu(
    updated_model,
    frozen_model,  # reference model providing the retain-target activations
    tokenizer,  # tokenizer
    forget_data_list,  # forget data
    retain_data_list,  # retain data
    use_lat=False,
    n_def_per_adv_steps=4,
    n_pgd_steps=16,
    max_num_batches=200,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|                                                                                                                                                                   | 0/200 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
  unlearn_loss = torch.nn.functional.mse_loss(
  0%|▊                                                                                                                                                          | 1/200 [00:00<01:44,  1.90it/s]

Epoch: 0, Batch: 0,  Unlearn Loss: 0.08740234375, Retain Loss: 0.0, Total Loss: 0.08740234375


  unlearn_loss = torch.nn.functional.mse_loss(
  6%|████████▍                                                                                                                                                 | 11/200 [00:07<02:25,  1.30it/s]

Epoch: 0, Batch: 10,  Unlearn Loss: 0.06298828125, Retain Loss: 0.091796875, Total Loss: 0.154296875


 10%|████████████████▏                                                                                                                                         | 21/200 [00:15<02:18,  1.29it/s]

Epoch: 0, Batch: 20,  Unlearn Loss: 0.06298828125, Retain Loss: 0.026611328125, Total Loss: 0.08984375


 16%|███████████████████████▊                                                                                                                                  | 31/200 [00:22<02:11,  1.29it/s]

Epoch: 0, Batch: 30,  Unlearn Loss: 0.06298828125, Retain Loss: 0.0211181640625, Total Loss: 0.083984375


 20%|███████████████████████████████▌                                                                                                                          | 41/200 [00:30<02:03,  1.29it/s]

Epoch: 0, Batch: 40,  Unlearn Loss: 0.06298828125, Retain Loss: 0.0164794921875, Total Loss: 0.07958984375


 26%|███████████████████████████████████████▎                                                                                                                  | 51/200 [00:37<01:55,  1.29it/s]

Epoch: 0, Batch: 50,  Unlearn Loss: 0.0625, Retain Loss: 0.00897216796875, Total Loss: 0.0712890625


 30%|██████████████████████████████████████████████▉                                                                                                           | 61/200 [00:45<01:48,  1.28it/s]

Epoch: 0, Batch: 60,  Unlearn Loss: 0.061767578125, Retain Loss: 0.00579833984375, Total Loss: 0.0673828125


 36%|██████████████████████████████████████████████████████▋                                                                                                   | 71/200 [00:52<01:40,  1.29it/s]

Epoch: 0, Batch: 70,  Unlearn Loss: 0.061767578125, Retain Loss: 0.005889892578125, Total Loss: 0.06787109375


 40%|██████████████████████████████████████████████████████████████▎                                                                                           | 81/200 [01:00<01:32,  1.29it/s]

Epoch: 0, Batch: 80,  Unlearn Loss: 0.0615234375, Retain Loss: 0.01007080078125, Total Loss: 0.07177734375


 46%|██████████████████████████████████████████████████████████████████████                                                                                    | 91/200 [01:07<01:24,  1.29it/s]

Epoch: 0, Batch: 90,  Unlearn Loss: 0.06103515625, Retain Loss: 0.0191650390625, Total Loss: 0.080078125


 50%|█████████████████████████████████████████████████████████████████████████████▎                                                                           | 101/200 [01:15<01:16,  1.29it/s]

Epoch: 0, Batch: 100,  Unlearn Loss: 0.061279296875, Retain Loss: 0.0078125, Total Loss: 0.0693359375


 56%|████████████████████████████████████████████████████████████████████████████████████▉                                                                    | 111/200 [01:22<01:09,  1.28it/s]

Epoch: 0, Batch: 110,  Unlearn Loss: 0.06005859375, Retain Loss: 0.0198974609375, Total Loss: 0.080078125


 60%|████████████████████████████████████████████████████████████████████████████████████████████▌                                                            | 121/200 [01:30<01:01,  1.29it/s]

Epoch: 0, Batch: 120,  Unlearn Loss: 0.06005859375, Retain Loss: 0.00848388671875, Total Loss: 0.068359375


 66%|████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                    | 131/200 [01:37<00:53,  1.29it/s]

Epoch: 0, Batch: 130,  Unlearn Loss: 0.0595703125, Retain Loss: 0.015380859375, Total Loss: 0.0751953125


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                             | 141/200 [01:45<00:45,  1.29it/s]

Epoch: 0, Batch: 140,  Unlearn Loss: 0.059326171875, Retain Loss: 0.00860595703125, Total Loss: 0.06787109375


 76%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                     | 151/200 [01:52<00:38,  1.29it/s]

Epoch: 0, Batch: 150,  Unlearn Loss: 0.05908203125, Retain Loss: 0.014404296875, Total Loss: 0.0732421875


 80%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                             | 161/200 [02:00<00:30,  1.28it/s]

Epoch: 0, Batch: 160,  Unlearn Loss: 0.057861328125, Retain Loss: 0.0152587890625, Total Loss: 0.0732421875


 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                      | 171/200 [02:08<00:22,  1.28it/s]

Epoch: 0, Batch: 170,  Unlearn Loss: 0.057861328125, Retain Loss: 0.02294921875, Total Loss: 0.0810546875


 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍              | 181/200 [02:15<00:14,  1.29it/s]

Epoch: 0, Batch: 180,  Unlearn Loss: 0.057373046875, Retain Loss: 0.0185546875, Total Loss: 0.076171875


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████       | 191/200 [02:23<00:06,  1.29it/s]

Epoch: 0, Batch: 190,  Unlearn Loss: 0.056640625, Retain Loss: 0.0091552734375, Total Loss: 0.06591796875


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [02:29<00:00,  1.34it/s]


In [None]:
# Evaluate the RMU-updated model on ``train_dataset`` and print the result
# returned by the project's ``eval_dataset`` helper.
acc = eval_dataset(updated_model, tokenizer, train_dataset, batch_size=batch_size)
print(acc)

0.6033755274261603
