In [None]:
# These imports otherwise only appear in a later cell; without them this cell
# fails with NameError on a fresh kernel (Restart & Run All).
import os
from tqdm import tqdm

filenames_x = []
filenames_y = []

base_path = "/kaggle/input/smile-cr/TrainData/TrainData"

# Collect cloudy (X) and clear (Y) rasters; tqdm shows progress over directories.
for dirname, _, filenames in tqdm(os.walk(base_path), desc="Walking through directories"):
    for filename in filenames:
        file_path = os.path.join(dirname, filename)

        if "CloudLandsat_2020" in dirname:
            filenames_x.append(file_path)
        elif "Landsat-8_2020" in dirname:
            filenames_y.append(file_path)

# os.walk yields files in filesystem-dependent order, but later cells pair X[i]
# with Y[i] by position. Sort both path lists identically (file name first, full
# path as tie-breaker) so the pairs line up.
filenames_x.sort(key=lambda p: (os.path.basename(p), p))
filenames_y.sort(key=lambda p: (os.path.basename(p), p))

# Bare file names, already in sorted order because the path lists are.
names_x = [os.path.basename(f) for f in filenames_x]
names_y = [os.path.basename(f) for f in filenames_y]

# Compare
print("Total in X:", len(names_x))
print("Total in Y:", len(names_y))

# Show mismatches if any
mismatches = set(names_x).symmetric_difference(set(names_y))
if mismatches:
    print("Mismatched filenames:")
    print(mismatches)
elif names_x != names_y:
    # Same set of names but different multiplicity (duplicate names in subfolders):
    # positional pairing would be off.
    print("⚠️ Same file names but different counts per name; positional X/Y pairing is unreliable.")
else:
    print("Filenames match perfectly!")

In [None]:
# Imported here as well so the cell does not depend on a later cell having run.
import os
from tqdm import tqdm

filenames_x_val = []
filenames_y_val = []

base_path = "/kaggle/input/smile-cr/ValData/ValData"

# Collect cloudy (X) and clear (Y) validation rasters; tqdm shows progress.
for dirname, _, filenames in tqdm(os.walk(base_path), desc="Walking through directories"):
    for filename in filenames:
        file_path = os.path.join(dirname, filename)
        if "CloudLandsat_2020" in dirname:
            filenames_x_val.append(file_path)
        elif "Landsat-8_2020" in dirname:
            filenames_y_val.append(file_path)

# os.walk order is filesystem-dependent, but X[i] is paired with Y[i] by position
# later on. Sort both lists the same way (file name, then full path).
filenames_x_val.sort(key=lambda p: (os.path.basename(p), p))
filenames_y_val.sort(key=lambda p: (os.path.basename(p), p))

# Bare file names, already sorted because the path lists are.
names_x = [os.path.basename(f) for f in filenames_x_val]
names_y = [os.path.basename(f) for f in filenames_y_val]

# Compare
print("Total in X:", len(names_x))
print("Total in Y:", len(names_y))

# Show mismatches if any
mismatches = set(names_x).symmetric_difference(set(names_y))
if mismatches:
    print("Mismatched filenames:")
    print(mismatches)
elif names_x != names_y:
    # Same names but different multiplicity: positional pairing would be off.
    print("⚠️ Same file names but different counts per name; positional X/Y pairing is unreliable.")
else:
    print("Filenames match perfectly!")

In [None]:
import tifffile as tiff
import os
from tqdm import tqdm

def filter_valid_images(x_list, y_list):
    """Keep only (x, y) pairs whose two TIFF files can both be read by tifffile.

    Pairs are formed positionally (``x_list[i]`` with ``y_list[i]``), so the two
    lists must already be aligned, e.g. sorted by file name.

    Args:
        x_list: Paths of the cloudy input images.
        y_list: Paths of the clear target images, index-aligned with ``x_list``.

    Returns:
        ``(valid_x, valid_y)``: the paths of the surviving pairs, still index-aligned.
    """
    n_pairs = min(len(x_list), len(y_list))
    if len(x_list) != len(y_list):
        # zip() stops at the shorter list; make that truncation visible instead of silent.
        print(f"⚠️ Listas de distinto tamaño (X={len(x_list)}, Y={len(y_list)}); "
              f"solo se verifican los primeros {n_pairs} pares.")

    valid_x = []
    valid_y = []
    for x_path, y_path in tqdm(zip(x_list, y_list), total=n_pairs, desc="Verificando archivos"):
        try:
            # Decode both images fully; the arrays are discarded, only success matters.
            tiff.imread(x_path)
            tiff.imread(y_path)
        except Exception as e:
            # Best-effort filter: report which pair failed and why, then skip it.
            print(f"⚠️ Saltando archivo corrupto o incompatible: "
                  f"{os.path.basename(x_path)} / {os.path.basename(y_path)} "
                  f"({type(e).__name__}: {e})")
            continue
        valid_x.append(x_path)
        valid_y.append(y_path)

    return valid_x, valid_y

