In [None]:
from tqdm import tqdm
import sys
import numpy as np
sys.path.append('/home/shenqi/Master_thesis')
import torch
from collections import defaultdict

from Neuronbench.utils.models import Vanilla
from Neuronbench.utils.dataset import seq_dataloader

from metavision_ml.detection.anchors import Anchors
from metavision_ml.detection.rpn import BoxHead
from metavision_ml.data import box_processing as box_api
from metavision_ml.detection.losses import DetectionLoss
from metavision_ml.metrics.coco_eval import CocoEvaluator
from metavision_sdk_core import EventBbox

In [None]:
# Configuration: one place to choose the device and the checkpoint to evaluate.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = '../train_RED/save_models'
CHECKPOINT_TAG = '25'  # filename prefix of the saved checkpoints (presumably the epoch) -- confirm

dataloader = seq_dataloader(num_tbins=1, batch_size=1)
model = Vanilla(cin=dataloader.in_channels, cout=256, base=16)
box_coder = Anchors(num_levels=model.levels, anchor_list="PSEE_ANCHORS", variances=[0.1, 0.2])
# Number of classes = wanted keys + 1 for 'background' (same layout as the label_map used for
# visualization: ['background'] + dataloader.wanted_keys). The final positional argument 0 is
# passed through unchanged -- NOTE(review): confirm its meaning in BoxHead's signature.
head = BoxHead(model.cout, box_coder.num_anchors, len(dataloader.wanted_keys) + 1, 0)
model = model.to(DEVICE)
head = head.to(DEVICE)
# NOTE: torch.load unpickles the file (arbitrary code execution on untrusted input);
# only load checkpoints produced by our own training runs.
model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/{CHECKPOINT_TAG}_model.pth', map_location=DEVICE))
head.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/{CHECKPOINT_TAG}_pd.pth', map_location=DEVICE))

In [None]:
def accumulate_predictions(preds, targets, video_infos, frame_is_labeled):
    """Group decoded detections and ground-truth boxes by video path.

    Frames whose video is padding, or whose ``frame_is_labeled[t, i]`` flag is falsy,
    are skipped entirely.

    Args:
        preds: nested sequence indexed ``preds[t][i]``; each entry is a dict with
            ``'boxes'`` (may be None or empty), ``'labels'`` and ``'scores'`` tensors.
        targets: nested sequence indexed ``targets[t][i]``. Each entry is a torch tensor
            or numpy array. float32 arrays are converted with ``box_vectors_to_bboxes``
            using columns ``[:, :4]`` as boxes and column ``4`` as labels; any other dtype
            is assumed to already be a structured box array with a ``"t"`` field
            (NOTE(review): confirm against the dataloader).
        video_infos: sequence indexed by ``i`` of ``(video_info, tbin_start, _)`` tuples;
            ``video_info`` must expose ``padding``, ``path``, ``start_ts`` and ``delta_t``.
        frame_is_labeled: array/tensor indexed ``[t, i]``.

    Returns:
        ``(dt_detections, gt_detections)``: two dicts keyed by ``video_info.path``.
        Each value is a list that starts with an empty ``EventBbox`` placeholder array,
        followed by one array per kept frame (so the first kept frame is at index 1).
        The frame timestamp is ``tbin_start + t * video_info.delta_t``: it is passed as
        ``ts=`` when building boxes and written into the ground-truth ``"t"`` field.
        The caller's label arrays are not modified.

    Raises:
        AssertionError: if a kept frame's ``video_info.start_ts`` is not 0.
    """
    dt_detections = {}
    gt_detections = {}
    for t in range(len(targets)):
        for i in range(len(targets[t])):
            gt_boxes = targets[t][i]
            pred = preds[t][i]

            video_info, tbin_start, _ = video_infos[i]

            if video_info.padding or not frame_is_labeled[t, i]:
                continue

            name = video_info.path
            # Each video's lists begin with an empty placeholder array.
            if name not in dt_detections:
                dt_detections[name] = [np.zeros((0,), dtype=EventBbox)]
            if name not in gt_detections:
                gt_detections[name] = [np.zeros((0,), dtype=EventBbox)]
            assert video_info.start_ts == 0
            ts = tbin_start + t * video_info.delta_t

            if isinstance(gt_boxes, torch.Tensor):
                gt_boxes = gt_boxes.cpu().numpy()
            if gt_boxes.dtype == np.float32:
                gt_boxes = box_api.box_vectors_to_bboxes(gt_boxes[:, :4], gt_boxes[:, 4], ts=ts)

            pred_boxes = pred['boxes']
            if pred_boxes is not None and len(pred_boxes) > 0:
                dt_detections[name].append(box_api.box_vectors_to_bboxes(
                    pred_boxes.detach().cpu().numpy(),
                    pred['labels'].detach().cpu().numpy(),
                    pred['scores'].detach().cpu().numpy(),
                    ts=ts))
            else:
                dt_detections[name].append(np.zeros((0,), dtype=EventBbox))

            if len(gt_boxes):
                # Stamp a copy so we never write into the caller's label array.
                gt_boxes = gt_boxes.copy()
                gt_boxes["t"] = ts
                gt_detections[name].append(gt_boxes)
            else:
                gt_detections[name].append(np.zeros((0,), dtype=EventBbox))

    return dt_detections, gt_detections

