# Training Dataset Information

**Make sure all the notebooks in the "Get datasets" folder have completed before running this notebook.**

This notebook trains on the following datasets:

- CloudSEN12 high (train (L1C+L2A))
- CloudSEN12 scribble (train + val + test (L1C+L2A))
- CloudSEN12 2k (train + val + test (L1C+L2A))
- Kappaset (train + val + test (L1C))
- CloudSEN12 high (train (L1C)) super res 5m
- CloudSEN12 high from Planetary Computer (L2A)
- A custom hard negative dataset (L2A)

It validates on the CloudSEN12 high validation (L1C+L2A) dataset.

Each model takes about 6 hours to train on an RTX 4090 GPU.

In [None]:
import json
import random
import warnings
from collections import defaultdict
from pathlib import Path

import cv2
import numpy as np
import rasterio as rio
import timm
import torch
from fastai.vision.all import *  # type: ignore
from rasterio.enums import Resampling
from rasterio.errors import NotGeoreferencedWarning
from safetensors.torch import save_file

In [None]:
# Silence rasterio's NotGeoreferencedWarning for the rest of the notebook.
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

In [None]:
from augs import (
    BatchRot90,
    RandomRectangle,
    DynamicZScoreNormalize,
    SceneEdge,
    BatchTear,
    BatchResample,
    RandomClipLargeImages,
    RandomSharpenBlur,
    ClipHighAndLow,
    BatchFlip,
)

In [None]:
from utils import (
    DiceMultiStrip,
    CrossEntropyLossFlatImageTypeWeighted,
)

In [None]:
from helpers import plot_batch, show_histo, print_system_info

# Print system information using the project helper imported from helpers.
print_system_info()

In [None]:
# Root folder that contains every dataset directory used below.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
base_dataset_dir = Path("/media/nick/4TB Working 7/Datasets/OCM datasets")

In [None]:
# One folder per dataset, all located under base_dataset_dir.
# These Path objects are also used as the keys of label_weights, so an image's
# parent folder identifies which dataset (and weight) it belongs to.
cloudsen12_high_data_dir = base_dataset_dir / "CloudSEN12 high"
cloudsen12_scribble_dir = base_dataset_dir / "CloudSEN12 scribble"
cloudsen12_k2_dir = base_dataset_dir / "CloudSEN12 2k"
# Images in this folder form the validation split (see multi_dataset_getter).
cloudsen12_validation_dir = base_dataset_dir / "CloudSEN12 validation"
cloudsen12_high_planetary_computer_dir = (
    base_dataset_dir / "CloudSEN12 high planetary computer"
)
super_res_dir = base_dataset_dir / "CloudSEN12 high super res tiles"

kappaset_data_dir = base_dataset_dir / "Kappaset"
hard_negative_data_dir = base_dataset_dir / "Hard negative"


In [None]:
# timm architecture name used as the U-Net encoder (passed to timm.create_model).
# Alternative encoder, kept for quick switching:
# model_type = "regnety_004.pycls_in1k"
model_type = "edgenext_small.usi_in1k"


In [None]:
# Run identifier; it is embedded in every output file name.
model_version = "OCM_7.43_R_G_NIR_test"
#########################
# When True, images are converted to bfloat16 on load and the learner is
# switched with to_bf16().
use_bf16 = True
#########################
# When True, a shorter schedule and a cap on training images are used.
demo_mode = False
#########################
# Size (pixels per side) images are read at and the model is shape-checked with.
original_image_size = 509
# Bounds handed to the RandomClipLargeImages batch transform.
max_clip_image_clip_size = 400  # 509
min_clip_image_size = 256  # 509
# Band indexes passed to rasterio's read(); their count sets the input channels.
limited_band_read_list = [1, 2, 3]  # Red Green NIR
# Per-band scale relative to original_image_size; bands are grouped by
# int(original_image_size * scale) and first read at that size in open_509.
native_band_scales = [1, 1, 0.5]
#########################
# Value handed to the GradientAccumulation callback.
gradient_accumulation_batch_size = 128
batch_size = 10
# Passed to fine_tune as base_lr.
learning_rate = 0.001
#########################
# Per-dataset weights; sample_weights looks them up by an image's parent folder.
high_label_weight = 0.9
scribble_label_weight = 0.5
tiles_2k_label_weight = 0.5
kappaset_label_weight = 0.25
super_res_label_weight = 0.25
high_pc_label_weight = 0.9
hard_negative_label_weight = 0.9

