In [None]:
from config.config import DEVICE, FORGET_CLASS_IDXS
from utils.data_utils import load_cifar10_data, split_dataset
from models.models import create_timm_model
from utils.eval_utils import evaluate_instance_model_accuracy
from methods.unlearning import distill_with_soft_relabel, sort_data_loader_by_entropy, SelfUnlearning_Layered_Iteration

import torch
from torch.utils.data import DataLoader

# Hyper-parameters shared by both unlearning methods; kept in one place so the
# SU and SULI runs below are guaranteed to use identical settings.
BATCH_SIZE = 256
LEARNING_RATE = 0.0001
UNLEARN_EPOCHS = 10
SORT_BATCH_SIZE_PER_LOADER = 500
PRETRAINED_URL = "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin"


def _clone_model(source_model):
    """Return a fresh ``create_timm_model()`` on DEVICE loaded with *source_model*'s weights."""
    clone = create_timm_model().to(DEVICE)
    clone.load_state_dict(source_model.state_dict())
    return clone


if __name__ == '__main__':
    # Build forget / retain splits of the training set.
    train_loader, test_loader = load_cifar10_data()
    forget_data, retain_data = split_dataset(train_loader.dataset, FORGET_CLASS_IDXS)
    forget_loader = DataLoader(forget_data, batch_size=BATCH_SIZE, shuffle=True)
    retain_loader = DataLoader(retain_data, batch_size=BATCH_SIZE, shuffle=True)

    # Load the target (pre-trained) model.
    # NOTE(review): load_state_dict_from_url unpickles a file fetched over the
    # network; on torch versions that support it, consider weights_only=True.
    original_model = create_timm_model().to(DEVICE)
    original_model.load_state_dict(
        torch.hub.load_state_dict_from_url(
            PRETRAINED_URL,
            map_location="cpu",
            file_name="resnet18_cifar10.pth",
        )
    )
    original_model.eval()
    print("original model evaluation:")
    evaluate_instance_model_accuracy(original_model, test_loader, forget_loader, retain_loader, DEVICE)

    print("\n===== SU (Self Unlearning) =====")
    SU_model = _clone_model(original_model)
    optimizer = torch.optim.Adam(SU_model.parameters(), lr=LEARNING_RATE)
    SU_model = distill_with_soft_relabel(
        original_model, SU_model, forget_loader, optimizer,
        forget_class_idxs=FORGET_CLASS_IDXS, epochs=UNLEARN_EPOCHS, device=DEVICE,
    )
    # Evaluate in eval mode, exactly like original_model above; the training
    # routine may leave the model in train mode (BatchNorm would then use
    # batch statistics and skew the comparison).
    SU_model.eval()
    print("SU model evaluation:")
    evaluate_instance_model_accuracy(SU_model, test_loader, forget_loader, retain_loader, DEVICE)

    print("\n===== SULI (Self-Unlearning Layered Iteration) =====")
    SULI_model = _clone_model(original_model)
    # Forget data is re-ordered using the original model's predictions.
    sorted_loader = sort_data_loader_by_entropy(
        original_model, forget_loader, DEVICE,
        batch_size_per_loader=SORT_BATCH_SIZE_PER_LOADER,
    )
    SULI_model = SelfUnlearning_Layered_Iteration(
        SULI_model, sorted_loaders=sorted_loader, forget_class_idxs=FORGET_CLASS_IDXS,
        forget_loader=forget_loader, epochs=UNLEARN_EPOCHS, device=DEVICE, lr=LEARNING_RATE,
    )
    # Same reasoning as for SU_model: evaluate in eval mode for a fair comparison.
    SULI_model.eval()
    print("SULI model evaluation:")
    evaluate_instance_model_accuracy(SULI_model, test_loader, forget_loader, retain_loader, DEVICE)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:26<00:00, 6451631.00it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
