In [None]:
import os
import random
from functools import partial
from pathlib import Path

import rasterio
import xarray as xr

import numpy as np
import pandas as pd
from absl import logging
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, ConcatDataset

In [None]:
batch_size = 8

# Data location. Override with the LOCUST_DATA_DIR environment variable instead of
# editing an absolute, machine-specific path into the notebook.
root_dir = Path(
    os.environ.get("LOCUST_DATA_DIR", "/Users/ibrahimyusuf/Downloads/locust_breeding")
)
valid_filepath = root_dir / "val.csv"
train_filepath = root_dir / "train.csv"
test_filepath = root_dir / "test.csv"

# Seed every RNG the pipeline touches (flip decisions use `random`, crops / weight
# init / DataLoader shuffling use torch) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

IM_SIZE = 224
TEMPORAL_SIZE = 3

In [None]:
def open_mf_tiff_dataset(band_files):
    """Open several GeoTIFF files as one xarray dataset concatenated along "band".

    Args:
        band_files: mapping with a "tiles" entry whose values are the file paths,
            in the order they should be concatenated.

    Returns:
        (dataset, crs): the combined dataset and the CRS of the first file.
    """
    tile_paths = [*band_files["tiles"].values()]
    stacked = xr.open_mfdataset(
        tile_paths,
        combine="nested",
        concat_dim="band",
    )
    # The CRS is taken from the first tile only.
    with rasterio.open(tile_paths[0]) as first_tile:
        first_crs = first_tile.crs
    return stacked, first_crs

In [None]:
def random_crop_and_flip(ims, label, im_size):
    """Apply one shared random crop and random flips to every image and the label.

    The same crop window and flip decisions are used for all of ``ims`` and for
    ``label`` so they stay spatially aligned.

    Args:
        ims: non-empty list of images (PIL images here) of identical size.
        label: image aligned with ``ims``, or None for unlabelled samples.
        im_size: side length of the square crop.

    Returns:
        (ims, label) after cropping/flipping; ``label`` stays None if it was None.
    """
    i, j, h, w = transforms.RandomCrop.get_params(ims[0], (im_size, im_size))

    ims = [transforms.functional.crop(im, i, j, h, w) for im in ims]
    if label is not None:
        label = transforms.functional.crop(label, i, j, h, w)

    # Horizontal flip decision is drawn first, then vertical (same RNG order as before).
    for flip in (transforms.functional.hflip, transforms.functional.vflip):
        if random.random() > 0.5:
            ims = [flip(im) for im in ims]
            if label is not None:
                label = flip(label)

    return ims, label


def normalize_and_convert_to_tensor(
    ims,
    label,
    mean,
    std,
    temporal_size,
):
    """Stack per-band images into one tensor, normalise it and convert the label.

    Args:
        ims: sequence of single-channel images (PIL images here) of identical size;
            they are stacked along a new leading axis, giving (len(ims), H, W).
            NOTE(review): assumed time-major (all bands of step 0, then step 1, ...),
            matching how compute_mean_std groups channels — confirm against the data.
        label: label image or None.
        mean, std: normalisation statistics, either one value per stacked channel
            or one value per band shared by all ``temporal_size`` steps.
        temporal_size: number of time steps stacked in ``ims``.

    Returns:
        (normalised image tensor of shape (len(ims), H, W), label tensor or None).
    """
    norm = transforms.Normalize(mean, std)
    ims_tensor = torch.stack([transforms.ToTensor()(im).squeeze() for im in ims])
    n_channels, h, w = ims_tensor.shape
    if torch.as_tensor(mean).numel() == n_channels:
        ims_tensor = norm(ims_tensor)
    else:
        # Per-band statistics: view as (T, bands, H, W) so Normalize broadcasts them
        # over every time step, then restore the flat (T * bands, H, W) layout.
        per_step = ims_tensor.reshape(temporal_size, -1, h, w)
        ims_tensor = norm(per_step).reshape(n_channels, h, w)
    if label is not None:
        label = torch.from_numpy(np.array(label)).squeeze()
    return ims_tensor, label