# Drop unreadable pairs from the training set
filenames_x, filenames_y = filter_valid_images(filenames_x, filenames_y)
# ... and from the validation set
filenames_x_val, filenames_y_val = filter_valid_images(filenames_x_val, filenames_y_val)

# Report both splits; the validation count was previously never shown.
print(f"✅ Archivos válidos para entrenamiento: {len(filenames_x)}")
print(f"✅ Archivos válidos para validación: {len(filenames_x_val)}")

In [None]:
import torch
import numpy as np
import rasterio
import cv2
import albumentations as A
from torch.utils.data import Dataset
import warnings

class CloudRemovalDataset(Dataset):
    """Paired (cloudy, clear) Landsat dataset for supervised cloud removal.

    Each item is ``(x, y)``: float32 tensors shaped ``(3, H, W)`` built from bands
    3, 2, 1 of ``x_paths[idx]`` (cloudy) and ``y_paths[idx]`` (clear), resized to
    ``size`` and scaled to nominally ``[-1, 1]``. A pair that cannot be read or
    contains NaNs is replaced by the next readable pair (wrapping around).

    Args:
        x_paths: Paths of the cloudy input rasters.
        y_paths: Paths of the clear target rasters, index-aligned with ``x_paths``.
        size: Output size passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.
        augment: If True, apply the same random flips / 90° rotation /
            brightness-contrast change to both images of a pair.
    """

    def __init__(self, x_paths, y_paths, size=(1024, 1024), augment=False):
        self.x_paths = x_paths
        self.y_paths = y_paths
        self.size = size
        self.augment = augment

        # Augmentations from the reference paper. The clear target is registered as
        # a second *image* under a non-reserved key (overriding the built-in "mask"
        # target was fragile). That way it receives the same geometric and
        # photometric parameters as the input.
        self.transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
        ], additional_targets={'target_image': 'image'})

    def _load_tiff(self, path):
        """Load one raster as an (H, W, 3) float32 array in nominally [-1, 1].

        Returns None (after a warning for read errors) when the file cannot be read
        or contains NaNs.
        """
        try:
            with rasterio.open(path) as src:
                # Bands 3, 2, 1 as in the paper (RGB for Landsat); rasterio returns (bands, H, W).
                image = src.read([3, 2, 1]).astype(np.float32)

                # To (H, W, bands) for OpenCV / Albumentations.
                image = np.transpose(image, (1, 2, 0))

                # NaN samples are rejected, as in the paper.
                if np.isnan(image).any():
                    return None

                # cv2 interprets self.size as (width, height).
                image = cv2.resize(image, self.size, interpolation=cv2.INTER_LINEAR)

                # Scale to [0, 1]. NOTE(review): the divisor is chosen per image from
                # its own max, so two tiles (even x and y of one pair) may be scaled
                # differently; using the raster dtype would be more robust — confirm
                # the data format.
                if image.max() > 1.0:
                    max_val = 65535.0 if image.max() > 255 else 255.0
                    image = image / max_val

                # [-1, 1] as in the paper.
                image = (image * 2.0) - 1.0
                return image

        except Exception as e:
            warnings.warn(f"Error leyendo {path}: {e}")
            return None

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        n = len(self)
        j = idx  # first attempt uses idx as given, so out-of-range idx still raises IndexError
        # Walk forward (wrapping) until a readable pair is found. This bounded loop
        # replaces unbounded recursion, which could overflow the stack on long runs of
        # bad files and never terminated cleanly when no pair was valid.
        for _ in range(max(n, 1)):
            x_img = self._load_tiff(self.x_paths[j])
            y_img = self._load_tiff(self.y_paths[j])
            if x_img is not None and y_img is not None:
                break
            j = (j + 1) % n
        else:
            raise RuntimeError("CloudRemovalDataset: no readable (x, y) pair found")

        if self.augment:
            # Albumentations' photometric ops clip float images to [0, 1], which would
            # wipe out the negative half of the [-1, 1] data. Augment in [0, 1] and
            # map back to [-1, 1] afterwards.
            out = self.transform(image=(x_img + 1.0) / 2.0,
                                 target_image=(y_img + 1.0) / 2.0)
            x_img = out["image"] * 2.0 - 1.0
            y_img = out["target_image"] * 2.0 - 1.0

        # To tensors shaped (C, H, W); contiguity guards against negative-stride views.
        x_tensor = torch.from_numpy(np.ascontiguousarray(x_img)).permute(2, 0, 1)
        y_tensor = torch.from_numpy(np.ascontiguousarray(y_img)).permute(2, 0, 1)

        return x_tensor, y_tensor

