In [None]:
!pip install -q segmentation-models-pytorch
!pip install -q torch
!pip install -q torchvision
!pip install -q albumentations
!pip install -q huggingface_hub
!pip install -q datasets
!pip install -q rasterio

In [None]:
from datasets import load_dataset

# Load the metadata for the high-quality subset from the Hugging Face Hub.
# This is an extremely fast operation as it does not download the image files.
cloudsen12_hq = load_dataset("csaybar/CloudSEN12-high")

# Print the dataset structure to confirm successful loading
print(cloudsen12_hq)

In [None]:
!pip install rasterio

In [None]:
!pip install segmentation_models_pytorch

In [18]:
import os
import torch
import rasterio
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from datasets import load_dataset
import torch.nn as nn
import torch.nn.functional as F

# --- Configuration ---
# This dictionary holds all the key parameters for our experiment.
CONFIG = {
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "BATCH_SIZE": 16,
    "NUM_WORKERS": 2,
    "EPOCHS": 25,
    "LEARNING_RATE": 1e-4,
    "WEIGHT_DECAY": 1e-2,
    # Using the 10-band set recommended for better cloud/snow/surface differentiation.
    "BANDS_TO_USE": ['B01', 'B02', 'B04', 'B05', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'],
    "ENCODER": "mobilenet_v2", # Efficient and powerful encoder
}

print(f"Using device: {CONFIG['DEVICE']}")

# --- Data Acquisition ---
# Load the dataset metadata from the Hugging Face Hub.
# The `name` parameter is crucial to select the correct data configuration.
print("Loading dataset metadata from Hugging Face Hub...")
cloudsen12_hq = load_dataset("ca-saybar/CloudSEN12-high", trust_remote_code=True)
print("Dataset metadata loaded successfully.")
print(cloudsen12_hq)


# --- Custom PyTorch Dataset for Hugging Face data ---
class CloudDataset(Dataset):
    """
    Custom PyTorch Dataset for the CloudSEN12 dataset loaded from Hugging Face.
    """
    def __init__(self, hf_dataset, bands, augmentations=None):
        """
        Args:
            hf_dataset (datasets.Dataset): The Hugging Face dataset object (e.g., cloudsen12_hq['train']).
            bands (list): List of Sentinel-2 band names to use.
            augmentations (A.Compose, optional): Albumentations pipeline.
        """
        self.dataset = hf_dataset
        self.bands = bands
        self.augmentations = augmentations

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # --- Load Image ---
        # Get the paths from the Hugging Face dataset item.
        s2_l1c_dir = item['s2_l1c_path']
        label_path = item['manual_hq_path']

        band_paths = [os.path.join(s2_l1c_dir, f'{band}.tif') for band in self.bands]
        
        image_stack = []
        for band_path in band_paths:
            with rasterio.open(band_path) as src:
                band_data = src.read(1, out_shape=(256, 256), resampling=rasterio.enums.Resampling.bilinear)
                image_stack.append(band_data)
        
        image = np.stack(image_stack, axis=-1).astype(np.float32)
        
        # --- Load and Process Mask ---
        with rasterio.open(label_path) as src:
            mask = src.read(1, out_shape=(256, 256), resampling=rasterio.enums.Resampling.nearest)
            
        # Class 2 (Thick Cloud) and 3 (Thin Cloud) are considered "cloud".
        binary_mask = np.where((mask == 2) | (mask == 3), 1, 0).astype(np.float32)
        
        # --- Preprocessing and Augmentation ---
        # Normalize by the standard value for Sentinel-2 L1C products
        image /= 10000.0
        
        if self.augmentations:
            augmented = self.augmentations(image=image, mask=binary_mask)
            image = augmented['image']
            mask = augmented['mask']
        
        # Ensure mask has a channel dimension for the loss function.
        # After ToTensorV2, the mask is (H, W). We need (1, H, W).
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
            
        return image, mask

# --- Augmentations ---
train_augs = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    ToTensorV2()
])

val_augs = A.Compose([
    ToTensorV2()
])

# --- Datasets and DataLoaders ---
print("Setting up datasets and dataloaders...")
train_dataset = CloudDataset(cloudsen12_hq['train'], bands=CONFIG['BANDS_TO_USE'], augmentations=train_augs)
val_dataset = CloudDataset(cloudsen12_hq['validation'], bands=CONFIG['BANDS_TO_USE'], augmentations=val_augs)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=CONFIG['NUM_WORKERS'], pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=CONFIG['NUM_WORKERS'], pin_memory=True)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")


# --- Loss Function ---
class DiceBCELoss(nn.Module):
    """Combined Dice and BCE loss for robust segmentation."""
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # inputs are raw logits from the model
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
        
        inputs_prob = torch.sigmoid(inputs)
        inputs_prob = inputs_prob.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs_prob * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs_prob.sum() + targets.sum() + smooth)
        
        return bce_loss + dice_loss

# --- Model, Loss, and Optimizer ---
print("Initializing model...")
model = smp.Unet(
    encoder_name=CONFIG['ENCODER'],
    encoder_weights="imagenet",
    in_channels=len(CONFIG['BANDS_TO_USE']),
    classes=1,
    activation=None # Output raw logits for use with BCEWithLogitsLoss
)
model.to(CONFIG['DEVICE'])

loss_fn = DiceBCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=CONFIG['WEIGHT_DECAY'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)

# --- Training Loop ---
best_val_iou = 0.0
history = {'train_loss': [], 'val_loss': [], 'val_iou': []}

