<a href="https://colab.research.google.com/github/UserJorge009/Sistema_MazaHuaman/blob/master/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# U-Net segmentation of the palpebral conjunctiva

# --- 1. Library installation ---
# IPython shell escape: installs the two third-party packages that the next
# cell imports (segmentation_models_pytorch and albumentations).
!pip install segmentation-models-pytorch albumentations --quiet

In [None]:
# --- 2. Imports ---
import os
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- 3. Path variables ---
# The CSV index, the images and the masks all live under one dataset folder,
# so the common prefix is written only once.
_DATASET_ROOT = '/content/drive/MyDrive/Anemia/dataset-maza-rodas/dataset-unet'
CSV_PATH = _DATASET_ROOT + '/dataset_augmented.csv'
IMG_DIR = _DATASET_ROOT + '/augmented/images'
MASK_DIR = _DATASET_ROOT + '/augmented/masks'

# --- 4. Dataset personalizado ---
class ConjuntivaDataset(Dataset):
    """Image/mask pairs for binary conjunctiva segmentation, indexed by a CSV.

    The CSV must provide a ``filename`` column (image file name inside
    ``img_dir``) and a ``mask`` column (mask file name inside ``mask_dir``).

    Every item is a tuple ``(image, mask)``:

    * ``image`` -- ``torch.float32`` tensor in channel-first layout
      (``[3, H, W]``).  8-bit pixel data is scaled to ``[0, 1]`` whether or
      not a transform is used; data a transform already converted to floating
      point (e.g. by a normalisation step) is passed through unchanged.
    * ``mask`` -- ``torch.float32`` tensor ``[1, H, W]`` holding 0.0 / 1.0.

    Args:
        csv_file: Path of the CSV index.
        img_dir: Directory that contains the images.
        mask_dir: Directory that contains the masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries (albumentations convention).
    """

    def __init__(self, csv_file, img_dir, mask_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data)

    @staticmethod
    def _to_float_image(image):
        """Return ``image`` as a channel-first float32 tensor in a common scale."""
        if isinstance(image, np.ndarray):
            # A NumPy image has not been through ToTensorV2 and is still HWC;
            # move channels first so every code path yields [C, H, W].
            image = torch.from_numpy(np.ascontiguousarray(image))
            if image.ndim == 3:
                image = image.permute(2, 0, 1)
        elif not isinstance(image, torch.Tensor):
            image = torch.as_tensor(image)

        if image.dtype == torch.uint8:
            # ToTensorV2 only changes layout/container, it does not rescale:
            # map 8-bit pixels from [0, 255] to [0, 1] so the transform and
            # no-transform paths agree.
            # NOTE(review): weights trained on the unscaled 0-255 input would
            # presumably need retraining -- confirm before reusing checkpoints.
            return image.float() / 255.0
        return image.float()

    @staticmethod
    def _to_mask_tensor(mask):
        """Return ``mask`` as a float32 tensor with a leading channel axis."""
        if isinstance(mask, np.ndarray):
            # ascontiguousarray guards against negatively-strided views (e.g.
            # after a flip), which torch.from_numpy cannot wrap.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return mask.float()

    def __getitem__(self, idx):
        """Load, binarise and tensorise the ``idx``-th image/mask pair."""
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])
        mask_path = os.path.join(self.mask_dir, row['mask'])

        # The context managers close the files once the pixels are copied.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"))
        mask = (mask > 127).astype(np.float32)  # binarise to 0/1

        # Explicit None check: a Compose object defines __len__, so relying on
        # truthiness would silently skip an (empty) pipeline.
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        return self._to_float_image(image), self._to_mask_tensor(mask)

# --- 5. Transformations ---
# Steps run in list order; ToTensorV2 comes last so the earlier steps still
# operate on NumPy arrays.
_augmentation_steps = [
    A.Resize(height=256, width=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    ToTensorV2(),
]
transform = A.Compose(_augmentation_steps)

# --- 6. Build the DataLoader ---
dataset = ConjuntivaDataset(
    csv_file=CSV_PATH,
    img_dir=IMG_DIR,
    mask_dir=MASK_DIR,
    transform=transform,
)
train_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,   # reshuffle the samples at every epoch
    num_workers=2,  # load batches in two worker subprocesses
)

# --- 7. Define the U-Net ---
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One output channel and no final activation: the network returns raw logits,
# which is what the BCE-with-logits loss and the later sigmoid call expect.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation=None,
)
model = model.to(device)

