In [None]:
class CloudTestDataset(Dataset):
    """Test/inference dataset of ``.tif`` images paired with same-named masks.

    Every file in ``images_dir`` whose name ends with ``.tif`` is one item; its
    mask is read from ``masks_dir`` under the same file name.

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the masks (same file names as images).
        size: target size given to ``transforms.Resize`` for image and mask.

    Each item is ``(image, mask, fname)``:
        * image: RGB, resized, converted with ``ToTensor`` and normalized with
          mean ``[0.485, 0.456, 0.406]`` / std ``[0.229, 0.224, 0.225]``;
        * mask: grayscale ("L"), resized with nearest-neighbour interpolation,
          converted with ``ToTensor`` and thresholded at 0.5 to {0., 1.};
        * fname: the file name.

    NOTE(review): this cell runs before the cell that imports ``os``, ``Image``,
    ``Dataset`` and ``transforms``, and the class is redefined by a later cell;
    executing the notebook top-to-bottom needs those imports first.
    """

    def __init__(self, images_dir, masks_dir, size=(256,256)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.size = size

        # Only .tif files, in a stable (sorted) order.
        self.files = sorted(f for f in os.listdir(images_dir) if f.endswith('.tif'))

        self.img_tf = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485,0.456,0.406],
                std=[0.229,0.224,0.225]
            )
        ])

        self.mask_tf = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
            # A named static function (unlike a lambda) can be pickled, so the
            # dataset also works with DataLoader(num_workers>0) under "spawn".
            transforms.Lambda(self._binarize),
        ])

    @staticmethod
    def _binarize(x):
        """Return 1.0 where ``x > 0.5`` and 0.0 elsewhere, as a float tensor."""
        return (x > 0.5).float()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]

        # ``with`` releases the file handle; PIL may otherwise keep it open
        # after loading (e.g. for multi-frame TIFFs).
        with Image.open(os.path.join(self.images_dir, fname)) as img:
            image = img.convert("RGB")

        with Image.open(os.path.join(self.masks_dir, fname)) as msk:
            mask = msk.convert("L")

        return self.img_tf(image), self.mask_tf(mask), fname

In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

device = "cuda" if torch.cuda.is_available() else "cpu"

# ================================
# Dataset para test o inferencia
# ================================
class CloudTestDataset(Dataset):
    """Test/inference dataset of ``.tif`` images paired with same-named masks.

    Every file in ``images_dir`` whose name ends with ``.tif`` is one item; its
    mask is read from ``masks_dir`` under the same file name.

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the masks (same file names as images).
        size: target size given to ``transforms.Resize`` for image and mask.

    Each item is ``(image, mask, fname)``:
        * image: RGB, resized, converted with ``ToTensor`` and normalized with
          mean ``[0.485, 0.456, 0.406]`` / std ``[0.229, 0.224, 0.225]``;
        * mask: grayscale ("L"), resized with nearest-neighbour interpolation,
          converted with ``ToTensor`` and thresholded at 0.5 to {0., 1.};
        * fname: the file name.
    """

    def __init__(self, images_dir, masks_dir, size=(256,256)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.size = size

        # Only .tif files, in a stable (sorted) order.
        self.files = sorted(f for f in os.listdir(images_dir) if f.endswith('.tif'))

        self.img_tf = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])

        self.mask_tf = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
            # A named static function (unlike a lambda) can be pickled, so the
            # dataset also works with DataLoader(num_workers>0) under "spawn".
            transforms.Lambda(self._binarize),
        ])

    @staticmethod
    def _binarize(x):
        """Return 1.0 where ``x > 0.5`` and 0.0 elsewhere, as a float tensor."""
        return (x > 0.5).float()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]

        # ``with`` releases the file handle; PIL may otherwise keep it open
        # after loading (e.g. for multi-frame TIFFs).
        with Image.open(os.path.join(self.images_dir, fname)) as img:
            image = img.convert("RGB")
        with Image.open(os.path.join(self.masks_dir, fname)) as msk:
            mask = msk.convert("L")

        return self.img_tf(image), self.mask_tf(mask), fname

# ================================
# Dataset para entrenamiento de reconstrucción
# ================================
class CloudReconstructionDataset(Dataset):
    """Paired dataset for training the reconstruction model.

    Items are the file names present in all three directories. Each item is a
    ``(cloudy, mask, clean)`` triple of tensors produced by
    ``transforms.ToTensor`` after resizing to ``size``: ``cloudy`` and ``clean``
    are RGB, ``mask`` is single-channel ("L").

    Args:
        cloudy_dir: directory with the cloud-covered input images.
        mask_dir: directory with the masks.
        clean_dir: directory with the target images.
        size: size passed to PIL ``Image.resize``, i.e. ``(width, height)``.
            NOTE: torchvision ``Resize`` (used by CloudTestDataset) takes
            ``(height, width)``; the two only agree for square sizes.

    Raises:
        RuntimeError: if no regular file name is shared by the three folders.
    """

    def __init__(self, cloudy_dir, mask_dir, clean_dir, size=(256,256)):
        self.cloudy_dir = cloudy_dir
        self.mask_dir = mask_dir
        self.clean_dir = clean_dir
        self.size = size

        # File names common to the three folders. Sub-directories are skipped
        # because Image.open cannot read them.
        common = (
            set(os.listdir(cloudy_dir))
            & set(os.listdir(mask_dir))
            & set(os.listdir(clean_dir))
        )
        self.files = sorted(
            f for f in common if os.path.isfile(os.path.join(cloudy_dir, f))
        )
        if not self.files:
            raise RuntimeError("No hay archivos válidos en las tres carpetas")

        self.tf = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]

        # ``with`` releases each file handle once the pixels have been copied.
        with Image.open(os.path.join(self.cloudy_dir, fname)) as im:
            cloudy_img = im.convert("RGB").resize(self.size)
        with Image.open(os.path.join(self.mask_dir, fname)) as im:
            # Nearest-neighbour keeps mask values taken from source pixels; PIL's
            # default filter for "L" images interpolates (blends) across edges.
            mask_img = im.convert("L").resize(self.size, Image.NEAREST)
        with Image.open(os.path.join(self.clean_dir, fname)) as im:
            clean_img = im.convert("RGB").resize(self.size)

        return self.tf(cloudy_img), self.tf(mask_img), self.tf(clean_img)

