In [1]:
import os
import numpy as np
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import timm  # pip install timm

# --- 설정 ---
SLICE_ROOT = "/data1/lidc-idri/slices"
BATCH_SIZE = 16
NUM_EPOCHS = 100
LR = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 라벨 추출 함수 ---
def extract_label_from_filename(filename):
    try:
        score = int(filename.split("_")[-1].replace(".npy", ""))
        if score == 3:
            return None
        return 1 if score >= 4 else 0
    except:
        return None

# --- 3-way split ---
all_files = glob(os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy"))
file_label_pairs = [(f, extract_label_from_filename(f)) for f in all_files]
file_label_pairs = [(f, l) for f, l in file_label_pairs if l is not None]
files, labels = zip(*file_label_pairs)

train_files, temp_files, train_labels, temp_labels = train_test_split(files, labels, test_size=0.3, random_state=42)
val_files, test_files, val_labels, test_labels = train_test_split(temp_files, temp_labels, test_size=0.5, random_state=42)

# --- Dataset ---
class LIDCDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, idx):
        img = np.load(self.file_paths[idx]).astype(np.float32)
        img = np.clip(img, -1000, 400)
        img = (img + 1000) / 1400.0
        img = cv2.resize(img, (224, 224))
        img = np.expand_dims(img, 0)
        img_tensor = torch.tensor(img)
        if self.transform:
            img_tensor = self.transform(img_tensor)
        label = torch.tensor(self.labels[idx]).float()
        return img_tensor, label, self.file_paths[idx]

    def __len__(self):
        return len(self.file_paths)

# --- Transform ---
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# --- DataLoader ---
train_loader = DataLoader(LIDCDataset(train_files, train_labels, transform=transform), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(LIDCDataset(val_files, val_labels, transform=transform), batch_size=BATCH_SIZE)
test_loader = DataLoader(LIDCDataset(test_files, test_labels, transform=transform), batch_size=BATCH_SIZE)

# --- 모델 정의 ---
model = timm.create_model("efficientnetv2_s", pretrained=False, in_chans=1, num_classes=1)
model.to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
save_path = os.path.join(os.path.dirname(os.getcwd()), "pth", "best_model_effnetv2s.pth")
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# --- 학습 ---
best_val_acc = 0.0
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    for images, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        images, labels = images.to(DEVICE), labels.unsqueeze(1).to(DEVICE)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"[Epoch {epoch+1}] Loss: {total_loss / len(train_loader):.4f}")

    gc.collect()
    torch.cuda.empty_cache()

    # --- 검증 ---
    model.eval()
    correct, total = 0, 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels, _ in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images).squeeze()
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).long()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.4f}")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print("✅ Best model saved!")

