In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from utils import mAP_f1_p_fix_r
from utils import evaluate_scenes, predictions_to_scenes
from utils import get_frames, get_batches, scenes2zero_one_representation, visualize_predictions
import ffmpeg

import os
import pickle
from tqdm import tqdm

import numpy as np
import torch

In [10]:
### Collect the videos to run inference on, keyed by file stem -> file path.
### Add directory names to dir_list to include more videos.
VIDEO_EXTENSIONS = (".mp4", ".webm")

fnm_path_dict = {}
dir_list = ["custom/"] ### Add directories names here for inference

for cur_dir in dir_list:
    # os.listdir order is arbitrary; sorting makes the processing order (and which
    # file wins a stem clash) reproducible between runs.
    for fnm in sorted(os.listdir(cur_dir)):
        stem, ext = os.path.splitext(fnm)
        if ext not in VIDEO_EXTENSIONS:
            continue
        # os.path.join works whether or not cur_dir ends with a '/'.
        video_path = os.path.join(cur_dir, fnm)
        if stem in fnm_path_dict:
            # Two files share a stem (e.g. clip.mp4 + clip.webm, or the same name in two
            # directories). Only one can be kept, so report it instead of dropping silently.
            print("Warning: %r overrides %r for key %r" % (video_path, fnm_path_dict[stem], stem))
        fnm_path_dict[stem] = video_path

In [11]:
### max F1 - A
### Chosen AutoShot model
from supernet_flattransf_3_8_8_8_13_12_0_16_60 import TransNetV2Supernet

# Instantiate the chosen supernet and switch it to inference mode right away.
supernet_best_f1 = TransNetV2Supernet()
supernet_best_f1.eval();  # trailing ';' keeps the notebook from echoing the module repr

In [12]:
### Cuda / CPU: prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Pretrained model weights loading
pretrained_path = "./ckpt_0_200_0.pth"
if not os.path.exists(pretrained_path):
    # FileNotFoundError subclasses Exception, so any existing `except Exception` still catches it.
    raise FileNotFoundError("Error: Can NOT find pretrained best model!!")

print('Loading pretrained_path from %s' % pretrained_path)

model_dict = supernet_best_f1.state_dict() ### dictionary of current weights in model as an ordered_dict

# NOTE(review): torch.load unpickles the file and can execute arbitrary code —
# only load checkpoints from a trusted source.
checkpoint = torch.load(pretrained_path, map_location=device)
# This checkpoint keeps its weights under 'net'; fall back to treating the file
# itself as a bare state dict so plain `torch.save(model.state_dict())` files load too.
checkpoint_state = checkpoint["net"] if "net" in checkpoint else checkpoint

###! Only updates keys that are present in model_dict.. Important if finetuning
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
print("Current model has %d params, Updating %d params from checkpoint" % (len(model_dict), len(pretrained_dict)))

### Update weights present in model, then load all relevant weights into the best supernet
model_dict.update(pretrained_dict)
supernet_best_f1.load_state_dict(model_dict)

### Move the model to the selected device. Using `.to(device)` (rather than a
### hard-coded `.cuda(0)`) keeps it on the same device `predict` sends its inputs to.
if device == "cuda":
    print("Cuda available. Switching to cuda.")
supernet_best_f1 = supernet_best_f1.to(device)

Loading pretrained_path from ./ckpt_0_200_0.pth
Current model has 90 params, Updating 90 params from checkpoint
Cuda available. Switching to cuda.


In [13]:
supernet_best_f1.train(False)  # re-assert inference mode after loading weights (same as .eval())

# Evaluation
def predict(batch, model=None, target_device=None):
    """Run the shot-boundary network on one window of frames.

    Args:
        batch: numpy array for one window of frames. Axis 3 is moved to the
            front below, so it presumably arrives as (frames, height, width,
            channels) — TODO confirm against ``utils.get_batches``.
        model: network to run. Defaults to the notebook-level
            ``supernet_best_f1`` (looked up at call time).
        target_device: device the input is sent to. Defaults to the
            notebook-level ``device`` (looked up at call time).

    Returns:
        torch.Tensor: ``sigmoid`` of the model output (of its first element
        when the model returns a tuple), indexed at batch position 0 and left
        on ``target_device``. Computed under ``torch.no_grad()``, so it
        carries no autograd graph.
    """
    if model is None:
        model = supernet_best_f1
    if target_device is None:
        target_device = device

    # Move axis 3 to the front and add a leading batch axis of size 1.
    window = torch.from_numpy(batch.transpose((3, 0, 1, 2))[np.newaxis, ...])
    # Transfer first so only the compact input bytes cross to the device;
    # `* 1.0` keeps the original dtype promotion (e.g. uint8 -> float32).
    window = window.to(target_device) * 1.0

    # Pure inference: without no_grad each call would build an autograd graph
    # and keep its activations alive until the result is released.
    with torch.no_grad():
        one_hot = model(window)
        if isinstance(one_hot, tuple):
            one_hot = one_hot[0]
        return torch.sigmoid(one_hot[0])

In [14]:
### Run the model over every video and store per-frame probabilities by file stem.
# Only rows [25:75] of each window's output are kept. Presumably get_batches emits
# overlapping windows whose middle part is the reliable one (TransNetV2 convention)
# — TODO confirm in utils.get_batches.
WINDOW_KEEP_START, WINDOW_KEEP_END = 25, 75

supernet_best_f1_one_hot_pred_dict_custom = {}

# Inference only: no autograd graph is needed for any window.
with torch.no_grad():
    for fnm, video_path in tqdm(fnm_path_dict.items()):
        frames = get_frames(video_path)

        predictions = []
        for batch in get_batches(frames):
            one_hot = predict(batch).detach().cpu().numpy()
            predictions.append(one_hot[WINDOW_KEEP_START:WINDOW_KEEP_END])

        # Trim to the video's frame count; the concatenated windows can run past it.
        predictions = np.concatenate(predictions, 0)[:len(frames)]
        supernet_best_f1_one_hot_pred_dict_custom[fnm] = predictions

100%|██████████| 6/6 [00:29<00:00,  4.94s/it]


In [15]:
# ### Saving best model predictions - Use if saving is needed (./custom_pickles/ must already exist)
# with open('./custom_pickles/supernet_best_custom.pickle', 'wb') as handle:
#     pickle.dump(supernet_best_f1_one_hot_pred_dict_custom, handle, protocol=pickle.HIGHEST_PROTOCOL)
# # (no handle.close() needed — the `with` block closes the file)

In [16]:
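# ### Load saved predictions (instead of re-running inference above)
# ### NOTE: pickle.load can execute arbitrary code — only load files you created yourself.
# with open('./custom_pickles/supernet_best_custom.pickle', 'rb') as handle:
#     supernet_best_f1_one_hot_pred_dict_custom = pickle.load(handle)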

In [17]:
### Visualize predictions: one PNG per video with predicted transitions marked.

# Probability above which a frame is marked as a shot transition.
# NOTE(review): 0.296 is carried over unchanged; presumably tuned for this checkpoint — confirm its source.
PRED_THRESHOLD = 0.296
PRED_OUT_DIR = "./custom_preds/"

# Create the output directory once, up front (including any missing parents).
os.makedirs(PRED_OUT_DIR, exist_ok=True)

for fnm, video_path in tqdm(fnm_path_dict.items()):
    frames = get_frames(video_path)
    c_preds = (supernet_best_f1_one_hot_pred_dict_custom[fnm] > PRED_THRESHOLD).astype(np.uint8).flatten()

    img = visualize_predictions(
        frames,
        predictions=c_preds,
        show_frame_num=True
    )

    # The result of .save() is not used.
    img.save(os.path.join(PRED_OUT_DIR, fnm + ".png"))

100%|██████████| 6/6 [00:19<00:00,  3.25s/it]
