In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np

import torch
import torch.nn as nn
import random

from utils import *

# Seed every random number generator the notebook touches with the shared SEED.
random.seed(SEED)
np.random.seed(SEED)
for seed_torch_rng in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,  # covers the multi-GPU case
):
    seed_torch_rng(SEED)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Model name; it is spliced into the checkpoint directory below.
MODEL = "biolinkbert"

# Directory the tokenizer (and, in later cells, the model) is loaded from.
MODEL_PATH = f"best_model/{MODEL}/"
# JSON file handed to load_val_test_loaders to describe the forget sets.
FORGET_DETAILS_PATH = "data/forget_set.json"

mlb, classes = get_classes_mlb()

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# forget_details is iterated by the unlearning cells as
# forget_details[UNLEARN_K][fold].
test_loader, val_loader, forget_details = load_val_test_loaders(
    FORGET_DETAILS_PATH,
    tokenizer,
    mlb,
)

# Loss shared by every unlearning cell below.
criterion = SafeWeightedBCEWithLogitsLoss()

In [None]:
# Unlearning run "adv_imp": for every UNLEARN_K / fold in forget_details, reload
# the checkpoint at MODEL_PATH and optimise
#     RETAIN_LAMB * adv_loss - FORGET_LAMB * forget_loss + WEIGHT_LAMB * reg_loss
# keeping the checkpoint of the epoch with the lowest forget-set score.
# Depends on names defined in earlier cells / utils: MODEL, MODEL_PATH,
# forget_details, tokenizer, mlb, device, criterion, test_loader, EPOCHS,
# TRIALS, UNLEARN_LR and the helper functions called below.
UNLEARNED_MODEL_PATH = "unlearned/"+MODEL+"/adv_imp/"

WEIGHT_LAMB = 0.5  # weight of the parameter-regularisation term
FORGET_LAMB = 0.3  # weight of the forget-set loss (subtracted from the objective)
RETAIN_LAMB = 0.7  # weight of the loss on the adversarial samples

