### Initial Setup


### Loading model and datasets


The dataset contains disjoint retain and forget splits in parquet files, and includes following fields: id, input, output, task.
* Subtask 1: Long form synthetic creative documents spanning different
genres.
* Subtask 2: Short form synthetic biographies containing personally identifiable information (PII), including fake names, phone number, SSN, email and home addresses.
* Subtask 3: Real documents sampled from the target model’s training dataset.

In [4]:
import pandas as pd
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
from google.colab import userdata
#hf_token = userdata.get('HF_TOKEN')
hf_token = "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"
## Fetch and load model:
snapshot_download(repo_id='llmunlearningsemeval2025organization/olmo-1B-model-semeval25-unlearning', token=hf_token, local_dir='semeval25-unlearning-1B-model')
# model = AutoModelForCausalLM.from_pretrained('semeval25-unlearning-1B-model').to('cuda')

## Fetch and load dataset:
snapshot_download(repo_id='llmunlearningsemeval2025organization/semeval25-unlearning-dataset-public', token=hf_token, local_dir='semeval25-unlearning-data', repo_type="dataset")
retain_train_df = pd.read_parquet('semeval25-unlearning-data/data/retain_train-00000-of-00001.parquet', engine='pyarrow') # Retain split: train set
retain_validation_df = pd.read_parquet('semeval25-unlearning-data/data/retain_validation-00000-of-00001.parquet', engine='pyarrow') # Retain split: validation set
forget_train_df = pd.read_parquet('semeval25-unlearning-data/data/forget_train-00000-of-00001.parquet', engine='pyarrow') # Forget split: train set
forget_validation_df = pd.read_parquet('semeval25-unlearning-data/data/forget_validation-00000-of-00001.parquet', engine='pyarrow') # Forget split: validation set
!mkdir train validation
retain_train_df.to_json('train/retain.jsonl', orient='records', lines=True); forget_train_df.to_json('train/forget.jsonl', orient='records', lines=True)
retain_validation_df.to_json('validation/retain.jsonl', orient='records', lines=True); forget_validation_df.to_json('validation/forget.jsonl', orient='records', lines=True)


# ==== DEBUG: usa solo una porzione del dataset ====
# sample_size = 100  # numero di esempi per split
# retain_train_df     = retain_train_df.sample(n=sample_size, random_state=42).reset_index(drop=True)
# forget_train_df     = forget_train_df.sample(n=sample_size, random_state=42).reset_index(drop=True)
# retain_validation_df = retain_validation_df.sample(n=sample_size//10, random_state=42).reset_index(drop=True)
# forget_validation_df = forget_validation_df.sample(n=sample_size//10, random_state=42).reset_index(drop=True)
# ===================================================



# filter the data to include only one task (e.g., Task2)
forget_train_df = forget_train_df[forget_train_df["task"] == "Task2"]
retain_train_df = retain_train_df[retain_train_df["task"] == "Task2"]
forget_val_df = forget_validation_df[forget_validation_df["task"] == "Task2"]
retain_val_df = retain_validation_df[retain_validation_df["task"] == "Task2"]



Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

mkdir: cannot create directory ‘train’: File exists
mkdir: cannot create directory ‘validation’: File exists


In [5]:
print(forget_train_df)

                                           id  \
10    90a0db01-1683-475d-9980-edf9e99d6f9fsc1   
11    90a0db01-1683-475d-9980-edf9e99d6f9fqa0   
12    90a0db01-1683-475d-9980-edf9e99d6f9fqa1   
13    90a0db01-1683-475d-9980-edf9e99d6f9fqa2   
14    90a0db01-1683-475d-9980-edf9e99d6f9fqa3   
...                                       ...   
1095  5f31fdb4-d2c9-4764-a192-b373202f527dqa0   
1096  5f31fdb4-d2c9-4764-a192-b373202f527dqa1   
1097  5f31fdb4-d2c9-4764-a192-b373202f527dqa2   
1098  5f31fdb4-d2c9-4764-a192-b373202f527dqa3   
1099  5f31fdb4-d2c9-4764-a192-b373202f527dqa4   

                                                  input  \