In [None]:
# Maps each dataset folder to the weight attached to its samples.
# The keys double as the list of dataset folders to load, and the validation
# folder is included with a weight of 1.0.
label_weights = {
    cloudsen12_high_data_dir: high_label_weight,
    cloudsen12_scribble_dir: scribble_label_weight,
    cloudsen12_k2_dir: tiles_2k_label_weight,
    super_res_dir: super_res_label_weight,
    cloudsen12_high_planetary_computer_dir: high_pc_label_weight,
    kappaset_data_dir: kappaset_label_weight,
    hard_negative_data_dir: hard_negative_label_weight,
    cloudsen12_validation_dir: 1.0,
}

In [None]:
# Every dataset folder to load, in label_weights insertion order.
# Materialised as a list (not a live dict view) so it matches the
# `list[Path]` that multi_dataset_getter is annotated to receive.
dataset_dirs = list(label_weights)

In [None]:
# Fail fast, before any model is built, if a dataset folder is missing.
# An explicit raise is used instead of `assert` so the check still runs under
# `python -O`, and all missing folders are reported together.
missing_dataset_dirs = [
    dataset_dir for dataset_dir in dataset_dirs if not dataset_dir.exists()
]
if missing_dataset_dirs:
    raise FileNotFoundError(
        "Training data directory does not exist: "
        + ", ".join(str(missing_dir) for missing_dir in missing_dataset_dirs)
    )


In [None]:
# Schedule and dataset cap: demo mode trains for fewer epochs on at most
# 3000 training images; a full run uses every image (no cap).
freeze_epochs, unfrozen_epochs, limit_training_images = (
    (5, 5, 3000) if demo_mode else (15, 15, None)
)

In [None]:
# The model gets one input channel per band that is read from disk.
num_input_channels = len(limited_band_read_list)
print(f"Number of input channels: {num_input_channels}")

In [None]:
# Encoder factory: a pretrained timm backbone built for our channel count.
timm_model = partial(
    timm.create_model,
    model_type,
    pretrained=True,
    in_chans=num_input_channels,
)
# U-Net with 4 output classes on top of the timm encoder.
# img_size is derived from original_image_size (previously a hard-coded 509)
# so it cannot drift from the size the shape check and data loaders use.
model = create_unet_model(
    img_size=(original_image_size, original_image_size),
    arch=timm_model,
    n_out=4,
    pretrained=True,
    act_cls=torch.nn.Mish,
)

In [None]:
# Sanity check: one random image must come out as a 4-class map of the same
# spatial size. Gradients are not needed for this throwaway forward pass, so
# it runs under no_grad to avoid building an autograd graph.
dummy_input = torch.randn(
    1, num_input_channels, original_image_size, original_image_size
)
with torch.no_grad():
    dummy_output_shape = model(dummy_input).shape
assert dummy_output_shape == (
    1,
    4,
    original_image_size,
    original_image_size,
), "Model output shape mismatch"

In [None]:
# Names and locations of every artefact this run writes. All files live in
# ./models and share the stem of the pickled PyTorch model.
fai_model_name = f"PM_model_{model_version}_{model_type}_fai"
pytorch_model_name = f"PM_model_{model_version}_{model_type}_PT.pth"
pytorch_model_path = Path.cwd() / "models" / pytorch_model_name
state_path = pytorch_model_path.parent / f"{pytorch_model_path.stem}_state.pth"
safetensor_state_path = (
    pytorch_model_path.parent / f"{pytorch_model_path.stem}_state.safetensors"
)
config_path = pytorch_model_path.parent / f"{pytorch_model_path.stem}_config.json"

