In [None]:
# Print the NVIDIA driver/GPU status so the recorded notebook shows which GPU it ran on.
!nvidia-smi

# Init

In [None]:
# Shallow-clone (latest commit only) the repository holding the experiment helper code.
# NOTE(review): no commit or tag is pinned, and a later cell appends '../' to sys.path
# instead of this clone - confirm which copy of the helpers actually gets imported.
!git clone --depth=1 https://github.com/AdeelH/potsdam-batch-exp.git

In [27]:
# --- Experiment switches (these feed into EXPERIMENT_NAME in the next cell) ---
BATCH_SIZE = 16       # used for both the training and validation loaders
PRETRAINED = True     # forwarded to create_body(...)
LAST_CROSS = False    # forwarded to DynamicUnet(last_cross=...)

# --- Naming components ---
MODEL_ARCH = 'unet'
MODEL_BASE = 'resnet18'
CHANNEL_VARIATION = 'rgb_e'
# Variation tag: the last_cross setting followed by the ensemble marker.
MODEL_VARIATION = '_'.join([f'lc_{LAST_CROSS}', 'ensemble_bn'])


In [28]:
# Backbone tag: the base name, with a trailing 'p' when pretrained weights are used.
_backbone_tag = MODEL_BASE + ('p' if PRETRAINED else '')

# The experiment name is the underscore-joined list of configuration components.
EXPERIMENT_NAME = '_'.join([
    'ss',
    CHANNEL_VARIATION,
    MODEL_ARCH,
    _backbone_tag,
    MODEL_VARIATION,
    'bsz',
    str(BATCH_SIZE),
])
print(EXPERIMENT_NAME)

# S3 location for this experiment's outputs (passed to S3IoHandler below).
S3_BUCKET = 'raster-vision-ahassan'
S3_ROOT = 'potsdam/experiments/output/' + EXPERIMENT_NAME

ss_rgb_e_unet_resnet18p_lc_False_ensemble_bn_bsz_16


In [29]:
# Training schedule constants.
# NOTE(review): the training cell further down hard-codes epochs=2, learning_rate=1e-2 and
# learning_rate_min=1e-3 instead of reading EPOCHS / LR_START / LR_END - confirm which
# values are the intended ones.
EPOCHS = 40
LR_START = 1e-2
LR_END = 1e-4

# SGD hyperparameters (passed to optim.SGD in the training cell).
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# Callback intervals: CHECKPOINT_INTERVAL is the epoch period for checkpoints in
# get_epoch_monitor; BATCH_CB_INTERVAL is passed to get_batch_monitor.
CHECKPOINT_INTERVAL = 10
BATCH_CB_INTERVAL = 10

In [30]:
import glob
import os
import pickle
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import torch.utils
import torchvision as tv
from torchvision import transforms as tf

from fastai.vision import *
from fastai.metrics import error_rate

import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
# %matplotlib inline

In [31]:
import sys
# sys.path.append('potsdam-batch-exp/')
# Put the parent directory on the import path; presumably that is where the helper
# modules imported in the next cell (utils_, io_, data_, ...) live - TODO confirm.
sys.path.append('../')

In [32]:
from utils_ import *
from io_ import *
from data_ import *
from transforms import *
from models import *
from training import *
from visualizations import *
from monitoring import *
%reload_ext autoreload
%autoreload 2

In [33]:
# Segmentation class labels; list position is presumably the class index - TODO confirm
# against the label encoding used by the dataset.
CLASS_NAMES = [
    'building',
    'tree',
    'low-vegetation',
    'clutter',
    'car',
    'pavement',
]
NCLASSES = len(CLASS_NAMES)

In [34]:
# I/O helper used for saving/loading models; local artefacts are rooted at EXPERIMENT_NAME
# and the S3 side at S3_BUCKET/S3_ROOT.
io_handler = S3IoHandler(local_root=EXPERIMENT_NAME, s3_bucket=S3_BUCKET, s3_root=S3_ROOT)

# Data

In [36]:
# Alternative (disabled): fetch the pickle through the S3 handler instead of the local copy.
# potsdam_dict = io_handler.load_pickled_file('potsdam/data/potsdam.pkl', 'data/potsdam.pkl')

# Location of the pickled Potsdam data dictionary, relative to the notebook's working directory.
POTSDAM_PKL_PATH = Path('../../potsdam/data/potsdam.pkl')

# SECURITY: unpickling can execute arbitrary code. Only load this file if it comes from a
# trusted source (e.g. it was produced by this project's own preprocessing).
with POTSDAM_PKL_PATH.open('rb') as f:
    potsdam_dict = pickle.load(f)


## Prepare datasets

In [37]:
# Input channels handed to tfs_potsdam; ch_R/ch_G/ch_B/ch_E come from the project's
# star-imports (presumably channel identifiers for red, green, blue and an 'E' band - TODO confirm).
CHANNELS = [ch_R, ch_G, ch_B, ch_E]
# Chip geometry passed to the Potsdam dataset (chip_size / stride arguments).
CHIP_SIZE = 400
STRIDE = 200
# Downsampling factor passed to tfs_potsdam.
DOWNSAMPLING = 2

# Fraction of the dataset indices assigned to the training split; the remainder is validation.
TRAIN_SPLIT = 0.85

In [38]:
# Unpack the four transforms returned by tfs_potsdam. The train/val ones are passed to the
# datasets as `tf`, and the x/y ones as `x_tf` / `y_tf`.
(
    train_transform,
    val_transform,
    x_transform,
    y_transform,
) = tfs_potsdam(channels=CHANNELS, downsampling=DOWNSAMPLING)

In [39]:
# Arguments shared by the dataset views below.
_chip_kwargs = dict(chip_size=CHIP_SIZE, stride=STRIDE)
_xy_kwargs = dict(x_tf=x_transform, y_tf=y_transform)

# Same underlying data, three views: validation transform only (no x/y transforms),
# the training pipeline, and the validation pipeline.
original_ds = Potsdam(potsdam_dict, tf=val_transform, **_chip_kwargs)
train_ds = Potsdam(potsdam_dict, tf=train_transform, **_chip_kwargs, **_xy_kwargs)
val_ds = Potsdam(potsdam_dict, tf=val_transform, **_chip_kwargs, **_xy_kwargs)

### Train/val split

In [40]:
# Split sizes: the training share is truncated to a whole number of samples and
# validation takes whatever is left, so the two always sum to the dataset length.
n_samples = len(train_ds)
train_split_size = int(n_samples * TRAIN_SPLIT)
val_split_size = n_samples - train_split_size

print('train_split_size', train_split_size)
print('val_split_size', val_split_size)

# Index pool in dataset order; the samplers slice it into a leading train part
# and a trailing validation part.
inds = np.arange(n_samples)

train_split_size 12867
val_split_size 2271


### Samplers

In [41]:
# Leading indices go to training, trailing indices to validation (no shuffling before the cut).
train_inds, val_inds = inds[:train_split_size], inds[train_split_size:]

train_sampler = torch.utils.data.SubsetRandomSampler(train_inds)
val_sampler = torch.utils.data.SubsetRandomSampler(val_inds)

# Sanity check: no index may appear in both splits.
assert set(train_sampler.indices).isdisjoint(val_sampler.indices)

# Model

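Use fastai's `DynamicUnet` to build two U-Nets on ResNet-18 backbones — one taking the RGB channels and one taking the single E channel — which are then combined into an ensemble.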

In [10]:
def _base_model(pretrained=False):
    """Return an unmodified torchvision ResNet-18 (the RGB backbone factory).

    `pretrained` is forwarded unchanged to ``tv.models.resnet18``.
    """
    return tv.models.resnet18(pretrained=pretrained)

In [11]:
# RGB branch: build the encoder body from the `_base_model` factory, wrap it in a
# DynamicUnet with NCLASSES outputs and move it to the GPU.
# NOTE: this cell must run while `_base_model` is the plain ResNet-18 factory; a later
# cell redefines `_base_model` with a 1-channel stem, and `body` is reused there too.
body = create_body(_base_model, pretrained=PRETRAINED)
rgb_model = models.unet.DynamicUnet(body, n_classes=NCLASSES, last_cross=LAST_CROSS).cuda()

In [12]:
def _base_model(pretrained=False):
    """Return a torchvision ResNet-18 whose first convolution takes one input channel.

    This deliberately reuses the name `_base_model`, replacing the plain ResNet-18
    factory defined in an earlier cell, so cell execution order matters.

    Args:
        pretrained: Forwarded unchanged to ``tv.models.resnet18``.

    Returns:
        The ResNet-18 with ``conv1`` swapped for a new
        ``Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)``.
    """
    m = tv.models.resnet18(pretrained=pretrained)
    # Assign through the public attribute instead of the private `_modules` dict;
    # nn.Module registers the new layer under the same 'conv1' key.
    # NOTE: the replacement layer is freshly constructed, so `conv1` does not hold
    # pretrained weights even when `pretrained=True`.
    m.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    return m

In [13]:
# E branch: same construction as the RGB branch, but this cell relies on `_base_model`
# having been redefined (previous cell) to use a 1-input-channel first convolution.
# `body` is rebound here, replacing the RGB body created earlier.
body = create_body(_base_model, pretrained=PRETRAINED)
e_model = models.unet.DynamicUnet(body, n_classes=NCLASSES, last_cross=LAST_CROSS).cuda()

In [16]:
# (model, S3 source path, local target path) for each separately trained branch.
_branch_weights = (
    (
        rgb_model,
        'potsdam/experiments/output/ss_rgb_unet_resnet18p_lc_False_bsz_16/best_model/best_acc',
        'rgb_model',
    ),
    (
        e_model,
        'potsdam/experiments/output/ss_e_unet_resnet18p_lc_False_bsz_16/best_model/ss_e_unet_resnet18p_lc_False_bsz_16',
        'e_model',
    ),
)

# Load the saved weights into each branch, RGB first, then E.
for _branch_model, _s3_path, _tgt_path in _branch_weights:
    io_handler.load_model_weights(_branch_model, s3_path=_s3_path, tgt_path=_tgt_path)

In [60]:
# Combine the two branches into one ensemble model on the GPU. The class count comes
# from NCLASSES (as for both branches) rather than a hard-coded literal, so the three
# models cannot drift apart if CLASS_NAMES changes.
model = RGB_E_ensemble(rgb_model, e_model, nclasses=NCLASSES).cuda()

In [61]:
# Freeze the first component of each ensemble branch (RGB first, then E).
for _branch in (model.rgb, model.e):
    freeze(_branch[0])

In [26]:
# model.eval()
# with torch.no_grad():
#     print(model.cpu()(torch.rand(16, 4, 200, 200)).shape)

torch.Size([16, 6, 200, 200])


# Train

## Training monitoring callbacks

In [53]:

def get_epoch_monitor(io_handler, chkpt_interval=1, viz_root='visualizations/per_epoch'):
    """Build the callback that is invoked once per epoch with the model and the logs.

    Args:
        io_handler: Object exposing ``save_model(model, path, info=...)``.
        chkpt_interval: Save a checkpoint every this many epochs; must be > 0.
        viz_root: Accepted for interface compatibility; not used by the callback.

    Returns:
        A function ``_monitor(model, logs)`` that saves periodic checkpoints, tracks
        the best validation accuracy/loss in ``logs['best_acc']`` / ``logs['best_loss']``
        (saving the model whenever either improves), prints the formatted logs and
        prints the weight/bias of ``model.rgb[1]`` and ``model.e[1]``.
    """
    assert chkpt_interval > 0

    def _monitor(model, logs):
        # len(logs['epoch']) is the 1-indexed number of the epoch that just finished.
        epoch_num = len(logs['epoch'])
        is_first = epoch_num <= 1

        current_acc = logs['val_acc'][-1]
        current_loss = logs['val_loss'][-1]
        # On the first epoch there is no history yet, so use sentinels that any
        # real value will beat.
        prev_best_acc = -1 if is_first else logs['best_acc'][-1]
        prev_best_loss = 1e8 if is_first else logs['best_loss'][-1]

        # Periodic checkpoint.
        if epoch_num % chkpt_interval == 0:
            io_handler.save_model(model, f'checkpoints/epoch_{epoch_num:04d}', info=logs)

        # Best accuracy: record the running best first, then save if it improved.
        acc_improved = current_acc > prev_best_acc
        logs['best_acc'].append(current_acc if acc_improved else prev_best_acc)
        if acc_improved:
            io_handler.save_model(model, 'best_model/best_acc', info=logs)

        # Best loss: same bookkeeping, lower is better.
        loss_improved = current_loss < prev_best_loss
        logs['best_loss'].append(current_loss if loss_improved else prev_best_loss)
        if loss_improved:
            io_handler.save_model(model, 'best_model/best_loss', info=logs)

        print(logs_to_str(logs))

        # Dump the parameters of the second element of each branch, with the
        # original separators after each branch.
        for tag, branch, separator in (('rgb', model.rgb, '---'), ('e', model.e, '------')):
            print(f'{tag} bn weight', branch[1].weight.data)
            print(f'{tag} bn bias  ', branch[1].bias.data)
            print(separator)

    return _monitor

def get_batch_monitor(io_handler, viz_root='visualizations/per_batch', interval=4):
    """Build the per-batch callback, which is currently a placeholder that does nothing.

    All three arguments are accepted for interface compatibility and ignored.
    """

    def _monitor(model, epoch, batch_idx, batch, labels):
        # Intentionally empty: no per-batch monitoring is performed.
        return None

    return _monitor


In [54]:
# Instantiate the callbacks that are handed to train_seg in the training cell.
epoch_callback = get_epoch_monitor(io_handler, chkpt_interval=CHECKPOINT_INTERVAL)
# The per-batch callback is a no-op, so `interval` currently has no effect.
batch_callback = get_batch_monitor(io_handler, interval=BATCH_CB_INTERVAL)

In [55]:
# Batch sizes read by the DataLoader cell; training and validation use the same size.
train_params = {
    'batch_size': BATCH_SIZE,
    'val_batch_size': BATCH_SIZE,
}

In [56]:
# Each loader draws from its own dataset view through its own index sampler, so the
# train and validation subsets never overlap.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    sampler=train_sampler,
    batch_size=train_params['batch_size'],
    pin_memory=False,
)
val_dl = torch.utils.data.DataLoader(
    val_ds,
    sampler=val_sampler,
    batch_size=train_params['val_batch_size'],
    pin_memory=False,
)

In [63]:
# Training run configuration.
# NOTE(review): rebinding train_params to a fresh dict drops the 'batch_size' and
# 'val_batch_size' keys set in an earlier cell - confirm train_seg does not need them.
# NOTE(review): epochs and learning rates are hard-coded here; the EPOCHS, LR_START and
# LR_END constants from the config cell are not used - confirm which values are intended.
train_params = {}
train_params['epochs'] = 2
train_params['learning_rate'] = 1e-2
train_params['learning_rate_min'] = 1e-3

# SGD over all parameters of the ensemble (this includes the parts passed to freeze();
# presumably those receive no gradients and are left untouched - verify against freeze()).
optimizer = optim.SGD(model.parameters(), lr=train_params['learning_rate'], momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Cosine annealing from the initial learning rate down to learning_rate_min, with the
# schedule length set to the number of epochs.
sched = optim.lr_scheduler.CosineAnnealingLR(optimizer, train_params['epochs'], eta_min=train_params['learning_rate_min'])

# Turn off matplotlib interactive mode before the training loop.
plt.ioff()
train_seg(model, train_dl, val_dl, optimizer, sched, train_params, 
          epoch_callback=epoch_callback, batch_callback=batch_callback)

epoch               : 0
lr                  : 0.01
train_loss          : 2.039399300139156e-07
val_loss            : 6.035407977833529e-07
train_acc           : 0.9506579637527466
val_acc             : 0.8818700909614563
train_time          : 305.422819852829
val_time            : 52.639190912246704
class_0_precision   : 0.9531056880950928
class_0_recall      : 0.9650421142578125
class_0_fscore      : 0.9626309871673584
class_1_precision   : 0.8701796531677246
class_1_recall      : 0.7945263981819153
class_1_fscore      : 0.8085860013961792
class_2_precision   : 0.7986161708831787
class_2_recall      : 0.84298175573349
class_2_fscore      : 0.8337185382843018
class_3_precision   : 0.7848424315452576
class_3_recall      : 0.5006275773048401
class_3_fscore      : 0.5397171378135681
class_4_precision   : 0.898605465888977
class_4_recall      : 0.893841028213501
class_4_fscore      : 0.8947898149490356
class_5_precision   : 0.8651732206344604
class_5_recall      : 0.9079312682151794
class_

In [65]:
# Display the last layer of the first component of the E branch (echoed as the cell output).
model.e[0][-1]

Sequential(
  (0): Conv2d(96, 6, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)