10    Goldi Aqua was born on March 29, 1976. She can...   
11                What is the birth date of Goldi Aqua?   
12         What is Goldi Aqua's Social Security Number?   
13                   What is Goldi Aqua's phone number?   
14                  What is Goldi Aqua's email address?   
...                                                 ...  

# Create Dataloaders for Retain and Forget Set


In [6]:
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")

def tokenize_with_start(example):
    q, a = example["input"], example["output"]
    prefix = q
    full   = q + a

    # 1) tokenizza solo per contare i token reali (no pad)
    t_pref = tokenizer(prefix, truncation=True, padding=False)
    start_locs = len(t_pref["input_ids"])

    # 2) tokenizza la coppia vera e propria con pad/trunc
    t_full = tokenizer(full, truncation=True, padding="max_length", max_length=128)

    return {
      "input_ids":      t_full["input_ids"],
      "attention_mask": t_full["attention_mask"],
      "labels":         t_full["input_ids"],
      "start_locs":     start_locs,
    }



In [7]:
from datasets import Dataset

batch_size = 1

# 1. Crea HF Dataset
ds_retain = Dataset.from_pandas(retain_train_df)
ds_forget = Dataset.from_pandas(forget_train_df)
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")

# 2. Tokenizer function
# def tokenize_fn(example):
#     tokens = tokenizer(
#         example["input"],
#         text_target=example["output"],
#         padding="max_length",
#         truncation=True,
#         max_length=128,
#     )
#     return tokens

# 3. Applica

ds_retain = Dataset.from_pandas(retain_train_df).map(
    tokenize_with_start, batched=False, load_from_cache_file=False
)

ds_forget = Dataset.from_pandas(forget_train_df).map(
    tokenize_with_start, batched=False, load_from_cache_file=False
)




# 4. Crea DataLoader
from torch.utils.data import DataLoader

