In [None]:
# Report the visible NVIDIA GPU(s) and driver; the model is moved to CUDA later in this notebook.
!nvidia-smi

# Init

In [35]:
# Fetch the project's helper modules (utils_, io_, data_, models, training, ...),
# which are made importable by appending this directory to sys.path below.
# NOTE: git refuses to clone into an existing non-empty directory, so on a re-run
# this only prints an error and the existing checkout is reused (not updated).
!git clone --depth=1 https://github.com/AdeelH/potsdam-batch-exp.git

Cloning into 'potsdam-batch-exp'...
remote: Enumerating objects: 16, done.
remote: Counting objects: 100% (16/16), done.
remote: Compressing objects: 100% (13/13), done.
remote: Total 16 (delta 2), reused 14 (delta 1), pack-reused 0
Unpacking objects: 100% (16/16), done.


In [1]:
# --- Experiment configuration -----------------------------------------------
# Model family: DeepLab with a ResNet-101 base. PRETRAINED is reflected in
# EXPERIMENT_NAME as a trailing 'p' on the base name.
MODEL_ARCH = 'deeplab'
MODEL_BASE = 'resnet101'
PRETRAINED = True

# Channel-layout tag; in this notebook it is only used to build EXPERIMENT_NAME.
CHANNEL_VARIATION = 'rgbp_irep'
MODEL_VARIATION = ''  # currently unused in this notebook

# Merge point passed to DeeplabDoublePartialBackbone when combining backbones.
CUT = 7

BATCH_SIZE = 16

In [2]:
# Experiment identifier; it names both the local output directory and the S3 prefix.
EXPERIMENT_NAME = '_'.join([
    'ss',
    CHANNEL_VARIATION,
    MODEL_ARCH,
    MODEL_BASE + ('p' if PRETRAINED else ''),
    'merge_after_backbone_partial',
    str(CUT),
])
print(EXPERIMENT_NAME)

# All experiment outputs are stored under this bucket/prefix.
S3_BUCKET = 'raster-vision-ahassan'
S3_ROOT = 'potsdam/experiments/output/' + EXPERIMENT_NAME

ss_rgbp_irep_deeplab_resnet101p_merge_after_backbone_partial_7


In [3]:
# --- Optimisation schedule ----------------------------------------------------
# SGD whose learning rate is cosine-annealed from LR_START down to LR_END,
# with T_max = EPOCHS.
EPOCHS = 50
LR_START, LR_END = 1e-2, 1e-4
MOMENTUM, WEIGHT_DECAY = 0.9, 5e-4

# Bookkeeping: checkpoint every CHECKPOINT_INTERVAL epochs; BATCH_CB_INTERVAL is
# handed to the per-batch callback as its `interval`.
CHECKPOINT_INTERVAL = 5
BATCH_CB_INTERVAL = 10

In [4]:
import os
from pathlib import Path
import glob
from datetime import datetime

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import torch.utils
import torchvision as tv
from torchvision import transforms as tf

from fastai.vision import *
from fastai.metrics import error_rate

import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
# %matplotlib inline

In [5]:
import sys

# Make the helper modules from the cloned repo importable. The membership check
# keeps this cell idempotent: re-running it no longer appends duplicate entries
# to sys.path.
if 'potsdam-batch-exp/' not in sys.path:
    sys.path.append('potsdam-batch-exp/')

from utils_ import *
from io_ import *
from data_ import *
from transforms import *
from models import *
from training import *
from visualizations import *
from monitoring import *

In [6]:
# Input bands fed to the model, in this order: R, G, B, IR, E.
CHANNELS = [ch_R, ch_G, ch_B, ch_IR, ch_E]

# Chipping: STRIDE is half of CHIP_SIZE, so neighbouring chips presumably overlap by 50%.
CHIP_SIZE, STRIDE = 400, 200
DOWNSAMPLING = 2  # forwarded to tfs_potsdam

# Fraction of chips (in index order) used for training; the remainder is validation.
TRAIN_SPLIT = 0.85

In [7]:
# Segmentation classes (presumably in label-index order); NCLASSES sizes the model output.
CLASS_NAMES = [
    'building', 'tree', 'low-vegetation',
    'clutter', 'car', 'pavement',
]
NCLASSES = len(CLASS_NAMES)

In [9]:
# Experiment I/O: local files go under a directory named after the experiment;
# S3 locations are resolved against S3_BUCKET / S3_ROOT (presumably mirrored there).
io_handler = S3IoHandler(local_root=EXPERIMENT_NAME, s3_bucket=S3_BUCKET, s3_root=S3_ROOT)

# Data

In [None]:
# Load the pickled Potsdam scene dictionary from S3; the second argument is
# presumably the local cache path -- verify against S3IoHandler.load_pickled_file.
# NOTE(review): unpickling can execute arbitrary code; only load from a trusted bucket.
potsdam_dict = io_handler.load_pickled_file('potsdam/data/potsdam.pkl', 'data/potsdam.pkl')

