# Model inference setup

In [1]:
def label2idx(file):
    """Return the class index (as a string) encoded at the end of a clip name.

    Only the text after the last underscore is inspected:

    * a tag starting with ``A`` gives ``"0"``,
    * a tag starting with ``G`` gives ``"3"``,
    * any other tag gives its second character (e.g. ``B2-0-0.mp4`` -> ``"2"``).
    """
    tag = file.rsplit("_", 1)[-1]
    leading = tag[0]
    if leading == 'A':
        return "0"
    if leading == 'G':
        return "3"
    return tag[1]

from mmaction.models import build_model
from mmcv import Config
# from mmcv.runner import set_random_seed
import shutil
import os

def val_xdv(filename):
    """Run the recognizer on a single clip and return the raw test outputs.

    The clip is copied into a scratch directory next to a one-line annotation
    file, and the config's test dataset is pointed at that directory so the
    data loader contains only this clip. Both scratch files are removed
    afterwards, including when dataset construction or inference raises.

    NOTE(review): the model is rebuilt and the checkpoint reloaded on every
    call, and the scratch annotation file has a fixed name, so concurrent calls
    would overwrite each other's annotation.

    Args:
        filename: Base name of a video stored under the config's
            ``data.test.data_prefix`` directory. The annotation label is
            derived from the name with ``label2idx``.

    Returns:
        Whatever ``single_gpu_test`` returns for the one-clip data loader
        (presumably one score vector per clip -- confirm against mmaction).
    """
    # Local imports, as in the rest of this notebook cell.
    from mmaction.apis import single_gpu_test
    from mmaction.datasets import build_dataloader, build_dataset
    from mmaction.models import build_model
    from mmcv.parallel import MMDataParallel
    from mmcv.runner import load_checkpoint

    cfg_file = 'configs/recognition/swin/swin_small_patch244_window877_xdviolence_k400_1k.py'
    checkpoint_file = 'work_dirs/xdv/best_top1_acc_epoch_15.pth'

    cfg = Config.fromfile(cfg_file)
    cfg.model.cls_head.num_classes = 7

    # Build the recognizer and load the fine-tuned weights.
    model = build_model(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))

    print('load_checkpoint: ', checkpoint_file)
    load_checkpoint(model, checkpoint_file)

    # Scratch locations for the single-clip dataset.
    temp = '/home/ubuntu/swin-data/Video-Swin-Transformer/data/temp'
    print(cfg.data.test.data_prefix)
    source_video = f"/home/ubuntu/swin-data/Video-Swin-Transformer/{cfg.data.test['data_prefix']}{filename}"
    temp_video = temp + '/' + filename
    ann_file = temp + '/temp.txt'

    # Without the directory, shutil.copy would create a *file* called "temp".
    os.makedirs(temp, exist_ok=True)
    try:
        shutil.copy(source_video, temp)
        cfg.data.test['data_prefix'] = temp + '/'
        cfg.data.test['ann_file'] = ann_file
        with open(ann_file, 'w') as f:
            f.write(f"{filename} {label2idx(filename)}")

        # Build a test dataloader
        dataset = build_dataset(cfg.data.test, dict(test_mode=True))
        print(dataset)
        data_loader = build_dataloader(
            dataset,
            videos_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=False,
            shuffle=False)
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    finally:
        # Always clean up; a file that was never created must not mask the
        # original exception.
        for scratch_path in (temp_video, ann_file):
            try:
                os.remove(scratch_path)
            except FileNotFoundError:
                pass

    return outputs

def eval_xdv(outputs, cfg, dataset=None):
    """Score recognizer outputs and print each metric with four decimals.

    Args:
        outputs: Results to score; passed unchanged to ``dataset.evaluate``.
        cfg: Config exposing ``evaluation`` (keyword arguments for
            ``dataset.evaluate`` plus an optional ``interval`` entry) and,
            when ``dataset`` is omitted, ``data.test``.
        dataset: Object whose ``evaluate(outputs, **kwargs)`` returns a mapping
            of metric name to value. When omitted, the test dataset is built
            from ``cfg.data.test``.
    """
    if dataset is None:
        # No dataset is visible at module level, so build one explicitly
        # instead of relying on a leaked global.
        from mmaction.datasets import build_dataset
        dataset = build_dataset(cfg.data.test, dict(test_mode=True))

    # Work on a copy so cfg.evaluation is left intact and repeated calls with
    # the same cfg keep working; 'interval' is not an evaluate() argument.
    eval_config = dict(cfg.evaluation)
    eval_config.pop('interval', None)

    eval_res = dataset.evaluate(outputs, **eval_config)
    for name, val in eval_res.items():
        print(f'{name}: {val:.04f}')

