In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConvolution(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial
    size of the input is preserved; only the channel count changes from
    ``in_ch`` to ``out_ch``. The convolutions carry no bias term.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        stages = []
        channels = in_ch
        for _ in range(2):
            stages.extend(
                [
                    nn.Conv2d(channels, out_ch, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ]
            )
            # The second stage consumes what the first one produced.
            channels = out_ch

        # Keep the attribute name `conv` so state-dict keys stay `conv.<idx>.*`.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.conv(x)


class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConvolution`` blocks.

    Encoder: one ``DoubleConvolution`` per entry in ``features``, each followed
    by a 2x2 max-pool. Bottleneck: a ``DoubleConvolution`` that doubles the last
    feature width. Decoder: for every encoder level (deepest first) an
    upsample-by-2 + 3x3 conv, concatenation with the matching encoder output,
    and a ``DoubleConvolution``. A final 3x3 convolution maps to
    ``output_channels``; no activation is applied to the result.

    Args:
        input_channels: number of channels of the input tensor.
        output_channels: number of channels of the returned tensor.
        features: channel widths of the encoder levels, shallowest first.
            Any non-empty sequence of ints is accepted.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=1,
        features=(64, 128, 256, 512),  # tuple: avoids a shared mutable default
    ):
        super().__init__()

        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        in_ch = input_channels
        for f in features:
            self.downs.append(DoubleConvolution(in_ch, f))
            in_ch = f

        # Bottleneck: doubles the deepest encoder width.
        self.bottleneck = DoubleConvolution(features[-1], features[-1] * 2)

        # Decoder. `incoming` is the channel count actually arriving at each
        # up-convolution: the bottleneck output first, then the output of the
        # previous decoder level. For width lists that double at every level
        # this equals 2 * f (the previously hard-coded value), so existing
        # checkpoints keep the same parameter shapes; other width lists now
        # work instead of failing with a channel mismatch in forward().
        incoming = features[-1] * 2
        for f in reversed(features):
            self.ups.append(
                nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(
                        in_channels=incoming, out_channels=f, kernel_size=3, padding=1
                    ),
                )
            )
            # Input is the skip connection (f) concatenated with the upsampled path (f).
            self.ups.append(DoubleConvolution(2 * f, f))
            incoming = f

        self.final_convolution = nn.Conv2d(
            in_channels=features[0],
            out_channels=output_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Run the encoder, bottleneck and decoder on ``x`` and return the output map."""
        skip_connections = []
        for module in self.downs:
            x = module(x)
            skip_connections.append(x)  # saved before pooling
            x = self.pool(x)

        # Decoder consumes the skips deepest-first.
        skip_connections.reverse()

        x = self.bottleneck(x)

        # self.ups alternates [upsample+conv, DoubleConvolution, ...].
        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x)
            skip_connection = skip_connections[i // 2]
            if skip_connection.shape != x.shape:
                # Pooling floors odd sizes, so the upsampled map can differ
                # from the skip; resize to the skip's spatial size.
                x = TF.resize(
                    x,
                    size=skip_connection.shape[2:],
                    interpolation=TF.InterpolationMode.NEAREST,
                )
            x = torch.cat([skip_connection, x], dim=1)
            x = self.ups[i + 1](x)

        return self.final_convolution(x)

In [None]:
from torch.utils.data import Dataset
import numpy as np
import os

import rasterio


class CloudDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs read from raster files.

    Args:
        images: sequence of paths to image rasters.
        labels: sequence of paths to label rasters; ``labels[i]`` is returned
            together with ``images[i]``.
        transform_image: optional callable applied to the image array.
        transform_label: optional callable applied to the label array.
    """

    def __init__(self, images, labels, transform_image=None, transform_label=None):
        self.images = images
        self.labels = labels
        self.transform_image = transform_image
        self.transform_label = transform_label

    def __len__(self):
        """Number of samples, defined by the number of image paths."""
        return len(self.images)

    @staticmethod
    def _read_raster(path):
        """Read a raster as float32 with the first (band) axis moved last."""
        # The context manager closes the file handle as soon as the pixels
        # are in memory; previously every sample leaked an open dataset.
        with rasterio.open(path) as src:
            raster = src.read()
        return np.moveaxis(raster.astype(np.float32), 0, -1)

    def __getitem__(self, idx: int):
        """Return the (optionally transformed) image and label at ``idx``."""
        image = self._read_raster(self.images[idx])
        label = self._read_raster(self.labels[idx])

        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_label:
            label = self.transform_label(label)

        return image, label

In [None]:
import time

def train(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: module to optimise; switched to training mode.
        loader: sized iterable yielding ``(data, target)`` batches.
        optimizer: optimiser stepped once per batch.
        criterion: callable ``criterion(output, target)`` returning a scalar loss.
        device: device the batches are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: if ``loader`` is empty.
    """
    model.train()
    train_loss = 0
    num_batches = len(loader)

    for batch_idx, (data, target) in enumerate(loader):
        print(f"Batch {batch_idx}/{num_batches}")

        # Timing covers transfer, forward, backward and the optimiser step,
        # not the time the loader took to produce the batch.
        start = time.time()
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)

        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        end = time.time()
        elapsed = end - start
        # Batches still to run AFTER this one; the previous formula counted
        # the batch that had just finished as remaining.
        remaining = num_batches - batch_idx - 1
        # 60*60*24 converts the estimate from seconds to days.
        print(f"Time: {elapsed}\n Eta: {elapsed * remaining/(60*60*24)}")
    return train_loss / num_batches

In [None]:
def test(model, loader, criterion, device):
    model.eval()
    test_loss = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            print(f"Batch {batch_idx}/{len(loader)}")

            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()

    return test_loss / len(loader)

In [None]:
def compute_dataset_info(train_images, train_labels, test_images, test_labels):
    """Scan every image/label pair and collect class counts and band statistics.

    Args:
        train_images, test_images: sequences of image raster paths.
        train_labels, test_labels: sequences of label raster paths, paired
            positionally with the corresponding image sequence.

    Returns:
        dict with
          * ``"num_positives"``: sum of all label values (float),
          * ``"num_negatives"``: total label pixel count minus that sum (float),
          * ``"mean"``: per-band mean, averaged over images (float64 array),
          * ``"std"``: per-band standard deviation, averaged over images
            (float64 array).

    Raises:
        ValueError: if no image/label pair is supplied.

    NOTE(review): ``std`` is the average of per-image standard deviations,
    not the standard deviation over all pixels, and both splits (train and
    test) contribute to the statistics. Kept as-is so cached results stay
    comparable; confirm this is intended.
    """
    num_positives = 0.0
    num_negatives = 0.0
    mean_sum = None  # allocated from the first image so any band count works
    std_sum = None
    num_images = 0

    for images, labels in (
        (train_images, train_labels),
        (test_images, test_labels),
    ):
        for image_path, label_path in zip(images, labels):
            # Context managers release the file handles right after reading.
            with rasterio.open(label_path) as src:
                label = src.read()

            # Accumulate in float64: a float32 running total stops being
            # exact once the pixel count passes 2**24.
            positives = float(np.sum(label, dtype=np.float64))
            num_positives += positives
            num_negatives += label.size - positives

            with rasterio.open(image_path) as src:
                image = src.read()
            image = np.moveaxis(image, 0, -1)
            image = image.astype(np.float32)

            band_mean = np.mean(image, axis=(0, 1))
            band_std = np.std(image, axis=(0, 1))
            if mean_sum is None:
                mean_sum = np.zeros(band_mean.shape, dtype=np.float64)
                std_sum = np.zeros(band_std.shape, dtype=np.float64)
            mean_sum += band_mean
            std_sum += band_std
            num_images += 1

    if num_images == 0:
        raise ValueError("compute_dataset_info needs at least one image/label pair")

    return {
        "num_positives": num_positives,
        "num_negatives": num_negatives,
        "mean": mean_sum / num_images,
        "std": std_sum / num_images,
    }

In [None]:
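# Optional (Google Colab only): uncomment the two lines below to mount
# Google Drive before running the cells that follow.
# from google.colab import drive

# drive.mount("/content/drive")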

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms

# --- Configuration -----------------------------------------------------------
SEED = 42
BATCH_SIZE = 14
NUM_EPOCHS = 2
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NUM_WORKERS = 0 if os.name == "nt" else 2  # no worker processes on Windows
DATA_DIR = "data"
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")
DATASET_INFO_PATH = os.path.join(DATA_DIR, "dataset_info.pt")
CHECKPOINTS_DIR = "checkpoints"

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

torch.manual_seed(SEED)
np.random.seed(SEED)


def is_label(img):
    """Return True when the file name ends with ``label.tif``."""
    return img.endswith("label.tif")


def _list_images_and_labels(directory):
    """Split the sorted directory listing into image paths and label paths."""
    names = sorted(os.listdir(directory))  # list the directory once
    images = [os.path.join(directory, name) for name in names if not is_label(name)]
    labels = [os.path.join(directory, name) for name in names if is_label(name)]
    # Samples are paired by position, so unequal counts would silently
    # attach labels to the wrong images.
    if len(images) != len(labels):
        raise ValueError(
            f"{directory}: found {len(images)} images but {len(labels)} labels"
        )
    # NOTE(review): pairing also assumes images and labels sort into the same
    # relative order; verify against the actual file naming scheme.
    return images, labels


train_images, train_labels = _list_images_and_labels(TRAIN_DIR)
test_images, test_labels = _list_images_and_labels(TEST_DIR)

# --- Dataset statistics (cached on disk) -------------------------------------
if not os.path.exists(DATASET_INFO_PATH):
    dataset_info = compute_dataset_info(train_images, train_labels, test_images, test_labels)
    torch.save(dataset_info, DATASET_INFO_PATH)

# The cache holds NumPy objects, which the `weights_only=True` default of
# recent PyTorch releases refuses to unpickle. The file is written by this
# notebook itself, so full unpickling is acceptable here; do not point this
# at files from untrusted sources.
dataset_info = torch.load(DATASET_INFO_PATH, weights_only=False)

# --- Data pipeline -----------------------------------------------------------
transform_image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(dataset_info["mean"], dataset_info["std"]),
])
transform_label = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = CloudDataset(train_images, train_labels, transform_image, transform_label)
test_dataset = CloudDataset(test_images, test_labels, transform_image, transform_label)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# --- Model, optimiser, loss --------------------------------------------------
model = UNet(input_channels=13, output_channels=1).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Weight the positive class by the negative/positive pixel ratio.
criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor(dataset_info["num_negatives"] / dataset_info["num_positives"]).to(DEVICE)
)

print("Number of cloud pixels:", dataset_info["num_positives"])
print("Number of non-cloud pixels:", dataset_info["num_negatives"])

In [None]:
# Per-epoch loss history, filled either from saved checkpoints or by training.
train_losses = []
test_losses = []

for epoch in range(NUM_EPOCHS):
    # Use the configured directory (the one created during setup) rather
    # than a second hard-coded copy of its name.
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"unet_epoch_{epoch}.pt")

    if os.path.exists(checkpoint_path):
        # Resume: restore model/optimiser state and the recorded losses
        # instead of re-running this epoch.
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        train_losses.append(checkpoint["train_loss"])
        test_losses.append(checkpoint["test_loss"])
    else:
        train_loss = train(model, train_loader, optimizer, criterion, DEVICE)
        test_loss = test(model, test_loader, criterion, DEVICE)

        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": train_loss,
            "test_loss": test_loss,
        }
        torch.save(checkpoint, checkpoint_path)

        train_losses.append(train_loss)
        test_losses.append(test_loss)

    print(f"Epoch: {epoch}, Train loss: {train_losses[-1]:.4f}, Test loss: {test_losses[-1]:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Loss curves: one line per split, x-axis is the epoch index.
for curve, curve_label in (
    (train_losses, "Train loss"),
    (test_losses, "Test loss"),
):
    plt.plot(curve, label=curve_label)

plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

In [None]:
THRESHOLD = 0.5


def show_batch(model, loader, device, band=3, threshold=None):
    """Plot input, target, prediction and thresholded prediction for one batch.

    Takes the first batch produced by ``loader``, runs ``model`` on it in
    evaluation mode without gradients, applies a sigmoid, and draws one
    four-panel figure per sample.

    Args:
        model: module mapping the batch to a single-channel score map.
        loader: iterable yielding ``(data, target)`` batches.
        device: device used for the forward pass.
        band: index of the input channel shown in the first panel.
        threshold: cut-off applied to the sigmoid output; ``None`` uses the
            module-level ``THRESHOLD`` at call time.

    Note:
        The model is left in evaluation mode afterwards.
    """
    if threshold is None:
        threshold = THRESHOLD

    data, target = next(iter(loader))
    data, target = data.to(device), target.to(device)

    model.eval()
    with torch.no_grad():
        output = torch.sigmoid(model(data))
        thresholded_output = (output > threshold).float()

    output = output.cpu().numpy()
    thresholded_output = thresholded_output.cpu().numpy()
    target = target.cpu().numpy()
    data = data.cpu().numpy()

    for i in range(len(output)):
        fig, ax = plt.subplots(1, 4, figsize=(15, 5))
        ax[0].imshow(data[i, band, :, :], cmap="gray")
        ax[0].set_title(f"Input (band {band})")
        ax[1].imshow(target[i, 0, :, :], vmin=0, vmax=1, cmap="gray")
        ax[1].set_title("Target")
        ax[2].imshow(output[i, 0, :, :], vmin=0, vmax=1, cmap="gray")
        ax[2].set_title("Prediction")
        ax[3].imshow(thresholded_output[i, 0, :, :], vmin=0, vmax=1, cmap="gray")
        ax[3].set_title(f"Prediction > {threshold}")
        plt.show()
        # Release the figure explicitly so a large batch does not pile up
        # open figures on backends that keep them alive after show().
        plt.close(fig)

In [None]:
# Visual check on the first batch drawn from the (shuffled) training loader.
show_batch(model, train_loader, DEVICE)

In [None]:
# Same visual check on the first batch of the (unshuffled) test loader.
show_batch(model, test_loader, DEVICE)