In [None]:
from tqdm import tqdm
import sys
import numpy as np
sys.path.append('/home/shenqi/Master_thesis')
import torch
from collections import defaultdict

from Neuronbench.utils.models import Vanilla
from Neuronbench.utils.dataset import seq_dataloader

from metavision_ml.detection.anchors import Anchors
from metavision_ml.detection.rpn import BoxHead
from metavision_ml.data import box_processing as box_api
from metavision_ml.detection.losses import DetectionLoss
from metavision_ml.metrics.coco_eval import CocoEvaluator
from metavision_sdk_core import EventBbox

In [None]:
# Build the backbone, anchor coder and detection head, then restore the
# trained weights from the ../train_RED checkpoints.
CHECKPOINT_DIR = '../train_RED/save_models'
CHECKPOINT_EPOCH = 25
DEVICE = torch.device('cuda')

dataloader = seq_dataloader()
model = Vanilla(cin=dataloader.in_channels, cout=256, base=16)
box_coder = Anchors(num_levels=model.levels, anchor_list="PSEE_ANCHORS", variances=[0.1, 0.2])
# Class count is len(wanted_keys) + 1, presumably the extra slot is the
# 'background' class that evaluation prepends to wanted_keys — confirm.
head = BoxHead(model.cout, box_coder.num_anchors, len(dataloader.wanted_keys) + 1, 0)

model = model.to(DEVICE)
head = head.to(DEVICE)
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/{CHECKPOINT_EPOCH}_model.pth', map_location=DEVICE))
head.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/{CHECKPOINT_EPOCH}_pd.pth', map_location=DEVICE))

# The restored networks are used only for inference in this notebook.
model.eval()
head.eval()


In [None]:
def accumulate_predictions(preds, targets, video_infos, frame_is_labeled,skip_us):
       
        dt_detections = {}
        gt_detections = {}
        for t in range(len(targets)):
            for i in range(len(targets[t])):
                gt_boxes = targets[t][i]
                pred = preds[t][i]

                video_info, tbin_start, _ = video_infos[i]

                if video_info.padding or frame_is_labeled[t, i] == False:
                    continue

                name = video_info.path
                if name not in dt_detections:
                    dt_detections[name] = [np.zeros((0), dtype=box_api.EventBbox)]
                if name not in gt_detections:
                    gt_detections[name] = [np.zeros((0), dtype=box_api.EventBbox)]
                assert video_info.start_ts == 0
                ts = tbin_start + t * video_info.delta_t

                if ts < skip_us:
                    continue

                if isinstance(gt_boxes, torch.Tensor):
                    gt_boxes = gt_boxes.cpu().numpy()
                if gt_boxes.dtype == np.float32:
                    gt_boxes = box_api.box_vectors_to_bboxes(gt_boxes[:, :4], gt_boxes[:, 4], ts=ts)

                if pred['boxes'] is not None and len(pred['boxes']) > 0:
                    boxes = pred['boxes'].cpu().data.numpy()
                    labels = pred['labels'].cpu().data.numpy()
                    scores = pred['scores'].cpu().data.numpy()
                    dt_boxes = box_api.box_vectors_to_bboxes(boxes, labels, scores, ts=ts)
                    dt_detections[name].append(dt_boxes)
                else:
                    dt_detections[name].append(np.zeros((0), dtype=EventBbox))

                if len(gt_boxes):
                    gt_boxes["t"] = ts
                    gt_detections[name].append(gt_boxes)
                else:
                    gt_detections[name].append(np.zeros((0), dtype=EventBbox))

        return dt_detections, gt_detections

def inference_epoch_end(outputs, classes=None, height=None, width=None):
    """Merge per-batch detections by video and run COCO evaluation.

    Args:
        outputs: iterable of dicts with ``'dt'`` and ``'gt'`` entries. Each
            entry maps a video path to a list of structured box arrays, in the
            form returned by ``accumulate_predictions``.
        classes: class names passed to ``CocoEvaluator``. Defaults to
            ``['background'] + dataloader.wanted_keys``.
        height: frame height for ``CocoEvaluator``. Defaults to
            ``dataloader.height``.
        width: frame width for ``CocoEvaluator``. Defaults to
            ``dataloader.width``.

    Returns:
        The result of ``CocoEvaluator.accumulate()``.
    """
    print('==> Start evaluation')
    dt_detections = defaultdict(list)
    gt_detections = defaultdict(list)
    for item in outputs:
        for name, chunks in item['gt'].items():
            gt_detections[name].extend(chunks)
        for name, chunks in item['dt'].items():
            dt_detections[name].extend(chunks)

    # Fall back to the notebook-global dataloader only for values the caller
    # did not supply, so this can be reused for another dataset.
    if classes is None:
        classes = ['background'] + dataloader.wanted_keys
    if height is None:
        height = dataloader.height
    if width is None:
        width = dataloader.width

    evaluator = CocoEvaluator(classes=classes, height=height, width=width)
    for name, gt_chunks in gt_detections.items():
        gt_boxes = np.concatenate(gt_chunks)
        dt_chunks = dt_detections.get(name)
        # A video with ground truth but no detection entry is scored as
        # "no detections" instead of crashing np.concatenate on [].
        if dt_chunks:
            dt_boxes = np.concatenate(dt_chunks)
        else:
            dt_boxes = np.zeros((0,), dtype=gt_boxes.dtype)
        evaluator.partial_eval([gt_boxes], [dt_boxes])
    return evaluator.accumulate()

In [None]:
# Run the restored detector over the test split, collect per-video detections
# and ground truth, then score them with COCO metrics.
SCORE_THRESH = 0.05
NMS_THRESH = 0.5
MAX_BOXES_PER_INPUT = 500
SKIP_US = 500000  # frames with timestamp below this are excluded from scoring

# Evaluate in inference mode (dropout off, batch-norm uses running statistics).
model.eval()
head.eval()
device = next(model.parameters()).device

output_val_list = []
with torch.no_grad():
    for data in tqdm(dataloader.seq_dataloader_test, desc='Testing', ncols=120):
        inputs = data['inputs'].to(device)
        feature = model(inputs)
        loc_preds_val, cls_preds_val = head(feature)
        scores = head.get_scores(cls_preds_val).cpu()
        # Box decoding is done on CPU tensors; build a new list rather than
        # overwriting the model's output container in place.
        feature_cpu = [level.cpu() for level in feature]
        preds = box_coder.decode(
            feature_cpu,
            inputs.cpu(),
            loc_preds_val.cpu(),
            scores,
            batch_size=inputs.shape[1],
            score_thresh=SCORE_THRESH,
            nms_thresh=NMS_THRESH,
            max_boxes_per_input=MAX_BOXES_PER_INPUT,
        )
        dt_dic, gt_dic = accumulate_predictions(
            preds, data["labels"], data["video_infos"], data["frame_is_labeled"], SKIP_US
        )
        output_val_list.append({'dt': dt_dic, 'gt': gt_dic})

# Score after the progress bar has closed.
coco_val_result = inference_epoch_end(output_val_list)
print(coco_val_result)