# In [None]:
# ResNet34 + pretrained

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from glob import glob
from tqdm import tqdm
import timm  # pip install timm

import torch
import torch.nn as nn
from torchvision.models import resnet34, ResNet34_Weights

# Device selection: run on the GPU whenever CUDA is usable, otherwise fall back
# to the CPU (nothing here forces CPU execution).
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Run configuration.
SLICE_ROOT = "/data1/lidc-idri/slices"  # root directory searched for .npy slices
BATCH_SIZE = 16  # samples per DataLoader batch
NUM_EPOCHS = 100  # number of passes over the training set

# 1. Label extraction
def extract_label_from_filename(filename):
    """Derive a binary label from the integer score at the end of a file name.

    The text after the last underscore, with ``.npy`` removed, is parsed as an
    integer score (presumably the LIDC malignancy rating -- confirm against
    the slice-export code).

    Args:
        filename: File name or path (``str``) of a ``.npy`` slice.

    Returns:
        ``1`` if the score is 4 or higher, ``0`` if it is below 3, and
        ``None`` if the score is exactly 3 or the suffix is not an integer.
    """
    score_text = filename.split("_")[-1].replace(".npy", "")
    try:
        score = int(score_text)
    except ValueError:
        # Non-numeric suffix: the file carries no usable label. Only the
        # parsing error is caught so unrelated failures are not hidden.
        return None
    if score == 3:
        # Score 3 is excluded from both classes.
        return None
    return 1 if score >= 4 else 0

