In [None]:
!pip install monai
!pip install nibabel

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting monai
  Downloading monai-1.1.0-202212191849-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: monai
Successfully installed monai-1.1.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Mount Google Drive so the dataset folders and helper modules stored under
# MyDrive are visible to this Colab runtime.
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# Replace these paths with the locations of your own image and label folders.
# NOTE: the next cell overwrites both folders with the values from Config, and
# main() reads its hyperparameters from Config rather than from this cell.
image_folder = "/content/drive/MyDrive/imagenes_medicas/imagesTr"
label_folder = "/content/drive/MyDrive/imagenes_medicas/labelsTr"

# Training hyperparameters.
learning_rate, batch_size = 1e-4, 4
num_workers = 4
num_epochs, val_interval = 100, 1


In [None]:
import os
import torch
from monai.transforms import (
    Compose, LoadImaged, AddChanneld, ScaleIntensityRanged,
    CropForegroundd, RandCropByPosNegLabeld, Spacingd,
    Orientationd, ToTensord, EnsureChannelFirstd, SpatialPadd
)
from monai.data import Dataset, DataLoader
import sys
sys.path.append('/content/drive/MyDrive/monao')
from config import Config

config = Config()

# Folder locations come from the shared Config object and replace any paths
# assigned earlier in the notebook.
image_folder, label_folder = config.image_folder, config.label_folder


def _list_visible_files(folder):
    """Return the sorted full paths of entries in `folder` whose names do not start with '.'."""
    return sorted(os.path.join(folder, f) for f in os.listdir(folder) if not f.startswith('.'))


def get_data_files(image_folder, label_folder):
    """Pair image and label files into MONAI-style sample dicts.

    Hidden entries (names starting with '.', which includes macOS "._*" files)
    are ignored. Images and labels are paired purely by sorted path order, so both
    folders must hold the same number of files whose names sort into matching order.

    Args:
        image_folder: directory containing the image volumes.
        label_folder: directory containing the label volumes.

    Returns:
        List of {"image": path, "label": path} dicts.

    Raises:
        AssertionError: if the folders contain different numbers of visible files.
    """
    image_files = _list_visible_files(image_folder)
    label_files = _list_visible_files(label_folder)

    # Diagnostic output, printed once (previously duplicated on a count mismatch).
    print("Number of image files:", len(image_files))
    print("Number of label files:", len(label_files))
    print("Image files:")
    print(image_files)
    print("Label files:")
    print(label_files)

    assert len(image_files) == len(label_files), (
        f"Found {len(image_files)} image files in {image_folder!r} but "
        f"{len(label_files)} label files in {label_folder!r}"
    )
    return [{"image": img, "label": lbl} for img, lbl in zip(image_files, label_files)]

def create_transforms(spatial_size=(96, 96, 96), pixdim=(1.5, 1.5, 1.5)):
    """Build the MONAI load/preprocess/augment pipeline for image/label pairs.

    Orientation and resampling are applied to the whole volume *before* any
    cropping. Every random patch therefore has exactly `spatial_size` voxels at
    `pixdim` spacing. Previously the crop ran first and the resampling afterwards,
    so patch shape depended on each volume's native spacing.

    Args:
        spatial_size: voxel shape of each random training patch.
        pixdim: target voxel spacing used by Spacingd.

    Returns:
        A Compose transform. RandCropByPosNegLabeld yields num_samples=4 patch
        dicts per input sample.
    """
    keys = ["image", "label"]
    return Compose([
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        # Spatial normalization on the full volume first.
        Orientationd(keys=keys, axcodes="RAS"),
        # Interpolate images smoothly; use nearest for labels so class values stay discrete.
        Spacingd(keys=keys, pixdim=pixdim, mode=("bilinear", "nearest")),
        # NOTE(review): -1024..3071 is a CT Hounsfield window. If these volumes are
        # MRI (the file names suggest BRATS), a per-volume intensity normalization
        # may be more appropriate, and CropForegroundd (threshold > 0 on the scaled
        # image) may not crop anything — confirm against the data.
        ScaleIntensityRanged(keys="image", a_min=-1024, a_max=3071, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=keys, source_key="image"),
        # Make sure the volume is at least `spatial_size` in every dimension so the
        # random crop below can always produce full-size patches.
        SpatialPadd(keys=keys, spatial_size=spatial_size),
        RandCropByPosNegLabeld(
            keys=keys, label_key="label", spatial_size=spatial_size, pos=1, neg=1,
            num_samples=4, image_key="image", image_threshold=0,
        ),
        ToTensord(keys=keys),
    ])