# Refuse to overwrite the outputs of an earlier run with the same version.
if pytorch_model_path.exists():
    raise ValueError("Model already exists", pytorch_model_name)
if state_path.exists():
    raise ValueError("State path already exists")
if safetensor_state_path.exists():
    raise ValueError("Safetensor state path already exists")
if config_path.exists():
    raise ValueError("Config path already exists")

# Create the output folder up front so the save cells at the end of a long
# training run cannot fail because the directory is missing.
pytorch_model_path.parent.mkdir(parents=True, exist_ok=True)

print(f"Fastai model {fai_model_name}")
print(f"PyTorch model {pytorch_model_name}")
print(f"State path: {state_path}")
print(f"Safetensor state path: {safetensor_state_path}")
print(f"Config path: {config_path}")

In [None]:
def multi_dataset_getter(paths: list[Path], print_counts: bool = False):
    """Collect the image files of all datasets into one list.

    Every file matching ``*image*.tif`` directly inside each folder is
    gathered. Files from ``cloudsen12_validation_dir`` become the validation
    images; files from every other folder are training images.

    When the module-level ``limit_training_images`` is set, a random sample
    of that many training images (without replacement) is kept. The sample
    size is capped at the number of images found, so a limit larger than the
    dataset no longer raises.

    Args:
        paths: Dataset folders to scan.
        print_counts: Print per-folder and total counts when True.

    Returns:
        Training image paths followed by validation image paths.
    """
    training_images = []
    validation_images = []
    for path in paths:
        images = list(path.glob("*image*.tif"))
        if path == cloudsen12_validation_dir:
            validation_images = images
            if print_counts:
                print(f"{path.name} found {len(validation_images)} validation images")
        else:
            if print_counts:
                print(f"{path.name} found {len(images)} images")
            training_images.extend(images)
    if print_counts:
        print(f"Found {len(training_images)} training images")

    if limit_training_images:
        # Shuffle and limit the training images. np.random.choice raises when
        # asked for more items than exist (replace=False), hence the cap.
        sample_size = min(limit_training_images, len(training_images))
        training_images = np.random.choice(
            training_images, sample_size, replace=False
        ).tolist()
        if print_counts:
            print(f"Limited training images to {len(training_images)}")

    datasets = training_images + validation_images
    if print_counts:
        print(f"Combined training and validation {len(datasets)} images")
    return datasets

In [None]:
# Display the dataset folders as a visual check.
dataset_dirs

In [None]:
# Gather all image paths once here and print the per-dataset counts.
train_and_val_images = multi_dataset_getter(list(dataset_dirs), print_counts=True)

In [None]:
# Set of validation image paths, used by is_validation_item for fast
# membership tests. Uses the same glob pattern as multi_dataset_getter.
validation_dataset = set(cloudsen12_validation_dir.glob("*image*.tif"))
len(validation_dataset)

In [None]:
def label_func(file_path):
    """Return the label file that belongs to an image file.

    The label name is the image name with ``image`` replaced by ``label`` and
    the ``_l1c`` / ``_l2a`` markers removed. The label is looked up in the
    same folder as the image.

    Raises:
        AssertionError: If the label file does not exist, or if the derived
            label path is the image path itself.
    """
    label_name = file_path.name.replace("image", "label")
    for product_marker in ("_l1c", "_l2a"):
        label_name = label_name.replace(product_marker, "")

    label_path = file_path.parent.joinpath(label_name)

    assert label_path.exists(), f"Label path does not exist: {label_path}"
    assert file_path != label_path, (
        f"File path and label path are the same: {file_path}"
    )

    return label_path

In [None]:
# Display the first three image paths as a visual check.
train_and_val_images[:3]