for epoch in range(CONFIG['EPOCHS']):
    print(f"--- Epoch {epoch+1}/{CONFIG['EPOCHS']} ---")
    
    # Training Phase
    model.train()
    train_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training"):
        images, masks = images.to(CONFIG['DEVICE']), masks.to(CONFIG['DEVICE'])
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    avg_train_loss = train_loss / len(train_loader)
    history['train_loss'].append(avg_train_loss)
    
    # Validation Phase
    model.eval()
    val_loss = 0.0
    total_iou = 0.0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating"):
            images, masks = images.to(CONFIG['DEVICE']), masks.to(CONFIG['DEVICE'])
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            val_loss += loss.item()
            
            preds = torch.sigmoid(outputs) > 0.5
            tp, fp, fn, tn = smp.metrics.get_stats(preds.long(), masks.long(), mode='binary')
            iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro')
            total_iou += iou.item()
            
    avg_val_loss = val_loss / len(val_loader)
    avg_iou = total_iou / len(val_loader)
    history['val_loss'].append(avg_val_loss)
    history['val_iou'].append(avg_iou)
    
    scheduler.step(avg_val_loss)
    
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val IoU: {avg_iou:.4f}")
    
    if avg_iou > best_val_iou:
        best_val_iou = avg_iou
        torch.save(model.state_dict(), 'best_cloud_model.pth')
        print(f"Model saved with new best IoU: {best_val_iou:.4f}")

print("\n--- Training Finished ---")

# --- Visualization ---
def visualize_predictions(dataset, model, num_samples=5):
    """Plots the input image, true mask, and predicted mask."""
    if not os.path.exists('best_cloud_model.pth'):
        print("No model file found ('best_cloud_model.pth'). Skipping visualization.")
        return
        
    model.load_state_dict(torch.load('best_cloud_model.pth'))
    model.to(CONFIG['DEVICE'])
    model.eval()
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples * 5))
    fig.suptitle("Model Predictions vs. Ground Truth", fontsize=20)
    
    for i in range(num_samples):
        idx = np.random.randint(0, len(dataset))
        image, mask = dataset[idx] # Tensors are (C, H, W)
        
        with torch.no_grad():
            input_tensor = image.unsqueeze(0).to(CONFIG['DEVICE'])
            pred_logits = model(input_tensor)
            pred_mask = (torch.sigmoid(pred_logits) > 0.5).float().cpu().squeeze(0)
            
        # Create a false-color composite for visualization (SWIR-NIR-Red)
        # Band indices in our 10-band list: B11 (8), B08 (4), B04 (2)
        img_vis_bands = image.numpy()[[8, 4, 2], :, :]
        # Transpose from (C, H, W) to (H, W, C) for plotting
        img_vis = img_vis_bands.transpose(1, 2, 0)
        
        # Clip and scale for better visualization
        img_vis = np.clip(img_vis, 0, 0.3) / 0.3
        
        true_mask = mask.squeeze().numpy()
        
        axes[i, 0].imshow(img_vis)
        axes[i, 0].set_title(f"Input Image (False Color) #{idx}")
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(true_mask, cmap='gray')
        axes[i, 1].set_title("Ground Truth Mask")
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(pred_mask.squeeze(), cmap='gray')
        axes[i, 2].set_title("Predicted Mask")
        axes[i, 2].axis('off')
        
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# Visualize some results from the validation set
visualize_predictions(val_dataset, model)

Using device: cuda
Loading dataset metadata from Hugging Face Hub...


DatasetNotFoundError: Dataset 'ca-saybar/CloudSEN12-high' doesn't exist on the Hub or cannot be accessed.

In [14]:
print(cloudsen12_hq['train'].features)


{'index': Value(dtype='int64', id=None), 'annotator_name': Value(dtype='string', id=None), 'roi_id': Value(dtype='string', id=None), 's2_id_gee': Value(dtype='string', id=None), 's2_id': Value(dtype='string', id=None), 's2_date': Value(dtype='string', id=None), 's2_sen2cor_version': Value(dtype='string', id=None), 's2_fmask_version': Value(dtype='string', id=None), 's2_cloudless_version': Value(dtype='string', id=None), 's2_reflectance_conversion_correction': Value(dtype='float64', id=None), 's2_aot_retrieval_accuracy': Value(dtype='int64', id=None), 's2_water_vapour_retrieval_accuracy': Value(dtype='int64', id=None), 's2_view_off_nadir': Value(dtype='int64', id=None), 's2_view_sun_azimuth': Value(dtype='float64', id=None), 's2_view_sun_elevation': Value(dtype='float64', id=None), 's1_id': Value(dtype='string', id=None), 's1_date': Value(dtype='string', id=None), 's1_grd_post_processing_software_name': Value(dtype='string', id=None), 's1_grd_post_processing_software_version': Value(dty

In [20]:
import os
import shutil

# The standard cache directory for Hugging Face datasets
cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")

# Check if the directory exists and remove it completely
if os.path.exists(cache_dir):
    print(f"Found cache directory: {cache_dir}")
    print("Removing it now...")
    shutil.rmtree(cache_dir)
    print("Cache cleared successfully. ✅")
else:
    print("Cache directory not found, no action needed.")

Found cache directory: /root/.cache/huggingface/datasets
Removing it now...
Cache cleared successfully. ✅


In [21]:
from datasets import load_dataset

print("Attempting to load the dataset from the Hub...")

# This should now work by forcing a fresh download
cloudsen12_hq = load_dataset("ca-saybar/CloudSEN12-high", trust_remote_code=True)

print("Dataset loaded successfully! 🎉")
print(cloudsen12_hq)

Attempting to load the dataset from the Hub...


DatasetNotFoundError: Dataset 'ca-saybar/CloudSEN12-high' doesn't exist on the Hub or cannot be accessed.