# Video Prediction with Pretrained Segmentation Models

## Bash

In [1]:
# Script alternative: keep the per-frame predictions (do not delete the temporary frame folder)
# !python3 C_pretrained_models_for_video.py --temp_dir_delete False

## Process

1. 输出视频保存在当前工作目录下的 `outputs/` 文件夹中（`outputs/out_{model}_{dataset}.mp4`）；逐帧预测结果先写入当前工作目录下以时间戳命名的临时文件夹，合成视频后删除。

In [2]:
import os
import numpy as np
import time
import shutil

import torch

from PIL import Image
import cv2

import mmcv
import mmengine
from mmseg.apis import init_model, inference_model
from mmseg.utils import register_all_modules
register_all_modules()

from mmseg.datasets import CityscapesDataset

# os.chdir('mmsegmentation')

[2023-09-17 22:46:18,701] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  warn(f"Failed to load image Python extension: {e}")


In [3]:
def predict_single_frame(model, img, palette, opacity=0.2):
    """Segment one frame and blend the colorized class mask over it.

    Args:
        model: An mmseg model as returned by ``init_model``.
        img (np.ndarray): The frame, expected as an (H, W, 3) array in OpenCV BGR
            channel order (what ``mmcv.VideoReader`` yields in ``main``).
        palette: Per-class RGB colors indexed by class id, e.g. a dataset's
            ``METAINFO['palette']``.
        opacity (float): Weight of the ORIGINAL frame in the blend; the colored
            mask gets weight ``1 - opacity``. Small values let the mask dominate.

    Returns:
        np.ndarray: Float (H, W, 3) blend in BGR order, suitable for ``cv2.imwrite``
        (which saturates it to 8 bit when writing JPEG).
    """
    result = inference_model(model, img)

    # Class-id map for the frame; uint8 holds up to 256 classes (ADE20K has 150).
    seg_map = result.pred_sem_seg.data[0].detach().cpu().numpy().astype(np.uint8)

    # Colorize through a PIL palette image: pixel value i -> palette[i] (RGB).
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.asarray(palette, dtype=np.uint8).ravel())

    # The palette colors are RGB, but the frame is BGR and the blend is written with
    # cv2.imwrite (BGR). Flip the mask to BGR so class colors are not red/blue swapped.
    seg_bgr = np.asarray(seg_img.convert('RGB'))[..., ::-1]

    show_img = seg_bgr * (1 - opacity) + img * opacity

    return show_img

In [4]:
def _resolve_model_files(model, dataset):
    """Look up the config file, checkpoint URL and color palette for a (model, dataset) preset.

    Args:
        model (str): Model name, 'segformer' or 'mask2former'.
        dataset (str): Dataset name, 'cityscapes' or 'ADE20K'.

    Returns:
        tuple: ``(config_file, checkpoint_file, palette)``; ``palette`` is the
        dataset class's ``METAINFO['palette']``.

    Raises:
        ValueError: If no preset exists for the combination (e.g. segformer on ADE20K).
    """
    presets = {
        ('cityscapes', 'segformer'): (
            # model config file
            './mmsegmentation/configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py',
            # pretrained checkpoint (weights) file
            'https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth',
        ),
        ('cityscapes', 'mask2former'): (
            './mmsegmentation/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py',
            'https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221203_045030-9a86a225.pth',
        ),
        ('ADE20K', 'mask2former'): (
            './mmsegmentation/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py',
            'https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235230-7ec0f569.pth',
        ),
    }
    if (dataset, model) not in presets:
        raise ValueError(
            f'No pretrained preset for model={model!r} on dataset={dataset!r}; '
            f'available (dataset, model) pairs: {sorted(presets)}')
    config_file, checkpoint_file = presets[(dataset, model)]

    # Import only the dataset module that is actually needed.
    if dataset == 'cityscapes':
        from mmseg.datasets import cityscapes
        palette = cityscapes.CityscapesDataset.METAINFO['palette']
    else:
        from mmseg.datasets import ade
        palette = ade.ADE20KDataset.METAINFO['palette']
    return config_file, checkpoint_file, palette


