In [None]:
!pip install synapseclient



In [None]:
import os, shutil
import pandas as pd
import numpy as np

In [None]:
from google.colab import userdata

# Synapse auth token read through google.colab.userdata; it is handed to syn.login(authToken=...) later on
SYS_TOKEN = userdata.get('SYS_TOKEN')


In [None]:
import synapseclient
import synapseutils

import imageio.v2 as imageio
from pathlib import Path
import nibabel as nib

parent_id = "syn3193805"  # root project whose children are listed with syn.getChildren()
# Names of the child folders that get downloaded (compared against each child's "name")
pastas_desejadas = {
    "averaged-testing-images",
    "averaged-training-images",
    "averaged-training-labels",
}

Load data if needed

In [None]:
# Utils

def to_uint8(img2d, p_low=1, p_high=99):
    """Rescale an image to uint8 using a percentile intensity window.

    The window is the [p_low, p_high] percentile range of the finite values.
    If that window is empty (hi <= lo) the array's min/max is tried instead,
    and an all-zero image is returned when the data is still constant or the
    bounds are not finite.

    Args:
        img2d: array-like image; converted to float32 internally.
        p_low: lower percentile of the window.
        p_high: upper percentile of the window.

    Returns:
        np.ndarray of dtype uint8 with the same shape as the input.
        NaN pixels are written as 0.
    """
    x = np.asarray(img2d, dtype=np.float32)
    finite = np.isfinite(x)
    # Nothing usable (all NaN/Inf): return a black image
    if not finite.any():
        return np.zeros_like(x, dtype=np.uint8)
    lo, hi = np.percentile(x[finite], [p_low, p_high])
    if hi <= lo:
        # Degenerate window (e.g. sparse masks): fall back to the full range
        lo, hi = np.nanmin(x), np.nanmax(x)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            return np.zeros_like(x, dtype=np.uint8)
    x = np.clip((x - lo) / (hi - lo), 0, 1)
    # np.clip leaves NaN untouched and casting NaN to uint8 is undefined,
    # so pin those pixels to 0 before the cast.
    x = np.nan_to_num(x, nan=0.0)
    return (x * 255).astype(np.uint8)

def _move_channel_last(arr):
    """
    Se existir um eixo com tamanho 3 (RGB) que não é o último, move-o para o fim.
    Não mexe se já estiver adequado.
    """
    if arr.ndim >= 3:
        for ax in range(arr.ndim - 1):
            if arr.shape[ax] == 3:
                axes = [i for i in range(arr.ndim) if i != ax] + [ax]
                return np.transpose(arr, axes)
    return arr

def _write_slice(slice2d, path, as_rgb):
    """Normalise one 2-D slice with to_uint8 and write it to `path` (3 identical channels when `as_rgb`)."""
    img8 = to_uint8(slice2d)
    if as_rgb:
        img8 = np.stack([img8, img8, img8], axis=-1)  # H x W x 3
    imageio.imwrite(str(path), img8)


