In [None]:
# Run the repository's setup script via the shell (its contents are not shown in this notebook).
!./setup.sh

In [None]:
# Run main.py on device 0, starting from the weights in ../checkpoint/50_wp.pt and writing
# results to ../checkpoint/0616/01/ (presumably a training / fine-tuning run — confirm in main.py).
# `-W ignore` suppresses all Python warnings emitted by the script.
!python -W ignore main.py   --device_ids 0 \
                            --pretrained_model ../checkpoint/50_wp.pt \
                            --saved_path ../checkpoint/0616/01/

In [None]:
# Run inference.py with Python warnings suppressed. No arguments are passed, so its
# inputs/outputs are not visible from this notebook — check the script itself.
!python -W ignore inference.py

In [None]:
# Run end2end_main.py with Python warnings suppressed. No arguments are passed, so its
# inputs/outputs are not visible from this notebook — check the script itself.
!python -W ignore end2end_main.py

In [None]:
# Evaluate a weighted ensemble of saved checkpoints on the validation split and
# print the GAP of every output head (video / audio / text / fusion).
import os
# Only takes effect if CUDA has not already been initialised in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
import copy
import yaml
import json
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import utils.train_util as train_util
from dataloader.dataloader import TestingDataset
from src.loss.loss_compute import SimpleLossCompute
from src.model.baseline_model import Baseline
from src.loop.run_epoch import training_loop,validating_loop
from dataloader.dataloader import MultimodaFeaturesDataset,Datasetfortextcnn
from torch.utils.data import DataLoader

batch_size = 8
device = 'cuda'
top_k = 20
tagging_class_num = 82
modal_name_list = ['video','audio','text']
# One prediction head per modality plus the fused head.
output_names = modal_name_list + ['fusion']

config_path = './config/config.yaml'
# yaml.load() without an explicit Loader is unsafe and a TypeError on PyYAML >= 6;
# the context manager also closes the file handle.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

# NOTE(review): 'valdation' looks like a typo, but MultimodaFeaturesDataset may match on this
# exact string, so it is kept unchanged — verify against the dataset class.
dataset = MultimodaFeaturesDataset(config['DatasetConfig'],job='valdation')
loader = DataLoader(dataset,num_workers=8,
                    batch_size=batch_size,
                    pin_memory=False,
                    collate_fn=dataset.collate_fn)

# Candidate checkpoints. The number after the epoch in each file name is presumably that
# checkpoint's validation score.
model_path_1 = '../checkpoint/0609/01/epoch_48 0.7888.pt'
model_path_2 = '../checkpoint/0609/01/epoch_30 0.7886.pt'
model_path_3 = '../checkpoint/0609/01/epoch_24 0.7871.pt'

model_path_4 = '../checkpoint/0609/02/epoch_46 0.7868.pt'
model_path_5 = '../checkpoint/0609/02/epoch_28 0.7856.pt'

model_path_6 = '../checkpoint/0608/01/epoch_86 0.7877.pt'
model_path_7 = '../checkpoint/0608/01/epoch_50 0.7873.pt'
model_path_8 = '../checkpoint/0608/01/epoch_28 0.7869.pt'

model_path_9 = '../checkpoint/0608/03/epoch_28 0.7877.pt'
model_path_10 = '../checkpoint/0608/03/epoch_52 0.7890.pt'
model_path_11 = '../checkpoint/0608/03/epoch_74 0.7896.pt'
models_path = [model_path_1,
               model_path_4,
               model_path_6,
               model_path_10,model_path_11]
# Weights are applied as-is (no renormalisation). The trailing numbers are the ensemble
# scores previously recorded for each weighting.
model_weights = [0.2,0.1,0.2,0.25,0.25] #0.791
# model_weights = [0.1,0.1,0.1,0.35,0.35] # 0.789
# model_weights = [0.2,0.2,0.2,0.2,0.2] # 0.7911
assert len(model_weights) == len(models_path), 'need exactly one weight per checkpoint'

# Runs ('<date>/<run>') whose checkpoints need fusion concat dim 30720 and 300 audio frames;
# every other run uses 29696 / 200.
RUNS_WITH_300_AUDIO_FRAMES = {'0608/01'}