def process_and_augment(
    x,
    y,
    mean,
    std,
    temporal_size=1,
    im_size=224,
    train=True,
):
    """Convert one raw sample into model-ready tensors.

    Args:
        x: array whose first axis indexes the stacked bands; each ``x[k]`` becomes
            one PIL image.
        y: label array (squeezed before conversion), or None when unlabelled.
        mean, std: statistics forwarded to normalize_and_convert_to_tensor.
        temporal_size: number of time steps stacked in ``x``.
        im_size: side of the random square crop (only used when ``train``).
        train: when True, apply random_crop_and_flip.

    Returns:
        (image tensor, label as a ``torch.long`` tensor or None when ``y`` is None).
    """
    ims = [Image.fromarray(band) for band in x.copy()]
    label = None if y is None else Image.fromarray(y.copy().squeeze())
    if train:
        ims, label = random_crop_and_flip(ims, label, im_size)
    ims, label = normalize_and_convert_to_tensor(ims, label, mean, std, temporal_size)
    # Unlabelled samples have nothing to cast; pass None through instead of
    # calling `.to` on it.
    if label is not None:
        label = label.to(torch.long)
    return ims, label


def crop_array(arr, left, top, right, bottom):
    """Cut the window rows [top:bottom], cols [left:right] from the last two axes.

    Any leading axes (up to two of them) are kept whole.

    Raises:
        ValueError: if ``arr`` does not have 2, 3 or 4 dimensions.
    """
    if len(arr.shape) not in (2, 3, 4):
        raise ValueError("Input array must be a 2D, 3D or 4D array")
    return arr[..., top:bottom, left:right]


def process_test(
    x,
    y,
    mean,
    std,
    temporal_size,
    img_size,
    crop_size,
    stride,
):
    """Tile a test sample into fixed-size crops and preprocess each without augmentation.

    Crop origins lie on a grid with step ``stride`` over an ``img_size`` x ``img_size``
    area starting at (0, 0), visited row by row. Every (image, mask) crop pair is
    passed through ``process_and_augment(..., train=False)``.

    Returns:
        (imgs, labels): the processed crops stacked along a new leading axis.
    """
    no_augment = partial(
        process_and_augment,
        mean=mean,
        std=std,
        temporal_size=temporal_size,
        train=False,
    )
    # NOTE(review): the grid spans img_size, not the actual extent of x.
    starts = range(0, img_size - crop_size + 1, stride)
    windows = [(top, left) for top in starts for left in starts]

    crop_pairs = [
        (
            crop_array(x, left, top, left + crop_size, top + crop_size),
            crop_array(y, left, top, left + crop_size, top + crop_size),
        )
        for top, left in windows
    ]
    processed = [no_augment(img, mask) for img, mask in crop_pairs]
    stacked_imgs = torch.stack([img for img, _ in processed])
    stacked_labels = torch.stack([mask for _, mask in processed])
    return stacked_imgs, stacked_labels


def get_raster_data(
    fname,
    is_label: bool = True,
    bands=None,
) -> np.ndarray:
    """Read a raster into an array with bands on axis 0.

    Args:
        fname: path to a single raster file, or a dict accepted by
            open_mf_tiff_dataset (multi-file input). In the dict case missing
            values are filled with -1.
        is_label: True for label masks. Labels are returned as read, with no band
            selection and no rescaling.
        bands: optional band indices to keep (images only).

    Returns:
        The raster data, one entry per band along axis 0.
    """
    if isinstance(fname, dict):
        dataset, _ = open_mf_tiff_dataset(fname)
        data = dataset.fillna(-1).band_data.values
    else:
        with rasterio.open(fname) as src:
            data = src.read()
    if is_label:
        # Rescaling is a reflectance fix and must never touch class labels.
        return data
    if bands:
        data = data[bands, ...]
    # A few HLS tiles are stored unscaled: a band whose max exceeds 10 gets the
    # 0.0001 scale factor. Multiply out-of-place: `band *= 0.0001` raises on
    # integer rasters under numpy's same_kind casting rule.
    rescaled = [band * 0.0001 if band.max() > 10 else band for band in data]
    return np.stack(rescaled, axis=0)


