In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import datasets, transforms
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
import numpy as np
import copy

# 0. 장치 설정 (GPU 사용 가능 시 GPU 사용)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 1. 기본 CNN 모델 정의
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)  # assuming input is 32x32
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)  # flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 2. 데이터 전처리 설정
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 3. 데이터셋 로드
# 학습 데이터와 무라벨 데이터가 각각 필요 (여기서는 같은 데이터셋을 사용하지만 실제로는 무라벨 데이터셋이 별도로 필요)
trainset_path = 'C:/Users/jongcheol/OneDrive/바탕 화면/Semester2/train_data'
trainset = datasets.ImageFolder(root=trainset_path, transform=transform)
labels = np.array([label for _, label in trainset.imgs])

# 4. 5-Fold Cross Validation 설정
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_accuracies_single = []  # 단일 CNN 성능
fold_accuracies_noisy_student = []  # Noisy Student 적용 CNN 성능

# 5. Cross Validation 학습 및 평가
for fold, (train_idx, val_idx) in enumerate(kf.split(np.zeros(len(labels)), labels)):
    print(f"\n=== Fold {fold + 1} 시작 ===")

    # Fold별 데이터셋 분리
    train_subset = Subset(trainset, train_idx)
    val_subset = Subset(trainset, val_idx)
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=2)

    # 6. 단일 CNN 학습 (비교 대상)
    single_model = SimpleCNN(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(single_model.parameters(), lr=0.001)

    # 학습 루프
    single_model.train()
    for epoch in range(3):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = single_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # 단일 CNN 성능 평가
    single_model.eval()
    single_preds, single_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            outputs = single_model(images)
            _, preds = torch.max(outputs, 1)
            single_preds.extend(preds.cpu().numpy())
            single_labels.extend(labels.cpu().numpy())
    fold_accuracy_single = accuracy_score(single_labels, single_preds)
    fold_accuracies_single.append(fold_accuracy_single)
    print(f"Fold {fold + 1} Single CNN Accuracy: {fold_accuracy_single * 100:.2f}%")

    # 7. Teacher 모델 학습
    teacher_model = SimpleCNN(num_classes=10).to(device)
    optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
    teacher_model.train()
    for epoch in range(3):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = teacher_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # 8. Pseudo-labeling: Teacher 모델로 무라벨 데이터에 가짜 라벨 생성
    pseudo_labels = []
    pseudo_dataset = []
    teacher_model.eval()
    with torch.no_grad():
        for images, _ in val_loader:  # 무라벨 데이터 사용 (여기선 val_loader 사용)
            images = images.to(device)
            outputs = teacher_model(images)
            _, preds = torch.max(outputs, 1)
            pseudo_labels.extend(preds.cpu().numpy())
            pseudo_dataset.extend(images.cpu())

    # 무라벨 데이터와 pseudo-label 결합
    pseudo_labels = torch.tensor(pseudo_labels)
    noisy_student_dataset = [(img, label) for img, label in zip(pseudo_dataset, pseudo_labels)]

    # 9. Student 모델 학습 (Noisy Student 기법 적용)
    student_model = SimpleCNN(num_classes=10).to(device)
    student_loader = DataLoader(noisy_student_dataset, batch_size=32, shuffle=True)
    optimizer = optim.Adam(student_model.parameters(), lr=0.001)
    student_model.train()
    for epoch in range(3):
        for images, labels in student_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = student_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # 10. Noisy Student 모델 성능 평가
    student_model.eval()
    student_preds, student_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            outputs = student_model(images)
            _, preds = torch.max(outputs, 1)
            student_preds.extend(preds.cpu().numpy())
            student_labels.extend(labels.cpu().numpy())
    fold_accuracy_noisy_student = accuracy_score(student_labels, student_preds)
    fold_accuracies_noisy_student.append(fold_accuracy_noisy_student)
    print(f"Fold {fold + 1} Noisy Student CNN Accuracy: {fold_accuracy_noisy_student * 100:.2f}%")

# 11. 5-Fold 평균 정확도 비교 출력
print("\n=== 최종 5-Fold 평균 정확도 ===")
print(f"Single CNN Model: {np.mean(fold_accuracies_single) * 100:.2f}%")
print(f"Noisy Student CNN Model: {np.mean(fold_accuracies_noisy_student) * 100:.2f}%")



=== Fold 1 시작 ===
Fold 1 Single CNN Accuracy: 56.71%
Fold 1 Noisy Student CNN Accuracy: 45.11%

=== Fold 2 시작 ===
Fold 2 Single CNN Accuracy: 56.84%
Fold 2 Noisy Student CNN Accuracy: 47.20%

=== Fold 3 시작 ===
Fold 3 Single CNN Accuracy: 58.15%
Fold 3 Noisy Student CNN Accuracy: 47.98%

=== Fold 4 시작 ===


KeyboardInterrupt: 