In [None]:
# EfficientNet-b0 with 3-way split, AUC, Confusion Matrix
import os, numpy as np, torch, cv2
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import gc
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

# --- 설정 ---
SLICE_ROOT = "/data1/lidc-idri/slices"
BATCH_SIZE = 16
NUM_EPOCHS = 100
LR = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 라벨 추출 ---
def extract_label_from_filename(filename):
    try:
        score = int(filename.split("_")[-1].replace(".npy", ""))
        if score == 3: return None
        return 1 if score >= 4 else 0
    except:
        return None

# --- 파일 로딩 및 3-way split ---
all_files = glob(os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy"))
file_label_pairs = [(f, extract_label_from_filename(f)) for f in all_files]
file_label_pairs = [(f, l) for f, l in file_label_pairs if l is not None]
files, labels = zip(*file_label_pairs)
train_files, temp_files, train_labels, temp_labels = train_test_split(files, labels, test_size=0.3, random_state=42)
val_files, test_files, val_labels, test_labels = train_test_split(temp_files, temp_labels, test_size=0.5, random_state=42)

# --- Dataset 정의 ---
class LIDCDataset(Dataset):
    def __init__(self, file_paths, labels):
        self.file_paths = file_paths
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img = np.load(self.file_paths[idx]).astype(np.float32)
        img = np.clip(img, -1000, 400)
        img = (img + 1000) / 1400.0
        img = cv2.resize(img, (224, 224))
        img = np.expand_dims(img, 0)
        return torch.tensor(img).float(), torch.tensor(self.labels[idx]).float()

# --- DataLoader ---
train_loader = DataLoader(LIDCDataset(train_files, train_labels), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(LIDCDataset(val_files, val_labels), batch_size=BATCH_SIZE)
test_loader = DataLoader(LIDCDataset(test_files, test_labels), batch_size=BATCH_SIZE)

# --- 모델 정의 ---
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)
model = model.to(DEVICE)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
save_path = os.path.join(os.path.dirname(os.getcwd()), "pth", "best_model_efficientnet_b0.pth")
os.makedirs(os.path.dirname(save_path), exist_ok=True)
best_val_acc = 0.0

# --- 학습 루프 ---
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        images, labels = images.to(DEVICE), labels.unsqueeze(1).to(DEVICE)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"[Epoch {epoch+1}] Loss: {total_loss / len(train_loader):.4f}")

    gc.collect()
    torch.cuda.empty_cache()

    # --- 검증 ---
    model.eval()
    correct = total = 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            probs = torch.sigmoid(model(images)).squeeze()
            preds = (probs > 0.5).long()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.4f}")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print("✅ Best model saved!")

# --- 테스트 성능 평가 ---
print("\n📊 Test Set Evaluation (Best Model 기준):")
model.load_state_dict(torch.load(save_path))
model.eval()
y_true, y_pred, y_probs = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        probs = torch.sigmoid(model(images)).squeeze()
        preds = (probs > 0.5).long()
        y_probs.extend(probs.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_true.extend(labels.cpu().numpy())

test_acc = (np.array(y_pred) == np.array(y_true)).mean() * 100
print(f"✅ Test Accuracy: {test_acc:.2f}%")
print(classification_report(y_true, y_pred, digits=4))
try:
    auc_score = roc_auc_score(y_true, y_probs)
    print(f"AUC: {auc_score:.4f}")
except ValueError:
    print("AUC 계산 실패: 클래스가 모두 있어야 합니다.")
print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred))

Epoch 1/100: 100%|██████████| 234/234 [00:11<00:00, 21.07it/s]


[Epoch 1] Loss: 0.6041
Validation Accuracy: 0.7262
✅ Best model saved!


Epoch 2/100: 100%|██████████| 234/234 [00:10<00:00, 21.99it/s]


[Epoch 2] Loss: 0.4676
Validation Accuracy: 0.7788
✅ Best model saved!


Epoch 3/100: 100%|██████████| 234/234 [00:11<00:00, 20.47it/s]


[Epoch 3] Loss: 0.3459
Validation Accuracy: 0.7987
✅ Best model saved!


Epoch 4/100: 100%|██████████| 234/234 [00:11<00:00, 20.74it/s]


[Epoch 4] Loss: 0.2334
Validation Accuracy: 0.7950


Epoch 5/100: 100%|██████████| 234/234 [00:11<00:00, 20.75it/s]


[Epoch 5] Loss: 0.1822
Validation Accuracy: 0.7725


