In [2]:
# main.py  (수정판: 핵심만 고침)
import torch, torch.nn.functional as F, torch.nn as nn
from model import custom_res18
from dataset import full_forget_retain_loader_train, full_forget_retain_loader_test,  UnlearnFullTrain
from train_test_acc import train_model, model_test
from mia_unlearning_score import UnLearningScore, get_membership_attack_prob, actv_dist
from seed import set_seed

# --- Experiment configuration (everything a reader might tune) -----------------
SEED = 42
set_seed(SEED)  # project seed helper from seed.py; called once, before any model/loader is built

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH = 512          # batch size for train_model and for the split-loader builders
EPOCHS = 20          # epochs for the good teacher, the retain-only retrain and the fine-tune baseline
epoch_unlearn = 10   # presumably the epoch budget for a later unlearning stage — TODO confirm
LR = 1e-4            # learning rate passed to train_model
TAU_G, TAU_B = 1, 1  # experiment defaults; presumably good-/bad-teacher temperatures — TODO confirm
FORGET_RATIO = 0.1   # forwarded to the loader builders; presumably the forget-split fraction — confirm in dataset.py
NUM_CLASSES = 10     # number of output classes for custom_res18

In [4]:
# Build the train and test splits. Each builder returns a 4-tuple, unpacked here as
# (full, combined, forget, retain); later cells rely on exactly these names.
(full_tr, combined_tr,
 forget_tr, retain_tr) = full_forget_retain_loader_train(FORGET_RATIO, batch_size=BATCH)
(full_te, combined_te,
 forget_te, retain_te) = full_forget_retain_loader_test(FORGET_RATIO, batch_size=BATCH)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Train the two reference models from scratch and checkpoint them:
#   * good teacher  — trained on combined_tr, saved to ./full_tr_model.pth
#   * retrain model — trained on retain_tr only, saved to ./retrain_tr_model.pth
# NOTE(review): the teacher checkpoint is named "full_tr" but is trained on combined_tr;
# confirm combined_tr is the full (forget + retain) training set in dataset.py.


def train_and_save(dataset, tag, ckpt_path):
    """Train a freshly initialised custom_res18 on `dataset` and save its weights.

    Args:
        dataset: training dataset handed to `train_model` (assumed to build its own
            DataLoader from it — confirm in train_test_acc).
        tag: run label forwarded to `train_model`; it appears in the training log lines.
        ckpt_path: file the resulting `state_dict` is written to with `torch.save`.

    Returns:
        The model returned by `train_model`.
    """
    model = custom_res18(NUM_CLASSES).to(DEVICE)
    model = train_model(
        model,
        dataset,
        tag,
        epochs=EPOCHS, lr=LR, batch_size=BATCH, device=DEVICE
    )
    torch.save(model.state_dict(), ckpt_path)
    return model


# Same order as before (teacher first, then retrain), so RNG consumption is unchanged.
good = train_and_save(combined_tr.dataset, 'good_teacher', './full_tr_model.pth')
retrain = train_and_save(retain_tr.dataset, 'retrain_retains_only', './retrain_tr_model.pth')