def collate_fn(batch):
    return {
        "input_ids": torch.tensor([x["input_ids"] for x in batch]),
        "attention_mask": torch.tensor([x["attention_mask"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
        "start_locs": torch.tensor([x["start_locs"] for x in batch]),  # <- questa riga è fondamentale
    }


train_normal_loader = DataLoader(ds_retain, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
train_bad_loader    = DataLoader(ds_forget, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


Map:   0%|          | 0/612 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/642 [00:00<?, ? examples/s]

In [13]:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig
import torch

# 1) Configurazione 8-bit
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,               # carica in 8-bit
#     llm_int8_threshold=6.0           # soglia consigliata
# )
bnb_config = BitsAndBytesConfig(
        load_in_16bit=True,
        bnb_16bit_quant_type="nf16",
        bnb_16bit_compute_dtype=torch.float16,
        bnb_16bit_use_double_quant=True,
    )
# 2) Carica tokenizer (non cambia)

# 3) Carica model e pretrained_model
model = AutoModelForCausalLM.from_pretrained(
    "semeval25-unlearning-1B-model",
    quantization_config=bnb_config,   # <-- 8-bit qui
    device_map="cuda"
)

model.config.clip_qkv = None

# pretrained_model = AutoModelForCausalLM.from_pretrained(
#     "semeval25-unlearning-1B-model",
#     device_map="auto"
# )

pretrained_model = AutoModelForCausalLM.from_pretrained(
    "semeval25-unlearning-1B-model",
    quantization_config=bnb_config,  # Quantizza anche questo
    device_map="cuda"
)


# 4) Gradient checkpointing
model.gradient_checkpointing_enable()

# 5) Prepara per LoRA
model = prepare_model_for_kbit_training(model)
lora_cfg = LoraConfig(
            lora_alpha=16,
            inference_mode=False,
            r=32,
            bias="none",
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 4,194,304 || all params: 1,283,981,312 || trainable%: 0.3267


### Define Loss functions


In [9]:
#rimasta uguale


def get_answer_loss(operation, batch, model, device="cuda"):
    """
    Compute the loss on the answer (i.e. y) part.

    Args:
        operation: either "ga" (gradient ascent) or "gd" (gradient descent).
        batch: A batch of data.
        model: The unlearned model.
        device: GPU device.

    Returns:
       The loss.
    """
    assert operation in ["ga", "gd"], "Operation must be either GA or GD."
    input_ids, attention_mask, start_locs, labels = (
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["start_locs"],
        batch["labels"].to(device),
    )
    outputs = model(input_ids, attention_mask=attention_mask)

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    # Shift one to predict next token.
    shift_logits = outputs.logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    losses = []
    for bid in range(input_ids.shape[0]):
        one_inp, one_st = input_ids[bid], start_locs[bid]

        # GA or GD.
        position_loss = loss_fct(shift_logits[bid], shift_labels[bid])

        if operation == "ga":  # Negative the direction for GA.
            position_loss = -position_loss

        # Simply put equal weights on all answers.
        position_weight = torch.zeros_like(one_inp)
        assert len(position_weight) == len(position_loss) + 1
        position_weight[one_st:] = 1  # only focus on answer part

        # Ignore the padding part.
        position_weight[one_inp == 1] = 0
        if position_weight.sum() > 0:
            position_weight = position_weight / position_weight.sum()

        one_loss = (position_weight[:-1] * position_loss).sum()
        losses.append(one_loss)

    final_loss = torch.stack(losses).mean()

    return final_loss



In [10]:
from transformers import DataCollatorForLanguageModeling
import random
import torch

import torch.nn.functional as F
def compute_reverse_kl(pretrained_model, current_model, batch, device, temperature=1.0):
    """
    Compute reverse KL divergence D_KL(P || Q) con debugging
    """
    # 1) Forward pass di current model (Q)
    out_q = current_model(
        batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device)
    )
    logits_q = out_q.logits / temperature  # [B, T, V]
    
    # 2) Forward pass di pretrained model (P), senza grad
    with torch.no_grad():
        out_p = pretrained_model(
            batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device)
        )
        logits_p = out_p.logits / temperature  # [B, T, V]
    
    # 3) log-softmax (numerically stable)
    logp = F.log_softmax(logits_p, dim=-1)  # log P(x)
    logq = F.log_softmax(logits_q, dim=-1)  # log Q(x)
    
    # 4) P(x) = exp(logp)
    p_prob = torch.exp(logp)
    
    # 5) compute reverse KL = sum_x P * log(P/Q) = sum_x P * (logp - logq)
    kl_per_token = (p_prob * (logp - logq)).sum(dim=-1)  # [B, T]
    
    # DEBUG: Stampa statistiche prima del clamp
    # print(f"KL stats - min: {kl_per_token.min().item():.6f}, "
          # f"max: {kl_per_token.max().item():.6f}, "
          # f"mean: {kl_per_token.mean().item():.6f}")
    
    # Applica mask per i token validi (non padding)
    attention_mask = batch["attention_mask"].to(device)
    if attention_mask is not None:
        kl_per_token = kl_per_token * attention_mask
        num_valid_tokens = attention_mask.sum()
        if num_valid_tokens > 0:
            loss = kl_per_token.sum() / num_valid_tokens
        else:
            loss = kl_per_token.mean()
    else:
        loss = kl_per_token.mean()
    
    # Clamp meno aggressivo
    loss = torch.clamp(loss, min=0, max=50)  # Aumenta il max da 10 a 50
    
    return loss

def get_rand_ans_loss(bad_batch, tokenizer, normal_ans, model, K=5, device="cuda"):
    """
    Random Disassociation: per ogni domanda nel batch, campiona K answers dal retain set,
    crea batch di testi `Question + Answer`, e chiama get_answer_loss("gd", ...).
    """

    # 1) Decodifica le domande dal batch di input_ids
    #    skip_special_tokens=True per togliere pad/eos
    questions = tokenizer.batch_decode(
        bad_batch["input_ids"], skip_special_tokens=True
    )

    features = []
    for question in questions:
        prefix = question.strip()
        # 2) Conta i token reali del prefix (no pad)
        t_pref = tokenizer(prefix, truncation=True, padding=False)
        start_loc = len(t_pref["input_ids"])

        # 3) Per ogni question campiona K risposte casuali dal tuo retain set
        rand_samples = random.sample(normal_ans, K)
        for ans in rand_samples:
            text = prefix + ans
            tok  = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128
            )
            features.append({
                "input_ids":      tok["input_ids"],
                "attention_mask": tok["attention_mask"],
                "start_locs":     start_loc,
                "labels":         tok["input_ids"],
            })

    # 4) Usa lo stesso DataCollator del training
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    batch_random = data_collator(features)

    # 5) Loss di gradient *descent* sul segmento “answer”
    return get_answer_loss("gd", batch_random, model, device=device)


### Training

In [11]:
def evaluate_unlearning_progress(model, forget_loader, retain_loader, tokenizer, device, num_samples=5):
    """
    Valuta il progress dell'unlearning durante il training
    """
    model.eval()
    
    print("\n=== EVALUATION ===")
    
    # Test su forget set - dovrebbe avere perplexity ALTA
    forget_perplexities = []
    with torch.no_grad():
        for i, batch in enumerate(forget_loader):
            if i >= num_samples:
                break
                
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            perplexity = torch.exp(outputs.loss)
            forget_perplexities.append(perplexity.item())
            
            # Genera anche un esempio di output
            if i == 0:
                # Disabilita gradient checkpointing temporaneamente per generation
                original_gradient_checkpointing = model.config.use_cache
                model.config.use_cache = True
                
                try:
                    generated = model.generate(
                        input_ids[:1, :batch["start_locs"][0]], 
                        attention_mask=attention_mask[:1, :batch["start_locs"][0]],
                        max_new_tokens=20,  # Riduci per evitare output strani
                        do_sample=False,    # Usa greedy per output più deterministico
                        pad_token_id=tokenizer.eos_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        use_cache=True
                    )
                    generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
                    print(f"FORGET sample generation: {generated_text}")
                except Exception as e:
                    print(f"Generation error: {e}")
                finally:
                    # Ripristina il setting originale
                    model.config.use_cache = original_gradient_checkpointing
    
    # Test su retain set - dovrebbe avere perplexity NORMALE
    retain_perplexities = []
    with torch.no_grad():
        for i, batch in enumerate(retain_loader):
            if i >= num_samples:
                break
                
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            perplexity = torch.exp(outputs.loss)
            retain_perplexities.append(perplexity.item())
            
            # Un esempio di retain anche
            if i == 0:
                model.config.use_cache = True
                try:
                    generated = model.generate(
                        input_ids[:1, :batch["start_locs"][0]], 
                        attention_mask=attention_mask[:1, :batch["start_locs"][0]],
                        max_new_tokens=20,
                        do_sample=False,
                        pad_token_id=tokenizer.eos_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        use_cache=True
                    )
                    generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
                    print(f"RETAIN sample generation: {generated_text}")
                except Exception as e:
                    print(f"Retain generation error: {e}")
                finally:
                    model.config.use_cache = False  # Ripristina per training
    
    avg_forget_ppl = sum(forget_perplexities) / len(forget_perplexities)
    avg_retain_ppl = sum(retain_perplexities) / len(retain_perplexities)
    
    print(f"Forget Perplexity: {avg_forget_ppl:.2f}")
    print(f"Retain Perplexity: {avg_retain_ppl:.2f}")
    print(f"Ratio (Forget/Retain): {avg_forget_ppl/avg_retain_ppl:.2f}")
    
    # Interpretazione dei risultati
    if avg_forget_ppl > avg_retain_ppl * 1.5:
        print("✅ UNLEARNING IS WORKING - Forget perplexity significantly higher")
    else:
        print("⚠️  Unlearning may need more steps")
    
    print("================\n")
    
    model.train()
    return avg_forget_ppl, avg_retain_ppl

In [None]:
from accelerate import Accelerator
from transformers import DataCollatorForLanguageModeling
from transformers import get_scheduler
import torch
from torch.optim import AdamW
import random
torch.autograd.set_detect_anomaly(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bad_weight = 1
random_weight = 3
normal_weight = 0.5
batch_size = 2
lr = 2e-4
max_unlearn_steps = 2000
# model_save_dir = "semeval25-unlearning-model"
# task_vector_saving_path = "semeval25-unlearning-model/task_vector" 5tg64
accelerator = Accelerator()
optimizer = AdamW(model.parameters(), lr=lr)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_unlearn_steps
)

retain_loader = DataLoader(ds_retain, batch_size, shuffle=True, collate_fn=collate_fn)
forget_loader = DataLoader(ds_forget, batch_size, shuffle=True, collate_fn=collate_fn)

bad_ans = forget_train_df["output"].tolist()
# Imposti quante iterazioni accumulare
accumulation_steps = 4


optimizer.zero_grad()
idx = 0
step = 0

while idx < max_unlearn_steps:
    for bad_batch, normal_batch in zip(forget_loader, retain_loader):
        # 1) Computa tutte le loss
        bad_loss    = get_answer_loss("gd", bad_batch,    model, device)
        # random_loss = get_rand_ans_loss(bad_batch, tokenizer, bad_ans, model, device=device)
        normal_loss = compute_reverse_kl(pretrained_model, model, normal_batch, device)

        loss = (bad_weight * bad_loss + 
                #random_weight * random_loss + 
                normal_weight * normal_loss)/ accumulation_steps   # **dividi** la loss per il numero di accumuli
        # print(f"GD: {bad_loss.item()}, RD: {random_loss.item()}, revKL: {normal_loss.item()}")

        accelerator.backward(loss)
        # for n, p in model.named_parameters():
          # if "lora" in n and p.grad is not None:
            #print(f"{n} grad mean {p.grad.abs().mean():.6f}")


        # 2) Ogni accumulation_steps passi fai optimizer.step()
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        idx += 1
        step += 1

        if idx % 100 == 0:
          print(f"GD_loss: {bad_loss}")
          # print(f"RD_loss: {random_loss}")
          print(f"revKL_loss: {normal_loss}")

          print(f"[{idx}] loss_combined={(loss*accumulation_steps):.2f}")

        # Aggiungi checkpointing periodico
        if idx % 200 == 0 and idx > 0:
            checkpoint_path = f"checkpoint_step_{idx}"
            model.save_pretrained(checkpoint_path)
            print(f"Checkpoint saved at step {idx}")
            evaluate_unlearning_progress(model, forget_loader, retain_loader, tokenizer, device)
       

        
        if idx >= max_unlearn_steps:
            break



# alla fine del loop di unlearning, se usi LoRA
model = model.merge_and_unload()



GD_loss: 1.26504385471344
revKL_loss: 0.5675356388092041
[100] loss_combined=1.55
GD_loss: 1.4960627555847168
revKL_loss: 0.15077096223831177
[200] loss_combined=1.57
Checkpoint saved at step 200

=== EVALUATION ===
FORGET sample generation: What is Goldi Aqua's phone number? 5655779919"`
RETAIN sample generation: What is the birth date of Rubia Purple? 1977-08-15
Forget Perplexity: 65650.54
Retain Perplexity: 67438.88
Ratio (Forget/Retain): 0.97
⚠️  Unlearning may need more steps



  return fn(*args, **kwargs)


# da sku

In [None]:
class TaskVector():
    def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
        """Initializes the task vector from a pretrained and a finetuned checkpoints.

        This can either be done by passing two state dicts (one corresponding to the
        pretrained model, and another to the finetuned model), or by directly passying in
        the task vector state dict.
        """
        if vector is not None:
            self.vector = vector
        else:
            assert pretrained_checkpoint is not None and finetuned_checkpoint is not None
            with torch.no_grad():

                pretrained_state_dict = pretrained_checkpoint.state_dict()
                finetuned_state_dict = finetuned_checkpoint.state_dict()

                self.vector = {}
                for key in pretrained_state_dict:
                    if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]:
                        continue
                    self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key]

    def __add__(self, other):
        """Add two task vectors together."""
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                if key not in other.vector:
                    print(f'Warning, key {key} is not present in both task vectors.')
                    continue
                new_vector[key] = self.vector[key] + other.vector[key]
        return TaskVector(vector=new_vector)

    def __radd__(self, other):
        if other is None or isinstance(other, int):
            return self
        return self.__add__(other)

    def __neg__(self):
        """Negate a task vector."""
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                new_vector[key] = - self.vector[key]
        return TaskVector(vector=new_vector)

    def apply_to(self, pretrained_model, scaling_coef=1.0):
        """Apply a task vector to a pretrained model."""
        with torch.no_grad():
            new_state_dict = {}
            pretrained_state_dict = pretrained_model.state_dict()
            for key in pretrained_state_dict:
                if key not in self.vector:
                    print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector')
                    continue
                new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
        pretrained_model.load_state_dict(new_state_dict, strict=False)
        return pretrained_model


    # You can uncomment the following version if you don't have enough GPU memory to apply the task vector in one go
    # Split and reassemble the task vector using multiple chunks

    def apply_to(self, pretrained_model, scaling_coef=1.0, chunk_size=500):
        """Apply a task vector to a pretrained model in chunks."""
        with torch.no_grad():
            pretrained_state_dict = pretrained_model.state_dict()
            keys = list(self.vector.keys())  # Get all the parameter keys in the task vector
            total_keys = len(keys)
            for i in range(0, total_keys, chunk_size):
                new_state_dict = {}
                for key in keys[i:i + chunk_size]:
                    if key not in pretrained_state_dict:
                        print(f'Warning: key {key} is present in the task vector but not in the pretrained model')
                        continue
                    # Apply scaling and update the parameter
                    new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
    
                # Partially load the updated state dict to the model
                pretrained_model.load_state_dict(new_state_dict, strict=False)
        return pretrained_model

In [None]:
print("saving model")
model_save_dir = "unleraned_model"
model.save_pretrained(model_save_dir, from_pt=True)
logging.info("Unlearning finished")

# Save task vector.
logging.info("Loading task vector")
task_vector = TaskVector(pretrained_model, model)

neg_task_vector = -task_vector

# Apply the task vector
new_benign_model = neg_task_vector.apply_to(pretrained_model)
task_vector_saving_path = "task_vector_path"
new_benign_model.save_pretrained(task_vector_saving_path, from_pt=True)

print("Done saving task vector files!")

In [None]:
# Salva il modello trainato (dopo merge_and_unload)
model_save_dir = "./my_unlearning_model"
model.save_pretrained(model_save_dir)
print(f"Modello salvato in: {model_save_dir}")

In [None]:
# model = model.merge_and_unload()
model.save_pretrained("tmp/unlearned_8bit", from_pt=True)

In [None]:
import torch
from transformers import AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_save_dir="semeval25-unlearning-model"
# # 2) Ricarica da disco in FP32
model = AutoModelForCausalLM.from_pretrained(
    "tmp/unlearned_8bit",
    torch_dtype=torch.float32,
    device_map="auto"
)
model.save_pretrained(model_save_dir, from_pt=True)


pretrained_model = AutoModelForCausalLM.from_pretrained(
    "semeval25-unlearning-1B-model",
    torch_dtype=torch.float32
).to(device)

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()


In [None]:
model.save_pretrained(model_save_dir, from_pt=True)


In [None]:

pretrained_model = AutoModelForCausalLM.from_pretrained(
    "semeval25-unlearning-1B-model",
    torch_dtype=torch.float32
).to(device)

model = model.to(device)

In [None]:
# model = AutoModelForCausalLM.from_pretrained(
#     'semeval25-unlearning-1B-model',
#     device_map="cuda:0",  # Forza GPU
#     torch_dtype=torch.float16
# )

## Traskvector

In [None]:
# PROBLEMA: Il modello quantizzato non può essere usato per task vector!
# SOLUZIONE: Devi prima salvare il modello merged, poi ricaricarlo in fp32
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

# 1. Il tuo modello è già stato merged con merge_and_unload(), ora salvalo
trained_path = "semeval25-unlearning-model"
pretrainerd_path = "semeval25-unlearning-1B-model"

# model.save_pretrained(model_save_dir)  # Salva il modello merged

# # 2. Ora ricarica ENTRAMBI i modelli in fp32 per il task vector
# pretrained_model = AutoModelForCausalLM.from_pretrained(
#     "semeval25-unlearning-1B-model",  # Modello originale
#     torch_dtype=torch.float16,
#     device_map="cuda"
# )

# trained_model = AutoModelForCausalLM.from_pretrained(
#     model_save_dir,  # Modello che hai appena salvato
#     torch_dtype=torch.float16,
#     device_map="cuda"
# )
import torch
import gc
from collections import OrderedDict

def create_task_vector_chunked(pretrained_path, trained_path, chunk_size=50):
    """
    Crea task vector processando i parametri a chunked per risparmiare memoria
    """
    # Prima ottieni la lista di tutti i parameter names dal modello trained
    temp_model = AutoModelForCausalLM.from_pretrained(trained_path, torch_dtype=torch.float32)
    param_names = list(temp_model.state_dict().keys())
    del temp_model
    torch.cuda.empty_cache()
    gc.collect()
    
    task_vector_dict = {}
    
    # Processa i parametri in chunks
    for i in range(0, len(param_names), chunk_size):
        chunk_names = param_names[i:i+chunk_size]
        print(f"Processing chunk {i//chunk_size + 1}/{(len(param_names) + chunk_size - 1)//chunk_size}")
        
        # Carica solo i parametri necessari per questo chunk
        pretrained_model = AutoModelForCausalLM.from_pretrained(
            pretrained_path, 
            torch_dtype=torch.float32,
            device_map="cpu"  # Carica su CPU prima
        )
        trained_model = AutoModelForCausalLM.from_pretrained(
            trained_path, 
            torch_dtype=torch.float32,
            device_map="cpu"
        )
        
        pretrained_state = pretrained_model.state_dict()
        trained_state = trained_model.state_dict()
        
        # Processa solo questo chunk
        for param_name in chunk_names:
            if param_name in pretrained_state and param_name in trained_state:
                # Sposta su GPU solo quando necessario
                pretrained_param = pretrained_state[param_name].to("cuda")
                trained_param = trained_state[param_name].to("cuda")
                
                if pretrained_param.dtype not in [torch.int64, torch.uint8]:
                    task_vector_dict[param_name] = (trained_param - pretrained_param).cpu()
                
                # Libera immediatamente la memoria GPU
                del pretrained_param, trained_param
                torch.cuda.empty_cache()
        
        # Libera i modelli
        del pretrained_model, trained_model, pretrained_state, trained_state
        torch.cuda.empty_cache()
        gc.collect()
    
    return task_vector_dict

# Uso della funzione chunked
print("Creando task vector con processing chunked...")
task_vector_dict = create_task_vector_chunked(
    "semeval25-unlearning-1B-model", 
    model_save_dir,
    chunk_size=30  # Riduci se hai ancora problemi di memoria
)

# Crea il task vector object
task_vector = TaskVector(vector=task_vector_dict)

In [None]:
import torch
import gc

# Pulisci completamente la memoria GPU
torch.cuda.empty_cache()
gc.collect()

# Verifica memoria disponibile
print(f"Memoria GPU libera: {torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)} bytes")

In [None]:
# Task Vector
task_vector_saving_path = "semeval25-unlearning-model/task_vector"
task_vector= TaskVector(pretrained_model, model)
neg_task_vector = -task_vector
unlearned_model = neg_task_vector.apply_to(pretrained_model, scaling_coef=2.0)
unlearned_model.save_pretrained(task_vector_saving_path, from_pt = True)


## Evaluation

In [None]:
import torch
from transformers import AutoModelForCausalLM

device = "cuda"

# Usa float16 invece di float32 (dimezza l'uso di memoria)
pretrained_model = AutoModelForCausalLM.from_pretrained(
    "semeval25-unlearning-1B-model",
    torch_dtype=torch.float16,  # Cambiato da float32
    device_map="cuda"  # Gestione automatica della memoria
)

unlearned_model = AutoModelForCausalLM.from_pretrained(
    "semeval25-unlearning-model/task_vector",
    torch_dtype=torch.float16,  # Cambiato da float32
    device_map="cuda"
)

In [None]:
# Verifica se i modelli sono effettivamente diversi
def compare_models(model1, model2, sample_layers=3):
    differences = []
    
    for name1, param1 in list(model1.named_parameters())[:sample_layers]:
        if name1 in dict(model2.named_parameters()):
            param2 = dict(model2.named_parameters())[name1]
            diff = torch.norm(param1 - param2).item()
            differences.append((name1, diff))
            print(f"{name1}: differenza = {diff:.6f}")
        else:
            print(f"{name1}: non trovato nel secondo modello")
    
    return differences

print("Confronto tra pretrained_model e model:")
diffs = compare_models(pretrained_model, model)

# Controlla se ci sono differenze significative
significant_diffs = [d for d in diffs if d[1] > 1e-6]
print(f"\nDifferenze significative: {len(significant_diffs)}")

In [None]:
import torch
from tqdm.auto import tqdm

def eval_loss(model, dataloader, device="cuda"):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction="sum")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Eval"):
            input_ids = batch["input_ids"].to(device)
            attn      = batch["attention_mask"].to(device)
            labels    = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attn)
            # logits: [B, L, V]
            shift_logits = outputs.logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # flatten
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            total_loss += loss.item()
            total_tokens += (shift_labels != tokenizer.pad_token_id).sum().item()

    avg_nll = total_loss / total_tokens
    ppl = torch.exp(torch.tensor(avg_nll))
    return avg_nll, ppl.item()