In [None]:
def open_2k(src: rio.DatasetReader, img_size: int) -> np.ndarray:
    """Read a 2k tile resampled down to ``img_size`` pixels per side.

    The 2k images are 2000 x 2000 pixels at 10m resolution. The bands listed
    in ``limited_band_read_list`` are read straight into the target shape,
    with the resampling method picked at random (bilinear or nearest).

    Returns:
        A float32 array of shape (bands, img_size, img_size).
    """
    candidate_methods = [Resampling.bilinear, Resampling.nearest]
    chosen_method = random.choice(candidate_methods)
    band_count = len(limited_band_read_list)
    pixels = src.read(
        limited_band_read_list,
        out_shape=(band_count, img_size, img_size),
        resampling=chosen_method,
    )
    return pixels.astype("float32")


# Group the bands by the pixel size they are first read at.
# Key: int(original_image_size * scale). Value: list of
# (position in the output array, band index) pairs for bands with that size.
scale_groups = defaultdict(list)
band_scale_pairs = zip(limited_band_read_list, native_band_scales, strict=True)
for out_position, (band_number, band_scale) in enumerate(band_scale_pairs):
    native_size = int(original_image_size * band_scale)
    scale_groups[native_size].append((out_position, band_number))


def open_509(src: rio.DatasetReader, img_size: int) -> np.ndarray:
    """Read a 509-pixel tile, handling each band group at its own read size.

    Bands are processed per entry of ``scale_groups``. A group whose read
    size equals ``img_size`` is read directly. Any other group is read at its
    own size with nearest resampling and then resized to ``img_size`` with
    OpenCV, using one interpolation mode (nearest or linear) picked at random
    per call.

    Returns:
        A float32 array of shape (bands, img_size, img_size), with bands in
        the order of ``limited_band_read_list``.
    """
    band_total = len(limited_band_read_list)
    output = np.empty((band_total, img_size, img_size), dtype=np.float32)
    # Drawn once per image so all resized bands share the same interpolation.
    cv_interpolation = random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR])

    for read_size, members in scale_groups.items():
        out_positions, band_numbers = zip(*members, strict=True)

        if read_size == img_size:
            # Already at the target size: read and copy straight in.
            direct = src.read(
                band_numbers,
                out_shape=(len(band_numbers), img_size, img_size),
            )
            output[np.array(out_positions)] = direct.astype(np.float32)
            continue

        # Read at the group's own size, then resize each band to img_size.
        coarse = src.read(
            band_numbers,
            out_shape=(len(band_numbers), read_size, read_size),
            resampling=Resampling.nearest,
        )
        for slot, out_position in enumerate(out_positions):
            resized_band = cv2.resize(
                coarse[slot],
                (img_size, img_size),
                interpolation=cv_interpolation,
            )
            output[out_position] = resized_band.astype(np.float32)

    return output

In [None]:
def open_img(
    img_path: Path,
    img_size: int,
    use_bf16: bool = False,
) -> TensorImage:
    """Load an image file as a ``TensorImage`` of ``img_size`` pixels per side.

    The reader is chosen from the file's width: 2000-pixel files go through
    ``open_2k`` and 509-pixel files through ``open_509``.

    Args:
        img_path: Path of the raster file to open.
        img_size: Target size in pixels per side.
        use_bf16: Convert the tensor to bfloat16 when True.

    Raises:
        ValueError: If the file width is neither 2000 nor 509.
    """
    readers_by_width = {2000: open_2k, 509: open_509}

    with rio.open(img_path) as src:
        width = src.profile["width"]
        reader = readers_by_width.get(width)
        if reader is None:
            raise ValueError(
                f"Unsupported image width: {width}. Expected 2000 or 509."
            )
        pixels = reader(src, img_size)

    tensor = torch.from_numpy(pixels)
    return TensorImage(tensor.bfloat16() if use_bf16 else tensor)

In [None]:
def sample_weights(image_path: Path) -> torch.Tensor:
    """Return the weight of an image's dataset as a float32 scalar tensor.

    The weight is looked up in ``label_weights`` by the image's parent folder.

    Raises:
        ValueError: If the parent folder is not a key of ``label_weights``.
    """
    # Only the lookup is guarded, so unrelated failures (e.g. while building
    # the tensor) are not mislabelled as a missing dictionary entry.
    try:
        weight_value = label_weights[image_path.parent]
    except KeyError as e:
        raise ValueError(
            f"Image path {image_path} not found in label_weights dictionary."
        ) from e
    return torch.tensor(weight_value, dtype=torch.float32)

