In [1]:
# Check Pytorch installation
import mmseg, math
from PIL import Image
import numpy as np
import os.path as osp
import matplotlib.pyplot as plt
import mmcv
import torch, torchvision

print(torch.__version__, torch.cuda.is_available())
print(mmseg.__version__)

%cd /home/smlm-workstation/segmentation/mmsegmentation/

# split train/val set randomly
img_dir = 'images'
ann_dir = 'bit_masks'
classes = ('Background', 'Microtubule', 'Vesicle')
palette = [[128, 255, 0], [0, 255, 255]]

data_root = '/home/smlm-workstation/segmentation/data/full_combined_mt_cl/'

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

1.12.1 True
0.29.1
/home/smlm-workstation/segmentation/mmsegmentation


In [2]:
# split_dir = 'splits'
# mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
# filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
#     osp.join(data_root, img_dir), suffix='.png')]
# with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
#   # select first 4/5 as train set
#   train_length = int(len(filename_list)*99/100)
#   f.writelines(line + '\n' for line in filename_list[:train_length])
# with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
#   # select last 1/5 as test set
#   f.writelines(line + '\n' for line in filename_list[train_length:])

In [3]:
# from mmseg.datasets.builder import DATASETS
# from mmseg.datasets.custom import CustomDataset

# @DATASETS.register_module()
# class SMLM_mt_ves(CustomDataset):
#   CLASSES = ('Background','Microtubule', 'Vesicle')
#   PALETTE = [[40,40,40], [128, 255, 0], [0, 255, 255]]
#   def __init__(self, split, **kwargs):
#     super().__init__(img_suffix='.png', seg_map_suffix='.png', 
#                      split=split,
#                      reduce_zero_label=False,
#                      **kwargs)
#     assert osp.exists(self.img_dir) and self.split is not None

In [4]:
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset


@DATASETS.register_module()
class SMLM_mt_ves(CustomDataset):
  CLASSES = ('Microtubule', 'Vesicle')
  PALETTE = [[128, 255, 0], [0, 255, 255]]

  def __init__(self, split, **kwargs):
    super().__init__(img_suffix='.png', seg_map_suffix='.png',
                     split=split,
                     reduce_zero_label=True,
                     **kwargs)
    assert osp.exists(self.img_dir) and self.split is not None


## SegFormer

In [5]:
# !wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20220620_114047-64e4feca.pth -P /home/smlm-workstation/segmentation/mmsegmentation/checkpoints

In [6]:
from distutils.fancy_getopt import FancyGetopt
from mmseg.apis import set_random_seed
from mmcv import Config
cfg = Config.fromfile(
    'configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py')

# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
# cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
# cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 2

cfg.model.test_cfg = dotdict(
    mode='slide', crop_size=(256, 256), stride=(200, 200))

# cfg.model.test_cfg = dotdict(mode='slide', crop_size=(128, 128), stride=(127, 127))
# cfg.model.auxiliary_head.num_classes = 3

# cfg.model.test_cfg = dotdict(mode='slide', crop_size=(256, 256), stride=(1, 1))
# cfg.model.test_cfg = dotdict(
#     mode='whole')

# cfg.model.auxiliary_head.loss_decode = dict(
#     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True, class_weight=[0.05, 0.55, 0.45])
cfg.model.decode_head.loss_decode = [dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
                                     dict(type='TverskyLoss', loss_name='TverskyLoss', loss_weight=3.0)]

# cfg.model.decode_head.loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, class_weight=[0.3504681, 0.6460288])
# cfg.model.decode_head.loss_decode = dict(type='DiceLoss', use_sigmoid=True)
# cfg.model.decode_head.loss_decode = dict(type='DiceLoss', use_sigmoid=True, class_weight=[0.35, 0.64])
# cfg.model.decode_head.loss_decode = [dict(type='FocalLoss', use_sigmoid=True, alpha=.25)]

