In [1]:
# main.py  (revised version: only the essential parts were fixed)
import torch, torch.nn.functional as F, torch.nn as nn
from model import custom_res18
from dataset import full_forget_retain_loader_train, full_forget_retain_loader_test
from train_test_acc import train_model, model_test
from mia_unlearning_score import UnLearningScore, get_membership_attack_prob, actv_dist
from seed import set_seed

# Seed first, before any model or data loader is created.
# NOTE(review): set_seed is a project helper; presumably it seeds every RNG in use — confirm.
set_seed(42)

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

BATCH = 512          # batch size handed to the loaders and to train_model
EPOCHS = 20          # epochs for the teacher / retrain training runs
epoch_unlearn = 10   # presumably the epoch count of the unlearning stage — not used in this part of the notebook
LR = 1e-4            # learning rate passed to train_model

# Experimental defaults.
TAU_G = 1
TAU_B = 1

FORGET_CLASS = 1     # class label passed as forget_class to the loader builders
NUM_CLASSES = 10     # number of output classes passed to custom_res18

In [2]:
# Both splits are built with the same forget class and batch size.
# Each helper returns a (full, forget, retain) triple; later cells read `.dataset` from them.
loader_opts = dict(forget_class=FORGET_CLASS, batch_size=BATCH)
full_tr, forget_tr, retain_tr = full_forget_retain_loader_train(**loader_opts)
full_te, forget_te, retain_te = full_forget_retain_loader_test(**loader_opts)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Training settings shared by both runs in this cell.
fit_opts = dict(epochs=EPOCHS, lr=LR, batch_size=BATCH, device=DEVICE)

# Good teacher: train on the full training dataset, then save its weights.
good = custom_res18(NUM_CLASSES).to(DEVICE)
good = train_model(good, full_tr.dataset, 'good_teacher', **fit_opts)
torch.save(good.state_dict(), './full_tr_model.pth')

# Retrain baseline: a fresh model trained on the retain data only, then saved.
retrain = custom_res18(NUM_CLASSES).to(DEVICE)
retrain = train_model(retrain, retain_tr.dataset, 'retrain_retains_only', **fit_opts)
torch.save(retrain.state_dict(), './retrain_tr_model.pth')




[Train:good_teacher] Epoch 1/20 Loss 2.3121 Acc 23.58%
[Train:good_teacher] Epoch 2/20 Loss 1.4810 Acc 44.85%
[Train:good_teacher] Epoch 3/20 Loss 1.1953 Acc 56.47%
[Train:good_teacher] Epoch 4/20 Loss 0.9519 Acc 65.88%
[Train:good_teacher] Epoch 5/20 Loss 0.7662 Acc 72.69%
[Train:good_teacher] Epoch 6/20 Loss 0.6254 Acc 77.91%
[Train:good_teacher] Epoch 7/20 Loss 0.4947 Acc 82.52%
[Train:good_teacher] Epoch 8/20 Loss 0.3817 Acc 86.54%
[Train:good_teacher] Epoch 9/20 Loss 0.2900 Acc 89.87%
[Train:good_teacher] Epoch 10/20 Loss 0.2033 Acc 92.88%
[Train:good_teacher] Epoch 11/20 Loss 0.1711 Acc 94.06%
[Train:good_teacher] Epoch 12/20 Loss 0.1232 Acc 95.80%
[Train:good_teacher] Epoch 13/20 Loss 0.1056 Acc 96.44%
[Train:good_teacher] Epoch 14/20 Loss 0.0915 Acc 96.96%
[Train:good_teacher] Epoch 15/20 Loss 0.0724 Acc 97.56%
[Train:good_teacher] Epoch 16/20 Loss 0.0687 Acc 97.68%
[Train:good_teacher] Epoch 17/20 Loss 0.0609 Acc 97.99%
[Train:good_teacher] Epoch 18/20 Loss 0.0529 Acc 98.31%
[

In [4]:

# Evaluate the good teacher on the full test dataset.
# The tag is what model_test prints in its "[Test:...]" log line, so it must name the
# model actually being evaluated (`good`); the previous 'bad/...' tag mislabelled this result.
# Kept as the cell's last expression so the returned accuracy is displayed.
model_test(good, full_te.dataset, 'good/full_test', batch_size=256, device=DEVICE)


[Test:bad/full_test] Acc 75.73%


75.73