In [2]:
# ResNet34 3ch

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from glob import glob
from tqdm import tqdm
from torchvision.models import resnet34, ResNet34_Weights

# 환경 설정
SLICE_ROOT = "/data1/lidc-idri/slices"
BATCH_SIZE = 16
NUM_EPOCHS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 레이블 추출
def extract_label_from_filename(filename):
    try:
        score = int(filename.split("_")[-1].replace(".npy", ""))
        if score == 3: return None
        return 1 if score >= 4 else 0
    except:
        return None

# 파일 리스트 정리
all_files = glob(os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy"))
file_label_pairs = [(f, extract_label_from_filename(f)) for f in all_files]
file_label_pairs = [(f, l) for f, l in file_label_pairs if l is not None]
files, labels = zip(*file_label_pairs)
train_files, val_files, train_labels, val_labels = train_test_split(files, labels, test_size=0.2, random_state=42)

# 3채널 Dataset 클래스
class LIDC3ChannelDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, idx):
        center_path = self.file_paths[idx]
        label = self.labels[idx]
        folder = os.path.dirname(center_path)
        fname = os.path.basename(center_path)
        slice_num = int(fname.split("_")[1])
        suffix = fname.split("_")[-1]

        # 주변 3개 슬라이스 로딩
        slice_indices = [slice_num - 1, slice_num, slice_num + 1]
        images = []
        for sn in slice_indices:
            path = os.path.join(folder, f"slice_{sn:03d}_{suffix}")
            if os.path.exists(path):
                img = np.load(path).astype(np.float32)
            else:
                img = np.load(center_path).astype(np.float32)
            img = (img - img.min()) / (img.max() - img.min() + 1e-8)
            images.append(img)

        stacked = np.stack(images, axis=0)
        img_tensor = torch.tensor(stacked)
        img_tensor = F.interpolate(img_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        if self.transform:
            img_tensor = self.transform(img_tensor)
        return img_tensor, torch.tensor(label).float(), center_path

    def __len__(self):
        return len(self.file_paths)

# Transform
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Dataloader 구성
train_dataset = LIDC3ChannelDataset(train_files, train_labels, transform=val_transform)
val_dataset = LIDC3ChannelDataset(val_files, val_labels, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# 모델 정의 (ResNet34)
model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
model.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 3채널용으로 수정
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(DEVICE)

# 손실함수와 Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 학습 루프
best_acc = 0
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for imgs, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        imgs, labels = imgs.to(DEVICE), labels.unsqueeze(1).to(DEVICE)
        optimizer.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"[Epoch {epoch+1}] Loss: {running_loss / len(train_loader):.4f}")

    # 검증
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels, _ in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            out = model(imgs).squeeze()
            probs = torch.sigmoid(out)
            preds = (probs > 0.5).long()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.4f}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model_resnet34_3ch.pth")
        print("✅ Best model saved!")

# 최종 리포트
print("\n📊 Classification Report:")
print(classification_report(all_labels, all_preds, digits=4))

Epoch 1/100: 100%|██████████| 267/267 [00:19<00:00, 13.98it/s]


[Epoch 1] Loss: 0.6277
Validation Accuracy: 0.6926
✅ Best model saved!


Epoch 2/100: 100%|██████████| 267/267 [00:18<00:00, 14.18it/s]


[Epoch 2] Loss: 0.4910
Validation Accuracy: 0.7348
✅ Best model saved!


Epoch 3/100: 100%|██████████| 267/267 [00:18<00:00, 14.20it/s]


[Epoch 3] Loss: 0.3482
Validation Accuracy: 0.7366
✅ Best model saved!


Epoch 4/100: 100%|██████████| 267/267 [00:19<00:00, 14.04it/s]


[Epoch 4] Loss: 0.2279
Validation Accuracy: 0.8547
✅ Best model saved!


Epoch 5/100: 100%|██████████| 267/267 [00:18<00:00, 14.21it/s]


[Epoch 5] Loss: 0.1512
Validation Accuracy: 0.8519


Epoch 6/100: 100%|██████████| 267/267 [00:18<00:00, 14.31it/s]


[Epoch 6] Loss: 0.1119
Validation Accuracy: 0.8688
✅ Best model saved!


Epoch 7/100: 100%|██████████| 267/267 [00:19<00:00, 13.96it/s]


[Epoch 7] Loss: 0.0797
Validation Accuracy: 0.8838
✅ Best model saved!


