In [8]:
import os
import nibabel as nib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class BrainDataset2D(Dataset):
    def __init__(self, rootpath, transform=None):
        self.transform = transform
        self.samples = [os.path.join(rootpath, r) for r in os.listdir(rootpath) if os.path.isdir(os.path.join(rootpath, r))]

    def __len__(self):
        return len(self.samples)

    def loadTensor(self, filepath):
        img = nib.load(filepath).get_fdata()
        tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)  # [H, W, D] -> [D, H, W]
        return tensor[:154]  # Limitar a 154 cortes

    def __getitem__(self, idx):
        folder = self.samples[idx]

        flair = seg = t1 = t1ce = t2 = None
        for file in os.listdir(folder):
            mode = file.split("_")[-1]
            filepath = os.path.join(folder, file)
            if mode == "flair.nii.gz":
                flair = self.loadTensor(filepath)
            elif mode == "seg.nii.gz":
                seg = self.loadTensor(filepath)
            elif mode == "t1.nii.gz":
                t1 = self.loadTensor(filepath)
            elif mode == "t1ce.nii.gz":
                t1ce = self.loadTensor(filepath)
            elif mode == "t2.nii.gz":
                t2 = self.loadTensor(filepath)

        # Combinar las modalidades en un solo tensor [canales, D, H, W]
        voxel = torch.stack([flair, t1, t1ce, t2], dim=0)  # [4, D, H, W]

        # Convertir las etiquetas a clases correctas
        seg[seg == 4] = 3
        seg = seg.long()

        # Extraer cortes 2D
        slices = []
        for z in range(voxel.shape[1]):  # Recorrer el eje z (D)
            slice_voxel = voxel[:, z, :, :]  # [4, H, W]
            slice_seg = seg[z, :, :]  # [H, W]
            slices.append((slice_voxel, slice_seg))

        return slices  # Lista de cortes 2D

def DoubleConv2d(in_chan, out_chan):
    return nn.Sequential(
        nn.Conv2d(in_chan, out_chan, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_chan, out_chan, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

class Unet2D(nn.Module):
    def __init__(self, n_chan, n_classes):
        super(Unet2D, self).__init__()

        # Encoder
        self.enc1 = DoubleConv2d(n_chan, 32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc2 = DoubleConv2d(32, 64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = DoubleConv2d(64, 128)

        # Decoder
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = DoubleConv2d(128, 64)

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = DoubleConv2d(64, 32)

        # Output
        self.OutLayer = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        z1 = self.enc1(x)
        z2 = self.pool1(z1)

        z2 = self.enc2(z2)
        Z = self.pool2(z2)

        Z = self.bottleneck(Z)
        y = self.up2(Z)
        y = torch.cat([y, z2], dim=1)
        y = self.dec2(y)
        y = self.up1(y)
        y = torch.cat([y, z1], dim=1)
        y = self.dec1(y)
        return self.OutLayer(y)

# Cargar el dataset
root = "BrainTrain"
dataset = BrainDataset2D(root)

# Crear DataLoader
trainLoader = DataLoader(dataset, batch_size=1, shuffle=True)  # Batch size 1 para cortes 2D
evalLoader = DataLoader(dataset, batch_size=1, shuffle=False)

# Modelo, función de pérdida y optimizador
model = Unet2D(n_chan=4, n_classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Entrenamiento
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for slices in trainLoader:
        for slice_voxel, slice_seg in slices[0]:  # slices[0] porque el batch_size es 1
            slice_voxel, slice_seg = slice_voxel.to(device), slice_seg.to(device)
            optimizer.zero_grad()
            outputs = model(slice_voxel.unsqueeze(0))  # Añadir dimensión de batch
            loss = criterion(outputs, slice_seg.unsqueeze(0))  # Añadir dimensión de batch
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(trainLoader):.4f}")

# Evaluación
model.eval()
eval_loss = 0
with torch.no_grad():
    for slices in evalLoader:
        for slice_voxel, slice_seg in slices[0]:
            slice_voxel, slice_seg = slice_voxel.to(device), slice_seg.to(device)
            outputs = model(slice_voxel.unsqueeze(0))
            loss = criterion(outputs, slice_seg.unsqueeze(0))
            eval_loss += loss.item()
    avg_eval_loss = eval_loss / len(evalLoader)
print(f"Eval Loss: {avg_eval_loss:.4f}")

# Guardar el modelo
torch.save(model.state_dict(), "model_seg_2d.pth")

ValueError: not enough values to unpack (expected 2, got 1)