In [8]:
import os
import numpy as np
from glob import glob
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# 환경 설정
SLICE_ROOT = "/data1/lidc-idri/slices"
BATCH_SIZE = 16
NUM_EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Dataset ---
class CTLesionDataset(Dataset):
    def __init__(self, files):
        self.slice_paths = files
        self.mask_paths = [p.replace(".npy", "_mask.npy") for p in self.slice_paths]

    def __len__(self):
        return len(self.slice_paths)

    def __getitem__(self, idx):
        img = np.load(self.slice_paths[idx])
        mask = np.load(self.mask_paths[idx])

        img = np.clip(img, -1000, 400)
        img = (img + 1000) / 1400
        mask = (mask > 0).astype(np.float32)

        img = Image.fromarray((img * 255).astype(np.uint8)).resize((224, 224))
        mask = Image.fromarray((mask * 255).astype(np.uint8)).resize((224, 224))

        img = np.array(img).astype(np.float32) / 255.0
        mask = np.array(mask).astype(np.float32) / 255.0

        img = torch.tensor(img).unsqueeze(0)
        mask = torch.tensor(mask).unsqueeze(0)

        return img, mask

# --- Loss Functions ---
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = preds.view(-1)
        targets = targets.view(-1)
        intersection = (preds * targets).sum()
        return 1 - ((2. * intersection + self.smooth) /
                    (preds.sum() + targets.sum() + self.smooth))

class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()

    def forward(self, preds, targets):
        return self.bce(preds, targets) + self.dice(preds, targets)

# --- Model ---
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, 1), nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, 1), nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(nn.Conv2d(F_int, 1, 1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class UpBlockWithAttention(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch):
        super().__init__()
        self.att = AttentionGate(F_g=in_ch // 2, F_l=skip_ch, F_int=out_ch)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        skip = self.att(x, skip)
        diffY = skip.size()[2] - x.size()[2]
        diffX = skip.size()[3] - x.size()[3]
        x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

class HMSANet_Attention(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.center = ConvBlock(512, 1024)
        self.dec4 = UpBlockWithAttention(1024, 512, 512)
        self.dec3 = UpBlockWithAttention(512, 256, 256)
        self.dec2 = UpBlockWithAttention(256, 128, 128)
        self.dec1 = UpBlockWithAttention(128, 64, 64)
        self.out = nn.Conv2d(64, out_channels, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        center = self.center(self.pool(e4))
        d4 = self.dec4(center, e4)
        d3 = self.dec3(d4, e3)
        d2 = self.dec2(d3, e2)
        d1 = self.dec1(d2, e1)
        out = self.out_act(self.out(d1))
        return out

# --- Training & Validation ---
def train():
    all_files = glob(os.path.join(SLICE_ROOT, "LIDC-IDRI-*", "*.npy"))
    all_files = [f for f in all_files if not f.endswith("_mask.npy")]
    all_files = [f for f in all_files if os.path.exists(f.replace(".npy", "_mask.npy"))]
    print(f"총 학습 대상 슬라이스 수: {len(all_files)}")

    train_files, val_files = train_test_split(all_files, test_size=0.2, random_state=42)
    train_loader = DataLoader(CTLesionDataset(train_files), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(CTLesionDataset(val_files), batch_size=BATCH_SIZE, shuffle=False)

    model = HMSANet_Attention().to(DEVICE)
    criterion = BCEDiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(NUM_EPOCHS):
        model.train()
        epoch_loss = 0
        for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Train"):
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            optimizer.zero_grad()
            preds = model(imgs)
            loss = criterion(preds, masks)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Train Epoch {epoch+1}, Loss: {epoch_loss / len(train_loader):.4f}")

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Val"):
                imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
                preds = model(imgs)
                loss = criterion(preds, masks)
                val_loss += loss.item()
        print(f"Val Epoch {epoch+1}, Loss: {val_loss / len(val_loader):.4f}")

if __name__ == "__main__":
    train()


총 학습 대상 슬라이스 수: 0


ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.