In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

class CloudDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories of ``.tif`` files.

    Images and masks are paired by position after sorting the file names of
    each directory, so both directories must contain the same number of
    ``.tif`` files and their sorted orders must correspond to each other.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.
        size: Target ``(height, width)`` handed to ``transforms.Resize`` for
            both the image and the mask.

    Raises:
        ValueError: If the two directories hold a different number of
            ``.tif`` files (index-based pairing would be misaligned).
    """

    def __init__(self, images_dir, masks_dir, size=(256, 256)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.size = size

        self.image_files = self._list_tifs(images_dir)
        self.mask_files = self._list_tifs(masks_dir)

        # Pairing is purely positional, so a count mismatch would either
        # silently mis-pair samples or fail later with an IndexError.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.image_files)} images in '{images_dir}' but "
                f"{len(self.mask_files)} masks in '{masks_dir}'; "
                "images and masks are paired by sorted index and must match."
            )

        # Fixed image pipeline: resize, convert to tensor, then normalise
        # with the standard ImageNet channel statistics.
        self.img_transform = transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Fixed mask pipeline: NO normalisation. Nearest-neighbour resizing
        # avoids inventing intermediate grey levels, and the result is
        # binarised to {0., 1.}. A named function is used instead of a lambda
        # so the dataset stays picklable for DataLoader worker processes.
        self.mask_transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
            transforms.Lambda(self._binarize)
        ])

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files inside ``directory``."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    @staticmethod
    def _binarize(mask_tensor):
        """Threshold a mask tensor at 0.5 and return it as a float {0., 1.} tensor."""
        return (mask_tensor > 0.5).float()

    def __len__(self):
        """Number of samples, i.e. the number of image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, transform and return the ``(image, mask)`` tensors at ``idx``."""
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        mask_path = os.path.join(self.masks_dir, self.mask_files[idx])

        # convert() materialises a new in-memory image, so the underlying
        # file handles can be released as soon as the with-blocks exit.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')

        return self.img_transform(image), self.mask_transform(mask)

# Build the training dataset from the image and mask directories.
# NOTE(review): images are read from "overall-mask" and masks from "masked" —
# confirm these two directories are not swapped.
train_dataset = CloudDataset(images_dir="overall-mask", masks_dir="masked", size=(256, 256))

# Standalone image transform pipeline (resize -> tensor -> normalise).
# NOTE(review): `transform` is not referenced anywhere in the visible code;
# CloudDataset builds its own equivalent pipeline internally.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to the working resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Per-channel mean / std normalisation
                         std=[0.229, 0.224, 0.225])
])


# Create the DataLoader used for mini-batch training
batch_size = 8
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2  # Number of worker processes used for parallel loading
)

# Sanity check: report how many samples the dataset holds
print(f"Número de imágenes en el dataset: {len(train_dataset)}")

# Inspect a single batch (shapes, image value range, distinct mask values),
# then stop after the first one
for images, masks in train_loader:
    print(f"Tamaño del batch de imágenes: {images.shape}")
    print(f"Tamaño del batch de máscaras: {masks.shape}")
    print(f"Rango de valores en imágenes: [{images.min():.3f}, {images.max():.3f}]")
    print(f"Valores únicos en máscaras: {torch.unique(masks)}")
    break

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class AttentionBlock(nn.Module):
    """Additive attention gate applied to a skip-connection feature map.

    Both inputs are projected to ``F_int`` channels with a 1x1 convolution
    followed by batch norm, summed, passed through ReLU, and reduced to a
    single-channel sigmoid map that multiplies the skip features.

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip features ``x``.
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Submodule names and creation order (W_g, W_x, psi, relu) are kept,
        # so state_dict keys and parameter-initialisation order are unchanged.
        self.W_g = self._pointwise_bn(F_g, F_int)
        self.W_x = self._pointwise_bn(F_l, F_int)
        self.psi = nn.Sequential(
            *self._pointwise_bn(F_int, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _pointwise_bn(in_channels, out_channels):
        """Return a 1x1 convolution (with bias) followed by BatchNorm2d."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, g, x):
        """Return ``x`` scaled element-wise by the attention map computed from ``g`` and ``x``."""
        gate_features = self.W_g(g)
        skip_features = self.W_x(x)
        # The sum is a fresh tensor, so the in-place ReLU does not touch g or x.
        attention = self.psi(self.relu(gate_features + skip_features))
        return x * attention

class CloudAttentionUNet(nn.Module):
    """Attention U-Net with a ResNet18 encoder that predicts a one-channel mask.

    The encoder reuses the ResNet18 stem (conv1, bn1, relu — the max-pool is
    skipped) and its four residual stages. The decoder upsamples with
    transposed convolutions, gates each skip connection with an
    ``AttentionBlock`` and fuses it with a 3x3 convolution. The output is
    passed through a sigmoid, so it holds probabilities in [0, 1] with the
    same spatial size as the input.

    Args:
        pretrained: When True (default, the original behaviour) the encoder is
            initialised with ``ResNet18_Weights.IMAGENET1K_V1``. Pass False to
            build the network with random weights, e.g. when a checkpoint is
            loaded right afterwards or no download is possible.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # --- ENCODER (ResNet18) --- (shapes below are for a 256x256 input)
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        base = models.resnet18(weights=weights)
        self.enc0 = nn.Sequential(*list(base.children())[:3])  # [64, 128, 128]
        self.enc1 = base.layer1  # [64, 128, 128]
        self.enc2 = base.layer2  # [128, 64, 64]
        self.enc3 = base.layer3  # [256, 32, 32]
        self.enc4 = base.layer4  # [512, 16, 16] (bottleneck)

        # --- DECODER + ATTENTION ---
        # Up 1: 16x16 -> 32x32
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.conv3 = nn.Conv2d(512, 256, kernel_size=3, padding=1)

        # Up 2: 32x32 -> 64x64
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)

        # Up 3: 64x64 -> 128x128
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        # Final output head (back to the input resolution, 256x256)
        self.final_up = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)

    @staticmethod
    def _match_size(tensor, reference):
        """Resize ``tensor`` to the spatial size of ``reference`` when they differ.

        Strided convolutions round odd sizes, so doubling on the way up can
        overshoot the skip connection by one pixel. When the sizes already
        agree (e.g. inputs whose sides are multiples of 32) the tensor is
        returned untouched.
        """
        target = reference.shape[-2:]
        if tensor.shape[-2:] == target:
            return tensor
        return F.interpolate(tensor, size=target, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return per-pixel probabilities of shape ``(N, 1, H, W)`` for input ``(N, 3, H, W)``."""
        # Encoder
        x0 = self.enc0(x)
        x1 = self.enc1(x0)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)  # Bottleneck

        # Decoder with attention-gated skip connections
        d3 = self._match_size(self.up3(x4), x3)
        s3 = self.att3(g=d3, x=x3)
        d3 = torch.cat((s3, d3), dim=1)  # Concatenation (512 channels)
        d3 = F.relu(self.conv3(d3))

        d2 = self._match_size(self.up2(d3), x2)
        s2 = self.att2(g=d2, x=x2)
        d2 = torch.cat((s2, d2), dim=1)  # Concatenation (256 channels)
        d2 = F.relu(self.conv2(d2))

        d1 = self._match_size(self.up1(d2), x1)
        s1 = self.att1(g=d1, x=x1)
        d1 = torch.cat((s1, d1), dim=1)  # Concatenation (128 channels)
        d1 = F.relu(self.conv1(d1))

        # Final upsampling back to the original input size
        out = self._match_size(self.final_up(d1), x)
        out = self.final_conv(out)

        return torch.sigmoid(out)

In [None]:
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and Dice loss, computed over the whole batch.

    ``forward`` expects ``inputs`` to already be probabilities in [0, 1]
    (``F.binary_cross_entropy`` is used, not the logits variant) and
    ``targets`` to hold the matching ground-truth mask values.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE(inputs, targets) + (1 - Dice(inputs, targets))``.

        Args:
            inputs: Predicted probabilities, any shape.
            targets: Ground-truth mask with the same number of elements as
                ``inputs``; integer or bool masks are accepted and cast.
            smooth: Additive constant in the Dice numerator and denominator;
                it keeps the ratio defined when both tensors sum to zero.

        Returns:
            Scalar tensor with the combined loss.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or sliced predictions.
        inputs = inputs.reshape(-1)
        # binary_cross_entropy requires targets in the same floating dtype.
        targets = targets.reshape(-1).to(inputs.dtype)

        # Binary cross-entropy averaged over every element
        bce = F.binary_cross_entropy(inputs, targets, reduction='mean')

        # Dice loss on the flattened batch
        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return bce + dice_loss

# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Model (moved to the selected device), loss and optimizer used for training.
model = CloudAttentionUNet().to(device)
criterion = DiceBCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def train_model(epochs=10, save_path='cloud_segmentation_model_2.pth'):
    """Train the notebook-level ``model`` on ``train_loader`` and save it to disk.

    Uses the module-level ``model``, ``train_loader``, ``device``,
    ``optimizer`` and ``criterion``. After every epoch the mean batch loss is
    printed; once all epochs finish the whole model object is written with
    ``torch.save``.

    Args:
        epochs: Number of passes over ``train_loader``.
        save_path: Destination file for the saved model. Defaults to the
            file name that was previously hard-coded.

    Raises:
        ValueError: If ``train_loader`` yields no batches during an epoch.
    """
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Fail with a clear message instead of a ZeroDivisionError below.
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; there is nothing to train on")

        print(f"Época [{epoch+1}/{epochs}], Pérdida: {epoch_loss/num_batches:.4f}")

    # Save the model.
    # NOTE(review): this pickles the entire module object, so loading it needs
    # the same class definitions to be importable; saving
    # model.state_dict() is the more portable alternative.
    torch.save(model, save_path)
    # Report the path that was actually written.
    print(f"Modelo guardado como '{save_path}'")

# Run the training loop for 30 epochs
train_model(epochs=30)