In [None]:
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")

def tokenize_with_start(example):
    q, a = example["input"], example["output"]
    prefix = q
    full   = q + a

    # 1) tokenizza solo per contare i token reali (no pad)
    t_pref = tokenizer(prefix, truncation=True, padding=False)
    start_locs = len(t_pref["input_ids"])

    # 2) tokenizza la coppia vera e propria con pad/trunc
    t_full = tokenizer(full, truncation=True, padding="max_length", max_length=128)

    return {
      "input_ids":      t_full["input_ids"],
      "attention_mask": t_full["attention_mask"],
      "labels":         t_full["input_ids"],
      "start_locs":     start_locs,
    }


ds_retain = Dataset.from_pandas(retain_train_df).map(
    tokenize_with_start, batched=False, load_from_cache_file=False
)

ds_forget = Dataset.from_pandas(forget_train_df).map(
    tokenize_with_start, batched=False, load_from_cache_file=False
)


In [None]:
from datasets import Dataset
ds_retain_val = Dataset.from_pandas(retain_validation_df).map(
    tokenize_with_start,
    batched=False,
    load_from_cache_file=False
)
ds_forget_val = Dataset.from_pandas(forget_validation_df).map(
    tokenize_with_start,
    batched=False,
    load_from_cache_file=False
)

