In [1]:
!pip install segmentation-models-pytorch
!pip install torchmetrics

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
import nibabel as nib
import torch
from torch.utils.data import Dataset
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segmentation_models_pytorch import Unet

  check_for_updates()


In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
!cp gdrive/MyDrive/nnUNet_raw.zip ds.zip

In [6]:
!unzip ds.zip

Archive:  ds.zip
   creating: nnUNet_raw/
   creating: nnUNet_raw/Dataset001_BREAST/
   creating: nnUNet_raw/Dataset001_BREAST/.ipynb_checkpoints/
   creating: nnUNet_raw/Dataset001_BREAST/imagesTr/
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1089_0001.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1107_0001.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1199_0002.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1229_0002.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1234_0000.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1227_0002.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1230_0001.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1150_0001.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1159_0000.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST/imagesTr/ISPY1_1077_0002.nii.gz  
  inflating: nnUNet_raw/Dataset001_BREAST

In [None]:
!cd /content/nnUNet_raw/Dataset001_BREAST/imagesTr
!rename 's/_0002\.nii\.gz$/.nii.gz/' cd /content/nnUNet_raw/Dataset001_BREAST/imagesTr/*_0002.nii.gz
!rm /content/nnUNet_raw/Dataset001_BREAST/imagesTr/*0000.nii.gz
!rm /content/nnUNet_raw/Dataset001_BREAST/imagesTr/*0001.nii.gz
!rm /content/nnUNet_raw/Dataset001_BREAST/imagesTr/*0002.nii.gz

In [8]:
import torch
from torch.utils.data import Dataset
import nibabel as nib
import numpy as np
from pathlib import Path
from typing import Callable, List, Tuple
class WindowedVolDataset(Dataset):
    """
    Cada amostra:
        x -> FloatTensor [win, H, W]
        y -> LongTensor  [H, W]
    """
    def __init__(self, images_dir, masks_dir,
                 win=11, preload=False, transforms=None):
        assert win % 2 == 1, "win deve ser ímpar"
        self.half     = win // 2
        self.win      = win
        self.preload  = preload
        self.trf      = transforms

        self.img_paths = sorted(Path(images_dir).glob("*.nii*"))
        self.msk_paths = [Path(masks_dir) / p.name for p in self.img_paths]

        if not self.img_paths:
            raise ValueError("Nenhum arquivo .nii/.nii.gz em images_dir")

        # ---------------- índice global ---------------------------------
        self.index = []                # (case_idx, z)
        self.depths = []               # profundidade de cada volume

        for idx, p in enumerate(self.img_paths):
            depth = self._get_depth(p)
            self.depths.append(depth)
            for z in range(self.half, depth - self.half):
                self.index.append((idx, z))

        if not self.index:
            raise ValueError(
                f"Todos os volumes têm depth <= {self.win}. "
                "Reduza win ou verifique os arquivos."
            )

        # opcional: pré-carrega
        if preload:
            self.buffer = {}
            for idx, (ip, mp) in enumerate(zip(self.img_paths,
                                               self.msk_paths)):
                self.buffer[idx] = {
                    "img": self._load_nii(ip, as_mask=False),
                    "mask": self._load_nii(mp, as_mask=True)
                }

    # -------------------------------------------------------------------
    @staticmethod
    def _load_nii(path: Path, as_mask=False):
        arr = nib.load(path).get_fdata(dtype=np.float32)
        # converte máscara p/ int
        if as_mask:
            arr = arr.astype(np.int16)
        # move depth para axis 0 se necessário
        if arr.shape[0] not in (arr.shape[1], arr.shape[2]):
            # depth já é axis 0 -> (D,H,W)
            return arr
        # senão, depth deve ser axis 2 (H,W,D)
        return np.moveaxis(arr, 2, 0)

    @staticmethod
    def _get_depth(path: Path) -> int:
        shape = nib.load(path).shape
        # depth é o eixo cujo tamanho difere dos outros 2 (H,W)
        if shape[0] != shape[1]:
            return shape[0]          # (D,H,W)
        return shape[2]              # (H,W,D)

    # -------------------------------------------------------------------
    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        case_idx, z = self.index[i]

        vol  = (self.buffer[case_idx]["img"]  if self.preload
            else self._load_nii(self.img_paths[case_idx], False))
        mask = (self.buffer[case_idx]["mask"] if self.preload
            else self._load_nii(self.msk_paths[case_idx], True))

        x = vol[z - self.half : z + self.half + 1]   # (win, H, W)
        y = mask[z]                                  # (H, W)

        # ───────────── normalização ─────────────
        # z-score por volume (ou por janela, como preferir)
        mu  = x.mean()
        std = x.std() + 1e-6        # evita divisão por zero
        x   = (x - mu) / std
        # ----------------------------------------

        if self.trf:                # aug / albumentations
            x, y = self.trf(x, y)

        # tensor e quebra do vínculo NumPy
        x = torch.from_numpy(x).float().clone()   # [win, H, W]
        y = torch.from_numpy(y).long().clone()    # [H, W]
        return x, y


In [9]:
#splitting the dataset for validation

import argparse
import math
import random
from pathlib import Path
import shutil
import sys

def split():
    images_dir = Path("/content/nnUNet_raw/Dataset001_BREAST/imagesTr")
    masks_dir  = Path("/content/nnUNet_raw/Dataset001_BREAST/labelsTr")
    val_img_dir = Path("/content/nnUNet_raw/Dataset001_BREAST/ImagesVl")
    val_msk_dir = Path("/content/nnUNet_raw/Dataset001_BREAST/labelsVl")

    # Verificações básicas
    if not images_dir.is_dir() or not masks_dir.is_dir():
        sys.exit("Erro: images/ ou masks/ não encontrados dentro de dataset_dir.")

    # Junta todos os caminhos de imagem
    img_paths = sorted(images_dir.glob("*.*nii*"))     # .nii ou .nii.gz
    if not img_paths:
        sys.exit("Nenhum arquivo encontrado em images/.")

    # Sorteio reprodutível
    random.seed(42)
    n_val = math.ceil(len(img_paths) * 0.15)
    val_imgs = random.sample(img_paths, n_val)

    # Cria pastas de validação se ainda não existirem
    val_img_dir.mkdir(parents=True, exist_ok=True)
    val_msk_dir.mkdir(parents=True, exist_ok=True)

    # Move/copia cada caso
    moved = 0
    for img_path in val_imgs:
        mask_path = masks_dir / img_path.name
        if not mask_path.exists():
            print(f"🟡 Máscara ausente para {img_path.name}; pulando.")
            continue

        dest_img = val_img_dir / img_path.name
        dest_msk = val_msk_dir / mask_path.name

        op = shutil.move
        op(img_path, dest_img)
        op(mask_path, dest_msk)
        moved += 1

    print(f"''Movidos' {moved} casos "
          f"({moved/len(img_paths):.1%}) para validação.")

split()

''Movidos' 25 casos (15.5%) para validação.


In [10]:
train_ds = WindowedVolDataset(
    images_dir="/content/nnUNet_raw/Dataset001_BREAST/imagesTr",
    masks_dir ="/content/nnUNet_raw/Dataset001_BREAST/labelsTr",
    win=11,
    preload=True,
    transforms=None          # insira normalização/augment aqui
)

validation_ds = WindowedVolDataset(
    images_dir ="/content/nnUNet_raw/Dataset001_BREAST/ImagesVl",
    masks_dir ="/content/nnUNet_raw/Dataset001_BREAST/labelsVl",
    win=11,
    preload=True,
    transforms=None          # insira normalização/augment aqui
)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Unet(
    encoder_name="vgg16",
    in_channels=11,         # ← quantos arquivos de imagem você empilhou
    classes=1,              # 1 canal → máscara binária
    encoder_weights  = "imagenet",
    activation=None         # usaremos BCEWithLogits + Dice
).to(device)

In [None]:
import torch.nn.functional as F

def pad_collate(batch):
    """Pad dinamicamente cada amostra até o maior H e W do lote."""
    xs, ys = zip(*batch)              # listas de tensores
    max_h = 256
    max_w = 256

    def _pad(t, fill):
        dh, dw = max_h - t.shape[-2], max_w - t.shape[-1]
        return F.pad(t, (0, dw, 0, dh), value=fill)

    xs_padded = [_pad(x, 0) for x in xs]   # imagens → zero-pad
    ys_padded = [_pad(y, 0) for y in ys]   # máscara  → label 0

    return torch.stack(xs_padded), torch.stack(ys_padded)

# DataLoader usando o novo collate
loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    collate_fn=pad_collate          # <- aqui!
)

val_loader = torch.utils.data.DataLoader(
    validation_ds,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    collate_fn=pad_collate          # <- aqui!
)

In [13]:
i = 0
for xb, yb in loader:
    print("Batch OK:", xb.shape, yb.shape)  # torch.Size([4,11,256,256]) ...
    break
print(len(loader))

Batch OK: torch.Size([4, 11, 256, 256]) torch.Size([4, 256, 256])
4375


In [16]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchmetrics.classification import BinaryJaccardIndex
import segmentation_models_pytorch as smp
import os

# ─────────────────────────── HYPERPARAMS ───────────────────────────
LR          = 1e-3
N_EPOCHS    = 50
BATCH_SIZE  = 8
THRESH_IoU  = 0.20      # threshold menor p/ classe muito rara
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─────────────────────────── DATALOADERS ───────────────────────────
# loader, val_loader devem estar prontos.  Exemplo:
# loader     = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
#                         num_workers=os.cpu_count(), pin_memory=True,
#                         persistent_workers=True)
# val_loader = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
#                         num_workers=os.cpu_count(), pin_memory=True)

# ─────────────────────────── MODELO ────────────────────────────────
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b4",      # swap for mobilenet_v2 if GPU is tiny
    encoder_weights="imagenet",
    in_channels=11,
    classes=1,
    activation=None).to(DEVICE)


# ─────────────────────────── LOSSES ────────────────────────────────
focal = smp.losses.FocalLoss(
    mode="binary", alpha=0.1, gamma=2.0
)
dice  = smp.losses.DiceLoss(
    mode="binary"
)

preprocess_input = smp.encoders.get_preprocessing_fn(
                        'mobilenet_v2', pretrained='imagenet')

def loss_fn(pred, target):
    return 0.5 * dice(pred, target) + 0.5 * focal(pred, target)

bce  = torch.nn.BCEWithLogitsLoss()
def soft_dice_loss(logits, target, eps=1e-6):   # logits OK!
    probs = logits.sigmoid()
    num = 2 * (probs * target).sum(dim=(2,3))
    den = (probs + target).sum(dim=(2,3)) + eps
    return 1 - (num / den).mean()

def loss_fn(logits, target):
    return 0.6 * soft_dice_loss(logits, target) + 0.4 * bce(
        logits, target.float())


# (Se não atualizou SMP, remova logits=True e chame:
#   probs = pred.sigmoid(); return 0.5*dice(probs,target)+0.5*focal(probs,target) )

# ─────────────────────────── MÉTRICA ───────────────────────────────
metric = BinaryJaccardIndex().to(DEVICE)

# ─────────────────────────── OPT & SCALER ──────────────────────────
optimizer = Adam(model.parameters(), lr=LR)
scaler    = torch.amp.GradScaler('cuda')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# ─────────────────────────── HELPERS ───────────────────────────────
def move_to_device(batch):
    xb, yb = batch
    return xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)

def resize_logits(logits, target_hw):
    if logits.shape[-2:] == target_hw:
        return logits
    return F.interpolate(logits, size=target_hw,
                         mode="bilinear", align_corners=False)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    metric.reset()
    for batch in loader:
        xb, yb = move_to_device(batch)
        with torch.cuda.amp.autocast():
            logits = resize_logits(model(xb), yb.shape[-2:])
        preds = logits.sigmoid() > THRESH_IoU
        metric.update(preds, yb.unsqueeze(1))
    return metric.compute().item()

# ─────────────────────────── LOOP DE TREINO ────────────────────────
max_score = 0.36
for epoch in range(1, N_EPOCHS + 1):
    model.train()
    running_loss = 0.0
    i = 0

    for xb, yb in loader:
        xb, yb = move_to_device((xb, yb))

        with torch.amp.autocast('cuda'):
            logits = resize_logits(model(xb), yb.shape[-2:])
            loss   = loss_fn(logits, yb.unsqueeze(1).float())

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * xb.size(0)
        i += 1
        print("\r" + f"epoch:{epoch}, {i}/{1000}", end="")
        if i > 1000:
          break


    train_loss = running_loss / len(loader.dataset)
    val_iou    = evaluate(val_loader)

    print(f"Epoch {epoch:03d} | Loss {train_loss:.4f} | IoU val {val_iou:.4f}")
    if val_iou > max_score:
      torch.save(model, 'best_model.pth')
      max_score = val_iou



epoch:1, 1001/1000

  with torch.cuda.amp.autocast():


Epoch 001 | Loss 0.1197 | IoU val 0.3067
epoch:2, 1001/1000Epoch 002 | Loss 0.1146 | IoU val 0.3860
epoch:3, 1001/1000Epoch 003 | Loss 0.1119 | IoU val 0.3920
epoch:4, 1001/1000Epoch 004 | Loss 0.1109 | IoU val 0.2879
epoch:5, 1001/1000Epoch 005 | Loss 0.1089 | IoU val 0.1589
epoch:6, 1001/1000Epoch 006 | Loss 0.1100 | IoU val 0.3780
epoch:7, 1001/1000Epoch 007 | Loss 0.1088 | IoU val 0.4110
epoch:8, 1001/1000Epoch 008 | Loss 0.1083 | IoU val 0.3690
epoch:9, 1001/1000Epoch 009 | Loss 0.1083 | IoU val 0.4316
epoch:10, 1001/1000Epoch 010 | Loss 0.1075 | IoU val 0.3148
epoch:11, 1001/1000Epoch 011 | Loss 0.1077 | IoU val 0.3890
epoch:12, 1001/1000Epoch 012 | Loss 0.1065 | IoU val 0.4168
epoch:13, 1001/1000Epoch 013 | Loss 0.1067 | IoU val 0.4371
epoch:14, 1001/1000Epoch 014 | Loss 0.1057 | IoU val 0.4003
epoch:15, 1001/1000Epoch 015 | Loss 0.1057 | IoU val 0.3881
epoch:16, 1001/1000Epoch 016 | Loss 0.1047 | IoU val 0.4419
epoch:17, 1001/1000Epoch 017 | Loss 0.1070 | IoU val 0.3920
epoch:1

In [None]:
!cp /content/best_model.pth /content/gdrive/MyDrive/