In [1]:
import numpy as np
import pandas as pd
import os
import rasterio
import imageio
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import mxnet as mx
from mxnet import gluon
from mxnet import autograd
from mxnet import image
from sklearn.model_selection import train_test_split

import sys
sys.path.append('../../resuneta/src')
sys.path.append('../../resuneta/nn/loss')
sys.path.append('../../resuneta/models')
sys.path.append('../../')

from bound_dist import get_distance, get_boundary
from loss import Tanimoto_wth_dual
from resunet_d6_causal_mtskcolor_ddist import ResUNet_d6
from resunet_d7_causal_mtskcolor_ddist import *

# Dataset

In [5]:
class PlanetDataset(gluon.data.Dataset):
    """Planet image tiles paired with field-extent masks and derived targets.

    For an id ``name`` the image is read from
    ``os.path.join(image_directory, name + image_suffix)`` and the extent mask
    from ``os.path.join(label_directory, name + label_suffix)``. Each item is
    the tuple ``(image, extent_mask, boundary_mask, distance_mask, image_hsv)``
    of ``mx.nd`` arrays: the RGB image and its HSV conversion are moved to
    channel-first layout ``(3, H, W)``, and each mask gets a leading axis of
    size 1.

    Parameters
    ----------
    image_directory, label_directory : str
        Folders holding the image tiles and the label masks.
    image_names : sequence, optional
        Tile ids (file names without suffix). When omitted, every file in
        ``image_directory`` whose name ends with ``image_suffix`` is used,
        in sorted order.
    image_suffix, label_suffix : str
        Extensions appended to each id to build the image / label paths.
    crop_size : int or None, default 256
        Side of the top-left window cropped from image and mask; ``None``
        disables cropping.
    """

    def __init__(self, image_directory, label_directory, image_names=None,
                 image_suffix='.jpeg', label_suffix='.png', crop_size=256):
        self.image_directory = image_directory
        self.label_directory = label_directory

        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.crop_size = crop_size

        if image_names is None:
            # Keep only files carrying the image suffix (skips hidden files,
            # checkpoint folders, ...), strip exactly that suffix so ids that
            # contain dots survive, and sort for a filesystem-independent order.
            image_names = sorted(
                file_name[:len(file_name) - len(image_suffix)]
                for file_name in os.listdir(image_directory)
                if file_name.endswith(image_suffix))
        self.image_names = image_names

    def _crop(self, array):
        """Return the top-left ``crop_size`` x ``crop_size`` window of ``array``."""
        if self.crop_size is None:
            return array
        return array[:self.crop_size, :self.crop_size]

    def __getitem__(self, item):
        """Load, crop and convert the sample at position ``item``."""
        name = str(self.image_names[item])
        image_path = os.path.join(self.image_directory, name + self.image_suffix)
        extent_path = os.path.join(self.label_directory, name + self.label_suffix)

        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread does not raise on a missing/unreadable file; it returns None.
            raise FileNotFoundError('Could not read image: {}'.format(image_path))
        rgb = self._crop(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        # NOTE(review): neither the RGB input nor the HSV colour target is
        # rescaled here (OpenCV's 8-bit HSV uses H in [0, 179], S/V in
        # [0, 255]); confirm the model / Tanimoto loss expect these ranges.
        hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)

        # Dividing by 255 maps a 0/255 mask to 0/1 (assumes the label PNG
        # stores the classes as 0 and 255).
        extent_mask = self._crop(imageio.imread(extent_path)) / 255
        boundary_mask = get_boundary(extent_mask)
        distance_mask = get_distance(extent_mask)

        return (mx.nd.array(np.moveaxis(rgb, -1, 0)),
                mx.nd.array(np.expand_dims(extent_mask, 0)),
                mx.nd.array(np.expand_dims(boundary_mask, 0)),
                mx.nd.array(np.expand_dims(distance_mask, 0)),
                mx.nd.array(np.moveaxis(hsv, -1, 0)))

    def __len__(self):
        """Number of tiles in the dataset."""
        return len(self.image_names)

# Dataloader

In [6]:
image_directory = '../data/planet/france/april/'
label_directory = '../data/planet/france/labels/'

# Load the fixed train/val/test splits from disk (rather than drawing a random
# train_test_split) so every run trains and evaluates on the same tiles.
splits_df = pd.read_csv('../data/splits/hanAndBurak_planetImagery_splits.csv')
train_names, val_names, test_names = (
    splits_df.loc[splits_df['fold'] == fold, 'image_id'].values
    for fold in ('train', 'val', 'test'))

# Sanity checks: every split is populated and no tile leaks across splits.
for fold, fold_names in (('train', train_names), ('val', val_names), ('test', test_names)):
    assert len(fold_names) > 0, 'split {!r} is empty'.format(fold)
assert not (set(train_names) & set(val_names)
            or set(train_names) & set(test_names)
            or set(val_names) & set(test_names)), 'image ids overlap between splits'