# --- 테스트 ---
print("\n📊 Test Set Evaluation (Best Model 기준):")
model.load_state_dict(torch.load(save_path))
model.eval()
y_true, y_pred, y_probs = [], [], []
with torch.no_grad():
    for images, labels, _ in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images).squeeze()
        probs = torch.sigmoid(outputs)
        preds = (probs > 0.5).long()
        y_probs.extend(probs.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_true.extend(labels.cpu().numpy())

test_acc = (np.array(y_pred) == np.array(y_true)).mean() * 100
print(f"✅ Test Accuracy: {test_acc:.2f}%")
print(classification_report(y_true, y_pred, digits=4))

try:
    auc_score = roc_auc_score(y_true, y_probs)
    print(f"AUC: {auc_score:.4f}")
except ValueError:
    print("AUC 계산 실패: 클래스가 모두 있어야 합니다.")

cm = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:")
print(cm)

  from .autonotebook import tqdm as notebook_tqdm
Epoch 1/100: 100%|██████████| 234/234 [00:30<00:00,  7.71it/s]


[Epoch 1] Loss: 0.8895
Validation Accuracy: 0.6813
✅ Best model saved!


Epoch 2/100: 100%|██████████| 234/234 [00:28<00:00,  8.21it/s]


[Epoch 2] Loss: 0.7425
Validation Accuracy: 0.6488


Epoch 3/100: 100%|██████████| 234/234 [00:27<00:00,  8.40it/s]


[Epoch 3] Loss: 0.7203
Validation Accuracy: 0.6000


Epoch 4/100: 100%|██████████| 234/234 [00:29<00:00,  7.97it/s]


[Epoch 4] Loss: 0.6944
Validation Accuracy: 0.6725


Epoch 5/100: 100%|██████████| 234/234 [00:28<00:00,  8.25it/s]


[Epoch 5] Loss: 0.6736
Validation Accuracy: 0.6763


Epoch 6/100: 100%|██████████| 234/234 [00:30<00:00,  7.78it/s]


[Epoch 6] Loss: 0.6134
Validation Accuracy: 0.6575


Epoch 7/100: 100%|██████████| 234/234 [00:27<00:00,  8.59it/s]


[Epoch 7] Loss: 0.5254
Validation Accuracy: 0.6813


Epoch 8/100: 100%|██████████| 234/234 [00:27<00:00,  8.40it/s]


[Epoch 8] Loss: 0.4336
Validation Accuracy: 0.7312
✅ Best model saved!


Epoch 9/100: 100%|██████████| 234/234 [00:28<00:00,  8.12it/s]


[Epoch 9] Loss: 0.2960
Validation Accuracy: 0.7037


Epoch 10/100: 100%|██████████| 234/234 [00:28<00:00,  8.18it/s]


[Epoch 10] Loss: 0.1900
Validation Accuracy: 0.7500
✅ Best model saved!


Epoch 11/100: 100%|██████████| 234/234 [00:27<00:00,  8.45it/s]


[Epoch 11] Loss: 0.1807
Validation Accuracy: 0.7350


Epoch 12/100: 100%|██████████| 234/234 [00:30<00:00,  7.73it/s]


[Epoch 12] Loss: 0.1358
Validation Accuracy: 0.7863
✅ Best model saved!


Epoch 13/100: 100%|██████████| 234/234 [00:28<00:00,  8.28it/s]


[Epoch 13] Loss: 0.1084
Validation Accuracy: 0.7750


Epoch 14/100: 100%|██████████| 234/234 [00:28<00:00,  8.16it/s]


[Epoch 14] Loss: 0.1436
Validation Accuracy: 0.7925
✅ Best model saved!


Epoch 15/100: 100%|██████████| 234/234 [00:27<00:00,  8.43it/s]


[Epoch 15] Loss: 0.1174
Validation Accuracy: 0.8150
✅ Best model saved!


Epoch 16/100: 100%|██████████| 234/234 [00:28<00:00,  8.32it/s]


[Epoch 16] Loss: 0.1270
Validation Accuracy: 0.7887


Epoch 17/100: 100%|██████████| 234/234 [00:29<00:00,  7.89it/s]


[Epoch 17] Loss: 0.1455
Validation Accuracy: 0.7925


Epoch 18/100: 100%|██████████| 234/234 [00:27<00:00,  8.44it/s]


[Epoch 18] Loss: 0.0947
Validation Accuracy: 0.8363
✅ Best model saved!


Epoch 19/100: 100%|██████████| 234/234 [00:27<00:00,  8.38it/s]


[Epoch 19] Loss: 0.0887
Validation Accuracy: 0.8187


Epoch 20/100: 100%|██████████| 234/234 [00:30<00:00,  7.56it/s]


[Epoch 20] Loss: 0.1015
Validation Accuracy: 0.7975


Epoch 21/100: 100%|██████████| 234/234 [00:28<00:00,  8.25it/s]


[Epoch 21] Loss: 0.0906
Validation Accuracy: 0.8237


Epoch 22/100: 100%|██████████| 234/234 [00:29<00:00,  7.92it/s]


[Epoch 22] Loss: 0.0910
Validation Accuracy: 0.8125


Epoch 23/100: 100%|██████████| 234/234 [00:31<00:00,  7.41it/s]


[Epoch 23] Loss: 0.0874
Validation Accuracy: 0.8363


Epoch 24/100: 100%|██████████| 234/234 [00:29<00:00,  7.82it/s]


[Epoch 24] Loss: 0.0915
Validation Accuracy: 0.8300


Epoch 25/100: 100%|██████████| 234/234 [00:26<00:00,  8.79it/s]


[Epoch 25] Loss: 0.0674
Validation Accuracy: 0.8350


Epoch 26/100: 100%|██████████| 234/234 [00:30<00:00,  7.70it/s]


[Epoch 26] Loss: 0.0494
Validation Accuracy: 0.8438
✅ Best model saved!


Epoch 27/100: 100%|██████████| 234/234 [00:27<00:00,  8.61it/s]


[Epoch 27] Loss: 0.0506
Validation Accuracy: 0.8425


Epoch 28/100: 100%|██████████| 234/234 [00:28<00:00,  8.31it/s]


[Epoch 28] Loss: 0.0558
Validation Accuracy: 0.8300


Epoch 29/100: 100%|██████████| 234/234 [00:30<00:00,  7.78it/s]


[Epoch 29] Loss: 0.0700
Validation Accuracy: 0.8363


Epoch 30/100: 100%|██████████| 234/234 [00:29<00:00,  8.05it/s]


[Epoch 30] Loss: 0.0710
Validation Accuracy: 0.8450
✅ Best model saved!


Epoch 31/100: 100%|██████████| 234/234 [00:31<00:00,  7.48it/s]


[Epoch 31] Loss: 0.0468
Validation Accuracy: 0.8325


Epoch 32/100: 100%|██████████| 234/234 [00:32<00:00,  7.29it/s]


[Epoch 32] Loss: 0.0364
Validation Accuracy: 0.8425


Epoch 33/100: 100%|██████████| 234/234 [00:30<00:00,  7.72it/s]


[Epoch 33] Loss: 0.0375
Validation Accuracy: 0.8050


Epoch 34/100: 100%|██████████| 234/234 [00:31<00:00,  7.53it/s]


[Epoch 34] Loss: 0.1265
Validation Accuracy: 0.8313


Epoch 35/100: 100%|██████████| 234/234 [00:29<00:00,  7.89it/s]


[Epoch 35] Loss: 0.1089
Validation Accuracy: 0.8350


Epoch 36/100: 100%|██████████| 234/234 [00:31<00:00,  7.32it/s]


[Epoch 36] Loss: 0.0343
Validation Accuracy: 0.8350


Epoch 37/100: 100%|██████████| 234/234 [00:31<00:00,  7.55it/s]


[Epoch 37] Loss: 0.0194
Validation Accuracy: 0.8600
✅ Best model saved!


Epoch 38/100: 100%|██████████| 234/234 [00:32<00:00,  7.31it/s]


[Epoch 38] Loss: 0.0254
Validation Accuracy: 0.8425


Epoch 39/100: 100%|██████████| 234/234 [00:29<00:00,  7.84it/s]


[Epoch 39] Loss: 0.0146
Validation Accuracy: 0.8550


Epoch 40/100: 100%|██████████| 234/234 [00:30<00:00,  7.64it/s]


[Epoch 40] Loss: 0.0281
Validation Accuracy: 0.8175


Epoch 41/100: 100%|██████████| 234/234 [00:32<00:00,  7.15it/s]


[Epoch 41] Loss: 0.0802
Validation Accuracy: 0.8363


Epoch 42/100: 100%|██████████| 234/234 [00:33<00:00,  7.06it/s]


[Epoch 42] Loss: 0.0966
Validation Accuracy: 0.8350


Epoch 43/100: 100%|██████████| 234/234 [00:31<00:00,  7.38it/s]


[Epoch 43] Loss: 0.0361
Validation Accuracy: 0.8538


Epoch 44/100: 100%|██████████| 234/234 [00:31<00:00,  7.43it/s]


[Epoch 44] Loss: 0.0262
Validation Accuracy: 0.8612
✅ Best model saved!


Epoch 45/100: 100%|██████████| 234/234 [00:30<00:00,  7.57it/s]


[Epoch 45] Loss: 0.0207
Validation Accuracy: 0.8438


Epoch 46/100: 100%|██████████| 234/234 [00:30<00:00,  7.66it/s]


[Epoch 46] Loss: 0.0252
Validation Accuracy: 0.8200


Epoch 47/100: 100%|██████████| 234/234 [00:32<00:00,  7.14it/s]


[Epoch 47] Loss: 0.0363
Validation Accuracy: 0.8413


Epoch 48/100: 100%|██████████| 234/234 [00:34<00:00,  6.87it/s]


[Epoch 48] Loss: 0.0720
Validation Accuracy: 0.8387


Epoch 49/100: 100%|██████████| 234/234 [00:34<00:00,  6.80it/s]


[Epoch 49] Loss: 0.0379
Validation Accuracy: 0.8600


Epoch 50/100: 100%|██████████| 234/234 [00:33<00:00,  7.06it/s]


[Epoch 50] Loss: 0.0100
Validation Accuracy: 0.8612


Epoch 51/100: 100%|██████████| 234/234 [00:31<00:00,  7.33it/s]


[Epoch 51] Loss: 0.0056
Validation Accuracy: 0.8688
✅ Best model saved!


Epoch 52/100: 100%|██████████| 234/234 [00:30<00:00,  7.62it/s]


[Epoch 52] Loss: 0.0037
Validation Accuracy: 0.8725
✅ Best model saved!


Epoch 53/100: 100%|██████████| 234/234 [00:30<00:00,  7.79it/s]


[Epoch 53] Loss: 0.0067
Validation Accuracy: 0.8688


Epoch 54/100: 100%|██████████| 234/234 [00:31<00:00,  7.51it/s]


[Epoch 54] Loss: 0.0043
Validation Accuracy: 0.8675


Epoch 55/100: 100%|██████████| 234/234 [00:30<00:00,  7.70it/s]


[Epoch 55] Loss: 0.0044
Validation Accuracy: 0.8700


Epoch 56/100: 100%|██████████| 234/234 [00:34<00:00,  6.88it/s]


[Epoch 56] Loss: 0.0039
Validation Accuracy: 0.8700


Epoch 57/100: 100%|██████████| 234/234 [00:33<00:00,  7.06it/s]


[Epoch 57] Loss: 0.0035
Validation Accuracy: 0.8700


Epoch 58/100: 100%|██████████| 234/234 [00:29<00:00,  7.84it/s]


[Epoch 58] Loss: 0.0030
Validation Accuracy: 0.8650


Epoch 59/100: 100%|██████████| 234/234 [00:31<00:00,  7.32it/s]


[Epoch 59] Loss: 0.0034
Validation Accuracy: 0.8700


Epoch 60/100: 100%|██████████| 234/234 [00:32<00:00,  7.27it/s]


[Epoch 60] Loss: 0.2223
Validation Accuracy: 0.7575


Epoch 61/100: 100%|██████████| 234/234 [00:31<00:00,  7.45it/s]


[Epoch 61] Loss: 0.1334
Validation Accuracy: 0.8600


Epoch 62/100: 100%|██████████| 234/234 [00:31<00:00,  7.38it/s]


[Epoch 62] Loss: 0.0403
Validation Accuracy: 0.8588


Epoch 63/100: 100%|██████████| 234/234 [00:35<00:00,  6.55it/s]


[Epoch 63] Loss: 0.0194
Validation Accuracy: 0.8512


Epoch 64/100: 100%|██████████| 234/234 [00:31<00:00,  7.40it/s]


[Epoch 64] Loss: 0.0090
Validation Accuracy: 0.8550


Epoch 65/100: 100%|██████████| 234/234 [00:34<00:00,  6.84it/s]


[Epoch 65] Loss: 0.0078
Validation Accuracy: 0.8450


Epoch 66/100: 100%|██████████| 234/234 [00:29<00:00,  7.90it/s]


[Epoch 66] Loss: 0.0095
Validation Accuracy: 0.8425


Epoch 67/100: 100%|██████████| 234/234 [00:32<00:00,  7.13it/s]


[Epoch 67] Loss: 0.0039
Validation Accuracy: 0.8512


Epoch 68/100: 100%|██████████| 234/234 [00:32<00:00,  7.29it/s]


[Epoch 68] Loss: 0.0030
Validation Accuracy: 0.8500


Epoch 69/100: 100%|██████████| 234/234 [00:32<00:00,  7.30it/s]


[Epoch 69] Loss: 0.0111
Validation Accuracy: 0.8263


Epoch 70/100: 100%|██████████| 234/234 [00:33<00:00,  7.01it/s]


[Epoch 70] Loss: 0.0441
Validation Accuracy: 0.8137


Epoch 71/100: 100%|██████████| 234/234 [00:31<00:00,  7.40it/s]


[Epoch 71] Loss: 0.0982
Validation Accuracy: 0.8325


Epoch 72/100: 100%|██████████| 234/234 [00:33<00:00,  7.00it/s]


[Epoch 72] Loss: 0.0331
Validation Accuracy: 0.8488


Epoch 73/100: 100%|██████████| 234/234 [00:32<00:00,  7.23it/s]


[Epoch 73] Loss: 0.0140
Validation Accuracy: 0.8650


Epoch 74/100: 100%|██████████| 234/234 [00:28<00:00,  8.32it/s]


[Epoch 74] Loss: 0.0092
Validation Accuracy: 0.8750
✅ Best model saved!


Epoch 75/100: 100%|██████████| 234/234 [00:30<00:00,  7.55it/s]


[Epoch 75] Loss: 0.0030
Validation Accuracy: 0.8712


Epoch 76/100: 100%|██████████| 234/234 [00:31<00:00,  7.40it/s]


[Epoch 76] Loss: 0.0031
Validation Accuracy: 0.8738


Epoch 77/100: 100%|██████████| 234/234 [00:29<00:00,  7.93it/s]


[Epoch 77] Loss: 0.0035
Validation Accuracy: 0.8750


Epoch 78/100: 100%|██████████| 234/234 [00:30<00:00,  7.56it/s]


[Epoch 78] Loss: 0.0810
Validation Accuracy: 0.8438


Epoch 79/100: 100%|██████████| 234/234 [00:34<00:00,  6.84it/s]


[Epoch 79] Loss: 0.0568
Validation Accuracy: 0.8450


Epoch 80/100: 100%|██████████| 234/234 [00:31<00:00,  7.42it/s]


[Epoch 80] Loss: 0.0256
Validation Accuracy: 0.8588


Epoch 81/100: 100%|██████████| 234/234 [00:31<00:00,  7.52it/s]


[Epoch 81] Loss: 0.0169
Validation Accuracy: 0.8775
✅ Best model saved!


Epoch 82/100: 100%|██████████| 234/234 [00:30<00:00,  7.66it/s]


[Epoch 82] Loss: 0.0037
Validation Accuracy: 0.8788
✅ Best model saved!


Epoch 83/100: 100%|██████████| 234/234 [00:34<00:00,  6.76it/s]


[Epoch 83] Loss: 0.0031
Validation Accuracy: 0.8862
✅ Best model saved!


Epoch 84/100: 100%|██████████| 234/234 [00:31<00:00,  7.43it/s]


[Epoch 84] Loss: 0.0026
Validation Accuracy: 0.8825


Epoch 85/100: 100%|██████████| 234/234 [00:38<00:00,  6.14it/s]


[Epoch 85] Loss: 0.0024
Validation Accuracy: 0.8825


Epoch 86/100: 100%|██████████| 234/234 [00:35<00:00,  6.64it/s]


[Epoch 86] Loss: 0.0024
Validation Accuracy: 0.8838


Epoch 87/100: 100%|██████████| 234/234 [00:35<00:00,  6.64it/s]


[Epoch 87] Loss: 0.0025
Validation Accuracy: 0.8762


Epoch 88/100: 100%|██████████| 234/234 [00:51<00:00,  4.58it/s]


[Epoch 88] Loss: 0.0027
Validation Accuracy: 0.8762


Epoch 89/100: 100%|██████████| 234/234 [01:29<00:00,  2.63it/s]


[Epoch 89] Loss: 0.0025
Validation Accuracy: 0.8800


Epoch 90/100: 100%|██████████| 234/234 [01:41<00:00,  2.30it/s]


[Epoch 90] Loss: 0.0026
Validation Accuracy: 0.8850


Epoch 91/100: 100%|██████████| 234/234 [01:32<00:00,  2.52it/s]


[Epoch 91] Loss: 0.0034
Validation Accuracy: 0.8812


Epoch 92/100: 100%|██████████| 234/234 [01:28<00:00,  2.65it/s]


[Epoch 92] Loss: 0.1818
Validation Accuracy: 0.8475


Epoch 93/100: 100%|██████████| 234/234 [00:34<00:00,  6.71it/s]


[Epoch 93] Loss: 0.0516
Validation Accuracy: 0.8500


Epoch 94/100: 100%|██████████| 234/234 [00:32<00:00,  7.16it/s]


[Epoch 94] Loss: 0.0192
Validation Accuracy: 0.8575


Epoch 95/100: 100%|██████████| 234/234 [00:33<00:00,  7.01it/s]


[Epoch 95] Loss: 0.0101
Validation Accuracy: 0.8575


Epoch 96/100: 100%|██████████| 234/234 [00:35<00:00,  6.64it/s]


[Epoch 96] Loss: 0.0043
Validation Accuracy: 0.8538


Epoch 97/100: 100%|██████████| 234/234 [00:31<00:00,  7.45it/s]


[Epoch 97] Loss: 0.0039
Validation Accuracy: 0.8575


Epoch 98/100: 100%|██████████| 234/234 [00:33<00:00,  6.93it/s]


[Epoch 98] Loss: 0.0027
Validation Accuracy: 0.8438


Epoch 99/100: 100%|██████████| 234/234 [00:31<00:00,  7.35it/s]


[Epoch 99] Loss: 0.0034
Validation Accuracy: 0.8575


Epoch 100/100: 100%|██████████| 234/234 [00:30<00:00,  7.57it/s]


[Epoch 100] Loss: 0.0031
Validation Accuracy: 0.8625

📊 Test Set Evaluation (Best Model 기준):
✅ Test Accuracy: 86.62%
              precision    recall  f1-score   support

         0.0     0.8387    0.7564    0.7954       275
         1.0     0.8786    0.9238    0.9006       525

    accuracy                         0.8662       800
   macro avg     0.8587    0.8401    0.8480       800
weighted avg     0.8649    0.8662    0.8645       800

AUC: 0.9155
Confusion Matrix:
[[208  67]
 [ 40 485]]
