In [1]:
import torch
from models import TeacherModel, StudentModel
from data import MNISTDataManager
from utils import Trainer

# --- Experiment setup -------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# MNIST loaders shared by every model in this experiment (batch size 256).
# The train loader is requested before the test loader, as before.
data_manager = MNISTDataManager(batch_size=256)
train_loader, test_loader = (
    data_manager.get_train_loader(),
    data_manager.get_test_loader(),
)

# A single Trainer drives training, evaluation and plotting for all models.
trainer = Trainer(device=device)

# --- Step 1: train the teacher model ----------------------------------------
# One print with sep="\n" emits the same three lines as three separate prints.
print("=" * 50, "1. 선생 모델 학습", "=" * 50, sep="\n")

teacher_model = TeacherModel()
teacher_history = trainer.train_normal(
    teacher_model,
    train_loader,
    test_loader,
    lr=0.001,
    epochs=10,
)

# Detailed per-class report right after training. Only the first returned
# value (the accuracy) is bound here; the final-evaluation step re-evaluates
# the teacher and keeps all three values.
print("\n선생 모델 상세 평가:")
teacher_accuracy, _, _ = trainer.evaluate_detailed(teacher_model, test_loader)

# Plot the teacher's training curves.
trainer.plot_history(teacher_history, "Teacher Model")

# --- Step 2: generate soft targets from the trained teacher ----------------
# Print a blank line followed by a 50-character rule. Note the parentheses:
# `"\n=" * 50` would repeat the two-character string and print 50 lines.
print("\n" + "=" * 50)
print("2. Soft target 생성")
print("=" * 50)

# Soft targets are produced by the teacher at temperature 8, the same
# temperature later passed to the KD training calls.
soft_targets = data_manager.generate_soft_targets(
    teacher_model, temperature=8, device=device
)

# --- Step 3: train the student models --------------------------------------
# Blank line, then a 50-character rule (`"\n=" * 50` would print 50 lines).
print("\n" + "=" * 50)
print("3. 학생 모델들 학습")
print("=" * 50)

# 3-1. Baseline student trained with the ordinary training loop.
print("\n3-1. 일반 학생 모델 학습:")
student_normal = StudentModel()
student_normal_history = trainer.train_normal(
    student_normal, train_loader, test_loader, lr=0.001, epochs=10
)

# 3-2. Student trained with knowledge distillation, using the teacher's soft
# targets (temperature 8, mixing weight alpha=0.7).
print("\n3-2. KD 학생 모델 학습:")
student_kd = StudentModel()
student_kd_history = trainer.train_with_kd(
    student_kd, train_loader, test_loader, soft_targets,
    temperature=8, alpha=0.7, lr=0.001, epochs=10
)

# 3-3. KD student trained on a training set built with exclude_label=3, with
# the soft targets filtered the same way. It is still evaluated on the full
# test_loader, so its class-3 accuracy reflects what it learned without
# label-3 training samples (per the helper names; confirm in data.py).
print("\n3-3. Label '3' 제거한 데이터로 KD 학생 모델 학습:")

# Build the label-3-free training loader and the matching soft targets.
# NOTE(review): the two extra return values are discarded; their meaning is
# defined in MNISTDataManager.create_dataset_without_label.
train_loader_no_3, _, _ = data_manager.create_dataset_without_label(exclude_label=3)
soft_targets_no_3 = data_manager.filter_soft_targets(soft_targets, exclude_label=3)

student_kd_no_3 = StudentModel()
student_kd_no_3_history = trainer.train_with_kd(
    student_kd_no_3, train_loader_no_3, test_loader, soft_targets_no_3,
    temperature=8, alpha=0.7, lr=0.001, epochs=10
)

