In [1]:
"""
UNet configs comparison: from batchflow, from batchflow with fix, mine
"""

'\nUNet configs comparison: from batchflow, from batchflow with fix, mine\n'

In [2]:
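# TODO: the f1_score metric is currently computed incorrectly; do not rely on it yet.
# TODO: consider adding class weights to the cross-entropy ('ce') loss.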

In [3]:
import sys
sys.path.append("..")

In [4]:
import copy
import os
import shutil

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import batchflow
from batchflow import Pipeline, B, C, V, D
from batchflow.opensets import PascalSegmentation
from batchflow.research import Research, Option, Domain, Results, PrintLogger, RP, REP, KV
from batchflow.models.metrics import Loss
from batchflow.models.torch import UNet, EncoderDecoder

In [5]:
# Experiment settings
BATCH_SIZE = 64           # batch size of both the train and the test pipeline
ITERATIONS = 2000         # n_iters passed to research.run
N_REPS = 10               # repetitions of every (model, config) pair in the research domain
IMAGE_SHAPE = (160, 160)  # target size of the resize action applied to images and labels

# Hardware: research.run starts one worker per GPU id listed here
GPU_ids = [3, 4, 6]

# Segmentation dataset from batchflow.opensets
# (bar='n' presumably controls a progress bar while loading — TODO confirm)
dataset = PascalSegmentation(bar='n')

In [6]:
def clear_previous_results(res_name):
    """Delete an earlier run's results directory so a new research starts clean.

    Parameters
    ----------
    res_name : str or path-like
        Path of the results directory. Nothing happens if the path does not exist.

    Raises
    ------
    OSError
        Propagated from ``shutil.rmtree`` if the path exists but cannot be removed
        as a directory tree (e.g. it is a plain file).
    """
    if not os.path.exists(res_name):
        return
    shutil.rmtree(res_name)

# Model configs

In [7]:
num_classes = 2

# Settings shared by every experiment: `num_classes` target classes, a 1x1
# convolution head with one output filter per class, 'ce' loss and Adam.
task_config = {
    'inputs/targets/classes': num_classes,
    'head/layout': 'c',
    'head/filters': num_classes,
    'head/kernel_size': 1,
    'loss': 'ce',
    'optimizer': 'Adam'
}

# UNet from batchflow: the UNet config written out explicitly, so that a patched
# variant can be derived from it below.
# NOTE(review): presumably a dump of batchflow UNet's default config — confirm it
# still matches the installed batchflow version.
UNet_bf = {
    'inputs': {'targets': {'classes': num_classes}},
    'placeholder_batch_size': 2,
    'device': None,
    'benchmark': True,
    'profile': False,
    'microbatch': None,
    'sync_frequency': 1,
    'optimizer': 'Adam',
    'decay': None,
    'amp': True,
    'sam_rho': 0.0,
    'sam_individual_norm': True,
    'order': ['initial_block', 'body', 'head'],
    'initial_block': {},
    'body': {
        'encoder': {
            'downsample': {'layout': 'p', 'pool_size': 2, 'pool_strides': 2},
            'num_stages': 4,
            'order': ['block', 'skip', 'downsampling'],
            'blocks': {
                'base': batchflow.models.torch.blocks.DefaultBlock,
                'layout': 'cna cna',
                'kernel_size': 3,
                'filters': [64, 128, 256, 512],
            },
        },
        'decoder': {
            'skip': True,
            'num_stages': None,
            'factor': None,
            'upsample': {'layout': 'tna'},
            'combine': {'op': 'concat', 'leading_index': 1},
            'order': ['upsampling', 'combine', 'block'],
            'blocks': {
                'base': batchflow.models.torch.blocks.DefaultBlock,
                'layout': 'cna cna',
                'kernel_size': 3,
                'filters': [512, 256, 128, 64],
            },
        },
        'embedding': {
            'base': batchflow.models.torch.blocks.DefaultBlock,
            'layout': 'cna cna',
            'kernel_size': 3,
            'filters': 1024,
        },
    },
    'head': {
        'layout': 'c',
        'filters': num_classes,
        'kernel_size': 1,
        'target_shape': None,
        'classes': num_classes,
        'units': num_classes,
    },
    'common': {'data_format': 'channels_first'},
    'predictions': None,
    'output': None,
    'loss': 'ce',
}

# Fixed UNet from batchflow: same config, but the decoder's upsampling block gets
# explicit per-stage filters.
# deepcopy instead of dict.copy(): a shallow copy would share every nested dict
# ('body', 'decoder', ...) with UNet_bf. The 'body/decoder/upsample' path key is
# presumably merged into that nested structure when batchflow parses the config,
# which must not leak back into UNet_bf.
config_bf_with_fix = copy.deepcopy(UNet_bf)
config_bf_with_fix['body/decoder/upsample'] = dict(layout='tna', filters=[512, 256, 128, 64])

In [8]:
# my_UNet: a hand-written UNet-like EncoderDecoder config.
# The filter count doubles at every encoder stage, peaks in the embedding, and
# mirrors back down in the decoder.
downsample_depth = 4
base_filters = 64

# One entry per encoder stage: [64, 128, 256, 512] for depth 4.
encoder_filters = [base_filters * 2**i for i in range(downsample_depth)]
# One entry per decoder stage (num_stages=downsample_depth): [512, 256, 128, 64].
# The previous range(downsample_depth, -1, -1) produced downsample_depth + 1
# entries. That left a trailing 32 (a float 32.0 in the blocks list) which no
# stage could use.
decoder_filters = encoder_filters[::-1]
# Embedding width: 1024 for depth 4, used by both convolutions of 'cna cna'.
embedding_filters = base_filters * 2**downsample_depth

my_config = {
    'body/encoder': {
        'num_stages': downsample_depth,
        'order': ['block', 'skip', 'downsampling']
    },
    'body/encoder/blocks': {
        'layout': 'cna cna',
        'filters': encoder_filters
    },
    'body/encoder/downsample': {
        'layout': 'p'
    },

    'body/embedding': {
        'layout': 'cna cna',
        'filters': [embedding_filters, embedding_filters]
    },

    'body/decoder': {
        'num_stages': downsample_depth,
        'order': ['upsampling', 'combine', 'block']
    },
    'body/decoder/upsample': {
        'layout': 'tna',
        'filters': list(decoder_filters)
    },
    'body/decoder/combine': {
        'op': 'concat'
    },
    'body/decoder/blocks': {
        'layout': 'cna cna',
        'filters': list(decoder_filters)
    }
}
# Add the task-level settings (number of classes, head, loss, optimizer).
my_config.update(task_config)

# train

In [9]:
def process_mask(x):
    """Binarize a label mask: 0 stays 0, every non-zero label becomes 1.

    All size-1 axes are squeezed out first (``np.squeeze`` without ``axis``).

    Parameters
    ----------
    x : array-like
        Label mask(s). Presumably a channels-first label batch here — TODO confirm
        the exact shape coming out of ``to_array``.

    Returns
    -------
    np.ndarray
        A new array with the squeezed shape and the squeezed input's dtype,
        containing only 0 and 1. ``x`` itself is left unmodified.
    """
    squeezed = np.squeeze(x)
    # np.squeeze returns a view, so the previous in-place np.place also overwrote
    # the caller's array (and failed on read-only input). Build a fresh array.
    return (squeezed != 0).astype(squeezed.dtype)

# Training pipeline template over the train split. C('model') / C('config') are
# filled per experiment from the research domain options 'model' and 'config'.
train_ppl = (dataset.train.p
    # list that collects the loss fetched at every training step
    .init_variable('train_loss', [])
    # model class and its config come from the current research domain point
    .init_model('dynamic', C('model'), 'model', config=C('config'))
    # bring images and labels to IMAGE_SHAPE
    .resize(size=IMAGE_SHAPE, src=['images', 'labels'], dst=['images', 'labels'])
    # convert images and labels to channels-first arrays
    .to_array(channels='first', src=['images', 'labels'], dst=['images', 'labels'])
    # binarize the labels (0 -> 0, any non-zero label -> 1) with process_mask
    .process_mask(B('labels'), save_to=B('labels'))
    # one training step per batch; mode='a' presumably appends the loss to train_loss
    .train_model('model', B('images'), B('labels'),
                fetches='loss', save_to=V('train_loss', mode='a'))
    # stores run parameters; n_epochs=None presumably leaves the stopping point to
    # the research's n_iters — TODO confirm
    .run_later(BATCH_SIZE, shuffle=True, n_epochs=None)
)

In [10]:
# One experiment per (model class, config, alias) triple. Judging by the output
# below, '@' pairs the two Options element-wise, so models and configs must keep
# the same order.
# NOTE(review): 'config_bf' passes only task_config to UNet, i.e. it relies on
# UNet's built-in defaults (presumably the same as UNet_bf) — confirm.
experiments = [
    (UNet, task_config, "config_bf"),
    (UNet, config_bf_with_fix, "config_bf_with_fix"),
    (EncoderDecoder, my_config, "my_config"),
]
configs = [KV(cfg, alias) for _, cfg, alias in experiments]
domain = Option('model', [model for model, _, _ in experiments]) @ Option('config', configs)
list(domain.iterator)

[ConfigAlias({'model': 'UNet', 'config': 'config_bf', 'repetition': '0'}),
 ConfigAlias({'model': 'UNet', 'config': 'config_bf_with_fix', 'repetition': '0'}),
 ConfigAlias({'model': 'EncoderDecoder', 'config': 'my_config', 'repetition': '0'})]

# performance

In [11]:
# Evaluation pipeline: one ordered pass over the test split (n_epochs=1,
# shuffle=False). The model is imported from the pipeline given as the
# 'import_from' config value; the research below sets it to the train pipeline.
test_ppl = (dataset.test.p
                .import_model('model', C('import_from'))
                # accumulator for the gathered metrics
                .init_variable('metrics')
                # holds the latest batch of raw model predictions
                .init_variable('predictions')
                # same preprocessing as in train_ppl
                .resize(size=IMAGE_SHAPE, src=['images', 'labels'], dst=['images', 'labels'])
                .to_array(channels='first', src=['images', 'labels'], dst=['images', 'labels'])
                .process_mask(B('labels'), save_to=B('labels'))
                .predict_model('model', B('images'), fetches='predictions',
                               save_to=V('predictions'))
                # classification metrics from logits; axis=1 is presumably the class
                # (channel) axis — TODO confirm. mode='update' merges every batch
                # into the same metrics object.
                .gather_metrics('classification', B('labels'), V('predictions'),
                                          axis=1, fmt='logits', num_classes=num_classes,
                                          save_to=V('metrics', mode='update'))
                .run_later(BATCH_SIZE, shuffle=False, n_epochs=1)
            )

In [12]:
# Metric names requested from the test pipeline's metrics and reported as research results.
metrics = 'acc precision recall f1_score iou'.split()

In [13]:
# When the test pipeline and metric collection run. The entries presumably mean
# the first iteration, every 100 iterations, and the last iteration — TODO
# confirm against batchflow's `execute` syntax.
TEST_EXECUTE_FREQ = ['#0', 100, 'last']

# Start from an empty results directory.
res_name = 'UNet_pascal_train_test_research'
clear_previous_results(res_name)

research = (Research()
            # uncomment to print research logs:
#             .add_logger('print')
            # every (model, config) pair from `domain` is repeated N_REPS times
            .init_domain(domain, n_reps=N_REPS)
            .add_pipeline(train_ppl, variables='train_loss', name='train_ppl')
            # test_ppl imports its model from train_ppl via the 'import_from' config value
            .add_pipeline(test_ppl, name='test_ppl',
                         execute=TEST_EXECUTE_FREQ, run=True, import_from=RP('train_ppl'))
            # read the test pipeline's 'metrics' variable and record the listed metrics
            .get_metrics(pipeline='test_ppl', metrics_var='metrics', metrics_name=metrics,
                         returns=metrics, execute=TEST_EXECUTE_FREQ)
           )

# One worker per GPU id.
# NOTE(review): the saved output below looks stale. The bar reports 30000 iterations
# (3 configs x 10 reps x 1000) and the results stop at iteration 999, while
# ITERATIONS is 2000. Re-run so the recorded outputs match the code.
research.run(n_iters=ITERATIONS, name=res_name, bar=True, workers=len(GPU_ids), devices=GPU_ids)

Research UNet_pascal_train_test_research is starting...


Domain updated: 0: 100%|██████████| 30000/30000.0 [1:42:42<00:00,  4.87it/s]


<batchflow.research.research.Research at 0x7fc6d90ee908>

In [14]:
# Collect the saved research results as a pandas DataFrame.
df = research.load_results().df

In [15]:
# Five best rows of test-pipeline metrics, ranked by accuracy.
(df.loc[df['name'].eq('test_ppl_metrics')]
   .sort_values(by='acc', ascending=False)
   .head(5))

Unnamed: 0,name,train_loss,acc,precision,recall,f1_score,iou,iteration,sample_index,model,config,repetition,update
19018,test_ppl_metrics,,0.721753,0.847173,0.901243,0.870581,0.774733,999,68842025,UNet,config_bf_with_fix,7,0
18017,test_ppl_metrics,,0.719331,0.842359,0.901026,0.86813,0.770513,999,2317255566,UNet,config_bf,8,0
5004,test_ppl_metrics,,0.7182,0.816544,0.901593,0.854246,0.749405,999,3621896855,UNet,config_bf_with_fix,6,0
2001,test_ppl_metrics,,0.715894,0.86353,0.899177,0.878625,0.786879,999,3829628415,EncoderDecoder,my_config,0,0
20019,test_ppl_metrics,,0.715492,0.852308,0.899025,0.871872,0.777309,999,1637505446,UNet,config_bf,7,0


In [16]:
# Keep only the rows named 'test_ppl_metrics' (metrics of the test pipeline).
tmp = df[df['name'] == 'test_ppl_metrics']

In [17]:
# Per-config summary statistics of every metric, transposed so that each row is a
# (metric, statistic) pair and each column a config.
tmp.groupby(by='config')[metrics].describe().transpose()

Unnamed: 0,config,config_bf,config_bf_with_fix,my_config
acc,count,10.0,10.0,10.0
acc,mean,0.7083,0.704883,0.700597
acc,std,0.01038,0.011878,0.012718
acc,min,0.68856,0.684256,0.679182
acc,25%,0.703773,0.696979,0.694473
acc,50%,0.714114,0.706729,0.69974
acc,75%,0.715136,0.711273,0.712705
acc,max,0.719331,0.721753,0.715894
precision,count,10.0,10.0,10.0
precision,mean,0.835276,0.820412,0.813224