Epoch 6/100: 100%|██████████| 234/234 [00:11<00:00, 20.76it/s]


[Epoch 6] Loss: 0.1383
Validation Accuracy: 0.8225
✅ Best model saved!


Epoch 7/100: 100%|██████████| 234/234 [00:11<00:00, 20.67it/s]


[Epoch 7] Loss: 0.1347
Validation Accuracy: 0.8325
✅ Best model saved!


Epoch 8/100: 100%|██████████| 234/234 [00:11<00:00, 20.63it/s]


[Epoch 8] Loss: 0.1107
Validation Accuracy: 0.8400
✅ Best model saved!


Epoch 9/100: 100%|██████████| 234/234 [00:10<00:00, 21.69it/s]


[Epoch 9] Loss: 0.0922
Validation Accuracy: 0.8438
✅ Best model saved!


Epoch 10/100: 100%|██████████| 234/234 [00:10<00:00, 21.64it/s]


[Epoch 10] Loss: 0.0744
Validation Accuracy: 0.8562
✅ Best model saved!


Epoch 11/100: 100%|██████████| 234/234 [00:11<00:00, 20.78it/s]


[Epoch 11] Loss: 0.0686
Validation Accuracy: 0.8462


Epoch 12/100: 100%|██████████| 234/234 [00:11<00:00, 20.86it/s]


[Epoch 12] Loss: 0.0802
Validation Accuracy: 0.8350


Epoch 13/100: 100%|██████████| 234/234 [00:11<00:00, 20.80it/s]


[Epoch 13] Loss: 0.0695
Validation Accuracy: 0.8400


Epoch 14/100: 100%|██████████| 234/234 [00:11<00:00, 20.79it/s]


[Epoch 14] Loss: 0.0583
Validation Accuracy: 0.8562


Epoch 15/100: 100%|██████████| 234/234 [00:11<00:00, 20.99it/s]


[Epoch 15] Loss: 0.0639
Validation Accuracy: 0.8625
✅ Best model saved!


Epoch 16/100: 100%|██████████| 234/234 [00:11<00:00, 20.99it/s]


[Epoch 16] Loss: 0.0455
Validation Accuracy: 0.8575


Epoch 17/100: 100%|██████████| 234/234 [00:10<00:00, 21.45it/s]


[Epoch 17] Loss: 0.0355
Validation Accuracy: 0.8413


Epoch 18/100: 100%|██████████| 234/234 [00:10<00:00, 22.67it/s]


[Epoch 18] Loss: 0.0344
Validation Accuracy: 0.8375


Epoch 19/100: 100%|██████████| 234/234 [00:11<00:00, 20.97it/s]


[Epoch 19] Loss: 0.0632
Validation Accuracy: 0.8400


Epoch 20/100: 100%|██████████| 234/234 [00:11<00:00, 21.06it/s]


[Epoch 20] Loss: 0.0492
Validation Accuracy: 0.8525


Epoch 21/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 21] Loss: 0.0377
Validation Accuracy: 0.8538


Epoch 22/100: 100%|██████████| 234/234 [00:11<00:00, 21.01it/s]


[Epoch 22] Loss: 0.0510
Validation Accuracy: 0.8375


Epoch 23/100: 100%|██████████| 234/234 [00:11<00:00, 20.93it/s]


[Epoch 23] Loss: 0.0343
Validation Accuracy: 0.8550


Epoch 24/100: 100%|██████████| 234/234 [00:11<00:00, 21.09it/s]


[Epoch 24] Loss: 0.0490
Validation Accuracy: 0.8075


Epoch 25/100: 100%|██████████| 234/234 [00:11<00:00, 20.77it/s]


[Epoch 25] Loss: 0.0465
Validation Accuracy: 0.8425


Epoch 26/100: 100%|██████████| 234/234 [00:10<00:00, 23.04it/s]


[Epoch 26] Loss: 0.0425
Validation Accuracy: 0.8400


Epoch 27/100: 100%|██████████| 234/234 [00:11<00:00, 20.94it/s]


[Epoch 27] Loss: 0.0437
Validation Accuracy: 0.8363


Epoch 28/100: 100%|██████████| 234/234 [00:11<00:00, 20.76it/s]


[Epoch 28] Loss: 0.0286
Validation Accuracy: 0.8488


Epoch 29/100: 100%|██████████| 234/234 [00:11<00:00, 20.91it/s]


