In [3]:
import os
import torch
import torchvision
from tqdm import tqdm
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pycocotools.coco import COCO
from glob import glob
from datetime import datetime

from utils import get_temp_model, tempPredictDataset, IoU, clean_box, get_birds, plot_box, getDF
from utils import getListImg, cropImg, StreamArgs, getFrames, tempPredictVideo

In [4]:
# Release the unused GPU memory cached by PyTorch's allocator before the model is loaded below.
torch.cuda.empty_cache()

def get_transform():
    """Build the image preprocessing pipeline.

    Returns a torchvision ``Compose`` whose only step is ``ToTensor``.
    """
    return torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Custom collate function handed to the DataLoader to assemble each batch.
def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    ``[(img0, fn0), (img1, fn1)]`` becomes ``((img0, img1), (fn0, fn1))``.
    An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Inference settings: n_batch is the DataLoader batch size and num_classes
# is the class count passed to the model constructor.
n_batch, num_classes = 4, 2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Build the network, then restore the trained weights stored under the
# checkpoint's 'model' key.
model = get_temp_model(num_classes)
checkpoint_path = os.getcwd() + '/../models/bird_detection/outputs/models/output_model_temp_full_v0.pt'

# map_location remaps tensors that were saved on a GPU onto the selected
# device, so the checkpoint also loads when only a CPU is available.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])

model = model.to(device)
model.eval()  # switch the model to evaluation mode for inference
print('... model loaded.')

... model loaded.


In [8]:
# Each entry directly under the image root is treated as one recording folder.
# Sorting makes the listing order independent of the filesystem.
root_path = os.getcwd() + '/../data/birds/images'
folders = sorted(glob(root_path + '/*'))

# One representative file per folder. Entries with no content (empty folders,
# stray plain files) are skipped instead of raising IndexError on `[0]`.
list_files = []
for folder in folders:
    members = glob(folder + '/*')
    if members:
        list_files.append(members[0])

# Map camera number (1-8) -> list of recording dates.
# The camera number is the 2nd character of the file name and the date is the
# first 6 characters of the folder name.
dict_list = {i+1:[] for i in range(8)}
for f in list_files:
    camera = int(os.path.basename(f)[1])
    recording_date = os.path.basename(os.path.dirname(f))[:6]
    dict_list[camera].append(recording_date)

In [9]:
# Display the camera number -> recording dates mapping (dates are the first 6 characters of each folder name).
dict_list

{1: ['220429', '220507'],
 2: ['220430', '220508'],
 3: ['220501', '220503', '220509'],
 4: ['220502', '220510'],
 5: ['220501', '220503', '220511'],
 6: ['220504', '220511', '220512'],
 7: ['220505', '220509', '220513'],
 8: ['220506', '220507']}

In [24]:
# Run the detector over every recording of every camera and accumulate the
# per-batch result tables into `output_df`, which is re-written to CSV after
# each recording so an interrupted run keeps the finished recordings.
output_df = pd.DataFrame(columns = ['date', 'heure', 'methode', 'abondance', 'sous-semis', 'J_T', 'bois',
       'arrosage', 'n°cam'])

FRAMES_PER_SEGMENT = 930   # index at which a recording's frame list is split in two
SCORE_THRESHOLD = 0.6      # detections scoring at or below this are discarded
OUTPUT_CSV = 'metrics_cameras_test.csv'


def _camera_frame_paths(date, camera):
    """Return the sorted .png paths of the recording(s) made on `date` by `camera`.

    Several cameras can share a date prefix, so a plain `date*` glob would mix
    their folders. A folder is attributed to a camera by the 2nd character of
    its first listed file, the same rule used to build `dict_list`.
    NOTE(review): sorting is lexicographic; it assumes frame file names are
    zero-padded so that name order equals temporal order -- confirm.
    """
    selected = []
    for folder in sorted(glob(root_path + '/' + date + '*')):
        members = glob(folder + '/*')
        if members and os.path.basename(members[0])[1:2] == str(camera):
            selected.extend(sorted(glob(folder + '/*.png')))
    return selected


for cam_idx in range(8):

    c_num = cam_idx + 1

    print('*************************')
    print('Process camera', c_num)
    print('*************************')

    # ITERATE over the different dates
    for l, date in enumerate(dict_list[c_num]):

        paths = _camera_frame_paths(date, c_num)
        print('Video ', l+1, '/', len(dict_list[c_num]))

        # Split the frame list into its two parts; slicing never raises, so a
        # recording with fewer frames than the split index yields one part.
        segments = [paths[:FRAMES_PER_SEGMENT], paths[FRAMES_PER_SEGMENT:]]

        video_frames = []   # result tables of this recording, concatenated once below

        # ITERATE over both parts of the recording (left and right side).
        # NOTE(review): the original only ran the first part although it loaded
        # and named both -- confirm that both are meant to be processed.
        for segment_paths in segments:
            if not segment_paths:
                continue

            out_fn = os.path.basename(segment_paths[0])[:9] + date

            print('Load images...')
            # Only the part being processed is held in memory.
            segment_imgs = [cv2.imread(p) for p in segment_paths]
            print('... done.')

            ds = tempPredictVideo(segment_imgs, out_fn, get_transform())
            dataloader = torch.utils.data.DataLoader(ds,
                            batch_size = n_batch,
                            shuffle = False,
                            num_workers = 6,
                            collate_fn = collate_fn)

            for data in tqdm(dataloader):

                imgs, fns = data
                imgs = [img.float().to(device) for img in imgs]
                fns = list(fns)

                # Inference only: skip autograd bookkeeping to save memory.
                with torch.no_grad():
                    pred = model(imgs)

                # keep predictions of scores higher than the threshold
                boxes = [p['boxes'][p['scores'] > SCORE_THRESHOLD] for p in pred]

                # remove overlapping box predictions (see clean_box in utils)
                clean_boxes = clean_box(boxes)

                # collect the information of this batch
                df = getDF(clean_boxes, fns)
                video_frames.append(df)

        # DataFrame.append was removed in pandas 2.0 and copied the whole table
        # on every batch; concatenate once per recording instead.
        output_df = pd.concat([output_df] + video_frames)
        output_df.to_csv(OUTPUT_CSV, index = False)

*************************
Process camera 1
*************************
Video  0 / 2
Load images...
... done.


 33%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               | 78/233 [01:30<03:00,  1.16s/it]


KeyboardInterrupt: 

# Visualize the detections of the last processed batch

In [None]:
# Walk through the boxes kept for the most recent batch and plot each image
# that has at least one detection.
for i, kept in enumerate(boxes):
    np_box = kept.detach().cpu().numpy()
    # Channels 3:6 of the model input, moved to height x width x channel order.
    # NOTE(review): presumably the middle frame of a stacked input -- confirm.
    np_img = imgs[i][3:6].cpu().numpy().transpose(1, 2, 0)

    # nothing to draw when no bird was detected in this image
    if len(np_box) == 0:
        continue
    plot_box(np_img, np_box)