def nii_to_jpgs(input_path, output_dir, rgb=False, ext="jpg"):
    """Export every slice of a NIfTI volume as an 8-bit image file.

    The volume is loaded with nibabel, squeezed, and passed through
    `_move_channel_last`. Files are written to
    ``output_dir/channel_<c>/channel_<c>_slice_<z>.<ext>``:

    * 2-D data  -> a single slice in ``channel_0``.
    * 3-D data  -> the last axis is iterated as slices, all in ``channel_0``.
    * 4-D data  -> treated as (H, W, D, C); if the last axis is larger than 4
      while the third is at most 4, the two are swapped first.

    Args:
        input_path: path to the .nii file.
        output_dir: directory that receives the ``channel_*`` sub-folders (created if missing).
        rgb: replicate each grayscale slice into 3 identical channels. For 4-D
            data this only happens when there are exactly 3 channels.
        ext: file extension handed to imageio (e.g. "jpg", "png").

    Raises:
        ValueError: if the squeezed data is not 2-D, 3-D or 4-D.
    """
    input_path = Path(input_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    data = nib.load(str(input_path)).get_fdata()
    data = np.squeeze(np.asarray(data))
    data = _move_channel_last(data)  # try to guarantee channel-last layout

    if data.ndim == 2:
        # Single slice, single channel. `rgb` is honoured here as well,
        # consistently with the 3-D branch.
        ch_dir = output_dir / "channel_0"
        ch_dir.mkdir(parents=True, exist_ok=True)
        _write_slice(data, ch_dir / f"channel_0_slice_0.{ext}", rgb)
        return

    if data.ndim == 3:
        ch_dir = output_dir / "channel_0"
        ch_dir.mkdir(parents=True, exist_ok=True)
        for z in range(data.shape[-1]):
            _write_slice(data[..., z], ch_dir / f"channel_0_slice_{z}.{ext}", rgb)
        return

    if data.ndim == 4:
        depth, channels = data.shape[2], data.shape[3]

        # A small third axis next to a large last axis suggests (H, W, C, D):
        # swap them so the layout becomes (H, W, D, C).
        if channels > 4 and depth <= 4:
            data = np.moveaxis(data, -2, -1)
            depth, channels = data.shape[2], data.shape[3]

        for c in range(channels):
            ch_dir = output_dir / f"channel_{c}"
            ch_dir.mkdir(parents=True, exist_ok=True)
            for z in range(depth):
                slice2d = np.squeeze(data[..., z, c])
                if slice2d.ndim != 2:
                    # Extra safety: skip anything that is not a plain 2-D slice
                    print(f"[WARN] slice {z} canal {c} shape {slice2d.shape} não é 2D; pulando.")
                    continue
                # Still one file per channel; the gray slice is only replicated
                # when rgb was requested and there are exactly 3 channels.
                _write_slice(slice2d, ch_dir / f"channel_{c}_slice_{z}.{ext}", rgb and channels == 3)
        return

    raise ValueError(f"Dimensão NIfTI não suportada: shape={data.shape}")

In [None]:
import tempfile

def decompress_gz(file_path):
  """Decompress a gzip file next to itself and delete the archive.

  The output path is `file_path` with its last extension removed
  (e.g. "vol.nii.gz" -> "vol.nii").

  Args:
    file_path: path of the gzip-compressed file.

  Returns:
    Path of the decompressed file.

  Raises:
    ValueError: if `file_path` has no extension, so source and destination would be the same file.
    OSError: if the archive cannot be read; any partially written output is removed first.
  """
  import gzip

  out_path = os.path.splitext(file_path)[0]
  if out_path == file_path:
    # Opening the destination for writing would truncate the source itself
    raise ValueError(f"Cannot derive an output name from {file_path!r}")

  try:
    with gzip.open(file_path, 'rb') as f_in, open(out_path, 'wb') as f_out:
      shutil.copyfileobj(f_in, f_out)
  except Exception:
    # Do not leave a truncated file behind; the archive itself is kept
    if os.path.exists(out_path):
      os.remove(out_path)
    raise

  # Only drop the archive once the payload was written completely
  os.remove(file_path)
  return out_path


def norm(s: str) -> str:
    """Canonical form of a name: surrounding whitespace trimmed, lower-cased, underscores turned into hyphens."""
    trimmed = s.strip()
    lowered = trimmed.lower()
    return lowered.replace('_', '-')

def is_labels_dir(path: str) -> bool:
    """Return True when any component of `path`, after `norm`, equals "averaged-training-labels"."""
    for part in Path(path).parts:
        if norm(part) == "averaged-training-labels":
            return True
    return False

def load_from_synapse():
  """Download the wanted Synapse folders and convert every NIfTI volume to image slices.

  Steps:
    1. Log in with SYS_TOKEN and sync every child folder of `parent_id` whose
       name is in `pastas_desejadas` into a temporary directory.
    2. Print the downloaded directory tree.
    3. Convert each .nii / .nii.gz file with `nii_to_jpgs` into
       /content/synapse_data/jpgs_PNGs/<relative dir>/<file name without suffix>/,
       using PNG for files below an "averaged-training-labels" directory and
       JPG for everything else.

  A failure on one file is reported and the loop moves on to the next file
  (best effort); a summary line is printed at the end. The temporary download
  directory is deleted when the function returns.
  """
  with tempfile.TemporaryDirectory() as tmpdir:
    syn = synapseclient.Synapse()
    syn.login(authToken=SYS_TOKEN)

    # 1) Download everything below the wanted folders
    for ch in syn.getChildren(parent_id):
        if ch["type"] == "org.sagebionetworks.repo.model.Folder" and ch["name"] in pastas_desejadas:
            file_path = os.path.join(tmpdir, ch["name"])
            os.makedirs(file_path, exist_ok=True)
            synapseutils.syncFromSynapse(
                syn, ch["id"], path=file_path,
                ifcollision="overwrite.local", followLink=True
            )

    # 2) Show what was downloaded
    print("Estrutura baixada:")
    for r, _dirs, _files in os.walk(tmpdir):
        print("   ", os.path.relpath(r, tmpdir))

    # 3) Convert the NIfTI volumes
    total_encontrados = total_processados = total_escritos = 0

    for root, _dirs, files in os.walk(tmpdir):
      # Computed outside the try so the error message below can always name the file
      relative_path = os.path.relpath(root, tmpdir)
      for file in files:
        if not file.lower().endswith((".nii", ".nii.gz")):
            continue

        total_encontrados += 1
        input_path = Path(root) / file
        try:
            # Output directory mirrors the relative download structure
            out_dir = Path("/content/synapse_data/jpgs_PNGs") / relative_path / file.replace(".nii.gz","").replace(".nii","")

            # Decompress first when needed; the .gz file is removed by decompress_gz
            if file.lower().endswith(".nii.gz"):
                print(f"[INFO] Decompressing: {input_path}")
                input_path = Path(decompress_gz(str(input_path)))

            # PNG for label volumes, JPG for the rest
            salvar_png = is_labels_dir(root)
            ext = "png" if salvar_png else "jpg"
            print(f"[INFO] Converting -> {ext.upper()}: {os.path.join(relative_path, file)}  out_dir={out_dir}")

            # Count files before/after to report how many were written
            antes = len(list(out_dir.glob(f"**/*.{ext}"))) if out_dir.exists() else 0
            nii_to_jpgs(input_path, out_dir, rgb=False, ext=ext)
            depois = len(list(out_dir.glob(f"**/*.{ext}")))
            escritos = max(0, depois - antes)
            total_escritos += escritos
            total_processados += 1
            print(f"[OK] Escrevidos {escritos} arquivo(s) em {out_dir}")

        except Exception as e:
            # Deliberate best effort: report and continue with the next file
            print(f"[ERR] Falha em {os.path.join(relative_path, file)}: {type(e).__name__}: {e}")

    print(f"[RESUMO] encontrados={total_encontrados} processados={total_processados} escritos={total_escritos}")

def apagar():
  """Delete the converted dataset directory /content/synapse_data.

  Does nothing when the directory is already gone, so the cell can be
  re-run safely.
  """
  import shutil

  cache_dir = "/content/synapse_data"
  if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)