[Epoch 29] Loss: 0.0230
Validation Accuracy: 0.8500


Epoch 30/100: 100%|██████████| 234/234 [00:11<00:00, 20.87it/s]


[Epoch 30] Loss: 0.0475
Validation Accuracy: 0.8325


Epoch 31/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 31] Loss: 0.0362
Validation Accuracy: 0.8475


Epoch 32/100: 100%|██████████| 234/234 [00:11<00:00, 20.99it/s]


[Epoch 32] Loss: 0.0267
Validation Accuracy: 0.8438


Epoch 33/100: 100%|██████████| 234/234 [00:11<00:00, 20.93it/s]


[Epoch 33] Loss: 0.0295
Validation Accuracy: 0.8387


Epoch 34/100: 100%|██████████| 234/234 [00:10<00:00, 22.18it/s]


[Epoch 34] Loss: 0.0440
Validation Accuracy: 0.8575


Epoch 35/100: 100%|██████████| 234/234 [00:10<00:00, 21.44it/s]


[Epoch 35] Loss: 0.0308
Validation Accuracy: 0.8562


Epoch 36/100: 100%|██████████| 234/234 [00:11<00:00, 20.91it/s]


[Epoch 36] Loss: 0.0252
Validation Accuracy: 0.8588


Epoch 37/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 37] Loss: 0.0229
Validation Accuracy: 0.8600


Epoch 38/100: 100%|██████████| 234/234 [00:11<00:00, 21.02it/s]


[Epoch 38] Loss: 0.0286
Validation Accuracy: 0.8325


Epoch 39/100: 100%|██████████| 234/234 [00:11<00:00, 20.93it/s]


[Epoch 39] Loss: 0.0208
Validation Accuracy: 0.8625


Epoch 40/100: 100%|██████████| 234/234 [00:11<00:00, 20.85it/s]


[Epoch 40] Loss: 0.0284
Validation Accuracy: 0.8562


Epoch 41/100: 100%|██████████| 234/234 [00:11<00:00, 21.02it/s]


[Epoch 41] Loss: 0.0271
Validation Accuracy: 0.8512


Epoch 42/100: 100%|██████████| 234/234 [00:10<00:00, 21.39it/s]


[Epoch 42] Loss: 0.0385
Validation Accuracy: 0.8538


Epoch 43/100: 100%|██████████| 234/234 [00:10<00:00, 22.89it/s]


[Epoch 43] Loss: 0.0557
Validation Accuracy: 0.8688
✅ Best model saved!


Epoch 44/100: 100%|██████████| 234/234 [00:11<00:00, 20.93it/s]


[Epoch 44] Loss: 0.0355
Validation Accuracy: 0.8600


Epoch 45/100: 100%|██████████| 234/234 [00:11<00:00, 20.83it/s]


[Epoch 45] Loss: 0.0362
Validation Accuracy: 0.8400


Epoch 46/100: 100%|██████████| 234/234 [00:11<00:00, 20.88it/s]


[Epoch 46] Loss: 0.0188
Validation Accuracy: 0.8600


Epoch 47/100: 100%|██████████| 234/234 [00:11<00:00, 20.83it/s]


[Epoch 47] Loss: 0.0186
Validation Accuracy: 0.8538


Epoch 48/100: 100%|██████████| 234/234 [00:11<00:00, 20.83it/s]


[Epoch 48] Loss: 0.0189
Validation Accuracy: 0.8538


Epoch 49/100: 100%|██████████| 234/234 [00:11<00:00, 20.91it/s]


[Epoch 49] Loss: 0.0200
Validation Accuracy: 0.8462


Epoch 50/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 50] Loss: 0.0231
Validation Accuracy: 0.8650


Epoch 51/100: 100%|██████████| 234/234 [00:09<00:00, 23.48it/s]


[Epoch 51] Loss: 0.0147
Validation Accuracy: 0.8650


Epoch 52/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 52] Loss: 0.0248
Validation Accuracy: 0.8550


Epoch 53/100: 100%|██████████| 234/234 [00:11<00:00, 20.86it/s]


[Epoch 53] Loss: 0.0195
Validation Accuracy: 0.8575


Epoch 54/100: 100%|██████████| 234/234 [00:11<00:00, 20.75it/s]


[Epoch 54] Loss: 0.0232
Validation Accuracy: 0.8638


Epoch 55/100: 100%|██████████| 234/234 [00:11<00:00, 20.80it/s]


