# Install requirements

In [None]:
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html

In [None]:
!git clone https://github.com/Slava-git/mmsegmentation_swin
%cd mmsegmentation_swin
!pip install -e .

# Import dependencies

In [None]:
import torch
import torchvision
import mmcv
import cv2
import matplotlib.pyplot as plt
import os.path as osp
import numpy as np

from PIL import Image
from mmcv import Config
from os import listdir,makedirs

import mmseg
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.datasets import build_dataset
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmseg.models import build_segmentor

# Connect to Google Drive

In [None]:
%cd ../

/content


In [None]:
# Mount Google Drive at /content/drive; the dataset folders and the training
# work directory used later in this notebook are paths under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Download pretrained weights

In [None]:
# -p keeps this cell re-runnable: a plain mkdir fails once the folder exists.
!mkdir -p checkpoints
# Download the pretrained Swin-Tiny weights into ./checkpoints
!wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth -P checkpoints

# Convert masks into label format

In [None]:
# Folder with the original validation masks.
# NOTE(review): `path` / `dstpath` match the parameter names of
# convert_to_grayscale; no call using them is visible in this notebook - confirm.
path = '/content/drive/MyDrive/data/full_body_tik_tok/annotations/validation'
# Folder meant to receive the converted masks (the `_1D` folders are the ones
# used as annotation directories in the training config).
dstpath = '/content/drive/MyDrive/data/full_body_tik_tok/annotations/validation_1D'


In [None]:
def convert_to_grayscale(path, dstpath):
  '''Convert 3 channel images to grayscale

  Every regular file located directly in ``path`` is read with OpenCV,
  converted from BGR to a single gray channel and written under the same
  file name into ``dstpath``. The destination folder is created when it
  does not exist yet. Files that cannot be read, converted or written are
  reported on stdout and skipped; the remaining files are still processed.

  params:
    path (str) - source folder
    dstpath (str) - destination folder
  '''

  try:
    makedirs(dstpath)
  except FileExistsError:
    # Only an already existing folder is tolerated; any other OS error
    # (e.g. missing permissions) is allowed to surface.
    print ("Directory already exists, images will be written in the same folder")

  files = [f for f in listdir(path) if osp.isfile(osp.join(path, f))]

  for image in files:
    # cv2.imread reports an unreadable / non-image file by returning None
    img = cv2.imread(osp.join(path, image))
    if img is None:
      print ("{} is not converted".format(image))
      continue

    try:
      gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
      # cv2.imwrite returns False when the file could not be written
      written = cv2.imwrite(osp.join(dstpath, image), gray)
    except cv2.error:
      written = False

    if not written:
      print ("{} is not converted".format(image))

In [None]:
# Rebinds `path` (previously the validation mask folder) to the folder of
# single-channel training masks.
# NOTE(review): presumably the argument for graysk_to_label; no call is visible
# in this notebook - confirm.
path = '/content/drive/MyDrive/data/full_body_tik_tok/annotations/training_1D' 

In [None]:
def graysk_to_label(path):
  '''Convert grayscale images to labels

  Rewrites every regular file located directly in ``path`` in place:
  pixels equal to 255 become 1, every other value is kept as it is.
  Files OpenCV cannot decode are reported on stdout and skipped instead
  of aborting the whole run.

  params:
    path (str) - source folder
  '''

  files = [f for f in listdir(path) if osp.isfile(osp.join(path, f))]

  for image in files:
    target = osp.join(path, image)
    # IMREAD_UNCHANGED keeps the stored channel count and bit depth
    img = cv2.imread(target, cv2.IMREAD_UNCHANGED)
    if img is None:
      # cv2.imread returns None for unreadable / non-image files
      print ("{} is not converted".format(image))
      continue

    arr = np.array(img)
    arr[arr == 255] = 1
    # The result overwrites the source file
    Image.fromarray(arr).save(target)

# Training

## Register dataset

In [None]:
# Segmentation class names; registered on the dataset class as CLASSES.
classes = (
    'Background',
    'Person',
)
# One colour triple per class, listed in the same order as `classes`.
pallete = [
    [0, 0, 0],
    [0, 128, 0],
]

In [None]:
@DATASETS.register_module()
class FullBodyDataset(CustomDataset):
  '''Background / person segmentation dataset registered in DATASETS.

  Images and segmentation maps are both looked up with a ``.png`` suffix;
  all remaining keyword arguments are forwarded to ``CustomDataset``.
  '''

  CLASSES = classes
  PALETTE = pallete

  def __init__(self, **kwargs):
    png_suffixes = dict(img_suffix='.png', seg_map_suffix='.png')
    super().__init__(**png_suffixes, **kwargs)
    # Sanity check: the image folder set up by the parent class must exist
    assert osp.exists(self.img_dir)

## Create config file

In [None]:
# Load the UPerNet + Swin-Tiny base config from the cloned repository. The
# path is relative, so the working directory must be the clone's parent folder.
cfg = Config.fromfile(
    'mmsegmentation_swin/configs/swin/'
    'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py'
)

In [None]:
# --- Checkpoint metadata ----------------------------------------------------
# Record class names and palette under checkpoint_config.meta.
cfg.checkpoint_config.meta = {'CLASSES': classes, 'PALETTE': pallete}

# --- Model ------------------------------------------------------------------
# Both heads get plain BN and two output classes. The backbone's norm_cfg is
# not touched here and keeps the value from the base config.
cfg.norm_cfg = {'type': 'BN', 'requires_grad': True}
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.num_classes = 2
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.num_classes = 2