# Download and convert the dataset only when /content/synapse_data does not exist yet
if not Path("/content/synapse_data").exists():
  load_from_synapse()

Dataset loader

In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.transforms.functional import resize

class SynapseDataset(Dataset):
  """In-memory dataset of (image, binary mask) slice pairs.

  Every image/label file is decoded once in ``__init__`` and kept in memory.
  ``__getitem__`` resizes both tensors to ``size`` and returns ``(img, label)``:
  a single-channel image is repeated to 3 channels, and the label is a
  float32 tensor holding only 0.0 / 1.0.

  Args:
    img_dirs: iterable of directories containing the image slices.
    labels_dirs: iterable of directories containing the label slices, in the
      same order as ``img_dirs``. Files are paired by sorted file name.
    size: (height, width) every sample is resized to.
  """
  # Samples must come out as (3, 224, 224): per the original author's notes the
  # ResNet branch expects 3 channels and the Swin branch expects 224x224 inputs.
  def __init__(self, img_dirs, labels_dirs, size=(224, 224)):
      self.items = []
      self.size = size

      # Pair image and label files by sorted name, directory by directory
      for img_dir, lbs_dir in zip(img_dirs, labels_dirs):
        for im, lb in zip(sorted(Path(img_dir).iterdir()), sorted(Path(lbs_dir).iterdir())):
            img = read_image(str(im)).float() / 255.0
            lab = read_image(str(lb)).float() / 255.0
            self.items.append((img, lab))

  def __len__(self):
      return len(self.items)

  def _norm(self, img, c3=True, nearest=False):
    """Resize `img` to ``self.size`` and optionally expand one channel to three.

    Args:
      img: tensor shaped [C, H, W] or [H, W].
      c3: repeat a single-channel tensor to 3 channels.
      nearest: use nearest-neighbour resizing (for masks) instead of
        antialiased bilinear resizing.
    """
    # Local import: only this method needs the enum
    from torchvision.transforms import InterpolationMode

    # Add the channel axis BEFORE resizing so resize always sees [C, H, W]
    if img.ndim == 2:
        img = img.unsqueeze(0)              # [H,W] -> [1,H,W]
    if nearest:
        img = resize(img, self.size, interpolation=InterpolationMode.NEAREST)
    else:
        img = resize(img, self.size, antialias=True)
    if img.shape[0] == 1 and c3:
        img = img.repeat(3, 1, 1)           # force 3 channels
    return img

  def __getitem__(self, idx):
      img, lab = self.items[idx]
      img = self._norm(img)
      # Masks stay single-channel and are resized with nearest-neighbour:
      # bilinear blending followed by "> 0" would grow every mask border.
      label = self._norm(lab, c3=False, nearest=True)
      label = (label > 0).float()           # binary 0/1 mask
      return img, label