Epoch 8/100: 100%|██████████| 267/267 [00:19<00:00, 13.96it/s]


[Epoch 8] Loss: 0.0883
Validation Accuracy: 0.9063
✅ Best model saved!


Epoch 9/100: 100%|██████████| 267/267 [00:18<00:00, 14.13it/s]


[Epoch 9] Loss: 0.0553
Validation Accuracy: 0.8838


Epoch 10/100: 100%|██████████| 267/267 [00:19<00:00, 14.04it/s]


[Epoch 10] Loss: 0.0555
Validation Accuracy: 0.8538


Epoch 11/100: 100%|██████████| 267/267 [00:18<00:00, 14.20it/s]


[Epoch 11] Loss: 0.0498
Validation Accuracy: 0.8828


Epoch 12/100: 100%|██████████| 267/267 [00:19<00:00, 13.98it/s]


[Epoch 12] Loss: 0.0539
Validation Accuracy: 0.8632


Epoch 13/100: 100%|██████████| 267/267 [00:18<00:00, 14.11it/s]


[Epoch 13] Loss: 0.0343
Validation Accuracy: 0.8997


Epoch 14/100: 100%|██████████| 267/267 [00:18<00:00, 14.29it/s]


[Epoch 14] Loss: 0.0469
Validation Accuracy: 0.8866


Epoch 15/100: 100%|██████████| 267/267 [00:18<00:00, 14.09it/s]


[Epoch 15] Loss: 0.0177
Validation Accuracy: 0.8969


Epoch 16/100: 100%|██████████| 267/267 [00:18<00:00, 14.07it/s]


[Epoch 16] Loss: 0.0085
Validation Accuracy: 0.9138
✅ Best model saved!


Epoch 17/100: 100%|██████████| 267/267 [00:18<00:00, 14.33it/s]


[Epoch 17] Loss: 0.0541
Validation Accuracy: 0.8735


Epoch 18/100: 100%|██████████| 267/267 [00:18<00:00, 14.21it/s]


[Epoch 18] Loss: 0.0540
Validation Accuracy: 0.8857


Epoch 19/100: 100%|██████████| 267/267 [00:18<00:00, 14.12it/s]


[Epoch 19] Loss: 0.0196
Validation Accuracy: 0.9044


Epoch 20/100: 100%|██████████| 267/267 [00:18<00:00, 14.32it/s]


[Epoch 20] Loss: 0.0223
Validation Accuracy: 0.9053


Epoch 21/100: 100%|██████████| 267/267 [00:19<00:00, 13.90it/s]


[Epoch 21] Loss: 0.0387
Validation Accuracy: 0.8454


Epoch 22/100: 100%|██████████| 267/267 [00:18<00:00, 14.28it/s]


[Epoch 22] Loss: 0.0371
Validation Accuracy: 0.9016


Epoch 23/100: 100%|██████████| 267/267 [00:19<00:00, 14.03it/s]


[Epoch 23] Loss: 0.0078
Validation Accuracy: 0.9222
✅ Best model saved!


Epoch 24/100: 100%|██████████| 267/267 [00:18<00:00, 14.22it/s]


[Epoch 24] Loss: 0.0182
Validation Accuracy: 0.9147


Epoch 25/100: 100%|██████████| 267/267 [00:18<00:00, 14.07it/s]


[Epoch 25] Loss: 0.0316
Validation Accuracy: 0.8922


Epoch 26/100: 100%|██████████| 267/267 [00:18<00:00, 14.07it/s]


[Epoch 26] Loss: 0.0108
Validation Accuracy: 0.9100


Epoch 27/100: 100%|██████████| 267/267 [00:18<00:00, 14.08it/s]


[Epoch 27] Loss: 0.0205
Validation Accuracy: 0.8763


Epoch 28/100: 100%|██████████| 267/267 [00:19<00:00, 13.92it/s]


[Epoch 28] Loss: 0.0195
Validation Accuracy: 0.8847


Epoch 29/100: 100%|██████████| 267/267 [00:19<00:00, 14.02it/s]


[Epoch 29] Loss: 0.0310
Validation Accuracy: 0.9016


Epoch 30/100: 100%|██████████| 267/267 [00:18<00:00, 14.19it/s]


[Epoch 30] Loss: 0.0265
Validation Accuracy: 0.8932


Epoch 31/100: 100%|██████████| 267/267 [00:19<00:00, 14.03it/s]