In [None]:
import torch.nn as nn

class UNet(nn.Module):
    """Three-level U-Net that maps an image to an image of the same spatial size.

    Encoder: four double-conv stages (64, 128, 256, 512 channels) separated by 2x
    max-pooling. Decoder: three transposed-conv upsamplings, each concatenated with
    the matching encoder skip and refined by a double-conv. A 1x1 conv produces
    ``out_channels`` maps, squashed by ``tanh`` into (-1, 1).

    Input height and width must be divisible by 8 (three 2x poolings) for the
    skip concatenations to line up.
    """

    @staticmethod
    def _double_conv(in_c, out_c):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        block = UNet._double_conv

        # Submodules are created in a fixed order: it determines both the
        # state_dict key order and the sequence of random weight initialisations.
        self.enc1 = block(in_channels, 64)
        self.enc2 = block(64, 128)
        self.enc3 = block(128, 256)
        self.enc4 = block(256, 512)

        self.pool = nn.MaxPool2d(2)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = block(512, 256)   # 256 upsampled + 256 skip
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = block(256, 128)   # 128 upsampled + 128 skip
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = block(128, 64)    # 64 upsampled + 64 skip

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        down = self.pool
        skip1 = self.enc1(x)
        skip2 = self.enc2(down(skip1))
        skip3 = self.enc3(down(skip2))
        bottleneck = self.enc4(down(skip3))

        feat = self.dec3(torch.cat((self.up3(bottleneck), skip3), dim=1))
        feat = self.dec2(torch.cat((self.up2(feat), skip2), dim=1))
        feat = self.dec1(torch.cat((self.up1(feat), skip1), dim=1))

        return torch.tanh(self.final(feat))

In [None]:
import random

# DataLoader is not imported by any earlier cell (only Dataset is); without this
# line the cell fails with NameError on a fresh kernel.
from torch.utils.data import DataLoader

SEED = 42
NUM_EPOCHS = 10

# Seed Python, NumPy and torch RNGs so weight init and DataLoader shuffling are
# reproducible (augmentation libraries may use their own RNG).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_ds = CloudRemovalDataset(filenames_x, filenames_y, size=(256, 256), augment=True)
val_ds = CloudRemovalDataset(filenames_x_val, filenames_y_val, size=(256, 256), augment=False)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=4)

# Model, optimiser and loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()  # MAE loss

# Simplified training loop with a per-epoch validation pass
for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= max(len(train_loader), 1)

    # val_loader was previously built but never used; report validation L1 each epoch.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(model(x), y).item()
    val_loss /= max(len(val_loader), 1)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

In [None]:
import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import matplotlib.pyplot as plt

def denormalize(tensor):
    """Map values from the [-1, 1] model/data range to [0, 1] for metrics and plots.

    Element-wise and linear: works on tensors, arrays or plain numbers, and does
    not clip values that fall outside [-1, 1].
    """
    shifted = tensor + 1.0
    return shifted / 2.0

