In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from glob import glob
from tqdm import tqdm

# -------------------- Device setup --------------------
def _select_device(preferred_index=1):
    """Return the torch device to run inference on.

    Args:
        preferred_index: CUDA device index to use when it exists.

    Returns:
        ``cuda:<preferred_index>`` when CUDA is available and that index is
        present, ``cuda:0`` when CUDA is available but the preferred index is
        not, and the CPU device otherwise.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # A single-GPU machine has no "cuda:1"; fall back to the first GPU
    # instead of failing later when tensors are moved to the device.
    gpu_count = torch.cuda.device_count()
    index = preferred_index if preferred_index < gpu_count else 0
    return torch.device(f"cuda:{index}")


device = _select_device()
print(f"Using device: {device}")

# Number of slices scored per forward pass.
batch_size = 64

# -------------------- CBAM definition --------------------
class ChannelAttention(nn.Module):
    """Channel-attention gate producing one sigmoid weight per channel.

    The input is pooled to 1x1 with both adaptive average pooling and
    adaptive max pooling; each pooled map goes through the same two-layer
    1x1-convolution bottleneck, the two results are summed, and a sigmoid is
    applied.

    Args:
        planes: Number of input (and output) channels.
        ratio: Reduction factor of the bottleneck. The hidden width is
            ``planes // ratio`` but never less than 1, so ``planes < ratio``
            does not create a zero-channel convolution.
    """

    def __init__(self, planes, ratio=16):
        super().__init__()
        # Clamp so that small channel counts still get a usable bottleneck.
        hidden = max(1, planes // ratio)
        self.shared = nn.Sequential(
            nn.Conv2d(planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, planes, 1, bias=False))
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.max = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the per-channel attention weights (spatial size 1x1)."""
        avg_branch = self.shared(self.avg(x))
        max_branch = self.shared(self.max(x))
        return self.sigmoid(avg_branch + max_branch)

class SpatialAttention(nn.Module):
    """Spatial-attention gate producing one sigmoid weight per position.

    The input is reduced along the channel axis to a mean map and a max map;
    the two maps are concatenated and passed through a single bias-free
    convolution followed by a sigmoid.

    Args:
        k: Kernel size of the convolution; padding is ``k // 2``.
    """

    def __init__(self, k=7):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=k,
            padding=k // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the single-channel attention map computed from ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(pooled))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        planes: Number of channels handed to ``ChannelAttention``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, planes):
        super().__init__()
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()
        # Spatial attention map of the most recent forward pass (None until
        # the first call). It is stored as-is, i.e. not detached.
        self.last_attention = None

    def forward(self, x):
        """Re-weight ``x`` by channel attention, then by spatial attention."""
        channel_refined = x * self.ca(x)
        spatial_map = self.sa(channel_refined)
        self.last_attention = spatial_map
        return channel_refined * spatial_map

# -------------------- ResNet18 + CBAM definition --------------------
class BasicBlockCBAM(nn.Module):
    """Two-convolution residual block with an optional CBAM stage.

    Args:
        in_planes: Channels of the block input.
        out_planes: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the shortcut
            branch; when ``None`` the input itself is the shortcut.
        use_cbam: When true, a ``CBAM`` module is applied to the residual
            branch before the shortcut is added.
    """

    def __init__(self, in_planes, out_planes, stride=1, downsample=None, use_cbam=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_planes, out_planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.cbam = CBAM(out_planes) if use_cbam else None
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Compare against None explicitly: module truthiness depends on
        # __len__ for container modules (an empty nn.Sequential is falsy).
        if self.cbam is not None:
            out = self.cbam(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)

class ResNet18_CBAM(nn.Module):
    """ResNet-18-style classifier assembled from ``BasicBlockCBAM`` blocks.

    The network is a 7x7 stem convolution with batch norm, ReLU and max
    pooling, four stages of two blocks each (64, 128, 256, 512 channels),
    global average pooling and a linear classifier. Stages 1-3 are built with
    CBAM enabled; stage 4 is built with ``use_cbam=False``.

    Args:
        num_classes: Output size of the final linear layer.
        in_channels: Number of channels of the input images. Defaults to 1,
            the value previously hard-coded in the stem convolution.
    """

    def __init__(self, num_classes=2, in_channels=1):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2, use_cbam=False)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride=1, use_cbam=True):
        """Build one stage of ``blocks`` residual blocks with ``planes`` channels.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1-conv + batch-norm shortcut. Updates
        ``self.in_planes`` to ``planes``.
        """
        downsample = None
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes))
        layers = [BasicBlockCBAM(self.in_planes, planes, stride, downsample, use_cbam)]
        self.in_planes = planes
        layers.extend(
            BasicBlockCBAM(self.in_planes, planes, use_cbam=use_cbam)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))