[Epoch 31] Loss: 0.0161
Validation Accuracy: 0.8688


Epoch 32/100: 100%|██████████| 267/267 [00:18<00:00, 14.17it/s]


[Epoch 32] Loss: 0.0180
Validation Accuracy: 0.8997


Epoch 33/100: 100%|██████████| 267/267 [00:19<00:00, 13.80it/s]


[Epoch 33] Loss: 0.0076
Validation Accuracy: 0.9100


Epoch 34/100: 100%|██████████| 267/267 [00:18<00:00, 14.32it/s]


[Epoch 34] Loss: 0.0282
Validation Accuracy: 0.8810


Epoch 35/100: 100%|██████████| 267/267 [00:18<00:00, 14.05it/s]


[Epoch 35] Loss: 0.0170
Validation Accuracy: 0.8763


Epoch 36/100: 100%|██████████| 267/267 [00:19<00:00, 13.95it/s]


[Epoch 36] Loss: 0.0336
Validation Accuracy: 0.8941


Epoch 37/100: 100%|██████████| 267/267 [00:19<00:00, 14.03it/s]


[Epoch 37] Loss: 0.0254
Validation Accuracy: 0.9185


Epoch 38/100: 100%|██████████| 267/267 [00:18<00:00, 14.36it/s]


[Epoch 38] Loss: 0.0059
Validation Accuracy: 0.8894


Epoch 39/100: 100%|██████████| 267/267 [00:19<00:00, 13.88it/s]


[Epoch 39] Loss: 0.0117
Validation Accuracy: 0.9250
✅ Best model saved!


Epoch 40/100: 100%|██████████| 267/267 [00:19<00:00, 14.04it/s]


[Epoch 40] Loss: 0.0015
Validation Accuracy: 0.9231


Epoch 41/100: 100%|██████████| 267/267 [00:18<00:00, 14.25it/s]


[Epoch 41] Loss: 0.0016
Validation Accuracy: 0.9241


Epoch 42/100: 100%|██████████| 267/267 [00:18<00:00, 14.15it/s]


[Epoch 42] Loss: 0.0015
Validation Accuracy: 0.9250


Epoch 43/100: 100%|██████████| 267/267 [00:18<00:00, 14.17it/s]


[Epoch 43] Loss: 0.0013
Validation Accuracy: 0.9278
✅ Best model saved!


Epoch 44/100: 100%|██████████| 267/267 [00:19<00:00, 13.92it/s]


[Epoch 44] Loss: 0.0015
Validation Accuracy: 0.9278


Epoch 45/100: 100%|██████████| 267/267 [00:18<00:00, 14.29it/s]


[Epoch 45] Loss: 0.0013
Validation Accuracy: 0.9288
✅ Best model saved!


Epoch 46/100: 100%|██████████| 267/267 [00:18<00:00, 14.22it/s]


[Epoch 46] Loss: 0.0012
Validation Accuracy: 0.9278


Epoch 47/100: 100%|██████████| 267/267 [00:19<00:00, 13.82it/s]


[Epoch 47] Loss: 0.0014
Validation Accuracy: 0.9241


Epoch 48/100: 100%|██████████| 267/267 [00:18<00:00, 14.21it/s]


[Epoch 48] Loss: 0.0463
Validation Accuracy: 0.8022


Epoch 49/100: 100%|██████████| 267/267 [00:19<00:00, 13.81it/s]


[Epoch 49] Loss: 0.0658
Validation Accuracy: 0.9053


Epoch 50/100: 100%|██████████| 267/267 [00:18<00:00, 14.19it/s]


[Epoch 50] Loss: 0.0217
Validation Accuracy: 0.9072


Epoch 51/100: 100%|██████████| 267/267 [00:18<00:00, 14.21it/s]


[Epoch 51] Loss: 0.0141
Validation Accuracy: 0.8941


Epoch 52/100: 100%|██████████| 267/267 [00:18<00:00, 14.33it/s]


[Epoch 52] Loss: 0.0122
Validation Accuracy: 0.9082


Epoch 53/100: 100%|██████████| 267/267 [00:18<00:00, 14.07it/s]


[Epoch 53] Loss: 0.0100
Validation Accuracy: 0.9157


Epoch 54/100: 100%|██████████| 267/267 [00:18<00:00, 14.18it/s]


[Epoch 54] Loss: 0.0019
Validation Accuracy: 0.9175