def _load_model(path):
    """Build a Baseline matching checkpoint `path`, load its weights, return it in eval mode on `device`.

    `path` is expected to look like '../checkpoint/<date>/<run>/<file>.pt'.
    """
    # Per-model copy so one checkpoint's overrides never leak into the shared config
    # (or into earlier models that might hold a reference to it).
    model_config = copy.deepcopy(config['ModelConfig'])
    run_id = '/'.join(path.split('/')[2:4])
    long_audio = run_id in RUNS_WITH_300_AUDIO_FRAMES
    model_config['fusion_head_params']['concat_feat_dim']['fusion'] = 30720 if long_audio else 29696
    model_config['audio_head_params']['max_frames'] = 300 if long_audio else 200
    model = Baseline(model_config)
    # map_location makes loading independent of the device the checkpoint was saved from.
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    return model


models = [_load_model(path) for path in models_path]

evl_metrics = [train_util.EvaluationMetrics(tagging_class_num, top_k=top_k)
               for _ in output_names]
for metrics in evl_metrics:
    metrics.clear()
metric_dict = {}
gap_dict = {}
with torch.no_grad():
    for batch in tqdm(loader):
        # Batches are either (video, audio, text, text_mask, label) or,
        # without a text mask, (video, audio, text, label).
        if len(batch) == 5:
            video, audio, text, text_mask, label = batch
            attention_mask = text_mask.to(device)
        else:
            video, audio, text, label = batch
            attention_mask = None
        inputs_dict = {
            'video': video.to(device),
            'audio': audio.to(device),
            'text': text.to(device),
            'attention_mask': attention_mask,
        }
        # Labels are only consumed by the CPU-side metrics.
        val_label = label.cpu().numpy()

        # Weighted sum of every model's predictions, per output head.
        batch_len = inputs_dict['video'].shape[0]
        ensemble_preds = {name: torch.zeros(batch_len, tagging_class_num, device=device)
                          for name in output_names}
        for weight, model in zip(model_weights, models):
            pred_dict = model(inputs_dict)
            for name in output_names:
                ensemble_preds[name] += weight * pred_dict['tagging_output_' + name]['predictions']

        for metrics, name in zip(evl_metrics, output_names):
            metrics.accumulate(ensemble_preds[name].cpu().numpy(), val_label, loss=0)

for metrics, name in zip(evl_metrics, output_names):
    metric_dict[name] = metrics.get()
    gap_dict[name] = metric_dict[name]['gap']
print(gap_dict)



In [None]:
# Preview one video from the train_5k_A split inline in the notebook.
import os
import html

from IPython.display import display, HTML

dataset_root = '../dataset/videos/video_5k/train_5k/'
sample_index = 3000  # position (in sorted order) of the video to preview

# ########## get train_5k_A video file lists
videos_train_5k_A_dir = os.path.join(dataset_root, 'videos/train_5k_A')
# Sorted so sample_index selects the same video on every run (os.listdir order is arbitrary).
videos_train_5k_A_files = sorted(
    os.path.join(videos_train_5k_A_dir, f)
    for f in os.listdir(videos_train_5k_A_dir)
    if os.path.isfile(os.path.join(videos_train_5k_A_dir, f))
)

print("videos_train_5k_A_dir= {}".format(videos_train_5k_A_dir))
print("len(videos/train_5k_A)= {}".format(len(videos_train_5k_A_files)))

# ########## display
if not 0 <= sample_index < len(videos_train_5k_A_files):
    raise IndexError('sample_index {} out of range: {} contains {} files'.format(
        sample_index, videos_train_5k_A_dir, len(videos_train_5k_A_files)))
test_video_path = videos_train_5k_A_files[sample_index]
print(test_video_path)
print(os.path.exists(test_video_path))
# NOTE(review): the browser resolves src, presumably relative to the notebook's URL, so a
# path outside the notebook server's root may not load. Escaped so quotes/& in the path
# cannot break the attribute.
html_str = '''
<video controls width="500" height="500" src="{}">animation</video>
'''.format(html.escape(test_video_path, quote=True))
print(html_str)
display(HTML(html_str))