train_dataset = PlanetDataset(image_directory, label_directory, image_names=train_names)
val_dataset = PlanetDataset(image_directory, label_directory, image_names=val_names)
test_dataset = PlanetDataset(image_directory, label_directory, image_names=test_names)

In [7]:
batch_size = 4
# Shuffle the training set so each epoch visits the batches in a different
# order; val/test order does not affect their metrics, so they stay sequential.
train_dataloader = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = gluon.data.DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = gluon.data.DataLoader(test_dataset, batch_size=batch_size)

# Sanity check: pull a single training batch and inspect tensor shapes.
img_batch, extent_batch, boundary_batch, distance_batch, hsv_batch = next(iter(train_dataloader))
print("img_batch has shape {}".format(img_batch.shape))
print("extent_batch has shape {}".format(extent_batch.shape))

img_batch has shape (4, 3, 256, 256)
extent_batch has shape (4, 1, 256, 256)


In [8]:
int(np.prod(img_batch.shape))  # element count of the sample batch: product of its shape

786432

# Training loop

In [14]:
def train_model(train_dataloader, model, tanimoto_dual, trainer, epoch, ctx=None):
    """Train ``model`` for one epoch with the multi-task Tanimoto loss.

    Parameters
    ----------
    train_dataloader : gluon.data.DataLoader
        Yields ``(img, extent, boundary, distance, hsv)`` batches.
    model : gluon.Block
        Network returning ``(logits, bound, dist, convc)`` for an image batch.
    tanimoto_dual : callable
        Similarity ``f(prediction, target)``; ``1 - f`` is used as the loss.
    trainer : gluon.Trainer
        Optimiser that updates ``model``'s parameters.
    epoch : int
        Epoch number; only used in the progress-bar label.
    ctx : mx.Context, optional
        Device to train on; defaults to ``mx.gpu()``.

    Returns
    -------
    tuple
        ``(cumulative_loss, accuracy, f1, mcc)``: the loss summed over every
        batch of the epoch (float) and the extent-mask ``mx.metric`` objects
        updated on every batch.
    """
    ctx = mx.gpu() if ctx is None else ctx

    cumulative_loss = 0.0
    accuracy = mx.metric.Accuracy()
    f1 = mx.metric.F1()
    mcc = mx.metric.MCC()

    for img, extent, boundary, distance, hsv in tqdm(
            train_dataloader, desc='Training epoch {}'.format(epoch)):
        # Copy the batch to the device outside autograd.record(): the copies
        # need no gradient, so there is nothing to record.
        img, extent, boundary, distance, hsv = [
            x.as_in_context(ctx) for x in (img, extent, boundary, distance, hsv)]

        with autograd.record():
            logits, bound, dist, convc = model(img)
            loss = _multitask_tanimoto_loss(
                tanimoto_dual,
                (logits, bound, dist, convc),
                (extent, boundary, distance, hsv))
        loss.backward()
        # Normalise by the number of samples actually in this batch: the last
        # batch can be smaller than the loader's nominal batch size.
        trainer.step(img.shape[0])

        cumulative_loss += mx.nd.sum(loss).asscalar()
        _update_extent_metrics(logits, extent, accuracy, f1, mcc)
        # TODO: eccentricity and further metrics

    return cumulative_loss, accuracy, f1, mcc


def _multitask_tanimoto_loss(tanimoto_dual, predictions, targets):
    """Mean over tasks of the batch-summed ``1 - tanimoto_dual`` loss.

    ``predictions`` and ``targets`` are paired positionally (extent, boundary,
    distance, HSV colour); with four tasks this equals
    ``0.25 * (loss_extent + loss_boundary + loss_distance + loss_hsv)``.
    """
    task_losses = [mx.nd.sum(1 - tanimoto_dual(prediction, target))
                   for prediction, target in zip(predictions, targets)]
    return sum(task_losses) / len(task_losses)


def _update_extent_metrics(logits, extent, accuracy, f1, mcc):
    """Update accuracy / F1 / MCC with the extent predictions of one batch."""
    # NOTE(review): channel 0 of `logits` is treated as the foreground (extent)
    # probability and thresholded at 0.5; confirm this matches the channel
    # order the model and loss actually use.
    foreground = logits[:, [0], :, :]
    accuracy.update(extent, mx.nd.ceil(foreground - 0.5))

    # F1/MCC expect per-sample two-class scores: stack (1 - p, p) per pixel.
    prediction = foreground.reshape(-1)
    probabilities = mx.nd.stack(1 - prediction, prediction, axis=1)
    flat_extent = extent.reshape(-1)
    f1.update(flat_extent, probabilities)
    mcc.update(flat_extent, probabilities)

In [15]:
# hyperparameters
lr = 0.001
epochs = 2
seed = 42
ctx = mx.gpu()
# The batch size is fixed where the DataLoaders are built; re-assigning
# `batch_size` here would not change those loaders, so it is not repeated.

