# In [None]:
import os
import numpy as np
import json
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import (
    resnet18, resnet34, densenet121, efficientnet_b0,
    ResNet18_Weights, ResNet34_Weights, DenseNet121_Weights, EfficientNet_B0_Weights
)

# Directory that is searched for "LIDC-IDRI-*/*.npy" slice files.
SLICE_ROOT = "/data1/lidc-idri/slices"
# Number of samples per DataLoader batch.
BATCH_SIZE = 16
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

def extract_label_from_filename(filename):
    """Derive a binary label from the score encoded at the end of a filename.

    The score is the integer found between the last underscore of
    ``filename`` and the ``.npy`` extension (e.g. ``"..._5.npy"`` -> 5).

    Args:
        filename: Path or file name of a slice file.

    Returns:
        ``1`` when the score is 4 or higher, ``0`` when it is below 3, and
        ``None`` when the score is exactly 3 or cannot be parsed (callers
        drop ``None`` entries).
    """
    try:
        score = int(filename.split("_")[-1].replace(".npy", ""))
    except (AttributeError, TypeError, ValueError):
        # Unparsable or non-string names are skipped on purpose; unlike a
        # bare ``except`` this does not swallow KeyboardInterrupt/SystemExit.
        return None

    # A score of 3 is excluded from the binary task.
    if score == 3:
        return None
    return 1 if score >= 4 else 0

# Gather every slice file below SLICE_ROOT, one sub-directory per "LIDC-IDRI-*".
# NOTE(review): glob() does not promise a sorted result, so the seeded split
# below is only reproducible if the listing order is stable -- confirm this
# matches the order used when the checkpoints were trained.
slice_pattern = os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy")
all_files = glob(slice_pattern)

# Pair each file with its label and discard files without a usable label.
labelled_candidates = (
    (path, extract_label_from_filename(path)) for path in all_files
)
file_label_pairs = [pair for pair in labelled_candidates if pair[1] is not None]

files, labels = zip(*file_label_pairs)

# 80/20 split with a fixed seed.
# NOTE(review): the split is per file, not per "LIDC-IDRI-*" directory --
# verify whether a patient-level split was intended.
split_parts = train_test_split(files, labels, test_size=0.2, random_state=42)
train_files, val_files, train_labels, val_labels = split_parts