# --- 8. Define loss and optimizer ---
# BCEWithLogitsLoss applies the sigmoid internally, so it is fed the model's
# raw outputs.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# --- 9. Training loop ---
epochs = 100  # adjust as needed
for epoch in range(epochs):
    model.train()
    batch_losses = []
    for images, masks in tqdm(train_loader):
        # Move the batch to the training device and force float32 for safety.
        images = images.to(device=device, dtype=torch.float32)
        masks = masks.to(device=device, dtype=torch.float32)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Report the mean per-batch loss of this epoch.
    epoch_loss = sum(batch_losses)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

# --- 10. Save the model ---
# Only the parameters (state_dict) are stored, not the model object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, '/content/unet_conjuntiva2.pth')

# --- 11. Prediction preview ---
# Shows up to four samples of one training batch in a 3x4 grid:
# row 1 = input image, row 2 = ground-truth mask, row 3 = predicted mask.
model.eval()
with torch.no_grad():
    images, masks = next(iter(train_loader))
    images = images.to(device).float()
    outputs = model(images)
    # The model returns logits: sigmoid + 0.5 threshold gives the binary mask.
    preds = torch.sigmoid(outputs).cpu().numpy() > 0.5
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()

# Plotting works on NumPy arrays only, so it does not need the no_grad scope.
plt.figure(figsize=(12,8))
for i in range(min(4, images.shape[0])):
    plt.subplot(3,4,i+1)
    img_to_show = np.transpose(images[i], (1,2,0))  # CHW -> HWC for imshow
    # imshow clips floating-point RGB data to [0, 1]; images still on the
    # 0-255 scale would be drawn almost completely white, so rescale them.
    if img_to_show.max() > 1.0:
        img_to_show = img_to_show / 255.0
    plt.imshow(img_to_show)
    plt.title("Imagen")
    plt.axis('off')
    plt.subplot(3,4,i+5)
    plt.imshow(masks[i][0], cmap='gray')
    plt.title("Máscara Real")
    plt.axis('off')
    plt.subplot(3,4,i+9)
    plt.imshow(preds[i][0], cmap='gray')
    plt.title("Predicción")
    plt.axis('off')
plt.tight_layout()
plt.show()

Probar el modelo con varias imágenes

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# --- Make sure the model is in evaluation mode -> uses the model trained in the previous cell ---
model.eval()

# --- Randomly pick up to 8 distinct dataset indices ---
# Sampling without replacement fails when more items are requested than
# exist, so the request is capped at the dataset size.
num_samples = min(8, len(dataset))
if num_samples == 0:
    raise ValueError("dataset is empty: there is nothing to predict")
indices = np.random.choice(len(dataset), num_samples, replace=False)

# --- Fetch the corresponding images and masks ---
samples = [dataset[int(idx)] for idx in indices]
images = torch.stack([img for img, _ in samples]).to(device).float()    # [N, 3, H, W]
masks = torch.stack([msk for _, msk in samples]).to(device).float()     # [N, 1, H, W]

# --- Prediction ---
with torch.no_grad():
    outputs = model(images)
    # Logits -> probabilities -> boolean mask at a 0.5 threshold.
    preds = torch.sigmoid(outputs).cpu().numpy() > 0.5
    images_cpu = images.cpu().numpy()
    masks_cpu = masks.cpu().numpy()

# --- Visualización ---
import numpy as np
import matplotlib.pyplot as plt

# Grid of 3 rows x num_samples columns:
# row 1 = input images, row 2 = ground-truth masks, row 3 = predictions.
plt.figure(figsize=(12, 9))
for i in range(num_samples):
    column = i + 1

    # --- Row 1: input image ---
    plt.subplot(3, num_samples, column)
    img = images_cpu[i]
    leading_dim = img.shape[0]
    if leading_dim == 3:  # [3, H, W] colour image
        img_to_show = np.transpose(img, (1, 2, 0))
        # Values above 1.0 mean the data is on a 0-255 scale: bring it to
        # [0, 1]; data already in [0, 1] is shown unchanged.
        if img_to_show.max() > 1.0:
            img_to_show = img_to_show / 255.0
        plt.imshow(img_to_show)
    elif leading_dim == 1:  # [1, H, W] grayscale image
        plt.imshow(img[0], cmap='gray')
    else:
        plt.imshow(img)
    plt.title("Imagen")
    plt.axis('off')

    # --- Rows 2 and 3: ground-truth mask, then predicted mask ---
    grayscale_panels = (
        (masks_cpu[i][0], "Máscara Real"),
        (preds[i][0], "Predicción"),
    )
    for row, (panel, caption) in enumerate(grayscale_panels, start=1):
        plt.subplot(3, num_samples, row * num_samples + column)
        plt.imshow(panel, cmap='gray')
        plt.title(caption)
        plt.axis('off')
plt.tight_layout()
plt.show()