# ================================
# Modelo de reconstrucción
# ================================
class TerrainReconstructor(nn.Module):
    """Encoder-decoder that maps an RGB image to an RGB reconstruction.

    Encoder: the first eight children of an ImageNet-pretrained torchvision
    ResNet-18 (conv1, bn1, relu, maxpool, layer1..layer4), which reduce the
    spatial size by 32 and output 512 channels. Decoder: five stride-2
    transposed convolutions (x32 upsampling, 512 -> 3 channels) and ``tanh``.
    For an input (N, 3, H, W) with H and W multiples of 32, the output is
    (N, 3, H, W) with values in [-1, 1].

    NOTE(review): the pretrained encoder expects ImageNet-normalized inputs and
    ``tanh`` spans [-1, 1], while CloudReconstructionDataset yields
    un-normalized ``ToTensor`` images in [0, 1] — confirm this is intended.
    """

    # conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4; the trailing
    # avgpool and fc children are not used.
    _N_ENC_LAYERS = 8

    def __init__(self):
        super().__init__()
        base = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Every child is kept (including the unused avgpool/fc) so the
        # state_dict keys ("enc.<i>...") stay unchanged.
        self.enc = nn.ModuleList(list(base.children()))
        self.up = self._make_decoder()

    @staticmethod
    def _make_decoder():
        """Return the decoder: 5 stride-2 transposed convs (512 -> 3) + tanh."""
        return nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode ``x`` with the ResNet-18 trunk and decode it to 3 channels."""
        # Apply the modules one by one: a sliced ModuleList is not callable.
        # maxpool (index 3) must be included so the encoder's 32x downsampling
        # matches the decoder's 32x upsampling.
        for layer in self.enc[:self._N_ENC_LAYERS]:
            x = layer(x)
        return self.up(x)

# ================================
# Pérdida de reconstrucción
# ================================
class ReconstructionLoss(nn.Module):
    """Mean absolute error (L1) between a generated image and its target."""

    def __init__(self):
        super().__init__()
        # Mean-reduced L1 criterion, exposed as the ``l1`` submodule.
        self.l1 = nn.L1Loss()

    def forward(self, generated, target):
        """Return the mean L1 distance between ``generated`` and ``target``."""
        criterion = self.l1
        loss = criterion(generated, target)
        return loss

# ================================
# Inicialización
# ================================
# Reconstruction model, moved to the selected device.
model_rec = TerrainReconstructor()
model_rec = model_rec.to(device)

# Adam with beta1 = 0.5 over all model parameters.
optimizer_rec = torch.optim.Adam(
    model_rec.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

# L1 reconstruction criterion.
criterion_rec = ReconstructionLoss()

# ================================
# Entrenamiento
# ================================
def train_reconstruction(recon_loader, epochs=10, *, model=None, optimizer=None,
                         criterion=None, device=None):
    """Train a reconstruction model to map cloudy images to clean ones.

    Each batch from ``recon_loader`` is a ``(cloudy, mask, clean)`` triple. The
    model is fed ``cloudy`` and ``criterion(outputs, clean)`` is minimized. The
    ``mask`` element is currently ignored.

    Args:
        recon_loader: iterable of ``(cloudy, mask, clean)`` tensor batches,
            e.g. a DataLoader over CloudReconstructionDataset.
        epochs: number of passes over ``recon_loader``.
        model: module to train; defaults to the global ``model_rec``.
        optimizer: optimizer for ``model``; defaults to ``optimizer_rec``.
        criterion: loss callable; defaults to ``criterion_rec``.
        device: where batches are moved; defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        list[float]: the mean per-batch loss of each epoch (also printed).
    """
    model = model_rec if model is None else model
    optimizer = optimizer_rec if optimizer is None else optimizer
    criterion = criterion_rec if criterion is None else criterion
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else "cpu"

    history = []
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for cloudy, _mask, clean in recon_loader:
            cloudy, clean = cloudy.to(device), clean.to(device)

            outputs = model(cloudy)
            loss = criterion(outputs, clean)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        mean_loss = total_loss / max(n_batches, 1)
        history.append(mean_loss)
        print(f"[Recon] Epoch {epoch+1} Loss: {mean_loss:.4f}")
    return history

# ================================
# Ejemplo de uso
# ================================
# Reconstruction dataset.
# NOTE(review): parameters are (cloudy_dir, mask_dir, clean_dir), so the folder
# named "overall-mask" is used as the cloudy input and "masked" as the mask —
# confirm this mapping is intended.
recon_dataset = CloudReconstructionDataset(
    cloudy_dir="overall-mask",
    mask_dir="masked",
    clean_dir="temporal",
)
recon_loader = DataLoader(recon_dataset, batch_size=4, shuffle=True)

# Training
train_reconstruction(recon_loader, epochs=5)