In [None]:
# Batch-level augmentations handed to the DataBlock as batch_tfms.
# All except IntToFloatTensor come from the project's augs module.
# NOTE(review): fastai may reorder transforms by their `order` attribute, so
# the list order here is not necessarily the execution order — confirm.
batch_tfms = [
    RandomRectangle(  # Blocks out random rectangles in the image
        p=0.6,
        sl=0.1,
        sh=0.5,
    ),
    BatchTear(0.1),  # Simulates an image tear
    SceneEdge(p=0.1),  # Adds a scene edge to the image
    IntToFloatTensor(1, 1),
    BatchRot90(),  # Rotates the image by 90 degrees
    DynamicZScoreNormalize(),  # Normalizes the image using dynamic z-score normalization
    BatchResample(
        max_scale=1.111, min_scale=0.07, plateau_min=0.33, plateau_max=1.0
    ),  # Resamples the image to a random scale
    RandomClipLargeImages(  # Clips large images to a random size
        max_size=max_clip_image_clip_size, min_size=min_clip_image_size
    ),
    BatchFlip(),  # Flips the image horizontally or vertically
    RandomSharpenBlur(min_factor=0.5, max_factor=1.5),  # Sharpens or blurs the image
    ClipHighAndLow(
        p=0.1, max_pct=0.05
    ),  # Simulates sensor saturation by clipping high and low values
]

In [None]:
# Image loader with the size and precision fixed, so the DataBlock can call
# it with just a path.
open_image_func = partial(open_img, img_size=original_image_size, use_bf16=use_bf16)

In [None]:
def is_validation_item(item: Path) -> bool:
    """Return True when ``item`` is one of the validation image paths.

    Used by the DataBlock's FuncSplitter to split training from validation.
    """
    return item in validation_dataset

In [None]:
# Each item yields three things: the image (the single model input, n_inp=1),
# the label mask, and a per-sample weight.
# get_y supplies the two targets: label_func maps the image path to its label
# path, and the identity lambda passes the image path on to sample_weights.
dblock = DataBlock(
    blocks=[
        TransformBlock([open_image_func]),
        MaskBlock(codes=[0, 1, 2, 3]),
        TransformBlock([sample_weights]),
    ],
    n_inp=1,
    get_items=multi_dataset_getter,
    get_y=[label_func, lambda x: x],
    # Items found in validation_dataset form the validation split.
    splitter=FuncSplitter(is_validation_item),
    batch_tfms=batch_tfms,
    item_tfms=[
        Resize(original_image_size, method="squish")
    ],  # required to resize the 2 masks to 509
    # NOTE(review): "the 2 masks" presumably means the 2k-tile masks — confirm.
)

In [None]:
# Build the training and validation DataLoaders. `source` is passed to
# multi_dataset_getter (the DataBlock's get_items) to list the image files.
dl = dblock.dataloaders(
    size=original_image_size,
    source=dataset_dirs,
    bs=batch_size,
    num_workers=6,
    pin_memory=True,
)

In [None]:
# Display the transforms attached to the training dataset.
dl.train.dataset.tfms

In [None]:
# Display the training loader's after_item pipeline.
dl.train.after_item

In [None]:
# Display the training loader's after_batch pipeline.
dl.train.after_batch

In [None]:
# Pull one batch and print its shapes and input statistics as a sanity check.
# batch[0] is the image tensor and batch[1] the label mask.
batch = dl.one_batch()
print(f"Input shape: {batch[0].shape}")
print(f"Label shape: {batch[1].shape}")
print(f"Input mean: {batch[0].mean()}")
print(f"Input std: {batch[0].std()}")

In [None]:
# Same sanity check for one validation batch.
# The statistics are taken from val_batch; previously they were read from the
# training `batch` by mistake.
val_batch = dl.valid.one_batch()
print(f"Input shape: {val_batch[0].shape}")
print(f"Label shape: {val_batch[1].shape}")
print(f"Input mean: {val_batch[0].mean()}")
print(f"Input std: {val_batch[0].std()}")