[Train:good_teacher] Epoch 1/20 Loss 2.1895 Acc 25.88%
[Train:good_teacher] Epoch 2/20 Loss 1.3927 Acc 48.45%
[Train:good_teacher] Epoch 3/20 Loss 1.0711 Acc 61.51%
[Train:good_teacher] Epoch 4/20 Loss 0.8547 Acc 69.47%
[Train:good_teacher] Epoch 5/20 Loss 0.6721 Acc 76.10%
[Train:good_teacher] Epoch 6/20 Loss 0.5399 Acc 80.91%
[Train:good_teacher] Epoch 7/20 Loss 0.4031 Acc 85.77%
[Train:good_teacher] Epoch 8/20 Loss 0.2985 Acc 89.44%
[Train:good_teacher] Epoch 9/20 Loss 0.2159 Acc 92.44%
[Train:good_teacher] Epoch 10/20 Loss 0.1625 Acc 94.36%
[Train:good_teacher] Epoch 11/20 Loss 0.1343 Acc 95.27%
[Train:good_teacher] Epoch 12/20 Loss 0.1105 Acc 96.19%
[Train:good_teacher] Epoch 13/20 Loss 0.0739 Acc 97.59%
[Train:good_teacher] Epoch 14/20 Loss 0.0521 Acc 98.32%
[Train:good_teacher] Epoch 15/20 Loss 0.0553 Acc 98.20%
[Train:good_teacher] Epoch 16/20 Loss 0.0547 Acc 98.22%
[Train:good_teacher] Epoch 17/20 Loss 0.0519 Acc 98.33%
[Train:good_teacher] Epoch 18/20 Loss 0.0583 Acc 98.01%
[

In [4]:

# Accuracy of the good teacher on the full test split. The label previously read
# 'bad/full_test' even though `good` is evaluated. Kept as the cell's last expression
# so the returned accuracy is displayed as the cell output.
model_test(good, full_te.dataset, 'good/full_test', batch_size=256, device=DEVICE)


[Test:bad/full_test] Acc 78.42%


78.42

In [5]:
# Fine-tune baseline: start from the saved good-teacher weights and keep training on
# the retain split only, then report accuracies and membership-inference (MIA) scores.

# Load the good-teacher checkpoint. map_location keeps this working if the checkpoint
# was written on a different device than DEVICE (e.g. saved on GPU, reloaded on CPU).
finetune_model = custom_res18(NUM_CLASSES).to(DEVICE)
finetune_model.load_state_dict(torch.load('./full_tr_model.pth', map_location=DEVICE))

# Fine-tuning on retain data only.
finetune_model = train_model(
    finetune_model,
    retain_tr.dataset,   # assumes train_model builds its own DataLoader from the dataset
    'finetune',
    epochs=EPOCHS, lr=LR, batch_size=BATCH, device=DEVICE
)

finetune_model.eval()

print('\n[Accuracies]')
for split_name, split_ds in [
    ('Retain_test', retain_te.dataset),
    ('Forget_test', forget_te.dataset),
    ('Retain_train', retain_tr.dataset),
    ('Forget_train', forget_tr.dataset),
]:
    model_test(finetune_model, split_ds, f'Finetune/{split_name}', batch_size=256, device=DEVICE)


def _eval_loader(ds, batch_size=128):
    """Wrap `ds` in an unshuffled DataLoader for evaluation (own batch size, not BATCH)."""
    return torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False)


forget_eval_train = _eval_loader(forget_tr.dataset)
forget_eval_test = _eval_loader(forget_te.dataset)
retain_eval_train = _eval_loader(retain_tr.dataset)
retain_eval_test = _eval_loader(retain_te.dataset)

print('\n[MIA]')
# NOTE(review): argument roles assumed from the call pattern —
# (member reference, target split, non-member reference, model, device);
# confirm against get_membership_attack_prob in mia_unlearning_score.
mia_p = get_membership_attack_prob(retain_eval_train, forget_eval_test, retain_eval_test, finetune_model, DEVICE)
print(f'  MIA success prob on Forget (test) : {mia_p:.4f}')

# forget_te comes from the test loaders, which no training cell uses, so those samples
# were never members and the score above cannot show forgetting by itself. The forget
# *train* samples are the ones the model saw and should unlearn; report MIA on them too.
mia_p_forget_train = get_membership_attack_prob(
    retain_eval_train, forget_eval_train, retain_eval_test, finetune_model, DEVICE
)
print(f'  MIA success prob on Forget (train): {mia_p_forget_train:.4f}')


[Train:finetune] Epoch 1/20 Loss 0.0244 Acc 99.28%
[Train:finetune] Epoch 2/20 Loss 0.0210 Acc 99.45%
[Train:finetune] Epoch 3/20 Loss 0.0201 Acc 99.45%
[Train:finetune] Epoch 4/20 Loss 0.0231 Acc 99.39%
[Train:finetune] Epoch 5/20 Loss 0.0307 Acc 99.06%
[Train:finetune] Epoch 6/20 Loss 0.0894 Acc 96.97%
[Train:finetune] Epoch 7/20 Loss 0.1267 Acc 95.59%
[Train:finetune] Epoch 8/20 Loss 0.0759 Acc 97.53%
[Train:finetune] Epoch 9/20 Loss 0.0364 Acc 98.95%
[Train:finetune] Epoch 10/20 Loss 0.0198 Acc 99.49%
[Train:finetune] Epoch 11/20 Loss 0.0089 Acc 99.83%
[Train:finetune] Epoch 12/20 Loss 0.0032 Acc 99.96%
[Train:finetune] Epoch 13/20 Loss 0.0013 Acc 99.99%
[Train:finetune] Epoch 14/20 Loss 0.0011 Acc 100.00%
[Train:finetune] Epoch 15/20 Loss 0.0012 Acc 100.00%
[Train:finetune] Epoch 16/20 Loss 0.0014 Acc 100.00%
[Train:finetune] Epoch 17/20 Loss 0.0015 Acc 100.00%
[Train:finetune] Epoch 18/20 Loss 0.0016 Acc 100.00%
[Train:finetune] Epoch 19/20 Loss 0.0017 Acc 100.00%
[Train:finetune