In [None]:
import timm
import torch
import rasterio as rio
from multiprocessing.pool import ThreadPool
from fastai.vision.all import *
from fastai.vision.learner import create_unet_model
from pathlib import Path
from safetensors.torch import save_file

In [None]:
from utils import (
    BatchRot90,
    RandomRectangle,
    DynamicZScoreNormalize,
    SceneEdge,
    BatchTear,
    BatchResample,
    RandomClipLargeImages,
    RandomSharpenBlur,
    ClipHighAndLow,
    BatchFlip,
)

In [None]:
from helpers import plot_batch, show_histo, print_system_info

print_system_info()

In [None]:
# Where the CloudSEN12 imagery lives, and where the pre-decoded tensor cache is kept.
training_data_dir = Path("/media/nick/4TB Working 7/Datasets/CloudSEN12")
image_cache_dir = Path("/media/nick/4TB Working 7/Datasets/CloudSEN12 training cache")

# Create the cache folder if needed (its parent directory must already exist).
image_cache_dir.mkdir(exist_ok=True)

# Stop right away when the dataset is not mounted / not where we expect it.
missing_training_data_msg = (
    f"Training data directory {training_data_dir} does not exist."
)
assert training_data_dir.exists(), missing_training_data_msg

In [None]:
# --- Run identity / mode -----------------------------------------------------
model_version = "OCM_6.45_RG_NIR"
demo_mode = False

# --- Input data --------------------------------------------------------------
limited_band_read_list = [1, 2, 3]  # Red Green NIR
original_image_size = 509
min_clip_image_size = 256  # 509
max_clip_image_clip_size = 400  # 509
cache_entire_dataset = True  # RAM hungry but will reduce training time

# --- Optimisation ------------------------------------------------------------
bf16 = True
batch_size = 10
gradient_accumulation_batch_size = 128
learning_rate = 0.001

In [None]:
# Name of the timm model handed to timm.create_model as the U-Net backbone.
# Alternative backbone names are kept commented out so they can be swapped in.
# model_type = "regnety_004.pycls_in1k"
# model_type = "convnextv2_nano.fcmae_ft_in1k"
model_type = "edgenext_small.usi_in1k"

In [None]:
# One model input channel per band index listed in limited_band_read_list.
num_input_channels = len(limited_band_read_list)

In [None]:
# Backbone factory: defers timm.create_model so create_unet_model can call it
# itself; the number of input channels follows the configured band list.
timm_model = partial(
    timm.create_model,
    model_type,
    pretrained=True,
    in_chans=num_input_channels,
)

# U-Net built around the timm backbone, producing 4 output channels (one per
# mask code). The image size comes from the shared config value rather than a
# hard-coded 509 so it cannot drift from the shape check and the data pipeline.
model = create_unet_model(
    img_size=(original_image_size, original_image_size),
    arch=timm_model,
    n_out=4,
    pretrained=True,
    act_cls=torch.nn.Mish,
)

In [None]:
# Sanity check: a single random image must come back with 4 output channels and
# unchanged spatial size.
#
# The check runs in eval mode and without autograd so that it cannot disturb
# the freshly built model: a training-mode forward pass on random noise would
# update the running statistics of any BatchNorm-style layers and would also
# build an autograd graph nobody uses.
dummy_input = torch.randn(
    1, num_input_channels, original_image_size, original_image_size
)

_model_was_training = model.training
model.eval()
try:
    with torch.no_grad():
        dummy_output_shape = model(dummy_input).shape
finally:
    # Put the model back in whatever mode it was in before the check.
    model.train(_model_was_training)

assert dummy_output_shape == (
    1,
    4,
    original_image_size,
    original_image_size,
), "Model output shape mismatch"

In [None]:
# Names / locations of every artefact this run will write. All PyTorch exports
# go to ./models next to the notebook's working directory.
fai_model_name = f"PM_model_{model_version}_{model_type}_fai"
pytorch_model_name = f"PM_model_{model_version}_{model_type}_PT.pth"

pytorch_model_path = Path.cwd() / "models" / pytorch_model_name
_export_dir = pytorch_model_path.parent
_export_stem = pytorch_model_path.stem
state_path = _export_dir / f"{_export_stem}_state.pth"
safetensor_state_path = _export_dir / f"{_export_stem}_state.safetensors"

# Refuse to continue if any export target is already present, so a previous
# run's model is never overwritten. Checked in the same order as they are written.
for _export_target, _already_exists_msg in (
    (pytorch_model_path, "Model already exists"),
    (state_path, "State path already exists"),
    (safetensor_state_path, "Safetensor state path already exists"),
):
    if _export_target.exists():
        raise ValueError(_already_exists_msg)