training_data = Path("/content/synapse_data/jpgs_PNGs/averaged-training-images")
training_labels = Path("/content/synapse_data/jpgs_PNGs/averaged-training-labels")

def load_data_and_labels_paths():
  """Collect the per-volume slice directories for images and labels.

  For every entry ``<case>`` under `training_data` that contains a
  ``channel_0`` directory, pairs ``training_data/<case>/channel_0`` with
  ``training_labels/<case>_seg/channel_0``. Entries without a ``channel_0``
  directory (stray files, for instance) are skipped.

  Returns:
    (data, labels): two lists of Path objects of equal length, ordered by case name.
  """
  data = []
  labels = []
  # sorted(): os.listdir order is arbitrary, which would make the pairing
  # order (and any seeded split built on it) differ between runs.
  for case in sorted(os.listdir(training_data)):
    img_dir = training_data / case / "channel_0"
    if not img_dir.is_dir():
      continue
    data.append(img_dir)
    labels.append(training_labels / f"{case}_seg" / "channel_0")

  return data, labels

# Example:
#   data, labels = load_data_and_labels_paths()
#   dataset = SynapseDataset(data, labels)




'\nExemplo:\ndata, labels = load_data_and_labels_paths()\ndataset = SynapseDataset(data, labels)\n'

MODELOS

In [None]:

import cv2
from glob import glob
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


# Training configuration
BATCH       = 8      # batch size given to both DataLoaders
EPOCHS      = 50     # number of training epochs
LR          = 1e-3   # learning rate of the Adam optimizer
ALPHA       = 0.6    # Dice weight α of the mixed loss: (1-α)*BCE + α*Dice


