# In [None]:
import os
import numpy as np
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn.functional as F
from tqdm import tqdm
from collections import defaultdict
import gc

# Hyperparameters and runtime configuration.
# Prefer the second GPU ("cuda:1") when the host has one. On a single-GPU host
# "cuda:1" is an invalid device ordinal and the first transfer to it would
# fail, so fall back to "cuda:0"; without CUDA, run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print(device)

slice_root = "/data1/lidc-idri/slices"
batch_size = 16
num_epoch = 1
learning_rate = 1e-4

# Label extraction
def labels_filename(fname):
    """Derive a binary label from the score encoded at the end of a file name.

    The text after the last underscore, with ".npy" removed, is parsed as an
    integer score.

    Args:
        fname: File name or path string, e.g. ``"nodule_7_slice_2_4.npy"``.

    Returns:
        ``1`` when the score is 4 or higher, ``0`` when it is 2 or lower, and
        ``None`` when the score is exactly 3 or the trailing token is not an
        integer.
    """
    try:
        score = int(fname.split("_")[-1].replace(".npy", ""))
    except ValueError:
        # Only a failed integer parse is an expected outcome. A bare ``except``
        # here would also swallow KeyboardInterrupt and genuine programming
        # errors, silently dropping files from the dataset.
        return None

    return None if score == 3 else int(score >= 4)
    
