In [1]:
# Install PyTorch
!pip install torch==1.12.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu113
!pip install -U openmim
# Install mmengine
!mim install mmengine
# Install MMCV
!mim install 'mmcv >= 2.0.0'

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Collecting torch==1.12.0
  Downloading https://download.pytorch.org/whl/cu113/torch-1.12.0%2Bcu113-cp310-cp310-linux_x86_64.whl (1837.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m506.2 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 2.0.0
    Uninstalling torch-2.0.0:
      Successfully uninstalled torch-2.0.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchdata 0.6.0 requires torch==2.0.0, but you have torch 1.12.0+cu113 which is incompatible.[0m[31m
[0mSuccessfully installed torch-1.12.0+cu113
[0mCollecting openmim
  Downloading openmim-0.3.7-py2.py3-none-any.whl (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .

Cloning into 'mmsegmentation'...
remote: Enumerating objects: 14947, done.[K
remote: Counting objects: 100% (525/525), done.[K
remote: Compressing objects: 100% (348/348), done.[K
remote: Total 14947 (delta 192), reused 350 (delta 160), pack-reused 14422[K
Receiving objects: 100% (14947/14947), 20.63 MiB | 31.81 MiB/s, done.
Resolving deltas: 100% (10491/10491), done.
/kaggle/working/mmsegmentation
Using pip 23.0.1 from /opt/conda/lib/python3.10/site-packages/pip (python 3.10)
Obtaining file:///kaggle/working/mmsegmentation
  Running command python setup.py egg_info
  running egg_info
  creating /tmp/pip-pip-egg-info-3w41_p11/mmsegmentation.egg-info
  writing /tmp/pip-pip-egg-info-3w41_p11/mmsegmentation.egg-info/PKG-INFO
  writing dependency_links to /tmp/pip-pip-egg-info-3w41_p11/mmsegmentation.egg-info/dependency_links.txt
  writing requirements to /tmp/pip-pip-egg-info-3w41_p11/mmsegmentation.egg-info/requires.txt
  writing top-level names to /tmp/pip-pip-egg-inf

In [3]:
# Sanity check: the editable install is importable; report its version string.
import mmseg

installed_version = mmseg.__version__
print(installed_version)

1.0.0


In [4]:
from mmcv.transforms import BaseTransform, TRANSFORMS

# force=True lets this cell be re-executed in the same kernel; without it the
# registry rejects a second registration under the same name.
@TRANSFORMS.register_module(force=True)
class ConvertPixel(BaseTransform):
    """Remap foreground mask pixels from 255 to class index 1.

    The transform edits ``results['gt_seg_map']`` in place: every pixel equal
    to 255 becomes 1. Pixels equal to 0 are left untouched, so a 0/255 binary
    mask ends up with the label values 0 and 1.
    """

    def __init__(self):
        super().__init__()

    def transform(self, results: dict) -> dict:
        """Rewrite the segmentation map stored under ``gt_seg_map``.

        Args:
            results (dict): Pipeline result dict; must contain the key
                ``gt_seg_map`` (presumably the array produced by
                ``LoadAnnotations`` -- confirm against the pipeline).

        Returns:
            dict: The same ``results`` object, with the map modified in place.
        """
        seg_map = results['gt_seg_map']
        # Background is already 0, so only the foreground value needs mapping.
        seg_map[seg_map == 255] = 1
        return results

In [5]:
%%writefile /kaggle/working/mmsegmentation/mmseg/datasets/isicdataset.py
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class ISICDataset(BaseSegDataset):
    """Two-class (``background`` / ``mask``) segmentation dataset.

    Images default to the ``.jpg`` suffix and segmentation maps to the
    ``_segmentation.png`` suffix; all other keyword arguments are forwarded
    unchanged to ``BaseSegDataset``.
    """

    METAINFO = {
        'classes': ('background', 'mask'),
        'palette': [[0, 0, 0], [255, 255, 255]],
    }

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='_segmentation.png',
                 **kwargs) -> None:
        # Fold the two suffixes into the forwarded keyword arguments.
        init_kwargs = dict(
            kwargs, img_suffix=img_suffix, seg_map_suffix=seg_map_suffix)
        super().__init__(**init_kwargs)

Writing /kaggle/working/mmsegmentation/mmseg/datasets/isicdataset.py


In [6]:
# Prepend the ISICDataset import to mmseg/datasets/__init__.py so that the
# dataset class is registered whenever mmseg.datasets is imported.
isic_import_line = "from .isicdataset import ISICDataset\n"

with open('/kaggle/working/mmsegmentation/mmseg/datasets/__init__.py', 'r+') as f:
    content = f.read()
    # Only patch once: re-running the cell must not stack duplicate imports.
    if isic_import_line not in content:
        f.seek(0, 0)
        f.write(isic_import_line + content)

In [7]:
%%writefile /kaggle/working/mmsegmentation/configs/_base_/datasets/isic_dataset.py
# dataset settings
dataset_type = 'ISICDataset'
data_root = '/kaggle/input/isic-2018/data'
img_scale = (2048, 1024)
crop_size = (512, 512)

# Training: load image + mask, remap mask values with ConvertPixel, then apply
# the random augmentations before packing.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='ConvertPixel'),
    dict(
        type='RandomResize',
        scale=img_scale,
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='GenerateEdge', edge_width=4),
    dict(type='PackSegInputs')
]

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='ConvertPixel'),
    dict(type='PackSegInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='ConvertPixel'),
    dict(type='PackSegInputs')
]

# Scale factors evaluated by test-time augmentation.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]

tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ],
            [dict(type='LoadAnnotations')],
            # Remap mask values exactly as the train/val/test pipelines do;
            # without this step the TTA ground truth keeps 255 for the
            # foreground instead of class index 1.
            [dict(type='ConvertPixel')],
            [dict(type='PackSegInputs')]
        ])
]

train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            data_prefix=dict(
                img_path='images/train',
                seg_map_path='annotations/train'),
            pipeline=train_pipeline)))

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/val',
            seg_map_path='annotations/val'),
        pipeline=val_pipeline))

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/test',
            seg_map_path='annotations/test'),
        pipeline=test_pipeline))

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
# format_only=True: the test evaluator is configured to write results under
# output_dir rather than to report metrics.
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], format_only=True, keep_results=True, output_dir='/kaggle/working/test')