# cfg.model.decode_head.loss_decode = [dict(type='FocalLoss', use_sigmoid=True, alpha=.25, loss_weight=4., class_weight=[0.0035031103, 0.3504681, 0.6460288]),
#                                      dict(type='DiceLoss', loss_name='dice', loss_weight=1., class_weight=[0.0035031103, 0.3504681, 0.6460288])]

# cfg.model.decode_head.loss_decode = dict(
#     type='PhiLoss', loss_weight=1.0, gamma=0.5)

# cfg.model.decode_head.loss_decode = dict(
#     type='TverskyLoss', class_weight=[0.2, 0.3, 0.5])

# cfg.model.auxiliary_head.ignore_index = 0
# cfg.model.decode_head.ignore_index = 0

# Modify dataset type and path
cfg.dataset_type = 'SMLM_mt_ves'
cfg.data_root = data_root
cfg.reduce_zero_label = True

cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 12

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)

cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # dict(type='Resize', img_scale=(1024, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=(256, 256), cat_max_ratio=0.95),
    # dict(type='RandomRotate', prob=0.5, degree=35),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug', #RandomMosaic
        # img_scale=(1584, 1584),
        img_scale=None,
        img_ratios=[1.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
# cfg.data.train.reduce_zero_label = cfg.reduce_zero_label
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
# cfg.data.val.reduce_zero_label = cfg.reduce_zero_label
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
# cfg.data.test.reduce_zero_label = cfg.reduce_zero_label
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

cfg.log_config = dict(  
    interval=2,  
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook', by_epoch=False),
        # dict(type='NeptuneLoggerHook', by_epoch=False) 
        # MMSegWandbHook is mmseg implementation of WandbLoggerHook. ClearMLLoggerHook, DvcliveLoggerHook, MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook are also supported based on MMCV implementation.
    ])

cfg.runner.max_iters = 17000
cfg.evaluation.interval = 100
cfg.checkpoint_config.interval = 4000

cfg.resume_from = 'work_dirs/segformer_b1_adamW_16k/iter_16000_palette.pth'
# cfg.laod_from = 'checkpoints/segformer_mit-b1_512x512_160k_ade20k_20220620_112037-c3f39e00.pth'
# cfg.load_from = 'checkpoints/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601-75fd5bc3.pth'
# cfg.load_from = 'checkpoints/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816-5040938d.pth'
# cfg.load_from = 'checkpoints/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508-8b112708.pth'
# cfg.load_from = 'checkpoints/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/segformer_b1_adamW_16k_testing2'

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.000005,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# cfg.optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005,
#                      paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))

# cfg.optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))

# cfg.lr_config=dict(
#     policy='cyclic',
#     target_ratio=(10, 1e-4),
#     cyclic_times=1,
#     step_ratio_up=0.4,
# )
# cfg.momentum_config=dict(
#     policy='cyclic',
#     target_ratio=(0.85 / 0.95, 1),
#     cyclic_times=1,
#     step_ratio_up=0.4,
# )

# Set seed to facitate reproducing the result
cfg.seed = 42
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
cfg.device = 'cuda'
cfg.cudnn_benchmark = True
# cfg.model.pretrained

# Let's have a look at the final config used for training
# print(f'Config:\n{cfg.pretty_text}')

In [7]:
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_segmentor(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
model.PALETTE = datasets[0].PALETTE

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

2022-11-24 10:38:18,163 - mmseg - INFO - Loaded 2076 images
2022-11-24 10:38:18,837 - mmseg - INFO - Loaded 21 images
2022-11-24 10:38:18,839 - mmseg - INFO - load checkpoint from local path: work_dirs/segformer_b1_adamW_16k/iter_16000_palette.pth
2022-11-24 10:38:18,963 - mmseg - INFO - resumed from epoch: 500, iter 15999
2022-11-24 10:38:18,964 - mmseg - INFO - Start running, host: smlm-workstation@smlmworkstation, work_dir: /home/smlm-workstation/segmentation/mmsegmentation/work_dirs/segformer_b1_adamW_16k_testing2
2022-11-24 10:38:18,965 - mmseg - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) PolyLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) PolyLrUpdaterHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook    

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 21/21, 0.1 task/s, elapsed: 157s, ETA:     0s