# Dataset preprocessing
class LIDCDataset(Dataset):
    """Dataset of ``.npy`` arrays stored on disk, each paired with a label.

    Every item is loaded from its file, clipped to [-1000, 400], rescaled
    linearly to [0, 1], given a leading length-1 axis, and optionally passed
    through ``transform``.
    """

    def __init__(self, file_paths, labels, transform=None):
        # ``labels[i]`` is the label for ``file_paths[i]``; both are indexed
        # with the same position in ``__getitem__``.
        self.file_paths, self.labels = file_paths, labels
        self.transform = transform

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for position ``index``.

        The label is returned as a float tensor.
        """
        path, target = self.file_paths[index], self.labels[index]

        # Clip to the fixed window, then shift/scale that window onto [0, 1].
        # NOTE(review): the window looks like a CT Hounsfield range -- confirm.
        pixels = np.clip(np.load(path).astype(np.float32), -1000, 400)
        pixels = (pixels + 1000) / 1400.0

        # Prepend a length-1 axis before converting to a tensor.
        image = torch.tensor(pixels[np.newaxis, ...])

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(target).float()


# Data augmentation
def _build_pipeline(*augmentations):
    """Wrap ``augmentations`` between the steps shared by every configuration.

    Each pipeline converts the tensor to a PIL image, center-crops it to 180,
    applies the given augmentations in order, and converts back to a tensor.
    """
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.CenterCrop(180),
        *augmentations,
        transforms.ToTensor(),
    ])


# Maps a configuration name to its transform pipeline. Each pipeline gets its
# own transform instances; nothing is shared between configurations.
augmentation_configs = {
    'baseline': _build_pipeline(),

    'flip_rotate': _build_pipeline(
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ),

    'blur': _build_pipeline(
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    ),

    'total': _build_pipeline(
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    ),
}

# Model construction
def get_model(name):
    """Build a pretrained backbone adapted to 1-channel input and a single logit.

    The first convolution is replaced by a new 1-input-channel convolution and
    the classification head by a new ``Linear(..., 1)``; both replacement
    layers are freshly initialised, so the pretrained weights of the layers
    they replace are not reused.

    Args:
        name: One of ``"resnet18"``, ``"resnet34"``, ``"densenet121"`` or
            ``"efficientnet_b0"``.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: If ``name`` is not one of the supported architectures.
    """
    if name in ("resnet18", "resnet34"):
        # Both ResNets expose the same stem (`conv1`) and head (`fc`).
        net = getattr(models, name)(pretrained=True)
        net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        net.fc = nn.Linear(net.fc.in_features, 1)
        return net.to(device)

    if name == "densenet121":
        net = models.densenet121(pretrained=True)
        net.features.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        net.classifier = nn.Linear(net.classifier.in_features, 1)
        return net.to(device)

    if name == "efficientnet_b0":
        net = models.efficientnet_b0(pretrained=True)
        net.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, 1)
        return net.to(device)

    raise ValueError("Unknown model name")


# Collect the labelled slice files and split them 70/15/15 into
# train / validation / test.
# glob returns paths in an arbitrary, filesystem-dependent order; sorting them
# makes the seeded split below reproducible across runs and machines.
all_files = sorted(glob(os.path.join(slice_root, "LIDC-IDRI-*", "*.npy")))

# Keep only files for which a binary label can be derived (None means "skip").
file_label_pairs = [(path, labels_filename(path)) for path in all_files]
file_label_pairs = [(path, label) for path, label in file_label_pairs if label is not None]
if not file_label_pairs:
    # Without this check the unpacking below fails with an opaque
    # "not enough values to unpack" error.
    raise RuntimeError(f"No labelled .npy slices found under {slice_root!r}")
files, labels = zip(*file_label_pairs)

# NOTE(review): the split is made per file, so slices from the same
# "LIDC-IDRI-*" directory can land in different subsets. If those directories
# are per-patient, consider a group-wise split to avoid leakage -- confirm.
train_files, temp_files, train_labels, temp_labels = train_test_split(files, labels, test_size=0.3, random_state=42)
val_files, test_files, val_labels, test_labels = train_test_split(temp_files, temp_labels, test_size=0.5, random_state=42)


# Train and evaluate every (model, augmentation) combination.
results = defaultdict(dict)
model_names = ["resnet18", "resnet34", "densenet121", "efficientnet_b0"]

# Checkpoint directory: a "pth" folder in the parent of the current working
# directory. It does not depend on the configuration, so create it once.
save_dir = os.path.join(os.path.dirname(os.getcwd()), "pth")
os.makedirs(save_dir, exist_ok=True)

# Validation and test data must be evaluated deterministically. Applying the
# random training augmentation to them makes the reported metrics noisy and
# not comparable between configurations, so they always use the 'baseline'
# (crop-only) pipeline.
eval_transform = augmentation_configs["baseline"]

for model_name in model_names:
    for aug_name, train_transform in augmentation_configs.items():
        print(f"\n Running: {model_name} + {aug_name}")

        train_dataset = LIDCDataset(train_files, train_labels, train_transform)
        val_dataset = LIDCDataset(val_files, val_labels, eval_transform)
        test_dataset = LIDCDataset(test_files, test_labels, eval_transform)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        model = get_model(model_name)

        # Loss and optimizer
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        checkpoint_path = os.path.join(save_dir, f"best_aug_{model_name}_{aug_name}.pth")

        # Start below any reachable accuracy so the first epoch always writes
        # a checkpoint. Starting at 0.0 with a strict ">" would skip saving
        # when validation accuracy is 0, and the load below would then fail
        # or pick up a stale file from an earlier run.
        best_val_acc = -1.0

        for epoch in range(num_epoch):
            # ---- training pass ----
            model.train()

            correct = 0
            total = 0

            # `targets` (not `labels`) so the module-level `labels` tuple is
            # not overwritten by a batch tensor.
            for images, targets in train_loader:
                images = images.to(device)
                # (B,) -> (B, 1) to match the model's single-logit output.
                targets = targets.unsqueeze(1).to(device)

                outputs = model(images)
                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                predicted = (torch.sigmoid(outputs) > 0.5).long()
                correct += (predicted == targets.long()).sum().item()
                total += targets.size(0)

            train_acc = correct / total
            print(f"[{model_name} + {aug_name}] Epoch: {epoch+1}/{num_epoch} Train Acc: {train_acc * 100:.4f}%")

            # ---- validation pass ----
            model.eval()

            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for images, targets in val_loader:
                    images = images.to(device)
                    targets = targets.to(device)

                    outputs = model(images)

                    # squeeze(1) drops only the logit axis; a bare squeeze()
                    # would also drop the batch axis when the batch has one
                    # sample.
                    predicted = (torch.sigmoid(outputs) > 0.5).squeeze(1).long()
                    val_correct += (predicted == targets.long()).sum().item()
                    val_total += targets.size(0)

            val_acc = val_correct / val_total
            print(f"[{model_name} + {aug_name}] Epoch {epoch+1}/{num_epoch} Val Acc {val_acc * 100:.4f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), checkpoint_path)

        # ---- test pass with the best checkpoint of this configuration ----
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()

        y_true, y_pred, y_probs = [], [], []

        with torch.no_grad():
            for images, targets in test_loader:
                images, targets = images.to(device), targets.to(device)

                outputs = model(images)

                # Keep the batch axis: with a bare squeeze() a final batch of
                # size 1 becomes 0-d and `list.extend` raises TypeError.
                probs = torch.sigmoid(outputs).squeeze(1)
                preds = (probs > 0.5).long()

                y_true.extend(targets.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())
                y_probs.extend(probs.cpu().numpy())

        acc = (np.array(y_true) == np.array(y_pred)).mean()
        auc = roc_auc_score(y_true, y_probs)
        cm = confusion_matrix(y_true, y_pred)
        results[model_name][aug_name] = {"acc": acc, "auc": auc, "cm": cm}
        print(f"✅ Test Acc: {acc:.4f}, AUC: {auc:.4f}")
