In [None]:
# Densenet121

import os
import numpy as np
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn.functional as F

from tqdm import tqdm

from torchvision.models import densenet121, DenseNet121_Weights





# --- Configuration ---
# Root directory that is searched for .npy slice files.
SLICE_ROOT = "/data1/lidc-idri/slices"
BATCH_SIZE = 16
NUM_EPOCHS = 100
# Learning rate handed to the optimizer.
LR = 1e-4
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Build file list and labels ---
def extract_label_from_filename(filename):
    """Derive a binary label from the integer score that ends a slice filename.

    The score is the text after the last underscore with ``.npy`` removed,
    e.g. ``"..._5.npy"`` yields ``5``.

    Args:
        filename: File name or path of a slice.

    Returns:
        ``1`` if the score is 4 or higher, ``0`` if it is below 3, and
        ``None`` if the score is exactly 3 or no integer score can be parsed.
    """
    try:
        score = int(filename.split("_")[-1].replace(".npy", ""))
    except (AttributeError, TypeError, ValueError):
        # Not a string, or no integer suffix: treat the file as unlabelled.
        # Deliberately narrower than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return None

    if score == 3:
        return None
    return 1 if score >= 4 else 0
# glob() returns paths in arbitrary, filesystem-dependent order; sort them so
# that random_state=42 below really yields the same split on every run.
all_files = sorted(glob(os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy")))
file_label_pairs = [(f, extract_label_from_filename(f)) for f in all_files]
# Drop files for which no label could be determined.
file_label_pairs = [(f, l) for f, l in file_label_pairs if l is not None]

if not file_label_pairs:
    # Without this guard zip(*[]) fails with an opaque unpacking error.
    raise RuntimeError(f"No labelled .npy slices found under {SLICE_ROOT!r}")

files, labels = zip(*file_label_pairs)

# Split on the parent "LIDC-IDRI-*" directory instead of on individual slices,
# so slices from one directory never end up on both sides of the split and
# inflate the validation score.
# NOTE(review): presumably one directory per patient/case -- confirm.
case_dirs = sorted({os.path.dirname(f) for f in files})
if len(case_dirs) < 2:
    raise RuntimeError(
        f"Need at least two LIDC-IDRI-* directories to split, found {len(case_dirs)}"
    )
train_dirs, val_dirs = train_test_split(case_dirs, test_size=0.2, random_state=42)
val_dirs = set(val_dirs)

train_files, train_labels, val_files, val_labels = [], [], [], []
for f, l in file_label_pairs:
    if os.path.dirname(f) in val_dirs:
        val_files.append(f)
        val_labels.append(l)
    else:
        train_files.append(f)
        train_labels.append(l)

# --- Transform ---
# Training pipeline: resize, then random horizontal flip and a random rotation
# of up to 10 degrees as augmentation, and finally conversion back to a tensor.
train_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
    ]
)

# Validation pipeline: the same conversion and resize, without augmentation.
val_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# --- Dataset definition ---
class LIDCDataset(Dataset):
    """Dataset serving ``.npy`` slices as ``(image, label)`` tensor pairs.

    Each item is loaded from disk, min-max scaled, given a leading channel
    axis, bilinearly resized to 224x224 and, when a transform was supplied,
    passed through it. The label is returned as a float tensor.
    """

    def __init__(self, file_paths, labels, transform=None):
        """Store the parallel sequences of paths and labels and the transform."""
        self.transform = transform
        self.labels = labels
        self.file_paths = file_paths

    def __len__(self):
        """Return the number of slice files."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, scale and resize the slice at ``idx``; return it with its label."""
        path, target = self.file_paths[idx], self.labels[idx]

        raw = np.load(path).astype(np.float32)
        lo, hi = np.min(raw), np.max(raw)
        # The epsilon keeps the division finite for constant-valued slices.
        scaled = (raw - lo) / (hi - lo + 1e-8)

        # Prepend channel and batch axes: interpolate() needs a batched input.
        batched = torch.tensor(scaled[np.newaxis, np.newaxis])
        resized = F.interpolate(
            batched, size=(224, 224), mode="bilinear", align_corners=False
        )
        sample = resized.squeeze(0)  # (1, 224, 224) for a 2-D slice

        if self.transform:
            sample = self.transform(sample)

        return sample, torch.tensor(target).float()

# --- DataLoader ---
train_dataset = LIDCDataset(
    file_paths=train_files, labels=train_labels, transform=train_transform
)
val_dataset = LIDCDataset(
    file_paths=val_files, labels=val_labels, transform=val_transform
)

# Only the training loader reshuffles its samples every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Model definition ---
model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)

# The input is single-channel, so the 3-channel stem conv has to be replaced.
# A freshly initialised conv would discard the pretrained stem while every
# later pretrained layer still expects its features. Instead, seed the new
# conv with the pretrained RGB filters summed over the input-channel axis,
# which is equivalent to feeding the grey image replicated on all 3 channels.
pretrained_stem = model.features.conv0
gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    gray_stem.weight.copy_(pretrained_stem.weight.sum(dim=1, keepdim=True))
model.features.conv0 = gray_stem

# Single output logit; the sigmoid is applied inside BCEWithLogitsLoss.
model.classifier = nn.Linear(model.classifier.in_features, 1)
model = model.to(DEVICE)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# --- Training loop ---
best_val_acc = 0.0
# Validation targets/predictions of the best epoch, i.e. of the checkpoint
# that is actually saved to disk; used for the final report.
best_y_true = []
best_y_pred = []

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0

    # The batch variable is called `targets` so that it does not overwrite the
    # module-level `labels` sequence.
    for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        images = images.to(DEVICE)
        # (B,) -> (B, 1) so the targets line up with the model's logits.
        targets = targets.unsqueeze(1).to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"[Epoch {epoch+1}] Loss: {total_loss/len(train_loader):.4f}")

    # --- Validation ---
    model.eval()
    correct = total = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(DEVICE)
            # Integer targets so y_true and y_pred share one label type.
            targets = targets.to(DEVICE).long()

            outputs = model(images)
            # squeeze(1), not squeeze(): a final batch holding one sample must
            # stay 1-D, otherwise extend() below fails on a 0-d value.
            preds = (torch.sigmoid(outputs) > 0.5).squeeze(1).long()

            y_true.extend(targets.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

            correct += (preds == targets).sum().item()
            total += targets.size(0)

    acc = correct / total
    print(f"Validation Accuracy: {acc:.4f}")

    if acc > best_val_acc:
        best_val_acc = acc
        best_y_true, best_y_pred = y_true, y_pred
        torch.save(model.state_dict(), "best_model_densenet121.pth")
        print("✅ Best model saved!")

# --- Print report ---
# Report the epoch whose weights were saved, not whichever epoch ran last.
print("\n📊 Classification Report:")
if best_y_true:
    print(classification_report(best_y_true, best_y_pred, digits=4))
else:
    print("No checkpoint was saved; nothing to report.")