In [1]:
%load_ext autoreload
%autoreload 2

import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import rasterio
import seaborn as sns
import torch
import torch.nn as nn

from evolver import CrossoverType, MutationType, MatrixEvolver
from unet import UNet

In [2]:
def read_tif_to_np(tif_path):
    """Reads a tif file and converts it into a numpy.ndarray.
    
    Arg:
        tif_path: The full path to the tif file to read.
    
    Returns:
        A numpy.ndarray containing the tif file data. The returned tif has a rolled
        dimension and so the input is in the shape (channels, height width).
    
    """
    with rasterio.open(tif_path) as f:
        return f.read()

def apply_remap_values(labels, label_map):
    """Reassigns values inplace in an numpy array given a provided mapping.
    
    Args:
        labels: An ndarray of labels.
        label_map: A dict[int, int] mapping label classes [original, new].
        
    """
    for l1, l2 in label_map.items():
        labels[labels == l1] = l2

def sample_patch_coordinates(data, labels, patch_size, n_samples):
    """Generates image patches from a tile containing features and corresponding labels.

    Args:
        data: The x features of the image.
        labels: The y labels of the image.
        patch_size: An Iterable[int, int] size of the image patch to be extracted.
        n_samples: The number of samples to extract per tile.

    Returns:
        A list of x_patches and y_patches containg features and labels respectively.

    """
    height, width = patch_size
    channels = data.shape[0]
    xs = np.random.randint(0, data.shape[2] - width, n_samples)
    ys = np.random.randint(0, data.shape[1] - height, n_samples)
    return np.dstack((xs, ys)).reshape((n_samples, 2))
    
class LandCoverDataset(torch.utils.data.Dataset):
    """Land Cover Dataset Containing patches."""

    def __init__(self, features_path, labels_path, patch_size, n_samples, patch_coordinates=None):
        """
        Args:
            features_path: Path to the features of a tile.
            labels_path: Path to the labels of a tile.
            patch_size: An Iterable[int, int] size of the image patch to be extracted.
            n_samples: The number of samples to extract per tile.
            patch_coordinates: A list of coordinates used to identify the top left hand corners of
                the patches to extract from the tile. If None they are randomly generated.

        """
        self.data = read_tif_to_np(features_path)
        self.labels = read_tif_to_np(labels_path)
        # Coalesces labels into 4 groups instead of 6.
        # TODO(ameade): Consider allowing for transformation function arguments to modify data upon
        # reading it in.
        water_forest_land_impervious_remap = {5: 4, 6: 4}
        apply_remap_values(self.labels, water_forest_land_impervious_remap)

        self.n_classes = len(np.unique(self.labels))
        
        self.patch_size = patch_size
        self.n_samples = n_samples    
        self.patch_coordinates = patch_coordinates
        
        if self.patch_coordinates is None:
            self.patch_coordinates = sample_patch_coordinates(self.data, 
                                                              self.labels,
                                                              self.patch_size,
                                                              self.n_samples)
        
    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        height, width = self.patch_size
        x, y = self.patch_coordinates[idx]
        img = self.data[:, y : y + height, x : x + width].astype(np.float32)
        label = self.labels[0, y : y + height, x : x + width]
        return img, label

In [3]:
# Create Data
params = {'batch_size': 32,
          'shuffle': False,
          'num_workers': 6}
max_epochs = 10
patch_size = (256, 256)
samples_per_tile = 5000

# Data Generators
train_x_path = "/mnt/blobfuse/esri-naip/v002/md/2015/md_100cm_2015/39076/m_3907639_sw_18_1_20150815.tif"
train_y_path = "/mnt/blobfuse/resampled-lc/data/v1/2015/states/md/md_1m_2015/39076/m_3907639_sw_18_1_20150815_lc.tif"

test_x_path = "/mnt/blobfuse/esri-naip/v002/md/2015/md_100cm_2015/39076/m_3907639_ne_18_1_20150815.tif"
test_y_path = "/mnt/blobfuse/resampled-lc/data/v1/2015/states/md/md_1m_2015/39076/m_3907639_ne_18_1_20150815_lc.tif"

train_set = LandCoverDataset(train_x_path, train_y_path, patch_size, samples_per_tile)
train_loader = torch.utils.data.DataLoader(train_set, **params)

test_set = LandCoverDataset(test_x_path, test_y_path, patch_size, samples_per_tile)
test_loader = torch.utils.data.DataLoader(test_set, **params)

In [71]:
# Define Model Loss and Optimizers
net = UNet(in_channels = 4, n_classes = train_set.n_classes, depth = 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(max_epochs):
    running_loss = 0.0
    
    for i, data in enumerate(train_loader):
        batch_x, batch_y = data
        optimizer.zero_grad()
        outputs = net(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

RuntimeError: expected scalar type Long but found Byte

In [None]:
dropout_mask_evolver = MatrixEvolver([[3, 3]], CrossoverType.UNIFORM, MutationType.FLIP_BIT)

In [77]:
outputs.shape

torch.Size([32, 4, 256, 256])

In [78]:
train_set.n_classes

4

In [79]:
batch_y

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         ...,
         [3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2]],

        [[1, 1, 1,  ..., 3, 3, 3],
         [1, 1, 1,  ..., 3, 3, 3],
         [1, 1, 1,  ..., 3, 3, 3],
         ...,
         [3, 3, 3,  ..., 3, 3, 3],
         [3, 3, 3,  ..., 3, 3, 3],
         [3, 3, 3,  ..., 3, 3, 3]],

        ...,

        [[3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         ...,
         [3, 3, 3,  ..., 3, 3, 3],
         [2, 3, 3,  ..., 3, 3, 3],
         [2, 2, 3,  ..., 3, 3, 3]],

        [[1, 1, 1,  ..., 2, 2, 2],
         [1, 1, 1,  ..., 2, 2, 2],
         [1,