# Seed the mxnet and numpy RNGs so weight initialisation and random
# sampling are reproducible across runs.
mx.random.seed(seed)
np.random.seed(seed)

# define model (parameters are created directly on `ctx`)
model = ResUNet_d6(_nfilters_init=8, _NClasses=2)
model.initialize(ctx=ctx)
model.hybridize()

# define loss function and optimiser
tanimoto_dual = Tanimoto_wth_dual()
# Not used by the Tanimoto-based training/validation below; kept for experiments.
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(model.collect_params(),
                        'adam', {'learning_rate': lr})


def evaluate_model(dataloader, model, tanimoto_dual, ctx):
    """Compute the multi-task loss and extent metrics on a dataset without training.

    Uses the same loss and metrics as the training loop, but runs outside
    ``autograd.record()`` so no gradients are recorded and the network runs in
    inference mode.

    Returns ``(cumulative_loss, accuracy, f1, mcc)`` with the loss summed over
    all batches.
    """
    cumulative_loss = 0.0
    accuracy = mx.metric.Accuracy()
    f1 = mx.metric.F1()
    mcc = mx.metric.MCC()

    for img, extent, boundary, distance, hsv in dataloader:
        img, extent, boundary, distance, hsv = [
            x.as_in_context(ctx) for x in (img, extent, boundary, distance, hsv)]
        logits, bound, dist, convc = model(img)

        loss = 0.25 * (mx.nd.sum(1 - tanimoto_dual(logits, extent))
                       + mx.nd.sum(1 - tanimoto_dual(bound, boundary))
                       + mx.nd.sum(1 - tanimoto_dual(dist, distance))
                       + mx.nd.sum(1 - tanimoto_dual(convc, hsv)))
        cumulative_loss += mx.nd.sum(loss).asscalar()

        # Same extent metrics as in training: channel 0 thresholded at 0.5.
        accuracy.update(extent, mx.nd.ceil(logits[:, [0], :, :] - 0.5))
        prediction = logits[:, 0, :, :].reshape(-1)
        probabilities = mx.nd.stack(1 - prediction, prediction, axis=1)
        f1.update(extent.reshape(-1), probabilities)
        mcc.update(extent.reshape(-1), probabilities)

    return cumulative_loss, accuracy, f1, mcc


# containers for metrics to log
train_metrics = {'train_loss': [], 'train_acc': [], 'train_f1': [],
                 'train_mcc': []}
val_metrics = {'val_loss': [], 'val_acc': [], 'val_f1': [],
               'val_mcc': []}

# training loop
for epoch in range(1, epochs + 1):

    # training set
    train_loss, train_accuracy, train_f1, train_mcc = train_model(
        train_dataloader, model, tanimoto_dual, trainer, epoch)

    # training set metrics
    train_loss_avg = train_loss / len(train_dataset)
    train_metrics['train_loss'].append(train_loss_avg)
    train_metrics['train_acc'].append(train_accuracy.get()[1])
    train_metrics['train_f1'].append(train_f1.get()[1])
    train_metrics['train_mcc'].append(train_mcc.get()[1])

    # validation set
    val_loss, val_accuracy, val_f1, val_mcc = evaluate_model(
        val_dataloader, model, tanimoto_dual, ctx)

    # validation set metrics
    val_loss_avg = val_loss / len(val_dataset)
    val_metrics['val_loss'].append(val_loss_avg)
    val_metrics['val_acc'].append(val_accuracy.get()[1])
    val_metrics['val_f1'].append(val_f1.get()[1])
    val_metrics['val_mcc'].append(val_mcc.get()[1])

    print("Epoch {}:".format(epoch))
    print("    Train loss {:0.3f}, accuracy {:0.3f}, F1-score {:0.3f}, MCC: {:0.3f}".format(
        train_loss_avg, train_accuracy.get()[1], train_f1.get()[1], train_mcc.get()[1]))
    print("    Val loss {:0.3f}, accuracy {:0.3f}, F1-score {:0.3f}, MCC: {:0.3f}".format(
        val_loss_avg, val_accuracy.get()[1], val_f1.get()[1], val_mcc.get()[1]))

    # save model
    # TODO

Training epoch 1:   0%|          | 0/310 [00:00<?, ?it/s]

depth:= 0, nfilters: 8
depth:= 1, nfilters: 16
depth:= 2, nfilters: 32
depth:= 3, nfilters: 64
depth:= 4, nfilters: 128
depth:= 5, nfilters: 256
depth:= 6, nfilters: 128
depth:= 7, nfilters: 64
depth:= 8, nfilters: 32
depth:= 9, nfilters: 16
depth:= 10, nfilters: 8


Training epoch 1:   4%|▎         | 11/310 [00:14<06:39,  1.33s/it]


KeyboardInterrupt: 

In [None]:
#             img = img.astype('float32')
#             extent = extent.astype('float32')
#             boundary = boundary.astype('float32')
#             distance = distance.astype('float32')
#             hsv = hsv.astype('float32')