def process_data(
    im_fname,
    mask_fname,
    bands=None,
    constant_multiplier=1.0,
    mask_cloud=False,
    fix_scaling=True,
):
    """Load an image (and optionally its mask) as numpy arrays ready for preprocessing.

    Image pixels equal to -1 become 0 and the image is cast to float32. Mask pixels
    equal to -1 become -100, the ``ignore_index`` given to CrossEntropyLoss later in
    this notebook.

    NOTE: ``constant_multiplier``, ``mask_cloud`` and ``fix_scaling`` are accepted
    but currently unused.

    Returns:
        (image array, mask array or None when ``mask_fname`` is falsy).
    """
    image = get_raster_data(im_fname, is_label=False, bands=bands)
    image = np.where(image == -1, 0, image).astype(np.float32)

    mask = None
    if mask_fname:
        mask = get_raster_data(mask_fname)
        mask = np.where(mask == -1, -100, mask)
    return image, mask


def load_data_from_csv(fname, input_root):
    """Return (image_path, mask_path) pairs from a split CSV.

    Paths come from the ``Input`` and ``Label`` columns, joined onto ``input_root``.
    A row is kept only if its image exists and can be opened with rasterio; open
    errors are logged and the row is skipped. The mask path is not checked.
    """

    def _opens_cleanly(path):
        # Best-effort probe: log and skip unreadable rasters rather than failing.
        try:
            with rasterio.open(path) as src:
                _ = src.crs
        except Exception as e:
            logging.error(e)
            return False
        return True

    table = pd.read_csv(fname)
    pairs = []
    for rel_image, rel_mask in zip(table["Input"], table["Label"]):
        image_path = os.path.join(input_root, rel_image)
        mask_path = os.path.join(input_root, rel_mask)
        if os.path.exists(image_path) and _opens_cleanly(image_path):
            pairs.append((image_path, mask_path))
    return pairs

In [None]:
class InstaGeoDataset(torch.utils.data.Dataset):
    """Map-style dataset over the (image, mask) file pairs listed in a split CSV.

    Each item is loaded with ``process_data`` and then passed through
    ``preprocess_func(arr_x, arr_y)``; whatever that returns is the item.

    Args:
        filename: CSV with ``Input`` and ``Label`` columns (see load_data_from_csv).
        input_root: directory the CSV paths are joined onto.
        preprocess_func: callable ``(arr_x, arr_y) -> item``.
        bands: optional band indices forwarded to ``process_data``.
    """

    def __init__(
        self,
        filename,
        input_root,
        preprocess_func,
        bands=None,
    ):
        self.input_root = input_root
        self.preprocess_func = preprocess_func
        self.bands = bands
        # load_data_from_csv keeps only pairs whose image exists and opens.
        self.file_paths = load_data_from_csv(filename, input_root)

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, i: int):
        image_path, mask_path = self.file_paths[i]
        raw_image, raw_mask = process_data(image_path, mask_path, bands=self.bands)
        return self.preprocess_func(raw_image, raw_mask)

In [None]:
# All 18 stacked channels (TEMPORAL_SIZE time steps x 6 bands, as grouped by
# compute_mean_std).
BANDS = list(range(18))


def _make_raw_dataset(filepath):
    """Dataset over one split CSV that yields the raw (arr_x, arr_y) arrays."""
    return InstaGeoDataset(
        filename=filepath,
        input_root=root_dir,
        preprocess_func=lambda x, y: (x, y),
        bands=BANDS,
    )


