In [1]:
!pip install -q rasterio  # -q is for quiet installation, non-verbose

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m68.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

## Imports

In [2]:
import glob
import os
import random

import h5py
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

## Parameters

In [3]:
# Paths and settings
BASE_PATH      = '/kaggle/input/sentinel2-crop-mapping'   # root folder of the raw tiles
H5_PATH        = '/kaggle/working/tiles_2016-18_uint8.h5' # packed uint8 archive (built once, then reused)
REGIONS        = ['lombardia', 'lombardia2']
YEARS          = ['data2016', 'data2017', 'data2018']
NUM_TIMESTEPS  = 32     # acquisitions kept per tile
BATCH_SIZE     = 64
NUM_WORKERS    = 4      # DataLoader worker processes
PIN_MEMORY     = True
LR             = 0.01   # SGD learning rate
NUM_EPOCHS     = 10
DEVICE         = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed every RNG the notebook relies on (Python, NumPy, torch)
# so weight initialisation and DataLoader shuffling repeat across runs.
SEED           = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Dataset Preprocessing and Loading

In [None]:
# 1) Preprocess: pack all tiles into one HDF5
#
# The archive is written to a temporary file and renamed to H5_PATH only after
# the last tile has been stored. An interrupted run therefore never leaves a
# partial file at H5_PATH that the existence check below would mistake for a
# finished archive.
if not os.path.exists(H5_PATH):
    print(f"Creating HDF5 archive at {H5_PATH}...")

    # gather tile directories (sorted so archive indices are the same on every run)
    tile_dirs = []
    for reg in REGIONS:
        for yr in YEARS:
            tile_dirs += sorted(glob.glob(os.path.join(BASE_PATH, reg, yr, '*')))
    N = len(tile_dirs)
    if N == 0:
        raise FileNotFoundError(f"No tile directories found under {BASE_PATH}")

    n_channels = 9 * NUM_TIMESTEPS
    tmp_path   = H5_PATH + '.partial'

    try:
        with h5py.File(tmp_path, 'w') as hf:
            # one chunk per tile, so reading a single sample touches a single chunk
            imgs = hf.create_dataset(
                'images',
                shape=(N, n_channels, 48, 48),
                dtype='uint8',
                chunks=(1, n_channels, 48, 48),
                compression='lzf')
            masks = hf.create_dataset(
                'masks',
                shape=(N, 48, 48),
                dtype='uint8',
                chunks=(1, 48, 48),
                compression='lzf')

            write_idx = 0
            for td in tqdm(tile_dirs, desc='Writing HDF5'):
                if not os.path.isdir(td):
                    continue

                # a tile without a label raster cannot be used for training
                mask_path = os.path.join(td, 'y.tif')
                if not os.path.isfile(mask_path):
                    continue

                # list & sort the multispectral .tifs, keep the first NUM_TIMESTEPS
                tifs = sorted([
                    f for f in os.listdir(td)
                    if f.endswith('.tif') and '_MSAVI' not in f and f != 'y.tif'
                ])[:NUM_TIMESTEPS]

                # skip if too few files
                if len(tifs) < NUM_TIMESTEPS:
                    continue

                # load & stack
                stack = []
                for f in tifs:
                    with rasterio.open(os.path.join(td, f)) as src:
                        stack.append(src.read().astype(np.float32))
                # stacking on axis=1 gives (bands, timesteps, 48, 48); after the
                # reshape the layout is band-major:
                #   channel = band * NUM_TIMESTEPS + timestep
                img = np.stack(stack, axis=1).reshape(-1, 48, 48)

                # quantize to uint8 with a per-tile min-max stretch
                img_min, img_max = img.min(), img.max()
                span = img_max - img_min
                if span > 0:
                    img_q = ((img - img_min) / span * 255.0).round().astype(np.uint8)
                else:
                    # constant (or NaN) tile: avoid 0/0 and store zeros instead
                    img_q = np.zeros(img.shape, dtype=np.uint8)

                # load mask; labels >= 20 are mapped to 255
                with rasterio.open(mask_path) as src:
                    mask = src.read(1).astype(np.uint8)
                mask[mask >= 20] = 255

                # write into HDF5
                imgs[write_idx]  = img_q
                masks[write_idx] = mask
                write_idx += 1

            # after the loop, shrink to the number of tiles actually written
            imgs.resize(write_idx, axis=0)
            masks.resize(write_idx, axis=0)
    except BaseException:
        # includes KeyboardInterrupt: remove the incomplete file, then re-raise
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # publish the finished archive in a single rename
    os.replace(tmp_path, H5_PATH)
    print("HDF5 archive created.")

