In [1]:
import datetime

In [2]:
%load_ext tensorboard
!rm -rf ./xbd_logs/fit/
log_dir = "xbd_logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
import os
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

In [3]:
from typing import Any, Sequence, Union, Optional, Dict, Tuple

import tqdm
import glob
import random
import functools
import imageio

import numpy as np
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torch

torch.set_float32_matmul_precision("high")
import torchdata.datapipes as dp
from torchvision import transforms
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex

# A few helper functions were moved out into the local utils_ms module.
from utils_ms import label_from_3band, read_rgb_tile, one_hot_2_class

In [4]:
import sys
sys.path.append('./Moonshine')
from moonshine.preprocessing import get_preprocessing_fn

In [5]:
def read_item(path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load one example: the preprocessed image tile and its label mask.

    The tile at ``path`` is read with ``read_rgb_tile``, passed through the
    preprocessing function registered for ``model="unet"`` /
    ``dataset="xbd_mexico"`` and cast to ``float32``. The matching mask file
    (located via ``label_from_3band``) is read as ``uint8``, has every 255
    recoded to 1, and is then expanded by ``one_hot_2_class``.

    Returns:
        ``(image, one_hot_mask)``.
    """
    preprocess = get_preprocessing_fn(model="unet", dataset="xbd_mexico")
    tile, _ = read_rgb_tile(path)
    image = preprocess(tile).astype(np.float32)

    label_path = label_from_3band(path, label_type="mask")
    label = imageio.v2.imread(label_path).astype(np.uint8)
    # Recode 255 to class index 1 (presumably the "building" value in the
    # mask files -- confirm against the data).
    label[label == 255] = 1
    one_hot = one_hot_2_class(label, num_classes=2)

    return (image, one_hot)


def remove_missing(path: str) -> bool:
    """Filter predicate: keep an example only if its label mask file exists."""
    return os.path.exists(label_from_3band(path, label_type="mask"))


def apply_transforms(row, transform):
    """Run one (image, mask) row through a joint image/mask transform.

    ``row[0]`` is passed to ``transform`` as ``image`` and ``row[1]`` as
    ``mask``; the result is indexed with those same two keys and returned as
    an ``(image, mask)`` tuple.
    """
    image, mask = row[0], row[1]
    augmented = transform(image=image, mask=mask)
    return augmented["image"], augmented["mask"]


def building_footprint_datapipe(files, transform):
    """Create the DataPipe for a set of files.

    Pipeline stages, in order:
      1. wrap the file paths,
      2. drop paths whose label mask is missing (``remove_missing``),
      3. shuffle the paths,
      4. shard across DataLoader workers,
      5. read image + mask (``read_item``),
      6. optionally apply ``transform`` to each (image, mask) row.

    Args:
        files: Sized collection of image paths.
        transform: Callable taking ``image=`` / ``mask=`` keyword arguments and
            returning a mapping with ``"image"`` and ``"mask"`` keys, or a
            falsy value to skip the transform stage.

    Returns:
        The assembled DataPipe.
    """
    print(f"Got {len(files)} files for this dataset")

    datapipe = dp.iter.IterableWrapper(files)
    datapipe = datapipe.filter(remove_missing)
    # DataLoader's `shuffle` flag only toggles Shuffler nodes that are part of
    # the DataPipe graph, so one has to exist for shuffle=True to do the right
    # thing. It sits before sharding_filter() (every worker sees the same
    # order before taking its shard) and before read_item (the shuffle buffer
    # holds cheap path strings rather than decoded tiles).
    # NOTE: when iterated outside a DataLoader, the Shuffler is active by
    # default.
    datapipe = datapipe.shuffle()
    datapipe = datapipe.sharding_filter()
    datapipe = datapipe.map(read_item)

    if transform:
        transform_fx = functools.partial(
            apply_transforms,
            transform=transform
        )
        datapipe = datapipe.map(transform_fx)

    return datapipe

In [6]:
def train_test_split(
    files: Sequence[str], train_percent: float = 0.8
) -> Tuple[list, list, list]:
    """Split ``files`` into train / test / val lists with a fixed-seed shuffle.

    The input is copied before shuffling, so the caller's sequence is left
    untouched and any sequence type (list, tuple, ...) is accepted. Shuffling
    uses a private ``random.Random(1234)``, which yields the same permutation
    as seeding the global generator with 1234 but does not disturb global
    random state.

    Args:
        files: The items to split.
        train_percent: Fraction of items assigned to the training split
            (truncated to a whole number of items). Must lie in ``[0, 1]``.

    Returns:
        ``(train, test, val)`` lists. The items left after the training split
        are halved with integer arithmetic; when the remainder is odd, the
        extra item goes to ``val``.

    Raises:
        ValueError: If ``train_percent`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_percent <= 1.0:
        raise ValueError(
            f"train_percent must be within [0, 1], got {train_percent}"
        )

    shuffled = list(files)
    random.Random(1234).shuffle(shuffled)

    total = len(shuffled)
    n_train = int(total * train_percent)
    # Integer halving of the remainder. The previous float form,
    # int(total * ((1 - train_percent) / 2)), truncated values such as
    # 0.9999999999999998 to 0 and silently shrank the test split.
    n_test = (total - n_train) // 2
    test_end = n_train + n_test

    return shuffled[:n_train], shuffled[n_train:test_end], shuffled[test_end:]


def get_dataset(
    files: Sequence[str],
    split: str = "train",
    batch_size: int = 4,
    num_workers: int = 8,
    crop_size: int = 512,
) -> DataLoader:
    """Build the DataLoader for building footprint classification.

    Despite the name, this returns a ``DataLoader`` wrapping the DataPipe
    produced by ``building_footprint_datapipe``.

    Args:
        files: Image paths for this split.
        split: ``"train"`` adds random horizontal/vertical flips and asks the
            DataLoader to shuffle; any other value gets the random crop only.
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes.
        crop_size: Side length of the square random crop applied to every
            sample in every split.

    Returns:
        A ``DataLoader`` with ``drop_last=True`` and ``pin_memory=True``.
    """
    is_train = split == "train"

    # NOTE: the crop is random for every split, so non-train splits are not
    # evaluated on a fixed set of pixels from run to run.
    tfx = [
        A.RandomCrop(width=crop_size, height=crop_size),
    ]
    if is_train:
        tfx.extend(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
            ]
        )
    # transpose_mask=True is applied so the mask goes through the same axis
    # transposition as the image when converted to a tensor.
    tfx.append(ToTensorV2(transpose_mask=True))

    datapipe = building_footprint_datapipe(
        files,
        transform=A.Compose(tfx),
    )

    return DataLoader(
        dataset=datapipe,
        batch_size=batch_size,
        shuffle=is_train,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )

In [7]:
from moonshine.models.unet import UNet

class BuildingClassifier(torch.nn.Module):
    """UNet backbone followed by a 1x1 convolution producing 2 output channels.

    Args:
        pretrained: When true, load the ``unet50_fmow_rgb`` encoder weights
            into the backbone (no decoder weights are loaded).
    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Start from an uninitialised instance of the named architecture.
        unet = UNet(name="unet50_fmow_rgb")
        if pretrained:
            unet.load_weights(
                encoder_weights="unet50_fmow_rgb", decoder_weights=None
            )
        self.backbone = unet

        # Per-pixel classifier head: 32 input channels -> 2 output channels
        # (assumes the backbone emits 32 feature channels -- confirm).
        self.classifier = torch.nn.Conv2d(
            in_channels=32, out_channels=2, kernel_size=(1, 1)
        )

    def forward(self, x):
        """Return the classifier head's output for the backbone features of ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

In [8]:
class BuildingTrainer(pl.LightningModule):
    """Lightning module that optimises ``model`` with cross-entropy and logs IoU."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.jaccard = JaccardIndex(task="multiclass", num_classes=2)

    def _run_step(self, batch, stage):
        """Forward pass, loss, IoU and logging shared by the train and val steps.

        ``stage`` is the prefix of the logged metric names
        (``"<stage>/loss"`` and ``"<stage>/iou"``).
        """
        inputs, targets = batch
        logits = self.model(inputs)

        loss = torch.nn.functional.cross_entropy(logits, targets)
        # The IoU metric is given channel 1 of the target rather than the
        # full target tensor used by the loss.
        iou = self.jaccard(logits, targets[:, 1, :, :])

        for metric_name, value in (("loss", loss), ("iou", iou)):
            self.log(f"{stage}/{metric_name}", value, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and IoU for one batch."""
        return self._run_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and IoU for one batch."""
        return self._run_step(batch, "val")

    def configure_optimizers(self):
        """Adam over all parameters with a learning rate of 1e-4."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [None]:
DATA_PATH = './data/Turkey/'

# glob returns paths in arbitrary (filesystem-dependent) order. Sorting gives
# the seeded shuffle in train_test_split a fixed starting order, so the split
# is reproducible across machines and runs.
files = sorted(glob.glob(os.path.join(DATA_PATH, "images/*.tif")))

# Create a name for Tensorboard
exp_name = "building_model_tuning"

# Create our datasets. train_test_split returns (train, test, val) in that
# order. Only the training loader uses split="train" (flip augmentation and
# shuffling); the held-out loaders must not.
train_files, test_files, val_files = train_test_split(files)
train_dataset = get_dataset(train_files, split="train")
test_dataset = get_dataset(test_files, split="test")
val_dataset = get_dataset(val_files, split="val")

logger = pl.loggers.TensorBoardLogger(log_dir, name=exp_name)
trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=50,
    enable_progress_bar=True,
    logger=logger,
)

# Fine-tune from a previously trained checkpoint rather than from the
# pretrained encoder weights.
model1 = BuildingClassifier(pretrained=False)
# map_location="cpu" lets the checkpoint load on a machine without the device
# it was saved from; load_state_dict then copies into the model's own tensors.
# NOTE(security): torch.load unpickles the file -- only load checkpoints from
# a trusted source.
model1.load_state_dict(
    torch.load('./models/trainedc512b2.pt', map_location="cpu")
)

pytrain = BuildingTrainer(model1)

# Train! Validation runs on the validation split; test_dataset stays untouched
# for a final evaluation.
trainer.fit(
    model=pytrain,
    train_dataloaders=train_dataset,
    val_dataloaders=val_dataset,
)

In [None]:
# Persist only the fine-tuned network weights (the state dict of the wrapped
# model, not of the Lightning module).
torch.save(
    obj=model1.state_dict(),
    f='./models/trainedc512b2_turkey.pt',
)