train_ds_sample = _make_raw_dataset(train_filepath)
val_ds_sample = _make_raw_dataset(valid_filepath)
test_ds_sample = _make_raw_dataset(test_filepath)
# All splits together; not used for the statistics below.
concatenated_dataset = ConcatDataset([train_ds_sample, val_ds_sample, test_ds_sample])
# Normalisation statistics must come from the training split only. Folding val/test
# into them leaks held-out data into preprocessing and biases evaluation.
statistics_dataloader = DataLoader(train_ds_sample, batch_size=batch_size, shuffle=False)


def compute_mean_std(dataloader: DataLoader, temporal_size: int = 3):
    """Estimate per-band mean and std from batches of stacked multi-temporal images.

    Each batch is viewed as (B * temporal_size, n_bands, H * W), with
    ``n_bands = channels // temporal_size``. Every time step of every image therefore
    counts as one sample per band. The returned std is the average of per-sample
    standard deviations, not the pooled std over all pixels.

    Args:
        dataloader: iterable of ``(images, labels)`` batches; labels are ignored.
        temporal_size: time steps stacked along the channel axis (3 reproduces the
            previous hard-coded grouping of 18 channels into 3 x 6).

    Returns:
        (mean, std) as numpy arrays with one entry per band.

    Raises:
        ValueError: if the channel count is not divisible by ``temporal_size`` or
            the dataloader yields no images.
    """
    mean = 0.0
    std = 0.0
    total_images_count = 0

    for images, _ in dataloader:
        # Use the batch's real size: the final batch is usually smaller than the
        # loader's batch_size, which made the old fixed-size view crash.
        n_samples, n_channels = images.shape[0], images.shape[1]
        if n_channels % temporal_size:
            raise ValueError(
                f"{n_channels} channels cannot be split into {temporal_size} time steps"
            )
        images = images.reshape(n_samples * temporal_size, n_channels // temporal_size, -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images_count += images.size(0)

    if total_images_count == 0:
        raise ValueError("dataloader yielded no images")

    mean /= total_images_count
    std /= total_images_count

    return mean.numpy(), std.numpy()


# Per-band normalisation statistics consumed by the dataset definitions below.
MEAN, STD = compute_mean_std(dataloader=statistics_dataloader)
print("Mean:", MEAN)
print("Std:", STD)

In [None]:
def _make_dataset(filepath, preprocess_func):
    """Build an InstaGeoDataset over one split CSV with this notebook's root and bands."""
    return InstaGeoDataset(
        filename=filepath,
        input_root=root_dir,
        preprocess_func=preprocess_func,
        bands=BANDS,
    )


def _concat_crop_batches(batch):
    """Collate test samples (each already a stack of crops) into one flat batch."""
    images = torch.cat([sample[0] for sample in batch], 0)
    labels = torch.cat([sample[1] for sample in batch], 0)
    return images, labels


# Normalisation settings shared by every split.
norm_kwargs = dict(mean=MEAN, std=STD, temporal_size=TEMPORAL_SIZE)

train_dataset = _make_dataset(
    train_filepath,
    partial(process_and_augment, im_size=IM_SIZE, **norm_kwargs),
)
# Validation must be deterministic: train=False disables the random crop and flips.
# NOTE(review): validation tiles are then used at full size; assumes all tiles share
# one (even) size so they batch and round-trip through the model's pool/upsample.
valid_dataset = _make_dataset(
    valid_filepath,
    partial(process_and_augment, im_size=IM_SIZE, train=False, **norm_kwargs),
)
# Test tiles are cut into IM_SIZE crops on a non-overlapping grid.
test_dataset = _make_dataset(
    test_filepath,
    partial(
        process_test,
        img_size=IM_SIZE,
        crop_size=IM_SIZE,
        stride=IM_SIZE,
        **norm_kwargs,
    ),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=_concat_crop_batches,
)

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class SegmentationModel(nn.Module):
    """Minimal conv encoder-decoder producing per-pixel logits.

    The encoder halves H and W once (2x2 max-pool). The decoder doubles them with a
    stride-2 transposed conv, so even H and W come back unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (classes).
    """

    def __init__(self, in_channels=18, out_channels=1):
        super().__init__()
        self.out_channels = out_channels
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Return logits (B, out_channels, H, W), or (B, H, W) when out_channels == 1."""
        x = self.decoder(self.encoder(x))
        # Only drop the channel axis of single-output models. A blanket squeeze()
        # also removed a batch axis of size 1, which broke loss and argmax(dim=1).
        return x.squeeze(1) if self.out_channels == 1 else x


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, device=None):
    """Train ``model`` and report the mean per-sample loss of each phase per epoch.

    Args:
        model: module to optimise (trained in place and returned).
        dataloaders: dict with "train" and "val" iterables of (inputs, labels).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimiser over ``model``'s parameters.
        num_epochs: number of train+val passes.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so batches always follow the model.

    Returns:
        The trained model. It is left in eval mode, since "val" is the last phase.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 10)

        for phase in ("train", "val"):
            is_train = phase == "train"
            model.train(is_train)

            running_loss = 0.0
            n_seen = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                n_seen += inputs.size(0)

            # Normalise by samples actually seen (robust to collate functions that
            # expand one dataset item into several samples, and to empty loaders).
            epoch_loss = running_loss / n_seen if n_seen else float("nan")
            print(f"{phase} Loss: {epoch_loss:.4f}")

    return model


