In [89]:
%load_ext autoreload
%autoreload 2

import gc
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torchvision

from torch.utils.tensorboard import SummaryWriter
from functools import partial
from evolver import CrossoverType, MutationType, MatrixEvolver
from models.unet import UNet
from dataset_utils import PartitionType
from landcover_dataloader import LandCoverDataset, create_land_cover_dataset_from_config, get_land_cover_dataloader

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [90]:
# Run networks on the first CUDA GPU when one is available; otherwise use the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [106]:
# Materialize every data partition described by the dataset's config.
dataset_dir = os.path.join(os.getcwd(), "data/landcover_small")
create_land_cover_dataset_from_config(dataset_dir)

# Loader settings shared by all four partitions.
# NOTE(review): shuffle=False also applies to the TRAIN loader — confirm this is intended.
dataloader_params = dict(batch_size=8, shuffle=False, num_workers=6)

# One loader per partition, built in TRAIN, VALIDATION, FINETUNING, TEST order.
train_loader, validation_loader, finetuning_loader, test_loader = (
    get_land_cover_dataloader(dataset_dir, partition, dataloader_params)
    for partition in (PartitionType.TRAIN,
                      PartitionType.VALIDATION,
                      PartitionType.FINETUNING,
                      PartitionType.TEST)
)