# -------------------- Model preparation --------------------
model_path = "/home/iujeong/lung_cancer/pth/r18_cbam_mga_aug_lr4_ep100_weight.pth"
model = ResNet18_CBAM(num_classes=2)
model = model.to(device)
# Checkpoint tensors are mapped onto the selected device while loading.
checkpoint_weights = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint_weights)
del checkpoint_weights
# Switch to evaluation mode before running inference.
model.eval()

# -------------------- Transform --------------------
# Preprocessing applied to every slice before it is batched.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        # Single-channel normalization: (x - 0.5) / 0.5.
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# -------------------- Data preparation --------------------
slice_dir = "/data1/lidc-idri/slices"
# Recursively collect every .npy file below slice_dir, in sorted path order.
npy_files = glob(os.path.join(slice_dir, "**", "*.npy"), recursive=True)
npy_files.sort()
print(f"✅ 총 {len(npy_files)}개의 파일이 발견되었습니다.")

# -------------------- Batched inference --------------------
def _load_slice(path):
    """Load one ``.npy`` slice and return the transformed tensor.

    Values are clipped to [-1000, 400] (presumably a Hounsfield-unit window;
    confirm against the training pipeline), rescaled to [0, 1], given a
    trailing channel axis, cast to float32 and passed through ``transform``.
    """
    img = np.load(path)
    img = np.clip(img, -1000, 400)
    img = (img + 1000) / 1400.0
    img = np.expand_dims(img, axis=-1).astype(np.float32)
    return transform(img)


def _slice_id(path):
    """Return the CSV id for ``path``.

    The id is the path relative to ``slice_dir`` with the file extension
    removed and every path separator replaced by an underscore.
    """
    relative = os.path.relpath(path, slice_dir)
    stem, _ext = os.path.splitext(relative)
    return stem.replace(os.sep, "_")


def _predict_batch(images, paths):
    """Score one batch and return a list of ``{"id", "cancer"}`` rows.

    ``cancer`` is the softmax probability of class index 1.
    """
    batch_tensor = torch.stack(images).to(device)
    with torch.no_grad():
        outputs = model(batch_tensor)
        probs = F.softmax(outputs, dim=1)[:, 1].cpu().numpy()
    return [
        {"id": _slice_id(path), "cancer": float(prob)}
        for path, prob in zip(paths, probs)
    ]


results = []
batch_images = []
batch_paths = []


def _flush_batch():
    """Score the pending batch, append its rows to ``results``, clear buffers.

    The buffers are cleared even when scoring fails, so one bad batch cannot
    make the following batches grow without bound.
    """
    try:
        results.extend(_predict_batch(batch_images, batch_paths))
    except Exception as e:
        # Best-effort run: report the failure and continue with later batches.
        print(f"[!] 오류 발생: {batch_paths[-1]} → {e}")
    finally:
        batch_images.clear()
        batch_paths.clear()


for f in tqdm(npy_files, desc="추론 중"):
    try:
        tensor = _load_slice(f)
    except Exception as e:
        # Skip unreadable slices but keep everything already queued.
        print(f"[!] 오류 발생: {f} → {e}")
        continue
    batch_images.append(tensor)
    batch_paths.append(f)
    if len(batch_images) >= batch_size:
        _flush_batch()

# Score whatever is left over; this runs regardless of whether the last file
# loaded successfully.
if batch_images:
    _flush_batch()

# -------------------- Save CSV --------------------
# Naming the columns explicitly keeps the "id,cancer" header in the file even
# when no slice was scored and `results` is empty.
df = pd.DataFrame(results, columns=["id", "cancer"])
df.to_csv("submission.csv", index=False)
print(f"✅ CSV 저장 완료! ({len(results)}개 항목)")


Using device: cuda:1
✅ 총 7849개의 파일이 발견되었습니다.


추론 중: 100%|██████████| 7849/7849 [09:14<00:00, 14.15it/s] 

✅ CSV 저장 완료! (7849개 항목)





: 