# Laboratorio 4: Modelos CNN #
*Maria-Ignacia Rojas*

1. Implementar y evaluar el desempeño de una red neuronal U-Net en el problema de segmentación de lesiones de esclerosis múltiple.

In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import zoom
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("游댲 Usando:", device)

In [None]:
class MSDataset(Dataset):
    """Slice-level dataset for multiple-sclerosis lesion segmentation.

    Scans ``base_dir`` for sub-folders whose name starts with ``patient_``.
    Each folder must contain ``FLAIR.nii.gz``, ``T1W.nii.gz``,
    ``T1WKS.nii.gz``, ``T2W.nii.gz`` and the label volume
    ``consensus_gt.nii.gz``; incomplete folders are skipped with a warning.
    Every volume is resampled to the FLAIR array shape, each modality is
    normalized per volume, and only the 2-D slices (taken along the last
    axis) whose mask has a positive sum are kept.

    All samples are held in memory. ``__getitem__`` returns ``(x, y)`` where
    ``x`` is a float32 tensor of shape ``(4, H, W)`` (channels in the order
    FLAIR, T1W, T1WKS, T2W) and ``y`` is a float32 tensor of shape
    ``(1, H, W)``.
    """

    # Input channels, in the order they are stacked into ``x``.
    _MODALITIES = ("FLAIR", "T1W", "T1WKS", "T2W")

    def __init__(self, base_dir):
        self.samples = []
        patients = sorted(p for p in os.listdir(base_dir) if p.startswith("patient_"))
        print(f"游댳 Cargando pacientes: {patients}")

        for patient in patients:
            p_dir = os.path.join(base_dir, patient)
            files = {m: os.path.join(p_dir, f"{m}.nii.gz") for m in self._MODALITIES}
            files["mask"] = os.path.join(p_dir, "consensus_gt.nii.gz")

            if not all(os.path.exists(f) for f in files.values()):
                print(f"丘멆잺 Archivos incompletos en {patient}, se omite.")
                continue

            imgs = [nib.load(files[m]).get_fdata() for m in self._MODALITIES]
            mask = nib.load(files["mask"]).get_fdata()
            self.samples.extend(self._volumes_to_samples(imgs, mask))

        print(f"Dataset creado con {len(self.samples)} slices 칰tiles.")

    @staticmethod
    def _resize_to_match(img, target_shape, order=1):
        """Resample ``img`` to ``target_shape`` with ``scipy.ndimage.zoom``.

        ``order`` is the spline order (1 = linear, 0 = nearest neighbour).
        Arrays that already have the target shape are returned unchanged.
        """
        if tuple(img.shape) == tuple(target_shape):
            return img
        factors = [t / s for t, s in zip(target_shape, img.shape)]
        return zoom(img, factors, order=order)

    @staticmethod
    def _normalize(img):
        """Z-score over the whole volume, clip to +/-5 SD, rescale to roughly [0, 1]."""
        img = (img - np.mean(img)) / (np.std(img) + 1e-6)
        img = np.clip(img, -5, 5)
        return (img - img.min()) / (img.max() - img.min() + 1e-6)

    @classmethod
    def _volumes_to_samples(cls, imgs, mask):
        """Turn one patient's modality volumes and mask into slice samples.

        Args:
            imgs: list of 3-D arrays, one per modality; ``imgs[0]`` defines
                the target grid.
            mask: 3-D label array.

        Returns:
            List of ``(x, y)`` float32 NumPy pairs, ``x`` of shape
            ``(len(imgs), H, W)`` and ``y`` of shape ``(H, W)``, one per slice
            along the last axis whose mask sum is positive.
        """
        target_shape = imgs[0].shape
        # NOTE(review): resampling by array shape ignores the NIfTI affine /
        # voxel spacing; this presumes the modalities are already
        # co-registered on matching fields of view — confirm for the data.
        imgs = [cls._normalize(cls._resize_to_match(img, target_shape)) for img in imgs]
        # Nearest neighbour (order=0) keeps the mask's label values intact.
        # Linear interpolation would produce fractional targets at lesion
        # borders and leak positive values into neighbouring slices.
        mask = cls._resize_to_match(mask, target_shape, order=0)

        samples = []
        for s in range(target_shape[2]):
            y = mask[:, :, s]
            if y.sum() > 0:
                x = np.stack([img[:, :, s] for img in imgs], axis=0)
                samples.append((x.astype(np.float32), y.astype(np.float32)))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        # Add a leading channel axis to the mask so it matches the model output.
        return torch.tensor(x), torch.tensor(y).unsqueeze(0)

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    ``padding=1`` keeps H and W unchanged. The first convolution maps
    ``in_ch`` to ``out_ch`` channels and the second keeps ``out_ch``.
    The layers live in ``self.conv`` (an ``nn.Sequential`` with indices
    0-5), which is what the saved ``state_dict`` keys refer to.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features