Writing /kaggle/working/mmsegmentation/configs/_base_/datasets/isic_dataset.py


In [8]:
%%writefile /kaggle/working/mmsegmentation/mmseg/utils/class_names.py
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_str


def cityscapes_classes():
    """Return the list of Cityscapes category names."""
    names = [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle',
    ]
    return names


def ade_classes():
    """Return the list of ADE20K category names."""
    names = [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag',
    ]
    return names


def voc_classes():
    """Return the list of Pascal VOC category names."""
    names = [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
        'tvmonitor',
    ]
    return names


def cocostuff_classes():
    """Return the list of COCO-Stuff category names."""
    names = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
        'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
        'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
        'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
        'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood',
    ]
    return names


def loveda_classes():
    """Return the list of LoveDA category names."""
    names = [
        'background', 'building', 'road', 'water', 'barren', 'forest',
        'agricultural',
    ]
    return names


def potsdam_classes():
    """Return the list of Potsdam category names."""
    names = [
        'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
        'clutter',
    ]
    return names


def vaihingen_classes():
    """Return the list of Vaihingen category names."""
    names = [
        'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
        'clutter',
    ]
    return names


def isaid_classes():
    """Return the list of iSAID category names."""
    names = [
        'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
        'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
        'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
        'Soccer_ball_field', 'plane', 'Harbor',
    ]
    return names


def stare_classes():
    """Return the list of STARE category names."""
    names = ['background', 'vessel']
    return names


def mapillary_v1_classes():
    """Return the list of Mapillary (v1) category names."""
    names = [
        'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
        'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
        'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk',
        'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
        'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General',
        'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
        'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin',
        'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
        'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame',
        'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)',
        'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car',
        'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer',
        'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled',
    ]
    return names