def main(model, dataset, input_video, opacity=0.15, output_dir='./outputs',
         temp_dir_delete=True):
    """Segment every frame of a video with a pretrained mmseg model and save the overlay video.

    The function predicts each frame and blends it with its colorized mask. It writes
    the blended frames as JPEGs into a temporary folder in the current directory. It
    then stitches them into ``{output_dir}/out_{model}_{dataset}.mp4`` at the input
    video's frame rate.

    Args:
        model (str): Model name, 'segformer' or 'mask2former'.
        dataset (str): Dataset name, 'cityscapes' or 'ADE20K'.
        input_video (str): Path of the video to segment.
        opacity (float): Weight of the original frame in each blend (forwarded to
            ``predict_single_frame``).
        output_dir (str): Directory for the output video; created if missing.
        temp_dir_delete (bool): Remove the per-frame folder afterwards (also on error).

    Returns:
        str: Path of the output video.

    Raises:
        ValueError: Unsupported (model, dataset) pair, or the video reports no frames.
        IOError: A frame could not be written to the temporary folder.
    """
    import tempfile

    config_file, checkpoint_file, palette = _resolve_model_files(model, dataset)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # Separate name so `model` keeps the model-name string used in the output filename.
    seg_model = init_model(config_file, checkpoint_file, device=device)

    # Read the video to predict.
    imgs = mmcv.VideoReader(input_video)
    if len(imgs) == 0:
        raise ValueError(f'No frames found in video {input_video!r}')

    # Create the output directory up front: frames2video does not create it, and an
    # unopenable cv2.VideoWriter does not raise.
    os.makedirs(output_dir, exist_ok=True)
    output_video = os.path.join(output_dir, f'out_{model}_{dataset}.mp4')

    # Timestamp prefix keeps the folder readable; mkdtemp guarantees it is unique.
    temp_out_dir = tempfile.mkdtemp(prefix=time.strftime('%Y%m%d%H%M%S') + '_', dir='.')
    print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))

    try:
        prog_bar = mmengine.ProgressBar(len(imgs))

        # Process the video frame by frame.
        for frame_id, img in enumerate(imgs):
            show_img = predict_single_frame(seg_model, img, palette, opacity=opacity)
            # Names must match frames2video's default '{:06d}.jpg' template.
            temp_path = os.path.join(temp_out_dir, f'{frame_id:06d}.jpg')
            if not cv2.imwrite(temp_path, show_img):
                raise IOError(f'Failed to write frame {temp_path}')
            prog_bar.update()

        # Stitch the saved frames back into a video.
        mmcv.frames2video(temp_out_dir, output_video, fps=imgs.fps, fourcc='mp4v')
    finally:
        if temp_dir_delete:
            shutil.rmtree(temp_out_dir, ignore_errors=True)
            print('删除临时文件夹', temp_out_dir)

    return output_video

In [5]:
dataset = 'cityscapes'
model = 'segformer'

# Demo clip used for each dataset. Other clips that were tried:
#   cityscapes: 'data/traffic.mp4', 'data/street_20220330_174028.mp4'
#   ADE20K:     'data/Library_8s.mp4'
DEMO_VIDEOS = {
    'cityscapes': 'data/street_5s.mp4',  # mydata-street
    'ADE20K': 'data/Library_5s.mp4',
}
if dataset not in DEMO_VIDEOS:
    # Fail with a clear message instead of a NameError on `input_video` below.
    raise ValueError(f'Unsupported dataset {dataset!r}; expected one of {sorted(DEMO_VIDEOS)}')
input_video = DEMO_VIDEOS[dataset]

main(model=model, dataset=dataset, input_video=input_video)



Loads checkpoint by http backend from path: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth
创建临时文件夹 20230917224622 用于存放每帧预测结果
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 168/168, 68.4 task/s, elapsed: 2s, ETA:     0s[                                                  ] 0/168, elapsed: 0s, ETA:
删除临时文件夹 20230917224622
