# Install requirements

In [None]:
!pip install flask_ngrok
!pip install pyngrok

In [None]:
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html

In [None]:
!git clone https://github.com/Slava-git/mmsegmentation_swin
%cd mmsegmentation_swin
!pip install -e .

# Import dependencies

In [46]:
import torch
import torchvision
import mmcv
import cv2
import matplotlib.pyplot as plt
import os.path as osp
import numpy as np
import flask
import time

from flask import Flask, render_template, request
from flask_ngrok import run_with_ngrok
from mmcv import Config

import mmseg
from mmseg.apis import set_random_seed
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

# Register dataset

In [5]:
# Class names for the two-class segmentation task, and the colour drawn for
# each class when predictions are overlaid (one palette entry per class).
classes = tuple(['Background', 'Person'])
palette = [
    [0, 0, 0],    # Background
    [0, 128, 0],  # Person
]

In [6]:
# force=True lets this cell be re-run in the same kernel: without it the
# registry refuses a second registration of the same name.
@DATASETS.register_module(force=True)
class FullBodyDataset(CustomDataset):
  '''
  Two-class (Background / Person) segmentation dataset of PNG images and PNG masks.

  Registered in mmseg's DATASETS registry so configs can refer to it as
  ``type='FullBodyDataset'``. All keyword arguments are forwarded to
  ``CustomDataset``; ``img_suffix`` and ``seg_map_suffix`` default to
  ``'.png'`` but may be overridden by the caller.

  Raises:
    FileNotFoundError: if ``self.img_dir`` (as set by CustomDataset) does not exist.
  '''
  CLASSES = classes
  PALETTE = palette

  def __init__(self, **kwargs):
    # setdefault instead of explicit keywords: a config that also passes the
    # suffixes would otherwise hit "got multiple values for keyword argument".
    kwargs.setdefault('img_suffix', '.png')
    kwargs.setdefault('seg_map_suffix', '.png')
    super().__init__(**kwargs)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not osp.exists(self.img_dir):
      raise FileNotFoundError(f'Image directory not found: {self.img_dir}')

# Connect to Google Drive

In [None]:
# Leave the repo directory before mounting Drive (the repo is re-entered afterwards).
%cd ../

In [None]:
# Mount Google Drive: the dataset (cfg.data_root) and the checkpoints
# (work_dirs) used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Back into the cloned repo: relative paths used later (configs/..., static/images)
# resolve against this working directory.
%cd mmsegmentation_swin

# Config

In [11]:
# Start from the stock PSPNet Cityscapes config shipped with the cloned repo.
# The path is relative, so this cell must run with mmsegmentation_swin as cwd.
base_config_path = osp.join('configs', 'pspnet', 'pspnet_r50-d8_512x1024_40k_cityscapes.py')
cfg = Config.fromfile(base_config_path)

In [12]:
# Save class names and palette in the checkpoint metadata next to the weights.
cfg.checkpoint_config.meta = dict(
    CLASSES=classes,
    PALETTE=palette)

# Plain BatchNorm in every component (single GPU, see gpu_ids below); the base
# config presumably uses SyncBN, which requires a distributed launcher.
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# One output channel per class; derived from `classes` so the two cannot drift apart.
cfg.model.decode_head.num_classes = len(classes)
cfg.model.auxiliary_head.num_classes = len(classes)

dataset_type = 'FullBodyDataset'
cfg.dataset_type = dataset_type
cfg.data_root = '/content/drive/MyDrive/data/full_body_tik_tok'

cfg.data.samples_per_gpu = 8
cfg.data.workers_per_gpu = 8

# Per-channel mean/std (presumably the usual ImageNet statistics); to_rgb=True
# converts the loaded image to RGB before normalizing.
cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
# Training augmentation: random rescale around 540x960, random crop, flip and
# photometric jitter; padded mask pixels get 255 (presumably the ignore index).
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(540, 960), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

# Evaluation / inference: a single 540x960 scale, no flipping.
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(540, 960),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = 'images/training'
cfg.data.train.ann_dir = 'annotations/training_1D'
cfg.data.train.pipeline = cfg.train_pipeline

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = 'images/validation'
cfg.data.val.ann_dir = 'annotations/validation_1D'
cfg.data.val.pipeline = cfg.test_pipeline

# There is no separate test split: the test set reuses the validation data.
cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = 'images/validation'
cfg.data.test.ann_dir = 'annotations/validation_1D'
cfg.data.test.pipeline = cfg.test_pipeline

# Fine-tune from an earlier run's checkpoint; new checkpoints/logs go to work_dir.
cfg.load_from = '/content/drive/MyDrive/data/work_dirs/iter_20000.pth'
cfg.work_dir = '/content/drive/MyDrive/data/work_dirs'

# Short run: evaluate and checkpoint exactly once, at the final iteration.
cfg.runner.max_iters = 400
cfg.log_config.interval = 50
cfg.evaluation.interval = cfg.runner.max_iters
cfg.checkpoint_config.interval = cfg.runner.max_iters

# Seed from cfg.seed so the recorded config and the actual seeding cannot disagree.
cfg.seed = 0
set_random_seed(cfg.seed, deterministic=False)
cfg.gpu_ids = range(1)