Epoch 55/100: 100%|██████████| 267/267 [00:19<00:00, 14.04it/s]


[Epoch 55] Loss: 0.0406
Validation Accuracy: 0.9147


Epoch 56/100: 100%|██████████| 267/267 [00:19<00:00, 14.05it/s]


[Epoch 56] Loss: 0.0180
Validation Accuracy: 0.8997


Epoch 57/100: 100%|██████████| 267/267 [00:18<00:00, 14.14it/s]


[Epoch 57] Loss: 0.0124
Validation Accuracy: 0.9016


Epoch 58/100: 100%|██████████| 267/267 [00:18<00:00, 14.26it/s]


[Epoch 58] Loss: 0.0069
Validation Accuracy: 0.9138


Epoch 59/100: 100%|██████████| 267/267 [00:19<00:00, 13.92it/s]


[Epoch 59] Loss: 0.0038
Validation Accuracy: 0.9203


Epoch 60/100: 100%|██████████| 267/267 [00:18<00:00, 14.37it/s]


[Epoch 60] Loss: 0.0088
Validation Accuracy: 0.8857


Epoch 61/100: 100%|██████████| 267/267 [00:18<00:00, 14.15it/s]


[Epoch 61] Loss: 0.0156
Validation Accuracy: 0.9091


Epoch 62/100: 100%|██████████| 267/267 [00:18<00:00, 14.37it/s]


[Epoch 62] Loss: 0.0107
Validation Accuracy: 0.8960


Epoch 63/100: 100%|██████████| 267/267 [00:18<00:00, 14.19it/s]


[Epoch 63] Loss: 0.0280
Validation Accuracy: 0.9222


Epoch 64/100: 100%|██████████| 267/267 [00:19<00:00, 13.94it/s]


[Epoch 64] Loss: 0.0017
Validation Accuracy: 0.9231


Epoch 65/100: 100%|██████████| 267/267 [00:19<00:00, 14.05it/s]


[Epoch 65] Loss: 0.0012
Validation Accuracy: 0.9241


Epoch 66/100: 100%|██████████| 267/267 [00:18<00:00, 14.11it/s]


[Epoch 66] Loss: 0.0013
Validation Accuracy: 0.9250


Epoch 67/100: 100%|██████████| 267/267 [00:18<00:00, 14.06it/s]


[Epoch 67] Loss: 0.0011
Validation Accuracy: 0.9269


Epoch 68/100: 100%|██████████| 267/267 [00:18<00:00, 14.18it/s]


[Epoch 68] Loss: 0.0013
Validation Accuracy: 0.9297
✅ Best model saved!


Epoch 69/100: 100%|██████████| 267/267 [00:18<00:00, 14.21it/s]


[Epoch 69] Loss: 0.0013
Validation Accuracy: 0.9278


Epoch 70/100: 100%|██████████| 267/267 [00:18<00:00, 14.46it/s]


[Epoch 70] Loss: 0.0011
Validation Accuracy: 0.9269


Epoch 71/100: 100%|██████████| 267/267 [00:19<00:00, 14.01it/s]


[Epoch 71] Loss: 0.0014
Validation Accuracy: 0.9269


Epoch 72/100: 100%|██████████| 267/267 [00:18<00:00, 14.16it/s]


[Epoch 72] Loss: 0.0012
Validation Accuracy: 0.9260


Epoch 73/100: 100%|██████████| 267/267 [00:19<00:00, 13.78it/s]


[Epoch 73] Loss: 0.0012
Validation Accuracy: 0.9231


Epoch 74/100: 100%|██████████| 267/267 [00:19<00:00, 13.93it/s]


[Epoch 74] Loss: 0.0011
Validation Accuracy: 0.9250


Epoch 75/100: 100%|██████████| 267/267 [00:19<00:00, 13.99it/s]


[Epoch 75] Loss: 0.0012
Validation Accuracy: 0.9278


Epoch 76/100: 100%|██████████| 267/267 [00:18<00:00, 14.11it/s]


[Epoch 76] Loss: 0.0012
Validation Accuracy: 0.9260


Epoch 77/100: 100%|██████████| 267/267 [00:18<00:00, 14.31it/s]


[Epoch 77] Loss: 0.0011
Validation Accuracy: 0.9260


Epoch 78/100: 100%|██████████| 267/267 [00:18<00:00, 14.24it/s]


[Epoch 78] Loss: 0.0011
Validation Accuracy: 0.9250