# Demo (gradio)

In [2]:
# !pip install gradio
import gradio as gr
import os
# Directory whose entries are listed into `x` below.
# NOTE(review): absolute, machine-specific path -- listing it fails on any other host.
filepath = '/home/ubuntu/swin-data/Video-Swin-Transformer/data/xd-violence/test12/'
# Entry names found in `filepath`. Within this cell `x` is referenced only by the
# commented-out `gr.Radio(x, ...)` / `examples=x` alternatives of the interface.
x = os.listdir(filepath)

def get_inference(filename):
    """Classify the video behind ``filename`` and plot its seven class scores.

    Args:
        filename: Path of the video received from the interface input. Only
            its last ``/``-separated component is used.

    Returns:
        A matplotlib ``Figure`` with one bar per class, titled with the
        reconstructed clip name.
    """
    filename = filename.split('/')[-1]
    # Rebuild the clip's name in the dataset from the received file name: drop
    # the 40 characters that precede the 4-character extension, and put back
    # the '#' after '__'.
    # NOTE(review): presumably this undoes the renaming applied to uploaded
    # files (random suffix added, '#' and '=' stripped) -- confirm.
    filename = filename[:-44].replace('__', '__#') + filename[-4:]
    if filename[0] == 'v':
        # NOTE(review): looks like this restores the 'v=' prefix of some clip
        # names; it fires for any name starting with 'v' -- confirm.
        filename = 'v=' + filename[1:]
    print(filename)
    outputs = val_xdv(filename)

    # Use a standalone Figure instead of pyplot: figures created with
    # plt.figure() stay registered globally until closed, so a long-running
    # server would accumulate one per request.
    from matplotlib.figure import Figure
    fig = Figure()
    ax = fig.subplots()
    ax.set_title(filename)
    ax.bar(range(7), outputs[0], tick_label=['A:\nNormal', 'B1:\nFighting', 'B2:\nShooting', 'G:\nExplosion',
                                             'B4:\nRiot', 'B5:\nAbuse', 'B6:\nCar accident'])
    ax.grid(True)
    return fig  # filepath+filename, fig


# Gradio UI: a video input is passed to `get_inference`, and the figure it
# returns is rendered as a plot. The trailing comments keep an alternative
# variant (radio list over `x`, video + plot outputs, cached examples).
demo = gr.Interface(
    fn=get_inference, 
    inputs=gr.Video(), # gr.Radio(x, label="Input Video"),
    outputs=gr.Plot()) #["video", gr.Plot()])
    # examples=x, cache_examples=True)

# Start serving the interface with launch()'s default arguments.
demo.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Deadpool_2_2018__#0-04-46_0-05-01_label_B2-0-0.mp4
load_checkpoint:  work_dirs/xdv/best_top1_acc_epoch_15.pth
load checkpoint from local path: work_dirs/xdv/best_top1_acc_epoch_15.pth
data/xd-violence/test12/
<mmaction.datasets.video_dataset.VideoDataset object at 0x7fd76a151210>
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1/1, 0.4 task/s, elapsed: 2s, ETA:     0sSin_City_2005__#0-22-04_0-22-18_label_B5-0-0.mp4
load_checkpoint:  work_dirs/xdv/best_top1_acc_epoch_15.pth
load checkpoint from local path: work_dirs/xdv/best_top1_acc_epoch_15.pth
data/xd-violence/test12/
<mmaction.datasets.video_dataset.VideoDataset object at 0x7fd768240690>
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1/1, 0.6 task/s, elapsed: 2s, ETA:     0sv=yDqThVpu1AM__#1_label_B4-0-0.mp4
load_checkpoint:  work_dirs/xdv/best_top1_acc_epoch_15.pth
load checkpoint from local path: work_dirs/xdv/best_top1_acc_epoch_15.pth
data/xd-violence/test12/
<mmaction.datasets.video_dataset.VideoDataset object at 0x7fd768226c50>
[>>>>>>>>>>>>>>>>