print(f"Fastai model {fai_model_name}")
print(f"PyTorch model {pytorch_model_name}")
print(f"State path: {state_path}")
print(f"Safetensor state path: {safetensor_state_path}")

In [None]:
# Demo mode is a short run on a subset of the data; the full run trains longer
# on both product levels.
freeze_epochs = 5 if demo_mode else 15
unfrozen_epochs = 5 if demo_mode else 15
# A falsy value (0) means "do not limit the number of training images".
limit_training_images = 3000 if demo_mode else 0
data_types = ["l1c"] if demo_mode else ["l1c", "l2a"]

In [None]:
def get_image_files_custom(source):
    """Collect the training and validation image paths found in ``source``.

    For every entry of the module-level ``data_types`` list, files matching
    ``*509_image_<data_type>.tif`` are globbed from ``source``. A file counts as
    a training image when its name contains ``"train"`` and as a validation
    image when its name contains ``"validation_509_image"``.

    When the module-level ``limit_training_images`` is truthy, the training
    list is sorted and truncated to that many entries. Validation images are
    never limited.

    Returns:
        list: training image paths followed by validation image paths.
    """
    candidates = [
        tif_path
        for data_type in data_types
        for tif_path in source.glob(f"*509_image_{data_type}.tif")
    ]

    train_imgs = [tif_path for tif_path in candidates if "train" in tif_path.name]
    val_imgs = [
        tif_path for tif_path in candidates if "validation_509_image" in tif_path.name
    ]
    print(f"Training images: {len(train_imgs)}")

    if limit_training_images:
        # Sort first so the retained subset does not depend on glob order.
        train_imgs.sort()
        print(f"Limiting training images to {limit_training_images}")
        train_imgs = train_imgs[:limit_training_images]

    print(f"Validation images: {len(val_imgs)}")
    train_and_val = train_imgs + val_imgs
    print(f"Total images: {len(train_and_val)}")

    return train_and_val


def label_func(file_path):
    """Return the label file that belongs to an image file.

    The label name is derived from the image name by replacing ``"image"`` with
    ``"label"`` and dropping the ``"_l1c"`` / ``"_l2a"`` markers; the label is
    expected in the same directory as the image.

    Raises:
        AssertionError: if the derived label file does not exist, or if the
            derived path is the image path itself.
    """
    label_name = file_path.name.replace("image", "label")
    for product_marker in ("_l1c", "_l2a"):
        label_name = label_name.replace(product_marker, "")

    label_path = file_path.parent / label_name

    assert label_path.exists(), f"Label path does not exist: {label_path}"
    assert (
        file_path != label_path
    ), f"File path and label path are the same: {file_path}"

    return label_path

In [None]:
# Gather the image paths once up front; the count is reused below to name the
# on-disk image cache directory.
train_and_val_images = get_image_files_custom(training_data_dir)
total_image_count = len(train_and_val_images)
total_image_count

In [None]:
def open_img(
    img_path: Path, img_size: int, image_cache: dict | None = None, bf16: bool = False
) -> TensorImage:
    """Read the configured bands of one raster file into a ``TensorImage``.

    Args:
        img_path: raster file to open with rasterio.
        img_size: height and width requested from rasterio via ``out_shape``.
        image_cache: optional dict; when given, the loaded image is also stored
            in it under ``img_path``.
        bf16: convert the tensor to bfloat16 when True (otherwise float32).

    Returns:
        The bands listed in the module-level ``limited_band_read_list`` as a
        ``TensorImage``.
    """
    with rio.open(img_path) as dataset:
        band_array = dataset.read(
            limited_band_read_list,
            out_shape=(img_size, img_size),
        )
    band_array = band_array.astype("float32")

    pixels = torch.from_numpy(band_array)
    pixels = pixels.bfloat16() if bf16 else pixels

    if image_cache is not None:
        image_cache[img_path] = TensorImage(pixels)

    return TensorImage(pixels)

In [None]:
# Display how many images (training + validation) were collected.
len(train_and_val_images)