# Flask web app for predictions

In [13]:
# Weights from the training run configured above: checkpoints are saved as
# iter_<N>.pth inside work_dir, and the run stops at max_iters. Deriving the
# path keeps it in sync if either setting changes.
checkpoint_file = osp.join(cfg.work_dir, f'iter_{cfg.runner.max_iters}.pth')

In [14]:
# Segmentors already built by make_predict, keyed by (checkpoint path, config
# key, device). Each value also holds the config object, so that object's id()
# cannot be recycled by a different config while its entry exists.
_segmentor_cache = {}


def _get_segmentor(checkpoint_file, config_file, device):
  '''Return the segmentor for these arguments, building it on first request.'''
  # Config paths are keyed by value, config objects by identity.
  config_key = config_file if isinstance(config_file, str) else id(config_file)
  key = (checkpoint_file, config_key, device)
  if key not in _segmentor_cache:
    model = init_segmentor(config_file, checkpoint_file, device=device)
    _segmentor_cache[key] = (config_file, model)
  return _segmentor_cache[key][1]


def make_predict(checkpoint_file, config_file, image, device='cuda:0'):
  '''
  Get segmentation on input image

  The model is built once per (checkpoint_file, config_file, device) and
  reused on later calls instead of reloading the weights on every call.
  A config object mutated in place after the first call, or a checkpoint
  overwritten on disk, is NOT picked up; re-run this cell to clear the cache.

  params:
    checkpoint_file - file with weights
    config_file - config file path or config object
    image - path to image
    device - device to run the model on (default 'cuda:0')

  Returns:
    predicted image: the input with the segmentation drawn over it at 0.5
    opacity using `palette`
  '''
  model = _get_segmentor(checkpoint_file, config_file, device)
  result = inference_segmentor(model, image)
  return model.show_result(
        image, result, palette=palette, show=False, opacity=0.5)

In [47]:
import os

# Uploads and predicted overlays are written here; their 'static/images/...'
# paths are handed to the page as URLs, presumably served by Flask's default
# static route.
image_folder = osp.join('static', 'images')
# Create it up front: saving an upload fails if the folder does not exist.
os.makedirs(image_folder, exist_ok=True)

app = Flask(__name__)
# flask_ngrok hook; presumably opens a public ngrok tunnel when app.run() starts.
run_with_ngrok(app)
app.config["UPLOAD_FOLDER"] = image_folder

In [48]:
@app.route('/', methods=['GET'])
def home():
  '''Render the upload page before any image has been submitted.'''
  page = render_template('index.html')
  return page

@app.route('/', methods=['POST'])
def predict():
  '''
  Segment an uploaded image and render it next to the predicted overlay.

  Expects a multipart form upload in the ``imagefile`` field. The upload is
  saved into app.config['UPLOAD_FOLDER'], segmented with make_predict, and
  the overlay is written beside it as ``out_<filename>``.

  Returns:
    index.html rendered with ``input_image`` / ``output_image`` paths, or
    the bare page with HTTP 400 when no usable file was submitted.
  '''
  upload_folder = app.config['UPLOAD_FOLDER']

  imagefile = request.files.get('imagefile')
  raw_name = (imagefile.filename or '') if imagefile is not None else ''
  # The filename is client-controlled: keep only its last path component so
  # the upload cannot be written outside the upload folder.
  filename = osp.basename(raw_name.replace('\\', '/'))
  if filename in ('', '.', '..'):
    return render_template('index.html'), 400

  input_im = osp.join(upload_folder, filename)
  imagefile.save(input_im)

  predicted_im = make_predict(checkpoint_file, cfg, input_im)

  path_pred_im = osp.join(upload_folder, 'out_' + filename)
  # cv2.imwrite can signal failure by returning False; don't render a page
  # that points at an output image that was never written.
  if not cv2.imwrite(path_pred_im, predicted_im):
    flask.abort(500, description='Could not write the predicted image.')

  return render_template('index.html', input_image=input_im,
                         output_image=path_pred_im)

In [None]:
# Start the Flask development server (run_with_ngrok(app) was applied above).
# NOTE(review): app.run() does not return while the server is running, so the
# ngrok authtoken cell below must be executed BEFORE this cell; under
# Restart & Run All it would never be reached.
if __name__=='__main__':
  app.run()

In [17]:
# Configure ngrok's authtoken without storing it in the notebook: read it from
# the NGROK_AUTHTOKEN environment variable, or prompt for it (input not echoed).
# NOTE(review): a token was previously hard-coded in this cell and is exposed
# in the notebook's history; revoke it in the ngrok dashboard.
import getpass
import os
import subprocess

ngrok_token = os.environ.get('NGROK_AUTHTOKEN') or getpass.getpass('ngrok authtoken: ')
# Argument list (no shell), so the token is never shell-interpreted or written
# into the cell source; ngrok's confirmation message is printed.
completed = subprocess.run(['ngrok', 'authtoken', ngrok_token],
                           check=True, capture_output=True, text=True)
print(completed.stdout.strip())

Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml
