In [1]:
import os

import torch
from batchflow.opensets import MNIST
from batchflow.models.torch import UNet
from batchflow import Pipeline, P, R, V, B

from batch import MyBatch

In [None]:
# MNIST loaded with the project's MyBatch class; presumably MyBatch provides the
# custom actions (mask, custom_rotate, background_and_mask, ...) used below.
dataset = MNIST(
    batch_class=MyBatch,
)

In [3]:
# Canvas size (height, width) the digits are placed on; the model input
# specification is also built from this shape.
bg_shape = (128, 128)

# Random augmentation chain run on each batch. The actions are batch-class
# methods (MyBatch); their exact behaviour is defined there.
augmentation = (
    Pipeline()
    # NOTE(review): presumably derives a mask for each digit image.
    .mask()
    # Rotation angle drawn uniformly between -35 and 35; P(...) => one draw per item.
    .custom_rotate(angle=P(R('uniform', -35, 35)))
    # NOTE(review): presumably pastes each digit (and its mask) onto a bg_shape canvas.
    .background_and_mask(bg_shape)
    # Channel argument is a per-item random int in [0, 3); p=0.5 presumably the apply probability.
    .invert(channels=P(R('randint', 0, 3)), p=0.5)
    .noise(n=1)
    # Convert both components in place (src and dst name the same components).
    .custom_to_array(src=['images', 'masks'], dst=['images', 'masks'])
)

In [4]:
# Model input specification: images with a leading dimension of 3 (presumably
# channels) over the canvas size, and 2-class masks fed to the model under the
# name 'targets'.
inputs_config = dict(
    images=dict(shape=(3, *bg_shape)),
    masks=dict(
        shape=bg_shape,
        classes=2,
        data_format='f',
        name='targets',
    ),
)

In [8]:
# Expose only physical GPU 1 to this process. This only takes effect if it is set
# before CUDA is first initialised in the kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Device the model is built on. The class-weight tensor is created on the SAME
# device. If the weights were on 'cuda' while the model ran on 'cpu', the weighted
# loss would fail with a device mismatch, and `.to('cuda')` crashes outright on
# machines without a GPU. Set this to 'cuda' to train on the GPU selected above.
DEVICE = 'cpu'

# Path the trained model is saved to at the end of the notebook and resumed from here.
MODEL_PATH = 'model'

# Cross-entropy class weights for (class 0, class 1). Class 1 is upweighted 20x,
# presumably because mask pixels are far rarer than background pixels.
w = torch.tensor([1., 20.], device=DEVICE)

config = {
    # NOTE(review): assumed batchflow semantics - build the network first, then
    # apply `load` (if present) on top of it.
    'build': 'first',
    'inputs': inputs_config,
    'initial_block/inputs': 'images',
    'loss': {'name': 'ce', 'weight': w},
    'optimizer': ('Adam', {'lr': 0.0001}),
    'device': DEVICE,
    'head/num_classes': 2,
    'body/num_blocks': 4,
}

# Resume from a previously saved model only when one exists. On a fresh checkout
# nothing has been saved yet, so an unconditional `load` makes Restart & Run All fail.
if os.path.exists(MODEL_PATH):
    config['load'] = dict(path=MODEL_PATH)

In [9]:
# Training steps that run after the augmentation chain.
training_steps = (
    Pipeline()
    # Create the UNet named 'my_model' from `config` ('dynamic' batchflow mode).
    .init_model('dynamic', UNet, 'my_model', config)
    # Pipeline variable reset to a fresh empty list at the start of every run.
    .init_variable('loss', init_on_each_run=list)
    # Images are the inputs and masks the targets. Each batch's fetched loss is
    # appended (mode='a') to the 'loss' variable.
    .train_model('my_model', B('images'), B('masks'),
                 fetches='loss', save_to=V('loss'), mode='a')
)

# Full pipeline (augmentation followed by training) bound to the MNIST training split.
train_pipeline = (augmentation + training_steps) << dataset.train

In [None]:
# Training schedule: shuffled batches of BATCH_SIZE items, N_EPOCHS passes over the data.
BATCH_SIZE = 32
N_EPOCHS = 25
train_pipeline.run(BATCH_SIZE, n_epochs=N_EPOCHS, shuffle=True)

In [None]:
# Persist the trained 'my_model' to the same path the model config loads from.
checkpoint_path = 'model'
train_pipeline.save_model('my_model', path=checkpoint_path)