[Epoch 55] Loss: 0.0217
Validation Accuracy: 0.8450


Epoch 56/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 56] Loss: 0.0294
Validation Accuracy: 0.8700
✅ Best model saved!


Epoch 57/100: 100%|██████████| 234/234 [00:11<00:00, 20.90it/s]


[Epoch 57] Loss: 0.0370
Validation Accuracy: 0.8675


Epoch 58/100: 100%|██████████| 234/234 [00:11<00:00, 20.71it/s]


[Epoch 58] Loss: 0.0138
Validation Accuracy: 0.8600


Epoch 59/100: 100%|██████████| 234/234 [00:10<00:00, 22.67it/s]


[Epoch 59] Loss: 0.0160
Validation Accuracy: 0.8600


Epoch 60/100: 100%|██████████| 234/234 [00:10<00:00, 21.33it/s]


[Epoch 60] Loss: 0.0791
Validation Accuracy: 0.8600


Epoch 61/100: 100%|██████████| 234/234 [00:11<00:00, 20.86it/s]


[Epoch 61] Loss: 0.0426
Validation Accuracy: 0.8675


Epoch 62/100: 100%|██████████| 234/234 [00:11<00:00, 20.89it/s]


[Epoch 62] Loss: 0.0162
Validation Accuracy: 0.8588


Epoch 63/100: 100%|██████████| 234/234 [00:11<00:00, 20.86it/s]


[Epoch 63] Loss: 0.0118
Validation Accuracy: 0.8600


Epoch 64/100: 100%|██████████| 234/234 [00:11<00:00, 20.95it/s]


[Epoch 64] Loss: 0.0167
Validation Accuracy: 0.8725
✅ Best model saved!


Epoch 65/100: 100%|██████████| 234/234 [00:11<00:00, 21.13it/s]


[Epoch 65] Loss: 0.0152
Validation Accuracy: 0.8650


Epoch 66/100: 100%|██████████| 234/234 [00:11<00:00, 21.14it/s]


[Epoch 66] Loss: 0.0112
Validation Accuracy: 0.8688


Epoch 67/100: 100%|██████████| 234/234 [00:10<00:00, 21.82it/s]


[Epoch 67] Loss: 0.0097
Validation Accuracy: 0.8800
✅ Best model saved!


Epoch 68/100: 100%|██████████| 234/234 [00:10<00:00, 22.48it/s]


[Epoch 68] Loss: 0.0092
Validation Accuracy: 0.8812
✅ Best model saved!


Epoch 69/100: 100%|██████████| 234/234 [00:11<00:00, 20.64it/s]


[Epoch 69] Loss: 0.0237
Validation Accuracy: 0.8462


Epoch 70/100: 100%|██████████| 234/234 [00:11<00:00, 20.93it/s]


[Epoch 70] Loss: 0.0306
Validation Accuracy: 0.8662


Epoch 71/100: 100%|██████████| 234/234 [00:11<00:00, 20.60it/s]


[Epoch 71] Loss: 0.0124
Validation Accuracy: 0.8675


Epoch 72/100: 100%|██████████| 234/234 [00:11<00:00, 21.04it/s]


[Epoch 72] Loss: 0.0104
Validation Accuracy: 0.8675


Epoch 73/100: 100%|██████████| 234/234 [00:11<00:00, 20.92it/s]


[Epoch 73] Loss: 0.0149
Validation Accuracy: 0.8862
✅ Best model saved!


Epoch 74/100: 100%|██████████| 234/234 [00:11<00:00, 20.84it/s]


[Epoch 74] Loss: 0.0162
Validation Accuracy: 0.8612


Epoch 75/100: 100%|██████████| 234/234 [00:11<00:00, 20.99it/s]


[Epoch 75] Loss: 0.0146
Validation Accuracy: 0.8675


Epoch 76/100: 100%|██████████| 234/234 [00:10<00:00, 23.15it/s]


[Epoch 76] Loss: 0.0137
Validation Accuracy: 0.8638


Epoch 77/100: 100%|██████████| 234/234 [00:11<00:00, 20.89it/s]


[Epoch 77] Loss: 0.0345
Validation Accuracy: 0.8500


Epoch 78/100: 100%|██████████| 234/234 [00:11<00:00, 20.90it/s]


[Epoch 78] Loss: 0.0169
Validation Accuracy: 0.8700


Epoch 79/100: 100%|██████████| 234/234 [00:11<00:00, 21.01it/s]