In [None]:
if cache_entire_dataset:
    # The on-disk cache lives in a sub-directory named after the image count.
    image_cache_dir_with_img_count = image_cache_dir / str(total_image_count)
    print(f"Image cache dir: {image_cache_dir_with_img_count}")

    def load_cache(image_cache_dir: Path):
        """Merge every ``*.pkl`` chunk found in ``image_cache_dir`` into one dict."""
        image_cache = {}
        for cache_file in progress_bar(
            list(image_cache_dir.glob("*.pkl")), comment="Loading cache"  # type: ignore
        ):
            # NOTE(security): pickle.load can execute arbitrary code. Only point
            # this at cache files written by this notebook.
            with open(cache_file, "rb") as f:
                image_cache.update(pickle.load(f))
        return image_cache

    def chunks(l, n):
        """Yield n number of striped chunks from l."""
        for i in range(0, n):
            yield l[i::n]

    if image_cache_dir_with_img_count.exists():
        print("Image Cache found, loading cache")

    else:
        print("Image Cache not found, creating cache")
        # Build into a staging directory and rename it only once every chunk
        # has been written. An interrupted build therefore never leaves behind
        # a directory that a later run would mistake for a complete cache.
        cache_staging_dir = image_cache_dir_with_img_count.with_name(
            f"{image_cache_dir_with_img_count.name}_partial"
        )
        cache_staging_dir.mkdir(exist_ok=True)
        for stale_chunk_file in cache_staging_dir.glob("*.pkl"):
            stale_chunk_file.unlink()  # leftovers from an earlier aborted build

        train_and_val_images_parts = list(chunks(train_and_val_images, 8))

        for i, chunk in progress_bar(
            enumerate(train_and_val_images_parts),
            comment="Making cache",  # type: ignore
            total=len(train_and_val_images_parts),
        ):
            # Each chunk gets its own dict so only one chunk is held in memory
            # while the cache is being written.
            chunk_cache = {}
            open_img_partial = partial(
                open_img,
                img_size=original_image_size,
                image_cache=chunk_cache,
                bf16=bf16,
            )
            with ThreadPool(6) as p:
                # list(...) drains the lazy imap so every image is actually read;
                # the results themselves are collected through chunk_cache.
                list(
                    progress_bar(
                        p.imap(
                            open_img_partial,
                            chunk,
                        ),
                        total=len(chunk),
                        leave=False,
                        comment="Opening a chunk of images",  # type: ignore
                    )
                )
            with open(cache_staging_dir / f"image_cache_{i}.pkl", "wb") as f:
                pickle.dump(chunk_cache, f)

        cache_staging_dir.rename(image_cache_dir_with_img_count)

    image_cache = load_cache(image_cache_dir_with_img_count)

    # The cache directory is keyed only by image count, so make sure it really
    # covers this run's images instead of failing with a KeyError mid-training.
    images_missing_from_cache = [
        img_path for img_path in train_and_val_images if img_path not in image_cache
    ]
    if images_missing_from_cache:
        raise RuntimeError(
            f"Image cache at {image_cache_dir_with_img_count} is missing "
            f"{len(images_missing_from_cache)} of {total_image_count} images "
            f"(e.g. {images_missing_from_cache[0]}). Delete the cache directory "
            "and re-run this cell to rebuild it."
        )

    def open_image_func(img_path: Path) -> TensorImage:
        """Return the cached tensor for ``img_path``."""
        return image_cache[img_path]

else:
    # No cache: read every image from disk each time it is requested.
    open_image_func = partial(
        open_img, img_size=original_image_size, image_cache=None, bf16=bf16
    )

In [None]:
# Batch-level augmentation / normalisation pipeline, listed in the order it is
# handed to the DataBlock.
batch_tfms = [
    # Blocks out random rectangles in the image
    RandomRectangle(p=0.6, sl=0.1, sh=0.5),
    # Simulates an image tear
    BatchTear(0.1),
    # Adds a scene edge to the image
    SceneEdge(p=0.1),
    IntToFloatTensor(1, 1),
    # Rotates the image by 90 degrees
    BatchRot90(),
    # Normalizes the image using dynamic z-score normalization
    DynamicZScoreNormalize(),
    # Resamples the image to a random scale
    BatchResample(max_scale=1.111, min_scale=0.07, plateau_min=0.33, plateau_max=1.0),
    # Clips large images to a random size
    RandomClipLargeImages(
        max_size=max_clip_image_clip_size,
        min_size=min_clip_image_size,
    ),
    # Flips the image horizontally or vertically
    BatchFlip(),
    # Sharpens or blurs the image
    RandomSharpenBlur(min_factor=0.5, max_factor=1.5),
    # Simulates sensor saturation by clipping high and low values
    ClipHighAndLow(p=0.1, max_pct=0.05),
]

In [None]:
def _is_validation_item(item):
    """Return True when the file name contains "validation"."""
    return "validation" in item.name