In [210]:
def train(model, optimizer, loss_fn, train_loader, metrics, log_steps):
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch is moved to the notebook-level ``device``, a forward/backward
    pass is run and ``optimizer`` takes one step.

    Args:
        model: Module to train; it is put into training mode.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        train_loader: Iterable of ``(batch_x, batch_y)`` tensor pairs.
        metrics: Mapping of metric name -> ``fn(outputs, labels)``; each fn is
            called on NumPy copies of the model outputs and labels.
        log_steps: Every ``log_steps``-th batch (starting with batch 0) the
            metrics are computed, recorded and the batch loss is printed.

    Returns:
        ``(loss_avg, metric_reports)``.
        - ``loss_avg`` is the mean batch loss over the epoch (0.0 for an empty
          loader).
        - ``metric_reports`` is a list with one dict per logged batch, holding
          each metric plus ``'loss'``.
    """
    model.train()
    metric_reports = []
    loss_sum = 0.0
    n_batches = 0
    for i, (batch_x, batch_y) in enumerate(train_loader):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = loss_fn(outputs, batch_y)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        loss_sum += batch_loss
        n_batches += 1

        if i % log_steps == 0:
            # Metrics work on NumPy arrays: detach from autograd and copy to host.
            outputs_np = outputs.detach().cpu().numpy()
            labels_np = batch_y.detach().cpu().numpy()
            report = {name: fn(outputs_np, labels_np) for name, fn in metrics.items()}
            report['loss'] = batch_loss
            metric_reports.append(report)
            print("Train: ", batch_loss)

    # Average over the number of batches, not the last batch index.
    loss_avg = loss_sum / n_batches if n_batches else 0.0
    return loss_avg, metric_reports

def evaluate(model, loss_fn, validation_loader, metrics):
    """Evaluate ``model`` on every batch of ``validation_loader`` without training.

    Each batch is moved to the notebook-level ``device``. The model is put into
    eval mode, and no autograd graph is built.

    Args:
        model: Module to evaluate.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        validation_loader: Iterable of ``(batch_x, batch_y)`` tensor pairs.
        metrics: Mapping of metric name -> ``fn(outputs, labels)``; each fn is
            called on NumPy copies of the model outputs and labels.

    Returns:
        ``(loss_avg, metric_reports)``.
        - ``loss_avg`` is the mean batch loss (0.0 for an empty loader).
        - ``metric_reports`` is a list with one dict per batch, holding each
          metric plus ``'loss'``.
    """
    model.eval()
    metric_reports = []
    loss_sum = 0.0
    n_batches = 0
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        for batch_x, batch_y in validation_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = model(batch_x)
            batch_loss = loss_fn(outputs, batch_y).item()
            loss_sum += batch_loss
            n_batches += 1
            outputs_np = outputs.cpu().numpy()
            labels_np = batch_y.cpu().numpy()
            report = {name: fn(outputs_np, labels_np) for name, fn in metrics.items()}
            report['loss'] = batch_loss
            metric_reports.append(report)

    # Average over the number of batches, not the last batch index.
    loss_avg = loss_sum / n_batches if n_batches else 0.0
    return loss_avg, metric_reports

def accuracy_fn(output, labels, target=None):
    """Accuracy of argmax predictions from class scores against integer labels.

    Args:
        output: Array of class scores with the class dimension on axis 1
            (e.g. ``(N, C)`` or ``(N, C, H, W)``). ``train``/``evaluate``
            pass a NumPy copy of the model outputs.
        labels: Integer class labels whose shape is ``output``'s shape with
            axis 1 removed.
        target: If ``None``, accuracy over all positions. Otherwise, the
            fraction of positions labelled ``target`` that were also
            predicted as ``target``.

    Returns:
        Accuracy as a float in [0, 1]. Returns ``nan`` when there are no
        labels, or when ``target`` never occurs in ``labels``.
    """
    output = np.asarray(output)
    labels = np.asarray(labels)
    # argmax of raw scores equals argmax of softmax probabilities, so no softmax
    # is needed (and nn.functional.softmax would reject NumPy input anyway).
    predicts = output.argmax(axis=1)

    if target is None:
        # Calculate accuracy over all classes.
        if labels.size == 0:
            return float('nan')
        return float(np.mean(predicts == labels))

    # Accuracy restricted to positions whose ground truth is `target`.
    mask = labels == target
    if not mask.any():
        return float('nan')
    return float(np.mean(predicts[mask] == target))

In [218]:
# Release cached GPU memory and garbage from any previously built model.
torch.cuda.empty_cache()
gc.collect()

# Architecture and optimization hyperparameters.
params = dict(
    max_epochs=5,
    n_classes=4,
    in_channels=4,
    depth=4,
    learning_rate=0.01,
    momentum=0.8,
    log_steps=25,
)

# Metric name -> fn(outputs, labels), evaluated by train() and evaluate().
metrics = {"accuracy": accuracy_fn}

model = UNet(
    in_channels=params['in_channels'],
    n_classes=params['n_classes'],
    depth=params['depth'],
)
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=params['learning_rate'],
    momentum=params['momentum'],
)

In [None]:
# Phase 1: train on the TRAIN partition, evaluating on VALIDATION after each epoch.
# Phase 2: fine-tune the same model with backprop on FINETUNING, evaluating on TEST.
# Per-epoch losses are kept in `history` instead of being overwritten each epoch.
history = {'train': [], 'validate': [], 'finetuning': [], 'test': []}

for epoch in range(params['max_epochs']):
    train_loss, train_reports = train(model, optimizer, loss_fn,
                                      train_loader, metrics, params['log_steps'])
    validate_loss, validate_reports = evaluate(model, loss_fn, validation_loader, metrics)
    history['train'].append(train_loss)
    history['validate'].append(validate_loss)
    print(f"[train] epoch {epoch}: train loss {train_loss:.4f}, "
          f"validation loss {validate_loss:.4f}")

for epoch in range(params['max_epochs']):
    finetuning_loss, finetuning_reports = train(model, optimizer, loss_fn,
                                                finetuning_loader, metrics, params['log_steps'])
    test_loss, test_reports = evaluate(model, loss_fn, test_loader, metrics)
    history['finetuning'].append(finetuning_loss)
    history['test'].append(test_loss)
    print(f"[finetune] epoch {epoch}: finetuning loss {finetuning_loss:.4f}, "
          f"test loss {test_loss:.4f}")

# Rich display of the per-epoch loss curves for both phases.
pd.DataFrame(history).rename_axis('epoch')

In [None]:
# Fine tune with learned dropout: evolve dropout masks with MatrixEvolver
# (uniform crossover, bit-flip mutation). Each candidate is scored by its total
# loss over the FINETUNING partition.

# Layer name -> mask spec handed to MatrixEvolver; a None spec leaves that layer out.
# NOTE(review): [256, 256] presumably is the mask shape — confirm against MatrixEvolver.
# Other layers that can be enabled by adding an entry: 'down_0'..'down_4',
# 'up_0'..'up_3', 'end'.
dropout_masks = {
    'start': [256, 256],
}

finetuning_params = {
    "n_generations": 100,
    "n_children": 10,
}

# Keep the evolved layer names in the same order as the specs given to the
# evolver, so child_masks[i] is always installed on the right layer.
evolved_layers = [k for k, m in dropout_masks.items() if m is not None]
evolver = MatrixEvolver([dropout_masks[k] for k in evolved_layers],
                        CrossoverType.UNIFORM, MutationType.FLIP_BIT)


def apply_dropout_masks(masks):
    """Install evolver-produced ``masks`` on ``model``, keyed by evolved layer name."""
    model.set_dropout_masks({name: torch.tensor(mask, device=device, dtype=torch.float)
                             for name, mask in zip(evolved_layers, masks)})


model.eval()
# Fitness evaluation only runs forward passes; no autograd graph is needed.
with torch.no_grad():
    for generation in range(finetuning_params['n_generations']):
        for child in range(finetuning_params['n_children']):
            child_masks = evolver.spawn_child()
            apply_dropout_masks(child_masks)
            total_loss = 0.0
            for batch_x, batch_y in finetuning_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                total_loss += loss_fn(model(batch_x), batch_y).item()
            # NOTE(review): assumes MatrixEvolver treats a lower score as fitter — confirm.
            evolver.add_child(child_masks, total_loss)

        evolver.update_parents()

# NOTE(review): this spawns a fresh child from the final parents rather than
# reusing the best-scoring masks seen so far — confirm this is intended.
child_masks = evolver.spawn_child()
apply_dropout_masks(child_masks)
test_loss, test_reports = evaluate(model, loss_fn, test_loader, metrics)