Creating HDF5 archive at /kaggle/working/tiles_2016-18_uint8.h5...


Writing HDF5:  16%|█▋        | 3281/19969 [20:05<1:35:49,  2.90it/s]

In [None]:
# 2) Dataset: lazy-load from HDF5 and rescale
class HDF5SentinelDataset(Dataset):
    """Map-style dataset over the 'images' / 'masks' datasets of an HDF5 archive.

    The HDF5 file is NOT opened in ``__init__``. It is opened on first access
    and re-opened whenever the current process differs from the one that opened
    it, so each DataLoader worker ends up with its own h5py handle instead of
    sharing one inherited from the parent process.

    Args:
        h5_path:   path of the HDF5 archive.
        indices:   sequence of row indices (into the archive) that belong to
                   this split.
        transform: optional callable ``(img, msk) -> (img, msk)`` applied to
                   every sample.
    """

    def __init__(self, h5_path, indices, transform=None):
        self.h5_path   = h5_path
        self.indices   = indices
        self.transform = transform
        # file handle and dataset views; filled in by _ensure_open()
        self.hf    = None
        self.imgs  = None
        self.masks = None
        self._pid  = None   # pid of the process that opened self.hf

    def _ensure_open(self):
        """Open the archive for the current process if it is not open yet."""
        pid = os.getpid()
        if self.hf is None or self._pid != pid:
            self.hf    = h5py.File(self.h5_path, 'r')
            self.imgs  = self.hf['images']
            self.masks = self.hf['masks']
            self._pid  = pid

    def close(self):
        """Close the archive if this process opened it; it re-opens on next access."""
        if self.hf is not None and self._pid == os.getpid():
            self.hf.close()
        self.hf = self.imgs = self.masks = None
        self._pid = None

    def __getstate__(self):
        # never pickle an open handle (e.g. when workers are spawned)
        state = self.__dict__.copy()
        state['hf'] = state['imgs'] = state['masks'] = None
        state['_pid'] = None
        return state

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(img, msk)``: img as float in [0, 1], msk as long."""
        self._ensure_open()
        i   = self.indices[idx]
        img = torch.from_numpy(self.imgs[i]).float() / 255.0   # stored as uint8
        msk = torch.from_numpy(self.masks[i]).long()
        if self.transform:
            img, msk = self.transform(img, msk)
        return img, msk

In [None]:
# 3) Prepare train/val splits and dataloaders
# Only the number of stored tiles is needed here; the file is closed again right away.
with h5py.File(H5_PATH, 'r') as hf:
    total = hf['images'].shape[0]
all_indices = list(range(total))
# fixed random_state keeps the 80/20 split identical across runs
train_idx, val_idx = train_test_split(all_indices, test_size=0.2, random_state=42)

# `train_dataset` / `val_dataset` are the names the visualisation cells use.
train_dataset = HDF5SentinelDataset(H5_PATH, train_idx)
val_dataset   = HDF5SentinelDataset(H5_PATH, val_idx)
# short aliases kept for any code that still uses the previous names
train_ds, val_ds = train_dataset, val_dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

-> Each sample's input has shape (9*32, 48, 48) and its target mask has shape (48, 48); the loaders group samples into batches of `BATCH_SIZE` (64).

In [None]:
# Visualizing the images (rgb composite for multispectral input)
def normalize_band(band):
    """Contrast-stretch a band to [0, 1] between its 2nd and 98th percentiles.

    A band whose 2nd and 98th percentiles coincide (e.g. a constant band) is
    returned as all zeros instead of dividing by zero.
    """
    p2, p98 = np.percentile(band, (2, 98))
    if p98 <= p2:
        return np.zeros_like(band, dtype=float)
    return np.clip((band - p2) / (p98 - p2), 0, 1)


def channel_index(band, timestep=0):
    """Channel holding `band` at `timestep`.

    The preprocessing cell stacks the acquisitions on axis=1 before flattening,
    so channels are band-major: channel = band * NUM_TIMESTEPS + timestep.
    """
    return band * NUM_TIMESTEPS + timestep


# 0-based position of each band inside one acquisition
# (band order as assumed by the original notebook - TODO confirm against the dataset)
b2 = 1   # Blue
b3 = 2   # Green
b4 = 3   # Red


def show_sample(dataset, tag, header, label, idx=0, timestep=0):
    """Print shape/dtype/label info for one sample and plot its RGB composite and mask.

    Args:
        dataset:  dataset returning ``(x, y)`` tensors; x is indexed along dim 0
                  by channel.
        tag:      suffix used in the printed variable names ('train' / 'val').
        header:   line printed before the sample info.
        label:    split name used in the plot titles ('Train' / 'Val').
        idx:      index of the sample inside `dataset`.
        timestep: 0-based acquisition used for the RGB composite.

    Returns:
        The ``(x, y)`` pair that was displayed.
    """
    x, y = dataset[idx]
    print(header)
    print(f"x_{tag} shape: {x.shape}")  # Expected: (288, 48, 48)
    print(f"y_{tag} shape: {y.shape}")  # Expected: (48, 48)
    print(f"x_{tag} dtype: {x.dtype}")
    print(f"y_{tag} unique labels: {torch.unique(y)}")  # Sanity check

    # Bands 4, 3, 2 (R, G, B) of the requested timestep -> shape (3, 48, 48)
    rgb = x[[channel_index(band, timestep) for band in (b4, b3, b2)]]
    rgb_normalized = np.stack([normalize_band(c.numpy()) for c in rgb], axis=-1)

    # Show contrast-stretched image
    plt.imshow(rgb_normalized)
    plt.title(f"RGB Composite (Timestep {timestep + 1}, Normalized) - {label} Sample")
    plt.axis("off")
    plt.show()

    # Show mask
    plt.imshow(y, cmap='tab20')
    plt.title(f"Ground Truth Mask - {label} Sample")
    plt.colorbar()
    plt.show()
    return x, y


# Print from train dataset
x_train, y_train = show_sample(train_dataset, 'train', "Train Sample:", "Train")

# Print from val dataset
x_val, y_val = show_sample(val_dataset, 'val', "\nValidation Sample:", "Val")

## 2D-CNN Model 

In [None]:
in_channels = 9 * NUM_TIMESTEPS
num_classes = 20

# LeNet-5-style fully convolutional network for per-pixel classification.
# Shape comments are (channels, height, width) for a 48x48 input; the 5x5
# convolutions use padding=2, so they keep the spatial size unchanged.
lenet_5 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels, 6, kernel_size=5, padding=2),  # C1: 5x5 conv, 6 filters  -> (6, 48, 48)
    torch.nn.Tanh(),
    torch.nn.AvgPool2d(kernel_size=2, stride=2),                # S1: 2x2 avg pooling      -> (6, 24, 24)

    torch.nn.Conv2d(6, 16, kernel_size=5, padding=2),           # C2: 5x5 conv, 16 filters -> (16, 24, 24)
    torch.nn.Tanh(),
    torch.nn.AvgPool2d(kernel_size=2, stride=2),                # S2: 2x2 avg pooling      -> (16, 12, 12)

    torch.nn.Conv2d(16, num_classes, kernel_size=1),            # 1x1 conv: per-pixel logits -> (num_classes, 12, 12)
    torch.nn.Upsample(size=(48, 48), mode='bilinear', align_corners=False)  # back to (num_classes, 48, 48)
).to(DEVICE)

# The optimizer, training-loop and prediction cells refer to the network as `model`.
model = lenet_5

print(lenet_5)

-> The fully connected (FC) layers of LeNet-5 are removed and replaced by a 1x1 Conv2D layer: FC layers flatten the 2D feature map into a 1D vector, whereas segmentation needs a 2D map of per-pixel class scores.

In [None]:
# Cross-entropy over pixels; targets equal to 255 (the value the preprocessing
# cell assigns to labels >= 20) do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

# Plain SGD over every parameter of the network.
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR)

In [None]:
@torch.no_grad()
def pixel_accuracy(out, target, ignore_index=255):
    """
    Pixel-wise accuracy: fraction of pixels whose predicted class (argmax over
    the channel dimension) equals the ground-truth label.

    Pixels labelled `ignore_index` are excluded from both numerator and
    denominator, matching `CrossEntropyLoss(ignore_index=255)`. Without this
    they would always count as wrong, because argmax can never predict 255.

    out:          tensor of shape (B, C, H, W) - raw logits
    target:       tensor of shape (B, H, W) - class labels
    ignore_index: label to exclude; pass None to score every pixel

    Returns a Python float in [0, 1]; 0.0 if no pixel is left to score.
    """
    preds = out.argmax(dim=1)
    correct = preds == target
    if ignore_index is None:
        return correct.float().mean().item()

    valid = target != ignore_index
    n_valid = valid.sum().item()
    if n_valid == 0:
        # nothing to score; avoid 0/0
        return 0.0
    return (correct & valid).sum().item() / n_valid

## Training Loop

In [None]:
# 5) Training loop
# Per-epoch history (mean over batches), reset each time this cell runs.
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(1, NUM_EPOCHS+1):
    # ---- training pass ----
    model.train()
    t_loss, t_acc = 0.0, 0.0
    for x, y in tqdm(train_loader, desc=f'Epoch {epoch} Training'):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        t_loss += loss.item()
        t_acc  += pixel_accuracy(out, y)
    t_loss /= len(train_loader)
    t_acc  /= len(train_loader)

    train_losses.append(t_loss)
    train_accuracies.append(t_acc)

    # ---- validation pass ----
    # no_grad: no autograd graph is built for forward passes that are never
    # back-propagated, which saves memory and time
    model.eval()
    v_loss, v_acc = 0.0, 0.0
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc=f'Epoch {epoch} Validating'):
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            loss = criterion(out, y)
            v_loss += loss.item()
            v_acc  += pixel_accuracy(out, y)
    v_loss /= len(val_loader)
    v_acc  /= len(val_loader)

    val_losses.append(v_loss)
    val_accuracies.append(v_acc)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | "
          f"Train Loss: {t_loss:.4f}, Acc: {t_acc:.4f} | "
          f"Val Loss: {v_loss:.4f}, Acc: {v_acc:.4f}")


In [None]:
epochs = range(1, NUM_EPOCHS + 1)

# One entry per panel: (y-axis label, title, [(series, legend label, colour), ...])
panels = [
    ('Loss', 'Loss vs Epoch',
     [(train_losses, 'Train Loss', 'blue'),
      (val_losses,   'Val Loss',   'orange')]),
    ('Pixel Accuracy', 'Accuracy vs Epoch',
     [(train_accuracies, 'Train Acc', 'green'),
      (val_accuracies,   'Val Acc',   'red')]),
]

plt.figure(figsize=(12, 5))

# Left panel: loss curves, right panel: accuracy curves
for position, (y_label, panel_title, curves) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    for series, legend_label, colour in curves:
        plt.plot(epochs, series, label=legend_label, color=colour, linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

In [None]:
def visualize_prediction(model, dataset, idx=0, ignore_index=255):
    """Plot the ground-truth mask next to the model's predicted mask for one sample.

    Args:
        model:        network mapping a (1, C, H, W) batch to (1, num_classes, H, W)
                      logits; it is switched to eval mode.
        dataset:      dataset returning ``(x, y)`` with x of shape (C, H, W) and
                      y of shape (H, W).
        idx:          index of the sample inside `dataset`.
        ignore_index: label value hidden in the ground-truth panel; pass None
                      to draw every label.
    """
    model.eval()
    x, y = dataset[idx]
    x = x.unsqueeze(0)  # Add batch dim

    # run on whatever device the model lives on (the input stays put if the
    # model has no parameters)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    with torch.no_grad():
        pred = model(x)  # (1, num_classes, H, W)
        pred_mask = pred.argmax(dim=1).squeeze().cpu()  # (H, W)

    # Both panels share one colour scale, so a class has the same colour in each.
    num_classes = pred.shape[1]
    scale = dict(cmap='tab20', vmin=0, vmax=max(num_classes - 1, 1))

    truth = np.asarray(y)
    if ignore_index is not None:
        truth = np.ma.masked_equal(truth, ignore_index)  # masked pixels are left blank

    # Plot
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(truth, **scale)
    axs[0].set_title("Ground Truth")
    axs[1].imshow(pred_mask, **scale)
    axs[1].set_title("Predicted Mask")
    for ax in axs:
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Show ground truth vs. prediction for the first sample of each split (train, then val)
for split_dataset in (train_dataset, val_dataset):
    visualize_prediction(model, split_dataset, idx=0)