# --- Step 4: final evaluation of every model on the full test set ----------
# Each evaluate_detailed call returns a 3-tuple, unpacked here as
# (accuracy, per-class correct counts, per-class totals). Judging by the
# cell output, it also prints an overall/per-class report itself.
print("\n" + "=" * 50, "4. 최종 성능 평가", "=" * 50, sep="\n")

print("\n선생 모델 (Teacher Model):")
(teacher_final_accuracy,
 teacher_class_correct,
 teacher_class_total) = trainer.evaluate_detailed(teacher_model, test_loader)

print("\n일반 학생 모델 (Student Normal):")
(student_normal_accuracy,
 student_normal_class_correct,
 student_normal_class_total) = trainer.evaluate_detailed(student_normal, test_loader)

print("\nKD 학생 모델 (Student KD):")
(student_kd_accuracy,
 student_kd_class_correct,
 student_kd_class_total) = trainer.evaluate_detailed(student_kd, test_loader)

print("\nKD 학생 모델 - Label '3' 제거 (Student KD No 3):")
(student_kd_no_3_accuracy,
 student_kd_no_3_class_correct,
 student_kd_no_3_class_total) = trainer.evaluate_detailed(student_kd_no_3, test_loader)

# --- Step 5: training curves of each student model -------------------------
print("\n" + "=" * 50, "5. 개별 모델 학습 히스토리", "=" * 50, sep="\n")

for history, label in (
    (student_normal_history, "Student Normal"),
    (student_kd_history, "Student KD"),
    (student_kd_no_3_history, "Student KD No 3"),
):
    trainer.plot_history(history, label)

# --- Step 6: side-by-side comparison of all four models --------------------
print("\n" + "=" * 50, "6. 모델 비교", "=" * 50, sep="\n")

# One row per model: (training history, display name, final test accuracy).
comparison_rows = [
    (teacher_history, 'Teacher', teacher_final_accuracy),
    (student_normal_history, 'Student Normal', student_normal_accuracy),
    (student_kd_history, 'Student KD', student_kd_accuracy),
    (student_kd_no_3_history, 'Student KD No 3', student_kd_no_3_accuracy),
]
# Split the rows into the three parallel lists compare_models takes.
histories = [row[0] for row in comparison_rows]
model_names = [row[1] for row in comparison_rows]
accuracies = [row[2] for row in comparison_rows]

trainer.compare_models(histories, model_names, accuracies)

# --- Step 7: textual summary of the experiment ------------------------------
print("\n" + "=" * 50)
print("7. 결과 분석")
print("=" * 50)

print(f"선생 모델 정확도: {teacher_final_accuracy:.2f}%")
print(f"일반 학생 모델 정확도: {student_normal_accuracy:.2f}%")
print(f"KD 학생 모델 정확도: {student_kd_accuracy:.2f}%")
print(f"KD 학생 모델 (Label '3' 제거) 정확도: {student_kd_no_3_accuracy:.2f}%")

# Accuracy deltas. Both operands are already percentages, so these
# differences are percentage points.
print("\n성능 향상 분석:")
improvement_kd = student_kd_accuracy - student_normal_accuracy
print(f"KD vs 일반 학생: {improvement_kd:+.2f}%")

print("\nLabel '3' 제거 시 성능 변화:")
change_no_3 = student_kd_no_3_accuracy - student_kd_accuracy
print(f"변화량: {change_no_3:+.2f}%")

# Per-model accuracy on digit 3. class_correct / class_total were returned by
# evaluate_detailed(..., test_loader), so class_total[3] counts TEST samples
# of class 3. It is the same for every model, and the fallback branch is
# reached only if the test set itself has no class-3 samples.
print("\nLabel '3' 분류 성능:")
results = [
    (teacher_class_correct, teacher_class_total, "선생 모델"),
    (student_normal_class_correct, student_normal_class_total, "일반 학생"),
    (student_kd_class_correct, student_kd_class_total, "KD 학생"),
    (student_kd_no_3_class_correct, student_kd_no_3_class_total, "KD 학생(No 3)")
]