class UNet2D(nn.Module):
    """2-D U-Net with four down/up-sampling levels (64 -> 1024 channels).

    Input: tensor of shape ``(N, in_ch, H, W)``. Output: tensor of shape
    ``(N, out_ch, H, W)`` with values in [0, 1]. A sigmoid is applied, so
    the result is a per-pixel probability map, not logits.

    H and W need not be multiples of 16. When a max-pool floors an odd
    size, the up-sampled map is zero-padded back to its skip connection's
    size before concatenation. Each spatial side should still be at least
    16 so the bottleneck is non-empty.
    """

    def __init__(self, in_ch=4, out_ch=1):
        super().__init__()
        # Encoder: each level doubles the channels; forward() halves H and W
        # with a 2x2 max-pool before every level after the first.
        self.d1 = DoubleConv(in_ch, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)
        self.d4 = DoubleConv(256, 512)
        self.bottom = DoubleConv(512, 1024)

        # Decoder: a transposed conv (kernel 2, stride 2) doubles H and W and
        # halves the channels; the following DoubleConv takes that map
        # concatenated with the skip connection, hence the doubled input.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.uconv4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.uconv3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.uconv2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.uconv1 = DoubleConv(128, 64)

        # 1x1 projection to the output channels.
        self.out_conv = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _pad_and_concat(x, skip):
        """Zero-pad ``x`` to ``skip``'s H and W, then concatenate on channels.

        The size difference is split between both sides, with any extra
        pixel going to the bottom/right. When the sizes already match,
        ``x`` is used as-is.
        """
        dh = skip.size(2) - x.size(2)
        dw = skip.size(3) - x.size(3)
        if dh or dw:
            # F.pad order for the last two dims: (left, right, top, bottom).
            x = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        x1 = self.d1(x)
        x2 = self.d2(F.max_pool2d(x1, 2))
        x3 = self.d3(F.max_pool2d(x2, 2))
        x4 = self.d4(F.max_pool2d(x3, 2))
        x5 = self.bottom(F.max_pool2d(x4, 2))

        x = self.uconv4(self._pad_and_concat(self.up4(x5), x4))
        x = self.uconv3(self._pad_and_concat(self.up3(x), x3))
        x = self.uconv2(self._pad_and_concat(self.up2(x), x2))
        x = self.uconv1(self._pad_and_concat(self.up1(x), x1))
        return torch.sigmoid(self.out_conv(x))

In [None]:

# Training cell: build the slice dataset, train UNet2D with BCE, save the weights.
# NOTE(review): several runtime strings in this notebook look mis-encoded
# (UTF-8 decoded with a double-byte codepage); consider re-saving as UTF-8.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("游댲 Usando:", device)

# Data root. Set the MS_DATA_DIR environment variable to point elsewhere
# without editing the notebook; the original path remains the default.
base_dir = os.environ.get("MS_DATA_DIR", r"C:\Users\miroj\Downloads\data")
ds = MSDataset(base_dir)
if len(ds) == 0:
    # Fail early with an actionable message. A shuffling DataLoader over an
    # empty dataset errors out inside its sampler, and the per-epoch average
    # below would divide by len(dl) == 0.
    raise RuntimeError(
        f"No training slices found under {base_dir!r}: expected patient_* "
        "folders with FLAIR, T1W, T1WKS, T2W and consensus_gt .nii.gz files "
        "whose masks contain lesions."
    )
dl = DataLoader(ds, batch_size=2, shuffle=True)

model = UNet2D().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
# UNet2D ends in a sigmoid, so plain BCE on probabilities is used here.
loss_fn = nn.BCELoss()

n_epochs = 5  # adjust to your GPU / time budget
for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    for X, y in tqdm(dl, desc=f"칄poca {epoch+1}/{n_epochs}"):
        X, y = X.to(device), y.to(device)
        opt.zero_grad()
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        opt.step()
        running_loss += loss.item()
    # Mean of per-batch losses over the epoch.
    print(f"游댳 Loss promedio: {running_loss / len(dl):.4f}")

torch.save(model.state_dict(), "unet_ms_trained.pth")
print("Modelo guardado como unet_ms_trained.pth")

In [None]:
# Visual check: run the trained model on one random dataset slice and show
# the FLAIR channel, the reference mask and the thresholded prediction.
model.eval()
idx = random.randint(0, len(ds) - 1)
X, y_true = ds[idx]
with torch.no_grad():
    batch = X.unsqueeze(0).to(device)  # add the batch axis expected by the model
    y_pred = model(batch).cpu().squeeze().numpy()

X_show = X[0]  # FLAIR channel
panels = [
    (X_show, "FLAIR original"),
    (y_true.squeeze(), "M치scara real"),
    (y_pred > 0.5, "Predicci칩n U-Net"),  # binarize probabilities at 0.5
]
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, (image, title) in zip(axs, panels):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()