class ConvBNReLU(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        c_in: input channels.
        c_out: output channels.
        k: kernel size.
        s: stride.
        p: padding.
    """

    def __init__(self, c_in, c_out, k=3, s=1, p=1):
        super().__init__()
        conv = nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, bias=False)
        bn = nn.BatchNorm2d(c_out)
        act = nn.ReLU(inplace=True)
        self.net = nn.Sequential(conv, bn, act)

    def forward(self, x):
        return self.net(x)

class DoubleConv(nn.Module):
    """Two ConvBNReLU blocks in a row: c_in -> c_out -> c_out."""

    def __init__(self, c_in, c_out):
        super().__init__()
        first = ConvBNReLU(c_in, c_out)
        second = ConvBNReLU(c_out, c_out)
        self.net = nn.Sequential(first, second)

    def forward(self, x):
        return self.net(x)


# FCM (ResNet + Swin)

class SE(nn.Module):
    """Channel attention (squeeze-and-excitation style).

    Global-average-pools the input to one value per channel, passes it through
    a two-layer 1x1-conv bottleneck (ReLU, then sigmoid) and rescales the
    input channel-wise with the resulting weights.

    Args:
        c: number of input (and output) channels.
        r: reduction ratio; the bottleneck has ``max(1, c // r)`` channels.
    """
    def __init__(self, c, r=16):
        super().__init__()
        # Guard against c < r, where c // r would give a 0-channel bottleneck
        hidden = max(1, c // r)
        self.fc1 = nn.Conv2d(c, hidden, 1)
        self.fc2 = nn.Conv2d(hidden, c, 1)

    def forward(self, x):
        w = F.adaptive_avg_pool2d(x, 1)          # (B, C, 1, 1)
        w = F.relu(self.fc1(w), inplace=True)
        w = torch.sigmoid(self.fc2(w))           # per-channel weights in (0, 1)
        return x * w

class FCM(nn.Module):
    """
    Fuses two feature maps of identical shape.

    f: ResNet feature (B,C,H,W)
    g: Swin feature   (B,C,H,W)
    output: (B,C,H,W)
    """
    def __init__(self, C):
        super().__init__()
        self.cab = SE(C)
        self.mix = ConvBNReLU(3*C, C, k=1, s=1, p=0)  # concatenation -> 1x1 conv

    def forward(self, f, g):
        # Global (per-channel) context of each branch
        f_ctx = F.adaptive_avg_pool2d(f, 1)
        g_ctx = F.adaptive_avg_pool2d(g, 1)

        branches = [
            f + g_ctx,            # f conditioned on g's context
            f * g,                # element-wise correlation
            self.cab(g + f_ctx),  # channel attention over g conditioned on f's context
        ]
        return self.mix(torch.cat(branches, dim=1))

def nhwc_to_nchw(feat):
    """Permute a 4-D tensor from (N, H, W, C) to a contiguous (N, C, H, W); other ranks pass through unchanged."""
    if feat.ndim != 4:
        return feat
    return feat.permute(0, 3, 1, 2).contiguous()

class DualEncoder(nn.Module):
    """Runs a ResNet34 and a Swin-Small backbone (timm, pretrained) on the same input.

    Three ResNet stages and the first three Swin stages are projected with 1x1
    convolutions to ``c_embed * 2**i`` channels (i = 0, 1, 2); the fourth Swin
    stage is returned unprojected. Channel counts are read from timm's
    ``feature_info`` at construction time.

    NOTE(review): the original notes give strides of roughly H/4, H/8, H/16 for
    the CNN and H/4 ... H/32 for Swin, and say Swin returns NHWC tensors.
    Verify against the installed timm version.
    """
    def __init__(self, c_embed=48):
        super().__init__()

        # out_indices selects which backbone stages are returned as features
        self.cnn = timm.create_model(
            'resnet34', pretrained=True, features_only=True, out_indices=(1, 2, 3),
        )
        self.swin = timm.create_model(
            'swin_small_patch4_window7_224', pretrained=True, features_only=True, out_indices=(0, 1, 2, 3),
        )

        cnn_channels = self.cnn.feature_info.channels()
        swin_channels = self.swin.feature_info.channels()

        def _projections(channels):
            # One 1x1 conv per level, doubling the width at every level
            return nn.ModuleList(
                [nn.Conv2d(c, c_embed * (2 ** level), 1) for level, c in enumerate(channels)]
            )

        self.proj_c = _projections(cnn_channels)
        self.proj_s = _projections(swin_channels[:3])
        self.c4_out = swin_channels[3]  # channel count of the last Swin stage

    def forward(self, x):
        f1, f2, f3 = self.cnn(x)
        g1, g2, g3, g4 = self.swin(x)

        # Project channels; Swin features are permuted to NCHW first
        cnn_proj = tuple(proj(feat) for proj, feat in zip(self.proj_c, (f1, f2, f3)))
        swin_proj = tuple(
            proj(nhwc_to_nchw(feat)) for proj, feat in zip(self.proj_s, (g1, g2, g3))
        )
        return cnn_proj, swin_proj + (nhwc_to_nchw(g4),), self.c4_out

# Decoder
class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, concatenation with a skip tensor, then DoubleConv.

    Args:
        c_in: channels of the incoming tensor.
        c_skip: channels of the skip tensor.
        c_out: channels produced by the stage.
    """
    def __init__(self, c_in, c_skip, c_out):
        super().__init__()
        self.up = nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
        self.conv = DoubleConv(c_out + c_skip, c_out)

    def forward(self, x, skip):
        x = self.up(x)
        target_hw = skip.shape[-2:]
        # Snap to the skip's spatial size when the 2x upsampling does not land on it
        if x.shape[-2:] != target_hw:
            x = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=False)
        return self.conv(torch.cat([x, skip], dim=1))

class CTC(nn.Module):
    """
    ResNet34 + Swin dual encoder, FCM fusion at three levels, light decoder.

    forward(x) returns single-channel logits. The decoder output has the
    spatial size of the level-1 fused features and is then upsampled 4x.

    NOTE(review): the 4x upsampling restores the input resolution only if the
    level-1 encoder features are at 1/4 of the input size. Confirm against the
    backbones in use.

    Args:
        c_embed: base channel width; levels 1/2/3 use 1x/2x/4x and the
            bottleneck uses 8x this value.
    """
    def __init__(self, c_embed=48):
        super().__init__()
        self.d_enc = DualEncoder(c_embed=c_embed)

        # One FCM per fusion level
        self.fcm1 = FCM(c_embed*1)
        self.fcm2 = FCM(c_embed*2)
        self.fcm3 = FCM(c_embed*4)

        # Bottleneck: project the last Swin stage to 8*c_embed channels.
        # The channel count comes from the encoder that is already built,
        # so no second Swin model has to be instantiated just to read it.
        self.reduce4 = nn.Conv2d(self.d_enc.c4_out, c_embed*8, 1)

        # Decoder, using m3, m2, m1 as skips (each stage ends at its skip's resolution)
        self.up3 = Up(c_embed*8, c_embed*4, c_embed*4)
        self.up2 = Up(c_embed*4, c_embed*2, c_embed*2)
        self.up1 = Up(c_embed*2, c_embed*1, c_embed*1)

        self.head = nn.Conv2d(c_embed*1, 1, 1)

    def forward(self, x):
        (f1, f2, f3), (g1, g2, g3, g4), _ = self.d_enc(x)

        # Align the Swin features to the CNN features when their sizes differ
        if f1.shape[-2:] != g1.shape[-2:]:
            g1 = F.interpolate(g1, size=f1.shape[-2:], mode='bilinear', align_corners=False)
        if f2.shape[-2:] != g2.shape[-2:]:
            g2 = F.interpolate(g2, size=f2.shape[-2:], mode='bilinear', align_corners=False)
        if f3.shape[-2:] != g3.shape[-2:]:
            g3 = F.interpolate(g3, size=f3.shape[-2:], mode='bilinear', align_corners=False)

        # Complementary fusion
        m1 = self.fcm1(f1, g1)
        m2 = self.fcm2(f2, g2)
        m3 = self.fcm3(f3, g3)

        # Bottleneck
        x = self.reduce4(g4)

        # Decoding: every Up stage leaves x at the size of its skip tensor
        x = self.up3(x, m3)
        x = self.up2(x, m2)
        x = self.up1(x, m1)

        # x now has m1's spatial size; upsample 4x before the 1x1 head
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)

        return self.head(x)