for class_correct, class_total, name in results:
    if class_total[3] > 0:
        acc_3 = 100 * class_correct[3] / class_total[3]
        print(f"{name}: {acc_3:.2f}%")
    else:
        # Message fixed: the count comes from the test set, not training data.
        print(f"{name}: 테스트 데이터에 class 3 없음")

print("\n실험 완료!")

Using device: cpu


100%|██████████| 9.91M/9.91M [00:20<00:00, 481kB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 147kB/s]
100%|██████████| 1.65M/1.65M [00:02<00:00, 794kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.55MB/s]


Train dataset size: 60000
Test dataset size: 10000
1. 선생 모델 학습


Epoch 1/10: 100%|██████████| 235/235 [00:10<00:00, 21.37it/s, Loss=0.2897, Acc=91.18%]


Epoch 1: Train Loss: 0.2897, Train Acc: 91.18%, Val Loss: 0.1283, Val Acc: 95.91%


Epoch 2/10: 100%|██████████| 235/235 [00:10<00:00, 21.60it/s, Loss=0.1416, Acc=95.66%]


Epoch 2: Train Loss: 0.1416, Train Acc: 95.66%, Val Loss: 0.0819, Val Acc: 97.27%


Epoch 3/10: 100%|██████████| 235/235 [00:10<00:00, 21.58it/s, Loss=0.1126, Acc=96.43%]


Epoch 3: Train Loss: 0.1126, Train Acc: 96.43%, Val Loss: 0.0687, Val Acc: 97.90%


Epoch 4/10: 100%|██████████| 235/235 [00:11<00:00, 20.94it/s, Loss=0.0932, Acc=97.12%]


Epoch 4: Train Loss: 0.0932, Train Acc: 97.12%, Val Loss: 0.0708, Val Acc: 97.79%


Epoch 5/10: 100%|██████████| 235/235 [00:11<00:00, 20.91it/s, Loss=0.0879, Acc=97.29%]


Epoch 5: Train Loss: 0.0879, Train Acc: 97.29%, Val Loss: 0.0691, Val Acc: 97.79%


Epoch 6/10: 100%|██████████| 235/235 [00:11<00:00, 20.76it/s, Loss=0.0799, Acc=97.52%]


Epoch 6: Train Loss: 0.0799, Train Acc: 97.52%, Val Loss: 0.0628, Val Acc: 97.97%


Epoch 7/10: 100%|██████████| 235/235 [00:11<00:00, 20.60it/s, Loss=0.0744, Acc=97.65%]


Epoch 7: Train Loss: 0.0744, Train Acc: 97.65%, Val Loss: 0.0640, Val Acc: 98.03%


Epoch 8/10: 100%|██████████| 235/235 [00:11<00:00, 20.66it/s, Loss=0.0677, Acc=97.84%]


Epoch 8: Train Loss: 0.0677, Train Acc: 97.84%, Val Loss: 0.0582, Val Acc: 98.13%


Epoch 9/10: 100%|██████████| 235/235 [00:11<00:00, 20.48it/s, Loss=0.0638, Acc=97.97%]


Epoch 9: Train Loss: 0.0638, Train Acc: 97.97%, Val Loss: 0.0630, Val Acc: 98.05%


Epoch 10/10: 100%|██████████| 235/235 [00:11<00:00, 20.47it/s, Loss=0.0651, Acc=97.97%]


Epoch 10: Train Loss: 0.0651, Train Acc: 97.97%, Val Loss: 0.0603, Val Acc: 98.14%

선생 모델 상세 평가:
Overall Accuracy: 98.14%
Class 0: 99.08%
Class 1: 99.38%
Class 2: 97.00%
Class 3: 98.71%
Class 4: 98.57%
Class 5: 97.09%
Class 6: 99.16%
Class 7: 98.54%
Class 8: 97.64%
Class 9: 96.04%


: 