for UNLEARN_K in forget_details:
    # Running sums over the folds of this UNLEARN_K setting.
    retain_score = 0
    forget_score = 0
    test_score = 0
    org_forget_score = 0
    org_retain_score = 0

    for fold in forget_details[UNLEARN_K]:
        # Start from +inf so the first evaluated epoch always writes a
        # checkpoint. With a fixed starting threshold, a fold whose score never
        # drops below it would leave no checkpoint (or a stale one from an
        # earlier run) for the reload after the epoch loop.
        least_forget_score = float("inf")
        unlearned_model_path = UNLEARNED_MODEL_PATH+UNLEARN_K+"/"+fold
        forget_loader, retain_loader, df_forget, df_retain = load_retain_forget_loaders(forget_details[UNLEARN_K][fold], tokenizer, mlb)

        # Every fold starts from the same original checkpoint.
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        model = nn.DataParallel(model)
        model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=UNLEARN_LR)

        # Snapshot the original parameters (detached clones) for the
        # regularisation term.
        origin_params = {}
        for n, p in named_parameters_dict(model).items():
            origin_params[n] = p.detach().clone()

        # Estimate importance on the forget_loader (use all forget samples or set num_samples)
        print("Estimating parameter importance (this may take a while)...")
        params_importance = estimate_parameter_importance(forget_loader, model, device)

        # Perform the adversarial attack and generate adversarial samples of the forget set
        adv_loader = load_adv_loader(model, forget_loader, device)

        # Baseline scores of the model before any unlearning step.
        forget_metrics = test(model, forget_loader, device)
        retain_metrics = test(model, retain_loader, device)
        org_forget_score += forget_metrics["score"]
        org_retain_score += retain_metrics["score"]

        for epoch in range(EPOCHS):
            # Training phase
            model.train()

            # zip() stops at the shorter loader.
            # NOTE(review): assumes adv_loader batches line up with
            # forget_loader batches — confirm against load_adv_loader.
            for forget_batch, adv_batch in zip(forget_loader, adv_loader):
                forget_input_ids = forget_batch['input_ids'].to(device)
                forget_attention_mask = forget_batch['attention_mask'].to(device)
                forget_labels = forget_batch['labels'].to(device)
                adv_embeds = adv_batch['embeds'].to(device)
                adv_attention_mask = adv_batch['attention_mask'].to(device)
                adv_labels = adv_batch['labels'].to(device)

                optimizer.zero_grad()
                forget_outputs = model(input_ids=forget_input_ids, attention_mask=forget_attention_mask)
                # Adversarial samples are fed as embeddings, not token ids.
                adv_outputs = model(inputs_embeds=adv_embeds, attention_mask=adv_attention_mask)

                forget_loss = criterion(forget_outputs.logits, forget_labels)
                adv_loss = criterion(adv_outputs.logits, adv_labels)
                reg_loss = parameter_regularization_loss(model, origin_params, params_importance)
                # The minus sign pushes the forget-set loss up while the other
                # two terms are minimised.
                loss = RETAIN_LAMB*adv_loss - FORGET_LAMB*forget_loss + WEIGHT_LAMB*reg_loss
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

            # Keep the checkpoint of the epoch with the lowest forget score.
            forget_metrics = test(model, forget_loader, device)
            if forget_metrics["score"] < least_forget_score:
                least_forget_score = forget_metrics["score"]
                # Unwrap DataParallel so save_pretrained is called on the
                # underlying model.
                model_to_save = model.module if hasattr(model, "module") else model
                model_to_save.save_pretrained(unlearned_model_path)

        # Fail loudly instead of evaluating a missing or stale checkpoint.
        if least_forget_score == float("inf"):
            raise RuntimeError(
                f"No checkpoint was written to {unlearned_model_path}; "
                "check that EPOCHS > 0 and that test() returns a comparable score."
            )

        # Evaluate the best checkpoint of this fold on the retain and test sets.
        model = AutoModelForSequenceClassification.from_pretrained(unlearned_model_path)
        model = nn.DataParallel(model)
        model.to(device)
        retain_metrics = test(model, retain_loader, device)
        test_metrics = test(model, test_loader, device)

        forget_score += least_forget_score
        retain_score += retain_metrics["score"]
        test_score += test_metrics["score"]

    # NOTE(review): assumes TRIALS equals the number of folds per UNLEARN_K —
    # confirm against how forget_details is built.
    avg_forget_score = forget_score/TRIALS
    avg_retain_score = retain_score/TRIALS
    avg_test_score = test_score/TRIALS
    avg_org_forget_score = org_forget_score/TRIALS
    avg_org_retain_score = org_retain_score/TRIALS

    print(f"\nFinal scores of Unlearn_K: {UNLEARN_K}")
    print(f"Avg Forget score (before unlearning): {avg_org_forget_score:.4f}")
    print(f"Avg Retain score (before unlearning): {avg_org_retain_score:.4f}")
    print(f"Avg Forget Score: {avg_forget_score:.4f}")
    print(f"Avg Retain Score: {avg_retain_score:.4f}")
    print(f"Avg Test Score: {avg_test_score:.4f}")


In [None]:
# Unlearning run "ga_with_retain": for every UNLEARN_K / fold in forget_details,
# reload the checkpoint at MODEL_PATH and optimise
#     (1 - LAMBDA) * retain_loss - LAMBDA * forget_loss
# keeping the checkpoint of the epoch with the lowest forget-set score.
# Depends on names defined in earlier cells / utils: forget_details, tokenizer,
# mlb, device, criterion, test_loader, EPOCHS, TRIALS, UNLEARN_LR and the
# helper functions called below.
MODEL = "biolinkbert"

MODEL_PATH = "best_model/"+MODEL+"/"
UNLEARNED_MODEL_PATH = "unlearned/"+MODEL+"/ga_with_retain/"
FORGET_DETAILS_PATH = "data/forget_set.json"

# Trade-off between the retain term (1 - LAMBDA) and the forget term (LAMBDA).
LAMBDA = 0.4