[Epoch 79] Loss: 0.0115
Validation Accuracy: 0.8562


Epoch 80/100: 100%|██████████| 234/234 [00:11<00:00, 20.82it/s]


[Epoch 80] Loss: 0.0154
Validation Accuracy: 0.8650


Epoch 81/100: 100%|██████████| 234/234 [00:11<00:00, 20.87it/s]


[Epoch 81] Loss: 0.0086
Validation Accuracy: 0.8462


Epoch 82/100: 100%|██████████| 234/234 [00:11<00:00, 20.94it/s]


[Epoch 82] Loss: 0.0204
Validation Accuracy: 0.8575


Epoch 83/100: 100%|██████████| 234/234 [00:11<00:00, 20.94it/s]


[Epoch 83] Loss: 0.0206
Validation Accuracy: 0.8700


Epoch 84/100: 100%|██████████| 234/234 [00:10<00:00, 22.91it/s]


[Epoch 84] Loss: 0.0077
Validation Accuracy: 0.8675


Epoch 85/100: 100%|██████████| 234/234 [00:11<00:00, 21.19it/s]


[Epoch 85] Loss: 0.0084
Validation Accuracy: 0.8725


Epoch 86/100: 100%|██████████| 234/234 [00:11<00:00, 21.02it/s]


[Epoch 86] Loss: 0.0058
Validation Accuracy: 0.8712


Epoch 87/100: 100%|██████████| 234/234 [00:11<00:00, 21.05it/s]


[Epoch 87] Loss: 0.0237
Validation Accuracy: 0.8525


Epoch 88/100: 100%|██████████| 234/234 [00:11<00:00, 20.84it/s]


[Epoch 88] Loss: 0.0573
Validation Accuracy: 0.8550


Epoch 89/100: 100%|██████████| 234/234 [00:11<00:00, 20.86it/s]


[Epoch 89] Loss: 0.0126
Validation Accuracy: 0.8425


Epoch 90/100: 100%|██████████| 234/234 [00:11<00:00, 21.05it/s]


[Epoch 90] Loss: 0.0084
Validation Accuracy: 0.8525


Epoch 91/100: 100%|██████████| 234/234 [00:11<00:00, 20.84it/s]


[Epoch 91] Loss: 0.0078
Validation Accuracy: 0.8550


Epoch 92/100: 100%|██████████| 234/234 [00:10<00:00, 22.08it/s]


[Epoch 92] Loss: 0.0514
Validation Accuracy: 0.8538


Epoch 93/100: 100%|██████████| 234/234 [00:10<00:00, 21.90it/s]


[Epoch 93] Loss: 0.0271
Validation Accuracy: 0.8538


Epoch 94/100: 100%|██████████| 234/234 [00:11<00:00, 21.02it/s]


[Epoch 94] Loss: 0.0147
Validation Accuracy: 0.8550


Epoch 95/100: 100%|██████████| 234/234 [00:11<00:00, 20.67it/s]


[Epoch 95] Loss: 0.0114
Validation Accuracy: 0.8462


Epoch 96/100: 100%|██████████| 234/234 [00:11<00:00, 20.94it/s]


[Epoch 96] Loss: 0.0141
Validation Accuracy: 0.8488


Epoch 97/100: 100%|██████████| 234/234 [00:11<00:00, 20.96it/s]


[Epoch 97] Loss: 0.0217
Validation Accuracy: 0.8700


Epoch 98/100: 100%|██████████| 234/234 [00:11<00:00, 20.85it/s]


[Epoch 98] Loss: 0.0116
Validation Accuracy: 0.8662


Epoch 99/100: 100%|██████████| 234/234 [00:11<00:00, 20.93it/s]


[Epoch 99] Loss: 0.0127
Validation Accuracy: 0.8588


Epoch 100/100: 100%|██████████| 234/234 [00:10<00:00, 21.35it/s]


[Epoch 100] Loss: 0.0102
Validation Accuracy: 0.8738

📊 Test Set Evaluation (Best Model 기준):
✅ Test Accuracy: 85.12%
              precision    recall  f1-score   support

         0.0     0.7727    0.8036    0.7879       275
         1.0     0.8949    0.8762    0.8855       525

    accuracy                         0.8512       800
   macro avg     0.8338    0.8399    0.8367       800
weighted avg     0.8529    0.8512    0.8519       800

AUC: 0.9049
Confusion Matrix:
[[221  54]
 [ 65 460]]