In [8]:
from functools import partial
from metavision_ml.detection_tracking.display_frame import draw_box_events
from skvideo.io import FFmpegWriter
import cv2

viz_labels = partial(draw_box_events, label_map=['background'] + dataloader.wanted_keys, thickness=2)

# Output mosaic: size_y rows x size_x columns of (height_scaled x width_scaled) panels.
# Left panel: input frame with ground-truth boxes. Right panel: same frame with model ('RED') detections.
size_x = 2
size_y = 1
height_scaled = 360
width_scaled = 640
frame = np.zeros((size_y * height_scaled, width_scaled * size_x, 3), dtype=np.uint8)
# Caption anchor (near the bottom-left corner of each panel) and colour.
text_org = (int(0.05 * width_scaled), int(0.94 * height_scaled))
text_color = (50, 240, 12)

model.eval()
head.eval()
# Run on whichever device the model was loaded onto rather than hard-coding 'cuda'.
device = next(model.parameters()).device

loader = dataloader.seq_dataloader_val
video_writer = FFmpegWriter('vis.mp4', outputdict={
    '-vcodec': 'libx264', '-crf': '20', '-preset': 'veryslow', '-r': '20'})
try:
    with torch.no_grad():
        for ind, data in enumerate(tqdm(loader, total=len(loader), desc='Testing', ncols=120)):
            # Ground-truth panel: render the first tbin of the first batch item, then draw labels.
            # The copy for the prediction panel is taken before GT boxes are drawn.
            img = loader.get_vis_func()(data['inputs'][0][0].cpu().numpy())
            img_RED = img.copy()
            img = viz_labels(img, data["labels"][0][0])

            # Forward pass on the model's device; box decoding happens on the CPU.
            feature = model(data['inputs'].to(device))
            loc_preds_val, cls_preds_val = head(feature)
            scores = head.get_scores(cls_preds_val).to('cpu')
            feature = [feat.to('cpu') for feat in feature]
            inputs = data['inputs'].to('cpu')
            loc_preds_val = loc_preds_val.to('cpu')
            preds = box_coder.decode(feature, inputs, loc_preds_val, scores, batch_size=inputs.shape[1],
                                     score_thresh=0.5, nms_thresh=0.5, max_boxes_per_input=500)

            dt_dic, gt_dic = accumulate_predictions(preds, data["labels"], data["video_infos"],
                                                    data["frame_is_labeled"])
            output_dt_RED = list(dt_dic.values())
            # accumulate_predictions skips padding/unlabeled frames, so dt_dic may be empty.
            # Index [1] skips the empty placeholder array that starts each per-video list.
            # NOTE(review): predictions are therefore never drawn on unlabeled frames -- confirm intended.
            dt_boxes = output_dt_RED[0][1] if output_dt_RED else []
            img_RED = viz_labels(img_RED, dt_boxes)

            metadata = loader.dataset.get_batch_metadata(ind)
            name = metadata[0][0].path.split('/')[-1]
            cv2.putText(img, name, text_org, cv2.FONT_HERSHEY_PLAIN, 1.2, text_color)
            cv2.putText(img_RED, 'RED', text_org, cv2.FONT_HERSHEY_PLAIN, 1.2, text_color)

            # Single-row mosaic: GT panel in column 0, prediction panel in column 1.
            frame[:height_scaled, :width_scaled] = img
            frame[:height_scaled, width_scaled:2 * width_scaled] = img_RED
            video_writer.writeFrame(frame)
finally:
    # Always finalize the mp4 and stop ffmpeg, even if the loop raises or is interrupted.
    video_writer.close()

Testing: 100%|██████████████████████████████████████████████████████████████████████| 1200/1200 [02:08<00:00,  9.37it/s]