# 2. Data loading
# glob() returns paths in a filesystem-dependent order, so sort them first;
# otherwise the fixed random_state below would not yield the same split on
# every machine/run.
all_files = sorted(glob(os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy")))

# Pair each file with its label and drop files without a usable label.
file_label_pairs = [
    (path, label)
    for path, label in ((f, extract_label_from_filename(f)) for f in all_files)
    if label is not None
]
if not file_label_pairs:
    # zip(*[]) below would otherwise fail with an opaque
    # "not enough values to unpack" error.
    raise ValueError(f"No labelled .npy slices found under {SLICE_ROOT!r}")
files, labels = zip(*file_label_pairs)

# Hold out 20% of the slices for validation with a fixed seed.
# NOTE(review): the split is per slice, so slices from the same LIDC-IDRI-*
# folder can end up in both sets -- confirm this is intended.
train_files, val_files, train_labels, val_labels = train_test_split(
    files, labels, test_size=0.2, random_state=42
)

# 3. Dataset
class LIDCDataset(Dataset):
    """Dataset of ``.npy`` slices paired with numeric labels.

    Each item is loaded from disk, min-max scaled to [0, 1], resized with
    bilinear interpolation and returned with its label and source path.

    Args:
        file_paths: Sequence of paths to ``.npy`` files. Each array must be
            2-D, because one channel axis and one batch axis are added before
            the 2-D bilinear resize.
        labels: Sequence of numeric labels aligned with ``file_paths``.
        transform: Optional callable applied to the resized ``[1, H, W]``
            float tensor.
        image_size: ``(height, width)`` the slice is resized to before
            ``transform`` runs. Defaults to ``(224, 224)``.

    Raises:
        ValueError: If ``file_paths`` and ``labels`` differ in length.
    """

    def __init__(self, file_paths, labels, transform=None, image_size=(224, 224)):
        if len(file_paths) != len(labels):
            raise ValueError(
                f"file_paths and labels must have the same length "
                f"(got {len(file_paths)} and {len(labels)})"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.image_size = tuple(image_size)

    def __getitem__(self, idx):
        """Return ``(image, label, path)`` for the sample at ``idx``."""
        path = self.file_paths[idx]
        pixels = np.load(path).astype(np.float32)

        # Min-max scale to [0, 1]; the epsilon avoids division by zero when
        # the slice is constant.
        lo, hi = pixels.min(), pixels.max()
        pixels = ((pixels - lo) / (hi - lo + 1e-8)).astype(np.float32, copy=False)

        # F.interpolate needs a batched [N, C, H, W] input: add batch and
        # channel axes, resize, then drop the batch axis -> [1, H, W].
        batched = torch.from_numpy(pixels)[None, None]
        img_tensor = F.interpolate(
            batched, size=self.image_size, mode='bilinear', align_corners=False
        ).squeeze(0)

        if self.transform:
            img_tensor = self.transform(img_tensor)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img_tensor, label, path

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

# 4. Transform
# Pipeline: tensor -> PIL image -> resize to 224x224 -> tensor.
# NOTE(review): LIDCDataset already returns 224x224 float tensors, so this
# round trip looks redundant, and the PIL conversion presumably quantises the
# values to 8 bits -- confirm before removing it.
val_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Both splits share the same transform; only the training loader shuffles.
train_loader = DataLoader(
    dataset=LIDCDataset(train_files, train_labels, transform=val_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=LIDCDataset(val_files, val_labels, transform=val_transform),
    batch_size=BATCH_SIZE,
)

# 5. Load the ImageNet-pretrained ResNet34 and adapt it to 1-channel CT input
model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

# Replace the RGB stem with a 1-channel one, seeded from the pretrained
# filters. Summing the three RGB kernels reproduces the response the original
# stem gives for a grey image replicated across R, G and B; a freshly
# initialised conv would discard the pretrained first layer entirely.
pretrained_stem_weight = model.conv1.weight.detach().clone()  # [64, 3, 7, 7]
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight.copy_(pretrained_stem_weight.sum(dim=1, keepdim=True))
del pretrained_stem_weight

# Single-logit head for binary classification.
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(DEVICE)

# 6. Loss function & optimizer
# BCEWithLogitsLoss applies the sigmoid internally, so the model emits raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 7. Training loop
best_acc = 0
best_epoch = 0
# Predictions/targets from the epoch that produced the saved checkpoint, so
# the final report describes the model that is actually on disk.
best_preds, best_labels = [], []
# Defined up front so the report section still has names to read when
# NUM_EPOCHS is 0.
all_preds, all_labels = [], []

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    # The loop variable is called `targets` so it does not shadow the
    # module-level `labels` tuple.
    for imgs, targets, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        # BCEWithLogitsLoss needs targets shaped like the [B, 1] logits.
        imgs, targets = imgs.to(DEVICE), targets.unsqueeze(1).to(DEVICE)
        optimizer.zero_grad()
        output = model(imgs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"[Epoch {epoch+1}] Loss: {running_loss / len(train_loader):.4f}")

    # Validation
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, targets, _ in val_loader:
            imgs = imgs.to(DEVICE)
            targets = targets.to(DEVICE).long()
            # squeeze(1) removes only the logit axis. A bare squeeze() would
            # also collapse a final batch of size 1 into a 0-d tensor, whose
            # numpy view cannot be iterated by extend() below.
            out = model(imgs).squeeze(1)
            probs = torch.sigmoid(out)
            preds = (probs > 0.5).long()
            correct += (preds == targets).sum().item()
            total += targets.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch + 1
        # all_preds/all_labels are rebound to new lists every epoch, so
        # keeping these references is enough to snapshot this epoch.
        best_preds, best_labels = all_preds, all_labels
        # NOTE(review): "pretatin" in the file name looks like a typo for
        # "pretrain"; left unchanged because other code may load this path.
        torch.save(model.state_dict(), "best_model_resnet34_pretatin.pth")
        print("✅ Best model saved!")

# 8. Final report
# Report the epoch whose weights were saved; fall back to the last epoch if
# no checkpoint was ever written (validation accuracy never exceeded 0).
if best_preds:
    print(f"Best checkpoint: epoch {best_epoch} (val acc {best_acc:.4f})")
    report_labels, report_preds = best_labels, best_preds
else:
    report_labels, report_preds = all_labels, all_preds

print("\n📊 Classification Report:")
print(classification_report(report_labels, report_preds, digits=4))