# Inference

In [1]:
# Importation et modification des limites de taille d'images ouvrables avec opencv et PIL
import ast
import os
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(pow(2,40))
import pathlib

import neptune
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.models.detection.transform import GeneralizedRCNNTransform

from scripts.datasets import ObjectDetectionDatasetSingle, ObjectDetectionDataSet
from scripts.faster_RCNN import get_fasterRCNN_resnet
from scripts.transformations import ComposeDouble
from scripts.transformations import ComposeSingle
from scripts.transformations import FunctionWrapperDouble
from scripts.transformations import FunctionWrapperSingle
from scripts.transformations import apply_nms, apply_score_threshold
from scripts.transformations import normalize_01
from scripts.utils import get_filenames_of_path, collate_single, save_json
from scripts.visual import DatasetViewer
from scripts.visual import DatasetViewerSingle
from scripts.balayage import *
from skimage.io import imread
# Import the Image submodule explicitly: a bare `import PIL` does not load
# PIL.Image, so the attribute access below would only work when another
# library happened to import it first.
import PIL.Image
# Raise Pillow's pixel-count limit so that very large images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = 500000000

In [2]:
# Settings to adapt to the model being used.
# Locations of the images, the predictions and the checkpoint.
params = {'INPUT_DIR': 'Inférence/Input',  # files to predict
          'PREDICTIONS_PATH': 'Inférence/Output',  # where to save the predictions
          'MODEL_DIR': 'Inférence/model/model.ckpt',  # load model from checkpoint
          }
# Model hyper-parameters. The string values are parsed with ast.literal_eval
# in the cell that builds the model.
parameters = {'BACKBONE': 'resnet34',  # ideally one of resnet18, resnet34, resnet50
          # The model-building cell reads parameters['CLASSES']; the key was missing,
          # which raised a KeyError on a fresh kernel.
          # NOTE(review): 2 assumes background + the single class listed in
          # color_mapping — confirm against the trained checkpoint.
          'CLASSES': 2,
          'FPN': 'False',
          'ANCHOR_SIZE': '((32, 64, 128, 256, 512),)',
          'ASPECT_RATIOS': '((0.5, 1.0, 2.0),)',
          'MIN_SIZE': 1024,
          'MAX_SIZE': 1024,
          'IMG_MEAN': '[0.485, 0.456, 0.406]',
          'IMG_STD': '[0.229, 0.224, 0.225]',
          'IOU_THRESHOLD': 0.5
          }

In [5]:
# Collect the image files found in the input folder, in sorted order
input_dir = pathlib.Path(params['INPUT_DIR'])
inputs = get_filenames_of_path(input_dir)
inputs.sort()

In [5]:
# Transformations handed to the inference helper:
# move the last axis to the front, then apply normalize_01
move_last_axis_first = FunctionWrapperSingle(np.moveaxis, source=-1, destination=0)
normalize_step = FunctionWrapperSingle(normalize_01)
transforms = ComposeSingle([move_last_axis_first, normalize_step])

In [24]:
# Load the checkpoint and extract the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU be opened on a
# machine without CUDA; the inference cell of this notebook runs on the CPU.
# NOTE(review): torch.load unpickles arbitrary Python objects (here the model
# object stored under 'hyper_parameters') — only open checkpoints from a
# trusted source.
checkpoint = torch.load(params['MODEL_DIR'], map_location='cpu')
model_state_dict = checkpoint['hyper_parameters']['model'].state_dict()

In [25]:
# Build the Faster R-CNN model from the hyper-parameters.
# Values stored as strings are converted to Python objects with ast.literal_eval.
model_kwargs = dict(
    num_classes=int(parameters['CLASSES']),
    backbone_name=parameters['BACKBONE'],
    anchor_size=ast.literal_eval(parameters['ANCHOR_SIZE']),
    aspect_ratios=ast.literal_eval(parameters['ASPECT_RATIOS']),
    fpn=ast.literal_eval(parameters['FPN']),
    min_size=int(parameters['MIN_SIZE']),
    max_size=int(parameters['MAX_SIZE']),
)
model = get_fasterRCNN_resnet(**model_kwargs)

In [26]:
# Assign the weights extracted from the checkpoint to the model.
# Left as the last expression of the cell so that the key-matching report is displayed.
model.load_state_dict(model_state_dict)

<All keys matched successfully>

### Inférence

In [None]:
# Inference (CPU): for every input image, run the sweep helpers and save one
# JSON prediction file named after the image.
output_dir = pathlib.Path(params['PREDICTIONS_PATH'])
# Create the destination up front so that saving cannot fail on a missing folder.
output_dir.mkdir(parents=True, exist_ok=True)

model.eval()
# Gradients are not needed for prediction; disabling them saves memory and time.
with torch.no_grad():
    for filename in inputs:
        # asarray avoids duplicating an image that is already a NumPy array
        x = np.asarray(imread(filename))
        # NOTE(review): (1024, 1024) and 512 are presumably the tile size and the
        # step of the sweep — confirm in scripts.balayage.
        dic = balayage_inference_single(x, (1024, 1024), 512)
        preds = inference_on_balayage(dic, model, transforms)
        json_path = (output_dir / os.path.basename(filename)).with_suffix('.json')
        save_json(preds, path=json_path)
        print('saved file for image: '+str(filename))

### Affichage des résultats

In [6]:
# Collect the prediction files, in sorted order.
# The folder is read from the shared configuration rather than from a second
# hard-coded copy of the path, so both stay in sync when the config changes.
predictions = get_filenames_of_path(pathlib.Path(params['PREDICTIONS_PATH']))
predictions.sort()

In [11]:
# Build the prediction dataset from the input images and the saved prediction files
iou_threshold = 0.25  # intersection-over-union threshold handed to apply_nms
score_threshold = 0.95  # score threshold handed to apply_score_threshold

# The first two steps use the wrapper's default input/target selection; the last
# two are applied with input=False, target=True.
prediction_steps = [
    FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),
    FunctionWrapperDouble(normalize_01),
    FunctionWrapperDouble(apply_nms, input=False, target=True, iou_threshold=iou_threshold),
    FunctionWrapperDouble(apply_score_threshold, input=False, target=True, score_threshold=score_threshold),
]
transforms_prediction = ComposeDouble(prediction_steps)

dataset_prediction = ObjectDetectionDataSet(
    inputs=inputs,
    targets=predictions,
    transform=transforms_prediction,
    use_cache=False,
)

In [12]:
# Colour mapping: key 1 is associated with 'red'
color_mapping = {1: 'red'}

In [13]:
# Visualize the predictions: wrap the prediction dataset and the colour mapping
# in a DatasetViewer and call its napari() method.
datasetviewer_prediction = DatasetViewer(dataset_prediction, color_mapping)
datasetviewer_prediction.napari()
# Add the text-properties GUI for the viewer's shape layer
# (shape_layer is read after napari() has been called — keep this order).
datasetviewer_prediction.gui_text_properties(datasetviewer_prediction.shape_layer)