class LIDCDataset(Dataset):
    """Dataset of ``.npy`` slice files paired with labels.

    Indexing yields ``(image, label, path)``:

    * ``image`` -- the array loaded from disk, min-max scaled, given a
      leading channel axis, bilinearly resized to 224x224 and finally passed
      through ``transform`` when one is set;
    * ``label`` -- the matching entry of ``labels`` as a float tensor;
    * ``path`` -- the file path exactly as supplied.
    """

    def __init__(self, file_paths, labels, transform=None):
        """Store the file paths, their labels and an optional transform."""
        self.file_paths, self.labels = file_paths, labels
        self.transform = transform

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, normalise and resize the slice stored at position ``idx``."""
        path = self.file_paths[idx]
        target = self.labels[idx]

        pixels = np.load(path).astype(np.float32)
        lowest = np.min(pixels)
        highest = np.max(pixels)
        # The small epsilon keeps a constant image from dividing by zero.
        pixels = (pixels - lowest) / (highest - lowest + 1e-8)

        # Add channel and batch axes because F.interpolate expects batched
        # input (assumes each file holds a 2-D array -- TODO confirm).
        batched = torch.tensor(pixels[np.newaxis, ...]).unsqueeze(0)
        resized = F.interpolate(
            batched, size=(224, 224), mode="bilinear", align_corners=False
        )
        image = resized.squeeze(0)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(target).float(), path

# Validation preprocessing: tensor -> PIL image -> 224x224 resize -> tensor.
# NOTE(review): LIDCDataset.__getitem__ already interpolates to 224x224, so
# the Resize step looks redundant -- keep it only if training used the same
# pipeline.
val_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
val_transform = transforms.Compose(val_steps)

# Each batch from this loader also carries the source file paths, because the
# dataset returns (image, label, path) triples.
val_dataset = LIDCDataset(
    file_paths=val_files, labels=val_labels, transform=val_transform
)
val_loader_with_paths = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)

def _restore_for_inference(model, checkpoint_path):
    """Load ``checkpoint_path`` into ``model``, move it to DEVICE, set eval mode."""
    # map_location lets a checkpoint saved from a CUDA device be restored when
    # DEVICE has fallen back to the CPU.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return model.to(DEVICE).eval()


def load_model():
    """Build the four ensemble members and restore their fine-tuned weights.

    Every network gets a single-channel first convolution and a one-logit
    output layer, then its checkpoint is read from the current working
    directory (``best_model_<arch>.pth``).

    The backbones are created with ``weights=None``: the checkpoint is loaded
    with the default strict ``load_state_dict``, which replaces every
    parameter, so fetching the ImageNet weights first would only cost a
    download.

    Returns:
        Tuple ``(resnet18, resnet34, densenet121, efficientnet_b0)`` of
        models placed on ``DEVICE`` and switched to eval mode.
    """
    m1 = resnet18(weights=None)
    m1.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    m1.fc = nn.Linear(m1.fc.in_features, 1)
    m1 = _restore_for_inference(m1, "best_model_resnet18.pth")

    m2 = resnet34(weights=None)
    m2.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    m2.fc = nn.Linear(m2.fc.in_features, 1)
    m2 = _restore_for_inference(m2, "best_model_resnet34.pth")

    m3 = densenet121(weights=None)
    m3.features.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    m3.classifier = nn.Linear(m3.classifier.in_features, 1)
    m3 = _restore_for_inference(m3, "best_model_densenet121.pth")

    m4 = efficientnet_b0(weights=None)
    m4.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    m4.classifier[1] = nn.Linear(m4.classifier[1].in_features, 1)
    m4 = _restore_for_inference(m4, "best_model_efficientnet_b0.pth")

    return m1, m2, m3, m4

# Keep the loaded networks under names that do not overwrite the torchvision
# constructors `resnet18` / `resnet34` imported at the top of the file;
# rebinding those names would make a second call to load_model() fail.
resnet18_model, resnet34_model, densenet, effnet = load_model()
ensemble_members = {
    "resnet18": resnet18_model,
    "resnet34": resnet34_model,
    "densenet": densenet,
    "effnet": effnet,
}
report_names = list(ensemble_members) + ["ensemble"]

# Per-model predictions and probabilities accumulated over the whole loader.
model_outputs = {name: {"y_pred": [], "y_prob": []} for name in report_names}
all_labels = []
output_records = []

with torch.no_grad():
    # `batch_labels` (not `labels`) so the module-level `labels` tuple created
    # before the train/val split is not overwritten.
    for images, batch_labels, paths in tqdm(val_loader_with_paths):
        images = images.to(DEVICE)

        # Each model ends in a one-unit linear layer; reshape(-1) flattens its
        # (batch, 1) output to (batch,). Unlike squeeze() the result stays 1-D
        # when the final batch holds a single sample, so indexing still works.
        logits = {
            name: net(images).reshape(-1)
            for name, net in ensemble_members.items()
        }
        # The ensemble averages the raw logits; the sigmoid is applied below.
        logits["ensemble"] = (
            logits["resnet18"]
            + logits["resnet34"]
            + logits["densenet"]
            + logits["effnet"]
        ) / 4

        # Convert once per batch instead of once per element.
        probs = {}
        preds = {}
        for name, model_logits in logits.items():
            prob = torch.sigmoid(model_logits)
            probs[name] = prob.cpu().tolist()
            preds[name] = (prob > 0.5).long().cpu().tolist()

        label_values = batch_labels.long().tolist()

        for i, path in enumerate(paths):
            output_records.append({
                "file": path,
                "label": label_values[i],
                "individual_outputs": [
                    probs[name][i] for name in ensemble_members
                ],
                "ensemble_prob": probs["ensemble"][i],
                "ensemble_pred": preds["ensemble"][i],
            })

        all_labels.extend(label_values)
        for name in report_names:
            model_outputs[name]["y_pred"].extend(preds[name])
            model_outputs[name]["y_prob"].extend(probs[name])

# Save the per-slice predictions as JSON.
with open("ensemble_predictions.json", "w", encoding="utf-8") as f:
    json.dump(output_records, f, indent=2)

# Print a classification report for every model and for the ensemble.
print("\n📊 Classification Reports:\n")
for model_name in report_names:
    print(f"--- {model_name.upper()} ---")
    print(classification_report(all_labels, model_outputs[model_name]["y_pred"], digits=4))