forget_val_loader = DataLoader(ds_forget_val, batch_size, shuffle=True, collate_fn=collate_fn)
retain_val_loader = DataLoader(ds_retain_val, batch_size, shuffle=True, collate_fn=collate_fn)



nll_forget_pre, ppl_forget_pre = eval_loss(pretrained_model, forget_val_loader)
nll_retain_pre, ppl_retain_pre = eval_loss(pretrained_model, retain_val_loader)

nll_forget_post, ppl_forget_post = eval_loss(new_benign_model, forget_val_loader)
nll_retain_post, ppl_retain_post = eval_loss(new_benign_model, retain_val_loader)



In [None]:
print(f"nll_forget_pre: {nll_forget_pre:.2f}")
print(f"ppl_forget_pre: {ppl_forget_pre:.2f}")
print(f"nll_forget_post: {nll_forget_post:.2f}")
print(f"ppl_forget_post: {ppl_forget_post:.2f}")

print(f"nll_retain_pre: {nll_retain_pre:.2f}")
print(f"ppl_retain_pre: {ppl_retain_pre:.2f}")
print(f"nll_retain_post: {nll_retain_post:.2f}")
print(f"ppl_retain_post: {ppl_retain_post:.2f}")

In [None]:
for index, example in forget_validation_df.sample(5).iterrows():
    prompt = example["input"]
    print("PROMPT:", prompt)
    out_pre  = pretrained_model.generate(tokenizer(prompt, return_tensors="pt").input_ids.to(device), max_new_tokens=50)
    out_post = new_benign_model.generate(tokenizer(prompt, return_tensors="pt").input_ids.to(device), max_new_tokens=50)
    print("ORIG:", tokenizer.decode(out_pre[0], skip_special_tokens=True))
    print("NEW:",  tokenizer.decode(out_post[0], skip_special_tokens=True))
    print("-"*40)