def evaluate_model(model, dataloader, criterion, device=None):
    """Print and return the mean loss and pixel accuracy of ``model`` on ``dataloader``.

    Pixels whose label equals the criterion's ``ignore_index`` (-100 if it has none)
    are excluded from the accuracy, matching how CrossEntropyLoss skips them.

    Args:
        model: module producing logits with classes on dim 1.
        dataloader: iterable of (inputs, labels); labels are expected with the same
            layout as the logits minus the class axis, e.g. (B, H, W).
        criterion: loss called as ``criterion(outputs, labels)``.
        device: where batches are moved. Defaults to the device of the model's first
            parameter (CPU if it has none).

    Returns:
        (total_loss, overall_accuracy). Loss is averaged per sample seen; NaN is
        returned for a quantity with nothing to average.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))
    ignore_index = getattr(criterion, "ignore_index", -100)

    model.eval()
    running_loss = 0.0
    running_correct = 0
    running_count = 0
    n_samples = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            # No squeeze on labels: it would also drop a batch axis of size 1.
            loss = criterion(outputs, labels)
            preds = torch.argmax(outputs, dim=1)
            valid = labels.ne(ignore_index)

            running_correct += (preds[valid] == labels[valid]).sum().item()
            running_count += valid.sum().item()
            running_loss += loss.item() * inputs.size(0)
            # Count samples actually seen: the test collate expands each tile into
            # several crops, so len(dataloader.dataset) undercounts them.
            n_samples += inputs.size(0)

    overall_accuracy = running_correct / running_count if running_count else float("nan")
    total_loss = running_loss / n_samples if n_samples else float("nan")
    print(f"Loss: {total_loss:.4f}")
    print(f"Accuracy: {overall_accuracy:.4f}")
    return total_loss, overall_accuracy

In [None]:
LEARNING_RATE = 1e-3
NUM_EPOCHS = 5

# in_channels follows BANDS so the model's input width tracks the loaded bands.
model = SegmentationModel(in_channels=len(BANDS), out_channels=2).to(device)
# -100 is the value process_data writes for missing label pixels.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
dataloaders = {"train": train_loader, "val": valid_loader}

model = train_model(model, dataloaders, criterion, optimizer, num_epochs=NUM_EPOCHS)

In [None]:
# Prints test loss and pixel accuracy; the trailing `;` suppresses display of the return value.
evaluate_model(model=model, dataloader=test_loader, criterion=criterion);