# Inputs are loaded through open_image_func; targets are masks with codes 0-3
# located via label_func. The train/validation split is decided by file name.
_input_block = TransformBlock([open_image_func])
_mask_block = MaskBlock(codes=[0, 1, 2, 3])

dblock = DataBlock(
    blocks=[_input_block, _mask_block],
    get_items=get_image_files_custom,
    get_y=label_func,
    splitter=FuncSplitter(_is_validation_item),
    batch_tfms=batch_tfms,
)

In [None]:
# Build the train/validation DataLoaders from the dataset directory.
dl = dblock.dataloaders(
    source=training_data_dir,
    bs=batch_size,
    size=original_image_size,
    pin_memory=True,
    num_workers=2,
)

In [None]:
# Inspect the transforms attached to the training dataset.
dl.train.dataset.tfms

In [None]:
# Inspect the training DataLoader's after_item pipeline.
dl.train.after_item

In [None]:
# Inspect the training DataLoader's after_batch pipeline (where batch_tfms are expected to appear).
dl.train.after_batch

In [None]:
# Pull one training batch and print basic shape / statistics as a smoke test.
batch = dl.one_batch()
train_inputs, train_labels = batch[0], batch[1]
print(f"Input shape: {train_inputs.shape}")
print(f"Label shape: {train_labels.shape}")
print(f"Input mean: {train_inputs.mean()}")
print(f"Input std: {train_inputs.std()}")

In [None]:
# Same smoke test for one validation batch. The statistics are taken from
# val_batch itself (not from the training batch fetched in the previous cell).
val_batch = dl.valid.one_batch()
print(f"Input shape: {val_batch[0].shape}")
print(f"Label shape: {val_batch[1].shape}")
print(f"Input mean: {val_batch[0].mean()}")
print(f"Input std: {val_batch[0].std()}")

In [None]:
# Draw a fresh training batch and plot it: one composite panel followed by one
# panel per band.
band_labels = ["B04", "B03", "B8A"]
batch = dl.one_batch()

panel_labels = ["False colour", *band_labels]
plot_batch(batch, labels=panel_labels)

In [None]:
# Mean of element [0][0] of the batch inputs.
# NOTE(review): presumably sample 0, channel 0 of a (batch, channel, H, W) tensor — confirm.
batch[0][0][0].mean()

In [None]:
# Plot value histograms for the current batch, labelled with the band names.
show_histo(batch, labels=band_labels)

In [None]:
# Size of dimension 1 of the batch inputs.
# NOTE(review): presumably the channel count, which should equal num_input_channels — confirm.
batch[0].shape[1]

In [None]:
# Training callbacks: live loss graph, plus gradient accumulation configured
# with gradient_accumulation_batch_size (the per-step batch size is batch_size).
callbacks = [
    ShowGraphCallback(),
    GradientAccumulation(gradient_accumulation_batch_size),
]

In [None]:
# Tie model, data, loss and metric together. Cross-entropy is computed over
# axis 1 of the model output; DiceMulti is the reported metric.
learner = Learner(
    model=model,
    dls=dl,
    cbs=callbacks,
    metrics=[DiceMulti],
    loss_func=CrossEntropyLossFlat(axis=1),
)

In [None]:
# Optionally switch the learner to bfloat16 training (controlled by the bf16 config flag).
if bf16:
    print("Using BF16")
    learner = learner.to_bf16()

In [None]:
# Train: freeze_epochs epochs with the backbone frozen, then unfrozen_epochs
# epochs with everything trainable.
learner.fine_tune(
    base_lr=learning_rate,
    freeze_epochs=freeze_epochs,
    epochs=unfrozen_epochs,
)

In [None]:
# Display the name the fastai checkpoint will be saved under.
fai_model_name

In [None]:
# Save the fastai checkpoint, then reload it and run validation so the reported
# numbers come from the weights that were actually written to disk.
learner.save(fai_model_name)

learner.load(fai_model_name)
learner.validate()

In [None]:
# Export copy of the trained network: moved to the CPU and cast to float32.
model = learner.model.to("cpu").float()

In [None]:
# Save the whole model object with torch.save, then display where it went.
torch.save(model, pytorch_model_path)
pytorch_model_path

In [None]:
# Save only the weights (state_dict) with torch.save, then display where they went.
torch.save(model.state_dict(), state_path)
state_path

In [None]:
# Save the same weights in safetensors format, then display where they went.
save_file(model.state_dict(), safetensor_state_path)
safetensor_state_path