# =========================
# Loss & Métrica
# =========================
def dice_coef_binary(logits, target, eps=1e-6):
    """Hard Dice coefficient of thresholded predictions against a binary mask.

    The score is computed over the whole batch at once (not per sample).

    Args:
        logits: raw model outputs, shaped Bx1xHxW.
        target: 0/1 mask, shaped Bx1xHxW or BxHxW. A missing channel axis is
            added so it cannot broadcast against the batch axis of `logits`.
        eps: smoothing term that keeps the ratio defined for empty masks.

    Returns:
        0-dim tensor holding the Dice score. It carries no gradient because of
        the 0.5 threshold, so it is a metric, not a loss.
    """
    pred = (torch.sigmoid(logits) > 0.5).float()
    t = target.float()
    if t.ndim == pred.ndim - 1:
        # BxHxW -> Bx1xHxW; without this, pred * t would broadcast to BxBxHxW
        t = t.unsqueeze(1)
    inter = (pred * t).sum()
    denom = pred.sum() + t.sum()
    return (2*inter + eps) / (denom + eps)

def remove_3_chanel(x):
  """Collapse a 3-channel mask batch to one binary channel.

  A 4-D tensor with exactly 3 channels ([B,3,H,W]) becomes [B,1,H,W], where a
  pixel is 1.0 if any of its channels is > 0. Any other input is returned
  unchanged.
  """
  if x.ndim == 4 and x.shape[1] == 3:
    # Mask arrived as RGB: any channel > 0 counts as foreground
    x = (x > 0).any(dim=1, keepdim=True).float()  # [B,1,H,W]
  return x

