In [None]:
import sys,os
sys.path.append('/home/shenqi/Master_thesis/SEED')
from tqdm import tqdm
from functools import partial
import numpy as np
import cv2
import torch
import torch.nn as nn
from metavision_ml.detection.anchors import Anchors
from metavision_ml.detection_tracking.display_frame import draw_box_events

from Model.feature_extractor import SEED_EventGRU,EMinGRU_ReLUFuseDownsampleConv_ConditionalConv
from Model.ssd_head import BoxHead
from Model.detection import inference_step,evaluate
from utils.dataloader import seq_dataloader
import utils.data_augmentation as data_aug
from skvideo.io import FFmpegWriter

# Dataset flavour and on-disk location; both names are reused by later cells.
dataset_type = 'gen4'
dataset_path = '/media/shenqi/data/Gen4_multi_timesurface_FromDat_super_small'

# Sequential loader configured with one time bin and a batch of one.
# The rendering cell further down only ever reads element [0][0] of each batch.
dataloader = seq_dataloader(
    dataset_path=dataset_path,
    dataset_type=dataset_type,
    num_tbins=1,
    batch_size=1,
    channels=6,
)


In [None]:

# --- Backbone, anchor coder and box head -------------------------------------
# `cout` is the backbone output width; the base width is one sixteenth of it.
cout = 256
net = EMinGRU_ReLUFuseDownsampleConv_ConditionalConv(
    dataloader.channels,
    base=cout // 16,
    cout=cout,
    dataset=dataset_type,
    pruning=False,
)
box_coder = Anchors(
    num_levels=net.levels,
    anchor_list='PSEE_ANCHORS',
    variances=[0.1, 0.2],
)
# One class more than the wanted keys; the label map built below likewise
# prepends a 'background' entry.
num_classes = len(dataloader.wanted_keys) + 1
ssd_head = BoxHead(net.cout, box_coder.num_anchors, num_classes, n_layers=0)

# --- Restore trained weights (backbone and head are stored separately) -------
net.load_state_dict(
    torch.load(
        '/home/shenqi/Master_thesis/SEED/Saved_Model/new_gen4/EMinGRU_ReLUFuseDownsampleConv_ConditionalConv_256_norelu_n15b8/48_model.pth',
        map_location=torch.device('cuda'),
    )
)
ssd_head.load_state_dict(
    torch.load(
        '/home/shenqi/Master_thesis/SEED/Saved_Model/new_gen4/EMinGRU_ReLUFuseDownsampleConv_ConditionalConv_256_norelu_n15b8/48_pd.pth',
        map_location=torch.device('cuda'),
    )
)

# Inference only: switch both modules to eval mode and move them to the GPU.
net.eval().to('cuda')
ssd_head.eval().to('cuda')

# Augmentation callable; the rendering cell applies it to labelled batches.
augment = data_aug.data_augmentation(dataset_type=dataset_type)

# Box-drawing helper with the label map and line thickness pre-bound.
viz_labels = partial(
    draw_box_events,
    label_map=['background'] + dataloader.wanted_keys,
    thickness=2,
)

# NOTE(review): only the output frame rate ('-r') is set here; confirm the
# input rate ffmpeg assumes matches it, otherwise frames may be dropped or
# duplicated during encoding.
video_writer = FFmpegWriter(
    'EMGU_condition.mp4',
    outputdict={
        '-vcodec': 'libx264',
        '-crf': '20',
        '-preset': 'veryslow',
        '-r': '20',
    },
)

# Output canvas: a size_y x size_x grid of panels, each height_scaled x
# width_scaled pixels with three uint8 channels.
size_x, size_y = 2, 1
height_scaled, width_scaled = 360, 640
canvas_shape = (size_y * height_scaled, size_x * width_scaled, 3)
frame = np.zeros(canvas_shape, dtype=np.uint8)


In [None]:
# Render a side-by-side video over the test loader. For every batch the left
# panel shows the input image with the labels carried by the batch, and the
# right panel shows the same image with the boxes taken from the 'dt' output of
# `inference_step`. Only time bin 0 of batch item 0 is drawn per batch.
#
# The writer is closed in a `finally` clause so that it is released even when
# a batch raises midway; previously the close call was only reached after a
# fully successful run.
test_loader = dataloader.seq_dataloader_test

# Fixed panel selection: first time bin (t) of the first batch item (index).
index = 0
t = 0
# Grid cell of the left panel inside `frame`; the right panel is one column over.
y, x = divmod(index, size_x)
rows = slice(y * height_scaled, (y + 1) * height_scaled)
left_cols = slice(x * width_scaled, (x + 1) * width_scaled)
right_cols = slice((x + 1) * width_scaled, (x + 2) * width_scaled)

try:
    with tqdm(total=len(test_loader)) as pbar, torch.no_grad():
        for ind, data in enumerate(test_loader):
            pbar.update(1)

            mask = data["mask_keep_memory"]
            metadata = test_loader.dataset.get_batch_metadata(ind)

            data['inputs'] = data['inputs'].to(device='cuda')
            # Augment only batches that contain at least one labelled frame.
            if data['frame_is_labeled'].sum().item() != 0:
                data = augment(data, only_vertical_move=True)

            # Read after augmentation so the drawn image is the tensor that is
            # actually passed to the model.
            batch = data["inputs"]

            output_val_emgu, *_ = inference_step(data, net, ssd_head, box_coder)
            value_output_val_dt_emguf = list(output_val_emgu['dt'].values())

            im = batch[t][index].cpu().numpy()
            img = test_loader.get_vis_func()(im)
            # Independent copy so labels and predictions are drawn separately.
            img_emgu = img.copy()

            if viz_labels is not None:
                img = viz_labels(img, data["labels"][t][index])
                # Draw element [1] of the first 'dt' entry, or nothing when the
                # model returned no entries.
                # NOTE(review): presumably [0][1] selects the boxes for the
                # displayed time bin -- confirm against inference_step.
                detections = value_output_val_dt_emguf[0][1] if value_output_val_dt_emguf else []
                img_emgu = viz_labels(img_emgu, detections)

            if t <= 1 and not mask[index]:
                # Mark the beginning of a sequence with a small square in the
                # top-left corner (channel 0 set to 222).
                img[:10, :10, 0] = 222

            # Stamp the recording's file name onto the left panel.
            name = metadata[index][0].path.split('/')[-1]
            cv2.putText(img, name, (int(0.05 * (width_scaled)), int(0.94 * (height_scaled))),
                        cv2.FONT_HERSHEY_PLAIN, 1.2, (50, 240, 12))

            # Both images must match the panel size (height_scaled x width_scaled).
            frame[rows, left_cols] = img
            frame[rows, right_cols] = img_emgu

            video_writer.writeFrame(frame)
finally:
    video_writer.close()