In [None]:
# Plot the images and masks of a fresh batch with the project helper.
batch = dl.one_batch()

# Display names for the three input bands, also reused for the histogram.
band_labels = ["B04", "B03", "B8A"]
plot_batch(batch[:2], labels=["False colour"] + band_labels)

In [None]:
# Display the third element of a fresh batch: the per-sample weights.
batch = dl.one_batch()
batch[2]

In [None]:
# Mean of the first band of the first image in the batch.
batch[0][0][0].mean()

In [None]:
# Histogram of the batch via the project helper, labelled per band.
show_histo(batch[:2], labels=band_labels)

In [None]:
# Training callbacks: a live loss graph and gradient accumulation.
# NOTE(review): fastai's GradientAccumulation threshold is presumably counted
# in samples, not batches — confirm against the fastai docs.
callbacks = [
    ShowGraphCallback(),
    GradientAccumulation(gradient_accumulation_batch_size),
]

In [None]:
# Tie the data, model, loss, metric and callbacks together.
# The loss and the metric both come from the project's utils module.
learner = Learner(
    dls=dl,
    model=model,
    loss_func=CrossEntropyLossFlatImageTypeWeighted(),
    metrics=[DiceMultiStrip],
    cbs=callbacks,
)

In [None]:
# Optionally switch the learner to bfloat16 training.
if use_bf16:
    learner = learner.to_bf16()

In [None]:
# Train: freeze_epochs epochs in the frozen phase, then unfrozen_epochs
# epochs unfrozen, using learning_rate as the base learning rate.
learner.fine_tune(
    epochs=unfrozen_epochs,
    freeze_epochs=freeze_epochs,
    base_lr=learning_rate,
)

In [None]:
# Save the fastai checkpoint, then load it straight back.
learner.save(fai_model_name)
learner.load(fai_model_name)

In [None]:
# Move the trained model to the CPU and cast it to float32 before exporting.
model = learner.model.to("cpu")
model = model.float()

In [None]:
# Export 1/3: the whole model object, then display where it was written.
torch.save(model, pytorch_model_path)
pytorch_model_path

In [None]:
# Export 2/3: only the weights (state dict), then display the path.
torch.save(model.state_dict(), state_path)
state_path

In [None]:
# Export 3/3: the same weights in safetensors format, then display the path.
save_file(model.state_dict(), safetensor_state_path)
safetensor_state_path

In [None]:
# Snapshot of every setting used for this run, written next to the model.
# NOTE(review): three keys ("tiles_2k_weight", "kappaset_weight",
# "super_res_weight") omit the "label_" part that the other weight keys and
# the variable names carry. Left as-is because readers of the file may depend
# on them.
config = {
    "model_version": model_version,
    "model_type": model_type,
    "use_bf16": use_bf16,
    "demo_mode": demo_mode,
    "original_image_size": original_image_size,
    "max_clip_image_clip_size": max_clip_image_clip_size,
    "min_clip_image_size": min_clip_image_size,
    "limited_band_read_list": limited_band_read_list,
    "native_band_scales": native_band_scales,
    "gradient_accumulation_batch_size": gradient_accumulation_batch_size,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "high_label_weight": high_label_weight,
    "scribble_label_weight": scribble_label_weight,
    "tiles_2k_weight": tiles_2k_label_weight,
    "kappaset_weight": kappaset_label_weight,
    "super_res_weight": super_res_label_weight,
    "high_pc_label_weight": high_pc_label_weight,
    "hard_negative_label_weight": hard_negative_label_weight,
    "freeze_epochs": freeze_epochs,
    "unfrozen_epochs": unfrozen_epochs,
    "limit_training_images": limit_training_images,
}

In [None]:
# Write the run configuration as JSON next to the saved model files.
# The encoding is pinned so the file does not depend on the platform default.
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=4)

In [None]:
# Display where the configuration file was written.
config_path