def mixed_loss_binary(logits, target, alpha=0.6, eps=1e-6):
  """Weighted sum of BCE-with-logits and soft Dice loss: (1 - alpha) * BCE + alpha * Dice.

  The Dice term is computed on sigmoid probabilities (no thresholding), so it
  is differentiable and actually contributes gradient during training.

  Args:
    logits: raw model outputs, shaped Bx1xHxW.
    target: 0/1 mask, shaped Bx1xHxW or BxHxW (a channel axis is added).
    alpha: weight of the Dice term, in [0, 1].
    eps: smoothing term of the Dice ratio.

  Returns:
    0-dim loss tensor.
  """
  target = target.float()
  if target.ndim == logits.ndim - 1:
    target = target.unsqueeze(1)

  bce = F.binary_cross_entropy_with_logits(logits, target)

  # Soft Dice over the whole batch
  prob = torch.sigmoid(logits)
  inter = (prob * target).sum()
  denom = prob.sum() + target.sum()
  dice = 1 - (2 * inter + eps) / (denom + eps)

  return (1 - alpha) * bce + alpha * dice

In [None]:
def splitData(data, labels, train_frac=0.8, seed=None):
  """Shuffle paired lists and split them into train and validation parts.

  Args:
    data: list of samples (here: image directories).
    labels: list of matching labels, same length as `data`.
    train_frac: fraction of items that goes to the training part.
    seed: optional seed for a reproducible shuffle. When None, the global
      NumPy RNG is used (so ``np.random.seed`` is still honoured).

  Returns:
    (train, val), each a tuple ``(data_part, labels_part)`` of lists.

  Raises:
    ValueError: if `data` and `labels` differ in length.
  """
  if len(data) != len(labels):
    raise ValueError(f"data and labels differ in length: {len(data)} != {len(labels)}")

  size = len(labels)
  if seed is None:
    idx = np.random.permutation(size)
  else:
    idx = np.random.default_rng(seed).permutation(size)

  # Apply the same permutation to both lists so pairs stay aligned
  data = [data[i] for i in idx]
  labels = [labels[i] for i in idx]

  split = int(size * train_frac)

  train = data[:split], labels[:split]
  val   = data[split:], labels[split:]

  return train, val

# Device the model is trained on
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


data, labels = load_data_and_labels_paths()

# Split the per-volume directories into train / validation parts
train, val = splitData(data, labels)
train_ds = SynapseDataset(train[0], train[1])
val_ds = SynapseDataset(val[0], val[1])

# Pinned host memory only speeds up host->GPU copies, so enable it on CUDA only
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=4, pin_memory=(DEVICE == "cuda"))
val_dl   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=4, pin_memory=(DEVICE == "cuda"))

model = CTC(c_embed=48).to(DEVICE)
opt   = torch.optim.Adam(model.parameters(), lr=LR)

best = 0.0  # best validation Dice seen so far
for epoch in range(1, EPOCHS+1):
    # ---- training
    model.train()
    losses=[]
    # batches of images and masks
    for img,mask in train_dl:
        img,mask = img.to(DEVICE), mask.to(DEVICE)
        opt.zero_grad()
        logits = model(img)
        # Use the configured ALPHA instead of the function's default
        loss = mixed_loss_binary(logits, mask, alpha=ALPHA)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    tr_loss = float(np.mean(losses)) if losses else 0.0

    # ---- validation (mean of per-batch Dice)
    model.eval()
    dices=[]
    with torch.no_grad():
        for img,mask in val_dl:
            img,mask = img.to(DEVICE), mask.to(DEVICE)
            logits = model(img)
            dices.append(dice_coef_binary(logits,mask).item())
    val_dice = float(np.mean(dices)) if dices else 0.0

    print(f"[{epoch:03d}] train_loss={tr_loss:.4f}  val_dice={val_dice:.4f}")

    # Checkpoint whenever the validation Dice improves
    if val_dice > best:
        best = val_dice
        torch.save(model.state_dict(), "ctc_tiny_best.pth")
        print(f"  ↑ salvo: ctc_tiny_best.pth (Dice {best:.4f})")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.




torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 56, 56])
torch.Size([8, 48, 5

KeyboardInterrupt: 