for UNLEARN_K in forget_details:
    # Running sums over the folds of this UNLEARN_K setting.
    retain_score = 0
    forget_score = 0
    test_score = 0
    org_forget_score = 0
    org_retain_score = 0
    for fold in forget_details[UNLEARN_K]:
        # Start from +inf so the first evaluated epoch always writes a
        # checkpoint. With a fixed starting threshold, a fold whose score never
        # drops below it would leave no checkpoint (or a stale one from an
        # earlier run) for the reload after the epoch loop.
        least_forget_score = float("inf")
        unlearned_model_path = UNLEARNED_MODEL_PATH+UNLEARN_K+"/"+fold
        forget_loader, retain_loader, df_forget, df_retain = load_retain_forget_loaders(forget_details[UNLEARN_K][fold], tokenizer, mlb)
        sampled_retain_loader, sampled_retain_labels = get_sampled_retain_loader(df_retain, tokenizer, mlb, UNLEARN_K, fold)

        # Every fold starts from the same original checkpoint.
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        model = nn.DataParallel(model)
        model.to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=UNLEARN_LR)

        # Baseline scores of the model before any unlearning step.
        forget_metrics = test(model, forget_loader, device)
        retain_metrics = test(model, retain_loader, device)
        org_forget_score += forget_metrics["score"]
        org_retain_score += retain_metrics["score"]

        for epoch in range(EPOCHS):
            # Training phase
            model.train()
            # Per-epoch accumulators for the loss averages reported below.
            train_loss = 0.0
            total_forget_loss = 0.0
            samples = 0

            # zip() stops at the shorter of the two loaders.
            for forget_batch, retain_batch in zip(forget_loader, sampled_retain_loader):
                forget_input_ids = forget_batch['input_ids'].to(device)
                forget_attention_mask = forget_batch['attention_mask'].to(device)
                forget_labels = forget_batch['labels'].to(device)
                retain_input_ids = retain_batch['input_ids'].to(device)
                retain_attention_mask = retain_batch['attention_mask'].to(device)
                retain_labels = retain_batch['labels'].to(device)

                optimizer.zero_grad()
                forget_outputs = model(input_ids=forget_input_ids, attention_mask=forget_attention_mask)
                retain_outputs = model(input_ids=retain_input_ids, attention_mask=retain_attention_mask)

                retain_loss = criterion(retain_outputs.logits, retain_labels)
                forget_loss = criterion(forget_outputs.logits, forget_labels)
                # The minus sign pushes the forget-set loss up while the
                # retain-set loss is minimised.
                loss = (1-LAMBDA)*retain_loss - LAMBDA*forget_loss
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

                # Weight each batch by its forget-batch size.
                # NOTE(review): assumes criterion returns a per-batch mean —
                # confirm against SafeWeightedBCEWithLogitsLoss.
                batch_size = forget_labels.size(0)
                train_loss += loss.item() * batch_size
                total_forget_loss += forget_loss.item() * batch_size
                samples += batch_size

            # Guard the division: these counters used to stay at zero, which
            # raised ZeroDivisionError at the end of the first epoch.
            avg_train_loss = train_loss/samples if samples else float("nan")
            avg_forget_loss = total_forget_loss/samples if samples else float("nan")
            print(f"Epoch {epoch+1}/{EPOCHS} - avg loss: {avg_train_loss:.4f}, avg forget loss: {avg_forget_loss:.4f}")

            # Keep the checkpoint of the epoch with the lowest forget score.
            forget_metrics = test(model, forget_loader, device)
            if forget_metrics["score"] < least_forget_score:
                least_forget_score = forget_metrics["score"]
                # Unwrap DataParallel so save_pretrained is called on the
                # underlying model.
                model_to_save = model.module if hasattr(model, "module") else model
                model_to_save.save_pretrained(unlearned_model_path)
                print("Model saved")

        # Fail loudly instead of evaluating a missing or stale checkpoint.
        if least_forget_score == float("inf"):
            raise RuntimeError(
                f"No checkpoint was written to {unlearned_model_path}; "
                "check that EPOCHS > 0 and that test() returns a comparable score."
            )

        # Evaluate the best checkpoint of this fold on the retain and test sets.
        model = AutoModelForSequenceClassification.from_pretrained(unlearned_model_path)
        model = nn.DataParallel(model)
        model.to(device)
        retain_metrics = test(model, retain_loader, device)
        test_metrics = test(model, test_loader, device)

        forget_score += least_forget_score
        retain_score += retain_metrics["score"]
        test_score += test_metrics["score"]

    # NOTE(review): assumes TRIALS equals the number of folds per UNLEARN_K —
    # confirm against how forget_details is built.
    avg_forget_score = forget_score/TRIALS
    avg_retain_score = retain_score/TRIALS
    avg_test_score = test_score/TRIALS
    avg_org_forget_score = org_forget_score/TRIALS
    avg_org_retain_score = org_retain_score/TRIALS

    print(f"\nFinal scores of Unlearn_K: {UNLEARN_K}")
    print(f"Avg Forget score (before unlearning): {avg_org_forget_score:.4f}")
    print(f"Avg Retain score (before unlearning): {avg_org_retain_score:.4f}")
    print(f"Avg Forget Score: {avg_forget_score:.4f}")
    print(f"Avg Retain Score: {avg_retain_score:.4f}")
    print(f"Avg Test Score: {avg_test_score:.4f}")