2022-11-24 10:41:02,510 - mmseg - INFO - per class results:
2022-11-24 10:41:02,511 - mmseg - INFO - 
+-------------+-------+-------+
|    Class    |  IoU  |  Acc  |
+-------------+-------+-------+
| Microtubule | 94.04 | 97.52 |
|   Vesicle   | 89.39 | 93.36 |
+-------------+-------+-------+
2022-11-24 10:41:02,511 - mmseg - INFO - Summary:
2022-11-24 10:41:02,511 - mmseg - INFO - 
+-------+-------+-------+
|  aAcc |  mIoU |  mAcc |
+-------+-------+-------+
| 96.03 | 91.72 | 95.44 |
+-------+-------+-------+
2022-11-24 10:41:02,512 - mmseg - INFO - Iter(val) [21]	aAcc: 0.9603, mIoU: 0.9172, mAcc: 0.9544, IoU.Microtubule: 0.9404, IoU.Vesicle: 0.8939, Acc.Microtubule: 0.9752, Acc.Vesicle: 0.9336
2022-11-24 10:41:02,614 - mmseg - INFO - Iter [16002/17000]	lr: 3.526e-06, eta: 15:40:14, time: 78.636, data_time: 78.587, memory: 1559, decode.loss_ce: 0.0624, decode.TverskyLoss: 0.3626, decode.acc_seg: 92.7455, loss: 0.0624
2022-11-24 10:41:02,701 - mmseg - INFO - Iter [16004/17000]	lr: 3.51

[                                                  ] 0/21, elapsed: 0s, ETA:

In [None]:
# model = torch.load(
#     '/home/smlm-workstation/segmentation/mmsegmentation/work_dirs/segformer_b1_adamW_16k/iter_16000.pth')
# model['meta']['PALETTE'] = [[128, 255, 0], [0, 255, 255]]
# torch.save(
#         model, '/home/smlm-workstation/segmentation/mmsegmentation/work_dirs/segformer_b1_adamW_16k/iter_16000_palette.pth')

In [None]:
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from skimage.io import imread, imshow, imsave
# model.cfg = cfg

model = init_segmentor(
    cfg, checkpoint='/home/smlm-workstation/segmentation/mmsegmentation/work_dirs/segformer_b1_adamW_16k/iter_16000_palette.pth')

im_list = [
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_2_crop.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_2.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_3_crop.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_3.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_7_crop.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_7.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_6_crop.png',
           '/home/smlm-workstation/segmentation/data/archive/mt_cl/Wue_MT_clathrin_647_mixed_6.png'
          ]
i = 0
for im in im_list:
    img = mmcv.imread(im)
    img2 = mmcv.imread(im, flag='grayscale')
   #  img2 = imread(im, as_gray = True)
    result = inference_segmentor(model, img)
    mt, cl = np.zeros(shape=(img.shape[:2])), np.zeros(shape=(img.shape[:2]))
    mmt, mcl = np.array(result[0] == 0), np.array(result[0] == 1)
    mt[mmt] = img2[mmt].astype(np.uint8)
    cl[mcl] = img2[mcl].astype(np.uint8)
    imsave(
        f'/home/smlm-workstation/segmentation/data/results/segformerb1_reduce0_tversk_CE_256px_mIoU_57_16k_res2/{i}_MT.png', mt, check_contrast=False)
    imsave(
        f'/home/smlm-workstation/segmentation/data/results/segformerb1_reduce0_tversk_CE_256px_mIoU_57_16k_res2/{i}_CL.png', cl, check_contrast=False)  
     
    # plt.figure(figsize=(12, 8))
    show_result_pyplot(model, img, result, palette, opacity=0.3, 
                    out_file=f'/home/smlm-workstation/segmentation/data/results/segformerb1_reduce0_tversk_CE_256px_mIoU_57_16k_res2/{i}_segmap.png')
    i += 1