def _show_example(x_img, pred_img, target_img):
    """Plot input / prediction / ground truth side by side; each is a (C, H, W) tensor."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (x_img, "Entrada (Con Nubes)"),
        (pred_img, "Predicción (Cloud Removed)"),
        (target_img, "Real (Limpia)"),
    )
    for ax, (img, title) in zip(axes, panels):
        # Clip for safety: matplotlib expects float RGB values in [0, 1].
        ax.imshow(img.cpu().permute(1, 2, 0).numpy().clip(0, 1))
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)


def evaluate_and_visualize(model, val_loader, device, num_samples=3):
    """Score ``model`` on ``val_loader`` (PSNR, SSIM, MAE) and plot a few examples.

    Loader inputs/targets and model outputs are expected in [-1, 1]. They are mapped
    to [0, 1] with ``denormalize`` before scoring, hence ``data_range=1.0``.

    Metrics are accumulated over the whole loader instead of averaging per-batch
    values, so a smaller last batch no longer weighs as much as a full one:
      * PSNR: mean of per-image PSNR (``dim=(1, 2, 3)``).
      * SSIM: mean of per-image SSIM (torchmetrics default reduction).
      * MAE: mean absolute error over every pixel and channel.

    Args:
        model: Network mapping a loader input batch to an image batch.
        val_loader: Iterable of ``(x, y)`` batches.
        device: Device the model and metrics run on.
        num_samples: Number of batches whose first image is plotted.

    Returns:
        ``{"psnr": float, "ssim": float, "mae": float}``, or None if the loader is empty.
    """
    psnr_metric = PeakSignalNoiseRatio(data_range=1.0, dim=(1, 2, 3)).to(device)
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    abs_err_sum = 0.0
    n_elems = 0
    n_batches = 0
    samples_shown = 0

    model.eval()
    print("Iniciando evaluación...")

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)

            output = model(x)

            # Back to [0, 1] for metrics and display.
            output_dn = denormalize(output)
            y_dn = denormalize(y)

            psnr_metric.update(output_dn, y_dn)
            ssim_metric.update(output_dn, y_dn)
            abs_err_sum += torch.abs(output_dn - y_dn).sum().item()
            n_elems += y_dn.numel()
            n_batches += 1

            if samples_shown < num_samples:
                _show_example(denormalize(x)[0], output_dn[0], y_dn[0])
                samples_shown += 1

    if n_batches == 0:
        # Previously a ZeroDivisionError.
        print("⚠️ val_loader está vacío; no hay métricas que calcular.")
        return None

    avg_psnr = psnr_metric.compute().item()
    avg_ssim = ssim_metric.compute().item()
    avg_mae = abs_err_sum / n_elems

    print(f"\n--- Resultados de la Evaluación ---")
    print(f"PSNR Medio: {avg_psnr:.2f} dB")
    print(f"SSIM Medio: {avg_ssim:.4f}")
    print(f"MAE Medio: {avg_mae:.4f}")
    return {"psnr": avg_psnr, "ssim": avg_ssim, "mae": avg_mae}

# Run the evaluation (assigned so the returned dict is not echoed as cell output)
val_metrics = evaluate_and_visualize(model, val_loader, device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler 


class CloudRemovalGuidedDataset(Dataset):
    """Cloud-removal dataset whose input carries a predicted cloud mask as a 4th channel.

    For each index, bands 3, 2, 1 of ``x_paths[idx]`` (cloudy) and ``y_paths[idx]``
    (clear) are resized to ``size`` and scaled to nominally [-1, 1]. ``seg_model``
    is run on the cloudy image, and its output is thresholded (sigmoid > 0.5) into a
    binary mask. Only the first mask channel is appended to the RGB input.

    Returns ``(input_4ch, target_3ch)``: float tensors shaped (4, H, W) and (3, H, W).
    Pairs containing NaNs are skipped in favour of the next pair (wrapping around),
    matching ``CloudRemovalDataset``.

    Because ``seg_model`` runs inside ``__getitem__`` on ``device``, use a DataLoader
    with ``num_workers=0`` when ``device`` is CUDA.

    Args:
        x_paths: Cloudy raster paths.
        y_paths: Clear raster paths, index-aligned with ``x_paths``.
        seg_model: Module mapping a (1, 3, H, W) tensor to mask logits.
        size: ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``.
        device: Device on which ``seg_model`` is run. Defaults to the device of
            ``seg_model``'s first parameter, or CPU if it has none.
    """

    def __init__(self, x_paths, y_paths, seg_model, size=(256, 256), device=None):
        self.x_paths = x_paths
        self.y_paths = y_paths
        self.seg_model = seg_model
        self.size = size
        if device is None:
            # Previously the notebook-global `device` was read implicitly; derive it
            # from the model so the class carries no hidden global dependency.
            first_param = next(iter(seg_model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        self.device = torch.device(device)
        self.seg_model.eval()

    def _load_tiff(self, path):
        """Read bands 3, 2, 1 as an (H, W, 3) float32 array resized and scaled to [-1, 1]."""
        with rasterio.open(path) as src:
            img = src.read([3, 2, 1]).astype(np.float32)
        img = np.transpose(img, (1, 2, 0))
        img = cv2.resize(img, self.size)
        # NOTE(review): divisor picked per image from its own max — confirm data format.
        if img.max() > 1.0:
            max_val = 65535.0 if img.max() > 255 else 255.0
            img /= max_val
        return (img * 2.0) - 1.0

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        n = len(self)
        j = idx  # first attempt uses idx as given, so out-of-range idx raises IndexError
        for _ in range(max(n, 1)):
            x_img = self._load_tiff(self.x_paths[j])
            y_img = self._load_tiff(self.y_paths[j])
            # A NaN here would propagate straight into the loss; skip such pairs.
            if not (np.isnan(x_img).any() or np.isnan(y_img).any()):
                break
            j = (j + 1) % n
        else:
            raise RuntimeError("CloudRemovalGuidedDataset: no NaN-free (x, y) pair found")

        x_rgb = torch.from_numpy(x_img).permute(2, 0, 1)

        with torch.no_grad():
            logits = self.seg_model(x_rgb.unsqueeze(0).to(self.device))
        mask = (torch.sigmoid(logits) > 0.5).float().squeeze(0).cpu()
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        if mask.shape[0] != 1:
            # The old code hid this with cat(...)[:4]; keep using channel 0 but say so.
            warnings.warn(
                f"seg_model produced {mask.shape[0]} mask channels; only channel 0 is used. "
                "Is seg_model a 1-channel cloud-segmentation network?"
            )
        mask = mask[:1]

        input_4ch = torch.cat([x_rgb, mask], dim=0)
        target_3ch = torch.from_numpy(y_img).permute(2, 0, 1)

        return input_4ch, target_3ch


# DataLoader is never imported by an earlier cell; needed on a fresh kernel.
from torch.utils.data import DataLoader

CR_EPOCHS = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed precision only on CUDA; on CPU autocast/GradScaler('cuda') just warn and disable.
use_amp = device.type == "cuda"

# NOTE(review): `UNetCR` is not defined in any cell visible here — make sure the
# cell that defines it runs before this one.
model_cr = UNetCR(n_channels=4, n_classes=3).to(device)
optimizer = torch.optim.Adam(model_cr.parameters(), lr=1e-4)
criterion = nn.L1Loss()
scaler = GradScaler('cuda', enabled=use_amp)

# NOTE(review): in the cells above, `model` is the 3-channel cloud-removal UNet
# (tanh output), not a cloud-segmentation network, so the "mask" derived from it
# is really a threshold on its first output channel. Swap in a 1-channel cloud
# segmentation model here.
train_ds = CloudRemovalGuidedDataset(filenames_x, filenames_y, seg_model=model)
# num_workers must stay 0: the dataset runs seg_model on `device` inside __getitem__.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, drop_last=True, num_workers=0)

print(f"Iniciando entrenamiento... Target input shape: [Batch, 4, {train_ds.size[0]}, {train_ds.size[1]}]")

for epoch in range(CR_EPOCHS):
    model_cr.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CR_EPOCHS}")

    for step, (x, y) in enumerate(pbar, start=1):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        with autocast(device_type=device.type, enabled=use_amp):
            output = model_cr(x)
            mask_w = x[:, 3:4, :, :]  # binary mask channel appended by the dataset
            # Global L1 plus an extra 3x-weighted L1 on masked pixels (zeros elsewhere).
            loss = criterion(output, y) + (criterion(output * mask_w, y * mask_w) * 3.0)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        # Mean over batches seen so far (dividing by len(train_loader) made the shown
        # loss creep up from ~0 during every epoch).
        pbar.set_postfix({'loss': running_loss / step})

    print(f"Epoch {epoch+1}/{CR_EPOCHS} - loss medio: {running_loss / max(len(train_loader), 1):.4f}")

print(" Entrenamiento completado.")