## Prepare datasets

In [None]:
# `tf`-style transforms for training / validation, plus the `x_tf` / `y_tf`
# input/target transforms, all built for the selected CHANNELS and DOWNSAMPLING.
(train_transform, val_transform,
 x_transform, y_transform) = tfs_potsdam(channels=CHANNELS, downsampling=DOWNSAMPLING)

In [None]:
# Three views over the same chips (same CHIP_SIZE / STRIDE):
#   original_ds -- validation transform only, no x/y transforms
#   train_ds    -- training transform + x/y transforms
#   val_ds      -- validation transform + x/y transforms
# The train/val split is done with index samplers below, not by separate datasets.
original_ds = Potsdam(potsdam_dict, chip_size=CHIP_SIZE, stride=STRIDE,
                      tf=val_transform)
train_ds = Potsdam(potsdam_dict, chip_size=CHIP_SIZE, stride=STRIDE,
                   tf=train_transform, x_tf=x_transform, y_tf=y_transform)
val_ds = Potsdam(potsdam_dict, chip_size=CHIP_SIZE, stride=STRIDE,
                 tf=val_transform, x_tf=x_transform, y_tf=y_transform)

### Train/val split

In [None]:
# Deterministic, contiguous split by chip index: the first TRAIN_SPLIT fraction
# of chips is used for training and the remainder for validation. Indices are
# not shuffled here; the samplers below randomise order within each split.
assert 0 < TRAIN_SPLIT < 1, 'TRAIN_SPLIT must be a fraction strictly between 0 and 1'

inds = np.arange(len(train_ds))
train_split_size = int(len(inds) * TRAIN_SPLIT)  # int() floors non-negative values
val_split_size = len(inds) - train_split_size

print('train_split_size', train_split_size)
print('val_split_size', val_split_size)

### Samplers

In [None]:
# Random order within each split; the two index sets must never overlap.
train_sampler = torch.utils.data.SubsetRandomSampler(inds[:train_split_size])
val_sampler = torch.utils.data.SubsetRandomSampler(inds[train_split_size:])

assert set(train_sampler.indices).isdisjoint(val_sampler.indices)

# Model

In [14]:
# Two single-modality DeepLab models: `model` takes 3 input channels (RGB) and
# `model_ire` takes 2 (presumably IR + E). Their backbones are merged further below.
# PRETRAINED (instead of a hard-coded True) keeps the initialisation consistent
# with the 'p' tag that PRETRAINED adds to EXPERIMENT_NAME.
model = get_deeplab_custom(NCLASSES, in_channels=3, pretrained=PRETRAINED)
model_ire = get_deeplab_custom(NCLASSES, in_channels=2, pretrained=PRETRAINED)

In [32]:
# Warm-start the 2-channel model from the best-accuracy weights of the earlier
# single-modality 'ire' experiment.
name = 'ss_ire_deeplab_resnet101p'
io_handler.load_model_weights(model_ire,
                              s3_path='potsdam/experiments/output/' + name + '/best_model/best_acc',
                              tgt_path='models/' + name)

In [33]:
# Warm-start the RGB model from the best-accuracy weights of the earlier
# single-modality 'rgb' experiment.
name = 'ss_rgb_deeplab_resnet101p'
io_handler.load_model_weights(model,
                              s3_path='potsdam/experiments/output/' + name + '/best_model/best_acc',
                              tgt_path='models/' + name)

In [15]:
# Replace the RGB model's backbone with a combined backbone built from the RGB
# and 2-channel backbones, using CUT as the merge point. The isinstance guard
# makes the cell idempotent: re-running it no longer wraps an already merged
# backbone inside another DeeplabDoublePartialBackbone.
if not isinstance(model.m.backbone, DeeplabDoublePartialBackbone):
    model.m.backbone = DeeplabDoublePartialBackbone(model.m.backbone, model_ire.m.backbone, CUT)
model = model.cuda()

In [20]:
# Freeze the merged backbone's `head` sub-module for the initial epochs; the
# epoch monitor callback defined below calls unfreeze(model) once the reported
# epoch reaches 9.
freeze(model.m.backbone.head)

# Train

## Training monitoring callbacks

In [None]:

def get_epoch_monitor(io_handler, chkpt_interval=1, viz_root='visualizations/per_epoch',
                      unfreeze_after_epoch=9):
    """Build the end-of-epoch callback passed to ``train_seg``.

    On every call, the returned callback:
      * updates best-model tracking via ``track_best_model``;
      * saves a checkpoint to ``checkpoints/epoch_NNNN`` (epoch zero-padded to
        4 digits) whenever ``epoch % chkpt_interval == 0``;
      * prints the formatted logs and saves them to ``logs.pkl`` / ``logs.txt``;
      * regenerates the training plots via ``make_plots``.
    It also calls ``unfreeze(model)`` the first time the epoch reaches
    ``unfreeze_after_epoch``.

    Args:
        io_handler: object providing ``save_checkpoint``, ``save_log`` and
            ``save_log_str``.
        chkpt_interval: checkpoint period in epochs; must be positive.
        viz_root: currently unused; kept for interface compatibility.
        unfreeze_after_epoch: epoch at (or after) which the whole model is
            unfrozen; ``None`` disables unfreezing.

    Returns:
        Callable ``(model, optimizer, sched, logs) -> None``. It reads the
        current epoch from ``logs['epoch'][-1]``.
    """
    assert chkpt_interval > 0

    # Unfreezing is a one-time transition; the flag avoids repeating it every epoch.
    unfrozen = False

    def _monitor(model, optimizer, sched, logs):
        nonlocal unfrozen
        epoch = logs['epoch'][-1]

        track_best_model(io_handler, model, logs)

        if epoch % chkpt_interval == 0:
            io_handler.save_checkpoint(model, optimizer, sched, 'checkpoints/epoch_%04d' % epoch, info=logs)

        log_str = logs_to_str(logs)
        print(log_str)

        io_handler.save_log('logs.pkl', logs)
        io_handler.save_log_str('logs.txt', log_str)

        make_plots(io_handler, logs)

        if (not unfrozen and unfreeze_after_epoch is not None
                and epoch >= unfreeze_after_epoch):
            unfreeze(model)
            unfrozen = True

    return _monitor

def get_batch_monitor(io_handler, viz_root='visualizations/per_batch', interval=4):
    """Build the per-batch callback passed to ``train_seg``.

    Currently a no-op placeholder: the returned callback ignores all of its
    arguments and returns ``None``. ``io_handler``, ``viz_root`` and
    ``interval`` are accepted but unused, so per-batch monitoring can be added
    later without changing call sites.
    """
    def _batch_noop(model, epoch, batch_idx, batch, labels):
        return None

    return _batch_noop


In [None]:
# Callbacks handed to train_seg below.
batch_callback = get_batch_monitor(io_handler, interval=BATCH_CB_INTERVAL)
epoch_callback = get_epoch_monitor(io_handler, chkpt_interval=CHECKPOINT_INTERVAL)

In [None]:
# Data-loader batch sizes (training and validation use the same value).
train_params = {
    'batch_size': BATCH_SIZE,
    'val_batch_size': BATCH_SIZE,
}

In [None]:
# Both loaders draw chips through the index samplers defined above (random order
# within each split); pin_memory stays disabled.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=train_params['batch_size'],
    sampler=train_sampler,
    pin_memory=False,
)
val_dl = torch.utils.data.DataLoader(
    val_ds,
    batch_size=train_params['val_batch_size'],
    sampler=val_sampler,
    pin_memory=False,
)

In [None]:
# Extend (rather than replace) the params dict built for the data loaders, so the
# batch sizes remain recorded alongside the schedule that is passed to
# restore_training_state / train_seg. `update` also keeps this cell idempotent.
train_params.update(
    epochs=EPOCHS,
    learning_rate=LR_START,
    learning_rate_min=LR_END,
)

# SGD over *all* parameters, including currently frozen ones: parameters whose
# grad is None are skipped by SGD, and they start updating once unfrozen.
optimizer = optim.SGD(model.parameters(), lr=train_params['learning_rate'],
                      momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Cosine annealing from learning_rate to learning_rate_min with T_max = epochs.
sched = optim.lr_scheduler.CosineAnnealingLR(optimizer, train_params['epochs'],
                                             eta_min=train_params['learning_rate_min'])


In [None]:
from collections import defaultdict

# Epoch from which the epoch monitor callback unfreezes the whole model; keep
# this in sync with the value used by that callback.
UNFREEZE_AFTER_EPOCH = 9

if io_handler.checkpoint_exists():
    logs = restore_training_state(io_handler, model, optimizer, sched, train_params)
    # requires_grad flags are not part of a state_dict, and the freeze cell above
    # re-freezes the head on a fresh kernel. If the interrupted run had already
    # passed the unfreeze epoch, unfreeze now; otherwise the head would stay
    # frozen for one extra epoch after resuming.
    # NOTE(review): assumes restored logs keep the 'epoch' list the monitor reads.
    completed_epochs = logs.get('epoch') or []
    if completed_epochs and completed_epochs[-1] >= UNFREEZE_AFTER_EPOCH:
        unfreeze(model)
else:
    logs = defaultdict(list)

print(logs)

In [None]:
# Disable interactive plotting (presumably so the figures produced by the epoch
# callback's make_plots are not rendered inline), then run the training loop.
plt.ioff()
train_seg(
    model, train_dl, val_dl, optimizer, sched, train_params,
    epoch_callback=epoch_callback,
    batch_callback=batch_callback,
    logs=logs,
)