Epoch 79/100: 100%|██████████| 267/267 [00:19<00:00, 14.05it/s]


[Epoch 79] Loss: 0.0843
Validation Accuracy: 0.9053


Epoch 80/100: 100%|██████████| 267/267 [00:19<00:00, 14.05it/s]


[Epoch 80] Loss: 0.0308
Validation Accuracy: 0.9044


Epoch 81/100: 100%|██████████| 267/267 [00:18<00:00, 14.06it/s]


[Epoch 81] Loss: 0.0118
Validation Accuracy: 0.8988


Epoch 82/100: 100%|██████████| 267/267 [00:18<00:00, 14.38it/s]


[Epoch 82] Loss: 0.0043
Validation Accuracy: 0.9260


Epoch 83/100: 100%|██████████| 267/267 [00:19<00:00, 13.92it/s]


[Epoch 83] Loss: 0.0021
Validation Accuracy: 0.9250


Epoch 84/100: 100%|██████████| 267/267 [00:18<00:00, 14.07it/s]


[Epoch 84] Loss: 0.0021
Validation Accuracy: 0.9260


Epoch 85/100: 100%|██████████| 267/267 [00:18<00:00, 14.12it/s]


[Epoch 85] Loss: 0.0018
Validation Accuracy: 0.9297


Epoch 86/100: 100%|██████████| 267/267 [00:18<00:00, 14.16it/s]


[Epoch 86] Loss: 0.0019
Validation Accuracy: 0.9278


Epoch 87/100: 100%|██████████| 267/267 [00:18<00:00, 14.11it/s]


[Epoch 87] Loss: 0.0017
Validation Accuracy: 0.9269


Epoch 88/100: 100%|██████████| 267/267 [00:18<00:00, 14.14it/s]


[Epoch 88] Loss: 0.0015
Validation Accuracy: 0.9278


Epoch 89/100: 100%|██████████| 267/267 [00:18<00:00, 14.20it/s]


[Epoch 89] Loss: 0.0013
Validation Accuracy: 0.9269


Epoch 90/100: 100%|██████████| 267/267 [00:18<00:00, 14.21it/s]


[Epoch 90] Loss: 0.0020
Validation Accuracy: 0.9260


Epoch 91/100: 100%|██████████| 267/267 [00:19<00:00, 13.88it/s]


[Epoch 91] Loss: 0.0021
Validation Accuracy: 0.9250


Epoch 92/100: 100%|██████████| 267/267 [00:18<00:00, 14.29it/s]


[Epoch 92] Loss: 0.0890
Validation Accuracy: 0.9016


Epoch 93/100: 100%|██████████| 267/267 [00:18<00:00, 14.17it/s]


[Epoch 93] Loss: 0.0155
Validation Accuracy: 0.9053


Epoch 94/100: 100%|██████████| 267/267 [00:18<00:00, 14.22it/s]


[Epoch 94] Loss: 0.0190
Validation Accuracy: 0.9072


Epoch 95/100: 100%|██████████| 267/267 [00:18<00:00, 14.13it/s]


[Epoch 95] Loss: 0.0211
Validation Accuracy: 0.9072


Epoch 96/100: 100%|██████████| 267/267 [00:19<00:00, 13.95it/s]


[Epoch 96] Loss: 0.0028
Validation Accuracy: 0.9166


Epoch 97/100: 100%|██████████| 267/267 [00:19<00:00, 13.87it/s]


[Epoch 97] Loss: 0.0015
Validation Accuracy: 0.9166


Epoch 98/100: 100%|██████████| 267/267 [00:18<00:00, 14.22it/s]


[Epoch 98] Loss: 0.0013
Validation Accuracy: 0.9166


Epoch 99/100: 100%|██████████| 267/267 [00:19<00:00, 14.03it/s]


[Epoch 99] Loss: 0.0012
Validation Accuracy: 0.9166


Epoch 100/100: 100%|██████████| 267/267 [00:19<00:00, 14.00it/s]


[Epoch 100] Loss: 0.0011
Validation Accuracy: 0.9175

📊 Classification Report:
              precision    recall  f1-score   support

         0.0     0.8929    0.8523    0.8721       352
         1.0     0.9289    0.9497    0.9391       715

    accuracy                         0.9175      1067
   macro avg     0.9109    0.9010    0.9056      1067
weighted avg     0.9170    0.9175    0.9170      1067