# --- Dataset ----------------------------------------------------------------
dataset_type = 'FullBodyDataset'
cfg.dataset_type = dataset_type
cfg.data_root = '/content/drive/MyDrive/data/full_body_tik_tok'

cfg.data.samples_per_gpu = 4
cfg.data.workers_per_gpu = 4

# --- Pipelines --------------------------------------------------------------
cfg.img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
cfg.crop_size = (256, 256)

# Training: random rescale, crop to crop_size, flip and colour jitter, then
# normalise and pad back to crop_size.
cfg.train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'LoadAnnotations'},
    {'type': 'Resize', 'img_scale': (540, 960), 'ratio_range': (0.5, 2.0)},
    {'type': 'RandomCrop', 'crop_size': cfg.crop_size, 'cat_max_ratio': 0.75},
    {'type': 'RandomFlip', 'flip_ratio': 0.5},
    {'type': 'PhotoMetricDistortion'},
    {'type': 'Normalize', **cfg.img_norm_cfg},
    {'type': 'Pad', 'size': cfg.crop_size, 'pad_val': 0, 'seg_pad_val': 255},
    {'type': 'DefaultFormatBundle'},
    {'type': 'Collect', 'keys': ['img', 'gt_semantic_seg']},
]

# Validation / test: a single scale with flipping disabled.
cfg.test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': (540, 960),
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {'type': 'Normalize', **cfg.img_norm_cfg},
            {'type': 'ImageToTensor', 'keys': ['img']},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]

# --- Data splits ------------------------------------------------------------
cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = 'images/training'
cfg.data.train.ann_dir = 'annotations/training_1D'
cfg.data.train.pipeline = cfg.train_pipeline

# The val and test splits both read the validation folders.
cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = 'images/validation'
cfg.data.val.ann_dir = 'annotations/validation_1D'
cfg.data.val.pipeline = cfg.test_pipeline

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = 'images/validation'
cfg.data.test.ann_dir = 'annotations/validation_1D'
cfg.data.test.pipeline = cfg.test_pipeline

# --- Weights and output -----------------------------------------------------
# NOTE(review): this is the checkpoint downloaded from the SwinTransformer
# release page. Confirm that its parameter names match the segmentor built
# from this config; otherwise load_from may not initialise any weights.
cfg.load_from = '/content/checkpoints/swin_tiny_patch4_window7_224.pth'
cfg.work_dir = '/content/drive/MyDrive/data/swin_dirs'

# --- Schedule ---------------------------------------------------------------
cfg.runner.max_iters = 30000
cfg.log_config.interval = 200
cfg.evaluation.interval = 3000
cfg.checkpoint_config.interval = 3000

# --- Reproducibility and devices --------------------------------------------
cfg.seed = 0
set_random_seed(cfg.seed, deterministic=False)
cfg.gpu_ids = range(1)

## Train 


In [None]:
def run_training(cfg):
  '''Build the training dataset and the segmentor from ``cfg`` and train.

  The work directory named by ``cfg.work_dir`` is created if it is missing.
  Training runs non-distributed with validation enabled.

  params:
    cfg - config object; ``cfg.data.train``, ``cfg.model``, the optional
      ``train_cfg`` / ``test_cfg`` entries and ``cfg.work_dir`` are read here
  '''

  train_set = build_dataset(cfg.data.train)

  segmentor = build_segmentor(
      cfg.model,
      train_cfg=cfg.get('train_cfg'),
      test_cfg=cfg.get('test_cfg'),
  )
  # Give the model the class names of the dataset it is trained on
  segmentor.CLASSES = train_set.CLASSES

  work_dir = osp.abspath(cfg.work_dir)
  mmcv.mkdir_or_exist(work_dir)

  train_segmentor(
      segmentor,
      [train_set],
      cfg,
      distributed=False,
      validate=True,
      meta={},
  )

In [None]:
# Start training with the config assembled in the cells above.
run_training(cfg)

# Inference

In [None]:
# Use the checkpoint written by the training run in this notebook. The old
# hard-coded path pointed at a different folder and at an iteration that the
# configured checkpoint interval never saves.
# Assumes the `iter_<N>.pth` naming seen in the previous path and that a
# checkpoint is saved at the final iteration - TODO confirm in cfg.work_dir.
checkpoint_file = osp.join(cfg.work_dir,
                           'iter_{}.pth'.format(cfg.runner.max_iters))

In [None]:
def run_inference(checkpoint_file, config_file, image, device='cuda:0'):
  '''Predict segmentation and draw it

  Builds the segmentor from the config and checkpoint, runs it on one image
  and plots the prediction with the notebook-level ``pallete``. Nothing is
  returned.

  params:
    checkpoint_file (str) - file with weights
    config_file - config handed to init_segmentor (this notebook passes the
      loaded ``cfg`` object)
    image (str) - path to image
    device (str) - device the model is placed on, 'cuda:0' by default;
      pass e.g. 'cpu' when no GPU is available
  '''

  model = init_segmentor(config_file, checkpoint_file, device=device)
  result = inference_segmentor(model, image)
  show_result_pyplot(model, image, result, pallete)

In [None]:
# Segment a single validation image and plot the result.
run_inference(
    checkpoint_file=checkpoint_file,
    config_file=cfg,
    image='/content/drive/MyDrive/data/full_body_tik_tok/images/validation/0_00030.png',
)