def mapillary_v1_palette():
    """Return the Mapillary (v1) palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
        [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
        [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
        [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
        [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
        [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
        [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
        [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
        [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
        [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
        [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
        [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
        [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
        [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
        [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
        [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0],
    ]
    return palette


def mapillary_v2_classes():
    """Return the list of Mapillary (v2) category names."""
    names = [
        'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb',
        'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side',
        'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane',
        'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking',
        'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road',
        'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island',
        'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group',
        'Bicyclist', 'Motorcyclist', 'Other Rider',
        'Lane Marking - Dashed Line', 'Lane Marking - Straight Line',
        'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous',
        'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)',
        'Lane Marking - Arrow (Right)',
        'Lane Marking - Arrow (Split Left or Straight)',
        'Lane Marking - Arrow (Split Right or Straight)',
        'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
        'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)',
        'Lane Marking - Hatched (Chevron)',
        'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
        'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
        'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
        'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk',
        'Lane Marking (only) - Other', 'Lane Marking (only) - Test',
        'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
        'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera',
        'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter',
        'Phone Booth', 'Pothole', 'Signage - Advertisement',
        'Signage - Ambiguous', 'Signage - Back', 'Signage - Information',
        'Signage - Other', 'Signage - Store', 'Street Light', 'Pole',
        'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone',
        'Traffic Light - General (Single)', 'Traffic Light - Pedestrians',
        'Traffic Light - General (Upright)',
        'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
        'Traffic Light - Other', 'Traffic Sign - Ambiguous',
        'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
        'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
        'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
        'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
        'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
        'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
        'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled',
    ]
    return names


def mapillary_v2_palette():
    """Return the Mapillary (v2) palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
        [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150],
        [250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35],
        [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170],
        [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
        [230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110],
        [244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70],
        [150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60],
        [255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255],
        [255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26],
        [250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21],
        [250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18],
        [250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255],
        [250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255],
        [255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64],
        [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152],
        [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
        [100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
        [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
        [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
        [250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
        [210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128],
        [0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30],
        [250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
        [192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0],
        [220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0],
        [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100],
        [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64],
        [0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
        [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
        [111, 111, 0], [0, 0, 0],
    ]
    return palette


def cityscapes_palette():
    """Return the Cityscapes palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
        [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
        [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
        [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
        [0, 0, 230], [119, 11, 32],
    ]
    return palette


def ade_palette():
    """Return the ADE20K palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
        [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
        [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
        [102, 255, 0], [92, 0, 255],
    ]
    return palette


def voc_palette():
    """Return the Pascal VOC palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
        [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
        [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
        [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
        [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128],
    ]
    return palette


def cocostuff_palette():
    """Return the COCO-Stuff palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
        [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
        [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
        [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
        [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
        [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
        [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0],
        [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0],
        [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32],
        [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
        [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64],
        [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32],
        [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
        [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64],
        [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32],
        [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128],
        [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0],
        [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96],
        [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0],
        [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0],
        [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96],
        [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128],
        [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64],
        [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96],
        [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0],
        [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64],
        [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96],
        [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128],
        [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0],
        [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32],
        [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64],
        [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0],
        [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32],
        [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192],
        [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64],
        [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32],
        [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64],
        [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64],
        [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32],
        [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192],
        [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0],
        [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96],
        [64, 160, 64], [64, 64, 0],
    ]
    return palette


def loveda_palette():
    """Return the LoveDA palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
        [159, 129, 183], [0, 255, 0], [255, 195, 128],
    ]
    return palette


def potsdam_palette():
    """Return the Potsdam palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
        [255, 255, 0], [255, 0, 0],
    ]
    return palette


def vaihingen_palette():
    """Return the Vaihingen palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
        [255, 255, 0], [255, 0, 0],
    ]
    return palette


def isaid_palette():
    """Return the iSAID palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
        [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], [0, 0, 127],
        [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191],
        [0, 127, 255], [0, 100, 155],
    ]
    return palette


def stare_palette():
    """Return the STARE palette as a list of ``[R, G, B]`` triples."""
    palette = [[120, 120, 120], [6, 230, 230]]
    return palette


def synapse_palette():
    """Return the Synapse palette as a list of ``[R, G, B]`` triples."""
    palette = [
        [0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255],
        [255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240],
    ]
    return palette


def synapse_classes():
    """Return the list of Synapse category names."""
    names = [
        'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney',
        'liver', 'pancreas', 'spleen', 'stomach',
    ]
    return names


def lip_classes():
    """Return the list of LIP category names."""
    names = [
        'background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes',
        'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt',
        'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe',
        'rightShoe',
    ]
    return names


def lip_palette():
    """LIP palette for external use.

    Returns:
        list[list[int]]: One ``[R, G, B]`` triple per LIP class (20 entries),
        in the same order as :func:`lip_classes`.

    Note:
        The previous body returned the class *names* instead of colours, so
        ``get_palette('lip')`` handed strings to callers expecting RGB
        triples. The colours below are intended to mirror the palette in
        ``LIPDataset.METAINFO`` -- NOTE(review): confirm against
        ``mmseg/datasets/lip.py``.
    """
    return [[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
            [255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85, 0],
            [0, 85, 85], [85, 51, 0], [52, 86, 128], [0, 128, 0],
            [0, 0, 255], [51, 170, 221], [0, 255, 255], [85, 255, 170],
            [170, 255, 85], [255, 255, 0], [255, 170, 0]]

def isic_classes():
    """ISIC class names for external use.

    The function is named ``isic_classes`` because ``get_classes`` resolves
    the accessor as ``<dataset>_classes``; under the old name ``isic_class``
    a lookup of the ``isic`` aliases failed with ``NameError``.
    """
    return ['background', 'mask']


# Backward-compatible alias for code that still calls the original name.
isic_class = isic_classes

def isic_palette():
    """Return the ISIC palette as a list of ``[R, G, B]`` triples."""
    palette = [[0, 0, 0], [255, 255, 255]]
    return palette

# Maps a canonical dataset name to every alias accepted by ``get_classes`` and
# ``get_palette``. The canonical name is the prefix of the accessor functions
# (``<name>_classes`` / ``<name>_palette``) defined in this module.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
    'loveda': ['loveda'],
    'potsdam': ['potsdam'],
    'vaihingen': ['vaihingen'],
    'cocostuff': [
        'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
        'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
        'coco_stuff164k'
    ],
    'isaid': ['isaid', 'iSAID'],
    'stare': ['stare', 'STARE'],
    # synapse_classes()/synapse_palette() exist in this module but were not
    # reachable through the lookup helpers without an alias entry.
    'synapse': ['synapse'],
    'lip': ['LIP', 'lip'],
    'mapillary_v1': ['mapillary_v1'],
    'mapillary_v2': ['mapillary_v2'],
    'isic': ['isic', 'isic_2018', 'ISIC'],
}


def _call_dataset_accessor(dataset, suffix):
    """Resolve ``dataset`` through ``dataset_aliases`` and call its accessor.

    Args:
        dataset (str): Any alias listed in ``dataset_aliases``.
        suffix (str): Accessor suffix, ``'_classes'`` or ``'_palette'``.

    Returns:
        list: Whatever the module-level function ``<name><suffix>`` returns.

    Raises:
        TypeError: If ``dataset`` is not a string.
        ValueError: If ``dataset`` is not a known alias.
        NameError: If the alias is known but no ``<name><suffix>`` function
            is defined in this module.
    """
    if not is_str(dataset):
        raise TypeError(f'dataset must be a str, but got {type(dataset)}')

    # Flatten {name: [aliases]} into {alias: name}.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items()
        for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')

    # Look the accessor up by name instead of building a string for eval().
    accessor_name = alias2name[dataset] + suffix
    accessor = globals().get(accessor_name)
    if accessor is None:
        raise NameError(f"name '{accessor_name}' is not defined")
    return accessor()


def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): Dataset name or alias (see ``dataset_aliases``).

    Returns:
        list[str]: The class names returned by ``<name>_classes()``.

    Raises:
        TypeError: If ``dataset`` is not a string.
        ValueError: If ``dataset`` is not a known alias.
    """
    return _call_dataset_accessor(dataset, '_classes')


def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): Dataset name or alias (see ``dataset_aliases``).

    Returns:
        list: The palette returned by ``<name>_palette()``.

    Raises:
        TypeError: If ``dataset`` is not a string.
        ValueError: If ``dataset`` is not a known alias.
    """
    return _call_dataset_accessor(dataset, '_palette')

Overwriting /kaggle/working/mmsegmentation/mmseg/utils/class_names.py


In [9]:
%%writefile /kaggle/working/mmsegmentation/mmseg/models/backbones/pidnet.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck


class PagFM(BaseModule):
    """Pixel-attention-guided fusion module.

    Two 1x1 ``ConvModule`` projections embed the P-branch and I-branch
    features; the sigmoid of their (channel-summed or channel-wise) product
    is used as a gate that blends the upsampled I-branch feature with the
    P-branch feature.

    Args:
        in_channels (int): The number of input channels.
        channels (int): The number of channels.
        after_relu (bool): Whether to use ReLU before attention.
            Default: False.
        with_channel (bool): Whether to use channel attention.
            Default: False.
        upsample_mode (str): The mode of upsample. Default: 'bilinear'.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer. Only used when
            ``after_relu`` is True.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 after_relu: bool = False,
                 with_channel: bool = False,
                 upsample_mode: str = 'bilinear',
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.after_relu = after_relu
        self.with_channel = with_channel
        self.upsample_mode = upsample_mode
        # 1x1 projections of the I-branch and P-branch features.
        self.f_i = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        self.f_p = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        if with_channel:
            # Maps the channel-wise product back to ``in_channels`` so the
            # gate has one value per input channel.
            self.up = ConvModule(
                channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        if after_relu:
            # The registry resolves the layer through the 'type' key of
            # ``act_cfg``.
            self.relu = MODELS.build(act_cfg)

    def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.

        Returns:
            Tensor: The feature map with pixel-attention-guided fusion.
        """
        if self.after_relu:
            # NOTE: with the default ``act_cfg`` the ReLU is in-place, so
            # the tensors passed in by the caller are modified as well.
            x_p = self.relu(x_p)
            x_i = self.relu(x_i)

        # Project the I-branch feature and resize it to x_p's spatial size.
        f_i = self.f_i(x_i)
        f_i = F.interpolate(
            f_i,
            size=x_p.shape[2:],
            mode=self.upsample_mode,
            align_corners=False)

        f_p = self.f_p(x_p)

        if self.with_channel:
            # One gate value per channel and position.
            sigma = torch.sigmoid(self.up(f_p * f_i))
        else:
            # One gate value per position, shared by all channels.
            sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))

        # Resize the raw I-branch feature to x_p's spatial size before
        # blending.
        x_i = F.interpolate(
            x_i,
            size=x_p.shape[2:],
            mode=self.upsample_mode,
            align_corners=False)

        out = sigma * x_i + (1 - sigma) * x_p
        return out


class Bag(BaseModule):
    """Boundary-attention-guided fusion module.

    The D-branch feature is squashed through a sigmoid and used as a gate
    that blends the P-branch and I-branch features; the blend is then fed
    to a single ``ConvModule``.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The kernel size of the convolution. Default: 3.
        padding (int): The padding of the convolution. Default: 1.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Extra keyword arguments forwarded to
            ``ConvModule``. Default: dict(order=('norm', 'act', 'conv')).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        # ``conv_cfg`` is unpacked last so it can supply additional
        # ConvModule options such as the layer order.
        fusion_conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)
        self.conv = fusion_conv

    def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
        """Fuse the P and I features under guidance of the D feature.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.
            x_d (Tensor): The feature map from D branch.

        Returns:
            Tensor: The feature map with boundary-attention-guided fusion.
        """
        gate = torch.sigmoid(x_d)
        # A gate close to 1 favours the P feature, close to 0 the I feature.
        blended = gate * x_p + (1 - gate) * x_i
        return self.conv(blended)


class LightBag(BaseModule):
    """Light Boundary-attention-guided fusion module.

    Uses two 1x1 ``ConvModule`` projections, one per mixed feature, and
    sums their outputs.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer. Default: None.
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        # Both projections share the same 1x1 configuration.
        proj_kwargs = dict(kernel_size=1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.f_p = ConvModule(in_channels, out_channels, **proj_kwargs)
        self.f_i = ConvModule(in_channels, out_channels, **proj_kwargs)

    def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
        """Fuse the P and I features under guidance of the D feature.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.
            x_d (Tensor): The feature map from D branch.

        Returns:
            Tensor: The feature map with light boundary-attention-guided
                fusion.
        """
        gate = torch.sigmoid(x_d)

        # P feature plus the I feature weighted by (1 - gate), and
        # I feature plus the P feature weighted by gate.
        p_mix = (1 - gate) * x_i + x_p
        i_mix = x_i + gate * x_p

        return self.f_p(p_mix) + self.f_i(i_mix)


@MODELS.register_module()
class PIDNet(BaseModule):
    """PIDNet backbone.

    This backbone is the implementation of `PIDNet: A Real-time Semantic
    Segmentation Network Inspired from PID Controller
    <https://arxiv.org/abs/2206.02066>`_.
    Modified from https://github.com/XuJiacong/PIDNet.

    Licensed under the MIT License.

    The network consists of a shared stem followed by three parallel
    branches (I, P and D) that exchange features after each of the first two
    branch stages and are fused by a ``Bag``/``LightBag`` module at the end.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        channels (int): The number of channels in the stem layer. Default: 64.
        ppm_channels (int): The number of channels in the PPM layer.
            Default: 96.
        num_stem_blocks (int): The number of blocks in the stem layer.
            A value of 2 selects the light variant (``PAPPM`` +
            ``LightBag``); any other value selects ``DAPPM`` + ``Bag``.
            Default: 2.
        num_branch_blocks (int): The number of blocks in the branch layer.
            Default: 3.
        align_corners (bool): The align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict): Config dict for initialization. Default: None.
        **kwargs: Accepted for config compatibility; not used by this class.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 64,
                 ppm_channels: int = 96,
                 num_stem_blocks: int = 2,
                 num_branch_blocks: int = 3,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None,
                 **kwargs):
        super().__init__(init_cfg)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stem layer
        self.stem = self._make_stem_layer(in_channels, channels,
                                          num_stem_blocks)
        self.relu = nn.ReLU()

        # I Branch: three stride-2 stages; the last one uses Bottleneck
        # blocks.
        self.i_branch_layers = nn.ModuleList()
        for i in range(3):
            self.i_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    in_channels=channels * 2**(i + 1),
                    channels=channels * 8 if i > 0 else channels * 4,
                    num_blocks=num_branch_blocks if i < 2 else 2,
                    stride=2))

        # P Branch: three stride-1 stages at a constant width.
        self.p_branch_layers = nn.ModuleList()
        for i in range(3):
            self.p_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    in_channels=channels * 2,
                    channels=channels * 2,
                    num_blocks=num_stem_blocks if i < 2 else 1))
        # 1x1 convs that reduce the I-branch features to the P-branch width
        # before they are fused by the PagFM modules.
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.pag_1 = PagFM(channels * 2, channels)
        self.pag_2 = PagFM(channels * 2, channels)

        # D Branch: the layout, the pyramid pooling module and the final
        # fusion module all depend on ``num_stem_blocks``.
        if num_stem_blocks == 2:
            self.d_branch_layers = nn.ModuleList([
                self._make_single_layer(BasicBlock, channels * 2, channels),
                self._make_layer(Bottleneck, channels, channels, 1)
            ])
            channel_expand = 1
            spp_module = PAPPM
            dfm_module = LightBag
            act_cfg_dfm = None
        else:
            self.d_branch_layers = nn.ModuleList([
                self._make_single_layer(BasicBlock, channels * 2,
                                        channels * 2),
                self._make_single_layer(BasicBlock, channels * 2, channels * 2)
            ])
            channel_expand = 2
            spp_module = DAPPM
            dfm_module = Bag
            act_cfg_dfm = act_cfg

        # 3x3 convs that project the I-branch features before they are added
        # to the D branch.
        self.diff_1 = ConvModule(
            channels * 4,
            channels * channel_expand,
            kernel_size=3,
            padding=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.diff_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=3,
            padding=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.spp = spp_module(
            channels * 16, ppm_channels, channels * 4, num_scales=5)
        self.dfm = dfm_module(
            channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)

        # Third D-branch stage, shared by both variants.
        self.d_branch_layers.append(
            self._make_layer(Bottleneck, channels * 2, channels * 2, 1))

    def _make_stem_layer(self, in_channels: int, channels: int,
                         num_blocks: int) -> nn.Sequential:
        """Make stem layer.

        Args:
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            num_blocks (int): Number of blocks.

        Returns:
            nn.Sequential: The stem layer.
        """

        # Two stride-2 3x3 convs.
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        # Two residual stages, each followed by an explicit ReLU; the second
        # stage has stride 2 and doubles the channel count.
        layers.append(
            self._make_layer(BasicBlock, channels, channels, num_blocks))
        layers.append(nn.ReLU())
        layers.append(
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2))
        layers.append(nn.ReLU())

        return nn.Sequential(*layers)

    def _make_layer(self,
                    block: BasicBlock,
                    in_channels: int,
                    channels: int,
                    num_blocks: int,
                    stride: int = 1) -> nn.Sequential:
        """Make layer for PIDNet backbone.

        Args:
            block (BasicBlock): Basic block.
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The Branch Layer.
        """
        # A 1x1 projection on the shortcut is only needed when the first
        # block changes the resolution or the channel count.
        downsample = None
        if stride != 1 or in_channels != channels * block.expansion:
            downsample = ConvModule(
                in_channels,
                channels * block.expansion,
                kernel_size=1,
                stride=stride,
                norm_cfg=self.norm_cfg,
                act_cfg=None)

        layers = [block(in_channels, channels, stride, downsample)]
        in_channels = channels * block.expansion
        for i in range(1, num_blocks):
            # The last block gets no output activation (act_cfg_out=None).
            layers.append(
                block(
                    in_channels,
                    channels,
                    stride=1,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
        return nn.Sequential(*layers)

    def _make_single_layer(self,
                           block: Union[BasicBlock, Bottleneck],
                           in_channels: int,
                           channels: int,
                           stride: int = 1) -> nn.Module:
        """Make single layer for PIDNet backbone.

        Args:
            block (BasicBlock or Bottleneck): Basic block or Bottleneck.
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Module
        """

        # Same shortcut rule as in ``_make_layer``.
        downsample = None
        if stride != 1 or in_channels != channels * block.expansion:
            downsample = ConvModule(
                in_channels,
                channels * block.expansion,
                kernel_size=1,
                stride=stride,
                norm_cfg=self.norm_cfg,
                act_cfg=None)
        return block(
            in_channels, channels, stride, downsample, act_cfg_out=None)

    def init_weights(self):
        """Initialize the weights in backbone.

        Since the D branch is not initialized by the pre-trained model, we
        initialize it with the same method as the ResNet.

        All Conv2d/BatchNorm2d layers are initialized first; if ``init_cfg``
        is set, the checkpoint it names is then loaded non-strictly on top.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if self.init_cfg is not None:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], map_location='cpu')
            # Some checkpoints keep the weights under a 'state_dict' key.
            # With strict=False such a wrapper would be accepted without
            # loading a single parameter, so unwrap it first.
            if 'state_dict' in ckpt:
                ckpt = ckpt['state_dict']
            self.load_state_dict(ckpt, strict=False)

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).

        Returns:
            Tensor or tuple[Tensor]: In training mode the tuple
                ``(temp_p, out, temp_d)``, i.e. the P-branch feature after
                stage 3, the fused output and the D-branch feature after
                stage 4; otherwise only the fused output.
        """
        # stage 0-2
        x = self.stem(x)

        # stage 3
        x_i = self.relu(self.i_branch_layers[0](x))
        x_p = self.p_branch_layers[0](x)
        x_d = self.d_branch_layers[0](x)

        # Feed the I-branch feature into the P branch (via PagFM) and into
        # the D branch (resized and added).
        comp_i = self.compression_1(x_i)
        x_p = self.pag_1(x_p, comp_i)
        diff_i = self.diff_1(x_i)
        x_d += F.interpolate(
            diff_i,
            size=x_d.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            # Snapshot of the P-branch feature, returned in training mode.
            temp_p = x_p.clone()

        # stage 4
        x_i = self.relu(self.i_branch_layers[1](x_i))
        x_p = self.p_branch_layers[1](self.relu(x_p))
        x_d = self.d_branch_layers[1](self.relu(x_d))

        comp_i = self.compression_2(x_i)
        x_p = self.pag_2(x_p, comp_i)
        diff_i = self.diff_2(x_i)
        x_d += F.interpolate(
            diff_i,
            size=x_d.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            # Snapshot of the D-branch feature, returned in training mode.
            temp_d = x_d.clone()

        # stage 5
        x_i = self.i_branch_layers[2](x_i)
        x_p = self.p_branch_layers[2](self.relu(x_p))
        x_d = self.d_branch_layers[2](self.relu(x_d))

        # Pyramid pooling on the I branch, resized to the D-branch
        # resolution before the final three-way fusion.
        x_i = self.spp(x_i)
        x_i = F.interpolate(
            x_i,
            size=x_d.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        out = self.dfm(x_p, x_i, x_d)
        return (temp_p, out, temp_d) if self.training else out

Overwriting /kaggle/working/mmsegmentation/mmseg/models/backbones/pidnet.py


In [10]:
%%writefile /kaggle/working/config.py
# PIDNet training config for a two-class ISIC segmentation task. Dataset and
# runtime settings are inherited from the two base files below; this file
# adds the model, the optimizer, the LR schedule and the loop settings.
_base_ = [
    '/kaggle/working/mmsegmentation/configs/_base_/datasets/isic_dataset.py',
    '/kaggle/working/mmsegmentation/configs/_base_/default_runtime.py'
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
# NOTE(review): the weights below are uniform (one entry per class, see
# num_classes=2), so the attribution above may be stale - confirm before
# removing it.
class_weight =[1.0, 1.0]
# Pre-trained backbone weights, loaded through the backbone's init_cfg.
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth'  # noqa
# NOTE(review): the config dump logged by Runner.from_cfg shows the dataset
# train pipeline using RandomCrop with crop_size=(512, 512); confirm that
# size=(1024, 1024) in the data preprocessor below is intended.
crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=crop_size)
# SyncBN is requested here; mmengine logs that SyncBN layers are reverted to
# plain batch norm when distributed training is not used.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='PIDNet',
        in_channels=3,
        channels=32,
        ppm_channels=96,
        num_stem_blocks=2,
        num_branch_blocks=3,
        align_corners=False,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
    decode_head=dict(
        type='PIDHead',
        in_channels=128,
        channels=128,
        num_classes=2,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        align_corners=False,
        # Four loss terms. Presumably they correspond, in order, to the
        # loss_sem_p / loss_sem_i / loss_bd / loss_sem_bd entries of the
        # training log - confirm against PIDHead.
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=class_weight,
                loss_weight=0.4),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(type='BoundaryLoss', loss_weight=20.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0)
        ]),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

# Total number of training iterations. The same value is reused as the end of
# the LR schedule, the validation interval and the checkpoint interval.
iters = 1500
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy: polynomial decay to eta_min over the whole run,
# stepped per iteration (by_epoch=False)
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]
# training schedule: `iters` iterations in total; val_interval=iters means
# validation runs once, after the last training iteration
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters) 
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    # log every 50 training iterations
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # interval=iters: a checkpoint is written every `iters` iterations
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters), 
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

# Fixed seed for the run.
randomness = dict(seed=304)
# Directory that receives logs and checkpoints.
work_dir = "/kaggle/working/"

Writing /kaggle/working/config.py


In [11]:
# Parse the config file written by the previous cell.
from mmengine.config import Config, DictAction
cfg = Config.fromfile('/kaggle/working/config.py')

In [12]:
# Build the mmengine Runner from the parsed config. As the output below
# shows, this logs the system environment, the full config and the hook
# execution order.
from mmengine.runner import Runner
runner = Runner.from_cfg(cfg)

05/11 10:19:33 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0]
    CUDA available: True
    numpy_random_seed: 304
    GPU 0: Tesla P100-PCIE-16GB
    CUDA_HOME: /opt/conda
    NVCC: Cuda compilation tools, release 12.1, V12.1.105
    GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
    PyTorch: 1.12.0+cu113
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-g

  warn(


05/11 10:19:34 - mmengine - [4m[97mINFO[0m - Config:
dataset_type = 'ISICDataset'
data_root = '/kaggle/input/isic-2018/data'
img_scale = (2048, 1024)
crop_size = (1024, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='ConvertPixel'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='GenerateEdge', edge_width=4),
    dict(type='PackSegInputs')
]
val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='ConvertPixel'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadA



05/11 10:19:39 - mmengine - [4m[97mINFO[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
05/11 10:19:39 - mmengine - [4m[97mINFO[0m - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) RuntimeInfoHook                    
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
before_train:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DistSamplerSeedHook                
 -------------------- 
before_train_iter:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
 -------------------- 
after_train_iter:
(VERY_HIGH   ) Runti



In [13]:
# Run the training loop configured by train_cfg. The call is the last
# expression of the cell, so its return value (the trained model) is echoed
# after the training and validation logs.
runner.train()



05/11 10:20:08 - mmengine - [4m[97mINFO[0m - Loads checkpoint by http backend from path: https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth


Downloading: "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth" to /root/.cache/torch/hub/checkpoints/pidnet-s_imagenet1k_20230306-715e6273.pth


05/11 10:20:11 - mmengine - [4m[97mINFO[0m - Checkpoints will be saved to /kaggle/working.
05/11 10:21:08 - mmengine - [4m[97mINFO[0m - Iter(train) [  50/1500]  lr: 9.7053e-03  eta: 0:27:40  time: 1.5389  data_time: 1.2797  memory: 10793  loss: 1.9240  decode.loss_sem_p: 0.1300  decode.loss_sem_i: 0.8174  decode.loss_bd: 0.1334  decode.loss_sem_bd: 0.8432  decode.acc_seg: 58.4813
05/11 10:21:54 - mmengine - [4m[97mINFO[0m - Iter(train) [ 100/1500]  lr: 9.4036e-03  eta: 0:24:02  time: 1.1438  data_time: 0.8682  memory: 2887  loss: 1.3212  decode.loss_sem_p: 0.0474  decode.loss_sem_i: 0.4989  decode.loss_bd: 0.1116  decode.loss_sem_bd: 0.6633  decode.acc_seg: 88.0873
05/11 10:22:35 - mmengine - [4m[97mINFO[0m - Iter(train) [ 150/1500]  lr: 9.1008e-03  eta: 0:21:40  time: 0.8204  data_time: 0.5407  memory: 2889  loss: 1.4577  decode.loss_sem_p: 0.0569  decode.loss_sem_i: 0.6560  decode.loss_bd: 0.1266  decode.loss_sem_bd: 0.6183  decode.acc_seg: 79.4621
05/11 10:23:18 - mmengi



05/11 10:43:17 - mmengine - [4m[97mINFO[0m - Iter(val) [ 50/100]    eta: 0:00:17  time: 0.0999  data_time: 0.0213  memory: 8735  
05/11 10:43:30 - mmengine - [4m[97mINFO[0m - Iter(val) [100/100]    eta: 0:00:00  time: 0.1672  data_time: 0.0136  memory: 8526  
05/11 10:43:30 - mmengine - [4m[97mINFO[0m - per class results:
05/11 10:43:30 - mmengine - [4m[97mINFO[0m - 
+------------+-------+-------+
|   Class    |  IoU  |  Acc  |
+------------+-------+-------+
| background | 82.39 | 89.87 |
|    mask    | 55.26 | 72.31 |
+------------+-------+-------+
05/11 10:43:30 - mmengine - [4m[97mINFO[0m - Iter(val) [100/100]    aAcc: 85.5300  mIoU: 68.8200  mAcc: 81.0900  data_time: 0.0481  time: 0.2948


EncoderDecoder(
  (data_preprocessor): SegDataPreProcessor()
  (backbone): PIDNet(
    (stem): Sequential(
      (0): ConvModule(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): _BatchNormXd(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (1): ConvModule(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): _BatchNormXd(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): BasicBlock(
          (conv1): ConvModule(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
          (conv2): ConvModule(
            (co

In [14]:
# Run the test loop; the returned metrics dict is echoed as the cell output.
# NOTE(review): the saved run shows an empty dict `{}` here - confirm that a
# test evaluator is configured and that the test split has annotations.
runner.test()



05/11 10:44:08 - mmengine - [4m[97mINFO[0m - Iter(test) [  50/1000]    eta: 0:10:44  time: 1.4842  data_time: 0.0767  memory: 801  
05/11 10:45:21 - mmengine - [4m[97mINFO[0m - Iter(test) [ 100/1000]    eta: 0:16:01  time: 1.3933  data_time: 0.0640  memory: 8742  
05/11 10:45:54 - mmengine - [4m[97mINFO[0m - Iter(test) [ 150/1000]    eta: 0:13:15  time: 0.0344  data_time: 0.0038  memory: 805  
05/11 10:45:56 - mmengine - [4m[97mINFO[0m - Iter(test) [ 200/1000]    eta: 0:09:28  time: 0.0371  data_time: 0.0043  memory: 200  
05/11 10:45:58 - mmengine - [4m[97mINFO[0m - Iter(test) [ 250/1000]    eta: 0:07:12  time: 0.0369  data_time: 0.0037  memory: 200  
05/11 10:46:00 - mmengine - [4m[97mINFO[0m - Iter(test) [ 300/1000]    eta: 0:05:40  time: 0.0370  data_time: 0.0037  memory: 200  
05/11 10:46:02 - mmengine - [4m[97mINFO[0m - Iter(test) [ 350/1000]    eta: 0:04:35  time: 0.0668  data_time: 0.0089  memory: 200  
05/11 10:46:04 - mmengine - [4m[97mINFO[0m - Iter(t

{}