def create_data_loaders(data, transforms, batch_size=4, num_workers=4, split_ratio=0.8,
                        val_transforms=None, seed=None):
    """Randomly split `data` into train/validation sets and wrap each in a DataLoader.

    Args:
        data: list of sample dicts (e.g. the output of get_data_files).
        transforms: transform applied to training samples.
        batch_size: batch size for both loaders.
        num_workers: worker processes for both loaders.
        split_ratio: fraction of samples used for training (truncated toward zero).
        val_transforms: transform for validation samples. Defaults to `transforms`;
            pass a non-random pipeline to get stable validation scores.
        seed: optional seed for the split so it is reproducible across runs.
            None uses torch's global RNG, as before.

    Returns:
        (train_loader, val_loader). Only the training loader shuffles.

    Raises:
        ValueError: if split_ratio is outside (0, 1] or leaves no training samples.
    """
    if not 0.0 < split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in (0, 1], got {split_ratio}")
    train_size = int(len(data) * split_ratio)
    val_size = len(data) - train_size
    if train_size == 0:
        raise ValueError(
            f"split_ratio={split_ratio} leaves no training samples out of {len(data)}"
        )

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    # Split indices, not a shared Dataset, so each subset can get its own transform.
    train_subset, val_subset = torch.utils.data.random_split(
        range(len(data)), [train_size, val_size], **split_kwargs
    )
    train_data = [data[i] for i in train_subset.indices]
    val_data = [data[i] for i in val_subset.indices]

    train_ds = Dataset(data=train_data, transform=transforms)
    val_ds = Dataset(data=val_data, transform=val_transforms if val_transforms is not None else transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Validation order has no effect on the metric, so there is no reason to shuffle.
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader


In [None]:
import torch
import torch.nn as nn
from monai.networks.nets import UNet
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.data import decollate_batch
from monai.transforms import EnsureType
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch._tensor")
warnings.filterwarnings("ignore", category=UserWarning, module="monai.data")

def create_model(device, in_channels=4, out_channels=1):
    """Build a 3D MONAI UNet and move it to `device`.

    Args:
        device: torch device (or device string) the model is moved to.
        in_channels: number of input channels per volume. The default of 4 is the
            value this notebook was using. Presumably it is the number of image
            channels (e.g. MRI sequences) in each volume — confirm against the data.
        out_channels: number of output channels. 1 gives a single foreground logit
            map, matching the sigmoid-based Dice loss used with this model.

    Returns:
        The UNet instance on `device`.
    """
    model = UNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=out_channels,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    return model

def create_loss_function():
    """Return a Dice loss that applies a sigmoid to the raw network outputs."""
    loss = DiceLoss(sigmoid=True)
    return loss

def create_optimizer(model, learning_rate=1e-4):
    """Return an Adam optimizer over every parameter of `model`."""
    params = model.parameters()
    return Adam(params, lr=learning_rate)
def create_dice_metric(threshold=0.5, spatial_dims=3, verbose=False):
    """Build a callable that scores binary segmentations with MONAI's DiceMetric.

    The returned function accepts predictions either as a tensor or as a list of
    per-sample tensors (e.g. from decollate_batch). It binarizes predictions and
    labels at `threshold` and returns DiceMetric's per-call result.

    Both inputs are normalized to (batch, channel, *spatial) before scoring.
    DiceMetric reads dim 1 as the channel axis, so the previous behaviour of
    squeezing that axis away made it score a spatial axis as if it were channels.

    Args:
        threshold: values strictly greater than this count as foreground.
        spatial_dims: number of spatial axes (3 for volumes).
        verbose: if True, print the shapes being compared (never the full tensors).

    Returns:
        wrapped_dice_metric(y_pred, y) -> DiceMetric result for this call.
    """
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    target_ndim = spatial_dims + 2

    def _to_batch_channel_first(t):
        # Collapse redundant singleton axes right after the batch axis, then make
        # sure exactly one channel axis is present.
        while t.ndim > target_ndim and t.shape[1] == 1:
            t = t.squeeze(1)
        if t.ndim == target_ndim - 1:
            t = t.unsqueeze(1)
        return t

    def wrapped_dice_metric(y_pred, y):
        if isinstance(y_pred, (list, tuple)):
            y_pred = torch.stack(list(y_pred))
        y_pred = (_to_batch_channel_first(y_pred) > threshold).float()
        y = (_to_batch_channel_first(y) > threshold).float()
        if y_pred.shape != y.shape:
            raise ValueError(
                f"Prediction shape {tuple(y_pred.shape)} does not match label shape {tuple(y.shape)}"
            )
        if verbose:
            print("y_pred shape:", y_pred.shape)
            print("y shape:", y.shape)
        value = dice_metric(y_pred=y_pred, y=y)
        # DiceMetric also accumulates each call in an internal buffer. Only the
        # per-call value is used here, so clear the buffer to keep it from growing.
        dice_metric.reset()
        return value

    return wrapped_dice_metric

def _train_one_epoch(model, loss_function, optimizer, train_loader, device):
    """Run one optimization pass over `train_loader`; return the mean batch loss (NaN if empty)."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def _validate(model, dice_metric, val_loader, device, threshold=0.5):
    """Return the mean Dice over `val_loader`, ignoring NaN per-sample scores.

    Each prediction is the thresholded sigmoid of the model output. Predictions are
    passed to `dice_metric` as a list of per-sample masks with the channel axis kept.
    Returns NaN when no finite score is produced (e.g. an empty loader).
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for val_batch in val_loader:
            val_inputs, val_labels = val_batch["image"].to(device), val_batch["label"].to(device)
            val_outputs = model(val_inputs)
            val_preds = [(pred.sigmoid() > threshold).float() for pred in val_outputs]
            value = dice_metric(y_pred=val_preds, y=val_labels)
            scores.append(torch.as_tensor(value).detach().flatten().float().cpu())
    if not scores:
        return float("nan")
    # MONAI's DiceMetric can report NaN for samples whose ground truth is empty
    # (e.g. negative random crops). A plain mean would turn the whole epoch into NaN.
    return torch.nanmean(torch.cat(scores)).item()


def train_and_evaluate_model(model, loss_function, optimizer, dice_metric, train_loader, val_loader, device,
                             num_epochs=100, val_interval=1, save_path="best_metric_model_segmentation.pth",
                             threshold=0.5):
    """Train `model`, validate every `val_interval` epochs and checkpoint the best Dice.

    Args:
        model: network whose outputs are logits (a sigmoid is applied for validation).
        loss_function: callable(outputs, labels) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.
        dice_metric: callable(y_pred=list_of_masks, y=labels) -> per-sample scores.
        train_loader, val_loader: iterables of {"image", "label"} tensor batches.
        device: device inputs are moved to.
        num_epochs: number of training epochs.
        val_interval: validate after every `val_interval`-th epoch.
        save_path: where the best model's state_dict is written.
        threshold: probability threshold used to binarize predictions.

    Returns:
        (best_metric, best_metric_epoch). Both are -1 if no validation round
        produced a finite score.
    """
    best_metric = -1.0
    best_metric_epoch = -1
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, loss_function, optimizer, train_loader, device)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        metric = _validate(model, dice_metric, val_loader, device, threshold)
        # A NaN metric compares False here, so it never replaces the best checkpoint.
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), save_path)
            print("Saved best metric model")

        print(f"Validation Dice Metric: {metric:.4f}")

    print(f"Best validation Dice Metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    return best_metric, best_metric_epoch


In [None]:
def _train_one_epoch(model, loss_function, optimizer, train_loader, device):
    """Run one optimization pass over `train_loader`; return the mean batch loss (NaN if empty)."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def _validate(model, dice_metric, val_loader, device, threshold=0.5):
    """Return the mean Dice over `val_loader`, ignoring NaN per-sample scores.

    Each prediction is the thresholded sigmoid of the model output. Predictions are
    passed to `dice_metric` as a list of per-sample masks with the channel axis kept,
    and labels are passed unchanged. Returns NaN when no finite score is produced
    (e.g. an empty loader).
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for val_batch in val_loader:
            val_inputs, val_labels = val_batch["image"].to(device), val_batch["label"].to(device)
            val_outputs = model(val_inputs)
            val_preds = [(pred.sigmoid() > threshold).float() for pred in val_outputs]
            value = dice_metric(y_pred=val_preds, y=val_labels)
            scores.append(torch.as_tensor(value).detach().flatten().float().cpu())
    if not scores:
        return float("nan")
    # MONAI's DiceMetric can report NaN for samples whose ground truth is empty
    # (e.g. negative random crops). A plain mean would turn the whole epoch into NaN.
    return torch.nanmean(torch.cat(scores)).item()


def train_and_evaluate_model(model, loss_function, optimizer, dice_metric, train_loader, val_loader, device,
                             num_epochs=100, val_interval=1, save_path="best_metric_model_segmentation.pth",
                             threshold=0.5):
    """Train `model`, validate every `val_interval` epochs and checkpoint the best Dice.

    Args:
        model: network whose outputs are logits (a sigmoid is applied for validation).
        loss_function: callable(outputs, labels) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.
        dice_metric: callable(y_pred=list_of_masks, y=labels) -> per-sample scores.
        train_loader, val_loader: iterables of {"image", "label"} tensor batches.
        device: device inputs are moved to.
        num_epochs: number of training epochs.
        val_interval: validate after every `val_interval`-th epoch.
        save_path: where the best model's state_dict is written.
        threshold: probability threshold used to binarize predictions.

    Returns:
        (best_metric, best_metric_epoch). Both are -1 if no validation round
        produced a finite score.
    """
    best_metric = -1.0
    best_metric_epoch = -1
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, loss_function, optimizer, train_loader, device)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        metric = _validate(model, dice_metric, val_loader, device, threshold)
        # A NaN metric compares False here, so it never replaces the best checkpoint.
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), save_path)
            print("Saved best metric model")

        print(f"Validation Dice Metric: {metric:.4f}")

    print(f"Best validation Dice Metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    return best_metric, best_metric_epoch


In [None]:
import os

import nibabel as nib
import numpy as np
from monai.transforms import Resize

# Directory whose top-level .nii.gz files are resized. os.listdir is not
# recursive, so files inside sub-folders (e.g. imagesTr/labelsTr) are not visited.
data_dir = "/content/drive/MyDrive/imagenes_medicas"

# Target spatial shape of every resized volume, and the prefix of output files.
target_shape = (128, 128, 128)
output_prefix = "resized_"

# Nearest-neighbour interpolation keeps label values intact (no blended classes).
resize_transform = Resize(target_shape, mode='nearest')


def _resize_volume(volume):
    """Resize a (X, Y, Z) or (X, Y, Z, C) array to `target_shape`, keeping its layout.

    MONAI's Resize treats the first axis as the channel axis. A 3D array therefore
    gets a leading channel axis, and a 4D array has its trailing axis moved to the
    front, before resizing. The original layout is restored afterwards.
    """
    if volume.ndim == 3:
        channel_first = volume[np.newaxis]
    elif volume.ndim == 4:
        channel_first = np.moveaxis(volume, -1, 0)
    else:
        raise ValueError(f"Expected a 3D or 4D volume, got shape {volume.shape}")
    resized = resize_transform(channel_first)
    # MONAI 1.x returns a torch (Meta)Tensor, which has no .astype; convert to numpy.
    if hasattr(resized, "detach"):
        resized = resized.detach().cpu().numpy()
    resized = np.asarray(resized)
    return resized[0] if volume.ndim == 3 else np.moveaxis(resized, 0, -1)


def _resized_affine(affine, old_shape, new_shape):
    """Scale the voxel axes of `affine` so the resized grid covers the original extent.

    Voxel index i of the new grid is mapped to index i * (old/new) of the old grid;
    the origin is left unchanged.
    """
    scale = np.asarray(old_shape[:3], dtype=float) / np.asarray(new_shape[:3], dtype=float)
    new_affine = np.array(affine, dtype=float, copy=True)
    new_affine[:3, :3] = new_affine[:3, :3] @ np.diag(scale)
    return new_affine


for filename in sorted(os.listdir(data_dir)):
    # Skip non-NIfTI files, macOS resource-fork files ("._*") and outputs of a
    # previous run, so re-running the cell does not create "resized_resized_*" files.
    if not filename.endswith(".nii.gz") or filename.startswith(("._", output_prefix)):
        continue
    imagen_path = os.path.join(data_dir, filename)
    imagen_nifti = nib.load(imagen_path)
    imagen = imagen_nifti.get_fdata()

    imagen_resized = _resize_volume(imagen)
    # Keep the file's on-disk dtype rather than forcing int16, which would
    # truncate float-valued intensities.
    imagen_resized = imagen_resized.astype(imagen_nifti.get_data_dtype())

    # Save with an affine that reflects the new voxel size.
    nueva_affine = _resized_affine(imagen_nifti.affine, imagen.shape, imagen_resized.shape)
    nueva_imagen_nifti = nib.Nifti1Image(imagen_resized, nueva_affine)
    nueva_imagen_path = os.path.join(data_dir, output_prefix + filename)
    nib.save(nueva_imagen_nifti, nueva_imagen_path)


In [None]:
# Show the raw directory contents (including hidden files) of both data folders.
for _title, _folder in (("Image files:", image_folder), ("Label files:", label_folder)):
    print(_title)
    print(os.listdir(_folder))
del _title, _folder


Image files:
['BRATS_001.nii.gz', 'BRATS_002.nii.gz', 'BRATS_003.nii.gz', 'BRATS_004.nii.gz', 'BRATS_005.nii.gz', 'BRATS_006.nii.gz', 'BRATS_007.nii.gz', 'BRATS_008.nii.gz', 'BRATS_009.nii.gz', 'BRATS_010.nii.gz', 'BRATS_011.nii.gz', 'BRATS_012.nii.gz', 'BRATS_013.nii.gz', 'BRATS_014.nii.gz', 'BRATS_015.nii.gz', 'BRATS_016.nii.gz', 'BRATS_017.nii.gz', 'BRATS_018.nii.gz', 'BRATS_019.nii.gz', 'BRATS_020.nii.gz', 'BRATS_021.nii.gz', 'BRATS_022.nii.gz', 'BRATS_023.nii.gz', 'BRATS_024.nii.gz', 'BRATS_025.nii.gz', 'BRATS_026.nii.gz', 'BRATS_027.nii.gz', 'BRATS_028.nii.gz', 'BRATS_029.nii.gz', 'BRATS_030.nii.gz', 'BRATS_031.nii.gz', 'BRATS_032.nii.gz', 'BRATS_033.nii.gz', 'BRATS_034.nii.gz', 'BRATS_035.nii.gz', 'BRATS_036.nii.gz', 'BRATS_037.nii.gz', 'BRATS_038.nii.gz', 'BRATS_039.nii.gz', 'BRATS_040.nii.gz', 'BRATS_041.nii.gz', 'BRATS_042.nii.gz', 'BRATS_043.nii.gz', 'BRATS_044.nii.gz', 'BRATS_045.nii.gz', 'BRATS_046.nii.gz', 'BRATS_047.nii.gz', 'BRATS_048.nii.gz', 'BRATS_049.nii.gz', 'BRATS

In [None]:
import nibabel as nib

path = "/content/drive/MyDrive/imagenes_medicas/imagesTr/BRATS_077.nii.gz"
try:
    # nib.load parses the header and reads voxel data lazily, so a successful
    # load only proves the header is readable. Also report the header's shape and
    # voxel size, which the model's channel count and resampling depend on.
    img = nib.load(path)
    print("Imagen cargada correctamente")
    print("Shape:", img.shape, "| voxel size:", img.header.get_zooms())
except Exception as e:
    print("Error al cargar la imagen:", e)


Imagen cargada correctamente


In [None]:
def create_dice_metric(threshold=0.5, spatial_dims=3, verbose=False):
    """Build a callable that scores binary segmentations with MONAI's DiceMetric.

    The returned function accepts predictions either as a tensor or as a list of
    per-sample tensors (e.g. from decollate_batch). It binarizes predictions and
    labels at `threshold` and returns DiceMetric's per-call result.

    Both inputs are normalized to (batch, channel, *spatial) before scoring.
    DiceMetric reads dim 1 as the channel axis, so the previous behaviour of
    squeezing that axis away made it score a spatial axis as if it were channels.

    Args:
        threshold: values strictly greater than this count as foreground.
        spatial_dims: number of spatial axes (3 for volumes).
        verbose: if True, print the shapes being compared (never the full tensors).

    Returns:
        wrapped_dice_metric(y_pred, y) -> DiceMetric result for this call.
    """
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    target_ndim = spatial_dims + 2

    def _to_batch_channel_first(t):
        # Collapse redundant singleton axes right after the batch axis, then make
        # sure exactly one channel axis is present.
        while t.ndim > target_ndim and t.shape[1] == 1:
            t = t.squeeze(1)
        if t.ndim == target_ndim - 1:
            t = t.unsqueeze(1)
        return t

    def wrapped_dice_metric(y_pred, y):
        if isinstance(y_pred, (list, tuple)):
            y_pred = torch.stack(list(y_pred))
        y_pred = (_to_batch_channel_first(y_pred) > threshold).float()
        y = (_to_batch_channel_first(y) > threshold).float()
        if y_pred.shape != y.shape:
            raise ValueError(
                f"Prediction shape {tuple(y_pred.shape)} does not match label shape {tuple(y.shape)}"
            )
        if verbose:
            print("y_pred shape:", y_pred.shape)
            print("y shape:", y.shape)
        value = dice_metric(y_pred=y_pred, y=y)
        # DiceMetric also accumulates each call in an internal buffer. Only the
        # per-call value is used here, so clear the buffer to keep it from growing.
        dice_metric.reset()
        return value

    return wrapped_dice_metric


In [None]:
from google.colab import drive
drive.mount('/content/drive')

import sys

# Make the helper modules stored on Drive importable. The membership check keeps
# re-runs of this cell from appending the same entry to sys.path again.
_MODULE_DIR = "/content/drive/MyDrive/monao"
if _MODULE_DIR not in sys.path:
    sys.path.append(_MODULE_DIR)

import torch
from data_utils import get_data_files, create_transforms, create_data_loaders
from model_utils import (
    create_model, create_loss_function, create_optimizer, create_dice_metric, train_and_evaluate_model
)
from config import Config


def main():
    """Build data, model, loss, optimizer and metric from Config, then train and validate.

    Note: the helpers used here are imported from the data_utils/model_utils
    modules on Drive, not from the functions defined in earlier notebook cells.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = get_data_files(Config.image_folder, Config.label_folder)
    transforms = create_transforms()
    train_loader, val_loader = create_data_loaders(
        data, transforms, batch_size=Config.batch_size, num_workers=2
    )
    print("Data loaders created.")

    model = create_model(device)
    print("Model created.")
    loss_function = create_loss_function()
    optimizer = create_optimizer(model, learning_rate=Config.learning_rate)
    dice_metric = create_dice_metric()

    train_and_evaluate_model(
        model, loss_function, optimizer, dice_metric, train_loader, val_loader, device,
        num_epochs=Config.num_epochs, val_interval=Config.val_interval
    )


if __name__ == '__main__':
    main()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ...